In [None]:
# %% Deep learning - Section 11.110
#    Distribution of weights pre- and post-learning

# This code pertains to a deep learning course provided by Mike X. Cohen on Udemy:
#   > https://www.udemy.com/course/deeplearning_x
# The "base" code in this repository is adapted (with very minor modifications)
# from code developed by the course instructor (Mike X. Cohen), while the
# "exercises" and the "code challenges" contain more original solutions and
# creative input from my side. If you are interested in DL (and if you are
# reading this statement, chances are that you are), go check out the course, it
# is singularly good.


In [None]:
# %% Libraries and modules
import numpy               as np
import matplotlib.pyplot   as plt
import torch
import torch.nn            as nn
import seaborn             as sns
import copy
import torch.nn.functional as F
import pandas              as pd
import scipy.stats         as stats
import time

from torch.utils.data                 import DataLoader,TensorDataset
from sklearn.model_selection          import train_test_split
from google.colab                     import files
from torchsummary                     import summary
from IPython                          import display
from matplotlib_inline.backend_inline import set_matplotlib_formats
# Ask the inline backend to render figures as SVG
set_matplotlib_formats('svg')


In [None]:
# %% Reminder

# The whole purpose of deep learning is to find weights minimising the loss
# function (i.e., the difference between what the model predicts and what the
# reality of the data is).
# This is why it can be useful to analyse or visualise the distribution of the
# weights before and after training/model fitting


In [None]:
# %% Data

# Load data (pass the path directly so that NumPy opens AND closes the file;
# the previous `open(...)` handle was never closed)
data = np.loadtxt('sample_data/mnist_train_small.csv',delimiter=',')

# Split labels from data (first column holds the label, the rest the pixels)
labels = data[:,0]
data   = data[:,1:]

# Normalise data (original range is (0,255))
data_norm = data / np.max(data)


In [None]:
# %% Create train and test datasets

# Tensors: float32 features and int64 class labels
data_tensor   = torch.tensor(data_norm,dtype=torch.float32)
labels_tensor = torch.tensor(labels,dtype=torch.long)

# Hold out 10% of the samples for testing (scikit-learn split)
train_data,test_data,train_labels,test_labels = train_test_split(
    data_tensor,
    labels_tensor,
    test_size=0.1,
)

# Wrap features and labels into PyTorch datasets
train_data = TensorDataset(train_data,train_labels)
test_data  = TensorDataset(test_data,test_labels)

# DataLoaders: shuffled mini-batches for training, one single batch holding
# the whole test set for evaluation
batch_size   = 32
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader  = DataLoader(test_data,batch_size=len(test_data))


In [None]:
# %% Function to generate the model

def gen_model(drop_rate):
    """Build a fresh feed-forward MNIST classifier with its loss and optimizer.

    Args:
        drop_rate: dropout probability applied after each hidden activation.

    Returns:
        Tuple `(network, loss_function, optimizer)`.
    """

    class mnist_FFN(nn.Module):
        # Fully connected network: 784 -> 64 -> 32 -> 32 -> 10

        def __init__(self,dropout_rate):
            super().__init__()

            # Layers
            self.input  = nn.Linear(784,64)
            self.fc1    = nn.Linear( 64,32)
            self.fc2    = nn.Linear( 32,32)
            self.output = nn.Linear( 32,10)

            # Dropout probability used in forward()
            self.dr = dropout_rate

        def forward(self,x):
            # Every non-output layer follows the same recipe:
            # linear -> ReLU -> dropout (dropout only acts in training mode)
            for hidden in (self.input,self.fc1,self.fc2):
                activated = F.relu(hidden(x))
                x = F.dropout(activated,p=self.dr,training=self.training)

            # The output layer returns raw scores (no activation)
            return self.output(x)

    # Network instance
    net = mnist_FFN(drop_rate)

    # Cross-entropy loss on the raw output scores
    criterion = nn.CrossEntropyLoss()

    # Plain SGD (slows down learning for illustration purposes)
    sgd = torch.optim.SGD(net.parameters(),lr=0.01)

    return net,criterion,sgd


In [None]:
# %% Explore the model

# Throwaway network without dropout, only used for inspection
drop_rate   = 0
example_net = gen_model(drop_rate)[0]

# Overall architecture
print('Model summary:',example_net,'',sep='\n')

# Attributes of a single layer
print('Summary of input layer:',vars(example_net.input),'',sep='\n')

# Shape and values of one weight matrix
print(
    'Input layer weights:',
    example_net.input.weight.shape,
    example_net.input.weight,
    '',
    sep='\n',
)

