# In [None]:
import os

from tqdm import tqdm
import pandas as pd

import torch

from torch.utils.data import Dataset

from transformers import GPT2LMHeadModel, GPT2Config, CTRLLMHeadModel, GPT2TokenizerFast, CTRLTokenizer, AdamW, get_linear_schedule_with_warmup

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from collections import OrderedDict


class CustomDataset(Dataset):
    """Abstract/output pairs encoded as fixed-length causal-LM examples.

    Every row is laid out as ``<BOS> abstract <SEP> output <EOS> <PAD>...``
    and is padded or truncated to exactly ``maxlen`` token ids.
    """

    # BOS, SEP and EOS are always present, whatever gets truncated.
    _NUM_SPECIAL = 3

    def __init__(self, type, tokenizer, maxlen,
                 data_path='data/dataset_v1.tsv'):
        """Load the TSV and remember the encoding settings.

        Args:
            type: Split name such as ``'train'``. It is not read anywhere in
                this class (all splits load the same file); the parameter is
                kept so existing callers keep working.
            tokenizer: Object providing ``encode()`` and the
                ``pad_token_id`` / ``bos_token_id`` / ``eos_token_id`` /
                ``sep_token_id`` attributes.
            maxlen: Exact length of every sequence returned by
                ``__getitem__``.
            data_path: Tab-separated file with ``abstract`` and ``output``
                columns.

        Raises:
            ValueError: If ``maxlen`` cannot even hold the special tokens.
        """
        if maxlen < self._NUM_SPECIAL:
            raise ValueError(
                f'maxlen must be at least {self._NUM_SPECIAL} to fit the '
                f'BOS/SEP/EOS tokens, got {maxlen}'
            )

        # Store the contents of the file in a pandas dataframe
        self.df = pd.read_csv(data_path, sep='\t')
        self.maxlen = maxlen
        self.tokenizer = tokenizer

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.df)

    def __getitem__(self, index):
        """Encode row ``index``.

        Returns:
            A ``(input_ids, attention_mask, labels)`` tuple of 1-D tensors of
            length ``maxlen``. ``attention_mask`` is False on padding;
            ``labels`` is a copy of ``input_ids`` with the prompt
            (BOS + abstract + SEP) and the padding set to -100, the default
            ignore index of PyTorch's cross-entropy loss.
        """
        inp_str = self.df.loc[index, 'abstract']
        trg_str = self.df.loc[index, 'output']
        # 1000 is a hard cap on abstract tokens, applied before the
        # maxlen-based truncation below.
        inp_ids = self.tokenizer.encode(inp_str, truncation=True, max_length=1000)
        trg_ids = self.tokenizer.encode(trg_str)
        pad_id = int(self.tokenizer.pad_token_id)
        bos_id = int(self.tokenizer.bos_token_id)
        eos_id = int(self.tokenizer.eos_token_id)
        sep_id = int(self.tokenizer.sep_token_id)

        overflow = (len(inp_ids) + len(trg_ids) + self._NUM_SPECIAL
                    - self.maxlen)
        if overflow > 0:
            # Sacrifice the tail of the abstract first ...
            inp_ids = inp_ids[:max(len(inp_ids) - overflow, 0)]
            # ... and only if the abstract is used up, the tail of the
            # target. When the abstract absorbed the whole overflow this
            # budget equals len(trg_ids) and the slice is a no-op.
            trg_budget = self.maxlen - self._NUM_SPECIAL - len(inp_ids)
            trg_ids = trg_ids[:max(trg_budget, 0)]

        prompt = [bos_id] + inp_ids + [sep_id]
        src = prompt + trg_ids + [eos_id]
        # A non-positive multiplier yields an empty list, so no guard needed.
        src = src + [pad_id] * (self.maxlen - len(src))

        input_ids = torch.tensor(src)
        attn_mask = (input_ids != pad_id)

        labels = input_ids.clone()
        # Mask the whole prompt, including BOS and SEP, in both the
        # truncated and the untruncated case.
        labels[:len(prompt)] = -100
        labels[labels == pad_id] = -100

        # Sanity checks: each special token must occur exactly once, which
        # also catches raw text that happens to encode to a special id.
        assert src.count(bos_id) == 1, f'bos_id missing {src[:3]}'
        assert src.count(eos_id) == 1, 'eos_id missing'
        assert src.count(sep_id) == 1, 'sep_id missing'
        return input_ids, attn_mask, labels


