In [1]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
from scipy.special import softmax
import torch
import argparse
from logger import create_logger
import os


from utils import load_checkpoint, load_pretrained, NativeScalerWithGradNormCount
from config import get_config
from data import build_loader
from models import build_model
from lr_scheduler import build_scheduler
from optimizer import build_optimizer

from main import train_one_epoch, validate, throughput

from config import get_only_config
import json
import copy
import math
import time

import datetime
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.utils import accuracy, AverageMeter

  from .autonotebook import tqdm as notebook_tqdm


Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.
To use FusedLAMB or FusedAdam, please install apex.


In [32]:
# Development convenience: re-import `main` so that edits to main.py are picked
# up without restarting the kernel, then rebind `validate` to the reloaded version.
# NOTE(review): this cell carries execution count 32, i.e. it was run out of
# order; on a fresh Restart & Run All it simply reloads a module that the
# import cell above has only just imported.
import sys
import importlib
importlib.reload(sys.modules['main'])
from main import validate

In [2]:
# YAML experiment config to start from; the overrides for this notebook are
# applied in a later cell after calling config.defrost().
config_path = 'configs/swin/swin_tiny_patch4_window7_224_resisc45.yaml'
config = get_only_config(config_path)

=> merge config from configs/swin/swin_tiny_patch4_window7_224_resisc45.yaml


In [29]:
# Notebook-specific overrides. The config must be defrosted before it can be
# modified and is frozen again once all overrides are in place.
# NOTE(review): OUTPUT / PRETRAINED / RESUME are absolute, user-specific AFS
# paths; anyone else running this notebook has to edit them.
config.defrost()
config.OUTPUT = "/afs/ece.cmu.edu/usr/bmarimut/Private/output"
# config.MODEL.PRETRAINED = "/afs/ece.cmu.edu/usr/ashwinve/Public/ckpt_epoch_29_6.pth"
config.MODEL.PRETRAINED = "/afs/ece.cmu.edu/usr/ashwinve/Public/golden_resisc45.pth"
config.MODEL.RESUME = "/afs/ece.cmu.edu/usr/ashwinve/Public/golden_resisc45.pth"
config.DATA.CACHE_MODE = 'no'
config.DATA.DATA_PATH = './data/RESISC45/'
config.DATA.ZIP_MODE = True
# PRINT_FREQ is read by teacher_validate below: it logs every PRINT_FREQ batches.
config.PRINT_FREQ = 120
config.DATA.BATCH_SIZE = 4
config.freeze()
# The logger writes into config.OUTPUT, so the directory has to exist first.
os.makedirs(config.OUTPUT, exist_ok=True)
logger = create_logger(output_dir=config.OUTPUT, name=f"{config.MODEL.NAME}")

In [4]:
# Instantiate the network described by `config`; weights are loaded in the next cell.
model = build_model(config)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# Load the checkpoint onto the CPU and copy its 'model' weights into the network.
# strict=False tolerates keys that are absent from the checkpoint; the value
# displayed below lists them (only the lora_* weights are missing).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)

_IncompatibleKeys(missing_keys=['layers.0.blocks.0.attn.lora_k.weight', 'layers.0.blocks.0.attn.lora_v.weight', 'layers.0.blocks.0.attn.lora_rpb.weight', 'layers.1.blocks.0.attn.lora_k.weight', 'layers.1.blocks.0.attn.lora_v.weight', 'layers.1.blocks.0.attn.lora_rpb.weight', 'layers.2.blocks.0.attn.lora_k.weight', 'layers.2.blocks.0.attn.lora_v.weight', 'layers.2.blocks.0.attn.lora_rpb.weight', 'layers.2.blocks.2.attn.lora_k.weight', 'layers.2.blocks.2.attn.lora_v.weight', 'layers.2.blocks.2.attn.lora_rpb.weight', 'layers.2.blocks.4.attn.lora_k.weight', 'layers.2.blocks.4.attn.lora_v.weight', 'layers.2.blocks.4.attn.lora_rpb.weight', 'layers.3.blocks.0.attn.lora_k.weight', 'layers.3.blocks.0.attn.lora_v.weight', 'layers.3.blocks.0.attn.lora_rpb.weight'], unexpected_keys=[])