# Flatten the input-layer weights and show their distribution as a histogram
w = example_net.input.weight.detach().flatten()
plt.hist(w,bins=40)
plt.title('Distribution of initialized input layer weights')
plt.xlabel('Weight value')
plt.ylabel('Count')

# Save to disk, display, then trigger the Colab download of the saved file
plt.savefig('figure29_weights_distribution_pre_post_learning.png')

plt.show()

files.download('figure29_weights_distribution_pre_post_learning.png')


In [None]:
# %% Function to create a histogram of all weights across all layers

def weights_histogram(network,bin_edges=None):
    """Density histogram of all parameters of `network`, pooled together.

    Every tensor yielded by `network.parameters()` is flattened and pooled
    into a single vector before binning.

    Args:
        network: object exposing `.parameters()` (e.g. an `nn.Module`).
        bin_edges: optional bin specification forwarded to `np.histogram`.
            Defaults to 100 equal-width bins spanning [-0.3, 0.3], which was
            previously hard-coded.

    Returns:
        hist_x: centres of the histogram bins.
        hist_y: density in each bin.
    """

    # Default binning (kept identical to the former hard-coded range)
    if bin_edges is None:
        bin_edges = np.linspace(-.3,.3,101)

    # Pool all parameters with a single concatenation. The leading empty
    # float64 array keeps the pooled dtype as before and lets the call succeed
    # for a network without parameters. `.cpu()` allows parameters that live
    # on an accelerator.
    chunks  = [np.empty(0)]
    chunks += [p.detach().cpu().flatten().numpy() for p in network.parameters()]
    W = np.concatenate(chunks)

    # Compute histogram and convert the returned bin edges into bin centres
    hist_y,hist_x = np.histogram(W,bins=bin_edges,density=True)
    hist_x = (hist_x[1:]+hist_x[:-1])/2

    return hist_x,hist_y


In [None]:
# %% Test the histogram function

# Histogram of the untrained example network (all parameters pooled)
hist_x,hist_y = weights_histogram(example_net)

# Label the figure so that it can be read on its own, and end with show()
# so that the Line2D repr is not echoed as cell output
plt.plot(hist_x,hist_y)
plt.xlabel('Weight value')
plt.ylabel('Density')
plt.title('Distribution of initialized weights (all layers)')
plt.show()


In [None]:
# %% Function to train the model

def train_model(num_epochs=60,drop_rate=0.25):
    """Train a fresh model and record the pooled weight histogram per epoch.

    Relies on the notebook-level `gen_model`, `weights_histogram`,
    `train_loader` and `test_loader`.

    The histogram stored in row `i` is taken *before* the training pass of
    epoch `i`: row 0 describes the initialised network, and the state reached
    after the last epoch is not recorded.

    Args:
        num_epochs: number of passes over `train_loader` (default 60).
        drop_rate: dropout probability handed to `gen_model` (default 0.25).

    Returns:
        train_acc: per-epoch mean of the batch accuracies (%).
        test_acc: per-epoch accuracy (%) on the first batch of `test_loader`.
        losses: per-epoch mean of the batch losses.
        ANN: the trained network (left in training mode).
        hist_x: histogram bin centres.
        hist_y: array (num_epochs x number of bins) of histogram densities.
    """

    # Model instance, loss function and optimizer
    ANN,loss_fun,optimizer = gen_model(drop_rate)

    # Initialise outputs
    losses    = []
    train_acc = []
    test_acc  = []

    # Size the storage from the histogram function itself instead of
    # hard-coding the number of bins
    hist_x,_ = weights_histogram(ANN)
    hist_y   = np.zeros((num_epochs,len(hist_x)))

    # Loop over epochs
    for epoch_i in range(num_epochs):

        # Weights distribution at the start of this epoch
        hist_y[epoch_i,:] = weights_histogram(ANN)[1]

        # Loop over training batches
        batch_acc  = []
        batch_loss = []

        for X,y in train_loader:

            # Forward propagation and loss
            yHat = ANN(X)
            loss = loss_fun(yHat,y)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Loss and accuracy from this batch, stored as plain Python
            # floats (not 0-d tensors) so that NumPy can average them directly
            batch_loss.append(loss.item())

            matches = (torch.argmax(yHat,axis=1) == y).float()
            batch_acc.append(100 * matches.mean().item())

        losses.append( np.mean(batch_loss) )
        train_acc.append( np.mean(batch_acc) )

        # Test accuracy (evaluation mode switches dropout off)
        ANN.eval()

        with torch.no_grad():
            X,y = next(iter(test_loader))
            yHat = ANN(X)

        test_matches = (torch.argmax(yHat,axis=1) == y).float()
        test_acc.append( 100 * test_matches.mean().item() )

        # Back to training mode for the next epoch
        ANN.train()

    return train_acc,test_acc,losses,ANN,hist_x,hist_y


