<a href="https://colab.research.google.com/github/GrainSack/ML/blob/main/Attention_based_RNN_Seq2Seq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so that files under /content/drive are reachable from this Colab runtime.
from google.colab import drive
drive.mount("/content/drive")
import os
import sys
from datetime import datetime

# Placeholder: replace with the project directory on the mounted drive.
# It is appended to sys.path (so modules stored there can be imported) and is
# reused later as the root for run output directories.
drive_project_root = "Your-path"
sys.path.append(drive_project_root)
# Install the dependencies listed in the requirements file stored on the drive.
!pip install -r "/content/drive/MyDrive/requirements.txt"

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

In [None]:
# Fallback for when PyTorch Lightning does not work on Google Colab: install it from source.
!pip install git+https://github.com/PyTorchLightning/pytorch-lightning
import pytorch_lightning as pl
print(pl.__version__)

# For data loading.
from typing import List
from typing import Dict
from typing import Union
from typing import Any
from typing import Optional
from typing import Iterable
from typing import Callable
from abc import abstractmethod
from abc import ABC
from datetime import datetime
from functools import partial
from collections import Counter
from collections import OrderedDict
import random
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pprint import pprint

from torchtext import data
from torchtext import datasets
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import Vocab, build_vocab_from_iterator, vocab
import spacy

# For configuration
from omegaconf import DictConfig
from omegaconf import OmegaConf
import hydra
from hydra.core.config_store import ConfigStore

# For logger
from torch.utils.tensorboard import SummaryWriter
import wandb
os.environ["WANDB_START_METHOD"]="thread"

In [None]:
from data_utils import dataset_split
from config_utils import flatten_dict
from config_utils import register_config
from config_utils import configure_optimizers_from_cfg
from config_utils import get_loggers
from config_utils import get_callbacks

In [None]:
# Download the spaCy language data for English and German.
# Both the short names (`en`, `de`) and the full model names are requested.
!python -m spacy download en
!python -m spacy download en_core_web_sm
!python -m spacy download de
!python -m spacy download de_core_news_sm

In [None]:
# Data configuration for German -> English translation.
#   tokenizer / src_lang / tgt_lang: passed to torchtext's `get_tokenizer`.
#   src_index / tgt_index: position of each language inside a dataset sample
#       (used by `yield_tokens` to pick the right sentence of a pair).
#   vocab.special_symbol2index: ids of the special tokens; <bos>/<eos> are added
#       around every sentence and <pad> is used as the padding value for batches.
#   vocab.min_freq: minimum frequency passed to `build_vocab_from_iterator`.
data_spacy_de_en_cfg = {
    "name": "spacy_de_en",
    "data_root": os.path.join(os.getcwd(), "data"),
    "tokenizer": "spacy",
    "src_lang": "de",
    "tgt_lang": "en",
    "src_index": 0,
    "tgt_index": 1,
    "vocab": {
        "special_symbol2index": {
            "<unk>": 0,
            "<pad>": 1,
            "<bos>": 2,
            "<eos>": 3,
        },
        "special_first": True,
        "min_freq": 2
    }
}

# Wrap the plain dict in an OmegaConf config so fields can be read with attribute access.
data_cfg = OmegaConf.create(data_spacy_de_en_cfg)

In [None]:
# Multi30k is called with `data_cfg.data_root` as its only argument and the
# result is unpacked into the three splits.
train_data, valid_data, test_data = Multi30k(data_cfg.data_root)
# Pass the test split through torchtext's `to_map_style_dataset` before using it below.
test_data = to_map_style_dataset(test_data)

In [None]:
# Print only the first sample of the test split as a quick sanity check.
for i in test_data:
    print(i)
    break

In [None]:
# 1. token_transform (token ...)

def get_token_transform(data_cfg: DictConfig) -> dict:
    """Create one tokenizer per language.

    Args:
        data_cfg: data configuration providing ``tokenizer``, ``src_lang``
            and ``tgt_lang``.

    Returns:
        A dict mapping the source and target language codes to the tokenizer
        returned by ``get_tokenizer`` for that language.
    """
    return {
        lang: get_tokenizer(data_cfg.tokenizer, language=lang)
        for lang in (data_cfg.src_lang, data_cfg.tgt_lang)
    }

# Module-level tokenizers; read as a global by `yield_tokens` and `get_text_transform`.
token_transform = get_token_transform(data_cfg)


In [None]:
# 2. vocab_transform 
def yield_tokens(
    data_iter: Iterable, lang: str, lang2index: Dict[str, int],
) -> Iterable[List[str]]:
    """Lazily tokenize one language of every sample in ``data_iter``.

    This is a generator: it yields one token list per sample rather than
    returning a single list.

    Args:
        data_iter: iterable of samples; each sample is indexable and holds the
            sentences of a translation pair.
        lang: language code selecting both the tokenizer (from the module-level
            ``token_transform``) and the sentence inside each sample.
        lang2index: maps a language code to the position of its sentence
            inside a sample.

    Yields:
        The tokens of the ``lang`` sentence of each sample.
    """
    # Resolve the per-language lookups once instead of once per sample.
    tokenize = token_transform[lang]
    sentence_index = lang2index[lang]
    for data_sample in data_iter:
        yield tokenize(data_sample[sentence_index])


