## Notebook: Tea_Grading_Training.ipynb
## Folder: notebooks/
## Purpose: Interactive training and evaluation for Tea Grading AI Model

### imports

In [10]:
import copy
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from torchvision import transforms, models

### load dataset

In [12]:
class TeaDataset(torch.utils.data.Dataset):
    """Image dataset yielding ``(image, grade_label, quality_label)`` triples.

    The directory tree under ``root_dir`` is expected to look like::

        root_dir/<GRADE>/quality_<N>/<image>.{jpg,jpeg,png}

    where ``<GRADE>`` is one of the keys of ``grade_map`` and ``<N>`` is a
    positive integer. ``quality_label`` is the zero-based index ``N - 1``.

    Args:
        root_dir: path of the folder that contains the grade folders.
        transform: optional callable applied to each RGB ``PIL.Image`` before
            it is returned by ``__getitem__``.

    Attributes:
        samples: list of ``(img_path, grade_label, quality_label)`` tuples in
            sorted (deterministic) path order.
        grade_map: mapping from grade folder name to integer class index.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # map tea grades for classification
        self.grade_map = {
            "OP": 0,
            "OP1": 1,
            "OPA": 2,
            "NOT_TEA": 3
        }

        # go through all main grade folders in the dataset
        # (sorted so that sample order does not depend on the filesystem's listing order)
        for grade_name in sorted(os.listdir(root_dir)):
            grade_path = os.path.join(root_dir, grade_name)

            # ignore unknown folder names and stray files
            if not os.path.isdir(grade_path) or grade_name not in self.grade_map:
                continue

            grade_label = self.grade_map[grade_name]

            # go through subfolders (quality) in the main folder
            for quality_folder in sorted(os.listdir(grade_path)):
                quality_path = os.path.join(grade_path, quality_folder)

                # convert subfolder names to indexes
                quality_num = self._parse_quality(quality_folder)

                # ignore unknown folder names and anything that is not a folder
                if quality_num is None or not os.path.isdir(quality_path):
                    continue

                # go through image files in the subfolder
                for img_name in sorted(os.listdir(quality_path)):
                    if img_name.lower().endswith((".jpg", ".jpeg", ".png")):
                        img_path = os.path.join(quality_path, img_name)

                        # add image_path, grade_label and quality_label to samples
                        self.samples.append((img_path, grade_label, quality_num))

        print(f"total samples found: {len(self.samples)}")

    @staticmethod
    def _parse_quality(folder_name):
        """Return the zero-based quality index for ``quality_<N>``, else ``None``.

        ``None`` is returned when the name does not start with ``quality_``,
        when ``<N>`` is not an integer, or when the resulting index would be
        negative (a negative value is not a usable class label).
        """
        if not folder_name.startswith("quality_"):
            return None
        try:
            quality_index = int(folder_name.split('_')[1]) - 1
        except ValueError:
            return None
        return quality_index if quality_index >= 0 else None

    # get total number of images in the dataset
    def __len__(self):
        return len(self.samples)

    # ready images for model input
    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, grade_label, quality_label)``."""
        img_path, grade_label, quality_label = self.samples[idx]

        # load and convert image into RGB; the context manager closes the file
        # handle as soon as the pixel data has been copied by convert()
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # apply transformations
        if self.transform:
            image = self.transform(image)

        return image, grade_label, quality_label

### transforms and dataloader

In [13]:
# fixed seed so the train/validation split is identical after every kernel restart
SPLIT_SEED = 42

# channel mean / std used to normalize inputs (the standard ImageNet statistics,
# matching the ImageNet-pretrained backbone used by the model)
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# define image transformation for training inputs (includes random augmentation)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# deterministic transformation for validation: same preprocessing, no random flip
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

# dataset path
DATASET_PATH = "../dataset/images"

# create dataset to train
dataset = TeaDataset(DATASET_PATH, transform=transform)

# split dataset into train and validation sets (80% train, 20% validation)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_split = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# the validation subset reads the same sample list through a shallow copy of the
# dataset whose transform has no augmentation, so evaluation is deterministic
val_view = copy.copy(dataset)
val_view.transform = eval_transform
val_dataset = torch.utils.data.Subset(val_view, val_split.indices)

# collect grade labels to handle imbalances
# (read from the sample list directly; indexing the dataset would decode and
# transform every training image just to obtain an integer label)
train_labels = torch.tensor(
    [dataset.samples[idx][1] for idx in train_dataset.indices],
    dtype=torch.long
)

# boost weight on tea classes
class_sample_weights = torch.tensor([3.0, 3.0, 3.0, 1.0])  # OP, OP1, OPA, NOT_TEA

