# Import libraries

In [None]:
# Import PyTorch
import torch
from torch import nn
from torch import autograd

# Import torchvision
import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor

# Import matplotlib
import matplotlib.pyplot as plt

# For data
from sklearn.model_selection import train_test_split

# Check versions
print(torch.__version__)
print(torchvision.__version__)

# Setup device-agnostic code

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# The device must be given as the string 'cuda'; a bare `cuda` name is undefined.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

# Utils

In [None]:
from timeit import default_timer as timer 
def print_train_time(start: float, end: float, device: torch.device = None):
    """Report and return the elapsed time between two timestamps.

    Args:
        start (float): Timestamp taken before the computation.
        end (float): Timestamp taken after the computation.
        device (optional): Device label interpolated into the printed message.
            Defaults to None.

    Returns:
        float: ``end - start`` (seconds when both come from ``timeit.default_timer``).
    """
    elapsed = end - start
    message = f"Train time on {device}: {elapsed:.3f} seconds"
    print(message)
    return elapsed

In [None]:
# Accuracy as a percentage (a classification metric)
def accuracy_fn(y_true, y_pred):
    """Return the percentage of positions where predictions equal the labels.

    Args:
        y_true (torch.Tensor): Ground-truth labels.
        y_pred (torch.Tensor): Predicted labels, compared element-wise to ``y_true``.

    Returns:
        float: Accuracy in percent, e.g. 78.45.
    """
    n_correct = torch.eq(y_true, y_pred).sum().item()
    n_total = len(y_pred)
    return (n_correct / n_total) * 100

# Engine

In [None]:
def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               accuracy_fn,
               device: torch.device = device):
    """Run one backpropagation training epoch and print the mean loss/accuracy.

    Args:
        model: Network to train; moved to ``device`` and switched to train mode.
        data_loader: Iterable yielding ``(X, y)`` batches.
        loss_fn: Callable invoked as ``loss_fn(y_pred, y)``.
        optimizer: Optimizer stepped once per batch.
        accuracy_fn: Callable invoked as ``accuracy_fn(y_true=..., y_pred=...)``
            with the arg-max over dim 1 of the model output.
        device: Device the model and every batch are moved to.
    """
    train_loss, train_acc = 0.0, 0.0
    model.to(device)
    # test_step leaves the model in eval mode, so re-enable train mode each epoch.
    model.train()
    for X, y in data_loader:
        # Send data to the target device
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        y_pred = model(X)

        # 2. Calculate loss
        loss = loss_fn(y_pred, y)
        # Accumulate a plain float: adding the tensor itself would keep every
        # batch's autograd history alive until the end of the epoch.
        train_loss += loss.item()

        # Update accuracy
        train_acc += accuracy_fn(y_true=y,
                                 y_pred=y_pred.argmax(dim=1))

        # 3. Optimizer zero grad
        optimizer.zero_grad()

        # 4. Loss backward
        loss.backward()

        # 5. Optimizer step
        optimizer.step()

    # Average the per-batch metrics over the epoch and report them
    num_batches = len(data_loader)
    train_loss /= num_batches
    train_acc /= num_batches
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")

def test_step(data_loader: torch.utils.data.DataLoader,
              model: torch.nn.Module,
              loss_fn: torch.nn.Module,
              accuracy_fn,
              device: torch.device = device):
    """Evaluate ``model`` on ``data_loader`` and print the mean loss/accuracy.

    The model is moved to ``device`` and put in eval mode; the whole loop,
    including the averaging and the print, runs under ``torch.inference_mode``.
    """
    model.to(device)
    model.eval()  # evaluation behaviour for the layers that distinguish it
    loss_total, acc_total = 0, 0
    with torch.inference_mode():
        for features, labels in data_loader:
            # Move the batch next to the model
            features = features.to(device)
            labels = labels.to(device)

            # Forward pass
            logits = model(features)

            # Accumulate the batch loss and accuracy
            loss_total += loss_fn(logits, labels)
            acc_total += accuracy_fn(y_true=labels,
                                     y_pred=logits.argmax(dim=1))

        # Average over the number of batches and report
        n_batches = len(data_loader)
        loss_total /= n_batches
        acc_total /= n_batches
        print(f"Test loss: {loss_total:.5f} | Test accuracy: {acc_total:.2f}%\n")

