In [1]:
import os
import sys
import wandb
import argparse
import numpy as np

from tqdm import tqdm


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "")))
import torch
import torchvision.transforms as T
import torchvision

from dataloaders.dataloader_cifar10 import get_cifar10
from dataloaders.dataloader_cifar100 import get_cifar100
from utils.eval_metrics import linear_evaluation, get_t_SNE_plot
from models.linear_classifer import LinearClassifier
from models.ssl import  SimSiam, Siamese, Encoder, Predictor

from trainers.train_simsiam import train_simsiam
from trainers.train_infomax import train_infomax
from trainers.train_barlow import train_barlow

from trainers.train_PFR import train_PFR_simsiam
from trainers.train_PFR_contrastive import train_PFR_contrastive_simsiam
from trainers.train_contrastive import train_contrastive_simsiam
from trainers.train_ering import train_ering_simsiam

from torchsummary import summary
import random
from utils.lr_schedulers import LinearWarmupCosineAnnealingLR, SimSiamScheduler
from utils.eval_metrics import Knn_Validation_cont
from copy import deepcopy
from loss import invariance_loss,CovarianceLoss,ErrorCovarianceLoss
import torch.nn as nn
import time
import torch.nn.functional as F
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

from models.linear_classifer import LinearClassifier
from torch.utils.data import DataLoader
from dataloaders.dataset import TensorDataset


from torchvision import transforms

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[0.1, 2.0]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = torchvision.transforms.functional.gaussian_blur(x,kernel_size=[3,3],sigma=sigma)#kernel size and sigma are open problems but right now seems ok!
        return x


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

In [3]:
class Args():
    normalization = 'batch'
    weight_standard = False
    same_lr = False
    pretrain_batch_size = 512
    pretrain_warmup_epochs = 10
    pretrain_warmup_lr = 3e-3
    pretrain_base_lr = 0.03
    pretrain_momentum = 0.9
    pretrain_weight_decay = 5e-4
    min_lr = 0.00
    lambdap = 1.0
    appr = 'barlow_PFR'
    knn_report_freq = 10
    cuda_device = 7
    num_workers = 8
    contrastive_ratio = 0.001
    dataset = 'cifar100'
    class_split = [20,20,20,20,20]
    epochs = [500,500,500,500,500]
    cov_loss_weight = 1.0
    sim_loss_weight = 250.0
    info_loss = 'invariance'
    lambda_norm = 1.0
    subspace_rate = 0.99
    lambda_param = 5e-3
    bsize = 32
    msize = 150
    proj_hidden = 2048
    proj_out = 2048 #infomax 64
    pred_hidden = 512
    pred_out = 2048



In [4]:
args = Args()

In [5]:
if args.dataset == "cifar10":
    get_dataloaders = get_cifar10
    num_classes=10
elif args.dataset == "cifar100":
    get_dataloaders = get_cifar100
    num_classes=100
assert sum(args.class_split) == num_classes
assert len(args.class_split) == len(args.epochs)

In [6]:
num_worker = args.num_workers
#device
device = torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
print(device)

cuda:7


In [7]:
#wandb init
wandb.init(project="CSSL",  entity="yavuz-team",
            mode="disabled",
            config=args,
            name= str(args.dataset) + '-algo' + str(args.appr) + "-e" + str(args.epochs) + "-b" 
            + str(args.pretrain_batch_size) + "-lr" + str(args.pretrain_base_lr)+"-CS"+str(args.class_split))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




In [8]:
if 'infomax' in args.appr or 'barlow' in args.appr:
    transform = T.Compose([
            T.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)]), p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur()], p=0.5), 
            T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])

    transform_prime = T.Compose([
            T.RandomResizedCrop(size=32, scale=(0.2, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)]), p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([GaussianBlur()], p=0.5), 
            T.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])

