In [1]:
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, CONFIG_NAME 
from pytorch_transformers import AdamW
from fp16 import FP16_Module, FP16_Optimizer
from torch import nn
import os
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import json
from multiprocessing import Pool
from parallel import DataParallelModel, DataParallelCriterion
from torch.nn import CrossEntropyLoss
from scheduler import AnnealingLR
import warnings
import time
from datetime import datetime
import pickle

# import GPUtil

import logging

In [2]:
tasks = ['movie', 'boolq', 'scifact']
TASK_TRIAD = ''.join([task_name[0] for task_name in tasks])
DATETIME = datetime.today().strftime('%Y%m%dT%H%M%S')

RUN_ID = DATETIME + "_" + TASK_TRIAD

logging.basicConfig(filename=f'{RUN_ID}.log', encoding='utf-8', level=logging.INFO)

'20210630T181500'

In [2]:
FILL_VAL = -1

n_gpus = 1
device_ids = [1]

train_batch_size = 3
test_batch_size = 3

data_dir = "/root/LAMOL/lamol_data"
MODEL_BASE_DIR = "./model"
n_train_epochs  = 5 

min_n_steps = 1500
min_batch_size = 4
n_train_epochs = 3

# Adaptation Phase
num_updates = 5

DEVICE = 'cuda:0'

model_dir = os.path.join(MODEL_BASE_DIR, RUN_ID)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

In [3]:
MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer, GPT2Config),
}

In [4]:
TASK_DICT = {
    "movie": {
               "train":os.path.join(data_dir,"movie_train.json"),
               "eval":os.path.join(data_dir,"movie_dev.json"),
               "test":os.path.join(data_dir,"movie_test.json"),
               "n_train_epochs": n_train_epochs 
    },
    "boolq": {
               "train":os.path.join(data_dir,"boolq_train.json"),
               "eval":os.path.join(data_dir,"boolq_dev.json"),
               "test":os.path.join(data_dir,"boolq_test.json"),
               "n_train_epochs": n_train_epochs 
    },
    "scifact": {
               "train":os.path.join(data_dir,"scifact_train.json"),
               "eval":os.path.join(data_dir,"scifact_dev.json"),
               "test":os.path.join(data_dir,"scifact_test.json"),
               "n_train_epochs": n_train_epochs 
    }
}

In [5]:
# In settings.py
special_tokens = {"ans_token":'__ans__', "pad_token":'__pad__', "unk_token":'__unk__', "eos_token": '<|endoftext|>'}

model_class, tokenizer_class, config_class = MODEL_CLASSES['gpt2']
tokenizer = tokenizer_class.from_pretrained('gpt2')
tokenizer.add_tokens(list(special_tokens.values()))
special_token_ids = {k:tokenizer.convert_tokens_to_ids(v) for k,v in special_tokens.items()}


model_config = config_class.from_pretrained('gpt2')
model_config.vocab_size = len(tokenizer)

tokens_weight = torch.ones([model_config.vocab_size], dtype=torch.float).to(DEVICE)
tokens_weight[special_token_ids["ans_token"]] = 5


MODEL_CLASS = model_class
TOKENIZER = tokenizer
SPECIAL_TOKENS = special_tokens
SPECIAL_TOKEN_IDS = special_token_ids
TOKENS_WEIGHT = tokens_weight

model_config = config_class.from_pretrained('gpt2')
max_len = model_config.n_positions

In [6]:
max_len

1024

In [7]:
def get_losses(parallel_model, cqa, Y, gen_X, gen_Y, loss_fct):
    qa_logits = parallel_model(cqa)
    lm_logits = parallel_model(gen_X)
    qa_loss = loss_fct([torch.transpose(l, 1, 2) for l in qa_logits], Y)
    lm_loss = loss_fct([torch.transpose(l, 1, 2) for l in lm_logits], gen_Y)
    return torch.mean(qa_loss), args.lm_lambda * torch.mean(lm_loss)

# Dataset

