In [1]:
import numpy as np
import os
from retinanet import model
from retinanet import coco_eval
from retinanet.dataloader import CocoDataset_inOrder,rehearsal_DataSet, collater, Resizer, AspectRatioBasedSampler, Augmenter, Normalizer
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import collections
import torch
# Root directory that contains the 'DataSet' and 'model' sub-directories used below.
# Set the CL_ROOT environment variable to run on another machine; the default
# is the original hard-coded location.
root_path = os.environ.get('CL_ROOT', '/home/deeplab307/Documents/Anaconda/Shiang/CL/')
method = 'w_distillation'  # path component under <root_path>/model for checkpoints
data_split = '15+1'        # passed to the dataset as data_split; also a checkpoint sub-directory
start_round = 1            # passed to the dataset as start_round; selects the round<N> checkpoint directory
batch_size = 1             # not referenced by the cells visible in this notebook

checkpoint_epoch = 50      # epoch number embedded in the checkpoint file name that is loaded

def checkDir(path):
    """Make sure the directory ``path`` exists, creating it if necessary.

    Missing parent directories are created too, and an already existing
    directory is left untouched. ``exist_ok=True`` removes the window between
    "check" and "create" in which another process could create the directory
    and make the call fail.
    """
    os.makedirs(path, exist_ok=True)
def get_checkpoint_path(method, now_round, epoch, data_split ="None"):
    """Return the checkpoint file path for a given method, round and epoch.

    The returned path has the form
    ``<root_path>/model/<method>/round<now_round>/<data_split>/voc_retinanet_<epoch>_checkpoint.pt``.
    Before returning, the round directory and the data-split directory are
    each passed to ``checkDir``.
    """
    round_dir = os.path.join(root_path, 'model', method, 'round{}'.format(now_round))
    split_dir = os.path.join(round_dir, data_split)
    for directory in (round_dir, split_dir):
        checkDir(directory)
    file_name = 'voc_retinanet_{}_checkpoint.pt'.format(epoch)
    return os.path.join(split_dir, file_name)


def readCheckpoint(method, now_round, epoch, data_split, retinanet, optimizer = None, scheduler = None):
    """Restore model (and optionally optimizer / scheduler) state from a checkpoint.

    The checkpoint file is located with ``get_checkpoint_path(method, now_round,
    epoch, data_split)`` and must contain the key ``'model_state_dict'``; the keys
    ``'optimizer_state_dict'`` and ``'scheduler_state_dict'`` are only read when
    the corresponding argument is given.

    Args:
        method, now_round, epoch, data_split: forwarded to ``get_checkpoint_path``.
        retinanet: object whose ``load_state_dict`` receives the model state.
        optimizer: optional; its ``load_state_dict`` receives the optimizer state.
        scheduler: optional; its ``load_state_dict`` receives the scheduler state.

    Returns:
        None. The passed objects are updated in place.
    """
    print('readcheckpoint at Round{} Epoch{}'.format(now_round, epoch))
    checkpoint_path = get_checkpoint_path(method, now_round, epoch, data_split)
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    prev_checkpoint = torch.load(checkpoint_path)
    retinanet.load_state_dict(prev_checkpoint['model_state_dict'])
    # Identity comparison is the correct test for "argument not supplied".
    if optimizer is not None:
        optimizer.load_state_dict(prev_checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(prev_checkpoint['scheduler_state_dict'])
    
# Build the VOC2012 training dataset for the configured data split and round.
voc_root = os.path.join(root_path, 'DataSet', 'VOC2012')
train_transform = transforms.Compose([Normalizer(), Resizer()])
dataset_train = CocoDataset_inOrder(
    voc_root,
    set_name='TrainVoc2012',
    dataset='voc',
    transform=train_transform,
    data_split=data_split,
    start_round=start_round,
)

# Create the model with as many classes as the dataset reports and move it to the GPU.
retinanet = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True)
retinanet.cuda()

# Replace the freshly constructed weights with the saved checkpoint for this round / epoch.
readCheckpoint(method, start_round, checkpoint_epoch, data_split, retinanet)

