<a href="https://colab.research.google.com/github/GitOfTheseus/MoE_ViT/blob/main/Moe_ViT_on_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Implementation and Training of a MoE-ViT**

This notebook implements and trains a Vision Transformer (ViT) equipped with a Sparse Mixture of Experts (MoE) for image classification on the CIFAR-10 dataset.

In [None]:
# Mount Drive to save results, figures and checkpoints
# (Colab-only: later cells write results and checkpoints under /content/drive/MyDrive/indeep)
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
pip install -q transformers datasets triton # installing libraries not included in colab

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.5/209.5 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.3/179.3 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incomp

In [None]:
import os
import random
import numpy as np
import pandas as pd
from datetime import datetime
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset, DatasetDict
from datasets import Dataset as Ds
from triton.language import trans

from transformers import ViTImageProcessor
from transformers import ViTForImageClassification

In [None]:
# Run-wide configuration: one seed shared by every RNG in use, the compute
# device, and the names used to label figures, results and checkpoints.
random_seed = 0
np.random.seed(random_seed)
random.seed(random_seed)
torch.manual_seed(random_seed)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

dataset_name = 'cifar10'
model_name = f'MoE_ViT_on_{dataset_name}'

Design of the Sparse MoE with MLPs

In [None]:
class SparseMoE(nn.Module):
    """Sparse Mixture-of-Experts layer whose experts are two-layer MLPs.

    A linear gate scores every expert for every token. Each token is sent to
    its ``top_k`` best-scoring experts, and their outputs are summed with
    weights given by a softmax taken over the selected gate logits only.

    Args:
        dim: Size of the last dimension of the input (and of the output).
        hidden_dim: Width of the hidden layer inside each expert MLP.
        num_experts: Number of expert MLPs.
        top_k: Number of experts evaluated per token.
        experts: Accepted for backward compatibility; currently ignored.
    """

    def __init__(self, dim, hidden_dim, num_experts=32, top_k=2, experts=None):
        super(SparseMoE, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.dim = dim

        # Defining the experts
        self.experts = nn.ModuleList([nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim)
        ) for _ in range(num_experts)])

        # Gating network
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        """Route every token through its top-k experts.

        Args:
            x: Tensor of shape (batch_size, seq_len, dim).

        Returns:
            Tensor of shape (batch_size, seq_len, dim).
        """
        batch_size, seq_len, dim = x.shape

        # Calculating gating scores and selecting top-k experts
        gate_logits = self.gate(x)  # Shape: (batch_size, seq_len, num_experts)
        topk_values, topk_indices = torch.topk(gate_logits, self.top_k, dim=-1)

        # Normalizing gating values over the selected experts only
        topk_values = torch.softmax(topk_values, dim=-1)

        # Flattening batch and sequence dimensions; reshape (unlike view) also
        # accepts non-contiguous inputs.
        x_flat = x.reshape(-1, dim)  # Shape: (batch_size * seq_len, dim)
        topk_indices_flat = topk_indices.reshape(-1, self.top_k)  # Shape: (batch_size * seq_len, top_k)
        topk_values_flat = topk_values.reshape(-1, self.top_k)  # Shape: (batch_size * seq_len, top_k)

        # Initializing output tensor
        output = torch.zeros_like(x_flat)

        # Dispatch per expert instead of per token: each expert runs once on
        # the whole batch of tokens routed to it, rather than one forward pass
        # per token per slot.
        for expert_id, expert in enumerate(self.experts):
            token_ids, slot_ids = torch.where(topk_indices_flat == expert_id)
            if token_ids.numel() == 0:
                continue  # no token selected this expert in this batch

            expert_outputs = expert(x_flat[token_ids])
            expert_weight = topk_values_flat[token_ids, slot_ids].unsqueeze(1)

            # Weighted contribution scattered back to the owning tokens.
            # topk indices are distinct per token, so a token appears at most
            # once per expert.
            contribution = (expert_outputs * expert_weight).to(output.dtype)
            output.index_add_(0, token_ids, contribution)

        # Reshaping back to (batch_size, seq_len, dim)
        return output.view(batch_size, seq_len, dim)

