Imports and Device Setup

In [2]:
import os 
from pathlib import Path
import json

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
import matplotlib.pyplot as plt


In [3]:
# Pick the fastest available backend: CUDA GPU first, then Apple-silicon MPS,
# and fall back to the CPU so the notebook still runs on any machine.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("using device", device)

using device mps


Paths

In [4]:
# Dataset root (relative to the notebook), with one sub-folder per split.
DATA_DIR = Path("..") / "data" / "raw"
TRAIN_DIR, VAL_DIR = (DATA_DIR / split for split in ("training", "validation"))


Image Transforms

In [5]:
IMAGE_SIZE = 224


def _build_transforms(*augmentations):
    """Compose resize -> optional augmentations -> tensor -> normalisation.

    Any ``augmentations`` are applied to the resized image, before it is
    converted to a tensor and normalised per channel.
    """
    return transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])


# Training images additionally get a random horizontal flip and a random
# rotation of up to 10 degrees; validation images are only resized/normalised.
train_transforms = _build_transforms(
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
)
val_transforms = _build_transforms()


Dataset and DataLoader
- Using `ImageFolder`: a torchvision dataset class that assigns each image a label according to the name of the folder it is stored in.

In [6]:
# ImageFolder derives each label from the name of the sub-folder an image sits in.
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=train_transforms)
val_dataset = datasets.ImageFolder(VAL_DIR, transform=val_transforms)

class_to_idx = train_dataset.class_to_idx

# Each ImageFolder builds its own class->index mapping from the folders it
# finds, so a missing/extra class folder in the validation split would shift
# the label indices and silently corrupt the validation metrics.
if val_dataset.class_to_idx != class_to_idx:
    raise ValueError(
        "Training and validation class folders differ: "
        f"{sorted(class_to_idx)} vs {sorted(val_dataset.class_to_idx)}"
    )

# Inverse lookup (index -> class name), used to decode predictions.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
print("Classes:", class_to_idx)

Classes: {'Bread': 0, 'Dairy product': 1, 'Dessert': 2, 'Egg': 3, 'Fried food': 4, 'Meat': 5, 'Noodles-Pasta': 6, 'Rice': 7, 'Seafood': 8, 'Soup': 9, 'Vegetable-Fruit': 10}


In [21]:
# Display the dataset summary: number of samples, root folder and transform pipeline.
train_dataset

Dataset ImageFolder
    Number of datapoints: 9866
    Root location: ../data/raw/training
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [7]:
# Persist the index -> class-name mapping next to the notebook.
# Note: JSON object keys are always strings, so the integer indices come back
# as "0", "1", ... when this file is loaded again.
Path("class_mapping.json").write_text(json.dumps(idx_to_class))

DataLoader: controls how data is batched and fed to the model
- `num_workers=2` sets how many CPU worker processes load data in parallel

In [8]:
BATCH_SIZE = 32

# Settings shared by both loaders. Pinned (page-locked) host memory only
# speeds up host-to-GPU copies on CUDA, so it is enabled only for that backend.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

Load ResNet18
- We freeze the pretrained backbone (ResNet18 was originally trained on 1000 classes) and train only a new final layer sized for our classes.

In [9]:
# Start from pretrained weights. `weights=` replaces the deprecated
# `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that flag loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze backbone: no gradients are computed for the pretrained parameters.
for param in model.parameters():
    param.requires_grad = False

# Replace classifier: the new layer is created after the freeze, so its
# parameters are trainable. Its width follows the dataset's class count
# instead of a hard-coded number.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_to_idx))
model = model.to(device)




Loss and Optimizer

In [10]:
# Classification loss applied to the model outputs and integer class labels.
criterion = nn.CrossEntropyLoss()

# Only the replacement classifier head is handed to the optimizer.
head_parameters = model.fc.parameters()
optimizer = torch.optim.Adam(head_parameters, lr=1e-4)

In [14]:
# Display the configured loss function.
criterion

CrossEntropyLoss()

Training Loop

