In [1]:
import os
import sys
import wandb
import argparse
import numpy as np


# Make project packages importable: after this, sys.path starts with the
# current working directory, followed by its parent directory.
for _rel_dir in ("../", ""):
    sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), _rel_dir)))
del _rel_dir
import torch
import torchvision.transforms as T
import torchvision

from dataloaders.dataloader_cifar10 import get_cifar10
from dataloaders.dataloader_cifar100 import get_cifar100
from utils.eval_metrics import linear_evaluation, get_t_SNE_plot
from models.linear_classifer import LinearClassifier
from models.ssl import  SimSiam, Siamese, Encoder, Predictor

from trainers.train_simsiam import train_simsiam
from trainers.train_infomax import train_infomax
from trainers.train_barlow import train_barlow

from trainers.train_PFR import train_PFR_simsiam
from trainers.train_PFR_contrastive import train_PFR_contrastive_simsiam
from trainers.train_contrastive import train_contrastive_simsiam
from trainers.train_ering import train_ering_simsiam

from torchsummary import summary
import random
from utils.lr_schedulers import LinearWarmupCosineAnnealingLR, SimSiamScheduler
from utils.eval_metrics import Knn_Validation_cont
from copy import deepcopy
from loss import invariance_loss,CovarianceLoss,ErrorCovarianceLoss
import torch.nn as nn
import time
import torch.nn.functional as F

# Enumerate GPUs by PCI bus id and expose devices 0-7 to this process.
# NOTE(review): these variables only take effect if set before CUDA is first initialized.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7",
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call draws ``sigma`` uniformly from ``[sigma[0], sigma[1]]`` and blurs the
    input with ``torchvision.transforms.functional.gaussian_blur``.

    Args:
        sigma: (low, high) range for the blur standard deviation.
        kernel_size: blur kernel size, either an int (square kernel) or a
            (height, width) pair; values must be odd and positive. Defaults to 3,
            the size that was previously hard-coded.
    """

    def __init__(self, sigma=(0.1, 2.0), kernel_size=3):
        # Copy into a fresh list so instances never alias a shared (mutable) default.
        self.sigma = list(sigma)
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = list(kernel_size)

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        # Kernel size and sigma range are tuning choices kept from the original experiments.
        return torchvision.transforms.functional.gaussian_blur(x, kernel_size=self.kernel_size, sigma=sigma)


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Serialize ``state`` with ``torch.save`` and optionally keep a copy as the best model.

    Args:
        state: object to save (typically a dict of state_dicts and metadata).
        is_best: when true, the written ``filename`` is also copied to ``best_filename``.
        filename: destination path for this checkpoint.
        best_filename: destination path for the best-model copy.
    """
    import shutil  # not imported at the top of the notebook; needed for the best-model copy

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)

In [30]:
class Args:
    """Run configuration: CIFAR-100 split into 4 tasks of 25 classes, approach 'barlow_PFR'.

    All settings are class attributes and are read as ``args.<name>``.
    """

    # Backbone normalisation
    normalization = 'group'
    weight_standard = True

    # Pre-training optimisation and schedule
    same_lr = False
    pretrain_batch_size = 512
    pretrain_warmup_epochs = 10
    pretrain_warmup_lr = 0.003
    pretrain_base_lr = 0.03
    pretrain_momentum = 0.9
    pretrain_weight_decay = 0.0005
    min_lr = 0.0

    # Method selection and runtime
    lambdap = 1.0
    appr = 'barlow_PFR'
    knn_report_freq = 10
    cuda_device = 5
    num_workers = 8
    contrastive_ratio = 0.001

    # Data and task split
    dataset = 'cifar100'
    class_split = [25, 25, 25, 25]
    epochs = [500, 500, 500, 500]

    # Loss weights, subspace and replay settings
    cov_loss_weight = 1.0
    sim_loss_weight = 250.0
    info_loss = 'invariance'
    lambda_norm = 1.0
    subspace_rate = 0.99
    lambda_param = 0.005
    bsize = 32
    msize = 150

    # Projector / predictor widths
    proj_hidden = 2048
    proj_out = 2048  # infomax
    pred_hidden = 512
    pred_out = 2048



In [31]:
# Instantiate the configuration; every hyper-parameter is a class attribute of Args.
args = Args()

In [15]:
# Resolve the dataset-specific loader factory and total class count from the config.
if args.dataset == "cifar10":
    get_dataloaders = get_cifar10
    num_classes = 10
elif args.dataset == "cifar100":
    get_dataloaders = get_cifar100
    num_classes = 100
else:
    # Fail fast with a clear message instead of a NameError (or a stale value) for num_classes below.
    raise ValueError(f"Unsupported dataset: {args.dataset!r} (expected 'cifar10' or 'cifar100')")
# The task split must sum to the total number of classes, with one epoch count per task.
assert sum(args.class_split) == num_classes
assert len(args.class_split) == len(args.epochs)

