In [1]:
import os
import sys
import wandb
import argparse
import numpy as np


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "")))
import torch
import torchvision.transforms as T
import torchvision

from dataloaders.dataloader_cifar10 import get_cifar10
from dataloaders.dataloader_cifar100 import get_cifar100
from utils.eval_metrics import linear_evaluation, get_t_SNE_plot
from models.linear_classifer import LinearClassifier
from models.ssl import  SimSiam, Siamese, Encoder, Predictor

from trainers.train_simsiam import train_simsiam
from trainers.train_infomax import train_infomax
from trainers.train_barlow import train_barlow

from trainers.train_PFR import train_PFR_simsiam
from trainers.train_PFR_contrastive import train_PFR_contrastive_simsiam
from trainers.train_contrastive import train_contrastive_simsiam
from trainers.train_ering import train_ering_simsiam

from torchsummary import summary
import random
from utils.lr_schedulers import LinearWarmupCosineAnnealingLR, SimSiamScheduler
from utils.eval_metrics import Knn_Validation_cont
from copy import deepcopy
from loss import invariance_loss,CovarianceLoss,ErrorCovarianceLoss
import torch.nn as nn
import time
import torch.nn.functional as F
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[0.1, 2.0]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = torchvision.transforms.functional.gaussian_blur(x,kernel_size=[3,3],sigma=sigma)#kernel size and sigma are open problems but right now seems ok!
        return x


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [21]:
class Args():
    normalization = 'batch'
    weight_standard = False
    same_lr = False
    pretrain_batch_size = 512
    pretrain_warmup_epochs = 10
    pretrain_warmup_lr = 3e-3
    pretrain_base_lr = 0.03
    pretrain_momentum = 0.9
    pretrain_weight_decay = 5e-4
    min_lr = 0.00
    lambdap = 1.0
    appr = 'barlow_PFR'
    knn_report_freq = 10
    cuda_device = 3
    num_workers = 8
    contrastive_ratio = 0.001
    dataset = 'cifar100'
    class_split = [20,20,20,20,20]
    epochs = [500,500,500,500,500]
    cov_loss_weight = 1.0
    sim_loss_weight = 250.0
    info_loss = 'invariance'
    lambda_norm = 1.0
    subspace_rate = 0.99
    lambda_param = 5e-3
    bsize = 32
    msize = 150
    proj_hidden = 2048
    proj_out = 2048 #infomax 64
    pred_hidden = 512
    pred_out = 2048



In [4]:
args = Args()

In [5]:
if args.dataset == "cifar10":
    get_dataloaders = get_cifar10
    num_classes=10
elif args.dataset == "cifar100":
    get_dataloaders = get_cifar100
    num_classes=100
assert sum(args.class_split) == num_classes
assert len(args.class_split) == len(args.epochs)

In [6]:
num_worker = args.num_workers
#device
device = torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
print(device)

cuda:3


In [7]:
#wandb init
wandb.init(project="CSSL",  entity="yavuz-team",
            mode="disabled",
            config=args,
            name= str(args.dataset) + '-algo' + str(args.appr) + "-e" + str(args.epochs) + "-b" 
            + str(args.pretrain_batch_size) + "-lr" + str(args.pretrain_base_lr)+"-CS"+str(args.class_split))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [8]:
if 'infomax' in args.appr or 'barlow' in args.appr:
    transform = T.Compose([
            T.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)]), p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur()], p=0.5), 
            T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])

    transform_prime = T.Compose([
            T.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)]), p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur()], p=0.5), 
            T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])

In [15]:
batch_size = args.pretrain_batch_size
train_data_loaders, train_data_loaders_knn, test_data_loaders, _, train_data_loaders_linear, train_data_loaders_pure, train_data_loaders_generic = get_dataloaders(transform, transform_prime, \
                                        classes=args.class_split, valid_rate = 0.00, batch_size=batch_size, seed = 0, num_worker= num_worker)
