In [1]:
from pytorch_transformers import AdamW, WarmupLinearSchedule, WarmupConstantSchedule
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
import argparse
import copy
import logging
import numpy as np
import os, time, csv, random
import pickle
import higher
from settings import parse_test_args, model_classes, init_logging
from utils import TextClassificationDataset, dynamic_collate_fn, prepare_inputs, DynamicBatchSampler

# From Meta-MbPA Paper

>In our experiments, we find that it is sufficient to take this to the extreme such that we consider all test examples as a single cluster. Consequently, we consider the whole memory as neighbours and we randomly sample from it to be comparable with the original local adaptation formulation (i.e. same batch sizes and gradient steps). As shown in the next section, it has two benefits: (1) it is more robust to negative transfer, (2) it is faster when we evaluate testing examples as a group.

In [2]:
class dotdict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    Reading an attribute whose key is absent yields ``None`` (same as
    ``dict.get``); deleting an absent key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall back to the key.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment always stores into the mapping itself.
        self[name] = value

    def __delattr__(self, name):
        del self[name]

In [3]:
# Evaluation-time settings for the run being tested.
# Other run directories previously evaluated under /data/model_runs/em_in_lll/:
#   MetaMBPA_order1, MetaMBPA_order1_v2, MetaMBPA_order1_v3, MetaMBPA_order1_v32,
#   MetaMBPA_order2_v3, MetaMBPA_order3_v3
args = dotdict(
    output_dir="/data/model_runs/em_in_lll/MetaMBPA_order4_v3",
    adapt_lambda=1e-3,
    adapt_lr=5e-5,
    adapt_steps=30,
    no_fp16_test=False,
    seed=42,
)
output_dir = args.output_dir

# Seed all three random number generators used by the notebook with the same value.
np.random.seed(args["seed"])
random.seed(args["seed"])
torch.manual_seed(args["seed"])

In [4]:
# Restore the arguments pickled at training time (stored next to the checkpoints) and merge
# them into the evaluation settings; for keys present in both, the training value wins.
# NOTE: pickle.load can execute arbitrary code -- only point output_dir at a trusted run.
with open(os.path.join(output_dir, 'train_args'), 'rb') as train_args_file:
    # The context manager closes the file handle even if unpickling fails.
    train_args = pickle.load(train_args_file)
args.update(train_args.__dict__)
print(str(args))

# Convenience aliases used by the evaluation cells below.
model_type = args["model_type"]
model_name = args["model_name"]
n_labels = args["n_labels"]
tasks = args["tasks"]

{'output_dir': '/data/model_runs/em_in_lll/MetaMBPA_order4_v3', 'adapt_lambda': 0.001, 'adapt_lr': 5e-05, 'adapt_steps': 30, 'no_fp16_test': False, 'seed': 42, 'adam_epsilon': 1e-08, 'batch_size': 12353, 'debug': False, 'learning_rate': 3e-05, 'logging_steps': 500, 'max_grad_norm': 1.0, 'model_name': 'bert-base-uncased', 'model_type': 'bert', 'n_labels': 33, 'n_neighbors': 32, 'n_test': 7600, 'n_train': 115000, 'n_workers': 4, 'overwrite': False, 'replay_interval': 100, 'reproduce': False, 'tasks': ['ag_news_csv', 'yelp_review_full_csv', 'amazon_review_full_csv', 'yahoo_answers_csv', 'dbpedia_csv'], 'valid_ratio': 0, 'warmup_steps': 0, 'weight_decay': 0, 'inner_lr': 1e-05, 'write_prob': 0.01, 'device_id': 0}