def get_vocab_transform(data_cfg: DictConfig) -> dict:
    """Build a torchtext vocabulary for the source and the target language.

    Each vocabulary is built from the Multi30k training split, tokenized with
    ``yield_tokens``.

    Args:
        data_cfg: data configuration providing the language codes, the sample
            indices of each language and the ``vocab`` section
            (``special_symbol2index``, ``min_freq``, optional ``special_first``).

    Returns:
        A dict mapping each language code to its vocabulary object.
    """
    language_pair = (data_cfg.src_lang, data_cfg.tgt_lang)
    lang2index = {
        data_cfg.src_lang: data_cfg.src_index,
        data_cfg.tgt_lang: data_cfg.tgt_index,
    }

    symbol2index = data_cfg.vocab.special_symbol2index
    # Hand the special symbols over in the order of their configured ids rather
    # than relying on the key order of the config mapping.
    specials = sorted(symbol2index.keys(), key=lambda symbol: symbol2index[symbol])
    # Honour the configured flag (it used to be hard-coded to True); fall back
    # to True when the config does not define it.
    special_first = data_cfg.vocab.get("special_first", True)

    vocab_transform: dict = {}
    for ln in language_pair:
        # A fresh iterator over the training split is created for every language.
        train_iter = Multi30k(split="train", language_pair=language_pair)

        vocab_transform[ln] = build_vocab_from_iterator(
            yield_tokens(train_iter, ln, lang2index),
            min_freq=data_cfg.vocab.min_freq,
            specials=specials,
            special_first=special_first,
        )
        # Return the <unk> id for tokens that are not in the vocabulary.
        # Without a default index, looking up an unknown token can raise a
        # runtime error.
        vocab_transform[ln].set_default_index(symbol2index["<unk>"])

    return vocab_transform

# Module-level vocabularies; read as a global by `get_text_transform` and by the Lightning module.
vocab_transform = get_vocab_transform(data_cfg)

In [None]:
# Sanity check: print the ids of "<unk>" in both vocabularies and of two English words.
print(vocab_transform["de"]["<unk>"])
print(vocab_transform["en"]["<unk>"])
print(vocab_transform["en"]["hello"], vocab_transform["en"]["world"])

In [None]:
# 3 integrated transforms
# text_transform: [token_transform --> vocab_transform --> torch.tensor transform]


# helper function for collate_fn
def sequential_transforms(*transforms):
    """Compose ``transforms`` into one callable applied left to right.

    Args:
        *transforms: callables taking one argument; the output of each one is
            fed to the next.

    Returns:
        A function that pipes its input through all ``transforms`` in order.
        With no transforms, the input is returned unchanged.
    """
    def _pipeline(txt_input):
        result = txt_input
        for step in transforms:
            result = step(result)
        return result

    return _pipeline

# convert to torch.tensor with bos & eos
def tensor_transform(token_ids: List[int], bos_index: int, eos_index: int):
    """Wrap token ids with the begin/end-of-sentence ids and return a tensor.

    Args:
        token_ids: vocabulary ids of one sentence (may be empty).
        bos_index: id placed before the sentence.
        eos_index: id placed after the sentence.

    Returns:
        A 1-D ``torch.long`` tensor ``[bos_index, *token_ids, eos_index]``.
    """
    # Build the tensor in one go with an explicit integer dtype: converting an
    # empty id list on its own yields a float tensor, which would turn the
    # whole concatenated sequence into floats.
    return torch.tensor([bos_index, *token_ids, eos_index], dtype=torch.long)

# src & tgt lang language text_transforms to convert raw strings --> tensor indices
def get_text_transform(data_cfg: DictConfig):
    """Build the raw-string --> id-tensor pipeline for both languages.

    For each language the pipeline chains the module-level tokenizer, the
    module-level vocabulary lookup and ``tensor_transform`` (which adds the
    configured <bos>/<eos> ids).

    Args:
        data_cfg: data configuration providing the language codes and the
            special symbol ids.

    Returns:
        A dict mapping each language code to its composed transform.
    """
    special_ids = data_cfg.vocab.special_symbol2index
    return {
        ln: sequential_transforms(
            token_transform[ln],
            vocab_transform[ln],
            partial(
                tensor_transform,
                bos_index=special_ids["<bos>"],
                eos_index=special_ids["<eos>"],
            ),
        )
        for ln in (data_cfg.src_lang, data_cfg.tgt_lang)
    }

# Module-level text pipelines; read as a global by `collate_fn`.
text_transform = get_text_transform(data_cfg)

In [None]:
# Sanity check: print the id tensors produced for progressively longer English inputs.
print(text_transform["en"]("hello"))
print(text_transform["en"]("hello,"))
print(text_transform["en"]("hello, how"))
print(text_transform["en"]("hello, how are you ?"))

