# Testing Jacobian gradient descent

This tests the aggregation methods using per-sample gradient computation.

Caveat: custom aggregation methods consume a lot of GPU memory due to the overhead of explicitly creating and manipulating per-sample gradients.

More memory optimizations will be needed for memory constrained environments. Furthermore, per-param group aggregation methods could be explored.

In [1]:
from __future__ import annotations

from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import pandas as pd
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss, CrossEntropyLoss
from torch.optim import Optimizer, SGD, Adam, AdamW
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.functional.image import total_variation
import torchinfo
from image_classification.utils import trange

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(device)

cuda:0


## ResNet-18

For quick prototyping, also consider using `ConvNet16`, a 1.6M-parameter model that is much smaller than `ResNet18`.

In [3]:
from image_classification.models import ResNet18, ConvNet16
from image_classification.datasets import cifar10_train_test, cifar100_train_test
from image_classification.nn import train_epoch, train_loop, train_val_loop, test_epoch

Using cuda device


In [4]:
# Set to 10 for CIFAR-10, 100 for CIFAR-100
num_classes = 10

# The images are already normalized by theses datasets
if num_classes == 10:
    get_train_test = cifar10_train_test
elif num_classes == 100:
    get_train_test = cifar100_train_test
else:
    raise ValueError(f"Can't find CIFAR dataset with {num_classes} classes")
print(f"Loading CIFAR-{num_classes}")

training_data, test_data = get_train_test(root='data')
N_test = len(test_data)
N_val = len(training_data) // 10
N_aux = N_val
N = len(training_data) - N_val - N_aux
# This works since training data is already shuffled
training_data, val_data, aux_data = training_data.split([N, N_val, N_aux])

batch_size = 100
N, N_val, N_aux, N_test

Loading CIFAR-10


(40000, 5000, 5000, 10000)

In [5]:
train_loader = DataLoader(training_data, batch_size, drop_last=True)
val_loader = DataLoader(val_data, batch_size, drop_last=True)
aux_loader = DataLoader(aux_data, batch_size, drop_last=True)

### Hyperparameters

In [6]:
lr = 1e-3
weight_decay = 5e-4
# For learning rate scheduling
max_lr = 0.1

epochs = 6
steps_per_epoch = N // batch_size

lr_sched_params = dict(max_lr=max_lr, epochs=epochs, steps_per_epoch=steps_per_epoch)

criterion = CrossEntropyLoss()

top_k = {10: 1, 100: 5}[num_classes]

metric = MulticlassAccuracy(num_classes=num_classes, top_k=top_k)

### Optimizer

In [7]:
def make_optimizer(model: nn.Module, opt_name='adamw', lr=lr, weight_decay=weight_decay, **kwargs) -> Optimizer:
    cls = {'sgd': SGD, 'adam': Adam, 'adamw': AdamW}[opt_name]
    return cls(model.parameters(), lr=lr, weight_decay=weight_decay, **kwargs)

In [8]:
from importlib import reload
import federated as fed
reload(fed)
reload(fed.utils)

<module 'federated.utils' from '/home/lvt/dev/python/ml/psc/playground/implementations/projects/poisoning/federated/utils.py'>

In [9]:
from image_classification.nn import Metric, MetricLogger, _detect_device
import federated as fed
from federated import Aggregator, Mean, Krum
from federated.utils import convert_bn_modules_to_gn

def train_epoch_jd(
        model: nn.Module,
        dataloader: DataLoader,
        criterion: _Loss,
        optimizer: Optimizer,
        aggregator: Aggregator,
        keep_pbars=True,
        metric: Metric = None,
    ):
    device = _detect_device(model)
    model.train()

    logger = MetricLogger(
        metric,
        desc='Train loop', total=len(dataloader.dataset), keep_pbars=keep_pbars,
    )

    for step, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        logits = model(X).detach()
        mean_loss = criterion(logits, y).item()
        
        optimizer.zero_grad()
        fed.backpropagate_grads(model, X, y, criterion, aggregator)
        optimizer.step()
        optimizer.zero_grad()

        logger.compute_metrics(X, y, logits, mean_loss)
        del X, y
    
    logger.finish()
    return logger

In [10]:
from torch.optim.lr_scheduler import LRScheduler

