In [1]:
from google.colab import drive
# Mount the user's Google Drive at /content/drive (Colab runtime only).
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# CONFIG

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import pandas as pd
import math
from math import sqrt
from math import ceil
import numpy as np
from scipy.stats import gmean
from tqdm import tqdm
import sys
import torchvision.transforms as transforms


In [3]:
class SinAct(nn.Module):
    """Element-wise sine activation, usable wherever an ``nn.Module`` activation is expected."""

    def forward(self, x):
        """Return ``sin(x)`` computed element-wise."""
        activated = torch.sin(x)
        return activated

class CustomBatchNorm1d(nn.Module):
    """Wrapper around ``nn.BatchNorm1d`` with the learnable affine transform disabled."""

    def __init__(self, d):
        """
        Args:
            d (int): Number of features passed to ``nn.BatchNorm1d``.
        """
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features=d, affine=False)

    def forward(self, x):
        """Apply the wrapped batch-normalization layer to ``x``."""
        normalized = self.bn(x)
        return normalized

import torch
import torch.nn as nn
from math import sqrt

class CustomNormalization(nn.Module):
    """Parameter-free normalization that rescales a 2-D tensor along one axis.

    Each slice along the chosen axis is (optionally) centered and then divided by
    ``norm / factor``, where ``factor`` is ``sqrt(size of that axis)`` unless
    ``force_factor`` is given.
    """

    def __init__(self, norm_type, mean_reduction, force_factor=None):
        """
        Args:
            norm_type (str): 'bn' normalizes along dim 0 (the batch axis), 'ln' along
                dim 1 (the feature axis), 'id' leaves the input untouched.
            mean_reduction (bool): If truthy, the mean along the axis is subtracted first.
            force_factor (float, optional): Replaces the default ``sqrt(axis size)`` factor.

        Raises:
            ValueError: If ``norm_type`` is not one of 'bn', 'ln', 'id'.
        """
        super().__init__()
        self.mean_reduction = mean_reduction
        self.norm_type = norm_type
        self.force_factor = force_factor

        # Map the normalization name to the axis it acts on; -1 marks "identity".
        for known_type, axis in (('bn', 0), ('ln', 1), ('id', -1)):
            if norm_type == known_type:
                self.dim = axis
                break
        else:
            raise ValueError("No such normalization.")

    def forward(self, X):
        """
        Normalize ``X`` along the configured axis.

        Args:
            X (Tensor): Input tensor.

        Returns:
            Tensor: ``X`` itself for 'id'; otherwise the rescaled (and optionally centered) tensor.
        """
        axis = self.dim
        if axis == -1:
            # Identity: hand back the very same tensor object.
            return X

        if self.mean_reduction:
            X = X - X.mean(dim=axis, keepdim=True)

        norm = X.norm(dim=axis, keepdim=True)

        # Default factor is sqrt(n) for 'bn' and sqrt(d) for 'ln'; force_factor overrides it.
        default_factor = sqrt(X.shape[axis])
        factor = default_factor if self.force_factor is None else self.force_factor

        return X / (norm / factor)


class GainedActivation(nn.Module):
    """Activation with a learnable scalar gain applied to its input: ``act(gain * x)``."""

    def __init__(self, activation, gain):
        """
        Args:
            activation (callable): Zero-argument factory (e.g. ``nn.Tanh``) returning the activation module.
            gain (int | float): Initial value of the learnable gain.
        """
        super().__init__()
        self.activation = activation()
        # Build the tensor in the default floating dtype: integer gains (e.g. 1)
        # would otherwise yield an integer tensor, which cannot require gradients.
        # nn.Parameter sets requires_grad=True itself.
        self.gain = nn.Parameter(torch.tensor([gain], dtype=torch.get_default_dtype()))

    def forward(self, x):
        """Scale ``x`` by the learnable gain, then apply the activation."""
        return self.activation(self.gain * x)



###############################