class CustomOutput(nn.Module):
    """Parameter-free stand-in for a block's output layer once the MoE replaces the Linear layer.

    No projection is applied: the incoming hidden states are only summed with
    the residual branch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states, input_tensor):
        # Identity on the hidden states, then the skip connection.
        residual_sum = hidden_states + input_tensor
        return residual_sum


def model_creation(num_classes=10, num_experts=8, top_k=2, num_moe_layers=2):
    """Build a pretrained ViT whose last transformer blocks use a Sparse MoE.

    In each of the last ``num_moe_layers`` encoder blocks the ``intermediate``
    sub-layer is replaced by a freshly initialised ``SparseMoE`` and the
    ``output`` sub-layer by ``CustomOutput`` (residual sum only). The
    classification head is replaced by a new Linear layer with
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes of the new classifier head.
        num_experts: Number of experts in every MoE layer.
        top_k: Number of experts evaluated per token.
        num_moe_layers: How many encoder blocks, counted from the end, get an MoE.

    Returns:
        The modified ``ViTForImageClassification`` model.
    """

    # Loading Google ViT from Hugging Face
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
    print("\ndense model:", model)

    # Replacing Intermediate Linear Layers with Custom Class MoE in the last transformer blocks.
    # The start index is computed explicitly because a slice of [-0:] would select every block.
    encoder_layers = model.vit.encoder.layer
    first_moe_layer = max(len(encoder_layers) - num_moe_layers, 0)
    for layer in encoder_layers[first_moe_layer:]:
        dense = layer.intermediate.dense
        # Experts keep the dimensions of the dense layer they replace
        moe_layer = SparseMoE(dense.in_features, dense.out_features, num_experts=num_experts, top_k=top_k)
        layer.intermediate = moe_layer
        # The MoE already maps back to the model width, so the output layer only adds the residual
        layer.output = CustomOutput()

    # Replacing the output of the ViT with the number of classes of our dataset
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=num_classes)

    print("\nMoE model:", model)

    return model

Class for Custom Preprocessing of the dataset

In [None]:
class CustomImageDataset(Dataset):
    """PyTorch Dataset wrapping an indexable collection of ``{'img', 'label'}`` samples.

    Args:
        ds: Indexable dataset; ``ds[idx]`` must provide the keys ``'img'`` and ``'label'``.
        transform: Optional callable invoked as ``transform(images=image, return_tensors="pt")``
            whose result exposes ``pixel_values``. When omitted, the image is
            returned as stored.
        target_transform: Truthy to return the label as a detached tensor.
    """

    def __init__(self, ds, transform=None, target_transform=None):
        self.ds = ds
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for sample ``idx``."""
        # Fetch the sample once: indexing the underlying dataset twice repeats the lookup
        sample = self.ds[idx]
        image = sample['img']
        label = sample['label']

        # Single-channel images are tiled to three channels
        if image.shape[0] == 1:
            image = image.repeat(3, 1, 1)

        if self.transform:
            pixel_values = self.transform(images=image, return_tensors="pt").pixel_values
        else:
            # Without a transform there is nothing to preprocess
            pixel_values = image

        if self.target_transform:
            # torch.tensor(existing_tensor) emits a copy-construct UserWarning,
            # so tensors are copied with clone().detach() instead
            if torch.is_tensor(label):
                label = label.clone().detach()
            else:
                label = torch.tensor(label)
        return pixel_values, label

Classes and Functions to Plot and Analyze performance