In [19]:
num_worker = args.num_workers
# Use the configured GPU index when CUDA is available, otherwise run on the CPU.
device = torch.device(f"cuda:{args.cuda_device}" if torch.cuda.is_available() else "cpu")
print(device)

cuda:5


In [20]:
# Experiment tracking (logging is currently switched off via mode="disabled").
# All hyper-parameters are *class* attributes of Args, so the instance's own __dict__ is
# empty; wandb converts non-dict configs via vars(), which would record an empty config.
# Build an explicit dict instead (instance attributes, if any were set, override defaults).
wandb.init(project="CSSL", entity="yavuz-team",
           mode="disabled",
           config={**{k: v for k, v in vars(type(args)).items() if not k.startswith('_')}, **vars(args)},
           name=str(args.dataset) + '-algo' + str(args.appr) + "-e" + str(args.epochs) + "-b"
           + str(args.pretrain_batch_size) + "-lr" + str(args.pretrain_base_lr) + "-CS" + str(args.class_split))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [21]:
def _ssl_view_transform():
    """Build one stochastic augmentation pipeline for a self-supervised view.

    Random resized crop, horizontal flip, colour jitter, grayscale and Gaussian blur,
    followed by per-channel normalisation with CIFAR-style mean/std. There is no
    ToTensor step, so inputs presumably arrive as tensors already.
    """
    return T.Compose([
            T.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)]), p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur()], p=0.5),
            T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])


if 'infomax' in args.appr or 'barlow' in args.appr:
    # The two views use identically configured but independently sampled pipelines.
    transform = _ssl_view_transform()
    transform_prime = _ssl_view_transform()

In [22]:
print("Creating Dataloaders..")
# Class-incremental loaders: presumably one loader per entry of args.class_split.
(train_data_loaders, train_data_loaders_knn, test_data_loaders, _,
 train_data_loaders_linear, train_data_loaders_pure) = get_dataloaders(
    transform, transform_prime,
    classes=args.class_split, valid_rate=0.00, batch_size=args.pretrain_batch_size,
    seed=0, num_worker=num_worker)
# Joint loaders over all classes at once (a single split), used for global evaluation.
(_, train_data_loaders_knn_all, test_data_loaders_all, _,
 train_data_loaders_linear_all, train_data_loaders_pure_all) = get_dataloaders(
    transform, transform_prime,
    classes=[num_classes], valid_rate=0.00, batch_size=args.pretrain_batch_size,
    seed=0, num_worker=num_worker)


Creating Dataloaders..
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [32]:
device = torch.device(f"cuda:{args.cuda_device}" if torch.cuda.is_available() else "cpu")
print(device)
if 'barlow' in args.appr or 'infomax' in args.appr:
    # Siamese SSL model around an Encoder whose projector uses the configured widths.
    proj_hidden, proj_out = args.proj_hidden, args.proj_out
    encoder = Encoder(
        hidden_dim=proj_hidden,
        output_dim=proj_out,
        normalization=args.normalization,
        weight_standard=args.weight_standard,
        appr_name=args.appr,
    )
    model = Siamese(encoder)
    model.to(device)  # nn.Module.to moves the parameters in place

cuda:5


In [26]:
# Load a pretrained checkpoint.
# map_location='cpu' restores tensors on the CPU instead of the GPU index they were saved
# from (which may be missing or differ from `device` here); load_state_dict later copies
# the weights onto the model's own device. torch.load unpickles: only load trusted files.
# NOTE(review): `dict` shadows the Python builtin; later cells read dict['state_dict'], so rename them together.
file_name = 'checkpoints/checkpoint_cifar100-algocassle_barlow-e[500, 500, 500, 500]-b256-lr0.06-CS[25, 25, 25, 25]acc_60.42999999999999.pth.tar'
dict = torch.load(file_name, map_location='cpu')

In [29]:
# Inspect which entries the loaded checkpoint provides.
dict.keys()

dict_keys(['arch', 'lr', 'state_dict', 'optimizer', 'loss', 'encoder', 'classifier'])

In [34]:
# Attach an extra `temporal_projector` MLP (proj_out -> proj_hidden -> proj_out) to the model,
# presumably so its module layout matches the checkpoint's state_dict keys before loading.
model.temporal_projector = nn.Sequential(
            nn.Linear(args.proj_out, args.proj_hidden, bias=False),
            nn.BatchNorm1d(args.proj_hidden),
            nn.ReLU(),
            nn.Linear(args.proj_hidden, args.proj_out),
        ).to(device)

In [35]:
# Copy the checkpoint weights into the model (strict key matching by default).
model.load_state_dict(dict['state_dict'])

<All keys matched successfully>

