In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
import random

# Basic set-up

In [4]:
# Fetch the Iris data: a feature matrix and the integer class label per row.
iris = load_iris()
X, y = iris.data, iris.target

# Hold out 20% of the rows for testing; random_state pins the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Standardize features. The scaler statistics are estimated on the training
# rows only and then applied unchanged to both splits.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# Convert to PyTorch tensors: float32 features, int64 class labels.
X_train_tensor, X_test_tensor = (
    torch.tensor(features, dtype=torch.float32) for features in (X_train, X_test)
)
y_train_tensor, y_test_tensor = (
    torch.tensor(labels, dtype=torch.int64) for labels in (y_train, y_test)
)


# Define a basic neural network

In [37]:
# # regular model 
def set_seed(seed_value):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator, PyTorch's
    default generator and all CUDA generators, and configures cuDNN for
    deterministic behaviour.

    Args:
        seed_value: Integer seed handed to each generator.

    Returns:
        None.
    """
    # Same seed for every generator.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,  # only relevant when CUDA is used
    )
    for seeder in seeders:
        seeder(seed_value)

    # Disable cuDNN autotuning and request deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Fix the seed before the network below is constructed, so that the random
# weight initialisation is repeatable from run to run.
seed = 42
set_seed(seed)

# # # Define the neural network architecture
class Net(nn.Module):
    """Baseline classifier: one hidden ReLU layer between two linear maps.

    Layout: 4 input features -> 10 hidden units -> 3 raw class scores.
    Indexing the model (``net[0]``, ``net[1]``) returns its two linear layers.
    """

    def __init__(self):
        super().__init__()
        # Input -> hidden: 4 features in, 10 hidden units out.
        self.fc1 = nn.Linear(4, 10)
        # Hidden -> output: 10 hidden units in, 3 class scores out.
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        """Return the raw class scores for the batch ``x``."""
        hidden = self.fc1(x).relu()
        return self.fc2(hidden)

    def __getitem__(self, idx):
        """Return the layer at position ``idx`` (0 -> fc1, 1 -> fc2).

        Raises:
            IndexError: for any other index.
        """
        if idx == 0:
            return self.fc1
        if idx == 1:
            return self.fc2
        raise IndexError("Index out of range")

    def get_first_layer(self):
        """Return the first linear layer (``fc1``)."""
        return self.fc1

# Instantiate the baseline network (4 -> 10 -> 3). It is built after
# set_seed(seed) above, so its initial weights are repeatable.
net = Net()

# Define a generic training routine

In [7]:
def train_nn(net, num_epochs=1000, lr=0.1, log_every=0):
    """Train ``net`` with full-batch SGD, then print its test-set accuracy.

    The data is read from the notebook-level tensors ``X_train_tensor``,
    ``y_train_tensor``, ``X_test_tensor`` and ``y_test_tensor``. ``net`` is
    updated in place.

    Args:
        net: Module mapping a batch of inputs to one raw score per class.
        num_epochs: Number of full-batch gradient steps (default 1000).
        lr: SGD learning rate (default 0.1).
        log_every: If positive, print the training loss every ``log_every``
            epochs. The default 0 keeps training silent.

    Returns:
        None. The accuracy is printed, not returned.
    """
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr)

    # Train the network: every epoch is one step on the whole training set.
    for epoch in range(1, num_epochs + 1):
        optimizer.zero_grad()  # zero the gradients
        outputs = net(X_train_tensor)  # forward pass
        loss = criterion(outputs, y_train_tensor)  # compute the loss
        loss.backward()  # backward pass
        optimizer.step()  # update weights
        if log_every and epoch % log_every == 0:
            print(f'Epoch [{epoch}/{num_epochs}], Loss: {loss.item():.4f}')

    # Test the network; no gradients are needed for evaluation.
    with torch.no_grad():
        outputs = net(X_test_tensor)
        predicted = torch.argmax(outputs, 1)
        accuracy = (predicted == y_test_tensor).sum().item() / y_test_tensor.size(0)
        print(f'Accuracy on the test set: {100 * accuracy:.2f}%')
        

In [41]:
# For benchmarking purposes, let's look at the accuracy of the original
# (uncompressed) network first.
train_nn(net)

Accuracy on the test set: 100.00%


# Low Rank Approximation of the weight matrix directly

Extract the first-layer weights and compute a low-rank approximation of them, $$W^{(1)} \approx U_r\Sigma_r V_r^T.$$ Then check that the approximation looks similar to the original weights.

In [42]:
# Snapshot the trained first-layer weights. detach().clone() yields an
# independent copy, so the "original" printed below cannot be changed by any
# later update of the network.
weights_fc1 = net.fc1.weight.detach().clone()

# Reduced SVD: weights_fc1 = U @ diag(S) @ V.T
# torch.linalg.svd replaces the deprecated torch.svd and returns V transposed.
U, S, Vh = torch.linalg.svd(weights_fc1, full_matrices=False)
V = Vh.t()

# Choose the desired rank for the low-rank approximation
k = 2  # keep the k largest singular values and their singular vectors

# Form low-rank approximation
U_k = U[:, :k]
S_k = torch.diag(S[:k])
V_k = V[:, :k]
low_rank_approximation = torch.mm(U_k, torch.mm(S_k, V_k.t()))

# Replace weights of the first layer with the low-rank approximation.
# The layer receives its own copy: nn.Parameter shares storage with the tensor
# it wraps, so without clone() the in-place optimizer updates during training
# would also overwrite low_rank_approximation.
net.fc1.weight = nn.Parameter(low_rank_approximation.clone())

print("Original weights of the first layer:")
print(weights_fc1)
print("\nLow-rank approximation of the first layer:")
print(low_rank_approximation)


Original weights of the first layer:
tensor([[-4.1648e-02, -4.6252e-01,  1.7128e+00,  1.5548e+00],
        [-4.0338e-02, -1.7054e-01,  7.0191e-01,  6.2048e-01],
        [ 1.1626e-01, -6.6152e-01,  9.7242e-01,  1.0223e+00],
        [ 1.6039e-01,  2.3432e-01, -1.0579e+00, -8.7162e-01],
        [ 2.7400e-01,  2.0377e-01, -1.2229e+00, -9.7948e-01],
        [-7.3428e-02,  4.9181e-02, -1.7385e-01, -1.7795e-01],
        [-2.8788e-01,  6.2951e-01, -1.1089e+00, -1.1594e+00],
        [ 1.2932e-01, -1.9027e-01,  1.5801e-01,  1.9960e-01],
        [-2.2499e-02, -1.1339e-01,  4.5061e-01,  4.0080e-01],
        [ 2.7131e-04, -7.5118e-04,  1.5493e-03,  1.6636e-03]])

Low-rank approximation of the first layer:
tensor([[-8.5978e-02, -5.1320e-01,  1.7142e+00,  1.5338e+00],
        [-5.5634e-02, -1.8698e-01,  7.0093e-01,  6.1522e-01],
        [ 2.0243e-01, -5.7831e-01,  9.9132e-01,  1.0339e+00],
        [ 1.4468e-01,  2.1212e-01, -1.0514e+00, -8.8720e-01],
        [ 2.4014e-01,  1.6710e-01, -1.2247e+00, -9

# Find the Optimal Rank to Choose

As discussed, different choices of $r$ lead to different levels of approximation. We look for the smallest $r$ that satisfies the constraint below, where $\phi$ is any accuracy measure (in particular the AUC):
$$\min_{r \in \left[1,\, \frac{m \times n}{m+n}\right]} r \quad \text{s.t.} \quad \phi(y, \hat{y}) - \phi(y, \hat{y}') < \delta
$$


In [43]:
from OptimizeRank import OptimizeRank
import tensorly

In [44]:
# OptimizeRank comes from the local OptimizeRank module; its internals are not
# shown in this notebook. It is constructed from the training inputs, the
# network and the training labels, and optimize_rank_binary_search() returns
# three values, unpacked below.
# NOTE(review): presumably optimal_rank is the smallest rank that keeps the
# accuracy drop within tolerance — confirm against the OptimizeRank source.
opt_rank = OptimizeRank(X_train_tensor,net,y_train_tensor)
optimal_rank, accuracy_diff, space_saved  = opt_rank.optimize_rank_binary_search()

Max rank: 2


  self.tensor = torch.tensor(tensor).float()


In [45]:
# Load the low-rank weights into the first layer and fine-tune.
# The layer gets its own copy: nn.Parameter shares storage with the tensor it
# wraps, so without clone() the in-place SGD updates inside train_nn would
# also overwrite low_rank_approximation.
net.fc1.weight = nn.Parameter(low_rank_approximation.clone())
train_nn(net) # the accuracy should be the same or close to the benchmark

Accuracy on the test set: 100.00%


# Allow the neural network to "learn" a lower rank weight structure

In [46]:
class NetLite(nn.Module):
    """Variant of the baseline network with a factorized first layer.

    The 4 -> 10 first layer is replaced by a chain of three linear layers,
    4 -> rank -> rank -> 10, with no activation between them. The chain is
    followed by ReLU and the usual 10 -> 3 output layer.

    Args:
        rank: Width of the bottleneck inside the factorized first layer.
    """

    def __init__(self, rank):
        super(NetLite, self).__init__()
        # For the matrix factorization approach we were doing an SVD of W = USV'
        # instead of that, we can pass in low rank matrices, u, s, v that mimics that low rank property
        # but instead of calculating the values we let back propagation "learn" it for us
        # Note: each nn.Linear keeps its default bias term, so the chain is an
        # affine map whose weight matrix has rank at most `rank`.
        self.u1 = nn.Linear(4, rank)
        self.core1 = nn.Linear(rank, rank)
        self.v1 = nn.Linear(rank, 10)
        self.fc2 = nn.Linear(10, 3)  # 10 neurons in the hidden layer, 3 output classes

    def forward(self, x):
        """Return the raw class scores for the batch ``x``."""
        # the next 3 are just linear layers without any activation in between
        # we want this to mimic a matrix multiplication as closely as possible
        x = self.u1(x)
        x = self.core1(x)
        x = self.v1(x)

        # now it is business as usual from here on
        x = torch.relu(x)
        x = self.fc2(x)
        return x

    def __getitem__(self, idx):
        """Return the layer at position ``idx``.

        Index 0 is the factorized first layer, returned as an ``nn.Sequential``
        over ``u1``, ``core1`` and ``v1`` (the same module objects, not copies).
        Index 1 is ``fc2``.

        Raises:
            IndexError: for any other index.
        """
        if idx == 0:
            # This class has no `fc1` attribute; the first layer is the
            # u1 -> core1 -> v1 chain.
            return nn.Sequential(self.u1, self.core1, self.v1)
        elif idx == 1:
            return self.fc2
        else:
            raise IndexError("Index out of range")

    def return_decomposed_layers(self):
        """Return the three factor layers ``(u1, core1, v1)`` of the first layer."""
        return self.u1, self.core1, self.v1

In [47]:
# Choose a rank you suspect is a good one; here we reuse the optimal rank found
# above by working on W directly.
netLite = NetLite(rank=optimal_rank)

In [29]:
from flopth import flopth
# flopth is a third-party package used here to count FLOPs and parameters.
# RESTART KERNEL IF YOU SEE KeyError: "attribute 'flops' already exists"
# (seen when this cell is run more than once in the same kernel).
# Random inputs with the same shape as the training matrix; only the shape
# matters for the FLOP count.
dummy_inputs = torch.rand(X_train.shape)
# Cost of the original network.
flops, params = flopth(net, inputs=(dummy_inputs,))
print(f"Original FLOPS : {flops}, #params : {params}")
# Cost of the factorized network, for comparison.
flops, params = flopth(netLite, inputs=(dummy_inputs,))
print(f"Lite FLOPS : {flops}, #params : {params}")

Original FLOPS : 8.4K, #params : 83.0
Lite FLOPS : 5.4K, #params : 60.0


In [48]:
# W_dash is what the neural network proposes as a rough equivalent of W; the two can be very different, as shown below.
# This is because the network learns the low-rank matrix triplet that minimizes the loss, not the one that approximates W as closely as possible.
# NOTE(review): no training call for netLite is visible before this cell. Unless it is trained
# elsewhere (e.g. train_nn(netLite)), W_dash reflects the random initialisation, not learned weights.

u1, core1, v1 = netLite.return_decomposed_layers()
# forward() computes x @ u1.W^T @ core1.W^T @ v1.W^T, so every factor must be
# transposed, including core1 (this only cancels out when rank == 1).
W_dash = torch.mm(torch.mm(u1.weight.T, core1.weight.T), v1.weight.T)
print("Learned Low Rank Weights of the First Layer")
# NOTE(review): the values are multiplied by 10 for display only, while W below
# is printed unscaled — confirm this scaling is intended.
print(np.round(W_dash.T.detach().numpy()*10,2))
W = net.get_first_layer().weight
print("Original Weights of the First Layer")
print(np.round(W.detach().numpy(),2))

Learned Low Rank Weights of the First Layer
[[ 0.45  0.49 -0.14  0.54]
 [ 0.68  0.74 -0.21  0.82]
 [-0.57 -0.61  0.17 -0.68]
 [ 0.67  0.73 -0.21  0.81]
 [ 0.14  0.16 -0.04  0.17]
 [ 0.57  0.62 -0.17  0.68]
 [ 0.1   0.11 -0.03  0.13]
 [ 0.37  0.4  -0.11  0.45]
 [-0.11 -0.12  0.03 -0.13]
 [ 0.59  0.65 -0.18  0.71]]
Original Weights of the First Layer
[[-0.08 -0.51  1.73  1.55]
 [-0.05 -0.19  0.71  0.62]
 [ 0.18 -0.61  0.99  1.03]
 [ 0.15  0.21 -1.06 -0.89]
 [ 0.24  0.17 -1.24 -1.  ]
 [-0.03  0.09 -0.17 -0.17]
 [-0.26  0.68 -1.11 -1.17]
 [ 0.12 -0.2   0.15  0.21]
 [-0.03 -0.12  0.46  0.4 ]
 [ 0.   -0.    0.    0.  ]]
