In [1]:
# https://stackoverflow.com/questions/30656777/how-to-call-module-written-with-argparse-in-ipython-notebook
# REMOVE THIS ONCE YOU PUT IT IN A MODULE!
import sys
sys.argv = ['']
# END

import os, time
from pytorch_transformers import AdamW
from fp16 import FP16_Module, FP16_Optimizer
import json, csv
import torch
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from collections import OrderedDict
import logging
import warnings
from metrics import compute_metrics

from knight_utils import (
    FILL_VAL, MODEL_BASE_DIR, TASK_DICT , QADataset, DEVICE, 
    args, MODEL_CONFIG, MODEL_CLASS, TOKENIZER, CONFIG_CLASS, SPECIAL_TOKEN_IDS, SPECIAL_TOKENS, TOKENS_WEIGHT, 
    create_dataloader, get_gen_token, CONFIG_NAME
)

In [2]:
# ---------------------------------------------------------------------------
# Run configuration. Several names here (MODEL_BASE_DIR, DEVICE, FILL_VAL)
# deliberately override the values imported from knight_utils in this notebook.
# ---------------------------------------------------------------------------
data_dir = "/root/LAMOL/lamol_data"
MODEL_BASE_DIR = "/data/model_runs"

# Previously evaluated runs, kept for reference (only the last assignment
# below was ever effective):
# MODEL_DIR_NAME = "20210630T184801_mbs_SEQ_MAML"
# MODEL_DIR_NAME = "20210701T033914_mbs_SEQ"
# MODEL_DIR_NAME = "20210701T073948_msb_SEQ_MAML"
# MODEL_DIR_NAME = "20210702T034545_msb_SEQ"
# MODEL_DIR_NAME = "20210701T180911_sbm_SEQ_MAML"
# MODEL_DIR_NAME = "20210824T030642_mbs_SEQ_MAML" # Special of movie_steps101.model. Save special in between
# MODEL_DIR_NAME = "20210824T035319_mbs_SEQ_MAML_savenetpibetween" # special of movie_steps101.model
# MODEL_DIR_NAME = "20210824T053734_mbs_SEQ_MAML_loadstatedict" # special load state dict
MODEL_DIR_NAME = "20210829T003703_mbs_SEQ_MAML_v2" # special load state dict

MODEL_DIR = os.path.join(MODEL_BASE_DIR,MODEL_DIR_NAME)

# Task order must match the order the checkpoints were trained in.
# tasks = ['movie',  'scifact', 'boolq']
tasks = ['movie', 'boolq',  'scifact']

DEVICE = 'cuda:0'
temperature_qa = 1.0
n_train_epochs = 1
num_updates = 5          # gradient steps per adaptation phase

train_batch_size = 3
test_batch_size = 1

gen_lm_sample_percentage = 0.05

# Sampling settings used by logits_to_tokens.
top_k_qa = 20
top_p_qa = 0.

FILL_VAL = -1            # label value ignored by the loss
n_gpus = 1
device_ids = [1]

# Hard cap on the total length of a generated sequence. It is read by
# sample_sequence and test_one_to_one but was never defined in this notebook,
# which raised NameError on the first generation step.
# NOTE(review): assumes MODEL_CONFIG is a GPT-style config exposing
# `n_positions` (the model's maximum context length) - TODO confirm.
max_len = MODEL_CONFIG.n_positions

logging.basicConfig(filename=f'{MODEL_DIR}/test_run.log', level=logging.INFO)
logger = logging.getLogger(__name__)

In [3]:
def remove_id(idx, need_process, all_pasts):
    """Mark sequence ``idx`` as finished.

    Deletes ``idx`` from ``need_process`` (it must be present) and overwrites
    its cached entry in every layer of ``all_pasts`` with ``0``.
    """
    assert idx in need_process
    need_process.pop(idx)
    n_layers = MODEL_CONFIG.n_layer
    for layer in range(n_layers):
        # Replace the stored tensor with a placeholder so it can be freed.
        all_pasts[layer][idx] = 0
        
        
def sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens):
    """Generate tokens for every sequence in ``need_process`` until each one finishes.

    A sequence finishes when the sampled token equals
    ``SPECIAL_TOKEN_IDS["eos_token"]`` or when its length reaches either
    ``max_tot_lens[id]`` or the module-level ``max_len``. Finished ids are
    removed from ``need_process`` via ``remove_id``.

    Args:
        model: callable invoked as ``model(input_ids=..., past=...)``; element
            ``[0]`` of its result is used as logits and ``[1]`` as the
            per-layer cache.
        need_process: ordered mapping whose keys are the ids still being
            generated. Mutated in place.
        qa_results: indexable by id; each entry is a 1-D token tensor that is
            extended in place (by re-assignment) with the sampled tokens.
        all_pasts: ``all_pasts[layer][id]`` holds the cached state for that
            sequence. Updated in place.
        max_tot_lens: indexable by id; per-sequence length limit.

    Returns:
        None. All results are delivered through the mutated arguments.

    Reads the module-level names ``test_batch_size``, ``temperature_qa``,
    ``max_len``, ``MODEL_CONFIG`` and ``SPECIAL_TOKEN_IDS``.
    """
    while len(need_process) > 0:
        first_id = next(iter(need_process))
        # Only sequences with this exact length are batched together in one pass.
        shortest_len = len(qa_results[first_id])
        decode_batch_size = test_batch_size
        it = iter(need_process)
        stop = False
        remove_ids = []
        while not stop:
            batch_ids, input_ids, past = [], [], [[] for _ in range(MODEL_CONFIG.n_layer)]
            for cur_id in it:
                if len(qa_results[cur_id]) > shortest_len:
                    # A longer sequence ends this pass; it is picked up again
                    # when the outer loop restarts the iteration.
                    stop = True
                    break
                batch_ids.append(cur_id)
                input_ids.append(qa_results[cur_id][-1:])
                for layer_id in range(MODEL_CONFIG.n_layer):
                    past[layer_id].append(all_pasts[layer_id][cur_id])
                if len(input_ids) == decode_batch_size:
                    break
            else:
                # Iterator exhausted: nothing left to batch in this pass.
                stop = True

            n_inputs = len(input_ids)
            if n_inputs == 0:
                break
            input_ids = torch.stack(input_ids)
            for layer_id in range(MODEL_CONFIG.n_layer):
                past[layer_id] = torch.stack(past[layer_id], dim=1)

            # Pure inference: without no_grad every step's autograd graph would
            # be kept alive through the tensors cached in all_pasts.
            with torch.no_grad():
                all_outputs = model(input_ids=input_ids.cuda(), past=past)

                outputs = all_outputs[0]
                pasts = all_outputs[1]

                next_logits = outputs[..., -1, :] / temperature_qa
                next_tokens = logits_to_tokens(next_logits).cpu()

            for i, cur_id in enumerate(batch_ids):
                if next_tokens[i] == SPECIAL_TOKEN_IDS["eos_token"]:
                    remove_ids.append(cur_id)
                else:
                    qa_results[cur_id] = torch.cat((qa_results[cur_id], next_tokens[i]))
                    if len(qa_results[cur_id]) in [max_tot_lens[cur_id], max_len]:
                        remove_ids.append(cur_id)
                    else:
                        for layer_id in range(MODEL_CONFIG.n_layer):
                            all_pasts[layer_id][cur_id] = pasts[layer_id][:, i].type(torch.half)
        # Removal is deferred until the pass is over so need_process is never
        # mutated while it is being iterated.
        for idx in remove_ids:
            remove_id(idx, need_process, all_pasts)

def get_gen_token(task):
    """Return the generation token for ``task``: the task name wrapped in double underscores.

    Note: this definition replaces the ``get_gen_token`` imported from
    ``knight_utils`` at the top of the notebook.
    """
    return '__'.join(['', task, ''])

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

        The filtering works along the last dimension, so ``logits`` may be a
        single distribution ``(vocab,)`` or a batch ``(..., vocab)``.

        Args:
            logits: logits distribution, vocabulary on the last dimension.
                Modified IN PLACE; the same tensor is returned.
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the smallest set of top tokens whose cumulative
                probability reaches top_p (nucleus filtering). The token that
                crosses the threshold is kept, so at least one token survives.
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
            filter_value: value written over every removed logit.
        Returns:
            The filtered ``logits`` tensor (same object as the input).
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p > 0.0:
        # Previously `top_p` was accepted but silently ignored.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        # Softmax in float32 so half-precision logits accumulate accurately.
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits.float(), dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        # Write the (partly masked) sorted values back to their original slots.
        logits.scatter_(-1, sorted_indices, sorted_logits.masked_fill(sorted_remove, filter_value))
    return logits

