In [1]:
import torch
import torch.nn as nn
from IPython.display import Image
# !pip install torchview
import torchvision
from torchview import draw_graph

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [2]:
model_parameters={}
model_parameters['resnet18'] = ([64,128,256,512],[2,2,2,2],1,False)
model_parameters['resnet34'] = ([64,128,256,512],[3,4,6,3],1,False)
model_parameters['resnet50'] = ([64,128,256,512],[3,4,6,3],4,True)
model_parameters['resnet101'] = ([64,128,256,512],[3,4,23,3],4,True)
model_parameters['resnet152'] = ([64,128,256,512],[3,8,36,3],4,True)


class Bottleneck(nn.Module):

    def __init__(self,in_channels,intermediate_channels,expansion,is_Bottleneck,stride):
        
        """
        Creates a Bottleneck with conv 1x1->3x3->1x1 layers.
        
        Note:
          1. Addition of feature maps occur at just before the final ReLU with the input feature maps
          2. if input size is different from output, select projected mapping or else identity mapping.
          3. if is_Bottleneck=False (3x3->3x3) are used else (1x1->3x3->1x1). Bottleneck is required for resnet-50/101/152
        Args:
            in_channels (int) : input channels to the Bottleneck
            intermediate_channels (int) : number of channels to 3x3 conv 
            expansion (int) : factor by which the input #channels are increased
            stride (int) : stride applied in the 3x3 conv. 2 for first Bottleneck of the block and 1 for remaining

        Attributes:
            Layer consisting of conv->batchnorm->relu

        """

        super(Bottleneck,self).__init__()

        self.expansion = expansion
        self.in_channels = in_channels
        self.intermediate_channels = intermediate_channels
        self.is_Bottleneck = is_Bottleneck
        
        # i.e. if dim(x) == dim(F) => Identity function
        if self.in_channels==self.intermediate_channels*self.expansion:
            self.identity = True
        else:
            self.identity = False
            projection_layer = []
            projection_layer.append(nn.Conv2d(in_channels=self.in_channels, out_channels=self.intermediate_channels*self.expansion, kernel_size=1, stride=stride, padding=0, bias=False ))
            projection_layer.append(nn.BatchNorm2d(self.intermediate_channels*self.expansion))
            # Only conv->BN and no ReLU
            # projection_layer.append(nn.ReLU())
            self.projection = nn.Sequential(*projection_layer)

        # commonly used relu
        self.relu = nn.ReLU()

        # is_Bottleneck = True for all ResNet 50+
        if self.is_Bottleneck:
            # bottleneck
            # 1x1
            self.conv1_1x1 = nn.Conv2d(in_channels=self.in_channels, out_channels=self.intermediate_channels, kernel_size=1, stride=1, padding=0, bias=False )
            self.batchnorm1 = nn.BatchNorm2d(self.intermediate_channels)
            
            # 3x3
            self.conv2_3x3 = nn.Conv2d(in_channels=self.intermediate_channels, out_channels=self.intermediate_channels, kernel_size=3, stride=stride, padding=1, bias=False )
            self.batchnorm2 = nn.BatchNorm2d(self.intermediate_channels)
            
            # 1x1
            self.conv3_1x1 = nn.Conv2d(in_channels=self.intermediate_channels, out_channels=self.intermediate_channels*self.expansion, kernel_size=1, stride=1, padding=0, bias=False )
            self.batchnorm3 = nn.BatchNorm2d( self.intermediate_channels*self.expansion )
        
        else:
            # basicblock
            # 3x3
            self.conv1_3x3 = nn.Conv2d(in_channels=self.in_channels, out_channels=self.intermediate_channels, kernel_size=3, stride=stride, padding=1, bias=False )
            self.batchnorm1 = nn.BatchNorm2d(self.intermediate_channels)
            
            # 3x3
            self.conv2_3x3 = nn.Conv2d(in_channels=self.intermediate_channels, out_channels=self.intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False )
            self.batchnorm2 = nn.BatchNorm2d(self.intermediate_channels)

    def forward(self,x):
        # input stored to be added before the final relu
        in_x = x

        if self.is_Bottleneck:
            # conv1x1->BN->relu
            x = self.relu(self.batchnorm1(self.conv1_1x1(x)))
            
            # conv3x3->BN->relu
            x = self.relu(self.batchnorm2(self.conv2_3x3(x)))
            
            # conv1x1->BN
            x = self.batchnorm3(self.conv3_1x1(x))
        
        else:
            # conv3x3->BN->relu
            x = self.relu(self.batchnorm1(self.conv1_3x3(x)))

            # conv3x3->BN
            x = self.batchnorm2(self.conv2_3x3(x))


        # identity or projected mapping
        if self.identity:
            x += in_x
        else:
            x += self.projection(in_x)

        # final relu
        x = self.relu(x)
        
        return x