In [None]:
from tqdm.auto import tqdm

def train_test_BP(model: torch.nn.Module,
             train_dataloader: torch.utils.data.DataLoader,
             test_dataloader: torch.utils.data.DataLoader,
             loss_fn: torch.nn.Module,
             optimizer: torch.optim.Optimizer,
             accuracy_fn,
             lr: float,
             epochs: int,
             device: torch.device = device):
    """Train with backpropagation and evaluate after every epoch.

    Args:
        model: Network to train.
        train_dataloader: Batches passed to ``train_step``.
        test_dataloader: Batches passed to ``test_step``.
        loss_fn: Loss callable forwarded to both steps.
        optimizer: Optimizer to use. Pass ``None`` to get a plain SGD over
            ``model.parameters()`` with learning rate ``lr``.
        accuracy_fn: Accuracy callable forwarded to both steps.
        lr: Learning rate; only used when ``optimizer`` is ``None``.
        epochs: Number of train/test rounds.
        device: Device forwarded to both steps and shown in the timing message.
    """
    # Honour the optimizer supplied by the caller instead of silently
    # replacing it; only build the default SGD when none was given.
    if optimizer is None:
        optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)

    # Measure time
    time_start = timer()

    # Train and test model
    for epoch in tqdm(range(epochs)):
        print(f"Epoch: {epoch}\n--------")
        train_step(model=model,
                   data_loader=train_dataloader,
                   loss_fn=loss_fn,
                   optimizer=optimizer,
                   accuracy_fn=accuracy_fn,
                   device=device)
        test_step(model=model,
                  data_loader=test_dataloader,
                  loss_fn=loss_fn,
                  accuracy_fn=accuracy_fn,
                  device=device)

    time_end = timer()
    print_train_time(start=time_start,
                     end=time_end,
                     device=device)

In [None]:
def train_step_PEPITA(model: torch.nn.Module,
                      data_loader: torch.utils.data.DataLoader,
                      loss_fn: torch.nn.Module,
                      accuracy_fn,
                      lr: float,
                      device: torch.device = device):
    """Run one PEPITA training epoch and print the mean loss/accuracy.

    For every batch the model does a standard forward pass, then
    ``model.modulated_forward`` is called with the batch, the prediction and the
    one-hot target so the model can update its own weights. No autograd graph
    is built: the whole batch is processed under ``torch.no_grad``.

    Args:
        model: Network exposing ``modulated_forward(X, Y, target, lr)``.
        data_loader: Iterable yielding ``(X, y)`` batches; ``y`` holds integer
            class indices (required by ``one_hot``).
        loss_fn: Callable invoked as ``loss_fn(y_pred, y)``; used for reporting only.
        accuracy_fn: Callable invoked as ``accuracy_fn(y_true=..., y_pred=...)``.
        lr: Learning rate forwarded to ``modulated_forward``.
        device: Device the model and every batch are moved to.
    """
    train_loss, train_acc = 0.0, 0.0
    model.to(device)
    # test_step leaves the model in eval mode, so re-enable train mode each epoch.
    model.train()
    for X, y in data_loader:
        with torch.no_grad():
            # Send data to the target device
            X, y = X.to(device), y.to(device)

            # 1. Forward pass
            y_pred = model(X)

            # 2. Calculate loss
            train_loss += loss_fn(y_pred, y).item()

            # Update accuracy
            train_acc += accuracy_fn(y_true=y,
                                     y_pred=y_pred.argmax(dim=1))

            # The one-hot target must have the same width as the model output,
            # so take the class count from the prediction instead of relying
            # on a notebook-level `num_classes` global.
            target = torch.nn.functional.one_hot(y, num_classes=y_pred.shape[-1])
            model.modulated_forward(X=X,
                                    Y=y_pred,
                                    target=target,
                                    lr=lr)

    # Average the per-batch metrics over the epoch and report them
    num_batches = len(data_loader)
    train_loss /= num_batches
    train_acc /= num_batches
    print(f"Train loss: {train_loss:.5f} | Train accuracy: {train_acc:.2f}%")