In [None]:
class AverageMeter(object):
    """Keeps the latest value and a running weighted average.

    class taken from https://github.com/pranoyr/cnn-lstm/blob/7062a1214ca0dbb5ba07d8405f9fbcd133b1575e/utils.py#L52

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times (``n`` defaults to 1)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight has no defined mean; report 0 instead of raising
        self.avg = self.sum / self.count if self.count else 0

def plot_performances(epoch, performances, model_name):
    """Plot the learning curves up to ``epoch`` and save them as a PNG.

    Args:
        epoch: Number of completed epochs; every curve is plotted against ``epoch`` points.
        performances: Dict with the lists ``'train_loss'``, ``'eval_loss'``,
            ``'train_acc'`` and ``'eval_acc'``.
        model_name: Used for the file name ``figures/<model_name>_training.png``
            under the current working directory.
    """
    # Each epoch is drawn at the centre of its tick interval
    x_positions = np.arange(1, epoch + 1) - 0.5
    tick_positions = np.arange(0.5, epoch + 0.5, 1)
    tick_labels = [str(int(tick)) for tick in range(1, epoch + 1, 1)]

    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1)

    # Set the font size for various elements
    fontsize = 14  # Adjust this value as needed

    # First plot (losses)
    ax1.xaxis.set_label_coords(0.1, -0.2)
    ax1.plot(x_positions, performances['train_loss'], 'g', label='Training loss')
    ax1.plot(x_positions, performances['eval_loss'], 'b', label='Validation loss')
    ax1.set_title('Training and Validation loss', fontsize=fontsize)
    ax1.set_xlabel('Epochs', fontsize=fontsize)
    ax1.set_ylabel('Loss', fontsize=fontsize)
    ax1.set_xticks(tick_positions)
    ax1.set_xticklabels(tick_labels, fontsize=fontsize)
    ax1.legend(fontsize=fontsize)

    # Second plot (accuracies)
    ax2.plot(x_positions, performances['train_acc'], 'g', label='Training accuracy')
    ax2.plot(x_positions, performances['eval_acc'], 'b', label='Validation accuracy')
    ax2.set_title('Training and Validation accuracy', fontsize=fontsize)
    ax2.set_xlabel('Epochs', fontsize=fontsize)
    ax2.set_ylabel('Accuracy', fontsize=fontsize)  # was mislabelled 'Loss'
    ax2.set_xticks(tick_positions)
    ax2.set_xticklabels(tick_labels, fontsize=fontsize)
    # Lower y-limit: slightly below the worst accuracy, clamped to [0, 1]
    lowest_acc = min(min(performances['train_acc']), min(performances['eval_acc']))
    ax2.set_ylim(max(0, min(lowest_acc - 0.05, 1)))
    ax2.legend(fontsize=fontsize)

    # Adjust layout and save figure
    plt.tight_layout()
    fig_dir = os.path.join(os.getcwd(), 'figures', model_name + '_training.png')
    os.makedirs(os.path.dirname(fig_dir), exist_ok=True)
    plt.savefig(fig_dir)
    plt.close('all')

    return


def calculate_performance(confusion_matrix, class_names, model_name,
                          results_root='/content/drive/MyDrive/indeep/results'):
    """Compute test-set metrics from a confusion matrix, print and pickle them.

    Args:
        confusion_matrix: Square NumPy array; rows are summed as actual counts
            and columns as predicted counts.
        class_names: Class labels, in the same order as the matrix rows/columns.
        model_name: Used in the printed header and in the output file names.
        results_root: Directory receiving ``results_<model_name>.pkl`` and
            ``conf_matrix_<model_name>.pkl``; created if missing.

    Returns:
        Dict with ``'accuracy'`` and ``'error_rate'`` (percentages),
        ``'recall'`` and ``'precision'`` (one-column DataFrames indexed by
        class name, values in [0, 1]) and ``'overall_acc'`` (DataFrame).
    """
    total_predicted = confusion_matrix.sum(0)
    total_actual = confusion_matrix.sum(1)
    totals = confusion_matrix.sum()
    correct_per_class = np.diag(confusion_matrix)

    # A class with no actual (or no predicted) samples yields NaN, as 0/0 does in pandas
    with np.errstate(divide='ignore', invalid='ignore'):
        recall = pd.DataFrame(correct_per_class / total_actual, index=class_names)
        precision = pd.DataFrame(correct_per_class / total_predicted, index=class_names)
        accuracy = correct_per_class.sum() / totals * 100

    error_rate = 100 - accuracy
    overall_perf = pd.DataFrame([accuracy], index=['accuracy'])

    performance_test = {'accuracy': accuracy, 'error_rate': error_rate, 'recall': recall,
                        'precision': precision, 'overall_acc': overall_perf}

    print(f'RESULTS OF TEST {model_name}')
    print("the accuracy of the model is {} %".format(performance_test['accuracy']))
    print("the error rate of the model is {} %".format(performance_test['error_rate']))
    print("recall : \n{}".format(performance_test['recall']))
    print("precision : \n{}".format(performance_test['precision']))

    # The target folder may not exist yet (e.g. on a fresh Drive)
    os.makedirs(results_root, exist_ok=True)
    results_dir = os.path.join(results_root, f'results_{model_name}.pkl')
    confusion_matrix_dir = os.path.join(results_root, f'conf_matrix_{model_name}.pkl')
    with open(results_dir, 'wb') as file:
        pickle.dump(performance_test, file)
    with open(confusion_matrix_dir, 'wb') as file:
        pickle.dump(confusion_matrix, file)

    return performance_test



def plot_confusion_matrix(performance_test, confusion_matrix, class_names):
    """Draw a 2x2 figure: confusion matrix, per-class recall, per-class precision, macro F1.

    Args:
        performance_test: Dict holding the one-column DataFrames ``'recall'``
            and ``'precision'`` (values in [0, 1], one row per class).
        confusion_matrix: Square array indexed as ``[real class, predicted class]``.
        class_names: Class labels in matrix order.

    The figure is saved as ``figures/conf_mtrx.png`` under the current working
    directory and then shown.
    """
    fig, ([ax1, ax2], [ax3, ax4]) = plt.subplots(nrows=2, ncols=2,
                                                 gridspec_kw={'width_ratios': [7, 3], 'height_ratios': [8, 2]})

    # plot 1 confusion matrix: rows (y axis) = real class, columns (x axis) = predicted class
    ax1.imshow(confusion_matrix)
    ax1.set_xticks(np.arange(len(class_names)))
    ax1.set_yticks(np.arange(len(class_names)))
    ax1.set_ylabel('Real Class', fontweight='bold', fontsize=9)
    ax1.set_xlabel('Predicted Class', fontweight='bold', fontsize=9)

    # Position the x-axis label at the top
    ax1.xaxis.set_label_coords(.5, 1.15)  # Keep this for label positioning
    ax1.xaxis.tick_top()  # Place ticks at the top
    ax1.xaxis.set_label_position('top')  # Ensure the label is at the top

    # Tick styling; x labels are rotated and left-aligned, y labels right-aligned
    ax1.tick_params(axis='x', which='both', top=True, labeltop=True, labelsize=10, pad=10)  # Adjust padding
    ax1.tick_params(axis='y', which='both', labelsize=10)
    ax1.set_xticklabels(class_names, rotation=30, ha="left")
    ax1.set_yticklabels(class_names, rotation=60, ha="right")

    # Loop over data dimensions and create text annotations.
    for i in range(len(class_names)):
        for j in range(len(class_names)):
            ax1.text(j, i, int(confusion_matrix[i, j]),
                     ha="center", va="center", color="w")

    ## plot 2 recall (column vector, one cell per class, shown as a percentage)
    ax2.set_title("Recall", fontweight ='bold', fontsize=12)
    ax2.imshow(performance_test['recall'])
    for i, c in enumerate(class_names):
        ax2.text(0, i, str(round(performance_test['recall'].iloc[i, 0]*100, 1)), ha="center", va="center", color="w")
    ax2.tick_params(top=False, bottom=False, left=False, right=False)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.setp(ax2.get_yticklabels(), visible=False)

    ## plot 3 precision (row vector, one cell per class, shown as a percentage)
    ax3.set_title("Precision", fontweight ='bold', fontsize=12)
    ax3.imshow(performance_test['precision'].T)
    for i, c in enumerate(class_names):
        ax3.text(i, 0, str(round(performance_test['precision'].iloc[i, 0]*100, 1)), ha="center", va="center",
                 color="w")
    ax3.tick_params(top=False, bottom=False, left=False, right=False)
    plt.setp(ax3.get_xticklabels(), visible=False)
    plt.setp(ax3.get_yticklabels(), visible=False)

    # plot 4 macro average F1-score
    ax4.set_title("Macro Avg F1-score", fontweight ='bold', fontsize=12)
    precision_values = performance_test['precision'].values
    recall_values = performance_test['recall'].values
    pr_sum = precision_values + recall_values
    # A class whose precision + recall is 0 or undefined (NaN) contributes an
    # F1 of 0 instead of turning the whole macro average into NaN.
    with np.errstate(divide='ignore', invalid='ignore'):
        f1_scores = np.where(pr_sum > 0, 2 * (precision_values * recall_values) / pr_sum, 0.0)
    macro_avg_f1 = f1_scores.mean()
    print(macro_avg_f1)
    ax4.imshow(pd.DataFrame([macro_avg_f1], columns=['Value']))
    ax4.text(0, 0, str(round(macro_avg_f1*100, 2)), ha="center", va="center",
             color="w")

    ax4.tick_params(top=False, bottom=False, left=False, right=False)
    plt.setp(ax4.get_xticklabels(), visible=False)
    plt.setp(ax4.get_yticklabels(), visible=False)

    fig_dir = os.path.join(os.getcwd(), 'figures', 'conf_mtrx.png')
    os.makedirs(os.path.dirname(fig_dir), exist_ok=True)
    plt.savefig(fig_dir)
    plt.show()

    return

In [None]:
def data_loading(dataset):
    """Load a dataset, split/filter it, preprocess it and build the DataLoaders.

    Side effects: sets the module-level globals ``batch_size`` and
    ``class_names`` that later cells read.

    Args:
        dataset: ``'cifar10'`` or ``'tiny_imagenet'``.

    Returns:
        Dict mapping each available split name (``'train'``, ``'valid'`` and,
        when the dataset provides one, ``'test'``) to a ``DataLoader``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """

    global batch_size, class_names
    batch_size = 32

    # To train the dataset on CIFAR10
    if dataset == 'cifar10':
        # Load the dataset from Hugging Face
        cifar10_dataset = load_dataset('uoft-cs/cifar10').with_format('torch')
        dataset_to_train = cifar10_dataset

        class_names = cifar10_dataset['train'].features['label'].names
        train_features = dataset_to_train['train'].features
        df_total = dataset_to_train['train'].to_pandas()

        # for a smaller train set
        #df_total = pd.concat([group.sample(frac=0.2) for _, group in df_total.groupby('label')], axis=0).reset_index(drop=True)

        # Stratified 80/20 split. The sampled rows must keep their original
        # index until they have been removed from the training frame:
        # resetting it first would drop the first 20% of rows instead and
        # leave the validation samples inside the training set.
        df_valid = pd.concat([group.sample(frac=0.2) for _, group in df_total.groupby('label')], axis=0)
        df_train = df_total.drop(index=df_valid.index).reset_index(drop=True)
        df_valid = df_valid.reset_index(drop=True)
        dataset_to_train['train'] = Ds.from_pandas(df_train, features=train_features).with_format('torch')
        dataset_to_train['valid'] = Ds.from_pandas(df_valid, features=train_features).with_format('torch')

    # To train the dataset on tiny ImageNet !!! this dataset does not have any test set
    elif dataset == 'tiny_imagenet':
        # Load the dataset from Hugging Face
        tiny_imagenet_dataset = load_dataset('Maysee/tiny-imagenet').with_format('torch')

        # Only labels 0..19 are kept, so the class list is cut to match.
        # NOTE(review): assumes the label feature exposes `.names`; falls back to numeric names otherwise.
        num_kept_classes = 20
        label_names = getattr(tiny_imagenet_dataset['train'].features['label'], 'names', None)
        if label_names is not None:
            class_names = list(label_names[:num_kept_classes])
        else:
            class_names = [str(i) for i in range(num_kept_classes)]

        # Creating a subset of the dataset for faster training
        df_train = tiny_imagenet_dataset['train'].to_pandas()
        df_train_filtered = pd.concat([group.sample(frac=0.2) for _, group in df_train[df_train['label'] < num_kept_classes].groupby('label')], axis=0).reset_index(drop=True)
        df_valid = tiny_imagenet_dataset['valid'].to_pandas()
        df_valid_filtered = pd.concat([group.sample(frac=0.2) for _, group in df_valid[df_valid['label'] < num_kept_classes].groupby('label')], axis=0).reset_index(drop=True)
        filtered_ds_train = Ds.from_pandas(df_train_filtered, features=tiny_imagenet_dataset['train'].features).with_format('torch')
        filtered_ds_valid = Ds.from_pandas(df_valid_filtered, features=tiny_imagenet_dataset['valid'].features).with_format('torch')
        print(filtered_ds_train)
        filtered_dataset = DatasetDict({
            'train': filtered_ds_train,
            'valid': filtered_ds_valid
        })
        dataset_to_train = filtered_dataset

    else:
        raise ValueError(f"Unsupported dataset '{dataset}': expected 'cifar10' or 'tiny_imagenet'")

    # Loading the preprocessing rules directly from the model
    processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

    # Only build loaders for the splits that exist (tiny ImageNet has no test split)
    splits = [x for x in ['train', 'valid', 'test'] if x in dataset_to_train]

    # Preprocessing the data
    processed = {x: CustomImageDataset(dataset_to_train[x], transform=processor, target_transform=True) for x in splits}
    # Creating the DataLoader for the training
    shuffle_ds = {'train': True, 'valid': False, 'test': False}
    dataloader = {x: DataLoader(processed[x], batch_size=batch_size, shuffle=shuffle_ds[x]) for x in splits}

    return dataloader

In [None]:
def _save_checkpoint(epoch, num_epochs, model, optimizer, scheduler, criterion, performances, learning_rate):
    """Write the current training state to the Drive checkpoint folder, overwriting the previous file."""
    checkpoint_dict = {'epoch': epoch,
                       # The current epoch number, which helps resume training from the correct point.
                       'state_dict': model.state_dict(),
                       # The model's weights (parameters), stored in the state_dict format.
                       'optimizer_state_dict': optimizer.state_dict(),
                       # The state of the optimizer, including momenta and other parameters.
                       'loss_function': criterion,  # The loss function (criterion) used during training.
                       'performances': performances,  # Performances of the model
                       'lr_scheduler_state_dict': scheduler.state_dict(),  # State of the LR scheduler
                       'hyperparameters': {
                           'batch_size': batch_size,
                           'learning_rate': learning_rate
                       },
                       'random_seed': random_seed,
                       'library_versions': {
                           'torch': torch.__version__,
                           'numpy': np.__version__
                       }
                       }
    checkpoint_dir = os.path.join('/content/drive/MyDrive/indeep/checkpoints', f'checkpoint_{model_name}_{num_epochs}_epoch.pth')
    # The folder may not exist yet on a fresh Drive
    os.makedirs(os.path.dirname(checkpoint_dir), exist_ok=True)
    print(f'\nsaving checkpoint in {checkpoint_dir}\n')
    torch.save(checkpoint_dict, checkpoint_dir)


def training(dataloader, model, num_epochs=4, learning_rate=2e-5):
    """Train ``model``, validating, plotting and checkpointing after every epoch.

    Reads the module-level globals ``device``, ``model_name``, ``batch_size``
    and ``random_seed``.

    Args:
        dataloader: Dict with ``'train'`` and ``'valid'`` DataLoaders yielding
            ``(pixel_values, label)`` batches.
        model: Model whose forward output exposes ``.logits``; expected to be
            on ``device`` already.
        num_epochs: Maximum number of epochs (also part of the checkpoint file name).
        learning_rate: Adam learning rate.

    Returns:
        ``(performances, model)``; ``performances`` holds the per-epoch lists
        ``'train_loss'``, ``'train_acc'``, ``'eval_loss'``, ``'eval_acc'`` and
        the scalar ``'best_acc'`` (best validation accuracy).
    """

    # Defining Training Params
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # With step_size=5 the learning rate only decays once 5 epochs have run,
    # so the default 4-epoch run trains at a constant rate.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    last_loss = 100

    # Early stopping: stop once the validation loss has risen `patience` times
    # (cannot trigger while num_epochs < patience, as with the defaults).
    patience = 5
    trigger_times = 0

    performances = {
        'best_acc': 0,
        'train_acc': [],
        'train_loss': [],
        'eval_loss': [],
        'eval_acc': [],
    }

    for epoch in range(1, num_epochs + 1):

        print('\n')

        for phase in ['train', 'valid']:
            loader = dataloader[phase]
            num_batches = len(loader)

            if phase == 'train':
                print(f'Number of batches in train_loader: {num_batches}')
                current_batch = loader.batch_size
                print(f'Batch size: {current_batch}')
                print(f'Number of samples in dataset: {len(loader.dataset)}')
                train_features, train_labels = next(iter(loader))
                print(f'Feature batch shape: {train_features.squeeze(1).shape}')  # IF IT STOPS...
                print(f'Labels batch shape: {train_labels.shape}')
                model.train()
            else:
                print(f'\nValidation phase is starting at {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
                model.eval()

            losses = AverageMeter()
            accuracies = AverageMeter()

            for batch_idx, batch_data in enumerate(loader):
                # The dataset yields pixel_values with an extra leading axis of size 1; squeeze it away
                images = batch_data[0].squeeze(1).to(device)
                labels = batch_data[1].to(device)
                if phase == 'train':
                    print(f'training status {round((batch_idx+(num_batches*(epoch-1)))/(num_batches*num_epochs)*100, 2)}% --> batch n: {batch_idx + 1} at {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')

                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(images).logits
                    _, predicted = torch.max(outputs, 1)

                    # Calculating loss and backpropagate
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # Weight by the real number of samples in this batch: the last
                # batch of a split can be smaller than the nominal batch size.
                samples_in_batch = labels.size(0)
                batch_accuracy = (predicted == labels).sum().item() / samples_in_batch
                losses.update(loss.item(), samples_in_batch)
                accuracies.update(batch_accuracy, samples_in_batch)

            print(f'Epoch [{epoch}/{num_epochs}] {phase} set ({len(loader.dataset)} samples): Average loss: {losses.avg:.4f}\tAcc: {accuracies.avg * 100:.4f}%')

            if phase == 'train':
                scheduler.step()
                performances['train_loss'].append(losses.avg)
                performances['train_acc'].append(accuracies.avg)
            else:
                performances['eval_loss'].append(losses.avg)
                performances['eval_acc'].append(accuracies.avg)
                if accuracies.avg > performances['best_acc']:
                    performances['best_acc'] = accuracies.avg

        # Plotting Learning curves
        plot_performances(epoch, performances, model_name)

        # Custom Early Stopping Rule, based on this epoch's validation loss
        valid_loss = performances['eval_loss'][-1]
        stop_early = False
        if valid_loss > last_loss:
            trigger_times += 1
            print('Trigger Times before patience:', trigger_times)

            if trigger_times >= patience:
                print("Epoch {}\n".format(epoch))
                print('Early stopping!')
                stop_early = True

        last_loss = valid_loss
        _save_checkpoint(epoch, num_epochs, model, optimizer, scheduler, criterion, performances, learning_rate)

        if stop_early:
            return performances, model

    return performances, model

In [None]:
def test(dataloader, model):
    """Evaluate ``model`` on ``dataloader['test']`` and report the results.

    Builds the confusion matrix (rows = real class, columns = predicted
    class), then delegates to ``calculate_performance`` (prints and saves the
    metrics) and ``plot_confusion_matrix``. Reads the module-level globals
    ``class_names``, ``device`` and ``model_name``.

    Returns:
        The metrics dict produced by ``calculate_performance``.
    """

    model.eval()

    num_classes = len(class_names)
    confusion_matrix = np.zeros((num_classes, num_classes))

    with torch.no_grad():
        for batch_data in dataloader['test']:
            images = batch_data[0].squeeze(1).to(device)
            outputs = model(images).logits
            _, predicted = torch.max(outputs, 1)

            actual_ids = batch_data[1].cpu().numpy().astype(int)
            predicted_ids = predicted.cpu().numpy().astype(int)
            # np.add.at accumulates repeated (row, col) pairs correctly,
            # unlike fancy-indexed `+=`, which counts each pair only once
            np.add.at(confusion_matrix, (actual_ids, predicted_ids), 1)

    performance_test = calculate_performance(confusion_matrix, class_names, model_name)

    plot_confusion_matrix(performance_test, confusion_matrix, class_names)

    return performance_test

Now we can actually start the training, using all the functions and classes defined above

In [None]:
# Loading the dataset and creating the dataloader
# (data_loading() also sets the globals `batch_size` and `class_names` used by later cells)
dataloader = data_loading(dataset=dataset_name)

In [None]:
# Building the model
# `class_names` is the global populated by data_loading(), so the data-loading cell must run first
model = model_creation(num_classes=len(class_names))


dense model: ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense)

In [None]:
# Starting the training
# training() moves every batch to `device`, so the model has to live on the same device
model.to(device)
performances, model = training(dataloader, model)



Number of batches in train_loader: 1250
Batch size: 32
Number of samples in dataset: 40000
Feature batch shape: torch.Size([32, 3, 224, 224])
Labels batch shape: torch.Size([32])


  label = torch.tensor(label).clone().detach()


training status 0.0% --> batch n: 1 at 2024-11-30 11:18:22
training status 0.02% --> batch n: 2 at 2024-11-30 11:18:33
training status 0.04% --> batch n: 3 at 2024-11-30 11:18:45


KeyboardInterrupt: 

Now that training has finished, we can assess the model's performance on the test set. If you have just finished training in this session, you can skip loading the checkpoint.
Otherwise, if you only want to test the model: build the dataloader, build the model again, define the required functions, load the checkpoint, and then run the test.

In [None]:
# Load the checkpoint written by training() and restore its weights into `model`.
num_epochs = 4  # must match the epoch count used by training(); it is part of the checkpoint file name
filename = f'checkpoint_{model_name}_{num_epochs}_epoch.pth'
checkpoint_dir = os.path.join('/content/drive/MyDrive/indeep/checkpoints', filename)
# The checkpoint also stores a pickled loss-function object, which the
# weights-only loader rejects. weights_only=False unpickles arbitrary objects,
# so only load checkpoints you created yourself.
checkpoint = torch.load(checkpoint_dir, map_location=torch.device(device), weights_only=False)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

  checkpoint = torch.load(checkpoint_dir, map_location=torch.device(device))


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/indeep/checkpoints/checkpoint_MoE_ViT_on_cifar10_5_epoch.pth'

In [None]:
# Testing the model
# Evaluates on dataloader['test'], prints the metrics, saves them and plots the confusion matrix
test(dataloader, model)

NameError: name 'trained_model' is not defined