class MLPWithBatchNorm(nn.Module):
    """Fully connected network made of ``num_layers`` Linear + normalization + activation
    blocks followed by a final Linear output layer.

    The input to every activation is multiplied by a depth-dependent gain
    ``(l + 1) ** exponent`` (``l`` is the 0-based block index). ``reset_parameters``
    must be called before the first forward pass.
    """

    def __init__(self, input_dim, output_dim, num_layers, hidden_dim, norm_type, mean_reduction, activation, save_hidden, exponent, order='norm_act', force_factor=None, bias=False):
        """
        Initializes the MLPWithBatchNorm class.

        Args:
            input_dim (int): Dimension of the input features.
            output_dim (int): Dimension of the output.
            num_layers (int): Number of Linear + norm + activation blocks before the output layer.
            hidden_dim (int): Dimension of the hidden layers.
            norm_type (str): 'torch_bn' for ``nn.BatchNorm1d``; any other value is forwarded
                to ``CustomNormalization`` ('bn', 'ln' or 'id').
            mean_reduction (bool): Forwarded to ``CustomNormalization``; unused for 'torch_bn'.
            activation (callable): Zero-argument factory returning the activation module.
            save_hidden (bool): If True, saves a detached copy of every intermediate output.
            exponent (float): Exponent of the per-layer gain ``(l + 1) ** exponent``.
            order (str): Either 'norm_act' (normalize, then activate) or 'act_norm'.
            force_factor (float, optional): Forwarded to ``CustomNormalization``; unused for 'torch_bn'.
            bias (bool): If True, the Linear layers have a learnable bias.

        Raises:
            ValueError: If ``order`` is neither 'act_norm' nor 'norm_act'.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.hiddens = {}  # Detached outputs of each sub-layer, filled when save_hidden is True.
        self.initialized = False  # Set by reset_parameters(); forward() asserts on it.
        self.exponent = exponent
        self.save_hidden = save_hidden
        self.order = order

        # Validate the order of normalization and activation.
        if self.order not in ['act_norm', 'norm_act']:
            raise ValueError("Unknown order")

        self.layers = nn.ModuleDict()

        # Input block.
        self.layers['fc_0'] = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.layers['norm_0'] = self._make_norm(norm_type, hidden_dim, mean_reduction, force_factor)
        self.layers['act_0'] = activation()

        # Hidden blocks.
        for l in range(1, num_layers):
            self.layers[f'fc_{l}'] = nn.Linear(hidden_dim, hidden_dim, bias=bias)
            self.layers[f'norm_{l}'] = self._make_norm(norm_type, hidden_dim, mean_reduction, force_factor)
            self.layers[f'act_{l}'] = activation()

        # Output layer.
        self.layers[f'fc_{num_layers}'] = nn.Linear(hidden_dim, output_dim, bias=bias)

    @staticmethod
    def _make_norm(norm_type, hidden_dim, mean_reduction, force_factor):
        """Build one normalization module: PyTorch BatchNorm1d for 'torch_bn', else CustomNormalization."""
        if norm_type == 'torch_bn':
            return nn.BatchNorm1d(hidden_dim)
        return CustomNormalization(norm_type, mean_reduction, force_factor=force_factor)

    def _store_hidden(self, key, value):
        """Save a detached copy of ``value`` under ``key`` when save_hidden is on; return ``value`` unchanged."""
        if self.save_hidden:
            self.hiddens[key] = value.clone().detach()
        return value

    def forward(self, x):
        """
        Forward pass of the MLP.

        Args:
            x (Tensor): Input tensor; it is flattened to shape (-1, input_dim).

        Returns:
            Tensor: Output of the final Linear layer.

        Raises:
            AssertionError: If ``reset_parameters`` has not been called yet.
        """
        assert self.initialized, "Model parameters not initialized."

        # Flatten the input; reshape (unlike view) also accepts non-contiguous tensors.
        x = x.reshape(-1, self.input_dim)

        for l in range(self.num_layers):
            # Depth-dependent gain applied to the activation input.
            layer_gain = ((l + 1) ** self.exponent)

            x = self._store_hidden(f'fc_{l}', self.layers[f'fc_{l}'](x))

            # Apply normalization and activation in the specified order.
            if self.order == 'norm_act':
                x = self._store_hidden(f'norm_{l}', self.layers[f'norm_{l}'](x))
                x = self._store_hidden(f'act_{l}', self.layers[f'act_{l}'](x * layer_gain))
            elif self.order == 'act_norm':
                x = self._store_hidden(f'act_{l}', self.layers[f'act_{l}'](x * layer_gain))
                x = self._store_hidden(f'norm_{l}', self.layers[f'norm_{l}'](x))

        # Final layer to produce output.
        out_key = f'fc_{self.num_layers}'
        return self._store_hidden(out_key, self.layers[out_key](x))

    def set_save_hidden(self, state):
        """
        Enables or disables saving of hidden layer outputs.

        Args:
            state (bool): If True, enables saving and clears previously stored outputs.
        """
        self.save_hidden = state
        if state:
            self.hiddens.clear()

    def reset_parameters(self, init_type, gain=1.0):
        """
        Re-initializes the weight of every Linear layer and marks the model as initialized.
        Biases are left as created by ``nn.Linear``.

        Args:
            init_type (str): 'xavier_normal' or 'orthogonal'.
            gain (float): Gain factor passed to the chosen initializer.

        Raises:
            ValueError: If ``init_type`` is not a supported scheme.
        """
        for name, p in self.named_modules():
            if isinstance(p, nn.Linear):
                if init_type == 'xavier_normal':
                    nn.init.xavier_normal_(p.weight, gain=gain)
                elif init_type == 'orthogonal':
                    # Forward the gain here too; it was previously dropped for this scheme.
                    nn.init.orthogonal_(p.weight, gain=gain)
                else:
                    raise ValueError("No such initialization scheme.")
        self.initialized = True



In [4]:
def generate_matrix_close_to_isometry(d):
    """Return a random ``d x d`` matrix whose squared singular values are ``1 + eps_i``,
    where the perturbations ``eps_i`` sum to zero (so the squared singular values sum to ``d``).

    The singular vectors come from the SVD of a uniform random matrix.
    """
    base = torch.rand(d, d)

    # Random weights normalized to sum to 1, then shifted by 1/d so they sum to 0.
    perturbation = torch.rand(d)
    perturbation = perturbation / perturbation.sum()
    perturbation = perturbation - 1 / d

    singular_values = torch.sqrt(1 + perturbation)
    U, _, Vt = torch.linalg.svd(base)
    return U @ torch.diag(singular_values) @ Vt

def generate_matrix_far_from_isometry(d, eps):
    """Return a random ``d x d`` matrix with one squared singular value equal to
    ``d - (d - 1) * eps`` and the remaining ``d - 1`` equal to ``eps``.

    The singular vectors come from the SVD of a uniform random matrix.
    """
    base = torch.rand(d, d)

    dominant = d - (d - 1) * eps
    squared_spectrum = np.asarray([dominant] + [eps] * (d - 1))
    singular_values = torch.from_numpy(squared_spectrum).float().sqrt()

    U, _, Vt = torch.linalg.svd(base)
    return U @ torch.diag(singular_values) @ Vt

def cosine(x, y):
    """Cosine of the angle between the 1-D tensors ``x`` and ``y``."""
    inner = torch.dot(x, y)
    scale = x.norm().abs() * y.norm().abs()
    return inner / scale

def cosine_similarity(x, y):
    """Return ``1 - |cosine(x, y)|``.

    Note: despite the name, the value is 0 for parallel or anti-parallel vectors
    and 1 for orthogonal ones.
    """
    alignment = cosine(x, y).abs()
    return 1 - alignment

def isometry_gap(X):
    """Isometry gap of ``X``: ``log(mean(eigs)) - mean(log(eigs))`` over the
    eigenvalues of the Gram matrix ``X X^T`` (computed on CPU, detached from autograd).
    """
    gram = (X @ X.t()).detach().cpu()
    spectrum = torch.linalg.eigvalsh(gram)
    mean_of_logs = torch.log(spectrum).mean()
    log_of_mean = torch.log(torch.mean(spectrum))
    return -mean_of_logs + log_of_mean

def ortho_gap(X):
    """Frobenius distance between the Frobenius-normalized Gram matrix
    ``X X^T / ||X||_F^2`` and the identity scaled by ``1 / ||I_n||_F^2``.
    """
    n, _ = X.shape
    identity = torch.eye(n).to(X.device)
    gram_normalized = (X @ X.T) / (X.norm('fro') ** 2)
    identity_normalized = identity / (identity.norm('fro') ** 2)
    difference = gram_normalized - identity_normalized
    return difference.norm('fro')

def isometry_gap2(X):
    """NumPy/SciPy variant of the isometry gap: ``-log(gmean(eigs) / mean(eigs))``
    over the eigenvalues of the Gram matrix ``X X^T``.
    """
    gram = (X @ X.t()).detach().cpu()
    spectrum = torch.linalg.eigvalsh(gram).numpy()
    ratio = gmean(spectrum) / spectrum.mean()
    return -np.log(ratio)


def get_measurements(model, inputs, labels, criterion, epoch, device):
    """Collect per-layer diagnostics from one forward/backward pass without an optimizer step.

    Args:
        model: Network exposing ``layers``, ``hiddens``, ``num_layers`` and ``set_save_hidden``.
        inputs (Tensor): Input batch.
        labels (Tensor): Targets for ``criterion``.
        criterion (callable): Loss function ``criterion(outputs, labels)``.
        epoch (int): Epoch number recorded in every row.
        device: Device the batch is moved to.

    Returns:
        list[dict]: One row per Linear layer (``fc_0`` .. ``fc_{num_layers}``) with the
        isometry gap of the weight, of the saved fc/act/norm outputs, and the Frobenius
        norm of the weight gradient. ``act_isogap``/``norm_isogap`` are NaN for the output layer.

    Note:
        The model is put in train mode, so although no weights are updated, layers that
        track statistics during training (such as ``nn.BatchNorm1d``) may still update
        their buffers during this forward pass.
    """
    model.train()
    inputs, labels = inputs.to(device), labels.to(device)
    model.set_save_hidden(True)
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    model.zero_grad(set_to_none=True)
    loss.backward()

    measurements = []
    # `tqdm` is the progress-bar callable itself (imported via `from tqdm import tqdm`),
    # so it is called directly rather than as `tqdm.tqdm`.
    for l in tqdm(range(0, model.num_layers+1)):
        w = model.layers[f'fc_{l}'].weight
        w_grad = model.layers[f'fc_{l}'].weight.grad
        w_grad_fro = torch.linalg.matrix_norm(w_grad, ord='fro').item()
        w_ig = isometry_gap(w).item()
        fc_ig = isometry_gap(model.hiddens[f'fc_{l}']).item()
        if l < model.num_layers:
            act_ig = isometry_gap(model.hiddens[f'act_{l}']).item()
            norm_ig = isometry_gap(model.hiddens[f'norm_{l}']).item()
        else:
            # The output layer has no activation / normalization.
            act_ig = np.nan
            norm_ig = np.nan
        measurements.append({
            'layer': l,
            'epoch': epoch,
            'weight_isogap': w_ig,
            'fc_isogap': fc_ig,
            'act_isogap': act_ig,
            'norm_isogap': norm_ig,
            'grad_fro_norm': w_grad_fro,
        })

    # Drop the gradients and stop recording hidden outputs so training is unaffected.
    model.zero_grad(set_to_none=True)
    model.set_save_hidden(False)
    return measurements


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model (nn.Module): Network to train (put in train mode).
        loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        optimizer: Optimizer stepping the model parameters.
        criterion (callable): Loss function ``criterion(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        tuple[float, float]: Mean per-batch loss and overall accuracy.

    Raises:
        SystemExit: With code 1 as soon as the accumulated loss becomes NaN.
    """
    model.train()
    running_loss = 0.0
    correct = 0.0
    total = 0.0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        running_loss += loss.item()

        if math.isnan(running_loss):
            print("Train loss is nan", flush=True, file=sys.stderr)
            # sys.exit rather than the `exit` helper: the latter is an interactive
            # convenience that is not guaranteed to exist or to stop execution.
            sys.exit(1)

    train_accuracy = correct / total
    train_loss = running_loss / len(loader)
    return train_loss, train_accuracy

