## Token-level Adversarial Contrastive Training (TACT)

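This is a variant of CAT in which, instead of perturbing the word-embedding matrix, we directly perturb the token representations (the encoder's last hidden states). Each token vector is shifted by a small ε (0.005 by default) times the L2-normalised gradient of the classification loss with respect to that vector. The resulting perturbed representation is used as the positive pair for the contrastive loss.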

<div>
<img src="Images/ADV_1.png"  width="400"/>
</div>

### Import Libraries

In [1]:

import numpy as np
import pandas as pd
import json, sys, regex
import torch
#import GPUtil
import torch.nn as nn
import shutil
from glob import glob
from shutil import copyfile
from tqdm import tqdm, trange
import os
from pytorch_metric_learning import losses
from torch.nn import functional as F


import random

from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, classification_report, confusion_matrix
##----------------------------------------------------
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler

import datasets

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Count of visible CUDA devices (0 on a CPU-only machine).
n_gpu = torch.cuda.device_count()
print(device)


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    ``torch.cuda.manual_seed_all`` is documented as safe (a silent no-op) when
    CUDA is unavailable, so no dependency on a module-level GPU count is needed.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

cpu


### Dataset Class for Tokenizing Train & Test Data

In [2]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one DataFrame row per item.

    Each item is a dict of LongTensors:
        'ids', 'mask': shape (n_texts, max_len); n_texts is 2 for sentence
            pairs (each sentence is encoded separately) and 1 otherwise.
        'token': token type ids, present only for sentence pairs.
        'targets': the label mapped through ``lab2ind``; a label that is not a
            key of ``lab2ind`` is used as-is.

    Args:
        dataframe: source table. Rows are read with ``Series[index]``
            (label-based); presumably a default RangeIndex - TODO confirm.
        tokenizer: tokenizer exposing ``batch_encode_plus``.
        max_len: padding/truncation length.
        lab2ind: mapping from raw label to class index.
        text_col_1: name of the first text column.
        text_col_2: name of the second text column, or None for single sentences.
        label_col: name of the label column.
    """

    def __init__(self, dataframe, tokenizer, max_len, lab2ind, text_col_1 = 'sentence', text_col_2 = None, label_col = 'labels'):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text_col_1 = self.data[text_col_1]
        self.labels = self.data[label_col]
        self.max_len = max_len
        self.lab2ind = lab2ind

        self.isPair = text_col_2 is not None
        # In single-sentence mode column 2 mirrors column 1 so both are always readable.
        self.text_col_2 = self.data[text_col_2] if self.isPair else self.data[text_col_1]

    def __len__(self):
        return len(self.text_col_1)

    def __getitem__(self, index):
        text_1 = str(self.text_col_1[index])
        text_2 = str(self.text_col_2[index])

        # Map the label exactly once. The previous code mapped it twice, which
        # silently remapped e.g. label 2 -> 1 -> 0 when lab2ind was {1: 0, 2: 1}.
        raw_label = self.labels[index]
        try:
            label = self.lab2ind[raw_label]
        except (KeyError, TypeError):
            # Label not in the map (or unhashable): keep the raw value.
            label = raw_label

        encode_kwargs = dict(
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        )
        if self.isPair:
            texts = [text_1, text_2]
            encode_kwargs['return_token_type_ids'] = True
        else:
            texts = [text_1]
        inputs = self.tokenizer.batch_encode_plus(texts, **encode_kwargs)

        dic = {
            'ids': torch.tensor(inputs.input_ids, dtype=torch.long),
            'mask': torch.tensor(inputs.attention_mask, dtype=torch.long),
        }
        if self.isPair:
            dic['token'] = torch.tensor(inputs.token_type_ids, dtype=torch.long)
        dic['targets'] = torch.tensor(label, dtype=torch.long)
        return dic

### Function for Encoding Dataset

In [3]:
def regular_encode(file_path, tokenizer, lab2ind, shuffle=True, num_workers = 1, batch_size=64, maxlen = 32, mode = 'train', text_col_1 = 'sentence', text_col_2 = None, label_col = 'labels'):
    """Load a CSV/TSV file into a tokenizing DataLoader.

    Args:
        file_path: path to the data file; a name ending in 'tsv' is read
            tab-separated, anything else with pandas' default delimiter.
        tokenizer: tokenizer forwarded to ``CustomDataset``.
        lab2ind: label -> class index mapping forwarded to ``CustomDataset``.
        shuffle, num_workers, batch_size: forwarded to ``DataLoader``.
        maxlen: padding/truncation length.
        mode: 'train' or 'evaluate'; both currently load the same columns.
        text_col_1, text_col_2, label_col: column names forwarded to ``CustomDataset``.

    Returns:
        A ``DataLoader``, or None (after printing a message) for an unknown ``mode``.
    """
    if mode not in ('train', 'evaluate'):
        print("the type of mode should be either 'train' or 'evaluate'. ")
        return None

    delimiter = '\t' if str(file_path).endswith('tsv') else None
    df = pd.read_csv(file_path, delimiter=delimiter)
    custom_set = CustomDataset(df, tokenizer, maxlen, lab2ind, text_col_1=text_col_1,
                               text_col_2=text_col_2, label_col=label_col)

    print("{} Dataset: {}".format(file_path, df.shape))

    return DataLoader(custom_set, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

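### Soft Supervised Contrastive Loss

A label-weighted variant of the supervised contrastive loss. Note that the training loop below does not use this class; it uses `NTXentLoss` / `SupConLoss` from `pytorch_metric_learning`.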

In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from pprint import pprint


class Soft_SupConLoss_CLS(nn.Module):
    """Soft (prediction-weighted) supervised contrastive loss.

    Positives for anchor i are the other samples with the same label. Every
    non-self pair (i, k) is weighted by ``softmax(weights)[i, labels[k]]``, the
    anchor's predicted probability for sample k's class. The weighting applies
    both in the denominator sum and in the positive-pair numerator.

    Args:
        num_classes: number of classes used to one-hot encode ``labels``.
        temperature: temperature dividing the cosine similarities.
        device: kept for backward compatibility; internal tensors are now
            created on ``features.device`` so inputs on any device work.
    """

    def __init__(self, num_classes, temperature=0.07, device='cpu'):
        super(Soft_SupConLoss_CLS, self).__init__()
        self.temperature = temperature
        self.num_classes = num_classes
        self.device = device

    def forward(self, features, labels=None, weights=None, mask = None):
        """Compute the loss.

        Args:
            features: (batch, dim) embeddings; L2-normalised along dim 1 here.
            labels: (batch,) integer class ids. Required.
            weights: (batch, num_classes) logits, softmax-ed along dim 1. Required.
            mask: not supported together with ``labels``; passing it raises.

        Returns:
            A loss scalar: mean over anchors. An anchor with no positive
            contributes 0 instead of making the whole loss NaN.

        Raises:
            ValueError: if ``labels`` or ``weights`` is missing, ``mask`` is
                given, or the number of labels differs from the batch size.
        """
        if labels is None:
            raise ValueError('`labels` are required to build the soft weighting mask')
        if mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        if weights is None:
            raise ValueError('`weights` (per-sample class logits) are required')

        device = features.device
        features = F.normalize(features, dim=1, p=2)
        batch_size = features.shape[0]

        labels = labels.contiguous().view(-1)
        if labels.shape[0] != batch_size:
            raise ValueError('Num of labels does not match num of features')

        weights = F.softmax(weights, dim=1).to(device)  # logit -> class probabilities
        labels_one_hot = F.one_hot(labels, num_classes=self.num_classes).float().to(device)

        label_col = labels.view(-1, 1)
        same_label = torch.eq(label_col, label_col.T).float().to(device)

        # 1 everywhere except the diagonal (an anchor is never its own pair).
        logits_mask = 1.0 - torch.eye(batch_size, device=device)
        pos_mask = same_label * logits_mask

        # weighted_mask[i, k] = p_i(class of k), self-pairs removed.
        weighted_mask = torch.matmul(weights, labels_one_hot.T) * logits_mask
        pos_weighted_mask = weighted_mask * pos_mask

        anchor_dot_contrast = torch.matmul(features, features.T) / self.temperature
        # Subtract the row max for numerical stability (does not change log-softmax).
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Denominator: sum_k w_ik * exp(h_i . h_k / t)
        exp_logits = torch.exp(logits) * weighted_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        pos_count = pos_mask.sum(1)
        pos_sum = (pos_weighted_mask * log_prob).sum(1)
        # Anchors without positives would give 0/0; make them contribute 0.
        mean_log_prob_pos = torch.where(
            pos_count > 0,
            pos_sum / pos_count.clamp_min(1.0),
            torch.zeros_like(pos_sum),
        )

        return -mean_log_prob_pos.mean()

### Define Bert Model

In [5]:
def initial_module(module):
    """Re-initialise a layer in place: weight ~ N(0, 0.02^2), bias = 0.

    Args:
        module: a layer exposing ``weight`` and ``bias`` parameters (e.g. nn.Linear).
    """
    init = torch.nn.init
    init.normal_(module.weight, mean=0.0, std=0.02)
    init.zeros_(module.bias)


class Bert_CLS(nn.Module):
    """Transformer encoder with a classification head and a contrastive projection head.

    Both heads read the first ([CLS]) token of the last hidden state. When
    ``sequence_output`` is passed in, the encoder is skipped. The token
    representations are then shifted by ``eps`` times the per-token
    L2-normalised gradient of the classification loss (the TACT adversarial
    view) before the contrastive head is applied.

    Args:
        lab2ind: label -> index mapping; its size sets the number of classes.
        model_path: name/path passed to ``AutoModel.from_pretrained``.
        hidden_size: encoder hidden size (input width of both heads).
        loss_func: classification loss, called as ``loss_func(logits, labels)``.
        cl_dim: output width of the contrastive projection.
        eps: adversarial step size applied to token representations.
    """

    def __init__(self, lab2ind, model_path, hidden_size, loss_func, cl_dim = 300, eps = 0.005):
        super(Bert_CLS, self).__init__()
        self.model_path = model_path
        self.hidden_size = hidden_size
        self.loss_func = loss_func
        self.cl_dim = cl_dim

        self.eps = eps

        self.label_num = len(lab2ind.keys())

        self.bert = AutoModel.from_pretrained(model_path)

        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.hidden_size, self.label_num)

        self.cl_dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.cl_fc = nn.Linear(self.hidden_size, self.cl_dim)

        initial_module(self.dense)
        initial_module(self.fc)
        initial_module(self.cl_dense)
        initial_module(self.cl_fc)

    def forward(self, input_ids, attention_mask, token_type_ids = None, labels=None, sequence_output = None):
        """Return ``[loss, logits, cl_emb, sequence_output]``.

        Args:
            input_ids, attention_mask, token_type_ids: encoder inputs; ignored
                when ``sequence_output`` is given.
            labels: class targets; required when ``sequence_output`` is given.
            sequence_output: precomputed token representations, indexed as
                (batch, seq, hidden) by ``[:, 0, :]``. When given, the encoder is
                skipped and the adversarial perturbation is applied.

        Returns:
            loss (None if ``labels`` is None), classification logits, contrastive
            embedding, and the (possibly perturbed) token representations.

        Raises:
            ValueError: if ``sequence_output`` is given without ``labels``.
        """
        is_seq = sequence_output is not None
        if not is_seq:
            outputs = self.bert(input_ids=input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids,
                                    output_hidden_states = True, return_dict = True)
            sequence_output = outputs['last_hidden_state']
        else:
            if labels is None:
                raise ValueError("`labels` are required when `sequence_output` is given "
                                 "(they drive the adversarial gradient)")
            if not sequence_output.requires_grad:
                sequence_output.requires_grad_(True)

        x = self.dropout(sequence_output[:, 0, :])
        x = F.relu(self.dense(x))
        logits = self.fc(x)

        loss = None
        if is_seq:
            loss = self.loss_func(logits, labels)
            # autograd.grad returns d(loss)/d(sequence_output) without accumulating
            # .grad into this module's parameters or into the caller's tensor
            # (loss.backward() did both, and a stale .grad would add up).
            # retain_graph keeps `loss` differentiable for the caller.
            (cl_grad,) = torch.autograd.grad(loss, sequence_output, retain_graph=True)
            cl_grad = cl_grad.detach()
            l2_norm = torch.norm(cl_grad, dim=-1, keepdim=True)
            cl_grad = cl_grad / (l2_norm + 1e-12)  # unit-norm direction per token

            sequence_output = sequence_output + self.eps * cl_grad

        x = self.dropout(sequence_output[:, 0, :])
        x = F.relu(self.cl_dense(x))
        cl_emb = self.cl_fc(x)

        if labels is not None and loss is None:
            loss = self.loss_func(logits, labels)
        return [loss, logits, cl_emb, sequence_output]

### Define Train Function for TACT

<div>
<img src="Images/tact_code_0.PNG"  width="300"/>
</div>

In [6]:
from copy import deepcopy

def train(model_main, iterator, optimizer, scheduler, label_size, supervised_cl_weight, supervised_cls_temp, 
          soft_supervised_cl = False, weight_model = None, isPair = False):
    """Run one TACT training epoch and return the mean per-batch loss.

    Per batch the objective is ``(1 - w) * CE + w * contrastive`` with
    ``w = supervised_cl_weight``.
    - ``soft_supervised_cl=True``: NT-Xent between each example's contrastive
      embedding and the embedding of its adversarially perturbed token
      representations. The perturbed view comes from a throw-away deep copy of
      the model, so no gradient from it reaches ``model_main``.
    - ``soft_supervised_cl=False``: SupCon over the contrastive embeddings
      using the class labels.
    The contrastive term is skipped when ``w == 0``.

    Args:
        model_main: model returning ``[loss, logits, cl_emb, sequence_output]``.
        iterator: DataLoader yielding dicts with 'ids', 'mask', 'targets' and
            optionally 'token'.
        optimizer, scheduler: each stepped once per batch.
        label_size: unused; kept for interface compatibility.
        supervised_cl_weight: mixing weight ``w`` of the contrastive term.
        supervised_cls_temp: temperature of the contrastive loss.
        soft_supervised_cl: choose the adversarial NT-Xent objective.
        weight_model: unused; a fresh copy of ``model_main`` is made per batch.
        isPair: whether batches hold sentence pairs.

    Relies on the module-level ``device`` and ``max_grad_norm`` (the latter is
    set by ``fine_tuning``).
    """
    model_main.train()

    epoch_loss = 0.0
    if soft_supervised_cl:
        print("ADV CL")
        loss_supercl_cls = losses.NTXentLoss(temperature=supervised_cls_temp)
    else:
        print("Supervised CL")
        loss_supercl = losses.SupConLoss(temperature=supervised_cls_temp)

    for batch in tqdm(iterator, desc="Iteration"):
        input_ids = batch['ids'].to(device, dtype=torch.long)
        attention_mask = batch['mask'].to(device, dtype=torch.long)
        token_type_ids = batch['token'].to(device, dtype=torch.long) if 'token' in batch else None
        labels = batch['targets'].to(device, dtype=torch.long)

        # Items are (batch, n_texts, max_len); n_texts is read before flattening.
        num_sent = input_ids.shape[1]

        if input_ids.dim() == 3:
            input_ids = input_ids.view(-1, input_ids.size(-1))
            attention_mask = attention_mask.view(-1, input_ids.size(-1))
        # `is not None` avoids evaluating a multi-element tensor as a bool (RuntimeError).
        if token_type_ids is not None and token_type_ids.dim() == 3 and isPair:
            token_type_ids = token_type_ids.view(-1, input_ids.size(-1))

        supervised_labels = labels
        if supervised_labels.dim() == 1 and not isPair:
            # One label per flattened row: size (bs * num_sent).
            supervised_labels = torch.cat([supervised_labels.unsqueeze(1)] * num_sent, 1).view(-1)
        # NOTE(review): in pair mode the inputs are flattened to 2*bs rows while
        # the labels stay at bs - confirm the intended pair handling.

        outputs = model_main(input_ids=input_ids, attention_mask=attention_mask,
                             token_type_ids=token_type_ids, labels=supervised_labels)
        loss, cl_emb, sequence_output = outputs[0], outputs[2], outputs[3]

        loss = loss * (1.0 - supervised_cl_weight)

        if supervised_cl_weight > 0.0:
            if soft_supervised_cl:
                # Adversarial view from a detached copy of the token representations.
                weight_model = deepcopy(model_main)
                sequence_output_clone = sequence_output.detach()
                sequence_output_clone.requires_grad = True
                adv_outputs = weight_model(input_ids=input_ids, attention_mask=attention_mask,
                                           token_type_ids=token_type_ids, labels=supervised_labels,
                                           sequence_output=sequence_output_clone)
                cl_emb_adv = adv_outputs[2]

                # Interleave rows: [emb_0, adv_0, emb_1, adv_1, ...] with labels [0, 0, 1, 1, ...].
                cl_emb_new = torch.cat([cl_emb, cl_emb_adv], dim=1).view(-1, cl_emb.size(-1))
                cont_labels = torch.arange(cl_emb.size(0), device=cl_emb.device).repeat_interleave(2)
                loss = loss + loss_supercl_cls(cl_emb_new, cont_labels) * supervised_cl_weight
            else:
                loss = loss + loss_supercl(cl_emb, supervised_labels) * supervised_cl_weight

        # mean() is a no-op for a scalar and reduces per-replica losses under DataParallel.
        loss = loss.mean()
        loss.backward()
        epoch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model_main.parameters(), max_grad_norm)

        optimizer.step()

        model_main.zero_grad()

        scheduler.step()
        optimizer.zero_grad()

    return epoch_loss / len(iterator)

### Define Evaluation

In [7]:
def evaluate(model, iterator, metric, is_regression = False, isPair = False):
    """Evaluate ``model`` on ``iterator`` without gradient tracking.

    Args:
        model: model returning ``[loss, logits, ...]`` when called with labels.
        iterator: DataLoader yielding dicts with 'ids', 'mask', 'targets' and
            optionally 'token'.
        metric: callable ``metric(predictions=..., references=...)`` returning a dict.
        is_regression: if True, flattened logits are compared to the targets by MSE.
        isPair: whether batches hold sentence pairs.

    Returns:
        ``(mean per-batch loss, result dict)``.

    Relies on the module-level ``device``.
    """
    model.eval()
    epoch_loss = 0.0

    all_pred = []
    all_label = []

    with torch.no_grad():
        for batch in iterator:
            input_ids = batch['ids'].to(device, dtype=torch.long)
            attention_mask = batch['mask'].to(device, dtype=torch.long)
            token_type_ids = batch['token'].to(device, dtype=torch.long) if 'token' in batch else None
            labels = batch['targets'].to(device, dtype=torch.long)

            if input_ids.dim() == 3:
                input_ids = input_ids.view(-1, input_ids.size(-1))
                attention_mask = attention_mask.view(-1, input_ids.size(-1))
            # `is not None` avoids evaluating a multi-element tensor as a bool (RuntimeError).
            if token_type_ids is not None and token_type_ids.dim() == 3 and isPair:
                token_type_ids = token_type_ids.view(-1, input_ids.size(-1))

            outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, labels=labels)
            loss, logits = outputs[:2]

            # sum() is a no-op for a scalar and adds per-replica losses under DataParallel.
            epoch_loss += loss.sum().item()

            logits = logits.cpu()
            if is_regression:
                all_pred.extend(logits.view(-1).tolist())
            else:
                all_pred.extend(logits.argmax(dim=1).tolist())
            all_label.extend(labels.cpu().tolist())

    if is_regression:
        diff = np.asarray(all_pred, dtype=float) - np.asarray(all_label, dtype=float)
        result = {"mse": float(np.mean(diff ** 2))}
    else:
        result = metric(predictions=all_pred, references=all_label)
    return epoch_loss / len(iterator), result

### Create Optimizer and Metric

In [8]:
def create_optimizer_and_scheduler(total_params, num_training_steps, warmup_steps, weight_decay, learning_rate, is_constant_lr):
    """
    Setup the optimizer and the learning rate scheduler.

    Parameters whose name contains "bias" or "LayerNorm.weight" are excluded
    from weight decay.

    Args:
        total_params: list of (name, parameter) pairs, e.g. ``list(model.named_parameters())``.
        num_training_steps: total optimizer steps (used by the linear schedule).
        warmup_steps: linear warm-up length.
        weight_decay: decay applied to the non-excluded parameters.
        learning_rate: peak learning rate.
        is_constant_lr: if exactly ``True``, use a constant schedule; otherwise
            linear warm-up then linear decay.

    Returns:
        ``(optimizer, lr_scheduler)``.
    """
    no_decay = ("bias", "LayerNorm.weight")
    decay_params = []
    no_decay_params = []
    for name, param in total_params:
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)

    # Strict comparison kept on purpose: a truthy string such as "False" must not enable it.
    if is_constant_lr == True:  # noqa: E712
        # Previously referenced without being imported (NameError).
        from transformers import get_constant_schedule
        lr_scheduler = get_constant_schedule(optimizer)
    else:
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps
        )
    return optimizer, lr_scheduler

In [9]:
def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels`` as a Python float.

    Both arguments must support element-wise ``==`` producing an array with
    ``.mean()`` (e.g. NumPy arrays).
    """
    hits = preds == labels
    fraction = hits.mean()
    return fraction.item()

def accuracy(predictions, references):
    """Return ``{"accuracy": accuracy_score(references, predictions)}``.

    Args:
        predictions: predicted labels.
        references: ground-truth labels.
    """
    score = accuracy_score(references, predictions)
    return dict(accuracy=score)

### Finetuning Function

In [12]:
def fine_tuning(config, seed):
    """Fine-tune a ``Bert_CLS`` model with TACT and report dev metrics per epoch.

    Reads data paths, hyper-parameters and contrastive-learning settings from
    ``config``. It then builds train/dev DataLoaders, the model, the optimizer
    and the scheduler, and runs ``config['epochs']`` rounds of ``train``
    followed by ``evaluate`` on the dev set.

    Args:
        config: dict of settings. Optional keys: text_column_2 (None),
            is_constant_lr (False), early_stop (5), save_model (False),
            num_workers (1), warmup_proportion (0.06), and hidden_size
            (defaults to the pretrained model's config).
        seed: random seed passed to ``set_seed``.

    Side effects:
        Sets the module-level ``max_grad_norm`` that ``train`` reads.
    """
    from transformers import AutoConfig

    set_seed(seed)
    #---------------------------------------
    print("[INFO] step (1) load train_test config:")

    task_name = config["task_name"]  # NOTE: read (required) but not used below
    text_col_1 = config["text_column_1"]
    text_col_2 = config.get("text_column_2")
    isPair = text_col_2 is not None  # sentence-pair task when a second text column is configured
    label_col = config["label_column"]

    train_file = os.path.join(config["data_dir"], config["train_file"])

    is_constant_lr = config.get("is_constant_lr", False)

    print("[INFO] USE Supervised Contrastive Learing:")
    supervised_cl = config["supervised_cl"]  # NOTE: read (required) but not used below
    supervised_cl_weight = config["supervised_cl_weight"]
    supervised_cls_temp = config["supervised_cls_temp"]
    soft_supervised_cl = config["soft_supervised_cl"]  # True -> adversarial NT-Xent, False -> SupCon

    if supervised_cl_weight == 0.0:
        # No contrastive term: the adversarial branch would be wasted work.
        soft_supervised_cl = False

    dev_file = os.path.join(config["data_dir"], config["dev_file"])
    test_file = os.path.join(config["data_dir"], config["test_file"])

    max_seq_length = int(config["max_seq_length"])
    batch_size = int(config["batch_size"])

    early_stop = config.get("early_stop", 5)  # NOTE: not yet used
    save_model = config.get("save_model", False)  # NOTE: not yet used

    learning_rate = float(config["lr"])
    model_path = config['pretrained_model_path']
    num_epochs = config['epochs']

    #---------------------------------------------------------
    print("[INFO] step (2) check checkpoit directory and report file:")
    ckpt_dir = config["ckpt_dir"] + "/"  # NOTE: not yet used
    #-------------------------------------------------------
    print("[INFO] step (3) load label to number dictionary:")

    delimiter = '\t' if str(train_file).endswith('tsv') else None
    df = pd.read_csv(train_file, delimiter=delimiter)
    labels = df[label_col].tolist()
    # Float labels switch the criterion to MSELoss below. NOTE(review): lab2ind is
    # still built as a class-index map and evaluate() is not told about regression,
    # so regression is not supported end-to-end.
    is_regression = isinstance(labels[0], float)
    lab2ind = {l: ind for ind, l in enumerate(list(set(labels)))}

    num_workers = config.get('num_workers', 1)

    print("[INFO] train_file", train_file)
    print("[INFO] dev_file", dev_file)
    print("[INFO] test_file", test_file)
    print("[INFO] num_epochs", num_epochs)
    print("[INFO] model_path", model_path)
    print("[INFO] max_seq_length", max_seq_length)
    print("[INFO] batch_size", batch_size)
    print("[INFO] Number of Classes", len(lab2ind))
    print("[INFO] Number of Workers", num_workers)
    print("[INFO] step (4) Use defined funtion to extract tokanize data")

    criterion = nn.MSELoss() if is_regression else nn.CrossEntropyLoss()

    print("loading Model setting")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    print("[INFO] step (5) Create an iterator of data with torch DataLoader.")

    # num_workers was previously read from config but never forwarded.
    train_dataloader = regular_encode(train_file, tokenizer, lab2ind, shuffle=True, num_workers=num_workers,
                                      batch_size=batch_size, maxlen=max_seq_length, mode="train",
                                      text_col_1=text_col_1, text_col_2=text_col_2, label_col=label_col)
    # Evaluation order does not matter, so the dev set is not shuffled.
    validation_dataloader = regular_encode(dev_file, tokenizer, lab2ind, shuffle=False, num_workers=num_workers,
                                           batch_size=batch_size, maxlen=max_seq_length, mode="evaluate",
                                           text_col_1=text_col_1, text_col_2=text_col_2, label_col=label_col)

    # Take the width from the pretrained config instead of hard-coding 768.
    hidden_size = int(config.get("hidden_size") or AutoConfig.from_pretrained(model_path).hidden_size)
    model = Bert_CLS(lab2ind, model_path, hidden_size, criterion, 300)

    print("[INFO] step (6) run with parallel CPU/GPUs")
    if torch.cuda.is_available() and torch.cuda.device_count() == 1:
        print("Run", model_path, "with one GPU")
    # Always put the model where train()/evaluate() send the batches; before, it
    # stayed on the CPU whenever more than one GPU was visible.
    model = model.to(device)

    #---------------------------------------------------
    print("[INFO] step (7) set Parameters, schedules, and loss function:")
    global max_grad_norm
    max_grad_norm = 1.0
    warmup_proportion = config.get("warmup_proportion", 0.06)

    num_training_steps = len(train_dataloader) * num_epochs
    num_warmup_steps = num_training_steps * warmup_proportion
    # AdamW here is the Hugging Face optimizer with decoupled weight decay.
    weight_decay = 0.01

    total_params = list(model.named_parameters())

    optimizer, scheduler = create_optimizer_and_scheduler(total_params, num_training_steps, num_warmup_steps,
                                                          weight_decay, learning_rate, is_constant_lr)

    metric = accuracy

    print("[INFO] step (8) start fine_tuning")

    for epoch in trange(num_epochs, desc="Epoch"):
        print(f'Epoch: {epoch+1}')
        # isPair must be passed by keyword: positionally it landed in train's `weight_model` slot.
        train_loss = train(model, train_dataloader, optimizer, scheduler, len(lab2ind), supervised_cl_weight,
                           supervised_cls_temp, soft_supervised_cl, isPair=isPair)
        eval_loss, eval_result = evaluate(model, validation_dataloader, metric, isPair=isPair)
        print(f'Train Loss: {train_loss}')
        print(eval_result)


### Run Model

In [13]:
# Run configuration for a small SST-2 sanity check (keys are read by fine_tuning).
config = {
    'task_name': "Binary Sentiment Classification",
    'text_column_1': "sentence",
    'text_column_2': None,
    'label_column': "label",
    'data_dir': "./",
    'train_file': "sst2_tiny.csv",
    'dev_file': "sst2_tiny.csv",
    'test_file': "sst2_tiny.csv",
    'is_constant_lr': False,
    'supervised_cl': True,
    'supervised_cl_weight': 0.5,
    'supervised_cls_temp': 0.3,
    'soft_supervised_cl': True,
    'weight_model_flag': True,
    'max_seq_length': 64,
    'batch_size': 5,
    'early_stop': 5,
    'save_model': False,
    'lr': 0.00005,
    'pretrained_model_path': "bert-base-uncased",
    'epochs': 3,
    'ckpt_dir': "ckpt",
    'warmup_proportion': 0.05,
    'metric': "accuracy",
    'num_workers': 1,
}

seed = 42

fine_tuning(config, seed)

[INFO] step (1) load train_test config:
[INFO] USE Supervised Contrastive Learing:
[INFO] step (2) check checkpoit directory and report file:
[INFO] step (3) load label to number dictionary:
[INFO] train_file ./sst2_tiny.csv
[INFO] dev_file ./sst2_tiny.csv
[INFO] test_file ./sst2_tiny.csv
[INFO] num_epochs 3
[INFO] model_path bert-base-uncased
[INFO] max_seq_length 64
[INFO] batch_size 5
[INFO] Number of Classes 2
[INFO] Number of Workers 1
[INFO] step (4) Use defined funtion to extract tokanize data
loading Model setting
[INFO] step (5) Create an iterator of data with torch DataLoader.
./sst2_tiny.csv Dataset: (10, 3)
./sst2_tiny.csv Dataset: (10, 3)




[INFO] step (6) run with parallel CPU/GPUs
[INFO] step (7) set Parameters, schedules, and loss function:
[INFO] step (8) start fine_tuning


Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch: 1
ADV CL



Iteration:   0%|          | 0/2 [00:00<?, ?it/s][A

1.0909855365753174



Iteration:  50%|█████     | 1/2 [00:04<00:04,  4.62s/it][A

1.1126958131790161



Iteration: 100%|██████████| 2/2 [00:08<00:00,  4.08s/it][A
Epoch:  33%|███▎      | 1/3 [00:10<00:21, 10.96s/it]

Train Loss: 1.1018406748771667
{'accuracy': 0.9}
Epoch: 2
ADV CL



Iteration:   0%|          | 0/2 [00:00<?, ?it/s][A

0.9569948315620422



Iteration:  50%|█████     | 1/2 [00:05<00:05,  5.18s/it][A

0.9035152196884155



Iteration: 100%|██████████| 2/2 [00:08<00:00,  4.40s/it][A
Epoch:  67%|██████▋   | 2/3 [00:22<00:11, 11.25s/it]

Train Loss: 0.9302550256252289
{'accuracy': 0.8}
Epoch: 3
ADV CL



Iteration:   0%|          | 0/2 [00:00<?, ?it/s][A

0.8728577494621277



Iteration:  50%|█████     | 1/2 [00:05<00:05,  5.05s/it][A

0.8933364152908325



Iteration: 100%|██████████| 2/2 [00:08<00:00,  4.35s/it][A
Epoch: 100%|██████████| 3/3 [00:33<00:00, 11.29s/it]

Train Loss: 0.8830970823764801
{'accuracy': 0.9}