In [None]:
# 4. collate_fn --> how each batch is preprocessed.
def collate_fn(batch, data_cfg: DictConfig):
    """Turn a batch of raw sentence pairs into two padded id tensors.

    Args:
        batch: iterable of ``(src_sentence, tgt_sentence)`` string pairs.
        data_cfg: data configuration providing the language codes and the
            <pad> id.

    Returns:
        ``(src_batch, tgt_batch)``: the encoded sentences of each side, padded
        with the <pad> id by ``pad_sequence``.
    """
    pad_index = data_cfg.vocab.special_symbol2index["<pad>"]
    encode_src = text_transform[data_cfg.src_lang]
    encode_tgt = text_transform[data_cfg.tgt_lang]

    src_tensors, tgt_tensors = [], []
    for src_sample, tgt_sample in batch:
        # Trailing newlines are removed before the text is tokenized.
        src_tensors.append(encode_src(src_sample.rstrip("\n")))
        tgt_tensors.append(encode_tgt(tgt_sample.rstrip("\n")))

    return (
        pad_sequence(src_tensors, padding_value=pad_index),
        pad_sequence(tgt_tensors, padding_value=pad_index),
    )

def get_collate_fn(cfg: DictConfig):
    """Return ``collate_fn`` with ``data_cfg`` bound to ``cfg.data``."""
    data_section = cfg.data
    return partial(collate_fn, data_cfg=data_section)

# 5 data loader 
def get_multi30k_dataloader(split_mode: str, language_pair: tuple, batch_size: int, collate_fn: Callable):
    """Create a DataLoader over one Multi30k split.

    Args:
        split_mode: split name passed to ``Multi30k`` (this notebook uses
            "train", "valid" and "test").
        language_pair: ``(source_language, target_language)`` codes.
        batch_size: number of samples per batch.
        collate_fn: function that turns a list of samples into a batch.

    Returns:
        A ``DataLoader`` over the map-style version of the requested split.
    """
    raw_split = Multi30k(split=split_mode, language_pair=language_pair)
    return DataLoader(
        to_map_style_dataset(raw_split),
        batch_size=batch_size,
        collate_fn=collate_fn,
    )

# Small demo loader over the test split (batch size 3) used to inspect one batch below.
test_dataloader = get_multi30k_dataloader("test", (data_cfg.src_lang, data_cfg.tgt_lang), 3, collate_fn=partial(collate_fn, data_cfg=data_cfg))

In [None]:
# Print only the first collated batch as a quick sanity check.
for i in test_dataloader:
    print(i)
    break

In [None]:
SEED = 1234

# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and CUDA) with the same value, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

# Ask cuDNN to use deterministic algorithms.
torch.backends.cudnn.deterministic = True

In [None]:
def _text_postprocessing(res: List[str]) -> str:
    if "<eos>" in res:
        res = res[:res.index("<eos>")]
    if "<pad>" in res:
        res = res[:res.index["<pad>"]]
    res = " ".join(res).replace("<bos>", "")
    return res


