In [29]:
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, CONFIG_NAME 
from pytorch_transformers import AdamW
from fp16 import FP16_Module, FP16_Optimizer
from torch import nn
import os
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import json
from multiprocessing import Pool
from parallel import DataParallelModel, DataParallelCriterion
from torch.nn import CrossEntropyLoss
from scheduler import AnnealingLR
import warnings

import GPUtil

In [2]:
FILL_VAL = -1

# MEMORY_FACTOR = {
#     "finetune": 0.58,
#     "multitask": 0.58,
#     "lll": 0.35,
#     "ewc": 0.30,
#     "mas": 0.18,
#     "gem": 0.50,
# }
n_gpus = 1

#DRIVER MISMATCH??
# device_ids = GPUtil.getAvailable(maxLoad=0.1, maxMemory=0.05, limit=n_gpus)
# gpus = GPUtil.getGPUs()
# gpu_names = [gpus[device_id].name for device_id in device_ids]

# memory_sizes = [gpus[device_id].memoryTotal for device_id in device_ids]
# memory_sizes[0] = args.memory_sizes[0] * (1 - 0.04 * (n_gpus-1))
# for i in range(1, n_gpus):
#     memory_sizes[i] = args.memory_sizes[i] * 1.04

# train_batch_size = [int(memory_size * MEMORY_FACTOR['lll']) for memory_size in memory_sizes]
# test_batch_size = [int(memory_size * MEMORY_FACTOR['lll']) for memory_size in memory_sizes]

train_batch_size = 4
test_batch_size = 4


In [3]:
# n_gpus = 1
# device_ids = [0]
# gpus = GPUtil.getGPUs()
# gpu_names = [gpus[device_id].name for device_id in device_ids]
# memory_sizes = [gpus[device_id].memoryTotal for device_id in device_ids]
# memory_sizes[0] = memory_sizes[0] * (1 - 0.04 * (n_gpus-1))
# for i in range(1, n_gpus):
#     memory_sizes[i] = args.memory_sizes[i] * 1.04
    
# train_batch_size = [int(memory_size * MEMORY_FACTOR['lll']) for memory_size in memory_sizes]
# test_batch_size = [int(memory_size * MEMORY_FACTOR['lll']) for memory_size in memory_sizes]
# print(memory_sizes)
# print(train_batch_size)
# print(test_batch_size)

In [4]:
MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config),
}

In [5]:
# In settings.py
special_tokens = {"ans_token":'__ans__', "pad_token":'__pad__', "unk_token":'__unk__', "eos_token": '<|endoftext|>'}

model_class, tokenizer_class, config_class = MODEL_CLASSES['gpt2']
tokenizer = tokenizer_class.from_pretrained('gpt2')
tokenizer.add_tokens(list(special_tokens.values()))
special_token_ids = {k:tokenizer.convert_tokens_to_ids(v) for k,v in special_tokens.items()}


model_config = config_class.from_pretrained('gpt2')
model_config.vocab_size = len(tokenizer)

tokens_weight = torch.ones([model_config.vocab_size], dtype=torch.float).cuda()
tokens_weight[special_token_ids["ans_token"]] = 5


MODEL_CLASS = model_class
TOKENIZER = tokenizer
SPECIAL_TOKENS = special_tokens
SPECIAL_TOKEN_IDS = special_token_ids
TOKENS_WEIGHT = tokens_weight

model_config = config_class.from_pretrained('gpt2')
max_len = model_config.n_positions

In [6]:
max_len

1024

In [7]:
# In Train.py
model = MODEL_CLASS.from_pretrained('gpt2').cuda()
model.resize_token_embeddings(len(TOKENIZER))
model = FP16_Module(model)
model