# create sampler to balance classes during training
sample_weights = class_sample_weights[train_labels]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# use sampler in DataLoader instead of shuffle
train_loader = DataLoader(train_dataset, batch_size=16, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

total samples found: 6915


### model definition

In [14]:
# neural network for tea grading and quality prediction
class TeaNet(nn.Module):
    """ResNet-18 backbone with two linear heads: tea grade and tea quality.

    Args:
        num_grades: number of grade classes (OP, OP1, OPA, NOT_TEA by default).
        num_qualities: number of quality classes (quality 1-10 by default).
    """

    def __init__(self, num_grades=4, num_qualities=10):
        super().__init__()

        # use a pretrained ResNet-18 for feature extraction (to leverage learned visual features)
        # `pretrained=True` is deprecated in newer torchvision releases in favour of the
        # `weights=` enum; fall back to the old keyword when the enum is not available
        if hasattr(models, "ResNet18_Weights"):
            self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        else:
            self.backbone = models.resnet18(pretrained=True)

        # get number of features
        num_features = self.backbone.fc.in_features

        # remove the original ResNet connected layer
        self.backbone.fc = nn.Identity()

        # classification head for grade and quality
        self.grade_head = nn.Linear(num_features, num_grades)   # OP, OP1, OPA, NOT_TEA
        self.quality_head = nn.Linear(num_features, num_qualities)  # quality 1-10

    def forward(self, x):
        """Return ``(grade_out, quality_out)`` for a batch of images.

        Both outputs are the raw outputs of the linear heads (logits); no
        softmax is applied here.
        """
        features = self.backbone(x)

        grade_out = self.grade_head(features)
        quality_out = self.quality_head(features)

        return grade_out, quality_out


### training steps

In [15]:
# run on the GPU when torch can see one, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# how many full passes to make over the training data
epochs = 25

# build the two-headed network and place it on the chosen device
model = TeaNet(num_grades=4, num_qualities=10)
model = model.to(device)

# grade loss: cross entropy with per-class weights (OP, OP1, OPA, NOT_TEA);
# the NOT_TEA class contributes a quarter of the weight of each tea class
class_weights = torch.tensor([1.0, 1.0, 1.0, 0.25], device=device)
criterion_grade = nn.CrossEntropyLoss(weight=class_weights)

# quality loss: plain, unweighted cross entropy
criterion_quality = nn.CrossEntropyLoss()

# Adam updates every parameter of the model (backbone and both heads)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

### training loop

In [16]:
# grade index of the NOT_TEA class; quality loss is skipped for these samples
# (defined once here because it never changes between batches)
not_tea_label = 3

for epoch in range(epochs):
    # set model to training mode
    model.train()

    # summed loss over all batches of the current epoch
    total_loss = 0.0

    # go through training data
    for images, grade_labels, quality_labels in train_loader:
        # move data to the device (GPU/CPU)
        images = images.to(device)
        grade_labels = grade_labels.to(device)
        quality_labels = quality_labels.to(device)

        # reset gradients from the previous step
        optimizer.zero_grad()

        # get predictions for grade and quality
        grade_preds, quality_preds = model(images)

        # calculate loss for grade classification
        loss_grade = criterion_grade(grade_preds, grade_labels)

        # calculate loss for quality classification (only for OP, OP1, OPA)
        mask = (grade_labels != not_tea_label)
        if mask.any():
            loss_quality = criterion_quality(quality_preds[mask], quality_labels[mask])
        else:
            # batch contains only NOT_TEA images: nothing to learn for the quality head
            loss_quality = torch.tensor(0.0, device=device)

        loss = loss_grade + loss_quality

        # backpropagation to calculate gradients
        loss.backward()

        # update model parameters
        optimizer.step()

        # accumulate batch loss
        total_loss += loss.item()

    # the summed loss grows with the number of batches, so also report the
    # per-batch mean, which stays comparable if dataset or batch size changes
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"epoch {epoch+1}/{epochs} - training loss: {total_loss:.4f} (avg per batch: {avg_loss:.4f})")

epoch 1/25 - training loss: 202.9434
epoch 2/25 - training loss: 18.4998
epoch 3/25 - training loss: 8.7080
epoch 4/25 - training loss: 6.6258
epoch 5/25 - training loss: 7.2024
epoch 6/25 - training loss: 5.5314
epoch 7/25 - training loss: 4.3386
epoch 8/25 - training loss: 3.1436
epoch 9/25 - training loss: 1.7929
epoch 10/25 - training loss: 20.0953
epoch 11/25 - training loss: 3.4812
epoch 12/25 - training loss: 2.9847
epoch 13/25 - training loss: 3.2718
epoch 14/25 - training loss: 0.8392
epoch 15/25 - training loss: 1.4619
epoch 16/25 - training loss: 4.1413
epoch 17/25 - training loss: 9.6820
epoch 18/25 - training loss: 4.2149
epoch 19/25 - training loss: 4.4106
epoch 20/25 - training loss: 4.4802
epoch 21/25 - training loss: 1.4725
epoch 22/25 - training loss: 0.8164
epoch 23/25 - training loss: 0.6545
epoch 24/25 - training loss: 0.7505
epoch 25/25 - training loss: 0.3486


