In [1]:
import numpy as np
import os
from retinanet import model
from retinanet import coco_eval
from retinanet.dataloader import CocoDataset_inOrder,rehearsal_DataSet, collater, Resizer, AspectRatioBasedSampler, Augmenter, Normalizer
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import collections
import torch
root_path = '/home/deeplab307/Documents/Anaconda/Shiang/CL/'
method = 'w_distillation'
data_split = '15+1'
start_round = 1
batch_size = 1

checkpoint_epoch = 50

def checkDir(path):
    """check whether directory exists or not.If not, then create it 
    """
    if not os.path.isdir(path):
        os.mkdir(path)
def get_checkpoint_path(method, now_round, epoch, data_split ="None"):
    global root_path
    checkDir(os.path.join(root_path, 'model', method, 'round{}'.format(now_round)))
    checkDir(os.path.join(root_path, 'model', method, 'round{}'.format(now_round), data_split))
    path = os.path.join(root_path, 'model', method, 'round{}'.format(now_round), data_split,'voc_retinanet_{}_checkpoint.pt'.format(epoch))
    return path


def readCheckpoint(method, now_round, epoch, data_split, retinanet, optimizer = None, scheduler = None):
    print('readcheckpoint at Round{} Epoch{}'.format(now_round, epoch))
    prev_checkpoint = torch.load(get_checkpoint_path(method, now_round, epoch, data_split))
    retinanet.load_state_dict(prev_checkpoint['model_state_dict'])
    if optimizer != None:
        optimizer.load_state_dict(prev_checkpoint['optimizer_state_dict'])
    if scheduler != None:
        scheduler.load_state_dict(prev_checkpoint['scheduler_state_dict'])
    
dataset_train = CocoDataset_inOrder(os.path.join(root_path, 'DataSet', 'VOC2012'), set_name='TrainVoc2012', dataset = 'voc',
                                    transform=transforms.Compose([Normalizer(), Resizer()]),
                                   data_split=data_split, start_round=start_round)
retinanet = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True)
retinanet.cuda()

readCheckpoint(method, start_round, checkpoint_epoch,data_split, retinanet)

loading annotations into memory...
Done (t=0.09s)
creating index...
index created!
loading annotations into memory...
Done (t=0.05s)
creating index...
index created!
{'id': [[1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 15, 16, 17, 18, 19], [6], [20], [14], [12], [13]], 'name': [['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person'], ['train'], ['sheep'], ['sofa'], ['pottedplant'], ['tvmonitor']]}
dataloader class_num = 15
readcheckpoint at Round1 Epoch50


In [18]:
model = retinanet
num_data = len(dataset_train)

i = 0
precision_matrices = {}
for n, p in model.named_parameters():
    precision_matrices[n] = p.clone().detach().fill_(0) 
for idx, data in enumerate(dataset_train):
    with torch.cuda.device(0):
        if torch.cuda.is_available():
            features, regression, classification = model([data['img'].permute(2, 0, 1).cuda().float().unsqueeze(dim=0), data['annot'].cuda().unsqueeze(dim=0)])
        else:
            print('not have gpu')
            break
        classification = torch.norm(classification)
        regression = torch.norm(regression)
        ratio = float(classification / regression) 
        regression = regression * ratio
        
        output = classification + regression
        output.backward()
        
        for n, p in model.named_parameters():                      
            precision_matrices[n].data += p.grad.abs() / num_data ## difference with EWC      

In [21]:
import torch.nn as nn
import pickle
import os
class MAS(object):
    def __init__(self, model: nn.Module, dataloader):
        self.model = model
        self.dataloader = dataloader
    def load_importance(path):
        with open(os.path.join(path, "MAS.pickle")) as f:
            self.precision_matrices = pickle.load(f)
    def calculate_importance(self):
        print('Computing MAS')
        self.model.distill_feature = True
        precision_matrices = {}
        for n, p in self.params.items():
            precision_matrices[n] = p.clone().detach().fill_(0)
        self.model.train()
        self.model.freeze_bn()
        num_data = len(self.dataloader)
        for idx, data in enumerate(self.dataloader):
            with torch.cuda.device(0):
                self.model.zero_grad()

                features, regression, classification = self.model([data['img'].permute(2, 0, 1).cuda().float().unsqueeze(dim=0))
                                                          
                classification = torch.norm(classification)
                regression = torch.norm(regression)
                regression *= float(classification / regression) 
                output = classification + regression
                output.backward()
                                          
                for n, p in self.model.named_parameters():                      
                    precision_matrices[n].data += p.grad.abs() / num_data

        self.model.distill_feature = False
        self.precision_matrices = precision_matrices

    def penalty(self, model: nn.Module):
        loss = 0
                                                                       
        for n, p in model.named_parameters():
            if "classificationModel.output" not in name:
                _loss = self.precision_matrices[n] * (p - self.old_params[n]) ** 2
                loss += _loss.sum()
#             else:
#                 _loss = self.precision_matrices[n] * (p[] - self.old_params[n]) ** 2                                                           
#                 self.output.weight.data[i*20 + j,:,:,:] = old_output.weight.data[i*4 + part_idx,:,:,:]
#                 self.output.bias.data[i*20 + j] = old_output.bias.data[i*4 + part_idx]
                
        return loss

In [23]:
mas = MAS(retinanet, dataset_train)

Computing MAS


In [25]:
mas.penalty

<bound method MAS.penalty of <__main__.MAS object at 0x7f510c767550>>