# 🧠 ResNet18 Classification Training on Colab (6 classes)

In [None]:
!pip install mlflow torchvision


In [None]:
import mlflow, torch, glob, os
from torch import nn, optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# MLflow tracking server. An explicit set_tracking_uri() call overrides the
# MLFLOW_TRACKING_URI environment variable, so honour that variable here and
# fall back to a local server. On Colab, set it to a reachable server/tunnel
# unless you start an MLflow server on localhost:5000 yourself.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000"))
mlflow.set_experiment("dispatch-pipeline")

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [None]:
# Dataset with 6 classes: {dish, tray} x {empty, not_empty, kakigori}.
class SixClassDataset(Dataset):
    """Image-classification dataset laid out as ``root/<container>/<state>/*.jpg``.

    The two directory levels directly above each image (e.g. ``dish/empty``)
    select the integer label from ``label_map``.

    Args:
        root: Directory containing the ``dish`` / ``tray`` sub-trees.
        transform: Optional callable applied to each loaded RGB PIL image.

    Raises:
        ValueError: If a ``.jpg`` sits in a ``<container>/<state>`` directory
            pair that is not in ``label_map``. This is reported at construction
            instead of as a ``KeyError`` in the middle of an epoch.
    """

    def __init__(self, root, transform=None):
        self.label_map = {
            "dish/empty": 0, "dish/not_empty": 1, "dish/kakigori": 2,
            "tray/empty": 3, "tray/not_empty": 4, "tray/kakigori": 5,
        }
        # glob() order is filesystem-dependent; sort so that dataset indices
        # (and therefore seeded shuffles/splits) are reproducible.
        self.files = sorted(glob.glob(os.path.join(root, "*", "*", "*.jpg")))
        self.transform = transform

        # Resolve every label up front so a stray folder fails fast here.
        self.labels = []
        unknown = set()
        for path in self.files:
            key = self._class_key(path)
            if key in self.label_map:
                self.labels.append(self.label_map[key])
            else:
                unknown.add(key)
        if unknown:
            raise ValueError(
                f"Unknown class directories under {root!r}: {sorted(unknown)}; "
                f"expected one of {sorted(self.label_map)}"
            )

    @staticmethod
    def _class_key(path):
        """Return the ``"<container>/<state>"`` label key for an image path."""
        # Split on the OS separator rather than a hard-coded "/" so paths
        # produced by glob on Windows are handled too.
        parts = os.path.normpath(path).split(os.sep)
        return "/".join(parts[-3:-1])

    def __getitem__(self, idx):
        path = self.files[idx]
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

    def __len__(self):
        return len(self.files)


In [None]:
# Preprocessing: resize to 128x128, convert to a [0, 1] float tensor, then
# apply the ImageNet mean/std normalization that torchvision's pretrained
# ResNet weights were trained with. Any inference code must use the same steps.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

DATA_ROOT = "/content/drive/MyDrive/dispatch/data/raw/Dataset/Classification"
dataset = SixClassDataset(DATA_ROOT, transform)
# With shuffle=True an empty dataset only fails later with an opaque sampler
# error, so report the likely cause (wrong path / Drive not mounted) up front.
if len(dataset) == 0:
    raise FileNotFoundError(
        f"No .jpg images found under {DATA_ROOT!r} (is Google Drive mounted?)"
    )
loader = DataLoader(dataset, batch_size=32, shuffle=True)


In [None]:
# Fine-tune an ImageNet-pretrained ResNet18 on the 6 classes and log the run to MLflow.
NUM_CLASSES = 6
EPOCHS = 5
LR = 1e-3

# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the
# checkpoint that `pretrained=True` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the 1000-way ImageNet head with a NUM_CLASSES-way classifier.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

with mlflow.start_run(run_name="resnet18_classifier_v1"):
    # Log the same constants the loop uses so params cannot drift from the code.
    mlflow.log_params({"model": "resnet18", "classes": NUM_CLASSES, "epochs": EPOCHS, "lr": LR})
    for epoch in range(EPOCHS):
        model.train()
        loss_sum, correct, total = 0.0, 0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss returns the batch mean; weight by batch size so
            # the epoch value is a true per-sample mean even when the last
            # batch is short.
            loss_sum += loss.item() * y.size(0)
            correct += (out.argmax(1) == y).sum().item()
            total += y.size(0)
        if total == 0:
            raise RuntimeError("Training loader yielded no samples")
        # Per-sample mean loss; a raw sum of batch losses would scale with the
        # number of batches and not be comparable across dataset sizes.
        mlflow.log_metric("loss", loss_sum / total, step=epoch)
        # Accuracy on the training batches seen this epoch (no held-out split here).
        mlflow.log_metric("accuracy", correct / total, step=epoch)

    torch.save(model.state_dict(), "resnet18_dispatch.pt")
    mlflow.log_artifact("resnet18_dispatch.pt")
