In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

In [2]:
# Seed PyTorch's global random number generator so runs are repeatable.
# The returned generator is assigned to `_` so the cell displays nothing.
_ = torch.manual_seed(69)

In [3]:

# Device used for the model and for every batch: first CUDA GPU if present, else CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Per-image preprocessing: convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 training split (downloaded to ./data if missing), served in shuffled batches of 10.
trainset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=10,
    shuffle=True,
)

# CIFAR-10 test split, same preprocessing and batching.
testset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=10,
    shuffle=True,
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:06<00:00, 28080913.93it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [4]:
class DenseFNN(nn.Module):
    """Fully connected classifier: three linear layers with ReLU after the first two.

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
        input_size: number of features per flattened sample. Defaults to
            32 * 32 * 3, the value that was previously hard-coded.
        num_classes: number of output logits. Defaults to 10, the value that
            was previously hard-coded.
    """

    def __init__(self, hidden_size_1=2048, hidden_size_2=1024,
                 input_size=32 * 32 * 3, num_classes=10):
        super().__init__()
        # Remembered so forward() flattens to the width linear1 expects.
        self.input_size = input_size
        self.linear1 = nn.Linear(input_size, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, num_classes)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten `img` to rows of `input_size` features and return the raw logits.

        Accepts either image-shaped batches or already flattened batches.
        """
        # reshape() instead of view(): view() raises on non-contiguous input
        # (for example a permuted tensor), reshape() copies only when needed.
        x = img.reshape(-1, self.input_size)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

# Build the network with its default hidden sizes and move its parameters to `device`.
dense_network = DenseFNN().to(device)


In [5]:

# Loss shared by every call to train(): cross-entropy between logits and integer targets.
loss_fn = nn.CrossEntropyLoss()

def train(train_loader, model, epochs=5,iterations_limit=None):
    """Optimise `model` on `train_loader` with Adam (lr=0.001) and `loss_fn`.

    Relies on the notebook-level names `device`, `loss_fn` and `tqdm`.

    Args:
        train_loader: iterable yielding (inputs, targets) batches; must support len().
        model: module to train. It is put in training mode and updated in place.
        epochs: maximum number of passes over `train_loader`.
        iterations_limit: when not None, training stops as soon as this many
            optimizer steps have been taken in total (counted across epochs).

    Returns:
        The same `model` object that was passed in.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    model.train()
    steps_done = 0
    for epoch in range(epochs):
        print(f'<---------epoch:{epoch+1}----------->')
        running_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            # Each sample is flattened to 32 * 32 * 3 features before the forward pass.
            logits = model(inputs.view(-1, 32 * 32 * 3))
            loss = loss_fn(logits, targets)
            running_loss += loss.item()

            loss.backward()
            optimizer.step()
            steps_done += 1

            limit_hit = iterations_limit is not None and steps_done >= iterations_limit
            if limit_hit:
                # Early exit: the average loss of the unfinished epoch is not reported.
                print("Iteration limit reached. Stopping training.")
                return model

        avg_loss = running_loss / len(train_loader)
        print(f'Average loss for epoch:{epoch+1} is {avg_loss}')

    return model

# Train on the full training loader for 15 epochs. train() returns the model it was
# given, so `pretrained_model` and `dense_network` refer to the same object.
pretrained_model=train(train_loader, dense_network, epochs=15)

<---------epoch:1----------->


100%|██████████| 5000/5000 [00:21<00:00, 229.67it/s]


Average loss for epoch:1 is 1.8011643437981606
<---------epoch:2----------->


100%|██████████| 5000/5000 [00:21<00:00, 234.79it/s]


Average loss for epoch:2 is 1.6511358682513237
<---------epoch:3----------->


100%|██████████| 5000/5000 [00:21<00:00, 232.83it/s]


Average loss for epoch:3 is 1.5737762178540229
<---------epoch:4----------->


100%|██████████| 5000/5000 [00:21<00:00, 231.49it/s]


Average loss for epoch:4 is 1.5113011913120746
<---------epoch:5----------->


100%|██████████| 5000/5000 [00:21<00:00, 231.20it/s]


Average loss for epoch:5 is 1.4610699408233165
<---------epoch:6----------->


