In [1]:
import os
import sys
import wandb
import argparse
import numpy as np

from tqdm import tqdm


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "")))
import torch
import torchvision.transforms as T
import torchvision

from dataloaders.dataloader_cifar10 import get_cifar10
from dataloaders.dataloader_cifar100 import get_cifar100
from utils.eval_metrics import linear_evaluation, get_t_SNE_plot
from models.linear_classifer import LinearClassifier
from models.ssl import  SimSiam, Siamese, Encoder, Predictor

from trainers.train_simsiam import train_simsiam
from trainers.train_infomax import train_infomax
from trainers.train_barlow import train_barlow

from trainers.train_PFR import train_PFR_simsiam
from trainers.train_PFR_contrastive import train_PFR_contrastive_simsiam
from trainers.train_contrastive import train_contrastive_simsiam
from trainers.train_ering import train_ering_simsiam

from torchsummary import summary
import random
from utils.lr_schedulers import LinearWarmupCosineAnnealingLR, SimSiamScheduler
from utils.eval_metrics import Knn_Validation_cont
from copy import deepcopy
from loss import invariance_loss,CovarianceLoss,ErrorCovarianceLoss
import torch.nn as nn
import time
import torch.nn.functional as F
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

from models.linear_classifer import LinearClassifier
from torch.utils.data import DataLoader
from dataloaders.dataset import TensorDataset

