<a href="https://colab.research.google.com/github/TheS1n233/Distributed-Learning-Project5/blob/experiments/LocalSGD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
!pip install --upgrade torch




# Imports

In [10]:
import torch
import argparse
from torch.optim.optimizer import Optimizer, required
import torch
import torchvision
import torchvision.transforms as transforms
import random
import numpy as np
import json
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, SequentialLR, LinearLR
import matplotlib.pyplot as plt
import copy
import time
from torch.amp import GradScaler, autocast
import os
from google.colab import drive

In [11]:
drive.mount('/content/drive')
if not os.path.exists('/content/drive/MyDrive'):
    raise RuntimeError("Google Drive not mounted correctly!")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Model

In [12]:
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5)

        # Layer fully connected
        self.fc1 = nn.Linear(64 * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 100)

    def forward(self, x):
        # Layer convolutivi con ReLU e max-pooling
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        # Flatten per i layer fully connected
        x = torch.flatten(x, 1)

        # Layer fully connected con ReLU
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)

        return x

## Early stopping

In [13]:
# Early Stopping Class
class EarlyStopping:
    def __init__(self, patience=10, delta=0, path='/content/drive/MyDrive/Early2checkpoint.pt', verbose=False):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path
        self.verbose = verbose

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, optimizer=None):
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...")
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict() if optimizer else None,
            'val_loss_min': val_loss
        }
        torch.save(checkpoint, self.path)
        self.val_loss_min = val_loss



## Function to get train, test and val dataset

In [14]:
def get_dataset(batch_size):
    batch_size=batch_size
    print("batch_size", batch_size)
    # Define the transform to only convert the images to tensors (without normalization yet)
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Load the CIFAR-100 training dataset
    train_dataset = torchvision.datasets.CIFAR100(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )


    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=False, num_workers=2)

    # Initialize sums for calculating mean and std
    mean = torch.zeros(3)
    std = torch.zeros(3)

    for images, _ in train_loader:
        # Compute mean and std for each channel
        mean += images.mean(dim=[0, 2, 3])
        std += images.std(dim=[0, 2, 3])

    mean /= len(train_loader)
    std /= len(train_loader)

    print("Mean: ", mean)
    print("Std: ", std)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[mean[0].item(), mean[1].item(), mean[2].item()],
                            std=[std[0].item(), std[1].item(), std[2].item()])
    ])

    transform_test = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(mean=[mean[0].item(), mean[1].item(), mean[2].item()],
                            std=[std[0].item(), std[1].item(), std[2].item()])
    ])

    # Load CIFAR-100 dataset
    start_time = time.time()
    train_dataset = torchvision.datasets.CIFAR100(
        root='./data',
        train=True,
        download=True,
        transform=transform_train
    )
    test_dataset = torchvision.datasets.CIFAR100(
        root='./data',
        train=False,
        download=True,
        transform=transform_test
    )
    print(f"Dataset loading time: {time.time() - start_time:.2f} seconds")

    # Split training and validation sets
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

    # Data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Debugging: Check DataLoader outputs
    for i, (inputs, labels) in enumerate(train_loader):
        print(f"Batch {i}: inputs shape: {inputs.shape}, labels shape: {labels.shape}")
        if i == 10:  # Test first 10 batches
            break
    print(f"Data loading for 10 batches completed.")
    print(f"Training dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")

    return train_loader, val_loader, test_loader

# LocalSGD

