In [1]:
import random, gc, os, pickle, csv, time

import datasets.utils
import models.utils
from models.cls_oml_ori_v2 import OML
from models.base_models_ori import LabelAwareReplayMemory

import numpy as np

import higher
import torch
import torch.nn.functional as F
from torch.utils import data

# Constants

In [2]:
# Order in which dataset ids are evaluated, per experiment "order" id.
dataset_order_mapping = {
    1: [2, 0, 3, 1, 4],
    2: [3, 4, 0, 1, 2],
    3: [2, 4, 1, 3, 0],
    4: [0, 2, 1, 4, 3]
}
# Total number of classes; should match the label ranges in memory_buffer.task_dict (labels 0-32).
n_classes = 33
# Replace with torch.device('cpu') to force CPU execution.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoints from previous runs. Each run directory maps to the model checkpoint AND the
# replay memory pickled by that same run, so the two can never be mismatched by toggling
# separate commented-out lines.
MODEL_RUNS_DIR = "/data/model_runs/original_oml"
CHECKPOINTS = {
    # run directory: (model checkpoint file, replay-memory pickle file)
    "aOML-order1-2022-07-18": (
        "OML-order1-id4-2022-07-18_17-53-13.518612.pt",
        "OML-order1-id4-2022-07-18_17-53-13.518639_memory.pickle"),
    "aOML-order1-inlr002-2022-07-31": (
        "OML-order1-id4-2022-07-31_14-53-46.456804.pt",
        "OML-order1-id4-2022-07-31_14-53-46.456828_memory.pickle"),
    "aOML-order1-inlr005-2022-07-31": (
        "OML-order1-id4-2022-07-31_18-47-41.477968.pt",
        "OML-order1-id4-2022-07-31_18-47-41.477992_memory.pickle"),
    "aOML-order1-inlr005-up20-2022-08-01": (
        "OML-order1-id4-2022-08-01_14-45-55.869765.pt",
        "OML-order1-id4-2022-08-01_14-45-55.869797_memory.pickle"),
    "aOML-order1-inlr010-2022-07-31": (
        "OML-order1-id4-2022-07-31_21-18-36.241546.pt",
        "OML-order1-id4-2022-07-31_21-18-36.241572_memory.pickle"),
    "aOML-order1-inlr020-2022-08-16": (
        "OML-order1-id4-2022-08-16_11-37-19.424113.pt",
        "OML-order1-id4-2022-08-16_11-37-19.424139_memory.pickle"),
    "aOML-order1-inlr050-2022-08-16": (
        "OML-order1-id4-2022-08-16_14-16-12.167637.pt",
        "OML-order1-id4-2022-08-16_14-16-12.167666_memory.pickle"),
    # v. SR
    "aOML-order1-inlr010-2022-08-29-sr": (
        "OML-order1-id4-2022-08-29_18-10-31.695669.pt",
        "OML-order1-id4-2022-08-29_18-10-31.695692_memory.pickle"),
    # v. SR Query
    "aOML-order1-inlr010-2022-08-30-sr-query": (
        "OML-order1-id4-2022-08-30_05-21-18.854228.pt",
        "OML-order1-id4-2022-08-30_05-21-18.854254_memory.pickle"),
}
# Pick the run to evaluate (v. SR Query).
RUN_DIR = "aOML-order1-inlr010-2022-08-30-sr-query"
_model_file, _memory_file = CHECKPOINTS[RUN_DIR]
model_path = f"{MODEL_RUNS_DIR}/{RUN_DIR}/{_model_file}"
memory_path = f"{MODEL_RUNS_DIR}/{RUN_DIR}/{_memory_file}"

# Intended toggle for the on-disk cache of pickled test datasets kept in cache_dir.
use_db_cache = True
cache_dir = 'tmp'

In [3]:
# Hyper-parameters for the OML learner; the whole dict is forwarded as OML(**args).
args = dict(
    order=1,
    n_epochs=1,
    lr=3e-5,
    inner_lr=0.001*10,
    meta_lr=3e-5,
    model="bert",
    learner="oml",
    mini_batch_size=16,
    updates=5*1,
    write_prob=1.0,
    max_length=448,
    seed=42,
    replay_rate=0.01,
    replay_every=9600,
)
# Values used directly by the evaluation cells.
updates, mini_batch_size, order = args["updates"], args["mini_batch_size"], args["order"]