In [None]:
import torch.nn.functional as F

def train_test_PEPITA(model: torch.nn.Module,
                 train_dataloader: torch.utils.data.DataLoader,
                 test_dataloader: torch.utils.data.DataLoader,
                 loss_fn: torch.nn.Module,
                 accuracy_fn,
                 lr: float,
                 epochs: int,
                 device: torch.device = device):
    """Train with PEPITA and evaluate after every epoch.

    Args:
        model: Network exposing ``modulated_forward``; moved to ``device``.
        train_dataloader: Batches passed to ``train_step_PEPITA``.
        test_dataloader: Batches passed to ``test_step``.
        loss_fn: Loss callable forwarded to both steps.
        accuracy_fn: Accuracy callable forwarded to both steps.
        lr: Learning rate forwarded to ``train_step_PEPITA``.
        epochs: Number of train/test rounds.
        device: Device forwarded to both steps and shown in the timing message.
    """
    # Measure time
    time_start = timer()

    model = model.to(device)

    # The number of classes is not needed here, so the dataset is not
    # required to expose a `.classes` attribute.
    for epoch in tqdm(range(epochs)):
        print(f"Epoch: {epoch}\n--------")
        train_step_PEPITA(model=model,
                          data_loader=train_dataloader,
                          loss_fn=loss_fn,
                          accuracy_fn=accuracy_fn,
                          lr=lr,
                          device=device)
        test_step(model=model,
                  data_loader=test_dataloader,
                  loss_fn=loss_fn,
                  accuracy_fn=accuracy_fn,
                  device=device)

    time_end = timer()
    print_train_time(start=time_start,
                     end=time_end,
                     device=device)

# Data

In [None]:
from torchvision import datasets 

def get_train_test_data(dataset_class, root):
    """Instantiate the train and test splits of a dataset class.

    ``dataset_class`` is called twice, with ``train=True`` and then
    ``train=False``; both calls use ``download=True``, a fresh ``ToTensor()``
    transform and no target transform.

    Returns:
        tuple: ``(train_data, test_data)``.
    """
    splits = []
    for is_train in (True, False):
        splits.append(
            dataset_class(
                root=root,
                train=is_train,
                download=True,
                transform=ToTensor(),
                target_transform=None,
            )
        )
    train_data, test_data = splits
    return train_data, test_data

In [None]:
def plot_some_images(train_data, class_names):
    """Draw a 4x4 grid of randomly picked samples titled with their class name.

    The global torch RNG is re-seeded with 42 first, so the same indices are
    drawn on every call for a dataset of a given length.
    """
    torch.manual_seed(42)
    n_rows = n_cols = 4
    figure = plt.figure(figsize=(9, 9))
    for cell in range(n_rows * n_cols):
        sample_idx = torch.randint(0, len(train_data), size=[1]).item()
        image, target = train_data[sample_idx]
        # Subplot positions are 1-based
        figure.add_subplot(n_rows, n_cols, cell + 1)
        plt.imshow(image.squeeze(), cmap="gray")
        plt.title(class_names[target])
        plt.axis(False)

In [None]:
# Build the dataloaders for a torchvision dataset class (e.g. MNIST or FashionMNIST)
from torchvision import datasets 
from torch.utils.data import DataLoader

