# The following blocks of code are for pruning

In [None]:
import os
# Output directory for the intermediate pruned backbones saved during retraining.
# exist_ok=True makes re-running this cell a no-op and avoids the race between
# a separate existence check and the mkdir call.
os.makedirs("./temp_backbone", exist_ok=True)
from copy import deepcopy

import torch
import torch.nn as nn
from torchvision import transforms
import cv2
import numpy as np
import matplotlib.cm as cm
from src.utils.plotting import make_matching_figure
from pathlib import Path
import torch_pruning as tp
from src.loftr.backbone import build_backbone
from src.loftr.backbone.resnet_fpn import BasicBlock

from einops.einops import rearrange

from src.loftr.utils.position_encoding import PositionEncodingSine
from src.loftr.loftr_module import LocalFeatureTransformer, FinePreprocess
from src.loftr.utils.coarse_matching import CoarseMatching
from src.loftr.utils.fine_matching import FineMatching
from src.loftr import LoFTR, default_cfg

In [None]:
# LoFTR's default config (dual-softmax coarse matching) is shared by the indoor
# and outdoor models. A deep copy is edited so the library default stays
# untouched; other defaults such as thr or coarse_match_type can be changed here.
_default_cfg = deepcopy(default_cfg)
_default_cfg['coarse'].update(temp_bug_fix=True)  # set this to False when loading an old checkpoint
original_backbone = build_backbone(_default_cfg)

In [None]:
from collections import OrderedDict

# Split the backbone weights out of the full LoFTR checkpoint: keys prefixed
# with 'matcher.backbone.' move (prefix stripped) into backbone_weight, and
# everything else stays behind in remain_weight.
# map_location="cpu" lets a checkpoint saved from a GPU load on any machine;
# original_backbone is still on the CPU here, so nothing else changes.
_BACKBONE_PREFIX = 'matcher.backbone.'
remain_weight = torch.load("./weights/indoor_ds_new.ckpt", map_location="cpu")['state_dict']
backbone_weight = OrderedDict()
for key in list(remain_weight.keys()):  # snapshot the keys: entries are popped below
    if key.startswith(_BACKBONE_PREFIX):
        backbone_weight[key[len(_BACKBONE_PREFIX):]] = remain_weight.pop(key)
original_backbone.load_state_dict(backbone_weight)
# The student starts as an exact copy of the teacher and is pruned later.
new_backbone = deepcopy(original_backbone)

In [None]:
def prune_model(model, block_prune_probs=(0.05, 0.05, 0.1, 0.1, 0.1, 0.1)):
    """Run one pruning round over every BasicBlock of ``model``, in place.

    For each BasicBlock, in ``model.modules()`` order, the filters of ``conv1``
    and then ``conv2`` are pruned by torch_pruning's L1 strategy. The strategy
    selects the given fraction of filters with the smallest L1 norm, and the
    dependency graph turns each selection into a plan that also updates the
    dependent layers.

    Args:
        model: backbone to prune. It is moved to the CPU first.
        block_prune_probs: pruning fraction for each BasicBlock, in module
            order. It needs at least one entry per BasicBlock; extra entries are
            ignored. The default matches the six blocks the original code
            assumed.

    Returns:
        The same ``model`` object, pruned and on the CPU.

    Raises:
        ValueError: if ``model`` has more BasicBlocks than ``block_prune_probs``
            has entries.

    NOTE(review): pruning changes parameter shapes, so any optimizer built
    before this call must be pointed at the new parameters afterwards.
    """
    model.cpu()
    # The dependency graph is built from an example input. Its shape
    # (1 image, 1 channel, 480x640) matches the grayscale crops used for training.
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 1, 480, 640))

    def prune_conv(conv, amount=0.2):
        # Pick the `amount` fraction of conv's filters with the smallest L1 norm,
        # then execute the plan that removes them from conv and its dependents.
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    probs = tuple(block_prune_probs)
    blocks = [m for m in model.modules() if isinstance(m, BasicBlock)]
    if len(blocks) > len(probs):
        raise ValueError(
            f"model has {len(blocks)} BasicBlocks but only {len(probs)} "
            "pruning fractions were given"
        )
    for block, amount in zip(blocks, probs):
        prune_conv(block.conv1, amount)
        prune_conv(block.conv2, amount)
    return model

# Augmentation applied to every sampled training image: a random 480x640 crop
# followed by a horizontal flip with probability 0.5.
transform = transforms.Compose([
    transforms.RandomCrop((480, 640)),
    transforms.RandomHorizontalFlip(p=0.5),
])


