In [1]:
from pytorch_transformers import AdamW, WarmupLinearSchedule, WarmupConstantSchedule
from torch import optim
from torch.utils.data import DataLoader
import argparse
import copy
import logging
import numpy as np
import os, time, csv
import pickle
import torch
import torch.nn.functional as F
import higher
from settings import parse_test_args, model_classes, init_logging
from utils import TextClassificationDataset, dynamic_collate_fn, prepare_inputs, DynamicBatchSampler

# From Meta-MbPA Paper

>In our experiments, we find that it is sufficient to take this to the extreme such that we consider all test examples as a single cluster. Consequently, we consider the whole memory as neighbours and we randomly sample from it to be comparable with the original local adaptation formulation (i.e. same batch sizes and gradient steps). As shown in the next section, it has two benefits: (1) it is more robust to negative transfer, (2) it is faster when we evaluate testing examples as a group.

In [2]:
class dotdict(dict):
    """A ``dict`` whose keys can also be read, set and deleted as attributes.

    Reading an absent key as an attribute yields ``None`` (it delegates to
    ``dict.get``) instead of raising ``AttributeError``. Deleting an absent key
    as an attribute raises ``KeyError``, exactly like ``del d[key]``.
    """

    def __getattr__(self, key, default=None):
        # Only reached when normal attribute lookup fails, i.e. for dict keys.
        return self.get(key, default)

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        del self[key]

In [3]:
# Test-time settings for local adaptation. Training-time settings are merged in
# from the run directory by the next cell.
# Previous run directory: /data/model_runs/em_in_lll/MetaMBPA_order1
args = dotdict(
    output_dir="/data/model_runs/em_in_lll/MetaMBPA_order1_v2",
    adapt_lambda=1e-3,   # not read by test_task below
    adapt_lr=2e-3,       # not read by test_task below (it steps with args.learning_rate)
    adapt_steps=20,      # not read by test_task below
    no_fp16_test=False,  # False -> test_task casts the model to fp16 before adapting
)
output_dir = args.output_dir

In [4]:
# Merge the hyper-parameters saved at training time into `args`.
# NOTE(security): pickle.load can execute arbitrary code; only load run directories you trust.
with open(os.path.join(output_dir, 'train_args'), 'rb') as train_args_file:
    train_args = pickle.load(train_args_file)
# dict.update semantics: on key collisions the training-time values override the ones set above.
args.update(train_args.__dict__)
model_type = args["model_type"]
model_name = args["model_name"]
n_labels = args["n_labels"]
tasks = args["tasks"]

In [5]:
# Rebuild tokenizer/config and load the checkpoint and episodic memory saved after the last task.
config_class, model_class, args["tokenizer_class"] = model_classes[model_type]
tokenizer = args["tokenizer_class"].from_pretrained(model_name)
# Hidden and attention dropout are set to 0 for testing/adaptation.
model_config = config_class.from_pretrained(model_name, num_labels=n_labels, hidden_dropout_prob=0, attention_probs_dropout_prob=0)
last_task_idx = len(tasks) - 1
save_model_path = os.path.join(output_dir, f'checkpoint-{last_task_idx}')
model = model_class.from_pretrained(save_model_path, config=model_config).cuda()
# NOTE(security): pickle.load can execute arbitrary code; only load run directories you trust.
with open(os.path.join(output_dir, f'memory-{last_task_idx}'), 'rb') as memory_file:
    memory = pickle.load(memory_file)

