In [1]:
import os
import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torchvision.models as models

class Food101Dataset(Dataset):
    def __init__(self, data_dir, train_txt_path, test_txt_path, val_split_ratio=0.2):
        self.data_dir = data_dir
        self.train_txt_path = train_txt_path
        self.test_txt_path = test_txt_path
        self.val_split_ratio = val_split_ratio
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
        ])

        self.train_data, self.val_data, self.test_data, self.label_to_int = self._load_data(train_txt_path, test_txt_path)

    def _load_data(self, train_txt_path, test_txt_path):
        train_data = []
        test_data = []

        label_to_int = {}
        int_label = 0

        with open(os.path.join(self.data_dir, train_txt_path), 'r') as f:
            lines = f.readlines()
        for line in lines:
            filename = line.strip() + '.jpg'
            label = filename.split('/')[0]
            if label not in label_to_int:
                label_to_int[label] = int_label
                int_label += 1
            image_path = os.path.join(self.data_dir, 'images', filename)
            train_data.append((image_path, label))

        with open(os.path.join(self.data_dir, test_txt_path), 'r') as f:
            lines = f.readlines()
        for line in lines:
            filename = line.strip() + '.jpg'
            label = filename.split('/')[0]
            image_path = os.path.join(self.data_dir, 'images', filename)
            test_data.append((image_path, label))

        # Split train_data into train and validation sets using train_test_split
        train_data, val_data = train_test_split(train_data, test_size=self.val_split_ratio, random_state=42)

        return train_data, val_data, test_data, label_to_int

    def __len__(self):
        return len(self.train_data) + len(self.val_data) + len(self.test_data)

    def __getitem__(self, idx):
        if idx < len(self.train_data):
            data_source = self.train_data
        elif idx < len(self.train_data) + len(self.val_data):
            data_source = self.val_data
            idx -= len(self.train_data)
        else:
            data_source = self.test_data
            idx -= (len(self.train_data) + len(self.val_data))

        image_path, label = data_source[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        # Convert label to integer using label_to_int mapping
        label = torch.tensor(self.label_to_int[label], dtype=torch.int64)

        return {
            'image': image,
            'label': label,
        }

    def get_splits(self):
        train_subset = Subset(self, list(range(len(self.train_data))))
        val_subset = Subset(self, list(range(len(self.train_data), len(self.train_data) + len(self.val_data))))
        test_subset = Subset(self, list(range(len(self.train_data) + len(self.val_data), len(self.train_data) + len(self.val_data) + len(self.test_data))))
        return train_subset, val_subset, test_subset
    
class Config:
    num_classes = 101
    learning_rate = 1e-4
    num_epochs = 256
    batch_size = 128

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pre-trained ViT model
model = models.vit_b_16(pretrained=True)

in_features = model.heads.head.in_features
classifier = nn.Linear(in_features=in_features, out_features=Config.num_classes)
model.heads.head = classifier

for param in model.parameters():
    param.requires_grad = False
model.heads.head.weight.requires_grad = True
model.heads.head.bias.requires_grad = True

model.to(device)

criterion = nn.CrossEntropyLoss()

data_handler = Food101Dataset(data_dir='food101', train_txt_path='meta/train.txt', test_txt_path='meta/test.txt')

train_data, val_data, test_data = data_handler.get_splits()

# Create data loaders using the batch_size from the Config class
train_loader = DataLoader(train_data, batch_size=Config.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=Config.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=Config.batch_size, shuffle=False)

optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)



In [2]:
from collections import deque

# Initialize a deque to keep track of the last two validation accuracies
val_loss_history = deque(maxlen=2)

# Define a variable to track the number of consecutive times validation accuracy drops
consecutive_drops = 0

for epoch in range(Config.num_epochs):
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    model.train()
    total_loss = 0.0

    for batch_idx, batch in progress_bar:
        # Retrieve features and labels from the current batch
        features = batch['image'].to(device)
        labels = batch['label'].to(device)

        outputs = model(features)
        predictions = torch.argmax(outputs, dim=1)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        progress_bar.set_description(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

    progress_bar.close()
    average_loss = total_loss / len(train_loader)

    model.eval()
    val_total_loss = 0.0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in val_loader:
            features = batch['image'].to(device)
            labels = batch['label'].to(device)

            outputs = model(features)

            val_loss = criterion(outputs, labels)
            val_total_loss += val_loss.item()

            predictions = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())

    average_val_loss = val_total_loss / len(val_loader)

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f'Epoch [{epoch + 1}/{Config.num_epochs}] - Loss: {average_loss:.4f} - Validation Loss: {average_val_loss:.4f} - Validation Accuracy: {accuracy:.4f}')

    # Append the validation loss to the history
    val_loss_history.append(average_val_loss)

    # Check if validation loss has increased twice in a row
    if len(val_loss_history) == 2 and val_loss_history[0] < val_loss_history[1]:
        consecutive_increases += 1
        if consecutive_increases >= 2:
            print('Validation loss has increased twice in a row. Stopping training.')
            break
    else:
        consecutive_increases = 0  # Reset counter