def data(dataset_class, batch_size: int = 32, root: str = 'data'):
    """Load a dataset, plot a few samples and wrap both splits in dataloaders.

    Args:
        dataset_class: Dataset class forwarded to ``get_train_test_data``; the
            returned train split must expose ``classes``, ``class_to_idx``,
            ``data`` and ``targets`` (the test split ``data`` and ``targets``).
        batch_size: Batch size of both dataloaders. Defaults to 32.
        root: Directory the dataset is stored in / downloaded to. Defaults to 'data'.

    Returns:
        tuple: ``(train_dataloader, test_dataloader, class_names, class_to_idx,
        train_data, test_data, n, m)`` where ``n`` is the number of elements of
        one input sample and ``m`` the number of classes.
    """
    train_data, test_data = get_train_test_data(dataset_class=dataset_class, root=root)

    class_names = train_data.classes
    class_to_idx = train_data.class_to_idx
    # How many samples are there?
    print(f"train X: {len(train_data.data)}, train y: {len(train_data.targets)}, test X: {len(test_data.data)}, test y: {len(test_data.targets)}")

    plot_some_images(train_data=train_data, class_names=class_names)

    # Turn the datasets into iterables of batches
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=batch_size,
                                  shuffle=True)

    # The test split is not shuffled so every evaluation sees the same batches
    test_dataloader = DataLoader(dataset=test_data,
                                 batch_size=batch_size,
                                 shuffle=False)

    # Let's check out what we have created
    print(f"Length of train_dataloader: {len(train_dataloader)} batches of {batch_size}...")
    print(f"Length of test_dataloader: {len(test_dataloader)} batches of {batch_size}...")
    # Check out what's inside the training dataloader
    train_features_batch, train_labels_batch = next(iter(train_dataloader))
    print(f"Shape of the training batch: {train_features_batch.size()}, Shape of the training label batch: {train_labels_batch.size()}")
    n = train_features_batch.size()[1:].numel()     # input dimension (batch axis excluded)
    m = len(class_names)                            # output dimension
    return train_dataloader, test_dataloader, class_names, class_to_idx, train_data, test_data, n, m

# Models

In [None]:
from torch import linalg as LA

In [None]:
class SUPER_PEP(nn.Module):
    """
        SUPER_PEP is a stack of linear layers with a modulated pass, to be trained with PEPITA, while being highly customizable:
       -layers_dim: list containing the dimensions of ALL the layers i.e. [input_dim, hidden_dim1, hidden_dim2, ..., output_dim]
       -bias: bool, do you want to use bias?
       -F_std: float, the standard deviation of the F matrix
       -vision: bool, do you want to flatten the input?
       -F_norm: bool, do you want to normalize the "projection" matrix F?
       -layer_norm: bool, do you want to normalize the output of every layer during the forward?
       -error_norm: bool, do you want to normalize the error E before propagating it in the modulated pass?
       -delta_norm: bool, do you want to normalize the delta_W before updating the weights?
       -F_decay: float, factor F is multiplied by at every modulated pass

        F_T (shape [output_dim, input_dim]) is kept as a non-persistent buffer:
        it follows the module across devices but is not part of the state_dict.
    """
    def __init__(self,
                 layers_dim: list,
                 bias: bool=False,
                 F_std: float=1,
                 vision: bool=False,
                 F_norm: bool=False,
                 layer_norm: bool= False,
                 error_norm: bool=False,
                 delta_norm: bool=False,
                 F_decay: float=1):
        super().__init__()
        self.vision = vision
        self.F_norm = F_norm
        self.layer_norm = layer_norm
        self.error_norm = error_norm
        self.delta_norm = delta_norm
        self.F_decay = F_decay

        if self.vision:
            self.flatten = nn.Flatten()

        self.nb_layers = len(layers_dim)
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=layers_dim[i], out_features=layers_dim[i+1], bias=bias)
             for i in range(self.nb_layers - 1)]
        )

        F_T = F_std * torch.randn(size=(layers_dim[-1], layers_dim[0]))
        print(f"F_norm: {LA.norm(F_T)}")
        if F_norm:
            # Really important: with our initialization (N(0,1)) the expected l_2 norm squared of F is input_dim * output_dim
            F_T = F_T / LA.norm(F_T)
        # A buffer is moved by `.to(device)` together with the layers; a plain
        # tensor attribute would stay on the CPU and break torch.mm on CUDA.
        self.register_buffer("F_T", F_T, persistent=False)

        self.activations = []
        # Register hooks: every layer appends its output to self.activations
        for layer in self.layers:
            layer.register_forward_hook(lambda _module, _inputs, output: self.save_activation(output))

    def save_activation(self, output):
        """Append one layer output to the activations of the current pass."""
        self.activations.append(output)

    def forward(self, x):
        """Run the layers in order; afterwards self.activations[i] is the output of layers[i]."""
        # A new list is created (not cleared) so a caller can keep the previous one.
        self.activations = []
        if self.vision:
            x = self.flatten(x)
        for i in range(self.nb_layers-1):
            x = self.layers[i](x)
            if self.layer_norm:
                # In place: the activation saved by the hook is normalized as well
                x /= LA.matrix_norm(x)
        return x

    @torch.no_grad()  # the weights are leaf parameters updated in place
    def modulated_forward(self, X, Y, target, lr):
        """Do the modulated pass and update every layer's weight in place.

        Must be called right after `forward(X)`, whose activations are reused.
        -X: the input batch of that forward pass
        -Y: the output of that forward pass
        -target: one-hot targets with the same shape as Y
        -lr: learning rate of the update
        The index H_mod[-2] of the last update requires at least two layers.
        """
        self.F_T.mul_(self.F_decay)
        target = target.float()
        E = Y - target
        if self.error_norm:
            E = E / LA.matrix_norm(E)
        if self.vision:
            X = self.flatten(X)

        # Project the error back onto the input
        X_mod = X - torch.mm(E, self.F_T)

        H = self.activations # Forward activations (forward() below installs a new list)

        self.forward(X_mod)
        H_mod = self.activations # Modulated activations

        # First layer: its input in the modulated pass is X_mod
        delta_W = torch.mm((H[0]-H_mod[0]).T,X_mod)
        if self.delta_norm:
            delta_W /= LA.matrix_norm(delta_W)
        self.layers[0].weight -= lr * delta_W

        # Hidden layers: the input of layers[i] is the output of layers[i-1],
        # i.e. H_mod[i-1]; this gives delta_W the shape of layers[i].weight
        # for any combination of hidden widths.
        for i in range(1, self.nb_layers-2):
            delta_W = torch.mm((H[i]-H_mod[i]).T,H_mod[i-1])

            if self.delta_norm:
                delta_W /= LA.matrix_norm(delta_W)

            self.layers[i].weight -= lr * delta_W

        # Last layer: error times the modulated input of the last layer
        delta_W_L = torch.mm(E.T,H_mod[-2])
        if self.delta_norm:
            delta_W_L /= LA.matrix_norm(delta_W_L)

        self.layers[-1].weight -= lr * delta_W_L