def logits_to_tokens(next_logits):
    """Sample one token id per row from ``next_logits``.

    The logits are first filtered with ``top_k_top_p_filtering`` using the
    module-level ``top_k_qa`` / ``top_p_qa`` settings, turned into
    probabilities with a softmax over the last dimension, and then one index
    is drawn per row with ``torch.multinomial``.
    """
    probs = F.softmax(
        top_k_top_p_filtering(next_logits, top_k=top_k_qa, top_p=top_p_qa),
        dim=-1,
    )
    return torch.multinomial(probs, num_samples=1)

In [4]:
def get_test_score(task_eval,qa_results,score_dict):
    """Score ``qa_results`` for ``task_eval`` and store the result in ``score_dict``.

    The optional metric switches passed to ``compute_metrics`` are selected by
    substring matches on the task name. ``score_dict`` is updated in place
    under the key ``task_eval``; nothing is returned.
    """
    metric_switches = {
        "bleu": 'iwslt.en.de' in task_eval or 'multinli.in.out' in task_eval,
        "dialogue": 'woz.en' in task_eval,
        "rouge": 'cnn_dailymail' in task_eval,
        "logical_form": 'wikisql' in task_eval,
        "corpus_f1": 'zre' in task_eval,
    }
    score_dict[task_eval] = compute_metrics(qa_results, **metric_switches)