# Enumerate GPUs in PCI bus order and make devices 0-7 visible to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
    }
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GaussianBlur(object):
    """Random-strength Gaussian blur, the SimCLR augmentation (https://arxiv.org/abs/2002.05709).

    ``sigma`` is a two-element ``[low, high]`` sequence. Every call draws a blur
    strength uniformly from that interval and applies a 3x3 Gaussian kernel.
    """

    def __init__(self, sigma=[0.1, 2.0]):
        # The default list is only stored and read, never mutated by this class.
        self.sigma = sigma

    def __call__(self, x):
        low, high = self.sigma[0], self.sigma[1]
        drawn_sigma = random.uniform(low, high)
        # Kernel size and sigma range are acknowledged as open tuning choices.
        return torchvision.transforms.functional.gaussian_blur(
            x, kernel_size=[3, 3], sigma=drawn_sigma
        )


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename`` and optionally keep a copy as the best model.

    Args:
        state: Object passed unchanged to ``torch.save``.
        is_best: When truthy, the file just written is also copied to
            ``model_best.pth.tar`` in the current working directory.
        filename: Destination path of the checkpoint.
    """
    # Imported here because the notebook's import cell never brings ``shutil``
    # into scope; without it the ``is_best`` branch raised NameError.
    import shutil

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [3]:
class Args:
    """Static experiment configuration; later cells read it as ``args.<name>``.

    All settings are plain class attributes. Attributes without a comment are
    not referenced by the cells visible in this notebook.
    """

    # Forwarded to ``Encoder`` when the backbone is built.
    normalization = 'batch'
    weight_standard = False
    same_lr = False

    # ``pretrain_batch_size`` is the batch size handed to the dataloader factory.
    pretrain_batch_size = 512
    pretrain_warmup_epochs = 10
    pretrain_warmup_lr = 0.003
    pretrain_base_lr = 0.03
    pretrain_momentum = 0.9
    pretrain_weight_decay = 0.0005
    min_lr = 0.0
    lambdap = 1.0

    # Approach name; cells branch on whether it contains 'infomax' or 'barlow'.
    appr = 'barlow_PFR'
    knn_report_freq = 10
    # GPU index used to build the torch device, and dataloader worker count.
    cuda_device = 5
    num_workers = 8
    contrastive_ratio = 0.001

    # Dataset name plus per-task class counts and epoch budgets
    # (asserted to have equal length, with class counts summing to the dataset's classes).
    dataset = 'cifar100'
    class_split = [20] * 5
    epochs = [500] * 5

    cov_loss_weight = 1.0
    sim_loss_weight = 250.0
    info_loss = 'invariance'
    lambda_norm = 1.0
    subspace_rate = 0.99
    lambda_param = 0.005
    bsize = 32
    msize = 150

    # Projector / predictor widths; proj_* are passed to ``Encoder``.
    proj_hidden = 2048
    proj_out = 2048  # infomax 64
    pred_hidden = 512
    pred_out = 2048



In [4]:
# Instantiate the configuration holder; the cells below read settings as ``args.<name>``.
args = Args()

In [5]:
# Map the configured dataset name to its dataloader factory and class count.
if args.dataset == "cifar10":
    get_dataloaders = get_cifar10
    num_classes = 10
elif args.dataset == "cifar100":
    get_dataloaders = get_cifar100
    num_classes = 100
else:
    # Fail with a clear message instead of a NameError on ``num_classes`` below.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected 'cifar10' or 'cifar100'"
    )

# Per-task class counts must add up to the dataset's class count,
# and every task needs its own epoch budget.
assert sum(args.class_split) == num_classes
assert len(args.class_split) == len(args.epochs)

In [6]:
# Worker-process count handed to the dataloader factory.
num_worker = args.num_workers

# Use the configured GPU index when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:" + str(args.cuda_device))
else:
    device = torch.device("cpu")
print(device)

cuda:5


In [7]:
# Initialise Weights & Biases. ``mode="disabled"`` is passed, and ``args`` is supplied as the run config.
# The run name concatenates dataset, approach, per-task epochs, batch size,
# base learning rate and class split.
wandb.init(project="CSSL",  entity="yavuz-team",
            mode="disabled",
            config=args,
            name= str(args.dataset) + '-algo' + str(args.appr) + "-e" + str(args.epochs) + "-b" 
            + str(args.pretrain_batch_size) + "-lr" + str(args.pretrain_base_lr)+"-CS"+str(args.class_split))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [8]:
def _build_view_transform():
    """Return a fresh augmentation pipeline for one view of a 32x32 image."""
    color_jitter = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    return T.Compose([
        T.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
        T.RandomHorizontalFlip(),
        T.RandomApply(torch.nn.ModuleList([color_jitter]), p=0.8),
        T.RandomGrayscale(p=0.2),
        T.RandomApply([GaussianBlur()], p=0.5),
        T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]),
    ])


if 'infomax' in args.appr or 'barlow' in args.appr:
    # Both views share one recipe but each gets its own pipeline object.
    transform = _build_view_transform()
    transform_prime = _build_view_transform()

In [9]:
# Build the dataloaders through the factory selected for the configured dataset.
print("Creating Dataloaders..")

# Loaders built with the configured class split.
(
    train_data_loaders,
    train_data_loaders_knn,
    test_data_loaders,
    _,
    train_data_loaders_linear,
    train_data_loaders_pure,
    train_data_loaders_generic,
) = get_dataloaders(
    transform,
    transform_prime,
    classes=args.class_split,
    valid_rate=0.00,
    batch_size=args.pretrain_batch_size,
    seed=0,
    num_worker=num_worker,
)

# Same call with a single split that contains every class.
(
    _,
    train_data_loaders_knn_all,
    test_data_loaders_all,
    _,
    train_data_loaders_linear_all,
    train_data_loaders_pure_all,
    _,
) = get_dataloaders(
    transform,
    transform_prime,
    classes=[num_classes],
    valid_rate=0.00,
    batch_size=args.pretrain_batch_size,
    seed=0,
    num_worker=num_worker,
)


Creating Dataloaders..
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [10]:
device = torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
print(device)
if 'infomax' in args.appr or 'barlow' in args.appr:
    proj_hidden = args.proj_hidden
    proj_out = args.proj_out
    encoder = Encoder(hidden_dim=proj_hidden, output_dim=proj_out, normalization = args.normalization, weight_standard = args.weight_standard, appr_name = args.appr)
    model = Siamese(encoder)
    model.to(device)

# Checkpoint to restore into ``model``.
file_name = "checkpoints/checkpoint_cifar100-algocassle_contrastive_v3_barlow-e[500, 500, 500, 500, 500]-b256-lr0.1-CS[20, 20, 20, 20, 20]acc_59.199999999999996.pth.tar"
# ``map_location`` remaps the stored tensors onto the selected device, so the
# load does not depend on which GPU index the checkpoint was saved from.
# The result is named ``checkpoint`` so the builtin ``dict`` is not shadowed.
# NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(file_name, map_location=device)

# These two heads are attached before ``load_state_dict`` so their keys exist
# when the checkpoint is restored.
model.temporal_projector = nn.Sequential(
            nn.Linear(args.proj_out, args.proj_hidden, bias=False),
            nn.BatchNorm1d(args.proj_hidden),
            nn.ReLU(),
            nn.Linear(args.proj_hidden, args.proj_out),
        ).to(device)
# One output row per entry of ``train_data_loaders_generic``.
model.contrastive_projector = nn.Linear(512, len(train_data_loaders_generic), bias=False).to(device)
model.load_state_dict(checkpoint['state_dict'])

cuda:5


<All keys matched successfully>

In [11]:
def first_eigenvector(model, loader):
    """Return the leading left-singular vector of the model outputs over ``loader``.

    The outputs of every batch are stacked into a ``(D, num_samples)`` matrix
    (one column per sample). The columns of ``U`` from its SVD are the
    eigenvectors of the uncentered output outer-product matrix, ordered by
    decreasing singular value, so the first eigenvector is ``U[:, 0]``.
    The previous implementation returned ``U[0:1, :]``, i.e. the first *row*
    of ``U`` (the first coordinate of every singular vector).

    Args:
        model: Module mapping a batch to a 2-D ``(batch, D)`` tensor
            (2-D output is assumed, as in the original transpose).
        loader: Iterable of ``(inputs, _)`` batches.

    Returns:
        Tensor of shape ``(1, D)`` on the CPU. Its sign is arbitrary, as with
        any singular vector.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()

    # Run on the device the model lives on rather than on a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    outs = []
    # Inference only: skip building an autograd graph for every batch.
    with torch.no_grad():
        for x, _ in loader:
            outs.append(model(x.to(run_device)).cpu())
    if not outs:
        raise ValueError("first_eigenvector: loader produced no batches")

    # (num_samples, D) -> (D, num_samples)
    outs = torch.cat(outs, dim=0).t()

    U, _, _ = torch.svd(outs)
    return U[:, 0:1].t()
    