In [36]:
def collect_activations(net, train_data_loader, device, orth_set, max_batches=16):
    """Collect per-layer im2col patch matrices and remove their component in an existing basis.

    Runs ``net`` over up to ``max_batches`` batches of ``train_data_loader`` and reads the
    activations cached on its blocks (``net[i][j].act['conv_0' | 'conv_1']``). For each of the
    20 tracked weight layers the activations are unfolded into patches using that layer's
    stride: kernel 3 after 1-pixel zero padding, or kernel 1 (no padding) for shortcut convs.
    This gives a ``(features, num_patches)`` matrix ``M``. If ``orth_set[layer]`` holds a basis
    ``U`` (assumed orthonormal columns — TODO confirm), ``M`` is replaced by ``M - U U^T M``.

    Args:
        net: backbone indexable as ``net[block][sub]`` whose sub-blocks expose an ``.act`` dict
            filled during the forward pass (assumes a ResNet-18-style layout — TODO confirm).
        train_data_loader: iterable of ``(x, label)`` batches.
        device: device used for the forward pass; unfolding and projection run on the CPU.
        orth_set: dict mapping each layer name to a basis tensor or ``None``.
        max_batches: number of batches to use (16 matches the previous hard-coded limit).

    Returns:
        dict mapping layer name -> CPU tensor of shape ``(features, num_patches)``.

    Raises:
        ValueError: if no batch was collected.
    """
    layer_names = ['0.weight', '3.0.conv1.weight', '3.0.conv2.weight', '3.1.conv1.weight',
                   '3.1.conv2.weight', '4.0.conv1.weight', '4.0.conv2.weight', '4.0.shortcut.0.weight',
                   '4.1.conv1.weight', '4.1.conv2.weight', '5.0.conv1.weight', '5.0.conv2.weight',
                   '5.0.shortcut.0.weight', '5.1.conv1.weight', '5.1.conv2.weight', '6.0.conv1.weight',
                   '6.0.conv2.weight', '6.0.shortcut.0.weight', '6.1.conv1.weight', '6.1.conv2.weight']
    # Stride used when unfolding each layer's activations (same order as layer_names).
    stride_list = [1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 1, 2, 1, 2, 1, 1]

    net.eval()
    collected = {name: [] for name in layer_names}
    with torch.no_grad():  # activations are only read, so don't build autograd graphs
        for batch_index, (x, _) in enumerate(train_data_loader):
            if batch_index >= max_batches:
                break
            x = x.to(device)
            net(x)  # forward pass; presumably refreshes the blocks' .act caches
            # One tensor per entry of layer_names; a block's shortcut reuses that block's 'conv_0' activation.
            act_list = [x,
                        net[3][0].act['conv_0'], net[3][0].act['conv_1'], net[3][1].act['conv_0'], net[3][1].act['conv_1'],
                        net[4][0].act['conv_0'], net[4][0].act['conv_1'], net[4][0].act['conv_0'], net[4][1].act['conv_0'], net[4][1].act['conv_1'],
                        net[5][0].act['conv_0'], net[5][0].act['conv_1'], net[5][0].act['conv_0'], net[5][1].act['conv_0'], net[5][1].act['conv_1'],
                        net[6][0].act['conv_0'], net[6][0].act['conv_1'], net[6][0].act['conv_0'], net[6][1].act['conv_0'], net[6][1].act['conv_1']]
            for name, act in zip(layer_names, act_list):
                collected[name].append(act.detach().cpu())

    if not collected[layer_names[0]]:
        raise ValueError("train_data_loader yielded no batches (or max_batches <= 0)")

    result = {}
    for name, stride in zip(layer_names, stride_list):
        act = torch.cat(collected[name], dim=0)
        if "shortcut" in name:
            kernel = 1
        else:
            act = F.pad(act, (1, 1, 1, 1), "constant", 0)
            kernel = 3

        unfolder = torch.nn.Unfold(kernel, dilation=1, padding=0, stride=stride)
        # (N, C*k*k, L) -> (N*L, C*k*k) -> transpose to (C*k*k, N*L)
        mat = unfolder(act).permute(0, 2, 1)
        mat = mat.reshape(-1, mat.shape[2]).T

        basis = orth_set[name]
        if basis is not None:
            basis = basis.cpu()
            mat = mat - basis @ basis.T @ mat
        result[name] = mat.cpu()
    return result
 

In [37]:
from sklearn.utils.extmath import randomized_svd