def get_random_img(img_list):
    """Load one uniformly-random image from ``img_list`` as an augmented tensor.

    The image is read as grayscale, scaled to [0, 1], given leading batch and
    channel axes (shape (1, 1, H, W)), and passed through ``transform``.
    RandomCrop is used without padding, so images smaller than 480x640 are
    expected to make it fail.

    Args:
        img_list: sequence of image file paths.

    Returns:
        A float tensor of shape (1, 1, 480, 640).

    Raises:
        ValueError: if ``img_list`` is empty.
        FileNotFoundError: if OpenCV cannot read the chosen file.
    """
    if len(img_list) == 0:
        raise ValueError("img_list is empty; check the dataset paths used to build it")
    path = img_list[np.random.randint(0, len(img_list))]
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"cv2 could not read image: {path}")
    img = torch.from_numpy(img)[None][None] / 255.
    return transform(img)

In [None]:
#prune model
#prune_model(new_backbone)
#new_backbone = torch.load('./temp_backbone/seventh_prune.pth')

In [None]:
# Build the list of retraining images. Download the ScanNet-1500 and
# MegaDepth-1500 test sets first, then point the directories below at them.
img_list = [
    str(path)
    for root in (
        '/your_location_to/ScanNet/scannet_test_1500',
        '/your_location_to/Megadepth/megadepth_test_1500',
    )
    for path in Path(root).rglob('*.jpg')
]
# rglob on a missing directory yields nothing rather than raising, so fail
# here instead of deep inside the training loop.
if not img_list:
    raise FileNotFoundError(
        "no *.jpg images found; update the ScanNet/MegaDepth dataset paths above"
    )

# Knowledge distillation: the unpruned backbone is a frozen teacher.
for param in original_backbone.parameters():
    param.requires_grad = False
