In [1]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from scipy.special import softmax
import torch
import argparse
from logger import create_logger
import os


from utils import load_checkpoint, load_pretrained
from config import get_config
from data import build_loader
from models import build_model

from main import train_one_epoch, validate, throughput

from config import get_only_config
import json
import copy
import math
import time

import datetime
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.utils import accuracy, AverageMeter

  from .autonotebook import tqdm as notebook_tqdm


Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.
To use FusedLAMB or FusedAdam, please install apex.


In [2]:
# Load the Swin-T / RESISC45 experiment configuration from its YAML file.
config_path = 'configs/swin/swin_tiny_patch4_window7_224_resisc45.yaml'
config = get_only_config(config_path)

=> merge config from configs/swin/swin_tiny_patch4_window7_224_resisc45.yaml


In [3]:
# Notebook-specific overrides of the YAML config.
# Paths default to the original AFS locations but can be redirected through environment
# variables, so the notebook also runs outside that cluster.
config.defrost()
config.OUTPUT = os.environ.get("SWIN_OUTPUT_DIR", "/afs/ece.cmu.edu/usr/bmarimut/Private/output")
# Alternative checkpoint: "/afs/ece.cmu.edu/usr/ashwinve/Public/ckpt_epoch_29_6.pth"
GOLDEN_CKPT = os.environ.get("SWIN_GOLDEN_CKPT", "/afs/ece.cmu.edu/usr/ashwinve/Public/golden_resisc45.pth")
# PRETRAINED and RESUME intentionally point at the same checkpoint.
config.MODEL.PRETRAINED = GOLDEN_CKPT
config.MODEL.RESUME = GOLDEN_CKPT
config.DATA.CACHE_MODE = 'no'
config.DATA.DATA_PATH = os.environ.get("RESISC45_DATA_PATH", './data/RESISC45/')
config.DATA.ZIP_MODE = True
config.PRINT_FREQ = 120
config.DATA.BATCH_SIZE = 8
config.freeze()
os.makedirs(config.OUTPUT, exist_ok=True)
logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}")

In [4]:
# Build the LoRA-augmented Swin model.
# NOTE(review): the meaning of these three arguments is defined in models.build_model.
# Presumably lora_selector chooses a rank configuration, lora_layer the first stage to
# approximate, and keep_qkv whether the dense qkv projection is retained. Confirm there.
lora_selector = 4
lora_layer = 0
keep_qkv = True
model = build_model(config, lora_selector, lora_layer, keep_qkv)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# Load the golden checkpoint. strict=False because the low-rank parameters
# (qkv_wUSprime / qkv_wVprime / qkv_b) are not in it; they show up as missing_keys below.
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)