In [12]:
def correct_top_k(outputs, targets, top_k=(1,5)):
    """Count, for each ``k`` in ``top_k``, the rows whose target is among the ``k`` best scores.

    Args:
        outputs: Score tensor whose last dimension indexes classes.
        targets: Tensor of target class indices, one per row of ``outputs``.
        top_k: Iterable of cut-offs to evaluate.

    Returns:
        List of Python floats, one count per entry of ``top_k``.
    """
    with torch.no_grad():
        # Class indices per row, ordered from highest to lowest score.
        ranked = torch.argsort(outputs, dim=-1, descending=True)
        wanted = targets.unsqueeze(dim=-1)
        return [
            torch.sum((ranked[:, 0:k] == wanted).any(dim=-1).float()).item()
            for k in top_k
        ]
    
def contrastive_train_first_task(net, data_loader, task_id, optimizer, classifier, scheduler, epochs, device):
    """Train ``classifier`` to map features of ``net`` to the constant label ``task_id``.

    ``net`` is kept in eval mode and receives no gradient; only the row-normalised
    weight of ``classifier`` is in the autograd graph (its bias, if any, is unused).

    Fix: the running total and the per-batch loss used to share the name
    ``linear_loss``, so the total was overwritten every batch and the
    displayed/returned value was a meaningless tensor. They are now separate,
    and a Python float is returned.

    Args:
        net: Feature extractor, called as ``net(batch)``.
        data_loader: Iterable of ``(inputs, labels)``; labels only supply the batch size.
        task_id: Integer class index used as the target for every sample.
        optimizer: Optimizer over the classifier parameters.
        classifier: Module exposing a ``weight`` of shape ``(num_outputs, feature_dim)``.
        scheduler: Optional LR scheduler stepped once per epoch (may be ``None``).
        epochs: Number of passes over ``data_loader``.
        device: Device the inputs and targets are moved to.

    Returns:
        Sample-weighted mean cross-entropy of the last epoch that saw data,
        or ``0.0`` if no batch was processed.
    """
    mean_loss = 0.0
    for epoch in range(1, epochs+1):
        net.eval()  # keep BatchNorm statistics frozen
        total_num, train_bar = 0, tqdm(data_loader)
        running_loss = 0.0
        for data_tuple in train_bar:
            pos_1, targets = data_tuple
            pos_1 = pos_1.to(device)
            # The backbone is never updated here, so no autograd graph is needed.
            with torch.no_grad():
                features = net(pos_1)

            batchsize_bc = features.shape[0]
            # Every sample in the batch is labelled with the current task id.
            targets = torch.ones(targets.shape[0], dtype=torch.long, device=device) * task_id

            # Cosine-style logits: features against unit-norm classifier rows.
            c_weights = torch.nn.functional.normalize(classifier.weight, dim=1)
            logits = features @ c_weights.T

            batch_loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            total_num += batchsize_bc
            running_loss += batch_loss.item() * batchsize_bc

            train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f}'.format(epoch, running_loss / total_num))
        if scheduler is not None:
            scheduler.step()
        if total_num:
            mean_loss = running_loss / total_num
        # wandb.log({" Linear Layer Train Loss ": linear_loss / total_num, " Epoch ": epoch})
    return mean_loss

