In [1]:
import copy
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from time import time

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Custom linear layer
class CustomLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(CustomLinear, self).__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))
    def forward(self, x):
        return x @ self.weight.t() + self.bias

# Custom linear layer with SVD
class CustomLinearSVD(nn.Module):
    def __init__(self, original_layer, rank=None):
        super(CustomLinearSVD, self).__init__()
        U, S, V = torch.svd(original_layer.weight.data)

        if rank is not None and rank < len(S):
            U = U[:, :rank]
            S = S[:rank]
            V = V[:, :rank]
        self.U = nn.Parameter(U,requires_grad=False)
        self.S = nn.Parameter(torch.diag(S),requires_grad=False)
        self.V = nn.Parameter(V)
        self.precomputed_SU = nn.Parameter(self.S.t() @ self.U.t(),requires_grad=False)
        self.bias = nn.Parameter(original_layer.bias.data)
    def forward(self, x):

        x= x @ self.V
        x= x @ self.precomputed_SU
        return x  + self.bias

def replace_linear_with_svd(model, layer_num, rank=None):
    new_model=copy.deepcopy(model)
    for l_n in layer_num:
        layer_name = f'layer{l_n}'
        original_layer = getattr(new_model, layer_name)
        svd_layer = CustomLinearSVD(original_layer, rank=rank)
        setattr(new_model, layer_name, svd_layer)
    return new_model

# Model definition

In [3]:
class Net(nn.Module):
    def __init__(self, linear_layer=CustomLinear):
        super(Net, self).__init__()
        self.layer1 = linear_layer(784, 8192)
        self.layer2 = linear_layer(8192, 8192)
        self.layer3 = linear_layer(8192, 10)
    def forward(self, x):
        x = x.view(x.size(0), -1) # Flatten the tensor
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return F.log_softmax(self.layer3(x), dim=1)
    def freeze_it(self):
        for param in self.parameters():
            param.requires_grad = False


In [4]:

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=63, shuffle=True)
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = DataLoader(testset, batch_size=63, shuffle=False)


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Function to train the model
def evaluate_model(model, dataloader, device):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy
def evaluate_model_and_time(model, dataloader, device):
    correct = 0
    total = 0
    time_records=[]
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            start_time=time()
            outputs = model(images)
            time_records.append(time() - start_time)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy,np.mean(time_records)
# Updated train model function
def train_model(model, trainloader, testloader, device, epochs=3,lr=0.01):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    for epoch in range(epochs):
        model.train()
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

        # Evaluate on training data
        train_accuracy = evaluate_model(model, trainloader, device)
        # Evaluate on test data
        test_accuracy = evaluate_model(model, testloader, device)

        print(f'Epoch {epoch+1}/{epochs} - Train Accuracy: {train_accuracy:.2f}% - Test Accuracy: {test_accuracy:.2f}%')

# Initialize and train the original model
model_original = Net(CustomLinear).to(device)
train_model(model_original, trainloader,testloader,device)


Epoch 1/3 - Train Accuracy: 93.08% - Test Accuracy: 92.44%
Epoch 2/3 - Train Accuracy: 93.78% - Test Accuracy: 92.73%
Epoch 3/3 - Train Accuracy: 94.95% - Test Accuracy: 94.00%


In [10]:

accuracy_original,time_original = evaluate_model_and_time(model_original.to('cpu'), testloader,'cpu')
model_original.freeze_it()
model_svd = replace_linear_with_svd(model_original, layer_num=[2], rank=512).to(device) # You can adjust the rank
print(f"Original Model Accuracy: {accuracy_original:.3f}%, Time: {time_original:.8f} seconds")
accuracy_svd ,time_svd= evaluate_model_and_time(model_svd.to('cpu'), testloader,'cpu')
print(f"SVD Model (not tuned) Accuracy: {accuracy_svd:.3f}%, Time: {time_svd:.8f} seconds")
train_model(model_svd, trainloader,testloader,device,2,0.0001)
accuracy_svd ,time_svd= evaluate_model_and_time(model_svd.to('cpu'), testloader,'cpu')
print(f"SVD Model (Tuned) Accuracy: {accuracy_svd:.3f}%, Time: {time_svd:.8f} seconds")


Original Model Accuracy: 94.000%, Time: 0.03497588 seconds
SVD Model (not tuned) Accuracy: 69.000%, Time: 0.01164589 seconds
Epoch 1/2 - Train Accuracy: 97.25% - Test Accuracy: 96.04%
Epoch 2/2 - Train Accuracy: 98.38% - Test Accuracy: 96.66%
SVD Model (Tuned) Accuracy: 96.660%, Time: 0.01784200 seconds


In [7]:
# for name, param in model_svd.named_parameters():
#     print(name,param.requires_grad)
#     print()