In [1]:
import random, gc, os, pickle, csv, time

import datasets.utils
import models.utils
from models.cls_oml_ori_v2 import OML
from models.base_models_ori import LabelAwareReplayMemory

import numpy as np
import matplotlib.pyplot as plt

import higher
import torch
import torch.nn.functional as F
from torch.utils import data

# Constants

In [2]:
# Maps an order id to the sequence of dataset ids; the loading cell indexes it
# with args["order"] to decide which test sets to load and in which order.
dataset_order_mapping = {
    1: [2, 0, 3, 1, 4],
    2: [3, 4, 0, 1, 2],
    3: [2, 4, 1, 3, 0],
    4: [0, 2, 1, 4, 3]
}
# Number of output classes handed to the learner (the task_dict label ranges set later span 0..32).
n_classes = 33
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

# Checkpoint to score. Exactly one model_path should be active at a time.
# NOTE(review): presumably model_path and memory_path must come from the same run directory — confirm when switching.
# model_path = "/data/model_runs/original_oml/aOML-order1-2022-07-18/OML-order1-id4-2022-07-18_17-53-13.518612.pt"
# model_path = "/data/model_runs/original_oml/aOML-order1-inlr002-2022-07-31/OML-order1-id4-2022-07-31_14-53-46.456804.pt"
# model_path = "/data/model_runs/original_oml/aOML-order1-inlr005-2022-07-31/OML-order1-id4-2022-07-31_18-47-41.477968.pt"
# model_path = "/data/model_runs/original_oml/aOML-order1-inlr005-up20-2022-08-01/OML-order1-id4-2022-08-01_14-45-55.869765.pt"
# model_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-07-31/OML-order1-id4-2022-07-31_21-18-36.241546.pt"
# model_path = "/data/model_runs/original_oml/aOML-order1-inlr020-2022-08-16/OML-order1-id4-2022-08-16_11-37-19.424113.pt"
# model_path = "/data/model_runs/original_oml/aOML-order1-inlr050-2022-08-16/OML-order1-id4-2022-08-16_14-16-12.167637.pt"

# v. SR 
# model_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-08-29-sr/OML-order1-id4-2022-08-29_18-10-31.695669.pt"
# v. SR Query
model_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-08-30-sr-query/OML-order1-id4-2022-08-30_05-21-18.854228.pt"

# Pickled replay memory that is loaded into `memory_buffer` together with the checkpoint above.
# memory_path = "/data/model_runs/original_oml/aOML-order1-2022-07-18/OML-order1-id4-2022-07-18_17-53-13.518639_memory.pickle"
# memory_path = "/data/model_runs/original_oml/aOML-order1-inlr002-2022-07-31/OML-order1-id4-2022-07-31_14-53-46.456828_memory.pickle"
# memory_path = "/data/model_runs/original_oml/aOML-order1-inlr005-2022-07-31/OML-order1-id4-2022-07-31_18-47-41.477992_memory.pickle"
# memory_path = "/data/model_runs/original_oml/aOML-order1-inlr005-up20-2022-08-01/OML-order1-id4-2022-08-01_14-45-55.869797_memory.pickle"
# memory_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-07-31/OML-order1-id4-2022-07-31_21-18-36.241572_memory.pickle"
# memory_path = "/data/model_runs/original_oml/aOML-order1-inlr020-2022-08-16/OML-order1-id4-2022-08-16_11-37-19.424139_memory.pickle"
# memory_path = "/data/model_runs/original_oml/aOML-order1-inlr050-2022-08-16/OML-order1-id4-2022-08-16_14-16-12.167666_memory.pickle"
# v. SR 
# memory_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-08-29-sr/OML-order1-id4-2022-08-29_18-10-31.695692_memory.pickle"
# v. SR Query
memory_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-08-30-sr-query/OML-order1-id4-2022-08-30_05-21-18.854254_memory.pickle"

# For Sample Score
# TRIM_ER is the limit passed to memory_buffer.trim_buffer_dict(limit_n=...);
# TOTAL_EPOCH is the number of passes the scoring loop makes over all tasks.
TRIM_ER = 500
TOTAL_EPOCH = 600
# new_memory_path, ext = os.path.splitext(memory_path)
# new_memory_path = new_memory_path + "_label" + ext

# NOTE(review): use_db_cache is not read by any cell visible in this notebook — confirm it is still needed.
use_db_cache = True
# Directory holding the pickled test datasets ("<dataset_id>.cache").
cache_dir = 'tmp'

In [3]:
# Hyper-parameters forwarded to OML(**args) when the learner is built; the notebook
# itself also reads "order", "updates", "mini_batch_size" and "seed" from this dict.
# NOTE(review): presumably these should match the settings the loaded checkpoint was trained with — confirm.
args = {
    "order": 1,
    "n_epochs": 1,
    "lr": 3e-5,
    "inner_lr": 0.001*10,
    "meta_lr": 3e-5,
    "model": "bert",
    "learner": "oml",
    "mini_batch_size": 16,
    "updates": 5*1,
    "write_prob": 1.0,
    "max_length": 448,
    "seed": 42,
    "replay_rate": 0.01,
    "replay_every": 9600
}
# Convenience aliases used by later cells. `updates` is only referenced in a
# commented-out line of the scoring loop, which draws a single mini-batch per task instead.
updates = args["updates"]
mini_batch_size = args["mini_batch_size"]
order = args["order"]

In [4]:
# Seed the torch, Python and NumPy random number generators with the same configured value.
torch.manual_seed(args["seed"])
random.seed(args["seed"])
np.random.seed(args["seed"])

# Load Dataset

In [5]:
# Load the test split of every dataset in the configured order, using a pickle
# cache under `cache_dir` so the (slow) dataset construction only happens once.
print('Loading the datasets')
# Make sure the cache directory exists; otherwise the first, uncached run
# would fail when trying to write its pickle.
if cache_dir:
    os.makedirs(cache_dir, exist_ok=True)
test_datasets = []
for dataset_id in dataset_order_mapping[order]:
    test_dataset_file = os.path.join(cache_dir, f"{dataset_id}.cache")
    if os.path.exists(test_dataset_file):
        # NOTE: pickle.load can execute arbitrary code; only use cache files
        # that this notebook wrote itself.
        with open(test_dataset_file, 'rb') as f:
            test_dataset = pickle.load(f)
    else:
        test_dataset = datasets.utils.get_dataset_test("", dataset_id)
        print('Loaded {}'.format(test_dataset.__class__.__name__))
        test_dataset = datasets.utils.offset_labels(test_dataset)
        # Use a context manager so the file handle is flushed and closed
        # deterministically instead of being left to the garbage collector.
        with open(test_dataset_file, "wb") as f:
            pickle.dump(test_dataset, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"Pickle saved at {test_dataset_file}")
    test_datasets.append(test_dataset)
print('Finished loading all the datasets')

Loading the datasets
Finished loading all the datasets


# Load Model

In [6]:
# Build the learner from the `args` hyper-parameters and load the checkpoint at `model_path`.
learner = OML(device=device, n_classes=n_classes, **args)
print('Using {} as learner'.format(learner.__class__.__name__))
learner.load_model(model_path)
# The pickled replay memory is kept in a standalone `memory_buffer` variable
# rather than being attached to the learner (see the commented-out line).
# NOTE(review): pickle.load can execute arbitrary code — only load memory files from trusted runs.
with open(memory_path, 'rb') as f:
#     learner.memory = pickle.load(f)
    memory_buffer = pickle.load(f)


2022-11-28 20:20:55,826 - transformers.tokenization_utils_base - INFO - loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
2022-11-28 20:20:57,074 - transformers.configuration_utils - INFO - loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
2022-11-28 20:20:57,077 - transformers.configuration_utils - INFO - Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer

Using OML as learner


In [7]:
# Setting up task dict for task-aware
# Maps task index -> list of label ids that belong to that task.
# Tasks 1 (Amazon) and 2 (Yelp) share labels 0-4, so a label alone cannot tell them apart.
memory_buffer.task_dict = {
    0: list(range(5, 9)), # AG
    1: list(range(0, 5)), # Amazon
    2: list(range(0, 5)), # Yelp
    3: list(range(9, 23)), # DBPedia
    4: list(range(23, 33)), # Yahoo
}

In [8]:
# label to task converter - hack, since normally we would just use the task token/identifier
def get_task_from_label(label_idx, task_dict):
    """Return the index of the first task whose class list contains ``label_idx``.

    Tasks are checked in ``task_dict`` iteration order, so a label shared by
    several tasks resolves to the earliest one. Returns -1 when no task
    contains the label.
    """
    matching_tasks = (
        task_idx
        for task_idx, class_list in task_dict.items()
        if label_idx in class_list
    )
    return next(matching_tasks, -1)
# Smoke test: label 8 lies in task 0's range(5, 9), so this prints 0.
print(get_task_from_label(8, memory_buffer.task_dict))
# Find mode: https://stackoverflow.com/questions/10797819/finding-the-mode-of-a-list
def mode(array):
    """Return the most frequent element of ``array``.

    When several elements tie for the highest count, the one that appears
    first in ``array`` wins. An empty ``array`` raises ``ValueError``.
    """
    occurrences = [array.count(item) for item in array]
    # max() raises ValueError on an empty list; index() then finds the first
    # position holding the highest count.
    highest = max(occurrences)
    return array[occurrences.index(highest)]
# label list to task id
def get_task_from_label_list(label_list, task_dict):
    """Return the task index that most labels in ``label_list`` map to.

    Every label is converted with ``get_task_from_label`` (unknown labels
    count as -1) and the most common resulting task index is returned.
    """
    task_votes = []
    for label in label_list:
        task_votes.append(get_task_from_label(label, task_dict))
    return mode(task_votes)
# Smoke test: six of the seven labels are in 0-4, which resolve to task 1, so this prints 1.
print(get_task_from_label_list([1,2,32,1,4,2,0], memory_buffer.task_dict))

0
1


In [9]:
# Maps a dataset class name to a task index; the indices line up with the
# task comments on memory_buffer.task_dict (0 AG, 1 Amazon, 2 Yelp, 3 DBPedia, 4 Yahoo).
dataclass_mapper = {
    "AGNewsDataset": 0,
    "AmazonDataset": 1,
    "YelpDataset": 2,
    "DBPediaDataset": 3,
    "YahooAnswersDataset": 4
}
# Quick lookup check; displays 0.
dataclass_mapper["AGNewsDataset"]

0

In [11]:
# Trim buffer dict to TRIM_ER (currently 500) entries via limit_n
# NOTE(review): meta_length presumably sets how many meta values are kept per sample — confirm in LabelAwareReplayMemory.
memory_buffer.meta_length = 4
# Reset the per-sample meta storage before the scoring loop starts writing to it.
memory_buffer.reset_meta()
memory_buffer.trim_buffer_dict(limit_n=TRIM_ER)

In [12]:
# Sanity check after trimming: size of the buffer stored under key 0.
len(memory_buffer.buffer_dict[0])

500

# Get Scores for each sample in ER

In [13]:
# Returns preds, labels, labels_conf (the loss is not computed or returned)
def validate(fpln, validation_set):
    """Score every batch of ``validation_set`` with the prediction network ``fpln``.

    Text is encoded and embedded with the notebook-level ``learner.rln`` and the
    resulting representation is passed to ``fpln``.

    Args:
        fpln: callable mapping the ``learner.rln`` representation to class
            logits; in this notebook it is the functional network yielded by
            ``higher.innerloop_ctx``.
        validation_set: iterable of ``(text, labels, indexes)`` batches. The
            third element is ignored here.

    Returns:
        A tuple ``(preds, labels, label_conf)`` of flat Python lists that
        concatenate all batches in order: the predictions from
        ``models.utils.make_prediction``, the true labels, and the softmax
        probability the network assigned to each sample's true label.
    """
    all_valid_preds, all_valid_labels, all_valid_label_conf = [], [], []

    for valid_text, valid_labels, _ in validation_set:
        valid_labels = torch.tensor(valid_labels).to(device)
        valid_input_dict = learner.rln.encode_text(valid_text)
        valid_repr = learner.rln(valid_input_dict)
        valid_output = fpln(valid_repr) # Output has size of torch.Size([16, 33]) [BATCH, CLASSES]

        # The loss used to be computed here and then thrown away; it was never
        # returned, so that per-batch computation has been dropped.

        valid_output_softmax = F.softmax(valid_output, -1)
        # For every row, pick the probability at that row's true label.
        valid_label_conf = valid_output_softmax[np.arange(len(valid_output_softmax)), valid_labels]

        valid_pred = models.utils.make_prediction(valid_output.detach())

        # Accumulate the per-batch results into the flat output lists.
        all_valid_preds.extend(valid_pred.tolist())
        all_valid_labels.extend(valid_labels.tolist())
        all_valid_label_conf.extend(valid_label_conf.tolist())
    return all_valid_preds, all_valid_labels, all_valid_label_conf

# Compare diff results between the unadapted vs adapted
def calculate_diff_class(validate_labels, validate_label_conf_0, validate_label_conf_n, initial_dict=None, return_dict=True):
    """Compute the change in true-label confidence caused by adaptation.

    The difference is ``validate_label_conf_n - validate_label_conf_0``, i.e.
    adapted minus unadapted confidence, element by element.

    Args:
        validate_labels: class index of every sample (shared by both runs).
        validate_label_conf_0: true-label confidences before adaptation.
        validate_label_conf_n: true-label confidences after adaptation.
        initial_dict: optional ``class_idx -> list of diffs`` mapping to extend.
            It is never modified; ``None`` (the default) means "start empty".
        return_dict: when true, group the differences by class; when false,
            return the raw difference array.

    Returns:
        If ``return_dict`` is true, a new dict ``class_idx -> [diff, ...]``
        containing the entries of ``initial_dict`` followed by the new
        differences (aggregate later with ``np.sum``/``np.mean``). Otherwise
        the ``np.ndarray`` of differences in input order.
    """
    # Adapted confs - NonAdapted Confs (a-n)
    conf_diff = np.array(validate_label_conf_n) - np.array(validate_label_conf_0)

    if not return_dict:
        return conf_diff

    # Copy each list so that appending below cannot touch the caller's data.
    diffs_by_class = {
        class_idx: list(diffs) for class_idx, diffs in (initial_dict or {}).items()
    }
    for i, class_idx in enumerate(validate_labels):
        # Append in place instead of rebuilding the list for every sample.
        diffs_by_class.setdefault(class_idx, []).append(conf_diff[i])
    return diffs_by_class

For every task we draw one mini-batch of `16*1 = 16` samples; there are 5 tasks, so that is `16*5 = 80` samples per epoch (around 0.2 s per task, ~1 s per epoch). We need to repeat this 300 times, which takes ~5 minutes (300 s) and scores `80*300 = 24,000` samples in total; hopefully the support (number of scores) will then be around `16*300/300 ~ 16` per sample.

```
Adapt Time: 0.165510892868042 s
Support set metrics: Loss = 0.3509, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
```

```
For 600 Epoch
Total Scoring Time: 29.0 m
```

In [None]:
# Scoring loop: for TOTAL_EPOCH passes and for every task, draw one mini-batch from the
# memory buffer, record the true-label confidence before and after one inner-loop
# update on that same batch, and store the results through memory_buffer.update_meta.
TIC_score = time.time()
for epoch in range(TOTAL_EPOCH):
    print(f"This is epoch {epoch}/{TOTAL_EPOCH}")
    # Representation network in eval mode, prediction network in train mode.
    learner.rln.eval()
    learner.pln.train()
    
    for task_idx in memory_buffer.task_dict.keys():
        # Collect the support batches for this task together with their buffer indexes.
        support_set = []
        #for _ in range(updates):
        for _ in range(1): # CHANGE THIS to 1 minibatch.
            text, labels, indexes = memory_buffer.read_batch_task(batch_size=mini_batch_size, task_idx=task_idx, \
                                                                  with_index=True)
            support_set.append((text, labels, indexes))
        
        # fpln is the functional version of learner.pln and diffopt the matching
        # differentiable optimizer; higher-order gradient tracking is switched off.
        # NOTE(review): presumably the updates only affect fpln, so learner.pln starts
        # from the same weights for every task — confirm against the `higher` docs.
        with higher.innerloop_ctx(learner.pln, learner.inner_optimizer,
                                  copy_initial_weights=False, track_higher_grads=False) as (fpln, diffopt):
            
            # Test validation_set BEFORE the update (update=0)
            with torch.no_grad():
                all_valid_preds_0, all_valid_labels_0, all_valid_label_conf_0  = validate(fpln, support_set)
            
            INNER_tic = time.time()
            # Inner loop
            task_predictions, task_labels, task_indexes = [], [], []
            support_loss = []
            for text, labels, indexes in support_set:
                labels = torch.tensor(labels).to(device)
                input_dict = learner.rln.encode_text(text)
                _repr = learner.rln(input_dict)
                output = fpln(_repr)
                loss = learner.loss_fn(output, labels)
                diffopt.step(loss)
                # `output` was computed before the step, so the predictions (and the
                # support metrics printed below) reflect the pre-update network.
                pred = models.utils.make_prediction(output.detach())
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(labels.tolist())
                task_indexes.extend(indexes)
            INNER_toc = time.time() - INNER_tic
            print("Adapt Time: "+ str(INNER_toc) +" s" )
            
            # Test validation_set AFTER ALL the update
            with torch.no_grad():
                all_valid_preds_n, all_valid_labels_n, all_valid_label_conf_n = validate(fpln, support_set)
                # Per-sample confidence change (after minus before) as a flat array.
                diff_list = calculate_diff_class(all_valid_labels_0, all_valid_label_conf_0, all_valid_label_conf_n, return_dict=False)
                # Hand labels, buffer indexes, before/after confidences and their
                # difference to the memory buffer; all follow the support_set order.
                memory_buffer.update_meta(all_valid_labels_0, task_indexes, \
                                          np.array(all_valid_label_conf_0), np.array(all_valid_label_conf_n), diff_list)

            acc, prec, rec, f1 = models.utils.calculate_metrics(task_predictions, task_labels)

            print('Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                        'recall = {:.4f}, F1 score = {:.4f}'.format(np.mean(support_loss), acc, prec, rec, f1))

# Total wall-clock time, reported in whole minutes (floor division).
TOC_score = time.time() - TIC_score 
print("Total Scoring Time: "+ str(TOC_score//60) +" m" )

This is epoch 0/600
Adapt Time: 0.1683790683746338 s
Support set metrics: Loss = 0.2438, accuracy = 0.8125, precision = 0.8375, recall = 0.8125, F1 score = 0.8185
Adapt Time: 0.20820999145507812 s
Support set metrics: Loss = 1.0968, accuracy = 0.5625, precision = 0.6333, recall = 0.5667, F1 score = 0.5571
Adapt Time: 0.19322800636291504 s
Support set metrics: Loss = 1.0571, accuracy = 0.6250, precision = 0.6700, recall = 0.6167, F1 score = 0.6048
Adapt Time: 0.17128252983093262 s
Support set metrics: Loss = 0.0049, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.2116405963897705 s
Support set metrics: Loss = 0.7671, accuracy = 0.7500, precision = 0.6167, recall = 0.7000, F1 score = 0.6300
This is epoch 1/600
Adapt Time: 0.16647052764892578 s
Support set metrics: Loss = 0.2433, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.19341707229614258 s
Support set metrics: Loss = 1.2023, accuracy = 0.4375, precision

Adapt Time: 0.21056628227233887 s
Support set metrics: Loss = 0.8190, accuracy = 0.7500, precision = 0.7833, recall = 0.7333, F1 score = 0.7190
Adapt Time: 0.19951701164245605 s
Support set metrics: Loss = 1.2495, accuracy = 0.5625, precision = 0.6400, recall = 0.5333, F1 score = 0.5378
Adapt Time: 0.16719889640808105 s
Support set metrics: Loss = 0.0955, accuracy = 0.9375, precision = 0.8929, recall = 0.9286, F1 score = 0.9048
Adapt Time: 0.185899019241333 s
Support set metrics: Loss = 0.7900, accuracy = 0.8125, precision = 0.8167, recall = 0.8500, F1 score = 0.8300
This is epoch 12/600
Adapt Time: 0.17038464546203613 s
Support set metrics: Loss = 0.1295, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.19846367835998535 s
Support set metrics: Loss = 0.6144, accuracy = 0.7500, precision = 0.7867, recall = 0.7333, F1 score = 0.7167
Adapt Time: 0.2400650978088379 s
Support set metrics: Loss = 1.1507, accuracy = 0.5000, precision = 0.5200, recall = 

Adapt Time: 0.20836448669433594 s
Support set metrics: Loss = 0.9955, accuracy = 0.5000, precision = 0.4800, recall = 0.4833, F1 score = 0.4671
Adapt Time: 0.17590618133544922 s
Support set metrics: Loss = 0.0051, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.22094941139221191 s
Support set metrics: Loss = 0.9305, accuracy = 0.6250, precision = 0.5917, recall = 0.6000, F1 score = 0.5800
This is epoch 23/600
Adapt Time: 0.1657707691192627 s
Support set metrics: Loss = 0.3286, accuracy = 0.8125, precision = 0.8375, recall = 0.8125, F1 score = 0.7986
Adapt Time: 0.19753646850585938 s
Support set metrics: Loss = 1.0517, accuracy = 0.4375, precision = 0.5633, recall = 0.4333, F1 score = 0.4514
Adapt Time: 0.18878674507141113 s
Support set metrics: Loss = 1.4980, accuracy = 0.4375, precision = 0.5167, recall = 0.4667, F1 score = 0.4492
Adapt Time: 0.1702558994293213 s
Support set metrics: Loss = 0.0047, accuracy = 1.0000, precision = 1.0000, recall =

Adapt Time: 0.1761157512664795 s
Support set metrics: Loss = 0.0063, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.20282530784606934 s
Support set metrics: Loss = 1.0842, accuracy = 0.6250, precision = 0.5067, recall = 0.6000, F1 score = 0.5371
This is epoch 34/600
Adapt Time: 0.16529273986816406 s
Support set metrics: Loss = 0.4552, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.20038914680480957 s
Support set metrics: Loss = 0.6693, accuracy = 0.6875, precision = 0.5900, recall = 0.6667, F1 score = 0.6092
Adapt Time: 0.18573665618896484 s
Support set metrics: Loss = 0.8065, accuracy = 0.5625, precision = 0.6233, recall = 0.5833, F1 score = 0.5814
Adapt Time: 0.1697525978088379 s
Support set metrics: Loss = 0.0058, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.21049976348876953 s
Support set metrics: Loss = 0.3556, accuracy = 0.8750, precision = 0.9000, recall =

Adapt Time: 0.19167184829711914 s
Support set metrics: Loss = 0.7050, accuracy = 0.8750, precision = 0.9167, recall = 0.9000, F1 score = 0.8800
This is epoch 45/600
Adapt Time: 0.16561245918273926 s
Support set metrics: Loss = 0.3662, accuracy = 0.8125, precision = 0.8667, recall = 0.8125, F1 score = 0.8294
Adapt Time: 0.19085216522216797 s
Support set metrics: Loss = 0.7569, accuracy = 0.7500, precision = 0.7533, recall = 0.7500, F1 score = 0.7348
Adapt Time: 0.18537259101867676 s
Support set metrics: Loss = 0.7617, accuracy = 0.6875, precision = 0.7133, recall = 0.6833, F1 score = 0.6848
Adapt Time: 0.17517566680908203 s
Support set metrics: Loss = 0.0141, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.23413825035095215 s
Support set metrics: Loss = 0.8318, accuracy = 0.8125, precision = 0.6833, recall = 0.7500, F1 score = 0.7000
This is epoch 46/600
Adapt Time: 0.16699814796447754 s
Support set metrics: Loss = 0.3086, accuracy = 0.8750, preci

Adapt Time: 0.1645963191986084 s
Support set metrics: Loss = 0.1960, accuracy = 0.8750, precision = 0.9000, recall = 0.8750, F1 score = 0.8730
Adapt Time: 0.1981372833251953 s
Support set metrics: Loss = 1.2868, accuracy = 0.4375, precision = 0.4667, recall = 0.4333, F1 score = 0.4222
Adapt Time: 0.20116662979125977 s
Support set metrics: Loss = 1.0744, accuracy = 0.5000, precision = 0.5667, recall = 0.5000, F1 score = 0.5219
Adapt Time: 0.1747589111328125 s
Support set metrics: Loss = 0.0055, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.1942131519317627 s
Support set metrics: Loss = 0.6555, accuracy = 0.7500, precision = 0.7833, recall = 0.8000, F1 score = 0.7500
This is epoch 57/600
Adapt Time: 0.16799616813659668 s
Support set metrics: Loss = 0.1539, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.21863031387329102 s
Support set metrics: Loss = 0.6273, accuracy = 0.7500, precision = 0.7500, recall = 0

Adapt Time: 0.18785834312438965 s
Support set metrics: Loss = 1.0946, accuracy = 0.6250, precision = 0.6333, recall = 0.6167, F1 score = 0.5981
Adapt Time: 0.2176380157470703 s
Support set metrics: Loss = 0.8950, accuracy = 0.6875, precision = 0.7333, recall = 0.6833, F1 score = 0.6933
Adapt Time: 0.18323874473571777 s
Support set metrics: Loss = 0.0188, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.20559191703796387 s
Support set metrics: Loss = 0.9088, accuracy = 0.6875, precision = 0.6667, recall = 0.7000, F1 score = 0.6467
This is epoch 68/600
Adapt Time: 0.16702651977539062 s
Support set metrics: Loss = 0.2482, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.20554709434509277 s
Support set metrics: Loss = 0.7632, accuracy = 0.6875, precision = 0.7300, recall = 0.6833, F1 score = 0.6829
Adapt Time: 0.2503190040588379 s
Support set metrics: Loss = 1.3707, accuracy = 0.3750, precision = 0.3867, recall =

Adapt Time: 0.18242812156677246 s
Support set metrics: Loss = 0.8290, accuracy = 0.6250, precision = 0.7000, recall = 0.6167, F1 score = 0.6014
Adapt Time: 0.18062710762023926 s
Support set metrics: Loss = 0.0067, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.1844499111175537 s
Support set metrics: Loss = 1.0022, accuracy = 0.6250, precision = 0.4167, recall = 0.5500, F1 score = 0.4667
This is epoch 79/600
Adapt Time: 0.16719365119934082 s
Support set metrics: Loss = 0.1706, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.18859291076660156 s
Support set metrics: Loss = 0.8374, accuracy = 0.6875, precision = 0.6600, recall = 0.6667, F1 score = 0.6578
Adapt Time: 0.1992177963256836 s
Support set metrics: Loss = 0.8229, accuracy = 0.6250, precision = 0.6000, recall = 0.6000, F1 score = 0.6000
Adapt Time: 0.17033815383911133 s
Support set metrics: Loss = 0.0910, accuracy = 0.9375, precision = 0.8929, recall =

Adapt Time: 0.17834734916687012 s
Support set metrics: Loss = 0.0837, accuracy = 0.9375, precision = 0.9643, recall = 0.9643, F1 score = 0.9524
Adapt Time: 0.19498419761657715 s
Support set metrics: Loss = 0.8347, accuracy = 0.7500, precision = 0.7167, recall = 0.7500, F1 score = 0.6967
This is epoch 90/600
Adapt Time: 0.16582417488098145 s
Support set metrics: Loss = 0.3660, accuracy = 0.8750, precision = 0.8875, recall = 0.8750, F1 score = 0.8740
Adapt Time: 0.19243693351745605 s
Support set metrics: Loss = 0.5943, accuracy = 0.6875, precision = 0.7000, recall = 0.6833, F1 score = 0.6814
Adapt Time: 0.20187783241271973 s
Support set metrics: Loss = 0.9042, accuracy = 0.5625, precision = 0.5533, recall = 0.5667, F1 score = 0.5310
Adapt Time: 0.1769580841064453 s
Support set metrics: Loss = 0.0052, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.21320319175720215 s
Support set metrics: Loss = 0.9657, accuracy = 0.7500, precision = 0.5900, recall 

Adapt Time: 0.19288253784179688 s
Support set metrics: Loss = 1.4727, accuracy = 0.5000, precision = 0.4333, recall = 0.4000, F1 score = 0.3900
This is epoch 101/600
Adapt Time: 0.1664111614227295 s
Support set metrics: Loss = 0.1149, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.21629118919372559 s
Support set metrics: Loss = 0.9110, accuracy = 0.5625, precision = 0.6300, recall = 0.5500, F1 score = 0.5500
Adapt Time: 0.20493841171264648 s
Support set metrics: Loss = 0.9179, accuracy = 0.5000, precision = 0.4667, recall = 0.4833, F1 score = 0.4738
Adapt Time: 0.1764969825744629 s
Support set metrics: Loss = 0.1428, accuracy = 0.9375, precision = 0.8929, recall = 0.9286, F1 score = 0.9048
Adapt Time: 0.22547149658203125 s
Support set metrics: Loss = 1.2031, accuracy = 0.6250, precision = 0.5167, recall = 0.6000, F1 score = 0.5267
This is epoch 102/600
Adapt Time: 0.1726522445678711 s
Support set metrics: Loss = 0.3519, accuracy = 0.8750, precis

Adapt Time: 0.16817569732666016 s
Support set metrics: Loss = 0.3727, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.19678711891174316 s
Support set metrics: Loss = 0.8170, accuracy = 0.6250, precision = 0.7067, recall = 0.6333, F1 score = 0.6000
Adapt Time: 0.19364047050476074 s
Support set metrics: Loss = 0.8972, accuracy = 0.6250, precision = 0.6833, recall = 0.6333, F1 score = 0.6324
Adapt Time: 0.1683037281036377 s
Support set metrics: Loss = 0.0060, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.2252659797668457 s
Support set metrics: Loss = 1.0213, accuracy = 0.6875, precision = 0.5667, recall = 0.6500, F1 score = 0.5767
This is epoch 113/600
Adapt Time: 0.16701269149780273 s
Support set metrics: Loss = 0.3456, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.19420933723449707 s
Support set metrics: Loss = 1.1465, accuracy = 0.5625, precision = 0.6867, recall 

Adapt Time: 0.18324565887451172 s
Support set metrics: Loss = 0.9918, accuracy = 0.5625, precision = 0.4467, recall = 0.5333, F1 score = 0.4810
Adapt Time: 0.2028512954711914 s
Support set metrics: Loss = 0.7575, accuracy = 0.6875, precision = 0.7533, recall = 0.6833, F1 score = 0.6948
Adapt Time: 0.17281770706176758 s
Support set metrics: Loss = 0.0048, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.1951305866241455 s
Support set metrics: Loss = 1.1061, accuracy = 0.6875, precision = 0.8000, recall = 0.7500, F1 score = 0.7300
This is epoch 124/600
Adapt Time: 0.17316603660583496 s
Support set metrics: Loss = 0.3261, accuracy = 0.8750, precision = 0.8875, recall = 0.8750, F1 score = 0.8740
Adapt Time: 0.2149796485900879 s
Support set metrics: Loss = 0.9912, accuracy = 0.6875, precision = 0.7533, recall = 0.7000, F1 score = 0.6967
Adapt Time: 0.19191980361938477 s
Support set metrics: Loss = 0.7121, accuracy = 0.7500, precision = 0.6643, recall =

Adapt Time: 0.19580316543579102 s
Support set metrics: Loss = 1.4279, accuracy = 0.3125, precision = 0.3167, recall = 0.3167, F1 score = 0.3086
Adapt Time: 0.17070770263671875 s
Support set metrics: Loss = 0.0049, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18976759910583496 s
Support set metrics: Loss = 0.8308, accuracy = 0.7500, precision = 0.7667, recall = 0.8000, F1 score = 0.7467
This is epoch 135/600
Adapt Time: 0.16315650939941406 s
Support set metrics: Loss = 0.7384, accuracy = 0.8125, precision = 0.8667, recall = 0.8125, F1 score = 0.7722
Adapt Time: 0.20214366912841797 s
Support set metrics: Loss = 0.8019, accuracy = 0.6875, precision = 0.7167, recall = 0.6833, F1 score = 0.6910
Adapt Time: 0.206817626953125 s
Support set metrics: Loss = 0.7613, accuracy = 0.7500, precision = 0.8333, recall = 0.7500, F1 score = 0.7476
Adapt Time: 0.16741204261779785 s
Support set metrics: Loss = 0.3539, accuracy = 0.9375, precision = 0.9643, recall 

Adapt Time: 0.1746993064880371 s
Support set metrics: Loss = 0.0107, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.20491456985473633 s
Support set metrics: Loss = 1.0149, accuracy = 0.7500, precision = 0.6667, recall = 0.8000, F1 score = 0.7200
This is epoch 146/600
Adapt Time: 0.16571521759033203 s
Support set metrics: Loss = 0.1730, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.20123052597045898 s
Support set metrics: Loss = 0.8461, accuracy = 0.5625, precision = 0.6467, recall = 0.5833, F1 score = 0.5610
Adapt Time: 0.18778610229492188 s
Support set metrics: Loss = 1.0022, accuracy = 0.5625, precision = 0.6000, recall = 0.5667, F1 score = 0.5467
Adapt Time: 0.17483258247375488 s
Support set metrics: Loss = 0.0081, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18797945976257324 s
Support set metrics: Loss = 1.1268, accuracy = 0.6875, precision = 0.7417, recall

Adapt Time: 0.21052885055541992 s
Support set metrics: Loss = 1.4787, accuracy = 0.5625, precision = 0.5000, recall = 0.5000, F1 score = 0.4633
This is epoch 157/600
Adapt Time: 0.17017650604248047 s
Support set metrics: Loss = 0.1589, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.21952128410339355 s
Support set metrics: Loss = 0.8771, accuracy = 0.5000, precision = 0.5700, recall = 0.4833, F1 score = 0.4848
Adapt Time: 0.1905994415283203 s
Support set metrics: Loss = 0.9054, accuracy = 0.6250, precision = 0.7000, recall = 0.6167, F1 score = 0.5990
Adapt Time: 0.17270636558532715 s
Support set metrics: Loss = 0.0077, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.20772123336791992 s
Support set metrics: Loss = 0.6873, accuracy = 0.8750, precision = 0.8333, recall = 0.8500, F1 score = 0.8267
This is epoch 158/600
Adapt Time: 0.1663072109222412 s
Support set metrics: Loss = 0.2153, accuracy = 0.8750, preci

Adapt Time: 0.16913938522338867 s
Support set metrics: Loss = 0.1214, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18047237396240234 s
Support set metrics: Loss = 1.2006, accuracy = 0.5000, precision = 0.5967, recall = 0.5167, F1 score = 0.5086
Adapt Time: 0.18573951721191406 s
Support set metrics: Loss = 0.7809, accuracy = 0.6875, precision = 0.7367, recall = 0.6833, F1 score = 0.6714
Adapt Time: 0.1770024299621582 s
Support set metrics: Loss = 0.4695, accuracy = 0.9375, precision = 0.9643, recall = 0.9643, F1 score = 0.9524
Adapt Time: 0.19587016105651855 s
Support set metrics: Loss = 1.1938, accuracy = 0.5625, precision = 0.5583, recall = 0.5500, F1 score = 0.5167
This is epoch 169/600
Adapt Time: 0.17243647575378418 s
Support set metrics: Loss = 0.0606, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18101978302001953 s
Support set metrics: Loss = 1.0339, accuracy = 0.6875, precision = 0.6000, recall

Adapt Time: 0.19517970085144043 s
Support set metrics: Loss = 0.9503, accuracy = 0.6250, precision = 0.5367, recall = 0.6167, F1 score = 0.5595
Adapt Time: 0.1973738670349121 s
Support set metrics: Loss = 1.0593, accuracy = 0.6875, precision = 0.6433, recall = 0.6667, F1 score = 0.6425
Adapt Time: 0.16850566864013672 s
Support set metrics: Loss = 0.0048, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.21028614044189453 s
Support set metrics: Loss = 1.3751, accuracy = 0.5625, precision = 0.5500, recall = 0.6000, F1 score = 0.5500
This is epoch 180/600
Adapt Time: 0.16668057441711426 s
Support set metrics: Loss = 0.2715, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.18498516082763672 s
Support set metrics: Loss = 0.8268, accuracy = 0.6875, precision = 0.7500, recall = 0.6833, F1 score = 0.6714
Adapt Time: 0.2097301483154297 s
Support set metrics: Loss = 0.9346, accuracy = 0.5625, precision = 0.5467, recall 

Adapt Time: 0.18601512908935547 s
Support set metrics: Loss = 0.7859, accuracy = 0.6875, precision = 0.7033, recall = 0.6833, F1 score = 0.6733
Adapt Time: 0.16601943969726562 s
Support set metrics: Loss = 0.0051, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.1947786808013916 s
Support set metrics: Loss = 1.0799, accuracy = 0.7500, precision = 0.6500, recall = 0.7000, F1 score = 0.6567
This is epoch 191/600
Adapt Time: 0.16765689849853516 s
Support set metrics: Loss = 0.4578, accuracy = 0.8125, precision = 0.8929, recall = 0.8125, F1 score = 0.8128
Adapt Time: 0.20365285873413086 s
Support set metrics: Loss = 0.9218, accuracy = 0.6250, precision = 0.6333, recall = 0.6167, F1 score = 0.6133
Adapt Time: 0.18169426918029785 s
Support set metrics: Loss = 1.0271, accuracy = 0.5625, precision = 0.6833, recall = 0.5667, F1 score = 0.5905
Adapt Time: 0.16752076148986816 s
Support set metrics: Loss = 0.0051, accuracy = 1.0000, precision = 1.0000, recall

Adapt Time: 0.1715984344482422 s
Support set metrics: Loss = 0.0055, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19799304008483887 s
Support set metrics: Loss = 0.8670, accuracy = 0.7500, precision = 0.5500, recall = 0.6500, F1 score = 0.5900
This is epoch 202/600
Adapt Time: 0.16341066360473633 s
Support set metrics: Loss = 0.0660, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19615864753723145 s
Support set metrics: Loss = 1.0054, accuracy = 0.5000, precision = 0.4600, recall = 0.5333, F1 score = 0.4800
Adapt Time: 0.20015859603881836 s
Support set metrics: Loss = 0.8861, accuracy = 0.5000, precision = 0.4167, recall = 0.4833, F1 score = 0.4381
Adapt Time: 0.1771843433380127 s
Support set metrics: Loss = 0.0053, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.257366418838501 s
Support set metrics: Loss = 0.7509, accuracy = 0.8125, precision = 0.7167, recall = 

Adapt Time: 0.22620463371276855 s
Support set metrics: Loss = 0.7631, accuracy = 0.6875, precision = 0.7000, recall = 0.7000, F1 score = 0.6700
This is epoch 213/600
Adapt Time: 0.1670849323272705 s
Support set metrics: Loss = 0.2038, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.18323278427124023 s
Support set metrics: Loss = 0.7777, accuracy = 0.7500, precision = 0.8200, recall = 0.7500, F1 score = 0.7314
Adapt Time: 0.1925034523010254 s
Support set metrics: Loss = 0.6303, accuracy = 0.8125, precision = 0.8333, recall = 0.8167, F1 score = 0.8076
Adapt Time: 0.17110633850097656 s
Support set metrics: Loss = 0.0941, accuracy = 0.9375, precision = 0.9643, recall = 0.9643, F1 score = 0.9524
Adapt Time: 0.18986821174621582 s
Support set metrics: Loss = 0.9663, accuracy = 0.6875, precision = 0.7167, recall = 0.7000, F1 score = 0.6567
This is epoch 214/600
Adapt Time: 0.1661076545715332 s
Support set metrics: Loss = 0.1715, accuracy = 0.9375, precis

Adapt Time: 0.1635754108428955 s
Support set metrics: Loss = 0.3725, accuracy = 0.8750, precision = 0.8875, recall = 0.8750, F1 score = 0.8740
Adapt Time: 0.20159626007080078 s
Support set metrics: Loss = 1.0164, accuracy = 0.5625, precision = 0.6000, recall = 0.5667, F1 score = 0.5800
Adapt Time: 0.19233155250549316 s
Support set metrics: Loss = 0.9910, accuracy = 0.6875, precision = 0.7000, recall = 0.6833, F1 score = 0.6800
Adapt Time: 0.16874361038208008 s
Support set metrics: Loss = 0.0048, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.2004241943359375 s
Support set metrics: Loss = 0.5751, accuracy = 0.8125, precision = 0.7833, recall = 0.8000, F1 score = 0.7767
This is epoch 225/600
Adapt Time: 0.17081236839294434 s
Support set metrics: Loss = 0.3456, accuracy = 0.8750, precision = 0.9167, recall = 0.8750, F1 score = 0.8667
Adapt Time: 0.22786569595336914 s
Support set metrics: Loss = 0.9694, accuracy = 0.5625, precision = 0.6500, recall 

Adapt Time: 0.18546819686889648 s
Support set metrics: Loss = 0.7935, accuracy = 0.7500, precision = 0.8133, recall = 0.7333, F1 score = 0.7533
Adapt Time: 0.19890165328979492 s
Support set metrics: Loss = 0.8393, accuracy = 0.6250, precision = 0.5767, recall = 0.6000, F1 score = 0.5759
Adapt Time: 0.17829275131225586 s
Support set metrics: Loss = 0.0054, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18452811241149902 s
Support set metrics: Loss = 1.2510, accuracy = 0.6875, precision = 0.6500, recall = 0.7000, F1 score = 0.6300
This is epoch 236/600
Adapt Time: 0.16663312911987305 s
Support set metrics: Loss = 0.1968, accuracy = 0.8750, precision = 0.9167, recall = 0.8750, F1 score = 0.8786
Adapt Time: 0.20247292518615723 s
Support set metrics: Loss = 1.2420, accuracy = 0.5625, precision = 0.5667, recall = 0.5667, F1 score = 0.5619
Adapt Time: 0.21647214889526367 s
Support set metrics: Loss = 1.0458, accuracy = 0.6875, precision = 0.6833, recal

Adapt Time: 0.20318603515625 s
Support set metrics: Loss = 1.1523, accuracy = 0.6250, precision = 0.6833, recall = 0.6167, F1 score = 0.6376
Adapt Time: 0.17152667045593262 s
Support set metrics: Loss = 0.0058, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18566560745239258 s
Support set metrics: Loss = 0.4672, accuracy = 0.8125, precision = 0.7500, recall = 0.7500, F1 score = 0.7333
This is epoch 247/600
Adapt Time: 0.16517400741577148 s
Support set metrics: Loss = 0.1560, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.21874260902404785 s
Support set metrics: Loss = 1.1895, accuracy = 0.3750, precision = 0.3300, recall = 0.3667, F1 score = 0.3371
Adapt Time: 0.21083760261535645 s
Support set metrics: Loss = 1.0366, accuracy = 0.5625, precision = 0.5633, recall = 0.5667, F1 score = 0.5381
Adapt Time: 0.1691455841064453 s
Support set metrics: Loss = 0.0067, accuracy = 1.0000, precision = 1.0000, recall = 

Adapt Time: 0.1721940040588379 s
Support set metrics: Loss = 0.0048, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.21241044998168945 s
Support set metrics: Loss = 1.2598, accuracy = 0.5625, precision = 0.5000, recall = 0.5500, F1 score = 0.5033
This is epoch 258/600
Adapt Time: 0.1671741008758545 s
Support set metrics: Loss = 0.8019, accuracy = 0.8125, precision = 0.8500, recall = 0.8125, F1 score = 0.8056
Adapt Time: 0.1904916763305664 s
Support set metrics: Loss = 0.9931, accuracy = 0.7500, precision = 0.7433, recall = 0.7333, F1 score = 0.7225
Adapt Time: 0.21106457710266113 s
Support set metrics: Loss = 0.9143, accuracy = 0.6875, precision = 0.7500, recall = 0.6833, F1 score = 0.6833
Adapt Time: 0.17960524559020996 s
Support set metrics: Loss = 0.0143, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19687819480895996 s
Support set metrics: Loss = 0.7468, accuracy = 0.7500, precision = 0.7667, recall =

Adapt Time: 0.1876358985900879 s
Support set metrics: Loss = 0.5088, accuracy = 0.8750, precision = 0.8167, recall = 0.8500, F1 score = 0.8133
This is epoch 269/600
Adapt Time: 0.16602134704589844 s
Support set metrics: Loss = 0.6301, accuracy = 0.8125, precision = 0.8542, recall = 0.8125, F1 score = 0.8304
Adapt Time: 0.18856477737426758 s
Support set metrics: Loss = 1.1377, accuracy = 0.6250, precision = 0.5900, recall = 0.6167, F1 score = 0.5500
Adapt Time: 0.19510889053344727 s
Support set metrics: Loss = 0.8301, accuracy = 0.6250, precision = 0.6433, recall = 0.6000, F1 score = 0.6083
Adapt Time: 0.17245888710021973 s
Support set metrics: Loss = 0.0358, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19927120208740234 s
Support set metrics: Loss = 1.0898, accuracy = 0.5625, precision = 0.5167, recall = 0.5000, F1 score = 0.4800
This is epoch 270/600
Adapt Time: 0.16546368598937988 s
Support set metrics: Loss = 0.0655, accuracy = 1.0000, prec

Adapt Time: 0.16567659378051758 s
Support set metrics: Loss = 0.0921, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19524264335632324 s
Support set metrics: Loss = 0.6485, accuracy = 0.8125, precision = 0.8533, recall = 0.8000, F1 score = 0.8033
Adapt Time: 0.2182016372680664 s
Support set metrics: Loss = 0.8152, accuracy = 0.6875, precision = 0.8000, recall = 0.6833, F1 score = 0.7048
Adapt Time: 0.16894841194152832 s
Support set metrics: Loss = 0.3846, accuracy = 0.9375, precision = 1.0000, recall = 0.9643, F1 score = 0.9762
Adapt Time: 0.191619873046875 s
Support set metrics: Loss = 0.9338, accuracy = 0.6250, precision = 0.4833, recall = 0.6500, F1 score = 0.5433
This is epoch 281/600
Adapt Time: 0.16677260398864746 s
Support set metrics: Loss = 0.1119, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.21265482902526855 s
Support set metrics: Loss = 0.6959, accuracy = 0.6875, precision = 0.6500, recall =

Adapt Time: 0.19730901718139648 s
Support set metrics: Loss = 1.0447, accuracy = 0.5625, precision = 0.5133, recall = 0.5500, F1 score = 0.5190
Adapt Time: 0.19533777236938477 s
Support set metrics: Loss = 1.2523, accuracy = 0.4375, precision = 0.4800, recall = 0.4167, F1 score = 0.4314
Adapt Time: 0.17195701599121094 s
Support set metrics: Loss = 0.3568, accuracy = 0.9375, precision = 0.9048, recall = 0.9286, F1 score = 0.9143
Adapt Time: 0.2189805507659912 s
Support set metrics: Loss = 0.6205, accuracy = 0.7500, precision = 0.7833, recall = 0.8000, F1 score = 0.7667
This is epoch 292/600
Adapt Time: 0.16374993324279785 s
Support set metrics: Loss = 0.1611, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.19223332405090332 s
Support set metrics: Loss = 0.7106, accuracy = 0.8750, precision = 0.8833, recall = 0.8667, F1 score = 0.8648
Adapt Time: 0.2198631763458252 s
Support set metrics: Loss = 0.8550, accuracy = 0.5625, precision = 0.6633, recall 

Adapt Time: 0.18420886993408203 s
Support set metrics: Loss = 1.1442, accuracy = 0.4375, precision = 0.4500, recall = 0.4500, F1 score = 0.4381
Adapt Time: 0.16864514350891113 s
Support set metrics: Loss = 0.0049, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.20273327827453613 s
Support set metrics: Loss = 0.4601, accuracy = 0.8125, precision = 0.7667, recall = 0.8500, F1 score = 0.7967
This is epoch 303/600
Adapt Time: 0.16558456420898438 s
Support set metrics: Loss = 0.5333, accuracy = 0.8750, precision = 0.9000, recall = 0.8750, F1 score = 0.8611
Adapt Time: 0.2068476676940918 s
Support set metrics: Loss = 0.9274, accuracy = 0.7500, precision = 0.8033, recall = 0.7667, F1 score = 0.7481
Adapt Time: 0.19281911849975586 s
Support set metrics: Loss = 0.9544, accuracy = 0.5000, precision = 0.3800, recall = 0.4667, F1 score = 0.4067
Adapt Time: 0.17183756828308105 s
Support set metrics: Loss = 0.2772, accuracy = 0.9375, precision = 0.8929, recall

Adapt Time: 0.17074847221374512 s
Support set metrics: Loss = 0.0072, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19604802131652832 s
Support set metrics: Loss = 1.2334, accuracy = 0.6875, precision = 0.6167, recall = 0.7000, F1 score = 0.6133
This is epoch 314/600
Adapt Time: 0.16481804847717285 s
Support set metrics: Loss = 0.3635, accuracy = 0.8750, precision = 0.9167, recall = 0.8750, F1 score = 0.8786
Adapt Time: 0.18412041664123535 s
Support set metrics: Loss = 0.9376, accuracy = 0.6250, precision = 0.7000, recall = 0.6333, F1 score = 0.6552
Adapt Time: 0.1882154941558838 s
Support set metrics: Loss = 1.0778, accuracy = 0.6250, precision = 0.5033, recall = 0.6167, F1 score = 0.5524
Adapt Time: 0.17554664611816406 s
Support set metrics: Loss = 0.0066, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.1934819221496582 s
Support set metrics: Loss = 0.4853, accuracy = 0.8125, precision = 0.8167, recall 

Adapt Time: 0.1924140453338623 s
Support set metrics: Loss = 0.6664, accuracy = 0.8125, precision = 0.6667, recall = 0.8000, F1 score = 0.7133
This is epoch 325/600
Adapt Time: 0.16641902923583984 s
Support set metrics: Loss = 0.4106, accuracy = 0.8750, precision = 0.8750, recall = 0.8750, F1 score = 0.8750
Adapt Time: 0.18693280220031738 s
Support set metrics: Loss = 0.8479, accuracy = 0.6250, precision = 0.7500, recall = 0.6167, F1 score = 0.6500
Adapt Time: 0.18883752822875977 s
Support set metrics: Loss = 1.0704, accuracy = 0.5625, precision = 0.7000, recall = 0.5667, F1 score = 0.6019
Adapt Time: 0.1738119125366211 s
Support set metrics: Loss = 0.0071, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.2085714340209961 s
Support set metrics: Loss = 0.5931, accuracy = 0.8125, precision = 0.8167, recall = 0.8000, F1 score = 0.7800
This is epoch 326/600
Adapt Time: 0.16846466064453125 s
Support set metrics: Loss = 0.1143, accuracy = 0.9375, precis

Adapt Time: 0.163254976272583 s
Support set metrics: Loss = 0.1065, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.21427416801452637 s
Support set metrics: Loss = 1.0885, accuracy = 0.5625, precision = 0.5667, recall = 0.5500, F1 score = 0.5524
Adapt Time: 0.17825579643249512 s
Support set metrics: Loss = 0.5965, accuracy = 0.8750, precision = 0.9100, recall = 0.8667, F1 score = 0.8692
Adapt Time: 0.17694807052612305 s
Support set metrics: Loss = 0.0055, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19681644439697266 s
Support set metrics: Loss = 1.2272, accuracy = 0.5625, precision = 0.4750, recall = 0.5000, F1 score = 0.4567
This is epoch 337/600
Adapt Time: 0.1675703525543213 s
Support set metrics: Loss = 0.1332, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.2000141143798828 s
Support set metrics: Loss = 0.9531, accuracy = 0.5000, precision = 0.5533, recall = 

Adapt Time: 0.1856222152709961 s
Support set metrics: Loss = 0.7189, accuracy = 0.5000, precision = 0.5033, recall = 0.5000, F1 score = 0.4814
Adapt Time: 0.20359468460083008 s
Support set metrics: Loss = 1.0857, accuracy = 0.6250, precision = 0.5833, recall = 0.6167, F1 score = 0.5905
Adapt Time: 0.17803430557250977 s
Support set metrics: Loss = 0.4433, accuracy = 0.9375, precision = 0.9048, recall = 0.9286, F1 score = 0.9143
Adapt Time: 0.20475435256958008 s
Support set metrics: Loss = 0.7887, accuracy = 0.6875, precision = 0.5667, recall = 0.7000, F1 score = 0.5967
This is epoch 348/600
Adapt Time: 0.16727685928344727 s
Support set metrics: Loss = 0.1351, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.19313979148864746 s
Support set metrics: Loss = 1.0070, accuracy = 0.5625, precision = 0.5333, recall = 0.5333, F1 score = 0.4886
Adapt Time: 0.20200181007385254 s
Support set metrics: Loss = 0.9747, accuracy = 0.5000, precision = 0.6067, recall

Adapt Time: 0.19615793228149414 s
Support set metrics: Loss = 0.8464, accuracy = 0.5000, precision = 0.5200, recall = 0.5000, F1 score = 0.4900
Adapt Time: 0.1798570156097412 s
Support set metrics: Loss = 0.1795, accuracy = 0.9375, precision = 0.8929, recall = 0.9286, F1 score = 0.9048
Adapt Time: 0.1832103729248047 s
Support set metrics: Loss = 0.8710, accuracy = 0.6875, precision = 0.7500, recall = 0.7000, F1 score = 0.6833
This is epoch 359/600
Adapt Time: 0.16570234298706055 s
Support set metrics: Loss = 0.2164, accuracy = 0.8750, precision = 0.9167, recall = 0.8750, F1 score = 0.8667
Adapt Time: 0.18292474746704102 s
Support set metrics: Loss = 1.0206, accuracy = 0.5625, precision = 0.5200, recall = 0.5500, F1 score = 0.5157
Adapt Time: 0.20116758346557617 s
Support set metrics: Loss = 1.3829, accuracy = 0.4375, precision = 0.3833, recall = 0.4167, F1 score = 0.3976
Adapt Time: 0.17636513710021973 s
Support set metrics: Loss = 0.0166, accuracy = 1.0000, precision = 1.0000, recall 

Adapt Time: 0.17838478088378906 s
Support set metrics: Loss = 0.0414, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.2289869785308838 s
Support set metrics: Loss = 0.8172, accuracy = 0.8125, precision = 0.9167, recall = 0.8500, F1 score = 0.8467
This is epoch 370/600
Adapt Time: 0.16971516609191895 s
Support set metrics: Loss = 0.0580, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18059062957763672 s
Support set metrics: Loss = 1.1401, accuracy = 0.6250, precision = 0.6667, recall = 0.6167, F1 score = 0.6267
Adapt Time: 0.1798551082611084 s
Support set metrics: Loss = 1.0282, accuracy = 0.5000, precision = 0.6000, recall = 0.5167, F1 score = 0.5127
Adapt Time: 0.17212605476379395 s
Support set metrics: Loss = 0.0056, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.20731496810913086 s
Support set metrics: Loss = 1.0670, accuracy = 0.7500, precision = 0.7333, recall 

Adapt Time: 0.2089846134185791 s
Support set metrics: Loss = 1.0901, accuracy = 0.6250, precision = 0.5000, recall = 0.6000, F1 score = 0.5133
This is epoch 381/600
Adapt Time: 0.16750574111938477 s
Support set metrics: Loss = 0.3528, accuracy = 0.8125, precision = 0.8375, recall = 0.8125, F1 score = 0.8185
Adapt Time: 0.20549750328063965 s
Support set metrics: Loss = 0.7818, accuracy = 0.7500, precision = 0.7867, recall = 0.7500, F1 score = 0.7481
Adapt Time: 0.2228858470916748 s
Support set metrics: Loss = 0.7246, accuracy = 0.6875, precision = 0.7667, recall = 0.6833, F1 score = 0.7124
Adapt Time: 0.1767435073852539 s
Support set metrics: Loss = 0.0073, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18294668197631836 s
Support set metrics: Loss = 1.4547, accuracy = 0.6250, precision = 0.6500, recall = 0.6000, F1 score = 0.6000
This is epoch 382/600
Adapt Time: 0.16437268257141113 s
Support set metrics: Loss = 0.3401, accuracy = 0.9375, precis

Adapt Time: 0.16495251655578613 s
Support set metrics: Loss = 0.2481, accuracy = 0.8750, precision = 0.9000, recall = 0.8750, F1 score = 0.8611
Adapt Time: 0.1889503002166748 s
Support set metrics: Loss = 0.6611, accuracy = 0.8125, precision = 0.8600, recall = 0.8000, F1 score = 0.7806
Adapt Time: 0.19796109199523926 s
Support set metrics: Loss = 1.3474, accuracy = 0.5000, precision = 0.5667, recall = 0.5000, F1 score = 0.5219
Adapt Time: 0.1745593547821045 s
Support set metrics: Loss = 0.0058, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.2133314609527588 s
Support set metrics: Loss = 0.5676, accuracy = 0.8750, precision = 0.7167, recall = 0.8000, F1 score = 0.7467
This is epoch 393/600
Adapt Time: 0.16763639450073242 s
Support set metrics: Loss = 0.0605, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.1840832233428955 s
Support set metrics: Loss = 1.0244, accuracy = 0.7500, precision = 0.7667, recall = 

Adapt Time: 0.20681405067443848 s
Support set metrics: Loss = 0.9303, accuracy = 0.7500, precision = 0.9200, recall = 0.7667, F1 score = 0.7833
Adapt Time: 0.19379591941833496 s
Support set metrics: Loss = 1.0197, accuracy = 0.5625, precision = 0.6500, recall = 0.5500, F1 score = 0.5419
Adapt Time: 0.16972041130065918 s
Support set metrics: Loss = 0.0049, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.1935133934020996 s
Support set metrics: Loss = 0.5562, accuracy = 0.8750, precision = 0.9333, recall = 0.9000, F1 score = 0.8933
This is epoch 404/600
Adapt Time: 0.16552281379699707 s
Support set metrics: Loss = 0.1713, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.19651198387145996 s
Support set metrics: Loss = 0.7456, accuracy = 0.6250, precision = 0.6500, recall = 0.6000, F1 score = 0.5781
Adapt Time: 0.2013099193572998 s
Support set metrics: Loss = 0.6395, accuracy = 0.8125, precision = 0.8833, recall 

Adapt Time: 0.20928144454956055 s
Support set metrics: Loss = 0.9235, accuracy = 0.5000, precision = 0.4533, recall = 0.5000, F1 score = 0.4633
Adapt Time: 0.17627382278442383 s
Support set metrics: Loss = 0.0101, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18070626258850098 s
Support set metrics: Loss = 0.6258, accuracy = 0.8125, precision = 0.7833, recall = 0.8000, F1 score = 0.7600
This is epoch 415/600
Adapt Time: 0.16610121726989746 s
Support set metrics: Loss = 0.0837, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.2535674571990967 s
Support set metrics: Loss = 0.7894, accuracy = 0.7500, precision = 0.8100, recall = 0.7333, F1 score = 0.7235
Adapt Time: 0.21556997299194336 s
Support set metrics: Loss = 0.7141, accuracy = 0.7500, precision = 0.7500, recall = 0.7500, F1 score = 0.7329
Adapt Time: 0.1662764549255371 s
Support set metrics: Loss = 0.0086, accuracy = 1.0000, precision = 1.0000, recall 

Adapt Time: 0.17197608947753906 s
Support set metrics: Loss = 0.0052, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.19038176536560059 s
Support set metrics: Loss = 0.9899, accuracy = 0.6250, precision = 0.5000, recall = 0.5500, F1 score = 0.5167
This is epoch 426/600
Adapt Time: 0.16393065452575684 s
Support set metrics: Loss = 0.1399, accuracy = 0.9375, precision = 0.9500, recall = 0.9375, F1 score = 0.9365
Adapt Time: 0.1852118968963623 s
Support set metrics: Loss = 0.9581, accuracy = 0.6250, precision = 0.7633, recall = 0.6333, F1 score = 0.6381
Adapt Time: 0.1920464038848877 s
Support set metrics: Loss = 1.3501, accuracy = 0.5000, precision = 0.5000, recall = 0.5000, F1 score = 0.4895
Adapt Time: 0.16941070556640625 s
Support set metrics: Loss = 0.0458, accuracy = 1.0000, precision = 1.0000, recall = 1.0000, F1 score = 1.0000
Adapt Time: 0.18837547302246094 s
Support set metrics: Loss = 0.5302, accuracy = 0.8750, precision = 0.8167, recall 

In [None]:
# Report, for every class, the summed and mean support (column 1 of meta_score)
# over the first TRIM_ER entries of that class.
print("Support of each class: ")
for class_idx, class_meta in enumerate(memory_buffer.meta_score):
    support_column = class_meta[:TRIM_ER, 1]
    total_support = np.sum(support_column)
    mean_support = np.mean(support_column)
    print(f"\t#Class {class_idx}: {total_support} , Avg: {mean_support}")

In [None]:
# Mean score (column 0) per class, ignoring entries whose support (column 1) is zero.
# NOTE(review): the zero-support filter is applied BEFORE trimming to TRIM_ER, i.e. this
# averages the first TRIM_ER *supported* entries — confirm that this order is intended.
print("Average Score of each class: ")
for class_idx, class_meta in enumerate(memory_buffer.meta_score):
    has_support = class_meta[..., 1] != 0
    supported_entries = class_meta[has_support]
    mean_score = np.mean(supported_entries[:TRIM_ER, 0])
    print(f"\tScore Class {class_idx}: {mean_score}")

In [None]:
# Sanity check for class 0: number of trimmed entries vs. how many of them have non-zero support.
class0_scores = memory_buffer.meta_score[0, :TRIM_ER, 0]
class0_support = memory_buffer.meta_score[0, :TRIM_ER, 1]
print(class0_scores.shape)
print(class0_scores[class0_support != 0].shape)

In [None]:
# Persist `memory_buffer` next to the model checkpoint (same path stem, new suffix).
_model_path0 = os.path.splitext(model_path)[0]
MEMORY_SAVE_LOC = _model_path0 + "_memory_ss6.pickle"
# Use a context manager so the file handle is flushed and closed deterministically,
# even if pickling raises part-way through.
with open(MEMORY_SAVE_LOC, "wb") as memory_file:
    pickle.dump(memory_buffer, memory_file, protocol=pickle.HIGHEST_PROTOCOL)
print(f"Done writing Memory Pickle File at {MEMORY_SAVE_LOC}")

In [None]:
# new_memory_path = "/data/model_runs/original_oml/aOML-order1-inlr010-2022-08-30-sr-query/OML-order1-id4-2022-08-30_05-21-18.854228_memory_ss4.pickle"
# with open(new_memory_path, 'rb') as f:
#     memory_buffer = pickle.load(f)

### Plotting Distribution of Scores for each Class

In [None]:
class_idx = 1
# Scores (column 0) of this class, keeping only entries whose support (column 1) is non-zero.
score_list = memory_buffer.meta_score[class_idx,...,0][memory_buffer.meta_score[class_idx, ...,1] != 0]

fig, ax =  plt.subplots(figsize=(10,4))
# Label the figure so it can be told apart from the other per-class histograms.
ax.set(title=f"Score distribution for class {class_idx} (non-zero support)",
       xlabel="Score", ylabel="Number of entries")
ax.hist(score_list, bins=15)

In [None]:
class_idx = 8
# Scores (column 0) of this class, keeping only entries whose support (column 1) is non-zero.
score_list = memory_buffer.meta_score[class_idx,...,0][memory_buffer.meta_score[class_idx, ...,1] != 0]

fig, ax =  plt.subplots(figsize=(10,4))
# Label the figure so it can be told apart from the other per-class histograms.
ax.set(title=f"Score distribution for class {class_idx} (non-zero support)",
       xlabel="Score", ylabel="Number of entries")
ax.hist(score_list, bins=15)

In [None]:
class_idx = 30
# Scores (column 0) of this class, keeping only entries whose support (column 1) is non-zero.
score_list = memory_buffer.meta_score[class_idx,...,0][memory_buffer.meta_score[class_idx, ...,1] != 0]

fig, ax =  plt.subplots(figsize=(10,4))
# Label the figure so it can be told apart from the other per-class histograms.
ax.set(title=f"Score distribution for class {class_idx} (non-zero support)",
       xlabel="Score", ylabel="Number of entries")
ax.hist(score_list, bins=15)

### Plotting Adapted Conf

In [None]:
class_idx = 30
# Column 2 of meta_score for this class, keeping only entries whose support (column 1) is
# non-zero. NOTE(review): column 2 is taken to be the "adapted conf" named in the section
# heading — confirm against LabelAwareReplayMemory.
score_list = memory_buffer.meta_score[class_idx,...,2][memory_buffer.meta_score[class_idx, ...,1] != 0]

fig, ax =  plt.subplots(figsize=(10,4))
# Label the figure so it is not mistaken for the score histograms of the same class.
ax.set(title=f"Adapted conf distribution for class {class_idx} (non-zero support)",
       xlabel="Adapted conf", ylabel="Number of entries")
ax.hist(score_list, bins=15)

### Plotting Distribution of Scores for one sample

In [None]:
class_idx = 0
sample_idx = 0
# Individual scores recorded in meta_debug for one (class, sample) pair.
score_list = memory_buffer.meta_debug[class_idx][sample_idx]

# Compare the raw score list with the aggregated score/support stored in meta_score.
print(f"#Scores in ScoreList {len(score_list)}")
print(f"# Score in sample_idx: {memory_buffer.meta_score[class_idx, sample_idx, 0]}")
print(f"# Support in sample_idx: {memory_buffer.meta_score[class_idx, sample_idx, 1]}")

fig, ax =  plt.subplots(figsize=(10,4))
# Label the figure so the class/sample being inspected is visible on the plot itself.
ax.set(title=f"Score distribution for class {class_idx}, sample {sample_idx}",
       xlabel="Score", ylabel="Count")
ax.hist(score_list, bins=15)

In [None]:
# List the keys of the debug record for the current class (presumably sample indices, since
# meta_debug[class_idx] is indexed by sample_idx elsewhere in this notebook).
# Depends on `class_idx` already being defined by an earlier cell.
memory_buffer.meta_debug[class_idx].keys()

# Testing

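Reference for selecting a specific column index per row with NumPy-style indexing (the technique used in `evaluate` below to pick each query sample's softmax confidence for its true label):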
https://stackoverflow.com/questions/23435782/numpy-selecting-specific-column-index-per-row-by-using-a-list-of-indexes

In [None]:
def evaluate(dataloader, updates: int, mini_batch_size: int, dataname: str = ""):
    """Adapt on replayed support batches, then score every query batch of ``dataloader``.

    For each query batch, ``updates`` support batches are read from the global
    ``memory_buffer`` for the task mapped to ``dataname``. The patched model returned by
    ``higher.innerloop_ctx`` for ``learner.pln`` takes one optimizer step per support
    batch and is then used (under ``torch.no_grad``) to predict the query batch.

    Relies on the notebook globals ``learner``, ``memory_buffer``, ``dataclass_mapper``
    and ``device``. Prints support, query and overall test metrics as it goes.

    Args:
        dataloader: Iterable yielding ``(query_text, query_labels)`` batches; must
            support ``len()``.
        updates: Number of support batches (inner-loop steps) used per query batch.
        mini_batch_size: Batch size requested from ``memory_buffer.read_batch_task``.
        dataname: Key looked up in ``dataclass_mapper`` to obtain the task index.

    Returns:
        Tuple ``(acc, prec, rec, f1, all_predictions, all_labels, all_label_conf,
        all_adaptation_time)``. The four metrics come from
        ``models.utils.calculate_metrics`` over all query predictions;
        ``all_label_conf`` holds, per query sample, the softmax probability assigned to
        its true label; ``all_adaptation_time`` holds the inner-loop wall-clock time in
        seconds for each query batch.
    """
    # NOTE(review): pln is left in train mode during evaluation — presumably intentional
    # because it is the part being updated in the inner loop; confirm.
    learner.rln.eval()
    learner.pln.train()

    all_losses, all_predictions, all_labels, all_label_conf = [], [], [], []
    all_adaptation_time = []
    # Get Query set first. and then find supporting support set
    for query_idx, (query_text, query_labels) in enumerate(dataloader):
        print(f"Query ID {query_idx}/{len(dataloader)}")
        # The task id to optimize to for support set
        # task_idx = get_task_from_label_list(query_labels, memory_buffer.task_dict)
        task_idx = dataclass_mapper[dataname]


        # Collect `updates` support batches from the replay memory for this task.
        support_set = []
        for _ in range(updates):
            text, labels = memory_buffer.read_batch_task(batch_size=mini_batch_size, task_idx=task_idx, sort_score=True, \
                                                        sort_asc=True ) # this is changed!!!
            support_set.append((text, labels))

        # A fresh inner-loop context is opened for every query batch.
        # NOTE(review): track_higher_grads=False presumably because nothing is
        # differentiated through the inner loop at test time — confirm.
        with higher.innerloop_ctx(learner.pln, learner.inner_optimizer,
                                  copy_initial_weights=False, track_higher_grads=False) as (fpln, diffopt):

            # Time only the adaptation (support forward passes + optimizer steps).
            INNER_tic = time.time()
            # Inner loop
            task_predictions, task_labels = [], []
            support_loss = []
            for text, labels in support_set:
                labels = torch.tensor(labels).to(device)
                input_dict = learner.rln.encode_text(text)
                _repr = learner.rln(input_dict)
                output = fpln(_repr)
                loss = learner.loss_fn(output, labels)
                diffopt.step(loss)
                pred = models.utils.make_prediction(output.detach())
                support_loss.append(loss.item())
                task_predictions.extend(pred.tolist())
                task_labels.extend(labels.tolist())
            INNER_toc = time.time() - INNER_tic
            all_adaptation_time.append(INNER_toc)

            acc, prec, rec, f1 = models.utils.calculate_metrics(task_predictions, task_labels)

            print('Support set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                        'recall = {:.4f}, F1 score = {:.4f}'.format(np.mean(support_loss), acc, prec, rec, f1))

            # Query set is now here!
            query_labels = torch.tensor(query_labels).to(device)
            query_input_dict = learner.rln.encode_text(query_text)
            # No gradients are needed for the query forward pass.
            with torch.no_grad():
                query_repr = learner.rln(query_input_dict)
                query_output = fpln(query_repr) # Output has size of torch.Size([16, 33]) [BATCH, CLASSES]
                query_loss = learner.loss_fn(query_output, query_labels)
            query_loss = query_loss.item()
            # print(output.detach().size())
            # output.detach().max(-1) max on each Batch, which will return [0] max, [1] indices
            query_output_softmax = F.softmax(query_output, -1)
            # Row i, column query_labels[i]: the probability given to each sample's true label.
            query_label_conf = query_output_softmax[np.arange(len(query_output_softmax)), query_labels] # Select labels in the softmax of 33 classes

            query_pred = models.utils.make_prediction(query_output.detach())
            query_acc, query_prec, query_rec, query_f1 = models.utils.calculate_metrics(query_pred.tolist(), query_labels.tolist())

            # query_loss is already a Python float here, so np.mean leaves its value unchanged.
            print('Query set metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
                'recall = {:.4f}, F1 score = {:.4f}'.format(np.mean(query_loss), query_acc, query_prec, query_rec, query_f1))

            all_losses.append(query_loss)
            all_predictions.extend(query_pred.tolist())
            all_labels.extend(query_labels.tolist())
            all_label_conf.extend(query_label_conf.tolist())

    # Overall metrics are computed over every query sample seen, not averaged per batch.
    acc, prec, rec, f1 = models.utils.calculate_metrics(all_predictions, all_labels)
    print('Test metrics: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
                'F1 score = {:.4f}'.format(np.mean(all_losses), acc, prec, rec, f1))
    return acc, prec, rec, f1, all_predictions, all_labels, all_label_conf, all_adaptation_time

In [None]:
tic = time.time()
print('----------Testing on test set starts here----------')

# One entry per test dataset.
accuracies, precisions, recalls, f1s = [], [], [], []
# Inner-loop adaptation times (seconds) pooled over all datasets.
all_adapt_time = []
# Rows for visualization: (data_idx, label, label_conf, pred)
data_for_visual = []

for test_dataset in test_datasets:
    dataset_name = test_dataset.__class__.__name__
    print('Testing on {}'.format(dataset_name))
    # shuffle=False keeps sample order stable so the generated data ids are reproducible.
    test_dataloader = data.DataLoader(
        test_dataset,
        batch_size=mini_batch_size,
        shuffle=False,
        collate_fn=datasets.utils.batch_encode,
    )
    acc, prec, rec, f1, all_pred, all_label, all_label_conf, all_adaptation_time = evaluate(
        dataloader=test_dataloader,
        updates=updates,
        mini_batch_size=mini_batch_size,
        dataname=dataset_name,
    )

    # Ids are "<DatasetClassName><running index>" in evaluation order.
    data_ids = [dataset_name + str(sample_no) for sample_no in range(len(all_label))]
    data_for_visual.extend(zip(data_ids, all_label, all_label_conf, all_pred))
    all_adapt_time.extend(all_adaptation_time)

    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1s.append(f1)


# Per-dataset accuracies, one per line, for easy copying into a spreadsheet.
print()
print("COPY PASTA - not really but ok")
for dataset_accuracy in accuracies:
    print(dataset_accuracy)
print()
print('Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
            'F1 score = {:.4f}'.format(np.mean(accuracies), np.mean(precisions), np.mean(recalls), np.mean(f1s)))

# Elapsed wall-clock time for the whole test run; reported in whole minutes.
toc = time.time() - tic
print(f"Total Time used: {toc//60} minutes")

In [None]:
# Dump per-sample predictions next to the model checkpoint for later visualization.
_model_path0 = os.path.splitext(model_path)[0]
csv_filename = _model_path0 + "_update"+ str(updates) +"_results_sr_ta_ss6.csv" # for selective replay
# newline='' is what the csv module expects for files it writes; without it, platforms that
# translate "\n" end up with blank rows between records.
with open(csv_filename, 'w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(["data_idx", "label", "label_conf", "pred"])
    csv_writer.writerows(data_for_visual)
print(f"Done writing CSV File at {csv_filename}")

In [None]:
# Log Time for Inference
_model_path0 = os.path.splitext(model_path)[0]
time_txt_filename = _model_path0 + "_update"+ str(updates) +"_time_inference_sr_ta_ss6.csv" 
with open(time_txt_filename, 'w') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(["time_id", "time"])
    csv_writer.writerow(["Total Time", f"{toc//60} minutes"])
    csv_writer.writerow(["mean Adapt Time", f"{np.mean(all_adapt_time)} s"])
print(f"Done writing Time CSV File at {time_txt_filename}")