# Food101 CNN Classification (Fixed)

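This notebook loads the Food101 dataset, converts every image to RGB inside custom collate functions, and trains a simple CNN. The RGB conversion matters because the transforms normalize with three-channel mean/std values: an image that is not RGB (for example, a grayscale one) would otherwise trigger a broadcast error during normalization.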

In [3]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

In [5]:
## 1. Define transforms and collate functions

# Transforms
# Both pipelines resize to the same square size and finish with the same
# per-channel normalization, so those pieces are defined once and shared.
# NOTE(review): the mean/std values look like the commonly used ImageNet
# statistics — confirm they are the intended choice for Food101.
_IMAGE_SIZE = (128, 128)
_NORMALIZE = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: adds a random horizontal flip as augmentation.
train_tf = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _NORMALIZE,
])

# Evaluation pipeline: same preprocessing, no augmentation.
test_tf = transforms.Compose([
    transforms.Resize(_IMAGE_SIZE),
    transforms.ToTensor(),
    _NORMALIZE,
])

# Custom collate functions that convert to RGB
def train_collate_fn(batch):
    """Collate dataset items into one training batch.

    Each item's ``'image'`` is converted to RGB when it is a PIL image in
    another mode, passed through ``train_tf``, and the results are stacked.

    Args:
        batch: Sequence of items, each indexable by ``'image'`` and ``'label'``.

    Returns:
        dict with ``'pixel_values'`` (the stacked transformed images) and
        ``'labels'`` (a ``torch.long`` tensor of the items' labels).
    """
    def _as_rgb(picture):
        # Only PIL images carry a ``mode``; anything else is passed through as-is.
        if isinstance(picture, Image.Image) and picture.mode != 'RGB':
            return picture.convert('RGB')
        return picture

    stacked = torch.stack([train_tf(_as_rgb(sample['image'])) for sample in batch])
    targets = torch.tensor([sample['label'] for sample in batch], dtype=torch.long)
    return {'pixel_values': stacked, 'labels': targets}

def test_collate_fn(batch):
    """Collate dataset items into one evaluation batch.

    Each item's ``'image'`` is converted to RGB when it is a PIL image in
    another mode, passed through ``test_tf``, and the results are stacked.

    Args:
        batch: Sequence of items, each indexable by ``'image'`` and ``'label'``.

    Returns:
        dict with ``'pixel_values'`` (the stacked transformed images) and
        ``'labels'`` (a ``torch.long`` tensor of the items' labels).
    """
    def _as_rgb(picture):
        # Only PIL images carry a ``mode``; anything else is passed through as-is.
        if isinstance(picture, Image.Image) and picture.mode != 'RGB':
            return picture.convert('RGB')
        return picture

    stacked = torch.stack([test_tf(_as_rgb(sample['image'])) for sample in batch])
    targets = torch.tensor([sample['label'] for sample in batch], dtype=torch.long)
    return {'pixel_values': stacked, 'labels': targets}

In [7]:
## 2. Load dataset and create DataLoaders

# Number of samples per batch, shared by both loaders.
BATCH_SIZE = 64

# The 'validation' split is used as the held-out evaluation set.
raw_train = load_dataset('food101', split='train')
raw_test = load_dataset('food101', split='validation')

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(
    raw_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=train_collate_fn,
)
test_loader = DataLoader(
    raw_test,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=test_collate_fn,
)

print('Number of training batches:', len(train_loader))
print('Number of validation batches:', len(test_loader))

Number of training batches: 1184
Number of validation batches: 395


In [9]:
## 3. Define the CNN model

import torch.nn as nn
import torch.nn.functional as F

class SimpleFoodCNN(nn.Module):
    """Small three-stage convolutional classifier for 3-channel images.

    Each stage is ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``, widening
    the channels 3 -> 32 -> 64 -> 128 while halving height and width. The
    feature map is then brought to a fixed 16x16 grid, flattened, and
    classified by a two-layer fully connected head with dropout.

    A 128x128 input reaches 16x16 after the three poolings on its own, so the
    adaptive pooling step is an identity there; other resolutions are resampled
    to 16x16 instead of failing with a shape mismatch at ``fc1``.

    Args:
        num_classes: Number of output logits (default 101).
    """

    # Spatial size (height == width) of the feature map fed to the classifier head.
    _GRID = 16

    def __init__(self, num_classes=101):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool  = nn.MaxPool2d(2,2)
        # Parameter-free, so the state_dict keys and weight-initialisation
        # order are the same as without it.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((self._GRID, self._GRID))
        self.fc1   = nn.Linear(128 * self._GRID * self._GRID, 512)
        self.fc2   = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Batch of images with 3 channels in dimension 1 (as required by
                ``conv1``); height and width may vary.

        Returns:
            Tensor of logits with one row per input sample and ``num_classes``
            columns.
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Guarantee the 16x16 grid that fc1 was sized for.
        x = self.adaptive_pool(x)
        # flatten (unlike view) also works when the tensor is not contiguous.
        x = x.flatten(start_dim=1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# Seed PyTorch's global RNG before the model is built so that weight
# initialisation, and later randomness drawn from that RNG in this kernel,
# is repeatable across Restart & Run All. (Some GPU kernels can still be
# non-deterministic, so results may not match bit-for-bit on CUDA.)
SEED = 42
torch.manual_seed(SEED)

# Instantiate: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleFoodCNN().to(device)

In [11]:
## 4. Training and evaluation functions

import torch.optim as optim

# Step size used by the optimizer.
LEARNING_RATE = 1e-3

# Cross-entropy loss, applied to the model's raw logits in the loops below.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the model.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

def train_epoch(loader):
    """Run one optimisation pass over ``loader`` and report averaged metrics.

    Relies on the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Args:
        loader: Iterable of batches, each a mapping with ``'pixel_values'``
            and ``'labels'`` tensors.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the loss averaged per sample and the
        fraction of correctly predicted samples over everything seen.
    """
    model.train()
    loss_sum = 0
    hits = 0
    seen = 0
    for batch in tqdm(loader, desc='Train'):
        inputs = batch['pixel_values'].to(device)
        targets = batch['labels'].to(device)
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller final batch does not
        # skew the per-sample average (assumes the criterion returns a batch
        # mean, which is the CrossEntropyLoss default).
        loss_sum += batch_loss.item() * batch_size
        predicted = outputs.argmax(dim=-1)
        hits += (predicted == targets).sum().item()
        seen += batch_size
    return loss_sum / seen, hits / seen

def eval_epoch(loader):
    """Measure classification accuracy over ``loader`` without updating weights.

    Relies on the notebook-level ``model`` and ``device``. The model is put in
    eval mode and gradients are disabled for the whole pass.

    Args:
        loader: Iterable of batches, each a mapping with ``'pixel_values'``
            and ``'labels'`` tensors.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc='Eval '):
            inputs = batch['pixel_values'].to(device)
            targets = batch['labels'].to(device)
            predicted = model(inputs).argmax(dim=-1)
            hits += (predicted == targets).sum().item()
            seen += inputs.size(0)
    return hits / seen

In [13]:
## 5. Run training loop

# Total number of passes over the training split.
epochs = 5

for epoch in range(1, epochs + 1):
    # Train first, then score the freshly updated weights on the held-out split.
    train_stats = train_epoch(train_loader)
    train_loss, train_acc = train_stats
    val_acc = eval_epoch(test_loader)
    print(f"Epoch {epoch}: Loss={train_loss:.4f}, Train Acc={train_acc:.3f}, Val Acc={val_acc:.3f}")

Train: 100%|███████████████████████████████████████████████████████████████████████| 1184/1184 [05:32<00:00,  3.56it/s]
Eval : 100%|█████████████████████████████████████████████████████████████████████████| 395/395 [01:42<00:00,  3.85it/s]


Epoch 1: Loss=4.3568, Train Acc=0.044, Val Acc=0.107


Train: 100%|███████████████████████████████████████████████████████████████████████| 1184/1184 [05:35<00:00,  3.53it/s]
Eval : 100%|█████████████████████████████████████████████████████████████████████████| 395/395 [01:40<00:00,  3.94it/s]


Epoch 2: Loss=3.8291, Train Acc=0.122, Val Acc=0.187


Train: 100%|███████████████████████████████████████████████████████████████████████| 1184/1184 [05:32<00:00,  3.57it/s]
Eval : 100%|█████████████████████████████████████████████████████████████████████████| 395/395 [01:38<00:00,  4.00it/s]


Epoch 3: Loss=3.5029, Train Acc=0.177, Val Acc=0.235


Train: 100%|███████████████████████████████████████████████████████████████████████| 1184/1184 [05:28<00:00,  3.60it/s]
Eval : 100%|█████████████████████████████████████████████████████████████████████████| 395/395 [01:42<00:00,  3.84it/s]


Epoch 4: Loss=3.2657, Train Acc=0.219, Val Acc=0.257


Train: 100%|███████████████████████████████████████████████████████████████████████| 1184/1184 [05:34<00:00,  3.54it/s]
Eval : 100%|█████████████████████████████████████████████████████████████████████████| 395/395 [01:35<00:00,  4.15it/s]

Epoch 5: Loss=3.0526, Train Acc=0.262, Val Acc=0.276



