# Environment

In [None]:
from tensorflow import keras
from glob import glob
from scipy import misc
from sklearn.model_selection import KFold
import torch
from torch import nn
from os.path import isfile
import numpy as np
import matplotlib.pyplot as plt
from skimage import transform
from scipy import linalg
import scipy.io as sio
import time
import copy

#from skimage.measure import compare_ssim as ssim # Old Version
from skimage.metrics import structural_similarity as ssim # New Version

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from tqdm import tqdm
from tqdm import notebook as tq

# Select the compute device. `device` is the string passed to `.to(device)`
# throughout the notebook. Fall back to the CPU when CUDA is unavailable so
# the notebook can still run (slowly) on a machine without a GPU instead of
# failing inside torch.cuda.set_device().
device_id=0
if torch.cuda.is_available():
    torch.cuda.set_device(device_id)
    device='cuda:'+str(device_id)
else:
    device='cpu'


# Data Preparation

In [None]:
# MNIST Dataset:
# Load the train/test splits through Keras.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten every sample to a single row vector (one row per sample).
x_train = x_train.reshape(len(x_train), -1)
x_test = x_test.reshape(len(x_test), -1)


# Permuted MNIST
# Task 0 is the original data; each further task applies a seeded pixel
# permutation to every sample. Labels are shared by all tasks.
num_task=20
step=100
init=100

# One RNG seed per task: init, init+step, ...
seeds=np.arange(init,num_task*step+init,step)
n_class=len(np.unique(y_train))

# One-hot encode the labels (vectorised; same float64 result as a row-by-row fill).
train_y=np.zeros((len(y_train),n_class))
train_y[np.arange(len(y_train)),y_train]=1
test_y=np.zeros((len(y_test),n_class))
test_y[np.arange(len(y_test)),y_test]=1

data=[[x_train,train_y]]
data_test=[[x_test,test_y]]

pixels=np.arange(x_train.shape[1])

# `pixels` is shuffled in place below, so store *copies*: appending the array
# itself would leave every entry of pixel_set (including the identity entry
# for task 0) aliasing one array that ends up equal to the last permutation.
pixel_set=[pixels.copy()]
for task in range(num_task-1):
    # Shuffles are cumulative: each task permutes the previous task's order.
    np.random.seed(seeds[task])
    np.random.shuffle(pixels)
    pixel_set.append(pixels.copy())

    # Fancy indexing copies, so each task owns its permuted arrays.
    data.append([x_train[:,pixels],train_y])
    data_test.append([x_test[:,pixels],test_y])
# Parameters:

NUM_K_FOLDS = 1 # if > 1, partition the training set into K held-out partitions
NUM_FOLDS_VALIDATION = 1 # Train and test on this many partitions; must not exceed NUM_K_FOLDS
# Use NUM_K_FOLDS = 1 to train on the entire training set and test on the test set.

# Partition the dataset of every task into folds. The four arrays built here
# are indexed as [task][fold][sample, ...].
X_train, X_test = [], []
Y_train, Y_test = [], []
# NUM_K_FOLDS partitions:
if NUM_K_FOLDS >= 2:
    kfold = KFold(n_splits=NUM_K_FOLDS, shuffle=True)
    for task in range(num_task):
        X_currTask_train, Y_currTask_train = data[task]
        # The split is computed on the training set, so both halves of every
        # fold must be taken from the training arrays: the indices are not
        # valid for the separate test set in data_test.
        # NOTE(review): the folds can only be stacked below when the number
        # of training samples is divisible by NUM_K_FOLDS (equal fold sizes).
        for train_index, test_index in kfold.split(Y_currTask_train):
            X_train.append(X_currTask_train[train_index])
            X_test.append(X_currTask_train[test_index])
            Y_train.append(Y_currTask_train[train_index])
            Y_test.append(Y_currTask_train[test_index])
# Full dataset:
else:
    for task in range(num_task):
        X_train.append(data[task][0])
        X_test.append(data_test[task][0])
        Y_train.append(data[task][1])
        Y_test.append(data_test[task][1])

