# **Setup**

### Install dependencies

In [87]:
# Install missing dependencies
!pip install -q torchinfo torchmetrics wandb
!pip install torch
!pip install torchvision
!pip install numpy
!pip install matplotlib



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip

### Import important libraries

In [88]:
# Import required libraries/code
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

from torch import nn
from torchvision import transforms

from torchinfo import summary

from torchvision.datasets import CIFAR100
from torch.utils.data import random_split, DataLoader, Subset
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer

import os
import glob
import wandb
import pickle

import warnings
from typing import Optional, Union, Callable, List
import scipy.stats as stats
import math

import copy
from itertools import combinations
import time
from datetime import timedelta


In [89]:
# from google.colab import drive
# drive.mount('/content/gdrive')

In [90]:
# Setup device agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

### Data Preprocessing

In [91]:

# Define a class for handling CIFAR100 data
class CIFAR100Data:
    """
    A class used to represent the CIFAR100 dataset.

    ...

    Attributes
    ----------
    batch_size : int
        the number of samples that will be propagated through the network simultaneously
    original_train_set : torchvision.datasets.CIFAR100
        the original training set downloaded from CIFAR100
    original_test_set : torchvision.datasets.CIFAR100
        the original test set downloaded from CIFAR100
    train_set : torch.utils.data.Subset
        the training set after splitting the original training set
    validation_set : torch.utils.data.Subset
        the validation set after splitting the original training set
    test_set : torchvision.datasets.CIFAR100
        the test set, same as the original test set
    original_train_loader : torch.utils.data.DataLoader
        data loader for the original training set
    original_test_loader : torch.utils.data.DataLoader
        data loader for the original test set
    train_loader : torch.utils.data.DataLoader
        data loader for the training set
    validation_loader : torch.utils.data.DataLoader
        data loader for the validation set
    test_loader : torch.utils.data.DataLoader
        data loader for the test set

    Methods
    -------
    compute_mean_std(loader)
        Computes the mean and standard deviation of the images in the loader.
    download_data()
        Downloads the CIFAR100 dataset.
    split_data(original_train_set, validation_ratio=0.2)
        Splits the original training set into a training set and a validation set.
    compute_statistics(train_set)
        Computes the mean and standard deviation of the training set.
    apply_transforms(train_mean, train_std, is_validation_set_available = False)
        Defines and applies the transformations for the training set, validation set, and test set.
    save_data(data_loader, data_set, file_name: str)
        Saves the data loader to Google Drive.
    load_data(file_name: str)
        Loads the data loader from Google Drive.
    create_and_save_data_loaders(train_set, test_set, validation_set=None)
        Creates data loaders for the training, validation, and test sets and saves them to Google Drive.
    prepare_data(validation_ratio = None)
        Prepares the data by downloading it, splitting it, computing statistics, applying transforms, and creating and saving data loaders.
    train_valid_test(validation_ratio=0.2)
        Loads or prepares the data loaders for the training, validation, and test sets and returns them.
    train_test()
        Loads or prepares the data loaders for the original training and test sets and returns them.
    iid_shards(num_shards=2)
        Loads or prepares the data loaders for the shards of the original training set and returns them.
    """
    def __init__(self, batch_size=64):
        """
        Initialize the CIFAR100Data object with the given batch size.

        Parameters:
        batch_size (int): The size of the batches for the data loaders.
        """
        self.batch_size = batch_size
        self.original_train_set = None
        self.original_test_set = None
        self.train_set = None
        self.validation_set = None
        self.test_set = None

        self.original_train_loader = None
        self.original_test_loader = None
        self.train_loader = None
        self.validation_loader = None
        self.test_loader = None

    def compute_mean_std(self, loader):
        """
        Compute the mean and standard deviation of the images in the loader.

        Parameters:
        loader (DataLoader): The DataLoader object containing the image data.

        Returns:
        mean (Tensor): The mean of the images.
        std (Tensor): The standard deviation of the images.
        """
        channels_sum, channels_squared_sum, num_batches = 0, 0, 0
        for data, _ in loader:
            channels_sum += torch.mean(data, dim=[0, 2, 3])
            channels_squared_sum += torch.mean(data**2, dim=[0, 2, 3])
            num_batches += 1
        mean = channels_sum / num_batches
        std = torch.sqrt((channels_squared_sum / num_batches) - mean**2)
        return mean, std

    def download_data(self):
        """
        Download the CIFAR100 dataset and store it in instance variables.
        """
        self.original_train_set = CIFAR100(root='./data', train=True, download=True, transform=transforms.ToTensor())
        self.original_test_set = CIFAR100(root='./data', train=False, download=True, transform=transforms.ToTensor())

    def split_data(self, original_train_set, validation_ratio=0.2):
        """
        Split the original training set into a training set and a validation set.

        Parameters:
        original_train_set (Dataset): The original training set.
        validation_ratio (float): The ratio of the original training set to use for validation.

        Returns:
        train_set (Subset): The new training set.
        validation_set (Subset): The new validation set.
        """
        train_len = int(len(original_train_set) * (1 - validation_ratio))
        val_len = len(original_train_set) - train_len
        train_set, validation_set = random_split(original_train_set, [train_len, val_len])

        return train_set, validation_set

    def compute_statistics(self, train_set):
        """
        Compute the mean and standard deviation of the train set.

        Parameters:
        train_set (Dataset/Subset): The training set.

        Returns:
        train_mean (Tensor): The mean of the training set.
        train_std (Tensor): The standard deviation of the training set.
        """
        trainloader_tmp = DataLoader(train_set, batch_size=self.batch_size, shuffle=True)
        train_mean, train_std = self.compute_mean_std(trainloader_tmp)

        return train_mean, train_std

    def apply_transforms(self, train_mean, train_std, is_validation_set_available = False):
        """
        Define the transformations for the training set, validation set, and test set
        and apply them to the datasets.

        Parameters:
        train_mean (Tensor): The mean of the training set.
        train_std (Tensor): The standard deviation of the training set.
        is_validation_set_available (bool): Whether a validation set is available.
        """

        # Transformations for the training set
        train_transforms = transforms.Compose([
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(train_mean.tolist(), train_std.tolist())
        ])

        # Transformations for the validation and test sets
        test_val_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(train_mean.tolist(), train_std.tolist())
        ])

        # Apply the transformations to the datasets
        if is_validation_set_available:
            self.train_set.transform = train_transforms
            self.validation_set.transform = test_val_transforms
            self.test_set.transform = test_val_transforms
        else:
            self.original_train_set.transform = train_transforms
            self.original_test_set.transform = test_val_transforms

    def save_data(self, data_loader, file_name: str):
        """
        Save the given data loader and data set to Google Drive.

        Parameters:
        data_loader (DataLoader): The data loader to save.
        file_name (str): The name of the file to save the data loader and data set to.
        """
        # Check if the directory exists, if not, create it
        # if not os.path.exists('/content/gdrive/MyDrive/data/data_loaders/'):
        #     os.makedirs('/content/gdrive/MyDrive/data/data_loaders/')

        # Open each file in write-binary mode on Google Drive and dump (pickle) the data loader into it
        # with open(f'/content/gdrive/MyDrive/data/data_loaders/{self.batch_size}_{file_name}_loader.pkl', 'wb') as f:
        #     pickle.dump(data_loader, f)
        if not os.path.exists(f'./data/data_loaders/{self.batch_size}/'):
            os.makedirs(f'./data/data_loaders/{self.batch_size}/')
        
        open(f'./data/data_loaders/{self.batch_size}/{file_name}_loader.pkl', 'wb').write(pickle.dumps(data_loader))

    def load_data(self, file_name: str):
        """
        Load a data loader from Google Drive.

        Parameters:
        file_name (str): The name of the file to load the data loader from.

        Returns:
        data_loader (DataLoader): The loaded data loader, or None if the file does not exist.
        """
        # Check if the file exists
        # if os.path.exists(f'/content/gdrive/MyDrive/data/data_loaders/{self.batch_size}_{file_name}_loader.pkl'):
        #     # If it exists, open the file in read-binary mode and load (unpickle) the data loader from it
        #     with open(f'/content/gdrive/MyDrive/data/data_loaders/{self.batch_size}_{file_name}_loader.pkl', 'rb') as f:
        #         return pickle.load(f)
        # else:
        #     return None
        if os.path.exists(f'./data/data_loaders/{self.batch_size}/{file_name}_loader.pkl'):
            return pickle.loads(open(f'./data/data_loaders/{self.batch_size}/{file_name}_loader.pkl', 'rb').read())
        else:
            return None

    def create_and_save_data_loaders(self, train_set, test_set, train_name: str, test_name: str, validation_set=None):
        """
        Create data loaders for the training, validation, and test sets and save them to Google Drive.

        Parameters:
        train_set (Subset): The training set.
        test_set (Subset): The test set.
        validation_set (Subset, optional): The validation set.

        Returns:
        train_loader (DataLoader): The data loader for the training set.
        validation_loader (DataLoader): The data loader for the validation set, if it exists.
        test_loader (DataLoader): The data loader for the test set.
        """
        train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True, num_workers =8)
        test_loader = DataLoader(test_set, batch_size=self.batch_size, shuffle=False, num_workers=8)

        # Save the newly created data loaders to Google Drive
        self.save_data(train_loader, train_name)
        self.save_data(test_loader, test_name)

        if validation_set is not None:
            validation_loader = DataLoader(validation_set, batch_size=self.batch_size, shuffle=False, num_workers=8)
            self.save_data(validation_loader, 'validation')
            return train_loader, validation_loader, test_loader

        return train_loader, test_loader

    def prepare_data(self, validation_ratio = None):
        """
        Prepare the data by downloading it, splitting it into training, validation, and test sets,
        computing statistics, applying transformations, and creating and saving data loaders.

        Parameters:
        validation_ratio (float, optional): The ratio of the original training set to use for validation.
        """
        try:
            self.download_data()
        except IOError:
            print("Error downloading data")
            return

        if validation_ratio is not None:
            self.train_set, self.validation_set = self.split_data(self.original_train_set, validation_ratio)
            if self.validation_set is None:
                print("Validation set is not available")
                return
            self.test_set = self.original_test_set
            train_mean, train_std = self.compute_statistics(self.train_set)
            self.apply_transforms(train_mean, train_std, is_validation_set_available = True)

            # Create and save data loaders
            self.train_loader, self.validation_loader, self.test_loader = self.create_and_save_data_loaders(self.train_set, self.test_set, 'train', 'test', self.validation_set)

        else:
            train_mean, train_std = self.compute_statistics(self.original_train_set)
            self.apply_transforms(train_mean, train_std)

            # Create and save data loaders
            self.original_train_loader, self.original_test_loader = self.create_and_save_data_loaders(self.original_train_set, self.original_test_set, 'original_train', 'original_test')

    def train_valid_test(self, validation_ratio=0.2):
        """
        Load the training, validation, and test data loaders from Google Drive, or prepare the data if they do not exist.

        Parameters:
        validation_ratio (float): The ratio of the original training set to use for validation.

        Returns:
        train_loader (DataLoader): The data loader for the training set.
        validation_loader (DataLoader): The data loader for the validation set.
        test_loader (DataLoader): The data loader for the test set.
        """
        self.train_loader = self.load_data('train')
        self.validation_loader = self.load_data('validation')
        self.test_loader = self.load_data('test')

        if self.train_loader is None or self.validation_loader is None or self.test_loader is None:
            self.prepare_data(validation_ratio)

        # Return the data loaders
        return self.train_loader, self.validation_loader, self.test_loader

    def train_test(self):
        """
        Load the original training and test data loaders from Google Drive, or prepare the data if they do not exist.

        Returns:
        original_train_loader (DataLoader): The data loader for the original training set.
        original_test_loader (DataLoader): The data loader for the original test set.
        """
        self.original_train_loader = self.load_data('original_train')
        self.original_test_loader = self.load_data('original_test')

        if self.original_train_loader is None or self.original_test_loader is None:
            self.prepare_data()

        # Return the data loaders
        return self.original_train_loader, self.original_test_loader

    def iid_shards(self, num_shards=2):
        """
        Create or load independent and identically distributed (IID) shards of the original training set.

        Parameters:
        num_shards (int): The number of shards to create.

        Returns:
        shard_loaders (list of DataLoader): The data loaders for the shards.
        """
        # Try to load the shard datasets and their corresponding data loaders from Google Drive
        shard_loaders = []
        for i in range(num_shards):
            shard_loader = self.load_data(f'iid_sharding/{num_shards}_chunk_{i+1}')
            if shard_loader is None:
                break
            shard_loaders.append(shard_loader)

        # If all shard data loaders were successfully loaded, return them
        if len(shard_loaders) == num_shards:
            return shard_loaders

        # If not all shard data loaders were successfully loaded, create them
        if self.original_train_set is None:
            self.download_data()
            train_mean, train_std = self.compute_statistics(self.original_train_set)
            self.apply_transforms(train_mean, train_std)

        # Shuffle the indices
        indices = torch.randperm(len(self.original_train_set))

        # Split the indices into K chunks
        shard_size = len(indices) // num_shards
        shards = [indices[i*shard_size:(i+1)*shard_size] for i in range(num_shards)]

        # Create subsets for each shard
        shard_datasets = [Subset(self.original_train_set, shard) for shard in shards]

        # Create data loaders for each shard
        shard_loaders = [DataLoader(shard_dataset, batch_size=self.batch_size, shuffle=True) for shard_dataset in shard_datasets]

        # Save each shard dataset and its corresponding data loader
        for i, shard_loader in enumerate(shard_loaders):
            self.save_data(shard_loader, f'iid_sharding_{num_shards}_chunk_{i+1}')

        return shard_loaders


In [92]:
data = CIFAR100Data()
train_loader, validation_loader, test_loader = data.train_valid_test(validation_ratio=0.2)
original_train_loader, original_test_loader = data.train_test()

### Define the model architecture

In [93]:
# Define the model architecture
# Check if it is LeNet-5 or similar to LeNet-5 we want similar.
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=5, padding=0) # 28x28
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, padding=0) # 10x10
        self.pool = nn.MaxPool2d(2, 2) # 14x14 for conv1 and 5x5 for conv2
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(5 * 5 * 64, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x, indexing=None):
        intermediate_outputs = {}

        # Convolutional and pooling layers
        x = F.relu(self.conv1(x))
        if indexing == 'conv1':
            return x
        intermediate_outputs['conv1'] = x

        x = self.pool(x)
        if indexing == 'pool1':
            return x
        intermediate_outputs['pool1'] = x

        x = F.relu(self.conv2(x))
        if indexing == 'conv2':
            return x
        intermediate_outputs['conv2'] = x

        x = self.pool(x)
        if indexing == 'pool2':
            return x
        intermediate_outputs['pool2'] = x

        # Flatten layer
        x = self.flatten(x)
        if indexing == 'flatten':
            return x
        intermediate_outputs['flatten'] = x

        # Fully connected layers
        x = F.relu(self.fc1(x))
        if indexing == 'fc1':
            return x
        intermediate_outputs['fc1'] = x

        x = F.relu(self.fc2(x))
        if indexing == 'fc2':
            return x
        intermediate_outputs['fc2'] = x

        x = self.fc3(x)
        if indexing == 'fc3':
            return x
        intermediate_outputs['fc3'] = x

        # If no indexing, return final output
        return x if indexing is None else intermediate_outputs.get(indexing, x)

In [94]:
# Define loss function
criterion = nn.CrossEntropyLoss()

### Define some basic functions for train and test the model

In [95]:
def train(model, dataloader, optimizer, loss_fn, accumulation_steps=1, device=device, is_wandb=False):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    optimizer.zero_grad()  # Initialize gradients to zero at the start of each epoch

    for i, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        loss.backward()  # Backpropagation

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Gradient accumulation
        if (i+1) % accumulation_steps == 0 or i+1 == len(dataloader):
            optimizer.step()  # Update model parameters
            optimizer.zero_grad()  # Reset gradients to zero

    train_loss = running_loss / len(dataloader)
    train_accuracy = 100. * correct / total

    if is_wandb:
        wandb.log({"Train Loss": train_loss, "Train Accuracy": train_accuracy})

    return train_loss, train_accuracy

In [96]:
def test(model, dataloader, loss_fn, device = device, is_wandb= False):

    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if is_wandb:
              # Log the loss and accuracy values at each step
              wandb.log({
                  'Test Loss': test_loss / (batch_idx + 1),
                  'Test Accuracy': 100 * correct / total
              })

    test_loss = test_loss / len(dataloader)
    test_accuracy = 100. * correct / total

    return test_loss, test_accuracy

