## Declare the TrAct Implementation
---

In [1]:
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.linalg import inv

# Module-level debug switch read by TrActFunction.backward. When True, the TrAct
# correction is skipped and the plain weight gradient is returned, so the result
# can be compared against an unwrapped layer to verify the gradient implementation.
_DEBUG = False

def unfold3d(x, kernel_size, padding=0, stride=1, dilation=1):
    """
    Extract sliding 3D blocks from a 5D tensor (3D counterpart of ``F.unfold``).

    Args:
        x: A 5D tensor of shape (batch_size, channels, depth, height, width).
        kernel_size: An int or a sequence of 3 ints, the kernel size per dimension.
        padding: An int or a sequence of 3 ints, zero padding added on both sides
            of each spatial dimension.
        stride: An int or a sequence of 3 ints, the stride per dimension.
        dilation: An int or a sequence of 3 ints, the spacing between kernel
            elements per dimension.

    Returns:
        A tensor of shape (batch_size, channels * kD * kH * kW, out_d * out_h * out_w).
        Rows are ordered channel-major, then kernel depth/height/width; columns
        enumerate the output positions in row-major (depth, height, width) order.

    Raises:
        ValueError: If ``x`` is not 5D or an argument is a sequence whose length is not 3.
    """

    def _triple(value, name):
        # Accept a scalar or any length-3 sequence (tuple, list, torch.Size).
        try:
            items = tuple(value)
        except TypeError:
            items = (value,) * 3
        if len(items) != 3:
            raise ValueError(f"{name} must be an int or a sequence of 3 ints, got {value!r}")
        return items

    if x.dim() != 5:
        raise ValueError(f"unfold3d expects a 5D tensor, got {x.dim()} dimensions")

    kernel_size = _triple(kernel_size, "kernel_size")
    padding = _triple(padding, "padding")
    stride = _triple(stride, "stride")
    dilation = _triple(dilation, "dilation")

    batch_size, channels = x.shape[:2]

    # F.pad takes the padding of the last dimension first: (W, W, H, H, D, D)
    if any(padding):
        pad_d, pad_h, pad_w = padding
        x = nn.functional.pad(x, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d))

    # Unfold depth, height and width in turn. Every unfold appends the window as a
    # new trailing dimension; a dilated window spans d * (k - 1) + 1 input elements,
    # of which every d-th one belongs to the kernel.
    for axis, (size, step, gap) in enumerate(zip(kernel_size, stride, dilation), start=2):
        x = x.unfold(axis, gap * (size - 1) + 1, step)
        if gap != 1:
            x = x[..., ::gap]

    # Current shape: (B, C, out_d, out_h, out_w, kD, kH, kW)
    # Move the kernel dimensions next to the channels:
    # (B, C, kD, kH, kW, out_d, out_h, out_w)
    x = x.permute(0, 1, 5, 6, 7, 2, 3, 4)

    # Merge channels with kernel elements and flatten the output positions:
    # (B, C * kD * kH * kW, out_d * out_h * out_w)
    rows = channels * kernel_size[0] * kernel_size[1] * kernel_size[2]
    return x.reshape(batch_size, rows, -1)