FP16_Module(
  (module): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50260, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): FusedLayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): FusedLayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): Block(
          (ln_1): FusedLayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dr

In [7]:
class WrapModel(torch.nn.Module):
    def __init__(self, model):
        super(WrapModel, self).__init__()
        self.model = model

    def forward(self, input_ids):
        outputs = self.model(input_ids)
        return outputs[0]

In [27]:
model = WrapModel(model)

In [8]:
parallel_model = DataParallelModel(WrapModel(model), [0])

In [None]:
class MetaLearner(nn.Module):
    """
    As we have mentioned in Learner class, the metalearner class will receive a series of loss on different tasks/episodes
    on theta_pi network, and it will merage all loss and then sum over it. The summed loss will be backproped on theta
    network to update theta parameters, which is the initialization point we want to find.
    """
    
    def __init__(self, net_cls, net_cls_args, n_way, k_shot, meta_batchsz, beta, num_updates):
        """
        :param net_cls: class, not instance. the class of specific Network for learner
        :param net_cls_args: tuple, args for net_cls, like (n_way, imgsz)
        :param n_way:
        :param k_shot:
        :param meta_batchsz: number of tasks/episode
        :param beta: learning rate for meta-learner
        :param num_updates: number of updates for learner
        """
        super(MetaLearner, self).__init__()

        self.n_way = n_way
        self.k_shot = k_shot
        self.meta_batchsz = meta_batchsz
        self.beta = beta
        # self.alpha = alpha # set alpha in Learner.optimizer directly.
        self.num_updates = num_updates

        # it will contains a learner class to learn on episodes and gather the loss together.
        self.learner = Learner(net_cls, *net_cls_args)
        # the optimizer is to update theta parameters, not theta_pi parameters.
        self.optimizer = optim.Adam(self.learner.parameters(), lr=beta)
        
    def write_grads(self, dummy_loss, sum_grads_pi):
        """
        write loss into learner.net, gradients come from sum_grads_pi.
        Since the gradients info is not calculated by general backward, we need this function to write the right gradients
        into theta network and update theta parameters as wished.
        :param dummy_loss: dummy loss, nothing but to write our gradients by hook
        :param sum_grads_pi: the summed gradients
        :return:
        """

        # Register a hook on each parameter in the net that replaces the current dummy grad
        # with our grads accumulated across the meta-batch
        hooks = []

        for i, v in enumerate(self.learner.parameters()):
            def closure():
                ii = i
                return lambda grad: sum_grads_pi[ii]

            # if you write: hooks.append( v.register_hook(lambda grad : sum_grads_pi[i]) )
            # it will pop an ERROR, i don't know why?
            hooks.append(v.register_hook(closure()))

        # use our sumed gradients_pi to update the theta/net network,
        # since our optimizer receive the self.net.parameters() only.
        self.optimizer.zero_grad()
        dummy_loss.backward()
        self.optimizer.step()

        # if you do NOT remove the hook, the GPU memory will expode!!!
        for h in hooks:
            h.remove()
            
    def forward(self, support_x, support_y, query_x, query_y):
        """
        Here we receive a series of episode, each episode will be learned by learner and get a loss on parameters theta.
        we gather the loss and sum all the loss and then update theta network.
        setsz = n_way * k_shotf
        querysz = n_way * k_shot
        :param support_x: [meta_batchsz, setsz, c_, h, w]
        :param support_y: [meta_batchsz, setsz]
        :param query_x:   [meta_batchsz, querysz, c_, h, w]
        :param query_y:   [meta_batchsz, querysz]
        :return:
        """
        sum_grads_pi = None
        meta_batchsz = support_y.size(0)

        # support_x[i]: [setsz, c_, h, w]
        # we do different learning task sequentially, not parallel.
        accs = []
        # for each task/episode.
        for i in range(meta_batchsz):
            _, grad_pi, episode_acc = self.learner(support_x[i], support_y[i], query_x[i], query_y[i], self.num_updates)
            accs.append(episode_acc)
            if sum_grads_pi is None:
                sum_grads_pi = grad_pi
            else:  # accumulate all gradients from different episode learner
                sum_grads_pi = [torch.add(i, j) for i, j in zip(sum_grads_pi, grad_pi)]

        # As we already have the grads to update
        # We use a dummy forward / backward pass to get the correct grads into self.net
        # the right grads will be updated by hook, ignoring backward.
        # use hook mechnism to write sumed gradient into network.
        # we need to update the theta/net network, we need a op from net network, so we call self.learner.net_forward
        # to get the op from net network, since the loss from self.learner.forward will return loss from net_pi network.
        dummy_loss, _ = self.learner.net_forward(support_x[0], support_y[0])
        self.write_grads(dummy_loss, sum_grads_pi)

        return accs
    
    def pred(self, support_x, support_y, query_x, query_y):
        """
        predict for query_x
        :param support_x:
        :param support_y:
        :param query_x:
        :param query_y:
        :return:
        """
        meta_batchsz = support_y.size(0)

        accs = []
        # for each task/episode.
        # the learner will copy parameters from current theta network and then fine-tune on support set.
        for i in range(meta_batchsz):
            _, _, episode_acc = self.learner(support_x[i], support_y[i], query_x[i], query_y[i], self.num_updates)
            accs.append(episode_acc)

        return np.array(accs).mean()

In [None]:
meta = MetaLearner(Naive, (n_way, imgsz), n_way=n_way, k_shot=k_shot, meta_batchsz=meta_batchsz, beta=meta_lr,
                   num_updates=num_updates).cuda()

In [None]:
learner = Learner(net_cls, *net_cls_args)

In [11]:
# pls make sure net_cls is a class but NOT an instance of class.
print(model.__class__)
assert model.__class__ == type

<class 'fp16.FP16_Module'>


AssertionError: 

In [18]:
class Learner(nn.Module):
    """
    Learner class will be responsible for ADAPT on meta-train set (Support) and return loss on meta-test set (Query)
    (1) Create 2 same network, theta and theta_pi.
    (2) For each episode, theta_pi copy initial parameters from theta and ADAPT several steps by meta-train set (Support)
    (3) and then calculate its loss on meta-test set (Query). 
    (4) All loss on meta-test will be summed and backprop on theta network (done on Metalearner class)
    """
    def __init__(self, net_cls, token_size):
        """
        It will receive a class: net_cls and its parameters: args for net_cls.
        :param net_cls: class, not instance
        :param args: the parameters for net_cls
        """
        super(Learner, self).__init__()
        # pls make sure net_cls is a class but NOT an instance of class.
        assert net_cls.__class__ == type
        
        # we will create two class instance meanwhile and use one as theta network and the other as theta_pi network.
        self.net = net_cls.from_pretrained('gpt2').cuda()
        self.net.resize_token_embeddings(token_size)
        self.net = FP16_Module(model)
        
        self.net_pi = net_cls.from_pretrained('gpt2').cuda()
        self.net_pi.resize_token_embeddings(token_size)
        self.net_pi = FP16_Module(model)
        
        
        # From train.py
        param_optimizer = list(self.net_pi.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        
        # ADAPT theta_pi = theta_pi - lr * grad
        self.optimizer = AdamW(optimizer_grouped_parameters, lr=6.25e-5, eps=1e-4)
        
        
    def parameters(self):
        """
        Override this function to return only net parameters for MetaLearner's optimize
        it will ignore theta_pi network parameters.
        :return:
        """
        return self.net.parameters()
    def update_pi(self):
        """
        copy parameters from self.net -> self.net_pi
        :return:
        """
        for m_from, m_to in zip(self.net.modules(), self.net_pi.modules()):
#             if isinstance(m_to, nn.Linear) or isinstance(m_to, nn.Conv2d) or isinstance(m_to, nn.BatchNorm2d):
#                 m_to.weight.data = m_from.weight.data.clone()
#                 if m_to.bias is not None:
#                     m_to.bias.data = m_from.bias.data.clone()

            # NOT SURE IF THIS IS CORRECT??? USING STATEDICT loading/ via NAMED PARAMETERS??
            m_to.load_state_dict(m_from.state_dict().copy())
    def forward(self, support_x, support_y, query_x, query_y, num_updates):
        """
        learn on current episode meta-train: support_x & support_y and then calculate loss on meta-test set: query_x&y
        :param support_x: [setsz, c_, h, w]
        :param support_y: [setsz]
        :param query_x:   [querysz, c_, h, w]
        :param query_y:   [querysz]
        :param num_updates: 5
        :return:
        """
        
        # now try to fine-tune from current $theta$ parameters -> $theta_pi$
        # after num_updates of fine-tune, we will get a good theta_pi parameters so that it will retain satisfying
        # performance on specific task, that's, current episode.
        # firstly, copy theta_pi from theta network
        self.update_pi()
        
        # update for several steps
        for i in range(num_updates):
            # forward and backward to update net_pi grad.
            loss, pred = self.net_pi(support_x, support_y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        
        # Compute the meta gradient and return it, the gradient is from one episode
        # in metalearner, it will merge all loss from different episode and sum over it.
        loss, pred = self.net_pi(query_x, query_y)
        # pred: [setsz, n_way], indices: [setsz]
        _, indices = torch.max(pred, dim=1)
#         correct = torch.eq(indices, query_y).sum().data[0]
        correct = torch.eq(indices, query_y).sum().item()
        acc = correct / query_y.size(0)
        
        # gradient for validation on theta_pi
        # after call autorad.grad, you can not call backward again except for setting create_graph = True
        # as we will use the loss as dummpy loss to conduct a dummy backprop to write our gradients to theta network,
        # here we set create_graph to true to support second time backward.
        grads_pi = autograd.grad(loss, self.net_pi.parameters(), create_graph=True)

        return loss, grads_pi, acc
    
    def net_forward(self, support_x, support_y):
        """
        This function is purely for updating net network. In metalearner, we need the get the loss op from net network
        to write our merged gradients into net network, hence will call this function to get a dummy loss op.
        :param support_x: [setsz, c, h, w]
        :param support_y: [sessz, c, h, w]
        :return: dummy loss and dummy pred
        """
        loss, pred = self.net(support_x, support_y)
        return loss, pred

In [7]:
def get_losses(parallel_model, cqa, Y, gen_X, gen_Y, loss_fct):
    qa_logits = parallel_model(cqa)
    lm_logits = parallel_model(gen_X)
    qa_loss = loss_fct([torch.transpose(l, 1, 2) for l in qa_logits], Y)
    lm_loss = loss_fct([torch.transpose(l, 1, 2) for l in lm_logits], gen_Y)
    return torch.mean(qa_loss), args.lm_lambda * torch.mean(lm_loss)

In [19]:
learner = Learner(MODEL_CLASS, token_size =len(TOKENIZER) )

# Dataset

In [8]:
data_dir = "/root/LAMOL/lamol_data"
n_train_epochs  = 5 

In [9]:
TASK_DICT = {
    "movie": {
               "train":os.path.join(data_dir,"movie_train.json"),
               "eval":os.path.join(data_dir,"movie_dev.json"),
               "test":os.path.join(data_dir,"movie_test.json"),
               "n_train_epochs": n_train_epochs 
    },
    "boolq": {
               "train":os.path.join(data_dir,"boolq_train.json"),
               "eval":os.path.join(data_dir,"boolq_dev.json"),
               "test":os.path.join(data_dir,"boolq_test.json"),
               "n_train_epochs": n_train_epochs 
    },
    "scifact": {
               "train":os.path.join(data_dir,"scifact_train.json"),
               "eval":os.path.join(data_dir,"scifact_dev.json"),
               "test":os.path.join(data_dir,"scifact_test.json"),
               "n_train_epochs": n_train_epochs 
    }
}

In [10]:

class QADataset(Dataset):
    def __init__(self, data_paths, data_type, gen_token, extra_data=[]):
        self.data_type = data_type
        self.gen_token = gen_token
        self.ans_token = SPECIAL_TOKEN_IDS["ans_token"]
        self.eos_token = SPECIAL_TOKEN_IDS["eos_token"]
        self.pad_token = SPECIAL_TOKEN_IDS["pad_token"]

        if not isinstance(data_paths, list):
            data_paths = [data_paths]

        data = []
        for data_path in data_paths:
            if not data_path:
                continue
            with open(data_path, "r") as f:
                raw_ds = json.load(f)
            raw_ds = map(lambda x: x["paragraphs"], raw_ds["data"])
            d = []
            for raw_d in raw_ds:
                d.extend(raw_d)
            data += d
        
        self.data = []
        self.max_a_len = 0
        if len(data) > 0:
            self.data_tokenization(data)

        if len(extra_data) > 0:
            extra_data = map(lambda x: self.etl_single_extra_data(x), extra_data)
            extra_data = list(filter(lambda x:x, extra_data))
            if args.gen_lm_sample_percentage > 0. and len(extra_data) == 0:
                logger.warning("No good extra data but sample percentage > 0!")
            self.data += extra_data


    def etl_single_extra_data(self, data):
        gen_token = data[0]
        data = ' '.join([str(datum) for datum in data[1:]])
        try:
            context = ""
            qa = data
            question, answer = re.split(str(SPECIAL_TOKEN_IDS["ans_token"]), qa)
            context = [int(c) for c in context.strip().split()]
            question = [int(q) for q in question.strip().split()]
            answer = [int(a) for a in re.sub(str(SPECIAL_TOKEN_IDS["eos_token"]), "", answer).strip().split()]
            uid = uuid.uuid1().hex
            data = self.parse_example(gen_token, context, question, answer, uid)
        except ValueError:
            return
        return data

    def concat_example(self, gen_token, c, sep_token, q, ans_token, a, eos_token):
        example = sep_token + q + ans_token + a
        if len(example) + 1 > max_len:
            logger.warning('an example with len {} is too long!'.format(len(example) + 1))
            return
        example = gen_token + c[:max_len-len(example)-1] + example + eos_token
        return example

    def parse_example(self, gen_token, context, question, answer, idx):
        cq_example = self.concat_example([], context, [], question, [self.ans_token], [], [])
        cqa_example = self.concat_example([], context, [], question, [self.ans_token], answer, [])
        Y_example = self.concat_example([], [], [], [], [], answer, [self.eos_token])
        Y_example = [FILL_VAL] * (len(cqa_example) - len(Y_example)) + Y_example
        gen_X_example = self.concat_example([gen_token], context, [], question, [self.ans_token], answer, [])
        gen_Y_example = self.concat_example([], context, [], question, [self.ans_token], answer, [self.eos_token])
        return cq_example, len(cq_example), cqa_example, len(cqa_example), Y_example, gen_X_example, gen_Y_example, idx

    def parallel_tokenization(self, d):
        examples = []
        context = TOKENIZER.encode(d["context"])
        max_a_len = 0
        for qa in d["qas"]:
            question = TOKENIZER.encode(qa["question"])

            raw_answers = qa["answers"]
            if len(raw_answers) == 0:
                assert qa["is_impossible"]
                raw_answers.append({"text": ""})

            answer = []
            for i, raw_answer in enumerate(raw_answers):
                answer.extend(TOKENIZER.encode(raw_answer["text"]))
                if i != len(raw_answers) - 1:
                    answer.append(self.pad_token)
            max_a_len = max(max_a_len, len(answer))

            examples.append(self.parse_example(self.gen_token, context, question, answer, qa.get("id", 0)))
        return examples, max_a_len

    def data_tokenization(self, data):
        with Pool(4) as pool:
            data = pool.map(self.parallel_tokenization, data)
        for datum, max_a_len in data:
            self.data.extend(datum)
            self.max_a_len = max(self.max_a_len, max_a_len)

    def sort(self):
        self.data.sort(key=lambda x: len(x[0]))
        return self

    def sort_by_index(self):
        self.data.sort(key=lambda x: x[-1])

    def get_indices(self):
        return [d[-1] for d in self.data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [11]:
class DynamicBatchSampler(Sampler):
    def __init__(self, dataset, data_type, max_batch_size):
        self.dataset = dataset
        self.data_type = data_type
        if data_type == "train":
            self.batch_size = train_batch_size
        else:
            self.batch_size = test_batch_size
        self.n_samples = len(dataset)
        self.max_batch_size = max_batch_size

    def __iter__(self):
        if self.data_type == "test":
            indices = range(self.n_samples)
        else:
            indices = np.random.permutation(self.n_samples)
        max_len, cnt, st = 0, 0, 0
        batch = []
        for ed, idx in enumerate(indices):
            ln = len(self.dataset[idx][2])
            if max(max_len, ln)**LEN_FACTOR * (ed - st + 1) > self.batch_size[cnt]:
                st = ed
                cnt += 1
                max_len = 0
                if cnt == args.n_gpus:
                    yield batch
                    cnt = 0
                    batch = []
            max_len = max(max_len, ln)
            batch.append(idx)
            if len(batch) == self.max_batch_size and self.data_type == "train":
                yield batch
                cnt, max_len, st = 0, 0, ed
                batch = []
        if len(batch) > 0:
            yield batch

    def __len__(self):
        raise NotImplementedError


In [12]:

def dynamic_collate_fn(data, batch_size):

    def local_collate():
        null_counter = 0
        _cqs, _len_cqs, _cqas, _len_cqas, _Ys, _gen_Xs, _gen_Ys = [], [], [], [], [], [], []
        Y_max_len = max(len(data[j][4]) for j in range(st, ed))
        cq_max_len = max(len(data[j][0]) for j in range(st, ed))
        for j in range(st, ed):
            if None in data[j] or [] in data[j]:
                null_counter+=1
                logger.warning('null example in collate_fn, count: {}'.format(null_counter))
                continue

            pad_len = cqa_max_len - len(data[j][2])

            _cqs.append(pad_to_max_len(data[j][0], cq_max_len-len(data[j][0]), SPECIAL_TOKEN_IDS["pad_token"]))
            _len_cqs.append(data[j][1])
            _cqas.append(pad_to_max_len(data[j][2], pad_len, SPECIAL_TOKEN_IDS["pad_token"]))
            _len_cqas.append(data[j][3])
            _Ys.append(pad_to_max_len(data[j][4], Y_max_len - len(data[j][4]), FILL_VAL))
            _gen_Xs.append(pad_to_max_len(data[j][5], pad_len, SPECIAL_TOKEN_IDS["pad_token"]))
            _gen_Ys.append(pad_to_max_len(data[j][6], pad_len, FILL_VAL))

        cqs.append(torch.tensor(_cqs))
        len_cqs.append(torch.tensor(_len_cqs))
        cqas.append(torch.tensor(_cqas))
        len_cqas.append(torch.tensor(_len_cqas))
        Ys.append(torch.tensor(_Ys))
        gen_Xs.append(torch.tensor(_gen_Xs))
        gen_Ys.append(torch.tensor(_gen_Ys))

    cqs, len_cqs, cqas, len_cqas, Ys, gen_Xs, gen_Ys = [], [], [], [], [], [], []
    cqa_max_len, cnt, st = 0, 0, 0
    for ed, datum in enumerate(data):
        ln = len(datum[2]) # use cqas to calibrate
        if max(cqa_max_len, ln)**LEN_FACTOR * (ed - st + 1) > batch_size[cnt]:
            local_collate()
            cnt += 1
            cqa_max_len = 0
            st = ed
        cqa_max_len = max(cqa_max_len, ln)
    ed += 1  # otherwise ed will be len(data)-1
    local_collate()

    return cqs, len_cqs, cqas, len_cqas, Ys, gen_Xs, gen_Ys

In [13]:

def varlen_collate_fn(data):
    batch_size = (len(data) + n_gpus - 1) // n_gpus
    cqs = torch.tensor(pad_all_to_max_len([datum[0] for datum in data], SPECIAL_TOKEN_IDS["pad_token"])).split(batch_size)
    len_cqs = torch.tensor([datum[1] for datum in data]).split(batch_size)
    cqas = torch.tensor(pad_all_to_max_len([datum[2] for datum in data], SPECIAL_TOKEN_IDS["pad_token"])).split(batch_size)
    len_cqas = torch.tensor([datum[3] for datum in data]).split(batch_size)
    Ys = torch.tensor(pad_all_to_max_len([datum[4] for datum in data], FILL_VAL)).split(batch_size)
    gen_Xs = torch.tensor(pad_all_to_max_len([datum[5] for datum in data], SPECIAL_TOKEN_IDS["pad_token"])).split(batch_size)
    gen_Ys = torch.tensor(pad_all_to_max_len([datum[6] for datum in data], FILL_VAL)).split(batch_size)
    return list(cqs), list(len_cqs), list(cqas), list(len_cqas), list(Ys), list(gen_Xs), list(gen_Ys)

In [14]:
def pad_to_max_len(l, pad_len, val):
    return l + [val] * pad_len
def pad_all_to_max_len(ls, val):
    max_len = max(len(l) for l in ls)
    return [pad_to_max_len(l, max_len-len(l), val) for l in ls]

In [15]:


def create_dataloader(dataset, data_type, max_batch_size=1000000000):
    if data_type == "train":
        batch_size = train_batch_size
    else:
        batch_size = test_batch_size

    if isinstance(batch_size, list):
        collate_fn=lambda x,bs=batch_size: dynamic_collate_fn(x, bs)
        shuffle = False
        batch_size = 1
        batch_sampler = DynamicBatchSampler(dataset, data_type, max_batch_size)
    else:
        collate_fn=lambda x: varlen_collate_fn(x)
#         shuffle = not (data_type != "train" or args.debug)
        shuffle = False
        batch_sampler = None

    dataloader =  DataLoader(dataset, num_workers=4,
                             collate_fn=collate_fn,
                             shuffle=shuffle,
                             batch_size=batch_size,
                             batch_sampler=batch_sampler)
    return dataloader


In [16]:
def get_gen_token(task):
    return '__' + task + '__'

In [17]:
min_n_steps = 1500
min_batch_size = 4

In [18]:
# First Epoch
tasks = ['movie']


gen_token = get_gen_token(tasks[0])
TOKENIZER.add_tokens([gen_token])
# TOKENIZER.save_pretrained(model_dir)
SPECIAL_TOKENS[tasks[0]] = gen_token
SPECIAL_TOKEN_IDS[tasks[0]] = TOKENIZER.convert_tokens_to_ids(gen_token)
# logger.info('gen token = {} , gen token id = {}'.format(gen_token, SPECIAL_TOKEN_IDS[tasks[0]]))
# MODEL_CONFIG.vocab_size = len(TOKENIZER)
# MODEL_CONFIG.to_json_file(os.path.join(model_dir,CONFIG_NAME))


train_extra_data = []
train_dataset = [TASK_DICT[t]["train"] for t in tasks]
train_qadata = QADataset(train_dataset, "train", SPECIAL_TOKEN_IDS[tasks[0]], train_extra_data)
max_train_batch_size = max(len(train_qadata) // min_n_steps, min_batch_size)
train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)

Token indices sequence length is longer than the specified maximum sequence length for this model (1185 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1808 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2222 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1044 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1099 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1329 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1324 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1107 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1198 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1077 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1945 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1356 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1120 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1201 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1048 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1210 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1072 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1065 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1467 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1289 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1685 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1029 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1167 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1236 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1298 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1248 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1489 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2182 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2260 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1840 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1351 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1845 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1487 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1596 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1463 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1418 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1280 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1112 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1047 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1187 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1564 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1072 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1668 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1563 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

In [19]:
train_qadata.__len__()

1600

In [20]:
# return cq_example, len(cq_example), cqa_example, len(cqa_example), Y_example, gen_X_example, gen_Y_example, idx
#           0                1               2           3                   4       5              6          7
print(train_qadata.__getitem__(0)[0][:10])
print(len(train_qadata.__getitem__(0)[0])) # Checking 
print(train_qadata.__getitem__(0)[1])
print(train_qadata.__getitem__(0)[2][:10])
print(train_qadata.__getitem__(0)[3])
print(train_qadata.__getitem__(0)[4][:10])
print(train_qadata.__getitem__(0)[5][:10])
print(train_qadata.__getitem__(0)[6][:10])
print(train_qadata.__getitem__(0)[7])

[7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
901
901
[7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
903
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
[50260, 7110, 25, 734, 6036, 11886, 467, 284, 257, 4928]
[7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
0


## LAMOL Dataset

`train_qadata` is a pytorch Dataset.  
A single datapoint of a returns a list of length 8.  
```python
return cq_example, len(cq_example), cqa_example, len(cqa_example), Y_example, gen_X_example, gen_Y_example, idx
           0                1               2           3                   4       5              6          7
# 0 cq_example is context+question+__ans__. ie. [7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
# 1 len(cq_example) is the length ie. 901
# 2 cqa_example is context+question+__ans__+answer ie. [7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
# 3 len(cqa_example) is the length ie. 903
# 4 Y_example is FILL_VALUE+answer only. ie. [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
# 5 gen_X_example is __gen__+context+question+__ans__+answer ie. [50260, 7110, 25, 734, 6036, 11886, 467, 284, 257, 4928]
# 6 gen_Y_example is context+question+__ans__+answer ie. [7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
# 7 idx is id (supposed to be uuid? but i don't see it) ie. 0
```

In [21]:
train_loss_fct = CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT)

In [26]:
max_grad_norm=1
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=6.25e-5, eps=1e-4)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=None, dynamic_loss_scale=True,
                                   dynamic_loss_args={'scale_window': 100, 'min_scale': 1, 'delayed_shift': 2})

In [30]:
n_train_epochs = 3
n_train_optimization_steps = len(train_qadata) * n_train_epochs
scheduler = AnnealingLR(optimizer, start_lr=6.25e-5, warmup_iter=int(0.005*len(train_qadata)),
            num_iters=int(n_train_optimization_steps), decay_style="linear")

In [33]:
device_ids = [0]

# Initialize 2 models
net_cls = MODEL_CLASS
net = net_cls.from_pretrained('gpt2').cuda()
net.resize_token_embeddings(token_size)
net = FP16_Module(model)

net_pi = net_cls.from_pretrained('gpt2').cuda()
net_pi.resize_token_embeddings(token_size)
net_pi = FP16_Module(model)


# Sequential Tasks
tasks = ['movies', 'boolq', 'scifact']
for task_id in range(len(tasks)):
    
    _tasks = [tasks[task_id] for task_id in task_ids]
    
    ##### Start training on task_id #####
    gen_token = get_gen_token(tasks[task_id])
    TOKENIZER.add_tokens([gen_token])
    # TOKENIZER.save_pretrained(model_dir)
    SPECIAL_TOKENS[tasks[0]] = gen_token
    SPECIAL_TOKEN_IDS[tasks[0]] = TOKENIZER.convert_tokens_to_ids(gen_token)
    # logger.info('gen token = {} , gen token id = {}'.format(gen_token, SPECIAL_TOKEN_IDS[tasks[0]]))
    # MODEL_CONFIG.vocab_size = len(TOKENIZER)
    # MODEL_CONFIG.to_json_file(os.path.join(model_dir,CONFIG_NAME))

    ##### Get Extra data and that particular dataset #####
    train_extra_data = []
    train_dataset = [TASK_DICT[t]["train"] for t in _tasks]
    train_qadata = QADataset(train_dataset, "train", SPECIAL_TOKEN_IDS[_tasks[0]], train_extra_data)
    max_train_batch_size = max(len(train_qadata) // min_n_steps, min_batch_size)
    train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)


    ##### Stream from that dataset's dataloader #####
    cum_loss, cum_qa_loss, cum_lm_loss, cur_n_inputs = 0, 0, 0, 0
    for n_steps, (_, _, cqa, _, Y, gen_X, gen_Y) in enumerate(train_dataloader):
        print(len(cqa))
        print(cqa)
        print(cqa[0].shape)

        n_inputs = sum(_cqa.shape[0] for _cqa in cqa)
        print(n_inputs)

        support_x = cqa[0][:10,...]
        support_y = Y[0][:10]
        query_x = cqa[0][10:20,...]
        query_y = Y[0][10:20]

        support_x = support_x.to('cuda:0')
        support_y = support_y.to('cuda:0')


    #     for i in range(len(cqa)):
    #         cqa[i] = (cqa[i].to(device_ids[i]),)
    #         Y[i] = Y[i].to(device_ids[i])
    #         gen_X[i] = (gen_X[i].to(device_ids[i]),)
    #         gen_Y[i] = gen_Y[i].to(device_ids[i])


    #     losses = get_losses(parallel_model, cqa, Y, gen_X, gen_Y, train_loss_fct)
        # Supposed to be parallel!!
    #     qa_logits = parallel_model([cqa[0][:10,...]])
        qa_logits = model(support_x)

    #     lm_logits = parallel_model(gen_X)
    #     print(qa_logits)
        print(len(qa_logits))
        print(qa_logits[0].shape)
        print(len(qa_logits[1]))
    #      Somehow it also returns attentions in [1]?, this is selecting 0 of what WrapModel is doing 



        qa_logits = qa_logits[0]

        print(qa_logits.shape)
        print(support_y.shape)
        qa_loss = train_loss_fct(qa_logits.transpose(1,2), support_y)
    #     qa_loss = train_loss_fct([torch.transpose(l, 1, 2) for l in qa_logits], support_y)
        print(qa_loss)
    #     lm_loss = loss_fct([torch.transpose(l, 1, 2) for l in lm_logits], gen_Y)
    #     loss = sum([torch.mean(qa_loss), args.lm_lambda * torch.mean(lm_loss)])
        loss = qa_loss

        # Update Optimizer
        optimizer.backward(loss, update_master_grads=False) # instead of loss.backward() for fp16
        optimizer.update_master_grads()
        optimizer.clip_master_grads(max_grad_norm)
        optimizer.step()
        if not optimizer.overflow:
            for i in range(n_inputs):
                scheduler.step()
        optimizer.zero_grad()


        qa_loss = qa_loss.item() * n_inputs # since each cross entropy loss only gives us the mean as reduciton
    #     lm_loss = losses[1].item() * n_inputs
    #     cum_loss += (qa_loss + lm_loss)
        cum_qa_loss += qa_loss
    #     cum_lm_loss += lm_loss
        cur_n_inputs += n_inputs
        print(qa_loss)

        break

1
[tensor([[ 7110,    25,   734,  ..., 50257,   399,  7156],
        [  262,  3772, 31030,  ..., 50258, 50258, 50258],
        [  340,   318,  6918,  ..., 50258, 50258, 50258],
        [  366,  1235,   329,  ..., 50258, 50258, 50258]])]
torch.Size([4, 903])
4
2
torch.Size([4, 903, 50260])
12
torch.Size([4, 903, 50260])
torch.Size([4, 903])
tensor(102.7731, device='cuda:0', grad_fn=<NllLoss2DBackward>)
411.0924987792969


In [28]:
model

FP16_Module(
  (module): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50260, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): FusedLayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): FusedLayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): Block(
          (ln_1): FusedLayerNorm(torch.Size([768]), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dr

In [28]:
dir(tokenizer)

['SPECIAL_TOKENS_ATTRIBUTES',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_additional_special_tokens',
 '_bos_token',
 '_cls_token',
 '_convert_id_to_token',
 '_convert_token_to_id',
 '_convert_token_to_id_with_added_voc',
 '_eos_token',
 '_from_pretrained',
 '_mask_token',
 '_pad_token',
 '_sep_token',
 '_tokenize',
 '_unk_token',
 'add_special_tokens',
 'add_special_tokens_sentences_pair',
 'add_special_tokens_single_sentence',
 'add_tokens',
 'added_tokens_decoder',
 'added_tokens_encoder',
 'additional_special_tokens',
 'additional_special_tokens_ids',
 'all_special_ids',
 'all_special_tokens',
 'bos_token',
 'bos_token_id',
 'bpe',
 'bpe_ranks',
 

In [29]:
tokenizer

<pytorch_transformers.tokenization_gpt2.GPT2Tokenizer at 0x7f4c7d589550>

In [None]:
train_loss_fct = DataParallelCriterion(CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT), args.device_ids)