# Summary of the MLP Setup, Forward and Backward Passes, and Kernels

## Multilayer Perceptron (MLP) Setup
An MLP is a feedforward neural network composed of multiple layers of linear transformations followed by non-linear activation functions. In this setup:

- **Layers**: The MLP consists of $L$ layers.
- **Weights**: Each layer $l$ has a weight matrix $W^{(l)}$.
- **No Biases**: For simplicity, we assume there are no bias terms.
- **Activation Function**: A non-linear activation function $f$ is applied element-wise.

## Architecture
- **Input**: $x \in \mathbb{R}^n$
- **Layer 1**:
    - Pre-activation: $z^{(1)} = W^{(1)} x$
    - Activation: $a^{(1)} = f(z^{(1)})$
- **Layer $l$** (for $l = 2, \dots, L$):
    - Pre-activation: $z^{(l)} = W^{(l)} a^{(l-1)}$
    - Activation: $a^{(l)} = f(z^{(l)})$
- **Output**: The activations of the last layer $a^{(L)}$ represent the output of the MLP.

## Forward Pass
The forward pass computes the activations of each layer given an input $x$.
- **Layer 1**:
  - $z^{(1)} = W^{(1)} x$
  - $a^{(1)} = f(z^{(1)})$
- **Layer $l$** (for $l = 2, \dots, L$):
  - $z^{(l)} = W^{(l)} a^{(l-1)}$
  - $a^{(l)} = f(z^{(l)})$

## Backward Pass
The backward pass computes the gradients of a loss function $L$ with respect to the weights using backpropagation.
- **Initialize Gradient at Output Layer**:
  - $\delta^{(L)} = \nabla_{a^{(L)}} L \circ f'(z^{(L)})$
  - where $\circ$ denotes element-wise multiplication and $f'$ is the derivative of the activation function.

- **Backpropagate Through Layers $l = L, L-1, \dots, 1$**:
- $\delta^{(l)} = (W^{(l+1)\top} \delta^{(l+1)}) \circ f'(z^{(l)})$

- **Gradient with Respect to Weights**:
  - $\nabla_{W^{(l)}} L = \delta^{(l)} (a^{(l-1)})^\top$
  - For $l = 1$, $a^{(0)} = x$.

## Forward Kernel
The forward kernel measures the similarity between the activations of two inputs $x$ and $Y$ at each layer.

### Definition
At layer $l$, the forward kernel $K^{(l)}(x,Y)$ is defined as:
$$
K^{(l)}(x,Y) = \mathbb{E}[a^{(l)}(x)^\top a^{(l)}(Y)]
$$
The expectation $\mathbb{E}$ can be over the randomness in the weights or inputs if they are stochastic.

### Properties
- **Recursive Computation**: The forward kernel can be computed recursively using the kernels from the previous layer and the properties of the activation function.
- **Influence on Learning**: The forward kernel captures how similar the representations of different inputs are at each layer, influencing the network's ability to generalize.

## Backward Kernel
The backward kernel measures the similarity between the gradients of the loss with respect to the inputs or weights for two different inputs.

### Definition
At layer $l$, the backward kernel $B^{(l)}(x,Y)$ is defined as:
$$
B^{(l)}(x,Y) = \mathbb{E}[\delta^{(l)}(x)^\top \delta^{(l)}(Y)]
$$
where $\delta^{(l)}(x)$ is the backpropagated error at layer $l$ for input $x$.

### Properties
- **Captures Gradient Alignment**: The backward kernel indicates how aligned the gradients are for different inputs, affecting convergence during training.
- **Dependency on Activation Function**: The computation of the backward kernel depends on the activation function and its derivative.

## Kernel Equivalence Between Models
To ensure that two models $M_1$ and $M_2$ have equal forward and backward kernels:

- **Activation Functions**: Both models must use the same activation function $f$.
- **Weight Initialization**: Weights must be initialized such that the statistical properties (variances and covariances) of the pre-activations and activations match between the models.
- **Input Mapping**: If the input dimensions differ, inputs must be mapped appropriately (e.g., duplication and scaling) to preserve input variances.

## Constructing $M_2$ from $M_1$
Given $M_1$ with weight matrix $W_1 \in \mathbb{R}^{d \times d}$, construct $M_2$ with $W_2 \in \mathbb{R}^{2d \times 2d}$ as:
$$
W_2 = \frac{1}{2} \begin{bmatrix} W_1 & W_1 \\ W_1 & W_1 \end{bmatrix}
$$