_, train_data_loaders_knn_all, test_data_loaders_all, _, train_data_loaders_linear_all, _, _ = get_dataloaders(transform, transform_prime, \
                                        classes=[num_classes], valid_rate = 0.00, batch_size=batch_size, seed = 0, num_worker= num_worker)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [16]:

from tqdm import tqdm
def correct_top_k(outputs, targets, top_k=(1,5)):
    with torch.no_grad():
        prediction = torch.argsort(outputs, dim=-1, descending=True)
        result= []
        for k in top_k:
            correct_k = torch.sum((prediction[:, 0:k] == targets.unsqueeze(dim=-1)).any(dim=-1).float()).item() 
            result.append(correct_k)
        return result

def linear_test(net, data_loader, classifier, epoch, device, task_sep=False, intra_task=False):
    # evaluate model:
    net.eval() # for not update batchnorm
    linear_loss = 0.0
    num = 0
    total_loss, total_correct_1, total_correct_5, total_num, test_bar = 0.0, 0.0, 0.0, 0, tqdm(data_loader)
    with torch.no_grad():
        for data_tuple in test_bar:
            data, target = [t.to(device) for t in data_tuple]
            if task_sep:
                target = target // 25
            if intra_task:
                target = target % 25
            output = net(data)
            if classifier is not None:  #else net is already a classifier
                output = classifier(output) 
            linear_loss = F.cross_entropy(output, target)
            
            # Batchsize for loss and accuracy
            num = data.size(0)
            total_num += num 
            total_loss += linear_loss.item() * num 
            # Accumulating number of correct predictions 
            correct_top_1, correct_top_5 = correct_top_k(output, target, top_k=(1,5))    
            total_correct_1 += correct_top_1
            total_correct_5 += correct_top_5

            test_bar.set_description('Lin.Test Epoch: [{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}% '
                                     .format(epoch,  total_loss / total_num,
                                             total_correct_1 / total_num * 100, total_correct_5 / total_num * 100
                                             ))
        acc_1 = total_correct_1/total_num*100
        acc_5 = total_correct_5/total_num*100
        wandb.log({" Linear Layer Test Loss ": linear_loss / total_num, " Epoch ": epoch})
        wandb.log({" Linear Layer Test - Acc": acc_1, " Epoch ": epoch})
    return total_loss / total_num, acc_1 , acc_5 

def linear_train(net, data_loader, train_optimizer, classifier, scheduler, epoch, device, task_sep=False, intra_task=False):

    net.eval() # for not update batchnorm 
    total_num, train_bar = 0, tqdm(data_loader)
    linear_loss = 0.0
    total_correct_1, total_correct_5 = 0.0, 0.0
    for data_tuple in train_bar:
        # Forward prop of the model with single augmented batch
        pos_1, target = data_tuple
        pos_1 = pos_1.to(device)
        feature_1 = net(pos_1)
        # feature_1 = net.get_representation(pos_1) 

        # Batchsize
        batchsize_bc = feature_1.shape[0]
        features = feature_1
        targets = target.to(device)
        if task_sep:
            targets = targets // 25
        if intra_task:
            targets = targets % 25
        logits = classifier(features.detach()) 
        # Cross Entropy Loss 
        linear_loss_1 = F.cross_entropy(logits, targets)

        # Number of correct predictions
        linear_correct_1, linear_correct_5 = correct_top_k(logits, targets, top_k=(1, 5))
    

        # Backpropagation part
        train_optimizer.zero_grad()
        linear_loss_1.backward()
        train_optimizer.step()

        # Accumulating number of examples, losses and correct predictions
        total_num += batchsize_bc
        linear_loss += linear_loss_1.item() * batchsize_bc
        total_correct_1 += linear_correct_1 
        total_correct_5 += linear_correct_5


        acc_1 = total_correct_1/total_num*100
        # # This bar is used for live tracking on command line (batch_size -> batchsize_bc: to show current batchsize )
        train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f} ACC: {:.2f}'.format(\
                epoch, linear_loss / total_num, acc_1))
    scheduler.step()
    acc_1 = total_correct_1/total_num*100
    acc_5 = total_correct_5/total_num*100       
    wandb.log({" Linear Layer Train Loss ": linear_loss / total_num, " Epoch ": epoch})
    wandb.log({" Linear Layer Train - Acc": acc_1, " Epoch ": epoch})
    # print(f'Linear Layer Train - Acc: {acc_1}')
        
    return linear_loss/total_num, acc_1, acc_5