def test_one_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0
    correct = 0.0
    total = 0.0

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            loss = criterion(outputs, labels)
            running_loss += loss.item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            test_accuracy = correct / total

            if math.isnan(running_loss):
                print("Test loss is nan", flush=True, file=sys.stderr)
                exit(1)

    test_accuracy = correct / total
    test_loss = running_loss / len(loader)
    return test_loss, test_accuracy


In [5]:
def dataset_to_tensors(dataset, indices=None, device='cuda'):
    """Materialize an indexable dataset of ``(x, y)`` samples as two stacked tensors.

    Args:
        dataset: Indexable collection whose items expose the input at position 0
            and the label at position 1.
        indices (iterable, optional): Indices to take; all samples when None.
        device: Device both tensors are moved to.

    Returns:
        tuple[Tensor, Tensor]: Stacked inputs and stacked labels.
    """
    selected = range(len(dataset)) if indices is None else indices

    features, targets = [], []
    for idx in selected:
        sample = dataset[idx]
        features.append(sample[0])
        targets.append(torch.tensor(sample[1]))

    x = torch.stack(features).to(device)
    y = torch.stack(targets).to(device)
    return x, y


class TensorDataLoader:
    """Minimal batch iterator over two pre-loaded tensors ``x`` and ``y``.

    Batches are plain slices of the stored tensors. With ``shuffle=True`` the stored
    tensors are re-permuted (on their own device) each time iteration starts.
    """

    def __init__(self, x, y, batch_size=500, shuffle=False):
        """
        Args:
            x (Tensor): Inputs; first dimension indexes samples.
            y (Tensor): Labels; same first dimension as ``x``.
            batch_size (int): Number of samples per batch (the last batch may be smaller).
            shuffle (bool): Whether to permute the data at the start of each iteration.
        """
        assert x.size(0) == y.size(0), 'Size mismatch'
        self.x, self.y = x, y
        self.device = x.device
        self.n_data = y.size(0)
        self.batch_size = batch_size
        self.n_batches = ceil(self.n_data / self.batch_size)
        self.shuffle = shuffle

    def __iter__(self):
        """Restart iteration, permuting the stored data first when shuffling is enabled."""
        if self.shuffle:
            order = torch.randperm(self.n_data, device=self.device)
            self.x = self.x[order]
            self.y = self.y[order]
        self.i_batch = 0
        return self

    def __next__(self):
        """Return the next ``(x, y)`` batch or raise StopIteration after the last one."""
        if not self.i_batch < self.n_batches:
            raise StopIteration

        lo = self.i_batch * self.batch_size
        hi = lo + self.batch_size
        batch = (self.x[lo:hi], self.y[lo:hi])
        self.i_batch += 1
        return batch

    def __len__(self):
        """Number of batches per pass."""
        return self.n_batches

    def attach(self):
        """Clear the ``_detach`` flag and return self (the flag is not read within this class)."""
        self._detach = False
        return self

    def detach(self):
        """Set the ``_detach`` flag and return self (the flag is not read within this class)."""
        self._detach = True
        return self

    @property
    def dataset(self):
        """Placeholder object whose ``len()`` is the number of samples."""
        return DatasetDummy(self.n_data)