### Input Mapping:
$$
X_2 = \frac{1}{2} \begin{bmatrix} x \\ x \end{bmatrix}
$$

### Activation Function Requirements:
- $f(0) = 0$
- $f$ should be homogeneous or satisfy specific symmetry properties (e.g., ReLU).

### Result
With the above construction:

- **Forward Pass**: The activations in $M_2$ corresponding to the first $d$ neurons replicate those in $M_1$.
- **Backward Pass**: The gradients with respect to the weights in $M_2$ align with those in $M_1$ for the corresponding weights.
- **Kernel Equivalence**: Both the forward and backward kernels of $M_1$ and $M_2$ are equal, ensuring similar training dynamics.

## Practical Implementation
- **PyTorch Models**: Implement $M_1$ and $M_2$ as PyTorch models without bias terms.
- **Initialization**: Initialize $M_2$ weights based on $M_1$ to satisfy the kernel equivalence conditions.
- **Verification**: Compute the forward and backward passes for a given input and verify that the activations and gradients match as expected.


# Empirical 

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


activation_functions = {
    'relu': nn.ReLU,
    'leaky_relu': nn.LeakyReLU,
    'prelu': nn.PReLU,
    'rrelu': nn.RReLU,
    'relu6': nn.ReLU6,
    'elu': nn.ELU,
    'selu': nn.SELU,
    'celu': nn.CELU,
    'gelu': nn.GELU,
    'sigmoid': nn.Sigmoid,
    'tanh': nn.Tanh,
    'softmax': nn.Softmax,
    'log_softmax': nn.LogSoftmax,
}

class ConfigurableMLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, activation='relu', norm_layer='none'):
        super(ConfigurableMLP, self).__init__()
        self.layers = nn.ModuleList()
        
        # Input layer
        input_linear = nn.Linear(input_dim, hidden_dims[0])
        input_linear.name = "linear_in"  # Optional naming
        self.layers.append(input_linear)
        
        # Hidden layers
        for i in range(len(hidden_dims) - 1):
            linear = nn.Linear(hidden_dims[i], hidden_dims[i+1])
            linear.name = f"linear_{i}"  # Optional naming
            self.layers.append(linear)
            
            # Add normalization layer
            if norm_layer == 'layernorm':
                norm = nn.LayerNorm(hidden_dims[i+1])
                norm.name = f"norm_{i}"  # Optional naming
                self.layers.append(norm)
            elif norm_layer == 'batchnorm':
                norm = nn.BatchNorm1d(hidden_dims[i+1])
                norm.name = f"norm_{i}"  # Optional naming
                self.layers.append(norm)
            elif norm_layer == 'rmsnorm':
                norm = nn.GroupNorm(1, hidden_dims[i+1])  # RMSNorm is GroupNorm with num_groups=1
                norm.name = f"norm_{i}"  # Optional naming
                self.layers.append(norm)
            
            # Add activation layer
            act = activation_functions[activation]()
            act.name = f"act_{i}"  # Optional naming
            self.layers.append(act)
        
        # Output layer
        output_linear = nn.Linear(hidden_dims[-1], output_dim)
        output_linear.name = "linear_out"  # Optional naming
        self.layers.append(output_linear)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def double_mlp_size(original_mlp, config, noise=0.01):

    _, input_dim, output_dim = get_dataset(config['dataset'],)
    # Create new MLP with doubled hidden layer sizes
    hidden_dims = [dim * 2 for dim in config['hidden_dims']]
    new_mlp = ConfigurableMLP(input_dim, hidden_dims, output_dim, config['activation'], config['norm_layer'])

    # Initialize weights by cloning
    with torch.no_grad():
        # Input layer (mx2n)
        new_mlp.layers[0].weight.data = torch.cat([
            original_mlp.layers[0].weight.data,
            original_mlp.layers[0].weight.data
        ], dim=0) 
        new_mlp.layers[0].bias.data = torch.cat([
            original_mlp.layers[0].bias.data,
            original_mlp.layers[0].bias.data,
        ],)

        # Hidden layers (2mx2n)
        for i in range(1, len(hidden_dims) + 1):
            orig_layer = [l for l in original_mlp.layers if isinstance(l, nn.Linear)][i]
            new_layer = [l for l in new_mlp.layers if isinstance(l, nn.Linear)][i]

            new_layer.weight.data = torch.cat([
                orig_layer.weight.data,
                orig_layer.weight.data
            ],dim=0) 
            new_layer.weight.data = torch.cat([
                new_layer.weight.data,
                new_layer.weight.data
            ],dim=1) * 0.5
            new_layer.bias.data = torch.cat([
                orig_layer.bias.data,
                orig_layer.bias.data
            ])

        # Output layer (2mxn)
        new_mlp.layers[-1].weight.data = torch.cat([
            original_mlp.layers[-1].weight.data,
            original_mlp.layers[-1].weight.data
        ], dim=1) * 0.5
        new_mlp.layers[-1].bias.data = original_mlp.layers[-1].bias.data.clone()
    for l in new_mlp.layers:
        if isinstance(l,nn.Linear):
            sd = torch.std(l.weight.data)
            l.weight.data += torch.randn_like(l.weight.data) * sd * noise
            sd = torch.std(l.bias.data)
            l.bias.data += torch.randn_like(l.bias.data) * sd * noise

    return new_mlp