Batch 473/474, Loss: 3.3518: 100%|██████████| 474/474 [07:34<00:00,  1.04it/s]


Epoch [1/256] - Loss: 3.8601 - Validation Loss: 3.2599 - Validation Accuracy: 0.3630


Batch 473/474, Loss: 2.6555: 100%|██████████| 474/474 [07:40<00:00,  1.03it/s]


Epoch [2/256] - Loss: 2.9163 - Validation Loss: 2.6806 - Validation Accuracy: 0.4314


Batch 473/474, Loss: 2.6231: 100%|██████████| 474/474 [07:51<00:00,  1.00it/s]


Epoch [3/256] - Loss: 2.4982 - Validation Loss: 2.4002 - Validation Accuracy: 0.4692


Batch 473/474, Loss: 2.1252: 100%|██████████| 474/474 [07:39<00:00,  1.03it/s]


Epoch [4/256] - Loss: 2.2671 - Validation Loss: 2.2219 - Validation Accuracy: 0.4933


Batch 473/474, Loss: 1.9708: 100%|██████████| 474/474 [07:43<00:00,  1.02it/s]


Epoch [5/256] - Loss: 2.1144 - Validation Loss: 2.0989 - Validation Accuracy: 0.5100


Batch 473/474, Loss: 2.1129: 100%|██████████| 474/474 [07:35<00:00,  1.04it/s]


Epoch [6/256] - Loss: 2.0065 - Validation Loss: 2.0137 - Validation Accuracy: 0.5242


Batch 473/474, Loss: 1.5661: 100%|██████████| 474/474 [07:47<00:00,  1.01it/s]


Epoch [7/256] - Loss: 1.9257 - Validation Loss: 1.9511 - Validation Accuracy: 0.5301


Batch 473/474, Loss: 1.7125: 100%|██████████| 474/474 [07:42<00:00,  1.03it/s]


Epoch [8/256] - Loss: 1.8568 - Validation Loss: 1.8894 - Validation Accuracy: 0.5430


Batch 473/474, Loss: 2.0777: 100%|██████████| 474/474 [07:37<00:00,  1.04it/s]


Epoch [9/256] - Loss: 1.8014 - Validation Loss: 1.8444 - Validation Accuracy: 0.5516


Batch 473/474, Loss: 1.8777: 100%|██████████| 474/474 [07:43<00:00,  1.02it/s]


Epoch [10/256] - Loss: 1.7597 - Validation Loss: 1.8043 - Validation Accuracy: 0.5570


Batch 473/474, Loss: 1.3515: 100%|██████████| 474/474 [07:45<00:00,  1.02it/s]


Epoch [11/256] - Loss: 1.7153 - Validation Loss: 1.7706 - Validation Accuracy: 0.5622


Batch 473/474, Loss: 1.7074: 100%|██████████| 474/474 [07:45<00:00,  1.02it/s]


Epoch [12/256] - Loss: 1.6806 - Validation Loss: 1.7392 - Validation Accuracy: 0.5682


Batch 473/474, Loss: 2.0426: 100%|██████████| 474/474 [07:53<00:00,  1.00it/s]


Epoch [13/256] - Loss: 1.6490 - Validation Loss: 1.7225 - Validation Accuracy: 0.5710


Batch 473/474, Loss: 1.2076: 100%|██████████| 474/474 [07:36<00:00,  1.04it/s]


Epoch [14/256] - Loss: 1.6199 - Validation Loss: 1.7003 - Validation Accuracy: 0.5760


Batch 473/474, Loss: 1.7592: 100%|██████████| 474/474 [07:40<00:00,  1.03it/s]


Epoch [15/256] - Loss: 1.5978 - Validation Loss: 1.6788 - Validation Accuracy: 0.5825


