In [1]:
import sys
sys.path.append("/disk2/iping/NTU_ADL/ADL_hw1/hw1_sample_code/src")

import os
import pickle
from argparse import Namespace
from typing import Tuple, Dict
import random

import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import torch.nn.functional as F

from dataset import Seq2SeqDataset
from pytorch_lightning.callbacks import ModelCheckpoint

In [2]:
class Encoder(nn.Module):
    """Bidirectional GRU encoder over pretrained word embeddings.

    The embedding matrix is read from a pickle whose object exposes
    ``.vectors`` and wrapped with ``nn.Embedding.from_pretrained`` (frozen by
    PyTorch's default). NOTE(review): unpickling can execute code — only load
    trusted files.

    Args:
        embedding_path: path to the pickled embedding object.
        emb_dim: GRU input size; must equal the embedding width.
        enc_hid_dim: GRU hidden size per direction.
        enc_layers: number of stacked GRU layers.
        dec_hid_dim: size of the summary state handed to the decoder.
        enc_dropout: dropout probability applied to the input embeddings.
    """

    def __init__(self, embedding_path, emb_dim, enc_hid_dim, enc_layers, dec_hid_dim, enc_dropout):
        super().__init__()
        with open(embedding_path, 'rb') as pkl_file:
            pretrained = pickle.load(pkl_file)
        # Registered first, like the other submodules below, to keep the
        # parameter / state_dict order stable.
        self.embedding = nn.Embedding.from_pretrained(pretrained.vectors)

        self.emb_dim, self.enc_hid_dim = emb_dim, enc_hid_dim
        self.enc_layers, self.enc_dropout = enc_layers, enc_dropout

        self.rnn = nn.GRU(input_size=emb_dim,
                          hidden_size=enc_hid_dim,
                          num_layers=enc_layers,
                          bidirectional=True,
                          batch_first=False)
        # Projects [last forward state; last backward state] to the decoder size.
        self.fc = nn.Linear(2 * enc_hid_dim, dec_hid_dim)
        self.dropout = nn.Dropout(enc_dropout)

    def forward(self, src):
        """Encode a time-major batch of token ids.

        Args:
            src: token ids of shape (src_len, batch) (``batch_first=False``).

        Returns:
            (outputs, summary): the per-step GRU outputs of shape
            (src_len, batch, 2 * enc_hid_dim), and
            ``tanh(fc([h_fwd_last; h_bwd_last]))`` of shape (batch, dec_hid_dim).
        """
        rnn_out, final_states = self.rnn(self.dropout(self.embedding(src)))
        # The last two rows of h_n are the top layer's forward and backward states.
        last_fwd, last_bwd = final_states[-2], final_states[-1]
        summary = torch.tanh(self.fc(torch.cat((last_fwd, last_bwd), dim=1)))
        return rnn_out, summary


