In [3]:
import loralib as lora
import copy
import os
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn

from scipy.spatial import distance
import seaborn as sns

In [4]:
def disable_module(module):
    """Freeze ``module`` by switching ``requires_grad`` off on every parameter it yields."""
    for weight in module.parameters():
        weight.requires_grad_(False)
        
def enable_module(module):
    """Unfreeze ``module`` by switching ``requires_grad`` on for every parameter it yields."""
    for weight in module.parameters():
        weight.requires_grad_(True)


def check_tunable_params(model, verbose=True):
    """
    Print and return the trainable vs. total parameter element counts of ``model``.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        verbose: If True, also print the name of every trainable parameter.

    Returns:
        Tuple ``(trainable_params, all_param)`` of element counts.
    """
    trainable_params = 0
    all_param = 0

    for name, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            if verbose:
                print(name)
            trainable_params += param.numel()

    # A model without parameters would otherwise raise ZeroDivisionError here.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.5f}"
    )

    return trainable_params, all_param

def create_mapping(model, vector):
    """
    Pair each attention parameter name of ``model`` with the next entry of ``vector``.

    A parameter counts as attention when its name contains ``'.attn.'`` or
    ``'attention'``. Entries of ``vector`` are consumed in ``named_parameters()``
    order. Side effect: every non-attention parameter gets ``requires_grad = False``.

    Returns:
        dict mapping attention parameter name -> corresponding ``vector`` entry.
    """
    mapping = {}

    for param_name, param in model.named_parameters():
        is_attention = '.attn.' in param_name or 'attention' in param_name
        if not is_attention:
            param.requires_grad = False
            continue
        # Names from named_parameters() are unique, so the current size of the
        # mapping is the index of the next unused vector entry.
        mapping[param_name] = vector[len(mapping)]

    return mapping

def sort_dict(dict, descending=False):
    """
    Return a new dict with the items of ``dict`` ordered by value.

    Args:
        dict: Mapping to sort. The parameter name is kept for backward
            compatibility with keyword callers even though it shadows the builtin.
        descending: If True, largest values come first.

    Returns:
        A new ``dict`` whose insertion order follows the sorted values (ties
        keep their original relative order, since ``sorted`` is stable).
    """
    # The parameter shadows the builtin ``dict``, so the original
    # ``dict(sorted(...))`` tried to call the argument and raised TypeError.
    # A comprehension builds the result without touching the builtin name.
    ordered_items = sorted(dict.items(), key=lambda item: item[1], reverse=descending)
    return {key: value for key, value in ordered_items}

def get_modules_from_vector(vector, model):
    """
    Split block indices into trainable (entry == 1) and frozen (entry == 0) groups.

    ``model`` is accepted for interface compatibility but is not used.

    Returns:
        ``(trainable_blocks, frozen_blocks)``, each the tuple returned by
        ``np.where`` (index arrays, one per dimension of ``vector``).
    """
    mask = np.array(vector)
    return np.where(mask == 1), np.where(mask == 0)

def get_model_for_bitfit(model):
    """
    Configure ``model`` for BitFit-style tuning, in place.

    Each parameter becomes trainable when its name contains ``'bias'``,
    ``'pooler.dense.bias'`` or ``'head'``; every other parameter is frozen.

    Returns:
        A list holding one ``1`` per parameter tensor that was made trainable.
        No ``0`` entries are recorded for frozen parameters.
    """
    trainable_components = ('bias', 'pooler.dense.bias', 'head')

    vector = []
    # named_parameters() visits the same parameters as parameters(), so a
    # single pass can both freeze and (re-)enable.
    for name, param in model.named_parameters():
        selected = any(part in name for part in trainable_components)
        param.requires_grad = selected
        if selected:
            vector.append(1)

    return vector