In [5]:
def test_task(task_id, args, model, test_dataset):
    """Adapt `model` on one batch retrieved from episodic memory, then evaluate on `test_dataset`.

    The adaptation is done once for the whole test set: a single batch is drawn from
    memory, `args.adapt_steps` SGD steps are taken on a functional copy of the model
    (with an L2 penalty, weighted by `args.adapt_lambda`, pulling the weights towards
    the original ones), and the adapted copy is used to score every test batch.

    Relies on the module-level name `memory` (the episodic memory loaded for the
    checkpoint under test); it must be defined before this function is called.

    Args:
        task_id: Index of the task being tested (currently unused).
        args: Settings object; reads `no_fp16_test`, `weight_decay`, `learning_rate`,
            `adapt_steps`, `adapt_lambda`, `n_workers`, `batch_size`, `logging_steps`.
        model: Classification model to adapt and evaluate. Unless `args.no_fp16_test`
            is set it is cast to half precision, which modifies the module in place.
        test_dataset: Dataset to evaluate on.

    Returns:
        Tuple `(accuracy, all_labels, all_label_confs, all_preds)`: accuracy over the
        whole dataset, and per-example lists of gold labels, softmax probability
        assigned to the gold label, and predicted labels.
    """
    if not args.no_fp16_test:
        model = model.half()

    def update_metrics(loss, preds, labels, cur_loss, cur_acc):
        """Add one batch's summed loss and number of correct predictions to the totals."""
        n_correct = np.sum(preds == labels.detach().cpu().numpy())
        return cur_loss + loss, cur_acc + n_correct

    # Before anything else, just sample randomly from memory!
    # Use this as sample going forward!
    s_input_ids, s_masks, s_labels = memory.sample(32)
    print(f"Total No.# of sampled: {len(s_labels)}")
    # query like this first just like training... this will need to be removed later!!! so we can adapt 32x32
    with torch.no_grad():
        q_input_ids, q_masks, q_labels = memory.query(s_input_ids, s_masks)
        # Only the first retrieved batch is used for adaptation.
        # Note: Need to find way to improve later
        q_input_ids = q_input_ids[0].cuda()
        q_masks     = q_masks[0].cuda()
        q_labels    = q_labels[0].cuda()

    # Meta-Learning Local Adaptation
    # Bias and LayerNorm weights are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    # NOTE(review): the notebook configures `adapt_lr` but it is never read; the inner
    # optimizer uses `args.learning_rate`. Confirm which learning rate was intended.
    inner_optimizer = optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate)

    # Flattened snapshot of the unadapted weights, used as the anchor of the L2 penalty.
    with torch.no_grad():
        org_params = torch.cat([torch.reshape(param, [-1]) for param in model.parameters()], 0)

    with higher.innerloop_ctx(model, inner_optimizer, copy_initial_weights=False,track_higher_grads=False) as (fmodel, diffopt):
        # 1 Inner Loop (Support Set) - Once for all testing samples
        for step in range(args.adapt_steps):
            params = torch.cat([torch.reshape(param, [-1]) for param in fmodel.parameters()], 0)
            task_loss = fmodel(input_ids=q_input_ids, attention_mask=q_masks, labels=q_labels)[0]
            loss = task_loss + args.adapt_lambda * torch.sum((org_params - params)**2)
            diffopt.step(loss)
            fmodel.zero_grad() # Is this necessary? but local adapt have this!

        # 2 Outer Loop (Query Set)
        tot_n_inputs = 0
        cur_loss, cur_acc = 0, 0
        all_labels, all_label_confs, all_preds = [], [], []

        test_dataloader = DataLoader(test_dataset, num_workers=args.n_workers, collate_fn=dynamic_collate_fn,
                                     batch_sampler=DynamicBatchSampler(test_dataset, args.batch_size * 4))

        # Switching to eval mode is loop-invariant, so do it once before scoring.
        fmodel.eval()

        for step, batch in enumerate(test_dataloader):
            n_inputs, input_ids, masks, labels = prepare_inputs(batch)
            tot_n_inputs += n_inputs

            with torch.no_grad():
                output = fmodel(input_ids=input_ids, attention_mask=masks, labels=labels)[:2]
                loss = output[0].item()
                logits = output[1].detach().cpu().numpy()
                softmax = F.softmax(output[1], -1)

            # softmax is [BATCH, CLASSES]; pick each example's probability for its gold label.
            label_conf = softmax[np.arange(len(softmax)), labels]
            preds = np.argmax(logits, axis=1)

            # The model's loss is a batch mean, so weight it by the batch size before summing.
            cur_loss, cur_acc = update_metrics(loss*n_inputs, preds, labels, cur_loss, cur_acc)

            # Append all!
            all_labels.extend(labels.tolist())
            all_label_confs.extend(label_conf.tolist())
            all_preds.extend(preds.tolist())

            if (step+1) % args.logging_steps == 0:
                print("Tested {}/{} examples, test loss: {:.3f} , test acc: {:.3f}".format(
                    tot_n_inputs, len(test_dataset), cur_loss/tot_n_inputs, cur_acc/tot_n_inputs))


    print("test loss: {:.3f} , test acc: {:.3f}".format(
        cur_loss / len(test_dataset), cur_acc / len(test_dataset)))
    return cur_acc / len(test_dataset), all_labels, all_label_confs, all_preds