def get_dataset(name, train=True):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    
    if name == 'CIFAR10':
        dataset = datasets.CIFAR10(root='./data', train=train, download=True, transform=transform)
        input_dim = 3 * 32 * 32
        output_dim = 10
    elif name == 'CIFAR100':
        dataset = datasets.CIFAR100(root='./data', train=train, download=True, transform=transform)
        input_dim = 3 * 32 * 32
        output_dim = 100
    elif name == 'MNIST':
        dataset = datasets.MNIST(root='./data', train=train, download=True, transform=transform)
        input_dim = 28 * 28
        output_dim = 10
    elif name == 'FashionMNIST':
        dataset = datasets.FashionMNIST(root='./data', train=train, download=True, transform=transform)
        input_dim = 28 * 28
        output_dim = 10
    else:
        raise ValueError(f"Unknown dataset: {name}")
    
    return dataset, input_dim, output_dim

def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def test(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, accuracy

    
def main(config):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    
    train_dataset, input_dim, output_dim = get_dataset(config['dataset'], train=True)
    test_dataset, _, _ = get_dataset(config['dataset'], train=False)
    
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)
    
    model = ConfigurableMLP(input_dim, config['hidden_dims'], output_dim, config['activation'], config['norm_layer']).to(device)
    
    if config['optimizer'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    elif config['optimizer'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=config['lr'])
    else:
        raise ValueError(f"Unknown optimizer: {config['optimizer']}")
    
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(config['epochs']):
        train(model, train_loader, optimizer, criterion, device)
        test_loss, accuracy = test(model, test_loader, criterion, device)
        print(f'Epoch: {epoch+1}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
        
    return model, train_loader, test_loader, criterion, device


def get_default_config():
    return {
        'dataset': 'CIFAR100',
        'hidden_dims': [512,]*5,
        'activation': 'tanh',
        'norm_layer': 'none',
        'lr': 0.001,
        'epochs': 1,
        'optimizer': 'adam',
        'batch_size': 512
    }


# config = get_default_config()
# config['epochs'] = 10
model, train_loader, test_loader, criterion, device = main(config)
model2 = double_mlp_size(model, config, noise=0.1)
test_loss, accuracy = test(model2.to(device), test_loader, criterion, device)
print(f'Hyper cloned model Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

# if __name__ == '__main__':
    # Use default configuration
    # config = get_default_config()

    # You can modify the configuration here
    # config['dataset'] = 'CIFAR10'
    # config['hidden_dims'] = [256, 128, 64]
    # config['activation'] = 'tanh'
    # config['norm_layer'] = 'batchnorm'
    # config['lr'] = 0.0005
    # config['epochs'] = 20
    # config['optimizer'] = 'sgd'
    # config['batch_size'] = 32

    # main(config)

Files already downloaded and verified
Files already downloaded and verified
Epoch: 1, Test loss: 0.0075, Accuracy: 13.19%
Epoch: 2, Test loss: 0.0072, Accuracy: 16.94%
Epoch: 3, Test loss: 0.0070, Accuracy: 18.62%
Epoch: 4, Test loss: 0.0068, Accuracy: 19.43%
Epoch: 5, Test loss: 0.0067, Accuracy: 21.16%
Epoch: 6, Test loss: 0.0067, Accuracy: 20.46%
Epoch: 7, Test loss: 0.0067, Accuracy: 22.00%
Epoch: 8, Test loss: 0.0066, Accuracy: 22.16%
Epoch: 9, Test loss: 0.0066, Accuracy: 22.27%
Epoch: 10, Test loss: 0.0066, Accuracy: 23.13%
Files already downloaded and verified
Hyper cloned model Test loss: 0.0066, Accuracy: 23.05%


In [None]:
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.view(data.size(0), -1)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    
def test_models(model1, model2, criterion, test_loader, device):
    model.eval()
    error = 0
    correct1, correct2 = 0, 0
    loss1, loss2 = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)
            o1 = model(data)
            o2 = model2(data)
            error += torch.mean((o1-o2)**2)
            loss1 += criterion(o1, target).item()
            loss2 += criterion(o2, target).item()
            pred = o1.argmax(dim=1, keepdim=True)
            correct1 += pred.eq(target.view_as(pred)).sum().item()            
            pred = o2.argmax(dim=1, keepdim=True)
            correct2 += pred.eq(target.view_as(pred)).sum().item()

    loss1 /= len(test_loader.dataset)
    loss2 /= len(test_loader.dataset)
    acc1 = 100. * correct1 / len(test_loader.dataset)
    acc2 = 100. * correct2 / len(test_loader.dataset)
    error /= len(test_loader)
    return loss1,acc1, loss2,acc2,error 


config = {
    'dataset': 'CIFAR100',
    'hidden_dims': [256,]*5,
    'activation': 'tanh',
    'norm_layer': 'rmsnorm',
    'lr': 0.001,
    'epochs': 3,
    'gram_epochs': 3,
    'freeze_others': True,
    'doubling_noise': 0.0,
    'optimizer': 'adam',
    'batch_size': 512
}
model, train_loader, test_loader = main(config)
model2 = double_mlp_size(model, config, noise=config['doubling_noise']).to(device)
opt1 = optim.Adam(model.parameters(), lr=config['lr'])
opt2 = optim.Adam(model2.parameters(), lr=config['lr'])

epoch = 0
loss1,acc1, loss2,acc2,err  = test_models(model, model2, criterion, test_loader, device)
print(f"Epoch {epoch}, model discrepancy={err:.8f}, loss1={loss1:.2f},acc1={acc1:.1f}, loss2={loss2:.2f},acc2={acc2:.1f}")
while epoch < 3:
    epoch += 1
    train(model, train_loader, opt1, criterion, device)
    train(model2, train_loader, opt2, criterion, device)
    
    loss1,acc1, loss2,acc2,err  = test_models(model, model2, criterion, test_loader, device)
    
    print(f"Epoch {epoch}, model discrepancy={err:.8f}, loss1={loss1:.2f},acc1={acc1:.1f}, loss2={loss2:.2f},acc2={acc2:.1f}")

In [395]:

cos = nn.CosineSimilarity(dim=0)
for l in model2.layers:
    if isinstance(l,nn.Linear):
        if l.weight.shape[0]!=l.weight.shape[1]:
            continue
        w = l.weight.detach().cpu()
        d = w.shape[0]//2
        print(cos(w[:d,:d].flatten(),w[d:,d:].flatten()),cos(w[:d,:d].flatten(),w[:d,d:].flatten()))


tensor(1.0000) tensor(1.0000)
tensor(1.0000) tensor(1.0000)
tensor(1.0000) tensor(1.0000)
tensor(1.0000) tensor(1.0000)


In [373]:
for epoch in range(5):
    train(model, train_loader, optimizer, criterion, device)
    test_loss, accuracy = test(model, test_loader, criterion, device)
    print(f'Epoch: {epoch+1}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

training model
Epoch: 1, Test loss: 0.0033, Accuracy: 42.13%
training model
Epoch: 2, Test loss: 0.0033, Accuracy: 42.13%
training model
Epoch: 3, Test loss: 0.0033, Accuracy: 42.13%
training model
Epoch: 4, Test loss: 0.0033, Accuracy: 42.13%
training model
Epoch: 5, Test loss: 0.0033, Accuracy: 42.13%


In [191]:
def get_hidden_layers(model, x, layer_name):
    hidden_layers = []
    
    def hook(module, input, output):
        hidden_layers.append((module.name, output.detach().cpu()))
    
    hooks = []
    for layer in model.layers:
        print(layer.name)
        hooks.append(layer.register_forward_hook(hook))
    
    # Forward pass
    model(x)
    
    # Remove hooks
    for h in hooks:
        h.remove()
    
    return hidden_layers

for batch in train_loader:
    data, _ = batch
    data = data.view(data.size(0), -1).to('cuda')
    break

hidden = get_hidden_layers(model, data, layer_name='linear')
first_hidden = model.layers[0](data).cpu()

hidden2 = get_hidden_layers(model2, data, layer_name='linear')
first_hidden2 = model2.layers[0](data).cpu()
hidden[0][1]-first_hidden, hidden2[0][1]-first_hidden2




linear_in
linear_0
act_0
linear_1
act_1
linear_2
act_2
linear_out
linear_in
linear_0
act_0
linear_1
act_1
linear_2
act_2
linear_out


(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SubBackward0>),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], grad_fn=<SubBackward0>))