def enable_from_vector(vector, model):
    """
    Freeze all of ``model``, then re-enable attention in the blocks flagged by ``vector``.

    ``vector[i] == 1`` enables ``model.blocks[i].attn``; any other value keeps
    it frozen. Progress is printed for each enabled block.
    """
    print("Vector: ", vector)

    # Start from a fully frozen model.
    disable_module(model)

    for block_idx, block in enumerate(model.blocks):
        switch_on = vector[block_idx] == 1
        if switch_on:
            print("Enabling attention in Block {}".format(block_idx))
        toggle = enable_module if switch_on else disable_module
        toggle(block.attn)

def _sum_mask_vectors(vector_paths, n_blocks):
    """Element-wise sum of the ``.npy`` vectors at ``vector_paths``, starting from ``zeros(n_blocks)``."""
    total = np.zeros(n_blocks, dtype=int)
    for vector_path in vector_paths:
        # Out-of-place add lets numpy upcast (e.g. to float) if the saved
        # vectors are not integer-typed; in-place ``+=`` on the int accumulator
        # raises a casting error in that case.
        total = total + np.load(vector_path)
    return total


def create_best_worst_vectors(df, k=10, n_blocks=12):
    """
    Sum the saved block-mask vectors of the ``k`` best and ``k`` worst runs.

    Args:
        df: DataFrame with a ``'Test Acc@1'`` column (ranking key) and a
            ``'Vector Path'`` column (path to a ``.npy`` vector per run).
        k: Number of top/bottom runs to aggregate.
        n_blocks: Length of each saved vector; also the length of the
            returned zero vectors when ``df`` is empty.

    Returns:
        ``(best_vector, worst_vector)``: element-wise sums over the ``k``
        highest- and ``k`` lowest-accuracy runs respectively.
    """
    best_df = df.sort_values(by=['Test Acc@1'], ascending=False).head(k)
    worst_df = df.sort_values(by=['Test Acc@1'], ascending=True).head(k)

    best_vector = _sum_mask_vectors(best_df['Vector Path'], n_blocks)
    worst_vector = _sum_mask_vectors(worst_df['Vector Path'], n_blocks)

    return best_vector, worst_vector

def tune_blocks_random(model, mask, segment):
    """
    Enable or freeze one sub-module group in every block of ``model.blocks``.

    Args:
        model: Object with an iterable ``blocks`` attribute.
        mask: Per-block 0/1 sequence (list or numpy array), or ``None`` to draw
            each bit at random with ``np.random`` (probability 0.5 of enabling).
        segment: ``'attention'`` toggles ``block.attn``; ``'layernorm'``
            toggles ``block.norm1`` and ``block.norm2``; any other value only
            records/prints the bits without touching the model.

    Returns:
        List of the 0/1 bits applied, one per block.

    Raises:
        AssertionError: if a ``mask`` was given and the applied bits differ
            from it (e.g. the mask contains values other than 0/1).
    """
    vector = []

    for idx, block in enumerate(model.blocks):

        if mask is None:
            bit = int(np.random.random(1)[0] > 0.5)
        else:
            bit = mask[idx]

        if bit == 1:
            print("Enabling {} in Block {}".format(segment, idx))
            if segment == 'attention':
                enable_module(block.attn)
            elif segment == 'layernorm':
                enable_module(block.norm1)
                enable_module(block.norm2)

            vector.append(1)
        else:
            print("Disabling {} in Block {}".format(segment, idx))
            if segment == 'attention':
                disable_module(block.attn)
            elif segment == 'layernorm':
                disable_module(block.norm1)
                disable_module(block.norm2)

            vector.append(0)

    if mask is not None:
        # ``mask == vector`` is element-wise when mask is a numpy array (as
        # produced by np.load), which made the bare assert raise ValueError.
        assert np.array_equal(mask, vector), "applied bits do not match the given mask"

    return vector