# Bottleneck(64*4,64,4,stride=1)

def test_Bottleneck():
    x = torch.randn(1,64,112,112)
    model = Bottleneck(64,64,4,True,2)
    print(model(x).shape)
    del model

test_Bottleneck()

torch.Size([1, 256, 56, 56])


In [3]:
class ResNet(nn.Module):

    def __init__(self, resnet_variant,in_channels,num_classes):
        """
        Creates the ResNet architecture based on the provided variant. 18/34/50/101 etc.
        Based on the input parameters, define the channels list, repeatition list along with expansion factor(4) and stride(3/1)
        using _make_blocks method, create a sequence of multiple Bottlenecks
        Average Pool at the end before the FC layer 

        Args:
            resnet_variant (list) : eg. [[64,128,256,512],[3,4,6,3],4,True]
            in_channels (int) : image channels (3)
            num_classes (int) : output #classes 

        Attributes:
            Layer consisting of conv->batchnorm->relu

        """
        super(ResNet,self).__init__()
        self.channels_list = resnet_variant[0]
        self.repeatition_list = resnet_variant[1]
        self.expansion = resnet_variant[2]
        self.is_Bottleneck = resnet_variant[3]

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False )
        self.batchnorm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

        self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)

        self.block1 = self._make_blocks( 64 , self.channels_list[0], self.repeatition_list[0], self.expansion, self.is_Bottleneck, stride=1 )
        self.block2 = self._make_blocks( self.channels_list[0]*self.expansion , self.channels_list[1], self.repeatition_list[1], self.expansion, self.is_Bottleneck, stride=2 )
        self.block3 = self._make_blocks( self.channels_list[1]*self.expansion , self.channels_list[2], self.repeatition_list[2], self.expansion, self.is_Bottleneck, stride=2 )
        self.block4 = self._make_blocks( self.channels_list[2]*self.expansion , self.channels_list[3], self.repeatition_list[3], self.expansion, self.is_Bottleneck, stride=2 )

        self.average_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear( self.channels_list[3]*self.expansion , num_classes)



    def forward(self,x):
        x = self.relu(self.batchnorm1(self.conv1(x)))
        x = self.maxpool(x)
        
        x = self.block1(x)
        
        x = self.block2(x)
        
        x = self.block3(x)
        
        x = self.block4(x)
        
        x = self.average_pool(x)

        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        
        return x

    def _make_blocks(self,in_channels,intermediate_channels,num_repeat, expansion, is_Bottleneck, stride):
        
        """
        Args:
            in_channels : #channels of the Bottleneck input
            intermediate_channels : #channels of the 3x3 in the Bottleneck
            num_repeat : #Bottlenecks in the block
            expansion : factor by which intermediate_channels are multiplied to create the output channels
            is_Bottleneck : status if Bottleneck in required
            stride : stride to be used in the first Bottleneck conv 3x3

        Attributes:
            Sequence of Bottleneck layers

        """
        layers = [] 

        layers.append(Bottleneck(in_channels,intermediate_channels,expansion,is_Bottleneck,stride=stride))
        for num in range(1,num_repeat):
            layers.append(Bottleneck(intermediate_channels*expansion,intermediate_channels,expansion,is_Bottleneck,stride=1))

        return nn.Sequential(*layers)


def test_ResNet(params):
    model = ResNet( params , in_channels=3, num_classes=1000)
    x = torch.randn(1,3,224,224)
    output = model(x)
    print(output.shape)
    return model