def linear_evaluation(net, data_loader,test_data_loader,train_optimizer,classifier, scheduler, epochs, device, task_sep=False, intra_task=False):
    for epoch in range(1, epochs+1):
        linear_loss, linear_acc1, linear_acc5 = linear_train(net,data_loader,train_optimizer,classifier,scheduler, epoch, device, task_sep, intra_task)
        with torch.no_grad():
            # Testing for linear evaluation
            test_loss, test_acc1, test_acc5 = linear_test(net, test_data_loader, classifier, epoch, device, task_sep, intra_task)

    return test_loss, test_acc1, test_acc5, classifier

In [27]:
if 'infomax' in args.appr or 'barlow' in args.appr:
    proj_hidden = args.proj_hidden
    proj_out = args.proj_out
    encoder = Encoder(hidden_dim=proj_hidden, output_dim=proj_out, normalization = args.normalization, weight_standard = args.weight_standard, appr_name = args.appr)
    model = Siamese(encoder).to(device)

In [28]:
#load model here
file_name = 'checkpoints/checkpoint_cifar100-algocassle_contrastive_v3_barlow-e[500, 500, 500, 500, 500]-b256-lr0.1-CS[20, 20, 20, 20, 20]acc_59.199999999999996.pth.tar'
dict = torch.load(file_name)

In [29]:
model.temporal_projector = nn.Sequential(
            nn.Linear(args.proj_out, args.proj_hidden, bias=False),
            nn.BatchNorm1d(args.proj_hidden),
            nn.ReLU(),
            nn.Linear(args.proj_hidden, args.proj_out),
        ).to(device)

In [30]:
model.load_state_dict(dict['state_dict'])