In [11]:
def train_one_epoch(model, loader, loss_fn=None, optim=None, run_device=None):
    """Run one optimisation pass over ``loader`` and report its metrics.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        loss_fn: Loss callable. Defaults to the notebook-level ``criterion``.
        optim: Optimizer to step. Defaults to the notebook-level ``optimizer``.
        run_device: Device the batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        Tuple ``(mean_loss, accuracy)`` computed over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    optim = optimizer if optim is None else optim
    run_device = device if run_device is None else run_device

    # NOTE(review): train() also puts any BatchNorm layers into training mode,
    # so their running statistics keep updating even when their weights are
    # frozen -- confirm this is intended for the frozen backbone.
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in tqdm(loader):
        images, labels = images.to(run_device), labels.to(run_device)

        optim.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optim.step()

        # Weight the batch loss by its size so a smaller final batch does not
        # skew the epoch average (assumes loss_fn returns the batch mean,
        # which is the nn.CrossEntropyLoss default).
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch received an empty loader")
    return running_loss / total, correct / total

Validation Loop

In [12]:
def validate(model, loader, loss_fn=None, run_device=None):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        loss_fn: Loss callable. Defaults to the notebook-level ``criterion``.
        run_device: Device the batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        Tuple ``(mean_loss, accuracy)`` computed over every sample seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    run_device = device if run_device is None else run_device

    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Weight the batch loss by its size so a smaller final batch does
            # not skew the average (assumes loss_fn returns the batch mean,
            # which is the nn.CrossEntropyLoss default).
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("validate received an empty loader")
    return running_loss / total, correct / total

In [13]:
EPOCHS = 8
best_val_acc = 0

# Train for a fixed number of epochs; after each one, report the metrics and
# checkpoint the weights whenever validation accuracy sets a new best.
for epoch in range(EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss, val_acc = validate(model, val_loader)

    epoch_report = f"""
    Epoch {epoch+1}/{EPOCHS}
    Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}
    Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}
    """
    print(epoch_report)

    # No improvement: keep the previously saved checkpoint.
    if not val_acc > best_val_acc:
        continue

    best_val_acc = val_acc
    torch.save(model.state_dict(), "best_model.pth")

  0%|          | 0/309 [00:00<?, ?it/s]

100%|██████████| 309/309 [00:36<00:00,  8.50it/s]



    Epoch 1/8
    Train Loss: 2.0690 | Train Acc: 0.2878
    Val Loss:   1.7656 | Val Acc:   0.4359
    


100%|██████████| 309/309 [00:36<00:00,  8.49it/s]



    Epoch 2/8
    Train Loss: 1.5944 | Train Acc: 0.5257
    Val Loss:   1.4236 | Val Acc:   0.5956
    


100%|██████████| 309/309 [00:38<00:00,  7.95it/s]



    Epoch 3/8
    Train Loss: 1.3441 | Train Acc: 0.6243
    Val Loss:   1.2346 | Val Acc:   0.6501
    


100%|██████████| 309/309 [00:39<00:00,  7.87it/s]



    Epoch 4/8
    Train Loss: 1.1876 | Train Acc: 0.6624
    Val Loss:   1.1233 | Val Acc:   0.6738
    


100%|██████████| 309/309 [00:37<00:00,  8.23it/s]



    Epoch 5/8
    Train Loss: 1.0894 | Train Acc: 0.6914
    Val Loss:   1.0335 | Val Acc:   0.6921
    


100%|██████████| 309/309 [00:36<00:00,  8.49it/s]



    Epoch 6/8
    Train Loss: 1.0107 | Train Acc: 0.7076
    Val Loss:   0.9658 | Val Acc:   0.7111
    


100%|██████████| 309/309 [00:39<00:00,  7.81it/s]



    Epoch 7/8
    Train Loss: 0.9567 | Train Acc: 0.7149
    Val Loss:   0.9111 | Val Acc:   0.7219
    


100%|██████████| 309/309 [00:38<00:00,  7.98it/s]



    Epoch 8/8
    Train Loss: 0.9127 | Train Acc: 0.7315
    Val Loss:   0.8776 | Val Acc:   0.7315
    
