In [11]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms

import os
import os.path
from collections import OrderedDict

#import matplotlib.pyplot as plt
import numpy as np

import time
from copy import deepcopy

In [12]:
## Define AlexNet model
def compute_conv_output_size(Lin,kernel_size,stride=1,padding=0,dilation=1):
    """Return the output length of a sliding window along one spatial axis.

    Evaluates floor((Lin + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)
    and returns the result as a Python int.
    """
    # Extent covered by one (possibly dilated) kernel placement.
    window_extent = dilation * (kernel_size - 1) + 1
    padded_length = Lin + 2 * padding
    full_strides = (padded_length - window_extent) / float(stride)
    return int(np.floor(full_strides + 1))
class AlexNet(nn.Module):
    """Three-conv / two-FC network with one linear output head per task.

    Besides the layers, the constructor records bookkeeping that the
    representation-matrix and gradient-projection code reads:

    * ``act``        - inputs to each conv/FC layer, refreshed on every forward pass
    * ``map``        - per layer: input spatial size (conv) or input feature count (FC)
    * ``ksize``      - kernel size of each conv layer
    * ``in_channel`` - input channel count of each conv layer
    * ``klen``       - number of conv layers
    * ``fbegin``     - index in ``named_parameters()`` of the first task-head parameter
    """
    def __init__(self,taskcla):
        """Build the network; ``taskcla`` is an iterable of (task_index, n_classes) pairs."""
        # conv1..bn5 register 15 parameters (5 weights + 5 BN weight/bias pairs),
        # so parameter index 15 is the first fc3 head weight.
        self.fbegin = 15
        self.klen = 3
        super(AlexNet, self).__init__()
        self.act=OrderedDict()
        self.map =[]
        self.ksize=[]
        self.in_channel =[]
        # conv1 sees the raw 32x32 image.
        self.map.append(32)
        self.conv1 = nn.Conv2d(3, 64, 4, bias=False)
        self.bn1 = nn.BatchNorm2d(64, track_running_stats=False)
        s=compute_conv_output_size(32,4)
        # Halve for the 2x2 max-pool applied after the block in forward().
        s=s//2
        self.ksize.append(4)
        self.in_channel.append(3)
        self.map.append(s)
        self.conv2 = nn.Conv2d(64, 128, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(128, track_running_stats=False)
        s=compute_conv_output_size(s,3)
        s=s//2
        self.ksize.append(3)
        self.in_channel.append(64)
        self.map.append(s)
        self.conv3 = nn.Conv2d(128, 256, 2, bias=False)
        self.bn3 = nn.BatchNorm2d(256, track_running_stats=False)
        s=compute_conv_output_size(s,2)
        s=s//2
        # Spatial side length of the final feature map fed to fc1.
        self.smid=s
        self.ksize.append(2)
        self.in_channel.append(128)
        # From here on ``map`` holds flattened feature counts, not spatial sizes.
        self.map.append(256*self.smid*self.smid)
        self.maxpool=torch.nn.MaxPool2d(2)
        self.relu=torch.nn.ReLU()
        self.drop1=torch.nn.Dropout(0.2)
        self.drop2=torch.nn.Dropout(0.5)

        self.fc1 = nn.Linear(256*self.smid*self.smid,2048, bias=False)
        self.bn4 = nn.BatchNorm1d(2048, track_running_stats=False)
        self.fc2 = nn.Linear(2048,2048, bias=False)
        self.bn5 = nn.BatchNorm1d(2048, track_running_stats=False)
        # Input width of fc2.
        self.map.extend([2048])

        self.taskcla = taskcla
        # One bias-free classifier head per task, appended in taskcla order.
        self.fc3=torch.nn.ModuleList()
        for t,n in self.taskcla:
            self.fc3.append(torch.nn.Linear(2048,n,bias=False))

    def forward(self, x):
        """Return a list with one logits tensor per task head.

        As a side effect, stores the input of every conv/FC layer in
        ``self.act`` under the keys conv1, conv2, conv3, fc1, fc2.
        """
        bsz = deepcopy(x.size(0))
        self.act['conv1']=x
        x = self.conv1(x)
        x = self.maxpool(self.drop1(self.relu(self.bn1(x))))

        self.act['conv2']=x
        x = self.conv2(x)
        x = self.maxpool(self.drop1(self.relu(self.bn2(x))))


        self.act['conv3']=x
        x = self.conv3(x)
        x = self.maxpool(self.drop2(self.relu(self.bn3(x))))


        # Flatten to (batch, features) for the fully connected layers.
        x=x.view(bsz,-1)
        self.act['fc1']=x
        x = self.fc1(x)
        x = self.drop2(self.relu(self.bn4(x)))


        self.act['fc2']=x
        x = self.fc2(x)
        x = self.drop2(self.relu(self.bn5(x)))


        # The task index t is used directly as the position in self.fc3.
        y=[]
        for t,i in self.taskcla:
            y.append(self.fc3[t](x))

        return y

In [13]:
class LeNet(nn.Module):
    """Two-conv / two-FC network with one linear output head per task.

    Exposes the same bookkeeping attributes as the other models in this
    notebook (``act``, ``map``, ``ksize``, ``in_channel``, ``klen``,
    ``fbegin``), which the representation-matrix and projection code reads.
    """
    def __init__(self,taskcla):
        """Build the network; ``taskcla`` is an iterable of (task_index, n_classes) pairs."""
        # conv1, conv2, fc1 and fc2 are all bias-free and there are no norm
        # layers with parameters, so the shared trunk holds 4 parameters.
        self.fbegin = 4
        self.klen = 2
        super(LeNet, self).__init__()
        self.act=OrderedDict()
        self.map =[]
        self.ksize=[]
        self.in_channel =[]

        # conv1 sees the raw 32x32 image.
        self.map.append(32)
        self.conv1 = nn.Conv2d(3, 20, 5, bias=False, padding=2)

        # Size after the padded 5x5 conv, then after the 3x3/stride-2/pad-1 max-pool.
        s=compute_conv_output_size(32,5,1,2)
        s=compute_conv_output_size(s,3,2,1)
        self.ksize.append(5)
        self.in_channel.append(3)
        self.map.append(s)
        self.conv2 = nn.Conv2d(20, 50, 5, bias=False, padding=2)

        s=compute_conv_output_size(s,5,1,2)
        s=compute_conv_output_size(s,3,2,1)
        self.ksize.append(5)
        self.in_channel.append(20)
        # Spatial side length of the final feature map fed to fc1.
        self.smid=s
        # From here on ``map`` holds flattened feature counts, not spatial sizes.
        self.map.append(50*self.smid*self.smid)
        self.maxpool=torch.nn.MaxPool2d(3,2,padding=1)
        self.relu=torch.nn.ReLU()
        # Dropout probability 0: both dropout modules are configured as no-ops.
        self.drop1=torch.nn.Dropout(0)
        self.drop2=torch.nn.Dropout(0)
        self.lrn = torch.nn.LocalResponseNorm(4,0.001/9.0,0.75,1)

        self.fc1 = nn.Linear(50*self.smid*self.smid,800, bias=False)
        self.fc2 = nn.Linear(800,500, bias=False)
        # Input width of fc2.
        self.map.extend([800])

        self.taskcla = taskcla
        # One bias-free classifier head per task, appended in taskcla order.
        self.fc3=torch.nn.ModuleList()
        for t,n in self.taskcla:
            self.fc3.append(torch.nn.Linear(500,n,bias=False))

    def forward(self, x):
        """Return a list with one logits tensor per task head.

        As a side effect, stores the input of every conv/FC layer in
        ``self.act`` under the keys conv1, conv2, fc1, fc2.
        """
        bsz = deepcopy(x.size(0))
        self.act['conv1']=x
        x = self.conv1(x)
        x = self.maxpool(self.drop1(self.lrn(self.relu(x))))

        self.act['conv2']=x
        x = self.conv2(x)
        x = self.maxpool(self.drop1(self.lrn (self.relu(x))))

        # Flatten to (batch, features) for the fully connected layers.
        x=x.reshape(bsz,-1)
        self.act['fc1']=x
        x = self.fc1(x)
        x = self.drop2(self.relu(x))

        self.act['fc2']=x
        x = self.fc2(x)
        x = self.drop2(self.relu(x))

        # The task index t is used directly as the position in self.fc3.
        y=[]
        for t,i in self.taskcla:
            y.append(self.fc3[t](x))

        return y


In [14]:
class VGG16(nn.Module):
    """VGG-style network: eight conv blocks, two FC blocks, one head per task.

    Every conv block is Conv2d(3x3, padding 1) -> BatchNorm2d -> ReLU; blocks
    2, 4 and 7 additionally end in a 2x2 max-pool. The Conv2d and Linear
    layers of the trunk keep their default bias; only the task heads are
    bias-free.

    Exposes the same bookkeeping attributes as the other models in this
    notebook (``act``, ``map``, ``ksize``, ``in_channel``, ``klen``,
    ``fbegin``), which the representation-matrix and projection code reads.
    """
    def __init__(self, taskcla):
        """Build the network; ``taskcla`` is an iterable of (task_index, n_classes) pairs."""
        super(VGG16, self).__init__()
        self.ksize = []
        self.in_channel = []
        self.map = []

        # layer1 sees the raw 32x32 image.
        self.map.append(32)
        self.act=OrderedDict()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.ksize.append(3)
        self.in_channel.append(3)
        s=compute_conv_output_size(32,3,padding=1)
        self.map.append(s)


        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.ksize.append(3)
        self.in_channel.append(64)
        s=compute_conv_output_size(s,3,padding=1)
        # Halve for the max-pool that closes this block.
        s=s//2
        self.map.append(s)

        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU())

        self.ksize.append(3)
        self.in_channel.append(64)
        s=compute_conv_output_size(s,3,padding=1)
        self.map.append(s)

        self.layer4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.ksize.append(3)
        self.in_channel.append(128)
        s=compute_conv_output_size(s,3,padding=1)
        s=s//2
        self.map.append(s)

        self.layer5 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())
        self.ksize.append(3)
        self.in_channel.append(128)
        s=compute_conv_output_size(s,3,padding=1)
        self.map.append(s)

        self.layer6 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU())
        self.ksize.append(3)
        self.in_channel.append(256)
        s=compute_conv_output_size(s,3,padding=1)
        self.map.append(s)


        self.layer7 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.ksize.append(3)
        self.in_channel.append(256)
        s=compute_conv_output_size(s,3,padding=1)
        s=s//2
        self.map.append(s)

        self.layer8 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU())
        self.ksize.append(3)
        self.in_channel.append(256)
        # Output size of the last conv block; only used to size the first FC layer.
        s=compute_conv_output_size(s,3,padding=1)


        self.klen = len(self.ksize)
        # NOTE(review): each conv block appears to register 4 parameters (conv
        # weight/bias, BN weight/bias) and each FC block 2, which would put the
        # first head parameter at index 36 rather than 35 - confirm against
        # model.named_parameters() before relying on this value.
        self.fbegin = 35

        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512*s*s, 4096),
            nn.ReLU())

        # From here on ``map`` holds flattened feature counts, not spatial sizes.
        self.map.append(512*s*s)

        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())

        self.map.append(4096)
        self.taskcla = taskcla
        # One bias-free classifier head per task, appended in taskcla order.
        self.fc2=torch.nn.ModuleList()
        for t,n in self.taskcla:
            self.fc2.append(torch.nn.Linear(4096,n,bias=False))

    def forward(self, x):
        """Return a list with one logits tensor per task head.

        As a side effect, stores the input of every conv/FC block in
        ``self.act`` under the keys conv1..conv8, fc, fc1.
        """
        self.act['conv1']=x
        out = self.layer1(x)
        self.act['conv2']=out
        out = self.layer2(out)
        self.act['conv3']=out
        out = self.layer3(out)
        self.act['conv4']=out
        out = self.layer4(out)
        self.act['conv5']=out
        out = self.layer5(out)
        self.act['conv6']=out
        out = self.layer6(out)
        self.act['conv7']=out
        out = self.layer7(out)
        self.act['conv8']=out
        out = self.layer8(out)

        # Flatten to (batch, features) for the fully connected blocks.
        out = out.reshape(out.size(0), -1)
        self.act['fc']=out
        out = self.fc(out)
        self.act['fc1']=out
        out = self.fc1(out)
        # The task index t is used directly as the position in self.fc2.
        y=[]
        for t,i in self.taskcla:
            y.append(self.fc2[t](out))

        return y