In [8]:
class QADataset(Dataset):
    def __init__(self, data_paths, data_type, gen_token, extra_data=[]):
        self.data_type = data_type
        self.gen_token = gen_token
        self.ans_token = SPECIAL_TOKEN_IDS["ans_token"]
        self.eos_token = SPECIAL_TOKEN_IDS["eos_token"]
        self.pad_token = SPECIAL_TOKEN_IDS["pad_token"]

        if not isinstance(data_paths, list):
            data_paths = [data_paths]

        data = []
        for data_path in data_paths:
            if not data_path:
                continue
            with open(data_path, "r") as f:
                raw_ds = json.load(f)
            raw_ds = map(lambda x: x["paragraphs"], raw_ds["data"])
            d = []
            for raw_d in raw_ds:
                d.extend(raw_d)
            data += d
        
        self.data = []
        self.max_a_len = 0
        if len(data) > 0:
            self.data_tokenization(data)

        if len(extra_data) > 0:
            extra_data = map(lambda x: self.etl_single_extra_data(x), extra_data)
            extra_data = list(filter(lambda x:x, extra_data))
            if args.gen_lm_sample_percentage > 0. and len(extra_data) == 0:
                logger.warning("No good extra data but sample percentage > 0!")
            self.data += extra_data


    def etl_single_extra_data(self, data):
        gen_token = data[0]
        data = ' '.join([str(datum) for datum in data[1:]])
        try:
            context = ""
            qa = data
            question, answer = re.split(str(SPECIAL_TOKEN_IDS["ans_token"]), qa)
            context = [int(c) for c in context.strip().split()]
            question = [int(q) for q in question.strip().split()]
            answer = [int(a) for a in re.sub(str(SPECIAL_TOKEN_IDS["eos_token"]), "", answer).strip().split()]
            uid = uuid.uuid1().hex
            data = self.parse_example(gen_token, context, question, answer, uid)
        except ValueError:
            return
        return data

    def concat_example(self, gen_token, c, sep_token, q, ans_token, a, eos_token):
        example = sep_token + q + ans_token + a
        if len(example) + 1 > max_len:
            logger.warning('an example with len {} is too long!'.format(len(example) + 1))
            return
        example = gen_token + c[:max_len-len(example)-1] + example + eos_token
        return example

    def parse_example(self, gen_token, context, question, answer, idx):
        cq_example = self.concat_example([], context, [], question, [self.ans_token], [], [])
        cqa_example = self.concat_example([], context, [], question, [self.ans_token], answer, [])
        Y_example = self.concat_example([], [], [], [], [], answer, [self.eos_token])
        Y_example = [FILL_VAL] * (len(cqa_example) - len(Y_example)) + Y_example
        gen_X_example = self.concat_example([gen_token], context, [], question, [self.ans_token], answer, [])
        gen_Y_example = self.concat_example([], context, [], question, [self.ans_token], answer, [self.eos_token])
        return cq_example, len(cq_example), cqa_example, len(cqa_example), Y_example, gen_X_example, gen_Y_example, idx

    def parallel_tokenization(self, d):
        # ADD MAX LENGTH FOR MODEL SO IT DOESNT SHOW WARNING
        # OLD VERSION OF PYTORCH HUGGINGFACE DOESNT HAVE MAX LENGTH!!!!!
        # Suppress the warnings instead! https://stackoverflow.com/questions/14463277/how-to-disable-python-warnings
        # Still doesn't work. idk what to do. we can delete all the warnings catches here
        examples = []
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            context = TOKENIZER.encode(d["context"])
        max_a_len = 0
        for qa in d["qas"]:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                question = TOKENIZER.encode(qa["question"])

            raw_answers = qa["answers"]
            if len(raw_answers) == 0:
                assert qa["is_impossible"]
                raw_answers.append({"text": ""})

            answer = []
            for i, raw_answer in enumerate(raw_answers):
                answer.extend(TOKENIZER.encode(raw_answer["text"]))
                if i != len(raw_answers) - 1:
                    answer.append(self.pad_token)
            max_a_len = max(max_a_len, len(answer))

            examples.append(self.parse_example(self.gen_token, context, question, answer, qa.get("id", 0)))
        return examples, max_a_len

    def data_tokenization(self, data):
        with Pool(4) as pool:
            data = pool.map(self.parallel_tokenization, data)
        for datum, max_a_len in data:
            self.data.extend(datum)
            self.max_a_len = max(self.max_a_len, max_a_len)

    def sort(self):
        self.data.sort(key=lambda x: len(x[0]))
        return self

    def sort_by_index(self):
        self.data.sort(key=lambda x: x[-1])

    def get_indices(self):
        return [d[-1] for d in self.data]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

