In [27]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

_ = torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor() ,

transforms.Normalize((0.1307,),(0.3081,))

])




mnist_trainset = datasets.MNIST(root='./data', train=True, download=False, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=False, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=False)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class SimplifiedVGG16(nn.Module):
    def __init__(self, num_classes=10):
        super(SimplifiedVGG16, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 3 * 3, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = SimplifiedVGG16(num_classes=10).to(device)


num_classes = 10
num_epochs = 10
batch_size = 10
learning_rate = 0.005

model =SimplifiedVGG16(num_classes).to(device)


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  


# Train the model
total_step = len(train_loader)

total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            
    # Validation
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs
    
        print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total)) 

Epoch [1/10], Step [6000/6000], Loss: 0.0976
Accuracy of the network on the 5000 validation images: 97.75 %
Epoch [2/10], Step [6000/6000], Loss: 0.0016
Accuracy of the network on the 5000 validation images: 98.0 %
Epoch [3/10], Step [6000/6000], Loss: 0.0080
Accuracy of the network on the 5000 validation images: 97.55 %
Epoch [4/10], Step [6000/6000], Loss: 0.0281
Accuracy of the network on the 5000 validation images: 97.56 %
Epoch [5/10], Step [6000/6000], Loss: 0.0004
Accuracy of the network on the 5000 validation images: 98.61 %
Epoch [6/10], Step [6000/6000], Loss: 0.0869
Accuracy of the network on the 5000 validation images: 98.19 %
Epoch [7/10], Step [6000/6000], Loss: 0.0249
Accuracy of the network on the 5000 validation images: 97.71 %
Epoch [8/10], Step [6000/6000], Loss: 0.0294
Accuracy of the network on the 5000 validation images: 98.69 %
Epoch [9/10], Step [6000/6000], Loss: 0.2552
Accuracy of the network on the 5000 validation images: 96.76 %
Epoch [10/10], Step [6000/600

In [None]:
original_weights = {}

for name , param in model.named_parameters():

    original_weights[name] = param.clone().detach()

print(original_weights)

# now test this model on each numbers in mnist


In [29]:
def test():

    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():

        for data in tqdm(test_loader , desc='Testing'):

                x,y = data

                x = x.to(device)
                y = y.to(device)
                output = model(x)

                for idx , i in enumerate(output):

                    if torch.argmax(i) == y[idx]:
                        correct += 1
                    else:
                        wrong_counts[y[idx]] +=1
                    total+=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 579.31it/s]

Accuracy: 0.986
wrong counts for the digit 0: 7
wrong counts for the digit 1: 10
wrong counts for the digit 2: 7
wrong counts for the digit 3: 5
wrong counts for the digit 4: 31
wrong counts for the digit 5: 11
wrong counts for the digit 6: 14
wrong counts for the digit 7: 30
wrong counts for the digit 8: 13
wrong counts for the digit 9: 13





# Calculating Total number of parameters


In [30]:
total_params = 0
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        weights_params = module.weight.nelement()
        bias_params = module.bias.nelement() if module.bias is not None else 0
        total_params += weights_params + bias_params
        print(f'Layer {name}: W: {module.weight.shape} + B: {module.bias.shape if module.bias is not None else "No bias"}')
        print(f'Parameters: {weights_params + bias_params}')

print(f'Total number of parameters: {total_params:,}')

Layer features.0: W: torch.Size([64, 1, 3, 3]) + B: torch.Size([64])
Parameters: 640
Layer features.3: W: torch.Size([128, 64, 3, 3]) + B: torch.Size([128])
Parameters: 73856
Layer features.6: W: torch.Size([256, 128, 3, 3]) + B: torch.Size([256])
Parameters: 295168
Layer features.8: W: torch.Size([256, 256, 3, 3]) + B: torch.Size([256])
Parameters: 590080
Layer classifier.0: W: torch.Size([1024, 2304]) + B: torch.Size([1024])
Parameters: 2360320
Layer classifier.2: W: torch.Size([10, 1024]) + B: torch.Size([10])
Parameters: 10250
Total number of parameters: 3,330,314


In [31]:
!pip install torchsummary



In [32]:
from torchsummary import summary

summary(model, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]             640
              ReLU-2           [-1, 64, 28, 28]               0
         MaxPool2d-3           [-1, 64, 14, 14]               0
            Conv2d-4          [-1, 128, 14, 14]          73,856
              ReLU-5          [-1, 128, 14, 14]               0
         MaxPool2d-6            [-1, 128, 7, 7]               0
            Conv2d-7            [-1, 256, 7, 7]         295,168
              ReLU-8            [-1, 256, 7, 7]               0
            Conv2d-9            [-1, 256, 7, 7]         590,080
             ReLU-10            [-1, 256, 7, 7]               0
        MaxPool2d-11            [-1, 256, 3, 3]               0
           Linear-12                 [-1, 1024]       2,360,320
             ReLU-13                 [-1, 1024]               0
           Linear-14                   

In [33]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper: 
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # Section 4.1 of the paper: 
        #   We then scale ∆Wx by α/r , where α is a constant in r. 
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. 
        #   As a result, we simply set α to the first r we try and do not tune it. 
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