100%|██████████| 5000/5000 [00:21<00:00, 234.00it/s]


Average loss for epoch:6 is 1.409328603476286
<---------epoch:7----------->


100%|██████████| 5000/5000 [00:21<00:00, 232.83it/s]


Average loss for epoch:7 is 1.3735657937765122
<---------epoch:8----------->


100%|██████████| 5000/5000 [00:21<00:00, 234.08it/s]


Average loss for epoch:8 is 1.341443283957243
<---------epoch:9----------->


100%|██████████| 5000/5000 [00:21<00:00, 233.65it/s]


Average loss for epoch:9 is 1.3003583974719048
<---------epoch:10----------->


100%|██████████| 5000/5000 [00:21<00:00, 234.65it/s]


Average loss for epoch:10 is 1.2723090587079524
<---------epoch:11----------->


100%|██████████| 5000/5000 [00:21<00:00, 234.27it/s]


Average loss for epoch:11 is 1.2414974037885667
<---------epoch:12----------->


100%|██████████| 5000/5000 [00:21<00:00, 234.08it/s]


Average loss for epoch:12 is 1.2149927342146636
<---------epoch:13----------->


100%|██████████| 5000/5000 [00:21<00:00, 233.97it/s]


Average loss for epoch:13 is 1.1920241847425699
<---------epoch:14----------->


100%|██████████| 5000/5000 [00:21<00:00, 233.53it/s]


Average loss for epoch:14 is 1.1654484893083572
<---------epoch:15----------->


100%|██████████| 5000/5000 [00:21<00:00, 234.75it/s]

Average loss for epoch:15 is 1.1439058827638626





In [6]:
# Snapshot of every parameter of the trained model, keyed by parameter name.
# Each entry is a detached copy, so later updates to the model do not change it.
original_weights = {
    param_name: tensor.clone().detach()
    for param_name, tensor in pretrained_model.named_parameters()
}

In [7]:
# Human-readable class names; test() looks a name up with classes[label],
# so position i in this list is the name shown for integer label i.
classes = ['plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [29]:
# def test(model):
#     correct = 0
#     total = 0

#     wrong_counts = [0 for i in range(10)]

#     with torch.no_grad():
#         for batch in tqdm(test_loader, desc='Testing'):
#             inputs, targets = batch
#             inputs = inputs.to(device)
#             targets = targets.to(device)
            
#             output = model(inputs.view(-1, 32 * 32 * 3))
#             for idx, i in enumerate(output):
#                 if torch.argmax(i) == targets[idx]:
#                     correct +=1
#                 else:
#                     wrong_counts[targets[idx]] +=1
#                 total +=1
#     for i in range(len(wrong_counts)):
#         print(f'wrong counts for the animal {classes[i]}: {wrong_counts[i]}')

# test(pretrained_model)

In [8]:
    
def test(model):
    """Print the per-class accuracy of `model` on the notebook's `test_loader`.

    Relies on the notebook-level names `classes`, `test_loader`, `device` and `tqdm`.

    Args:
        model: module mapping a batch of images to one score per class.

    Returns:
        None. One line per class is printed; a class with no samples in the
        loader is reported as "n/a" instead of raising ZeroDivisionError.
    """
    # Per-class counters, keyed by class name.
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for images, labels in tqdm(test_loader, desc='Testing'):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predictions = torch.max(outputs, 1)
                # Convert once per batch to plain ints instead of comparing
                # 0-d tensors element by element.
                for label, prediction in zip(labels.tolist(), predictions.tolist()):
                    classname = classes[label]
                    if label == prediction:
                        correct_pred[classname] += 1
                    total_pred[classname] += 1
    finally:
        model.train(was_training)

    # Print accuracy for each class.
    for classname, correct_count in correct_pred.items():
        seen = total_pred[classname]
        if seen == 0:
            print(f'Accuracy for class: {classname:5s} is n/a (no samples)')
            continue
        accuracy = 100 * float(correct_count) / seen
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
        
# Per-class accuracy of the trained network, measured before any LoRA parametrization is attached.
test(pretrained_model)

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 265.06it/s]