In [5]:
def _adapt_on_support(model, state_dict, support_x, support_y, loss_fct):
    """Reset ``model`` to ``state_dict`` and take ``num_updates`` gradient steps on one support batch."""
    model.load_state_dict(state_dict)
    model.train()

    max_grad_norm = 1
    param_optimizer = list(model.named_parameters())
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    # A fresh optimizer per adaptation so no optimizer state leaks between examples.
    optimizer = AdamW(optimizer_grouped_parameters, lr=6.25e-5, eps=1e-4)
    optimizer = FP16_Optimizer(optimizer, static_loss_scale=None, dynamic_loss_scale=True,
                               dynamic_loss_args={'scale_window': 100, 'min_scale': 1, 'delayed_shift': 2})

    for _ in range(num_updates):
        # The model returns a tuple; element 0 holds the logits.
        qa_logits = model(support_x)[0]
        qa_loss = loss_fct(qa_logits.transpose(1,2), support_y)

        logger.info(f"[DEBUG] Adaptation loss: {qa_loss.item()}")
        optimizer.backward(qa_loss, update_master_grads=False) # instead of loss.backward() for fp16
        optimizer.update_master_grads()
        optimizer.clip_master_grads(max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()


def test_one_to_one(task_load, task_eval, model, score_dict):
    """Evaluate the checkpoint of ``task_load`` on the test set of ``task_eval``.

    For every test batch the model is reset to the weights stored at
    ``model.model_path``, adapted on one batch of ``task_eval`` training data
    (``num_updates`` steps), and then used to generate an answer for the test
    question. Predictions and references are scored with ``get_test_score``
    and written to ``qa_<task_eval>.csv`` inside ``model.model_dir``.

    Args:
        task_load: task whose generation token id (``SPECIAL_TOKEN_IDS[task_load]``)
            is passed to the datasets.
        task_eval: task whose train/test files (``TASK_DICT[task_eval]``) are used.
        model: model wrapper carrying ``model_path`` and ``model_dir`` attributes.
        score_dict: dict updated in place with the score for ``task_eval``.

    Returns:
        ``(model, score_dict)``.

    Reads module-level configuration, including ``max_len`` (sequence-length
    cap), ``num_updates``, ``temperature_qa``, ``FILL_VAL`` and ``DEVICE``.
    """
    tic_TASK = time.time()
    logger.info("start to test { task: %s (load) %s (eval)}" % (task_load, task_eval))
    print("start to test { task: %s (load) %s (eval)}" % (task_load, task_eval))

    # Support set = training split, query set = test split of the evaluated task.
    support_qadata = QADataset(TASK_DICT[task_eval]["train"], "train", SPECIAL_TOKEN_IDS[task_load])
    test_qadata = QADataset(TASK_DICT[task_eval]["test"] , "test", SPECIAL_TOKEN_IDS[task_load]).sort()

    max_a_len = test_qadata.max_a_len
    n_examples = len(test_qadata)
    logger.info("len of test dataset: {}".format(n_examples))
    print("len of test dataset: {}".format(n_examples))

    iter_support_dataloader = iter(create_dataloader(support_qadata, "train"))
    iter_test_dataloader = iter(create_dataloader(test_qadata, "test"))

    need_process = OrderedDict()
    # qa_results[cnt]: token tensor during generation, then [prediction, references].
    qa_results = [0 for _ in range(n_examples)]
    # all_pasts[layer_id][cnt]: cached model state for example cnt.
    all_pasts = [[0 for _ in range(n_examples)] for __ in range(MODEL_CONFIG.n_layer)]
    # max_tot_lens[cnt]: per-example length limit.
    max_tot_lens = [0 for _ in range(n_examples)]

    # Loop-invariant work: read the checkpoint from disk once. load_state_dict
    # copies values into the model, so the loaded tensors stay pristine and can
    # be reused for every reset.
    state_dict = torch.load(model.model_path, map_location='cuda:0')
    train_loss_fct = CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT)

    cnt = 0
    while True:
        # 1. One support batch and one query batch per step; stop as soon as
        #    either loader is exhausted.
        try:
            _, _, support_x, _, support_y, _, _ = next(iter_support_dataloader)
            query_x, query_x_len, query_x_cqa, _, query_y, _, _ = next(iter_test_dataloader)
        except StopIteration:
            break

        n_inputs = sum(_cqa.shape[0] for _cqa in query_x)

        # Single-GPU run: only the first per-device split is used.
        support_x = support_x[0].to(DEVICE)
        support_y = support_y[0].to(DEVICE)
        query_x_len = query_x_len[0]
        query_y = query_y[0]
        query_x_cqa = query_x_cqa[0].to(DEVICE)
        query_x = query_x[0].to(DEVICE)

        # 2./3. Reset to the stored weights and adapt on the support batch.
        _adapt_on_support(model, state_dict, support_x, support_y, train_loss_fct)

        model.eval()

        # 4. Evaluation only: no autograd graph is needed, and without no_grad
        #    the cached pasts would keep the whole graph alive.
        with torch.no_grad():
            all_outputs = model(query_x)
            outputs = all_outputs[0]
            pasts = all_outputs[1]
            next_logits = outputs[range(n_inputs), query_x_len-1, :] / temperature_qa
            next_tokens = logits_to_tokens(next_logits).cpu()

            # Loss on the full question+answer sequence, logged for comparison only.
            qa_logits = model(query_x_cqa)[0]
            qa_loss = train_loss_fct(qa_logits.transpose(1,2), query_y.to(DEVICE))
        logger.info(f"[DEBUG] QUERY LOSS: {qa_loss.item()}")

        query_x_cpu = query_x.cpu()
        for batch_i in range(n_inputs):
            # Length of this example's prompt; used to slice off the prompt
            # later (the original sliced with the whole length tensor, which
            # only works for batch size 1).
            cq_len = int(query_x_len[batch_i])

            max_tot_lens[cnt] = max_a_len + test_qadata[cnt][1]
            qa_results[cnt] = query_x_cpu[batch_i][:cq_len]

            if next_tokens[batch_i] != SPECIAL_TOKEN_IDS["eos_token"]:
                qa_results[cnt] = torch.cat((qa_results[cnt], next_tokens[batch_i]))
                if len(qa_results[cnt]) not in [max_tot_lens[cnt], max_len]:
                    need_process[cnt] = None
                    for layer_id in range(MODEL_CONFIG.n_layer):
                        all_pasts[layer_id][cnt] = pasts[layer_id][:, batch_i, ..., :cq_len, :].type(torch.half)

            # Generate the remainder of the answer for everything pending.
            with torch.no_grad():
                sample_sequence(model, need_process, qa_results, all_pasts, max_tot_lens)

            predicted_ids = qa_results[cnt].tolist()
            logger.info(f"[ERROR_ANALYSIS] {task_eval} {cnt}/{n_examples} Predicted Answer {TOKENIZER.decode(predicted_ids)}")
            logger.info(f"[ERROR_ANALYSIS] {task_eval} {cnt}/{n_examples} Predicted Tokens {predicted_ids[cq_len:]}")

            # Reference answer: drop ignored label positions and the trailing eos.
            Y = query_y[batch_i].tolist()
            Y = [tok for tok in Y if tok != FILL_VAL][:-1]
            logger.info(f"[ERROR_ANALYSIS] {task_eval} {cnt}/{n_examples} Actual Tokens {Y}")
            # Multiple references are separated by the pad token id.
            Y = ' '.join([str(y) for y in Y]).split(str(SPECIAL_TOKEN_IDS["pad_token"]))
            Y = [TOKENIZER.decode(list(map(int, y.split()))) for y in Y]
            # From here on qa_results[cnt] is [decoded prediction, decoded references].
            qa_results[cnt] = [TOKENIZER.decode(predicted_ids[cq_len:]), Y]
            print(f"Predict vs Actual {cnt}/{n_examples}", qa_results[cnt])
            logger.info(f"[ERROR_ANALYSIS] {task_eval} {cnt}/{n_examples} Actual Answer {Y}")
            logger.info(f"[ERROR_ANALYSIS] {task_eval} {cnt}/{n_examples} Predict vs Actual {qa_results[cnt]}")

            cnt += 1

    toc_TASK = time.time() - tic_TASK
    logger.info(f'[TIME] TASK {(task_load, task_eval)} {toc_TASK}')

    if cnt < n_examples:
        # A loader ran dry before all test examples were seen. The untouched
        # placeholder entries would break scoring and the CSV export, so only
        # the evaluated prefix is kept.
        logger.warning(f"Only {cnt}/{n_examples} test examples of {task_eval} were evaluated")
        qa_results = qa_results[:cnt]

    get_test_score(task_eval, qa_results, score_dict)
    print(score_dict)

    model_dir = model.model_dir
    results_path = os.path.join(model_dir,f"qa_{task_eval}.csv")
    # newline="" is what the csv module expects from the file object.
    with open(results_path, "w", encoding="utf-8", newline="") as f:
        qa_writer = csv.writer(f,delimiter=',')
        qa_writer.writerow(["y","pred"])
        for pred, y in qa_results:
            qa_writer.writerow([y,pred])

    return model, score_dict

