# Notebook: Tea_Grading_Training.ipynb
# Folder: notebooks/
# Purpose: Interactive training and evaluation for Tea Grading AI Model

# ---------- imports ----------

In [1]:
import copy
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, models

# ---------- load dataset ----------

In [None]:
class TeaDataset(torch.utils.data.Dataset):
    """Image dataset labelled with a tea grade and a quality level.

    The directory tree under ``root_dir`` is expected to look like::

        root_dir/<grade>/quality_<n>/<image file>

    where ``<grade>`` is one of the keys of ``grade_map`` and ``<n>`` is a
    1-based integer. Every image becomes one entry of ``samples`` as a
    ``(image_path, grade_label, quality_label)`` tuple, with
    ``quality_label == n - 1``.

    Entries that do not fit this layout (unknown grade folders, folders not
    named ``quality_<positive int>``, stray files, non-image files) are
    skipped.

    Args:
        root_dir: Directory containing one sub-folder per grade.
        transform: Optional callable applied to the loaded RGB PIL image
            in ``__getitem__``.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        self.grade_map = {"OP": 0, "OP1": 1, "OPA": 2}

        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order (and any index-based split) reproducible.
        for grade_name in sorted(os.listdir(root_dir)):
            grade_path = os.path.join(root_dir, grade_name)
            if not os.path.isdir(grade_path) or grade_name not in self.grade_map:
                continue

            grade_label = self.grade_map[grade_name]

            for quality_folder in sorted(os.listdir(grade_path)):
                quality_num = self._parse_quality(quality_folder)
                if quality_num is None:
                    continue

                quality_path = os.path.join(grade_path, quality_folder)
                # A plain file called "quality_<n>" must not be listed as a
                # directory.
                if not os.path.isdir(quality_path):
                    continue

                for img_name in sorted(os.listdir(quality_path)):
                    if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                        img_path = os.path.join(quality_path, img_name)
                        self.samples.append((img_path, grade_label, quality_num))

        print(f"total samples found: {len(self.samples)}")

    @staticmethod
    def _parse_quality(folder_name):
        """Return the 0-based quality label for ``quality_<n>``, else None."""
        if not folder_name.startswith("quality_"):
            return None
        try:
            level = int(folder_name.split("_")[1])
        except ValueError:
            return None
        # Levels below 1 would produce a negative class index.
        if level < 1:
            return None
        return level - 1

    def __len__(self):
        """Return the number of images found under ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, grade_label, quality_label)``."""
        img_path, grade_label, quality_label = self.samples[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been converted into a new RGB image.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        return image, grade_label, quality_label

# ---------- transforms and dataloader ----------

In [4]:
# Channel statistics handed to Normalize; shared by both pipelines below so
# training and evaluation inputs are scaled identically.
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: the random flip is data augmentation and must only be
# applied to training samples.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Deterministic pipeline for validation and single-image inference. It keeps
# the name `transform` because later cells use that name to preprocess
# individual images.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

DATASET_PATH = "../dataset/images"

dataset = TeaDataset(DATASET_PATH, transform=train_transform)

# Shallow copy: shares the already-scanned sample list with `dataset` (no
# second directory walk) but applies the augmentation-free pipeline.
eval_dataset = copy.copy(dataset)
eval_dataset.transform = transform

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

# Seed the split so the same images land in train/val on every run (given
# that the dataset enumerates its files in the same order); otherwise a
# re-run could validate on images an earlier run trained on.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_split = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
# Same validation indices, but read through the augmentation-free view.
val_dataset = torch.utils.data.Subset(eval_dataset, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

Total samples found: 6113


# ---------- model definition ----------

In [5]:
class TeaNet(nn.Module):
    """ResNet-18 feature extractor feeding two linear classification heads.

    The backbone's own final ``fc`` layer is replaced by ``nn.Identity`` so
    the backbone outputs its pooled feature vector, which both heads consume.

    ``forward`` returns ``(grade_logits, quality_logits)`` with 3 and 10
    output units respectively.
    """

    def __init__(self):
        super().__init__()

        resnet = models.resnet18(pretrained=True)
        # Remember the width of the features before removing the classifier.
        feature_dim = resnet.fc.in_features
        resnet.fc = nn.Identity()
        self.backbone = resnet

        # One head per task, both reading the same shared features.
        self.grade_head = nn.Linear(feature_dim, 3)  # OP, OP1, OPA
        self.quality_head = nn.Linear(feature_dim, 10)  # quality 1-10

    def forward(self, x):
        """Run the backbone once and apply both heads to its output."""
        shared = self.backbone(x)
        return self.grade_head(shared), self.quality_head(shared)

# ---------- training setup (device, model, loss, optimizer) ----------

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = TeaNet().to(device)

# A single cross-entropy criterion is reused for both classification heads.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Number of full passes over the training data.
epochs = 10




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\User/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth


100.0%


# ---------- training loop ----------

In [None]:
for epoch in range(epochs):
    model.train()
    # Sum (not mean) of the per-batch losses over the epoch, so the printed
    # value scales with the number of batches.
    total_loss = 0

    for images, grade_labels, quality_labels in train_loader:
        images = images.to(device)
        grade_labels = grade_labels.to(device)
        quality_labels = quality_labels.to(device)

        optimizer.zero_grad()

        grade_preds, quality_preds = model(images)
        # Both tasks contribute equally to the objective.
        grade_loss = criterion(grade_preds, grade_labels)
        quality_loss = criterion(quality_preds, quality_labels)
        loss = grade_loss + quality_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"epoch {epoch+1}/{epochs} - training loss: {total_loss:.4f}")

Epoch 1/10 - Training Loss: 177.6365
Epoch 2/10 - Training Loss: 13.8094
Epoch 3/10 - Training Loss: 7.0935
Epoch 4/10 - Training Loss: 6.9107
Epoch 5/10 - Training Loss: 10.2302
Epoch 6/10 - Training Loss: 4.2895
Epoch 7/10 - Training Loss: 2.1046
Epoch 8/10 - Training Loss: 1.7329
Epoch 9/10 - Training Loss: 1.3256
Epoch 10/10 - Training Loss: 1.5078


# ---------- validation ----------

In [None]:
# Put the model in evaluation mode before measuring accuracy.
model.eval()

correct_grade = 0
total_grade = 0
correct_quality = 0
total_quality = 0

# No gradients are needed while scoring the validation set.
with torch.no_grad():
    for images, grade_labels, quality_labels in val_loader:
        images = images.to(device)
        grade_labels = grade_labels.to(device)
        quality_labels = quality_labels.to(device)

        grade_preds, quality_preds = model(images)

        # Predicted class = index of the largest logit in each row.
        grade_predicted = grade_preds.argmax(dim=1)
        quality_predicted = quality_preds.argmax(dim=1)

        total_grade += grade_labels.size(0)
        total_quality += quality_labels.size(0)
        correct_grade += (grade_predicted == grade_labels).sum().item()
        correct_quality += (quality_predicted == quality_labels).sum().item()

print(f"grade accuracy: {100 * correct_grade / total_grade} %")
print(f"quality accuracy: {100 * correct_quality / total_quality} %")

Grade Accuracy: 99.83646770237122 %
Quality Accuracy: 99.75470155355683 %


# ---------- save trained model ----------

In [9]:
# Persist only the learned weights (state dict), creating the target folder
# first if it does not exist yet.
save_dir = "../saved_models"
os.makedirs(save_dir, exist_ok=True)

weights_path = f"{save_dir}/tea_grading_model.pth"
torch.save(model.state_dict(), weights_path)

print("Model saved.")

Model saved.


# ---------- test prediction ----------

In [12]:
def predict(image_path, model):
    """Predict the tea grade and quality level for a single image file.

    Args:
        image_path: Path to an image readable by PIL.
        model: Network returning ``(grade_logits, quality_logits)`` for a
            batch of preprocessed images.

    Returns:
        Tuple ``(grade_name, quality)`` where ``grade_name`` is one of
        ``"OP"``, ``"OP1"``, ``"OPA"`` and ``quality`` is the 1-based
        quality level.
    """
    model.eval()

    # Reuse the notebook's preprocessing but drop the random flip: it is a
    # training-time augmentation and would make predictions for the same
    # image vary from call to call.
    inference_transform = transforms.Compose([
        step for step in transform.transforms
        if not isinstance(step, transforms.RandomHorizontalFlip)
    ])

    # The context manager releases the file handle once the pixels are read.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    # unsqueeze(0) adds the batch dimension the model expects.
    batch = inference_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        grade_out, quality_out = model(batch)
        grade = torch.argmax(grade_out, 1).item()
        # Quality labels are stored 0-based; report the 1-based level.
        quality = torch.argmax(quality_out, 1).item() + 1
    grade_map = {0: "OP", 1: "OP1", 2: "OPA"}

    return grade_map[grade], quality

# Other sample images to try:
# predict("../dataset/images/OP/quality_1/DSC00145.JPG", model)
# predict("../dataset/images/OPA/quality_7/DSC03637.JPG", model)
predict("../dataset/images/OP1/quality_4/DSC08756.JPG", model)

('OP1', 4)