def expand_orth_set(activations, orth_set, eps = 0.95, verbose=True):
    """Grow each layer's basis until it represents roughly an ``eps`` share of activation energy.

    For every layer ``key`` with activation matrix ``A`` (features x samples):
      1. project ``A`` onto the existing basis ``orth_set[key]`` (``None`` = empty basis) and
         compute the captured energy ratio ``||U U^T A||^2 / ||A||^2``;
      2. if that ratio already reaches ``eps``, nothing is added;
      3. otherwise take the SVD of the residual's second-moment matrix ``R R^T / n`` and
         append the fewest leading singular vectors whose cumulative spectrum share exceeds
         ``eps - ratio`` (all of them if none does).

    ``orth_set`` is updated in place and also returned for convenience.

    Args:
        activations: dict layer name -> 2-D tensor (features x samples).
        orth_set: dict layer name -> basis tensor (features x k) or ``None``.
        eps: target fraction of activation energy to represent.
        verbose: print the shape of each newly added block of basis vectors.

    Returns:
        The updated ``orth_set``.
    """
    for key, act in activations.items():
        basis = orth_set[key]
        act_energy = torch.norm(act) ** 2
        if act_energy == 0:
            # All-zero activations: nothing to represent (the ratio below would be 0/0).
            continue

        if basis is None:
            remaining = act
            captured = 0.0
        else:
            projected = basis @ basis.T @ act
            remaining = act - projected
            captured = torch.norm(projected) ** 2 / act_energy

        eps_new = eps - captured
        if eps_new <= 0:
            # Existing basis already meets the threshold; don't force an extra vector.
            continue

        # Second moment of the residual: symmetric PSD, so U holds its eigenvectors and S its
        # eigenvalues (i.e. already squared singular values of the residual).
        cov = remaining @ remaining.T / remaining.shape[1]
        U, S, _ = torch.linalg.svd(cov.cpu(), full_matrices=False)
        tot = S.sum()  # equals the L1 norm because singular values are non-negative
        if tot <= 0:
            # Numerically zero residual: a 0/0 share would otherwise select every vector.
            continue

        # Smallest k whose cumulative spectrum share exceeds eps_new (all vectors if none does).
        share = torch.cumsum(S, dim=0) / tot
        hits = torch.nonzero(share > eps_new)
        k = int(hits[0, 0]) + 1 if hits.numel() > 0 else S.numel()
        new_basis = U[:, :k]
        if verbose:
            print(new_basis.shape)

        if basis is None:
            orth_set[key] = new_basis
        else:
            orth_set[key] = torch.cat((basis, new_basis.to(basis.device)), dim=1)
    return orth_set

In [38]:
def get_children(model: torch.nn.Module):
    """Recursively flatten ``model`` into its leaf modules (modules without children).

    Returns the module itself when ``model`` has no children; otherwise a flat list of leaf
    modules in depth-first ``model.children()`` order.
    """
    children = list(model.children())
    if not children:
        # No children: the model itself is a leaf.
        return model
    leaves = []
    for child in children:
        sub = get_children(child)
        # A leaf comes back as a Module, a subtree as a list. Checking the type (rather than
        # try/extend/except TypeError) keeps iterable leaves such as an empty nn.Sequential
        # and does not mask unrelated TypeErrors.
        if isinstance(sub, list):
            leaves.extend(sub)
        else:
            leaves.append(sub)
    return leaves

In [19]:
# all_modules = get_children(model)
# for m in all_modules:
#     if isinstance(m, nn.GroupNorm):
#         m.track_running_stats = False
# for key,p in model.encoder.backbone.named_parameters():
#     if 'bias' in key or 'bn' in key or 'shortcut.1' in key or '1.weight' == key:
#         print(key)
#         p.requires_grad = False

In [39]:
def update_memory(memory, dataloader, size):
    """Append a random sample of ``size`` examples from ``dataloader.dataset`` to ``memory``.

    Sampling is without replacement using the global NumPy RNG. ``size`` is capped at the
    dataset length, so a small dataset no longer makes ``np.random.choice`` raise.

    Args:
        memory: tensor of previously stored inputs (may be an empty ``torch.Tensor()``).
        dataloader: object whose ``.dataset`` accepts an index array and returns ``(x, y)``
            (assumes the dataset supports fancy indexing — TODO confirm for all loaders).
        size: number of examples to add.

    Returns:
        ``memory`` with the sampled inputs concatenated along dim 0.
    """
    n = min(size, len(dataloader.dataset))
    indices = np.random.choice(len(dataloader.dataset), size=n, replace=False)
    x, _ = dataloader.dataset[indices]
    return torch.cat((memory, x), dim=0)