In [6]:
# load_pretrained(config, model, logger)

In [7]:
# # Freeze specific layers for downstream task training
# model_named_params = list(model.named_parameters())
# num_params = len(model_named_params)
# for param_iter in range(num_params):
#     param_name, param = model_named_params[param_iter]
#     if "lora" not in param_name:
#         param.requires_grad = False

In [8]:
# Log the number of parameters that currently require gradients and, when the
# model exposes a flops() method, its FLOP count in GFLOPs.
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"number of params: {n_parameters}")
if hasattr(model, 'flops'):
    flops = model.flops()
    logger.info(f"number of GFLOPs: {flops / 1e9}")

# Move the model to the GPU. `model_without_ddp` and `super_model` are plain
# aliases of the very same object (no copy is made); later cells call it
# `super_model` when it acts as the distillation teacher.
model.cuda()
model_without_ddp = model
super_model = model

[32m[2022-12-04 23:01:52 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-8-ae1c34bdd2f3> 2)[0m: INFO number of params: 52624074
[32m[2022-12-04 23:01:52 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-8-ae1c34bdd2f3> 5)[0m: INFO number of GFLOPs: 4.423600896


In [9]:
# model

In [10]:
# Build the datasets, their data loaders and the mixup function from `config`.
dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

In [11]:
# Candidate low-rank sizes per attention module, keyed by the module's
# parameter-name prefix. LORA_SELECTOR is the index into each list that the
# student-construction cells below actually use.
# NOTE(review): the last entry (49) equals 7*7, the number of tokens in a window
# for window_size = 7, so it presumably means "no rank reduction"; what the
# first three columns correspond to is not visible in this notebook — confirm.
LORA_SELECTOR = 0
LORA_RANK_DICT = {
    'layers.0.blocks.0.attn': [9,   12, 13, 49],
    'layers.0.blocks.1.attn': [28,  35, 38, 49],
    'layers.1.blocks.0.attn': [24,  32, 39, 49],
    'layers.1.blocks.1.attn': [19,  22, 23, 49],
    'layers.2.blocks.0.attn': [16,  18, 19, 49],
    'layers.2.blocks.1.attn': [20,  22, 22, 49],
    'layers.2.blocks.2.attn': [22,  26, 30, 49],
    'layers.2.blocks.3.attn': [22,  25, 26, 49],
    'layers.2.blocks.4.attn': [24,  24, 25, 49],
    'layers.2.blocks.5.attn': [23,  24, 24, 49],
    'layers.3.blocks.0.attn': [20,  21, 22, 49],
    'layers.3.blocks.1.attn': [16,  16, 17, 49]
    }

In [12]:
# attn_md = get_attn(2,5)
# print(get_attn(2,5).input)

In [13]:
# print(list(list(model.children())[2][0].children())[0][1])
# print(get_attn(2,5))

In [14]:
# import importlib
# # import models
# import sys
# importlib.reload(sys.modules['models'])
# from models import build_model

In [15]:
# Learning rate of the per-student Adam optimizer; LORA_WindowAttention2.__init__
# reads this module-level name when it builds its optimizer.
learning_rate = 1e-02

In [16]:
class LORA_WindowAttention2(torch.nn.Module):
    r""" Window based multi-head self attention (W-MSA) student with low-rank token projections.

    Keys, values and the relative position bias are each projected from the full
    window length ``N = Wh * Ww`` down to ``lora_rank`` tokens by a bias-free linear
    layer (``lora_k``, ``lora_v``, ``lora_rpb``), so the attention map has shape
    ``(B_, num_heads, N, lora_rank)`` instead of ``(B_, num_heads, N, N)``.

    The module is self-training: it owns an Adam optimizer over its ``lora_*``
    parameters only, a ``ReduceLROnPlateau`` scheduler and an MSE criterion, and it
    keeps the result of the last ``forward`` in ``self.output`` so that
    ``do_backward`` can regress it onto a teacher activation.

    Args:
        dim (int): Number of input channels.
        lora_rank (int): Number of tokens keys/values/bias are projected down to.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        param_name (str): State-dict prefix of the teacher attention module this
            student mirrors (e.g. ``'layers.0.blocks.0.attn'``).
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        lr (float | None, optional): Learning rate of the internal Adam optimizer.
            ``None`` (default) falls back to the module-level ``learning_rate``.
    """

    def __init__(self, dim, lora_rank, window_size, num_heads, param_name, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0., lr=None):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        num_tokens = window_size[0] * window_size[1]  # N = Wh*Ww

        # define a parameter table of relative position bias
        self.relative_position_bias_table = torch.nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # indexing="ij" is the behaviour the original call relied on implicitly;
        # passing it explicitly silences the torch.meshgrid deprecation warning.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)

        # Same initialisation as before, via PyTorch's own implementation.
        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = torch.nn.Softmax(dim=-1)

        # Low-rank projections. Each acts on a flattened per-head matrix:
        #   lora_k / lora_v : (head_dim * N) -> (head_dim * lora_rank)
        #   lora_rpb        : (N * N)        -> (N * lora_rank)
        self.lora_k = torch.nn.Linear(head_dim * num_tokens, head_dim * lora_rank, bias=False)
        self.lora_v = torch.nn.Linear(head_dim * num_tokens, head_dim * lora_rank, bias=False)
        self.lora_rpb = torch.nn.Linear(num_tokens * num_tokens, num_tokens * lora_rank, bias=False)

        self.lora_rank = lora_rank

        # Only the lora_* weights are optimised; qkv, proj and the bias table are
        # expected to be inherited from the teacher (see load_pretrained_weights).
        if lr is None:
            lr = learning_rate
        lora_params = [param for name, param in self.named_parameters() if "lora" in name]
        self.optimizer = torch.optim.Adam(lora_params, lr=lr)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', patience=1700)
        self.criterion = torch.nn.MSELoss()

        self.param_name = param_name
        self.avg_loss = 0  # running sum of loss values; callers reset/normalise it

    def do_backward(self, target, print_log):
        """Take one optimisation step that pulls the last forward output towards ``target``.

        Args:
            target: Teacher activation with the same shape as ``self.output``.
            print_log: Accepted for caller compatibility; currently unused.

        The loss is the RMSE of the activations after scaling both by 100. It is
        stored in ``self.loss`` and its scalar value is added to ``self.avg_loss``.
        ``forward`` must have been called beforehand so that ``self.output`` exists.
        """
        self.optimizer.zero_grad()
        self.loss = torch.sqrt(self.criterion(self.output * 100, target * 100))
        self.loss.backward()
        self.avg_loss += self.loss.item()
        self.optimizer.step()

    def do_lr_step(self):
        """Feed the most recent loss value to the ReduceLROnPlateau scheduler."""
        # Pass a plain float rather than a tensor that still belongs to an autograd graph.
        self.lr_scheduler.step(self.loss.item())

    def load_pretrained_weights(self, super_model):
        """Copy the non-LoRA attention weights of the mirrored teacher module into this student.

        Args:
            super_model: Teacher network whose ``state_dict()`` contains entries
                prefixed with ``self.param_name``.

        ``qkv.weight``, ``qkv.bias`` and ``relative_position_bias_table`` are always
        copied (a missing key raises ``KeyError``). ``proj.weight`` / ``proj.bias``
        are copied when the teacher provides them: the student's optimizer never
        updates ``proj``, so leaving it at its random initialisation would make it
        impossible to match the teacher's post-projection output, and would not
        correspond to the projection used once the lora_* weights are copied back
        into the teacher.
        """
        teacher_sd = super_model.state_dict()
        prefix = self.param_name + '.'

        # state_dict() returns a fresh mapping on every call, so it can be edited
        # in place without deep-copying every tensor.
        new_sd = self.state_dict()
        for key in ('qkv.weight', 'qkv.bias', 'relative_position_bias_table'):
            new_sd[key] = teacher_sd[prefix + key]
        for key in ('proj.weight', 'proj.bias'):
            if prefix + key in teacher_sd:
                new_sd[key] = teacher_sd[prefix + key]

        self.load_state_dict(new_sd)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            Tensor of shape (num_windows*B, N, C); also stored in ``self.output``.

        Raises:
            NotImplementedError: if ``mask`` is given while ``lora_rank != N``. The
                mask has N key columns but the attention map only has ``lora_rank``.
        """
        B_, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale

        # Project keys from N tokens down to lora_rank tokens, per head.
        flattened_k = k.transpose(-2, -1).reshape(B_, self.num_heads, -1)  # B_, nH, head_dim*N
        k_lora = self.lora_k(flattened_k).reshape(B_, self.num_heads, head_dim, self.lora_rank)

        attn = (q @ k_lora)  # B_, nH, N, lora_rank

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

        # Project the bias the same way so that it lines up with the reduced attention map.
        flattened_rpb = relative_position_bias.unsqueeze(0).reshape(self.num_heads, -1)  # nH, N*N
        rpb_lora = self.lora_rpb(flattened_rpb).view(
            self.num_heads, self.window_size[0] * self.window_size[1], self.lora_rank)  # nH, N, lora_rank

        attn = attn + rpb_lora

        if mask is not None:
            if self.lora_rank != N:
                # NotImplementedError derives from RuntimeError, which is what the
                # failing .view() below used to raise in this situation.
                raise NotImplementedError(
                    f"attention masks are only supported when lora_rank == N "
                    f"(got lora_rank={self.lora_rank}, N={N}); shifted-window blocks "
                    f"cannot use a reduced rank")
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # Project values down to lora_rank tokens, then restore (tokens, head_dim) order.
        flattened_v = v.transpose(-2, -1).reshape(B_, self.num_heads, -1)
        v_lora = self.lora_v(flattened_v).reshape(
            B_, self.num_heads, head_dim, self.lora_rank).transpose(-2, -1)  # B_, nH, lora_rank, head_dim

        x = (attn @ v_lora).transpose(1, 2).reshape(B_, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        self.output = x  # kept for do_backward
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        """Return a partial FLOP estimate for one window with token length ``N``.

        Only the qkv and output projections are counted.
        TODO: add the lora_k / lora_v / lora_rpb projections and the two reduced
        attention matmuls (q @ k_lora and attn @ v_lora).
        """
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

In [17]:
def get_block(my_model, layer_id, block_id):
    """Return block ``block_id`` of stage ``layer_id`` from ``my_model``.

    The lookup is purely positional: the third child of ``my_model`` is indexed
    with ``layer_id``, and the first child of that stage is indexed with
    ``block_id``.
    """
    top_level_children = list(my_model.children())
    stage = top_level_children[2][layer_id]
    stage_children = list(stage.children())
    return stage_children[0][block_id]

In [18]:
def get_attn(my_model, layer_id, block_id):
    """Return the second child module of the block selected by ``get_block``.

    Callers in this notebook treat that child as the block's attention module.
    """
    block_children = list(get_block(my_model, layer_id, block_id).children())
    return block_children[1]

In [19]:
# print(get_attn(super_model, 0, 0))

In [20]:
# print(get_block(super_model, 0, 0))

In [21]:
# Architecture constants that the cells below use to walk the teacher:
# LAYER_DEPTHS[i] is the number of blocks in stage i and NUM_HEADS[i] its number
# of attention heads.
# NOTE(review): these must match the model built from the YAML config; they are
# hard-coded here rather than read from `config` — confirm they agree.
LAYER_DEPTHS = [2, 2, 6, 2]
NUM_LAYERS = len(LAYER_DEPTHS)
NUM_HEADS = [ 3, 6, 12, 24 ]
# H, W, B and L are not referenced by any cell visible in this notebook.
H = 224
W = 224
B = 1
L = 224 * 224
# Side length of the (square) attention window; converted with to_2tuple below.
window_size = 7


# i = 0
# C = 96 * 2**i
# num_heads = NUM_HEADS[i]
# dim = C

In [22]:
# Build one student attention module for every even-indexed block of every
# stage, initialised from the teacher. Odd-indexed blocks get a None placeholder
# so that all_lora_attns[layer_id][block_id] lines up with the teacher's layout.
all_lora_attns = []
for layer_id in range(NUM_LAYERS):
    layer_attns = []
    # Channel width doubles with every stage, starting from 96.
    C = 96 * 2**layer_id
    num_heads = NUM_HEADS[layer_id]
    dim = C
    for block_id in range(LAYER_DEPTHS[layer_id]):
        if block_id%2 == 0:
            # State-dict prefix of the teacher module this student mirrors.
            param_name = 'layers.' + str(layer_id) + ".blocks." + str(block_id) + ".attn"
            lora_attn = LORA_WindowAttention2(dim, LORA_RANK_DICT[param_name][LORA_SELECTOR],
            to_2tuple(window_size), num_heads, param_name)
            lora_attn.cuda()
            # Copy the teacher's weights for this attention module into the student.
            lora_attn.load_pretrained_weights(super_model)
            layer_attns.append(lora_attn)
        else:
            layer_attns.append(None)
    all_lora_attns.append(layer_attns)
    

### Save LoRA Students to disk

In [199]:
# Write every student's state dict to "<param_name>.pth" in the current working
# directory (one file per even-indexed block).
for layer_id in range(NUM_LAYERS):
    for block_id in range(LAYER_DEPTHS[layer_id]):
        if block_id%2 == 0:
            param_name = 'layers.' + str(layer_id) + ".blocks." + str(block_id) + ".attn"
            torch.save({
                'state_dict': all_lora_attns[layer_id][block_id].state_dict()},
                param_name+".pth"
            )

In [22]:
# Alternative to the construction cell above: rebuild the students and restore
# their weights from the "<param_name>.pth" files written by the save cell,
# instead of initialising them from the teacher.
# NOTE(review): this cell shares execution count 22 with the cell above and
# rebinds all_lora_attns; run only one of the two, depending on whether saved
# students exist.
all_lora_attns = []
for layer_id in range(NUM_LAYERS):
    layer_attns = []
    C = 96 * 2**layer_id
    num_heads = NUM_HEADS[layer_id]
    dim = C
    for block_id in range(LAYER_DEPTHS[layer_id]):
        if block_id%2 == 0:
            param_name = 'layers.' + str(layer_id) + ".blocks." + str(block_id) + ".attn"
            lora_attn = LORA_WindowAttention2(dim, LORA_RANK_DICT[param_name][LORA_SELECTOR],
            to_2tuple(window_size), num_heads, param_name)
            # The file holds {'state_dict': ...} as written by the save cell.
            lora_attn.load_state_dict(torch.load(param_name+".pth")['state_dict'])
            lora_attn.cuda()
            layer_attns.append(lora_attn)
        else:
            layer_attns.append(None)
    all_lora_attns.append(layer_attns)

In [23]:
def load_lora_weights(super_model, lora_attns=None):
    """Copy the trained lora_* weights of every student into ``super_model``.

    Args:
        super_model: Network whose state dict has ``layers.<i>.blocks.<j>.attn.lora_*``
            entries. It is updated in place.
        lora_attns: Nested list indexed as ``[layer_id][block_id]`` holding a student
            module or ``None`` for blocks without a student. Defaults to the
            notebook-level ``all_lora_attns``.

    Returns:
        ``super_model`` itself, after a single ``load_state_dict`` call.

    For each student, ``lora_k.weight``, ``lora_v.weight`` and ``lora_rpb.weight``
    are taken from its state dict and stored under the prefix derived from its
    position in ``lora_attns``. Everything else in ``super_model`` keeps its
    current value.
    """
    if lora_attns is None:
        lora_attns = all_lora_attns

    # state_dict() returns a fresh mapping on every call, so it can be edited
    # directly; deep-copying every tensor of the model is not needed.
    merged_sd = super_model.state_dict()

    for layer_id, layer_students in enumerate(lora_attns):
        for block_id, student in enumerate(layer_students):
            if student is None:
                continue
            prefix = 'layers.' + str(layer_id) + ".blocks." + str(block_id) + ".attn"
            student_sd = student.state_dict()  # built once per student
            for key in ('lora_k.weight', 'lora_v.weight', 'lora_rpb.weight'):
                merged_sd[prefix + '.' + key] = student_sd[key]

    # Load once, after all students have been merged (the original reloaded the
    # entire model inside the loop, once per student).
    super_model.load_state_dict(merged_sd)
    return super_model

In [24]:
# param_name+'.lora_k'

In [25]:
# Push the students' current lora_* weights into the teacher network.
super_model = load_lora_weights(super_model)

In [28]:
# super_model.state_dict()['layers.3.blocks.0.attn.lora_k']

In [29]:
# all_lora_attns[0][0].state_dict()
# [p[1] for p in all_lora_attns[0][0].named_parameters() if "lora" in p[0]]

In [30]:
# optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# # create a loss function
# criterion = torch.nn.MSELoss()

### Student teacher training

In [197]:
# @torch.no_grad()
def teacher_validate(config, data_loader, model):
    """Run the teacher over ``data_loader`` and take one distillation step per student per batch.

    For every batch the teacher ``model`` is evaluated without gradients. Then, for
    every even-indexed block, the matching student in ``all_lora_attns`` is run on
    the teacher attention module's recorded ``input``, regressed onto its recorded
    ``output`` (``do_backward``) and its LR scheduler is stepped.

    NOTE: despite the name, this function always updates the students, whichever
    loader is passed in. Calling it on a validation loader trains on validation data.

    Args:
        config: Experiment config; ``AMP_ENABLE`` and ``PRINT_FREQ`` are read.
        data_loader: Iterable of ``(images, target)`` batches.
        model: Teacher network. Its attention modules are located with ``get_attn``
            and must expose ``input`` / ``output`` attributes after a forward pass.

    Returns:
        Tuple ``(acc1_avg, acc5_avg, loss_avg)`` of the teacher over the loader.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    # One distillation-loss meter per student; None placeholders for odd blocks
    # keep the [layer_id][block_id] indexing aligned with all_lora_attns.
    attn_loss_meters = [
        [AverageMeter() if block_id % 2 == 0 else None for block_id in range(LAYER_DEPTHS[layer_id])]
        for layer_id in range(NUM_LAYERS)
    ]

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=False)
        target = target.cuda(non_blocking=False)
        should_log = idx % config.PRINT_FREQ == 0

        # compute teacher output (this also refreshes the recorded attention input/output)
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
                output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # Train students: only even-indexed blocks have one.
        for layer_id in range(NUM_LAYERS):
            for block_id in range(0, LAYER_DEPTHS[layer_id], 2):
                lora_md = all_lora_attns[layer_id][block_id]
                # Read activations from the model that was just run, not from a
                # notebook global that merely happens to be the same object.
                teacher_attn = get_attn(model, layer_id, block_id)
                lora_md(teacher_attn.input)
                lora_md.do_backward(teacher_attn.output, should_log)
                lora_md.do_lr_step()
                attn_loss_meters[layer_id][block_id].update(lora_md.loss.item(), 1)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if should_log:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
            # Per-student distillation loss (last / running average) and current LR.
            for layer_id in range(NUM_LAYERS):
                for block_id in range(0, LAYER_DEPTHS[layer_id], 2):
                    param_name = 'layers.' + str(layer_id) + ".blocks." + str(block_id) + ".attn"
                    meter = attn_loss_meters[layer_id][block_id]
                    print(param_name, ": Loss : ", meter.val, " : ", meter.avg,
                          " LR : ", all_lora_attns[layer_id][block_id].optimizer.param_groups[0]['lr'])

    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg

### Training Loop

In [198]:
# Distillation driver: each pass over the training loader runs the teacher and,
# inside teacher_validate, takes one optimisation step per student per batch.
# The returned teacher metrics (acc1, acc5, loss) are not used further here.
logger.info("Start teacher-student training")
start_time = time.time()
# for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
# The number of passes is hard-coded to 10 rather than taken from config.TRAIN.
for epoch in range(0, 10):
    # if not config.TEST.SEQUENTIAL:
    # data_loader_train.sampler.set_epoch(epoch)

    # train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler=lr_scheduler,
    #                 loss_scaler=None)
    print(" Epoch : ", epoch)
    acc1, acc5, loss = teacher_validate(config, data_loader_train, super_model)
    # for layer_id in range(NUM_LAYERS):
    #     for block_id in range(LAYER_DEPTHS[layer_id]):
    #         if block_id%2 == 0:
    #             avg_loss = all_lora_attns[layer_id][block_id].avg_loss / len(data_loader_train)
    #             # all_lora_attns[layer_id][block_id].do_lr_step()
    #             # all_lora_attns[layer_id][block_id].do_lr_step(avg_loss)
    #             all_lora_attns[layer_id][block_id].avg_loss = 0

    # acc1, acc5, loss = validate(config, data_loader_val, model)
    # logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
    # max_accuracy = max(max_accuracy, acc1)
    # logger.info(f'Max accuracy: {max_accuracy:.2f}%')

# Report wall-clock duration as H:MM:SS.
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str))

[32m[2022-12-04 21:56:31 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-198-0dbba6bb9711> 1)[0m: INFO Start teacher-student training
 Epoch :  0
[32m[2022-12-04 21:56:32 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-197-2d8bca5164a9> 78)[0m: INFO Test: [0/3150]	Time 0.841 (0.841)	Loss 0.0740 (0.0740)	Acc@1 100.000 (100.000)	Acc@5 100.000 (100.000)	Mem 1573MB
layers.0.blocks.0.attn : Loss :  41.49988555908203  :  41.49988555908203  LR :  0.01
layers.1.blocks.0.attn : Loss :  30.26732635498047  :  30.26732635498047  LR :  0.01
layers.2.blocks.0.attn : Loss :  32.3546257019043  :  32.3546257019043  LR :  0.01
layers.2.blocks.2.attn : Loss :  33.35563659667969  :  33.35563659667969  LR :  0.01
layers.2.blocks.4.attn : Loss :  50.34062194824219  :  50.34062194824219  LR :  0.01
layers.3.blocks.0.attn : Loss :  39.98451614379883  :  39.98451614379883  LR :  0.01
[32m[2022-12-04 21:56:42 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-197-2

KeyboardInterrupt: 

### Student Teacher inference

In [None]:
# acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

In [None]:
# Run the teacher over the validation loader and collect its accuracy / loss.
# NOTE(review): teacher_validate also performs an optimizer step for every
# student on every batch, so this "inference" pass updates the students on
# validation data. `start_time` is assigned but not read again in this cell.
logger.info("Start teacher-student inference")
start_time = time.time()
acc1, acc5, loss = teacher_validate(config, data_loader_val, super_model)

### Train Teacher

In [27]:
# Fine-tuning setup for the full network: only parameters whose name contains
# "lora" are handed to the optimizer.
# NOTE(review): the remaining parameters are not frozen (the freezing cell near
# the top of the notebook is commented out), so gradients are still computed for
# them even though the optimizer never updates them.
lora_params = [p[1] for p in super_model.named_parameters() if "lora" in p[0]]

optimizer = torch.optim.SGD(lora_params, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)

loss_scaler = NativeScalerWithGradNormCount()

# The scheduler is built from the config and the number of batches per epoch.
lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

criterion = torch.nn.CrossEntropyLoss()

# Best top-1 accuracy seen so far; updated by the training / validation cells below.
max_accuracy = 0.0

In [34]:
# optimizer = build_optimizer(config, model)

# Fine-tune the network end to end (optimizer covers the lora_* weights only)
# and validate after every epoch, tracking the best top-1 accuracy.
# `model` here is the same object as `super_model` (aliased in an earlier cell).
# NOTE(review): the epoch range is hard-coded to 1..9 instead of using
# config.TRAIN; the captured log below shows grad_norm values in the tens of
# millions, which suggests the run was diverging when it was interrupted.
logger.info("Start training")
start_time = time.time()
for epoch in range(1, 10):
    # Reseed the sampler each epoch unless sequential testing is configured.
    if not config.TEST.SEQUENTIAL:
        data_loader_train.sampler.set_epoch(epoch)

    train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                    loss_scaler, logger)
    

    acc1, acc5, loss = validate(config, data_loader_val, model, logger)
    logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
    max_accuracy = max(max_accuracy, acc1)
    logger.info(f'Max accuracy: {max_accuracy:.2f}%')

# Report wall-clock duration as H:MM:SS.
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time {}'.format(total_time_str))

[32m[2022-12-04 23:17:36 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-34-58531364c9af> 3)[0m: INFO Start training
[32m[2022-12-04 23:17:38 swin_tiny_patch4_window7_224_resisc45][0m[33m(main.py 282)[0m: INFO Train: [1/30][0/393]	eta 0:13:05 lr 0.0000997261	 wd 0.0000500000	time 1.9990 (1.9990)	loss 4.1720 (4.1720)	grad_norm 20412088.0000 (20412088.0000)	loss_scale 65536.0000 (65536.0000)	mem 7896MB
[32m[2022-12-04 23:19:49 swin_tiny_patch4_window7_224_resisc45][0m[33m(main.py 282)[0m: INFO Train: [1/30][120/393]	eta 0:05:01 lr 0.0000997261	 wd 0.0000500000	time 1.1269 (1.1034)	loss 8.2928 (6.7201)	grad_norm 25304230.0000 (31591084.0000)	loss_scale 65536.0000 (65536.0000)	mem 7899MB
[32m[2022-12-04 23:22:03 swin_tiny_patch4_window7_224_resisc45][0m[33m(main.py 282)[0m: INFO Train: [1/30][240/393]	eta 0:02:49 lr 0.0000997261	 wd 0.0000500000	time 1.0970 (1.1077)	loss 5.7804 (6.7508)	grad_norm 22607152.0000 (29714714.0000)	loss_scale 65536.0000 (65536.0000)	

KeyboardInterrupt: 

In [33]:
# Stand-alone validation of the current network on the validation loader.
# NOTE(review): this cell has execution count 33, i.e. it was run BEFORE the
# training cell above (count 34); its output reflects the model right after the
# students' lora_* weights were loaded, not the result of the training loop.
acc1, acc5, loss = validate(config, data_loader_val, model, logger)
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
max_accuracy = max(max_accuracy, acc1)
logger.info(f'Max accuracy: {max_accuracy:.2f}%')

[32m[2022-12-04 23:16:43 swin_tiny_patch4_window7_224_resisc45][0m[33m(main.py 337)[0m: INFO Test: [0/99]	Time 0.925 (0.925)	Loss 3.4437 (3.4437)	Acc@1 4.688 (4.688)	Acc@5 46.875 (46.875)	Mem 7863MB
[32m[2022-12-04 23:17:20 swin_tiny_patch4_window7_224_resisc45][0m[33m(main.py 343)[0m: INFO  * Acc@1 3.079 Acc@5 14.302
[32m[2022-12-04 23:17:20 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-33-d2c43735fa0c> 2)[0m: INFO Accuracy of the network on the 6300 test images: 3.1%
[32m[2022-12-04 23:17:20 swin_tiny_patch4_window7_224_resisc45][0m[33m(<ipython-input-33-d2c43735fa0c> 4)[0m: INFO Max accuracy: 3.08%