# Stack to (tasks, folds, samples, features-or-classes).
X_train = np.asarray(X_train)
X_train = np.reshape(X_train,(num_task,NUM_K_FOLDS,X_train.shape[1],X_train.shape[2])) # tasks X folds
X_test = np.asarray(X_test)
X_test = np.reshape(X_test,(num_task,NUM_K_FOLDS,X_test.shape[1],X_test.shape[2])) # tasks X folds
Y_train = np.asarray(Y_train)
Y_train = np.reshape(Y_train,(num_task,NUM_K_FOLDS,Y_train.shape[1],Y_train.shape[2])) # tasks X folds
Y_test = np.asarray(Y_test)
Y_test = np.reshape(Y_test,(num_task,NUM_K_FOLDS,Y_test.shape[1],Y_test.shape[2])) # tasks X folds
print(X_train.shape)

# Hyperparameters

In [None]:
# Parameters:
num_task = 20
n_class = 10  # Output dimension

# Rank added by the first task (r1) and by every later task (r2), one entry per fold:
r1_byfolds = [11]
r2_byfolds = [1]

batch_size = 128
# Epochs by task (task1, task2, task3 ....); use [1] * num_task for a quick smoke run.
train_epochs = [500] * num_task

# Learning rates and the optional step-decay schedule.
lr_base, lr_cont = 1e-5, 1e-4
lr_drop_epoch = []
lr_drop_factor = 0.5

# Layers def:
relu = nn.ReLU()
sigmoid = nn.Sigmoid()
softmax = nn.Softmax(dim=1)
Lsoftmax = nn.LogSoftmax(dim=1)
mse_loss = nn.MSELoss(reduction='mean')


# Training and Testing

In [None]:
# Two factorized intermediate layers FCN (with selector):
# Use K-fold for parameter sweep if needed.
# fw1, fw2 are concatenated instead of added.
#
# Each hidden layer weight is built as U @ diag(s) @ V.T. Task 0 trains r1
# columns of U and V; every later task appends r2 new trainable columns to the
# stored (detached) U/V of the earlier tasks and trains a selector s over all
# columns, plus its own biases and output layer w3.


def _factorized_weight(U, s, V):
    """Return the dense layer weight U @ diag(s) @ V.T."""
    return torch.matmul(torch.matmul(U, torch.diag(s)), V.T)


def _forward(x, w1, w2, w3, b1, b2, b3):
    """Apply the two ReLU hidden layers and return the softmax output."""
    h1 = relu(torch.matmul(x, w1) + b1)
    h2 = relu(torch.matmul(h1, w2) + b2)
    return softmax(torch.matmul(h2, w3) + b3)


def _count_correct(y_hat, y_onehot):
    """Return how many rows of `y_hat` have the same argmax as the one-hot rows of `y_onehot`."""
    prediction = np.argmax(y_hat.detach().cpu().numpy(), axis=1)
    y_label = np.argmax(y_onehot, axis=1)
    return int(np.sum(y_label == prediction))


def _evaluate_task(X, Y, task, r1, r2, U, V, S1, S2, biases, heads):
    """Return the accuracy on (X, Y) of the parameters stored for `task`.

    U and V hold the concatenated factor columns of both hidden layers; task
    `task` uses the first r1 + task * r2 columns together with its own
    selectors S1[task] / S2[task], biases[task] and output layer heads[task].
    """
    rank = r1 + task * r2
    correct, total = 0, 0
    with torch.no_grad():
        # The layer weights do not depend on the batch: build them once.
        w1 = _factorized_weight(U[0][:, 0:rank], S1[task], V[0][:, 0:rank])
        w2 = _factorized_weight(U[1][:, 0:rank], S2[task], V[1][:, 0:rank])
        b1, b2, b3 = biases[task]
        # Accuracy does not depend on sample order, so no shuffling is needed.
        for start in range(0, len(X), batch_size):
            y_batch = Y[start:start + batch_size]
            x_batch_tensor = torch.FloatTensor(X[start:start + batch_size]).to(device)
            y_hat = _forward(x_batch_tensor, w1, w2, heads[task], b1, b2, b3)
            correct += _count_correct(y_hat, y_batch)
            total += len(y_batch)
    return float(correct / total)


# Initial/trained parameters of every task, appended fold after fold.
# Entry layout: [wR1, wL1, wR2, wL2, w3, s1, s2, b1, b2, b3]
w_task=[]