In [204]:
def compute_gradients(model, batch, criterion, device):
    model.train()  # Set the model to training mode
    model.zero_grad()  # Zero out any existing gradients

    # Unpack the batch
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)
    
    # Reshape the input if necessary (e.g., for MNIST or CIFAR)
    inputs = inputs.view(inputs.size(0), -1)

    # Forward pass
    outputs = model(inputs)
    
    # Compute the loss
    loss = criterion(outputs, targets)
    
    # Backward pass
    loss.backward()

    # Now the gradients are stored in the .grad attribute of each parameter
    gradients = []
    for name, param in model.named_parameters():
        if param.grad is not None:
            gradients.append((name, param.grad.clone().cpu()))
    
    return gradients

gradients = compute_gradients(model, batch, criterion, device)
gradients2 = compute_gradients(model2, batch, criterion, device)

In [216]:
l = 3
name, gradient = gradients[l]
name, gradient2 = gradients2[l]
name, gradient.shape, gradient2.shape

('layers.1.bias', torch.Size([500]), torch.Size([1000]))

Files already downloaded and verified
Files already downloaded and verified
M1 - Epoch: 1, Test loss: 0.0074, Accuracy: 14.82%
M1 - Epoch: 2, Test loss: 0.0071, Accuracy: 17.26%
M1 - Epoch: 3, Test loss: 0.0068, Accuracy: 19.41%
M1 - Epoch: 4, Test loss: 0.0067, Accuracy: 19.70%
M1 - Epoch: 5, Test loss: 0.0066, Accuracy: 21.55%
M1 - Epoch: 6, Test loss: 0.0065, Accuracy: 22.43%
M1 - Epoch: 7, Test loss: 0.0065, Accuracy: 22.38%
Epoch 1, Train Loss: 4.018192, Test Gram Loss: 1.355158
M2 - Final Test loss: 0.0125, Accuracy: 0.45%
Training after kernel transfer (M3 = M2 from scratch)
M2 - Epoch: 1, Test loss: 0.0074, Accuracy: 14.22%
M2 - Epoch: 2, Test loss: 0.0072, Accuracy: 15.91%
M2 - Epoch: 3, Test loss: 0.0070, Accuracy: 17.86%
M2 - Epoch: 4, Test loss: 0.0068, Accuracy: 19.75%
M2 - Epoch: 5, Test loss: 0.0067, Accuracy: 20.42%
M2 - Epoch: 6, Test loss: 0.0066, Accuracy: 21.30%
M2 - Epoch: 7, Test loss: 0.0065, Accuracy: 21.99%
M3 - Epoch: 1, Test loss: 0.0075, Accuracy: 13.98%
M3 