In [97]:
def save_checkpoint(state, epoch, batch_size, optimizer_name, hyperparameters):
    # dir_path = f"/content/gdrive/MyDrive/{optimizer_name}/{batch_size}/"
    dir_path = f"./train/{optimizer_name}/{batch_size}/"
    for key, value in hyperparameters.items():
        dir_path = os.path.join(dir_path, f"{key}_{value}/")
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)
    path = os.path.join(dir_path, f"epoch_{epoch:03}.pt")
    torch.save(state, path)

    # Get list of all files
    list_of_files = glob.glob(os.path.join(dir_path, f"epoch_*.pt"))
    # Sort files by creation time
    list_of_files.sort(key=os.path.getctime)
    # If there are more than 2 files, delete the second last one
    if len(list_of_files) > 1:
        os.remove(list_of_files[-2])

def load_checkpoint(optimizer_name, batch_size, hyperparameters):
    # dir_path = f"/content/gdrive/MyDrive/{optimizer_name}/{batch_size}/"
    dir_path = f"./train/{optimizer_name}/{batch_size}/"
    for key, value in hyperparameters.items():
        dir_path = os.path.join(dir_path, f"{key}_{value}/")
    list_of_files = glob.glob(os.path.join(dir_path, f"epoch_*.pt")) # * means all if need specific format then *.csv
    if not list_of_files:  # I'm using glob which can return an empty list
        return None
    latest_file = max(list_of_files, key=os.path.getctime)
    if os.path.isfile(latest_file):
        return torch.load(latest_file)
    return None