In [13]:
def contrastive_train(net, data_loader, task_id, optimizer, scheduler, epochs, device):
    """Train ``net`` so its features score highest on row ``task_id`` of its contrastive projector.

    Logits are the features multiplied by the row-normalised weight of
    ``net.contrastive_projector``. Gradients flow into ``net`` (features are not
    detached). After each backward pass the gradient of projector rows
    ``[0:task_id]`` is zeroed, so those rows receive no gradient from this task.

    Fix: the running total and the per-batch loss used to share the name
    ``linear_loss``, so the total was overwritten every batch and the
    displayed/returned value was a meaningless tensor. They are now separate,
    and a Python float is returned.

    Args:
        net: Module called as ``net(batch)`` and exposing ``contrastive_projector.weight``.
        data_loader: Iterable of ``(inputs, labels)``; labels only supply the batch size.
        task_id: Integer class index used as the target for every sample.
        optimizer: Optimizer to step after each batch.
        scheduler: Optional LR scheduler stepped once per epoch (may be ``None``).
        epochs: Number of passes over ``data_loader``.
        device: Device the inputs and targets are moved to.

    Returns:
        Sample-weighted mean cross-entropy of the last epoch that saw data,
        or ``0.0`` if no batch was processed.
    """
    mean_loss = 0.0
    for epoch in range(1, epochs+1):
        net.eval()  # keep BatchNorm statistics frozen
        total_num, train_bar = 0, tqdm(data_loader)
        running_loss = 0.0
        for data_tuple in train_bar:
            pos_1, targets = data_tuple
            pos_1 = pos_1.to(device)
            features = net(pos_1)

            # Cosine-style logits: features against unit-norm projector rows.
            c_weights = torch.nn.functional.normalize(net.contrastive_projector.weight, dim=1)
            logits = features @ c_weights.T

            batchsize_bc = features.shape[0]
            # Every sample in the batch is labelled with the current task id.
            targets = torch.ones(targets.shape[0], dtype=torch.long, device=device) * task_id

            batch_loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            batch_loss.backward()
            # Drop the gradient of the rows that belong to earlier tasks.
            net.contrastive_projector.weight.grad[0:task_id].zero_()
            optimizer.step()

            total_num += batchsize_bc
            running_loss += batch_loss.item() * batchsize_bc

            train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f}'.format(epoch, running_loss / total_num))
        if scheduler is not None:
            scheduler.step()
        if total_num:
            mean_loss = running_loss / total_num
        # wandb.log({" Linear Layer Train Loss ": linear_loss / total_num, " Epoch ": epoch})
    return mean_loss

In [33]:
# Two-output projector on 512-dimensional features, trained on the first task's loader.
contrastive_projector = nn.Linear(512, 2).to(device)
task_id = 0
lin_epoch = 10
lin_optimizer = torch.optim.SGD(
    contrastive_projector.parameters(), lr=1e-3, momentum=0.9, weight_decay=0
)
# lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) 
test_loss = contrastive_train_first_task(
    model,
    train_data_loaders_linear[task_id],
    task_id,
    lin_optimizer,
    contrastive_projector,
    None,  # no LR scheduler
    epochs=lin_epoch,
    device=device,
)

Lin.Train Epoch: [1] Loss: 0.0000: 100%|██████████| 20/20 [00:03<00:00,  5.31it/s]
Lin.Train Epoch: [2] Loss: 0.0000: 100%|██████████| 20/20 [00:03<00:00,  5.71it/s]
  0%|          | 0/20 [00:00<?, ?it/s]Exception ignored in: <function _releaseLock at 0x7f98779bba70>
Traceback (most recent call last):
  File "/home/duygu/anaconda3/envs/fedml_academic/lib/python3.7/logging/__init__.py", line 221, in _releaseLock
    def _releaseLock():
KeyboardInterrupt
Lin.Train Epoch: [3] Loss: 0.0000:  95%|█████████▌| 19/20 [00:03<00:00,  8.71it/s]Exception ignored in: <function Socket.__del__ at 0x7f9876df1ef0>
Traceback (most recent call last):
  File "/home/duygu/anaconda3/envs/fedml_academic/lib/python3.7/site-packages/zmq/sugar/socket.py", line 110, in __del__
    def __del__(self):
KeyboardInterrupt
Lin.Train Epoch: [3] Loss: 0.0000: 100%|██████████| 20/20 [00:03<00:00,  5.32it/s]
Lin.Train Epoch: [4] Loss: 0.0000: 100%|██████████| 20/20 [00:03<00:00,  5.48it/s]
Lin.Train Epoch: [5] Loss: 0.000

KeyboardInterrupt: 

In [15]:
device = torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
print(device)
if 'infomax' in args.appr or 'barlow' in args.appr:
    proj_hidden = args.proj_hidden
    proj_out = args.proj_out
    encoder = Encoder(hidden_dim=proj_hidden, output_dim=proj_out, normalization = args.normalization, weight_standard = args.weight_standard, appr_name = args.appr)
    model2 = Siamese(encoder)
    model2.to(device)

