In [47]:
import torch 

d = 4
W_rank = 2
k = 4
W = torch.randn(d, W_rank) @ torch.randn(W_rank, k)
print(W)
U, S, V = torch.svd(W)
print(U)
print(S)
print(V)


tensor([[-8.5756e-01, -4.5152e-01,  1.5634e+00, -4.4968e-01],
        [-5.2195e-03,  7.3769e-01, -5.9260e-01, -1.6914e-01],
        [-1.9742e-01, -5.3052e+00,  4.5895e+00,  1.0654e+00],
        [ 2.1769e-01,  2.1175e+00, -2.0255e+00, -3.3596e-01]])
tensor([[-0.1728, -0.9720,  0.0669,  0.1441],
        [ 0.1217, -0.0677, -0.9797,  0.1444],
        [-0.9022,  0.2103, -0.0722,  0.3696],
        [ 0.3760,  0.0796,  0.1748,  0.9065]])
tensor([7.8612e+00, 1.3568e+00, 1.1492e-07, 3.4908e-08])
tensor([[ 0.0518,  0.5968,  0.2014, -0.7749],
        [ 0.7315, -0.4113, -0.3976, -0.3712],
        [-0.6671, -0.4981, -0.2514, -0.4936],
        [-0.1311,  0.4760, -0.8592,  0.1345]])


In [48]:
import numpy as np
W_rank = np.linalg.matrix_rank(W)
print(f'Rank of W: {W_rank}')

Rank of W: 2


In [49]:
W = W@W.T


In [50]:
W.shape

torch.Size([4, 4])

In [51]:
eigenvalues, eigenvectors = torch.linalg.eig(W)

print("Eigenvalues:", eigenvalues)
print("Eigenvectors:", eigenvectors)

Eigenvalues: tensor([ 6.1798e+01+0.j,  1.8408e+00+0.j,  7.5576e-07+0.j, -5.5081e-07+0.j])
Eigenvectors: tensor([[-0.1728+0.j, -0.9720+0.j, -0.1181+0.j,  0.1587+0.j],
        [ 0.1217+0.j, -0.0677+0.j, -0.4259+0.j, -0.2312+0.j],
        [-0.9022+0.j,  0.2103+0.j, -0.3745+0.j,  0.3161+0.j],
        [ 0.3760+0.j,  0.0796+0.j, -0.8151+0.j,  0.9063+0.j]])


In [52]:
import torch
import numpy as np 
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets 
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"This model runs on {device}")

class_names = [str(i) for i in range(10)]  # MNIST has digits 0-9

class ComplexModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28 * 28, 512)  # Input layer
        self.dropout1 = nn.Dropout(p=0.5)      # Dropout layer for regularization
        self.layer2 = nn.Linear(512, 256)       # Second layer
        self.dropout2 = nn.Dropout(p=0.5)
        self.layer3 = nn.Linear(256, 128)       # Third layer
        self.dropout3 = nn.Dropout(p=0.5)
        self.layer4 = nn.Linear(128, 10)        # Output layer
        self.relu = nn.ReLU()                   # Activation function

    def forward(self, X):
        X = X.view(-1, 28 * 28)  # Flatten the input
        res = self.layer1(X)
        res = self.relu(res)
        res = self.dropout1(res)  # Apply dropout
        res = self.layer2(res)
        res = self.relu(res)
        res = self.dropout2(res)
        res = self.layer3(res)
        res = self.relu(res)
        res = self.dropout3(res)
        res = self.layer4(res)  # Output logits
        return res 

def Prepare_Data(batch_size_train=64, batch_size_test=1024):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # Normalize with MNIST mean and std
    ])
    mnist_trainset = datasets.MNIST(
        root='../data', train=True, download=True, transform=transform
    )
    train_loader = DataLoader(mnist_trainset, batch_size=batch_size_train, shuffle=True)
    mnist_testset = datasets.MNIST(
        root='../data', train=False, download=True, transform=transform
    )
    test_loader = DataLoader(mnist_testset, batch_size=batch_size_test, shuffle=False)
    return train_loader, test_loader

model = ComplexModel().to(device=device)
print(model)
 
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Lower learning rate