Accuracy for class: plane is 79.0 %
Accuracy for class: car   is 66.5 %
Accuracy for class: bird  is 22.6 %
Accuracy for class: cat   is 54.5 %
Accuracy for class: deer  is 41.0 %
Accuracy for class: dog   is 3.4 %
Accuracy for class: frog  is 59.7 %
Accuracy for class: horse is 52.6 %
Accuracy for class: ship  is 56.7 %
Accuracy for class: truck is 46.4 %





In [31]:
# sample_test_loader =torch.utils.data.DataLoader(testset, batch_size=2, shuffle=True)
# for batch in sample_test_loader:
#     inputs,targets = batch
#     inputs = inputs.to(device)
#     targets = targets.to(device)
#     break
# output= lora_ftnd_model(inputs.view(-1, 32 * 32 * 3))    
# print(output.shape)
# torch.max(output,dim=1)
# for idx,i in enumerate(output):
#     print(f'idx:{idx} ,and i: {i}')
    

In [9]:
class LoRAParametrization(nn.Module):
    """Parametrization that adds a scaled low-rank product to a weight tensor.

    Two trainable matrices are held: `lora_A` of shape (rank, out_features),
    initialised from a normal distribution (mean 0, std 1), and `lora_B` of
    shape (input_features, rank), initialised to zeros. Their product
    `lora_B @ lora_A` is therefore all zeros right after construction.

    Args:
        input_features: first dimension of the low-rank product.
        out_features: second dimension of the low-rank product.
        rank: inner dimension shared by the two matrices.
        alpha: numerator of the scaling factor `alpha / rank`.
        device: device on which both matrices are created.
    """

    def __init__(self, input_features, out_features, rank, alpha=1, device=device):
        super().__init__()

        a_init = torch.zeros((rank, out_features), device=device)
        b_init = torch.zeros((input_features, rank), device=device)
        self.lora_A = nn.Parameter(a_init)
        self.lora_B = nn.Parameter(b_init)
        nn.init.normal_(self.lora_A, mean=0, std=1)

        self.scale = alpha / rank
        # When False, forward() hands the weight back untouched.
        self.enabled = True

    def forward(self, original_weights):
        """Return `original_weights`, plus the scaled low-rank update when enabled."""
        if not self.enabled:
            return original_weights

        # The product is viewed with the weight's own shape before being scaled and added.
        low_rank_delta = (self.lora_B @ self.lora_A).view(original_weights.shape)
        return original_weights + low_rank_delta * self.scale



In [10]:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank, lora_alpha=1):
    """Build a LoRAParametrization sized to match `layer.weight`.

    The two dimensions of `layer.weight.shape` are forwarded, in order, as the
    first two positional arguments of LoRAParametrization, so the low-rank
    product has exactly the shape of the weight.

    NOTE(review): nn.Linear stores its weight as (out_features, in_features),
    so the value received by LoRAParametrization's `input_features` parameter
    is the layer's output width and vice versa. The shapes still line up; only
    the parameter names are misleading.

    Args:
        layer: module exposing a 2-D `weight`.
        device: device on which the LoRA matrices are created.
        rank: rank of the low-rank update.
        lora_alpha: passed through as `alpha`.

    Returns:
        A new LoRAParametrization instance.
    """
    weight_rows, weight_cols = layer.weight.shape
    lora = LoRAParametrization(
        weight_rows,
        weight_cols,
        rank=rank,
        alpha=lora_alpha,
        device=device,
    )
    return lora



In [11]:
def parameterize_model(model,rank):
    """Attach a LoRA parametrization to the weight of every direct nn.Linear child.

    The model is modified in place and also returned. Relies on the
    notebook-level names `device`, `parametrize` and
    `linear_layer_parameterization`.

    A weight that already carries a parametrization is skipped, so re-running
    the cell does not stack a second LoRA update on the same layer.

    Args:
        model: module whose direct children are inspected.
        rank: rank used for every LoRA parametrization that is created.

    Returns:
        The same `model` object that was passed in.
    """
    for module in model.children():
        if not isinstance(module, nn.Linear):
            continue
        if parametrize.is_parametrized(module, "weight"):
            # Already wrapped by an earlier call; registering again would
            # stack a second update on the same weight.
            continue
        parametrize.register_parametrization(
            module, "weight", linear_layer_parameterization(module, device, rank=rank)
        )

    return model