def _lrd_subspace_losses(f1, f2, f1_old, f2_old, Q):
    """Split features by the old-task subspace ``Q`` and compute the two norm-ratio penalties.

    With ``P = Q @ Q.T`` (a projector if Q has orthonormal columns — TODO confirm):
      * current-task views: ``f_proj = f @ P``, ``f_rem = f - f_proj``; ``loss_norm`` is the
        batch mean of ``||f_proj|| / (||f_rem|| + 1e-7)``, averaged over the two views;
      * replayed memory views: ``loss_norm_old`` is the batch mean of
        ``||f_rem_old|| / (||f_proj_old|| + 1e-7)``, averaged over the two views.

    Returns:
        (f1_rem, f2_rem, f1_proj, f2_proj, loss_norm, loss_norm_old)
    """
    f1_proj = f1 @ Q @ Q.T
    f2_proj = f2 @ Q @ Q.T
    f1_rem = f1 - f1_proj
    f2_rem = f2 - f2_proj

    norm_loss_1 = torch.mean(torch.norm(f1_proj, dim=1) / (torch.norm(f1_rem, dim=1) + 0.0000001))
    norm_loss_2 = torch.mean(torch.norm(f2_proj, dim=1) / (torch.norm(f2_rem, dim=1) + 0.0000001))
    loss_norm = (norm_loss_1 + norm_loss_2) / 2

    f1_proj_old = f1_old @ Q @ Q.T
    f2_proj_old = f2_old @ Q @ Q.T
    f1_rem_old = f1_old - f1_proj_old
    f2_rem_old = f2_old - f2_proj_old

    norm_loss_1 = torch.mean(torch.norm(f1_rem_old, dim=1) / (torch.norm(f1_proj_old, dim=1) + 0.0000001))
    norm_loss_2 = torch.mean(torch.norm(f2_rem_old, dim=1) / (torch.norm(f2_proj_old, dim=1) + 0.0000001))
    loss_norm_old = (norm_loss_1 + norm_loss_2) / 2

    return f1_rem, f2_rem, f1_proj, f2_proj, loss_norm, loss_norm_old