# Checkpoint to restore into ``model2``.
file_name = "./checkpoints/checkpoint_cifar100-algocassle_barlow-e[500, 500, 500, 500, 500]-b256-lr0.25-CS[20, 20, 20, 20, 20]_task_1_same_lr_True_norm_batch_ws_False.pth.tar"
# ``map_location`` remaps the stored tensors onto the selected device, so the
# load does not depend on which GPU index the checkpoint was saved from.
# The result is named ``checkpoint`` so the builtin ``dict`` is not shadowed.
# NOTE: ``torch.load`` unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(file_name, map_location=device)

# The temporal projector is attached before ``load_state_dict`` so its keys
# exist when the checkpoint is restored.
model2.temporal_projector = nn.Sequential(
            nn.Linear(args.proj_out, args.proj_hidden, bias=False),
            nn.BatchNorm1d(args.proj_hidden),
            nn.ReLU(),
            nn.Linear(args.proj_hidden, args.proj_out),
        ).to(device)
model2.load_state_dict(checkpoint['state_dict'])


cuda:5


<All keys matched successfully>

In [14]:
def correct_top_k(outputs, targets, top_k=(1,5)):
    """Count, for each ``k`` in ``top_k``, the rows whose target is among the ``k`` best scores.

    Args:
        outputs: Score tensor whose last dimension indexes classes.
        targets: Tensor of target class indices, one per row of ``outputs``.
        top_k: Iterable of cut-offs to evaluate.

    Returns:
        List of Python floats, one count per entry of ``top_k``.
    """
    with torch.no_grad():
        # Class indices per row, ordered from highest to lowest score.
        ranked = torch.argsort(outputs, dim=-1, descending=True)
        wanted = targets.unsqueeze(dim=-1)
        return [
            torch.sum((ranked[:, 0:k] == wanted).any(dim=-1).float()).item()
            for k in top_k
        ]

def linear_test(net, data_loader, classifier, epoch, device, task_num):
    """Evaluate ``classifier`` on top of ``net`` and log mean loss and top-1 accuracy.

    Fix: the value logged to wandb as the test loss used to be the *last batch's*
    loss tensor divided by the total sample count. It is now the same
    sample-weighted mean loss that the function returns.

    Args:
        net: Feature extractor (or a full classifier when ``classifier`` is ``None``).
        data_loader: Iterable of ``(inputs, targets)`` batches.
        classifier: Optional head applied to ``net``'s output.
        epoch: Epoch number, used for the progress text and the wandb records.
        device: Device that inputs and targets are moved to.
        task_num: Task identifier embedded in the wandb metric names.

    Returns:
        ``(mean_loss, top1_accuracy_percent)`` as Python floats.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    net.eval()  # keep BatchNorm statistics frozen
    total_loss, total_correct_1, total_num, test_bar = 0.0, 0.0, 0, tqdm(data_loader)
    with torch.no_grad():
        for data_tuple in test_bar:
            data, target = [t.to(device) for t in data_tuple]
            output = net(data)
            if classifier is not None:  # otherwise net is already a classifier
                output = classifier(output)
            batch_loss = F.cross_entropy(output, target)

            # Weight the batch loss by its size so the mean is per sample.
            num = data.size(0)
            total_num += num
            total_loss += batch_loss.item() * num
            total_correct_1 += correct_top_k(output, target, top_k=[1])[0]
            test_bar.set_description('Lin.Test Epoch: [{}] Loss: {:.4f} ACC: {:.2f}% '
                                     .format(epoch,  total_loss / total_num,
                                             total_correct_1 / total_num * 100
                                             ))

    if total_num == 0:
        raise ValueError("linear_test: data_loader produced no samples")

    mean_loss = total_loss / total_num
    acc_1 = total_correct_1 / total_num * 100
    wandb.log({f" {task_num} Linear Layer Test Loss ": mean_loss, "Linear Epoch ": epoch})
    wandb.log({f" {task_num} Linear Layer Test - Acc": acc_1, "Linear Epoch ": epoch})
    return mean_loss, acc_1

def linear_train(net, data_loader, train_optimizer, classifier, scheduler, epoch, device, task_num):
    """Run one epoch of linear-probe training of ``classifier`` on features from ``net``.

    ``net`` stays in eval mode and is never updated; only ``classifier`` is trained.

    Changes from the previous version:
      * ``scheduler`` may be ``None`` (as the contrastive trainers in this notebook
        already allow); it used to be stepped unconditionally.
      * The backbone forward pass runs under ``torch.no_grad()``, since its output
        was detached anyway, so no autograd graph is built for it.

    Args:
        net: Feature extractor, called as ``net(batch)``.
        data_loader: Iterable of ``(inputs, targets)`` batches.
        train_optimizer: Optimizer over the classifier parameters.
        classifier: Head mapping features to class logits.
        scheduler: Optional LR scheduler stepped once after the epoch.
        epoch: Epoch number, used for the progress text and the wandb records.
        device: Device that inputs and targets are moved to.
        task_num: Task identifier embedded in the wandb metric names.

    Returns:
        ``(mean_loss, top1_accuracy_percent)`` as Python floats.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    net.eval()  # keep BatchNorm statistics frozen
    total_num, train_bar = 0, tqdm(data_loader)
    running_loss = 0.0
    total_correct_1 = 0.0
    for data_tuple in train_bar:
        pos_1, target = data_tuple
        pos_1 = pos_1.to(device)
        with torch.no_grad():
            features = net(pos_1)
        batchsize_bc = features.shape[0]
        targets = target.to(device)

        logits = classifier(features)
        batch_loss = F.cross_entropy(logits, targets)

        # Number of correct top-1 predictions in this batch.
        linear_correct_1 = correct_top_k(logits, targets, top_k=[1])

        train_optimizer.zero_grad()
        batch_loss.backward()
        train_optimizer.step()

        # Weight the batch loss by its size so the mean is per sample.
        total_num += batchsize_bc
        running_loss += batch_loss.item() * batchsize_bc
        total_correct_1 += linear_correct_1[0]

        acc_1 = total_correct_1/total_num*100
        train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f} ACC: {:.2f}'.format(
                epoch, running_loss / total_num, acc_1))

    if scheduler is not None:
        scheduler.step()
    if total_num == 0:
        raise ValueError("linear_train: data_loader produced no samples")

    mean_loss = running_loss / total_num
    acc_1 = total_correct_1/total_num*100
    wandb.log({f" {task_num} Linear Layer Train Loss ": mean_loss, "Linear Epoch ": epoch})
    wandb.log({f" {task_num} Linear Layer Train - Acc": acc_1, "Linear Epoch ": epoch})

    return mean_loss, acc_1


