In [None]:
import os
import math
import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.transforms import Compose, Resize, ToTensor
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from PIL import Image
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import torch.nn as nn

class Food101Dataset(Dataset):
    """
    Custom dataset class for handling the Food-101 dataset
    """
    def __init__(self, data_dir, num_samples_per_split=30000):
        """
        Initialize the Food-101 dataset.

        Args:
            data_dir (str): Path to the root directory of the Food-101 dataset.
            num_samples_per_split (int): Number of samples to load for each split (training and validation).
        """
        self.data_dir = data_dir
        self.num_samples_per_split = num_samples_per_split

        # Create a list of image file paths and their corresponding labels
        self.image_paths, self.labels = self._load_data()

        print('Number of images found: ', len(self.image_paths))
        print('Number of labels found: ', len(self.labels))

        self.transform = Compose([
            Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        category_name_to_index = {}
        for index, category in enumerate(os.listdir(os.path.join(self.data_dir, 'images'))):
            category_name_to_index[category] = index

        self.category_name_to_index = category_name_to_index

    def _load_data(self):
        """
        Load image file paths and labels from the dataset.

        Returns:
            list, list: A list of image file paths and a list of corresponding labels.
        """
        image_paths = []
        labels = []

        # only include images that are in the training set
        # I can tell if an image is in the training set by looking at the file name
        # and seeing if it is in /meta/train.txt

        # Load the train.txt file
        with open(os.path.join(self.data_dir, 'meta/train.txt'), 'r') as f:
            train_files = f.readlines()

        # Extract the image file name from each line
        train_files = [file.split('/')[1].strip() + '.jpg' for file in train_files]

        # Traverse through each food category
        for category in os.listdir(os.path.join(self.data_dir, 'images')):
            category_path = os.path.join(self.data_dir, 'images', category)

            # Ensure it's a directory
            if os.path.isdir(category_path):
                for image_filename in os.listdir(category_path):
                    if image_filename.endswith('.jpg') and image_filename in train_files:
                        image_path = os.path.join(category_path, image_filename)
                        image_paths.append(image_path)
                        labels.append(category)

                    # Stop loading once we reach the desired number of samples per split
                    if len(image_paths) >= self.num_samples_per_split:
                        break

                # Additional check to ensure we have exactly the desired number of samples
                if len(image_paths) >= self.num_samples_per_split:
                    break

        return image_paths, labels

    def __len__(self):
        """
        Returns the length of the dataset.
        """
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')

        if self.transform:
            image = self.transform(image)  # Apply the transform here

        label = self.labels[idx]
        category_index = torch.tensor(self.category_name_to_index[label])

        return {
            'image': image,
            'label': category_index,
            'filepath': image_path,
        }

    def get_splits(self):
        """
        Generate training and validation indices for training.
        """
        # Create a list of indices for the dataset split
        indices = list(range(len(self.image_paths)))
        train_indices, val_indices = train_test_split(indices, test_size=0.2, random_state=42)

        # Create the train and validation subsets
        train_subset = Subset(self, train_indices)
        val_subset = Subset(self, val_indices)

        return train_subset, val_subset

class Attention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        B, N, C = x.shape
        
        # Split the input into self.num_heads different heads
        query, key, value = self.query(x), self.key(x), self.value(x)
        
        query = query.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = key.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 3, 1)
        value = value.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        
        # Calculate scaled dot-product attention
        scores = torch.matmul(query, key) * self.scale
        attention = torch.nn.functional.softmax(scores, dim=-1)
        
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous().view(B, N, C)
        
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout_rate=0.1):
        super(MultiHeadAttention, self).__init__()
        self.attention_heads = nn.ModuleList([Attention(embed_dim, num_heads) for _ in range(num_heads)])
        self.fc_out = nn.Linear(embed_dim * num_heads, embed_dim)  # Adjust in_features and out_features
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        attention_outputs = [head(x) for head in self.attention_heads]
        concat_attention = torch.cat(attention_outputs, dim=2)
        out = self.fc_out(concat_attention)
        out = self.dropout(out)
        return out