In [6]:
tic_RUN = time.time()

all_avg_acc = []  # one average accuracy per checkpoint
all_acc = []      # one list of per-task accuracies per checkpoint

# The model classes, tokenizer and config do not depend on the checkpoint index,
# so they are resolved once instead of on every iteration.
config_class, model_class, args["tokenizer_class"] = model_classes[model_type]
tokenizer = args["tokenizer_class"].from_pretrained(model_name)
model_config = config_class.from_pretrained(model_name, num_labels=n_labels, hidden_dropout_prob=0, attention_probs_dropout_prob=0)

# Evaluate the checkpoint saved after each sequential task on every task's test set.
for run_task in range(len(tasks)):
    save_model_path = os.path.join(output_dir, f'checkpoint-{run_task}')
    model = model_class.from_pretrained(save_model_path, config=model_config).cuda()
    # `memory` must stay a module-level name: test_task reads it as a global.
    # NOTE: pickle.load can execute arbitrary code -- only load files from a trusted run.
    with open(os.path.join(output_dir, f'memory-{run_task}'), 'rb') as memory_file:
        memory = pickle.load(memory_file)

    avg_acc = 0
    accuracies = []
    data_for_visual = []
    for task_id, task in enumerate(tasks):
        print("Start testing {}...".format(task))
        test_dataset = TextClassificationDataset(task, "test", args, tokenizer)
        task_acc, all_labels, all_label_confs, all_preds = test_task(task_id, args, model, test_dataset)

        # Per-example records (id, gold label, gold-label confidence, prediction).
        data_ids = [task + str(i) for i in range(len(all_labels))]
        data_for_visual.extend(list(zip(data_ids, all_labels, all_label_confs, all_preds)))
        accuracies.append(task_acc)

        avg_acc += task_acc / len(args.tasks)
    print(f"Average Accuracy: {avg_acc}")
    print(f"Accuracies: {accuracies}")

    all_avg_acc.append(avg_acc)
    all_acc.append(accuracies)

toc_RUN = time.time() - tic_RUN
print(f"Time run in {toc_RUN} seconds")
print(all_avg_acc)

Start testing ag_news_csv...
Total No.# of sampled: 32
test loss: 0.194 , test acc: 0.938
Start testing yelp_review_full_csv...
Total No.# of sampled: 32
test loss: 8.524 , test acc: 0.000
Start testing amazon_review_full_csv...
Total No.# of sampled: 32
test loss: 8.370 , test acc: 0.000
Start testing yahoo_answers_csv...
Total No.# of sampled: 32
test loss: 8.972 , test acc: 0.000
Start testing dbpedia_csv...
Total No.# of sampled: 32
test loss: 6.947 , test acc: 0.000
Average Accuracy: 0.18760526315789475
Accuracies: [0.9380263157894737, 0.0, 0.0, 0.0, 0.0]
Start testing ag_news_csv...
Total No.# of sampled: 32
test loss: 0.324 , test acc: 0.908
Start testing yelp_review_full_csv...
Total No.# of sampled: 32
test loss: 0.945 , test acc: 0.602
Start testing amazon_review_full_csv...
Total No.# of sampled: 32
test loss: 1.140 , test acc: 0.542
Start testing yahoo_answers_csv...
Total No.# of sampled: 32
test loss: 8.098 , test acc: 0.000
Start testing dbpedia_csv...
Total No.# of samp

In [7]:
# One tab-separated line per checkpoint, one column per task (easy to paste into a spreadsheet).
for acc_row in all_acc:
    print(*acc_row, sep="\t")

0.9380263157894737	0.0	0.0	0.0	0.0
0.9084210526315789	0.6015789473684211	0.5422368421052631	0.0	0.0
0.8928947368421053	0.5911842105263158	0.6123684210526316	0.0	0.0
0.8798684210526316	0.5481578947368421	0.5101315789473684	0.728421052631579	0.0
0.8798684210526316	0.5563157894736842	0.5260526315789473	0.6967105263157894	0.9923684210526316
