# *Challenge 2*: *Discovering Complexity in Neural Networks*

Advanced Topics in Machine Learning -- Fall 2023, UniTS

<a target="_blank" href="https://colab.research.google.com/github/ganselmif/adv-ml-units/blob/main/notebooks/AdvML_Challenge_2.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

In this notebook we perform an automatic analysis of the complexity of a CNN trained on the CIFAR-10 dataset, in order to understand how non-linear our feature representation is (I chose this challenge as my second project instead of Challenge 2).

### Main Idea
We'll analyze a naive CNN architecture: Convolution -> BN -> ReLU -> MaxPool -> ... -> Linear -> Softmax
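Every block ends with a 2x2 max-pool, so each block halves the height and width of the 32x32 CIFAR-10 images. The sketch below uses a hypothetical helper, `flattened_size`, which is not part of the notebook's code. It computes how many features reach the final Linear layer, assuming 3x3 convolutions with `padding=1` and 64 channels per block, as in the model defined later.

```python
def flattened_size(num_blocks: int, channels: int = 64, image_size: int = 32) -> int:
    """Size of the flattened feature vector after `num_blocks` Conv-BN-act-MaxPool blocks.

    Assumes 3x3 convolutions with padding=1 (spatial size preserved) followed by
    a 2x2 max-pool with stride 2 (spatial size halved, rounding down).
    """
    side = image_size
    for _ in range(num_blocks):
        if side < 2:
            raise ValueError("too many pooling stages for this image size")
        side //= 2  # effect of one 2x2 max-pool
    return channels * side * side


for n in range(1, 6):
    print(n, "blocks ->", flattened_size(n), "features")
```

With 5 blocks, the final layer sees only 64 features (a 1x1 map per channel). Removing one block raises this to 256, so a pruned model needs a differently sized Linear layer.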

We will give this small model a boost with a pruning technique: whenever the non-linearity of a block becomes (almost) linear, the whole block is removed.

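Now, we replace each ReLU with a trainable non-linearity $R_{\beta}(x) = \mathrm{ReLU}(x) - \beta \cdot \mathrm{ReLU}(-x)$. Taking $\beta = 1 - \alpha$ we obtain

$R_{\alpha}(x) = \mathrm{ReLU}(x) - (1-\alpha)\cdot \mathrm{ReLU}(-x)$

Since $x = \mathrm{ReLU}(x) - \mathrm{ReLU}(-x)$, this finally becomes:

$$R_{\alpha}(x) = x + \alpha \cdot \mathrm{ReLU}(-x)$$

where $\alpha \in [0, 1]$. With $\alpha = 1$ we recover the standard ReLU. With $\alpha = 0$ we get the identity, i.e. a purely linear layer.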
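Here is a quick numerical check of the derivation in plain Python. The helper names are illustrative and do not come from the notebook. The check verifies that $R_\beta$ with $\beta = 1 - \alpha$ coincides with $R_\alpha(x) = x + \alpha\,\mathrm{ReLU}(-x)$. It also prints the two extremes: $\alpha = 0$ returns the input unchanged, and $\alpha = 1$ reproduces ReLU.

```python
def relu(x: float) -> float:
    return max(x, 0.0)


def r_alpha(x: float, alpha: float) -> float:
    """Trainable non-linearity R_alpha(x) = x + alpha * ReLU(-x)."""
    return x + alpha * relu(-x)


def r_beta(x: float, beta: float) -> float:
    """Original parametrisation R_beta(x) = ReLU(x) - beta * ReLU(-x)."""
    return relu(x) - beta * relu(-x)


xs = [-2.0, -0.5, 0.0, 0.5, 2.0]
for alpha in (0.0, 0.3, 1.0):
    # The two parametrisations agree when beta = 1 - alpha
    assert all(abs(r_alpha(x, alpha) - r_beta(x, 1 - alpha)) < 1e-12 for x in xs)
    print(alpha, [r_alpha(x, alpha) for x in xs])
```

This is why an $\alpha$ close to 0 signals that the corresponding non-linearity contributes almost nothing.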

*Note*: you can also interpret the action of $R_\alpha(x)$ as a residual connection!

*Note*: we'll use a different $\alpha$ for each non-linearity.

There are different ways to model $\alpha$. Professor Anselmi suggested simply penalizing the final loss with the L1 norm of the alphas. Instead, I'll use $\alpha_t = f\big((f(\gamma)-0.5)\,\nu_t\big)$, with $\gamma$ trainable, $f(x) = 1/(1+e^{-x})$ (the sigmoid), and $\nu_t$ increasing in time. As $\nu_t$ grows, $\alpha_t$ is pushed towards $1$ when $\gamma > 0$ and towards $0$ when $\gamma < 0$. In the end, the model automatically selects whether each $\alpha$ (i.e. each non-linearity) is relevant or not.
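The schedule can be sketched in plain Python using only the `math` module, no PyTorch. `alpha_t` follows the same formula as `GammaAlphaCompute.compute_alpha` below. The sampled $\nu$ values are arbitrary illustrations, not the notebook's actual increments. `ALPHA_THRESHOLD` reuses the `alpha_threshold=0.01` passed to training later. The sketch shows that, for large $\nu$, a negative $\gamma$ drives $\alpha_t$ below the pruning threshold, a positive one drives it to 1, and $\gamma = 0$ keeps it at 0.5.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def alpha_t(gamma: float, nu_t: float) -> float:
    """alpha_t = f((f(gamma) - 0.5) * nu_t), with f the logistic sigmoid."""
    return sigmoid((sigmoid(gamma) - 0.5) * nu_t)


ALPHA_THRESHOLD = 0.01  # same value as the alpha_threshold used for training

for gamma in (0.5, 0.0, -0.5):
    # nu_t grows during training; here we only sample a few increasing values
    trajectory = [alpha_t(gamma, nu) for nu in (0.0, 10.0, 100.0, 1000.0)]
    prunable = trajectory[-1] < ALPHA_THRESHOLD
    print(f"gamma={gamma:+.1f}", [round(a, 4) for a in trajectory], "prunable:", prunable)
```

A side effect of this design is that the sigmoid saturates as $\nu$ grows, so the gradient reaching $\gamma$ vanishes. Late in training, the $\gamma$ values are therefore effectively frozen.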

The same update strategy is used in this paper: https://appliednetsci.springeropen.com/articles/10.1007/s41109-023-00542-x

In [47]:
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

In [48]:
class GammaAlphaCompute(nn.Module):

    """
    Computes and tracks the alpha value based on a learnable gamma parameter.

    This module is responsible for computing the alpha value used in the CustomNonLinearity.
    It maintains a history of alpha and gamma values during training.

    Attributes:
        gamma (nn.Parameter): Learnable parameter used to compute alpha.
        nu (torch.Tensor): Buffer incremented by 0.01 on every compute_alpha() call in training mode.
        name (str): Identifier for the instance.
        alpha_history (List[float]): History of computed alpha values during training.
        gamma_history (List[float]): History of gamma values during training.
    """

    def __init__(self, name: str):
        """
        Initializes the GammaAlphaCompute module.

        Args:
            name (str): Identifier for this instance.
        """
        super().__init__()
        self.gamma = nn.Parameter(torch.tensor(0.0))
        self.register_buffer('nu', torch.tensor(0.0))
        self.name = name
        self.alpha_history: List[float] = []
        self.gamma_history: List[float] = []

    def compute_alpha(self) -> torch.Tensor:
        """
        Computes the alpha value based on the current gamma and nu.

        If in training mode, this method also updates nu and records the
        current alpha and gamma values in their respective histories.

        Returns:
            torch.Tensor: The computed alpha value.
        """
        if self.training:
            self.nu = self.nu + 0.01  # Increase nu during training

        f_gamma = torch.sigmoid(self.gamma)
        alpha = torch.sigmoid((f_gamma - 0.5) * self.nu)

        if self.training:
            self.alpha_history.append(alpha.item())
            self.gamma_history.append(self.gamma.item())

        return alpha

class CustomNonLinearity(nn.Module):
    """
    Implements a custom non-linearity function using the computed alpha value.

    This module applies a non-linear transformation to the input using an alpha
    value computed by a GammaAlphaCompute instance.

    Attributes:
        ga_compute (GammaAlphaCompute): Instance to compute the alpha value.
    """

    def __init__(self, name: str):
        """
        Initializes the CustomNonLinearity module.

        Args:
            name (str): Identifier for this instance, passed to GammaAlphaCompute.
        """
        super().__init__()
        self.ga_compute = GammaAlphaCompute(name)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the custom non-linearity to the input tensor.

        The non-linearity is defined as: f(x) = x + alpha * ReLU(-x)

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The result of applying the non-linearity to the input.
        """
        alpha = self.ga_compute.compute_alpha()
        return x + alpha * F.relu(-x)

In [49]:
class DynamicPruningCNN(nn.Module):
    """
    A Convolutional Neural Network with dynamic pruning capabilities.

    This CNN can dynamically remove layers during training based on their importance,
    as determined by the alpha values of CustomNonLinearity layers.

    Attributes:
        layers (nn.ModuleList): List of CNN layers.
        input_channels (int): Number of input channels.
        hidden_channels (int): Number of hidden channels in CNN layers.
        num_classes (int): Number of output classes.
        output_size (int): Size of the flattened output before the final FC layer.
        fc (nn.Linear): Final fully connected layer.
        softmax (nn.Softmax): Softmax activation for output.
        gamma_history (dict): History of gamma values for each layer.
        alpha_history (dict): History of alpha values for each layer.
    """

    def __init__(self, initial_layers: int = 3, input_channels: int = 3,
                 hidden_channels: int = 64, num_classes: int = 10):
        """
        Initializes the DynamicPruningCNN.

        Args:
            initial_layers (int): Number of initial CNN layers.
            input_channels (int): Number of input channels.
            hidden_channels (int): Number of hidden channels in CNN layers.
            num_classes (int): Number of output classes.
        """
        super(DynamicPruningCNN, self).__init__()
        self.layers = nn.ModuleList()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.num_classes = num_classes

        for i in range(initial_layers):
            in_channels = input_channels if i == 0 else hidden_channels
            self.layers.append(self._create_cnn_block(in_channels, hidden_channels, f"CNN{i+1}"))

        self.output_size = self._get_output_size(initial_layers)
        self.fc = nn.Linear(self.output_size, num_classes)
        self.softmax = nn.Softmax(dim=1)

        self.gamma_history = {f"CNN{i+1}": [] for i in range(initial_layers)}
        self.alpha_history = {f"CNN{i+1}": [] for i in range(initial_layers)}

    def _create_cnn_block(self, in_channels: int, out_channels: int, name: str) -> nn.Sequential:
        """
        Creates a CNN block consisting of Conv2d, BatchNorm2d, and CustomNonLinearity.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            name (str): Name for the CustomNonLinearity layer.

        Returns:
            nn.Sequential: A CNN block.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            CustomNonLinearity(name)
        )

    def _get_output_size(self, num_layers: int) -> int:
        """
        Calculates the output size after applying all CNN layers and max pooling.

        Args:
            num_layers (int): Number of CNN layers.

        Returns:
            int: The size of the flattened output.
        """
        x = torch.randn(1, self.input_channels, 32, 32)  # CIFAR-10 input size
        for _ in range(num_layers):
            x = F.max_pool2d(x, 2)
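        # Spatial positions left after pooling (x.numel() / input channels),
        # times the number of feature channels produced by each block
        return x.numel() * self.hidden_channels // x.shape[1]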

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after passing through the network.
        """
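        for i, layer in enumerate(self.layers):
            x = layer(x)
            x = F.max_pool2d(x, 2)

            # Record the histories only on training steps, so that evaluation batches
            # do not end up in the curves. Note that compute_alpha() advances nu in
            # training mode, so together with the call inside CustomNonLinearity
            # nu grows by 0.02 (not 0.01) per training step.
            if self.training:
                layer_name = f"CNN{i+1}"
                self.gamma_history[layer_name].append(layer[-1].ga_compute.gamma.item())
                self.alpha_history[layer_name].append(layer[-1].ga_compute.compute_alpha().item())

        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Note: nn.CrossEntropyLoss already applies log-softmax, so returning
        # probabilities applies softmax twice. The loss then cannot drop below
        # log(e + 9) - 1 (about 1.46) for 10 classes, and its gradients shrink.
        # Returning the logits `x` is the usual choice with CrossEntropyLoss.
        return self.softmax(x)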

    def prune_and_rebuild(self, alpha_threshold: float = 0.01) -> Tuple['DynamicPruningCNN', Optional[int]]:
        """
        Prunes the network by removing a layer with alpha value below the threshold.

        If a layer is pruned, a new model is created with the remaining layers,
        and the weights are transferred from the old model to the new one.

        Args:
            alpha_threshold (float): Threshold for pruning decision.

        Returns:
            Tuple[DynamicPruningCNN, Optional[int]]: A tuple containing the new (or same) model
            and the index of the pruned layer (or None if no pruning occurred).
        """
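        # Find the first block whose alpha fell below the threshold. When called during
        # training (as in train_with_dynamic_pruning), each compute_alpha() call also
        # advances nu by 0.01.
        pruned_layer_index = next((i for i, layer in enumerate(self.layers)
                                   if layer[-1].ga_compute.compute_alpha() < alpha_threshold), None)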

        if pruned_layer_index is not None:
            new_model = DynamicPruningCNN(len(self.layers) - 1,
                                          self.input_channels,
                                          self.hidden_channels,
                                          self.num_classes)

            # This part copies weights and history from the old model to the new submodel
            with torch.no_grad():
                for i in range(pruned_layer_index):
                    new_model.layers[i].load_state_dict(self.layers[i].state_dict())
                    new_model.gamma_history[f"CNN{i+1}"] = self.gamma_history[f"CNN{i+1}"]
                    new_model.alpha_history[f"CNN{i+1}"] = self.alpha_history[f"CNN{i+1}"]

                for i in range(pruned_layer_index, len(new_model.layers)):
                    new_model.layers[i].load_state_dict(self.layers[i+1].state_dict())
                    new_model.gamma_history[f"CNN{i+1}"] = self.gamma_history[f"CNN{i+2}"]
                    new_model.alpha_history[f"CNN{i+1}"] = self.alpha_history[f"CNN{i+2}"]

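                # Preserve the history of the pruned layer (and of layers pruned earlier)
                # under a separate key. The "CNN{k}" keys now belong to the renumbered live
                # layers, so reusing one would overwrite a live layer's history.
                for key in self.gamma_history:
                    if key.startswith("pruned"):
                        new_model.gamma_history[key] = self.gamma_history[key]
                        new_model.alpha_history[key] = self.alpha_history[key]
                pruned_layer_name = f"CNN{pruned_layer_index+1}"
                pruned_key = f"pruned {len(self.layers)}->{len(new_model.layers)} layers: {pruned_layer_name}"
                new_model.gamma_history[pruned_key] = self.gamma_history[pruned_layer_name]
                new_model.alpha_history[pruned_key] = self.alpha_history[pruned_layer_name]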

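                # Adjust the FC layer. When a pooling stage is removed, the flattened
                # feature layout changes (e.g. 64 -> 256 features going from 5 to 4
                # blocks). The copied columns therefore do not line up with the new
                # features, and the FC layer is effectively re-learned after pruning.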
                if new_model.output_size != self.output_size:
                    new_fc = nn.Linear(new_model.output_size, self.num_classes)
                    new_fc.weight.data[:, :self.output_size] = self.fc.weight.data
                    new_fc.bias.data = self.fc.bias.data
                    new_model.fc = new_fc
                else:
                    new_model.fc.load_state_dict(self.fc.state_dict())

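            # Freeze the blocks before the pruned position: only the blocks that followed
            # the pruned one (and the FC layer) stay trainable. If the last block was
            # pruned, every convolutional block is frozen until unfreeze_all_layers().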
            for i, layer in enumerate(new_model.layers):
                for param in layer.parameters():
                    param.requires_grad = (i >= pruned_layer_index)

            return new_model, pruned_layer_index

        return self, None

    def unfreeze_all_layers(self) -> None:
        """
        Unfreezes all layers in the network, making them trainable.
        """
        for layer in self.layers:
            for param in layer.parameters():
                param.requires_grad = True

In [50]:
def evaluate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module) -> Tuple[float, float]:
    """
    Evaluates the model on the given dataloader.

    Args:
        model (nn.Module): The model to evaluate.
        dataloader (DataLoader): The dataloader containing the evaluation data.
        criterion (nn.Module): The loss function.

    Returns:
        Tuple[float, float]: A tuple containing the average loss and accuracy.
    """
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return running_loss / len(dataloader), 100 * correct / total

def train_with_dynamic_pruning(model: DynamicPruningCNN,
                               trainloader: DataLoader,
                               valloader: DataLoader,
                               criterion: nn.Module,
                               optimizer: optim.Optimizer,
                               num_epochs: int,
                               alpha_threshold: float = 0.01) -> Tuple[DynamicPruningCNN, List[float], List[float], List[float], List[float], List[int], List[int]]:
    """
    Trains the model with dynamic pruning.

    Args:
        model (DynamicPruningCNN): The model to train.
        trainloader (DataLoader): The dataloader for training data.
        valloader (DataLoader): The dataloader for validation data.
        criterion (nn.Module): The loss function.
        optimizer (optim.Optimizer): The optimizer.
        num_epochs (int): The number of epochs to train.
        alpha_threshold (float): The threshold for pruning decision.

    Returns:
        Tuple[DynamicPruningCNN, List[float], List[float], List[float], List[float], List[int], List[int]]:
        A tuple containing the trained model, training losses, training accuracies,
        validation losses, validation accuracies, pruned epochs, and pruned layers.
    """
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    pruned_epochs, pruned_layers = [], []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        total_samples = 0

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("Model parameters:")
        for name, param in model.named_parameters():
            if 'gamma' in name:
                print(f"{name}: {param.item():.4f}")

        pbar = tqdm(enumerate(trainloader), total=len(trainloader), desc=f'Training: ')
        for i, (inputs, labels) in pbar:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            running_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            pbar.set_postfix({
                'loss': f'{running_loss / (i + 1):.4f}',
                'accuracy': f'{100 * running_correct / total_samples:.2f}%'
            })

        train_loss = running_loss / len(trainloader)
        train_acc = 100 * running_correct / total_samples
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Check for pruning
        new_model, pruned_layer = model.prune_and_rebuild(alpha_threshold)
        if pruned_layer is not None:
            model = new_model
            optimizer = optim.Adam(model.parameters(), lr=0.001)
            pruned_epochs.append(epoch)
            pruned_layers.append(pruned_layer)
            print(f"Pruned layer {pruned_layer} at epoch {epoch+1}")

        # Validation
        val_loss, val_acc = evaluate(model, valloader, criterion)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'Epoch {epoch+1} summary:')
        print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%')

        # Unfreeze all layers after one epoch of training following a pruning step
        if pruned_epochs and epoch == pruned_epochs[-1] + 1:
            model.unfreeze_all_layers()
            print("Unfreezing all layers")

    return model, train_losses, train_accs, val_losses, val_accs, pruned_epochs, pruned_layers

In [51]:
def plot_training_metrics(train_losses: List[float], train_accs: List[float],
                          val_losses: List[float], val_accs: List[float]) -> None:
    """
    Plots training and validation metrics.

    Args:
        train_losses (List[float]): Training losses.
        train_accs (List[float]): Training accuracies.
        val_losses (List[float]): Validation losses.
        val_accs (List[float]): Validation accuracies.
    """
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss over epochs')

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train')
    plt.plot(val_accs, label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Accuracy over epochs')

    plt.tight_layout()
    plt.savefig('training_validation_metrics.png')
    plt.close()

def plot_gamma_alpha_history(model: DynamicPruningCNN, pruned_epochs: List[int], pruned_layers: List[int]) -> None:
    """
    Plots the history of gamma and alpha values.

    Args:
        model (DynamicPruningCNN): The trained model.
        pruned_epochs (List[int]): List of epochs where pruning occurred.
        pruned_layers (List[int]): List of layers that were pruned.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 16))

    for layer_name, gamma_history in model.gamma_history.items():
        ax1.plot(gamma_history, label=layer_name)

    for layer_name, alpha_history in model.alpha_history.items():
        ax2.plot(alpha_history, label=layer_name)

    for epoch, layer in zip(pruned_epochs, pruned_layers):
        # Note: the curves are indexed by recorded step, while pruned_epochs holds
        # epoch indices, so these markers are not on the same scale as the curves
        label = f'Pruned layer {layer} (epoch {epoch+1})'
        ax1.axvline(x=epoch, color='r', linestyle='--', label=label)
        ax2.axvline(x=epoch, color='r', linestyle='--', label=label)

    ax1.set_title('Gamma History')
    ax1.set_xlabel('Training Steps')
    ax1.set_ylabel('Gamma Value')
    ax1.legend()

    ax2.set_title('Alpha History')
    ax2.set_xlabel('Training Steps')
    ax2.set_ylabel('Alpha Value')
    ax2.legend()

    plt.tight_layout()
    plt.savefig('gamma_alpha_history.png')
    plt.close()

In [52]:
torch.autograd.set_detect_anomaly(False)  # set to True to locate NaN/backward errors (slows training)


In [53]:
transform = transforms.Compose([
    transforms.ToTensor()
])

fullset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# Split the dataset
train_size = int(0.8 * len(fullset))
val_size = len(fullset) - train_size
trainset, valset = random_split(fullset, [train_size, val_size])

trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

# Model initialization and training
model = DynamicPruningCNN(initial_layers=5)  # as we'll see, 5 blocks are more than needed: one is pruned early in training
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model, train_losses, train_accs, val_losses, val_accs, pruned_epochs, pruned_layers = train_with_dynamic_pruning(
    model, trainloader, valloader, criterion, optimizer, num_epochs=20, alpha_threshold=0.01
)

Files already downloaded and verified
Files already downloaded and verified

Epoch 1/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.0000
layers.1.2.ga_compute.gamma: 0.0000
layers.2.2.ga_compute.gamma: 0.0000
layers.3.2.ga_compute.gamma: 0.0000
layers.4.2.ga_compute.gamma: 0.0000


Training: 100%|██████████| 625/625 [00:44<00:00, 14.20it/s, loss=1.9523, accuracy=52.37%]


Epoch 1 summary:
Train Loss: 1.9523, Train Accuracy: 52.37%
Val Loss: 1.9251, Val Accuracy: 53.83%

Epoch 2/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.2872
layers.1.2.ga_compute.gamma: 0.3652
layers.2.2.ga_compute.gamma: 0.2950
layers.3.2.ga_compute.gamma: 0.3371
layers.4.2.ga_compute.gamma: -0.6518


Training: 100%|██████████| 625/625 [00:43<00:00, 14.24it/s, loss=1.8104, accuracy=65.94%]


Pruned layer 4 at epoch 2
Epoch 2 summary:
Train Loss: 1.8104, Train Accuracy: 65.94%
Val Loss: 2.3028, Val Accuracy: 10.69%

Epoch 3/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.4141
layers.1.2.ga_compute.gamma: 0.5302
layers.2.2.ga_compute.gamma: 0.4808
layers.3.2.ga_compute.gamma: 0.4204


Training: 100%|██████████| 625/625 [00:24<00:00, 25.32it/s, loss=1.8966, accuracy=60.91%]


Epoch 3 summary:
Train Loss: 1.8966, Train Accuracy: 60.91%
Val Loss: 1.8328, Val Accuracy: 65.42%
Unfreezing all layers

Epoch 4/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.4141
layers.1.2.ga_compute.gamma: 0.5302
layers.2.2.ga_compute.gamma: 0.4808
layers.3.2.ga_compute.gamma: 0.4204


Training: 100%|██████████| 625/625 [00:41<00:00, 14.89it/s, loss=1.8032, accuracy=67.00%]


Epoch 4 summary:
Train Loss: 1.8032, Train Accuracy: 67.00%
Val Loss: 1.8479, Val Accuracy: 61.69%

Epoch 5/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.4991
layers.1.2.ga_compute.gamma: 0.6294
layers.2.2.ga_compute.gamma: 0.5973
layers.3.2.ga_compute.gamma: 0.2110


Training: 100%|██████████| 625/625 [00:42<00:00, 14.58it/s, loss=1.7661, accuracy=70.28%]


Epoch 5 summary:
Train Loss: 1.7661, Train Accuracy: 70.28%
Val Loss: 1.7966, Val Accuracy: 66.82%

Epoch 6/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5278
layers.1.2.ga_compute.gamma: 0.6546
layers.2.2.ga_compute.gamma: 0.6277
layers.3.2.ga_compute.gamma: 0.0487


Training: 100%|██████████| 625/625 [00:42<00:00, 14.76it/s, loss=1.7450, accuracy=72.25%]


Epoch 6 summary:
Train Loss: 1.7450, Train Accuracy: 72.25%
Val Loss: 1.8053, Val Accuracy: 66.00%

Epoch 7/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5366
layers.1.2.ga_compute.gamma: 0.6611
layers.2.2.ga_compute.gamma: 0.6355
layers.3.2.ga_compute.gamma: 0.0135


Training: 100%|██████████| 625/625 [00:42<00:00, 14.71it/s, loss=1.7271, accuracy=73.93%]


Epoch 7 summary:
Train Loss: 1.7271, Train Accuracy: 73.93%
Val Loss: 1.7746, Val Accuracy: 68.92%

Epoch 8/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5392
layers.1.2.ga_compute.gamma: 0.6625
layers.2.2.ga_compute.gamma: 0.6372
layers.3.2.ga_compute.gamma: 0.0163


Training: 100%|██████████| 625/625 [00:42<00:00, 14.57it/s, loss=1.7126, accuracy=75.34%]


Epoch 8 summary:
Train Loss: 1.7126, Train Accuracy: 75.34%
Val Loss: 1.7789, Val Accuracy: 68.45%

Epoch 9/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5400
layers.1.2.ga_compute.gamma: 0.6628
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0008


Training: 100%|██████████| 625/625 [00:42<00:00, 14.63it/s, loss=1.7009, accuracy=76.58%]


Epoch 9 summary:
Train Loss: 1.7009, Train Accuracy: 76.58%
Val Loss: 1.7559, Val Accuracy: 70.65%

Epoch 10/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5403
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: 0.0029


Training: 100%|██████████| 625/625 [00:42<00:00, 14.69it/s, loss=1.6910, accuracy=77.53%]


Epoch 10 summary:
Train Loss: 1.6910, Train Accuracy: 77.53%
Val Loss: 1.7540, Val Accuracy: 70.81%

Epoch 11/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: 0.0004


Training: 100%|██████████| 625/625 [00:42<00:00, 14.73it/s, loss=1.6796, accuracy=78.63%]


Epoch 11 summary:
Train Loss: 1.6796, Train Accuracy: 78.63%
Val Loss: 1.7709, Val Accuracy: 68.98%

Epoch 12/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0089


Training: 100%|██████████| 625/625 [00:42<00:00, 14.69it/s, loss=1.6714, accuracy=79.42%]


Epoch 12 summary:
Train Loss: 1.6714, Train Accuracy: 79.42%
Val Loss: 1.7746, Val Accuracy: 68.55%

Epoch 13/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: 0.0020


Training: 100%|██████████| 625/625 [00:42<00:00, 14.79it/s, loss=1.6658, accuracy=79.95%]


Epoch 13 summary:
Train Loss: 1.6658, Train Accuracy: 79.95%
Val Loss: 1.7455, Val Accuracy: 71.69%

Epoch 14/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: 0.0022


Training: 100%|██████████| 625/625 [00:42<00:00, 14.78it/s, loss=1.6594, accuracy=80.68%]


Epoch 14 summary:
Train Loss: 1.6594, Train Accuracy: 80.68%
Val Loss: 1.7428, Val Accuracy: 71.84%

Epoch 15/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0024


Training: 100%|██████████| 625/625 [00:42<00:00, 14.80it/s, loss=1.6479, accuracy=81.90%]


Epoch 15 summary:
Train Loss: 1.6479, Train Accuracy: 81.90%
Val Loss: 1.7367, Val Accuracy: 72.63%

Epoch 16/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0067


Training: 100%|██████████| 625/625 [00:42<00:00, 14.73it/s, loss=1.6158, accuracy=85.08%]


Epoch 16 summary:
Train Loss: 1.6158, Train Accuracy: 85.08%
Val Loss: 1.7224, Val Accuracy: 73.93%

Epoch 17/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0112


Training: 100%|██████████| 625/625 [00:42<00:00, 14.75it/s, loss=1.6008, accuracy=86.56%]


Epoch 17 summary:
Train Loss: 1.6008, Train Accuracy: 86.56%
Val Loss: 1.7019, Val Accuracy: 76.15%

Epoch 18/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0144


Training: 100%|██████████| 625/625 [00:42<00:00, 14.69it/s, loss=1.5907, accuracy=87.55%]


Epoch 18 summary:
Train Loss: 1.5907, Train Accuracy: 87.55%
Val Loss: 1.7158, Val Accuracy: 74.57%

Epoch 19/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0080


Training: 100%|██████████| 625/625 [00:42<00:00, 14.72it/s, loss=1.5811, accuracy=88.50%]


Epoch 19 summary:
Train Loss: 1.5811, Train Accuracy: 88.50%
Val Loss: 1.7122, Val Accuracy: 74.81%

Epoch 20/20
Model parameters:
layers.0.2.ga_compute.gamma: 0.5404
layers.1.2.ga_compute.gamma: 0.6629
layers.2.2.ga_compute.gamma: 0.6376
layers.3.2.ga_compute.gamma: -0.0083


Training: 100%|██████████| 625/625 [00:42<00:00, 14.81it/s, loss=1.5746, accuracy=89.22%]


Epoch 20 summary:
Train Loss: 1.5746, Train Accuracy: 89.22%
Val Loss: 1.6974, Val Accuracy: 76.46%


In [54]:
# Final test evaluation
test_loss, test_acc = evaluate(model, testloader, criterion)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

# Plot training and validation metrics
plot_training_metrics(train_losses, train_accs, val_losses, val_accs)
print("Training/validation metrics plot saved as 'training_validation_metrics.png'")

# Plot gamma and alpha history
plot_gamma_alpha_history(model, pruned_epochs, pruned_layers)
print("Gamma and Alpha history plot saved as 'gamma_alpha_history.png'")

Test Loss: 1.6984, Test Accuracy: 76.32%
Training/validation metrics plot saved as 'training_validation_metrics.png'
Gamma and Alpha history plot saved as 'gamma_alpha_history.png'