In [15]:
def get_model(model):
    """Snapshot ``model``: return a deep copy of its ``state_dict()``.

    The copy is independent of the live parameters, so later training steps
    do not alter the returned snapshot.
    """
    snapshot = model.state_dict()
    return deepcopy(snapshot)

def set_model_(model,state_dict):
    """Restore ``model`` in place from ``state_dict``; returns None.

    A deep copy of ``state_dict`` is what gets handed to
    ``model.load_state_dict``, so the caller's snapshot is never shared.
    """
    restored = deepcopy(state_dict)
    model.load_state_dict(restored)

def adjust_learning_rate(optimizer, epoch, args):
    """Update the learning rate of every parameter group of ``optimizer``.

    At ``epoch == 1`` each group's rate is reset to ``args.lr``; at any other
    epoch the group's current rate is divided by ``args.lr_factor``.
    """
    for group in optimizer.param_groups:
        if epoch != 1:
            group['lr'] /= args.lr_factor
            continue
        group['lr'] = args.lr

def train(args, model, device, x,y, optimizer,criterion, task_id):
    """Run one epoch of unconstrained SGD on a single task.

    Args:
        args: object providing ``batch_size_train``.
        model: network whose forward pass returns a list of per-task outputs.
        device: torch device the data is moved to.
        x, y: full input and target tensors for the task.
        optimizer: optimizer stepped once per mini-batch.
        criterion: loss applied to ``output[task_id]`` and the targets.
        task_id: index of the output head to train.

    The sample order is drawn with ``np.random.shuffle``, so it follows the
    global NumPy seed. Returns None; the model is updated in place.
    """
    x = x.to(device)
    y = y.to(device)
    model.train()
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    # Loop batches; slicing past the end yields the shorter final batch.
    for start in range(0, len(order), args.batch_size_train):
        batch = order[start:start + args.batch_size_train]
        # x and y already live on ``device``, so the batches do too.
        data, target = x[batch], y[batch]
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output[task_id], target)
        loss.backward()
        optimizer.step()