class VisionTransformer(nn.Module):
    def __init__(self, num_classes, num_patches, embed_dim, num_heads, num_layers, dropout_rate=0.1):
        super(VisionTransformer, self).__init__()
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.patch_size = int(math.sqrt(224 * 224 / num_patches))  # Correctly calculate the patch size
        self.num_patches_h = int(224 / self.patch_size)
        self.num_patches_w = int(224 / self.patch_size)
        self.patch_embedding = nn.Linear(3 * self.patch_size * self.patch_size, embed_dim)
        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches, embed_dim))
        self.transformer_layers = nn.ModuleList([MultiHeadAttention(embed_dim, num_heads) for _ in range(num_layers)])
        self.fc = nn.Linear(embed_dim, num_classes)
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        B, C, H, W = x.shape  # Get batch size, channels, height, and width
        x = x.permute(0, 2, 3, 1)  # Permute to (Batch Size, Height, Width, Channels)
        
        # Calculate the size of each patch
        self.num_patches_h = int(H / self.patch_size)
        self.num_patches_w = int(W / self.patch_size)
        
        # Split the image into patches
        x = x.unfold(1, self.patch_size, self.patch_size).unfold(2, self.patch_size, self.patch_size)
        
        # Reshape to (Batch Size, num_patches_h * num_patches_w, 3 * self.patch_size * self.patch_size)
        x = x.contiguous().view(B, self.num_patches, 3 * self.patch_size * self.patch_size)
        
        x = self.patch_embedding(x)
        x = x + self.positional_embedding

        for layer in self.transformer_layers:
            x = layer(x)
            x = self.dropout(x)

        x = x.mean(dim=1)  # Global average pooling
        x = self.fc(x)
        return x


    
class Config:
    num_classes = 101
    num_patches = 256
    
    embed_dim = 64
    num_heads = 8
    num_layers = 12
    learning_rate = 0.0001
    num_epochs = 64
    batch_size = 32

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
criterion = nn.CrossEntropyLoss()

data_handler = Food101Dataset(data_dir='food101', num_samples_per_split=30000)

train_data, val_data = data_handler.get_splits()

# Create data loaders using the batch_size from the Config class
train_loader = DataLoader(train_data, batch_size=Config.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=Config.batch_size, shuffle=True)

model = VisionTransformer(Config.num_classes, Config.num_patches, Config.embed_dim, Config.num_heads, Config.num_layers)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=Config.learning_rate)


cuda
Number of images found:  30000
Number of labels found:  30000


In [2]:
for epoch in range(Config.num_epochs):
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    model.train()
    total_loss = 0.0

    for batch_idx, batch in progress_bar:
        # Retrieve features and labels from the current batch
        features = batch['image'].to(device)
        labels = batch['label'].to(device)

        # Forward pass
        outputs = model(features)

        predictions = torch.argmax(outputs, dim=1)

        # Calculate the loss
        loss = criterion(outputs, labels)

        # Backpropagation and optimization with gradient clipping
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        progress_bar.set_description(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

    average_loss = total_loss / len(train_loader)

    model.eval()
    val_total_loss = 0.0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in val_loader:
            features = batch['image'].to(device)
            labels = batch['label'].to(device)

            outputs = model(features)

            val_loss = criterion(outputs, labels)
            val_total_loss += val_loss.item()

            predictions = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())

    average_val_loss = val_total_loss / len(val_loader)

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f'Epoch [{epoch + 1}/{Config.num_epochs}] - Loss: {average_loss:.4f} - Validation Loss: {average_val_loss:.4f} - Validation Accuracy: {accuracy:.4f}')

Batch 749/750, Loss: 3.7960: 100%|██████████| 750/750 [04:27<00:00,  2.81it/s]


Epoch [1/64] - Loss: 3.7833 - Validation Loss: 3.7072 - Validation Accuracy: 0.0233


Batch 749/750, Loss: 3.6887: 100%|██████████| 750/750 [04:14<00:00,  2.94it/s]


Epoch [2/64] - Loss: 3.7037 - Validation Loss: 3.6970 - Validation Accuracy: 0.0258


Batch 749/750, Loss: 3.4776: 100%|██████████| 750/750 [04:21<00:00,  2.86it/s]


Epoch [3/64] - Loss: 3.6545 - Validation Loss: 3.5710 - Validation Accuracy: 0.0630


Batch 749/750, Loss: 3.5551: 100%|██████████| 750/750 [04:21<00:00,  2.86it/s]


Epoch [4/64] - Loss: 3.5451 - Validation Loss: 3.5213 - Validation Accuracy: 0.0588


Batch 749/750, Loss: 3.6010: 100%|██████████| 750/750 [04:23<00:00,  2.85it/s]