## Direct PEPITA

In [None]:
class Dpepita(nn.Module):
    """
        Dpepita is a variant of PEPITA where the error is projected directly onto the input of every layer and there is no forward propagation of the modulated inputs, while being highly customizable:
       -layers_dim: list containing the dimensions of ALL the layers i.e. [input_dim, hidden_dim1, hidden_dim2, ..., output_dim]
       -bias: bool, do you want to use bias?
       -F_std: float, the standard deviation of the F matrices
       -vision: bool, do you want to flatten the input?
       -layer_norm: bool, do you want to normalize the output of every layer during the forward?
       -error_norm: bool, do you want to normalize the error E before propagating it in the modulated pass?
       -delta_norm: bool, do you want to normalize the delta_W before updating the weights?

        F_T is a list with one projection matrix per layer; F_T[i] has shape
        [output_dim, layers_dim[i]], i.e. it maps the error onto the input of layers[i].
        The matrices are stored as non-persistent buffers so they follow the module
        across devices without entering the state_dict.

        NOTE(review): the original modulated_forward was unfinished and could not run;
        the update rule below was reconstructed from this docstring and the F_T shapes
        and should be confirmed against the intended algorithm.
    """
    def __init__(self,
                 layers_dim: list,
                 bias: bool=False,
                 F_std: float=1,
                 vision: bool=False,
                 layer_norm: bool= False,
                 error_norm: bool=False,
                 delta_norm: bool=False):
        super().__init__()
        self.vision = vision
        self.layer_norm = layer_norm
        self.error_norm = error_norm
        self.delta_norm = delta_norm

        if self.vision:
            self.flatten = nn.Flatten()

        self.nb_layers = len(layers_dim)
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=layers_dim[i], out_features=layers_dim[i+1], bias=bias)
             for i in range(self.nb_layers - 1)]
        )

        # One projection matrix per layer, registered as buffers so that
        # `.to(device)` moves them (a plain Python list of tensors is not moved).
        for i in range(self.nb_layers - 1):
            self.register_buffer(f"_F_T_{i}",
                                 F_std * torch.randn(size=(layers_dim[-1], layers_dim[i])),
                                 persistent=False)

        self.activations = []
        # Register hooks: every layer appends its output to self.activations
        for layer in self.layers:
            layer.register_forward_hook(lambda _module, _inputs, output: self.save_activation(output))

    @property
    def F_T(self):
        """List of the per-layer projection matrices, in layer order."""
        return [getattr(self, f"_F_T_{i}") for i in range(self.nb_layers - 1)]

    def save_activation(self, output):
        """Append one layer output to the activations of the current pass."""
        self.activations.append(output)

    def forward(self, x):
        """Run the layers in order; afterwards self.activations[i] is the output of layers[i]."""
        self.activations = []
        if self.vision:
            x = self.flatten(x)
        for i in range(self.nb_layers-1):
            x = self.layers[i](x)
            if self.layer_norm:
                # torch.linalg.matrix_norm is the real name (there is no torch.linalg_matrix_norm)
                x /= LA.matrix_norm(x)
        return x

    @torch.no_grad()  # the weights are leaf parameters updated in place
    def modulated_forward(self, X, Y, target, lr):
        """Update every layer from a directly modulated version of its own input.

        Must be called right after `forward(X)`, whose activations are reused.
        For layers[i] with clean input h_in (X for the first layer, otherwise the
        output of layers[i-1]):
            h_in_mod = h_in - E @ F_T[i]
            hidden/first layers: delta_W = (H[i] - layers[i](h_in_mod)).T @ h_in_mod
            last layer:          delta_W = E.T @ h_in_mod
        The layers are applied functionally here, so the forward hooks do not
        fire and self.activations is left untouched.
        """
        target = target.float()
        E = Y - target
        if self.error_norm:
            E = E / torch.norm(E)
        if self.vision:
            X = self.flatten(X)

        H = self.activations # Forward activations
        if len(H) != len(self.layers):
            raise RuntimeError("modulated_forward must be called right after forward")

        # Clean input of every layer: X, then the output of the previous layer
        layer_inputs = [X] + list(H[:-1])
        last = len(self.layers) - 1

        for i, (layer, h_in, F_T) in enumerate(zip(self.layers, layer_inputs, self.F_T)):
            h_in_mod = h_in - torch.mm(E, F_T)

            if i == last:
                delta_W = torch.mm(E.T, h_in_mod)
            else:
                h_out_mod = nn.functional.linear(h_in_mod, layer.weight, layer.bias)
                if self.layer_norm:
                    # Same normalization the clean pass applied to H[i]
                    h_out_mod = h_out_mod / LA.matrix_norm(h_out_mod)
                delta_W = torch.mm((H[i] - h_out_mod).T, h_in_mod)

            if self.delta_norm:
                delta_W /= LA.matrix_norm(delta_W)

            layer.weight -= lr * delta_W