Batch 473/474, Loss: 1.8603: 100%|██████████| 474/474 [07:40<00:00,  1.03it/s]


Epoch [16/256] - Loss: 1.5709 - Validation Loss: 1.6548 - Validation Accuracy: 0.5871


Batch 473/474, Loss: 1.4865: 100%|██████████| 474/474 [07:41<00:00,  1.03it/s]


Epoch [17/256] - Loss: 1.5513 - Validation Loss: 1.6354 - Validation Accuracy: 0.5882


Batch 473/474, Loss: 1.3304: 100%|██████████| 474/474 [07:44<00:00,  1.02it/s]


Epoch [18/256] - Loss: 1.5324 - Validation Loss: 1.6302 - Validation Accuracy: 0.5888


Batch 473/474, Loss: 1.5023: 100%|██████████| 474/474 [07:50<00:00,  1.01it/s]


Epoch [19/256] - Loss: 1.5159 - Validation Loss: 1.6166 - Validation Accuracy: 0.5967


Batch 473/474, Loss: 1.2919: 100%|██████████| 474/474 [07:35<00:00,  1.04it/s]


Epoch [20/256] - Loss: 1.4985 - Validation Loss: 1.5965 - Validation Accuracy: 0.5974


Batch 473/474, Loss: 1.6231: 100%|██████████| 474/474 [07:54<00:00,  1.00s/it]


Epoch [21/256] - Loss: 1.4823 - Validation Loss: 1.5882 - Validation Accuracy: 0.5981


Batch 473/474, Loss: 1.0847: 100%|██████████| 474/474 [07:44<00:00,  1.02it/s]


Epoch [22/256] - Loss: 1.4672 - Validation Loss: 1.5816 - Validation Accuracy: 0.5970


Batch 473/474, Loss: 1.8866: 100%|██████████| 474/474 [07:40<00:00,  1.03it/s]


Epoch [23/256] - Loss: 1.4524 - Validation Loss: 1.5634 - Validation Accuracy: 0.6030


Batch 473/474, Loss: 1.4973: 100%|██████████| 474/474 [07:49<00:00,  1.01it/s]


Epoch [24/256] - Loss: 1.4385 - Validation Loss: 1.5612 - Validation Accuracy: 0.6030


Batch 473/474, Loss: 1.3206: 100%|██████████| 474/474 [07:42<00:00,  1.03it/s]


Epoch [25/256] - Loss: 1.4296 - Validation Loss: 1.5486 - Validation Accuracy: 0.6069


Batch 473/474, Loss: 1.3887: 100%|██████████| 474/474 [07:40<00:00,  1.03it/s]


Epoch [26/256] - Loss: 1.4142 - Validation Loss: 1.5326 - Validation Accuracy: 0.6110


Batch 473/474, Loss: 1.3913: 100%|██████████| 474/474 [07:38<00:00,  1.03it/s]


Epoch [27/256] - Loss: 1.4046 - Validation Loss: 1.5250 - Validation Accuracy: 0.6104


Batch 473/474, Loss: 1.5247: 100%|██████████| 474/474 [07:35<00:00,  1.04it/s]


Epoch [28/256] - Loss: 1.3921 - Validation Loss: 1.5192 - Validation Accuracy: 0.6157


Batch 473/474, Loss: 1.3588: 100%|██████████| 474/474 [07:43<00:00,  1.02it/s]


Epoch [29/256] - Loss: 1.3828 - Validation Loss: 1.5149 - Validation Accuracy: 0.6141


Batch 473/474, Loss: 1.1048: 100%|██████████| 474/474 [07:41<00:00,  1.03it/s]


Epoch [30/256] - Loss: 1.3741 - Validation Loss: 1.5073 - Validation Accuracy: 0.6145


Batch 473/474, Loss: 1.2477: 100%|██████████| 474/474 [07:42<00:00,  1.02it/s]


Epoch [31/256] - Loss: 1.3646 - Validation Loss: 1.4985 - Validation Accuracy: 0.6180


Batch 473/474, Loss: 1.0374: 100%|██████████| 474/474 [07:19<00:00,  1.08it/s]


Epoch [32/256] - Loss: 1.3532 - Validation Loss: 1.4938 - Validation Accuracy: 0.6224


Batch 473/474, Loss: 1.3463: 100%|██████████| 474/474 [07:29<00:00,  1.05it/s]


Epoch [33/256] - Loss: 1.3474 - Validation Loss: 1.4909 - Validation Accuracy: 0.6193