In [9]:
#Dataloaders
print("Creating Dataloaders..")
#Class Based
train_data_loaders, train_data_loaders_knn, test_data_loaders, _, train_data_loaders_linear, train_data_loaders_pure, train_data_loaders_generic  = get_dataloaders(transform, transform_prime, \
                                    classes=args.class_split, valid_rate = 0.00, batch_size=args.pretrain_batch_size, seed = 0, num_worker= num_worker)
_, train_data_loaders_knn_all, test_data_loaders_all, _, train_data_loaders_linear_all, train_data_loaders_pure_all, _ = get_dataloaders(transform, transform_prime, \
                                        classes=[num_classes], valid_rate = 0.00, batch_size=args.pretrain_batch_size, seed = 0, num_worker= num_worker)


Creating Dataloaders..
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [10]:
def correct_top_k(outputs, targets, top_k=(1,5)):
    with torch.no_grad():
        prediction = torch.argsort(outputs, dim=-1, descending=True)
        result= []
        for k in top_k:
            correct_k = torch.sum((prediction[:, 0:k] == targets.unsqueeze(dim=-1)).any(dim=-1).float()).item() 
            result.append(correct_k)
        return result

def linear_test(net, data_loader, classifier, epoch, device, task_num):
    # evaluate model:
    net.eval() # for not update batchnorm
    linear_loss = 0.0
    num = 0
    total_loss, total_correct_1, total_num, test_bar = 0.0, 0.0, 0, tqdm(data_loader)
    with torch.no_grad():
        for data_tuple in test_bar:
            data, target = [t.to(device) for t in data_tuple]
            output = net(data)
            if classifier is not None:  #else net is already a classifier
                output = classifier(output) 
            linear_loss = F.cross_entropy(output, target)
            
            # Batchsize for loss and accuracy
            num = data.size(0)
            total_num += num 
            total_loss += linear_loss.item() * num 
            # Accumulating number of correct predictions 
            correct_top_1 = correct_top_k(output, target, top_k=[1])    
            total_correct_1 += correct_top_1[0]
            test_bar.set_description('Lin.Test Epoch: [{}] Loss: {:.4f} ACC: {:.2f}% '
                                     .format(epoch,  total_loss / total_num,
                                             total_correct_1 / total_num * 100
                                             ))
        acc_1 = total_correct_1/total_num*100
        wandb.log({f" {task_num} Linear Layer Test Loss ": linear_loss / total_num, "Linear Epoch ": epoch})
        wandb.log({f" {task_num} Linear Layer Test - Acc": acc_1, "Linear Epoch ": epoch})
    return total_loss / total_num, acc_1  

def linear_train(net, data_loader, train_optimizer, classifier, scheduler, epoch, device, task_num):

    net.eval() # for not update batchnorm 
    total_num, train_bar = 0, tqdm(data_loader)
    linear_loss = 0.0
    total_correct_1 = 0.0
    for data_tuple in train_bar:
        # Forward prop of the model with single augmented batch
        pos_1, target = data_tuple
        pos_1 = pos_1.to(device)
        feature_1 = net(pos_1)
        # Batchsize
        batchsize_bc = feature_1.shape[0]
        features = feature_1
        targets = target.to(device)
        logits = classifier(features.detach()) 
        # Cross Entropy Loss 
        linear_loss_1 = F.cross_entropy(logits, targets)

        # Number of correct predictions
        linear_correct_1 = correct_top_k(logits, targets, top_k=[1])
    
        # Backpropagation part
        train_optimizer.zero_grad()
        linear_loss_1.backward()
        train_optimizer.step()

        # Accumulating number of examples, losses and correct predictions
        total_num += batchsize_bc
        linear_loss += linear_loss_1.item() * batchsize_bc
        total_correct_1 += linear_correct_1[0] 

        acc_1 = total_correct_1/total_num*100
        # # This bar is used for live tracking on command line (batch_size -> batchsize_bc: to show current batchsize )
        train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f} ACC: {:.2f}'.format(\
                epoch, linear_loss / total_num, acc_1))
    scheduler.step()
    acc_1 = total_correct_1/total_num*100   
    wandb.log({f" {task_num} Linear Layer Train Loss ": linear_loss / total_num, "Linear Epoch ": epoch})
    wandb.log({f" {task_num} Linear Layer Train - Acc": acc_1, "Linear Epoch ": epoch})
        
    return linear_loss/total_num, acc_1