In [9]:
class DynamicBatchSampler(Sampler):
    def __init__(self, dataset, data_type, max_batch_size):
        self.dataset = dataset
        self.data_type = data_type
        if data_type == "train":
            self.batch_size = train_batch_size
        else:
            self.batch_size = test_batch_size
        self.n_samples = len(dataset)
        self.max_batch_size = max_batch_size

    def __iter__(self):
        if self.data_type == "test":
            indices = range(self.n_samples)
        else:
            indices = np.random.permutation(self.n_samples)
        max_len, cnt, st = 0, 0, 0
        batch = []
        for ed, idx in enumerate(indices):
            ln = len(self.dataset[idx][2])
            if max(max_len, ln)**LEN_FACTOR * (ed - st + 1) > self.batch_size[cnt]:
                st = ed
                cnt += 1
                max_len = 0
                if cnt == args.n_gpus:
                    yield batch
                    cnt = 0
                    batch = []
            max_len = max(max_len, ln)
            batch.append(idx)
            if len(batch) == self.max_batch_size and self.data_type == "train":
                yield batch
                cnt, max_len, st = 0, 0, ed
                batch = []
        if len(batch) > 0:
            yield batch

    def __len__(self):
        raise NotImplementedError


In [10]:
def dynamic_collate_fn(data, batch_size):

    def local_collate():
        null_counter = 0
        _cqs, _len_cqs, _cqas, _len_cqas, _Ys, _gen_Xs, _gen_Ys = [], [], [], [], [], [], []
        Y_max_len = max(len(data[j][4]) for j in range(st, ed))
        cq_max_len = max(len(data[j][0]) for j in range(st, ed))
        for j in range(st, ed):
            if None in data[j] or [] in data[j]:
                null_counter+=1
                logger.warning('null example in collate_fn, count: {}'.format(null_counter))
                continue

            pad_len = cqa_max_len - len(data[j][2])

            _cqs.append(pad_to_max_len(data[j][0], cq_max_len-len(data[j][0]), SPECIAL_TOKEN_IDS["pad_token"]))
            _len_cqs.append(data[j][1])
            _cqas.append(pad_to_max_len(data[j][2], pad_len, SPECIAL_TOKEN_IDS["pad_token"]))
            _len_cqas.append(data[j][3])
            _Ys.append(pad_to_max_len(data[j][4], Y_max_len - len(data[j][4]), FILL_VAL))
            _gen_Xs.append(pad_to_max_len(data[j][5], pad_len, SPECIAL_TOKEN_IDS["pad_token"]))
            _gen_Ys.append(pad_to_max_len(data[j][6], pad_len, FILL_VAL))

        cqs.append(torch.tensor(_cqs))
        len_cqs.append(torch.tensor(_len_cqs))
        cqas.append(torch.tensor(_cqas))
        len_cqas.append(torch.tensor(_len_cqas))
        Ys.append(torch.tensor(_Ys))
        gen_Xs.append(torch.tensor(_gen_Xs))
        gen_Ys.append(torch.tensor(_gen_Ys))

    cqs, len_cqs, cqas, len_cqas, Ys, gen_Xs, gen_Ys = [], [], [], [], [], [], []
    cqa_max_len, cnt, st = 0, 0, 0
    for ed, datum in enumerate(data):
        ln = len(datum[2]) # use cqas to calibrate
        if max(cqa_max_len, ln)**LEN_FACTOR * (ed - st + 1) > batch_size[cnt]:
            local_collate()
            cnt += 1
            cqa_max_len = 0
            st = ed
        cqa_max_len = max(cqa_max_len, ln)
    ed += 1  # otherwise ed will be len(data)-1
    local_collate()

    return cqs, len_cqs, cqas, len_cqas, Ys, gen_Xs, gen_Ys

In [11]:

def varlen_collate_fn(data):
    batch_size = (len(data) + n_gpus - 1) // n_gpus
    cqs = torch.tensor(pad_all_to_max_len([datum[0] for datum in data], SPECIAL_TOKEN_IDS["pad_token"])).split(batch_size)
    len_cqs = torch.tensor([datum[1] for datum in data]).split(batch_size)
    cqas = torch.tensor(pad_all_to_max_len([datum[2] for datum in data], SPECIAL_TOKEN_IDS["pad_token"])).split(batch_size)
    len_cqas = torch.tensor([datum[3] for datum in data]).split(batch_size)
    Ys = torch.tensor(pad_all_to_max_len([datum[4] for datum in data], FILL_VAL)).split(batch_size)
    gen_Xs = torch.tensor(pad_all_to_max_len([datum[5] for datum in data], SPECIAL_TOKEN_IDS["pad_token"])).split(batch_size)
    gen_Ys = torch.tensor(pad_all_to_max_len([datum[6] for datum in data], FILL_VAL)).split(batch_size)
    return list(cqs), list(len_cqs), list(cqas), list(len_cqas), list(Ys), list(gen_Xs), list(gen_Ys)