class Decoder(nn.Module):
    """Single-step GRU decoder (no attention) over pretrained word embeddings.

    Each call consumes one token per batch element and returns vocabulary
    logits plus the updated recurrent state.

    Args:
        embedding_path: path to a pickle whose object exposes ``.vectors``;
            wrapped with ``nn.Embedding.from_pretrained`` (frozen by default).
            NOTE(review): unpickling can execute code — only load trusted files.
        output_dim: logit dimension (target vocabulary size).
        emb_dim: GRU input size; must equal the embedding width.
        enc_hid_dim: unused here; kept so the constructor signature is unchanged.
        dec_hid_dim: GRU hidden size.
        dec_dropout: dropout probability applied to the input embeddings.
        dec_num_layers: number of stacked GRU layers.
    """

    def __init__(self, embedding_path, output_dim, emb_dim, enc_hid_dim, dec_hid_dim, dec_dropout, dec_num_layers):
        super().__init__()

        with open(embedding_path, 'rb') as f:
            embedding = pickle.load(f)
        self.embedding = nn.Embedding.from_pretrained(embedding.vectors)

        self.output_dim = output_dim

        # Submodules are created in the same order as before so checkpoints
        # (state_dict keys) and seeded initialisation stay compatible.
        self.rnn = nn.GRU(emb_dim, dec_hid_dim, num_layers=dec_num_layers, batch_first=False)
        self.fc_out = nn.Linear(dec_hid_dim, output_dim)
        self.dropout = nn.Dropout(dec_dropout)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: token ids of shape (batch,).
            hidden: either a single state of shape (batch, dec_hid_dim), such
                as the encoder summary, or the state returned by a previous
                call.

        Returns:
            (prediction, hidden): logits of shape (batch, output_dim) computed
            from the top GRU layer, and the new state. For a one-layer GRU the
            state has shape (batch, dec_hid_dim), as before. With more layers
            it stays (num_layers, batch, dec_hid_dim) so it can be fed back in
            unchanged.
        """
        # Add the length-1 time axis expected by the time-major GRU.
        embedded = self.dropout(self.embedding(input.unsqueeze(0)))

        if hidden.dim() == 2:
            # A single (batch, hid) state seeds every layer. For one layer this
            # is just the (1, batch, hid) view the GRU expects.
            hidden = hidden.unsqueeze(0).expand(self.rnn.num_layers, -1, -1).contiguous()

        output, hidden = self.rnn(embedded, hidden)

        # One time step: output[0] is the top layer's new state.
        prediction = self.fc_out(output.squeeze(0))

        # squeeze(0) only drops the layer axis when there is a single layer.
        return prediction, hidden.squeeze(0)

class Seq2Seq(pl.LightningModule):
    """GRU encoder-decoder (no attention) for abstractive summarization.

    The encoder's projected bidirectional final state seeds the decoder.
    Tensors inside the model are time-major: ``src``/``trg`` are
    (seq_len, batch).

    ``hparams`` must provide: embedding_path, emb_dim, enc_hid_dim,
    enc_num_layers, enc_dropout, output_dim, dec_hid_dim, dec_dropout,
    dec_num_layers, ignore_idx, lr, batch_size, train_dataset_path and
    valid_dataset_path.
    """

    def __init__(self, hparams) -> None:
        super(Seq2Seq, self).__init__()
        # NOTE(review): direct assignment matches the older Lightning API this
        # notebook targets (ModelCheckpoint(filepath=..., prefix=...)); newer
        # releases require self.save_hyperparameters(hparams) instead.
        self.hparams = hparams
        # Target positions equal to ignore_idx contribute nothing to the loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.hparams.ignore_idx)
        self.encoder = Encoder(hparams.embedding_path,
                               hparams.emb_dim,
                               hparams.enc_hid_dim,
                               hparams.enc_num_layers,
                               hparams.dec_hid_dim,
                               hparams.enc_dropout)
        self.decoder = Decoder(hparams.embedding_path,
                               hparams.output_dim,
                               hparams.emb_dim,
                               hparams.enc_hid_dim,
                               hparams.dec_hid_dim,
                               hparams.dec_dropout,
                               hparams.dec_num_layers)

    def forward(self, src, trg, teacher_forcing_ratio=0.5, max_len=80, bos_idx=1):
        """Decode a batch, either with teacher forcing or greedily.

        Args:
            src: source token ids, shape (src_len, batch).
            trg: target token ids of shape (trg_len, batch) for
                training/validation, or the string ``'test'`` to decode
                without targets.
            teacher_forcing_ratio: probability, per step, of feeding the
                ground-truth token instead of the model's argmax prediction.
                Ignored in ``'test'`` mode.
            max_len: number of output rows in ``'test'`` mode. Row 0 is
                included, so ``max_len - 1`` decoding steps are run.
            bos_idx: token id fed at the first step in ``'test'`` mode.
                NOTE(review): assumed to be the start-of-sequence id — confirm
                against the vocab.

        Returns:
            Logits of shape (trg_len or max_len, batch, output_dim). Row 0 is
            never written and stays all zeros.
        """
        inference = isinstance(trg, str) and trg == 'test'
        batch_size = src.shape[1]
        trg_vocab_size = self.decoder.output_dim
        # Follow the input's device instead of a hard-coded 'cuda:0'.
        device = src.device

        if inference:
            trg_len = max_len
            dec_input = torch.full((batch_size,), bos_idx, dtype=torch.int64, device=device)
        else:
            trg_len = trg.shape[0]
            dec_input = trg[0, :]

        outputs = torch.zeros(trg_len, batch_size, trg_vocab_size, device=device)
        _, hidden = self.encoder(src)
        for t in range(1, trg_len):
            output, hidden = self.decoder(dec_input, hidden)
            outputs[t] = output
            # Teacher-force with probability teacher_forcing_ratio. The old
            # `>` test did the opposite.
            if not inference and random.random() < teacher_forcing_ratio:
                dec_input = trg[t]
            else:
                dec_input = output.argmax(1)
        return outputs

    def _unpack_batch(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(text, summary)``; ``summary`` is None when the batch has none.

        Always returning a pair keeps ``x, _ = self._unpack_batch(batch)`` valid
        for unlabeled batches too.
        """
        try:
            summary = batch['summary']
        except KeyError:
            summary = None
        return batch['text'], summary

    def _calculate_loss(self, output, trg) -> torch.Tensor:
        """Token-level cross-entropy between decoder logits and target ids.

        Args:
            output: logits from ``forward``, shape (trg_len, batch, vocab).
            trg: target ids, shape (trg_len, batch).

        Position 0 is dropped from both, because ``forward`` never predicts
        it. Targets equal to ``hparams.ignore_idx`` are ignored by the
        criterion.
        """
        vocab_size = output.shape[-1]
        logits = output[1:].reshape(-1, vocab_size)
        targets = trg[1:].reshape(-1)
        return self.criterion(logits, targets)

    def _shared_step(self, batch) -> torch.Tensor:
        """Compute the teacher-forced loss for one labeled batch."""
        x, y = self._unpack_batch(batch)
        # Batches arrive batch-major; the model is time-major.
        x = x.permute(1, 0)
        y = y.permute(1, 0)
        output = self.forward(x, y)
        return self._calculate_loss(output, y)

    def training_step(self, batch, batch_nb) -> Dict:
        loss = self._shared_step(batch)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb) -> Dict:
        # Uses the default random teacher forcing, so val_loss is stochastic.
        loss = self._shared_step(batch)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs) -> Dict:
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        # 'avg_val_loss' is the key monitored for checkpointing.
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def _load_dataset(self, dataset_path: str) -> Seq2SeqDataset:
        """Unpickle a preprocessed dataset.

        NOTE(review): pickle.load can execute arbitrary code — only use
        trusted files.
        """
        with open(dataset_path, 'rb') as f:
            dataset = pickle.load(f)
        return dataset

    def train_dataloader(self):
        dataset = self._load_dataset(self.hparams.train_dataset_path)
        return DataLoader(dataset,
                          self.hparams.batch_size,
                          shuffle=True,
                          collate_fn=dataset.collate_fn)

    def val_dataloader(self):
        dataset = self._load_dataset(self.hparams.valid_dataset_path)
        return DataLoader(dataset,
                          self.hparams.batch_size,
                          collate_fn=dataset.collate_fn)
    