def train_val_loop_jd(
        model: nn.Module,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        criterion: _Loss,
        optimizer: Optimizer,
        aggregator: Aggregator,
        epochs: int,
        *,
        lr_scheduler: LRScheduler = None,
        keep_pbars=True,
        metric: Metric = None,
        validate_every: int = 2,
        early_stopping = True,
    ):
    """
    Run the training loop on the model with periodic validation.

    If `val_dataloader` is `None`, no validation is performed.

    If `early_stopping` is True, the training loop exits when validation loss starts decreasing.
    """
    val_loss = float('inf')
    for epoch in trange(epochs, desc='Train epochs', unit='epoch', leave=keep_pbars):
        train_epoch_jd(
            model, train_dataloader, criterion, optimizer, aggregator,
            keep_pbars=keep_pbars, metric=metric,
        )
        if lr_scheduler is not None:
            lr_scheduler.step()
        if val_dataloader is not None and epoch % validate_every == 0:
            logger = test_epoch(
                model, val_dataloader, criterion,
                keep_pbars=keep_pbars, metric=metric,
            )
            next_val_loss = logger.avg_loss.compute()
            if early_stopping and next_val_loss > val_loss:
                print(f"Epoch {epoch}: validation loss stopped improving, exiting train loop.")
                break
            val_loss = next_val_loss

In [11]:
class SimpleCNNBN(nn.Module):
    def __init__(self):
        super(SimpleCNNBN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(12544, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.bn1(self.conv1(x))
        x = F.relu(x)
        x = self.bn2(self.conv2(x))
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

In [12]:
#net = ResNet18(num_classes=num_classes).to(device)
net = convert_bn_modules_to_gn(SimpleCNNBN()).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)
aggregator = Mean()

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop_jd(
    net, train_loader, val_loader,
    criterion, opt, aggregator,
    epochs,
    metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4750.75it/s, MulticlassAccuracy=0.618, avg_loss=1.53]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 19638.35it/s, MulticlassAccuracy=0.6, avg_loss=1.22]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:07<00:00, 5215.71it/s, MulticlassAccuracy=0.765, avg_loss=1.02]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:07<00:00, 5186.87it/s, MulticlassAccuracy=0.755, avg_loss=0.878]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 21419.33it/s, MulticlassAccuracy=0.609, avg_loss=0.959]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:07<00:00, 5033.01it/s, MulticlassAccuracy=0.83, avg_loss=0.796]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:07<00:00, 5084.69it/s, MulticlassAccuracy=0.818, avg_loss=0.731]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 21686.77it/s, MulticlassAccuracy=0.592, avg_loss=0.935]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:07<00:00, 5070.55it/s, MulticlassAccuracy=0.821, avg_loss=0.674]


In [13]:
net = convert_bn_modules_to_gn(SimpleCNNBN()).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop(
    net, train_loader, val_loader,
    criterion, opt,
    epochs,
    metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:03<00:00, 13248.02it/s, MulticlassAccuracy=0.598, avg_loss=1.55]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 23126.77it/s, MulticlassAccuracy=0.531, avg_loss=1.24]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:03<00:00, 13277.38it/s, MulticlassAccuracy=0.714, avg_loss=1.03]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:03<00:00, 13267.40it/s, MulticlassAccuracy=0.769, avg_loss=0.889]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 19934.39it/s, MulticlassAccuracy=0.651, avg_loss=0.996]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:02<00:00, 13463.46it/s, MulticlassAccuracy=0.749, avg_loss=0.807]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:03<00:00, 13243.49it/s, MulticlassAccuracy=0.807, avg_loss=0.741]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 20102.14it/s, MulticlassAccuracy=0.647, avg_loss=0.988]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:03<00:00, 13331.30it/s, MulticlassAccuracy=0.829, avg_loss=0.685]


In [15]:
#net = ResNet18(num_classes=num_classes).to(device)
net = convert_bn_modules_to_gn(SimpleCNNBN()).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)
aggregator = Krum(num_byzantine=batch_size//6, num_selected=batch_size//2)

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop_jd(
    net, train_loader, val_loader,
    criterion, opt, aggregator,
    epochs,
    metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [02:23<00:00, 278.92it/s, MulticlassAccuracy=0.492, avg_loss=3.3]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 18503.08it/s, MulticlassAccuracy=0.401, avg_loss=3.12]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [02:21<00:00, 281.71it/s, MulticlassAccuracy=0.541, avg_loss=3.62]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [02:23<00:00, 277.88it/s, MulticlassAccuracy=0.575, avg_loss=4.04]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 17817.06it/s, MulticlassAccuracy=0.525, avg_loss=4.17]