def train_LRD_replay_infomax(model, train_data_loaders, knn_train_data_loaders, train_data_loaders_pure, test_data_loaders, device, args, transform, transform_prime):
    """Continual SSL training with a low-rank old-task subspace, replay and distillation.

    For each task (one loader per task in ``train_data_loaders``):
      * trains with the weighted invariance + covariance objective
        (``args.sim_loss_weight`` / ``args.cov_loss_weight``);
      * once a subspace ``Q`` exists (tasks after the first), removes from the current backbone
        features their component in ``Q`` before the projector, and adds norm-ratio penalties
        on current (``loss_norm``) and replayed memory (``loss_norm_old``) features plus a
        cosine distillation term (``lossKD``) against a frozen copy of the previous backbone;
      * after the task: freezes a backbone copy, recomputes ``Q`` with ``extract_subspace``,
        adds ``args.msize`` samples to the replay memory and saves a checkpoint under
        ``./checkpoints/``.

    Returns:
        (model, loss_, optimizer), where ``loss_`` holds the per-epoch mean task loss of the
        last task only (it is reset at the start of every task).
    """
    memory = torch.Tensor()  # replay buffer of raw inputs, grown after every task

    epoch_counter = 0
    oldModel = None  # frozen previous-task backbone (distillation teacher); set after task 0
    criterion = nn.CosineSimilarity(dim=1)
    Q = None  # old-task feature subspace basis; None during the first task

    for task_id, loader in enumerate(train_data_loaders):
        # Optimizer and scheduler are rebuilt for every task.
        model.task_id = task_id
        # LR scaled linearly with batch size; later tasks use 1/10 of it unless args.same_lr.
        init_lr = args.pretrain_base_lr * args.pretrain_batch_size / 256.
        if task_id != 0 and args.same_lr != True:
            init_lr = init_lr / 10

        project_dim = args.proj_out
        covarince_loss = CovarianceLoss(project_dim, device=device)

        optimizer = torch.optim.SGD(model.parameters(), lr=init_lr, momentum=args.pretrain_momentum, weight_decay=args.pretrain_weight_decay)
        # Warmup + cosine schedule; values follow the infomax paper.
        scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=args.pretrain_warmup_epochs, max_epochs=args.epochs[task_id], warmup_start_lr=args.pretrain_warmup_lr, eta_min=args.min_lr)

        loss_ = []
        for epoch in range(args.epochs[task_id]):
            start = time.time()
            model.train()
            epoch_loss_task = []
            epoch_loss_kd = []
            epoch_loss_norm = []
            epoch_loss_norm_old = []
            for x1, x2, y in loader:
                x1, x2 = x1.to(device), x2.to(device)
                f1 = model.encoder.backbone(x1).squeeze()  # NxC
                f2 = model.encoder.backbone(x2).squeeze()  # NxC

                if task_id > 0:
                    # Replay: two augmented views of a random mini-batch drawn from memory.
                    # NOTE(review): the transforms are applied to the whole batch in one call, so
                    # the random augmentation parameters are likely shared across the batch.
                    indices = np.random.choice(len(memory), size=min(args.bsize, len(memory)), replace=False)
                    x = memory[indices].to(device)
                    x1_old, x2_old = transform(x), transform_prime(x)
                    f1_old = model.encoder.backbone(x1_old).squeeze()  # NxC
                    f2_old = model.encoder.backbone(x2_old).squeeze()  # NxC

                if Q is not None:
                    # Q only exists from task 1 on, so f1_old / f2_old are defined here.
                    f1, f2, f1_projected, f2_projected, loss_norm, loss_norm_old = _lrd_subspace_losses(f1, f2, f1_old, f2_old, Q)
                else:
                    loss_norm = torch.tensor(0)
                    loss_norm_old = torch.tensor(0)

                z1 = model.encoder.projector(f1)  # NxC
                z2 = model.encoder.projector(f2)  # NxC

                z1 = F.normalize(z1, p=2)
                z2 = F.normalize(z2, p=2)
                cov_loss = covarince_loss(z1, z2)
                sim_loss = invariance_loss(z1, z2)

                loss_task = (args.sim_loss_weight * sim_loss) + (args.cov_loss_weight * cov_loss)

                if task_id != 0:
                    # Distillation: cosine-align the in-subspace component of the current features
                    # with the frozen previous backbone's features (teacher needs no graph).
                    with torch.no_grad():
                        f1Old = oldModel(x1).squeeze()
                        f2Old = oldModel(x2).squeeze()
                    lossKD = -(criterion(f1_projected, f1Old).mean() * 0.5
                               + criterion(f2_projected, f2Old).mean() * 0.5)
                else:
                    lossKD = torch.tensor(0)

                epoch_loss_task.append(loss_task.item())
                epoch_loss_kd.append(lossKD.item())
                epoch_loss_norm.append(loss_norm.item())
                epoch_loss_norm_old.append(loss_norm_old.item())

                optimizer.zero_grad()
                loss = loss_task + args.lambdap * lossKD + args.lambda_norm * loss_norm + args.lambda_norm * loss_norm_old
                loss.backward()
                optimizer.step()

            epoch_counter += 1
            scheduler.step()
            loss_.append(np.mean(epoch_loss_task))
            end = time.time()
            if (epoch+1) % args.knn_report_freq == 0:
                knn_acc, task_acc_arr = Knn_Validation_cont(model, knn_train_data_loaders[:task_id+1], test_data_loaders[:task_id+1], device=device, K=200, sigma=0.5)
                wandb.log({" Global Knn Accuracy ": knn_acc, " Epoch ": epoch_counter})
                for i, acc in enumerate(task_acc_arr):
                    wandb.log({" Knn Accuracy Task-"+str(i): acc, " Epoch ": epoch_counter})
                    print(f" Knn Accuracy Task- {str(i)} : {acc},  Epoch : {epoch_counter}")
                print(f'Task {task_id:2d} | Epoch {epoch:3d} | Time:  {end-start:.1f}s  | Loss: {np.mean(epoch_loss_task):.4f} | KDLoss: {np.mean(epoch_loss_kd):.4f} | Norm_Loss: {np.mean(epoch_loss_norm):.4f}  | Norm_Loss Old:  {np.mean(epoch_loss_norm_old):.4f}   | Knn:  {knn_acc*100:.2f}')
                print(task_acc_arr)
            else:
                print(f'Task {task_id:2d} | Epoch {epoch:3d} | Time:  {end-start:.1f}s  | Loss: {np.mean(epoch_loss_task):.4f} | KDLoss: {np.mean(epoch_loss_kd):.4f} | Norm_Loss: {np.mean(epoch_loss_norm):.4f} | Norm_Loss Old:  {np.mean(epoch_loss_norm_old):.4f} ')

            wandb.log({" Average Training Loss ": np.mean(epoch_loss_task), " Epoch ": epoch_counter, " Average KD Loss ": np.mean(epoch_loss_kd) , " Average Norm Loss ": np.mean(epoch_loss_norm) , " Average Norm Loss Old": np.mean(epoch_loss_norm_old) })
            wandb.log({" lr ": optimizer.param_groups[0]['lr'], " Epoch ": epoch_counter})

        # Snapshot the t-1 backbone as a frozen teacher. Note .eval() here also leaves the live
        # backbone in eval mode until the next model.train().
        with torch.no_grad():
            oldModel = deepcopy(model.encoder.backbone.eval())
        oldModel.to(device)
        oldModel.eval()
        for param in oldModel.parameters():
            param.requires_grad = False

        # NOTE(review): extract_subspace is not defined in this part of the notebook.
        Q = extract_subspace(model, knn_train_data_loaders[task_id], rate=args.subspace_rate, device=device, Q_prev=Q)
        Q = Q.to(device)

        memory = update_memory(memory, train_data_loaders_pure[task_id], args.msize)

        # torch.save does not create missing directories; without this the save fails after a full task.
        os.makedirs('./checkpoints', exist_ok=True)
        file_name = './checkpoints/checkpoint_' + str(args.dataset) + '-algo' + str(args.appr) + "-e" + str(args.epochs) + "-b" + str(args.pretrain_batch_size) + "-lr" + str(args.pretrain_base_lr) + "-CS" + str(args.class_split) + '_task_' + str(task_id) + '_lambdap_' + str(args.lambdap) + '_lambda_norm_' + str(args.lambda_norm) + '_same_lr_' + str(args.same_lr) + '_norm_' + str(args.normalization) + '_ws_' + str(args.weight_standard) + '.pth.tar'

        torch.save({
                        'state_dict': model.state_dict(),
                        'optimizer' : optimizer.state_dict(),
                        'encoder': model.encoder.backbone.state_dict(),
                    }, file_name)

    return model, loss_, optimizer