loading annotations into memory...
Done (t=0.09s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
{'id': [[1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 15, 16, 17, 18, 19], [6], [20], [14], [12], [13]], 'name': [['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person'], ['train'], ['sheep'], ['sofa'], ['pottedplant'], ['tvmonitor']]}
dataloader class_num = 15
readcheckpoint at Round1 Epoch50


In [2]:
from retinanet import MAS

# Compute MAS parameter importances for the checkpointed model over the training
# dataset, using the implementation packaged in retinanet.MAS.
# NOTE: a later cell defines a class that is also named `MAS`; once that cell has
# run, the name `MAS` no longer refers to this module and `MAS.MAS` stops working.
mas = MAS.MAS(retinanet, dataset_train)
mas.calculate_importance()

Computing MAS


In [5]:
# Scratch check: printing dict.keys() shows a dict_keys view of the keys.
b = {"123":1,"456":10}
print(b.keys())

dict_keys(['123', '456'])


In [None]:
import pickle 

# Load previously saved MAS importances for the configured method / round / split.
# The location is derived from the notebook configuration rather than repeated
# as a literal; with the default configuration it is the same file as before.
mas_pickle_path = os.path.join(root_path, 'model', method, 'round{}'.format(start_round), data_split, 'MAS.pickle')
# pickle files must be opened in binary mode.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
with open(mas_pickle_path, 'rb') as f:
    precision = pickle.load(f)

In [3]:
# Inspect the computed importances (a dict mapping parameter name -> tensor).
# NOTE: this displays every tensor in the dict, so the output is very long.
mas.precision_matrices

{'conv1.weight': tensor([[[[0.2647, 0.2653, 0.2660,  ..., 0.2661, 0.2652, 0.2647],
           [0.2662, 0.2671, 0.2679,  ..., 0.2673, 0.2659, 0.2650],
           [0.2675, 0.2690, 0.2702,  ..., 0.2697, 0.2677, 0.2664],
           ...,
           [0.2686, 0.2700, 0.2720,  ..., 0.2760, 0.2747, 0.2722],
           [0.2672, 0.2681, 0.2688,  ..., 0.2725, 0.2729, 0.2718],
           [0.2658, 0.2662, 0.2663,  ..., 0.2685, 0.2692, 0.2693]],
 
          [[0.2637, 0.2644, 0.2650,  ..., 0.2652, 0.2641, 0.2635],
           [0.2648, 0.2658, 0.2667,  ..., 0.2661, 0.2645, 0.2634],
           [0.2660, 0.2673, 0.2691,  ..., 0.2688, 0.2664, 0.2647],
           ...,
           [0.2662, 0.2674, 0.2699,  ..., 0.2748, 0.2733, 0.2704],
           [0.2650, 0.2658, 0.2669,  ..., 0.2709, 0.2711, 0.2699],
           [0.2639, 0.2642, 0.2646,  ..., 0.2667, 0.2672, 0.2672]],
 
          [[0.2530, 0.2537, 0.2541,  ..., 0.2550, 0.2548, 0.2545],
           [0.2539, 0.2545, 0.2552,  ..., 0.2553, 0.2545, 0.2540],
        

In [16]:
import torch.nn as nn
import pickle
import os
import torch
class MAS(object):
    """Per-parameter importance weights and the quadratic penalty built from them.

    "MAS" presumably stands for Memory Aware Synapses — TODO confirm. The
    importance of a parameter is the mean absolute gradient of the L2 norm of
    the model's classification output with respect to that parameter, averaged
    over all samples of ``dataloader``.

    Attributes:
        model: the network whose parameters are scored.
        dataloader: sized iterable of samples; each sample is indexed with
            ``['img']`` and the result is permuted with ``(2, 0, 1)``
            (presumably an HWC image tensor — TODO confirm).
        precision_matrices: dict ``parameter name -> importance tensor``; only
            present after ``calculate_importance`` or ``load_importance`` ran.
    """

    def __init__(self, model: nn.Module, dataloader):
        self.model = model
        self.dataloader = dataloader

    def load_importance(self, path):
        """Load importances from ``<path>/MAS.pickle`` into ``precision_matrices``."""
        pickle_name = "MAS.pickle"
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(path, pickle_name), "rb") as f:
            self.precision_matrices = pickle.load(f)

    def calculate_importance(self):
        """Compute the importances and store them in ``self.precision_matrices``.

        Parameters whose name contains ``"prev_model"`` are not scored. While
        running, ``model.distill_feature`` is forced to ``True``; its previous
        value is restored afterwards, also when an exception is raised. The
        model is switched to ``train()`` with ``freeze_bn()`` applied and is
        left in that state.
        """
        print('Computing MAS')

        # One zero tensor per scored parameter (same shape / dtype / device).
        precision_matrices = {
            n: torch.zeros_like(p)
            for n, p in self.model.named_parameters()
            if "prev_model" not in n
        }

        # Run on whatever device the model lives on. For a model without
        # parameters fall back to the first GPU, which is what the code
        # assumed before.
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda', 0)

        num_data = len(self.dataloader)
        origin_status = self.model.distill_feature
        self.model.distill_feature = True
        try:
            self.model.train()
            self.model.freeze_bn()
            for data in self.dataloader:
                if device.type == 'cuda':
                    # Keep the model's GPU as the current CUDA device during the pass.
                    with torch.cuda.device(device):
                        self._accumulate_importance(data, device, precision_matrices, num_data)
                else:
                    self._accumulate_importance(data, device, precision_matrices, num_data)
        finally:
            self.model.distill_feature = origin_status
        self.precision_matrices = precision_matrices

    def _accumulate_importance(self, data, device, precision_matrices, num_data):
        """Add one sample's ``|grad| / num_data`` to ``precision_matrices`` in place."""
        self.model.zero_grad()

        img = data['img'].permute(2, 0, 1).to(device).float().unsqueeze(dim=0)
        features, regression, classification = self.model(img)

        # Only the classification output drives the importance. The earlier
        # code also formed classification + regression but overwrote it with
        # the classification norm before calling backward().
        output = torch.norm(classification)
        output.backward()

        for n, p in self.model.named_parameters():
            # Skip parameters that are not scored (e.g. "prev_model" ones) and
            # parameters that received no gradient.
            if n in precision_matrices and p.grad is not None:
                precision_matrices[n] += p.grad.detach().abs() / num_data

    def penalty(self, model: nn.Module):
        """Return the importance-weighted squared distance to ``model.prev_model``.

        For every parameter of ``model`` that has an importance tensor, the
        term ``importance * (p - old_p) ** 2`` is summed, where ``old_p`` is
        the parameter of the same name in ``model.prev_model``. Parameters
        whose name contains ``"classificationModel.output"`` or
        ``"prev_model"`` are skipped.

        Returns:
            The summed penalty as a tensor, or the integer ``0`` when no
            parameter qualifies.

        Raises:
            AssertionError: if ``model.prev_model`` is ``None``.
        """
        assert model.prev_model is not None
        old_params = dict(model.prev_model.named_parameters())
        loss = 0
        for n, p in model.named_parameters():
            if "classificationModel.output" in n or "prev_model" in n:
                # TODO: the classification output layer is not penalised yet;
                # presumably its rows have to be mapped between the old and the
                # new layer layout first — confirm before enabling.
                continue
            if n not in self.precision_matrices:
                continue
            loss = loss + (self.precision_matrices[n] * (p - old_params[n]) ** 2).sum()
        return loss

In [17]:
# Recompute the importances with the MAS class defined in this notebook;
# this rebinds `mas`, replacing any object created earlier under that name.
mas = MAS(retinanet, dataset_train)
mas.calculate_importance()

Computing MAS


In [20]:
# Inspect the recomputed importances (a dict mapping parameter name -> tensor).
# NOTE: this displays every tensor in the dict, so the output is very long.
mas.precision_matrices

{'conv1.weight': tensor([[[[177.0055, 178.1228, 178.2866,  ..., 178.4211, 177.5034, 176.6324],
           [177.4981, 178.5931, 178.7591,  ..., 178.5535, 177.5217, 176.7374],
           [178.5961, 180.1426, 180.7525,  ..., 180.4848, 178.8314, 177.6571],
           ...,
           [178.9204, 180.6105, 181.7114,  ..., 184.3036, 183.1820, 180.9999],
           [176.9341, 177.8875, 178.4373,  ..., 180.6802, 180.6986, 179.5204],
           [175.9596, 176.6727, 176.9409,  ..., 178.4650, 178.7221, 178.2795]],
 
          [[174.9468, 175.8566, 175.8772,  ..., 175.7624, 175.1313, 174.5938],
           [175.6991, 176.6791, 176.8339,  ..., 176.6178, 175.7208, 175.0639],
           [176.7031, 178.0967, 178.6270,  ..., 178.2943, 176.8773, 175.9861],
           ...,
           [177.6493, 178.8876, 179.6546,  ..., 181.7527, 180.7531, 178.9516],
           [175.9956, 176.8333, 177.2005,  ..., 179.1051, 179.0056, 177.9986],
           [175.3943, 176.1378, 176.4686,  ..., 177.7597, 177.8555, 177.2425]],


In [19]:
import pickle 

# Persist the computed importances next to this round's checkpoints.
# The location is derived from the notebook configuration rather than repeated
# as a literal; with the default configuration it is the same file as before.
mas_logits_path = os.path.join(root_path, 'model', method, 'round{}'.format(start_round), data_split, 'MAS_logits.pickle')
with open(mas_logits_path, 'wb') as f:
    pickle.dump(mas.precision_matrices, f)

In [3]:
# Create a MAS object and load stored importances instead of recomputing them.
# load_importance reads the file "MAS.pickle" inside the given directory; the
# directory is derived from the notebook configuration rather than repeated as
# a literal (same location as before with the default configuration).
mas = MAS(retinanet, dataset_train)
mas.load_importance(os.path.join(root_path, 'model', method, 'round{}'.format(start_round), data_split))

In [4]:
# Attach `retinanet` as the reference model; MAS.penalty reads it via model.prev_model.
# NOTE(review): `retinanet16` is not defined in any cell visible in this notebook —
# it must come from a cell that is no longer present; confirm how it is constructed.
retinanet16.prev_model = retinanet

In [5]:
# Evaluate the importance-weighted penalty of retinanet16 against its prev_model.
mas.penalty(retinanet16)

tensor(0.0656, device='cuda:0', grad_fn=<AddBackward0>)

In [6]:
import random
# Scratch check: draw 5 distinct values from range(5000).
# No seed is set, so the result differs between runs.
random.sample(range(5000),5)

[3542, 180, 943, 4379, 2910]