def train_projected(args,model,device,x,y,optimizer,criterion,feature_mat,task_id):
    """Run one epoch of SGD with the trunk gradients projected before each step.

    For every parameter whose index in ``model.named_parameters()`` is below
    ``model.fbegin``:

    * non-1-D parameters have ``grad @ feature_mat[j]`` subtracted from their
      gradient (viewed as a 2-D matrix), where ``j`` counts such parameters
      in registration order;
    * 1-D parameters have their gradient zeroed whenever ``task_id != 0``.

    Parameters at index ``model.fbegin`` and above are left untouched.
    Returns None; the model is updated in place.
    """
    x = x.to(device)
    y = y.to(device)
    model.train()
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    step = args.batch_size_train
    # Loop batches; slicing past the end yields the shorter final batch.
    for start in range(0, len(order), step):
        batch = order[start:start + step]
        data = x[batch].to(device)
        target = y[batch].to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output[task_id], target)
        loss.backward()
        # Gradient Projections
        mat_idx = 0
        for k, (_, params) in enumerate(model.named_parameters()):
            if k >= model.fbegin:
                continue
            if params.dim() != 1:
                grad = params.grad.data
                rows = grad.size(0)
                along_memory = torch.mm(grad.view(rows, -1), feature_mat[mat_idx])
                params.grad.data = grad - along_memory.view(params.size())
                mat_idx += 1
            elif task_id != 0:
                params.grad.data.fill_(0)

        optimizer.step()