class TrActFunction(Function):
    """
    Autograd function for a linear or convolutional layer whose weight gradient
    is replaced by the TrAct-adjusted gradient.

    The forward pass is the ordinary layer. In the backward pass the plain weight
    gradient ``G^T X`` is right-multiplied by ``(X^T X / b + lambda * I)^-1``,
    where ``X`` holds one flattened input sample (or unfolded patch) per row and
    ``b`` is the number of rows. No gradient is returned for the input.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, lambda_, is_conv, conv_params):
        """
        Custom forward pass for TrAct.

        Args:
            input (torch.Tensor): Input tensor.
            weight (torch.Tensor): Weight parameter.
            bias (torch.Tensor or None): Bias parameter.
            lambda_ (float): Regularization parameter.
            is_conv (bool): Whether the layer is a convolutional layer.
            conv_params (tuple or None): ``(stride, padding, dilation, groups, dim)``
                for convolutions, ignored for linear layers.

        Returns:
            torch.Tensor: The output of the layer.

        Raises:
            ValueError: If ``dim`` is not 1, 2 or 3.
        """
        ctx.save_for_backward(input, weight, bias)
        ctx.lambda_ = lambda_
        ctx.is_conv = is_conv
        ctx.conv_params = conv_params

        if not is_conv:
            return torch.nn.functional.linear(input, weight, bias)

        stride, padding, dilation, groups, dim = conv_params
        conv_ops = {
            1: torch.nn.functional.conv1d,
            2: torch.nn.functional.conv2d,
            3: torch.nn.functional.conv3d,
        }
        if dim not in conv_ops:
            raise ValueError(f"Unsupported convolution dimension: {dim}")
        return conv_ops[dim](input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def _tract_weight_grad(samples, grads, lambda_):
        """
        Weight gradient for ``samples`` of shape (b, n) and ``grads`` of shape (b, m).

        Returns the plain gradient ``grads^T @ samples`` when ``_DEBUG`` is set,
        otherwise that gradient multiplied by the regularized inverse Gram matrix.
        """
        plain = grads.T @ samples
        if _DEBUG:
            return plain

        count, features = samples.shape
        gram = samples.T @ samples / count
        # Build the identity in the dtype of the data so no mixed-precision
        # matmul is attempted further down.
        ridge = lambda_ * torch.eye(features, device=samples.device, dtype=samples.dtype)
        return plain @ torch.linalg.inv(gram + ridge)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute the TrAct weight gradient and the ordinary bias gradient.

        Returns one entry per forward argument; only ``weight`` and ``bias``
        receive gradients.

        Raises:
            NotImplementedError: For grouped convolutions (``groups != 1``).
            ValueError: If the convolution dimension is not 1, 2 or 3.
        """
        input, weight, bias = ctx.saved_tensors
        needs_weight_grad = ctx.needs_input_grad[1]
        grad_weight = None

        if ctx.is_conv:
            # Unpack convolutional parameters
            stride, padding, dilation, groups, dim = ctx.conv_params
            spatial_dims = tuple(range(2, 2 + dim))

            if needs_weight_grad:
                if groups != 1:
                    raise NotImplementedError(
                        "TrAct backward does not support grouped convolutions "
                        f"(groups={groups})."
                    )
                kernel_size = tuple(weight.shape[2:])  # (kW) or (kH, kW) or (kD, kH, kW)

                if dim == 1:
                    # Treat the sequence as an (L, 1) image. F.unfold already returns
                    # (B, C * k, L_out); squeezing here would drop the position
                    # axis whenever L_out == 1.
                    input_unfolded = torch.nn.functional.unfold(
                        input.unsqueeze(-1), kernel_size=(kernel_size[0], 1),
                        dilation=(dilation[0], 1), padding=(padding[0], 0),
                        stride=(stride[0], 1))
                elif dim == 2:
                    input_unfolded = torch.nn.functional.unfold(
                        input, kernel_size, dilation, padding, stride)
                elif dim == 3:
                    input_unfolded = unfold3d(input, kernel_size=kernel_size, dilation=dilation,
                                              padding=padding, stride=stride)
                else:
                    raise ValueError(f"Unsupported convolution dimension: {dim}")

                # One row per (sample, output position) in both matrices
                grads = grad_output.permute(0, *spatial_dims, 1).reshape(-1, grad_output.shape[1])
                patches = input_unfolded.transpose(1, 2).reshape(-1, input_unfolded.shape[1])

                grad_weight = TrActFunction._tract_weight_grad(patches, grads, ctx.lambda_)
                grad_weight = grad_weight.reshape(weight.shape)

            # The bias gradient is the ordinary one
            grad_bias = grad_output.sum(dim=(0, *spatial_dims)) if bias is not None else None

        else:
            # Handle (B, *, C) inputs for Linear; reshape also copes with
            # non-contiguous tensors, which view() rejects.
            grad_output_flat = grad_output.reshape(-1, grad_output.shape[-1])  # (B*, M)

            if needs_weight_grad:
                input_flat = input.reshape(-1, input.shape[-1])  # (B*, C)
                grad_weight = TrActFunction._tract_weight_grad(input_flat, grad_output_flat, ctx.lambda_)

            grad_bias = grad_output_flat.sum(0) if bias is not None else None

        # Intended for the first layer of a network: no gradient is propagated to
        # the input, even if the input requires grad.
        grad_input = None
        return grad_input, grad_weight, grad_bias, None, None, None