def linear_evaluation(net, data_loaders,test_data_loaders,train_optimizer,classifier, scheduler, epochs, device, task_num):
    """Train and test a linear probe over the union of several loaders' datasets.

    The ``train_data`` / ``label_data`` tensors of every loader's dataset are
    concatenated, wrapped in a ``TensorDataset`` with the first loader's
    transform, and served in shuffled batches of 256. For each epoch the probe is
    trained once with ``linear_train`` and then evaluated with ``linear_test``.

    Returns:
        ``(test_loss, test_acc1, classifier)`` from the final epoch.
    """

    def _merged_loader(loaders):
        # Start from empty float / integer tensors and append loader by loader.
        samples = torch.Tensor([])
        labels = torch.tensor([], dtype=int)
        for ldr in loaders:
            samples = torch.cat((samples, ldr.dataset.train_data), dim=0)
            labels = torch.cat((labels, ldr.dataset.label_data), dim=0)
        merged = TensorDataset(samples, labels, transform=loaders[0].dataset.transform)
        return DataLoader(merged, batch_size=256, shuffle=True, num_workers=5, pin_memory=True)

    data_loader = _merged_loader(data_loaders)
    test_data_loader = _merged_loader(test_data_loaders)

    for epoch in range(1, epochs + 1):
        linear_train(net, data_loader, train_optimizer, classifier, scheduler, epoch, device, task_num)
        with torch.no_grad():
            # Evaluate the probe after every training epoch.
            test_loss, test_acc1 = linear_test(net, test_data_loader, classifier, epoch, device, task_num)

    return test_loss, test_acc1, classifier

In [18]:
# Linear probe on the classes of tasks 0..task_id, with ``model2`` as the fixed backbone.
task_id = 1
lin_epoch = 100
num_class = np.sum(args.class_split[:task_id+1])
classifier = LinearClassifier(num_classes=num_class).to(device)
# Infomax: no weight decay, epoch 100, cosine scheduler
lin_optimizer = torch.optim.SGD(
    classifier.parameters(), lr=0.2, momentum=0.9, weight_decay=0
)
# scheduler + values ref: infomax paper
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    lin_optimizer, T_max=lin_epoch, eta_min=0.002
)
linear_evaluation(
    model2,
    train_data_loaders_linear[:task_id+1],
    test_data_loaders[:task_id+1],
    lin_optimizer,
    classifier,
    lin_scheduler,
    lin_epoch,
    device,
    task_id,
)


