# The following code blocks are for pruning.
### Unlike the original pruning experiment, all training data here is randomly generated.

In [1]:
import os
from copy import deepcopy

import torch
import torch.nn as nn
from torchvision import transforms
import cv2
import numpy as np
import matplotlib.cm as cm
from src.utils.plotting import make_matching_figure
from pathlib import Path
import torch_pruning as tp
from src.loftr.backbone import build_backbone
from src.loftr.backbone.resnet_fpn import BasicBlock

from einops.einops import rearrange

from src.loftr.utils.position_encoding import PositionEncodingSine
from src.loftr.loftr_module import LocalFeatureTransformer, FinePreprocess
from src.loftr.utils.coarse_matching import CoarseMatching
from src.loftr.utils.fine_matching import FineMatching
from src.loftr import LoFTR, default_cfg

In [2]:
# Work on a private deep copy of LoFTR's default config so the library-level
# default is never mutated. That config uses dual-softmax coarse matching and
# is shared by the indoor and outdoor models; values such as thr and
# coarse_match_type can be tweaked on this copy.
_default_cfg = deepcopy(default_cfg)
_default_cfg['coarse'].update(temp_bug_fix=True)  # use False with the old ckpt
original_backbone = build_backbone(_default_cfg)

In [3]:
from collections import OrderedDict

# The checkpoint stores the whole LoFTR matcher. Keep only the tensors whose
# keys start with 'matcher.backbone.', strip that prefix so the names match the
# modules built by build_backbone(), and pop them so remain_weight is left
# holding only the non-backbone entries.
remain_weight = torch.load("./weights/indoor_ds_new.ckpt")['state_dict']
backbone_weight = OrderedDict(
    (key[len('matcher.backbone.'):], remain_weight.pop(key))
    for key in list(remain_weight)
    if key.startswith('matcher.backbone.')
)
original_backbone.load_state_dict(backbone_weight)
# The student starts out as an exact copy of the teacher.
new_backbone = deepcopy(original_backbone)

In [4]:
def prune_model(model, block_prune_probs=(0.05, 0.05, 0.1, 0.1, 0.1, 0.1),
                example_input=None):
    """Structurally prune conv1/conv2 of every BasicBlock in ``model``, in place.

    Filters are ranked by L1 norm (``tp.strategy.L1Strategy``) and removed via a
    torch_pruning dependency graph, so layers that depend on a pruned conv are
    shrunk along with it.

    Args:
        model: network to prune. It is moved to the CPU by this call.
        block_prune_probs: fraction of output channels to remove from both
            convs of each BasicBlock, in ``model.modules()`` order. The default
            has six entries (presumably one per BasicBlock of the LoFTR
            backbone — confirm if the architecture changes).
        example_input: dummy input used to trace layer dependencies; defaults
            to ``torch.randn(1, 1, 480, 640)``.

    Returns:
        The same ``model`` object, pruned.

    Raises:
        ValueError: if ``model`` has more BasicBlocks than ``block_prune_probs``
            has entries. This is checked before anything is pruned, so the
            model is never left half-pruned.
    """
    model.cpu()
    if example_input is None:
        example_input = torch.randn(1, 1, 480, 640)
    DG = tp.DependencyGraph().build_dependency(model, example_input)
    strategy = tp.strategy.L1Strategy()

    def prune_conv(conv, amount=0.2):
        # Indices of the `amount` fraction of filters with the smallest L1 norm.
        pruning_index = strategy(conv.weight, amount=amount)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    blocks = [m for m in model.modules() if isinstance(m, BasicBlock)]
    if len(blocks) > len(block_prune_probs):
        raise ValueError(
            f"model has {len(blocks)} BasicBlocks but only "
            f"{len(block_prune_probs)} pruning ratios were given"
        )
    for block, amount in zip(blocks, block_prune_probs):
        prune_conv(block.conv1, amount)
        prune_conv(block.conv2, amount)
    return model