def linear_evaluation(net, data_loaders,test_data_loaders,train_optimizer,classifier, scheduler, epochs, device, task_num):
    train_X = torch.Tensor([])
    train_Y = torch.tensor([],dtype=int)
    for loader in data_loaders:
        train_X = torch.cat((train_X, loader.dataset.train_data), dim=0)
        train_Y = torch.cat((train_Y, loader.dataset.label_data), dim=0)
    data_loader = DataLoader(TensorDataset(train_X, train_Y,transform=data_loaders[0].dataset.transform), batch_size=256, shuffle=True, num_workers = 5, pin_memory=True)

    test_X = torch.Tensor([])
    test_Y = torch.tensor([],dtype=int)
    for loader in test_data_loaders:
        test_X = torch.cat((test_X, loader.dataset.train_data), dim=0)
        test_Y = torch.cat((test_Y, loader.dataset.label_data), dim=0)
    test_data_loader = DataLoader(TensorDataset(test_X, test_Y,transform=test_data_loaders[0].dataset.transform), batch_size=256, shuffle=True, num_workers = 5, pin_memory=True)

    for epoch in range(1, epochs+1):
        linear_train(net,data_loader,train_optimizer,classifier,scheduler, epoch, device, task_num)
        with torch.no_grad():
            # Testing for linear evaluation
            test_loss, test_acc1 = linear_test(net, test_data_loader, classifier, epoch, device, task_num)

    return test_loss, test_acc1, classifier

In [11]:
device = torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
print(device)
if 'infomax' in args.appr or 'barlow' in args.appr:
    proj_hidden = args.proj_hidden
    proj_out = args.proj_out
    encoder = Encoder(hidden_dim=proj_hidden, output_dim=proj_out, normalization = args.normalization, weight_standard = args.weight_standard, appr_name = args.appr)
    model2 = Siamese(encoder)
    model2.to(device) #automatically detects from model
#load model here
file_name = "./checkpoints/checkpoint_cifar100-algocassle_barlow-e[500, 500, 500, 500, 500]-b256-lr0.25-CS[20, 20, 20, 20, 20]_task_1_same_lr_True_norm_batch_ws_False.pth.tar"
dict = torch.load(file_name)
model2.temporal_projector = nn.Sequential(
            nn.Linear(args.proj_out, args.proj_hidden, bias=False),
            nn.BatchNorm1d(args.proj_hidden),
            nn.ReLU(),
            nn.Linear(args.proj_hidden, args.proj_out),
        ).to(device)
model2.load_state_dict(dict['state_dict'])

cuda:5


<All keys matched successfully>

In [12]:
task_id = 1
lin_epoch = 100
num_class = np.sum(args.class_split[:task_id+1])
classifier = LinearClassifier(num_classes = num_class).to(device)
lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
linear_evaluation(model2, train_data_loaders_linear[:task_id+1], test_data_loaders[:task_id+1], lin_optimizer,classifier, lin_scheduler, lin_epoch, device, task_id)  