class DatasetDummy:
    """Stand-in dataset object that only knows its size."""

    def __init__(self, N):
        """Store the number of samples ``N``."""
        self.N = N

    def __len__(self):
        """Return ``N`` converted to ``int``."""
        size = int(self.N)
        return size


In [6]:
# Constants
# Flattened input size (channels * height * width) per dataset.
DS_INPUT_SIZES = {
    'CIFAR10': 3 * 32 * 32,
    'CIFAR100': 3 * 32 * 32,
    'MNIST': 28 * 28,
    'FashionMNIST': 28 * 28,
}
# Number of target classes per dataset.
DS_NUM_CLASSES = {
    'CIFAR10': 10,
    'CIFAR100': 100,
    'MNIST': 10,
    'FashionMNIST': 10,
}

# Per-dataset preprocessing: convert to tensor, then normalize each channel with the given (mean, std).
# Single-channel datasets use one-element tuples; note that `(0.1307)` is just a float, not a tuple.
DS_TRANSFORMS = {
    'CIFAR10': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))]),
    'CIFAR100': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))]),
    'MNIST': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
    'FashionMNIST': transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))]),
}

# Activation name -> module class (instantiated with no arguments by the model).
ACTIVATIONS = {
    'identity': nn.Identity,
    'sin': SinAct,
    'tanh': nn.Tanh,
    'selu': nn.SELU,
    'relu': nn.ReLU,
    'leaky_relu': nn.LeakyReLU
}