def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()            
        outputs = model(data)           
        loss = criterion(outputs, target)   
        loss.backward()                
        optimizer.step()                 
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch [{epoch}/100], Step [{batch_idx + 1}/{len(train_loader)}], '
                  f'Loss: {running_loss / (batch_idx + 1):.4f}, '
                  f'Accuracy: {100 * correct / total:.2f}%')
    epoch_time = time.time() - start_time
    print(f'Epoch [{epoch}/100] completed in {epoch_time:.2f} seconds. '
          f'Average Loss: {running_loss / len(train_loader):.4f}, '
          f'Accuracy: {100 * correct / total:.2f}%')

def test(model, device, test_loader, criterion, class_names):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    class_correct = list(0. for _ in range(len(class_names)))
    class_total = list(0. for _ in range(len(class_names)))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            c = (predicted == target).squeeze()
            for i in range(len(target)):
                label = target[i]
                class_correct[label] += (predicted[i] == label).item()
                class_total[label] += 1
            correct += (predicted == target).sum().item()
            total += target.size(0)
    average_loss = test_loss / len(test_loader)
    accuracy = 100 * correct / total
    print(f'Test Loss: {average_loss:.4f}, Test Accuracy: {accuracy:.2f}%\n')
    for i in range(len(class_names)):
        if class_total[i] > 0:
            print(f'Class: {class_names[i]:15s} - Correct: {int(class_correct[i])}/{int(class_total[i])} '
                  f'({100 * class_correct[i] / class_total[i]:.2f}%)')
        else:
            print(f'Class: {class_names[i]:15s} - No samples.')
    print("\n")
    return accuracy

if __name__ == "__main__":
    train_loader, test_loader = Prepare_Data()
    best_accuracy = 0.0
    num_epochs = 50
    for epoch in range(1, num_epochs + 1):
        train(model, device, train_loader, optimizer, criterion, epoch)
        accuracy = test(model, device, test_loader, criterion, class_names)
        # if accuracy > best_accuracy:
        #     best_accuracy = accuracy
        #     torch.save(model.state_dict(), 'best_model.pth')
        #     print(f'Best model saved with accuracy: {best_accuracy:.2f}%\n')
    print(f'Training completed. Best Test Accuracy: {best_accuracy:.2f}%')


This model runs on cuda
ComplexModel(
  (layer1): Linear(in_features=784, out_features=512, bias=True)
  (dropout1): Dropout(p=0.5, inplace=False)
  (layer2): Linear(in_features=512, out_features=256, bias=True)
  (dropout2): Dropout(p=0.5, inplace=False)
  (layer3): Linear(in_features=256, out_features=128, bias=True)
  (dropout3): Dropout(p=0.5, inplace=False)
  (layer4): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
)
Epoch [1/100], Step [100/938], Loss: 1.1655, Accuracy: 59.67%
Epoch [1/100], Step [200/938], Loss: 0.8305, Accuracy: 72.30%
Epoch [1/100], Step [300/938], Loss: 0.6883, Accuracy: 77.51%
Epoch [1/100], Step [400/938], Loss: 0.6050, Accuracy: 80.60%
Epoch [1/100], Step [500/938], Loss: 0.5494, Accuracy: 82.66%
Epoch [1/100], Step [600/938], Loss: 0.5091, Accuracy: 84.11%
Epoch [1/100], Step [700/938], Loss: 0.4790, Accuracy: 85.15%
Epoch [1/100], Step [800/938], Loss: 0.4545, Accuracy: 86.01%
Epoch [1/100], Step [900/938], Loss: 0.4315, Accuracy: 8

In [60]:
total_parameters_original= sum(p.numel() for p in model.parameters() if p.requires_grad)
total_parameters_original

569882

In [54]:
import torch.nn.utils.parametrize as parametrize
class LoRA_Scratch(nn.Module):
    def __init__(self, Layer_dims, rank, alpha= 0.5):
        
        super().__init__()
        feature_in, feature_out = Layer_dims.weight.shape

        self.A = nn.Parameter(torch.zeros(feature_in, rank).to(device=device))
        self.A = self.A.to(device)

        self.B = nn.Parameter(torch.zeros(rank, feature_out).to(device=device))
        self.A = self.A.to(device)

        self.scale = alpha/rank
        self.LoRA = True
    
    def forward(self, X):
        
        if self.LoRA:
            return  X + (self.A @ self.B) * self.scale
        else:
            return X

