# Aleksey Senkin - Low-Rank approximation

# Exploring the rank of trained Neural Networks

In this notebook, you're going to explore trained neural networks and study the rank of their matrices.

**Reminder**: The rank of a matrix is the number of its linearly independent columns. If a matrix $A \in \mathbb{R}^{n\times m}$ has rank $k$, then $A$ can be factored exactly as

$$A = B \cdot C$$

where $B \in \mathbb{R}^{n\times k}$ and $C \in \mathbb{R}^{k\times m}$. If $A$ is only close to a rank-$k$ matrix, the same factorization gives an approximation $A \approx B \cdot C$.

You can find the rank of matrix $A$ by performing Gaussian elimination and counting the pivots. This can be done in a few lines of `numpy` code.
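A minimal sketch of the pivot-counting approach in NumPy, assuming a fixed absolute tolerance `tol` to decide when an entry counts as zero (the function name `rank_gaussian_elimination` and the tolerance value are illustrative choices). In practice, `np.linalg.matrix_rank` and `torch.linalg.matrix_rank` use the SVD instead, which is numerically more robust.

```python
import numpy as np


def rank_gaussian_elimination(A, tol=1e-10):
    """Rank of A as the number of pivots found by Gaussian elimination (partial pivoting)."""
    M = np.array(A, dtype=float)  # float copy, so the input is not modified
    n_rows, n_cols = M.shape
    rank = 0  # pivots found so far = index of the next pivot row
    for col in range(n_cols):
        if rank == n_rows:
            break
        # Partial pivoting: choose the row with the largest entry in this column
        pivot = rank + int(np.argmax(np.abs(M[rank:, col])))
        if abs(M[pivot, col]) <= tol:
            continue  # no pivot in this column
        M[[rank, pivot]] = M[[pivot, rank]]  # swap the pivot row into place
        # Eliminate the entries below the pivot
        M[rank + 1:] -= np.outer(M[rank + 1:, col] / M[rank, col], M[rank])
        rank += 1
    return rank


rng = np.random.default_rng(0)
A = rng.standard_normal((6, 2)) @ rng.standard_normal((2, 5))  # rank 2 by construction
print(rank_gaussian_elimination(A), np.linalg.matrix_rank(A))  # 2 2
```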

**References**:
- https://arxiv.org/pdf/1804.08838
- https://arxiv.org/pdf/2209.13569
- https://arxiv.org/pdf/2012.13255

Note: The references above are not needed to complete this notebook, but reading them might give you additional insights.

## Important

1. For every training run, plot the loss and accuracy at each epoch.

    - You can use either TensorBoard or a static matplotlib plot.
    
2. Don't add biases to the layers in the network; they are not important for this notebook.
3. No need to use Dropout or BatchNorm on the network.
4. Remember to use a GPU for training.
5. Always test your hypotheses on both the training and test sets; the results can sometimes be surprising.

## Task 1: Downloading MNIST and Dataloaders

Download the MNIST dataset, split it into training and test sets, and create dataloaders.

Link: https://pytorch.org/vision/stable/generated/torchvision.datasets.MNIST.html

In [None]:
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [None]:
batch_size = 32

# Create data loaders (the training set is reshuffled every epoch).
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([32, 1, 28, 28])
Shape of y: torch.Size([32]) torch.int64


## Task 2: Train a neural network

Build a simple multilayer perceptron (MLP) with ReLU activations and train it on MNIST until it reaches 95% accuracy or higher.


In [None]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Train `model` for one epoch and return the average training loss."""
    model.train()
    total_loss = 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

In [None]:
def validate_on_test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss, correct

In [None]:
def validate_on_train(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    train_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            train_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_loss /= num_batches
    correct /= size
    print(f"Train Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {train_loss:>8f}")
    return train_loss, correct

In [None]:
# Define model
class NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512, bias=False),
            nn.ReLU(),
            nn.Linear(512, 512, bias=False),
            nn.ReLU(),
            nn.Linear(512, 10, bias=False)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NN().to(device)
print(model)

In [None]:
epochs = 3

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    validate_on_train(train_dataloader, model, loss_fn)
    validate_on_test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
Train Error: 
 Accuracy: 96.5%, Avg loss: 0.111752
Test Error: 
 Accuracy: 96.1%, Avg loss: 0.124072 

Epoch 2
-------------------------------
Train Error: 
 Accuracy: 96.4%, Avg loss: 0.115141
Test Error: 
 Accuracy: 95.7%, Avg loss: 0.144408 

Epoch 3
-------------------------------
Train Error: 
 Accuracy: 97.4%, Avg loss: 0.083799
Test Error: 
 Accuracy: 96.5%, Avg loss: 0.136577 

Done!


## Task 3: Analyze the rank of the matrices in this network

Perform experiments and answer the following questions:
- What's the average rank of the matrices on all layers?
- How does the rank increase as we go to deeper layers?
- Try the same MLP, but change the activation function to others ($\tanh, \sigma, \dots$). Do the answers change?

In [None]:
ranks = []

for i, parameter in enumerate(model.parameters()):
    print(f"Shape of {i + 1} layer: {parameter.shape}")
    rank = torch.linalg.matrix_rank(parameter)
    print(f"Rank of {i + 1} layer: {rank}")
    ranks.append(rank)

Shape of 1 layer: torch.Size([512, 784])
Rank of 1 layer: 512
Shape of 2 layer: torch.Size([512, 512])
Rank of 2 layer: 511
Shape of 3 layer: torch.Size([10, 512])
Rank of 3 layer: 10


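### Every weight matrix is (almost) full rank: its rank is close to *min(matrix dimensions)*

The ranks are 512/512, 511/512 and 10/10 for layers 1, 2 and 3, so the rank does not grow with depth: each one is simply bounded by the smaller dimension of its matrix.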
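The first question of this task asks for the average rank. A small sketch using the (shape, rank) pairs printed above; because the layers differ in size, it also normalizes each rank by its maximum possible value $\min(n, m)$. The variable names are illustrative.

```python
import numpy as np

# (shape, rank) pairs copied from the output above (ReLU MLP from Task 2)
reported = [((512, 784), 512), ((512, 512), 511), ((10, 512), 10)]

ranks_found = np.array([rank for _, rank in reported])
full_ranks = np.array([min(shape) for shape, _ in reported])  # maximum possible rank

print(f"average rank: {ranks_found.mean():.1f}")                             # 344.3
print(f"average rank / full rank: {(ranks_found / full_ranks).mean():.4f}")  # 0.9993
```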

### Change activation function to tanh

In [None]:
# Define model
class NN_tanh(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_tanh_stack = nn.Sequential(
            nn.Linear(28*28, 512, bias=False),
            nn.Tanh(),
            nn.Linear(512, 512, bias=False),
            nn.Tanh(),
            nn.Linear(512, 10, bias=False)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_tanh_stack(x)
        return logits

model_tanh = NN_tanh().to(device)
print(model_tanh)

NN_tanh(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_tanh_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=False)
    (1): Tanh()
    (2): Linear(in_features=512, out_features=512, bias=False)
    (3): Tanh()
    (4): Linear(in_features=512, out_features=10, bias=False)
  )
)


In [None]:
epochs = 3

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_tanh.parameters(), lr=1e-3)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model_tanh, loss_fn, optimizer)
    validate_on_train(train_dataloader, model_tanh, loss_fn)
    validate_on_test(test_dataloader, model_tanh, loss_fn)
print("Done!")

Epoch 1
-------------------------------
Train Error: 
 Accuracy: 94.2%, Avg loss: 0.186010
Test Error: 
 Accuracy: 94.1%, Avg loss: 0.191365 

Epoch 2
-------------------------------
Train Error: 
 Accuracy: 96.5%, Avg loss: 0.109385
Test Error: 
 Accuracy: 96.2%, Avg loss: 0.125641 

Epoch 3
-------------------------------
Train Error: 
 Accuracy: 97.5%, Avg loss: 0.078888
Test Error: 
 Accuracy: 96.9%, Avg loss: 0.105623 

Done!


In [None]:
ranks = []

for i, parameter in enumerate(model_tanh.parameters()):
    print(f"Shape of {i + 1} layer: {parameter.shape}")
    rank = torch.linalg.matrix_rank(parameter)
    print(f"Rank of {i + 1} layer: {rank}")
    ranks.append(rank)

Shape of 1 layer: torch.Size([512, 784])
Rank of 1 layer: 512
Shape of 2 layer: torch.Size([512, 512])
Rank of 2 layer: 512
Shape of 3 layer: torch.Size([10, 512])
Rank of 3 layer: 10


### With tanh, the ranks do not change: every weight matrix is still full rank (rank = *min(matrix dimensions)*)

## Task 4: Overfit by scaling the MLP

1. Create a bigger network and train it on MNIST until it overfits.
2. Now check the rank of the matrices in the network, and answer the same questions.

In [None]:
# Define model
class NN_scaled(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 1024, bias=False),
            nn.ReLU(),
            nn.Linear(1024, 1024, bias=False),
            nn.ReLU(),
            nn.Linear(1024, 512, bias=False),
            nn.ReLU(),
            nn.Linear(512, 256, bias=False),
            nn.ReLU(),
            nn.Linear(256, 10, bias=False)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model_scaled = NN_scaled().to(device)
print(model_scaled)

NN_scaled(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=False)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=1024, bias=False)
    (3): ReLU()
    (4): Linear(in_features=1024, out_features=512, bias=False)
    (5): ReLU()
    (6): Linear(in_features=512, out_features=256, bias=False)
    (7): ReLU()
    (8): Linear(in_features=256, out_features=10, bias=False)
  )
)


In [None]:
epochs = 50

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_scaled.parameters(), lr=1e-4)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model_scaled, loss_fn, optimizer)
    validate_on_train(train_dataloader, model_scaled, loss_fn)
    validate_on_test(test_dataloader, model_scaled, loss_fn)
print("Done!")

Epoch 1
-------------------------------
Train Error: 
 Accuracy: 94.5%, Avg loss: 0.178778
Test Error: 
 Accuracy: 94.8%, Avg loss: 0.176825 

Epoch 2
-------------------------------
Train Error: 
 Accuracy: 96.7%, Avg loss: 0.109358
Test Error: 
 Accuracy: 96.3%, Avg loss: 0.118930 

Epoch 3
-------------------------------
Train Error: 
 Accuracy: 98.0%, Avg loss: 0.066444
Test Error: 
 Accuracy: 97.3%, Avg loss: 0.091504 

Epoch 4
-------------------------------
Train Error: 
 Accuracy: 98.6%, Avg loss: 0.044982
Test Error: 
 Accuracy: 97.5%, Avg loss: 0.085121 

Epoch 5
-------------------------------
Train Error: 
 Accuracy: 99.0%, Avg loss: 0.032451
Test Error: 
 Accuracy: 97.7%, Avg loss: 0.086183 

Epoch 6
-------------------------------
Train Error: 
 Accuracy: 98.8%, Avg loss: 0.038224
Test Error: 
 Accuracy: 97.5%, Avg loss: 0.102648 

Epoch 7
-------------------------------
Train Error: 
 Accuracy: 99.4%, Avg loss: 0.020740
Test Error: 
 Accuracy: 97.8%, Avg loss: 0.088526 


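**Train Error:**
Accuracy: 100%, Avg loss: 0.000372

**Test Error:**
Accuracy: 98.4%, Avg loss: 0.104507

The model has memorized the whole training set: the training loss is almost zero, while the test loss (0.1045) is higher than at epoch 4 (0.0851), which is a sign of overfitting.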

In [None]:
ranks = []

for i, parameter in enumerate(model_scaled.parameters()):
    print(f"Shape of {i + 1} layer: {parameter.shape}")
    rank = torch.linalg.matrix_rank(parameter)
    print(f"Rank of {i + 1} layer: {rank}")
    ranks.append(rank)

Shape of 1 layer: torch.Size([1024, 784])
Rank of 1 layer: 784
Shape of 2 layer: torch.Size([1024, 1024])
Rank of 2 layer: 1022
Shape of 3 layer: torch.Size([512, 1024])
Rank of 3 layer: 512
Shape of 4 layer: torch.Size([256, 512])
Rank of 4 layer: 256
Shape of 5 layer: torch.Size([10, 256])
Rank of 5 layer: 10


Despite reaching 100% accuracy on the training set, the weight matrices did not become lower-rank.
This could be because the network still lacks the parameters needed for obvious overfitting. Furthermore, MNIST images are not well suited to fully connected networks (which ignore the spatial structure of the images), so I suppose this architecture does not extract all the relevant information from the data, which makes it hard to overfit properly.

Also, the elements of these weight matrices are real-valued (floating-point) numbers, so exact linear dependence between rows is very unlikely. `torch.linalg.matrix_rank` counts every singular value above a very small tolerance, so a matrix can be reported as full rank even when many of its singular values are tiny.
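To see why the exact rank is a weak signal, here is a sketch on a hypothetical matrix (not taken from the trained model) built from random orthogonal factors and an exponentially decaying spectrum $\sigma_i = e^{-i/20}$. `np.linalg.matrix_rank` reports full rank, while an energy-based effective rank, i.e. the smallest $r$ whose top singular values hold 99% of $\lVert A\rVert_F^2$ (a common heuristic; the 99% threshold and the name `energy_rank` are arbitrary choices), is much smaller.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 256

# Hypothetical matrix A = U diag(sigma) V^T with random orthogonal U, V
# and singular values sigma_i = exp(-i / 20), i = 0..n-1
U, _ = np.linalg.qr(rng.standard_normal((n, n)))
V, _ = np.linalg.qr(rng.standard_normal((n, n)))
sigma = np.exp(-np.arange(n) / 20.0)
A = U @ np.diag(sigma) @ V.T


def energy_rank(A, energy=0.99):
    """Smallest r whose top-r singular values hold `energy` of ||A||_F^2."""
    s = np.linalg.svd(A, compute_uv=False)  # sorted in descending order
    cumulative = np.cumsum(s**2) / np.sum(s**2)
    return int(np.searchsorted(cumulative, energy) + 1)


print("numerical rank:", np.linalg.matrix_rank(A))  # 256: every sigma_i is above the tolerance
print("99% energy rank:", energy_rank(A))            # 47
```

The same `energy_rank` function could be applied to the trained weights, e.g. `energy_rank(W.detach().cpu().numpy())` for each `W` in `model_scaled.parameters()`; no results for that are reported here.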

## Task 5: Approximate low-rank

Some of the references given at the beginning suggest that trained neural networks have a low intrinsic dimensionality, which motivates approximating their weight matrices with low-rank ones.

In this task, take the overparametrized network trained in Task 4 and try to approximate each layer's matrix with a product of two low-rank matrices.

This means that, if a layer has a matrix $A \in\mathbb{R}^{n\times m}$, you should find two matrices $B \in \mathbb{R}^{n\times r}$ and $C \in \mathbb{R}^{r\times m}$ such that $\lVert A - B\cdot C \rVert_F$ is minimized, where $\lVert \cdot \rVert_F$ denotes the Frobenius norm. You can use a different norm if you think it makes sense. To learn $B$ and $C$, you can use gradient-descent-like algorithms that alternate between updating $B$ and $C$ at each optimization step.
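The task suggests alternating updates of $B$ and $C$. As an alternative to plain gradient steps, here is a sketch of alternating least squares (ALS) in NumPy: with $C$ fixed, the optimal $B$ is a linear least-squares solution, and vice versa, so each half-step solves its subproblem exactly and the Frobenius error never increases. The function name, the random initialization and the iteration count are illustrative assumptions.

```python
import numpy as np


def als_low_rank(A, r, n_iters=50, seed=0):
    """Alternating least squares for min ||A - B @ C||_F, B: n x r, C: r x m."""
    rng = np.random.default_rng(seed)
    m = A.shape[1]
    C = rng.standard_normal((r, m))  # random start for C
    for _ in range(n_iters):
        # C fixed: B = argmin ||B @ C - A||_F, solved as C.T @ B.T ~ A.T
        B = np.linalg.lstsq(C.T, A.T, rcond=None)[0].T
        # B fixed: C = argmin ||B @ C - A||_F
        C = np.linalg.lstsq(B, A, rcond=None)[0]
    return B, C


rng = np.random.default_rng(42)
A = rng.standard_normal((60, 5)) @ rng.standard_normal((5, 40))  # exactly rank 5
A += 0.01 * rng.standard_normal(A.shape)                          # plus small noise
B, C = als_low_rank(A, r=5)

s = np.linalg.svd(A, compute_uv=False)
print("ALS error:    ", np.linalg.norm(A - B @ C))
print("optimal error:", np.sqrt(np.sum(s[5:] ** 2)))  # truncated-SVD lower bound
```

To apply it to a layer, one could pass `W.cpu().numpy()` for a weight `W` from `layers`.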

**Ablate**:
Try different values of $r$ and analyze how good your approximation is (e.g., by taking the average Frobenius norm across all layers) as $r$ increases. Make a plot of the result.

Determine the effective rank $r$: the smallest rank whose approximation is good enough (meaning that the Frobenius norm of the error is below some threshold of your choice).
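For the ablation, the Eckart-Young theorem gives the best achievable error for every $r$ at once: $\lVert A - A_r\rVert_F = \sqrt{\sum_{i > r} \sigma_i^2}$, where $A_r$ is the truncated SVD of $A$. The sketch below computes this curve and picks the effective rank as the smallest $r$ whose relative error falls below a threshold; the function names and the 10% default threshold are hypothetical choices, not values from this notebook. Because SVD is optimal, these errors are a lower bound for any factorization of the same rank, including the ones learned by gradient descent in this task.

```python
import numpy as np


def svd_error_curve(A):
    """Optimal error ||A - A_r||_F for r = 0, 1, ..., min(n, m) (Eckart-Young)."""
    s = np.linalg.svd(A, compute_uv=False)
    tail = np.cumsum((s**2)[::-1])[::-1]  # tail[r] = sum of s_i^2 for i >= r (0-indexed)
    return np.sqrt(np.append(tail, 0.0))  # last entry (r = min(n, m)) is exactly 0


def effective_rank(A, rel_threshold=0.1):
    """Smallest r with ||A - A_r||_F / ||A||_F <= rel_threshold."""
    rel_err = svd_error_curve(A) / np.linalg.norm(A)
    return int(np.argmax(rel_err <= rel_threshold))  # index of the first True


A = np.diag([3.0, 2.0, 1.0, 0.5])
print(svd_error_curve(A))      # equals sqrt([14.25, 5.25, 1.25, 0.25, 0.0])
print(effective_rank(A, 0.2))  # 3
```

To produce the requested plot, one could call `svd_error_curve(W.cpu().numpy())` for each matrix in `layers` and plot the curves with matplotlib.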

In [None]:
layers = []

for i, parameter in enumerate(model_scaled.parameters()):
    print(f"Shape of {i + 1} layer: {parameter.shape}")
    layers.append(parameter.data)

layers = layers[:-1]  # skip the last layer: its rank (10) is already low

Shape of 1 layer: torch.Size([1024, 784])
Shape of 2 layer: torch.Size([1024, 1024])
Shape of 3 layer: torch.Size([512, 1024])
Shape of 4 layer: torch.Size([256, 512])
Shape of 5 layer: torch.Size([10, 256])


In [None]:
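class LowRankApprox(nn.Module):
    """Factorized map x -> x @ B.weight.T @ C.weight.T.

    With the n x n identity as input, the output is the n x m product
    B.weight.T @ C.weight.T, whose rank is at most `rank`.
    """
    def __init__(self, n, m, rank):
        super().__init__()
        self.B = nn.Linear(n, rank, bias=False)  # weight: rank x n
        self.C = nn.Linear(rank, m, bias=False)  # weight: m x rank

    def forward(self, x):
        return self.C(self.B(x))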

In [None]:
def get_approximated_layers(layers, ranks, lr, epochs):
    """For each weight matrix A (n x m), learn B (n x r) and C (r x m) with A ~ B @ C.

    B and C are trained jointly with Adam on the MSE between B @ C and A
    (the squared Frobenius norm divided by n * m).
    """
    approximated_layers = []

    for layer, rank in zip(layers, ranks):
        n, m = layer.shape
        # Feeding the identity makes the model output the product B @ C itself
        identity = torch.eye(n, device=device)
        target = layer.to(device)

        approx = LowRankApprox(n, m, rank).to(device)
        optimizer = torch.optim.Adam(approx.parameters(), lr=lr)
        loss_fn = nn.MSELoss()

        for epoch in range(epochs):
            loss = loss_fn(approx(identity), target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Store B (n x r) and C (r x m) as plain tensors
        approximated_layers.append((approx.B.weight.detach().T, approx.C.weight.detach().T))

    return approximated_layers

In [None]:
def evaluate_approximated_layers(approximated_layers, layers):
    """Print the compression ratio and Frobenius-norm error of each factorized layer."""
    errors = []
    for i, (B, C) in enumerate(approximated_layers):
        original_matrix = layers[i].cpu().numpy()
        approximation = B.cpu().numpy() @ C.cpu().numpy()

        n, rank = B.shape
        m = C.shape[1]

        print(f"Layer {i+1}:")

        # Parameters of A (n*m) vs. parameters of B and C (n*r + r*m)
        compression_ratio = (n * m) / (n * rank + rank * m)
        print(f'Compression Ratio: {compression_ratio:.3f}')

        # np.linalg.norm of a 2-D array defaults to the Frobenius norm
        error = np.linalg.norm(original_matrix - approximation)
        print(f'Approximation Error (Frobenius Norm): {error:.5f}\n')
        errors.append(error)
    return errors

### Try some combinations of ranks

In [None]:
ranks = [[40, 40, 40, 20],
         [100, 100, 100, 50],
         [150, 150, 150, 80],
         [200, 200, 200, 100]]
lr = 1e-4
epochs = 10000

In [None]:
appr_layers = []

for i in range(4):
    print(f"Training for ranks = {ranks[i]}")
    approximated_layers = get_approximated_layers(layers, ranks[i], lr, epochs)
    appr_layers.append(approximated_layers)
    evaluate_approximated_layers(approximated_layers, layers)

Training for ranks = [40, 40, 40, 20]



Layer 1:
Compression Ratio: 11.101
Approximation Error (Frobenius Norm): 22.99904

Layer 2:
Compression Ratio: 12.800
Approximation Error (Frobenius Norm): 22.48495

Layer 3:
Compression Ratio: 8.533
Approximation Error (Frobenius Norm): 13.42279

Layer 4:
Compression Ratio: 8.533
Approximation Error (Frobenius Norm): 8.92651

Training for ranks = [100, 100, 100, 50]



Layer 1:
Compression Ratio: 4.440
Approximation Error (Frobenius Norm): 18.84841

Layer 2:
Compression Ratio: 5.120
Approximation Error (Frobenius Norm): 19.29996

Layer 3:
Compression Ratio: 3.413
Approximation Error (Frobenius Norm): 11.17345

Layer 4:
Compression Ratio: 3.413
Approximation Error (Frobenius Norm): 7.50189

Training for ranks = [150, 150, 150, 80]



Layer 1:
Compression Ratio: 2.960
Approximation Error (Frobenius Norm): 16.50859

Layer 2:
Compression Ratio: 3.413
Approximation Error (Frobenius Norm): 17.35723

Layer 3:
Compression Ratio: 2.276
Approximation Err

### As expected, larger ranks give lower Frobenius-norm errors, but also lower compression ratios

We choose [150, 150, 150, 80] as the effective ranks, as a compromise between approximation error and compression ratio.

## Task 6: Learning with low-rank factorization

Once you have found the effective rank $r$, take the same architecture as in the previous task, and replace each layer $A \in \mathbb{R}^{n\times m}$ with a layer that applies $B\cdot C$, where $B\in \mathbb{R}^{n\times r}$ and $C \in \mathbb{R}^{r\times m}$.

**Question**: How much memory do you save? (you can just count the number of parameters of the original network and compare to that of the new network).
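To answer the memory question, note that a layer $A \in \mathbb{R}^{n\times m}$ stores $nm$ parameters, while its factors $B$ and $C$ store $r(n+m)$, so factorization only saves memory when $r < nm/(n+m)$ (e.g. $r < 512$ for a $1024\times1024$ layer). A sketch of this count for `NN_scaled` and the ranks chosen in Task 5, keeping the last layer dense as in `NN_approximated`; the helper names are illustrative.

```python
def dense_params(n_in, n_out):
    return n_in * n_out  # full matrix A


def low_rank_params(n_in, n_out, r):
    return r * (n_in + n_out)  # B (n_out x r) plus C (r x n_in)


# (in, out) sizes of the NN_scaled layers and the ranks chosen in Task 5;
# None means the layer stays dense (the last layer, as in NN_approximated)
sizes = [(784, 1024), (1024, 1024), (1024, 512), (512, 256), (256, 10)]
chosen_ranks = [150, 150, 150, 80, None]

dense_total = sum(dense_params(i, o) for i, o in sizes)
low_rank_total = sum(
    dense_params(i, o) if r is None else low_rank_params(i, o, r)
    for (i, o), r in zip(sizes, chosen_ranks)
)
print(dense_total, low_rank_total)                   # 2509312 872800
print(f"{dense_total / low_rank_total:.2f}x fewer")  # 2.88x fewer
print(f"float32: {4 * dense_total / 1e6:.1f} MB vs {4 * low_rank_total / 1e6:.1f} MB")  # 10.0 MB vs 3.5 MB
```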

Initialize these values with standard initialization, and train this network.

**Question**: How does the learning change? Does it converge faster or slower? What about accuracy on both training and testing sets?

**Question**: Now try doing inference, how much improvement do you see?
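For the inference question, a factorized layer replaces one product with $nm$ multiply-adds per input by two products with $r(n+m)$ in total. A CPU sketch in NumPy with hypothetical random factors, sized like the second layer of `NN_approximated` ($1024\times1024$, $r = 150$), checks that both forms give the same output and compares their timings. Measured speedups depend on hardware, BLAS library and batch size, so the time ratio may be well below the ideal multiply-add ratio; on a GPU, `torch.cuda.synchronize()` must be called before reading a timer because CUDA kernels run asynchronously.

```python
import time

import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, r, batch = 1024, 1024, 150, 256

# Hypothetical factors B (n_out x r), C (r x n_in) and the equivalent dense weight W = B @ C
B = rng.standard_normal((n_out, r))
C = rng.standard_normal((r, n_in))
W = B @ C
x = rng.standard_normal((batch, n_in))


def dense():
    return x @ W.T  # batch * n_in * n_out multiply-adds


def factored():
    return (x @ C.T) @ B.T  # batch * r * (n_in + n_out) multiply-adds


def best_time(fn, repeats=20):
    """Best wall-clock time of fn() over several runs, in seconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


print("outputs match:", np.allclose(dense(), factored()))
print(f"multiply-add ratio: {n_in * n_out / (r * (n_in + n_out)):.2f}")  # 3.41
print(f"measured time ratio: {best_time(dense) / best_time(factored):.2f}")
```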

In [None]:
for layer in appr_layers[2]:
    print(layer[0].shape, layer[1].shape)

torch.Size([1024, 150]) torch.Size([150, 784])
torch.Size([1024, 150]) torch.Size([150, 1024])
torch.Size([512, 150]) torch.Size([150, 1024])
torch.Size([256, 80]) torch.Size([80, 512])


In [None]:
# Define model
class NN_approximated(nn.Module):
    def __init__(self):
        super().__init__()
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.B1 = nn.Linear(150, 1024, bias=False)
        self.C1 = nn.Linear(784, 150, bias=False)

        self.B2 = nn.Linear(150, 1024, bias=False)
        self.C2 = nn.Linear(1024, 150, bias=False)

        self.B3 = nn.Linear(150, 512, bias=False)
        self.C3 = nn.Linear(1024, 150, bias=False)

        self.B4 = nn.Linear(80, 256, bias=False)
        self.C4 = nn.Linear(512, 80, bias=False)

        self.A5 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.flatten(x)
        a1 = self.activation(self.B1(self.C1(x)))

        a2 = self.activation(self.B2(self.C2(a1)))
        a3 = self.activation(self.B3(self.C3(a2)))
        a4 = self.activation(self.B4(self.C4(a3)))

        logits = self.A5(a4)
        return logits

model_approximated = NN_approximated().to(device)
print(model_approximated)

NN_approximated(
  (activation): ReLU()
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (B1): Linear(in_features=150, out_features=1024, bias=False)
  (C1): Linear(in_features=784, out_features=150, bias=False)
  (B2): Linear(in_features=150, out_features=1024, bias=False)
  (C2): Linear(in_features=1024, out_features=150, bias=False)
  (B3): Linear(in_features=150, out_features=512, bias=False)
  (C3): Linear(in_features=1024, out_features=150, bias=False)
  (B4): Linear(in_features=80, out_features=256, bias=False)
  (C4): Linear(in_features=512, out_features=80, bias=False)
  (A5): Linear(in_features=256, out_features=10, bias=False)
)


### Compare low-rank and basic models' performances

In [None]:
# low-rank

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_approximated.parameters(), lr=0.5e-3)

epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model_approximated, loss_fn, optimizer)
    validate_on_train(train_dataloader, model_approximated, loss_fn)
    validate_on_test(test_dataloader, model_approximated, loss_fn)
print("Done!")

In [None]:
# basic

model_scaled = NN_scaled().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_scaled.parameters(), lr=0.5e-3)

epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model_scaled, loss_fn, optimizer)
    validate_on_train(train_dataloader, model_scaled, loss_fn)
    validate_on_test(test_dataloader, model_scaled, loss_fn)
print("Done!")

### With *epochs = 15, lr = 0.5e-3*, we get the following results

Model with low-rank approximation *(ranks = [150, 150, 150, 80])*:
- Time for training: 4m 0s
- Train loss: 0.02393
- Train accuracy: 99.3%
- Test loss: 0.133316
- Test accuracy: 97.8%

Basic model:
- Time for training: 5m 17s
- Train loss: 0.007569
- Train accuracy: 99.8%
- Test loss: 0.094549
- Test accuracy: 98.3%

### Compare the size of two models (total number of parameters)

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The low-rank model has {count_parameters(model_approximated):,} trainable parameters')
print(f'The basic model has {count_parameters(model_scaled):,} trainable parameters')

The low-rank model has 872,800 trainable parameters
The basic model has 2,509,312 trainable parameters


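### The low-rank model has about 2.9 times fewer parameters than the basic one

872,800 vs 2,509,312 parameters, i.e. about 65% less memory for the weights (roughly 3.5 MB vs 10.0 MB in float32).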

## Task 7: Final conclusions

Based on all the previous experiments, report your conclusions and try to give an explanation to the behaviours you observed.

Can you think of other ways of using the low-rank factorizations? What about SVD? Provide an explanation.

### Some conclusions about low-rank factorizations:

- Overfitting is expected to lower the (effective) rank of the weight matrices. Overfitted models tend to fit more complex dependencies than the ones that actually exist between the training samples and the targets, because of their large number of parameters and complex structure; in this case, low-rank weight matrices point to an excessive number of parameters. Note, however, that in Task 4 the exact numerical rank did not drop even at 100% training accuracy.
- Low-rank factorization sped up training in our experiment (4m 0s vs 5m 17s for 15 epochs) while keeping the quality close to that of the full model (97.8% vs 98.3% test accuracy). It should also speed up inference, since a factorized layer needs $r(n+m)$ instead of $nm$ multiply-adds per input. The choice of the hidden rank $r$ controls the trade-off between quality and cost.
- The main application of low-rank factorization is reducing the number of parameters in a neural network and, hence, its computational cost (both memory consumption and processing time).
- Truncated SVD gives the optimal rank-$r$ approximation of a matrix in the Frobenius norm (Eckart-Young theorem), with error $\lVert A - A_r \rVert_F = \sqrt{\sum_{i>r} \sigma_i^2}$. Its factors could therefore replace the gradient-based fitting of Task 5, or be used to initialize the factorized layers of Task 6 instead of the standard initialization.
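A sketch of the SVD-based factorization in NumPy: keeping the top $r$ singular triplets gives factors $B = U_r \Sigma_r$ and $C = V_r^\top$ with the shapes used in Task 5. Splitting $\Sigma_r$ between the two factors (e.g. $\Sigma_r^{1/2}$ on each side) is an equally valid choice. The function name `svd_factorize` and the example sizes are illustrative; using these factors to initialize the Task 6 layers is a suggestion, not something tested in this notebook.

```python
import numpy as np


def svd_factorize(A, r):
    """Optimal rank-r factors A ~ B @ C (Eckart-Young): B = U_r * S_r, C = V_r^T."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    B = U[:, :r] * s[:r]  # n x r: first r left singular vectors scaled by singular values
    C = Vt[:r]            # r x m: first r right singular vectors
    return B, C


rng = np.random.default_rng(0)
A = rng.standard_normal((50, 30))
B, C = svd_factorize(A, 10)

s = np.linalg.svd(A, compute_uv=False)
print(np.linalg.norm(A - B @ C), np.sqrt(np.sum(s[10:] ** 2)))  # equal up to rounding
```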

## BONUS Task: LoRA

Propose ideas for how low-rank factorization could improve fine-tuning and training. What disadvantages does it have?

Read about LoRA (given in one of the references at the beginning of the notebook).

Now take MNIST and remove one digit from the dataset (keep the same labels; just remove the datapoints of that label).

Train a simple MLP on this modified dataset.
Fine-tune it on the datapoints of the chosen digit using LoRA.

Report the memory and time overheads.
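A minimal NumPy sketch of the forward pass of a LoRA-adapted linear layer, following the standard LoRA formulation: the pretrained weight $W$ is frozen and a trainable low-rank update $\frac{\alpha}{r}BA$ is added, with $B$ initialized to zero so that the adapted layer initially matches the original one. The class name, the 0.01 scale of the random init of $A$ and the example sizes are assumptions; training $A$ and $B$ (e.g. as `nn.Parameter`s in PyTorch while `W` has `requires_grad=False`) is not shown.

```python
import numpy as np


class LoRALinear:
    """Forward pass of y = x W^T + (alpha / r) x A^T B^T with W frozen."""

    def __init__(self, W, r, alpha=1.0, seed=0):
        rng = np.random.default_rng(seed)
        out_features, in_features = W.shape
        self.W = W                                             # frozen pretrained weight
        self.A = 0.01 * rng.standard_normal((r, in_features))  # small random init (assumed scale)
        self.B = np.zeros((out_features, r))                   # zero init: no change at start
        self.scale = alpha / r

    def __call__(self, x):
        return x @ self.W.T + self.scale * (x @ self.A.T) @ self.B.T

    def trainable_params(self):
        return self.A.size + self.B.size

    def merged_weight(self):
        # After fine-tuning, the update can be folded into W: no extra inference cost
        return self.W + self.scale * self.B @ self.A


rng = np.random.default_rng(1)
W = rng.standard_normal((512, 784))  # stands in for a frozen pretrained layer
layer = LoRALinear(W, r=8, alpha=16)
x = rng.standard_normal((4, 784))
print(np.allclose(layer(x), x @ W.T))    # True: B = 0 at initialization
print(layer.trainable_params(), W.size)  # 10368 401408
```

For this $512\times784$ layer with $r = 8$, only $8\cdot(784+512) = 10{,}368$ of the $401{,}408$ weights are trainable, and `merged_weight()` folds the update back into $W$ once fine-tuning is done.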