Lin.Train Epoch: [1] Loss: 3.3257 ACC: 43.93: 100%|██████████| 79/79 [00:03<00:00, 21.00it/s]
Lin.Test Epoch: [1] Loss: 3.4347 ACC: 47.48% : 100%|██████████| 16/16 [00:01<00:00, 12.17it/s]
Lin.Train Epoch: [2] Loss: 3.3782 ACC: 48.93: 100%|██████████| 79/79 [00:03<00:00, 20.80it/s]
Lin.Test Epoch: [2] Loss: 2.6580 ACC: 53.70% : 100%|██████████| 16/16 [00:01<00:00, 11.61it/s]
Lin.Train Epoch: [3] Loss: 3.0190 ACC: 51.44: 100%|██████████| 79/79 [00:03<00:00, 20.93it/s]
Lin.Test Epoch: [3] Loss: 2.6857 ACC: 55.33% : 100%|██████████| 16/16 [00:01<00:00, 11.48it/s]
Lin.Train Epoch: [4] Loss: 2.9855 ACC: 52.05: 100%|██████████| 79/79 [00:03<00:00, 19.92it/s]
Lin.Test Epoch: [4] Loss: 2.8573 ACC: 54.02% : 100%|██████████| 16/16 [00:01<00:00, 11.68it/s]
Lin.Train Epoch: [5] Loss: 3.0599 ACC: 52.29: 100%|██████████| 79/79 [00:04<00:00, 19.74it/s]
Lin.Test Epoch: [5] Loss: 3.1281 ACC: 52.88% : 100%|██████████| 16/16 [00:01<00:00, 10.30it/s]
Lin.Train Epoch: [6] Loss: 3.1692 ACC: 52.26: 100%|████

(1.1195441608428955,
 67.25,
 LinearClassifier(
   (classifier): Linear(in_features=512, out_features=40, bias=True)
 ))

In [17]:
# Attach the existing projector to model2 as its contrastive head.
model2.contrastive_projector = contrastive_projector
task_id = 1
lin_epoch = 10
# The optimizer covers all of model2's parameters.
lin_optimizer = torch.optim.SGD(
    model2.parameters(), lr=1e-3, momentum=0.9, weight_decay=0
)
test_loss = contrastive_train(
    model2,
    train_data_loaders_linear[task_id],
    task_id,
    lin_optimizer,
    None,  # no LR scheduler
    epochs=lin_epoch,
    device=device,
)

Lin.Train Epoch: [1] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.00it/s]
Lin.Train Epoch: [2] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.11it/s]
Lin.Train Epoch: [3] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.09it/s]
Lin.Train Epoch: [4] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.07it/s]
Lin.Train Epoch: [5] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.09it/s]
Lin.Train Epoch: [6] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.16it/s]
Lin.Train Epoch: [7] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.19it/s]
Lin.Train Epoch: [8] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.07it/s]
Lin.Train Epoch: [9] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  2.94it/s]
Lin.Train Epoch: [10] Loss: 0.0000: 100%|██████████| 20/20 [00:06<00:00,  3.05it/s]


In [18]:
# Linear probe on the classes of tasks 0..task_id, with ``model2`` as the fixed backbone.
# In file order this cell follows the ``contrastive_train`` call on ``model2``.
task_id = 1
lin_epoch = 100
num_class = np.sum(args.class_split[:task_id+1])
classifier = LinearClassifier(num_classes=num_class).to(device)
# Infomax: no weight decay, epoch 100, cosine scheduler
lin_optimizer = torch.optim.SGD(
    classifier.parameters(), lr=0.2, momentum=0.9, weight_decay=0
)
# scheduler + values ref: infomax paper
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    lin_optimizer, T_max=lin_epoch, eta_min=0.002
)
linear_evaluation(
    model2,
    train_data_loaders_linear[:task_id+1],
    test_data_loaders[:task_id+1],
    lin_optimizer,
    classifier,
    lin_scheduler,
    lin_epoch,
    device,
    task_id,
)


