In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_b5
import torch.optim as optim
from torch.optim import lr_scheduler
from sklearn.preprocessing import LabelEncoder
from PIL import Image
from torch.utils.checkpoint import checkpoint

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data_dir = '/kaggle/input/food-101/food-101/food-101/'

with open(os.path.join(data_dir, 'meta', 'train.txt'), 'r') as f:
    train_images = f.read().splitlines()

with open(os.path.join(data_dir, 'meta', 'test.txt'), 'r') as f:
    test_images = f.read().splitlines()

specified_classes = [
    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio',
    'beef_tartare', 'beet_salad', 'beignets', 'bibimbap',
    'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad',
    'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche',
    'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla'
]

train_images = [os.path.join(data_dir, 'images', img + '.jpg') for img in train_images]
test_images = [os.path.join(data_dir, 'images', img + '.jpg') for img in test_images]

train_labels = [img.split('/')[0] if img.split('/')[0] in specified_classes else 'other' for img in train_images]
test_labels = [img.split('/')[0] if img.split('/')[0] in specified_classes else 'other' for img in test_images]

classes = specified_classes + ['other']

label_encoder = LabelEncoder()
label_encoder.fit(classes)
train_labels_encoded = label_encoder.transform(train_labels)
test_labels_encoded = label_encoder.transform(test_labels)

class FoodDataset(Dataset):
    def __init__(self, paths, labels, transform=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

train_transforms = transforms.Compose([
    transforms.Resize((456, 456)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_transforms = transforms.Compose([
    transforms.Resize((456, 456)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = FoodDataset(train_images, train_labels_encoded, transform=train_transforms)
test_dataset = FoodDataset(test_images, test_labels_encoded, transform=test_transforms)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=True)

class FoodClassifier(nn.Module):
    def __init__(self, num_classes):
        super(FoodClassifier, self).__init__()
        self.model = efficientnet_b5(weights='IMAGENET1K_V1')
        self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)

    def forward(self, x):
        # Apply checkpointing to save memory
        x = checkpoint(self.model.features, x)
        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.model.classifier(x)
        return x

num_classes = len(classes)
model = FoodClassifier(num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

steps_per_epoch = len(train_loader)
scheduler = lr_scheduler.OneCycleLR(
    optimizer, 
    max_lr=0.001,
    steps_per_epoch=steps_per_epoch,
    epochs=10,
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25.0
)

num_epochs = 5
best_val_acc = 0.0
best_model_path = 'best_food_classifier_model.pth'

scaler = torch.cuda.amp.GradScaler()

accumulation_steps = 4  # Number of steps to accumulate gradients

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    total_train = 0
    correct_train = 0

    print(f'Epoch [{epoch + 1}/{num_epochs}]')
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Zero the gradients every 'accumulation_steps'
        if i % accumulation_steps == 0:
            optimizer.zero_grad()

        # Forward pass with autocast for mixed-precision
        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, labels) / accumulation_steps  # Scale loss for accumulation
        
        scaler.scale(loss).backward()

        # Perform step update every 'accumulation_steps'
        if (i + 1) % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        running_loss += loss.item() * images.size(0) * accumulation_steps  # Reverse scale for reporting
        _, predicted = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    epoch_loss = running_loss / total_train
    epoch_acc = 100 * correct_train / total_train
    print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.2f}%')

    model.eval()
    total_val = 0
    correct_val = 0
    val_running_loss = 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.data, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    val_loss = val_running_loss / total_val
    val_acc = 100 * correct_val / total_val
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%')

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f'Model saved with Validation Accuracy: {best_val_acc:.2f}%')

model.load_state_dict(torch.load(best_model_path))
model.eval()
total_test = 0
correct_test = 0
test_running_loss = 0.0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        test_running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs.data, 1)
        total_test += labels.size(0)
        correct_test += (predicted == labels).sum().item()

test_loss = test_running_loss / total_test
test_acc = 100 * correct_test / total_test
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

Downloading: "https://download.pytorch.org/models/efficientnet_b5_lukemelas-1a07897c.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b5_lukemelas-1a07897c.pth
100%|██████████| 117M/117M [00:01<00:00, 96.2MB/s]
  scaler = torch.cuda.amp.GradScaler()


Epoch [1/5]


  with torch.cuda.amp.autocast():
  return fn(*args, **kwargs)


Train Loss: 1.6951, Train Accuracy: 84.30%
Validation Loss: 2.7395, Validation Accuracy: 59.01%
Model saved with Validation Accuracy: 59.01%
Epoch [2/5]
Train Loss: 1.2189, Train Accuracy: 92.96%
Validation Loss: 2.2582, Validation Accuracy: 71.19%
Model saved with Validation Accuracy: 71.19%
Epoch [3/5]
Train Loss: 1.0334, Train Accuracy: 96.96%
Validation Loss: 1.7739, Validation Accuracy: 84.55%
Model saved with Validation Accuracy: 84.55%
Epoch [4/5]
Train Loss: 0.9118, Train Accuracy: 98.89%
Validation Loss: 1.4450, Validation Accuracy: 92.85%
Model saved with Validation Accuracy: 92.85%
Epoch [5/5]
Train Loss: 0.8413, Train Accuracy: 99.59%
Validation Loss: 1.2137, Validation Accuracy: 97.14%
Model saved with Validation Accuracy: 97.14%


  model.load_state_dict(torch.load(best_model_path))


Test Loss: 1.2137, Test Accuracy: 97.14%