Lin.Train Epoch: [1] Loss: 3.6393 ACC: 42.91: 100%|██████████| 79/79 [00:04<00:00, 16.35it/s]
Lin.Test Epoch: [1] Loss: 3.0033 ACC: 51.20% : 100%|██████████| 16/16 [00:01<00:00, 11.42it/s]
Lin.Train Epoch: [2] Loss: 3.4512 ACC: 49.25: 100%|██████████| 79/79 [00:03<00:00, 22.13it/s]
Lin.Test Epoch: [2] Loss: 3.2797 ACC: 54.17% : 100%|██████████| 16/16 [00:01<00:00, 12.11it/s]
Lin.Train Epoch: [3] Loss: 3.3490 ACC: 50.17: 100%|██████████| 79/79 [00:03<00:00, 22.40it/s]
Lin.Test Epoch: [3] Loss: 3.3353 ACC: 50.58% : 100%|██████████| 16/16 [00:01<00:00, 11.91it/s]
Lin.Train Epoch: [4] Loss: 3.2450 ACC: 51.43: 100%|██████████| 79/79 [00:03<00:00, 22.26it/s]
Lin.Test Epoch: [4] Loss: 2.8813 ACC: 53.23% : 100%|██████████| 16/16 [00:01<00:00, 11.77it/s]
Lin.Train Epoch: [5] Loss: 3.0826 ACC: 51.90: 100%|██████████| 79/79 [00:03<00:00, 21.74it/s]
Lin.Test Epoch: [5] Loss: 2.7467 ACC: 54.12% : 100%|██████████| 16/16 [00:01<00:00, 11.94it/s]
Lin.Train Epoch: [6] Loss: 3.0418 ACC: 52.61: 100%|████

(1.1261112451553346,
 66.975,
 LinearClassifier(
   (classifier): Linear(in_features=512, out_features=40, bias=True)
 ))

In [11]:
def correct_top_k(outputs, targets, top_k=(1,5)):
    with torch.no_grad():
        prediction = torch.argsort(outputs, dim=-1, descending=True)
        result= []
        for k in top_k:
            correct_k = torch.sum((prediction[:, 0:k] == targets.unsqueeze(dim=-1)).any(dim=-1).float()).item() 
            result.append(correct_k)
        return result
    

def store_samples(loader, num=150):
    x_data = loader.dataset.train_data
    select = np.random.randint(0,x_data.shape[0],num)
    return torch.Tensor(x_data[select])
    

In [12]:
task_id = 0
old_samples = store_samples(train_data_loaders_linear[task_id],num=150)
old_labels = torch.ones(old_samples.shape[0],dtype=torch.long) * task_id 

data_normalize_mean = (0.5071, 0.4865, 0.4409)
data_normalize_std = (0.2673, 0.2564, 0.2762)
random_crop_size = 32
transform_linear = transforms.Compose( [
          transforms.RandomResizedCrop(random_crop_size,  interpolation=transforms.InterpolationMode.BICUBIC), # scale=(0.2, 1.0) is possible
          transforms.RandomHorizontalFlip(),
          transforms.Normalize(data_normalize_mean, data_normalize_std),
      ] )

old_data_loader = DataLoader(TensorDataset(old_samples,old_labels,transform=transform_linear), batch_size=16, shuffle=True, 
                         num_workers = 5, pin_memory=True)


In [50]:
def contrastive_train(net, data_loader, old_data_loader, new_batch_size, task_id, optimizer, scheduler, epochs, device):
    for epoch in range(1, epochs+1):
        net.eval() # for not update batchnorm 
        total_num, train_bar = 0, tqdm(data_loader)
        linear_loss = 0.0

        current_data_loader = DataLoader(data_loader.dataset, batch_size=new_batch_size, shuffle=True, num_workers = 5, pin_memory=True)
        dataloader_iterator = iter(old_data_loader)
        for x1, _ in current_data_loader:
            try:
                x2, y2 = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(old_data_loader)
                x2, y2 = next(dataloader_iterator)

            y1 = torch.ones(x1.shape[0],dtype=torch.long) * task_id
            x_all = torch.cat((x1, x2), dim=0)
            y_all = torch.cat((y1, y2), dim=0)
            x_all = x_all.to(device)
            y_all = y_all.to(device)

            # Forward prop of the model with single augmented batch
            features = net(x_all)
            logits = net.contrastive_projector(features) 
            
            #c_weights = torch.nn.functional.normalize(net.contrastive_projector.weight,dim=1)
            #logits = features @ c_weights.T
            
            # Cross Entropy Loss 
            linear_loss = F.cross_entropy(logits, y_all)

            # Backpropagation part
            optimizer.zero_grad()
            linear_loss.backward()
            # net.contrastive_projector.weight.grad[0:task_id] = torch.zeros(net.contrastive_projector.weight.grad[0:task_id].shape).to(device)
            optimizer.step()

            # Accumulating number of examples, losses and correct predictions
            batchsize_bc = features.shape[0]
            total_num += batchsize_bc
            linear_loss += linear_loss.item() * batchsize_bc

            train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f}'.format(epoch, linear_loss / total_num))
        if scheduler is not None:
            scheduler.step()
        # wandb.log({" Linear Layer Train Loss ": linear_loss / total_num, " Epoch ": epoch})
    return linear_loss/total_num