_IncompatibleKeys(missing_keys=['layers.0.blocks.0.attn.qkv_wUSprime', 'layers.0.blocks.0.attn.qkv_wVprime', 'layers.0.blocks.0.attn.qkv_b', 'layers.0.blocks.1.attn.qkv_wUSprime', 'layers.0.blocks.1.attn.qkv_wVprime', 'layers.0.blocks.1.attn.qkv_b', 'layers.1.blocks.0.attn.qkv_wUSprime', 'layers.1.blocks.0.attn.qkv_wVprime', 'layers.1.blocks.0.attn.qkv_b', 'layers.1.blocks.1.attn.qkv_wUSprime', 'layers.1.blocks.1.attn.qkv_wVprime', 'layers.1.blocks.1.attn.qkv_b', 'layers.2.blocks.0.attn.qkv_wUSprime', 'layers.2.blocks.0.attn.qkv_wVprime', 'layers.2.blocks.0.attn.qkv_b', 'layers.2.blocks.1.attn.qkv_wUSprime', 'layers.2.blocks.1.attn.qkv_wVprime', 'layers.2.blocks.1.attn.qkv_b', 'layers.2.blocks.2.attn.qkv_wUSprime', 'layers.2.blocks.2.attn.qkv_wVprime', 'layers.2.blocks.2.attn.qkv_b', 'layers.2.blocks.3.attn.qkv_wUSprime', 'layers.2.blocks.3.attn.qkv_wVprime', 'layers.2.blocks.3.attn.qkv_b', 'layers.2.blocks.4.attn.qkv_wUSprime', 'layers.2.blocks.4.attn.qkv_wVprime', 'layers.2.blocks.4.

In [6]:
# Populate the low-rank qkv parameters that the checkpoint did not provide
# (presumably derived from the loaded dense qkv weights; see the model code).
model.init_qkv_low_rank_weights()

  new_sd['qkv_wVprime'] = torch.tensor([q_wVprime, k_wVprime, v_wVprime])


In [None]:
# Export one checkpoint per LoRA configuration (config0_chkpt.pth ... config{N-1}_chkpt.pth)
# with every dense "qkv.weight" / "qkv.bias" entry removed from the saved state dict.
NUM_CONFIGS = 5

# The source checkpoint is identical for every configuration, so read it from disk once.
# load_state_dict copies values into the model and does not modify this dict.
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')

for i in range(NUM_CONFIGS):

    config_name = f"config{i}_chkpt.pth"

    lora_selector = i
    lora_layer = 0
    keep_qkv = True

    my_model = build_model(config, lora_selector, lora_layer, keep_qkv)
    # strict=False: the low-rank parameters are not present in the checkpoint.
    my_model.load_state_dict(checkpoint['model'], strict=False)
    my_model.init_qkv_low_rank_weights()

    # state_dict() returns a fresh OrderedDict that references the model's tensors.
    # Deleting keys from it does not affect the model, so deep-copying every tensor is unnecessary.
    new_sd = my_model.state_dict()
    del_keys = [key for key in new_sd if "qkv.weight" in key or "qkv.bias" in key]
    for key in del_keys:
        del new_sd[key]

    torch.save(new_sd, config_name)

In [7]:
# load_pretrained(config, model, logger)

In [8]:
# Move the model to the GPU. model_without_ddp and super_model become aliases for the
# same module object. Then build the train/val datasets and loaders.
model.cuda()
model_without_ddp = model
super_model = model
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

In [None]:
# Count parameters for two views of the same model (nothing is frozen here):
#   n_parameters_orig - dense variant: dense qkv weight/bias, no low-rank factors
#   n_parameters_lora - low-rank variant: low-rank factors instead of dense qkv
# All other parameters count toward both totals.
# NOTE: matching is by substring, so e.g. 'qkv_b' is caught by the 'v_b' pattern.
lora_w_name_pattern = ['q_b', 'k_b', 'v_b', 'prime']

n_parameters_orig = 0
n_parameters_lora = 0
for param_name, param in model.named_parameters():
    numel = param.numel()
    if any(substring in param_name for substring in lora_w_name_pattern):
        n_parameters_lora += numel
    elif "qkv.weight" in param_name or "qkv.bias" in param_name:
        n_parameters_orig += numel
    else:
        n_parameters_orig += numel
        n_parameters_lora += numel


In [9]:
# NOTE(review): rebuilds the loaders already built a few cells above; redundant on a
# top-to-bottom run.
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

In [None]:
# Report parameter counts and, when the model exposes them, dense vs. LoRA GFLOPs.
logger.info(f"number of params: {n_parameters_orig}")
logger.info(f"number of LoRA params: {n_parameters_lora}")
flops = 0
flops_lora = 0
if hasattr(model, 'flops'):
    flops = model.flops()
    logger.info(f"number of GFLOPs: {flops / 1e9}")
# Checked separately: a model may implement flops() without flops_lora().
if hasattr(model, 'flops_lora'):
    flops_lora = model.flops_lora()
    logger.info(f"Lora number of GFLOPs: {flops_lora / 1e9}")

model.cuda()
model_without_ddp = model

### LoRA rank 150

In [None]:
# Percentage reduction of the LoRA variant relative to the dense model.
# Guarded: flops stays 0 when the model does not report FLOPs.
if n_parameters_orig:
    print("Param savings = ", (n_parameters_orig - n_parameters_lora)/n_parameters_orig*100.0, " % ")
if flops:
    print("FLOPs savings = ", (flops - flops_lora)/flops*100.0, " % ")
else:
    print("FLOPs savings = n/a (model does not report FLOPs)")

### LoRA rank 300

In [None]:
# Percentage reduction of the LoRA variant relative to the dense model.
# Guarded: flops stays 0 when the model does not report FLOPs.
if n_parameters_orig:
    print("Param savings = ", (n_parameters_orig - n_parameters_lora)/n_parameters_orig*100.0, " % ")
if flops:
    print("FLOPs savings = ", (flops - flops_lora)/flops*100.0, " % ")
else:
    print("FLOPs savings = n/a (model does not report FLOPs)")

In [None]:
# super_model is the teacher used by the cells below; it aliases the same module as model.
super_model = model

In [None]:
# model

In [None]:
# NOTE(review): third rebuild of the same loaders; redundant on a top-to-bottom run.
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

In [None]:
# Per-attention-module candidate LoRA ranks; LORA_SELECTOR indexes into each list.
# The next cell rebinds LORA_SELECTOR and LORA_RANK_DICT, so whichever of the two
# cells ran last wins.
LORA_SELECTOR = 0
LORA_RANK_DICT = {
    'layers.0.blocks.0.attn': [9,   12, 13, 49],
    'layers.0.blocks.1.attn': [28,  35, 38, 49],
    'layers.1.blocks.0.attn': [24,  32, 39, 49],
    'layers.1.blocks.1.attn': [19,  22, 23, 49],
    'layers.2.blocks.0.attn': [16,  18, 19, 49],
    'layers.2.blocks.1.attn': [20,  22, 22, 49],
    'layers.2.blocks.2.attn': [22,  26, 30, 49],
    'layers.2.blocks.3.attn': [22,  25, 26, 49],
    'layers.2.blocks.4.attn': [24,  24, 25, 49],
    'layers.2.blocks.5.attn': [23,  24, 24, 49],
    'layers.3.blocks.0.attn': [20,  21, 22, 49],
    'layers.3.blocks.1.attn': [16,  16, 17, 49]
    }

In [None]:
# Uniform rank schedule: every attention module (Swin-T stage depths 2, 2, 6, 2)
# gets the same four candidate ranks; LORA_SELECTOR = 2 picks rank 350.
LORA_SELECTOR = 2
LORA_RANK_DICT = {
    f'layers.{layer}.blocks.{block}.attn': [300, 325, 350, 384]
    for layer, depth in enumerate((2, 2, 6, 2))
    for block in range(depth)
}


In [None]:
# attn_md = get_attn(2,5)
# print(get_attn(2,5).input)

In [None]:
# print(list(list(model.children())[2][0].children())[0][1])
# print(get_attn(2,5))

In [None]:
# import importlib
# # import models
# import sys
# importlib.reload(sys.modules['models'])
# from models import build_model

In [None]:
def get_block(my_model, layer_id, block_id):
    """Return block ``block_id`` of stage ``layer_id`` from ``my_model``.

    Resolves the named path ``layers.<layer_id>.blocks.<block_id>``, the same path
    used in the model's state-dict keys (e.g. ``layers.0.blocks.0.attn.qkv_b``).
    Positional ``children()`` indexing breaks silently if modules are registered
    in a different order, so it is not used.

    Args:
        my_model: model exposing a ``layers`` sequence whose entries have ``blocks``.
        layer_id: index into ``my_model.layers``.
        block_id: index into that stage's ``blocks``.

    Returns:
        The block module.
    """
    return my_model.layers[layer_id].blocks[block_id]

In [None]:
def get_attn(my_model, layer_id, block_id):
    """Return the attention module of block ``block_id`` in stage ``layer_id``.

    Resolves ``layers.<layer_id>.blocks.<block_id>.attn`` by name, the same path as
    in the model's state-dict keys. The original positional lookup (the block's
    second child) depends on module registration order.

    Args:
        my_model: model exposing ``layers[i].blocks[j].attn``.
        layer_id: stage index.
        block_id: block index within the stage.

    Returns:
        The attention submodule.
    """
    return my_model.layers[layer_id].blocks[block_id].attn

In [None]:
# print(get_attn(super_model, 0, 0))

In [None]:
# print(get_block(super_model, 0, 0))

In [None]:
# Learning rate. NOTE(review): only the commented-out optimizer cell below references it.
learning_rate = 1e-02

In [None]:
# Swin-T stage layout and input geometry.
LAYER_DEPTHS = [2, 2, 6, 2]   # number of blocks in each stage
NUM_LAYERS = len(LAYER_DEPTHS)
NUM_HEADS = [3, 6, 12, 24]    # attention heads in each stage
H = W = 224
B = 1
L = H * W
window_size = 7

In [None]:
# Build a LoRA student attention module for every block of stage 3 (layer_id >= 3),
# initialized via load_pretrained_weights(super_model) + init_low_rank_approx_weights().
# Every other position holds None.
# NOTE: C, dim, num_heads, param_name and lora_attn stay bound to the last stage/block
# after this loop; later cells read them.
all_lora_attns = []
for layer_id in range(NUM_LAYERS):
    C = 96 * 2**layer_id
    dim = C
    num_heads = NUM_HEADS[layer_id]
    layer_attns = []
    for block_id in range(LAYER_DEPTHS[layer_id]):
        if layer_id < 3:
            layer_attns.append(None)
            continue
        param_name = f"layers.{layer_id}.blocks.{block_id}.attn"
        lora_attn = LORA_WindowAttention3(
            dim,
            LORA_RANK_DICT[param_name][LORA_SELECTOR],
            to_2tuple(window_size),
            num_heads,
            param_name,
        )
        lora_attn.load_pretrained_weights(super_model)
        lora_attn.init_low_rank_approx_weights()
        lora_attn.cuda()
        layer_attns.append(lora_attn)
    all_lora_attns.append(layer_attns)
    

In [None]:
# Rebuild the stage-3 / block-0 student so its saved checkpoint can be reloaded next.
# The constructor arguments are spelled out for that exact module. Reading dim / num_heads /
# param_name leaked from the last loop cell would target layers.3.blocks.1.attn and, with
# the per-module rank table, the wrong rank.
# NOTE(review): the stage-3 students are built with LORA_WindowAttention3 above but this uses
# LORA_WindowAttention2. Confirm that both classes share the same state-dict layout.
tmp_layer_id, tmp_block_id = 3, 0
tmp_param_name = f"layers.{tmp_layer_id}.blocks.{tmp_block_id}.attn"
tmp_dim = 96 * 2**tmp_layer_id
tmp_num_heads = NUM_HEADS[tmp_layer_id]
tmp = LORA_WindowAttention2(tmp_dim, LORA_RANK_DICT[tmp_param_name][LORA_SELECTOR],
            to_2tuple(window_size), tmp_num_heads, tmp_param_name)

In [None]:
# Restore the saved stage-3 / block-0 student weights and switch to eval mode.
saved_student = torch.load("layers.3.blocks.0.attn.pth")
tmp.load_state_dict(saved_student['state_dict'])
tmp.eval()

In [None]:
# Inspect the reloaded student's state dict (displays every entry).
tmp.state_dict()

### Save LoRA students to disk

In [None]:
# Save each LoRA student to "<module path>.pth", e.g. "layers.3.blocks.0.attn.pth".
# Only even-indexed blocks are saved. Positions without a student (None in
# all_lora_attns) are skipped; the block_id % 2 filter alone would reach them and
# raise AttributeError on None.state_dict().
for layer_id in range(NUM_LAYERS):
    for block_id in range(LAYER_DEPTHS[layer_id]):
        student = all_lora_attns[layer_id][block_id]
        if block_id % 2 != 0 or student is None:
            continue
        param_name = 'layers.' + str(layer_id) + ".blocks." + str(block_id) + ".attn"
        torch.save({'state_dict': student.state_dict()}, param_name + ".pth")

In [None]:
# Recreate the LoRA students for the even-indexed blocks of stages 2 and 3 from the
# "<module path>.pth" files written by the save cell; all other positions hold None.
# NOTE(review): the build cell above only creates students for layer_id >= 3, so the
# stage-2 files must come from an earlier run. Confirm they exist before running this.
all_lora_attns = []
for layer_id in range(NUM_LAYERS):
    C = 96 * 2**layer_id
    dim = C
    num_heads = NUM_HEADS[layer_id]
    layer_attns = []
    for block_id in range(LAYER_DEPTHS[layer_id]):
        if layer_id < 2 or block_id % 2 != 0:
            layer_attns.append(None)
            continue
        param_name = f"layers.{layer_id}.blocks.{block_id}.attn"
        lora_attn = LORA_WindowAttention2(
            dim,
            LORA_RANK_DICT[param_name][LORA_SELECTOR],
            to_2tuple(window_size),
            num_heads,
            param_name,
        )
        lora_attn.load_state_dict(torch.load(f"{param_name}.pth")['state_dict'])
        lora_attn.cuda()
        layer_attns.append(lora_attn)
    all_lora_attns.append(layer_attns)

In [None]:
def load_lora_weights(super_model, lora_attns=None, layer_depths=None):
    """Copy trained LoRA student weights into the matching attention modules of ``super_model``.

    For every even-indexed block of stage 2 and later that has a student, the
    student's ``lora_k.weight``, ``lora_v.weight`` and ``lora_rpb.weight`` overwrite
    ``layers.<i>.blocks.<j>.attn.<name>`` in ``super_model``. All updates are applied
    with a single strict ``load_state_dict`` call.

    Args:
        super_model: model whose state dict contains the target keys.
        lora_attns: nested list indexed ``[layer_id][block_id]`` holding student modules
            or None. Defaults to the notebook-global ``all_lora_attns``.
        layer_depths: number of blocks per stage. Defaults to the notebook-global
            ``LAYER_DEPTHS``.

    Returns:
        ``super_model``, modified in place.
    """
    if lora_attns is None:
        lora_attns = all_lora_attns
    if layer_depths is None:
        layer_depths = LAYER_DEPTHS

    # state_dict() builds a fresh dict referencing (not copying) the model's tensors,
    # so replacing entries here does not touch the model until load_state_dict runs.
    new_sm_sd = super_model.state_dict()
    updated = False
    for layer_id, depth in enumerate(layer_depths):
        if layer_id < 2:
            continue
        for block_id in range(0, depth, 2):
            student = lora_attns[layer_id][block_id]
            if student is None:
                continue
            prefix = f"layers.{layer_id}.blocks.{block_id}.attn"
            student_sd = student.state_dict()
            for name in ('lora_k.weight', 'lora_v.weight', 'lora_rpb.weight'):
                new_sm_sd[f"{prefix}.{name}"] = student_sd[name]
            updated = True

    # One strict load after all updates, instead of a full reload per block.
    if updated:
        super_model.load_state_dict(new_sm_sd)
    return super_model

In [None]:
# param_name+'.lora_k'

In [None]:
# Copy the reloaded student weights into the teacher model (returned object is super_model itself).
super_model = load_lora_weights(super_model)

In [None]:
# super_model.state_dict()['layers.3.blocks.0.attn.lora_k']

In [None]:
# all_lora_attns[0][0].state_dict()
# [p[1] for p in all_lora_attns[0][0].named_parameters() if "lora" in p[0]]

In [None]:
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# # create a loss function
# criterion = torch.nn.MSELoss()

### Student–teacher training

In [10]:
# Not decorated with @torch.no_grad() as a whole (the decorator was disabled), presumably so
# student training that needs autograd can run inside the batch loop. Only the teacher
# forward pass below runs under no_grad.
def teacher_validate(config, data_loader, model):
    """Run ``model`` over ``data_loader`` and report cross-entropy loss and top-1/top-5 accuracy.

    Logs progress through the notebook-global ``logger`` every ``config.PRINT_FREQ``
    batches, then logs a final accuracy summary.

    Args:
        config: experiment config; reads ``AMP_ENABLE`` and ``PRINT_FREQ``.
        data_loader: iterable of ``(images, target)`` batches, moved to CUDA here.
        model: classifier; switched to eval mode.

    Returns:
        ``(acc1_avg, acc5_avg, loss_avg)`` accumulated over the whole loader.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time, loss_meter = AverageMeter(), AverageMeter()
    acc1_meter, acc5_meter = AverageMeter(), AverageMeter()

    end = time.time()
    for step, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=False)
        target = target.cuda(non_blocking=False)

        # Teacher forward pass: no autograd graph, optional mixed precision.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = target.size(0)
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        if step % config.PRINT_FREQ != 0:
            continue
        memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
        logger.info(
            f'Test: [{step}/{len(data_loader)}]\t'
            f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
            f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
            f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
            f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
            f'Mem {memory_used:.0f}MB')

    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg

In [None]:
# Run the teacher over the training split each epoch, then reset every student's
# accumulated loss. teacher_validate itself does not update the students.
logger.info("Start teacher-student training")
start_time = time.time()
NUM_TS_EPOCHS = 10  # used instead of config.TRAIN.START_EPOCH .. config.TRAIN.EPOCHS
for epoch in range(NUM_TS_EPOCHS):
    print(" Epoch : ", epoch)
    acc1, acc5, loss = teacher_validate(config, data_loader_train, super_model)
    for layer_id in range(NUM_LAYERS):
        for block_id in range(LAYER_DEPTHS[layer_id]):
            student = all_lora_attns[layer_id][block_id]
            # Only even-indexed blocks can hold students, and only in the stages the
            # build/load cell populated. Skip the None placeholders, which previously
            # raised AttributeError on .avg_loss.
            if block_id % 2 != 0 or student is None:
                continue
            # Per-batch mean of the student's loss; available for a plateau LR step,
            # e.g. student.do_lr_step(avg_loss).
            avg_loss = student.avg_loss / len(data_loader_train)
            student.avg_loss = 0

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str))

In [None]:
# acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

### Teacher inference

### Only layer 3

#### Only layer 3 — low rank: 350

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

#### Low rank: 384

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

#### Low rank: 768

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

In [None]:
# Shape of the q_wUSprime parameter of the first attention block in stage 2.
super_model.state_dict()["layers.2.blocks.0.attn.q_wUSprime"].shape

### Layers 2 and 3

#### Low rank: 350

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

#### Low rank: 325

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

### All layers

#### Low rank: 325

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

#### Low rank: 300

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

#### Low rank: 150

In [None]:
# Evaluate super_model on the validation split for the configuration named in the header above.
# NOTE(review): start_time is recorded but never reported in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

### BMM

In [11]:
# logger.info("Start teacher-student inference")
# Evaluate `model` on the validation split. The aliasing cells above bind super_model to
# the same module object. NOTE(review): start_time is recorded but never reported here.
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, model)

 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
[32m[2022-12-07 23:27:43 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-10-d2eb1b709886> 77)[0m: INFO Test: [0/788]	Time 0.307 (0.307)	Loss 0.0000 (0.0000)	Acc@1 100.000 (100.000)	Acc@5 100.000 (100.000)	Mem 314MB
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn fwd44
 Doing attn 

KeyboardInterrupt: 