In [34]:
import torch.nn.utils.parametrize as parametrize

def conv_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    features_in = layer.weight.shape[1] * layer.weight.shape[2] * layer.weight.shape[3]
    features_out = layer.weight.shape[0]
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

def add_lora_to_model(model, device, rank=1, lora_alpha=1):
    # Add LoRA to convolutional layers
    for i, layer in enumerate(model.features):
        if isinstance(layer, nn.Conv2d):
            parametrize.register_parametrization(
                layer, "weight", conv_layer_parameterization(layer, device, rank, lora_alpha)
            )

    # Add LoRA to linear layers
    for i, layer in enumerate(model.classifier):
        if isinstance(layer, nn.Linear):
            parametrize.register_parametrization(
                layer, "weight", linear_layer_parameterization(layer, device, rank, lora_alpha)
            )

def enable_disable_lora(model, enabled=True):
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)) and hasattr(module, 'parametrizations'):
            module.parametrizations["weight"][0].enabled = enabled



In [35]:
def count_parameters(model):
    total_params_original = 0
    total_params_lora = 0
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weights_params = module.weight.nelement()
            bias_params = module.bias.nelement() if module.bias is not None else 0
            total_params_original += weights_params + bias_params
            
            if hasattr(module, 'parametrizations'):
                lora_params = module.parametrizations["weight"][0].lora_A.nelement() + \
                              module.parametrizations["weight"][0].lora_B.nelement()
                total_params_lora += lora_params
                
                print(f'Layer {name}:')
                print(f'  W: {module.weight.shape}')
                print(f'  B: {module.bias.shape if module.bias is not None else "No bias"}')
                print(f'  Lora_A: {module.parametrizations["weight"][0].lora_A.shape}')
                print(f'  Lora_B: {module.parametrizations["weight"][0].lora_B.shape}')
                print(f'  Parameters: {weights_params + bias_params + lora_params}')
            else:
                print(f'Layer {name}:')
                print(f'  W: {module.weight.shape}')
                print(f'  B: {module.bias.shape if module.bias is not None else "No bias"}')
                print(f'  Parameters: {weights_params + bias_params}')
    
    print(f'\nTotal number of parameters (original): {total_params_original:,}')
    print(f'Total number of parameters (original + LoRA): {total_params_original + total_params_lora:,}')
    print(f'Parameters introduced by LoRA: {total_params_lora:,}')
    parameters_increment = (total_params_lora / total_params_original) * 100
    print(f'Parameters increment: {parameters_increment:.3f}%')

In [None]:
print("Parameters before adding LoRA:")
count_parameters(model)

# Add LoRA to the model
add_lora_to_model(model, device, rank=4, lora_alpha=1)

# Count parameters after adding LoRA
print("\nParameters after adding LoRA:")
count_parameters(model)

# Example of enabling/disabling LoRA
enable_disable_lora(model, enabled=True)

# Now lets finetune for 9  , 4 and 5

In [37]:
def train(train_loader , model , epochs=5 , total_iterations_limit = None):

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  


        # Train the model
        total_step = len(train_loader)

        total_step = len(train_loader)

        for epoch in range(num_epochs):
            for i, (images, labels) in enumerate(train_loader):  
                # Move tensors to the configured device
                images = images.to(device)
                labels = labels.to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                        .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
                    
            # Validation
            with torch.no_grad():
                correct = 0
                total = 0
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    del images, labels, outputs
            
                print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total)) 

In [None]:
# for name, param in model.named_parameters():
#     if 'lora' not in name:
#         print(f'Freezing non-LoRA parameter {name}')
#         param.requires_grad = False


# mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# exclude_indices = mnist_trainset.targets == 9
# mnist_trainset.data = mnist_trainset.data[exclude_indices]
# mnist_trainset.targets = mnist_trainset.targets[exclude_indices]

# train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# train(train_loader, model, epochs=1, total_iterations_limit=100)

for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
include_indices = (mnist_trainset.targets == 9) | (mnist_trainset.targets == 4) | (mnist_trainset.targets == 5)

# Apply the mask to both data and targets
mnist_trainset.data = mnist_trainset.data[include_indices]
mnist_trainset.targets = mnist_trainset.targets[include_indices]
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)

In [41]:
enable_disable_lora(model,enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:02<00:00, 483.76it/s]

Accuracy: 0.972
wrong counts for the digit 0: 23
wrong counts for the digit 1: 9
wrong counts for the digit 2: 23
wrong counts for the digit 3: 32
wrong counts for the digit 4: 8
wrong counts for the digit 5: 3
wrong counts for the digit 6: 41
wrong counts for the digit 7: 87
wrong counts for the digit 8: 43
wrong counts for the digit 9: 8





In [42]:
enable_disable_lora(model,enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 558.99it/s]

Accuracy: 0.986
wrong counts for the digit 0: 7
wrong counts for the digit 1: 10
wrong counts for the digit 2: 7
wrong counts for the digit 3: 5
wrong counts for the digit 4: 31
wrong counts for the digit 5: 11
wrong counts for the digit 6: 14
wrong counts for the digit 7: 30
wrong counts for the digit 8: 13
wrong counts for the digit 9: 13