In [4]:
# Seed every RNG this notebook uses (PyTorch, Python `random`, NumPy) from the configured seed.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(args["seed"])

# Load Dataset

In [5]:
# Load the test split of every dataset in the chosen order. When use_db_cache is set,
# each processed dataset is cached as a pickle under cache_dir and reused on later runs.
print('Loading the datasets')
if use_db_cache:
    # Create the cache directory up front; writing into a missing dir would fail.
    os.makedirs(cache_dir, exist_ok=True)
test_datasets = []
for dataset_id in dataset_order_mapping[order]:
    test_dataset_file = os.path.join(cache_dir, f"{dataset_id}.cache")
    if use_db_cache and os.path.exists(test_dataset_file):
        # NOTE: pickle.load can execute arbitrary code; only load caches written by this notebook.
        with open(test_dataset_file, 'rb') as f:
            test_dataset = pickle.load(f)
    else:
        test_dataset = datasets.utils.get_dataset_test("", dataset_id)
        print('Loaded {}'.format(test_dataset.__class__.__name__))
        test_dataset = datasets.utils.offset_labels(test_dataset)
        if use_db_cache:
            # Use a context manager so the cache file is flushed and closed deterministically.
            with open(test_dataset_file, "wb") as f:
                pickle.dump(test_dataset, f, protocol=pickle.HIGHEST_PROTOCOL)
            print(f"Pickle saved at {test_dataset_file}")
    test_datasets.append(test_dataset)
print('Finished loading all the datasets')

Loading the datasets
Finished loading all the datasets


# Load Model

In [6]:
# Build the OML learner with the notebook hyper-parameters and restore its trained weights.
learner = OML(device=device, n_classes=n_classes, **args)
print('Using {} as learner'.format(type(learner).__name__))
learner.load_model(model_path)

# The replay memory is kept as a standalone object (it is NOT assigned to learner.memory);
# evaluate() samples support sets from it directly via memory_buffer.read_batch_task.
# NOTE(review): LabelAwareReplayMemory is imported at the top, presumably so pickle can
# resolve the stored class. pickle.load executes code, so only load trusted run artifacts.
with open(memory_path, 'rb') as memory_file:
    memory_buffer = pickle.load(memory_file)


2022-09-10 09:10:02,186 - transformers.tokenization_utils_base - INFO - loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
2022-09-10 09:10:03,347 - transformers.configuration_utils - INFO - loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef516483e4ca884517
2022-09-10 09:10:03,351 - transformers.configuration_utils - INFO - Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer

Using OML as learner


In [7]:
# Setting up task dict for task-aware replay: task id -> list of class labels owned by that task.
# NOTE: Amazon (1) and Yelp (2) both own labels 0-4, so a label alone cannot tell those
# two tasks apart.
memory_buffer.task_dict = {
    task_idx: list(label_range)
    for task_idx, label_range in {
        0: range(5, 9),    # AG
        1: range(0, 5),    # Amazon
        2: range(0, 5),    # Yelp
        3: range(9, 23),   # DBPedia
        4: range(23, 33),  # Yahoo
    }.items()
}

In [8]:
# Label -> task converter. This is a hack: normally the task token/identifier would be used directly.
def get_task_from_label(label_idx, task_dict):
    """Return the id of the first task in ``task_dict`` whose class list contains ``label_idx``.

    ``task_dict`` maps task id -> list of class labels. Tasks are checked in
    the dict's iteration order, so a label owned by several tasks resolves to
    the earliest one. Returns -1 when no task owns the label.
    """
    return next(
        (task_idx for task_idx, class_list in task_dict.items() if label_idx in class_list),
        -1,
    )
# Sanity check: label 8 lies in range(5, 9), which task_dict assigns to AG (task 0).
print(get_task_from_label(8, memory_buffer.task_dict))
# Find mode: https://stackoverflow.com/questions/10797819/finding-the-mode-of-a-list
def mode(array):
    """Return the most frequent element of ``array``.

    Accepts any iterable (list, tuple, generator, ...). Ties go to the value
    whose first occurrence comes earliest, as with ``max(array, key=array.count)``.
    Counting is O(n) for hashable elements. Unhashable elements fall back to the
    quadratic ``list.count`` scan.

    Raises:
        ValueError: if ``array`` is empty.
    """
    from collections import Counter

    values = list(array)
    if not values:
        raise ValueError("mode() arg is an empty sequence")
    try:
        counts = Counter(values)
    except TypeError:  # unhashable elements (e.g. lists) cannot be counted by a Counter
        return max(values, key=values.count)
    # most_common orders equal counts by first occurrence, preserving the tie-break.
    return counts.most_common(1)[0][0]