Batch 473/474, Loss: 1.3514: 100%|██████████| 474/474 [07:34<00:00,  1.04it/s]


Epoch [34/256] - Loss: 1.3398 - Validation Loss: 1.4862 - Validation Accuracy: 0.6201


Batch 473/474, Loss: 1.2551: 100%|██████████| 474/474 [07:30<00:00,  1.05it/s]


Epoch [35/256] - Loss: 1.3277 - Validation Loss: 1.4777 - Validation Accuracy: 0.6244


Batch 473/474, Loss: 1.4690: 100%|██████████| 474/474 [07:23<00:00,  1.07it/s]


Epoch [36/256] - Loss: 1.3220 - Validation Loss: 1.4782 - Validation Accuracy: 0.6236


Batch 473/474, Loss: 0.9580: 100%|██████████| 474/474 [07:29<00:00,  1.06it/s]


Epoch [37/256] - Loss: 1.3129 - Validation Loss: 1.4636 - Validation Accuracy: 0.6246


Batch 473/474, Loss: 1.5459: 100%|██████████| 474/474 [07:28<00:00,  1.06it/s]


Epoch [38/256] - Loss: 1.3073 - Validation Loss: 1.4571 - Validation Accuracy: 0.6257


Batch 473/474, Loss: 1.0964: 100%|██████████| 474/474 [57:28<00:00,  7.28s/it]    


Epoch [39/256] - Loss: 1.3019 - Validation Loss: 1.4579 - Validation Accuracy: 0.6275


Batch 473/474, Loss: 1.3492: 100%|██████████| 474/474 [07:22<00:00,  1.07it/s]


Epoch [40/256] - Loss: 1.2953 - Validation Loss: 1.4510 - Validation Accuracy: 0.6258


Batch 473/474, Loss: 1.0958: 100%|██████████| 474/474 [07:32<00:00,  1.05it/s]


Epoch [41/256] - Loss: 1.2859 - Validation Loss: 1.4406 - Validation Accuracy: 0.6316


Batch 473/474, Loss: 1.4903: 100%|██████████| 474/474 [07:25<00:00,  1.06it/s]


Epoch [42/256] - Loss: 1.2789 - Validation Loss: 1.4480 - Validation Accuracy: 0.6290


Batch 473/474, Loss: 1.4293: 100%|██████████| 474/474 [07:34<00:00,  1.04it/s]


Epoch [43/256] - Loss: 1.2751 - Validation Loss: 1.4485 - Validation Accuracy: 0.6330
Validation loss has increased twice in a row. Stopping training.


In [3]:
# save the model
torch.save(model.state_dict(), 'food101_vit.pt')

In [4]:
import os
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, accuracy_score

def evaluate_on_test_data(model, criterion, num_classes, device):
    model.eval()
    total_loss = 0.0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in test_loader:
            features = batch['image'].to(device)
            labels = batch['label'].to(device)

            outputs = model(features)

            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Assuming your model returns class probabilities, apply softmax
            predictions = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())
    
    average_loss = total_loss / len(test_loader)

    # Compute the classification report
    classification_rep = classification_report(all_labels, all_predictions, target_names=[f'Class {i}' for i in range(num_classes)])
    
    accuracy = accuracy_score(all_labels, all_predictions)
    
    print(f'Test Loss: {average_loss:.4f} - Test Accuracy: {accuracy:.4f}')
    print('Classification Report:\n', classification_rep)

evaluate_on_test_data(model, criterion, Config.num_classes, device)

Test Loss: 1.2133 - Test Accuracy: 0.6863
Classification Report:
               precision    recall  f1-score   support

     Class 0       0.43      0.35      0.39       250
     Class 1       0.63      0.74      0.68       250
     Class 2       0.73      0.67      0.70       250
     Class 3       0.71      0.72      0.72       250
     Class 4       0.55      0.57      0.56       250
     Class 5       0.56      0.60      0.58       250
     Class 6       0.78      0.78      0.78       250
     Class 7       0.81      0.83      0.82       250
     Class 8       0.44      0.44      0.44       250
     Class 9       0.61      0.63      0.62       250
    Class 10       0.60      0.56      0.58       250
    Class 11       0.69      0.76      0.73       250
    Class 12       0.63      0.70      0.66       250
    Class 13       0.65      0.70      0.67       250
    Class 14       0.69      0.66      0.67       250
    Class 15       0.52      0.46      0.49       250
    Class 16   