# Augmentation used by get_random_img(): take a random 480x640 crop, then
# flip left-right with probability 0.5.
transform = transforms.Compose(
    [
        transforms.RandomCrop(size=(480, 640)),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
)
    
def get_random_img(img_list):
    """Load a randomly chosen image from ``img_list`` as an augmented tensor.

    The image is read as grayscale, scaled to [0, 1], given leading batch and
    channel dimensions and passed through the module-level ``transform``
    (random 480x640 crop + random horizontal flip).

    Args:
        img_list: non-empty sequence of image file paths.

    Returns:
        Float tensor of shape (1, 1, 480, 640).

    Raises:
        ValueError: if ``img_list`` is empty.
        FileNotFoundError: if OpenCV cannot read the chosen file.
        The crop step may also raise if the image is smaller than 480x640.
    """
    if len(img_list) == 0:
        raise ValueError("img_list is empty; no images to sample from")
    path = img_list[np.random.randint(0, len(img_list))]
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {path}")
    # (H, W) -> (1, 1, H, W); divide the 8-bit grayscale values into [0, 1].
    img = torch.from_numpy(img)[None][None] / 255.
    return transform(img)

In [5]:
# Pruning is skipped here (prune_model(new_backbone) would prune afresh).
# Instead, reload a backbone saved earlier as a whole pickled module,
# presumably pruned but not yet retrained.
# NOTE(review): PyTorch releases whose torch.load defaults to weights_only=True
# may need weights_only=False to unpickle a full module — confirm for the
# installed version.
new_backbone = torch.load(Path('./temp_backbone') / 'untrain.pth')

In [6]:
# Retraining image list: every .jpg under the ScanNet and MegaDepth
# "test_1500" folders. It is sampled by get_random_img().
img_list = [
    str(path)
    for root in (
        '/home/cvte-vm/Datasets/ScanNet/scannet_test_1500',
        '/home/cvte-vm/Datasets/Megadepth/megadepth_test_1500',
    )
    for path in Path(root).rglob('*.jpg')
]

# Knowledge distillation: the original backbone is the frozen teacher.
original_backbone.requires_grad_(False)
original_backbone = original_backbone.cuda()
original_backbone.eval()

criterion = nn.MSELoss()
# Move the student to the GPU *before* building the optimizer, as the
# nn.Module.cuda() documentation recommends, so the optimizer tracks the
# parameters that will actually be trained.
new_backbone = new_backbone.cuda()
optimizer = torch.optim.Adam(new_backbone.parameters(), lr=0.02, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# `verbose` stays at its default (it was previously passed as False, the same
# value); the argument is deprecated in recent PyTorch releases.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95, last_epoch=-1)

num_epoch = 4    # max retraining epochs per round
num_step = 300   # optimizer steps per epoch
batch_size = 8   # images per step

best_model_wts = deepcopy(new_backbone.state_dict())
def iterative_pruning(prune_time, criterion, optimizer, scheduler, original_backbone, new_backbone):
    """Retrain ``new_backbone`` to mimic the frozen ``original_backbone``, round by round.

    For each of ``prune_time`` rounds (the pruning call itself is currently
    disabled, so a round only retrains):

    * the student is moved to the GPU, set to train mode, and its parameter
      count is printed;
    * it is trained for up to ``num_epoch`` epochs of ``num_step`` steps. Each
      step uses a batch of ``batch_size`` random-noise images of shape
      (1, 480, 640), with pixel values k/255 for integer k in [0, 255]. The loss
      is ``criterion`` summed over the two outputs of the backbones;
    * the weights of the epoch with the lowest loss are remembered. Training
      stops early once an epoch's loss drops below 0.07. Otherwise
      ``scheduler.step()`` is called after the epoch;
    * the learning rate is then set to ``0.0003 * 2**round``. If the last
      epoch loss is above 0.1 (or NaN), the best weights are restored and the
      student is returned immediately. Otherwise a checkpoint
      ``backbones{round}.pth`` / ``.dict`` is written.

    Reads the module-level settings ``num_epoch``, ``num_step`` and
    ``batch_size``.

    Returns:
        ``new_backbone`` with the best weights of the last round loaded.
    """
    ckpt_prefix = '/home/cvte-vm/Deep_Feature_Extract/LoFTR/temp_backbone/backbones'

    def save_checkpoint(round_idx):
        # Whole pickled module (.pth) plus its state_dict (.dict) for this round.
        torch.save(new_backbone, ckpt_prefix + str(round_idx) + '.pth')
        torch.save(new_backbone.state_dict(), ckpt_prefix + str(round_idx) + '.dict')

    # Bind up front. This name is assigned inside the loop, which makes it local
    # to the function. Without this line the final load_state_dict raised
    # UnboundLocalError whenever no epoch improved (prune_time == 0, NaN loss).
    best_model_wts = deepcopy(new_backbone.state_dict())

    for round_idx in range(prune_time):
        print()
        print('**************************')
        print(f"starting the {round_idx} prune")
        # prune_model(new_backbone)  # pruning step currently disabled
        new_backbone = new_backbone.cuda()
        new_backbone.train()
        epoch_loss_old = 100
        epoch_loss = float('inf')  # defined even if num_epoch == 0
        num_param = sum(p.numel() for p in new_backbone.parameters())
        print(f"total parameters for backbone is now {num_param} after pruning, now start retrain")

        for _ in range(num_epoch):
            print(f"now learning rate becomes {optimizer.param_groups[0]['lr']}")
            running_loss = 0.0
            for step in range(num_step):
                # Random-noise batch: integer pixels 0..255 scaled into [0, 1].
                img = (torch.randint(0, 256, (batch_size, 1, 480, 640)) / 255.0).cuda()
                optimizer.zero_grad()
                with torch.no_grad():  # teacher outputs are fixed targets
                    layer2_label, layer4_label = original_backbone(img)
                layer2_student, layer4_student = new_backbone(img)
                total_loss = (criterion(layer2_student, layer2_label)
                              + criterion(layer4_student, layer4_label))
                total_loss.backward()
                optimizer.step()

                loss_value = total_loss.item()
                if step % 10 == 0:
                    print('step' + str(step) + ' loss is {:.4f} '.format(loss_value))
                running_loss += loss_value * batch_size

            epoch_loss = running_loss / (num_step * batch_size)
            print('*******epoch loss is {:.4f} '.format(epoch_loss))

            if epoch_loss < epoch_loss_old:  # remember the best epoch so far
                epoch_loss_old = epoch_loss
                best_model_wts = deepcopy(new_backbone.state_dict())
            if epoch_loss < 0.07:  # good enough; the checkpoint is saved below
                break
            scheduler.step()  # decay the learning rate for the next epoch

        # Reset the learning rate for the next round.
        for g in optimizer.param_groups:
            g['lr'] = 0.0003 * (2 ** round_idx)
        # `not <=` (rather than `>`) also treats a NaN loss as a failure.
        if not epoch_loss <= 0.1:
            print(f"can only prune {round_idx+1} time, cannot continue")
            new_backbone.load_state_dict(best_model_wts)
            return new_backbone
        save_checkpoint(round_idx)

    new_backbone.load_state_dict(best_model_wts)
    return new_backbone

In [7]:
# Run a single retraining round and keep the returned student.
new_backbone = iterative_pruning(
    prune_time=1,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    original_backbone=original_backbone,
    new_backbone=new_backbone,
)


**************************
starting the 0 prune
total parameters for backbone is now 2520887 after pruning, now start retrain
now learning rate becomes 0.02
step0 loss is 0.7674 
step10 loss is 0.1481 
step20 loss is 0.0814 
step30 loss is 0.0668 
step40 loss is 0.0604 
step50 loss is 0.0558 
step60 loss is 0.0525 
step70 loss is 0.0501 
step80 loss is 0.0476 
step90 loss is 0.0451 
step100 loss is 0.0428 
step110 loss is 0.0411 
step120 loss is 0.0391 
step130 loss is 0.0368 
step140 loss is 0.0353 
step150 loss is 0.0333 
step160 loss is 0.0320 
step170 loss is 0.0303 
step180 loss is 0.0289 
step190 loss is 0.0272 
step200 loss is 0.0260 
step210 loss is 0.0253 
step220 loss is 0.0235 
step230 loss is 0.0226 
step240 loss is 0.0215 
step250 loss is 0.0208 
step260 loss is 0.0199 
step270 loss is 0.0189 
step280 loss is 0.0187 
step290 loss is 0.0179 
*******epoch loss is 0.0615 


In [8]:
# Persist the retrained student twice: as a whole pickled module (.pth) and as
# a plain state_dict (.dict).
torch.save(obj=new_backbone, f='/home/cvte-vm/Deep_Feature_Extract/LoFTR/temp_backbone/random2.pth')
torch.save(obj=new_backbone.state_dict(), f='/home/cvte-vm/Deep_Feature_Extract/LoFTR/temp_backbone/random2.dict')

# The following code blocks are only for visualization; you don't need to run them.

In [None]:
# Print each BasicBlock of the teacher backbone, followed by a dotted separator.
for block in filter(lambda mod: isinstance(mod, BasicBlock), original_backbone.modules()):
    print(block, '............................', sep='\n')

In [None]:
# Parameter count of the original (teacher) backbone.
pytorch_total_params = sum(map(torch.numel, original_backbone.parameters()))
pytorch_total_params

In [11]:
# Parameter count of the student backbone (new_backbone).
pytorch_total_params = sum(map(torch.numel, new_backbone.parameters()))
pytorch_total_params
# Counts recorded after each successive prune:
#   1st: 4882176    2nd: 4111846    3rd: 3528448    4th: 3258448
#   5th: 3103365    6th: 2772528    7th: 2520887    8th: 2338495
#   9th: 2193625    10th: 2080770

2520887

In [None]:
# Parameter count of the complete LoFTR matcher, not just the backbone.
# `matcher` is not defined anywhere in this notebook, so the cell used to fail
# with NameError. Build an (untrained) matcher from the same config instead;
# the weights do not affect the count.
matcher = LoFTR(config=_default_cfg)
pytorch_total_params1 = sum(p.numel() for p in matcher.parameters())
pytorch_total_params1

In [None]:
# Dump every named parameter of the original backbone, followed by a dotted separator.
for param_name, tensor in original_backbone.named_parameters():
    print(param_name, tensor, '..................', sep='\n')

In [None]:
# Dump every named parameter of the student backbone, followed by a dotted separator.
for param_name, tensor in new_backbone.named_parameters():
    print(param_name, tensor, '..................', sep='\n')

In [None]:
# Display the module structure of the original (teacher) backbone.
original_backbone

In [None]:
# Display the module structure of the student backbone (new_backbone).
new_backbone

In [12]:
# Scratch cell: sample a random tensor of shape (1, 2, 3); unrelated to pruning.
torch.rand(1,2,3)

tensor([[[0.1381, 0.0192, 0.1413],
         [0.9916, 0.0400, 0.6787]]])

In [4]:
# Load and display a backbone saved as a whole pickled module; presumably the
# one retrained on random-noise inputs (per the file name).
random_backbone = torch.load(Path('./temp_backbone') / 'random.pth')
random_backbone

ResNetFPN_8_2(
  (conv1): Conv2d(1, 73, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(73, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(96, 73, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(73, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(96, 73, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2

In [5]:
# Load and display a backbone saved as a whole pickled module; presumably the
# result of the seventh prune round (per the file name).
seventh_backbone = torch.load(Path('./temp_backbone') / 'seventh_prune.pth')
seventh_backbone

ResNetFPN_8_2(
  (conv1): Conv2d(1, 73, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(73, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(96, 73, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(73, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(96, 73, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2

In [6]:
# Load and display the backbone stored before retraining (per the file name),
# saved as a whole pickled module.
untrain_backbone = torch.load(Path('./temp_backbone') / 'untrain.pth')
untrain_backbone

ResNetFPN_8_2(
  (conv1): Conv2d(1, 73, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(73, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(96, 73, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(73, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(96, 73, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm2

In [5]:
# Scratch check of the random-noise batch shape: (8, 1, 480, 640).
# NOTE(review): randint's upper bound is exclusive, so 255 here yields 0..254,
# whereas the training loop uses 256 (0..255). The shape is the same either way.
(torch.randint(0,255,(8,1,480,640))/255.0).shape

torch.Size([8, 1, 480, 640])

In [8]:
# Scratch cell: a random 4x4 integer tensor with entries 0 or 1.
torch.randint(0,2,(4,4))

tensor([[1, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 0, 0, 1],
        [0, 0, 1, 1]])

In [3]:
# Scratch arithmetic: 2**6 / 127.
2**6/127

0.5039370078740157