RuntimeError: Error(s) in loading state_dict for Siamese:
	Unexpected key(s) in state_dict: "contrastive_projector.weight", "encoder.backbone.1.running_mean", "encoder.backbone.1.running_var", "encoder.backbone.1.num_batches_tracked", "encoder.backbone.3.0.bn1.running_mean", "encoder.backbone.3.0.bn1.running_var", "encoder.backbone.3.0.bn1.num_batches_tracked", "encoder.backbone.3.0.bn2.running_mean", "encoder.backbone.3.0.bn2.running_var", "encoder.backbone.3.0.bn2.num_batches_tracked", "encoder.backbone.3.1.bn1.running_mean", "encoder.backbone.3.1.bn1.running_var", "encoder.backbone.3.1.bn1.num_batches_tracked", "encoder.backbone.3.1.bn2.running_mean", "encoder.backbone.3.1.bn2.running_var", "encoder.backbone.3.1.bn2.num_batches_tracked", "encoder.backbone.4.0.bn1.running_mean", "encoder.backbone.4.0.bn1.running_var", "encoder.backbone.4.0.bn1.num_batches_tracked", "encoder.backbone.4.0.bn2.running_mean", "encoder.backbone.4.0.bn2.running_var", "encoder.backbone.4.0.bn2.num_batches_tracked", "encoder.backbone.4.0.shortcut.1.running_mean", "encoder.backbone.4.0.shortcut.1.running_var", "encoder.backbone.4.0.shortcut.1.num_batches_tracked", "encoder.backbone.4.1.bn1.running_mean", "encoder.backbone.4.1.bn1.running_var", "encoder.backbone.4.1.bn1.num_batches_tracked", "encoder.backbone.4.1.bn2.running_mean", "encoder.backbone.4.1.bn2.running_var", "encoder.backbone.4.1.bn2.num_batches_tracked", "encoder.backbone.5.0.bn1.running_mean", "encoder.backbone.5.0.bn1.running_var", "encoder.backbone.5.0.bn1.num_batches_tracked", "encoder.backbone.5.0.bn2.running_mean", "encoder.backbone.5.0.bn2.running_var", "encoder.backbone.5.0.bn2.num_batches_tracked", "encoder.backbone.5.0.shortcut.1.running_mean", "encoder.backbone.5.0.shortcut.1.running_var", "encoder.backbone.5.0.shortcut.1.num_batches_tracked", "encoder.backbone.5.1.bn1.running_mean", "encoder.backbone.5.1.bn1.running_var", "encoder.backbone.5.1.bn1.num_batches_tracked", "encoder.backbone.5.1.bn2.running_mean", "encoder.backbone.5.1.bn2.running_var", "encoder.backbone.5.1.bn2.num_batches_tracked", "encoder.backbone.6.0.bn1.running_mean", "encoder.backbone.6.0.bn1.running_var", "encoder.backbone.6.0.bn1.num_batches_tracked", "encoder.backbone.6.0.bn2.running_mean", "encoder.backbone.6.0.bn2.running_var", "encoder.backbone.6.0.bn2.num_batches_tracked", "encoder.backbone.6.0.shortcut.1.running_mean", "encoder.backbone.6.0.shortcut.1.running_var", "encoder.backbone.6.0.shortcut.1.num_batches_tracked", "encoder.backbone.6.1.bn1.running_mean", "encoder.backbone.6.1.bn1.running_var", "encoder.backbone.6.1.bn1.num_batches_tracked", "encoder.backbone.6.1.bn2.running_mean", "encoder.backbone.6.1.bn2.running_var", "encoder.backbone.6.1.bn2.num_batches_tracked". 

In [14]:
lin_epoch=100
classifier = LinearClassifier(num_classes = 100).to(device)
lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
test_loss, test_acc1, test_acc5, classifier = linear_evaluation(model, train_data_loaders_linear_all[0],
                                                            test_data_loaders_all[0],lin_optimizer, classifier, 
                                                            lin_scheduler, epochs=lin_epoch, device=device) 
print(f'Total Accuracy: {test_acc1}')

Lin.Train Epoch: [1] Loss: 145.5651 ACC: 8.12: 100%|██████████| 196/196 [00:21<00:00,  9.26it/s]
Lin.Test Epoch: [1] Loss: 57.5297 ACC@1: 20.71% ACC@5: 40.16% : 100%|██████████| 20/20 [00:03<00:00,  5.28it/s]
Lin.Train Epoch: [2] Loss: 19.8100 ACC: 29.96: 100%|██████████| 196/196 [00:16<00:00, 12.02it/s]
Lin.Test Epoch: [2] Loss: 7.2459 ACC@1: 38.58% ACC@5: 66.12% : 100%|██████████| 20/20 [00:03<00:00,  5.27it/s]
Lin.Train Epoch: [3] Loss: 8.0540 ACC: 34.67: 100%|██████████| 196/196 [00:16<00:00, 12.24it/s]
Lin.Test Epoch: [3] Loss: 11.1811 ACC@1: 40.97% ACC@5: 66.24% : 100%|██████████| 20/20 [00:04<00:00,  4.67it/s]
Lin.Train Epoch: [4] Loss: 11.1018 ACC: 34.99: 100%|██████████| 196/196 [00:16<00:00, 11.55it/s]
Lin.Test Epoch: [4] Loss: 8.8164 ACC@1: 38.18% ACC@5: 64.05% : 100%|██████████| 20/20 [00:04<00:00,  4.77it/s] 
Lin.Train Epoch: [5] Loss: 7.9682 ACC: 36.04: 100%|██████████| 196/196 [00:16<00:00, 12.14it/s]
Lin.Test Epoch: [5] Loss: 7.0531 ACC@1: 42.58% ACC@5: 68.56% : 100%|██