In [29]:
class Args:
    """Run configuration: CIFAR-10 split into 2 tasks of 5 classes, approach 'LRD_infomax'.

    All settings are class attributes and are read as ``args.<name>``.
    """

    # Backbone normalisation
    normalization = 'group'
    weight_standard = True

    # Pre-training optimisation and schedule
    same_lr = False
    pretrain_batch_size = 512
    pretrain_warmup_epochs = 10
    pretrain_warmup_lr = 0.003
    pretrain_base_lr = 0.1
    pretrain_momentum = 0.9
    pretrain_weight_decay = 0.0005
    min_lr = 0.0

    # Method selection and runtime
    lambdap = 100.0
    appr = 'LRD_infomax'
    knn_report_freq = 1
    cuda_device = 2
    num_workers = 8
    contrastive_ratio = 0.001

    # Data and task split
    dataset = 'cifar10'
    class_split = [5, 5]
    epochs = [50, 500]

    # Loss weights, subspace and replay settings
    cov_loss_weight = 1.0
    sim_loss_weight = 250.0
    info_loss = 'invariance'
    lambda_norm = 1.0
    subspace_rate = 0.98
    lambda_param = 0.005
    bsize = 32
    msize = 150

    # Projector / predictor widths
    proj_hidden = 2048
    proj_out = 64  # infomax
    pred_hidden = 512
    pred_out = 2048
    # NOTE(review): m_size looks unused; train_LRD_replay_infomax reads args.msize.
    m_size = 100

In [30]:
# Re-instantiate the configuration from the most recently defined Args class.
args = Args()

In [31]:
# Sanity check: display the norm-loss weight of the active configuration.
args.lambda_norm

1.0

In [34]:
# Train over all tasks; `loss` receives the per-epoch mean task losses of the last task.
model, loss, optimizer = train_LRD_replay_infomax(
    model,
    train_data_loaders,
    knn_train_data_loaders=train_data_loaders_knn,
    train_data_loaders_pure=train_data_loaders_pure,
    test_data_loaders=test_data_loaders,
    device=device,
    args=args,
    transform=transform,
    transform_prime=transform_prime,
)

epoch end
 Knn Accuracy Task- 0 : 0.5298,  Epoch : 1
Task  0 | Epoch   0 | Time:  49.0s  | Loss: 13.2724 | KDLoss: 0.0000 | Norm_Loss: 0.0000  | Norm_Loss Old:  0.0000   | Knn:  52.98
[0.5298]
epoch end
 Knn Accuracy Task- 0 : 0.5468,  Epoch : 2
Task  0 | Epoch   1 | Time:  29.8s  | Loss: 13.9099 | KDLoss: 0.0000 | Norm_Loss: 0.0000  | Norm_Loss Old:  0.0000   | Knn:  54.68
[0.5468]
epoch end
 Knn Accuracy Task- 0 : 0.5546,  Epoch : 3
Task  0 | Epoch   2 | Time:  30.4s  | Loss: 12.6586 | KDLoss: 0.0000 | Norm_Loss: 0.0000  | Norm_Loss Old:  0.0000   | Knn:  55.46
[0.5546]
epoch end
 Knn Accuracy Task- 0 : 0.576,  Epoch : 4
Task  0 | Epoch   3 | Time:  30.8s  | Loss: 12.5329 | KDLoss: 0.0000 | Norm_Loss: 0.0000  | Norm_Loss Old:  0.0000   | Knn:  57.60
[0.576]
epoch end
 Knn Accuracy Task- 0 : 0.593,  Epoch : 5
Task  0 | Epoch   4 | Time:  31.1s  | Loss: 12.3795 | KDLoss: 0.0000 | Norm_Loss: 0.0000  | Norm_Loss Old:  0.0000   | Knn:  59.30
[0.593]
epoch end
 Knn Accuracy Task- 0 : 0.619

KeyboardInterrupt: 

In [27]:
if 'barlow' in args.appr or 'infomax' in args.appr:
    # Second Siamese model (same architecture) to hold the checkpointed weights.
    proj_hidden, proj_out = args.proj_hidden, args.proj_out
    encoder = Encoder(
        hidden_dim=proj_hidden,
        output_dim=proj_out,
        normalization=args.normalization,
        weight_standard=args.weight_standard,
        appr_name=args.appr,
    )
    old_model = Siamese(encoder)
    old_model.to(device)  # nn.Module.to moves the parameters in place

In [28]:
# Copy the checkpoint weights into old_model (strict key matching by default).
old_model.load_state_dict(dict['state_dict'])

<All keys matched successfully>