Lin.Train Epoch: [1] Loss: 3.0140 ACC: 44.57: 100%|██████████| 79/79 [00:04<00:00, 19.69it/s]
Lin.Test Epoch: [1] Loss: 2.5046 ACC: 51.70% : 100%|██████████| 16/16 [00:01<00:00, 11.06it/s]
Lin.Train Epoch: [2] Loss: 3.2415 ACC: 48.98: 100%|██████████| 79/79 [00:04<00:00, 19.28it/s]
Lin.Test Epoch: [2] Loss: 2.2925 ACC: 56.50% : 100%|██████████| 16/16 [00:01<00:00, 11.58it/s]
Lin.Train Epoch: [3] Loss: 2.8112 ACC: 51.92: 100%|██████████| 79/79 [00:03<00:00, 20.35it/s]
Lin.Test Epoch: [3] Loss: 2.6728 ACC: 54.97% : 100%|██████████| 16/16 [00:01<00:00, 11.19it/s]
Lin.Train Epoch: [4] Loss: 2.9044 ACC: 51.85: 100%|██████████| 79/79 [00:03<00:00, 20.46it/s]
Lin.Test Epoch: [4] Loss: 3.2235 ACC: 51.52% : 100%|██████████| 16/16 [00:01<00:00, 11.25it/s]
Lin.Train Epoch: [5] Loss: 3.1853 ACC: 52.03: 100%|██████████| 79/79 [00:03<00:00, 20.14it/s]
Lin.Test Epoch: [5] Loss: 2.8158 ACC: 55.07% : 100%|██████████| 16/16 [00:01<00:00, 10.21it/s]
Lin.Train Epoch: [6] Loss: 2.8871 ACC: 53.77: 100%|████

(1.1258106384277344,
 67.4,
 LinearClassifier(
   (classifier): Linear(in_features=512, out_features=40, bias=True)
 ))

In [None]:
def correct_top_k(outputs, targets, top_k=(1,5)):
    """Count, for each ``k`` in ``top_k``, the rows whose target is among the ``k`` best scores.

    Args:
        outputs: Score tensor whose last dimension indexes classes.
        targets: Tensor of target class indices, one per row of ``outputs``.
        top_k: Iterable of cut-offs to evaluate.

    Returns:
        List of Python floats, one count per entry of ``top_k``.
    """
    with torch.no_grad():
        # Class indices per row, ordered from highest to lowest score.
        ranked = torch.argsort(outputs, dim=-1, descending=True)
        wanted = targets.unsqueeze(dim=-1)
        return [
            torch.sum((ranked[:, 0:k] == wanted).any(dim=-1).float()).item()
            for k in top_k
        ]
    
def contrastive_train2(net, data_loader, task_id, optimizer, classifier, scheduler, epochs, device):
    """Train ``classifier`` to tell features of a batch from features of its negation.

    Each input batch ``x`` is extended to ``[x, -x]``. Features of the original
    half are labelled 0 and features of the negated half 1. ``net`` stays in eval
    mode and receives no gradient; only ``classifier`` is trained. ``task_id`` is
    accepted for signature parity with the other trainers and is not used.

    Fix: the running total and the per-batch loss used to share the name
    ``linear_loss``, so the total was overwritten every batch and the
    displayed/returned value was a meaningless tensor. They are now separate,
    and a Python float is returned.

    Args:
        net: Feature extractor, called as ``net(batch)``.
        data_loader: Iterable of ``(inputs, labels)``; labels only supply the batch size.
        task_id: Unused.
        optimizer: Optimizer over the classifier parameters.
        classifier: Head mapping features to two logits.
        scheduler: Optional LR scheduler stepped once per epoch (may be ``None``).
        epochs: Number of passes over ``data_loader``.
        device: Device the inputs and targets are moved to.

    Returns:
        Sample-weighted mean cross-entropy of the last epoch that saw data,
        or ``0.0`` if no batch was processed.
    """
    mean_loss = 0.0
    for epoch in range(1, epochs+1):
        net.eval()  # keep BatchNorm statistics frozen
        total_num, train_bar = 0, tqdm(data_loader)
        running_loss = 0.0
        for data_tuple in train_bar:
            pos_1, targets = data_tuple
            half = targets.shape[0]
            # Second half of the batch is the sign-flipped copy of the first.
            pos_1 = torch.cat((pos_1, -pos_1), dim=0).to(device)
            # The backbone is never updated here, so no autograd graph is needed.
            with torch.no_grad():
                features = net(pos_1)

            batchsize_bc = features.shape[0]
            # Label 0 for the original samples, 1 for the negated ones.
            targets = torch.cat(
                (
                    torch.zeros(half, dtype=torch.long, device=device),
                    torch.ones(half, dtype=torch.long, device=device),
                ),
                dim=0,
            )

            logits = classifier(features)
            batch_loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            total_num += batchsize_bc
            running_loss += batch_loss.item() * batchsize_bc

            train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f}'.format(epoch, running_loss / total_num))
        if scheduler is not None:
            scheduler.step()
        if total_num:
            mean_loss = running_loss / total_num
        # wandb.log({" Linear Layer Train Loss ": linear_loss / total_num, " Epoch ": epoch})
    return mean_loss