Total Accuracy: 61.9





In [21]:
for task, loader in enumerate(train_data_loaders_linear):
    # print()
    # print(f'Task {task}')
    lin_epoch=50
    classifier = LinearClassifier(num_classes = 25).to(device)
    lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
    lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
    test_loss, test_acc1, test_acc5, classifier = linear_evaluation(model, loader,test_data_loaders[task],lin_optimizer, classifier, 
                                                                lin_scheduler, epochs=lin_epoch, device=device, intra_task=True) 
    # print(f'Task {task}  accuracy: {test_acc1}')

Lin.Train Epoch: [1] Loss: 203.9225 ACC: 6.25: 100%|██████████| 49/49 [00:03<00:00, 14.50it/s]
Lin.Test Epoch: [1] Loss: 138.1759 ACC@1: 23.40% ACC@5: 33.08% : 100%|██████████| 5/5 [00:01<00:00,  2.88it/s]
Lin.Train Epoch: [2] Loss: 112.1662 ACC: 20.09: 100%|██████████| 49/49 [00:03<00:00, 14.50it/s]
Lin.Test Epoch: [2] Loss: 31.6262 ACC@1: 38.08% ACC@5: 60.84% : 100%|██████████| 5/5 [00:01<00:00,  3.02it/s]
Lin.Train Epoch: [3] Loss: 15.8342 ACC: 48.94: 100%|██████████| 49/49 [00:03<00:00, 14.41it/s]
Lin.Test Epoch: [3] Loss: 5.4577 ACC@1: 61.48% ACC@5: 90.28% : 100%|██████████| 5/5 [00:01<00:00,  2.93it/s]
Lin.Train Epoch: [4] Loss: 5.7665 ACC: 58.30: 100%|██████████| 49/49 [00:03<00:00, 14.63it/s]
Lin.Test Epoch: [4] Loss: 3.6409 ACC@1: 66.00% ACC@5: 92.80% : 100%|██████████| 5/5 [00:01<00:00,  2.81it/s]
Lin.Train Epoch: [5] Loss: 5.3261 ACC: 57.66: 100%|██████████| 49/49 [00:03<00:00, 14.67it/s]
Lin.Test Epoch: [5] Loss: 3.0302 ACC@1: 67.12% ACC@5: 94.20% : 100%|██████████| 5/5 [00

In [29]:
lin_epoch=100
classifier = LinearClassifier(num_classes = 4).to(device)
lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
test_loss, test_acc1, test_acc5, classifier = linear_evaluation(model, train_data_loaders_linear_all[0],
                                                            test_data_loaders_all[0],lin_optimizer, classifier, 
                                                            lin_scheduler, epochs=lin_epoch, device=device, task_sep=True) 
print(f'Total Accuracy: {test_acc1}')

Lin.Train Epoch: [1] Loss: 81.7665 ACC: 28.78: 100%|██████████| 196/196 [00:09<00:00, 20.43it/s] 
Lin.Test Epoch: [1] Loss: 66.3012 ACC@1: 25.76% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  6.97it/s]
Lin.Train Epoch: [2] Loss: 51.2712 ACC: 31.55: 100%|██████████| 196/196 [00:09<00:00, 20.46it/s]
Lin.Test Epoch: [2] Loss: 21.0970 ACC@1: 37.29% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  6.86it/s]
Lin.Train Epoch: [3] Loss: 35.4477 ACC: 33.33: 100%|██████████| 196/196 [00:09<00:00, 20.19it/s]
Lin.Test Epoch: [3] Loss: 19.8299 ACC@1: 36.14% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  7.02it/s]
Lin.Train Epoch: [4] Loss: 54.6476 ACC: 31.61: 100%|██████████| 196/196 [00:09<00:00, 20.55it/s]
Lin.Test Epoch: [4] Loss: 34.2510 ACC@1: 30.44% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  7.05it/s]
Lin.Train Epoch: [5] Loss: 37.6728 ACC: 34.62: 100%|██████████| 196/196 [00:09<00:00, 20.53it/s]
Lin.Test Epoch: [5] Loss: 18.9933 ACC@1: 42.06% ACC@5: 100.00%