In [98]:
def run_training(
    num_epochs,
    model,
    trainloader,
    validationloader,
    testloader,
    optimizer,
    scheduler,
    loss_fn,
    device,
    optimizer_name: str,
    accumulation_steps=1,
    hyperparameters=None,
    is_wandb = False,
    n_epochs_stop = None,
    warmup_ratio = 0
  ):

  best_accuracy = 0
  epochs_no_improve = 0
  warmup_steps = int(warmup_ratio * num_epochs)
  lr = optimizer.param_groups[0]['lr']
  
  start_epoch = 0
  run_id = None
  run_name = None
  # Load checkpoint if available
  checkpoint = load_checkpoint(optimizer_name, trainloader.batch_size, hyperparameters)
  if checkpoint is not None:
      model.load_state_dict(checkpoint['model_state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
      start_epoch = checkpoint['epoch'] + 1
      run_id = checkpoint.get('wandb_run_id', None)
      run_name = checkpoint.get('wandb_run_name', None)

  else:
    run_name = " ".join([f"{key}={value}" for key, value in hyperparameters.items()])

  if is_wandb:
    # Initialize a wandb run with the given hyperparameters
    wandb.init(id=run_id, name=run_name, project=f'cifar100-training-mldl2024-baseline-{optimizer_name}',
                   config=hyperparameters if hyperparameters is not None else {},
                   resume="allow", reinit=True)

  for epoch in range(start_epoch, num_epochs):
      if epoch < warmup_steps:
        optimizer.param_groups[0]['lr'] = lr * epoch / warmup_steps
      elif epoch == warmup_steps:
        optimizer.param_groups[0]['lr'] = lr
          
      # Call the training function for each epoch
      train_loss, train_acc = train(model, trainloader, optimizer, loss_fn, accumulation_steps, device, is_wandb=is_wandb)
      print(f'[{epoch+1}/{num_epochs}]: Training Loss: {train_loss}, Training Accuracy: {train_acc}')
      scheduler.step() # Update learning rate based on scheduler

      # Save checkpoint
      save_checkpoint({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': criterion,
                    'wandb_run_id': wandb.run.id if is_wandb else None,
                    'wandb_run_name': wandb.run.name if is_wandb else None,
                    }, epoch, trainloader.batch_size, optimizer_name, hyperparameters)

      if validationloader is not None:
        val_loss, val_acc = test(model, validationloader, criterion)
        print(f'Validation Loss: {val_loss}, Validation Accuracy: {val_acc}')


        if n_epochs_stop is not None:
          if val_acc > best_accuracy:
                best_accuracy = val_acc
                epochs_no_improve = 0
          else:
              epochs_no_improve += 1
              if epochs_no_improve == n_epochs_stop:
                  print('Early stopping!')
                  break

  print('*'*70)
  test_loss, test_acc = test(model, testloader, criterion, is_wandb = is_wandb)
  print(f'Test Loss: {test_loss}, Test Accuracy: {test_acc}')


  # Finish the wandb run after all epochs
  wandb.finish()

# **Centeralised baseline**

In [70]:
learning_rates = [1e-03, 1e-02]
weight_decays = [1e-04, 1e-03, 4e-04]


#### SGDM (Stochastic Gradient Descent with Momentum)

##### Hyperparameter Tuning

In [60]:
num_epochs = 150

for lr in learning_rates:
  for wd in weight_decays:

    print('='*50)
    print(f'Hyperparameter with lr:{lr} and wd:{wd}')
    print('='*50)

    hyperparameters = {'learning_rate': lr,
                       'weight_decay' : wd
                       }
    # Load the model
    model = LeNet5().to(device)

    # Optimizer and scheduler setup
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


    # Training
    run_training(num_epochs, model, train_loader, validation_loader, test_loader, optimizer, scheduler, criterion, device, 'SGDM-HyperParameterTuning', hyperparameters=hyperparameters, is_wandb = True, n_epochs_stop = 10)



Hyperparameter with lr:0.001 and wd:0.0001


[10/150]: Training Loss: 3.7955315692901612, Training Accuracy: 12.3825
Validation Loss: 3.7704276780413974, Validation Accuracy: 12.97
[11/150]: Training Loss: 3.693796951675415, Training Accuracy: 14.2675
Validation Loss: 3.6661312990127857, Validation Accuracy: 14.76
[12/150]: Training Loss: 3.5919596023559572, Training Accuracy: 16.0125
Validation Loss: 3.5912161389733575, Validation Accuracy: 16.26
[13/150]: Training Loss: 3.5116733703613283, Training Accuracy: 17.3775
Validation Loss: 3.5196323546634356, Validation Accuracy: 16.95
[14/150]: Training Loss: 3.4402521675109865, Training Accuracy: 18.835
Validation Loss: 3.4998342930131656, Validation Accuracy: 17.78
[15/150]: Training Loss: 3.374294019317627, Training Accuracy: 19.835
Validation Loss: 3.412280123704558, Validation Accuracy: 19.05
[16/150]: Training Loss: 3.3126377914428713, Training Accuracy: 20.9225
Validation Loss: 3.3567575497232425, Validation Accuracy: 20.21
[17/150]: Training Loss: 3.264142190551758, Training 

0,1
Test Accuracy,██▁▂▂▂▂▁▁▁▁▂▂▂▃▃▃▃▃▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
Test Loss,▁▄▆▇▇▇███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
Train Accuracy,▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
Train Loss,██▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁

0,1
Test Accuracy,12.53
Test Loss,12.75723
Train Accuracy,59.32
Train Loss,1.48869


Hyperparameter with lr:0.001 and wd:0.001


[1/150]: Training Loss: 4.605354942321777, Training Accuracy: 1.0075
Validation Loss: 4.604883777108162, Validation Accuracy: 0.9
[2/150]: Training Loss: 4.602973918151855, Training Accuracy: 1.0475
Validation Loss: 4.6016397293965525, Validation Accuracy: 0.9
[3/150]: Training Loss: 4.596875128936768, Training Accuracy: 1.3425
Validation Loss: 4.590152442834939, Validation Accuracy: 2.52
[4/150]: Training Loss: 4.556022624969483, Training Accuracy: 2.9175
Validation Loss: 4.472016604842653, Validation Accuracy: 3.55
[5/150]: Training Loss: 4.280242394256592, Training Accuracy: 4.6625
Validation Loss: 4.20257385673037, Validation Accuracy: 5.95
[6/150]: Training Loss: 4.122820203781128, Training Accuracy: 6.5925
Validation Loss: 4.098185727550725, Validation Accuracy: 6.72
[7/150]: Training Loss: 4.0448949375152585, Training Accuracy: 7.655
Validation Loss: 4.033748802865387, Validation Accuracy: 8.54
[8/150]: Training Loss: 3.977578921508789, Training Accuracy: 9.06
Validation Loss: 3

0,1
Test Accuracy,▄█▂▁▂▁▂▂▁▁▁▂▂▂▂▂▃▂▃▂▂▃▃▃▃▃▃▄▄▃▄▄▃▃▄▃▃▃▄▃
Test Loss,▃▁▄▄▆▇▇█▇▇▇▇▇▇█▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
Train Accuracy,▁▁▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
Train Loss,███▇▇▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,12.36
Test Loss,15.7624
Train Accuracy,67.3425
Train Loss,1.17824


Hyperparameter with lr:0.001 and wd:0.0004


[1/150]: Training Loss: 4.605505518341064, Training Accuracy: 1.0675
Validation Loss: 4.60469646514601, Validation Accuracy: 0.9
[2/150]: Training Loss: 4.603742761993408, Training Accuracy: 1.0775
Validation Loss: 4.602513228252435, Validation Accuracy: 0.97
[3/150]: Training Loss: 4.60020831451416, Training Accuracy: 1.42
Validation Loss: 4.596951505940432, Validation Accuracy: 1.39
[4/150]: Training Loss: 4.590583048248291, Training Accuracy: 1.345
Validation Loss: 4.579934718502555, Validation Accuracy: 1.81
[5/150]: Training Loss: 4.522595316314697, Training Accuracy: 2.9
Validation Loss: 4.390095902096694, Validation Accuracy: 3.43
[6/150]: Training Loss: 4.228593465805054, Training Accuracy: 5.2625
Validation Loss: 4.173950313762495, Validation Accuracy: 5.48
[7/150]: Training Loss: 4.1090137573242185, Training Accuracy: 6.88
Validation Loss: 4.087003700292794, Validation Accuracy: 7.44
[8/150]: Training Loss: 4.035317427825928, Training Accuracy: 8.06
Validation Loss: 4.0155219

0,1
Test Accuracy,█▆▁▃▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▂▃▃▃▃▃▃▃▃▃▄▄▃▃▄▄▄▄▄▄
Test Loss,▂▁▅▄▆▆▆▇▇▆▆▇█▇██▇█████▇▇▇███████████████
Train Accuracy,▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
Train Loss,███▇▇▇▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,11.15
Test Loss,17.28562
Train Accuracy,66.87
Train Loss,1.17442


Hyperparameter with lr:0.01 and wd:0.0001


[1/150]: Training Loss: 4.394053651046753, Training Accuracy: 3.235
Validation Loss: 4.1179144488778086, Validation Accuracy: 6.51
[2/150]: Training Loss: 3.940752759170532, Training Accuracy: 9.1825
Validation Loss: 3.793641860318032, Validation Accuracy: 10.96
[3/150]: Training Loss: 3.607163610458374, Training Accuracy: 14.85
Validation Loss: 3.445422892357893, Validation Accuracy: 18.4
[4/150]: Training Loss: 3.32619926071167, Training Accuracy: 19.315
Validation Loss: 3.202780070578217, Validation Accuracy: 21.55
[5/150]: Training Loss: 3.0974891372680666, Training Accuracy: 23.9725
Validation Loss: 3.0388081802684033, Validation Accuracy: 24.97
[6/150]: Training Loss: 2.901975968170166, Training Accuracy: 27.42
Validation Loss: 3.0498315498327755, Validation Accuracy: 25.89
[7/150]: Training Loss: 2.72861491394043, Training Accuracy: 31.1425
Validation Loss: 2.817266490049423, Validation Accuracy: 30.26
[8/150]: Training Loss: 2.565713974761963, Training Accuracy: 33.8575
Validat

0,1
Test Accuracy,▁▂▃█▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███████
Test Loss,▁▅█▇█▇██▇▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███▇▇▇▇▇▇▇▇▇▇▇▇
Train Accuracy,▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇███
Train Loss,█▇▇▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁

0,1
Test Accuracy,14.4
Test Loss,31.79171
Train Accuracy,81.7975
Train Loss,0.58307


Hyperparameter with lr:0.01 and wd:0.001


[1/150]: Training Loss: 4.452998904800415, Training Accuracy: 2.615
Validation Loss: 4.1346032179085315, Validation Accuracy: 5.28
[2/150]: Training Loss: 3.9527258255004885, Training Accuracy: 8.88
Validation Loss: 3.7435999432946465, Validation Accuracy: 11.55
[3/150]: Training Loss: 3.6381911922454835, Training Accuracy: 14.3725
Validation Loss: 3.50478922181828, Validation Accuracy: 15.58
[4/150]: Training Loss: 3.3785652015686036, Training Accuracy: 18.865
Validation Loss: 3.3451961089091697, Validation Accuracy: 19.0
[5/150]: Training Loss: 3.17810930519104, Training Accuracy: 22.375
Validation Loss: 3.1223374703887163, Validation Accuracy: 23.6
[6/150]: Training Loss: 3.0013810291290284, Training Accuracy: 25.565
Validation Loss: 3.015145599462424, Validation Accuracy: 25.56
[7/150]: Training Loss: 2.8544215950012206, Training Accuracy: 28.7175
Validation Loss: 2.917889818264421, Validation Accuracy: 28.04
[8/150]: Training Loss: 2.6971351528167724, Training Accuracy: 31.3825
Va

0,1
Test Accuracy,▄█▄▅▃▄▄▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test Loss,▁▂▆▄▃▅▆▇▇▆▆▇▇▇███▇▇▇██▇▇█▇██████████████
Train Accuracy,▁▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇███
Train Loss,█▇▆▆▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▁▁▁▁

0,1
Test Accuracy,18.45
Test Loss,17.25387
Train Accuracy,77.925
Train Loss,0.72145


Hyperparameter with lr:0.01 and wd:0.0004


[1/150]: Training Loss: 4.436210793685913, Training Accuracy: 2.8025
Validation Loss: 4.1512272403498365, Validation Accuracy: 5.38
[2/150]: Training Loss: 3.943741687011719, Training Accuracy: 9.175
Validation Loss: 3.791590502307673, Validation Accuracy: 11.54
[3/150]: Training Loss: 3.6135140613555907, Training Accuracy: 14.57
Validation Loss: 3.473675949558331, Validation Accuracy: 17.5
[4/150]: Training Loss: 3.3465395004272462, Training Accuracy: 19.22
Validation Loss: 3.3301091406755385, Validation Accuracy: 19.19
[5/150]: Training Loss: 3.1310076709747316, Training Accuracy: 22.94
Validation Loss: 3.0678081117617855, Validation Accuracy: 24.65
[6/150]: Training Loss: 2.9402838905334474, Training Accuracy: 26.4875
Validation Loss: 2.9766739310732313, Validation Accuracy: 25.68
[7/150]: Training Loss: 2.7704892574310302, Training Accuracy: 29.855
Validation Loss: 2.8771827342403924, Validation Accuracy: 28.79
[8/150]: Training Loss: 2.6034393648147582, Training Accuracy: 33.84
Va

0,1
Test Accuracy,█▆▁▅▆▇█▆▆▆▇▇▆▇▇▇▇▇████████▇██▇▇▇▇███████
Test Loss,▁▄▆▅▅▅▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████
Train Accuracy,▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇███
Train Loss,█▇▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁

0,1
Test Accuracy,17.1
Test Loss,24.73122
Train Accuracy,80.77
Train Loss,0.6179


##### Training Using Best Hyperparameters Set

In [99]:
num_epochs = 150
lr = 1e-02
wd = 1e-03

hyperparameters = {'learning_rate': lr,
                  'weight_decay' : wd}

# Load the model
model_0 = LeNet5().to(device)

# Optimizer and scheduler setup
optimizer_0 = torch.optim.SGD(model_0.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_0, T_max=num_epochs)


# Training
run_training(num_epochs, model_0, original_train_loader, original_test_loader, original_test_loader, optimizer_0, scheduler, criterion, device, optimizer_name='SGDM', hyperparameters=hyperparameters, is_wandb=True)


**********************************************************************
Test Loss: 1.8986814637093028, Test Accuracy: 57.55


0,1
Test Accuracy,▆█▄▃▄▃▂▁▂▂▂▂▁▁▁▁▂▂▂▁▁▁▁▂▁▂▁▂▂▂▂▁▂▂▂▂▂▂▂▂
Test Loss,█▁▄▅▅▅▆▇▅▆▅▄▅▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▄

0,1
Test Accuracy,57.55
Test Loss,1.89868
Train Accuracy,87.314
Train Loss,0.44531


#### AdamW (Adam with Weight Decay)

##### Hyperparameter Tuning

In [71]:
num_epochs = 150

for lr in learning_rates:
  for wd in weight_decays:

    print('='*50)
    print(f'Hyperparameter with lr:{lr} and wd:{wd}')
    print('='*50)

    hyperparameters = {'learning_rate': lr,
                       'weight_decay' : wd
                       }
    # Load the model
    model = LeNet5().to(device)

    # Optimizer and scheduler setup
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


    # Training
    run_training(num_epochs, model, train_loader, validation_loader, test_loader, optimizer, scheduler, criterion, device, 'AdamW-HyperParameterTuning', hyperparameters=hyperparameters, is_wandb = True, n_epochs_stop = 10)



Hyperparameter with lr:0.001 and wd:0.0001


0,1
Train Accuracy,▁
Train Loss,▁

0,1
Train Accuracy,7.7625
Train Loss,4.01706


[1/150]: Training Loss: 4.032126424407959, Training Accuracy: 7.645
Validation Loss: 3.633875617555752, Validation Accuracy: 13.81
[2/150]: Training Loss: 3.4117616390228274, Training Accuracy: 17.9375
Validation Loss: 3.291252329091358, Validation Accuracy: 19.57
[3/150]: Training Loss: 3.0980306770324706, Training Accuracy: 23.7325
Validation Loss: 3.0133402590539045, Validation Accuracy: 25.49
[4/150]: Training Loss: 2.875119793319702, Training Accuracy: 27.905
Validation Loss: 2.9483105938905365, Validation Accuracy: 26.77
[5/150]: Training Loss: 2.6911334384918213, Training Accuracy: 31.615
Validation Loss: 2.851942961383018, Validation Accuracy: 28.77
[6/150]: Training Loss: 2.5383362628936768, Training Accuracy: 34.6425
Validation Loss: 2.8010447116414454, Validation Accuracy: 30.38
[7/150]: Training Loss: 2.3951899276733397, Training Accuracy: 37.7375
Validation Loss: 2.772052654035532, Validation Accuracy: 31.08
[8/150]: Training Loss: 2.2677825828552245, Training Accuracy: 40

0,1
Test Accuracy,█▇▃▃▁▁▂▂▂▂▃▂▂▂▂▂▂▂▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Test Loss,▁▆▇▅████████▇█▇▇██▇█▇███▇███▇▇█▇███▇▇▇▇▇
Train Accuracy,▁▂▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇███
Train Loss,█▇▆▅▅▄▄▄▄▃▃▃▂▂▂▂▂▁▁▁

0,1
Test Accuracy,3.39
Test Loss,198.79228
Train Accuracy,68.56
Train Loss,1.0573


Hyperparameter with lr:0.001 and wd:0.001


[1/150]: Training Loss: 4.108658624267578, Training Accuracy: 6.6075
Validation Loss: 3.7040131304674087, Validation Accuracy: 12.1
[2/150]: Training Loss: 3.5060665573120118, Training Accuracy: 16.015
Validation Loss: 3.3598362321306947, Validation Accuracy: 18.42
[3/150]: Training Loss: 3.1865245040893555, Training Accuracy: 21.7275
Validation Loss: 3.128495198146553, Validation Accuracy: 22.84
[4/150]: Training Loss: 2.9679495239257814, Training Accuracy: 25.8525
Validation Loss: 2.9615708293428846, Validation Accuracy: 26.83
[5/150]: Training Loss: 2.797335900115967, Training Accuracy: 29.515
Validation Loss: 2.9264626108157406, Validation Accuracy: 26.96
[6/150]: Training Loss: 2.6617756172180176, Training Accuracy: 32.13
Validation Loss: 2.8755724490827816, Validation Accuracy: 28.35
[7/150]: Training Loss: 2.54268462638855, Training Accuracy: 34.435
Validation Loss: 2.7718013547788, Validation Accuracy: 30.25
[8/150]: Training Loss: 2.424749851608276, Training Accuracy: 37.1025


0,1
Test Accuracy,█▅▁▂▃▂▃▃▃▃▃▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
Test Loss,▇▇▁▂▁▄▄▇█▆▃▄▅▆▅▅▆▆▆▆▅▇▇▆▆▇▆▆▅▆▆▆▅▆▅▅▆▆▅▅
Train Accuracy,▁▂▃▃▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇▇███
Train Loss,█▇▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,3.33
Test Loss,156.7261
Train Accuracy,64.2825
Train Loss,1.2339


Hyperparameter with lr:0.001 and wd:0.0004


[1/150]: Training Loss: 4.053023485183716, Training Accuracy: 7.3
Validation Loss: 3.71551227873298, Validation Accuracy: 12.6
[2/150]: Training Loss: 3.4803551902770997, Training Accuracy: 16.4825
Validation Loss: 3.3104240119836894, Validation Accuracy: 19.6
[3/150]: Training Loss: 3.1457452793121337, Training Accuracy: 22.71
Validation Loss: 3.1647931253834134, Validation Accuracy: 22.36
[4/150]: Training Loss: 2.9213612785339356, Training Accuracy: 26.8375
Validation Loss: 2.967867093481076, Validation Accuracy: 26.13
[5/150]: Training Loss: 2.7346919048309326, Training Accuracy: 30.745
Validation Loss: 2.870804744161618, Validation Accuracy: 28.43
[6/150]: Training Loss: 2.580346668434143, Training Accuracy: 33.8975
Validation Loss: 2.813369401700937, Validation Accuracy: 29.97
[7/150]: Training Loss: 2.4525743492126466, Training Accuracy: 36.445
Validation Loss: 2.8111886962963517, Validation Accuracy: 31.76
[8/150]: Training Loss: 2.3342738655090334, Training Accuracy: 38.8
Vali

0,1
Test Accuracy,█▁▂▇▃▁▂▁▂▂▂▃▃▃▄▄▄▄▃▃▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
Test Loss,▁█▅▃▃▃▃▄▅▄▃▂▂▃▃▂▄▄▄▄▃▃▃▃▃▃▃▃▃▂▃▃▃▃▃▃▃▂▂▂
Train Accuracy,▁▂▃▃▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
Train Loss,█▇▆▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁

0,1
Test Accuracy,5.39
Test Loss,67.07529
Train Accuracy,63.035
Train Loss,1.27527


Hyperparameter with lr:0.01 and wd:0.0001


[1/150]: Training Loss: 4.622148120880127, Training Accuracy: 0.935
Validation Loss: 4.609954472560032, Validation Accuracy: 0.99
[2/150]: Training Loss: 4.609106838226318, Training Accuracy: 0.9475
Validation Loss: 4.611106529357327, Validation Accuracy: 0.9
[3/150]: Training Loss: 4.608888108062744, Training Accuracy: 0.9925
Validation Loss: 4.608767968074531, Validation Accuracy: 0.84
[4/150]: Training Loss: 4.608837638854981, Training Accuracy: 0.935
Validation Loss: 4.610516976399027, Validation Accuracy: 0.89
[5/150]: Training Loss: 4.608827792358398, Training Accuracy: 0.9775
Validation Loss: 4.610510003035236, Validation Accuracy: 0.9
[6/150]: Training Loss: 4.608845700073243, Training Accuracy: 0.935
Validation Loss: 4.609803257474474, Validation Accuracy: 0.83
[7/150]: Training Loss: 4.608741146087646, Training Accuracy: 0.9625
Validation Loss: 4.60897787665106, Validation Accuracy: 0.89
[8/150]: Training Loss: 4.608720026397705, Training Accuracy: 0.9775
Validation Loss: 4.6

0,1
Test Accuracy,▁▃▃▆▆▆▆▇▇▇▇▇▇▇▇█▇▇▇██▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
Test Loss,▅▄█▅▅▄▂▄▅▃▃▃▃▄▃▃▃▃▃▃▂▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁
Train Accuracy,▅▅▇▅▇▅▆▇█▄▁
Train Loss,█▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,1.01
Test Loss,19.34399
Train Accuracy,0.8625
Train Loss,4.60907


Hyperparameter with lr:0.01 and wd:0.001


[1/150]: Training Loss: 4.625637398529053, Training Accuracy: 0.885
Validation Loss: 4.60909015509733, Validation Accuracy: 1.06
[2/150]: Training Loss: 4.608977200317383, Training Accuracy: 0.94
Validation Loss: 4.609363568056921, Validation Accuracy: 0.92
[3/150]: Training Loss: 4.608765270996094, Training Accuracy: 0.9275
Validation Loss: 4.610460181145152, Validation Accuracy: 0.9
[4/150]: Training Loss: 4.609055269622803, Training Accuracy: 0.855
Validation Loss: 4.609649321076217, Validation Accuracy: 0.91
[5/150]: Training Loss: 4.608893507385254, Training Accuracy: 0.965
Validation Loss: 4.610334041012321, Validation Accuracy: 1.15
[6/150]: Training Loss: 4.609225297546387, Training Accuracy: 1.0125
Validation Loss: 4.608022522774472, Validation Accuracy: 1.16
[7/150]: Training Loss: 4.609016616821289, Training Accuracy: 0.9875
Validation Loss: 4.609603547746209, Validation Accuracy: 0.82
[8/150]: Training Loss: 4.609039859008789, Training Accuracy: 1.0025
Validation Loss: 4.60

0,1
Test Accuracy,▁▆▅▇▇███▇▇▇▆▆▆▆▆▇▇▇▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
Test Loss,▁▇█▆▅▆▆▇█▇▇▇▇▇▇▇▇▇▇█▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
Train Accuracy,▂▄▄▁▅▇▆▇▆▅▅█▄▂▄▆
Train Loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,1.0
Test Loss,51.00808
Train Accuracy,0.9875
Train Loss,4.60864


Hyperparameter with lr:0.01 and wd:0.0004


[1/150]: Training Loss: 4.626314550018311, Training Accuracy: 0.9725
Validation Loss: 4.610879454643103, Validation Accuracy: 0.88
[2/150]: Training Loss: 4.609021481323242, Training Accuracy: 0.9225
Validation Loss: 4.6096283493527945, Validation Accuracy: 1.0
[3/150]: Training Loss: 4.608779692077637, Training Accuracy: 0.9575
Validation Loss: 4.610723838684665, Validation Accuracy: 1.07
[4/150]: Training Loss: 4.609119167327881, Training Accuracy: 0.9125
Validation Loss: 4.60863508236636, Validation Accuracy: 0.89
[5/150]: Training Loss: 4.608814185333252, Training Accuracy: 0.95
Validation Loss: 4.611634175488903, Validation Accuracy: 0.93
[6/150]: Training Loss: 4.6092648048400875, Training Accuracy: 0.8575
Validation Loss: 4.610189516832874, Validation Accuracy: 0.92
[7/150]: Training Loss: 4.6090187705993655, Training Accuracy: 1.005
Validation Loss: 4.610130288798338, Validation Accuracy: 0.88
[8/150]: Training Loss: 4.608886881256104, Training Accuracy: 1.01
Validation Loss: 4

0,1
Test Accuracy,▃█▃▁▁▁▃▄▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
Test Loss,▁▅▇█▆▆▆▆▇▆▆▆▆▆▆▆▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇
Train Accuracy,▆▄▅▃▅▁▇▇▇▇█▃▄
Train Loss,█▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,1.08
Test Loss,7.14208
Train Accuracy,0.9225
Train Loss,4.60871


##### Training Using Best Hyperparameters Set

In [100]:
num_epochs = 150
lr = 1e-03
wd = 4e-04

hyperparameters = {'learning_rate': lr,
                    'weight_decay' : wd
                  }

# Load the model
model_1 = LeNet5().to(device)

# Optimizer and scheduler setup
optimizer_1 = torch.optim.AdamW(model_1.parameters(), lr=lr, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_1, T_max=num_epochs)


# Training
run_training(num_epochs, model_1, original_train_loader, original_test_loader, original_test_loader, optimizer_1, scheduler, criterion, device, optimizer_name='AdamW', hyperparameters=hyperparameters, is_wandb = True)

**********************************************************************
Test Loss: 2.113396340115055, Test Accuracy: 49.46


0,1
Test Accuracy,█▅▃▂▁▂▁▁▂▁▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁
Test Loss,▁▃▆▃█▆▅▇▅▅▃▃▃▄▄▃▃▃▄▅▆▅▅▄▄▅▅▅▄▄▄▄▄▄▄▄▄▄▃▃

0,1
Test Accuracy,49.46
Test Loss,2.1134
Train Accuracy,65.458
Train Loss,1.1963


# **Large Batch Optimizers**


#### LARS

In [13]:



class LARS(Optimizer):
    """
    Implements LARS (Layer-wise Adaptive Rate Scaling).

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
        lr (float): learning rate (default: 1e-3)
        momentum (float, optional): momentum factor (default: 0)
        trust_coef (float, optional): LARS coefficient as used in the paper (default: 1e-3)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        epsilon (float, optional): epsilon to prevent zero division (default: 0)
    """

    def __init__(
            self,
            params,
            lr: float = 1e-3,
            momentum: float = 0,
            trust_coef: float = 1e-3,
            dampening: float = 0,
            weight_decay: float = 0,
            nesterov=False,
            epsilon: float = 1e-9
    ):
        """
        Initializes a new instance of the LARS optimizer.

        Args:
            params: iterable of parameters to optimize or dicts defining
            lr: learning rate
            momentum: momentum factor
            trust_coef: LARS coefficient as used in the paper
            weight_decay: weight decay (L2 penalty)
            dampening: dampening for momentum
            nesterov: enables Nesterov momentum
            epsilon: epsilon to prevent zero division
        """

        if lr <= 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            trust_coef=trust_coef,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            epsilon=epsilon)

        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(LARS, self).__init__(params, defaults)

    def __setstate__(self, state):
        """
        Sets the state of the optimizer.

        Args:
            state: The state to set the optimizer to.
        """
        super(LARS, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def _compute_local_lr(self, p, weight_decay, trust_coef, epsilon):
        """
        Computes the local learning rate for a given parameter.

        Args:
            p: The parameter to compute the local learning rate for.
            weight_decay: The weight decay factor.
            trust_coef: The trust coefficient.
            epsilon: A small constant for numerical stability.

        Returns:
            float: The computed local learning rate.
        """
        w_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)
        if w_norm * g_norm > 0:
            return trust_coef * w_norm / (g_norm + weight_decay * w_norm + epsilon)
        else:
            return 1

    def _update_params(self, p, d_p, local_lr, lr, momentum, buf,
                       dampening, nesterov, weight_decay):
        """
        Updates the parameters with the computed update.

        Args:
            p: The parameter to be updated.
            d_p: The computed update for the parameter.
            local_lr: The local learning rate.
            lr: The global learning rate.
            momentum: The momentum factor.
            buf: The buffer for the momentum.
            dampening: The dampening for the momentum.
            nesterov: A flag indicating whether to use Nesterov momentum.
            weight_decay: The weight decay factor.
        """
        if weight_decay != 0:
            d_p.add_(weight_decay, p.data)
        if momentum != 0:
            param_state = self.state[p]
            if 'momentum_buffer' not in param_state:
                buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
            else:
                buf = param_state['momentum_buffer']
            buf.mul_(momentum).add_(1 - dampening, d_p)
            if nesterov:
                d_p = d_p.add(momentum, buf)
            else:
                d_p = buf

        p.data.add_(-local_lr * lr, d_p)

    def step(self):
        """Performs a single optimization step.
        """

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            trust_coef = group['trust_coef']
            dampening = group['dampening']
            nesterov = group['nesterov']
            epsilon = group['epsilon']

            for p in group['params']:
                if p.grad is None:
                    continue
                local_lr = self._compute_local_lr(p, weight_decay, trust_coef, epsilon)
                self._update_params(p, p.grad.data, local_lr, group['lr'], momentum, None, dampening, nesterov, weight_decay)

        return None

In [78]:
learning_rates = [1e-02, 5e-02, 1e-01, 5e-01, 1, 1.5, 2]
wd = 1e-03

##### LARS Hyperparameter Tunning

In [67]:
num_epochs = 150

for lr in learning_rates:

  print('='*50)
  print(f'Hyperparameter with lr:{lr} and wd:{wd}')
  print('='*50)

  hyperparameters = {'learning_rate': lr,
                      'weight_decay' : wd
                      }
  # Load the model
  model = LeNet5().to(device)

  # Optimizer and scheduler setup
  optimizer = LARS(model.parameters(), lr=lr, weight_decay=wd, momentum=0.9)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


  # Training
  run_training(num_epochs,
                model,
                train_loader,
                validation_loader,
                test_loader,
                optimizer,
                scheduler,
                criterion,
                device,
                'LARS-HyperParameterTuning',
                hyperparameters=hyperparameters,
                is_wandb = True,
                n_epochs_stop = 10
  )




Hyperparameter with lr:0.01 and wd:0.001


[1/150]: Training Loss: 4.605613048553467, Training Accuracy: 1.0325
Validation Loss: 4.605481047539195, Validation Accuracy: 1.01
[2/150]: Training Loss: 4.603893094635009, Training Accuracy: 1.2825
Validation Loss: 4.603514659176966, Validation Accuracy: 1.37
[3/150]: Training Loss: 4.601452378845215, Training Accuracy: 1.7625
Validation Loss: 4.600479241389378, Validation Accuracy: 1.96
[4/150]: Training Loss: 4.597398109436035, Training Accuracy: 1.815
Validation Loss: 4.595512341541849, Validation Accuracy: 1.41
[5/150]: Training Loss: 4.590549831390381, Training Accuracy: 1.5175
Validation Loss: 4.587287993947411, Validation Accuracy: 1.68
[6/150]: Training Loss: 4.579249038696289, Training Accuracy: 1.93
Validation Loss: 4.573939116897097, Validation Accuracy: 2.37
[7/150]: Training Loss: 4.560530513763427, Training Accuracy: 2.69
Validation Loss: 4.5522112117451465, Validation Accuracy: 2.64
[8/150]: Training Loss: 4.5309873321533205, Training Accuracy: 2.92
Validation Loss: 4.

0,1
Test Accuracy,█▇▂▂▂▂▂▁▁▁▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Test Loss,▁▁▇██▇▇█▇▇▆▇▇▇▇▆▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇
Train Accuracy,▁▁▂▂▃▃▄▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████████
Train Loss,████▇▆▅▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,5.7
Test Loss,4.83727
Train Accuracy,9.7625
Train Loss,3.99084


Hyperparameter with lr:0.05 and wd:0.001


[1/150]: Training Loss: 4.600790646362305, Training Accuracy: 1.1675
Validation Loss: 4.590667490746565, Validation Accuracy: 1.26
[2/150]: Training Loss: 4.559383809661865, Training Accuracy: 2.1575
Validation Loss: 4.506907958133965, Validation Accuracy: 2.54
[3/150]: Training Loss: 4.4036731437683105, Training Accuracy: 3.8525
Validation Loss: 4.321598778864381, Validation Accuracy: 4.84
[4/150]: Training Loss: 4.246265041351318, Training Accuracy: 5.2175
Validation Loss: 4.227905522486207, Validation Accuracy: 6.02
[5/150]: Training Loss: 4.176178216934204, Training Accuracy: 6.1675
Validation Loss: 4.179626170237353, Validation Accuracy: 6.38
[6/150]: Training Loss: 4.135982835769654, Training Accuracy: 6.6675
Validation Loss: 4.148335417364813, Validation Accuracy: 6.64
[7/150]: Training Loss: 4.108450821304321, Training Accuracy: 7.2275
Validation Loss: 4.130975275282648, Validation Accuracy: 6.89
[8/150]: Training Loss: 4.0844139430999755, Training Accuracy: 7.6
Validation Loss

0,1
Test Accuracy,▆█▂▁▂▂▁▁▂▂▃▃▂▂▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Test Loss,▃▁▄▆▆▇██▇▇▆▇▇▇▇▇▇▇▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
Train Accuracy,▁▂▃▃▃▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇████████████████
Train Loss,█▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,10.88
Test Loss,5.46211
Train Accuracy,24.6675
Train Loss,3.16033


Hyperparameter with lr:0.1 and wd:0.001


[1/150]: Training Loss: 4.554393922424317, Training Accuracy: 2.2
Validation Loss: 4.421840394378468, Validation Accuracy: 3.74
[2/150]: Training Loss: 4.284681567382813, Training Accuracy: 4.52
Validation Loss: 4.239770348664302, Validation Accuracy: 4.95
[3/150]: Training Loss: 4.181965134429932, Training Accuracy: 5.585
Validation Loss: 4.184003517126581, Validation Accuracy: 5.53
[4/150]: Training Loss: 4.125413941192627, Training Accuracy: 6.6775
Validation Loss: 4.123335443484556, Validation Accuracy: 6.52
[5/150]: Training Loss: 4.084817845535278, Training Accuracy: 7.4525
Validation Loss: 4.096477279237881, Validation Accuracy: 6.87
[6/150]: Training Loss: 4.048351853179931, Training Accuracy: 8.1125
Validation Loss: 4.069233142646255, Validation Accuracy: 7.8
[7/150]: Training Loss: 4.018422792816162, Training Accuracy: 8.56
Validation Loss: 4.0364193445558, Validation Accuracy: 8.44
[8/150]: Training Loss: 3.989247378540039, Training Accuracy: 8.9625
Validation Loss: 4.010960

0,1
Test Accuracy,▄█▆▄▅▃▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂
Test Loss,▅▁▆▆▇███▇▇▆▇█▇▇▇▇▇▇▇▇▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
Train Accuracy,▁▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█████████
Train Loss,█▇▆▆▆▅▅▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,12.53
Test Loss,6.14717
Train Accuracy,32.1475
Train Loss,2.75643


Hyperparameter with lr:0.5 and wd:0.001


[1/150]: Training Loss: 4.377298123168945, Training Accuracy: 3.2725
Validation Loss: 4.2101050941807445, Validation Accuracy: 5.2
[2/150]: Training Loss: 4.07323479423523, Training Accuracy: 7.05
Validation Loss: 3.997956647994412, Validation Accuracy: 8.69
[3/150]: Training Loss: 3.9166299026489257, Training Accuracy: 9.98
Validation Loss: 3.8700176090191882, Validation Accuracy: 10.85
[4/150]: Training Loss: 3.76255333404541, Training Accuracy: 12.6475
Validation Loss: 3.7763638131937403, Validation Accuracy: 12.03
[5/150]: Training Loss: 3.6385873168945313, Training Accuracy: 14.8
Validation Loss: 3.5995885247637514, Validation Accuracy: 15.3
[6/150]: Training Loss: 3.5313883621215822, Training Accuracy: 16.6625
Validation Loss: 3.547201700271315, Validation Accuracy: 16.33
[7/150]: Training Loss: 3.4311797870635985, Training Accuracy: 18.0425
Validation Loss: 3.447270431336324, Validation Accuracy: 18.49
[8/150]: Training Loss: 3.3406018447875976, Training Accuracy: 20.2325
Valida

0,1
Test Accuracy,█▄▂▂▂▁▂▁▁▁▁▁▁▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
Test Loss,█▁▅▄▄▅▆▇▅▄▃▄▅▄▅▅▄▅▄▅▅▄▄▄▄▄▄▄▄▅▅▅▅▅▄▅▅▅▄▄
Train Accuracy,▁▁▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇████
Train Loss,█▇▇▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁

0,1
Test Accuracy,13.22
Test Loss,16.03359
Train Accuracy,67.9275
Train Loss,1.13058


Hyperparameter with lr:1 and wd:0.001


[1/150]: Training Loss: 4.390103232955933, Training Accuracy: 2.935
Validation Loss: 4.160131826522244, Validation Accuracy: 5.54
[2/150]: Training Loss: 4.031327721786499, Training Accuracy: 7.5025
Validation Loss: 3.9958771960750505, Validation Accuracy: 8.6
[3/150]: Training Loss: 3.800543330383301, Training Accuracy: 11.665
Validation Loss: 3.7379826664165328, Validation Accuracy: 12.06
[4/150]: Training Loss: 3.6242603660583494, Training Accuracy: 14.345
Validation Loss: 3.559469526740396, Validation Accuracy: 15.28
[5/150]: Training Loss: 3.48455090675354, Training Accuracy: 16.9325
Validation Loss: 3.491581111956554, Validation Accuracy: 16.56
[6/150]: Training Loss: 3.3448009601593016, Training Accuracy: 19.435
Validation Loss: 3.3586603653658726, Validation Accuracy: 19.41
[7/150]: Training Loss: 3.229361641693115, Training Accuracy: 21.4375
Validation Loss: 3.2139996130754995, Validation Accuracy: 21.78
[8/150]: Training Loss: 3.1254781677246095, Training Accuracy: 23.4825
Va

0,1
Test Accuracy,█▅▁▂▁▁▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▂▂▂▂▂▂▁▂▂▁▂▂▂▂▂▂▂
Test Loss,▁▂▅▆▇▆▇█▇▇▇▇▇██▇███████▇▇███████████████
Train Accuracy,▁▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Train Loss,█▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,13.77
Test Loss,18.36643
Train Accuracy,65.5075
Train Loss,1.19387


Hyperparameter with lr:1.5 and wd:0.001


[1/150]: Training Loss: 4.426865048217773, Training Accuracy: 2.45
Validation Loss: 4.219578746018136, Validation Accuracy: 4.19
[2/150]: Training Loss: 4.060279036712647, Training Accuracy: 6.7225
Validation Loss: 3.8940811293899635, Validation Accuracy: 9.43
[3/150]: Training Loss: 3.7707011901855467, Training Accuracy: 11.6925
Validation Loss: 3.8030508642743346, Validation Accuracy: 11.0
[4/150]: Training Loss: 3.58028080368042, Training Accuracy: 15.05
Validation Loss: 3.516739834645751, Validation Accuracy: 16.1
[5/150]: Training Loss: 3.4152388591766356, Training Accuracy: 17.82
Validation Loss: 3.411277610025588, Validation Accuracy: 17.62
[6/150]: Training Loss: 3.2448327434539794, Training Accuracy: 20.7725
Validation Loss: 3.2892891753251385, Validation Accuracy: 19.86
[7/150]: Training Loss: 3.1129616333007815, Training Accuracy: 23.41
Validation Loss: 3.155244686041668, Validation Accuracy: 22.75
[8/150]: Training Loss: 2.9936462223052978, Training Accuracy: 25.73
Validati

0,1
Test Accuracy,█▆▁▄▃▄▄▄▄▄▄▄▄▄▄▅▅▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
Test Loss,▁▅▇▇▇▆▇██▇▇█████████████████████████████
Train Accuracy,▁▁▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
Train Loss,█▇▇▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,15.72
Test Loss,20.29688
Train Accuracy,69.42
Train Loss,1.01537


Hyperparameter with lr:2 and wd:0.001


[1/150]: Training Loss: 4.6014143745422365, Training Accuracy: 0.98
Validation Loss: 4.606205284215842, Validation Accuracy: 0.91
[2/150]: Training Loss: 4.605726162719726, Training Accuracy: 0.93
Validation Loss: 4.606422876856129, Validation Accuracy: 0.91
[3/150]: Training Loss: 4.605443458557129, Training Accuracy: 0.9375
Validation Loss: 4.606579792727331, Validation Accuracy: 0.82
[4/150]: Training Loss: 4.6053140991210935, Training Accuracy: 1.015
Validation Loss: 4.606754381945179, Validation Accuracy: 0.82
[5/150]: Training Loss: 4.6052259376525875, Training Accuracy: 1.0275
Validation Loss: 4.60688726765335, Validation Accuracy: 0.82
[6/150]: Training Loss: 4.605191477966309, Training Accuracy: 1.0175
Validation Loss: 4.607019655264107, Validation Accuracy: 0.82
[7/150]: Training Loss: 4.6051658882141115, Training Accuracy: 1.015
Validation Loss: 4.607097844409335, Validation Accuracy: 0.82
[8/150]: Training Loss: 4.605149390411377, Training Accuracy: 0.9775
Validation Loss: 

0,1
Test Accuracy,▁▆███▆▇▇▇▇▇▇▇▇▆▇███▇████████████████████
Test Loss,█▅▁▁▃▃▅▄▄▄▄▄▄▄▄▅▅▅▅▅▆▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▅▅▅
Train Accuracy,▅▁▂▇█▇▇▄▂▇▃
Train Loss,▁██▇▇▇▇▇▇▇▇

0,1
Test Accuracy,1.0
Test Loss,4.60539
Train Accuracy,0.9625
Train Loss,4.60514


##### LARS BaseLine B-Size 64 

In [26]:
num_epochs = 150
lr = 1.5
wd = 1e-03

hyperparameters = {'learning_rate': lr,
                    'weight_decay' : wd
                  }

# Load the model
model = LeNet5().to(device)

# Optimizer and scheduler setup
optimizer = LARS(model.parameters(), lr=lr, weight_decay=wd, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


# Training
run_training(
    num_epochs,
    model,
    original_train_loader,
    original_test_loader,
    original_test_loader,
    optimizer,
    scheduler,
    criterion,
    device,
    optimizer_name='LARS',
    hyperparameters=hyperparameters,
    is_wandb = True
)

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at ../torch/csrc/utils/python_arg_parser.cpp:1578.)
  d_p.add_(weight_decay, p.data)


[1/150]: Training Loss: 4.004822687724667, Training Accuracy: 8.156
Validation Loss: 3.6169370329304105, Validation Accuracy: 15.1
[2/150]: Training Loss: 3.489779775099986, Training Accuracy: 16.288
Validation Loss: 3.195649910884298, Validation Accuracy: 21.71
[3/150]: Training Loss: 3.188335987003258, Training Accuracy: 21.346
Validation Loss: 2.9737777011409685, Validation Accuracy: 26.15
[4/150]: Training Loss: 3.0118914528576006, Training Accuracy: 24.982
Validation Loss: 2.7995891707717995, Validation Accuracy: 29.11
[5/150]: Training Loss: 2.8688232484071152, Training Accuracy: 28.08
Validation Loss: 2.7174989600090464, Validation Accuracy: 30.23
[6/150]: Training Loss: 2.7214920926276984, Training Accuracy: 30.662
Validation Loss: 2.5305145515757763, Validation Accuracy: 34.82
[7/150]: Training Loss: 2.6200312084858983, Training Accuracy: 32.576
Validation Loss: 2.474041129373441, Validation Accuracy: 35.97
[8/150]: Training Loss: 2.53857664821093, Training Accuracy: 34.358
Va

0,1
Test Accuracy,▆█▃▄▃▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Test Loss,▁▃▆▅██▇█▆▇▆▅▆▇█▇▆▇▆▇▇▇▇▇▇▇█▇▇▇▇▇▇▇▆▆▇▆▆▆
Train Accuracy,▁▂▃▃▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇█████████
Train Loss,█▇▆▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,56.63
Test Loss,1.94981
Train Accuracy,85.464
Train Loss,0.49865


##### LARS Test Large Batches

In [28]:
num_epochs = 150
lr = 1.5
wd = 1e-03
batch_sizes = [512, 1024, 2048, 4096, 8192, 16384 , 32768]
learning_rates = [lr * (batch_size / 64.0) ** 0.5 for batch_size in batch_sizes] # Root square scale-up of learning rate


for i, batch_size in enumerate(batch_sizes):

  print('='*50)
  print(f'Batch size: {batch_size}, Weight decay: {wd}')
  print('='*50)

  hyperparameters = {
    'batch_size': batch_size,
    'learning_rate': learning_rates[i],
    'weight_decay' : wd
  }
  if batch_size <= 4096:
  
    # load data
    data = CIFAR100Data(batch_size= batch_size)
    original_train_loader_large_batch, original_test_loader_large_batch = data.train_test()
    # Load the model
    model = LeNet5().to(device)
    # Optimizer and scheduler setup
    optimizer = LARS(model.parameters(), lr= learning_rates[i], weight_decay=wd, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


    # Training
    run_training(
        num_epochs,
        model,
        original_train_loader_large_batch,
        original_test_loader_large_batch,
        original_test_loader_large_batch,
        optimizer,
        scheduler,
        criterion,
        device,
        optimizer_name='LARS_Large_Batches',
        hyperparameters=hyperparameters,
        is_wandb = True
    )
  else:
    # load data
    data = CIFAR100Data(batch_size= 4096)
    original_train_loader_large_batch, original_test_loader_large_batch = data.train_test()
    # Load the model
    model = LeNet5().to(device)
    # Optimizer and scheduler setup
    optimizer = LARS(model.parameters(), lr= learning_rates[i], weight_decay=wd, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    accumulation_steps = batch_size // 4096

    # Training
    run_training(
        num_epochs,
        model,
        original_train_loader_large_batch,
        original_test_loader_large_batch,
        original_test_loader_large_batch,
        optimizer,
        scheduler,
        criterion,
        device,
        optimizer_name='LARS_Large_Batches',
        accumulation_steps=accumulation_steps,
        hyperparameters=hyperparameters,
        is_wandb = True
    )

Batch size: 512, Learning rate: 4.242640687119286, Weight decay: 0.001


[1/150]: Training Loss: 4.194531238808924, Training Accuracy: 5.802
Validation Loss: 3.8219263911247254, Validation Accuracy: 11.06
[2/150]: Training Loss: 3.737521487839368, Training Accuracy: 12.162
Validation Loss: 3.5326468467712404, Validation Accuracy: 15.5
[3/150]: Training Loss: 3.469568571265863, Training Accuracy: 16.654
Validation Loss: 3.2325081706047056, Validation Accuracy: 21.14
[4/150]: Training Loss: 3.2623822543085836, Training Accuracy: 20.338
Validation Loss: 3.041664385795593, Validation Accuracy: 24.34
[5/150]: Training Loss: 3.1202199824002324, Training Accuracy: 23.134
Validation Loss: 2.9670259952545166, Validation Accuracy: 26.23
[6/150]: Training Loss: 2.973516508024566, Training Accuracy: 25.864
Validation Loss: 2.835892844200134, Validation Accuracy: 28.97
[7/150]: Training Loss: 2.8618772808386357, Training Accuracy: 27.73
Validation Loss: 2.731530475616455, Validation Accuracy: 31.08
[8/150]: Training Loss: 2.7622045083921782, Training Accuracy: 29.856
Va

0,1
Test Accuracy,█▆▁▅▄▅▃▆▄▃▃▅▅▅▅▅▆▆▆▆
Test Loss,▆▆█▄▃▃▄▂▂▃▂▃▃▂▃▃▁▂▁▁
Train Accuracy,▁▂▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████████
Train Loss,█▇▆▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,56.18
Test Loss,1.97009
Train Accuracy,84.57
Train Loss,0.5342


Batch size: 1024, Learning rate: 6.0, Weight decay: 0.001


[1/150]: Training Loss: 4.271560571631607, Training Accuracy: 4.982
Validation Loss: 3.9391962051391602, Validation Accuracy: 8.84
[2/150]: Training Loss: 3.82728639914065, Training Accuracy: 10.948
Validation Loss: 3.6827336311340333, Validation Accuracy: 13.37
[3/150]: Training Loss: 3.6269987845907408, Training Accuracy: 14.39
Validation Loss: 3.4175909042358397, Validation Accuracy: 18.27
[4/150]: Training Loss: 3.4358463238696664, Training Accuracy: 17.696
Validation Loss: 3.2815504550933836, Validation Accuracy: 20.38
[5/150]: Training Loss: 3.3039202203555984, Training Accuracy: 19.798
Validation Loss: 3.171510195732117, Validation Accuracy: 23.09
[6/150]: Training Loss: 3.1788000671231016, Training Accuracy: 22.106
Validation Loss: 3.083156633377075, Validation Accuracy: 24.25
[7/150]: Training Loss: 3.0781905553778826, Training Accuracy: 24.04
Validation Loss: 2.9470993995666506, Validation Accuracy: 26.99
[8/150]: Training Loss: 2.9905799846259917, Training Accuracy: 25.462
V

0,1
Test Accuracy,█▁▁▃▂▂▅▄▃▄
Test Loss,█▂▂▁▄▄▃▄▃▂
Train Accuracy,▁▂▂▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇█████████
Train Loss,█▇▆▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,55.35
Test Loss,2.00062
Train Accuracy,83.206
Train Loss,0.58266


Batch size: 2048, Learning rate: 8.485281374238571, Weight decay: 0.001


[1/150]: Training Loss: 4.36701738357544, Training Accuracy: 3.886
Validation Loss: 4.163058471679688, Validation Accuracy: 5.67
[2/150]: Training Loss: 4.113679485321045, Training Accuracy: 6.492
Validation Loss: 3.9081267356872558, Validation Accuracy: 9.58
[3/150]: Training Loss: 3.8604202461242676, Training Accuracy: 10.016
Validation Loss: 3.6436347484588625, Validation Accuracy: 14.06
[4/150]: Training Loss: 3.7227595615386964, Training Accuracy: 12.524
Validation Loss: 3.58444561958313, Validation Accuracy: 15.49
[5/150]: Training Loss: 3.5834127140045164, Training Accuracy: 14.702
Validation Loss: 3.4575596332550047, Validation Accuracy: 17.38
[6/150]: Training Loss: 3.4569161987304686, Training Accuracy: 16.904
Validation Loss: 3.262946367263794, Validation Accuracy: 20.16
[7/150]: Training Loss: 3.3158952045440673, Training Accuracy: 19.406
Validation Loss: 3.2002731800079345, Validation Accuracy: 22.26
[8/150]: Training Loss: 3.1832161426544188, Training Accuracy: 22.068
Val

0,1
Test Accuracy,▁█▆▃▇
Test Loss,▆▄▆█▁
Train Accuracy,▁▂▂▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇█████████
Train Loss,█▇▆▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,54.15
Test Loss,1.9843
Train Accuracy,80.532
Train Loss,0.67401


Batch size: 4096, Learning rate: 12.0, Weight decay: 0.001


[1/150]: Training Loss: 4.448218419001653, Training Accuracy: 3.104
Validation Loss: 4.19553820292155, Validation Accuracy: 5.83
[2/150]: Training Loss: 4.162900081047645, Training Accuracy: 5.992
Validation Loss: 4.001244783401489, Validation Accuracy: 8.19
[3/150]: Training Loss: 3.9526728116548977, Training Accuracy: 9.02
Validation Loss: 3.849735975265503, Validation Accuracy: 10.99
[4/150]: Training Loss: 3.9088880098783054, Training Accuracy: 9.748
Validation Loss: 3.7720150152842202, Validation Accuracy: 11.82
[5/150]: Training Loss: 3.7579474999354434, Training Accuracy: 11.674
Validation Loss: 3.6346964836120605, Validation Accuracy: 13.58
[6/150]: Training Loss: 3.8317384169651914, Training Accuracy: 10.848
Validation Loss: 3.6928911209106445, Validation Accuracy: 13.8
[7/150]: Training Loss: 3.6279670825371375, Training Accuracy: 13.764
Validation Loss: 3.426232655843099, Validation Accuracy: 17.9
[8/150]: Training Loss: 3.559702047934899, Training Accuracy: 15.198
Validatio

0,1
Test Accuracy,▂▁█
Test Loss,▇█▁
Train Accuracy,▁▁▂▂▃▄▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▆▇▇▇▇▇▇▇█████████
Train Loss,█▇▇▆▅▅▅▅▄▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,54.48
Test Loss,1.81654
Train Accuracy,71.694
Train Loss,0.98847


Batch size: 8192, Learning rate: 16.970562748477143, Weight decay: 0.001


[1/150]: Training Loss: 4.543815356034499, Training Accuracy: 1.934
Validation Loss: 4.557299613952637, Validation Accuracy: 4.11
[2/150]: Training Loss: 4.4831575613755446, Training Accuracy: 3.478
Validation Loss: 4.407679557800293, Validation Accuracy: 3.73
[3/150]: Training Loss: 4.383443318880522, Training Accuracy: 4.272
Validation Loss: 4.315660317738851, Validation Accuracy: 4.56
[4/150]: Training Loss: 4.753879473759578, Training Accuracy: 2.854
Validation Loss: 4.617130438486735, Validation Accuracy: 1.55
[5/150]: Training Loss: 4.834007776700533, Training Accuracy: 1.152
Validation Loss: 4.623517354329427, Validation Accuracy: 1.21
[6/150]: Training Loss: 4.748236729548528, Training Accuracy: 1.046
Validation Loss: 4.616909344991048, Validation Accuracy: 1.0
[7/150]: Training Loss: 4.6143631201524, Training Accuracy: 1.022
Validation Loss: 4.6133114496866865, Validation Accuracy: 0.95
[8/150]: Training Loss: 86586261001215.23, Training Accuracy: 1.076
Validation Loss: 199153

0,1
Test Accuracy,▁██
Test Loss,█▅▁
Train Accuracy,█▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train Loss,▁▁▁ █ █ █ ████ ██ ██ █ █

0,1
Test Accuracy,1.04
Test Loss,7.343644848636065e+34
Train Accuracy,1.01
Train Loss,inf


Batch size: 16384, Learning rate: 24.0, Weight decay: 0.001


[1/150]: Training Loss: 4.5802256144010105, Training Accuracy: 1.524
Validation Loss: 4.505903720855713, Validation Accuracy: 3.24
[2/150]: Training Loss: 4.487952305720403, Training Accuracy: 2.782
Validation Loss: 4.534458001454671, Validation Accuracy: 1.82
[3/150]: Training Loss: 4.498784615443303, Training Accuracy: 2.232
Validation Loss: 4.477282365163167, Validation Accuracy: 2.08
[4/150]: Training Loss: 4.4941191306481, Training Accuracy: 2.178
Validation Loss: 4.612959861755371, Validation Accuracy: 1.29
[5/150]: Training Loss: 4.5912819642287035, Training Accuracy: 1.504
Validation Loss: 5.598563989003499, Validation Accuracy: 2.19
[6/150]: Training Loss: 5.028410141284649, Training Accuracy: 1.632
Validation Loss: 7.806542873382568, Validation Accuracy: 1.07
[7/150]: Training Loss: 14089.29747926272, Training Accuracy: 1.044
Validation Loss: 7.282219409942627, Validation Accuracy: 0.93
[8/150]: Training Loss: 461533512.8043683, Training Accuracy: 0.978
Validation Loss: 14518

0,1
Test Accuracy,█▃▁
Train Accuracy,█▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train Loss,▁▁▁████

0,1
Test Accuracy,1.0
Test Loss,
Train Accuracy,1.0
Train Loss,


Batch size: 32768, Learning rate: 33.941125496954285, Weight decay: 0.001


[1/150]: Training Loss: 4.600700745215783, Training Accuracy: 1.18
Validation Loss: 4.5345486005147295, Validation Accuracy: 2.25
[2/150]: Training Loss: 4.525542259216309, Training Accuracy: 2.194
Validation Loss: 4.480770270029704, Validation Accuracy: 3.63
[3/150]: Training Loss: 4.467245468726525, Training Accuracy: 3.348
Validation Loss: 4.682004292805989, Validation Accuracy: 2.02
[4/150]: Training Loss: 4.584587500645564, Training Accuracy: 2.316
Validation Loss: 4.589324951171875, Validation Accuracy: 1.07
[5/150]: Training Loss: 4.593533919407771, Training Accuracy: 1.128
Validation Loss: 4.565906047821045, Validation Accuracy: 2.24
[6/150]: Training Loss: 4.610266098609338, Training Accuracy: 1.65
Validation Loss: 5.121236960093181, Validation Accuracy: 1.01
[7/150]: Training Loss: 5.261244590465839, Training Accuracy: 1.226
Validation Loss: 7.05544392267863, Validation Accuracy: 1.19
[8/150]: Training Loss: 6.352289456587571, Training Accuracy: 1.176
Validation Loss: 22.8200

0,1
Test Accuracy,█▃▁
Train Accuracy,▇█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Train Loss,▁▁▁▁█

0,1
Test Accuracy,1.0
Test Loss,
Train Accuracy,1.0
Train Loss,


#### LAMB

In [101]:
class LAMB(Optimizer):
    """
    Implements LAMB (Layer-wise Adaptive Moments) optimizer.

    Args:
        params (iterable): Iterable of parameters to optimize or dicts defining parameter groups.
        learning_rate (Union[float, Callable], optional): The learning rate. Default is 0.001.
        beta_1 (float, optional): The exponential decay rate for the 1st moment estimates. Default is 0.9.
        beta_2 (float, optional): The exponential decay rate for the 2nd moment estimates. Default is 0.999.
        epsilon (float, optional): A small constant for numerical stability. Default is 1e-6.
        weight_decay (float, optional): Weight decay. Default is 0.0.
        exclude_from_weight_decay (Optional[List[str]], optional): List of regex patterns of variables excluded from weight decay. Variables whose name contain a substring matching the pattern will be excluded. Default is None.
        exclude_from_layer_adaptation (Optional[List[str]], optional): List of regex patterns of variables excluded from layer adaptation. Variables whose name contain a substring matching the pattern will be excluded. Default is None.
        name (str, optional): Optional name for the operations created when applying gradients. Defaults to "LAMB".
        **kwargs: Keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip gradients by value, `decay` is included for backward compatibility to allow time inverse decay of learning rate. `lr` is included for backward compatibility, recommended to use `learning_rate` instead.

    Note:
        - If "weight_decay_rate" is found in kwargs, it will be renamed to "weight_decay", and will be deprecated in Addons 0.18.
        - If exclude_from_layer_adaptation is None, it will be set to exclude_from_weight_decay.
    """

    def __init__(
        self,
        params,
        lr: Union[float, Callable] = 0.001,
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        epsilon: float = 1e-6,
        wd: float = 0.0,
        exclude_from_weight_decay: Optional[List[str]] = None,
        exclude_from_layer_adaptation: Optional[List[str]] = None,
        name: str = "LAMB",
        **kwargs,
    ):
        """
        Initializes a new instance of the LAMB optimizer.
        """

        defaults = dict(
            lr=lr,
            betas=(beta_1, beta_2),
            eps=epsilon,
            wd=wd,
            **kwargs)
        super().__init__(params, defaults)

        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if
        # the arg is None.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def _compute_update(self, p, grad, state, group):
        """
        Computes the update for a given parameter.

        Args:
            p (Tensor): The parameter to be updated.
            grad (Tensor): The gradient of the parameter.
            state (dict): A dictionary containing information about the optimization state.
            group (dict): A dictionary containing the optimization parameters.

        Returns:
            Tensor: The computed update for the parameter.
        """
        # State initialization
        if len(state) == 0:
            state['step'] = 0
            # Exponential moving average of gradient values
            state['exp_avg'] = torch.zeros_like(p.data)
            # Exponential moving average of squared gradient values
            state['exp_avg_sq'] = torch.zeros_like(p.data)

        exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
        beta1, beta2 = group['betas']

        state['step'] += 1

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(1 - beta1, grad)
        exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

        denom = exp_avg_sq.sqrt().add_(group['eps'])

        update = exp_avg / denom

        # LAMB layer-wise adaptation
        r1 = p.data.pow(2).sum().sqrt()
        r2 = update.pow(2).sum().sqrt()
        r = torch.where(r1 == 0, torch.zeros_like(r1), r1 / r2)

        return r * update

    def _update_params(self, p, update, step_size, weight_decay):
        """
        Updates the parameters with the computed update.

        Args:
            p (Tensor): The parameter to be updated.
            update (Tensor): The computed update for the parameter.
            step_size (float): The step size for the update.
            weight_decay (float): The weight decay factor.
        """
        if weight_decay != 0:
            p.data.add_(-weight_decay * step_size, p.data)

        p.data.add_(-step_size, update)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model and returns the loss. Default is None.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('LAMB does not support sparse gradients.')

                state = self.state[p]

                update = self._compute_update(p, grad, state, group)
                self._update_params(p, update, group['lr'], group['weight_decay'])


##### LAMB Hyperparameter Tunning

In [30]:
num_epochs = 150
learning_rates = [9e-04, 95e-05, 1e-03, 15e-04, 2e-03]
wd = 4e-04
for lr in learning_rates:

  print('='*50)
  print(f'Hyperparameter with lr:{lr} and wd:{wd}')
  print('='*50)

  hyperparameters = {'learning_rate': lr,
                      'weight_decay' : wd
                      }
  # Load the model
  model = LeNet5().to(device)

  # Optimizer and scheduler setup
  optimizer = LAMB(model.parameters(), lr=lr, weight_decay=wd)
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


  # Training
  run_training(num_epochs,
                model,
                train_loader,
                validation_loader,
                test_loader,
                optimizer,
                scheduler,
                criterion,
                device,
                'LAMB-HyperParameterTuning',
                hyperparameters=hyperparameters,
                is_wandb = True,
                n_epochs_stop = 10
  )




Hyperparameter with lr:0.0009 and wd:0.0004


[1/150]: Training Loss: 4.366185345458985, Training Accuracy: 3.775
Validation Loss: 4.140061999582182, Validation Accuracy: 6.19
[2/150]: Training Loss: 3.9872212955474855, Training Accuracy: 9.2825
Validation Loss: 3.9175391182018693, Validation Accuracy: 10.13
[3/150]: Training Loss: 3.7751855819702147, Training Accuracy: 12.73
Validation Loss: 3.754828929901123, Validation Accuracy: 13.44
[4/150]: Training Loss: 3.6031822479248046, Training Accuracy: 15.7675
Validation Loss: 3.5609892857302525, Validation Accuracy: 16.33
[5/150]: Training Loss: 3.4478089206695555, Training Accuracy: 18.66
Validation Loss: 3.4119288389849816, Validation Accuracy: 19.45
[6/150]: Training Loss: 3.3203197284698485, Training Accuracy: 20.7075
Validation Loss: 3.349654642639646, Validation Accuracy: 20.47
[7/150]: Training Loss: 3.20732957611084, Training Accuracy: 22.7475
Validation Loss: 3.2386614349996967, Validation Accuracy: 22.79
[8/150]: Training Loss: 3.115479539489746, Training Accuracy: 24.5875

0,1
Test Accuracy,▃█▁▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▂▃▃▃▃▃▃▃▂▃▃
Test Loss,▆▁▂▂▄▄▆█▇▇▆▇█▇▇▇▆▇▇▇▇▇▆▆▆▆▆▆▆▇▇▆▆▆▆▆▆▆▆▆
Train Accuracy,▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇███
Train Loss,█▇▇▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,14.49
Test Loss,17.67243
Train Accuracy,70.73
Train Loss,1.0265


Hyperparameter with lr:0.00095 and wd:0.0004


[1/150]: Training Loss: 4.36243955039978, Training Accuracy: 3.8475
Validation Loss: 4.152171432592307, Validation Accuracy: 6.23
[2/150]: Training Loss: 3.988539717102051, Training Accuracy: 9.035
Validation Loss: 3.8899211443153914, Validation Accuracy: 10.42
[3/150]: Training Loss: 3.784552033996582, Training Accuracy: 12.44
Validation Loss: 3.716481655266634, Validation Accuracy: 13.63
[4/150]: Training Loss: 3.6115343948364256, Training Accuracy: 15.38
Validation Loss: 3.575136796684022, Validation Accuracy: 15.25
[5/150]: Training Loss: 3.464410584640503, Training Accuracy: 18.1625
Validation Loss: 3.448981456695848, Validation Accuracy: 17.92
[6/150]: Training Loss: 3.340861389923096, Training Accuracy: 20.2875
Validation Loss: 3.3443257702384024, Validation Accuracy: 19.79
[7/150]: Training Loss: 3.240240854263306, Training Accuracy: 22.0975
Validation Loss: 3.2861741910314866, Validation Accuracy: 21.07
[8/150]: Training Loss: 3.1493673683166503, Training Accuracy: 23.8725
Val

0,1
Test Accuracy,██▁▂▂▁▁▁▂▃▃▄▃▃▃▄▄▄▅▅▄▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
Test Loss,▄▁▄▇▆▇█▇▄▂▂▂▄▃▄▃▃▄▃▃▃▂▂▁▂▂▂▃▃▃▃▂▃▃▂▂▂▂▂▂
Train Accuracy,▁▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
Train Loss,█▇▇▇▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁

0,1
Test Accuracy,13.44
Test Loss,24.40566
Train Accuracy,76.875
Train Loss,0.78436


Hyperparameter with lr:0.001 and wd:0.0004


[1/150]: Training Loss: 4.351185768127442, Training Accuracy: 3.94
Validation Loss: 4.125218055810139, Validation Accuracy: 6.94
[2/150]: Training Loss: 3.934326037979126, Training Accuracy: 9.675
Validation Loss: 3.8352680054439863, Validation Accuracy: 11.53
[3/150]: Training Loss: 3.6920593730926514, Training Accuracy: 14.25
Validation Loss: 3.6194094821905636, Validation Accuracy: 15.04
[4/150]: Training Loss: 3.5119912174224854, Training Accuracy: 17.235
Validation Loss: 3.498675950773203, Validation Accuracy: 17.26
[5/150]: Training Loss: 3.3595370391845703, Training Accuracy: 19.9025
Validation Loss: 3.3320715184424334, Validation Accuracy: 20.65
[6/150]: Training Loss: 3.2359037544250486, Training Accuracy: 22.135
Validation Loss: 3.2282794827868226, Validation Accuracy: 21.83
[7/150]: Training Loss: 3.1236792266845703, Training Accuracy: 24.1875
Validation Loss: 3.1484436852157494, Validation Accuracy: 23.69
[8/150]: Training Loss: 3.0264555549621583, Training Accuracy: 25.99


0,1
Test Accuracy,██▂▃▁▁▂▁▁▁▂▂▁▁▂▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃
Test Loss,▂▁▂▃▅▅▆▇▇▆▆▆█▇██▇█▇▇▇▇▇▇▇▇▇▇▇██▇▇▇▇▇▇▇▇▇
Train Accuracy,▁▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
Train Loss,█▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,12.24
Test Loss,19.86775
Train Accuracy,68.255
Train Loss,1.11956


Hyperparameter with lr:0.0015 and wd:0.0004


[1/150]: Training Loss: 4.262849729156494, Training Accuracy: 4.9725
Validation Loss: 3.9988777105975304, Validation Accuracy: 9.32
[2/150]: Training Loss: 3.787733755874634, Training Accuracy: 12.5275
Validation Loss: 3.666901140455987, Validation Accuracy: 14.13
[3/150]: Training Loss: 3.5222632232666014, Training Accuracy: 16.975
Validation Loss: 3.430799526773441, Validation Accuracy: 18.28
[4/150]: Training Loss: 3.3261954177856445, Training Accuracy: 20.0275
Validation Loss: 3.268884590476941, Validation Accuracy: 20.25
[5/150]: Training Loss: 3.1651099575042725, Training Accuracy: 22.825
Validation Loss: 3.140599861266507, Validation Accuracy: 23.63
[6/150]: Training Loss: 3.0220409996032713, Training Accuracy: 25.57
Validation Loss: 3.062152488975768, Validation Accuracy: 24.92
[7/150]: Training Loss: 2.9046311878204345, Training Accuracy: 27.855
Validation Loss: 2.9628464735237654, Validation Accuracy: 26.95
[8/150]: Training Loss: 2.7910710647583006, Training Accuracy: 30.235

0,1
Test Accuracy,▅█▂▅▄▂▁▁▁▁▂▃▃▂▂▃▃▃▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▄▄▄▄▄▄
Test Loss,▁▄▇▆▇▇███▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆
Train Accuracy,▁▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
Train Loss,█▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,15.7
Test Loss,32.80274
Train Accuracy,82.41
Train Loss,0.55922


Hyperparameter with lr:0.002 and wd:0.0004


[1/150]: Training Loss: 4.222224911499024, Training Accuracy: 5.3975
Validation Loss: 3.937301506662065, Validation Accuracy: 9.22
[2/150]: Training Loss: 3.7261308227539063, Training Accuracy: 13.09
Validation Loss: 3.593827167134376, Validation Accuracy: 15.34
[3/150]: Training Loss: 3.4419474758148194, Training Accuracy: 18.2775
Validation Loss: 3.3379262875599465, Validation Accuracy: 19.89
[4/150]: Training Loss: 3.242010182952881, Training Accuracy: 21.865
Validation Loss: 3.2397833553848754, Validation Accuracy: 21.54
[5/150]: Training Loss: 3.074250465774536, Training Accuracy: 24.8525
Validation Loss: 3.2218855049959414, Validation Accuracy: 22.29
[6/150]: Training Loss: 2.9360343181610107, Training Accuracy: 27.3175
Validation Loss: 2.9865669110778033, Validation Accuracy: 26.78
[7/150]: Training Loss: 2.801424619293213, Training Accuracy: 30.0175
Validation Loss: 2.881270451150882, Validation Accuracy: 28.82
[8/150]: Training Loss: 2.688298546409607, Training Accuracy: 32.05

0,1
Test Accuracy,█▅▁▂▁▁▁▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▂▂▂▂▂▂▂▂▂
Test Loss,▁▃▆▅▆▇▇▇▇▇▇▇▇███████████████████████████
Train Accuracy,▁▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇▇███
Train Loss,█▇▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▁▁▁▁

0,1
Test Accuracy,11.85
Test Loss,35.52837
Train Accuracy,78.7575
Train Loss,0.68238


##### LAMB BaseLine B-Size = 64

In [32]:
num_epochs = 150
lr = 4.8/(2**5 *1e02) # approximately 15e-04
wd = 4e-04

hyperparameters = {'learning_rate': lr,
                    'weight_decay' : wd
                  }

# Load the model
model = LeNet5().to(device)

# Optimizer and scheduler setup
optimizer = LAMB(model.parameters(), learning_rate=lr, weight_decay=wd)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


# Training
run_training(num_epochs, model, original_train_loader, original_test_loader, original_test_loader, optimizer, scheduler, criterion, device, optimizer_name='LAMB', hyperparameters=hyperparameters, is_wandb = True)

0,1
Train Accuracy,▁▄▆▇█
Train Loss,█▅▃▂▁

0,1
Train Accuracy,22.668
Train Loss,3.18342


[1/150]: Training Loss: 4.170614353226274, Training Accuracy: 6.554
Validation Loss: 3.8419871087286883, Validation Accuracy: 11.55
[2/150]: Training Loss: 3.73113738697813, Training Accuracy: 13.326
Validation Loss: 3.5141605601948536, Validation Accuracy: 17.24
[3/150]: Training Loss: 3.4963422099037853, Training Accuracy: 17.148
Validation Loss: 3.329459586720558, Validation Accuracy: 21.14
[4/150]: Training Loss: 3.32537763838268, Training Accuracy: 20.17
Validation Loss: 3.1594036445496188, Validation Accuracy: 23.38
[5/150]: Training Loss: 3.1821046982274948, Training Accuracy: 22.724
Validation Loss: 3.024734082495331, Validation Accuracy: 25.42
[6/150]: Training Loss: 3.0576654736648132, Training Accuracy: 24.838
Validation Loss: 2.917888655024729, Validation Accuracy: 27.48
[7/150]: Training Loss: 2.9461568444586166, Training Accuracy: 26.996
Validation Loss: 2.8047209317517128, Validation Accuracy: 29.77
[8/150]: Training Loss: 2.860208951298843, Training Accuracy: 28.704
Val

0,1
Test Accuracy,▆█▁▁▁▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂
Test Loss,▂▁█▆▇▅▆▇▄▅▅▅▅▅▆▅▅▅▅▆▆▆▅▅▅▅▆▆▅▅▅▅▅▅▄▅▅▅▄▄
Train Accuracy,▁▂▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇████████████
Train Loss,█▇▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Test Accuracy,53.01
Test Loss,2.21681
Train Accuracy,78.148
Train Loss,0.74618


##### LAMB Test Large Batches

In [102]:
# without warmup
num_epochs = 150
wd = 1e-04
batch_sizes = [512, 1024, 2048, 4096, 8192, 16384 , 32768]
lr = 4.8/(2**5 *1e02) # approximately 15e-04
learning_rates = [lr * batch_size / 64.0 for batch_size in batch_sizes] # linear scale-up of learning rate
# warmup_ratio = [1/320, 1/160, 1/80, 1/40, 1/20, 1/10, 1/5]

for i, batch_size in enumerate(batch_sizes):
  print('='*50)
  print(f'Batch size: {batch_size}')
  print('='*50)

  hyperparameters = {
    'batch_size': batch_size,
    'learning_rate': learning_rates[i],
    'weight_decay' : wd
  }

  if batch_size <= 4096:
    # load data
    data = CIFAR100Data(batch_size= batch_size)
    original_train_loader_large_batch, original_test_loader_large_batch = data.train_test()
    # Load the model
    model = LeNet5().to(device)
    # Optimizer and scheduler setup
    optimizer = LAMB(model.parameters(), learning_rate=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


    # Training
    run_training(
        num_epochs,
        model,
        original_train_loader_large_batch,
        original_test_loader_large_batch,
        original_test_loader_large_batch,
        optimizer,
        scheduler,
        criterion,
        device,
        optimizer_name='LAMB_Large_Batches_without_warmup',
        hyperparameters=hyperparameters,
        is_wandb = True
    )
  else:
    accumulation_steps = batch_size // 4096
    # load data
    data = CIFAR100Data(batch_size= 4096)
    original_train_loader_large_batch, original_test_loader_large_batch = data.train_test()
    # Load the model
    model = LeNet5().to(device)
    # Optimizer and scheduler setup
    optimizer = LAMB(model.parameters(), learning_rate=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


    # Training
    run_training(
        num_epochs,
        model,
        original_train_loader_large_batch,
        original_test_loader_large_batch,
        original_test_loader_large_batch,
        optimizer,
        scheduler,
        criterion,
        device,
        optimizer_name='LAMB_Large_Batches_without_warmup',
        accumulation_steps=accumulation_steps,
        hyperparameters=hyperparameters,
        is_wandb = True
    )

Batch size: 512


**********************************************************************
Test Loss: 2.032588768005371, Test Accuracy: 47.49


0,1
Test Accuracy,▅▄▁▃▅▄▄▅▆▆▆▆▇▇▇▇███▇
Test Loss,██▇▃▃▃▃▁▁▂▂▁▂▁▂▂▁▁▁▁

0,1
Test Accuracy,47.49
Test Loss,2.03259


Batch size: 1024


**********************************************************************
Test Loss: 2.2387969970703123, Test Accuracy: 42.42


0,1
Test Accuracy,▅▁▂█▇▇▆▄▅▄
Test Loss,█▄▃▁▃▂▃▄▃▂

0,1
Test Accuracy,42.42
Test Loss,2.2388
Train Accuracy,42.978
Train Loss,2.20884


Batch size: 2048


**********************************************************************
Test Loss: 2.52529616355896, Test Accuracy: 36.42


0,1
Test Accuracy,▂█▆▁▄
Test Loss,█▁▃▆▃

0,1
Test Accuracy,36.42
Test Loss,2.5253
Train Accuracy,35.39
Train Loss,2.57354


Batch size: 4096


**********************************************************************
Test Loss: 2.9222216606140137, Test Accuracy: 29.16


0,1
Test Accuracy,█▁▄
Test Loss,▃█▁

0,1
Test Accuracy,29.16
Test Loss,2.92222
Train Accuracy,26.796
Train Loss,3.01636


Batch size: 8192


**********************************************************************
Test Loss: 3.3192431131998696, Test Accuracy: 21.19


0,1
Test Accuracy,▁▁█
Test Loss,▃█▁

0,1
Test Accuracy,21.19
Test Loss,3.31924
Train Accuracy,19.494
Train Loss,3.4046


Batch size: 16384


**********************************************************************
Test Loss: 3.6979848543802896, Test Accuracy: 14.82


0,1
Test Accuracy,█▁▃
Test Loss,▁█▃

0,1
Test Accuracy,14.82
Test Loss,3.69798
Train Accuracy,13.586
Train Loss,3.75874


Batch size: 32768


**********************************************************************
Test Loss: 4.131630261739095, Test Accuracy: 7.73


0,1
Test Accuracy,█▁▂
Test Loss,▁█▁

0,1
Test Accuracy,7.73
Test Loss,4.13163
Train Accuracy,7.258
Train Loss,4.1532


# **Distributed Approaches**

In [67]:
# Initialize a model and save its initial parameters
initial_model = LeNet5()
initial_state_dict = initial_model.state_dict()

### Local SGD

In [68]:
class LocalSGDOptimizer(Optimizer):
    def __init__(self, global_model, lr=0.01):
        self.global_model = global_model
        self.lr = lr
        params = list(global_model.parameters())
        super(LocalSGDOptimizer, self).__init__(params, {'lr': lr})

    def step(self, local_models):
        # Get global model parameters
        global_params = list(self.global_model.parameters())
        
        # Initialize deltas for each parameter, same size as global parameters
        deltas = [torch.zeros_like(param) for param in global_params]
        
        # Sum up differences between global model and local models
        for local_model in local_models:
            local_params = list(local_model.parameters())
            for i, param in enumerate(local_params):
                deltas[i] += (global_params[i] - param)
        
        # Average the delta over the number of local models
        num_models = len(local_models)
        for i, delta in enumerate(deltas):
            deltas[i] /= self.lr
            deltas[i] /= num_models
        
        # Update global model parameters
        with torch.no_grad():
            for i, param in enumerate(global_params):
                param.copy_(param - self.lr * deltas[i])
        
        return None


In [69]:
def local_SGD(shard_loaders, loss_fn, k, j, lr, wd, initial_state_dict, num_epochs, is_wandb=False):
      
  total_start_time = time.time()

  iterations = num_epochs // j
  # Initialize a model with same value of param for each chunk
  local_models = [LeNet5().to(device) for _ in range(k)]
  for model in local_models:
    model.load_state_dict(initial_state_dict)
  #Initialize optimizers for each chunk
  local_optimizers = [torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd) for model in local_models]
  # Initialize the global model
  global_model = LeNet5().to(device)
  global_model.load_state_dict(initial_state_dict)
  
  # Initialize the global optimizer with LocalSGDOptimizer
  global_optimizer = LocalSGDOptimizer(global_model, lr)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(global_optimizer, T_max=iterations)
  
  # Load checkpoint
  checkpoint = load_checkpoint('local_sgd', 64, {'k': k, 'j': j})
  if checkpoint is not None:
      global_model.load_state_dict(checkpoint['model_state_dict'])
      global_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
      
      test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn)
      print(f'Global Update: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
      
      return None

  # Training
  for iteration in range(iterations):
    for worker, shard_loader in enumerate(shard_loaders):
      train_start_time = time.time()
      for loca_step in range(j):
        train_loss, train_accuracy = train(local_models[worker], shard_loader, local_optimizers[worker], loss_fn, device = device,  is_wandb=False)
        print(f'Worker {worker+1}, [{loca_step+1:02}/{j:02}]: Training Loss: {train_loss:.9f}, Training Accuracy: {train_accuracy:.3f}')
      train_end_time = time.time()
      print(f'Time taken for training worker {worker+1}: {str(timedelta(seconds=train_end_time - train_start_time))}')
      print('-'*50)
    sync_start_time = time.time()
    
    # Synchronize local models with global model
    global_optimizer.step(local_models)

    scheduler.step()

    # Update local models learning rate to global learning rate after scheduler step
    for local_optimizer in local_optimizers:
        local_optimizer.param_groups[0]['lr'] = global_optimizer.param_groups[0]['lr']

    # Load global model to local models
    for local_model in local_models:
      local_model.load_state_dict(global_model.state_dict())
    
    sync_end_time = time.time()

    print('*'*50)
    print(f'Time taken for synchronization: {str(timedelta(seconds=sync_end_time - sync_start_time))}')
    
    test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn, is_wandb = False)
    
    print(f'Global Update {iteration+1:02}: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
    print('*'*50)
    
  
  total_end_time = time.time()
  
  # Save checkpoint
  save_checkpoint({
              'epoch': num_epochs,
              'model_state_dict': global_model.state_dict(),
              'optimizer_state_dict': global_optimizer.state_dict(),
              'scheduler_state_dict': scheduler.state_dict(),
              'loss': loss_fn
              }, 149, 64, 'local_sgd', {'k': k, 'j': j})
 
  print('/'*50)
  print(f'Total time taken for local_SGD: {str(timedelta(seconds=total_end_time - total_start_time))}')
  print('/'*50)

In [70]:
lr = 1e-02
wd = 1e-03
K = [2, 4, 8]
J = [4, 8, 16, 32, 64]
num_epochs = 150
data = CIFAR100Data()
loss_fn = nn.CrossEntropyLoss()

for k in K: # Number of workers
  shard_loaders = data.iid_shards(num_shards=k)
  for j in J:
    print('='*50)
    print(f'Number of Workers:{k}, Number of Local Steps:{j}')
    print('='*50)
    local_SGD(shard_loaders, loss_fn, k, j, lr, wd, initial_state_dict, num_epochs)

Files already downloaded and verified
Files already downloaded and verified
Number of Workers:2, Number of Local Steps:4
Global Update: Test Loss: 1.891312885, Test Accuracy: 56.100
Number of Workers:2, Number of Local Steps:8
Global Update: Test Loss: 1.937131206, Test Accuracy: 55.350
Number of Workers:2, Number of Local Steps:16
Global Update: Test Loss: 1.966259042, Test Accuracy: 55.450
Number of Workers:2, Number of Local Steps:32
Global Update: Test Loss: 2.047621591, Test Accuracy: 53.230
Number of Workers:2, Number of Local Steps:64
Global Update: Test Loss: 2.210712717, Test Accuracy: 42.830
Number of Workers:4, Number of Local Steps:4
Global Update: Test Loss: 1.875574264, Test Accuracy: 54.280
Number of Workers:4, Number of Local Steps:8
Global Update: Test Loss: 1.986871252, Test Accuracy: 52.610
Number of Workers:4, Number of Local Steps:16
Global Update: Test Loss: 2.156295303, Test Accuracy: 51.480
Number of Workers:4, Number of Local Steps:32
Global Update: Test Loss: 

### SlowMo

In [71]:
class SlowMoOptimizer(Optimizer):
    def __init__(self, global_model, momentum, lr=0.01, beta=0.5, alpha=1.0):
        self.global_model = global_model
        self.lr = lr
        self.beta = beta
        self.alpha = alpha
        self.momentum = momentum
        params = list(global_model.parameters())
        super(SlowMoOptimizer, self).__init__(params, {'lr': lr})
        
    def step(self, local_models):
        # Calculate exact average of local models
        avg_state_dict = {key: torch.zeros_like(value) for key, value in local_models[0].state_dict().items()}
        for model in local_models:
            for key, param in model.state_dict().items():
                avg_state_dict[key] += param
        
        # Averaging the models
        for key in avg_state_dict:
            avg_state_dict[key] /= len(local_models)
        
        # Perform SlowMo momentum update
        for key in self.global_model.state_dict().keys():
            self.momentum[key] = self.beta * self.momentum[key] + (1.0 / self.lr) * (self.global_model.state_dict()[key] - avg_state_dict[key])
        
        # Update global model parameters with outer update
        with torch.no_grad():
            for key, param in self.global_model.state_dict().items():
                param.copy_(param - self.alpha * self.lr * self.momentum[key])

        return None


In [72]:
def slowmo(shard_loaders, loss_fn, k, j, lr, wd, initial_state_dict, num_epochs, beta, alpha):

  total_start_time = time.time()

  iterations = num_epochs // j
  # Initialize a model with same value of param for each chunk
  local_models = [LeNet5().to(device) for _ in range(k)]
  for model in local_models:
    model.load_state_dict(initial_state_dict)
  #Initialize optimizers for each chunk
  local_optimizers = [torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd) for model in local_models]
  # Initialize the global model
  global_model = LeNet5().to(device)
  global_model.load_state_dict(initial_state_dict)
  
  momentum = {key: torch.zeros_like(value) for key, value in global_model.state_dict().items()}
  
  global_optimizer = SlowMoOptimizer(global_model, momentum , lr, beta, alpha)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(global_optimizer, T_max=iterations)

  checkpoint = load_checkpoint('slowmo', 64, {'k': k, 'j': j})
  if checkpoint is not None:
      global_model.load_state_dict(checkpoint['model_state_dict'])
      global_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
      
      test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn)
      print(f'Global Update: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
      
      return None

  for iteration in range(iterations):
    for worker, shard_loader in enumerate(shard_loaders):
      train_start_time = time.time()
      for loca_step in range(j):
        train_loss, train_accuracy = train(local_models[worker], shard_loader, local_optimizers[worker], loss_fn, device = device,  is_wandb=False)
        print(f'Worker {worker+1}, [{loca_step+1:02}/{j:02}]: Training Loss: {train_loss:.9f}, Training Accuracy: {train_accuracy:.3f}')
      train_end_time = time.time()
      print(f'Time taken for training worker {worker+1}: {str(timedelta(seconds=train_end_time - train_start_time))}')
      print('-'*50)
    sync_start_time = time.time()

    global_optimizer.step(local_models)

    scheduler.step()

    for local_optimizer in local_optimizers:
        local_optimizer.param_groups[0]['lr'] = global_optimizer.param_groups[0]['lr']    

    for local_model in local_models:
       local_model.load_state_dict(global_model.state_dict())
    sync_end_time = time.time()
    print('*'*50)
    print(f'Time taken for synchronization: {str(timedelta(seconds=sync_end_time - sync_start_time))}')
    test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn, is_wandb = False)
    print(f'Global Update {iteration+1:02}: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
    print('*'*50)
    
  # Save checkpoint
  save_checkpoint({
              'epoch': num_epochs,
              'model_state_dict': global_model.state_dict(),
              'optimizer_state_dict': global_optimizer.state_dict(),
              'scheduler_state_dict': scheduler.state_dict(),
              'loss': loss_fn
              }, 149, 64, 'slowmo', {'k': k, 'j': j})

  total_end_time = time.time()
  print('/'*50)
  print(f'Total time taken for local_SGD: {str(timedelta(seconds=total_end_time - total_start_time))}')
  print('/'*50)


In [73]:
lr = 1e-02
wd = 1e-03
beta = 0.4
alpha = 1
parameters = {'lr': lr, 'wd': wd, 'beta': beta}
K = [2, 4, 8]
J = [4, 8, 16, 32 , 64]
num_epochs = 150
data = CIFAR100Data()
loss_fn = nn.CrossEntropyLoss()

# Initialize slow model
slow_model = LeNet5()
slow_model.load_state_dict(initial_state_dict)

for k in K:
  shard_loaders = data.iid_shards(num_shards=k)
  for j in J:
    print('='*50)
    print(f'Number of Workers:{k}, Number of Local Steps:{j}')
    print('='*50)
    slowmo(shard_loaders, loss_fn, k, j, lr, wd, initial_state_dict, num_epochs, beta, alpha)

Files already downloaded and verified
Files already downloaded and verified
Number of Workers:2, Number of Local Steps:4
Global Update: Test Loss: 1.981766467, Test Accuracy: 56.000
Number of Workers:2, Number of Local Steps:8
Global Update: Test Loss: 2.081315363, Test Accuracy: 54.610
Number of Workers:2, Number of Local Steps:16
Global Update: Test Loss: 2.200608175, Test Accuracy: 53.200
Number of Workers:2, Number of Local Steps:32
Global Update: Test Loss: 2.824673044, Test Accuracy: 48.240
Number of Workers:2, Number of Local Steps:64
Global Update: Test Loss: 3.877965771, Test Accuracy: 40.550
Number of Workers:4, Number of Local Steps:4
Global Update: Test Loss: 2.035068179, Test Accuracy: 53.420
Number of Workers:4, Number of Local Steps:8
Global Update: Test Loss: 2.146553544, Test Accuracy: 52.300
Number of Workers:4, Number of Local Steps:16
Global Update: Test Loss: 2.383769921, Test Accuracy: 49.880
Number of Workers:4, Number of Local Steps:32
Global Update: Test Loss: 

# **Personal Contribution**

### SHAT (Asyncronous Approach)

In [74]:
class SHAT_PS_Optimizer(Optimizer):
    def __init__(self, global_model, lr=0.01):
        self.global_model = global_model
        self.lr = lr
        params = list(global_model.parameters())
        super(SHAT_PS_Optimizer, self).__init__(params, {'lr': lr})
        
    def step(self, gradients_model):
        # Update global model parameters with gradients 
        with torch.no_grad():
            for key, param in self.global_model.state_dict().items():
                param.copy_(param -  self.lr * gradients_model.state_dict()[key])

        return None

def generate_computation_latency_sequence(K, each_worker_iteration):
    """
    Simulates a sequence of operations for K workers with random computation latencies.

    The function generates a list of operations, each performed by a worker, sorted by
    their latencies. Each worker performs exactly t operations, and the latencies are
    scaled so that the fastest worker's latency is normalized to 1.

    Parameters:
    K (int): The number of workers in the simulation.
    t (int): The number of operations each computer will perform.

    Returns:
    tuple: A tuple containing:
        - result (list of dicts): A list of dictionaries where each dictionary represents
          an operation performed by a computer. The keys in the dictionary are:
              - "total_iterations" (int): The total number of iterations performed.
              - "computer" (int): The ID of the computer performing the operation.
              - "value" (int): The latency value associated with that operation.
        - computation_powers (list of ints): The original random computation latencies for each computer.
        - scaled_powers (list of floats): The computation latencies scaled such that the fastest
          worker's latency is 1.
    """
    # Generate random computation powers
    computation_powers = [5370, 9830]

    # Find the minimum computation power
    min_power = min(computation_powers)

    # Scale the computation powers so that the fastest (smallest value) is equal to 1
    scaled_powers = [power / min_power for power in computation_powers]

    # Initialize a list to store the dictionaries
    result = []

    # Store the number of turns taken by each worker
    turns_taken = [0] * K
    total_number_of_iterations = 0
    # Generate the dictionaries until all computers have 15 turns
    while any(turn < each_worker_iteration for turn in turns_taken):
        total_number_of_iterations += 1
        # Generate the next possible dictionaries for each worker
        possible_entries = []
        for index in range(K):
            if turns_taken[index] < each_worker_iteration:  # Only consider computers with less than 15 turns
                operation_value = computation_powers[index] * (turns_taken[index] + 1)
                possible_entries.append({"total_iterations": total_number_of_iterations,"worker": index, "value": operation_value})

        # Sort the possible dictionaries by their operation value
        possible_entries.sort(key=lambda x: x["value"])

        # Add the dictionary with the smallest value to the result
        chosen_entry = possible_entries[0]
        result.append(chosen_entry)

        # Update the turn count for the chosen worker
        chosen_index = chosen_entry["worker"] 
        turns_taken[chosen_index] += 1

    return result, computation_powers, scaled_powers

def calculate_gradients_model(global_model, local_model, lr):
    gradients_dict = {key: torch.zeros_like(value) for key, value in local_model.state_dict().items()}
    
    for key, value in local_model.state_dict().items():
        gradients_dict[key] += ((global_model.state_dict()[key] - value)/lr)
    
    return gradients_dict

def calculate_s_i(k,ci):
    return float(k/ci)

def calculate_alpha_i(si, k):
    alpha = 1 - (si / math.log(k))
    return alpha

def updatel_local_model(global_model, local_model, alpha_i):
    new_weights = {}
    for key, value in local_model.state_dict().items():
        new_weights[key] = (1 - alpha_i) * value + alpha_i * global_model.state_dict()[key]
    
    return new_weights



In [75]:


def SHAT_PS(shard_loaders, loss_fn, k, lr, wd, initial_state_dict, num_epochs):  
    # SHAT Parameter Server to manage workers
    total_start_time = time.time()

    # Initialize a model with same value of param for each chunk
    local_models = [LeNet5().to(device) for _ in range(k)]
    for model in local_models:
      model.load_state_dict(initial_state_dict)

    # Initialize the global model
    global_model = LeNet5().to(device)
    global_model.load_state_dict(initial_state_dict)

    global_optimizer = SHAT_PS_Optimizer(global_model, lr)
  
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(global_optimizer, T_max=num_epochs)
    
    checkpoint = load_checkpoint('shat', 64, {'k': k})
    if checkpoint is not None:
      global_model.load_state_dict(checkpoint['model_state_dict'])
      global_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
      
      test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn)
      print(f'Global Update: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
      
      return None
    # Generate a sequence of computation latency to simulate the difference of computation latency (Lower computation Latency means higher computation power)
    computation_latency_sequence, computation_latency, scaled_computation_latency = generate_computation_latency_sequence(k, num_epochs)

    # Print the original and scaled computation latency
    print("Original Computation Latency:", computation_latency)
    print("Scaled Computation Latency:", scaled_computation_latency)

    # print sequence of workers based on their computation latency
    print(f'workers simulated orders based on computation latency:{[entry["worker"]+1 for entry in computation_latency_sequence]}')

    C = [1 for _ in range(k)] # Staleness counter
    local_optimizers = [torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd) for model in local_models]
    
    #This specifies turn of the model
    for iteration_index, worker in enumerate([entry['worker'] for entry in computation_latency_sequence]):
      iteration_start_time = time.time()
      print('*'*50)

      train_loss, train_accuracy = train(local_models[worker], shard_loaders[worker], local_optimizers[worker], loss_fn, device = device,  is_wandb=False)
      print(f'Worker {worker+1}, [{iteration_index+1:02}]: Training Loss: {train_loss:.9f}, Training Accuracy: {train_accuracy:.3f}')
      print('*'*50)
      
      '''PS server: receive model from the worker and calculate diff model (gradient)'''
      gradients_model = LeNet5().to(device)
      gradients_model.load_state_dict(calculate_gradients_model(global_model, local_models[worker], lr))
      
      # Computing the staleness of each worker
      for i in C :
        if i != worker:
          i += 1
          
      '''PS Server update global model'''
      global_optimizer.step(gradients_model)
      
      '''send updated model to the worker'''
      '''calucale the staleness of the worker αi ← si − logn ,    s ← n/ci'''
      s_i = calculate_s_i(k,C[worker])
      alpha_i = calculate_alpha_i(s_i, k)
      '''update worker local model ba w ← (1 − α )w + α w̃'''
      local_models[worker].load_state_dict(updatel_local_model(global_model, local_models[worker], alpha_i))

      '''continue outer loop in PS'''
      '''ci = 0 ya 1'''
      C[worker] = 1
      
      iteration_end_time = time.time()
     
      print(f'Time taken for worker {worker+1} : {str(timedelta(seconds=iteration_end_time - iteration_start_time))}')
      print('-'*50)
      print(f'Iteration: {iteration_index+1:02}/{k*num_epochs:02} - True epochs: {num_epochs}')
      test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn, is_wandb = False)
      print(f'Global Update {iteration_index+1:02}: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
      print('-'*50)
    
    # Save checkpoint
    save_checkpoint({
              'epoch': num_epochs,
              'model_state_dict': global_model.state_dict(),
              'optimizer_state_dict': global_optimizer.state_dict(),
              'scheduler_state_dict': scheduler.state_dict(),
              'loss': loss_fn
              }, 149, 64, 'shat', {'k': k})
    
    total_end_time = time.time()
    print('/'*50)
    print(f'Total time taken for SHAT: {str(timedelta(seconds=total_end_time - total_start_time))}')
    print('/'*50)
      

In [76]:
lr = 1e-02
wd = 1e-03 
K = [2]
num_epochs = 150
data = CIFAR100Data()
loss_fn = nn.CrossEntropyLoss()

for k in K: # number of workers
  shard_loaders = data.iid_shards(num_shards=k)
  print('='*50)
  print(f'Number of Workers:{k}')
  print('='*50)
  SHAT_PS(shard_loaders, loss_fn, k, lr, wd, initial_state_dict, num_epochs)


Files already downloaded and verified
Files already downloaded and verified
Number of Workers:2
Global Update: Test Loss: 2.560581903, Test Accuracy: 42.860


### Our Approach 1: Layerwise Masking Approach

In [77]:
def freeze_layers(model, layers_to_freeze):
    """
    Freezes the specified layers in the model, skipping non-trainable layers.

    Parameters:
    - model (nn.Module): The model whose layers you want to freeze.
    - layers_to_freeze (list of nn.Module): The layers to freeze.
    """

    for idx, layer in enumerate(model.children()):
        if idx in layers_to_freeze and any(p.requires_grad for p in layer.parameters()):
            for param in layer.parameters():
                param.requires_grad = False

def build_model_list(model, num_layers_to_freeze):
    """
    Builds a list of models with different combinations of frozen layers.

    Parameters:
    - model (nn.Module): The model to use as a base.
    - num_layers_to_freeze (int): The number of layers to freeze in each model.

    Returns:
    - list of nn.Module: A list of models with different combinations of frozen layers.
    """
    model_list = []
    trainable_layers = [idx for idx, layer in enumerate(model.children()) if any(p.requires_grad for p in layer.parameters())]

    # Generate all possible combinations of trainable layers to freeze
    layer_combinations = list(combinations(trainable_layers, num_layers_to_freeze))

    for layers in layer_combinations:
        new_model = copy.deepcopy(model)
        freeze_layers(new_model, layers)
        model_list.append(new_model)

    return model_list

def train_select_best_model(model, num_layers_to_freeze, num_epochs, train_loader, test_loader, loss_fn, optimizer, device):
    """
    Trains the model with different combinations of frozen layers and selects the best one based on test accuracy.
    """
    model_list = build_model_list(model, num_layers_to_freeze)
    best_accuracy = 0.0
    best_layers = []

    for candidate_model in model_list:
        candidate_model.to(device)
        optimizer_copy = copy.deepcopy(optimizer)

        # Update only the parameters that require gradients
        for param_group in optimizer_copy.param_groups:
            param_group['params'] = [param for param in candidate_model.parameters() if param.requires_grad]

        candidate_model.train()

        for _ in range(num_epochs):
            train(candidate_model, train_loader, optimizer_copy, loss_fn)

        test_loss, test_accuracy = test(candidate_model, test_loader, loss_fn, is_wandb=False)

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_layers = [name for idx, (name, layer) in enumerate(candidate_model.named_children()) 
                           if not any(p.requires_grad for p in layer.parameters()) and any(p.requires_grad for p in model.get_submodule(name).parameters())]

    best_layer_indices = [idx for idx, (name, layer) in enumerate(model.named_children()) if name in best_layers]
    freeze_layers(model, best_layer_indices)

    return {tuple(best_layers): copy.deepcopy(model)}


In [78]:
class PerDLMaskOptimizer(Optimizer):
    def __init__(self, global_model, lr=0.01):
        self.global_model = global_model
        self.lr = lr
        params = list(global_model.parameters())
        super(PerDLMaskOptimizer, self).__init__(params, {'lr': lr})

    def expand_mask(self, model, layer_mask):
        """
        Expands the layer-based mask to match the number of parameters.

        Args:
        - model (nn.Module): The model whose parameters are being masked.
        - layer_mask (list of bool): The mask corresponding to each layer (True for trainable, False for frozen).

        Returns:
        - list of bool: Expanded mask corresponding to each parameter in the model.
        """
        expanded_mask = []
        for is_trainable, layer in zip(layer_mask, model.children()):
            # Append True/False for each parameter in the layer
            expanded_mask.extend([is_trainable] * len(list(layer.parameters())))
        return expanded_mask

    def step(self, local_models, masks):
        """
        Perform a step of Local SGD optimization using the mask for each worker.

        Args:
        - local_models (list of nn.Module): The list of local models (from different workers).
        - masks (list of list of bool): The mask for each worker indicating whether each layer is trainable or not.
        """
        # Get global model parameters
        global_params = list(self.global_model.parameters())
        
        # Initialize deltas for each parameter, same size as global parameters
        deltas = [torch.zeros_like(param) for param in global_params]

        # Iterate over local models and accumulate deltas based on the expanded mask
        for worker_idx, local_model in enumerate(local_models):
            local_params = list(local_model.parameters())
            
            # Expand the mask to match the number of parameters in the model
            expanded_mask = self.expand_mask(local_model, masks[worker_idx])

            # Ensure mask length matches local_params length
            assert len(expanded_mask) == len(local_params), f"Expanded mask size {len(expanded_mask)} doesn't match number of parameters {len(local_params)}"

            for i, param in enumerate(local_params):
                if expanded_mask[i]:  # Only consider non-frozen parameters
                    deltas[i] += (global_params[i] - param)

        # Adjust deltas based on the number of workers that updated each layer
        for i, delta in enumerate(deltas):
            num_active_workers = sum(expanded_mask[i] for expanded_mask in [self.expand_mask(m, masks[w_idx]) for w_idx, m in enumerate(local_models)])  # Count workers with trainable parameters
            if num_active_workers > 0:
                deltas[i] /= num_active_workers  # Average only over active workers
            deltas[i] /= self.lr  # Apply learning rate scaling
        
        # Update global model parameters
        with torch.no_grad():
            for i, param in enumerate(global_params):
                param.copy_(param - self.lr * deltas[i])
        
        return None


In [79]:
def PerDLMask(shard_loaders, loss_fn, k, j, lr, wd, initial_state_dict, num_epochs, index_weaker_worker, n_freeze_layer, is_wandb=False, fine_tuning=False):
      
  total_start_time = time.time()
  
  iterations = num_epochs // j
  # Initialize a model with same value of param for each chunk
  local_models = [LeNet5().to(device) for _ in range(k)]
  for model in local_models:
    model.load_state_dict(initial_state_dict)
  #Initialize optimizers for each chunk
  local_optimizers = [torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd) for model in local_models]

  # Initialize the global model
  global_model = LeNet5().to(device)
  global_model.load_state_dict(initial_state_dict)
  
  global_optimizer = PerDLMaskOptimizer(global_model, lr)
  
  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(global_optimizer, T_max=iterations)
  
  checkpoint = load_checkpoint('PerDLMask', 64, {'k': k, 'j': j, 'weaker_worker': index_weaker_worker, 'n_freeze_layer': n_freeze_layer, 'fine_tuning': fine_tuning})
  if checkpoint is not None:
      global_model.load_state_dict(checkpoint['model_state_dict'])
      global_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
      
      test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn)
      print(f'Global Update: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
      
      return None
  
  # select to freeze layers for the weaker worker with lower bais
  freeze_model = train_select_best_model(local_models[index_weaker_worker], n_freeze_layer, 5, shard_loaders[index_weaker_worker], original_test_loader, loss_fn, local_optimizers[index_weaker_worker], device)
  print(f'The following layers of worker {index_weaker_worker+1} are frozen: {list(freeze_model.keys())[0]}')
  local_models[index_weaker_worker] = list(freeze_model.values())[0]
  
  # Initialize masks for each worker
  masks = [[True for _ in model.named_children()] for model in local_models]
  # Freeze the layers for the weaker worker in the masks
  masks[index_weaker_worker] = [name not in list(freeze_model.keys())[0] for name, layer in list(freeze_model.values())[0].named_children()]

  for iteration in range(iterations-1 if fine_tuning else iterations):
    for worker, shard_loader in enumerate(shard_loaders):
      train_start_time = time.time()
      for loca_step in range(j):
        train_loss, train_accuracy = train(local_models[worker], shard_loader, local_optimizers[worker], loss_fn, device = device,  is_wandb=False)
        print(f'Worker {worker+1}, [{loca_step+1:02}/{j:02}]: Training Loss: {train_loss:.9f}, Training Accuracy: {train_accuracy:.3f}')
      train_end_time = time.time()
      print(f'Time taken for training worker {worker+1}: {str(timedelta(seconds=train_end_time - train_start_time))}')
      print('-'*50)
    sync_start_time = time.time()

    # Synchronize local models with global model
    global_optimizer.step(local_models, masks)

    scheduler.step()

    for local_optimizer in local_optimizers:
        local_optimizer.param_groups[0]['lr'] = global_optimizer.param_groups[0]['lr']


    for local_model in local_models:
      local_model.load_state_dict(global_model.state_dict())
    sync_end_time = time.time()
    print('*'*50)
    print(f'Time taken for synchronization: {str(timedelta(seconds=sync_end_time - sync_start_time))}')
    test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn, is_wandb = False)
    print(f'Global Update {iteration+1:02}: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
    print('*'*50)
    
  if fine_tuning:
    # Unfreeze all layers for worker 2 (the weaker worker)
    for name, layer in local_models[index_weaker_worker].named_children():
        for param in layer.parameters():
            if name in list(freeze_model.keys())[0]:
                param.requires_grad = True
    masks = [[True for _ in model.named_children()] for model in local_models]  
    
    for worker in range(k):
        train_start_time = time.time()
        for _ in range(j):
            train_loss, train_accuracy = train(local_models[worker], shard_loader, local_optimizers[worker], loss_fn, device = device,  is_wandb=False)
            print(f'Worker {worker+1}, [{loca_step+1:02}/{j:02}]: Training Loss: {train_loss:.9f}, Training Accuracy: {train_accuracy:.3f}')
        total_end_time = time.time()
        print(f'Time taken for training worker {worker+1}: {str(timedelta(seconds=train_end_time - train_start_time))}')
        print('-'*50)
    sync_start_time = time.time()

    global_optimizer.step(local_models, masks)
    scheduler.step()
    for local_optimizer in local_optimizers:
        local_optimizer.param_groups[0]['lr'] = global_optimizer.param_groups[0]['lr']


    for local_model in local_models:
        local_model.load_state_dict(global_model.state_dict())
    sync_end_time = time.time()
    print('*'*50)
    print(f'Time taken for synchronization: {str(timedelta(seconds=sync_end_time - sync_start_time))}')
    test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn, is_wandb = False)
    print(f'Global Update {iteration+1:02}: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
    print('*'*50)

  
  # Save checkpoint
  save_checkpoint({
              'epoch': num_epochs,
              'model_state_dict': global_model.state_dict(),
              'optimizer_state_dict': global_optimizer.state_dict(),
              'scheduler_state_dict': scheduler.state_dict(),
              'loss': loss_fn
              }, 149, 64, 'PerDLMask', {'k': k, 'j': j, 'weaker_worker': index_weaker_worker, 'n_freeze_layer': n_freeze_layer, 'fine_tuning': fine_tuning})
 

  print('/'*50)
  print(f'Total time taken for local_SGD: {str(timedelta(seconds=total_end_time - total_start_time))}')
  print('/'*50)


In [80]:
# Without Fine-tuning
lr = 1e-02
wd = 1e-03
K = [2]
J = [4]

num_epochs = 150
data = CIFAR100Data()
loss_fn = nn.CrossEntropyLoss()

for k in K: # Number of workers
  shard_loaders = data.iid_shards(num_shards=k)
  for j in J:
    print('='*50)
    print(f'Number of Workers:{k}, Number of Local Steps:{j}')
    print('='*50)
    PerDLMask(shard_loaders, loss_fn, k, j, lr, wd, initial_state_dict, num_epochs, index_weaker_worker=1, n_freeze_layer=2)

Files already downloaded and verified
Files already downloaded and verified
Number of Workers:2, Number of Local Steps:4
Global Update: Test Loss: 2.760830911, Test Accuracy: 48.840


In [81]:
# With Fine-tuning
lr = 1e-02
wd = 1e-03
K = [2]
J = [4]

num_epochs = 150
data = CIFAR100Data()
loss_fn = nn.CrossEntropyLoss()

for k in K: # Number of workers
  shard_loaders = data.iid_shards(num_shards=k)
  for j in J:
    print('='*50)
    print(f'Number of Workers:{k}, Number of Local Steps:{j}')
    print('='*50)
    PerDLMask(shard_loaders, loss_fn, k, j, lr, wd, initial_state_dict, num_epochs, index_weaker_worker=1, n_freeze_layer=2, fine_tuning=True)

Files already downloaded and verified
Files already downloaded and verified
Number of Workers:2, Number of Local Steps:4
Global Update: Test Loss: 2.543037563, Test Accuracy: 49.050


### Our Approach 2: Confidence Interval Strategy

In [82]:
def update_CI_Acc(accuracies, C_level=0.8):
    if len(accuracies) == 1:
        return 0.0
    mean = np.mean(accuracies)
    std_dev = np.std(accuracies, ddof=1)
    t_score = stats.t.ppf(1 - (1 - C_level) / 2, len(accuracies) - 1)
    CI = t_score * (std_dev / np.sqrt(len(accuracies)))
    # Compute relative error
    RE = CI / mean
    # Compute accuracy
    accuracy = 1 - RE

    return accuracy


In [83]:
def ConfidenceInetrvalApproach(shard_loaders, loss_fn, k, lr, wd, initial_state_dict, num_epochs, ci_acc_threshold):
  total_start_time = time.time()
  
  # Initialize a model with same value of param for each chunk
  local_models = [LeNet5().to(device) for _ in range(k)]
  train_accuracy_workers = {i: [] for i in range(k)}
  steps_per_worker = {i: 0 for i in range(k)}
  accuracy_workers = {i: 0 for i in range(k)}
  for model in local_models:
    model.load_state_dict(initial_state_dict)

  #Initialize optimizers for each chunk
  local_optimizers = [torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd) for model in local_models]
  
  # Initialize the global model
  global_model = LeNet5().to(device)
  global_model.load_state_dict(initial_state_dict)
  global_optimizer = LocalSGDOptimizer(global_model, lr=lr)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(global_optimizer, T_max=num_epochs)

  checkpoint = load_checkpoint('ci-strategy', 64, {'k': k, 'threshold': ci_acc_threshold})
  if checkpoint is not None:
      global_model.load_state_dict(checkpoint['model_state_dict'])
      global_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
      
      test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn)
      print(f'Global Update: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
      
      return None
  
  total_epochs = 0
  steps_per_worker = {i: 0 for i in range(k)}
  while not all(value > num_epochs for value in steps_per_worker.values()): 
    for worker, shard_loader in enumerate(shard_loaders):
      if steps_per_worker[worker] > num_epochs:
        continue
      train_start_time = time.time()
      step_counter = 0
      # Train the model
      while not accuracy_workers[worker] > ci_acc_threshold:
        
        train_loss, train_accuracy = train(local_models[worker], shard_loader, local_optimizers[worker], loss_fn, device = device,  is_wandb=False)
        train_accuracy_workers[worker].append(train_accuracy)

        accuracy_workers[worker] = update_CI_Acc(train_accuracy_workers[worker])
        steps_per_worker[worker] += 1
        step_counter += 1
        print(accuracy_workers[worker])
        print(f'Worker {worker+1}, [epoch: {step_counter:02}]: Training Loss: {train_loss:.9f}, Training Accuracy: {train_accuracy:.3f}')
        if steps_per_worker[worker] == num_epochs:
          break
      train_accuracy_workers[worker] = []
      accuracy_workers[worker] = 0  
      train_end_time = time.time()
      print(f'Time taken for training worker {worker+1}: {str(timedelta(seconds=train_end_time - train_start_time))}')
      print('-'*50)
    sync_start_time = time.time() 

    # Synchronize local models with global model
    global_optimizer.step(local_models)

    scheduler.step()
    
    # Update the local models with the global model
    for local_model in local_models:
       local_model.load_state_dict(global_model.state_dict())
    
    # Update the learning rate of the local optimizers with the global optimizer after scheduling
    for local_optimizer in local_optimizers:
        local_optimizer.param_groups[0]['lr'] = global_optimizer.param_groups[0]['lr']

    total_epochs += step_counter
    sync_end_time = time.time()
    print('*'*50)
    print(f'Time taken for synchronization: {str(timedelta(seconds=sync_end_time - sync_start_time))}')
    test_loss, test_accuracy = test(global_model,original_test_loader, loss_fn, is_wandb = False)
    print(f'Global Model: Test Loss: {test_loss:.9f}, Test Accuracy: {test_accuracy:.3f}')
    print('*'*50)
  
  # Save checkpoint
  save_checkpoint({
              'epoch': num_epochs,
              'model_state_dict': global_model.state_dict(),
              'optimizer_state_dict': global_optimizer.state_dict(),
              'scheduler_state_dict': scheduler.state_dict(),
              'loss': loss_fn
              }, 149, 64, 'ci-strategy', {'k': k, 'threshold': ci_acc_threshold}) 
  
  total_end_time = time.time()
  print('/'*50)
  print(f'Total time taken for Confidence Interval: {str(timedelta(seconds=total_end_time - total_start_time))}')
  print('/'*50)

In [84]:
lr = 1e-02
wd = 1e-03
K = [2]
num_epochs = 150
ci_acc_thresholds= [0.8, 0.9, 0.95]
data = CIFAR100Data()
loss_fn = nn.CrossEntropyLoss()

for k in K: # Number of workers
    for ci_acc_threshold in ci_acc_thresholds:
        shard_loaders = data.iid_shards(num_shards=k)
        print('='*50)
        print(f'Number of Workers:{k}, CI Accuracy Threshold:{ci_acc_threshold}')
        print('='*50)
        ConfidenceInetrvalApproach(shard_loaders, loss_fn, k, lr, wd, initial_state_dict, num_epochs, ci_acc_threshold)

Files already downloaded and verified
Files already downloaded and verified
Number of Workers:2, CI Accuracy Threshold:0.8
Global Update: Test Loss: 1.786572354, Test Accuracy: 55.280
Number of Workers:2, CI Accuracy Threshold:0.9
Global Update: Test Loss: 1.758802217, Test Accuracy: 55.350
Number of Workers:2, CI Accuracy Threshold:0.95
Global Update: Test Loss: 1.742772651, Test Accuracy: 54.220