class LM(pl.LightningModule):
    """LightningModule wrapping the CORD-19 GPT-2 checkpoint for fine-tuning.

    The tokenizer is extended with PAD/BOS/EOS/SEP/SEPO special tokens and
    the training data comes from ``CustomDataset``.
    """

    def __init__(self, max_len: int=1000):
        """Load model and tokenizer.

        Args:
            max_len: Sequence length handed to ``CustomDataset``.
        """
        super().__init__()
        self.max_len = max_len

        ######### CHANGED ##########
        pretrained_name = "mrm8488/GPT-2-finetuned-CORD19"
        self.model = GPT2LMHeadModel.from_pretrained(pretrained_name, return_dict=True)
        self.tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_name)

        # Key order is kept as-is: it decides the ids the new tokens receive.
        special_tokens = {
            'pad_token': '<PAD>',
            'bos_token': '<BOS>',
            'eos_token': '<EOS>',
            'sep_token': '<SEP>',
            'additional_special_tokens': ['<SEPO>'],
        }
        self.tokenizer.add_special_tokens(special_tokens)
        # Grow the embedding matrix to cover the tokens added above.
        self.model.resize_token_embeddings(len(self.tokenizer))

        # Disabled: resume from an earlier Lightning checkpoint by stripping
        # the 'model.' prefix from its state-dict keys.
        # checkpoint = torch.load('checkpoint-5-encode-15+-epoch=17-val_loss=0.035.ckpt')
        # if 'state_dict' in checkpoint.keys():
        #     state_dict = checkpoint['state_dict']
        #     new_state_dict = OrderedDict()
        #     for k, v in state_dict.items():
        #         if k[:6] == 'model.':
        #             name = k[6:]
        #         else:
        #             name = k
        #         new_state_dict[name] = v
        #     self.model.load_state_dict(new_state_dict)

    def forward(self, batch):
        """Run the model on an ``(input_ids, attention_mask, labels)`` batch."""
        input_ids, attention_mask, labels = batch
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        """Log and return the loss for one training batch."""
        loss = self.forward(batch).loss
        self.log('train_loss', loss)
        return {'loss': loss}

    # Disabled: validation loop.
    # def validation_step(self, batch, batch_idx):
    #     output = self.forward(batch)
    #     loss = output.loss
    #     self.log('val_loss', loss, prog_bar=True)
    #     return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=1e-4)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

    def train_dataloader(self):
        """Build the shuffled training loader (batch size 2, 2 workers)."""
        dataset = CustomDataset('train', self.tokenizer, self.max_len)
        return torch.utils.data.DataLoader(
            dataset, batch_size=2, num_workers=2, shuffle=True)

    # Disabled: validation data.
    # def val_dataloader(self):
    #     val_dataset = CustomDataset('validation', self.tokenizer, self.max_len)

    #     val_dataloader = torch.utils.data.DataLoader(
    #     val_dataset, batch_size=12, num_workers=2, shuffle=False)

    #     return val_dataloader


if __name__ == "__main__":
    # Train LM, writing checkpoints to a local directory, then copy that
    # directory to ./Checkpoints once training has finished.

    # Local directory the checkpoint callback writes into. makedirs with
    # exist_ok=True creates missing parents and has no exists()/mkdir() race.
    save_dir_path = "/content/Checkpoints"
    os.makedirs(save_dir_path, exist_ok=True)

    # Only the leaf directory is created here (mkdir, not makedirs), so a
    # missing parent still fails loudly instead of silently building a
    # directory tree -- presumably the parent is a mounted Google Drive;
    # confirm.
    drive_save_dir_path = "/content/drive/MyDrive/Muteffstage/Checkpoints"
    try:
        os.mkdir(drive_save_dir_path)
    except FileExistsError:
        pass

    model = LM()
    checkpoint_callback = ModelCheckpoint(
        dirpath=save_dir_path,
        # monitor='val_loss',
        filename='multiple-instances-cord19-{epoch:02d}',
        save_top_k=1,
        # mode='min',
        verbose=True,
        save_weights_only=True
    )
    # Early stopping is disabled because no validation loop is defined.
    # early_stopping_callback = EarlyStopping(
    #     monitor='val_loss',
    #     patience=2,
    #     mode='min',
    #     verbose=True
    # )
    trainer = Trainer(
        gpus=1,
        max_epochs=8,
        precision=16,
        callbacks=[
           checkpoint_callback,
        #    early_stopping_callback
        ]
    )
    trainer.fit(model)

    # NOTE(review): drive_save_dir_path is created above but the copy goes
    # to the relative "Checkpoints" directory -- confirm which destination
    # is intended.
    try:
        from distutils.dir_util import copy_tree
    except ImportError:
        # distutils was removed from the standard library in Python 3.12;
        # copytree with dirs_exist_ok=True merges into an existing target.
        import shutil
        shutil.copytree(save_dir_path, "Checkpoints", dirs_exist_ok=True)
    else:
        copy_tree(save_dir_path, "Checkpoints")