In [6]:
%%time

for task in tasks:
    
    model_path = os.path.join(MODEL_DIR, f"{task}.model")
    config_path = os.path.join(MODEL_DIR,CONFIG_NAME)

    gen_token = get_gen_token(task)
    TOKENIZER.add_tokens([gen_token])
    SPECIAL_TOKENS[task] = gen_token
    SPECIAL_TOKEN_IDS[task] = TOKENIZER.convert_tokens_to_ids(gen_token
    model = MODEL_CLASS(MODEL_CONFIG).cuda()
    # Don't load state dict here, load for every adaptation phase!
    
#     print(model)
    print(model_path)
    
    global TOKENS_WEIGHT
    if len(TOKENIZER) != TOKENS_WEIGHT.shape[0]:
        TOKENS_WEIGHT = torch.cat((TOKENS_WEIGHT, torch.ones([1]).cuda()))
    
    model.resize_token_embeddings(len(TOKENIZER))
    model = FP16_Module(model)
    
    
    model.model_dir = MODEL_DIR
    model.model_path = model_path
    # Try Loading the state dict like this!
    model.state_dict = torch.load(model.model_path, map_location='cuda:0')
    logger.info(f"task: {task}")
    score_dict = {k:None for k in tasks}
    
    for task_eval in tasks:
        test_one_to_one(task, task_eval, model, score_dict)
    logger.info("score: {}".format(score_dict))

    with open(os.path.join(MODEL_DIR, f"metrics-{task}.json"),"w") as f:
        json.dump(score_dict, f)

NameError: name 'model_config' is not defined