Total Accuracy: 47.9





In [22]:
#load model here
file_name = 'checkpoints/checkpoint_cifar100-algobasic_barlow-e[1000]-b256-lr0.3-CS[100]acc_69.38.pth.tar'
dict = torch.load(file_name)
device = torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
print(device)
if 'infomax' in args.appr or 'barlow' in args.appr:
    proj_hidden = args.proj_hidden
    proj_out = args.proj_out
    encoder = Encoder(hidden_dim=proj_hidden, output_dim=proj_out, normalization = 'batch', weight_standard = args.weight_standard,appr_name =args.appr)
    old_model = Siamese(encoder)
    old_model.to(device) #automatically detects from model

old_model.load_state_dict(dict['state_dict'])

cuda:3


<All keys matched successfully>

In [25]:
lin_epoch=100
classifier = LinearClassifier(num_classes = 100).to(device)
lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
test_loss, test_acc1, test_acc5, classifier = linear_evaluation(old_model, train_data_loaders_linear_all[0],
                                                            test_data_loaders_all[0],lin_optimizer, classifier, 
                                                            lin_scheduler, epochs=lin_epoch, device=device) 
print(f'Total Accuracy: {test_acc1}')

Lin.Train Epoch: [1] Loss: 2.0468 ACC: 54.13: 100%|██████████| 196/196 [00:08<00:00, 22.26it/s]
Lin.Test Epoch: [1] Loss: 1.6562 ACC@1: 61.05% ACC@5: 84.72% : 100%|██████████| 20/20 [00:03<00:00,  6.31it/s]
Lin.Train Epoch: [2] Loss: 1.8614 ACC: 58.72: 100%|██████████| 196/196 [00:08<00:00, 21.93it/s]
Lin.Test Epoch: [2] Loss: 1.6454 ACC@1: 60.48% ACC@5: 85.81% : 100%|██████████| 20/20 [00:02<00:00,  7.21it/s]
Lin.Train Epoch: [3] Loss: 1.8168 ACC: 59.82: 100%|██████████| 196/196 [00:09<00:00, 21.29it/s]
Lin.Test Epoch: [3] Loss: 1.5600 ACC@1: 61.96% ACC@5: 87.07% : 100%|██████████| 20/20 [00:02<00:00,  7.25it/s]
Lin.Train Epoch: [4] Loss: 1.8265 ACC: 60.18: 100%|██████████| 196/196 [00:08<00:00, 22.09it/s]
Lin.Test Epoch: [4] Loss: 1.5855 ACC@1: 62.25% ACC@5: 86.93% : 100%|██████████| 20/20 [00:02<00:00,  6.91it/s]
Lin.Train Epoch: [5] Loss: 1.8088 ACC: 60.46: 100%|██████████| 196/196 [00:08<00:00, 22.19it/s]
Lin.Test Epoch: [5] Loss: 1.5537 ACC@1: 62.89% ACC@5: 87.23% : 100%|████████

Total Accuracy: 69.56





In [23]:
for task, loader in enumerate(train_data_loaders_linear):
    # print()
    # print(f'Task {task}')
    lin_epoch=50
    classifier = LinearClassifier(num_classes = 25).to(device)
    lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
    lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
    test_loss, test_acc1, test_acc5, classifier = linear_evaluation(old_model, loader,test_data_loaders[task],lin_optimizer, classifier, 
                                                                lin_scheduler, epochs=lin_epoch, device=device, intra_task=True) 
    # print(f'Task {task}  accuracy: {test_acc1}')