In [40]:
# Extract the first task's feature subspace Q from the checkpointed model, then drop the model.
# NOTE(review): this cell raised a cross-device RuntimeError when it was run; extract_subspace is
# not defined in this part of the notebook, so confirm it moves both model and inputs to `device`.
Q = extract_subspace(old_model, train_data_loaders_knn[0], rate= args.subspace_rate,device = device,Q_prev = None)
old_model = None

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)

In [None]:
# Diagnostic on task-1 batches: ratio of the feature norm inside the old subspace Q to the norm
# outside it (same quantity as loss_norm in the training loop), computed on the CPU.
# NOTE(review): this rebinds the global `device` to 'cpu' and moves `model` to the CPU;
# cells run afterwards inherit both.
device = 'cpu'
Q = Q.to(device)
model.eval()
model.to(device)
with torch.no_grad():  # measurement only: don't build an autograd graph for every batch
    for x1, x2, y in train_data_loaders[1]:
        x1, x2 = x1.to(device), x2.to(device)
        f1 = model.encoder.backbone(x1).squeeze()  # NxC
        f2 = model.encoder.backbone(x2).squeeze()  # NxC

        if Q is not None:  # project onto / away from the old subspace
            f1_projected = f1 @ Q @ Q.T
            f2_projected = f2 @ Q @ Q.T

            f1 = f1 - f1_projected
            f2 = f2 - f2_projected

            norm_loss_1 = torch.norm(f1_projected, dim=1) / (torch.norm(f1, dim=1) + 0.0000001)
            norm_loss_1 = torch.mean(norm_loss_1)

            norm_loss_2 = torch.norm(f2_projected, dim=1) / (torch.norm(f2, dim=1) + 0.0000001)
            norm_loss_2 = torch.mean(norm_loss_2)

            loss_norm = (norm_loss_1 + norm_loss_2) / 2
            print(loss_norm)

tensor(0.9881, grad_fn=<DivBackward0>)


KeyboardInterrupt: 

In [None]:
# kNN accuracy of the current model on the first task only. Both slices use [:1]: `task_id`
# exists only inside the training function, so `task_id+1` raises NameError on a fresh kernel.
knn_acc, task_acc_arr = Knn_Validation_cont(model, train_data_loaders_knn[:1], test_data_loaders[:1], device=device, K=200, sigma=0.5)

In [37]:
# Release cached but unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [26]:
# Linear-probe evaluation: train a linear classifier on the model's features over all classes.
print("Starting Classifier Training..")
lin_epoch = 100
if args.dataset not in ('cifar10', 'cifar100'):
    # Without this, `classifier` would be undefined (or stale from an earlier run) below.
    raise ValueError(f"Unsupported dataset for linear evaluation: {args.dataset!r}")
classifier = LinearClassifier(num_classes=10 if args.dataset == 'cifar10' else 100).to(device)

lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.1, momentum=0.9)  # Infomax: no weight decay, epoch 100, cosine scheduler
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=2e-4)  # scheduler + values ref: infomax paper
test_loss, test_acc1, test_acc5, classifier = linear_evaluation(model, train_data_loaders_knn_all[0],
                                                                    test_data_loaders_all[0], lin_optimizer, classifier,
                                                                    lin_scheduler, epochs=lin_epoch, device=device)


Starting Classifier Training..


Lin.Train Epoch: [1] Loss: 1.9388 : 100%|██████████| 98/98 [00:09<00:00, 10.29it/s]
Lin.Test Epoch: [1] Loss: 1.7908 ACC@1: 63.40% ACC@5: 80.53% : 100%|██████████| 20/20 [00:02<00:00,  6.98it/s]
Lin.Train Epoch: [2] Loss: 1.6584 : 100%|██████████| 98/98 [00:09<00:00, 10.25it/s]
Lin.Test Epoch: [2] Loss: 1.6033 ACC@1: 69.04% ACC@5: 85.03% : 100%|██████████| 20/20 [00:02<00:00,  7.09it/s]
Lin.Train Epoch: [3] Loss: 1.4899 : 100%|██████████| 98/98 [00:09<00:00, 10.60it/s]
Lin.Test Epoch: [3] Loss: 1.4646 ACC@1: 71.50% ACC@5: 87.95% : 100%|██████████| 20/20 [00:02<00:00,  6.97it/s]
Lin.Train Epoch: [4] Loss: 1.3616 : 100%|██████████| 98/98 [00:09<00:00, 10.46it/s]
Lin.Test Epoch: [4] Loss: 1.3554 ACC@1: 72.52% ACC@5: 90.94% : 100%|██████████| 20/20 [00:02<00:00,  7.05it/s]
Lin.Train Epoch: [5] Loss: 1.2603 : 100%|██████████| 98/98 [00:09<00:00, 10.20it/s]
Lin.Test Epoch: [5] Loss: 1.2686 ACC@1: 73.70% ACC@5: 92.30% : 100%|██████████| 20/20 [00:02<00:00,  6.93it/s]
Lin.Train Epoch: [6] Loss

KeyboardInterrupt: 