# Label list -> task id.
def get_task_from_label_list(label_list, task_dict):
    """Infer a batch's task as the most common task id among its labels (see get_task_from_label)."""
    task_ids = [get_task_from_label(label, task_dict) for label in label_list]
    return mode(task_ids)
print(get_task_from_label_list([1,2,32,1,4,2,0], memory_buffer.task_dict))

0
1


In [4]:
# Dataset class name -> task id; ids follow the same numbering as memory_buffer.task_dict
# (0 AG, 1 Amazon, 2 Yelp, 3 DBPedia, 4 Yahoo).
dataclass_mapper = dict(
    AGNewsDataset=0,
    AmazonDataset=1,
    YelpDataset=2,
    DBPediaDataset=3,
    YahooAnswersDataset=4,
)
dataclass_mapper["AGNewsDataset"]

0

# Testing

Selecting a specific column index per row (used in `evaluate` to read each query sample's softmax probability for its true label):
https://stackoverflow.com/questions/23435782/numpy-selecting-specific-column-index-per-row-by-using-a-list-of-indexes

In [9]:
def _format_metrics(title, loss, acc, prec, rec, f1):
    """Render one metrics report line: '<title>: Loss = ..., accuracy = ..., ..., F1 score = ...'."""
    return ('{}: Loss = {:.4f}, accuracy = {:.4f}, precision = {:.4f}, '
            'recall = {:.4f}, F1 score = {:.4f}'.format(title, loss, acc, prec, rec, f1))


def _sample_support_set(task_idx, updates, mini_batch_size):
    """Read `updates` (text, labels) mini-batches for task `task_idx` from `memory_buffer`."""
    support_set = []
    for _ in range(updates):
        text, labels = memory_buffer.read_batch_task(batch_size=mini_batch_size, task_idx=task_idx)
        support_set.append((text, labels))
    return support_set


def evaluate(dataloader, updates, mini_batch_size, dataname="", verbose=True):
    """Evaluate the global `learner` with task-aware local adaptation.

    For every query batch from `dataloader`, a support set of `updates`
    mini-batches is drawn from `memory_buffer` for the task mapped from
    `dataname`. A functional copy of the PLN (via `higher.innerloop_ctx`) is
    adapted on it with `learner.inner_optimizer` and then predicts the query
    batch. The RLN is only run forward and is never updated here.

    Args:
        dataloader: iterable of (text, labels) query batches; must support len().
        updates: number of inner-loop steps (support mini-batches) per query batch.
        mini_batch_size: size of each support mini-batch.
        dataname: dataset class name; must be a key of `dataclass_mapper`.
        verbose: print per-batch progress and support/query metrics (default True).

    Returns:
        (acc, prec, rec, f1, all_predictions, all_labels, all_label_conf,
        all_adaptation_time). `all_label_conf` holds each query sample's softmax
        probability for its true label. `all_adaptation_time` holds the wall-clock
        seconds of each inner loop.

    Raises:
        KeyError: if `dataname` is not in `dataclass_mapper`.
    """
    if dataname not in dataclass_mapper:
        raise KeyError(f"Unknown dataset name {dataname!r}; expected one of {list(dataclass_mapper)}")
    # Support sets come from the task of the dataset under evaluation. Inferring it from the
    # query labels (get_task_from_label_list) is avoided: Amazon and Yelp share labels 0-4
    # in memory_buffer.task_dict, so labels cannot tell them apart.
    task_idx = dataclass_mapper[dataname]

    learner.rln.eval()
    # NOTE(review): PLN stays in train mode during adaptation and query prediction; if it
    # contains dropout, query predictions are stochastic.
    learner.pln.train()

    all_losses, all_predictions, all_labels, all_label_conf = [], [], [], []
    all_adaptation_time = []
    n_batches = len(dataloader)
    for query_idx, (query_text, query_labels) in enumerate(dataloader):
        if verbose:
            print(f"Query ID {query_idx}/{n_batches}")
        support_set = _sample_support_set(task_idx, updates, mini_batch_size)

        with higher.innerloop_ctx(learner.pln, learner.inner_optimizer,
                                  copy_initial_weights=False, track_higher_grads=False) as (fpln, diffopt):
            # Inner loop: adapt fpln on the support set. This span is the timed "adaptation".
            inner_tic = time.time()
            support_predictions, support_labels, support_losses = [], [], []
            for text, labels in support_set:
                labels = torch.tensor(labels).to(device)
                input_dict = learner.rln.encode_text(text)
                # The RLN is never updated here, so skip building an autograd graph through it.
                # diffopt.step only needs gradients w.r.t. fpln's parameters.
                with torch.no_grad():
                    support_repr = learner.rln(input_dict)
                output = fpln(support_repr)
                loss = learner.loss_fn(output, labels)
                diffopt.step(loss)
                support_losses.append(loss.item())
                support_predictions.extend(models.utils.make_prediction(output.detach()).tolist())
                support_labels.extend(labels.tolist())
            all_adaptation_time.append(time.time() - inner_tic)

            if verbose:
                sup_acc, sup_prec, sup_rec, sup_f1 = models.utils.calculate_metrics(
                    support_predictions, support_labels)
                print(_format_metrics('Support set metrics', np.mean(support_losses),
                                      sup_acc, sup_prec, sup_rec, sup_f1))

            # Query set: predict with the adapted fpln; no gradients needed.
            query_labels = torch.tensor(query_labels).to(device)
            query_input_dict = learner.rln.encode_text(query_text)
            with torch.no_grad():
                query_repr = learner.rln(query_input_dict)
                query_output = fpln(query_repr)  # [BATCH, CLASSES], e.g. torch.Size([16, 33])
                query_loss = learner.loss_fn(query_output, query_labels).item()
                # Softmax probability each sample assigns to its true label
                # (row i, column query_labels[i]); indices live on the same device.
                query_probs = F.softmax(query_output, dim=-1)
                row_idx = torch.arange(query_probs.size(0), device=query_probs.device)
                query_label_conf = query_probs[row_idx, query_labels]
            query_pred = models.utils.make_prediction(query_output.detach())

            if verbose:
                q_acc, q_prec, q_rec, q_f1 = models.utils.calculate_metrics(
                    query_pred.tolist(), query_labels.tolist())
                print(_format_metrics('Query set metrics', query_loss, q_acc, q_prec, q_rec, q_f1))

            all_losses.append(query_loss)
            all_predictions.extend(query_pred.tolist())
            all_labels.extend(query_labels.tolist())
            all_label_conf.extend(query_label_conf.tolist())

    acc, prec, rec, f1 = models.utils.calculate_metrics(all_predictions, all_labels)
    print(_format_metrics('Test metrics', np.mean(all_losses), acc, prec, rec, f1))
    return acc, prec, rec, f1, all_predictions, all_labels, all_label_conf, all_adaptation_time

In [10]:
tic = time.time()
print('----------Testing on test set starts here----------')

# Per-dataset metrics, in test_datasets order.
accuracies, precisions, recalls, f1s = [], [], [], []
all_adapt_time = []
# Data for Visualization: one (data_idx, label, label_conf, pred) row per test sample.
data_for_visual = []

for test_dataset in test_datasets:
    dataset_name = test_dataset.__class__.__name__
    print('Testing on {}'.format(dataset_name))
    test_dataloader = data.DataLoader(
        test_dataset,
        batch_size=mini_batch_size,
        shuffle=False,
        collate_fn=datasets.utils.batch_encode,
    )
    results = evaluate(dataloader=test_dataloader, updates=updates,
                       mini_batch_size=mini_batch_size, dataname=dataset_name)
    acc, prec, rec, f1, all_pred, all_label, all_label_conf, all_adaptation_time = results

    # Sample ids are the dataset class name followed by a running index, e.g. "YelpDataset0".
    data_ids = [f"{dataset_name}{i}" for i in range(len(all_label))]
    data_for_visual.extend(zip(data_ids, all_label, all_label_conf, all_pred))
    all_adapt_time.extend(all_adaptation_time)

    for metric_list, metric_value in ((accuracies, acc), (precisions, prec), (recalls, rec), (f1s, f1)):
        metric_list.append(metric_value)

print()
print("COPY PASTA - not really but ok")
# One accuracy per line, for copy-pasting.
for dataset_accuracy in accuracies:
    print(dataset_accuracy)
print()
print('Overall test metrics: Accuracy = {:.4f}, precision = {:.4f}, recall = {:.4f}, '
      'F1 score = {:.4f}'.format(*(np.mean(metric) for metric in (accuracies, precisions, recalls, f1s))))

toc = time.time() - tic
print(f"Total Time used: {toc//60} minutes")

----------Testing on test set starts here----------
Testing on YelpDataset
Query ID 0/475
Support set metrics: Loss = 1.0140, accuracy = 0.6250, precision = 0.6529, recall = 0.6233, F1 score = 0.6345
Query set metrics: Loss = 0.6118, accuracy = 0.6875, precision = 0.7800, recall = 0.6933, F1 score = 0.6900
Query ID 1/475
Support set metrics: Loss = 1.0684, accuracy = 0.5250, precision = 0.5283, recall = 0.5233, F1 score = 0.5191
Query set metrics: Loss = 0.6789, accuracy = 0.7500, precision = 0.7267, recall = 0.7933, F1 score = 0.7200
Query ID 2/475
Support set metrics: Loss = 0.8462, accuracy = 0.6125, precision = 0.6007, recall = 0.6000, F1 score = 0.5938
Query set metrics: Loss = 0.9035, accuracy = 0.4375, precision = 0.4733, recall = 0.4667, F1 score = 0.4538
Query ID 3/475
Support set metrics: Loss = 0.8568, accuracy = 0.6125, precision = 0.6152, recall = 0.6033, F1 score = 0.6053
Query set metrics: Loss = 0.8478, accuracy = 0.6250, precision = 0.6033, recall = 0.6000, F1 score = 

Support set metrics: Loss = 0.9176, accuracy = 0.6000, precision = 0.6328, recall = 0.6000, F1 score = 0.6126
Query set metrics: Loss = 0.6832, accuracy = 0.7500, precision = 0.8250, recall = 0.7625, F1 score = 0.7667
Query ID 36/475
Support set metrics: Loss = 0.7069, accuracy = 0.7500, precision = 0.7605, recall = 0.7467, F1 score = 0.7523
Query set metrics: Loss = 0.8782, accuracy = 0.6875, precision = 0.6600, recall = 0.7167, F1 score = 0.6502
Query ID 37/475
Support set metrics: Loss = 0.9647, accuracy = 0.5750, precision = 0.5916, recall = 0.5700, F1 score = 0.5747
Query set metrics: Loss = 0.6858, accuracy = 0.7500, precision = 0.6333, recall = 0.6433, F1 score = 0.6325
Query ID 38/475
Support set metrics: Loss = 0.9136, accuracy = 0.6625, precision = 0.6578, recall = 0.6467, F1 score = 0.6472
Query set metrics: Loss = 0.8557, accuracy = 0.6250, precision = 0.6000, recall = 0.6367, F1 score = 0.6033
Query ID 39/475
Support set metrics: Loss = 0.7527, accuracy = 0.6500, precision

Query set metrics: Loss = 1.3701, accuracy = 0.4375, precision = 0.5167, recall = 0.5867, F1 score = 0.5071
Query ID 71/475
Support set metrics: Loss = 0.8392, accuracy = 0.6500, precision = 0.6812, recall = 0.6533, F1 score = 0.6642
Query set metrics: Loss = 0.8921, accuracy = 0.5625, precision = 0.6033, recall = 0.4833, F1 score = 0.4967
Query ID 72/475
Support set metrics: Loss = 0.8179, accuracy = 0.6625, precision = 0.6677, recall = 0.6567, F1 score = 0.6557
Query set metrics: Loss = 0.7228, accuracy = 0.6250, precision = 0.6800, recall = 0.7400, F1 score = 0.6889
Query ID 73/475
Support set metrics: Loss = 0.8694, accuracy = 0.6500, precision = 0.6784, recall = 0.6567, F1 score = 0.6612
Query set metrics: Loss = 0.8577, accuracy = 0.7500, precision = 0.6600, recall = 0.6533, F1 score = 0.6425
Query ID 74/475
Support set metrics: Loss = 0.7934, accuracy = 0.7000, precision = 0.7083, recall = 0.6967, F1 score = 0.6984
Query set metrics: Loss = 0.8737, accuracy = 0.5625, precision =

Query set metrics: Loss = 0.6818, accuracy = 0.5625, precision = 0.5500, recall = 0.5500, F1 score = 0.5452
Query ID 106/475
Support set metrics: Loss = 0.9915, accuracy = 0.5875, precision = 0.5744, recall = 0.5800, F1 score = 0.5766
Query set metrics: Loss = 1.1959, accuracy = 0.3750, precision = 0.3500, recall = 0.4333, F1 score = 0.3705
Query ID 107/475
Support set metrics: Loss = 0.9823, accuracy = 0.6000, precision = 0.6078, recall = 0.5867, F1 score = 0.5932
Query set metrics: Loss = 0.9474, accuracy = 0.4375, precision = 0.5000, recall = 0.4000, F1 score = 0.3943
Query ID 108/475
Support set metrics: Loss = 0.9843, accuracy = 0.6500, precision = 0.6900, recall = 0.6467, F1 score = 0.6556
Query set metrics: Loss = 1.2025, accuracy = 0.3125, precision = 0.3524, recall = 0.2900, F1 score = 0.2438
Query ID 109/475
Support set metrics: Loss = 0.8028, accuracy = 0.6500, precision = 0.6443, recall = 0.6467, F1 score = 0.6418
Query set metrics: Loss = 1.0253, accuracy = 0.5625, precisi

Query set metrics: Loss = 0.5191, accuracy = 0.8125, precision = 0.8500, recall = 0.8533, F1 score = 0.8267
Query ID 141/475
Support set metrics: Loss = 0.8407, accuracy = 0.7125, precision = 0.7089, recall = 0.7033, F1 score = 0.7015
Query set metrics: Loss = 0.7784, accuracy = 0.6250, precision = 0.5133, recall = 0.5500, F1 score = 0.5254
Query ID 142/475
Support set metrics: Loss = 0.8067, accuracy = 0.6750, precision = 0.7066, recall = 0.6733, F1 score = 0.6835
Query set metrics: Loss = 0.6619, accuracy = 0.8750, precision = 0.9100, recall = 0.8500, F1 score = 0.8611
Query ID 143/475
Support set metrics: Loss = 0.7769, accuracy = 0.6750, precision = 0.6901, recall = 0.6733, F1 score = 0.6794
Query set metrics: Loss = 0.6048, accuracy = 0.6875, precision = 0.7133, recall = 0.6933, F1 score = 0.6663
Query ID 144/475
Support set metrics: Loss = 0.9612, accuracy = 0.5625, precision = 0.5521, recall = 0.5533, F1 score = 0.5486
Query set metrics: Loss = 0.9686, accuracy = 0.6250, precisi

Query set metrics: Loss = 0.8521, accuracy = 0.5625, precision = 0.5200, recall = 0.5533, F1 score = 0.5133
Query ID 176/475
Support set metrics: Loss = 0.8111, accuracy = 0.7000, precision = 0.6979, recall = 0.6900, F1 score = 0.6889
Query set metrics: Loss = 1.0040, accuracy = 0.4375, precision = 0.5000, recall = 0.3833, F1 score = 0.4071
Query ID 177/475
Support set metrics: Loss = 0.8129, accuracy = 0.6625, precision = 0.6650, recall = 0.6500, F1 score = 0.6567
Query set metrics: Loss = 0.7716, accuracy = 0.6250, precision = 0.8167, recall = 0.7333, F1 score = 0.7000
Query ID 178/475
Support set metrics: Loss = 0.8538, accuracy = 0.6125, precision = 0.5898, recall = 0.5933, F1 score = 0.5838
Query set metrics: Loss = 1.1714, accuracy = 0.5000, precision = 0.4429, recall = 0.5000, F1 score = 0.4667
Query ID 179/475
Support set metrics: Loss = 0.8317, accuracy = 0.6625, precision = 0.6577, recall = 0.6467, F1 score = 0.6499
Query set metrics: Loss = 0.8823, accuracy = 0.4375, precisi

Query set metrics: Loss = 1.1353, accuracy = 0.5625, precision = 0.6667, recall = 0.5833, F1 score = 0.5533
Query ID 211/475
Support set metrics: Loss = 0.8882, accuracy = 0.6875, precision = 0.6841, recall = 0.6833, F1 score = 0.6816
Query set metrics: Loss = 0.9715, accuracy = 0.5000, precision = 0.4500, recall = 0.5833, F1 score = 0.4400
Query ID 212/475
Support set metrics: Loss = 0.8997, accuracy = 0.6250, precision = 0.6239, recall = 0.6200, F1 score = 0.6210
Query set metrics: Loss = 0.8099, accuracy = 0.5625, precision = 0.7133, recall = 0.6000, F1 score = 0.5619
Query ID 213/475
Support set metrics: Loss = 0.7324, accuracy = 0.7125, precision = 0.7083, recall = 0.7033, F1 score = 0.6983
Query set metrics: Loss = 0.6822, accuracy = 0.6875, precision = 0.5533, recall = 0.6167, F1 score = 0.5600
Query ID 214/475
Support set metrics: Loss = 0.9554, accuracy = 0.6250, precision = 0.6198, recall = 0.6167, F1 score = 0.6159
Query set metrics: Loss = 1.0487, accuracy = 0.4375, precisi

Support set metrics: Loss = 0.9146, accuracy = 0.6250, precision = 0.6328, recall = 0.6167, F1 score = 0.6184
Query set metrics: Loss = 0.8607, accuracy = 0.5625, precision = 0.6000, recall = 0.6000, F1 score = 0.5733
Query ID 247/475
Support set metrics: Loss = 0.7731, accuracy = 0.7375, precision = 0.7580, recall = 0.7400, F1 score = 0.7470
Query set metrics: Loss = 1.1487, accuracy = 0.6250, precision = 0.7333, recall = 0.6800, F1 score = 0.6262
Query ID 248/475
Support set metrics: Loss = 0.8689, accuracy = 0.6500, precision = 0.6464, recall = 0.6367, F1 score = 0.6405
Query set metrics: Loss = 0.8692, accuracy = 0.5625, precision = 0.6500, recall = 0.6000, F1 score = 0.6033
Query ID 249/475
Support set metrics: Loss = 0.9255, accuracy = 0.6375, precision = 0.6303, recall = 0.6233, F1 score = 0.6238
Query set metrics: Loss = 0.8672, accuracy = 0.5625, precision = 0.6267, recall = 0.6810, F1 score = 0.6267
Query ID 250/475
Support set metrics: Loss = 1.0228, accuracy = 0.5875, preci

Support set metrics: Loss = 0.8844, accuracy = 0.6250, precision = 0.6362, recall = 0.6167, F1 score = 0.6207
Query set metrics: Loss = 0.8300, accuracy = 0.5625, precision = 0.5667, recall = 0.6600, F1 score = 0.5689
Query ID 282/475
Support set metrics: Loss = 0.9463, accuracy = 0.5625, precision = 0.5783, recall = 0.5633, F1 score = 0.5687
Query set metrics: Loss = 0.9759, accuracy = 0.6250, precision = 0.6167, recall = 0.7000, F1 score = 0.6333
Query ID 283/475
Support set metrics: Loss = 0.9176, accuracy = 0.6125, precision = 0.6047, recall = 0.6067, F1 score = 0.5985
Query set metrics: Loss = 1.0721, accuracy = 0.5000, precision = 0.5800, recall = 0.4833, F1 score = 0.4856
Query ID 284/475
Support set metrics: Loss = 0.8257, accuracy = 0.6500, precision = 0.6560, recall = 0.6467, F1 score = 0.6489
Query set metrics: Loss = 0.6061, accuracy = 0.6875, precision = 0.7429, recall = 0.7429, F1 score = 0.6905
Query ID 285/475
Support set metrics: Loss = 1.0221, accuracy = 0.5375, preci

Support set metrics: Loss = 0.9778, accuracy = 0.5875, precision = 0.5916, recall = 0.5800, F1 score = 0.5831
Query set metrics: Loss = 0.5701, accuracy = 0.6875, precision = 0.7000, recall = 0.6333, F1 score = 0.6467
Query ID 317/475
Support set metrics: Loss = 0.8850, accuracy = 0.6250, precision = 0.6408, recall = 0.6133, F1 score = 0.6241
Query set metrics: Loss = 0.8173, accuracy = 0.6250, precision = 0.5000, recall = 0.5667, F1 score = 0.5133
Query ID 318/475
Support set metrics: Loss = 0.8744, accuracy = 0.6250, precision = 0.6263, recall = 0.6200, F1 score = 0.6218
Query set metrics: Loss = 0.9696, accuracy = 0.5625, precision = 0.5500, recall = 0.6000, F1 score = 0.5048
Query ID 319/475
Support set metrics: Loss = 1.1479, accuracy = 0.5125, precision = 0.5127, recall = 0.5067, F1 score = 0.5088
Query set metrics: Loss = 0.9744, accuracy = 0.6250, precision = 0.6033, recall = 0.6333, F1 score = 0.6000
Query ID 320/475
Support set metrics: Loss = 1.0466, accuracy = 0.5500, preci

Support set metrics: Loss = 0.9874, accuracy = 0.6125, precision = 0.6027, recall = 0.6100, F1 score = 0.5999
Query set metrics: Loss = 1.1214, accuracy = 0.6250, precision = 0.6800, recall = 0.6190, F1 score = 0.6200
Query ID 352/475
Support set metrics: Loss = 0.9512, accuracy = 0.5875, precision = 0.5874, recall = 0.5733, F1 score = 0.5759
Query set metrics: Loss = 0.6870, accuracy = 0.7500, precision = 0.7381, recall = 0.7714, F1 score = 0.7124
Query ID 353/475
Support set metrics: Loss = 0.9564, accuracy = 0.5750, precision = 0.5802, recall = 0.5733, F1 score = 0.5754
Query set metrics: Loss = 1.3969, accuracy = 0.3750, precision = 0.4300, recall = 0.2667, F1 score = 0.3089
Query ID 354/475
Support set metrics: Loss = 1.0781, accuracy = 0.5500, precision = 0.5857, recall = 0.5467, F1 score = 0.5522
Query set metrics: Loss = 0.8561, accuracy = 0.5625, precision = 0.5833, recall = 0.6300, F1 score = 0.5443
Query ID 355/475
Support set metrics: Loss = 1.0475, accuracy = 0.6000, preci

Support set metrics: Loss = 0.8662, accuracy = 0.7000, precision = 0.7310, recall = 0.7000, F1 score = 0.7139
Query set metrics: Loss = 1.0037, accuracy = 0.5000, precision = 0.5800, recall = 0.5133, F1 score = 0.5032
Query ID 387/475
Support set metrics: Loss = 0.9853, accuracy = 0.7000, precision = 0.7187, recall = 0.7000, F1 score = 0.7066
Query set metrics: Loss = 1.5303, accuracy = 0.5000, precision = 0.5167, recall = 0.4524, F1 score = 0.4514
Query ID 388/475
Support set metrics: Loss = 1.1670, accuracy = 0.5625, precision = 0.5991, recall = 0.5567, F1 score = 0.5702
Query set metrics: Loss = 0.6792, accuracy = 0.7500, precision = 0.7933, recall = 0.7667, F1 score = 0.7378
Query ID 389/475
Support set metrics: Loss = 0.9831, accuracy = 0.6000, precision = 0.6119, recall = 0.5867, F1 score = 0.5968
Query set metrics: Loss = 1.0246, accuracy = 0.5000, precision = 0.4933, recall = 0.5667, F1 score = 0.4310
Query ID 390/475
Support set metrics: Loss = 0.8303, accuracy = 0.6500, preci

KeyboardInterrupt: 

In [None]:
# Persist per-sample results next to the model checkpoint:
# one (data_idx, label, label_conf, pred) row per evaluated test sample.
_model_path0 = os.path.splitext(model_path)[0]
csv_filename = _model_path0 + "_update" + str(updates) + "_results_sr_ta2.csv"  # for selective replay
# evaluate() only returns after a full pass, so a short row count means the test loop
# was interrupted and this CSV would silently hold partial results.
_expected_rows = sum(len(ds) for ds in test_datasets)
if len(data_for_visual) != _expected_rows:
    print(f"WARNING: writing {len(data_for_visual)} rows but the test sets hold "
          f"{_expected_rows} samples (was the evaluation interrupted?)")
# newline='' is required by the csv module so it controls line endings (no blank rows on Windows).
with open(csv_filename, 'w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(["data_idx", "label", "label_conf", "pred"])
    csv_writer.writerows(data_for_visual)
print(f"Done writing CSV File at {csv_filename}")

In [None]:
# Log Time for Inference: total wall-clock time of the test loop and the mean
# adaptation (inner-loop) time per query batch.
_model_path0 = os.path.splitext(model_path)[0]
time_txt_filename = _model_path0 + "_update" + str(updates) + "_time_inference_sr_ta2.csv"
# newline='' is required by the csv module so it controls line endings (no blank rows on Windows).
with open(time_txt_filename, 'w', newline='') as csv_file:
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(["time_id", "time"])
    csv_writer.writerow(["Total Time", f"{toc//60} minutes"])
    csv_writer.writerow(["mean Adapt Time", f"{np.mean(all_adapt_time)} s"])
print(f"Done writing Time CSV File at {time_txt_filename}")