In [51]:
device = torch.device("cuda:" + str(args.cuda_device) if torch.cuda.is_available() else "cpu")
print(device)
if 'infomax' in args.appr or 'barlow' in args.appr:
    proj_hidden = args.proj_hidden
    proj_out = args.proj_out
    encoder = Encoder(hidden_dim=proj_hidden, output_dim=proj_out, normalization = args.normalization, weight_standard = args.weight_standard, appr_name = args.appr)
    model2 = Siamese(encoder)
    model2.to(device) #automatically detects from model
#load model here
file_name = "./checkpoints/checkpoint_cifar100-algocassle_barlow-e[500, 500, 500, 500, 500]-b256-lr0.25-CS[20, 20, 20, 20, 20]_task_1_same_lr_True_norm_batch_ws_False.pth.tar"
dict = torch.load(file_name)
model2.temporal_projector = nn.Sequential(
            nn.Linear(args.proj_out, args.proj_hidden, bias=False),
            nn.BatchNorm1d(args.proj_hidden),
            nn.ReLU(),
            nn.Linear(args.proj_hidden, args.proj_out),
        ).to(device)
model2.load_state_dict(dict['state_dict'])


cuda:5


<All keys matched successfully>

In [52]:
contrastive_projector =  nn.Linear(512, 2, bias=False).to(device)
model2.contrastive_projector = contrastive_projector 
task_id = 1
lin_epoch= 10
lin_optimizer = torch.optim.SGD(model2.parameters(), 1e-3, momentum=0.9, weight_decay=0) 
test_loss = contrastive_train(model2, train_data_loaders_linear[task_id],old_data_loader, old_labels, task_id, lin_optimizer, None, epochs=lin_epoch, device=device) 

Lin.Train Epoch: [1] Loss: 0.0003: 100%|██████████| 40/40 [00:05<00:00,  7.21it/s]
Lin.Train Epoch: [2] Loss: 0.0000: 100%|██████████| 40/40 [00:05<00:00,  7.15it/s]
Lin.Train Epoch: [3] Loss: 0.0000: 100%|██████████| 40/40 [00:05<00:00,  7.23it/s]
Lin.Train Epoch: [4] Loss: 0.0006:  20%|██        | 8/40 [00:02<00:08,  3.79it/s]


KeyboardInterrupt: 

In [53]:
task_id = 1
lin_epoch = 100
num_class = np.sum(args.class_split[:task_id+1])
classifier = LinearClassifier(num_classes = num_class).to(device)
lin_optimizer = torch.optim.SGD(classifier.parameters(), 0.2, momentum=0.9, weight_decay=0) # Infomax: no weight decay, epoch 100, cosine scheduler
lin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(lin_optimizer, lin_epoch, eta_min=0.002) #scheduler + values ref: infomax paper
linear_evaluation(model2, train_data_loaders_linear[:task_id+1], test_data_loaders[:task_id+1], lin_optimizer,classifier, lin_scheduler, lin_epoch, device, task_id)  