original_backbone = original_backbone.cuda()
original_backbone.eval()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(new_backbone.parameters(), lr=0.0003, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# `verbose` is left at its default: recent PyTorch releases deprecate that argument.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95, last_epoch=-1)

# Retraining schedule applied after every prune round.
num_epoch, num_step, batch_size = 4, 300, 8  # epochs per round, batches per epoch, images per batch

# Snapshot of the student's weights before any pruning. iterative_pruning
# assigns its own local best_model_wts, so this global is never read by it.
best_model_wts = deepcopy(new_backbone.state_dict())
def _rebind_optimizer(optimizer, model):
    """Point ``optimizer``'s single param group at ``model``'s current parameters.

    After pruning, the parameters are either new objects (so the optimizer would
    keep updating tensors that are no longer in the model) or have new shapes
    (so Adam's moment buffers no longer match). Rebinding the group and clearing
    the per-parameter state handles both cases. Group hyper-parameters such as
    lr are kept, so a scheduler attached to ``optimizer`` keeps working.
    """
    if len(optimizer.param_groups) != 1:
        raise ValueError("expected an optimizer with a single parameter group")
    optimizer.param_groups[0]['params'] = list(model.parameters())
    optimizer.state.clear()


def _random_batch():
    """Concatenate ``batch_size`` random training crops from ``img_list`` along dim 0."""
    return torch.cat([get_random_img(img_list) for _ in range(batch_size)], dim=0)


def _save_backbone(model, prune_idx):
    """Save the whole module (.pth) and its state dict (.dict) for round ``prune_idx``."""
    stem = './temp_backbone/backbones' + str(prune_idx)
    torch.save(model, stem + '.pth')
    torch.save(model.state_dict(), stem + '.dict')


def iterative_pruning(prune_time, criterion, optimizer, scheduler, original_backbone, new_backbone):
    """Prune ``new_backbone`` repeatedly, retraining it by distillation after each round.

    Each round:
      1. Prunes the student once with ``prune_model``.
      2. Retrains it for up to ``num_epoch`` epochs of ``num_step`` random
         batches (module-level settings), so that both of its outputs match the
         frozen teacher's under ``criterion``.
      3. Stops the round early once an epoch's mean loss drops below 0.07.
      4. If the last epoch loss is at most 0.1, saves the student to
         ./temp_backbone/backbones{round}.pth (whole module) and .dict (state
         dict) and continues with the next round. Otherwise pruning stops.

    Args:
        prune_time: maximum number of prune rounds.
        criterion: loss between student and teacher outputs.
        optimizer: single-group optimizer over ``new_backbone``'s parameters.
            It is re-pointed at the pruned parameters every round.
        scheduler: learning-rate scheduler on ``optimizer``. It is stepped after
            every epoch that does not end the round early.
        original_backbone: teacher. Presumably already on CUDA, frozen, and in
            eval mode; the caller sets this up.
        new_backbone: student. It is pruned in place and moved to CUDA.

    Returns:
        The student with the lowest-epoch-loss weights of its final round loaded.
    """
    base_lr = optimizer.defaults['lr']
    best_model_wts = None
    for prune_idx in range(prune_time):
        print()
        print('**************************')
        print(f"starting the {prune_idx} prune")
        prune_model(new_backbone)
        new_backbone = new_backbone.cuda()
        # Must follow pruning and the device move: it rebinds to the final tensors.
        _rebind_optimizer(optimizer, new_backbone)
        new_backbone.train()
        best_loss = float('inf')
        best_model_wts = None  # weights of earlier rounds have other shapes
        num_param = sum(p.numel() for p in new_backbone.parameters())
        print(f"total parameters for backbone is now {num_param} after pruning, now start retrain")

        for _ in range(num_epoch):
            print(f"now learning rate becomes {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0
            for step in range(num_step):
                img = _random_batch().cuda()
                optimizer.zero_grad()
                with torch.no_grad():  # teacher outputs are fixed soft labels
                    layer2_label, layer4_label = original_backbone(img)
                layer2_student, layer4_student = new_backbone(img)
                total_loss = criterion(layer2_student, layer2_label) + criterion(layer4_student, layer4_label)
                total_loss.backward()
                optimizer.step()
                if step % 10 == 0:
                    print('step' + str(step) + ' loss is {:.4f} '.format(total_loss.item()))
                running_loss += total_loss.item() * batch_size
            epoch_loss = running_loss / (num_step * batch_size)
            print('*******epoch loss is {:.4f} '.format(epoch_loss))

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = deepcopy(new_backbone.state_dict())
            if epoch_loss < 0.07:  # good enough, move on to the next prune round
                break
            scheduler.step()  # decay the learning rate for the next epoch

        # Reset the learning rate for the next round. NOTE(review): this grows
        # geometrically with the round index (base, base, 2*base, 4*base, ...);
        # kept from the original schedule, so confirm it is intended.
        for g in optimizer.param_groups:
            g['lr'] = base_lr * (2 ** prune_idx)
        if not epoch_loss <= 0.1:  # also stops on a NaN loss
            print(f"can only prune {prune_idx+1} time, cannot continue")
            if best_model_wts is not None:
                new_backbone.load_state_dict(best_model_wts)
            return new_backbone
        _save_backbone(new_backbone, prune_idx)
    if best_model_wts is not None:
        new_backbone.load_state_dict(best_model_wts)
    return new_backbone

In [None]:
# Run up to 10 prune-and-distill rounds; the returned student replaces new_backbone.
new_backbone = iterative_pruning(
    prune_time=10,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    original_backbone=original_backbone,
    new_backbone=new_backbone,
)

# The following blocks of code are only for visualization; you don't need to run them

In [None]:
# Print each BasicBlock of the original (teacher) backbone, followed by a separator.
for block in filter(lambda mod: isinstance(mod, BasicBlock), original_backbone.modules()):
    print(block)
    print('............................')

In [None]:
# Total parameter count of the original (teacher) backbone.
pytorch_total_params = 0
for param in original_backbone.parameters():
    pytorch_total_params += param.numel()
pytorch_total_params

In [None]:
# Total parameter count of the pruned (student) backbone. The log below
# records the counts the author observed after each prune round.
pytorch_total_params = 0
for param in new_backbone.parameters():
    pytorch_total_params += param.numel()
pytorch_total_params
#4882176 for first prune
#4111846 for second prune
#3528448 for third prune
#3258448 for fourth prune
#3103365 for fifth prune
#2772528 for sixth prune
#2520887 for seventh prune
#2338495 for eighth prune
#2193625 for ninth prune
#2080770 for tenth prune

In [None]:
# Total parameter count of a full LoFTR matcher. `matcher` is not created
# anywhere in the visible notebook, so referencing it directly raised
# NameError. Build one from the same config when it is missing (weights are
# irrelevant for counting). To count the pruned pipeline instead, assign
# matcher.backbone = new_backbone before this line.
if 'matcher' not in globals():
    matcher = LoFTR(_default_cfg)
pytorch_total_params1 = sum(p.numel() for p in matcher.parameters())
pytorch_total_params1

In [None]:
# Dump every parameter of the original backbone: name, values, separator.
for param_name, tensor in original_backbone.named_parameters():
    print(param_name, tensor, '..................', sep='\n')

In [None]:
# Dump every parameter of the pruned backbone: name, values, separator.
for param_name, tensor in new_backbone.named_parameters():
    print(param_name, tensor, '..................', sep='\n')

In [None]:
# Display the original (teacher) backbone's module structure as the cell output.
original_backbone

In [None]:
# Display the pruned (student) backbone's module structure as the cell output.
new_backbone