def test(args, model, device, x, y, criterion, task_id):
    x = x.to(device)
    y = y.to(device)
    model.eval()
    total_loss = 0
    total_num = 0
    correct = 0
    r=np.arange(x.size(0))
    np.random.shuffle(r)
    r=torch.LongTensor(r).to(device)
    with torch.no_grad():
        # Loop batches
        for i in range(0,len(r),args.batch_size_test):
            if i+args.batch_size_test<=len(r): b=r[i:i+args.batch_size_test]
            else: b=r[i:]
            data = x[b]
            data, target = data.to(device), y[b].to(device)
            output = model(data)
            loss = criterion(output[task_id], target)
            pred = output[task_id].argmax(dim=1, keepdim=True)

            correct    += pred.eq(target.view_as(pred)).sum().item()
            total_loss += loss.data.cpu().numpy().item()*len(b)
            total_num  += len(b)

    acc = 100. * correct / total_num
    final_loss = total_loss / total_num
    return final_loss, acc

In [16]:
def get_representation_matrix (net, device, x, y=None):
    """Build one representation matrix per layer from a random batch of ``x``.

    Up to 125 randomly chosen samples are pushed through ``net`` so that
    ``net.act`` holds the input of every layer; the first 20 of them (fewer
    if the batch is smaller) are then turned into matrices:

    * conv layers (index < ``net.klen``): every ``ksize x ksize`` patch the
      kernel visits at stride 1 without padding becomes one column, giving
      shape ``(ksize*ksize*in_channel, s*s*n_samples)``;
    * remaining layers: the activations transposed, ``(features, n_samples)``.

    The forward pass runs in eval mode under ``torch.no_grad()`` so that
    dropout cannot perturb the collected activations and no autograd graph is
    retained; the previous train/eval mode is restored afterwards.

    Args:
        net: model exposing ``act``, ``map``, ``ksize``, ``in_channel``, ``klen``.
        device: torch device the data is moved to.
        x: input tensor for the task.
        y: unused; kept for call compatibility.

    Returns:
        list of numpy arrays, one per entry of ``net.map``.
    """
    x = x.to(device)
    # Collect activations by forward pass
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    example_data = x[order[0:125]] # Take 125 random samples
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            net(example_data)
    finally:
        net.train(was_training)

    samples_per_layer = 20
    mat_list=[]
    act_key=list(net.act.keys())

    for i in range(len(net.map)):
        act = net.act[act_key[i]].detach().cpu().numpy()
        # Never ask for more samples than the forward pass actually produced.
        bsz = min(samples_per_layer, act.shape[0])
        if i<net.klen:
            ksz= net.ksize[i]
            # Number of stride-1, unpadded kernel placements along one axis
            # (what compute_conv_output_size(net.map[i], ksz) evaluates to).
            s = net.map[i] - ksz + 1

            mat = np.zeros((ksz*ksz*net.in_channel[i],s*s*bsz))
            k=0
            for kk in range(bsz):
                for ii in range(s):
                    for jj in range(s):
                        mat[:,k]=act[kk,:,ii:ksz+ii,jj:ksz+jj].reshape(-1)
                        k +=1
            mat_list.append(mat)
        else:
            activation = act[0:bsz].transpose()
            mat_list.append(activation)

    print('-'*30)
    print('Representation Matrix')
    print('-'*30)
    for i in range(len(mat_list)):
        print ('Layer {} : {}'.format(i+1,mat_list[i].shape))
    print('-'*30)
    return mat_list