# Train test sequence:
accuracy = []
for fold in range(NUM_FOLDS_VALIDATION):
    fold_accuracy = []
    if NUM_K_FOLDS <= 1:
        print("Full Training Set(no partition):")
    else:
        print("Folds Number ", fold, "/",NUM_FOLDS_VALIDATION)
    # NOTE(review): train_epochs is a per-task list but is indexed by fold here.
    print("Epochs:",train_epochs[fold], "Batch Size:", batch_size,"R1:",r1_byfolds[fold],"R2:",r2_byfolds[fold])

    r1 = r1_byfolds[fold]
    r2 = r2_byfolds[fold]

    # Reset model every fold:
    # Fresh weights are kept in a per-fold list so that training never picks
    # up the already-trained tensors of an earlier fold. The same entry lists
    # are also appended to w_task, so w_task sees the trained values.
    fold_weights = []
    for task in range(num_task):
        r = r1 if task == 0 else r2

        wR1 = torch.empty(x_train.shape[1], r, device=device)
        wL1 = torch.empty(256, r, device=device)
        s1 = torch.ones(r, device=device)
        b1 = torch.zeros(256, device=device)

        wR2 = torch.empty(256, r, device=device)
        wL2 = torch.empty(256, r, device=device)
        s2 = torch.ones(r, device=device)
        b2 = torch.zeros(256, device=device)

        w3 = torch.empty(256, n_class, device=device)
        b3 = torch.zeros(n_class, device=device)

        # Initialization:
        for weight in (wR1, wL1, wR2, wL2, w3):
            nn.init.orthogonal_(weight)

        fold_weights.append([wR1,wL1,wR2,wL2,w3,s1,s2,b1,b2,b3])
    w_task.extend(fold_weights)

    # Train-Test sequence:
    print("Task Progress:")
    # U_[n], S_[n], V_[n]: factors of the n-th intermediate layer accumulated
    # over the tasks trained so far (placeholders until task 0 has finished).
    U_ = [fold_weights[0][0], fold_weights[0][2]]
    S_ = [fold_weights[0][5], fold_weights[0][6]]
    V_ = [fold_weights[0][1], fold_weights[0][3]]
    S1_tasks = []
    S2_tasks = []
    b_tasks = []
    w3_tasks = []
    for task in tq.tqdm(range(num_task)):
        params = fold_weights[task]

        # The first task uses the base learning rate, later tasks the
        # continual one. It must be chosen here: a value left over from the
        # initialisation loop would always be the last task's rate.
        lr = lr_base if task == 0 else lr_cont

        # Trainable leaf tensors sharing storage with the stored initial values.
        dwR1, dwL1, dwR2, dwL2, w3 = (p.detach().requires_grad_(True) for p in params[0:5])
        # Selector
        if task == 0:
            ds1 = params[5].detach().requires_grad_(True)
            ds2 = params[6].detach().requires_grad_(True)
        else:
            # Start with the columns of all previous tasks switched off
            # (zeros) and the new columns switched on; the whole selector is
            # trainable.
            ds1 = torch.cat((torch.zeros(S1_tasks[task-1].shape, device=device), params[5]), dim = 0).requires_grad_(True)
            ds2 = torch.cat((torch.zeros(S2_tasks[task-1].shape, device=device), params[6]), dim = 0).requires_grad_(True)
        # Bias
        b1, b2, b3 = (p.detach().requires_grad_(True) for p in params[7:10])

        optimizer=torch.optim.Adam([dwR1,dwL1,dwR2,dwL2,w3,ds1,ds2,b1,b2,b3],lr=lr)

        X_fold = X_train[task][fold]
        Y_fold = Y_train[task][fold]
        train_size = len(X_fold)
        batch_no = int(np.ceil(train_size/np.float64(batch_size)))

        loss_epoch=[]
        epoch_accuracy = []
        times = 1
        for epoch in tq.tqdm(range(train_epochs[task])):
            epoch_idx=np.arange(train_size)
            np.random.shuffle(epoch_idx)

            # Reduce lr during training if needed:
            if epoch in lr_drop_epoch:
                new_lr = lr * (lr_drop_factor ** times)
                times = times + 1
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lr
                print(new_lr)

            loss_batch = 0
            epoch_correct, epoch_total = 0, 0
            for batch_idx in range(batch_no):
                # Slicing clamps at the end of the array, so the last batch
                # may simply be shorter.
                batch_rows = epoch_idx[batch_idx*batch_size:(batch_idx+1)*batch_size]
                x_batch = X_fold[batch_rows,:]
                y_batch = Y_fold[batch_rows]

                x_batch_tensor=torch.FloatTensor(x_batch).to(device)
                y_batch_tensor=torch.FloatTensor(y_batch).to(device)

                if task == 0:
                    w1 = _factorized_weight(dwR1, ds1, dwL1)
                    w2 = _factorized_weight(dwR2, ds2, dwL2)
                else:
                    # Frozen columns of the earlier tasks followed by the
                    # trainable columns of the current task.
                    w1 = _factorized_weight(torch.cat((U_[0], dwR1), dim = 1), ds1,
                                            torch.cat((V_[0], dwL1), dim = 1))
                    w2 = _factorized_weight(torch.cat((U_[1], dwR2), dim = 1), ds2,
                                            torch.cat((V_[1], dwL2), dim = 1))

                y_hat = _forward(x_batch_tensor, w1, w2, w3, b1, b2, b3)

                loss=mse_loss(y_hat,y_batch_tensor)
                loss_batch = loss_batch + loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Prediction:
                epoch_correct += _count_correct(y_hat, y_batch)
                epoch_total += len(y_batch)

            epoch_accuracy.append(float(epoch_correct / epoch_total))
            loss_epoch.append(loss_batch)

        with torch.no_grad():
            # Concatenate the weight matrix of every tasks:
            if task == 0:
                U_[0] = dwR1.detach(); U_[1] = dwR2.detach()
                V_[0] = dwL1.detach(); V_[1] = dwL2.detach()
            else:
                U_[0] = torch.cat((U_[0], dwR1.detach()), dim = 1); U_[1] = torch.cat((U_[1], dwR2.detach()), dim = 1)
                V_[0] = torch.cat((V_[0], dwL1.detach()), dim = 1); V_[1] = torch.cat((V_[1], dwL2.detach()), dim = 1)
            S_[0] = ds1.detach(); S_[1] = ds2.detach()

            # Save selector matrices of every tasks:
            S1_tasks.append(ds1.detach()); S2_tasks.append(ds2.detach())
            b_tasks.append([b1.detach(), b2.detach(), b3.detach()])
            w3_tasks.append(w3.detach())

            # Write the trained parameters back for *this* task. Slice
            # assignment keeps the list object that w_task also references.
            fold_weights[task][:] = [dwR1.detach(), dwL1.detach(), dwR2.detach(), dwL2.detach(),
                                     w3.detach(), ds1.detach(), ds2.detach(),
                                     b1.detach(), b2.detach(), b3.detach()]

        # Plot loss:
        plt.figure()
        plt.plot(loss_epoch)
        plt.title(['Subset:',str(fold),' Task:',str(task)])
        plt.show()
        # Plot epoch accuracy:
        plt.figure()
        plt.plot(epoch_accuracy)
        plt.title(['Epochs accuracy:'])
        plt.show()

    # Training accuracy:
    train_accuracy = [
        _evaluate_task(X_train[t][fold], Y_train[t][fold], t, r1, r2,
                       U_, V_, S1_tasks, S2_tasks, b_tasks, w3_tasks)
        for t in range(num_task)
    ]
    print("Training Accuracy: ", train_accuracy, "Avg: ", np.mean(train_accuracy))

    # Test accuracy:
    task_accuracy = [
        _evaluate_task(X_test[t][fold], Y_test[t][fold], t, r1, r2,
                       U_, V_, S1_tasks, S2_tasks, b_tasks, w3_tasks)
        for t in range(num_task)
    ]

    accuracy.append(task_accuracy)
    fold_accuracy.append(task_accuracy)
    print("Subset", fold, "Complete. Validation Accuracy: ", fold_accuracy, "Avg: ", np.mean(fold_accuracy))


# Keep the learned factors of the last fold for inspection and drop the
# working names.
disp_1 = [U_, S_, V_, S1_tasks, S2_tasks, b_tasks, w3_tasks]
del U_, S_, V_, S1_tasks, S2_tasks, b_tasks, w3_tasks

    