Epoch [5/64] - Loss: 3.5094 - Validation Loss: 3.5148 - Validation Accuracy: 0.0612


Batch 749/750, Loss: 3.4136: 100%|██████████| 750/750 [04:15<00:00,  2.93it/s]


Epoch [6/64] - Loss: 3.4913 - Validation Loss: 3.5072 - Validation Accuracy: 0.0615


Batch 749/750, Loss: 3.7139: 100%|██████████| 750/750 [04:21<00:00,  2.87it/s]


Epoch [7/64] - Loss: 3.4750 - Validation Loss: 3.4975 - Validation Accuracy: 0.0663


Batch 749/750, Loss: 3.5118: 100%|██████████| 750/750 [04:27<00:00,  2.80it/s]


Epoch [8/64] - Loss: 3.4594 - Validation Loss: 3.4972 - Validation Accuracy: 0.0692


Batch 749/750, Loss: 3.4047: 100%|██████████| 750/750 [04:24<00:00,  2.84it/s]


Epoch [9/64] - Loss: 3.4475 - Validation Loss: 3.4940 - Validation Accuracy: 0.0622


Batch 749/750, Loss: 3.5766: 100%|██████████| 750/750 [04:26<00:00,  2.81it/s]


Epoch [10/64] - Loss: 3.4302 - Validation Loss: 3.4742 - Validation Accuracy: 0.0643


Batch 749/750, Loss: 3.0398: 100%|██████████| 750/750 [04:30<00:00,  2.77it/s]


Epoch [11/64] - Loss: 3.4152 - Validation Loss: 3.4901 - Validation Accuracy: 0.0600


Batch 749/750, Loss: 3.4034: 100%|██████████| 750/750 [04:10<00:00,  3.00it/s]


Epoch [12/64] - Loss: 3.4021 - Validation Loss: 3.4747 - Validation Accuracy: 0.0620


Batch 749/750, Loss: 3.0986: 100%|██████████| 750/750 [04:22<00:00,  2.86it/s]


Epoch [13/64] - Loss: 3.3866 - Validation Loss: 3.4780 - Validation Accuracy: 0.0688


Batch 749/750, Loss: 3.3295: 100%|██████████| 750/750 [04:20<00:00,  2.87it/s]


Epoch [14/64] - Loss: 3.3694 - Validation Loss: 3.4724 - Validation Accuracy: 0.0717


Batch 749/750, Loss: 3.2817: 100%|██████████| 750/750 [04:09<00:00,  3.01it/s]


Epoch [15/64] - Loss: 3.3574 - Validation Loss: 3.4708 - Validation Accuracy: 0.0698


Batch 749/750, Loss: 3.1178: 100%|██████████| 750/750 [04:23<00:00,  2.85it/s]


Epoch [16/64] - Loss: 3.3380 - Validation Loss: 3.4896 - Validation Accuracy: 0.0720


Batch 749/750, Loss: 3.1708: 100%|██████████| 750/750 [04:27<00:00,  2.81it/s]


Epoch [17/64] - Loss: 3.3239 - Validation Loss: 3.4599 - Validation Accuracy: 0.0738


Batch 749/750, Loss: 3.3941: 100%|██████████| 750/750 [04:32<00:00,  2.75it/s]


Epoch [18/64] - Loss: 3.3000 - Validation Loss: 3.5253 - Validation Accuracy: 0.0763


Batch 749/750, Loss: 3.2711: 100%|██████████| 750/750 [04:31<00:00,  2.76it/s]


Epoch [19/64] - Loss: 3.2856 - Validation Loss: 3.5119 - Validation Accuracy: 0.0640


Batch 749/750, Loss: 3.1214: 100%|██████████| 750/750 [04:29<00:00,  2.78it/s]


Epoch [20/64] - Loss: 3.2689 - Validation Loss: 3.5867 - Validation Accuracy: 0.0712


Batch 749/750, Loss: 3.5373: 100%|██████████| 750/750 [04:07<00:00,  3.02it/s]


Epoch [21/64] - Loss: 3.2515 - Validation Loss: 3.5513 - Validation Accuracy: 0.0732


Batch 749/750, Loss: 3.2038: 100%|██████████| 750/750 [04:14<00:00,  2.94it/s]


Epoch [22/64] - Loss: 3.2370 - Validation Loss: 3.6880 - Validation Accuracy: 0.0700


Batch 749/750, Loss: 3.4266: 100%|██████████| 750/750 [04:10<00:00,  3.00it/s]