In [12]:
# Attach rank-1 LoRA parametrizations to the trained model. parameterize_model returns
# the model it was given, so `lora_parameterized_model` and `pretrained_model` are the
# same object; the print shows the resulting module tree.
lora_parameterized_model = parameterize_model(pretrained_model,rank=1)
print(lora_parameterized_model)

DenseFNN(
  (linear1): ParametrizedLinear(
    in_features=3072, out_features=2048, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (linear2): ParametrizedLinear(
    in_features=2048, out_features=1024, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (linear3): ParametrizedLinear(
    in_features=1024, out_features=10, bias=True
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (relu): ReLU()
)


In [36]:
# parametrize.register_parametrization(
#     pretrained_model.linear1, "weight", linear_layer_parameterization(pretrained_model.linear1, device,rank=2)
# )
# parametrize.register_parametrization(
#     pretrained_model.linear2, "weight", linear_layer_parameterization(pretrained_model.linear2, device,rank=2)
# )
# parametrize.register_parametrization(
#     pretrained_model.linear3, "weight", linear_layer_parameterization(pretrained_model.linear3, device,rank=2)
# )

In [15]:
def lora_switch(model,enable=True):
    """Set the `enabled` flag of the first weight parametrization on
    `model.linear1`, `model.linear2` and `model.linear3`.

    Args:
        model: object exposing the three layers, each with a
            `parametrizations["weight"]` sequence.
        enable: value written to each flag.
    """
    for layer_name in ("linear1", "linear2", "linear3"):
        weight_parametrizations = getattr(model, layer_name).parametrizations["weight"]
        weight_parametrizations[0].enabled = enable

In [16]:
# Freeze every parameter whose name does not contain 'lora' (the biases and the
# original weights), so that only the LoRA matrices can receive gradients.
for name, param in lora_parameterized_model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False


# Reload the CIFAR-10 training split and keep only the samples whose label is 5
# ('dog' in `classes`), the class with the lowest accuracy (3.4 %) in the baseline test.
trainset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
specific_indices=[idx for idx,target in enumerate(trainset.targets)if target == 5]

filtered_trainset= torch.utils.data.Subset(trainset, specific_indices)

filtered_trainloader = torch.utils.data.DataLoader(filtered_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on class 5 ('dog') and only for 20 iterations of
# 10 images each (hoping that it would improve the performance on that class).
lora_ftnd_model = train(filtered_trainloader, lora_parameterized_model, epochs=1,iterations_limit=20)

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original
Files already downloaded and verified
<---------epoch:1----------->


  4%|▍         | 19/500 [00:00<00:03, 132.30it/s]

Iteration limit reached. Stopping training.





In [18]:
# Test with LoRA enabled.
# lora_switch's keyword is `enable`; passing `enabled=` raises TypeError.
lora_switch(lora_ftnd_model, enable=True)
test(lora_ftnd_model)

Testing: 100%|██████████| 1000/1000 [00:04<00:00, 203.27it/s]

Accuracy for class: plane is 74.9 %
Accuracy for class: car   is 66.1 %
Accuracy for class: bird  is 6.6 %
Accuracy for class: cat   is 4.3 %
Accuracy for class: deer  is 32.0 %
Accuracy for class: dog   is 75.4 %
Accuracy for class: frog  is 55.8 %
Accuracy for class: horse is 37.2 %
Accuracy for class: ship  is 12.6 %
Accuracy for class: truck is 44.7 %





In [21]:
# Test with LoRA disabled: each parametrization then returns the frozen original
# weight unchanged, and the accuracies printed below match the baseline run.
lora_switch(lora_ftnd_model,enable=False)
test(lora_ftnd_model)

Testing: 100%|██████████| 1000/1000 [00:03<00:00, 256.00it/s]

Accuracy for class: plane is 79.0 %
Accuracy for class: car   is 66.5 %
Accuracy for class: bird  is 22.6 %
Accuracy for class: cat   is 54.5 %
Accuracy for class: deer  is 41.0 %
Accuracy for class: dog   is 3.4 %
Accuracy for class: frog  is 59.7 %
Accuracy for class: horse is 52.6 %
Accuracy for class: ship  is 56.7 %
Accuracy for class: truck is 46.4 %