# Activation name -> gain value associated with that activation.
GAINS = {
    'identity': 1,
    'sin': 1,
    'tanh': 5/3,
    'selu': 3/4,
    'relu': np.sqrt(2),
    'leaky_relu': nn.init.calculate_gain('leaky_relu')
}

# Setup configuration

In [7]:
# Configuration setup
# Experiment configuration. Entries marked TBD are filled in per run by the training loop.
config = {
    "dataset": "MNIST",
    "num_layers": None,  # TBD
    "hidden_dim": 100,
    "batch_size": 100,
    "init_type": None,  # TBD
    "norm_type": "torch_bn",
    "activation": None,  # TBD
    "learning_rate": 0.001,
    "order": "norm_act",
    "bias": True,
    "mean_reduction": None,  # Not needed if we use default pytorch BN
    "force_factor": None,  # Not needed if we use default pytorch BN
    "gain_exponent": -0.4,
    "num_epochs": 1000,
}


# Load dataset

In [8]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Look up the torchvision dataset class named in the config (e.g. torchvision.datasets.MNIST).
ds = getattr(torchvision.datasets, config['dataset'])
# NOTE(review): only ToTensor() is applied here; the per-dataset normalization in
# DS_TRANSFORMS is defined in this notebook but not used. Confirm whether that is intended.
transform = torchvision.transforms.ToTensor()  # Replace with actual transform
# Download (if needed) the train and test splits under ../Data.
trainset = ds(root='../Data', train=True, download=True, transform=transform)
testset = ds(root='../Data', train=False, download=True, transform=transform)