def LoRA_convertor(layer, rank=1, lora_alpha=1):
    return LoRA_Scratch(
        layer, rank=rank, alpha=lora_alpha
    )

parametrize.register_parametrization(
    model.layer1, "weight", LoRA_convertor(model.layer1)
)
parametrize.register_parametrization(
    model.layer2, "weight", LoRA_convertor(model.layer2)
)
parametrize.register_parametrization(
    model.layer3, "weight", LoRA_convertor(model.layer3)
)


ParametrizedLinear(
  in_features=256, out_features=128, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRA_Scratch()
    )
  )
)

In [55]:
model.layer1

ParametrizedLinear(
  in_features=784, out_features=512, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): LoRA_Scratch()
    )
  )
)

In [57]:
import torch.nn.utils.parametrize as parametrize
class LoRA_Scratch(nn.Module):
    def __init__(self, Layer_dims, rank, alpha= 0.5):
        
        super().__init__()
        feature_in, feature_out = Layer_dims.weight.shape

        self.A = nn.Parameter(torch.zeros(feature_in, rank).to(device=device))
        self.A = self.A.to(device)

        self.B = nn.Parameter(torch.zeros(rank, feature_out).to(device=device))
        self.A = self.A.to(device)

        self.scale = alpha/rank
        self.LoRA = True
    
    def forward(self, X):
        
        if self.LoRA:
            return  X + (self.A @ self.B) * self.scale
        else:
            return X

def LoRA_convertor(layer, rank=1, lora_alpha=1):
    return LoRA_Scratch(
        layer, rank=rank, alpha=lora_alpha
    )

parametrize.register_parametrization(
    model.layer1, "weight", LoRA_convertor(model.layer1)
)
parametrize.register_parametrization(
    model.layer2, "weight", LoRA_convertor(model.layer2)
)
parametrize.register_parametrization(
    model.layer3, "weight", LoRA_convertor(model.layer3)
)


def Enable_lora(enable= True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = True

def Disable_lora(enable= True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = False

In [67]:
import torch.nn.utils.parametrize as parametrize
class LoRA_Scratch(nn.Module):
    def __init__(self, Layer_dims, rank, alpha= 0.5):
        
        super().__init__()
        feature_in, feature_out = Layer_dims.weight.shape

        self.A = nn.Parameter(torch.zeros(feature_in, rank).to(device=device))
        self.A = self.A.to(device)

        self.B = nn.Parameter(torch.zeros(rank, feature_out).to(device=device))
        self.A = self.A.to(device)

        self.scale = alpha/rank
        self.LoRA = True
    
    def forward(self, X):
        
        if self.LoRA:
            return  X + (self.A @ self.B) * self.scale
        else:
            return X

def LoRA_convertor(layer, rank=1, lora_alpha=1):
    return LoRA_Scratch(
        layer, rank=rank, alpha=lora_alpha
    )

parametrize.register_parametrization(
    model.layer1, "weight", LoRA_convertor(model.layer1)
)
parametrize.register_parametrization(
    model.layer2, "weight", LoRA_convertor(model.layer2)
)
parametrize.register_parametrization(
    model.layer3, "weight", LoRA_convertor(model.layer3)
)


def Enable_lora(enable= True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = True

def Disable_lora(enable= True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = False

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.layer1, model.layer2, model.layer3]):
    total_parameters_lora += layer.parametrizations["weight"][0].A.nelement() + layer.parametrizations["weight"][0].B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].A.shape} + Lora_B: {layer.parametrizations["weight"][0].B.shape}'
    )