In [None]:
# %% Run the training

# Train a fresh model; keeps the accuracy and loss curves, the trained network,
# and the weight histograms (bin centres in hist_x, densities in hist_y)
train_acc,test_acc,losses,ANN,hist_x,hist_y = train_model()


In [None]:
# %% Plotting

# Two panels side by side, figure proportions based on the golden ratio
phi = ( 1 + np.sqrt(5) ) / 2
fig,ax = plt.subplots(nrows=1,ncols=2,figsize=(1.5*6*phi,6))

# Left panel: training loss per epoch
ax[0].plot(losses)
ax[0].set_title('Model loss')
ax[0].set_ylim([0,3])
ax[0].set_ylabel('Loss')
ax[0].set_xlabel('Epochs')

# Right panel: train and test accuracy per epoch
for curve,curve_label in ((train_acc,'Train accuracy'),(test_acc,'Test accuracy')):
    ax[1].plot(curve,label=curve_label)
ax[1].set_title(f'Final model test accuracy: {test_acc[-1]:.2f}%')
ax[1].set_ylim([10,100])
ax[1].set_ylabel('Accuracy (%)')
ax[1].set_xlabel('Epochs')
ax[1].legend()

# Save to disk, display, then trigger the Colab download of the saved file
plt.savefig('figure33_weights_distribution_pre_post_learning.png')

plt.show()

files.download('figure33_weights_distribution_pre_post_learning.png')


In [None]:
# %% Plotting

# One histogram row per training epoch
num_epochs = hist_y.shape[0]

phi = ( 1 + np.sqrt(5) ) / 2
fig,ax = plt.subplots(1,2,figsize=(1.5*6*phi,6))

# Left panel: one line per epoch, coloured from dark (early) to bright (late)
cmaps = plt.cm.plasma(np.linspace(.2,.9,num_epochs))
for i in range(num_epochs):
    ax[0].plot(hist_x,hist_y[i,:],color=cmaps[i])

ax[0].set_title('Histograms of weights\n(brighter is later in training)')
ax[0].set_xlabel('Weight value')
ax[0].set_ylabel('Density')

# Right panel: the same histograms as an image (rows = epochs, columns = bins).
# `extent` is the OUTER boundary of the image, whereas hist_x holds bin centres
# and epochs are row indices: pad by half a bin / half a row so that every
# pixel is centred on its own bin centre and epoch number.
half_bin = (hist_x[1]-hist_x[0])/2
extent   = [hist_x[0]-half_bin,hist_x[-1]+half_bin,-.5,num_epochs-.5]
ax[1].imshow(hist_y,vmin=0,vmax=3,extent=extent,aspect='auto',origin='lower',cmap='hot')
ax[1].set_xlabel('Weight value')
ax[1].set_ylabel('Training epoch')
ax[1].set_title("Image of weight histograms\n(note the similarity with Sauron's eye)")

# Save to disk, display, then trigger the Colab download of the saved file
plt.savefig('figure34_weights_distribution_pre_post_learning.png')

plt.show()

files.download('figure34_weights_distribution_pre_post_learning.png')


In [None]:
# %% Exercise 1
#    Separate the distributions for input, hidden, and output layers.
#    Are the learning-related changes similar across all layers?

# The distribution of the weights gets less and less uniform as one goes through
# the layers, and that is already evident even in the initialised weights. This
# trend seems however clearer for the input layer than for the later layers. The
# dropout regularisation (0 or 0.25) doesn't seem to change much.