class MyPrintingCallback(pl.Callback):
    """Lightning callback that prints a marker line when validation begins and ends."""

    @staticmethod
    def _say(message):
        # Single place for the console output of both hooks.
        print(message)

    def on_validation_start(self, trainer, pl_module):
        self._say('validation starts')

    def on_validation_end(self, trainer, pl_module):
        self._say('validation end')

In [3]:
# Root folder of the preprocessed (pickled) data files.
data_path = '/disk2/iping/NTU_ADL/ADL_hw1/data'
hparams = Namespace(
    # data files
    embedding_path=data_path + "/embedding_seq2seq.pkl",
    train_dataset_path=data_path + "/train_seq2seq.pkl",
    valid_dataset_path=data_path + "/valid_seq2seq.pkl",

    # loss / batching: targets equal to ignore_idx are excluded from the loss
    ignore_idx=0,
    batch_size=64,

    # encoder (emb_dim must equal the pretrained embedding width)
    emb_dim=300,
    enc_hid_dim=512,
    enc_num_layers=1,
    enc_dropout=0,

    # decoder (output_dim equals the vocabulary size, i.e. the embedding rows)
    output_dim=97513,
    dec_hid_dim=512,
    dec_dropout=0,
    dec_num_layers=1,

    # optimizer
    lr=1e-04,
)

In [4]:
# Checkpoint filename template; Lightning fills in the epoch number.
PATH_checkpoint = "/disk2/iping/NTU_ADL/ADL_hw1/seq2seq_model/"
PATH_checkpoint += "model_drop0_lr1e04_noembed_{epoch:02d}"

# Keep only the single best epoch by average validation loss (the
# 'avg_val_loss' key returned from Seq2Seq.validation_epoch_end).
checkpoint_callback = ModelCheckpoint(
    filepath=PATH_checkpoint,
    save_top_k=1,  # number of checkpoints to keep: an int, not the bool True
    verbose=True,
    monitor='avg_val_loss',
    mode='min',
    prefix=''
)

seq2seq = Seq2Seq(hparams)
trainer = pl.Trainer(gpus=1, max_epochs=20, checkpoint_callback=checkpoint_callback)
trainer.fit(seq2seq)

  from ._conv import register_converters as _register_converters


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=313.0, style=Pr…