In [9]:
def create_lora_model(model, lora_r: int = 8, lora_alpha: int = 8, lora_dropout: float = 0., tune_k=False, block_mask=None):
    """
    Build a LoRA copy of a ViT-style ``model`` (the input model is not modified).

    In each selected block, ``block.attn.qkv`` is replaced by a loralib
    ``MergedLinear``. The original weights are copied in, and then only the
    LoRA parameters are left trainable.

    Args:
        model: Model exposing ``model.blocks``; each block must have ``attn.qkv``
            with ``in_features``/``out_features``/``weight``/``bias``.
        lora_r: LoRA rank (``r``), passed to ``MergedLinear``.
        lora_alpha: Passed to ``MergedLinear`` as ``lora_alpha``.
        lora_dropout: Passed to ``MergedLinear`` as ``lora_dropout``.
        tune_k: If True, LoRA is enabled on all three slices of the fused qkv
            output; otherwise the middle slice (presumably k) is skipped.
        block_mask: Per-block 0/1 sequence (list or numpy array). LoRA is added
            where the entry equals 1. Defaults to all blocks.

    Returns:
        The new LoRA model.

    Raises:
        ValueError: if copying the source weights leaves a non-LoRA parameter
            missing, or if the source has keys the LoRA model does not.
    """
    lora_model = copy.deepcopy(model)

    # enable_lora flags for the three slices of the fused qkv projection.
    tune_list = [True, True, True] if tune_k else [True, False, True]

    # Apply LoRA to all attention layers if a mask is not given.
    block_mask = [1] * len(model.blocks) if block_mask is None else block_mask

    for idx, block in enumerate(lora_model.blocks):
        if block_mask[idx] == 1:
            qkv = block.attn.qkv
            # Mirror the original layer's bias presence, device and dtype.
            # Otherwise a bias absent from the checkpoint stays randomly
            # initialised (hidden by strict=False below), and GPU/half models
            # end up with a CPU/float32 layer.
            block.attn.qkv = lora.MergedLinear(
                qkv.in_features,
                qkv.out_features,
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                enable_lora=tune_list,
                bias=qkv.bias is not None,
                device=qkv.weight.device,
                dtype=qkv.weight.dtype,
            )

    load_result = lora_model.load_state_dict(model.state_dict(), strict=False)
    # Only the freshly created LoRA factors are expected to be absent from the
    # source weights; anything else would be left uninitialised.
    non_lora_missing = [key for key in load_result.missing_keys if 'lora_' not in key]
    if non_lora_missing or load_result.unexpected_keys:
        raise ValueError(
            "Incompatible weights while building LoRA model: missing={}, unexpected={}".format(
                non_lora_missing, list(load_result.unexpected_keys)
            )
        )
    lora.mark_only_lora_as_trainable(lora_model)

    return lora_model

In [15]:
# Pretrained ViT-B/16 (224px input) from timm; used as the base for the LoRA copy below.
model = timm.create_model(model_name='vit_base_patch16_224', pretrained=True)

In [12]:
# Directory holding the saved block-mask vectors. Set LAYER_MASKING_VECTOR_DIR
# to run elsewhere instead of editing this absolute path.
VECTOR_DIR = os.environ.get(
    'LAYER_MASKING_VECTOR_DIR',
    '/home/co-dutt1/rds/hpc-work/Layer-Masking/Experiment_Vectors',
)
block_mask = np.load(os.path.join(VECTOR_DIR, 'random_vector_1.npy'))
print("Block Mask: ", block_mask)
# One mask entry per transformer block is expected.
assert len(block_mask) == len(model.blocks), "block mask length must match the number of blocks"
lora_model = create_lora_model(model, block_mask=block_mask)
check_tunable_params(lora_model, True)

Block Mask:  [1 1 1 1 0 0 1 0 0 1 0 1]
blocks.0.attn.qkv.lora_A
blocks.0.attn.qkv.lora_B
blocks.1.attn.qkv.lora_A
blocks.1.attn.qkv.lora_B
blocks.2.attn.qkv.lora_A
blocks.2.attn.qkv.lora_B
blocks.3.attn.qkv.lora_A
blocks.3.attn.qkv.lora_B
blocks.6.attn.qkv.lora_A
blocks.6.attn.qkv.lora_B
blocks.9.attn.qkv.lora_A
blocks.9.attn.qkv.lora_B
blocks.11.attn.qkv.lora_A
blocks.11.attn.qkv.lora_B
trainable params: 172032 || all params: 86739688 || trainable%: 0.19833


(172032, 86739688)