## Target PEPITA

In [None]:
class TARGET_PEP(nn.Module):
    """
        TARGET_PEP is what PEPITA wants to be, while being highly customizable: the projection
        matrix F_T is initialized as the product of the layer weights, W_last @ ... @ W_first.
       -layers_dim: list containing the dimensions of ALL the layers i.e. [input_dim, hidden_dim1, hidden_dim2, ..., output_dim]
       -bias: bool, do you want to use bias?
       -F_std: float, kept for signature compatibility; it is not used because F_T is built from the weights
       -vision: bool, do you want to flatten the input?
       -F_norm: bool, do you want to normalize the "projection" matrix F?
       -layer_norm: bool, do you want to normalize the output of every layer during the forward?
       -error_norm: bool, do you want to normalize the error E before propagating it in the modulated pass?
       -delta_norm: bool, do you want to normalize the delta_W before updating the weights?
       -F_decay: float, factor F is multiplied by at every modulated pass

        F_T (shape [output_dim, input_dim]) is kept as a non-persistent buffer:
        it follows the module across devices but is not part of the state_dict.
    """
    def __init__(self,
                 layers_dim: list,
                 bias: bool=False,
                 F_std: float=1,
                 vision: bool=False,
                 F_norm: bool=False,
                 layer_norm: bool= False,
                 error_norm: bool=False,
                 delta_norm: bool=False,
                 F_decay: float=1):
        super().__init__()
        self.vision = vision
        self.F_norm = F_norm
        self.layer_norm = layer_norm
        self.error_norm = error_norm
        self.delta_norm = delta_norm
        self.F_decay = F_decay

        if self.vision:
            self.flatten = nn.Flatten()

        self.nb_layers = len(layers_dim)
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=layers_dim[i], out_features=layers_dim[i+1], bias=bias)
             for i in range(self.nb_layers - 1)]
        )

        ############################################# WE CHANGE JUST THIS PART ###############################################
        # F_T is a snapshot of the initial weights, so it is computed without
        # autograd tracking (a tracked clone would stay attached to the parameters).
        with torch.no_grad():
            F_T = self.layers[0].weight.clone()
            for i in range(1, self.nb_layers - 1):
                F_T = torch.mm(self.layers[i].weight, F_T)
            print(f"F_norm: {LA.norm(F_T)}")
            if F_norm:
                F_T = F_T / LA.norm(F_T)
        # A buffer is moved by `.to(device)` together with the layers; a plain
        # tensor attribute would stay on the CPU and break torch.mm on CUDA.
        self.register_buffer("F_T", F_T, persistent=False)
        ######################################################################################################################
        self.activations = []
        # Register hooks: every layer appends its output to self.activations
        for layer in self.layers:
            layer.register_forward_hook(lambda _module, _inputs, output: self.save_activation(output))

    def save_activation(self, output):
        """Append one layer output to the activations of the current pass."""
        self.activations.append(output)

    def forward(self, x):
        """Run the layers in order; afterwards self.activations[i] is the output of layers[i]."""
        # A new list is created (not cleared) so a caller can keep the previous one.
        self.activations = []
        if self.vision:
            x = self.flatten(x)
        for i in range(self.nb_layers-1):
            x = self.layers[i](x)
            if self.layer_norm:
                # In place: the activation saved by the hook is normalized as well
                x /= LA.matrix_norm(x)
        return x

    @torch.no_grad()  # the weights are leaf parameters updated in place
    def modulated_forward(self, X, Y, target, lr):
        """Do the modulated pass and update every layer's weight in place.

        Must be called right after `forward(X)`, whose activations are reused.
        -X: the input batch of that forward pass
        -Y: the output of that forward pass
        -target: one-hot targets with the same shape as Y
        -lr: learning rate of the update
        The index H_mod[-2] of the last update requires at least two layers.
        """
        self.F_T.mul_(self.F_decay)
        target = target.float()
        E = Y - target
        if self.error_norm:
            E = E / LA.matrix_norm(E)
        if self.vision:
            X = self.flatten(X)

        # Project the error back onto the input
        X_mod = X - torch.mm(E, self.F_T)

        H = self.activations # Forward activations (forward() below installs a new list)

        self.forward(X_mod)
        H_mod = self.activations # Modulated activations

        # First layer: its input in the modulated pass is X_mod
        delta_W = torch.mm((H[0]-H_mod[0]).T,X_mod)
        if self.delta_norm:
            delta_W /= LA.matrix_norm(delta_W)
        self.layers[0].weight -= lr * delta_W

        # Hidden layers: the input of layers[i] is the output of layers[i-1],
        # i.e. H_mod[i-1]; this gives delta_W the shape of layers[i].weight
        # for any combination of hidden widths.
        for i in range(1, self.nb_layers-2):
            delta_W = torch.mm((H[i]-H_mod[i]).T,H_mod[i-1])

            if self.delta_norm:
                delta_W /= LA.matrix_norm(delta_W)

            self.layers[i].weight -= lr * delta_W

        # Last layer: error times the modulated input of the last layer
        delta_W_L = torch.mm(E.T,H_mod[-2])
        if self.delta_norm:
            delta_W_L /= LA.matrix_norm(delta_W_L)

        self.layers[-1].weight -= lr * delta_W_L