# The non-LoRA parameters count must match the original network
# assert total_parameters_non_lora == total_parameters_original, 'not matched'
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Layer 1: W: torch.Size([512, 784]) + B: torch.Size([512]) + Lora_A: torch.Size([512, 1]) + Lora_B: torch.Size([1, 784])
Layer 2: W: torch.Size([256, 512]) + B: torch.Size([256]) + Lora_A: torch.Size([256, 1]) + Lora_B: torch.Size([1, 512])
Layer 3: W: torch.Size([128, 256]) + B: torch.Size([128]) + Lora_A: torch.Size([128, 1]) + Lora_B: torch.Size([1, 256])
Total number of parameters (original): 566,144
Total number of parameters (original + LoRA): 568,592
Parameters introduced by LoRA: 2,448
Parameters incremment: 0.432%


In [65]:
total_parameters_non_lora

566144

In [66]:
total_parameters_original

569882

In [76]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torch.nn.utils.parametrize as parametrize
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"This model runs on {device}")

class LoRA_Scratch(nn.Module):
    def __init__(self, Layer_dims, rank, alpha=0.5):
        super().__init__()
        feature_in, feature_out = Layer_dims.weight.shape

        self.A = nn.Parameter(torch.zeros(feature_in, rank).to(device=device))
        self.B = nn.Parameter(torch.zeros(rank, feature_out).to(device=device))
        
        self.scale = alpha / rank
        self.LoRA = True
    
    def forward(self, X):
        if self.LoRA:
            return X + (self.A @ self.B) * self.scale
        else:
            return X

def LoRA_convertor(layer, rank=1, lora_alpha=1):
    return LoRA_Scratch(
        layer, rank=rank, alpha=lora_alpha
    )

# Example model structure
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28*28, 100)
        self.layer2 = nn.Linear(100, 100)
        self.layer3 = nn.Linear(100, 10)

    def forward(self, X):
        X = X.view(-1, 28*28)
        X = torch.relu(self.layer1(X))
        X = torch.relu(self.layer2(X))
        X = self.layer3(X)
        return X

model = Model().to(device=device)

# Register LoRA parameters
parametrize.register_parametrization(
    model.layer1, "weight", LoRA_convertor(model.layer1)
)
parametrize.register_parametrization(
    model.layer2, "weight", LoRA_convertor(model.layer2)
)
parametrize.register_parametrization(
    model.layer3, "weight", LoRA_convertor(model.layer3)
)

def Enable_lora(enable=True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = enable

def Disable_lora(enable=True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = not enable

# Count parameters
total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([model.layer1, model.layer2, model.layer3]):
    total_parameters_lora += layer.parametrizations["weight"][0].A.nelement() + layer.parametrizations["weight"][0].B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + (layer.bias.nelement() if layer.bias is not None else 0)
    print(
        f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape if layer.bias is not None else "None"} + '
        f'Lora_A: {layer.parametrizations["weight"][0].A.shape} + Lora_B: {layer.parametrizations["weight"][0].B.shape}'
    )

print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_increment = (total_parameters_lora / total_parameters_non_lora) * 100 if total_parameters_non_lora > 0 else 0
print(f'Parameters increment: {parameters_increment:.3f}%')

for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset, keeping only the digit 9
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with MNIST mean and std
])

mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# Create a dataloader for the training
train_loader = DataLoader(mnist_trainset, batch_size=10, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()            
        outputs = model(data)           
        loss = criterion(outputs, target)   
        loss.backward()                
        optimizer.step()                 
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        if (batch_idx + 1) % 10 == 0:  # Print every 10 batches
            print(f'Epoch [{epoch}], Step [{batch_idx+1}/{len(train_loader)}], '
                  f'Loss: {running_loss / (batch_idx+1):.4f}, '
                  f'Accuracy: {100 * correct / total:.2f}%')
    epoch_time = time.time() - start_time
    print(f'Epoch [{epoch}] completed in {epoch_time:.2f} seconds. '
          f'Average Loss: {running_loss / len(train_loader):.4f}, '
          f'Accuracy: {100 * correct / total:.2f}%')

if __name__ == "__main__":
    num_epochs = 2
    for epoch in range(1, num_epochs + 1):
        Enable_lora(True)  # Enable LoRA during training
        train(model, device, train_loader, optimizer, criterion, epoch)

print('Training completed.')


This model runs on cuda
Layer 1: W: torch.Size([100, 784]) + B: torch.Size([100]) + Lora_A: torch.Size([100, 1]) + Lora_B: torch.Size([1, 784])
Layer 2: W: torch.Size([100, 100]) + B: torch.Size([100]) + Lora_A: torch.Size([100, 1]) + Lora_B: torch.Size([1, 100])
Layer 3: W: torch.Size([10, 100]) + B: torch.Size([10]) + Lora_A: torch.Size([10, 1]) + Lora_B: torch.Size([1, 100])
Total number of parameters (original): 89,610
Total number of parameters (original + LoRA): 90,804
Parameters introduced by LoRA: 1,194
Parameters increment: 1.332%
Freezing non-LoRA parameter layer1.bias
Freezing non-LoRA parameter layer1.parametrizations.weight.original
Freezing non-LoRA parameter layer1.parametrizations.weight.0.A
Freezing non-LoRA parameter layer1.parametrizations.weight.0.B
Freezing non-LoRA parameter layer2.bias
Freezing non-LoRA parameter layer2.parametrizations.weight.original
Freezing non-LoRA parameter layer2.parametrizations.weight.0.A
Freezing non-LoRA parameter layer2.parametrizatio

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [77]:
def Enable_lora(enable=True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = enable
        for param in [Layer.parametrizations['weight'][0].A, Layer.parametrizations['weight'][0].B]:
            param.requires_grad = enable  # Ensure LoRA parameters require grad

# Train the network with LoRA only on the digit 9
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()            
        
        # Check that data requires grad
        assert data.requires_grad is False, "Data should not require gradients"
        
        outputs = model(data)           
        loss = criterion(outputs, target)   
        
        # Check that loss requires grad
        assert loss.requires_grad is True, "Loss should require gradients"
        
        loss.backward()                
        optimizer.step()                 
        running_loss += loss.item()
        
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        if (batch_idx + 1) % 10 == 0:  # Print every 10 batches
            print(f'Epoch [{epoch}], Step [{batch_idx+1}/{len(train_loader)}], '
                  f'Loss: {running_loss / (batch_idx+1):.4f}, '
                  f'Accuracy: {100 * correct / total:.2f}%')
    
    epoch_time = time.time() - start_time
    print(f'Epoch [{epoch}] completed in {epoch_time:.2f} seconds. '
          f'Average Loss: {running_loss / len(train_loader):.4f}, '
          f'Accuracy: {100 * correct / total:.2f}%')

if __name__ == "__main__":
    num_epochs = 2
    for epoch in range(1, num_epochs + 1):
        Enable_lora(True)  # Enable LoRA during training
        train(model, device, train_loader, optimizer, criterion, epoch)

print('Training completed.')


Epoch [1], Step [10/595], Loss: 2.1827, Accuracy: 59.00%
Epoch [1], Step [20/595], Loss: 2.1810, Accuracy: 58.00%
Epoch [1], Step [30/595], Loss: 2.1796, Accuracy: 57.67%
Epoch [1], Step [40/595], Loss: 2.1789, Accuracy: 58.25%
Epoch [1], Step [50/595], Loss: 2.1784, Accuracy: 58.00%
Epoch [1], Step [60/595], Loss: 2.1787, Accuracy: 57.50%
Epoch [1], Step [70/595], Loss: 2.1780, Accuracy: 57.86%
Epoch [1], Step [80/595], Loss: 2.1787, Accuracy: 57.62%
Epoch [1], Step [90/595], Loss: 2.1780, Accuracy: 58.22%
Epoch [1], Step [100/595], Loss: 2.1781, Accuracy: 58.40%
Epoch [1], Step [110/595], Loss: 2.1783, Accuracy: 57.64%
Epoch [1], Step [120/595], Loss: 2.1784, Accuracy: 57.58%
Epoch [1], Step [130/595], Loss: 2.1786, Accuracy: 57.23%
Epoch [1], Step [140/595], Loss: 2.1791, Accuracy: 56.79%
Epoch [1], Step [150/595], Loss: 2.1789, Accuracy: 57.07%
Epoch [1], Step [160/595], Loss: 2.1785, Accuracy: 57.12%
Epoch [1], Step [170/595], Loss: 2.1782, Accuracy: 57.18%
Epoch [1], Step [180/59

In [79]:
class LoRA_Scratch(nn.Module):
    def __init__(self, layer, rank, alpha=0.5):
        super().__init__()
        feature_in, feature_out = layer.weight.shape

        self.A = nn.Parameter(torch.zeros(feature_in, rank).to(device))
        self.B = nn.Parameter(torch.zeros(rank, feature_out).to(device))
        
        self.scale = alpha / rank
        self.LoRA = True

    def forward(self, X):
        if self.LoRA:
            return X + (self.A @ self.B) * self.scale
        else:
            return X

def LoRA_convertor(layer, rank=1, lora_alpha=1):
    return LoRA_Scratch(layer, rank=rank, alpha=lora_alpha)

# Register LoRA to the layers
parametrize.register_parametrization(model.layer1, "weight", LoRA_convertor(model.layer1))
parametrize.register_parametrization(model.layer2, "weight", LoRA_convertor(model.layer2))
parametrize.register_parametrization(model.layer3, "weight", LoRA_convertor(model.layer3))

def Enable_lora(enable=True):
    for Layer in [model.layer1, model.layer2, model.layer3]:
        Layer.parametrizations['weight'][0].LoRA = enable
        for param in [Layer.parametrizations['weight'][0].A, Layer.parametrizations['weight'][0].B]:
            param.requires_grad = enable

# Load the MNIST dataset, keeping only the digit 9
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize with MNIST mean and std
])
mnist_trainset = datasets.MNIST(root='../data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Define training function
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()            
        
        outputs = model(data)
        loss = criterion(outputs, target)   
        loss.backward()                
        optimizer.step()                 
        running_loss += loss.item()
        
        # Count correct predictions
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        if (batch_idx + 1) % 10 == 0:  # Print every 10 batches
            print(f'Epoch [{epoch}], Step [{batch_idx + 1}/{len(train_loader)}], '
                  f'Loss: {running_loss / (batch_idx + 1):.4f}, '
                  f'Accuracy: {100 * correct / total:.2f}%, '
                  f'Correct Predictions: {correct}/{total}')
    
    epoch_time = time.time() - start_time
    print(f'Epoch [{epoch}] completed in {epoch_time:.2f} seconds. '
          f'Average Loss: {running_loss / len(train_loader):.4f}, '
          f'Accuracy: {100 * correct / total:.2f}%')

# Set up optimizer and loss criterion

# Training loop
if __name__ == "__main__":
    num_epochs = 2
    for epoch in range(1, num_epochs + 1):
        Enable_lora(True)  # Enable LoRA during training
        train(model, device, train_loader, optimizer, criterion, epoch)

print('Training completed.')

Epoch [1], Step [10/595], Loss: 2.1782, Accuracy: 60.00%, Correct Predictions: 60/100
Epoch [1], Step [20/595], Loss: 2.1799, Accuracy: 55.00%, Correct Predictions: 110/200
Epoch [1], Step [30/595], Loss: 2.1808, Accuracy: 55.67%, Correct Predictions: 167/300
Epoch [1], Step [40/595], Loss: 2.1801, Accuracy: 58.00%, Correct Predictions: 232/400
Epoch [1], Step [50/595], Loss: 2.1792, Accuracy: 56.80%, Correct Predictions: 284/500
Epoch [1], Step [60/595], Loss: 2.1786, Accuracy: 56.67%, Correct Predictions: 340/600
Epoch [1], Step [70/595], Loss: 2.1790, Accuracy: 56.00%, Correct Predictions: 392/700
Epoch [1], Step [80/595], Loss: 2.1790, Accuracy: 56.50%, Correct Predictions: 452/800
Epoch [1], Step [90/595], Loss: 2.1785, Accuracy: 57.89%, Correct Predictions: 521/900
Epoch [1], Step [100/595], Loss: 2.1787, Accuracy: 57.70%, Correct Predictions: 577/1000
Epoch [1], Step [110/595], Loss: 2.1781, Accuracy: 57.64%, Correct Predictions: 634/1100
Epoch [1], Step [120/595], Loss: 2.1775,