In [1]:
!pip install torchvision



In [2]:
import torchvision
import torch
import torch.nn as nn

In [3]:
torch.cuda.is_available()

True

# ResNet-18

This experiment will train a ResNet-18 model on the CIFAR-10 dataset using the Prox Skip algorithm (Stochastic setting) and Local SGD.

## Data Preparation

The first thing to do is to pull and prepare CIFAR-10 dataset for the model to be trained on

In [4]:
# Get CIFAR10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
from proxskip.types import Vector
from proxskip.data import DataLoader

class StochasticCIFARBatchLoader(DataLoader):
    def __init__(self, dataset, batch_size):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def get(self) -> tuple[Vector, Vector]:
        random_batch = torch.randint(0, len(self.dataset), (self.batch_size,))
        images = self.dataset.data[random_batch]
        images = torch.stack([self.transforms(image) for image in images])
        
        return images, torch.tensor([self.dataset.targets[int(i)] for i in random_batch], dtype=torch.long)
    
    def total_size(self) -> int:
        return len(self.dataset)
    
    def get_data(self, left: int, size: int) -> tuple[Vector, Vector]:
        images = self.dataset.data[left:left+size]
        images = torch.stack([self.transforms(image) for image in images])
        targets = torch.tensor(self.dataset.targets[left:(left+size)], dtype=torch.long)
        return images, targets

## Loading Model

In [6]:
from proxskip.types import Vector
from proxskip.model import Model

class CIFARModel(Model):
    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(512, 10)
        
        self.resnet = self.resnet.to(self.device)
        
        for param in self.resnet.parameters():
            param.requires_grad = True
            
        self._params = [key for key, _ in self.resnet.named_parameters()]
        self._params.sort()
        
        
    def forward(self, x: Vector) -> Vector:
        if isinstance(x, torch.Tensor):
            x = x.to(self.device).to(torch.float32)
        else:
            x = torch.tensor(x, device=self.device)
        return self.resnet(x).detach().cpu().numpy()
    
    def backward(self, x: Vector, upstream: Vector) -> Vector:
        for param in self.resnet.parameters():
            param.grad = None

        y = self.resnet(torch.tensor(x, device=self.device))
        y.backward(torch.tensor(upstream, device=self.device))
            
        # Concatenate all gradients into a single vector 
        return torch.cat([self.resnet.get_parameter(p).grad.flatten() for p in self._params]).detach().cpu().numpy()
    
    def update(self, params: Vector) -> None:
        # The params vector contains the gradients for each parameter
        # flattened into a single vector. We need to unflatten it
        # and update the parameters accordingly.
        shift = 0
        for p in self._params:
            param = self.resnet.get_parameter(p)
            size = param.grad.numel()
            param = torch.tensor(
                params[shift:shift+size].reshape(param.grad.shape), 
                device=self.device,
                dtype=torch.float32
            )
            shift += size
            
            self.resnet.get_parameter(p).data = param
            
    def params(self) -> Vector:
        return torch.cat([self.resnet.get_parameter(p).flatten() for p in self._params]).detach().cpu().numpy()
            

In [7]:
from proxskip.model import Model
from proxskip.types import Vector
from proxskip.loss import LossFunction

class CrossEntropyLoss(LossFunction):
    def __init__(self) -> None:
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    def loss(self, m: Model, X: Vector, y: Vector) -> Vector:
        return self.loss_fn(
            torch.tensor(m.forward(X), device=self.device), 
            torch.tensor(y, device=self.device)
        )
    
    
    def upstream_gradient(self, m: Model, X: Vector, y: Vector) -> Vector:
        y_pred = m.forward(torch.tensor(X, device=self.device))
        y_pred = torch.tensor(y_pred, device=self.device)
        y_pred.requires_grad = True
        
        loss = self.loss_fn(y_pred, torch.tensor(y, device=self.device))
        loss.backward()
        
        return y_pred.grad.cpu().numpy()

In [8]:
from proxskip.consensus import ConsesusProx
from proxskip.optimizer import StochasticProxSkip

models = [CIFARModel()]
loss_fn = CrossEntropyLoss()
dataloaders = [StochasticCIFARBatchLoader(trainset, batch_size=32)]
num_iterations = 1000
communication_rounds = 1000


ps_optimizer = StochasticProxSkip(
    models=models,
    dataloaders=dataloaders,
    loss=loss_fn,
    prox=ConsesusProx(),
    num_iterations=num_iterations ** 2, # to have `comminication_rounds` of loss values
    learning_rate=1E-4,
    p=1
)



In [9]:
dataloaders[0].total_size()

50000

In [10]:
from tqdm import trange

progress = trange(communication_rounds)
c = 0
while True:
    on_prox = ps_optimizer.step()
    if on_prox is None or c == communication_rounds:
        break
    if on_prox:
        progress.set_description(f"loss={ps_optimizer._step['loss'][-1].item():.4f}")
        progress.update(1)
        c += 1

  y_pred = m.forward(torch.tensor(X, device=self.device))
  loss = self.loss_fn(y_pred, torch.tensor(y, device=self.device))
  y = self.resnet(torch.tensor(x, device=self.device))
  torch.tensor(y, device=self.device)
loss=2.4370:   2%|▏         | 15/1000 [20:01<21:55:10, 80.11s/it]

KeyboardInterrupt: 