architecture = 'resnet50'
model = test_ResNet(model_parameters[architecture])


torch.Size([1, 1000])


In [4]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

_ = torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor() ,

transforms.Normalize((0.1307,),(0.3081,))

])




mnist_trainset = datasets.MNIST(root='./data', train=True, download=False, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=False, transform=transform)

train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=False)


# model = SimplifiedVGG16(num_classes=10).to(device)


num_classes = 10
num_epochs = 1
batch_size = 10
learning_rate = 0.005

# model =.to(device)

model = ResNet( model_parameters['resnet50'] , in_channels=1, num_classes=10).to(device)


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  


# Train the model
total_step = len(train_loader)

total_step = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # Move tensors to the configured device
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            
    # Validation
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs
    
        print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total)) 


Epoch [1/1], Step [6000/6000], Loss: 0.5794
Accuracy of the network on the 5000 validation images: 94.26 %


# lets make a copy of original weights


In [5]:
original_weights = {}

for name , param in model.named_parameters():

    original_weights[name] = param.clone().detach()


print(original_weights)

{'conv1.weight': tensor([[[[ 1.5971e-01,  2.3322e-01,  2.1765e-01,  ..., -1.7124e-03,
            1.7205e-01,  3.2256e-02],
          [ 8.3770e-02,  1.9675e-01,  2.7954e-01,  ..., -8.1308e-02,
           -8.2088e-02, -6.3201e-02],
          [-8.3370e-02,  1.8852e-01,  2.5028e-01,  ..., -1.5003e-01,
           -1.6160e-01, -4.5111e-02],
          ...,
          [-9.3456e-03, -4.2822e-02,  8.6994e-02,  ...,  2.5766e-01,
            2.1900e-01,  1.1620e-01],
          [-1.3293e-01, -1.7470e-01,  6.7129e-02,  ...,  3.4651e-01,
            1.8532e-01,  1.5070e-01],
          [-1.1327e-01, -3.2232e-02, -1.1343e-03,  ..., -7.6857e-03,
            1.4115e-02,  1.3369e-01]]],


        [[[-8.3469e-02,  1.5050e-02,  1.5197e-01,  ...,  1.1746e-01,
            2.3503e-01,  9.4007e-02],
          [ 4.1660e-02,  7.9188e-02,  1.0423e-01,  ...,  1.4117e-01,
            1.9122e-01,  2.2379e-01],
          [ 1.5624e-01,  2.8769e-01,  2.1155e-01,  ...,  1.5953e-01,
            1.1956e-01,  2.8370e-01],
 

# checking its performance on each number in MNIST

In [6]:
def test():

    correct = 0
    total = 0

    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():

        for data in tqdm(test_loader , desc = 'Testing'):

            x,y = data

            x = x.to(device)
            y = y.to(device)

            output = model(x)

            for idx , i in enumerate(output):

                if torch.argmax(i) == y[idx]:
                    
                    correct += 1
                else:
                    wrong_counts[y[idx]] +=1

                total+=1
        
        print(f'Accuracy : {round(correct/total , 3)}')

        for i in range(len(wrong_counts)):

            print(f'wrong counts for the digit {i} : {wrong_counts[i]}')


test()

Testing: 100%|██████████| 1000/1000 [00:06<00:00, 163.37it/s]

Accuracy : 0.943
wrong counts for the digit 0 : 21
wrong counts for the digit 1 : 10
wrong counts for the digit 2 : 73
wrong counts for the digit 3 : 143
wrong counts for the digit 4 : 34
wrong counts for the digit 5 : 25
wrong counts for the digit 6 : 36
wrong counts for the digit 7 : 69
wrong counts for the digit 8 : 70
wrong counts for the digit 9 : 93





In [7]:
total_params = 0
for name, module in model.named_modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        weights_params = module.weight.nelement()
        bias_params = module.bias.nelement() if module.bias is not None else 0
        total_params += weights_params + bias_params
        print(f'Layer {name}: W: {module.weight.shape} + B: {module.bias.shape if module.bias is not None else "No bias"}')
        print(f'Parameters: {weights_params + bias_params}')

print(f'Total number of parameters: {total_params:,}')

Layer conv1: W: torch.Size([64, 1, 7, 7]) + B: No bias
Parameters: 3136
Layer block1.0.projection.0: W: torch.Size([256, 64, 1, 1]) + B: No bias
Parameters: 16384
Layer block1.0.conv1_1x1: W: torch.Size([64, 64, 1, 1]) + B: No bias
Parameters: 4096
Layer block1.0.conv2_3x3: W: torch.Size([64, 64, 3, 3]) + B: No bias
Parameters: 36864
Layer block1.0.conv3_1x1: W: torch.Size([256, 64, 1, 1]) + B: No bias
Parameters: 16384
Layer block1.1.conv1_1x1: W: torch.Size([64, 256, 1, 1]) + B: No bias
Parameters: 16384
Layer block1.1.conv2_3x3: W: torch.Size([64, 64, 3, 3]) + B: No bias
Parameters: 36864
Layer block1.1.conv3_1x1: W: torch.Size([256, 64, 1, 1]) + B: No bias
Parameters: 16384
Layer block1.2.conv1_1x1: W: torch.Size([64, 256, 1, 1]) + B: No bias
Parameters: 16384
Layer block1.2.conv2_3x3: W: torch.Size([64, 64, 3, 3]) + B: No bias
Parameters: 36864
Layer block1.2.conv3_1x1: W: torch.Size([256, 64, 1, 1]) + B: No bias
Parameters: 16384
Layer block2.0.projection.0: W: torch.Size([512, 2

In [8]:
from torchsummary import summary

summary(model, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 14, 14]           3,136
       BatchNorm2d-2           [-1, 64, 14, 14]             128
              ReLU-3           [-1, 64, 14, 14]               0
         MaxPool2d-4             [-1, 64, 7, 7]               0
            Conv2d-5             [-1, 64, 7, 7]           4,096
       BatchNorm2d-6             [-1, 64, 7, 7]             128
              ReLU-7             [-1, 64, 7, 7]               0
            Conv2d-8             [-1, 64, 7, 7]          36,864
       BatchNorm2d-9             [-1, 64, 7, 7]             128
             ReLU-10             [-1, 64, 7, 7]               0
           Conv2d-11            [-1, 256, 7, 7]          16,384
      BatchNorm2d-12            [-1, 256, 7, 7]             512
           Conv2d-13            [-1, 256, 7, 7]          16,384
      BatchNorm2d-14            [-1, 25

# Now lets define a LORA class

In [9]:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Section 4.1 of the paper: 
        #   We use a random Gaussian initialization for A and zero for B, so ∆W = BA is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank,features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        
        # Section 4.1 of the paper: 
        #   We then scale ∆Wx by α/r , where α is a constant in r. 
        #   When optimizing with Adam, tuning α is roughly the same as tuning the learning rate if we scale the initialization appropriately. 
        #   As a result, we simply set α to the first r we try and do not tune it. 
        #   This scaling helps to reduce the need to retune hyperparameters when we vary r.
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

# Now lets define function for conv layer parameterization and linear layer parametrization


In [17]:
import torch.nn.utils.parametrize as parametrize

def conv_layer_parameterization(layer , device , rank =1 , lora_alpha = 1):

    features_in = layer.weight.shape[1] * layer.weight.shape[2] * layer.weight.shape[3]

    features_out = layer.weight.shape[0]

    return LoRAParametrization(features_in , features_out , rank = rank , alpha=lora_alpha , device = device)

def linear_layer_parameterization(layer , device , rank = 1, lora_alpha=1):

    features_in , features_out = layer.weight.shape

    return LoRAParametrization(

        features_in,features_out,rank=rank,alpha=lora_alpha,device=device
    )



def add_lora_to_model(model, device, rank=1, lora_alpha=1):
    # Add LoRA to conv1
    if isinstance(model.conv1, nn.Conv2d):
        parametrize.register_parametrization(
            model.conv1, "weight", conv_layer_parameterization(model.conv1, device, rank, lora_alpha)
        )

    # Add LoRA to blocks
    for block_name in ['block1', 'block2', 'block3', 'block4']:
        block = getattr(model, block_name)
        for sub_block in block:
            for layer_name in ['conv1_1x1', 'conv2_3x3', 'conv3_1x1']:
                layer = getattr(sub_block, layer_name)
                if isinstance(layer, nn.Conv2d):
                    parametrize.register_parametrization(
                        layer, "weight", conv_layer_parameterization(layer, device, rank, lora_alpha)
                    )
            
            # Add LoRA to projection layer if it exists
            if hasattr(sub_block, 'projection') and isinstance(sub_block.projection[0], nn.Conv2d):
                parametrize.register_parametrization(
                    sub_block.projection[0], "weight", conv_layer_parameterization(sub_block.projection[0], device, rank, lora_alpha)
                )

    # Add LoRA to fc1
    if isinstance(model.fc1, nn.Linear):
        parametrize.register_parametrization(
            model.fc1, "weight", linear_layer_parameterization(model.fc1, device, rank, lora_alpha)
        )

def enable_disable_lora(model, enabled=True):
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)) and hasattr(module, 'parametrizations'):
            module.parametrizations["weight"][0].enabled = enabled



In [18]:
def count_parameters(model):
    total_params_original = 0
    total_params_lora = 0
    
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            weights_params = module.weight.nelement()
            bias_params = module.bias.nelement() if module.bias is not None else 0
            total_params_original += weights_params + bias_params
            
            if hasattr(module, 'parametrizations'):
                lora_params = module.parametrizations["weight"][0].lora_A.nelement() + \
                              module.parametrizations["weight"][0].lora_B.nelement()
                total_params_lora += lora_params
                
                print(f'Layer {name}:')
                print(f'  W: {module.weight.shape}')
                print(f'  B: {module.bias.shape if module.bias is not None else "No bias"}')
                print(f'  Lora_A: {module.parametrizations["weight"][0].lora_A.shape}')
                print(f'  Lora_B: {module.parametrizations["weight"][0].lora_B.shape}')
                print(f'  Parameters: {weights_params + bias_params + lora_params}')
            else:
                print(f'Layer {name}:')
                print(f'  W: {module.weight.shape}')
                print(f'  B: {module.bias.shape if module.bias is not None else "No bias"}')
                print(f'  Parameters: {weights_params + bias_params}')
    
    print(f'\nTotal number of parameters (original): {total_params_original:,}')
    print(f'Total number of parameters (original + LoRA): {total_params_original + total_params_lora:,}')
    print(f'Parameters introduced by LoRA: {total_params_lora:,}')
    parameters_increment = (total_params_lora / total_params_original) * 100
    print(f'Parameters increment: {parameters_increment:.3f}%')

In [19]:
print("Parameters before adding LoRA:")
count_parameters(model)

# Add LoRA to the model
add_lora_to_model(model, device, rank=4, lora_alpha=1)

# Count parameters after adding LoRA
print("\nParameters after adding LoRA:")
count_parameters(model)

# Example of enabling/disabling LoRA
enable_disable_lora(model, enabled=True)

Parameters before adding LoRA:
Layer conv1:
  W: torch.Size([64, 1, 7, 7])
  B: No bias
  Parameters: 3136
Layer block1.0.projection.0:
  W: torch.Size([256, 64, 1, 1])
  B: No bias
  Parameters: 16384
Layer block1.0.conv1_1x1:
  W: torch.Size([64, 64, 1, 1])
  B: No bias
  Parameters: 4096
Layer block1.0.conv2_3x3:
  W: torch.Size([64, 64, 3, 3])
  B: No bias
  Parameters: 36864
Layer block1.0.conv3_1x1:
  W: torch.Size([256, 64, 1, 1])
  B: No bias
  Parameters: 16384
Layer block1.1.conv1_1x1:
  W: torch.Size([64, 256, 1, 1])
  B: No bias
  Parameters: 16384
Layer block1.1.conv2_3x3:
  W: torch.Size([64, 64, 3, 3])
  B: No bias
  Parameters: 36864
Layer block1.1.conv3_1x1:
  W: torch.Size([256, 64, 1, 1])
  B: No bias
  Parameters: 16384
Layer block1.2.conv1_1x1:
  W: torch.Size([64, 256, 1, 1])
  B: No bias
  Parameters: 16384
Layer block1.2.conv2_3x3:
  W: torch.Size([64, 64, 3, 3])
  B: No bias
  Parameters: 36864
Layer block1.2.conv3_1x1:
  W: torch.Size([256, 64, 1, 1])
  B: No 

In [20]:
def train(train_loader , model , epochs=1 , total_iterations_limit = None):

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  


        # Train the model
        total_step = len(train_loader)

        total_step = len(train_loader)

        for epoch in range(num_epochs):
            for i, (images, labels) in enumerate(train_loader):  
                # Move tensors to the configured device
                images = images.to(device)
                labels = labels.to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                        .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
                    
            # Validation
            with torch.no_grad():
                correct = 0
                total = 0
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    del images, labels, outputs
            
                print('Accuracy of the network on the {} validation images: {} %'.format(5000, 100 * correct / total)) 

# Now lets fine tune it on 6,3,8

In [21]:
for name, param in model.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

# Load the MNIST dataset again, by keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a dataloader for the training
include_indices = (mnist_trainset.targets == 6) | (mnist_trainset.targets == 3) | (mnist_trainset.targets == 8)

# Apply the mask to both data and targets
mnist_trainset.data = mnist_trainset.data[include_indices]
mnist_trainset.targets = mnist_trainset.targets[include_indices]
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Train the network with LoRA only on the digit 9 and only for 100 batches (hoping that it would improve the performance on the digit 9)
train(train_loader, model, epochs=1, total_iterations_limit=100)

Freezing non-LoRA parameter conv1.parametrizations.weight.original
Freezing non-LoRA parameter batchnorm1.weight
Freezing non-LoRA parameter batchnorm1.bias
Freezing non-LoRA parameter block1.0.projection.0.parametrizations.weight.original
Freezing non-LoRA parameter block1.0.projection.1.weight
Freezing non-LoRA parameter block1.0.projection.1.bias
Freezing non-LoRA parameter block1.0.conv1_1x1.parametrizations.weight.original
Freezing non-LoRA parameter block1.0.batchnorm1.weight
Freezing non-LoRA parameter block1.0.batchnorm1.bias
Freezing non-LoRA parameter block1.0.conv2_3x3.parametrizations.weight.original
Freezing non-LoRA parameter block1.0.batchnorm2.weight
Freezing non-LoRA parameter block1.0.batchnorm2.bias
Freezing non-LoRA parameter block1.0.conv3_1x1.parametrizations.weight.original
Freezing non-LoRA parameter block1.0.batchnorm3.weight
Freezing non-LoRA parameter block1.0.batchnorm3.bias
Freezing non-LoRA parameter block1.1.conv1_1x1.parametrizations.weight.original
Free

In [22]:
enable_disable_lora(model,enabled=True)
test()

Testing: 100%|██████████| 1000/1000 [00:09<00:00, 106.60it/s]

Accuracy : 0.619
wrong counts for the digit 0 : 99
wrong counts for the digit 1 : 845
wrong counts for the digit 2 : 881
wrong counts for the digit 3 : 4
wrong counts for the digit 4 : 537
wrong counts for the digit 5 : 305
wrong counts for the digit 6 : 7
wrong counts for the digit 7 : 567
wrong counts for the digit 8 : 25
wrong counts for the digit 9 : 542





In [23]:
enable_disable_lora(model,enabled=False)
test()

Testing: 100%|██████████| 1000/1000 [00:06<00:00, 151.48it/s]

Accuracy : 0.943
wrong counts for the digit 0 : 21
wrong counts for the digit 1 : 10
wrong counts for the digit 2 : 73
wrong counts for the digit 3 : 143
wrong counts for the digit 4 : 34
wrong counts for the digit 5 : 25
wrong counts for the digit 6 : 36
wrong counts for the digit 7 : 69
wrong counts for the digit 8 : 70
wrong counts for the digit 9 : 93





# Thats it folks