# Modified model
def train_model(num_epochs=60,drop_rate=0):
    """Train a fresh model and record one weight histogram per layer and epoch.

    NOTE: this definition shadows the `train_model` defined earlier in the
    notebook. Both return six values, but the last two differ here (a dict of
    per-layer histograms and the bin centres), so re-running an earlier cell
    after this one silently unpacks the wrong objects.

    Relies on the notebook-level `gen_model`, `train_loader` and `test_loader`.
    Only the `.weight` matrices of the four layers are binned (no biases), and
    the histogram in row `i` is taken before the training pass of epoch `i`.

    Args:
        num_epochs: number of passes over `train_loader` (default 60).
        drop_rate: dropout probability handed to `gen_model` (default 0).

    Returns:
        train_acc: per-epoch mean of the batch accuracies (%).
        test_acc: per-epoch accuracy (%) on the first batch of `test_loader`.
        losses: per-epoch mean of the batch losses.
        ANN: the trained network (left in training mode).
        histograms_per_layer: dict mapping a layer label to an array
            (num_epochs x number of bins) of histogram densities.
        bin_centers: centres of the histogram bins.
    """

    ANN,loss_fun,optimizer = gen_model(drop_rate)

    losses    = []
    train_acc = []
    test_acc  = []

    # Define histogram bins
    bin_edges   = np.linspace(-0.7,0.7,101)
    bin_centers = (bin_edges[:-1]+bin_edges[1:])/2

    # Explicit label -> layer mapping (instead of zipping dict keys against a
    # separately maintained list of layers)
    layers = {
        "Input layer (784→64)":   ANN.input,
        "Hidden layer 1 (64→32)": ANN.fc1,
        "Hidden layer 2 (32→32)": ANN.fc2,
        "Output layer (32→10)":   ANN.output,
    }

    # Track histograms for each layer
    histograms_per_layer = {
        name: np.zeros((num_epochs,len(bin_centers))) for name in layers
    }

    for epoch in range(num_epochs):

        # Capture per-layer histograms at the start of the epoch
        for name,layer in layers.items():
            w       = layer.weight.detach().cpu().flatten().numpy()
            hist, _ = np.histogram(w,bins=bin_edges,density=True)
            histograms_per_layer[name][epoch,:] = hist

        # Training loop
        batch_loss = []
        batch_acc  = []

        for X, y in train_loader:
            yHat = ANN(X)
            loss = loss_fun(yHat,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Store plain Python floats (not 0-d tensors) so that NumPy can
            # average them directly
            batch_loss.append(loss.item())
            matches = (torch.argmax(yHat,axis=1)==y).float()
            batch_acc.append(100*matches.mean().item())

        losses.append(np.mean(batch_loss))
        train_acc.append(np.mean(batch_acc))

        # Test accuracy (evaluation mode switches dropout off)
        ANN.eval()

        with torch.no_grad():
            X,y = next(iter(test_loader))
            yHat = ANN(X)

        test_matches = (torch.argmax(yHat,axis=1)==y).float()
        test_acc.append(100*test_matches.mean().item())

        # Back to training mode for the next epoch
        ANN.train()

    return train_acc,test_acc,losses,ANN,histograms_per_layer,bin_centers

# Define a plotting function
def plot_layer_weights(histograms_per_layer,bin_centers):
    """Plot, for each layer, its weight histograms across epochs.

    One subplot per layer; within a subplot there is one line per epoch,
    coloured from dark (early) to bright (late). The figure is saved as a PNG,
    shown, and then downloaded through Colab's `files.download`.

    Args:
        histograms_per_layer: dict mapping a layer label (used as the subplot
            title) to an array with one histogram per row (rows = epochs).
        bin_centers: x-coordinates shared by all histograms.
    """

    num_layers = len(histograms_per_layer)

    # Two columns. Keep the 2x2 grid for up to four layers and add rows for
    # more, so that extra layers are no longer silently dropped by zip()
    num_cols = 2
    num_rows = max(2,int(np.ceil(num_layers/num_cols)))

    phi = ( 1 + np.sqrt(5) ) / 2
    fig,axs = plt.subplots(num_rows,num_cols,figsize=(8*phi,4*num_rows))

    axs = axs.flatten()

    for ax,(layer_name, hist_matrix) in zip(axs,histograms_per_layer.items()):
        num_epochs = hist_matrix.shape[0]
        cmaps      = plt.cm.plasma(np.linspace(.2,.9,num_epochs))

        for i in range(num_epochs):
            ax.plot(bin_centers,hist_matrix[i,:],color=cmaps[i],linewidth=1)

        ax.set_title(layer_name)
        ax.set_xlabel("Weight value")
        ax.set_ylabel("Density")
        ax.grid(True)

    # Hide the subplots that have no layer assigned
    for ax in axs[num_layers:]:
        ax.axis("off")

    plt.suptitle("Weight histograms per layer over epochs\n(brighter is later in training)",fontsize=14)
    plt.tight_layout()

    # Save to disk, display, then trigger the Colab download of the saved file
    plt.savefig('figure35_weights_distribution_pre_post_learning_extra1.png')

    plt.show()

    files.download('figure35_weights_distribution_pre_post_learning_extra1.png')

# Train with the per-layer version of train_model defined in this cell, then
# plot the histograms of each layer across epochs
train_acc,test_acc,losses,ANN,hists,bin_centers = train_model()
plot_layer_weights(hists,bin_centers)


In [None]:
# %% Exercise 2
#    Re-run the code without data normalization.
#    Does the scale of the data affect the scale of the weights?

# No, it doesn't seem to matter too much for these data, which are most likely
# already distributed in a fairly normalised way


In [None]:
# %% Exercise 3
#    Test how dropout regularization affects the weight distributions.

# Again, given that the model was already performing quite well, the impact of
# a 0.25 dropout regularisation doesn't seem too strong