In [15]:
def local_sgd(train_dataset, val_loader, test_loader, device, num_workers, local_steps, num_epoch, batch_size, hyperparams, checkpoint_path):

    # Calculate adjusted epochs based on J and K
    total_samples = len(train_dataset)
    batches_per_epoch_centralized = total_samples // batch_size
    batches_per_epoch_per_worker = total_samples // num_workers // batch_size
    batches_per_sync = local_steps * num_workers
    adjusted_epochs = num_epoch * batches_per_epoch_centralized // (batches_per_sync * batches_per_epoch_per_worker)

    print(f"Adjusted epochs for : {adjusted_epochs} K={num_workers}, J={local_steps}")

    # Initialize the global model and optimizer
    global_model = LeNet5().to(device)
    global_optimizer = optim.SGD(global_model.parameters(), lr=hyperparams['lr'], weight_decay=hyperparams['weight_decay'], momentum=hyperparams['momentum'])
    global_scheduler = CosineAnnealingLR(global_optimizer, T_max=adjusted_epochs)
    criterion = nn.CrossEntropyLoss()

    warmup_scheduler = LinearLR(global_optimizer, start_factor=0.1, total_iters=10)
    global_scheduler = SequentialLR(global_optimizer, schedulers=[warmup_scheduler, CosineAnnealingLR(global_optimizer, T_max=adjusted_epochs - 10)], milestones=[10])

    # Split the training dataset into K subsets
    split_sizes = [len(train_dataset) // num_workers] * num_workers
    split_sizes[-1] += len(train_dataset) % num_workers
    dataset_split = torch.utils.data.random_split(train_dataset, split_sizes)
    train_loaders = [torch.utils.data.DataLoader(subset, batch_size, shuffle=True, num_workers=2, pin_memory=True) for subset in dataset_split]

    # Record training progress
    train_losses, val_losses, train_accuracies, val_accuracies = [], [], [], []

    # Train each worker locally
    for epoch in range(adjusted_epochs):
        global_model.train()

        # Initialize a dictionary to store the sum of all workers' parameters
        aggregated_params = {key: torch.zeros_like(value) for key, value in global_model.state_dict().items()}

        # Train each worker locally and accumulate their parameters
        for worker_id, train_loader in enumerate(train_loaders):
            # Clone the global model for each worker
            local_model = LeNet5().to(device)
            local_model.load_state_dict(global_model.state_dict())

            # Define the optimizer for the local model
            local_optimizer = optim.SGD(local_model.parameters(), lr=hyperparams['lr'],weight_decay=hyperparams['weight_decay'], momentum=hyperparams['momentum'])

            # Train locally for `local_steps`
            for step, (inputs, labels) in enumerate(train_loader):
                if step >= local_steps:
                    break

                inputs, labels = inputs.to(device), labels.to(device)
                local_optimizer.zero_grad()
                outputs = local_model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                local_optimizer.step()

            """for key, param in local_model.named_parameters():
                noise = torch.normal(mean=0, std=0.01, size=param.size(), device=device)
                param.data.add_(noise)"""

            # Accumulate worker's parameters to the global parameter dictionary
            for key, value in local_model.state_dict().items():
                aggregated_params[key] += value

        # Average the accumulated parameters across all workers
        for key in aggregated_params:
            aggregated_params[key] /= num_workers

        # Update the global model with the averaged parameters
        global_model.load_state_dict(aggregated_params)

        # Adjust learning rate using the global scheduler
        global_scheduler.step()

        # Validate performance
        global_model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = global_model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100. * val_correct / val_total
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        print(f"Epoch {epoch + 1}/{adjusted_epochs} | Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

        # Save checkpoint
        torch.save(global_model.state_dict(), checkpoint_path)


    # Test performance
    global_model.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = global_model(inputs)
            _, predicted = outputs.max(1)
            test_total += labels.size(0)
            test_correct += predicted.eq(labels).sum().item()

    test_acc = 100. * test_correct / test_total
    print(f"Final Test Accuracy: {test_acc:.2f}%")
    return train_losses, val_losses, val_accuracies, test_acc


if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_epoch=150
    batch_size=64
    num_workers_list = [2, 4, 8]
    local_steps_list = [4, 8, 16, 32, 64]
    hyperparams = {
        'lr': 0.01,
        'weight_decay': 0.005,
        'momentum': 0.9,
    }

    train_loader, val_loader, test_loader = get_dataset(batch_size)   # modify batch size
    train_dataset = train_loader.dataset

    for num_workers in num_workers_list:
        for local_steps in local_steps_list:
            print(f"Running LocalSGD with {num_workers} workers and {local_steps} local steps")
            train_losses, val_losses, val_accuracies, test_acc = local_sgd(
                train_dataset=train_dataset,
                val_loader=val_loader,
                test_loader=test_loader,
                device=device,
                num_workers=num_workers,
                local_steps=local_steps,
                num_epoch=num_epoch,
                batch_size=batch_size,
                hyperparams=hyperparams,
                checkpoint_path=f"/content/drive/MyDrive/LocalSGD_{num_workers}_{local_steps}.pth"
            )


batch_size 64
Files already downloaded and verified
Mean:  tensor([0.5070, 0.4865, 0.4408])
Std:  tensor([0.2664, 0.2555, 0.2750])
Files already downloaded and verified
Files already downloaded and verified
Dataset loading time: 1.77 seconds
Batch 0: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 1: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 2: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 3: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 4: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 5: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 6: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 7: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 8: inputs shape: torch.Size([64, 3, 32, 32]), labels shape: torch.Size([64])
Batch 9: in

KeyboardInterrupt: 