Epoch 2: validation loss stopped improving, exiting train loop.


In [None]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(12544, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

In [None]:
#net = ResNet18(num_classes=num_classes).to(device)
net = SimpleCNN().to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)
aggregator = Mean()

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop_jd(
    net, train_loader, val_loader,
    criterion, opt, aggregator,
    epochs,
    metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:06<00:00, 5789.20it/s, MulticlassAccuracy=0.705, avg_loss=1.35]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 19681.70it/s, MulticlassAccuracy=0.538, avg_loss=1.08]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:06<00:00, 5785.00it/s, MulticlassAccuracy=0.74, avg_loss=0.955]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:06<00:00, 5735.72it/s, MulticlassAccuracy=0.813, avg_loss=0.802]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 18349.42it/s, MulticlassAccuracy=0.634, avg_loss=0.962]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:07<00:00, 5706.30it/s, MulticlassAccuracy=0.853, avg_loss=0.674]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:07<00:00, 5656.82it/s, MulticlassAccuracy=0.878, avg_loss=0.555]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 21140.11it/s, MulticlassAccuracy=0.635, avg_loss=0.99]
Epoch 4: validation loss stopped improving, exiting train loop.


In [None]:
net = SimpleCNN().to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop(
    net, train_loader, val_loader,
    criterion, opt,
    epochs,
    metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:02<00:00, 13994.02it/s, MulticlassAccuracy=0.624, avg_loss=1.38]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 18131.05it/s, MulticlassAccuracy=0.549, avg_loss=1.13]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:02<00:00, 13704.50it/s, MulticlassAccuracy=0.693, avg_loss=0.991]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:02<00:00, 13873.18it/s, MulticlassAccuracy=0.713, avg_loss=0.843]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 19857.16it/s, MulticlassAccuracy=0.593, avg_loss=0.982]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:02<00:00, 14271.85it/s, MulticlassAccuracy=0.785, avg_loss=0.725]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:02<00:00, 14520.87it/s, MulticlassAccuracy=0.836, avg_loss=0.623]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 20001.70it/s, MulticlassAccuracy=0.614, avg_loss=1.01]
Epoch 4: validation loss stopped improving, exiting train loop.


In [None]:
from image_classification.models import ShuffleNetV2

In [None]:
net = convert_bn_modules_to_gn(ShuffleNetV2()).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)
aggregator = Mean()

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop_jd(
    net, train_loader, val_loader,
    criterion, opt, aggregator, epochs, metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:28<00:00, 1385.18it/s, MulticlassAccuracy=0.472, avg_loss=1.9]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 10220.29it/s, MulticlassAccuracy=0.415, avg_loss=1.63]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:29<00:00, 1357.56it/s, MulticlassAccuracy=0.528, avg_loss=1.47]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:29<00:00, 1378.38it/s, MulticlassAccuracy=0.535, avg_loss=1.27]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 10404.80it/s, MulticlassAccuracy=0.495, avg_loss=1.29]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:30<00:00, 1307.74it/s, MulticlassAccuracy=0.645, avg_loss=1.15]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:30<00:00, 1330.37it/s, MulticlassAccuracy=0.657, avg_loss=1.07]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 10274.52it/s, MulticlassAccuracy=0.536, avg_loss=1.19]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:30<00:00, 1304.93it/s, MulticlassAccuracy=0.71, avg_loss=0.989]


In [None]:
net = convert_bn_modules_to_gn(ShuffleNetV2()).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop(
    net, train_loader, val_loader,
    criterion, opt, epochs, metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4325.88it/s, MulticlassAccuracy=0.46, avg_loss=1.91]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 10306.93it/s, MulticlassAccuracy=0.384, avg_loss=1.65]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4182.67it/s, MulticlassAccuracy=0.532, avg_loss=1.49]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4229.12it/s, MulticlassAccuracy=0.614, avg_loss=1.3]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 9364.76it/s, MulticlassAccuracy=0.533, avg_loss=1.32]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4080.83it/s, MulticlassAccuracy=0.717, avg_loss=1.18]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:09<00:00, 4150.55it/s, MulticlassAccuracy=0.746, avg_loss=1.09]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 9979.64it/s, MulticlassAccuracy=0.566, avg_loss=1.14]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4534.18it/s, MulticlassAccuracy=0.759, avg_loss=1.01]


