In [1]:
!pip install torchvision

[0m

In [2]:
import torchvision
import torch
import torch.nn as nn

In [3]:
torch.cuda.is_available()

True

# ResNet-18

This experiment will train a ResNet-18 model on the CIFAR-10 dataset using the Prox Skip algorithm (Stochastic setting) and Local SGD.

## Data Preparation

The first thing to do is to pull and prepare CIFAR-10 dataset for the model to be trained on

In [4]:
# Get CIFAR10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
trainset.data = trainset.data[:100]
trainset.targets = trainset.targets[:100]

In [6]:
from proxskip.types import Vector
from proxskip.data import DataLoader

class StochasticCIFARBatchLoader(DataLoader):
    def __init__(self, dataset, batch_size):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        self.targets = torch.tensor(self.dataset.targets, dtype=torch.long)

    def get(self) -> tuple[Vector, Vector]:
        random_batch = torch.randint(0, len(self.dataset), (self.batch_size,))
        images = self.dataset.data[random_batch]
        images = torch.stack([self.transforms(image) for image in images])
        
        return images, self.targets[[self.dataset.targets[int(i)] for i in random_batch]]
    
    def total_size(self) -> int:
        return len(self.dataset)
    
    def get_data(self, left: int, size: int) -> tuple[Vector, Vector]:
        images = self.dataset.data[left:left+size]
        images = torch.stack([self.transforms(image) for image in images])
        targets = torch.tensor(self.dataset.targets[left:(left+size)], dtype=torch.long)
        return images, targets

## Loading Model

In [7]:
from proxskip.types import Vector
from proxskip.model import Model

class CIFARModel(Model):
    def __init__(self) -> None:
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        self.resnet = torchvision.models.resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(512, 10)
        
        self.resnet = self.resnet.to(self.device)
        
        for param in self.resnet.parameters():
            param.requires_grad = True
            
        self._params = [key for key, _ in self.resnet.named_parameters()]
        self._params.sort()
        
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.resnet(x.to(self.device))
    
    def backward(self, x: torch.Tensor, upstream: torch.Tensor) -> torch.Tensor:
        for param in self.resnet.parameters():
            param.grad = None

        y = self.resnet(x.to(self.device))
        y.backward(upstream.to(self.device))
            
        # Concatenate all gradients into a single vector 
        return torch.cat([self.resnet.get_parameter(p).grad.flatten() for p in self._params])
    
    def update(self, params: torch.Tensor) -> None:
        # The params vector contains the gradients for each parameter
        # flattened into a single vector. We need to unflatten it
        # and update the parameters accordingly.
        shift = 0
        for p in self._params:
            param = self.resnet.get_parameter(p)
            size = param.grad.numel()
            param = params[shift:shift+size].reshape(param.grad.shape).to(torch.float32)
            shift += size
            
            self.resnet.get_parameter(p).data = param
            
    def params(self) -> Vector:
        return torch.cat([self.resnet.get_parameter(p).flatten() for p in self._params])
    
    def eval(self):
        self.resnet = self.resnet.eval()
        return self
        
    def train(self):
        self.resnet = self.resnet.train()
        return self
            

In [8]:
from proxskip.model import Model
from proxskip.types import Vector
from proxskip.loss import LossFunction

class CrossEntropyLoss(LossFunction):
    def __init__(self) -> None:
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
    def loss(self, m: Model, X: Vector, y: Vector) -> Vector:
        with torch.no_grad():
            return self.loss_fn(
                m.forward(X).to(self.device), 
                y.to(self.device)
            )
    
    
    def upstream_gradient(self, m: Model, X: Vector, y: Vector) -> Vector:
        m.resnet = m.resnet.eval()
        with torch.no_grad():
            y_pred = m.forward(X.to(self.device))
            y_pred = y_pred.to(self.device)
            

        y_pred.requires_grad = True        
        loss = self.loss_fn(y_pred, y.to(self.device))
        loss.backward()
        grad_ = y_pred.grad.detach()

        torch.cuda.empty_cache()
        
        return grad_

In [9]:
from proxskip.consensus import ConsesusTorchProx
from proxskip.optimizer import StochasticTorchProxSkip

models = [CIFARModel()]
loss_fn = CrossEntropyLoss()
dataloaders = [StochasticCIFARBatchLoader(trainset, batch_size=100)]
num_iterations = 1000
communication_rounds = 1000


ps_optimizer = StochasticTorchProxSkip(
    models=models,
    dataloaders=dataloaders,
    loss=loss_fn,
    prox=ConsesusTorchProx(),
    num_iterations=num_iterations ** 2, # to have `comminication_rounds` of loss values
    learning_rate=1E-4,
    p=1,
    device='cuda',
)

  self._step[key] = [torch.tensor(x, device=device) for x in self._step[key]]


In [10]:
ps_optimizer._step['h_t'][0].device

device(type='cuda', index=0)

In [11]:
dataloaders[0].total_size()

100

In [12]:
from tqdm import trange

progress = trange(communication_rounds)
c = 0
while True:
    on_prox = ps_optimizer.step()
    if on_prox is None or c == communication_rounds:
        break
    if on_prox:
        progress.set_description(f"loss={ps_optimizer._step['loss'][-1]:.4f}")
        progress.update(1)
        c += 1

loss=2.9248:  59%|█████▉    | 591/1000 [02:55<02:08,  3.18it/s]

KeyboardInterrupt: 