def _scale_importance(args, values):
    """Map strengths to importances via (a+1)*v / (a*v + max(v)), a = args.scale_coff.

    An empty input is returned unchanged instead of failing on ``max([])``.
    """
    if len(values) == 0:
        return values
    return ((args.scale_coff+1)*values)/(args.scale_coff*values + max(values))


def update_SGP (args, model, mat_list, threshold, task_id, feature_list=None, importance_list=None):
    """Update the per-layer gradient-projection bases and their importances.

    Args:
        args: object providing ``scale_coff``.
        model: unused; kept for call compatibility.
        mat_list: per-layer representation matrices (features x samples).
        threshold: per-layer energy thresholds, indexed like ``mat_list``.
        task_id: unused; kept for call compatibility.
        feature_list: per-layer basis matrices from earlier tasks. Empty or
            omitted means "first task". A supplied list is updated in place.
        importance_list: per-layer importance vectors matching ``feature_list``.
            A supplied list is updated in place.

    Returns:
        ``(feature_list, importance_list)``.

    The defaults are ``None`` rather than ``[]`` so that calls which omit the
    lists never share (and silently accumulate into) one module-level list.
    """
    if feature_list is None:
        feature_list = []
    if importance_list is None:
        importance_list = []

    print ('Threshold: ', threshold)
    print(len(mat_list))
    if not feature_list:
        # After First Task
        for i in range(len(mat_list)):
            activation = mat_list[i]

            U,S,Vh = np.linalg.svd(activation, full_matrices=False)
            # criteria (Eq-1)
            sval_total = (S**2).sum()
            sval_ratio = (S**2)/sval_total

            # Count of leading directions whose cumulative energy stays below the threshold.
            r = np.sum(np.cumsum(sval_ratio)<threshold[i]) #+1
            # update GPM
            feature_list.append(U[:,0:r])
            # update importance (Eq-2)
            importance_list.append(_scale_importance(args, S[0:r]))
    else:
        for i in range(len(mat_list)):
            activation = mat_list[i]
            U1,S1,Vh1=np.linalg.svd(activation, full_matrices=False)
            sval_total = (S1**2).sum()
            # Projected Representation (Eq-4)
            act_proj = np.dot(np.dot(feature_list[i],feature_list[i].transpose()),activation)
            r_old = feature_list[i].shape[1] # old GPM bases
            Uc,Sc,Vhc = np.linalg.svd(act_proj, full_matrices=False)
            # Strength of the new activations along each old basis vector.
            importance_new_on_old = np.dot(np.dot(feature_list[i].transpose(),Uc[:,0:r_old])**2, Sc[0:r_old]**2) ## r_old no of elm s**2 fmt
            importance_new_on_old = np.sqrt(importance_new_on_old)

            # Residual of the activations outside the span of the old bases.
            act_hat = activation - act_proj
            U,S,Vh = np.linalg.svd(act_hat, full_matrices=False)
            # criteria (Eq-5)
            sval_hat = (S**2).sum()
            sval_ratio = (S**2)/sval_total
            accumulated_sval = (sval_total-sval_hat)/sval_total

            # Add residual directions until the captured energy reaches the threshold.
            r = 0
            for ii in range (sval_ratio.shape[0]):
                if accumulated_sval < threshold[i]:
                    accumulated_sval += sval_ratio[ii]
                    r += 1
                else:
                    break
            if r == 0:
                print ('Skip Updating GPM for layer: {}'.format(i+1))
                # update importances
                importance = _scale_importance(args, importance_new_on_old)
                importance [0:r_old] = np.clip(importance [0:r_old]+importance_list[i][0:r_old], 0, 1)
                importance_list[i] = importance # update importance
                continue
            # update GPM
            Ui=np.hstack((feature_list[i],U[:,0:r]))
            # update importance
            importance = _scale_importance(args, np.hstack((importance_new_on_old,S[0:r])))
            importance [0:r_old] = np.clip(importance [0:r_old]+importance_list[i][0:r_old], 0, 1)

            # Never keep more basis vectors than the layer has dimensions.
            if Ui.shape[1] > Ui.shape[0] :
                feature_list[i]=Ui[:,0:Ui.shape[0]]
                importance_list[i] = importance[0:Ui.shape[0]]
            else:
                feature_list[i]=Ui
                importance_list[i] = importance

    print('-'*40)
    print('Gradient Constraints Summary')
    print('-'*40)
    for i in range(len(feature_list)):
        print ('Layer {} : {}/{}'.format(i+1,feature_list[i].shape[1], feature_list[i].shape[0]))
    print('-'*40)
    return feature_list, importance_list