In [12]:
def pad_to_max_len(l, pad_len, val):
    return l + [val] * pad_len
def pad_all_to_max_len(ls, val):
    max_len = max(len(l) for l in ls)
    return [pad_to_max_len(l, max_len-len(l), val) for l in ls]

In [13]:
def create_dataloader(dataset, data_type, max_batch_size=1000000000):
    if data_type == "train":
        batch_size = train_batch_size
    else:
        batch_size = test_batch_size

    if isinstance(batch_size, list):
        collate_fn=lambda x,bs=batch_size: dynamic_collate_fn(x, bs)
        shuffle = False
        batch_size = 1
        batch_sampler = DynamicBatchSampler(dataset, data_type, max_batch_size)
    else:
        collate_fn=lambda x: varlen_collate_fn(x)
#         shuffle = not (data_type != "train" or args.debug)
        shuffle = False
        batch_sampler = None

    dataloader =  DataLoader(dataset, num_workers=4,
                             collate_fn=collate_fn,
                             shuffle=shuffle,
                             batch_size=batch_size,
                             batch_sampler=batch_sampler)
    return dataloader


In [14]:
def get_gen_token(task):
    return '__' + task + '__'

## LAMOL Dataset

`train_qadata` is a pytorch Dataset.  
A single datapoint of a returns a list of length 8.  
```python
return cq_example, len(cq_example), cqa_example, len(cqa_example), Y_example, gen_X_example, gen_Y_example, idx
           0                1               2           3                   4       5              6          7
# 0 cq_example is context+question+__ans__. ie. [7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
# 1 len(cq_example) is the length ie. 901
# 2 cqa_example is context+question+__ans__+answer ie. [7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
# 3 len(cqa_example) is the length ie. 903
# 4 Y_example is FILL_VALUE+answer only. ie. [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]
# 5 gen_X_example is __gen__+context+question+__ans__+answer ie. [50260, 7110, 25, 734, 6036, 11886, 467, 284, 257, 4928]
# 6 gen_Y_example is context+question+__ans__+answer ie. [7110, 25, 734, 6036, 11886, 467, 284, 257, 4928, 2151]
# 7 idx is id (supposed to be uuid? but i don't see it) ie. 0
```

In [15]:
# Initialize 2 models
net_cls = MODEL_CLASS
net = net_cls.from_pretrained('gpt2').to(DEVICE)
net.resize_token_embeddings(len(TOKENIZER))
net = FP16_Module(net)

net_pi = net_cls.from_pretrained('gpt2').to(DEVICE)
net_pi.resize_token_embeddings(len(TOKENIZER))
net_pi = FP16_Module(net_pi)

# Training loss function
train_loss_fct = CrossEntropyLoss(ignore_index=FILL_VAL, weight=TOKENS_WEIGHT)

# Optimizer
max_grad_norm=1
param_optimizer = list(net_pi.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=6.25e-5, eps=1e-4)
optimizer = FP16_Optimizer(optimizer, static_loss_scale=None, dynamic_loss_scale=True,
                                   dynamic_loss_args={'scale_window': 100, 'min_scale': 1, 'delayed_shift': 2})


# the optimizer is to update theta parameters, not theta_pi parameters.