1

In [4]:
from tqdm import tqdm  # progress bar used by the decoding loop

# Restore the trained model.
# NOTE(review): this filename does not follow the
# `model_drop0_lr1e04_noembed_{epoch:02d}` template used when training;
# presumably a hand-renamed copy of the chosen epoch — confirm.
PATH_checkpoint = "/disk2/iping/NTU_ADL/ADL_hw1/seq2seq_model/" + "best_seq2seq_abstractive_model.ckpt"
seq2seq = Seq2Seq.load_from_checkpoint(PATH_checkpoint)

In [5]:
# Display the restored module tree to check the architecture against hparams.
seq2seq

Seq2Seq(
  (criterion): CrossEntropyLoss()
  (encoder): Encoder(
    (embedding): Embedding(97513, 300)
    (rnn): GRU(300, 512, bidirectional=True)
    (fc): Linear(in_features=1024, out_features=512, bias=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(97513, 300)
    (rnn): GRU(300, 512)
    (fc_out): Linear(in_features=512, out_features=97513, bias=True)
    (dropout): Dropout(p=0, inplace=False)
  )
)

In [5]:
# Greedy decoding over the validation split with a manual loop (fallback for
# Lightning's test loop). Result: `ans`, one list of predicted token ids per
# example, in dataset order (the loader does not shuffle).
with open(data_path + "/valid_seq2seq.pkl", 'rb') as f:
    dataset = pickle.load(f)
test_loader = DataLoader(dataset,
                         16,
                         collate_fn=dataset.collate_fn)
device = 'cuda:0'

# Move the model once (not once per batch) and switch off dropout.
seq2seq.to(device)
seq2seq.eval()

ans = []
# Gradients are not needed. Without no_grad, every batch would keep an
# autograd graph spanning all of its decoding steps.
with torch.no_grad():
    for batch in tqdm(test_loader, desc='Test'):
        x, _ = seq2seq._unpack_batch(batch)
        x = x.to(device).permute(1, 0)   # batch-major -> time-major (seq, batch)
        output = seq2seq(x, 'test', 0)   # (decode_len, batch, vocab) logits
        # argmax directly on the logits: log-softmax is monotonic, so applying
        # it first does not change the chosen ids.
        pred_ids = output.argmax(dim=2).permute(1, 0)   # (batch, decode_len)
        ans.extend(pred_ids.tolist())

Test:   0%|          | 0/1250 [00:00<?, ?it/s]

cuda:0


Test: 100%|██████████| 1250/1250 [04:49<00:00,  4.32it/s]


In [6]:
# Reload the embedding object (its `.vocab` maps ids to words) and the raw
# validation examples (their 'id' fields label each prediction).
data_path = '/disk2/iping/NTU_ADL/ADL_hw1/data'
with open(f"{data_path}/embedding_seq2seq.pkl", 'rb') as pkl_file:
    embed_dataset = pickle.load(pkl_file)
with open(f"{data_path}/valid_seq2seq.pkl", 'rb') as pkl_file:
    valid_dataset = pickle.load(pkl_file)

In [19]:
import numpy as np
# Turn each predicted id sequence into text and pair it with the validation
# example at the same position (`ans` was produced in dataset order).
ans_jsonl = []
for example, pred_ids in zip(valid_dataset, ans):
    words = []
    for tok_id in pred_ids:
        if tok_id == 2:
            # Id 2 ends the sequence (presumably the end-of-sequence token;
            # confirm against the vocab). Only on this path are 6 trailing and
            # then 7 leading characters cut off. NOTE(review): these offsets
            # are hand-tuned to drop header/trailer tokens; a sequence with no
            # id 2 keeps them, plus a trailing space.
            now_sent = ''.join(words)[:-6][7:]
            break
        words.append(embed_dataset.vocab[tok_id] + ' ')
    else:
        now_sent = ''.join(words)
    ans_jsonl.append({'id': example['id'], 'predict': now_sent})

In [21]:
import json
# Write the predictions as JSON Lines: one {"id", "predict"} object per line.
PATH_save = "/disk2/iping/NTU_ADL/ADL_hw1/for_testing/"
filename_save = 'output_seq2seq_lr1e04_nounk_noheader'
with open(f"{PATH_save}{filename_save}.jsonl", 'w') as outfile:
    outfile.writelines(json.dumps(record) + '\n' for record in ans_jsonl)