### validation

In [17]:
# grade index of the NOT_TEA class; quality accuracy is only measured on tea images
not_tea_label = 3

# accuracy calculation counters
correct_grade = 0
total_grade = 0
correct_quality = 0
total_quality = 0

# switch to evaluation mode: without this the model is still in training mode
# after the training loop, so BatchNorm would normalize with per-batch statistics
# and update its running statistics from the validation images
model.eval()

# disable gradient computation
with torch.no_grad():
    # go through validation set
    for images, grade_labels, quality_labels in val_loader:
        # move data to the device (GPU/CPU)
        images = images.to(device)
        grade_labels = grade_labels.to(device)
        quality_labels = quality_labels.to(device)

        # pass through the model to get prediction
        grade_preds, quality_preds = model(images)

        # get predicted grade and quality (index of the largest output per row)
        grade_predicted = torch.argmax(grade_preds, 1)
        quality_predicted = torch.argmax(quality_preds, 1)

        # grade accuracy including NOT_TEA
        total_grade += grade_labels.size(0)
        correct_grade += (grade_predicted == grade_labels).sum().item()

        # quality accuracy for real tea images
        mask = (grade_labels != not_tea_label)
        if mask.any():
            # count correct quality predictions (only for tea images)
            correct_quality += (quality_predicted[mask] == quality_labels[mask]).sum().item()
            total_quality += mask.sum().item()

# guard both ratios so an empty validation set reports a message instead of
# raising ZeroDivisionError
if total_grade > 0:
    print(f"grade accuracy: {100 * correct_grade / total_grade:.2f} %")
else:
    print("no samples found for grade evaluation")
if total_quality > 0:
    print(f"quality accuracy: {100 * correct_quality / total_quality:.2f} %")
else:
    print("no tea samples found for quality evaluation")

grade accuracy: 99.71 %
quality accuracy: 99.92 %


### save trained model

In [18]:
# where the trained weights are written (v3 of the saved model file)
model_dir = "../saved_models"
model_path = "../saved_models/tea_grading_model_v3.pth"

# make sure the target folder exists; an existing folder is not an error
os.makedirs(model_dir, exist_ok=True)

# persist only the learned parameters (state dict), not the whole module object
trained_weights = model.state_dict()
torch.save(trained_weights, model_path)

print("Model saved.")

Model saved.


### test prediction

In [19]:
# predict grade and quality for a single image
def predict(image_path, model, image_transform=None):
    """Predict the tea grade and quality for one image file.

    Args:
        image_path: path of the image to classify.
        model: network returning ``(grade_out, quality_out)`` for an image batch.
        image_transform: optional callable turning an RGB ``PIL.Image`` into a
            tensor. Defaults to the training preprocessing (resize to 224x224,
            tensor conversion, normalization) without the random horizontal
            flip, so repeated calls on the same image see the same input.

    Returns:
        ``(grade_name, quality)`` where ``grade_name`` is one of ``"OP"``,
        ``"OP1"``, ``"OPA"``, ``"NOT_TEA"`` and ``quality`` is the one-based
        quality class. ``quality`` is ``None`` when the image is classified as
        ``"NOT_TEA"``, because the quality head is not trained on those images.

    Side effects:
        Leaves ``model`` in evaluation mode.
    """
    if image_transform is None:
        image_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    # set model to evaluation mode
    model.eval()

    # run on whatever device the model's parameters live on instead of relying
    # on a notebook-level variable
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # load image and convert to RGB (the context manager closes the file handle)
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # preprocess and add the batch dimension
    batch = image_transform(image).unsqueeze(0).to(model_device)

    # disable gradient computation for inference
    with torch.no_grad():
        # pass through the model
        grade_out, quality_out = model(batch)

    # get predicted grade index
    grade = torch.argmax(grade_out, 1).item()

    # map grade index to text (index 3 is NOT_TEA, matching the training labels)
    grade_map = {0: "OP", 1: "OP1", 2: "OPA", 3: "NOT_TEA"}
    grade_name = grade_map[grade]

    # a quality score is meaningless for something that is not tea
    if grade_name == "NOT_TEA":
        return grade_name, None

    # quality classes are zero-based internally; report them as 1..N
    quality = torch.argmax(quality_out, 1).item() + 1

    return grade_name, quality

# quick sanity check on one image from the OP1 / quality_4 folder
# (other samples tried earlier: OP/quality_1/DSC00145.JPG, OPA/quality_7/DSC03637.JPG)
sample_image_path = "../dataset/images/OP1/quality_4/DSC08756.JPG"
predict(sample_image_path, model)

('OP1', 4)