Epoch [23/64] - Loss: 3.2211 - Validation Loss: 3.5876 - Validation Accuracy: 0.0715


Batch 749/750, Loss: 3.3423: 100%|██████████| 750/750 [04:04<00:00,  3.06it/s]


Epoch [24/64] - Loss: 3.1971 - Validation Loss: 3.6077 - Validation Accuracy: 0.0735


Batch 749/750, Loss: 3.2591: 100%|██████████| 750/750 [03:47<00:00,  3.29it/s]


Epoch [25/64] - Loss: 3.1806 - Validation Loss: 3.6435 - Validation Accuracy: 0.0737


Batch 749/750, Loss: 3.2349: 100%|██████████| 750/750 [04:11<00:00,  2.98it/s]


Epoch [26/64] - Loss: 3.1632 - Validation Loss: 3.7026 - Validation Accuracy: 0.0717


Batch 749/750, Loss: 3.0934: 100%|██████████| 750/750 [04:22<00:00,  2.86it/s]


Epoch [27/64] - Loss: 3.1429 - Validation Loss: 3.6603 - Validation Accuracy: 0.0685


Batch 749/750, Loss: 3.1862: 100%|██████████| 750/750 [04:18<00:00,  2.90it/s]


Epoch [28/64] - Loss: 3.1213 - Validation Loss: 3.6617 - Validation Accuracy: 0.0697


Batch 749/750, Loss: 3.3000: 100%|██████████| 750/750 [04:21<00:00,  2.87it/s]


Epoch [29/64] - Loss: 3.0983 - Validation Loss: 3.8713 - Validation Accuracy: 0.0745


Batch 749/750, Loss: 2.9649: 100%|██████████| 750/750 [04:19<00:00,  2.89it/s]


Epoch [30/64] - Loss: 3.0881 - Validation Loss: 3.9037 - Validation Accuracy: 0.0710


Batch 749/750, Loss: 3.1412: 100%|██████████| 750/750 [04:23<00:00,  2.84it/s]


Epoch [31/64] - Loss: 3.0627 - Validation Loss: 3.8495 - Validation Accuracy: 0.0667


Batch 749/750, Loss: 2.8567: 100%|██████████| 750/750 [04:20<00:00,  2.88it/s]


Epoch [32/64] - Loss: 3.0382 - Validation Loss: 4.0171 - Validation Accuracy: 0.0768


Batch 556/750, Loss: 3.1601:  74%|███████▍  | 557/750 [03:01<01:02,  3.11it/s]

In [None]:
def evaluate_on_test_data(model, criterion):
    # test data can be identified in /food101/meta/test.txt
    
    # Load the test.txt file
    with open(os.path.join('food101', 'meta', 'test.txt'), 'r') as f:
        test_files = f.readlines()

    # Extract the image file name from each line
    test_files = [file.split('/')[1].strip() + '.jpg' for file in test_files]

    # Create a list of image file paths and their corresponding labels
    image_paths = []
    labels = []

    # Traverse through each food category
    for category in os.listdir(os.path.join('food101', 'images')):
        category_path = os.path.join('food101', 'images', category)

        # Ensure it's a directory
        if os.path.isdir(category_path):
            for image_filename in os.listdir(category_path):
                if image_filename.endswith('.jpg') and image_filename in test_files:
                    image_path = os.path.join(category_path, image_filename)
                    image_paths.append(image_path)
                    labels.append(category)

    # process the test data the same way I did the training data
    test_data = Food101Dataset(data_dir='food101', num_samples_per_split=10000)

    # overwrite the image_paths and labels with the test data
    test_data.image_paths = image_paths
    test_data.labels = labels

    test_loader = DataLoader(test_data, batch_size=Config.batch_size, shuffle=True)

    model.eval()
    total_loss = 0.0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for batch in test_loader:
            features = batch['image'].to(device)
            labels = batch['label'].to(device)

            outputs = model(features)

            loss = criterion(outputs, labels)
            total_loss += loss.item()

            predictions = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.tolist())
            all_predictions.extend(predictions.tolist())
    
    average_loss = total_loss / len(test_loader)

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f'Test Loss: {average_loss:.4f} - Test Accuracy: {accuracy:.4f}')

evaluate_on_test_data(model, criterion)

Number of images found:  10000
Number of labels found:  10000
Test Loss: 9.3660 - Test Accuracy: 0.0410