In [17]:
import sys
# import utils
from sklearn.utils import shuffle

# Root directory handed to torchvision's CIFAR100 dataset (download target).
cf100_dir = './data/'
# Directory where get() saves and reloads the per-task tensors as .bin files.
file_dir = './data/binary_cifar100'

In [18]:
def get(seed=0,pc_valid=0.10):
    """Load CIFAR-100 split into 10 tasks of 10 consecutive classes each.

    On first use the dataset is downloaded, normalised, grouped by
    ``label // 10`` (with the in-task label ``label % 10``) and cached as
    tensors under ``file_dir``. Later calls read the cache. The cache is
    considered present only if every expected file exists, so a run that was
    interrupted after the directory was created is repaired rather than
    failing on load.

    Args:
        seed: random state for the train/validation split.
        pc_valid: fraction of each task's training set moved to validation.

    Returns:
        ``(data, taskcla, size)`` where ``data[t]`` holds ``name``, ``ncla``
        and the ``train``/``valid``/``test`` dicts of ``x``/``y`` tensors,
        ``data['ncla']`` is the total class count, ``taskcla`` is a list of
        ``(task, n_classes)`` pairs and ``size`` is ``[3, 32, 32]``.
    """
    taskcla=[]
    size=[3,32,32]
    n_tasks=10
    splits=['train','test']
    root=os.path.expanduser(file_dir)

    def bin_path(task,split,field):
        # File name layout: data<task><split><x|y>.bin
        return os.path.join(root,'data'+str(task)+split+field+'.bin')

    cache_complete=all(os.path.isfile(bin_path(t,s,f))
                       for t in range(n_tasks) for s in splits for f in ('x','y'))
    if not cache_complete:
        os.makedirs(root,exist_ok=True)

        mean=[x/255 for x in [125.3,123.0,113.9]]
        std=[x/255 for x in [63.0,62.1,66.7]]
        normalize=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean,std)])

        # CIFAR100
        dat={}
        dat['train']=datasets.CIFAR100(cf100_dir,train=True,download=True,transform=normalize)
        dat['test']=datasets.CIFAR100(cf100_dir,train=False,download=True,transform=normalize)
        raw={n:{s:{'x':[],'y':[]} for s in splits} for n in range(n_tasks)}
        for s in splits:
            loader=torch.utils.data.DataLoader(dat[s],batch_size=1,shuffle=False)
            for image,target in loader:
                n=target.numpy()[0]
                # Classes 10k..10k+9 form task k; labels are re-based to 0..9.
                raw[n//10][s]['x'].append(image)
                raw[n//10][s]['y'].append(n%10)

        # "Unify" and save
        for t in raw.keys():
            for s in splits:
                xs=torch.stack(raw[t][s]['x']).view(-1,size[0],size[1],size[2])
                ys=torch.LongTensor(np.array(raw[t][s]['y'],dtype=int)).view(-1)
                torch.save(xs,bin_path(t,s,'x'))
                torch.save(ys,bin_path(t,s,'y'))

    # Load binary files
    data={}
    ids=list(np.arange(10))
    print('Task order =',ids)
    for i in range(n_tasks):
        data[i] = dict.fromkeys(['name','ncla','train','test'])
        for s in splits:
            data[i][s]={'x':[],'y':[]}
            data[i][s]['x']=torch.load(bin_path(ids[i],s,'x'))
            data[i][s]['y']=torch.load(bin_path(ids[i],s,'y'))
        data[i]['ncla']=len(np.unique(data[i]['train']['y'].numpy()))
        if data[i]['ncla']==2:
            data[i]['name']='cifar10-'+str(ids[i])
        else:
            data[i]['name']='cifar100-'+str(ids[i])

    # Validation: carve a seeded random pc_valid fraction out of each training set.
    for t in data.keys():
        r=np.arange(data[t]['train']['x'].size(0))
        r=np.array(shuffle(r,random_state=seed),dtype=int)
        nvalid=int(pc_valid*len(r))
        ivalid=torch.LongTensor(r[:nvalid])
        itrain=torch.LongTensor(r[nvalid:])
        data[t]['valid']={}
        data[t]['valid']['x']=data[t]['train']['x'][ivalid].clone()
        data[t]['valid']['y']=data[t]['train']['y'][ivalid].clone()
        data[t]['train']['x']=data[t]['train']['x'][itrain].clone()
        data[t]['train']['y']=data[t]['train']['y'][itrain].clone()

    # Others
    n=0
    for t in data.keys():
        taskcla.append((t,data[t]['ncla']))
        n+=data[t]['ncla']
    data['ncla']=n

    return data,taskcla,size

In [19]:
def _build_model(args, taskcla, device):
    """Instantiate the architecture named by ``args.md`` and move it to ``device``."""
    if args.md == "alex":
        return AlexNet(taskcla).to(device)
    if args.md == "le":
        return LeNet(taskcla).to(device)
    if args.md == "vgg":
        return VGG16(taskcla).to(device)
    raise ValueError('Unknown model {!r}; expected "alex", "le" or "vgg"'.format(args.md))


def _fit_task(args, model, device, optimizer, criterion, task,
              xtrain, ytrain, xvalid, yvalid, train_epoch):
    """Train one task with patience-based LR decay, then restore the best weights.

    ``train_epoch`` is a zero-argument callable that runs one training epoch.
    The weights with the lowest validation loss are loaded back into
    ``model`` before returning.
    """
    lr = args.lr
    best_loss = np.inf
    best_model = get_model(model)
    # Start with a full patience budget so a first epoch that does not improve
    # on ``inf`` (e.g. a NaN loss) cannot hit an unbound variable.
    patience = args.lr_patience
    for epoch in range(1, args.n_epochs+1):
        # Train
        clock0=time.time()
        train_epoch()
        clock1=time.time()
        tr_loss,tr_acc = test(args, model, device, xtrain, ytrain, criterion, task)
        print('Epoch {:3d} | Train: loss={:.3f}, acc={:5.1f}% | time={:5.1f}ms |'.format(epoch,\
                                                    tr_loss,tr_acc, 1000*(clock1-clock0)),end='')
        # Validate
        valid_loss,valid_acc = test(args, model, device, xvalid, yvalid, criterion, task)
        print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, valid_acc),end='')
        # Adapt lr
        if valid_loss<best_loss:
            best_loss=valid_loss
            best_model=get_model(model)
            patience=args.lr_patience
            print(' *',end='')
        else:
            patience-=1
            if patience<=0:
                lr/=args.lr_factor
                print(' lr={:.1e}'.format(lr),end='')
                if lr<args.lr_min:
                    print()
                    break
                patience=args.lr_patience
                adjust_learning_rate(optimizer, epoch, args)
        print()
    set_model_(model,best_model)