# Experiments

## Let's get MNIST

In [None]:
# Build the MNIST dataloaders. `n` is the number of elements of one input
# sample and `num_classes` the number of class names, as returned by data().
train_dataloader, test_dataloader, class_names, class_to_idx, train_data, test_data, n, num_classes = data(datasets.MNIST)
m = num_classes 

## Instantiate the model

In [None]:
# Seed the RNG so the layer weights and the random F matrix are reproducible
torch.manual_seed(42)
# torch.cuda.manual_seed(42)  # optional: also seed the CUDA RNG
# One hidden layer of 1024 units between the n inputs and the m outputs
MYFIRSTsuper_pep = SUPER_PEP(layers_dim= [n] + [1024] * 1 + [m],
                             bias=False,
                             F_std=0.01,
                             vision=True,
                             F_norm=False,
                             layer_norm=False,
                             error_norm=False,
                             delta_norm=False,
                             F_decay=0.9)
MYFIRSTsuper_pep.to(device)

In [None]:
# Check that the forward works
torch.manual_seed(42)
# The model lives on `device`, so the dummy batch must be moved there too;
# otherwise the forward fails with a device mismatch when running on CUDA.
dummy_x = torch.rand(size=(3, 1, 28, 28)).to(device)
MYFIRSTsuper_pep(dummy_x)