In [6]:
def test_task(task_id, args, model, test_dataset, support_memory=None, n_support=32):
    """Evaluate ``model`` on one task's test set after a single local adaptation.

    Following the Meta-MbPA "single cluster" variant, the model is adapted once
    per task on examples drawn from episodic memory (not once per test example).
    The adapted fast weights are then used to score the whole test set.

    Args:
        task_id: Index of the task being tested. Not used inside the function;
            kept for interface compatibility.
        args: Config object read by attribute: ``no_fp16_test``, ``weight_decay``,
            ``learning_rate``, ``n_workers``, ``batch_size``, ``logging_steps``.
        model: Model to adapt and evaluate. When ``args.no_fp16_test`` is falsy it
            is converted with ``Module.half()``, which modifies the module in place.
        test_dataset: Dataset for the task's test split.
        support_memory: Episodic memory exposing ``sample(n)`` and
            ``query(input_ids, masks)``. Defaults to the notebook-level ``memory``.
        n_support: Number of examples sampled from memory (default 32).

    Returns:
        Tuple ``(accuracy, all_labels, all_label_confs, all_preds)``, where
        ``all_label_confs`` holds the softmax probability given to each gold label.

    Raises:
        ValueError: if ``test_dataset`` is empty (accuracy would be undefined).
    """
    n_test = len(test_dataset)
    if n_test == 0:
        raise ValueError("test_dataset is empty; cannot compute test loss/accuracy")

    if support_memory is None:
        # Fall back to the episodic memory loaded at notebook level.
        support_memory = memory

    if not args.no_fp16_test:
        model = model.half()

    # --- Support set: sampled once and shared by every test example of this task ---
    s_input_ids, s_masks, s_labels = support_memory.sample(n_support)
    print(f"Total No.# of sampled: {len(s_labels)}")
    # Query neighbours of the sampled examples, mirroring training.
    # TODO (original author): remove this query step so adaptation uses the sampled batch directly.
    with torch.no_grad():
        q_input_ids, q_masks, q_labels = support_memory.query(s_input_ids, s_masks)

    # Inner-loop optimizer; biases and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    # NOTE(review): steps with args.learning_rate; the `adapt_lr` set in the config cell is not used — confirm intended.
    inner_optimizer = optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate)

    tot_n_inputs = 0
    cur_loss, cur_acc = 0, 0
    all_labels, all_label_confs, all_preds = [], [], []

    # track_higher_grads=False: no higher-order graph is kept; only the adapted weights are needed.
    with higher.innerloop_ctx(model, inner_optimizer, copy_initial_weights=False,
                              track_higher_grads=False) as (fmodel, diffopt):
        # 1) Inner loop on the support set: one adaptation for all test samples of the task.
        for sup_input_ids, sup_masks, sup_labels in zip(q_input_ids, q_masks, q_labels):
            loss = fmodel(input_ids=sup_input_ids.cuda(), attention_mask=sup_masks.cuda(),
                          labels=sup_labels.cuda())[0]
            diffopt.step(loss)

        # 2) Score the task's test set with the adapted fast weights.
        test_dataloader = DataLoader(test_dataset, num_workers=args.n_workers, collate_fn=dynamic_collate_fn,
                                     batch_sampler=DynamicBatchSampler(test_dataset, args.batch_size * 4))
        fmodel.eval()
        for step, batch in enumerate(test_dataloader):
            n_inputs, input_ids, masks, labels = prepare_inputs(batch)
            tot_n_inputs += n_inputs

            with torch.no_grad():
                batch_loss, batch_logits = fmodel(input_ids=input_ids, attention_mask=masks, labels=labels)[:2]
                probs = F.softmax(batch_logits, -1)
                # Probability of each example's gold label: one value per row of the [batch, classes] matrix.
                label_conf = probs.gather(1, labels.reshape(-1, 1)).squeeze(1)

            preds = np.argmax(batch_logits.detach().cpu().numpy(), axis=1)
            gold = labels.detach().cpu().numpy()

            # The loss is scaled by the number of inputs so totals can be averaged over examples.
            cur_loss += batch_loss.item() * n_inputs
            cur_acc += int(np.sum(preds == gold))

            all_labels.extend(labels.tolist())
            all_label_confs.extend(label_conf.tolist())
            all_preds.extend(preds.tolist())

            if (step + 1) % args.logging_steps == 0:
                print("Tested {}/{} examples, test loss: {:.3f} , test acc: {:.3f}".format(
                    tot_n_inputs, n_test, cur_loss / tot_n_inputs, cur_acc / tot_n_inputs))

    print("test loss: {:.3f} , test acc: {:.3f}".format(
        cur_loss / n_test, cur_acc / n_test))
    return cur_acc / n_test, all_labels, all_label_confs, all_preds

In [7]:
# Evaluate every task (in training order) with the final checkpoint. For each test
# example, record (task-prefixed id, gold label, gold-label confidence, prediction)
# for later visualisation.
avg_acc = 0
accuracies = []
data_for_visual = []
for task_id, task in enumerate(tasks):
    print("Start testing {}...".format(task))
    test_dataset = TextClassificationDataset(task, "test", args, tokenizer)
    result = test_task(task_id, args, model, test_dataset)
    task_acc, all_labels, all_label_confs, all_preds = result

    accuracies.append(task_acc)
    # Every task contributes an equal share to the average accuracy.
    avg_acc += task_acc / len(args.tasks)

    data_ids = [task + str(idx) for idx in range(len(all_labels))]
    data_for_visual.extend(zip(data_ids, all_labels, all_label_confs, all_preds))

Start testing yelp_review_full_csv...
Total No.# of sampled: 32
test loss: 8.006 , test acc: 0.004
Start testing ag_news_csv...
Total No.# of sampled: 32
test loss: 6.987 , test acc: 0.000
Start testing dbpedia_csv...
Total No.# of sampled: 32
test loss: 1.789 , test acc: 0.550
Start testing amazon_review_full_csv...
Total No.# of sampled: 32
test loss: 7.823 , test acc: 0.003
Start testing yahoo_answers_csv...
Total No.# of sampled: 32
test loss: 0.798 , test acc: 0.747


In [8]:
# Summary: mean accuracy over tasks, then the per-task accuracies in task order.
print("Average Accuracy: {}".format(avg_acc))
print("Accuracies: {}".format(accuracies))

Average Accuracy: 0.26092105263157894
Accuracies: [0.003947368421052632, 0.0, 0.5502631578947368, 0.0030263157894736843, 0.7473684210526316]