def main(args):
    """Run the full continual-learning experiment described by ``args``.

    Tasks are learned in the order given by ``get()``. The first task is
    trained with plain SGD; every later task is trained with gradients
    projected through the importance-scaled bases collected so far. After
    each task the bases are updated and all tasks seen so far are
    re-evaluated. Progress, the accuracy matrix, final average accuracy and
    backward transfer are printed; nothing is returned.
    """
    tstart=time.time()
    ## Device Setting
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print (device)
    ## setup seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    data,taskcla,inputsize=get(pc_valid=args.pc_valid)

    # Row = number of tasks learned so far, column = task evaluated.
    n_tasks = len(taskcla)
    acc_matrix=np.zeros((n_tasks,n_tasks))
    criterion = torch.nn.CrossEntropyLoss()

    model = None
    feature_list = []
    importance_list = []
    task_list = []
    for task_id, (k,ncla) in enumerate(taskcla):
        print('*'*100)
        print('Task {:2d} ({:s})'.format(k,data[k]['name']))
        print('*'*100)
        xtrain=data[k]['train']['x']
        ytrain=data[k]['train']['y']
        xvalid=data[k]['valid']['x']
        yvalid=data[k]['valid']['y']
        xtest =data[k]['test']['x']
        ytest =data[k]['test']['y']
        task_list.append(k)

        print ('-'*40)
        print ('Task ID :{} | Learning Rate : {}'.format(task_id, args.lr))
        print ('-'*40)

        if task_id==0:
            model = _build_model(args, taskcla, device)

            print ('Model parameters ---')
            for k_t, (m, param) in enumerate(model.named_parameters()):
                print (k_t,m,param.shape)
            print ('-'*40)

            optimizer = optim.SGD(model.parameters(), lr=args.lr)

            def train_epoch():
                train(args, model, device, xtrain, ytrain, optimizer, criterion, k)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr)
            # Projection Matrix Precomputation: U diag(importance) U^T per layer.
            feature_mat = []
            for i in range(len(model.act)):
                Uf=torch.Tensor(np.dot(feature_list[i],np.dot(np.diag(importance_list[i]),feature_list[i].transpose()))).to(device)
                Uf.requires_grad = False
                feature_mat.append(Uf)

            def train_epoch():
                train_projected(args, model,device,xtrain, ytrain,optimizer,criterion,feature_mat,k)

        _fit_task(args, model, device, optimizer, criterion, k,
                  xtrain, ytrain, xvalid, yvalid, train_epoch)

        # Test
        if task_id==0:
            print ('-'*40)
        test_loss, test_acc = test(args, model, device, xtest, ytest,  criterion, k)
        print('Test: loss={:.3f} , acc={:5.1f}%'.format(test_loss,test_acc))

        # specify threshold hyperparameter: one entry per layer (at least 15,
        # as before), growing linearly with the task index.
        threshold = np.full(max(15, len(model.map)), args.gpm_eps + task_id * args.gpm_eps_inc)
        # Memory and Importance Update
        mat_list = get_representation_matrix(model, device, xtrain, ytrain)
        feature_list, importance_list = update_SGP(args, model, mat_list, threshold, task_id, feature_list, importance_list)

        # save accuracy
        for jj, ii in enumerate(task_list):
            xtest =data[ii]['test']['x']
            ytest =data[ii]['test']['y']
            _, acc_matrix[task_id,jj] = test(args, model, device, xtest, ytest,criterion,ii)
        print('Accuracies =')
        for i_a in range(task_id+1):
            print('\t',end='')
            for j_a in range(acc_matrix.shape[1]):
                print('{:5.1f}% '.format(acc_matrix[i_a,j_a]),end='')
            print()
    print('-'*50)

    # Simulation Results
    print ('Final Avg Accuracy: {:5.2f}%'.format(acc_matrix[-1].mean()))
    # Mean change in accuracy on each earlier task between the time it was
    # learned (diagonal) and the end of training (last row).
    bwt=np.mean((acc_matrix[-1]-np.diag(acc_matrix))[:-1])
    print ('Backward transfer: {:5.2f}%'.format(bwt))
    print('[Elapsed time = {:.1f} ms]'.format((time.time()-tstart)*1000))
    print('-'*50)