### Initialize TARGET_PEP

In [None]:
# Seed the RNG so the layer weights (and hence F_T, built from them) are reproducible
torch.manual_seed(42)
# torch.cuda.manual_seed(42)  # optional: also seed the CUDA RNG
# Fifteen hidden layers of 1024 units between the n inputs and the m outputs
Tpep = TARGET_PEP(layers_dim= [n] + [1024] * 15 + [m],
                             bias=False,
                             F_std=1,
                             vision=True,
                             F_norm=False,
                             layer_norm=False,
                             error_norm=False,
                             delta_norm=False,
                             F_decay=1)
Tpep.to(device)

In [None]:
# Check that the forward works
torch.manual_seed(42)
# The model lives on `device`, so the dummy batch must be moved there too;
# otherwise the forward fails with a device mismatch when running on CUDA.
dummy_x = torch.rand(size=(3, 1, 28, 28)).to(device)
Tpep(dummy_x)

In [None]:
# Choose the model that the optimizer and the training cells below operate on
model = Tpep

In [None]:
# Display the shape of the first layer's weight matrix
model.layers[0].weight.shape

In [None]:
# NOTE(review): this references the bound method without calling it, so the
# cell displays the method's repr; call `model.parameters()` to get the
# parameters themselves — confirm which one was intended.
model.parameters

## Common Hyperparameters for training

In [None]:
# Hyperparameters shared by the PEPITA and BP runs
lr = 1e-3
epochs = 10

loss_fn = nn.CrossEntropyLoss()
# accuracy_fn is the function defined in the Utils section; re-assigning it
# to itself was a no-op and has been dropped.
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)

## Train & Test with PEPITA

In [None]:
# Train `model` with PEPITA (weights are updated in place) and evaluate after every epoch
train_test_PEPITA(model = model,
                 train_dataloader = train_dataloader,
                 test_dataloader = test_dataloader,
                 loss_fn = loss_fn,
                 accuracy_fn = accuracy_fn,
                 lr = lr,
                 epochs = epochs,
                 device = device)

## Train & Test with BP

In [None]:
# Train `model` with backpropagation and evaluate after every epoch.
# NOTE(review): this is the same `model` object the PEPITA cell above updates
# in place, so when both cells are run in order BP starts from those weights;
# re-instantiate the model first if an independent comparison is wanted.
train_test_BP(model = model,
             train_dataloader = train_dataloader,
             test_dataloader = test_dataloader,
             loss_fn = loss_fn,
             optimizer = optimizer,
             accuracy_fn = accuracy_fn,
             lr = lr,
             epochs = epochs,
             device = device)

36.16 32.98 32.33 31.81 31.40 30.91