meta_param_optimizer = list(net.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
meta_optimizer_grouped_parameters = [
    {'params': [p for n, p in meta_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in meta_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
meta_optimizer = AdamW(meta_optimizer_grouped_parameters, lr=6.25e-5, eps=1e-4)
meta_optimizer = FP16_Optimizer(meta_optimizer, static_loss_scale=None, dynamic_loss_scale=True,
                                   dynamic_loss_args={'scale_window': 100, 'min_scale': 1, 'delayed_shift': 2})

# Sequential Tasks
for task_id in range(len(tasks)):
    tic_TASK = time.time()
    
    
    _tasks = [tasks[task_id]]
    
    
    
    ##### Start training on task_id #####
    gen_token = get_gen_token(_tasks[0])
    TOKENIZER.add_tokens([gen_token])
    # TOKENIZER.save_pretrained(model_dir)
    SPECIAL_TOKENS[_tasks[0]] = gen_token
    SPECIAL_TOKEN_IDS[_tasks[0]] = TOKENIZER.convert_tokens_to_ids(gen_token)
    # logger.info('gen token = {} , gen token id = {}'.format(gen_token, SPECIAL_TOKEN_IDS[tasks[0]]))
    # MODEL_CONFIG.vocab_size = len(TOKENIZER)
    # MODEL_CONFIG.to_json_file(os.path.join(model_dir,CONFIG_NAME))

    ##### Get Extra data and that particular dataset #####
    train_extra_data = []
    train_dataset = [TASK_DICT[t]["train"] for t in _tasks]
    train_qadata = QADataset(train_dataset, "train", SPECIAL_TOKEN_IDS[_tasks[0]], train_extra_data)
    max_train_batch_size = max(len(train_qadata) // min_n_steps, min_batch_size)
    train_dataloader = create_dataloader(train_qadata, "train", max_train_batch_size)
    
    
    
    # Scheduler
    n_train_optimization_steps = len(train_qadata) * n_train_epochs
    scheduler = AnnealingLR(optimizer, start_lr=6.25e-5, warmup_iter=int(0.005*len(train_qadata)),
            num_iters=int(n_train_optimization_steps), decay_style="linear")


    ##### Stream from that dataset's dataloader #####
    iter_dataloader = iter(train_dataloader)
    n_steps = 0
    
    episode_loss = []
    while True:
        tic_BATCH = time.time()
        
        # 1. Get the support data from the first batch 
        #    and the query data from the second batch
        try:
            _, _, support_x, _, support_y, support_gen_x, support_gen_y = next(iter_dataloader)
            _, _, query_x, _, query_y, query_gen_x, query_gen_y = next(iter_dataloader)
            
            n_inputs = sum(_cqa.shape[0] for _cqa in support_x)
            
            # Since we only have 1 GPU, just use the first one, it will separate batches according to the device IDS
            support_x = support_x[0]
            support_y = support_y[0]
            query_x = query_x[0]
            query_y = query_y[0]
            
            support_x = support_x.to(DEVICE)
            support_y = support_y.to(DEVICE)
            query_x = query_x.to(DEVICE)
            query_y = query_y.to(DEVICE)
            
        except StopIteration:
            break
        
        
#         print(support_x.shape)
#         print(query_x.shape)
        
        
        ### START Adaptation Phase ###
        # 2. Reinitialize net_pi with parameters from net
        for m_from, m_to in zip(net.modules(), net_pi.modules()):
            m_to.load_state_dict(m_from.state_dict().copy())
        
        
        # 3. Update the weights with the support set
        # May update for several steps
        for i in range(num_updates):
            
            qa_logits = net_pi(support_x)
            # Somehow it also returns attentions in [1]?, this is selecting 0 of what WrapModel is doing 
            qa_logits = qa_logits[0]
            qa_loss = train_loss_fct(qa_logits.transpose(1,2), support_y)
            loss = qa_loss
        
            # Update Optimizer
            optimizer.backward(loss, update_master_grads=False) # instead of loss.backward() for fp16
            optimizer.update_master_grads()
            optimizer.clip_master_grads(max_grad_norm)
            optimizer.step()
            if not optimizer.overflow:
                for i in range(n_inputs):
                    scheduler.step()
            optimizer.zero_grad()
        
        ### END Adaptation Phase ###

        ### START Meta-Learning Phase ###
        # 4. After Adaptation, use the query set for learning
        # Somehow it also returns attentions in [1]?, this is selecting 0 of what WrapModel is doing 
        qa_logits = net_pi(query_x)[0]
        qa_loss = train_loss_fct(qa_logits.transpose(1,2), query_y)
        loss = qa_loss
        
        # Add loss to episode loss
        episode_loss.append(loss.item())
        
        # gradient for validation on theta_pi
        # after call autorad.grad, you can not call backward again except for setting create_graph = True
        # as we will use the loss as dummpy loss to conduct a dummy backprop to write our gradients to theta network,
        # here we set create_graph to true to support second time backward.
        grads_pi = torch.autograd.grad(loss, net_pi.parameters(), create_graph=True)
        
        # As we already have the grads to update
        # We use a dummy forward / backward pass to get the correct grads into self.net
        # the right grads will be updated by hook, ignoring backward.
        # use hook mechnism to write sumed gradient into network.
        # we need to update the theta/net network, we need a op from net network, so we call self.learner.net_forward
        # to get the op from net network, since the loss from self.learner.forward will return loss from net_pi network.
        
        
        # Somehow it also returns attentions in [1]?, this is selecting 0 of what WrapModel is doing 
        qa_logits = net(query_x)[0]
        dummy_loss = train_loss_fct(qa_logits.transpose(1,2), query_y)
        
        
        # Register a hook on each parameter in the net that replaces the current dummy grad
        # with our grads accumulated across the meta-batch
        hooks = []
        for i, v in enumerate(net.parameters()):
            def closure():
                ii = i
                return lambda grad: grads_pi[ii]
            # if you write: hooks.append( v.register_hook(lambda grad : sum_grads_pi[i]) )
            # it will pop an ERROR, i don't know why?
            hooks.append(v.register_hook(closure()))
        
        # use our sumed gradients_pi to update the theta/net network,
        # since our optimizer receive the self.net.parameters() only.
        # Update Meta Optimizer
        meta_optimizer.backward(dummy_loss, update_master_grads=False) # instead of loss.backward() for fp16
        meta_optimizer.update_master_grads()
        meta_optimizer.clip_master_grads(max_grad_norm)
        meta_optimizer.step()
        # DO I NEED SCHEDULER HERE???
        meta_optimizer.zero_grad()

        # if you do NOT remove the hook, the GPU memory will expode!!!
        for h in hooks:
            h.remove()
        
        
        ### END Meta-Learning Phase ###
        n_steps += 1
        
        toc_BATCH = time.time() - tic_BATCH
        
        torch.cuda.empty_cache()
#         mem = float(torch.cuda.memory_allocated() / (1024 * 1024))
#         print("memory allocated:", mem, "MiB")
        if n_steps%10 == 0:
            logger.info(f'{RUN_ID} {_tasks[0]} Steps: {n_steps}/{len(train_qadata)//max_train_batch_size} Episode {n_steps}: Loss: {loss:.5f} lr {scheduler.get_lr():.1E} Batch: {n_inputs}')
            logger.info(f'[TIME] BATCH {RUN_ID} {_tasks[0]} {toc_BATCH}')
            
    toc_TASK = time.time() - tic_TASK
    MODEL_SAVE_LOC = os.path.join(model_dir, f'{_tasks[0]}.model')
    LOSS_SAVE_LOC = os.path.join(model_dir, f'{_tasks[0]}_loss.pickle')
    torch.save(net.state_dict(), MODEL_SAVE_LOC)
    logger.info(f'{RUN_ID} {_tasks[0]} Done Saving Model at {MODEL_SAVE_LOC}')
    logger.info(f'[TIME] TASK {RUN_ID} {_tasks[0]} {toc_TASK}')
    pickle.dump( episode_loss, open( LOSS_SAVE_LOC, "wb" ), protocol=pickle.HIGHEST_PROTOCOL )


Token indices sequence length is longer than the specified maximum sequence length for this model (1185 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1808 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2222 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1044 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1099 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1324 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1107 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1198 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1263 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1094 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1356 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1670 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1945 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1201 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1043 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1446 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1065 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1072 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1432 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1450 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1798 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1685 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1236 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1757 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1489 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2079 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1141 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1088 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1732 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1487 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2214 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1406 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1351 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1418 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1463 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1280 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1112 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1079 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

Token indices sequence length is longer than the specified maximum sequence length for this model (1067 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1187 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1694 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2917 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1070 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence leng

memory allocated: 12172.85546875 MiB
Episode 1: Loss: 98.00357055664062
memory allocated: 12872.00732421875 MiB
Episode 2: Loss: 100.80538940429688
memory allocated: 11563.5888671875 MiB
Episode 3: Loss: 91.32156372070312
memory allocated: 12851.9111328125 MiB
Episode 4: Loss: 97.4994125366211
memory allocated: 12848.3662109375 MiB
Episode 5: Loss: 104.6044692993164
memory allocated: 10717.36376953125 MiB
Episode 6: Loss: 100.3438491821289
memory allocated: 12266.951171875 MiB
Episode 7: Loss: 75.62680053710938
memory allocated: 13466.7998046875 MiB
Episode 8: Loss: 77.94529724121094
memory allocated: 14279.810546875 MiB
Episode 9: Loss: 71.88252258300781
memory allocated: 8320.7880859375 MiB
Episode 10: Loss: 40.5863037109375
memory allocated: 14278.22900390625 MiB
Episode 11: Loss: 25.581228256225586
memory allocated: 14290.01318359375 MiB
Episode 12: Loss: 58.23491668701172
memory allocated: 12218.890625 MiB
Episode 13: Loss: 12.959354400634766
memory allocated: 14286.60302734375 Mi

KeyboardInterrupt: 