In [21]:

class arg_class():
    """Plain attribute container for the experiment hyperparameters.

    Every constructor argument is stored under an attribute of the same
    name, except ``mod`` (model selector), which is stored as ``md``.
    """
    def __init__(self,batch_size_train=64, batch_size_test=64, n_epochs=1, seed=5, pc_valid=0.05, lr=0.05, momentum=0.9, lr_min=5e-05, lr_patience=6, lr_factor=2, scale_coff=10, gpm_eps=0.97, gpm_eps_inc=0.003,mod="alex"):
        # Batching and run length
        self.batch_size_train, self.batch_size_test = batch_size_train, batch_size_test
        self.n_epochs = n_epochs
        # Reproducibility and data split
        self.seed, self.pc_valid = seed, pc_valid
        # Optimiser and learning-rate schedule
        self.lr, self.momentum = lr, momentum
        self.lr_min, self.lr_patience, self.lr_factor = lr_min, lr_patience, lr_factor
        # Gradient-projection settings
        self.scale_coff = scale_coff
        self.gpm_eps, self.gpm_eps_inc = gpm_eps, gpm_eps_inc
        # Model selector
        self.md = mod

# Run the experiment with the default hyperparameters; ``mod`` picks the
# architecture that main() builds ("alex", "le" or "vgg").
args = arg_class(mod = "alex")

main(args)

cpu
Task order = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
****************************************************************************************************
Task  0 (cifar100-0)
****************************************************************************************************
----------------------------------------
Task ID :0 | Learning Rate : 0.05
----------------------------------------
Model parameters ---
0 conv1.weight torch.Size([64, 3, 4, 4])
1 bn1.weight torch.Size([64])
2 bn1.bias torch.Size([64])
3 conv2.weight torch.Size([128, 64, 3, 3])
4 bn2.weight torch.Size([128])
5 bn2.bias torch.Size([128])
6 conv3.weight torch.Size([256, 128, 2, 2])
7 bn3.weight torch.Size([256])
8 bn3.bias torch.Size([256])
9 fc1.weight torch.Size([2048, 1024])
10 bn4.weight torch.Size([2048])
11 bn4.bias torch.Size([2048])
12 fc2.weight torch.Size([2048, 2048])
13 bn5.weight torch.Size([2048])
14 bn5.bias torch.Size([2048])
15 fc3.0.weight torch.Size([10, 2048])
16 fc3.1.weight torch.Size([10, 2048])
17 