Lin.Train Epoch: [1] Loss: 1234.0271 ACC: 8.21: 100%|██████████| 79/79 [00:03<00:00, 21.60it/s]
Lin.Test Epoch: [1] Loss: 600.4466 ACC: 14.65% : 100%|██████████| 16/16 [00:01<00:00, 11.73it/s]
Lin.Train Epoch: [2] Loss: 178.7394 ACC: 37.83: 100%|██████████| 79/79 [00:03<00:00, 22.08it/s]
Lin.Test Epoch: [2] Loss: 45.3162 ACC: 49.65% : 100%|██████████| 16/16 [00:01<00:00, 11.01it/s]
Lin.Train Epoch: [3] Loss: 39.8305 ACC: 49.73: 100%|██████████| 79/79 [00:03<00:00, 21.98it/s]
Lin.Test Epoch: [3] Loss: 34.8870 ACC: 52.62% : 100%|██████████| 16/16 [00:01<00:00, 11.21it/s]
Lin.Train Epoch: [4] Loss: 34.9601 ACC: 49.29: 100%|██████████| 79/79 [00:03<00:00, 22.09it/s]
Lin.Test Epoch: [4] Loss: 27.7315 ACC: 51.62% : 100%|██████████| 16/16 [00:01<00:00, 11.49it/s]
Lin.Train Epoch: [5] Loss: 36.4738 ACC: 48.58: 100%|██████████| 79/79 [00:03<00:00, 21.42it/s]
Lin.Test Epoch: [5] Loss: 30.4145 ACC: 51.68% : 100%|██████████| 16/16 [00:01<00:00, 10.64it/s]
Lin.Train Epoch: [6] Loss: 37.1492 ACC: 48

(4.177773349761963,
 65.675,
 LinearClassifier(
   (classifier): Linear(in_features=512, out_features=40, bias=True)
 ))

In [None]:
def correct_top_k(outputs, targets, top_k=(1,5)):
    with torch.no_grad():
        prediction = torch.argsort(outputs, dim=-1, descending=True)
        result= []
        for k in top_k:
            correct_k = torch.sum((prediction[:, 0:k] == targets.unsqueeze(dim=-1)).any(dim=-1).float()).item() 
            result.append(correct_k)
        return result
    
def contrastive_train2(net, data_loader, task_id, optimizer, classifier, scheduler, epochs, device):
    for epoch in range(1, epochs+1):
        net.eval() # for not update batchnorm 
        total_num, train_bar = 0, tqdm(data_loader)
        linear_loss = 0.0
        for data_tuple in train_bar:
            # Forward prop of the model with single augmented batch
            pos_1, targets = data_tuple
            pos_1 = pos_1[0]
            pos_1 = torch.cat((pos_1, -pos_1), dim=0)
            # print(pos_1.shape)
            pos_1 = pos_1.to(device)
            features = net(pos_1)

            # Batchsize
            batchsize_bc = features.shape[0]
            targets = torch.zeros(targets.shape[0],dtype=torch.long).to(device)
            targets = torch.cat((targets, torch.ones(targets.shape[0],dtype=torch.long).to(device)), dim=0)
            # print(targets.shape)
            targets = targets.to(device)
            
            logits = classifier(features.detach()) 

            # Cross Entropy Loss 
            linear_loss = F.cross_entropy(logits, targets)

            # Backpropagation part
            optimizer.zero_grad()
            linear_loss.backward()
            optimizer.step()

            # Accumulating number of examples, losses and correct predictions
            total_num += batchsize_bc
            linear_loss += linear_loss.item() * batchsize_bc

            train_bar.set_description('Lin.Train Epoch: [{}] Loss: {:.4f}'.format(epoch, linear_loss / total_num))
        if scheduler is not None:
            scheduler.step()
        # wandb.log({" Linear Layer Train Loss ": linear_loss / total_num, " Epoch ": epoch})
    return linear_loss/total_num