In [293]:
i = 0
for data, label in test_loader:
    i += 1
    data = data.view(data.size(0), -1)
    data = data.to(device)
    if i==11:
        break
out = M1(data)
out2 = M2(data)
G = out @ out.t()
G2 = out2 @ out2.t()
(G - G2).norm(), G.norm()

(tensor(21680.8848, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>),
 tensor(60140.7812, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>))

M1 - Epoch: 6, Test loss: 0.0032, Accuracy: 43.91%
M1 - Epoch: 7, Test loss: 0.0030, Accuracy: 46.60%
M1 - Epoch: 8, Test loss: 0.0029, Accuracy: 48.73%
M1 - Epoch: 9, Test loss: 0.0028, Accuracy: 50.22%
M1 - Epoch: 10, Test loss: 0.0028, Accuracy: 50.80%


In [303]:
# Train M1 using the original training method
optimizer_M2 = optim.Adam(M2.parameters(), lr=config['lr'])
criterion = nn.CrossEntropyLoss()

for epoch in range(config['epochs']):
    train(M2, train_loader, optimizer_M2, criterion, device)
    test_loss, accuracy = test(M2, test_loader, criterion, device)
    print(f'M2 - Epoch: {config["epochs"]+epoch+1}, Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

M2 - Epoch: 6, Test loss: 0.0132, Accuracy: 22.18%
M2 - Epoch: 7, Test loss: 0.0132, Accuracy: 22.32%
M2 - Epoch: 8, Test loss: 0.0132, Accuracy: 22.44%
M2 - Epoch: 9, Test loss: 0.0132, Accuracy: 22.07%
M2 - Epoch: 10, Test loss: 0.0132, Accuracy: 22.47%