Lin.Train Epoch: [1] Loss: 1.2490 ACC: 69.21: 100%|██████████| 49/49 [00:03<00:00, 16.07it/s]
Lin.Test Epoch: [1] Loss: 0.7321 ACC@1: 79.36% ACC@5: 97.00% : 100%|██████████| 5/5 [00:01<00:00,  3.02it/s]
Lin.Train Epoch: [2] Loss: 1.0094 ACC: 75.76: 100%|██████████| 49/49 [00:03<00:00, 16.17it/s]
Lin.Test Epoch: [2] Loss: 0.6784 ACC@1: 81.40% ACC@5: 96.96% : 100%|██████████| 5/5 [00:01<00:00,  3.24it/s]
Lin.Train Epoch: [3] Loss: 0.9023 ACC: 76.94: 100%|██████████| 49/49 [00:03<00:00, 15.63it/s]
Lin.Test Epoch: [3] Loss: 0.7969 ACC@1: 78.84% ACC@5: 96.28% : 100%|██████████| 5/5 [00:01<00:00,  2.95it/s]
Lin.Train Epoch: [4] Loss: 0.9288 ACC: 76.52: 100%|██████████| 49/49 [00:03<00:00, 15.66it/s]
Lin.Test Epoch: [4] Loss: 0.6887 ACC@1: 80.16% ACC@5: 96.84% : 100%|██████████| 5/5 [00:01<00:00,  3.13it/s]
Lin.Train Epoch: [5] Loss: 0.9937 ACC: 76.18: 100%|██████████| 49/49 [00:02<00:00, 17.01it/s]
Lin.Test Epoch: [5] Loss: 0.7242 ACC@1: 80.64% ACC@5: 97.24% : 100%|██████████| 5/5 [00:01<00:

In [24]:
lin_epoch=100
classifier = LinearClassifier(num_classes = 4).to(device)
lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
test_loss, test_acc1, test_acc5, classifier = linear_evaluation(old_model, train_data_loaders_linear_all[0],
                                                            test_data_loaders_all[0],lin_optimizer, classifier, 
                                                            lin_scheduler, epochs=lin_epoch, device=device, task_sep=True) 
print(f'Total Accuracy: {test_acc1}')

Lin.Train Epoch: [1] Loss: 3.8722 ACC: 52.96: 100%|██████████| 196/196 [00:08<00:00, 22.60it/s]
Lin.Test Epoch: [1] Loss: 2.1194 ACC@1: 58.68% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  7.27it/s]
Lin.Train Epoch: [2] Loss: 3.0685 ACC: 54.47: 100%|██████████| 196/196 [00:08<00:00, 22.48it/s]
Lin.Test Epoch: [2] Loss: 3.3024 ACC@1: 53.46% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  7.50it/s]
Lin.Train Epoch: [3] Loss: 2.8170 ACC: 55.22: 100%|██████████| 196/196 [00:08<00:00, 22.63it/s]
Lin.Test Epoch: [3] Loss: 4.1562 ACC@1: 42.52% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  7.61it/s]
Lin.Train Epoch: [4] Loss: 2.9544 ACC: 54.48: 100%|██████████| 196/196 [00:08<00:00, 22.54it/s]
Lin.Test Epoch: [4] Loss: 5.0768 ACC@1: 48.48% ACC@5: 100.00% : 100%|██████████| 20/20 [00:02<00:00,  7.61it/s]
Lin.Train Epoch: [5] Loss: 3.2824 ACC: 54.26: 100%|██████████| 196/196 [00:08<00:00, 22.29it/s]
Lin.Test Epoch: [5] Loss: 3.5134 ACC@1: 52.55% ACC@5: 100.00% : 100%|███

Total Accuracy: 67.56