class BaseTranslateLightningModule(pl.LightningModule):
    """Shared training / validation / test logic for the translation models.

    Subclasses implement ``forward(src, tgt, teacher_forcing_ratio)`` and
    return logits of shape ``[tgt_len, batch_size, tgt_vocab_size]`` in which
    row ``t`` scores target position ``t``. The decoding loops in this file
    start at ``t = 1``, so row 0 is never predicted and stays zero.
    """

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        # Padding positions do not contribute to the loss.
        self.loss_function = torch.nn.CrossEntropyLoss(
            ignore_index=cfg.data.vocab.special_symbol2index["<pad>"]
        )

    def configure_optimizers(self):
        """Build optimizers and schedulers with ``configure_optimizers_from_cfg``."""
        self._optimizers, self._schedulers = configure_optimizers_from_cfg(
            self.cfg, self
        )
        return self._optimizers, self._schedulers

    @abstractmethod
    def forward(self, src, tgt, teacher_forcing_ratio: float):
        raise NotImplementedError()

    def _ids_to_texts(self, token_ids, lang: str) -> List[str]:
        """Decode a ``[seq_len, batch_size]`` id tensor into one string per sample."""
        vocab = vocab_transform[lang]
        # Transpose to [batch_size, seq_len] so that each row is one sentence.
        rows = torch.transpose(token_ids, 0, 1).detach().cpu().numpy().tolist()
        return [_text_postprocessing(vocab.lookup_tokens(row)) for row in rows]

    def _forward(self, src, tgt, mode: str, teacher_forcing_ratio: float = 0.5):
        """Run the model on one batch and compute the loss.

        Teacher forcing is widely used in seq2seq training: when the decoder
        is trained purely autoregressively, early positions are learned
        quickly but later positions have to wait until the earlier ones are
        right. Feeding the ground-truth token at random lets later positions
        be learned too. A ratio of 0.5 means the ground-truth token is fed
        with probability 0.5 at each step.

        Args:
            src: source ids, ``[src_len, batch_size]``.
            tgt: target ids, ``[tgt_len, batch_size]``, starting with <bos>.
            mode: one of "train", "val" or "test"; used as the key prefix.
            teacher_forcing_ratio: probability of feeding the ground-truth
                token at each decoding step.

        Returns:
            ``(logs, logs_detail)``: ``logs`` holds ``{mode}_loss``;
            ``logs_detail`` holds the inputs and logits and, for "val" and
            "test", the decoded source / target / predicted texts.
        """
        assert mode in ["train", "val", "test"]

        # The model writes the prediction for target position t into row t and
        # leaves row 0 (the <bos> position) untouched. Decode over the full
        # target and drop row 0 on both sides so that every prediction is
        # compared with the token it is meant to predict. Passing a shortened
        # target and comparing against tgt[1:] paired each prediction with the
        # token one step later.
        outputs = self(src, tgt, teacher_forcing_ratio=teacher_forcing_ratio)
        predictions = outputs[1:]  # [tgt_len - 1, batch_size, vocab_size]
        tgt_outputs = tgt[1:, :]   # same positions, without the start tokens

        loss = self.loss_function(
            predictions.reshape(-1, predictions.shape[-1]),  # [(seq * batch), vocab_size]
            tgt_outputs.reshape(-1),
        )

        logs_detail = {
            f"{mode}_src": src,
            f"{mode}_tgt": tgt,
            f"{mode}_results": predictions,
        }

        if mode in ["val", "test"]:
            # Greedy choice per position.
            _, tgt_results = torch.max(predictions, dim=2)

            src_texts = self._ids_to_texts(src, self.cfg.data.src_lang)
            tgt_texts = self._ids_to_texts(tgt, self.cfg.data.tgt_lang)
            res_texts = self._ids_to_texts(tgt_results, self.cfg.data.tgt_lang)

            text_result_summary = {
                f"{mode}_src_text": src_texts,
                f"{mode}_tgt_text": tgt_texts,
                f"{mode}_results_text": res_texts,
            }
            print(f"{self.global_step} step: \n src_text: {src_texts[0]}, \n tgt_text: {tgt_texts[0]}, \n result_text: {res_texts[0]}")
            logs_detail.update(text_result_summary)

        return {f"{mode}_loss": loss}, logs_detail

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss with the configured teacher forcing ratio."""
        src, tgt = batch[0], batch[1]
        logs, _ = self._forward(src, tgt, "train", self.cfg.model.teacher_forcing_ratio)
        self.log_dict(logs)
        logs["loss"] = logs["train_loss"]
        return logs

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss without teacher forcing."""
        src, tgt = batch[0], batch[1]
        logs, logs_detail = self._forward(src, tgt, "val", 0.0)
        self.log_dict(logs)
        logs["loss"] = logs["val_loss"]
        logs.update(logs_detail)
        return logs

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss without teacher forcing."""
        src, tgt = batch[0], batch[1]
        logs, logs_detail = self._forward(src, tgt, "test", 0.0)
        self.log_dict(logs)
        logs["loss"] = logs["test_loss"]
        logs.update(logs_detail)
        return logs

In [None]:
# utils for initialization
def init_weights(model: Union[nn.Module, pl.LightningModule]):
    """Re-initialize every parameter of ``model`` uniformly in [-0.08, 0.08].

    All parameters reachable from ``model`` are overwritten in place.
    """
    for weight in model.parameters():
        nn.init.uniform_(weight.data, -0.08, 0.08)

In [None]:
# model definition

# 1. encoder (??)
class LSTMEncoder(nn.Module):
    """Embeds a source sequence and encodes it with a multi-layer LSTM.

    Only the final hidden and cell states are returned; they are used to
    initialise the decoder.
    """

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(dropout)

        # Overwrite the default initialisation of all parameters.
        self.apply(init_weights)

    def forward(self, src):
        """Encode ``src`` of shape ``[seq_len, batch_size]``.

        Returns:
            ``(hidden, cell)``, each ``[n_layers, batch_size, hidden_dim]``.
        """
        token_vectors = self.dropout(self.embedding(src))  # [seq_len, batch_size, embed_dim]
        # The per-step outputs are not needed by the plain seq2seq model.
        _, (hidden, cell) = self.rnn(token_vectors)
        return hidden, cell

# 2. decoder (??)
class LSTMDecoder(nn.Module):
    """Single-step LSTM decoder: one input token per sample in, vocabulary logits out."""

    def __init__(
        self,
        output_dim: int,
        embed_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.output_dim = output_dim

        self.embedding = nn.Embedding(output_dim, embed_dim)
        self.rnn = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=dropout,
        )
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        """Decode one time step.

        Args:
            input: token ids of the current step, ``[batch_size]``.
            hidden: LSTM hidden state, ``[n_layers, batch_size, hidden_dim]``.
            cell: LSTM cell state, same shape as ``hidden``.

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` is
            ``[batch_size, output_dim]`` and the states are the updated ones.
        """
        # Add a sequence dimension of length 1: [1, batch_size].
        step_tokens = input.unsqueeze(0)
        step_vectors = self.dropout(self.embedding(step_tokens))  # [1, batch_size, embed_dim]

        rnn_out, (hidden, cell) = self.rnn(step_vectors, (hidden, cell))

        # Drop the sequence dimension again before projecting to the vocabulary.
        return self.fc_out(rnn_out.squeeze(0)), hidden, cell

# 3. Seq2Seq(cfg) <-- encoder + decoder
class LSTMSeq2Seq(BaseTranslateLightningModule):
    """Encoder-decoder translation model built from ``LSTMEncoder`` and ``LSTMDecoder``."""

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)

        self.encoder = LSTMEncoder(**cfg.model.enc)
        self.decoder = LSTMDecoder(**cfg.model.dec)

        # The encoder states are fed straight into the decoder, so both must
        # have the same state layout.
        assert self.encoder.hidden_dim == self.decoder.hidden_dim
        assert self.encoder.n_layers == self.decoder.n_layers

        # Initialise all parameters.
        self.apply(init_weights)

    def forward(self, src, tgt, teacher_forcing_ratio: float = 0.5):
        """Decode ``tgt.shape[0]`` steps, starting from the first row of ``tgt``.

        Args:
            src: source ids, ``[src_len, batch_size]``.
            tgt: target ids, ``[tgt_len, batch_size]``; row 0 holds the start
                tokens. The two lengths may differ.
            teacher_forcing_ratio: probability of feeding the ground-truth
                token instead of the model's own prediction at each step
                (should be 0.0 for validation and test).

        Returns:
            Logits of shape ``[tgt_len, batch_size, tgt_vocab_size]``; row 0 is
            left as zeros.
        """
        n_steps = tgt.shape[0]
        n_samples = tgt.shape[1]

        # Buffer collecting the decoder logits of every step.
        logits = torch.zeros(n_steps, n_samples, self.decoder.output_dim).to(self.device)

        hidden, cell = self.encoder(src)

        decoder_input = tgt[0, :]
        for step in range(1, n_steps):
            step_logits, hidden, cell = self.decoder(decoder_input, hidden, cell)
            logits[step] = step_logits

            # Either feed the ground-truth token or the most likely prediction.
            if random.random() < teacher_forcing_ratio:
                decoder_input = tgt[step]
            else:
                decoder_input = step_logits.argmax(1)

        return logits



In [None]:
# Define a new model based on concat (additive) attention.
# The encoder and decoder RNNs may differ!

class BidirectionalGRUEncoder(nn.Module):
    """Bidirectional GRU encoder that also produces the decoder's initial state."""

    def __init__(
        self,
        input_dim: int,
        embed_dim: int,
        enc_hidden_dim: int,
        dec_hidden_dim: int,
        n_layers: int,
        dropout: float
    ):
        super().__init__()
        self.input_dim = input_dim
        self.n_layers = n_layers

        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.rnn = nn.GRU(
            input_size=embed_dim,
            hidden_size=enc_hidden_dim,
            num_layers=n_layers,
            bidirectional=True,
            dropout=dropout,
        )
        # Maps the concatenated forward/backward state to the decoder size.
        self.fc = nn.Linear(enc_hidden_dim * 2, dec_hidden_dim)
        self.dropout = nn.Dropout(dropout)

        # Overwrite the default initialisation of all parameters.
        self.apply(init_weights)

    def forward(self, src):
        """Encode ``src`` of shape ``[seq_len, batch_size]``.

        Returns:
            ``(outputs, hidden)``: ``outputs`` is
            ``[seq_len, batch_size, enc_hidden_dim * 2]`` and ``hidden`` is
            ``[batch_size, dec_hidden_dim]``.
        """
        token_vectors = self.dropout(self.embedding(src))  # [seq_len, batch_size, embed_dim]

        outputs, final_states = self.rnn(token_vectors)
        # final_states is ordered [forward_1, backward_1, forward_2, backward_2, ...].
        # Only the last layer is needed: its forward and backward states are
        # concatenated into a single vector per sample.
        last_forward = final_states[-2]
        last_backward = final_states[-1]
        joined = torch.cat((last_forward, last_backward), dim=1)

        # A linear layer plus tanh connects the encoder state to the decoder.
        return outputs, torch.tanh(self.fc(joined))


class ConcatAttention(nn.Module):
    """Concat (additive) attention over the encoder outputs."""

    def __init__(self, enc_hidden_dim: int, dec_hidden_dim: int):
        super().__init__()

        self.attn = nn.Linear((enc_hidden_dim * 2) + dec_hidden_dim, dec_hidden_dim)
        self.v = nn.Linear(dec_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights for one decoding step.

        Args:
            hidden: decoder state used as the query, ``[batch_size, dec_hidden_dim]``.
            encoder_outputs: keys/values,
                ``[src_len, batch_size, enc_hidden_dim * 2]``.

        Returns:
            Weights of shape ``[batch_size, src_len]`` that sum to 1 over
            ``src_len``.
        """
        src_len = encoder_outputs.shape[0]

        # Pair the decoder state with every source position, batch first.
        query = hidden.unsqueeze(1).expand(-1, src_len, -1)  # [batch_size, src_len, dec_hidden_dim]
        keys = encoder_outputs.permute(1, 0, 2)              # [batch_size, src_len, enc_hidden_dim * 2]

        energy = torch.tanh(self.attn(torch.cat((query, keys), dim=2)))  # [batch_size, src_len, dec_hidden_dim]
        scores = self.v(energy).squeeze(2)                               # [batch_size, src_len]

        return F.softmax(scores, dim=1)


class AttentionalRNNDecoder(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs.

    The decoder state is carried between steps as a ``[batch_size,
    dec_hidden_dim]`` tensor, i.e. as the state of exactly one GRU layer.
    """

    def __init__(
        self,
        output_dim: int,
        embed_dim: int,
        enc_hidden_dim: int,
        dec_hidden_dim: int,
        n_layers: int,
        dropout: float,
        attention: nn.Module
    ):
        """
        Raises:
            ValueError: if ``n_layers`` is not 1. ``forward`` always passes a
                one-layer initial state to the GRU, so any other value could
                only fail later with a shape error inside the GRU.
        """
        super().__init__()

        if n_layers != 1:
            raise ValueError(
                f"AttentionalRNNDecoder supports exactly one GRU layer, got n_layers={n_layers}"
            )

        self.output_dim = output_dim
        self.attention = attention

        self.embedding = nn.Embedding(output_dim, embed_dim)
        # The GRU input is the token embedding concatenated with the attention context.
        self.rnn = nn.GRU((enc_hidden_dim * 2) + embed_dim, dec_hidden_dim, n_layers, dropout=dropout)
        # The output layer sees the GRU output, the context and the embedding together.
        self.fc_out = nn.Linear((enc_hidden_dim*2) + dec_hidden_dim + embed_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one time step.

        Args:
            input: token ids of the current step, ``[batch_size]``.
            hidden: decoder state, ``[batch_size, dec_hidden_dim]``.
            encoder_outputs: ``[src_len, batch_size, enc_hidden_dim * 2]``.

        Returns:
            ``(prediction, hidden)``: logits ``[batch_size, output_dim]`` and
            the updated state ``[batch_size, dec_hidden_dim]``.

        Raises:
            AssertionError: if the GRU output and its final state differ.
        """
        input = input.unsqueeze(0)  # [1, batch_size]

        embedded = self.dropout(self.embedding(input))  # [1, batch_size, embed_dim]

        a = self.attention(hidden, encoder_outputs)  # [batch_size, src_len]
        a = a.unsqueeze(1)  # [batch_size, 1, src_len]

        # Attention-weighted sum of the encoder outputs (the context vector).
        encoder_outputs = encoder_outputs.permute(1, 0, 2)  # [batch_size, src_len, enc_hidden_dim*2]
        weighted = torch.bmm(a, encoder_outputs)  # [batch_size, 1, enc_hidden_dim*2]
        weighted = weighted.permute(1, 0, 2)  # [1, batch_size, enc_hidden_dim*2]

        rnn_input = torch.cat((embedded, weighted), dim=2)  # [1, batch_size, enc_hidden_dim*2 + embed_dim]

        # The state gets a leading layer dimension: [1, batch_size, dec_hidden_dim].
        output, hidden = self.rnn(rnn_input, hidden.unsqueeze(0))
        # With one step and one layer both results are [1, batch_size, dec_hidden_dim].

        # Sanity check kept from the original implementation: for a single
        # step of a single-layer GRU the output equals the final state.
        if not (output == hidden).all():
            raise AssertionError("GRU output and final hidden state differ")

        embedded = embedded.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedded), dim=1))  # [batch_size, output_dim]

        return prediction, hidden.squeeze(0)



class AttentionBasedSeq2Seq(BaseTranslateLightningModule):
    """Translation model: bidirectional GRU encoder plus attention-based GRU decoder."""

    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)

        self.encoder = BidirectionalGRUEncoder(**cfg.model.enc)
        self.attention = ConcatAttention(**cfg.model.attention)
        self.decoder = AttentionalRNNDecoder(
            attention=self.attention, **cfg.model.dec
        )

    def forward(self, src, tgt, teacher_forcing_ratio: float = 0.5):
        """Decode ``tgt.shape[0]`` steps, starting from the first row of ``tgt``.

        Args:
            src: source ids, ``[src_len, batch_size]``.
            tgt: target ids, ``[tgt_len, batch_size]``; row 0 holds the start
                tokens. The two lengths may differ.
            teacher_forcing_ratio: probability of feeding the ground-truth
                token instead of the model's own prediction at each step
                (should be 0.0 for validation and test).

        Returns:
            Logits of shape ``[tgt_len, batch_size, tgt_vocab_size]``; row 0 is
            left as zeros.
        """
        n_steps = tgt.shape[0]
        n_samples = tgt.shape[1]

        # Buffer collecting the decoder logits of every step.
        logits = torch.zeros(n_steps, n_samples, self.decoder.output_dim).to(self.device)

        encoder_outputs, hidden = self.encoder(src)

        decoder_input = tgt[0, :]
        for step in range(1, n_steps):
            step_logits, hidden = self.decoder(decoder_input, hidden, encoder_outputs)
            logits[step] = step_logits

            # Either feed the ground-truth token or the most likely prediction.
            if random.random() < teacher_forcing_ratio:
                decoder_input = tgt[step]
            else:
                decoder_input = step_logits.argmax(1)

        return logits


In [None]:
# Data configuration for German -> English translation, defined again here so
# that the configuration cells below can be run on their own.
#   src_index / tgt_index: position of each language inside a dataset sample.
#   vocab.special_symbol2index: ids of the special tokens.
data_spacy_de_en_cfg = {
    "name": "spacy_de_en",
    "data_root": os.path.join(os.getcwd(), "data"),
    "tokenizer": "spacy",
    "src_lang": "de",
    "tgt_lang": "en",
    "src_index": 0,
    "tgt_index": 1,
    "vocab": {
        "special_symbol2index": {
            "<unk>": 0,
            "<pad>": 1,
            "<bos>": 2,
            "<eos>": 3,
        },
        "special_first": True,
        "min_freq": 2
    }
}

data_cfg = OmegaConf.create(data_spacy_de_en_cfg)

# get_dataset
train_data, valid_data, test_data = Multi30k(data_cfg.data_root)

# Rebuild the tokenizers and vocabularies; the vocabulary sizes are needed by
# the model configurations below.
token_transform = get_token_transform(data_cfg)
vocab_transform = get_vocab_transform(data_cfg)


In [None]:
# model configs

# Configuration of the plain LSTM seq2seq model.
# "enc" / "dec" are unpacked as keyword arguments into LSTMEncoder / LSTMDecoder;
# the embedding table sizes come from the vocabulary sizes. hidden_dim and
# n_layers must match between encoder and decoder (asserted by LSTMSeq2Seq).
model_translate_lstm_seq2seq_cfg = {
    "name": "LSTMSeq2Seq",
    "enc": {
        "input_dim": len(vocab_transform[data_cfg.src_lang]),
        "embed_dim": 256,
        "hidden_dim": 256,
        "n_layers": 2,
        "dropout": 0.5,
    },
    "dec": {
        "output_dim": len(vocab_transform[data_cfg.tgt_lang]),
        "embed_dim": 256,
        "hidden_dim": 256,
        "n_layers": 2,
        "dropout": 0.5,
    },
    "teacher_forcing_ratio": 0.5
}

# Configuration of the attention-based seq2seq model.
# "enc", "dec" and "attention" are unpacked as keyword arguments into
# BidirectionalGRUEncoder, AttentionalRNNDecoder and ConcatAttention.
# The enc_hidden_dim / dec_hidden_dim values are repeated in all three
# sections and have to be kept in sync by hand.
model_translate_attention_based_seq2seq_cfg = {
    "name": "AttentionBasedSeq2Seq",
    "enc": {
        "input_dim": len(vocab_transform[data_cfg.src_lang]),
        "embed_dim": 256,
        "enc_hidden_dim": 512,
        "dec_hidden_dim": 512,
        "n_layers": 1,
        "dropout": 0.5,
    },
    "dec": {
        "output_dim": len(vocab_transform[data_cfg.tgt_lang]),
        "embed_dim": 256,
        "enc_hidden_dim": 512,
        "dec_hidden_dim": 512,
        "n_layers": 1,
        "dropout": 0.5,
    },
    "attention": {
        "enc_hidden_dim": 512,
        "dec_hidden_dim": 512,
    },
    "teacher_forcing_ratio": 0.5
}

# Optimizer / LR-scheduler configuration. It is stored under cfg.opt and the
# Lightning module hands the whole cfg to `configure_optimizers_from_cfg`.
# NOTE(review): the scheduler name is None while kwargs are still given —
# presumably this means "no scheduler"; confirm against config_utils.
opt_cfg = {
    "optimizers": [
        {
            "name": "RAdam",
            "kwargs": {
                "lr": 1e-3,
            }
        }
    ],
    "lr_schedulers": [
        {
            "name": None,
            "kwargs": {
                "warmup_end_steps": 1000
            }
        }
    ]
}

# Named presets, each combining an optimizer, a data and a model configuration.
# The preset name is what is later passed to `hydra.compose`.
_merged_cfg_presets = {
    "LSTM_seq2seq_de_en_translate": {
        "opt": opt_cfg,
        "data": data_spacy_de_en_cfg,
        "model": model_translate_lstm_seq2seq_cfg,
    },
    "attention_based_seq2seq_de_en_translate": {
        "opt": opt_cfg,
        "data": data_spacy_de_en_cfg,
        "model": model_translate_attention_based_seq2seq_cfg,
    },
}

# Clear the global Hydra instance first so that this cell can be re-run.
hydra.core.global_hydra.GlobalHydra.instance().clear()

# Register the preset configs.
register_config(_merged_cfg_presets)

# Initialize Hydra and compose the config of the selected preset.
hydra.initialize(config_path=None)
cfg = hydra.compose("attention_based_seq2seq_de_en_translate")

# Names and directories for this run: timestamp + model name + data name.
run_name = f"{datetime.now().isoformat(timespec='seconds')}-{cfg.model.name}-{cfg.data.name}"

project_root_dir = os.path.join(
    drive_project_root, "runs", "de_en_translate_tutorials"
)
# Both names below point to the same per-run directory.
save_dir = os.path.join(project_root_dir, run_name)
run_root_dir = os.path.join(project_root_dir, run_name)

# Training configuration: batch sizes for the three dataloaders, the run
# directory, and keyword arguments that are unpacked into `pl.Trainer`.
# NOTE(review): "gpus" is given as the string "0" together with
# accelerator="gpu". How Lightning interprets this string (GPU index 0 vs.
# zero GPUs), and whether `gpus` is accepted at all, depends on the installed
# version — confirm it selects the intended device.
train_cfg = {
    "train_batch_size": 128,
    "val_batch_size": 32,
    "test_batch_size": 32,
    "train_val_split": [0.9, 0.1],
    "run_root_dir": run_root_dir,
    "trainer_kwargs": {
        "accelerator": "gpu",
        "gpus": "0",
        "max_epochs": 50,
        "val_check_interval": 1.0,
        "log_every_n_steps": 100,
    }
}

# Logger and callback configuration; read by `get_loggers` / `get_callbacks`
# through cfg.log. Checkpointing and early stopping both monitor "val_loss",
# which is the key logged by the Lightning module's validation step.
log_cfg = {
    "loggers": {
        "WandbLogger": {
            "project": "fastcampus_de_en_translate_tutorials",
            "name": run_name,
            "tags": ["fastcampus_de_en_translate_tutorials"],
            "save_dir": run_root_dir,
        },
        "TensorBoardLogger": {
            "save_dir": project_root_dir,
            "name": run_name,
            "log_graph": True,
        }
    },
    "callbacks": {
        "ModelCheckpoint": {
            "save_top_k": 3,
            "monitor": "val_loss",
            "mode": "min",
            "verbose": True,
            "dirpath": os.path.join(run_root_dir, "weights"),
            "filename": "{epoch}-{val_loss:.3f}",
        },
        "EarlyStopping": {
            "monitor": "val_loss",
            "mode": "min",
            "patience": 3,
            "verbose": True,
        }
    }
}

# Temporarily disable struct mode so that new top-level keys can be added.
OmegaConf.set_struct(cfg, False)
cfg.train = train_cfg
cfg.log = log_cfg

# Lock the config again and print the final result as YAML.
OmegaConf.set_struct(cfg, True)
print(OmegaConf.to_yaml(cfg))


In [None]:
# dataloader def
# One DataLoader per Multi30k split, each with its configured batch size and a
# collate function bound to cfg.data.
_language_pair = (data_cfg.src_lang, data_cfg.tgt_lang)

train_dataloader = get_multi30k_dataloader(
    "train", _language_pair, cfg.train.train_batch_size, collate_fn=get_collate_fn(cfg)
)
val_dataloader = get_multi30k_dataloader(
    "valid", _language_pair, cfg.train.val_batch_size, collate_fn=get_collate_fn(cfg)
)
test_dataloader = get_multi30k_dataloader(
    "test", _language_pair, cfg.train.test_batch_size, collate_fn=get_collate_fn(cfg)
)

In [None]:
# get model
def get_pl_model(cfg: DictConfig, checkpoint_path: Optional[str] = None):
    """Create the model selected by ``cfg.model.name``.

    Args:
        cfg: full configuration; ``cfg.model.name`` must be "LSTMSeq2Seq" or
            "AttentionBasedSeq2Seq".
        checkpoint_path: optional path of a Lightning checkpoint. When given,
            the model is restored from it instead of being freshly initialised.

    Returns:
        The constructed (or restored) Lightning module.

    Raises:
        NotImplementedError: if the model name is not supported.
    """
    if cfg.model.name == "LSTMSeq2Seq":
        model_cls = LSTMSeq2Seq
    elif cfg.model.name == "AttentionBasedSeq2Seq":
        model_cls = AttentionBasedSeq2Seq
    else:
        raise NotImplementedError("Not implemented model")

    if checkpoint_path is not None:
        # `load_from_checkpoint` takes the checkpoint path as its first
        # argument and forwards extra keyword arguments to the model's
        # __init__. Passing cfg positionally together with
        # checkpoint_path=... supplied the same parameter twice (TypeError).
        # It is called on the class, not on a throw-away instance.
        return model_cls.load_from_checkpoint(checkpoint_path, cfg=cfg)
    return model_cls(cfg)

# Drop the reference to any previous model before building a new one
# (useful when this cell is re-run), then build and print the model.
model = None
model = get_pl_model(cfg)
print(model)

In [None]:
# pytorch lightning trainer def
# Loggers and callbacks are built from cfg by the config_utils helpers.
logger = get_loggers(cfg)
callbacks = get_callbacks(cfg)

# The remaining Trainer options come from cfg.train.trainer_kwargs.
trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    default_root_dir=cfg.train.run_root_dir,
    num_sanity_val_steps=2,
    **cfg.train.trainer_kwargs,
)

In [None]:
# Train the model, validating on the validation dataloader.
trainer.fit(model, train_dataloader, val_dataloader)