class TrAct(nn.Module):
    """
    Wrapper that runs an nn.Linear or nn.Conv1d/2d/3d layer through TrActFunction.

    The wrapper shares the wrapped module's ``weight`` and ``bias`` parameters and
    copies its convolution hyper-parameters; the forward result is that of the
    wrapped layer, only the backward pass differs.
    """

    def __init__(self, module, lambda_=0.1):
        """
        Wraps a given nn.Linear or nn.Conv* module and modifies its backward pass using TrAct.

        Args:
            module (nn.Module): The module to wrap (must be nn.Linear or nn.Conv*).
            lambda_ (float): The regularization parameter for TrAct.

        Raises:
            TypeError: If ``module`` is not a supported layer type.
            ValueError: If a convolution uses a ``padding_mode`` other than 'zeros',
                or 'same' padding that cannot be expressed as symmetric padding.
        """
        super().__init__()
        if not isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            raise TypeError("TrAct only supports nn.Linear or convolutional layers.")

        self.lambda_ = lambda_

        # Transfer weight and bias to the TrAct wrapper directly (shared, not copied)
        self.weight = module.weight
        self.bias = getattr(module, "bias", None)

        # Handle convolution parameters for Conv layers
        self.is_conv = isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d))
        if self.is_conv:
            # The functional convolutions used by TrActFunction always pad with
            # zeros, so any other padding mode would silently change the output.
            if module.padding_mode != "zeros":
                raise ValueError(
                    f"TrAct only supports padding_mode='zeros', got {module.padding_mode!r}."
                )
            self.stride = module.stride
            self.padding = self._explicit_padding(module)
            self.dilation = module.dilation
            self.groups = module.groups
            self.dim = len(module.weight.shape) - 2  # Determine dimension (1D, 2D, or 3D)

    @staticmethod
    def _explicit_padding(module):
        """Return the module's padding as a tuple of ints, resolving 'valid'/'same'."""
        padding = module.padding
        if not isinstance(padding, str):
            return padding
        if padding == "valid":
            return (0,) * len(module.kernel_size)

        # 'same': the total padding per dimension is dilation * (kernel - 1). The
        # backward pass needs per-side integers, so the total must split evenly.
        totals = [d * (k - 1) for d, k in zip(module.dilation, module.kernel_size)]
        if any(total % 2 for total in totals):
            raise ValueError(
                "TrAct cannot represent padding='same' for this kernel/dilation "
                "because it requires asymmetric padding."
            )
        return tuple(total // 2 for total in totals)

    def forward(self, x):
        """Apply the wrapped layer to ``x`` through TrActFunction."""
        if self.is_conv:
            conv_params = (self.stride, self.padding, self.dilation, self.groups, self.dim)
            return TrActFunction.apply(x, self.weight, self.bias, self.lambda_, True, conv_params)
        return TrActFunction.apply(x, self.weight, self.bias, self.lambda_, False, None)


## Test Gradient Implementation
---

In [None]:
import copy

# Gradient check: wrap an nn.Linear layer and compare against an untouched copy.
# The weight gradients only agree when _DEBUG is True (TrAct correction disabled);
# the bias gradient is not modified by TrAct and should agree either way.
linear_layer = nn.Linear(10, 5)
linear_layer1 = copy.deepcopy(linear_layer)
tract_layer = TrAct(linear_layer, lambda_=0.01)

x = torch.randn(8, 15, 10)
x1 = x.clone()

# Backward through the custom TrAct function and through the reference layer
tract_layer(x).mean().backward()
linear_layer1(x1).mean().backward()

print("Weight gradients match:", torch.allclose(tract_layer.weight.grad, linear_layer1.weight.grad))
print("Bias gradients match:", torch.allclose(tract_layer.bias.grad, linear_layer1.bias.grad))

print("Gradient W (TrAct):", tract_layer.weight.grad)
print("Gradient W1 (reference):", linear_layer1.weight.grad)

print("Gradient B (TrAct):", tract_layer.bias.grad)
print("Gradient B1 (reference):", linear_layer1.bias.grad)

In [None]:
import copy

# Gradient check: wrap an nn.Conv1d layer and compare against an untouched copy.
# The weight gradients only agree when _DEBUG is True (TrAct correction disabled);
# the bias gradient is not modified by TrAct and should agree either way.
linear_layer = nn.Conv1d(10, 5, 3)
linear_layer1 = copy.deepcopy(linear_layer)
tract_layer = TrAct(linear_layer, lambda_=0.1)

x = torch.randn(2, 10, 7)
x1 = x.clone()

# Backward through the custom TrAct function and through the reference layer
tract_layer(x).mean().backward()
linear_layer1(x1).mean().backward()

print("Weight gradients match:", torch.allclose(tract_layer.weight.grad, linear_layer1.weight.grad))
print("Bias gradients match:", torch.allclose(tract_layer.bias.grad, linear_layer1.bias.grad))

print("Gradient W (TrAct):", tract_layer.weight.grad)
print("Gradient W1 (reference):", linear_layer1.weight.grad)

print("Gradient B (TrAct):", tract_layer.bias.grad)
print("Gradient B1 (reference):", linear_layer1.bias.grad)

In [None]:
import copy

# Gradient check: wrap an nn.Conv2d layer and compare against an untouched copy.
# The weight gradients only agree when _DEBUG is True (TrAct correction disabled);
# the bias gradient is not modified by TrAct and should agree either way.
linear_layer = nn.Conv2d(10, 5, 3)
linear_layer1 = copy.deepcopy(linear_layer)
tract_layer = TrAct(linear_layer, lambda_=0.1)

x = torch.randn(2, 10, 7, 7)
x1 = x.clone()

# Backward through the custom TrAct function and through the reference layer
tract_layer(x).view(2,-1).mean().backward()
linear_layer1(x1).view(2,-1).mean().backward()

print("Weight gradients match:", torch.allclose(tract_layer.weight.grad, linear_layer1.weight.grad))
print("Bias gradients match:", torch.allclose(tract_layer.bias.grad, linear_layer1.bias.grad))

print("Gradient W (TrAct):", tract_layer.weight.grad)
print("Gradient W1 (reference):", linear_layer1.weight.grad)

print("Gradient B (TrAct):", tract_layer.bias.grad)
print("Gradient B1 (reference):", linear_layer1.bias.grad)

In [None]:
import copy

# Gradient check: wrap an nn.Conv3d layer and compare against an untouched copy.
# The weight gradients only agree when _DEBUG is True (TrAct correction disabled);
# the bias gradient is not modified by TrAct and should agree either way.
linear_layer = nn.Conv3d(10, 5, 3)
linear_layer1 = copy.deepcopy(linear_layer)
tract_layer = TrAct(linear_layer, lambda_=0.1)

x = torch.randn(2, 10, 7, 7, 7)
x1 = x.clone()

# Backward through the custom TrAct function and through the reference layer
tract_layer(x).mean().backward()
linear_layer1(x1).mean().backward()

print("Weight gradients match:", torch.allclose(tract_layer.weight.grad, linear_layer1.weight.grad))
print("Bias gradients match:", torch.allclose(tract_layer.bias.grad, linear_layer1.bias.grad))

print("Gradient W (TrAct):", tract_layer.weight.grad)
print("Gradient W1 (reference):", linear_layer1.weight.grad)

print("Gradient B (TrAct):", tract_layer.bias.grad)
print("Gradient B1 (reference):", linear_layer1.bias.grad)

## Define Test Classification Models (ResNet and ViT)
---

In [2]:
import math
from timm.models.vision_transformer import Block as VitBlock # import the VitBlock to save space

class MLPPredictor(nn.Module):
    """Prediction head: Linear -> GELU -> Linear, stored as ``self.mlp``."""

    def __init__(self, in_channels, out_dim, hidden_dim=512):
        super().__init__()
        # Input projection, non-linearity, output projection (in that order)
        stages = [
            nn.Linear(in_channels, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Map features of size ``in_channels`` to ``out_dim`` outputs."""
        prediction = self.mlp(x)
        return prediction

# end MLPPredictor

class ViT(nn.Module):
    """Vision Transformer classifier built from a conv patch embedding and timm blocks."""

    def __init__(self, img_size=32, patch_size=4, num_layers=8, d_model=64, num_heads=2, 
                     pred_dim=512, num_classes=10,
                     train_pos_emb=True, use_cls=False, do_init=False):
        """
        A configurable ViT model.
        
        Args:
            img_size (int): Input image size to establish sequence length.
            patch_size (int): Size of non-overlapping input patches.
            num_layers (int): Number of stacked transformer blocks (L).
            d_model (int): Embedding dimension of the model (d).
            num_heads (int): Number of heads per attention layer (h).
            pred_dim (int): Intermediate projection size for the MLP predictor.
            num_classes (int): Number of output classes.
            train_pos_emb (bool): Use random trained PE (true) or fixed sinusoidal PE (false).
            use_cls (bool): Use a class token for prediction (true) or avgPool->MLP (false).
            do_init (bool): Use truncnorm init (true) or default torch init (false).
        """
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.d_model = d_model
        self.use_cls = use_cls

        # Patch embedding
        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)

        # Positional encoding. The sequence is one token longer when a class token
        # is prepended, so both variants must be sized for the full sequence.
        num_tokens = self.num_patches + 1 if use_cls else self.num_patches
        if train_pos_emb:
            self.pos_embed = nn.Parameter(torch.randn(1, num_tokens, d_model))
        else:
            pef = self._create_pos_emb(num_tokens, d_model)
            # Stored as a frozen parameter so it follows the model across devices
            self.pos_embed = nn.Parameter(pef, requires_grad=False)

        # output norm
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
        
        if self.use_cls:
            self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
            self.head = nn.Linear(d_model, num_classes)
        else:
            self.cls_token = None
            # MLP Predictor
            self.head = MLPPredictor(d_model, num_classes, hidden_dim=pred_dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            VitBlock(d_model, num_heads, mlp_ratio=4., qkv_bias=False, drop_path=0., act_layer=nn.GELU)
            for _ in range(num_layers)
        ])

        if do_init:
            self.init_weights()

    def init_weights(self):
        """Re-initialize all linear layers (and the class token) with truncated normals."""
        def _basic_init(module):
            # apply to nn.Linear
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # initialize the linear matrices with trunc norm and zero bias
        self.apply(_basic_init)

        if self.cls_token is not None:
            nn.init.trunc_normal_(self.cls_token, mean=0.0, std=0.02)

    def _create_pos_emb(self, max_len, d_model):
        """
        Build fixed sinusoidal position embeddings of shape (1, max_len, d_model).

        Even feature indices hold sines and odd indices hold cosines of
        ``position / 10000^(2i / d_model)``.
        """
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(0)
    
    def forward(self, x):
        """Classify a batch of images of shape (B, 3, img_size, img_size)."""
        # Patch embedding
        x = self.patch_embed(x)  # (B, d_model, h_patches, w_patches)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, d_model)

        if self.use_cls:
            # add the cls token to the start of every sequence in the batch
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_token, x], dim=1)

        # Add positional encoding
        x = x + self.pos_embed

        # Transformer layers
        for block in self.blocks:
            x = block(x)

        # Reduce the sequence to a single vector: class token or average pool
        if self.use_cls:
            x = x[:,0]
        else:
            x = x.mean(dim=1)  # Average pool over the patch dimension

        # normalize
        x = self.norm(x)
        
        # predict
        return self.head(x)

# end ViT


class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with a GELU, plus a shortcut."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.gelu = nn.GELU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the resolution or channel count changes, in
        # which case a strided 1x1 convolution + batch-norm projects the input.
        projection = []
        if not (stride == 1 and in_channels == out_channels):
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``gelu(bn2(conv2(gelu(bn1(conv1(x))))) + shortcut(x))``."""
        hidden = self.gelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out.add_(self.shortcut(x))  # in-place residual addition
        return self.gelu(out)

# end BasicBlock

class ResNet(nn.Module):
    """ResNet-style classifier: conv stem, residual stages, pooling, LayerNorm, MLP head."""

    def __init__(self, block, num_blocks, channels, pred_dim=512, num_classes=10):
        """
        A configurable ResNet model.
        
        Args:
            block (nn.Module): Residual block class (e.g., `BasicBlock`).
            num_blocks (tuple): Number of blocks in each stage.
            channels (list): Output channels for each stage.
            pred_dim (int): Intermediate projection size for the MLP predictor.
            num_classes (int): Number of output classes.
        """
        super().__init__()
        assert len(num_blocks) == len(channels), "num_blocks and channels must have the same length"

        # Running channel count; _make_layer advances it as stages are built
        self.in_channels = channels[0]
        
        # Stem (named patch_embed so it can be addressed like the ViT embedding)
        self.patch_embed = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.gelu = nn.GELU()

        # Residual stages: a stage downsamples (stride 2) only when it changes
        # the channel count relative to the previous stage.
        stages = []
        for depth, width in zip(num_blocks, channels):
            first_stride = 1 if self.in_channels == width else 2
            stages.append(self._make_layer(block, width, depth, stride=first_stride))
        self.blocks = nn.ModuleList(stages)

        # Global pooling to a single spatial position
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # output norm before the prediction
        final_width = channels[-1]
        self.out_norm = nn.LayerNorm(final_width, eps=1e-6)

        # MLP predictor
        self.mlp_head = MLPPredictor(final_width, num_classes, hidden_dim = pred_dim)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        stage = nn.Sequential()
        block_strides = [stride] + [1] * (num_blocks - 1)
        for index, block_stride in enumerate(block_strides):
            stage.add_module(str(index), block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return stage

    def forward(self, x):
        """Classify a batch of 3-channel images."""
        # Stem
        x = self.gelu(self.bn1(self.patch_embed(x)))

        # Pass through residual stages
        for stage in self.blocks:
            x = stage(x)

        # Pool to (B, C)
        x = torch.flatten(self.avgpool(x), 1)

        # normalize and predict
        return self.mlp_head(self.out_norm(x))
        
# end ResNet


In [None]:
# Build the two baseline networks that will be trained in parallel.

vit_config = dict(
    img_size=32,
    patch_size=4,         # 8x8 tokens for 32x32 image
    num_layers=12,
    d_model=128,
    num_heads=4,          # d_k = 128/4 = 32
    pred_dim=512,
    num_classes=100,      # CIFAR-100
    train_pos_emb=False,  # initializing to fixed sinusoidal embeddings works better
    use_cls=False,        # learning a MLP predictor after average pooling works better than a CLS token for this scale
    do_init=False,        # default torch init works better for this scale / model
)
vit_model = ViT(**vit_config).cuda()

resnet_config = dict(
    num_blocks=(2, 2, 2),     # 2 blocks per stage
    channels=[64, 128, 256],  # Channels for each stage
    pred_dim=512,
    num_classes=100,          # CIFAR-100
)
resnet_model = ResNet(BasicBlock, **resnet_config).cuda()

## Test Training Code
---

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from timm.data.transforms_factory import create_transform

# Augmented transform for the training split, built with timm's create_transform
augmentation_args = {
    "input_size": 32,
    "is_training": True,
    "use_prefetcher": False,
    "no_aug": False,
    "scale": [0.08, 1.0],
    "ratio": [0.75, 1.33],
    "hflip": 0.5,
    "vflip": 0.0,
    "color_jitter": 0.4,
    "auto_augment": 'rand-m9-mstd0.5-inc1',
    "interpolation": 'random',
    "mean": [0.5, 0.5, 0.5],
    "std": [0.5, 0.5, 0.5],
    "crop_pct": None,
    "tf_preprocessing": False,
    "re_prob": 0.25,
    "re_mode": 'pixel',
    "re_count": 1,
    "re_num_splits": 0,
    "separate": False,
}
train_transform = create_transform(**augmentation_args)

# Evaluation transform: tensor conversion and the same normalization, no augmentation
val_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
val_transform = transforms.Compose(val_steps)


# use a batch size of 128 as per the paper
batch_size = 128

# CIFAR-100 splits (downloaded on first use) and their loaders; only the
# training loader shuffles.
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=val_transform)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [5]:
import copy
import numpy as np
from tqdm import tqdm
from functools import partial

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.functional import cross_entropy


# Results of the most recent train_and_validate call, one inner list per model
loss_per_step = []
val_per_epoch = []

def train_and_validate(models, train_loader, val_loader, 
                       optimizers, schedules, epochs=10, 
                       lambdas=None, label_smoothing=0.1):
    """
    Train several models side by side on the same batches and validate each epoch.

    Args:
        models (list): Models to train; each must expose a ``patch_embed`` layer
            when a lambda is given for it.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        optimizers (list): One factory per model, called as ``factory(model.parameters())``.
        schedules (list): One factory per model, called as ``factory(optimizer)``. The
            resulting scheduler is stepped with ``scheduler.step(epoch + 1)``.
        epochs (int): Number of training epochs.
        lambdas (list or None): Per-model TrAct regularization; ``None`` entries (or
            ``lambdas=None``) leave the model's ``patch_embed`` unwrapped.
        label_smoothing (float): Label smoothing passed to the cross-entropy loss.

    Results are written to the module-level ``loss_per_step`` (loss of every 50th
    step within each epoch) and ``val_per_epoch`` (validation accuracy per epoch).
    """
    
    # make these global so that training can be interrupted without losing the results
    global loss_per_step
    global val_per_epoch

    if lambdas is None:
        lambdas = [None] * len(models)
    
    # apply TrAct based on the lambdas
    for model, lam in zip(models, lambdas):
        model_name = model.__class__.__name__
        if lam is None:
            print(f"Skipping TrAct on {model_name}")
        elif isinstance(model.patch_embed, TrAct):
            # Already wrapped (e.g. the cell was re-run): wrapping again would
            # fail, so only update the regularization strength.
            print(f"TrAct already applied to {model_name}, updating lambda")
            model.patch_embed.lambda_ = lam
        else:
            print(f"Applying TrAct to {model_name}")
            model.patch_embed = TrAct(model.patch_embed, lambda_=lam)
    # end init TrAct

    # setup the optimizers and LR schedulers
    optimizers = [opt(model.parameters()) for model, opt in zip(models,optimizers)]
    schedulers = [sched(opt) for opt, sched in zip(optimizers, schedules)]

    # Batches are moved to wherever the models live (CUDA if no parameter exists)
    first_param = next((p for model in models for p in model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Arrays to store results
    loss_per_step = [[] for _ in models]
    val_per_epoch = [[] for _ in models]

    # Training loop
    for epoch in range(epochs):
        # Set models to training mode
        for model in models:
            model.train()

        # Rolling (exponentially averaged) loss, restarted from zero every epoch
        rolling_loss = [0.0 for _ in models]
        step_count = 0

        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}") as pbar:
            for inputs, labels in pbar:
                inputs, labels = inputs.to(device), labels.to(device)

                # Zero gradients
                for optimizer in optimizers:
                    optimizer.zero_grad()

                # Forward pass, compute loss, backward pass, and optimizer step
                for i, (model, optimizer) in enumerate(zip(models, optimizers)):
                    outputs = model(inputs)
                    loss = cross_entropy(outputs, labels, label_smoothing=label_smoothing)
                    loss.backward()
                    optimizer.step()

                    # Read the loss back once per step
                    loss_value = loss.item()

                    # Update rolling loss
                    rolling_loss[i] = 0.99 * rolling_loss[i] + 0.01 * loss_value

                    # Save loss every 50 steps
                    if step_count % 50 == 0:
                        loss_per_step[i].append(loss_value)

                # Update step count and tqdm
                step_count += 1
                pbar.set_postfix({f"Model {i} Loss": rolling_loss[i] for i in range(len(models))})
        # end for batch in epoch

        # Update learning rate
        for scheduler in schedulers:
            scheduler.step(epoch+1)

        # Validation loop
        for i, model in enumerate(models):
            model.eval()
            correct = 0
            total = 0

            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs)
                    _, preds = outputs.max(1)
                    correct += (preds == labels).sum().item()
                    total += labels.size(0)

            # Calculate and save accuracy
            accuracy = correct / total
            val_per_epoch[i].append(accuracy)
            print(f"Model {i} Validation Accuracy: {accuracy:.4f}")
        # end validation for model in models
    # end for each epoch

In [None]:
from timm.scheduler.cosine_lr import CosineLRScheduler
import copy

def make_schedule(num_epochs=100, warmup=5, lr_base=1e-3):
    """
    Return a CosineLRScheduler factory that only needs the optimizer.

    The learning-rate floor is 1% of ``lr_base`` (1e-5 for a base of 1e-3) and the
    warm-up starts at 0.1% of ``lr_base`` (1e-6 for a base of 1e-3).
    """
    scheduler_kwargs = dict(
        t_initial=num_epochs,            # number of total epochs
        lr_min=lr_base*1e-2,
        warmup_t=warmup,                 # number of warmup epochs
        warmup_lr_init=lr_base*1e-3,
        t_in_epochs=True,
    )
    return partial(CosineLRScheduler, **scheduler_kwargs)
# end make_schedule

# Every list below follows the same model order:
#   0: ViT
#   1: ResNet
#   2: ViT + TrAct
#   3: ResNet + TrAct

NUM_EPOCHS = 100

models = [
    vit_model,
    resnet_model,
    copy.deepcopy(vit_model),    # make a deep copy to avoid collisions
    copy.deepcopy(resnet_model), # make a deep copy to avoid collisions
]

# Optimizer factories; each one is called with the parameters of its model
vit_optimizer = partial(optim.Adam, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)  # ViT trains better with Adam
resnet_optimizer = partial(optim.SGD,  lr=0.08, momentum=0.9, weight_decay=0.0005)          # ResNet trains better with SGD
optimizers = [vit_optimizer, resnet_optimizer, vit_optimizer, resnet_optimizer]

# One schedule factory per model, using the base learning rate of its optimizer
base_lrs = (1e-3, 0.08, 1e-3, 0.08)  # ViT, ResNet, ViT, ResNet
schedules = [make_schedule(lr_base = lr, num_epochs=NUM_EPOCHS) for lr in base_lrs]

lambdas = [
    None,  # ViT
    None,  # ResNet
    0.1,   # ViT + TrAct. Paper suggests 0.1 for ViT
    0.05,  # ResNet + TrAct. Paper suggests 0.05 for ResNet
]


# do the training
train_and_validate(
    models, train_loader, test_loader,
    optimizers, schedules,
    epochs=NUM_EPOCHS, lambdas=lambdas, label_smoothing=0.1,
)