# Materialize each split as tensors on `device` and wrap them in the in-memory batch
# iterator; only the training data is shuffled.
trainloader = TensorDataLoader(*dataset_to_tensors(trainset, device=device), batch_size=config['batch_size'], shuffle=True)
testloader = TensorDataLoader(*dataset_to_tensors(testset, device=device), batch_size=config['batch_size'], shuffle=False)



# To run all night

In [None]:
import os

# Define the combinations of parameters : RUN ONLY ONE COMBINATION PER COLAB TO PARALLELIZE
# Each entry is (init_type, num_layers, activation); one full training run is done per entry.
combinations = [

    # ('orthogonal', 100, 'identity'),
    # ('orthogonal', 100, 'tanh'),
    # ('orthogonal', 100, 'sin'),

]

# NOTE(review): these are relative paths; confirm the working directory is on persistent
# storage (e.g. the mounted Drive) if results must survive the Colab session.
save_path = 'training/mnist_v2/'
checkpoint_path = 'training/checkpoints/'

# Ensure directories exists
os.makedirs(checkpoint_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)

for init_type, num_layers, activation in combinations:
    # Update the config dictionary with the current combination
    config['init_type'] = init_type
    config['num_layers'] = num_layers
    config['activation'] = activation

    # Instantiate the model with the updated configuration; input/output sizes follow
    # the configured dataset instead of being hard-coded.
    model = MLPWithBatchNorm(
        input_dim=DS_INPUT_SIZES[config['dataset']],
        output_dim=DS_NUM_CLASSES[config['dataset']],
        num_layers=config['num_layers'],
        hidden_dim=config['hidden_dim'],
        norm_type=config['norm_type'],
        mean_reduction=config['mean_reduction'],
        activation=ACTIVATIONS[config['activation']],
        save_hidden=False,
        exponent=config['gain_exponent'],
        order=config['order'],
        force_factor=config['force_factor'],
        bias=config['bias']
    ).to(device)

    model.reset_parameters(config['init_type'])

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'])

    start_epoch = 1
    df = []  # Per-epoch history rows; turned into a DataFrame at the end.
    checkpoint_file = os.path.join(checkpoint_path, f'checkpoint_d{num_layers}_{activation}_{init_type}.pt')

    # Load checkpoint if it exists
    if os.path.isfile(checkpoint_file):
        # map_location lets a checkpoint written on GPU be resumed on a CPU runtime (and vice versa).
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # The checkpoint stores the last *completed* epoch, so resume with the next one;
        # otherwise that epoch is trained twice and logged twice in `df`.
        start_epoch = checkpoint['epoch'] + 1
        df = checkpoint['df']

    progress_bar = tqdm(range(start_epoch, config['num_epochs'] + 1), desc="Training")
    for epoch in progress_bar:
        train_loss, train_acc = train_one_epoch(model, trainloader, optimizer, criterion, device)
        test_loss, test_acc = test_one_epoch(model, testloader, criterion, device)

        df.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'test_loss': test_loss,
            'train_acc': train_acc,
            'test_acc': test_acc,
        })
        progress_bar.set_description(f"Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")


        # Save checkpoint after every completed epoch so an interrupted run can resume.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'df': df,
        }, checkpoint_file)

    # Save the final results to CSV
    results_df = pd.DataFrame(df)
    save_name = f'mnist_d{num_layers}_{activation}_{init_type}.csv'
    results_df.to_csv(os.path.join(save_path, save_name))


Epoch: 12, Train Loss: 1.1155, Train Acc: 0.8422, Test Loss: 1.4470, Test Acc: 0.8723:   1%|          | 12/1000 [04:30<6:07:38, 22.33s/it]