In [None]:
net = ShuffleNetV2().to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)

# Pretrain the model to make it learn the features
#mini_train_set = Subset(training_data, np.arange(N_aux))
#mini_train_loader = DataLoader(mini_train_set, batch_size=64)

train_val_loop(
    net, train_loader, val_loader,
    criterion, opt, epochs, metric=metric
);

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4778.15it/s, MulticlassAccuracy=0.44, avg_loss=1.79]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 12770.22it/s, MulticlassAccuracy=0.38, avg_loss=1.56]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4860.02it/s, MulticlassAccuracy=0.581, avg_loss=1.42]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4868.93it/s, MulticlassAccuracy=0.672, avg_loss=1.24]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 12015.54it/s, MulticlassAccuracy=0.554, avg_loss=1.23]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4914.24it/s, MulticlassAccuracy=0.722, avg_loss=1.1]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4907.02it/s, MulticlassAccuracy=0.742, avg_loss=0.983]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 12957.27it/s, MulticlassAccuracy=0.579, avg_loss=1.09]


Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 40000/40000 [00:08<00:00, 4843.80it/s, MulticlassAccuracy=0.788, avg_loss=0.89]


In [11]:
net = convert_bn_modules_to_gn(ResNet18(num_classes=num_classes)).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)
aggregator = Mean()

mini_train_set = Subset(training_data, np.arange(5_000))
# Avoid OOM with smaller batch size
mini_train_loader = DataLoader(mini_train_set, batch_size=16)
train_val_loop_jd(
    net, train_loader, val_loader,
    criterion, opt, aggregator,
    epochs,
    metric=metric,
)

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/40000 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.16 GiB. GPU 0 has a total capacity of 7.67 GiB of which 2.26 GiB is free. Including non-PyTorch memory, this process has 5.38 GiB memory in use. Of the allocated memory 4.91 GiB is allocated by PyTorch, and 299.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
del net

In [None]:
# Still convert to GN since batch_size is smaller
net = convert_bn_modules_to_gn(ResNet18(num_classes=num_classes)).to(device)
opt = make_optimizer(net, opt_name='adam', lr=lr)


mini_train_set = Subset(training_data, np.arange(20_000))
mini_train_loader = DataLoader(mini_train_set, batch_size=64)
train_val_loop(
    net, mini_train_loader, val_loader,
    criterion, opt,
    epochs,
    metric=metric,
)

Train epochs:   0%|          | 0/6 [00:00<?, ?epoch/s]

Train loop:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 20000/20000 [00:12<00:00, 1591.36it/s, MulticlassAccuracy=0.317, avg_loss=2.14]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:00<00:00, 5093.80it/s, MulticlassAccuracy=0.158, avg_loss=1.94]


Train loop:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 20000/20000 [00:12<00:00, 1579.58it/s, MulticlassAccuracy=0.323, avg_loss=1.69]


Train loop:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 20000/20000 [00:12<00:00, 1563.28it/s, MulticlassAccuracy=0.41, avg_loss=1.52]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:01<00:00, 4987.08it/s, MulticlassAccuracy=0.307, avg_loss=1.66]


Train loop:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 20000/20000 [00:12<00:00, 1553.68it/s, MulticlassAccuracy=0.53, avg_loss=1.41]


Train loop:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 20000/20000 [00:12<00:00, 1551.89it/s, MulticlassAccuracy=0.6, avg_loss=1.31]


Test epoch:   0%|          | 0/5000 [00:00<?, ?it/s]

Test epoch: 100%|██████████| 5000/5000 [00:01<00:00, 4941.87it/s, MulticlassAccuracy=0.457, avg_loss=1.37]


Train loop:   0%|          | 0/20000 [00:00<?, ?it/s]

Train loop: 100%|██████████| 20000/20000 [00:12<00:00, 1542.97it/s, MulticlassAccuracy=0.63, avg_loss=1.21]
