In [1]:
import pandas as pd
from datasets import load_dataset
import re
import string

from transformers import AutoTokenizer

import math

import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from torch.optim import Adam
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

import evaluate 

import os

In [2]:
ja_tokenizer = AutoTokenizer.from_pretrained('cl-tohoku/bert-base-japanese')
en_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [3]:
sos_id = en_tokenizer.cls_token_id
eos_id = en_tokenizer.sep_token_id

In [4]:
ja_punct = ['。', '、', '（', '）', '[', ']', '{',  '}', '【 ', '】', '〔', '〕', '<', '>', '，', '゠', '＝', '…', '‥', '『', '』', '〝', '〟',
           '⟨', '⟩', '〜', '：', '！', '♪', '〖', '〗', '〘', '〙', '※', '〇', '│', '│']

In [5]:
zenkaku = ['０', '１', '２', '３', '４', '５', '６', '７', '８', '９', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 
           '（', '）', '＊', '「', '」', '［', '］', '【', '】', '＜', '＞', '？', '・', '＃', '＠', '＄', '％', '＝']

In [6]:
raw_data = load_dataset("Verah/JParaCrawl-Filtered-English-Japanese-Parallel-Corpus")

In [7]:
ja_en = pd.DataFrame.from_dict(raw_data["train"][2:38000])
ja_en = ja_en[["japanese", "english"]].drop_duplicates().reset_index(drop=True)

In [8]:
ja_en.head()

Unnamed: 0,japanese,english
0,スポンサードリンク この広告は一定期間更新がない場合に表示されます。,Sponsored link This advertisement is displayed...
1,また、 プレミアムユーザー になると常に非表示になります。,"Also, it will always be hidden when becoming a..."
2,コンテンツの更新が行われると非表示に戻ります。,It will return to non-display when content upd...
3,Youtubeを中心にミニマリストと言っている方の動画をたくさんみましたが、納得いくもののも...,It’s like you can enrich it and save money as ...
4,ffmpeg -i sample.mp4 -strict -2 video.webm まとめ...,Go to the original video hierarchy of the conv...


In [9]:
ja_en.tail()

Unnamed: 0,japanese,english
21101,桜 Exhibition 2010 参加 @moko_u からのツイート 代表作品 桜 Ex...,Sakura Exhibition 2010 @moko_u My Works Entry ...
21102,主な仕事実績として、育児・教育関係の雑誌・書籍のイラスト多数。,I have drawn a lot of illustrations for books ...
21103,ストーリー性のあるイラストを、水彩絵の具や、 Photoshopで描いています。 桜 Exh...,I draw a sweet and somewhat funny world with w...
21104,お仕事では書籍の表紙や、カードゲームイラストを描かせていただいています。 @simetta ...,I draw illustrations for book covers and card ...
21105,季節感を表現した情緒あふれる作風が持ち味です。,I like to add the sense of the seasons to my i...


In [10]:
ja_en.shape

(21106, 2)

In [11]:
ja_en['japanese'].isna().sum()

np.int64(0)

In [12]:
train_dataset = ja_en[2:19001]
val_dataset = ja_en[19001::]

In [13]:
train_dataset

Unnamed: 0,japanese,english
2,コンテンツの更新が行われると非表示に戻ります。,It will return to non-display when content upd...
3,Youtubeを中心にミニマリストと言っている方の動画をたくさんみましたが、納得いくもののも...,It’s like you can enrich it and save money as ...
4,ffmpeg -i sample.mp4 -strict -2 video.webm まとめ...,Go to the original video hierarchy of the conv...
5,大企業と付き合っていると 、 「 安定」している認識になってしまいがちです。,"When you’re dealing with a large corporation, ..."
6,極端に古くなければだいたい大丈夫でしょう。,"If it’s not too old, you should be able to use..."
...,...,...
18996,東京生まれ。,Born in Tokyo.
18997,イラスト関係の 出版・イベント企画・展示会企画等、国内外で幅広く活動。,Running a gallery only for illustrations “12G”...
18998,イラストだけでなくデザインや音楽、詩など幅広い分野で活躍を目論んでます。,I aim to be active not only with drawing illus...
18999,しかし次第に父が画家であることにコンプレックスを感じるようになり、しばらく美術の世界から離れ...,"However, I began to have an inferiority comple..."


In [14]:
val_dataset

Unnamed: 0,japanese,english
19001,東洋美術学校卒業後、グラフィックデザイナー、webデザイナーなどを経て、イラストレーターとなる。,Became a graphic designer and a designer of we...
19002,コミティアやグループ展などで作品を発表しています。,I show my works at COMITIA and group exhibitions.
19003,12年間民族学校で学んだ後、様々なアルバイト経験を経てフリーの絵描きとして活動中。,After studying at a Korean school in Japan for...
19004,illustratorのベジェ曲線で描く、ファンタジックで可愛い、ダークでクール、そして妖艶...,"I hope viewers will enjoy fantastic, cute, dar..."
19005,2007年、カメラマン フルカワチヒロとの出会いをきっかけに創作活動を再開。,I started to create works again in 2007 after ...
...,...,...
21101,桜 Exhibition 2010 参加 @moko_u からのツイート 代表作品 桜 Ex...,Sakura Exhibition 2010 @moko_u My Works Entry ...
21102,主な仕事実績として、育児・教育関係の雑誌・書籍のイラスト多数。,I have drawn a lot of illustrations for books ...
21103,ストーリー性のあるイラストを、水彩絵の具や、 Photoshopで描いています。 桜 Exh...,I draw a sweet and somewhat funny world with w...
21104,お仕事では書籍の表紙や、カードゲームイラストを描かせていただいています。 @simetta ...,I draw illustrations for book covers and card ...


In [15]:
def prepare_en_train_corpus(data):
    train_data = data['train'][2:35000]
    train_data = pd.DataFrame.from_dict(train_data)
    train_data = train_data[['english', 'japanese']].reset_index(drop=True)
    train_data.drop_duplicates(inplace=True)
    en_corpus = train_data['english'].tolist()
    return en_corpus

In [16]:
def preprocess_the_ja_text(text):
    text = re.sub(r'(https?://[a-zA-Z0-9.-]*)', r'', text)
    text = re.sub(r'(quote=\w+\s?\w+;?\w+)', r'', text)
    text = re.sub(r'[^\w\s]|\d+', r'', text)
    
    for z in zenkaku:
        text = text.replace(z, '')
        
    for punct in ja_punct:
        text = text.replace(punct, '')
    return text.strip()

In [17]:
def preprocess_the_en_text(text):
    text = text.lower().strip()
    text = re.sub(r'(https?://[a-zA-Z0-9.-]*)', r'', text)
    text = re.sub(r'(quote=\w+\s?\w+;?\w+)', r'', text)
    text = text.replace("'m", ' am')
    text = text.replace("'re", " are")
    text = text.replace("'ll", " will")
    text = text.replace("'ve", " have")
    text = text.replace("'d", " would")
    text = text.translate(str.maketrans(' ', ' ', string.punctuation))
    return text.strip()

In [18]:
class TranslationDataset(Dataset):
    def __init__(self, dataset, ja_tokenizer, en_tokenizer, max_len=80):
        
        self.ja_samples = [preprocess_the_ja_text(t) for t in dataset["japanese"].to_list()]
        self.en_samples = [preprocess_the_en_text(t) for t in dataset["english"].to_list()]
        
        self.ja_tokenizer = ja_tokenizer
        self.en_tokenizer = en_tokenizer
        
        self.max_len = max_len

        self.ja_ids = [self.ja_tokenizer.encode(ja, 
                                                max_length=self.max_len, 
                                                padding="max_length", 
                                                truncation=True, 
                                                return_tensors='pt').squeeze() for ja in self.ja_samples]
        self.en_ids = [self.en_tokenizer.encode(en, 
                                                max_length=self.max_len, 
                                                padding="max_length", 
                                                truncation=True, 
                                                return_tensors='pt', 
                                                add_special_tokens=True).squeeze() for en in self.en_samples]
        self.ja_ids = torch.stack(self.ja_ids)
        self.en_ids = torch.stack(self.en_ids)

    def __getitem__(self, idx):
        return self.ja_ids[idx], self.en_ids[idx]

    def __len__(self):
        return self.ja_ids.size(0)

    def ja_vocab_size(self):
        return len(self.ja_tokenizer)

    def en_vocab_size(self):
        return len(self.en_tokenizer)

In [19]:
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len=80, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  
        
        self.register_buffer("pe", pe)
            
    def forward(self, x):
        seq_len = x.size(1)
        x = x * math.sqrt(self.d_model)
        x = x + self.pe[:, :seq_len, :].to(x.device)   
        return self.dropout(x)

In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        assert self.d_model % self.n_heads == 0, "d_model has to be divisible by n_heads."
        
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, self.d_model)
        self.w_v = nn.Linear(self.d_model, self.d_model)
        self.out = nn.Linear(self.d_model, self.d_model)

        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        q_len = q.size(1)
        k_len = k.size(1)

        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)

        q = q.view(batch_size, q_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(batch_size, k_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(batch_size, k_len, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

        attention_weights = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        
        if mask is not None:
            attention_weights = attention_weights.masked_fill(mask == 0, float('-1e20'))
            
        attention_scores = F.softmax(attention_weights,dim=-1)
        attention_scores = self.dropout(attention_scores)
        context = torch.matmul(attention_scores, v)
        concat = context.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, self.n_heads*self.head_dim)

        x = self.out(concat)
        return x

In [21]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_model*4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model*4, d_model)
        )
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        
    def forward(self, x, mask=None):
        attention_out = self.attention(x, x, x, mask)
        attention_residual_out = self.dropout1(attention_out) + x
        norm1_out = self.norm1(attention_residual_out)
        feedfwd_out = self.feedforward(norm1_out)
        feedfwd_residual_out = self.dropout2(feedfwd_out) + norm1_out
        output = self.norm2(feedfwd_residual_out)
        
        return output

In [22]:
class Encoder(nn.Module):
    def __init__(self, d_model, n_heads=8, dropout=0.1, n_layers=6):
        super().__init__()
        self.n_layers = n_layers
        self.layers = nn.ModuleList([
            EncoderBlock(d_model, n_heads, dropout=dropout)
            for i in range(n_layers)
        ])
        
    def forward(self, x, ja_mask=None):
        for layer in self.layers:
            x = layer(x, ja_mask)
        return x

In [23]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads=8, dropout=0.1, mask=None):
        super().__init__()
        self.masked_attention = MultiHeadAttention(d_model, n_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_model*4), 
            nn.ReLU(),
            nn.Linear(d_model*4, d_model)
        )
        self.dropout3 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)
        
    def forward(self, x, enc_out, ja_mask=None, en_mask=None):
        masked_attention_out = self.masked_attention(x, x, x, mask=en_mask)
        masked_attention_residual_out = self.dropout1(masked_attention_out) + x
        norm1_out = self.norm1(masked_attention_residual_out) 
        
        attention_out = self.attention(norm1_out, enc_out, enc_out, mask=ja_mask)
        attention_residual_out = self.dropout2(attention_out) + norm1_out
        norm2_out = self.norm2(attention_residual_out)
        
        feedfwd_out = self.feedforward(norm2_out)
        feedfwd_residual_out = self.dropout3(feedfwd_out) + norm2_out
        output = self.norm3(feedfwd_residual_out)
        
        return output

In [24]:
class Decoder(nn.Module):
    def __init__(self, d_model, n_heads=8, dropout=0.1, n_layers=6):
        super().__init__()
        self.n_layers = n_layers
        self.layers = nn.ModuleList([
            DecoderBlock(d_model, n_heads, dropout=dropout)
            for i in range(n_layers)
        ])
    def forward(self, x, enc_out, en_mask=None, ja_mask=None):
        for layer in self.layers:
            x = layer(x, enc_out, en_mask=en_mask, ja_mask=ja_mask)
        return x

In [25]:
class Transformer(nn.Module):
    def __init__(self, d_model, ja_vocab_size, en_vocab_size, max_seq_len=80, n_heads=8, dropout=0.1, n_layers=6):
        super().__init__()
        self.d_model = d_model
        self.ja_embedding = nn.Embedding(ja_vocab_size, d_model)
        self.en_embedding = nn.Embedding(en_vocab_size, d_model)
        
        self.positional_encoder = PositionalEncoder(d_model, max_seq_len)

        self.dropout = nn.Dropout(dropout)
        
        self.encoder = Encoder(d_model=d_model, n_heads=n_heads, dropout=dropout, n_layers=n_layers)
        self.decoder = Decoder(d_model=d_model, n_heads=n_heads, dropout=dropout, n_layers=n_layers)
        
        self.out_proj = nn.Linear(d_model, en_vocab_size)
        
    def generate_mask(self, ja_ids, dec_input, device):
        device = dec_input.device
        ja_mask = (ja_ids != 0).unsqueeze(1).unsqueeze(2)
        en_mask =  (dec_input != 0).unsqueeze(1).unsqueeze(2)
        seq_len = dec_input.size(1)
        causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
        causal_mask = ~causal_mask
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1)
        en_mask = en_mask & causal_mask
        return ja_mask, en_mask
        
    def forward(self, ja_ids, dec_input, ja_mask=None, en_mask=None):
        ja_mask, en_mask = self.generate_mask(ja_ids, dec_input, device)
        ja_embedding = self.dropout(self.positional_encoder(self.ja_embedding(ja_ids)))
        en_embedding = self.dropout(self.positional_encoder(self.en_embedding(dec_input)))
        encoder_out = self.encoder(ja_embedding, ja_mask)
        decoder_out = self.decoder(en_embedding, encoder_out, en_mask=en_mask, ja_mask=ja_mask)
        out = self.out_proj(decoder_out)
        return out

In [26]:
train_data = TranslationDataset(train_dataset,ja_tokenizer, en_tokenizer, max_len=80)
val_data = TranslationDataset(val_dataset,ja_tokenizer, en_tokenizer, max_len=80)

train_loader = DataLoader(train_data, batch_size=32, num_workers=0, shuffle=True, drop_last=False)
val_loader = DataLoader(val_data, batch_size=32, num_workers=0, shuffle=True, drop_last=False)

In [27]:
JA_VOCAB_SIZE = train_data.ja_vocab_size()
EN_VOCAB_SIZE = train_data.en_vocab_size()
D_MODEL = 512
N_HEADS = 8
N_LAYERS = 6
MAX_SEQ_LEN = 80
DROPOUT = 0.1
N_EPOCHS = 200
SAVE_EPOCH = 10
LR = 1e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [28]:
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [29]:
print(str(JA_VOCAB_SIZE) + ",", EN_VOCAB_SIZE)
print(len(train_dataset))

32000, 30522
18999


In [30]:
def save_checkpoint(epoch, model, optimizer, loss, path):
    checkpoint = {
        "epoch" : epoch,
        "model_state" : model.state_dict(),
        "optimizer_state" : optimizer.state_dict(),
        "loss" : loss
    }
    torch.save(checkpoint, path)
    print(f"Saved checkpoint: {path}")

In [31]:
def load_checkpoint(path, model, optimizer=None, device="cpu"):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    if optimizer is not None and "optimizer_state" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")
    return checkpoint["epoch"], checkpoint["loss"]

In [32]:
def lr_lambda(step):
    warmup_steps=4000
    step = max(step, 1)
    d_model=model.d_model
    return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5)

In [33]:
model = Transformer(d_model=D_MODEL,
                    ja_vocab_size=JA_VOCAB_SIZE,
                    en_vocab_size=EN_VOCAB_SIZE,
                    max_seq_len=MAX_SEQ_LEN, 
                    n_heads=N_HEADS,
                    dropout=DROPOUT,
                   n_layers=N_LAYERS).to(device)
writer = SummaryWriter()
optimizer = Adam(model.parameters(), lr=LR, betas=(0.9, 0.98), eps=1e-9)
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
scheduler = LambdaLR(optimizer, lr_lambda)
bleu = evaluate.load("bleu")

In [None]:
for epoch in range(1, N_EPOCHS+1):
    model.train()
    epoch_loss = 0
    train_iterator = tqdm(train_loader, desc=f"EPOCH: {epoch}/{N_EPOCHS}")
    
    for batch_idx, (ja_ids, en_ids) in enumerate(train_iterator):
        ja_ids = ja_ids.to(device)
        en_ids = en_ids.to(device)
        
        dec_input = en_ids[:, :-1]
        target = en_ids[:, 1:].contiguous().view(-1)
        
        outputs = model(ja_ids, dec_input)
        outputs = outputs.view(-1, outputs.size(-1))
        
        loss = criterion(outputs, target)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        

        epoch_loss += loss.item()
        train_iterator.set_postfix(loss=loss.item())
        writer.add_scalar("Train/Loss", loss.item(), epoch * len(train_loader) + batch_idx)

    avg_loss = epoch_loss / len(train_loader)
    perplexity = torch.exp(torch.tensor(avg_loss))
    print(f"Epoch {epoch} | Avg Loss: {avg_loss:.4f} | Perplexity: {perplexity:.4f}")
    writer.add_scalar("Train/Perplexity", perplexity, epoch)

    #model.eval()
    #predictions, references = [], []
    #with torch.no_grad():
        #for i, (ja_ids, en_ref_ids) in enumerate(val_loader):
            #if i >= 20:
                #break
                
            #ja_ids = ja_ids[0].unsqueeze(0).to(device)
            #dec_input = torch.tensor([[en_tokenizer.cls_token_id]], device=device)
            
            #output_tokens = []

            #for i in range(MAX_SEQ_LEN):
                #outputs = model(ja_ids, dec_input)
                #next_token = outputs[:, -1, :].argmax(-1)
                #dec_input = torch.cat([dec_input, next_token.unsqueeze(1)], dim=1)

                #if next_token.item() == en_tokenizer.sep_token_id:
                    #break
                #output_tokens.append(next_token.item())
                
            #prediction = en_tokenizer.decode(output_tokens, skip_special_tokens=True)
            #predictions.append(prediction)
            #en_ref = en_tokenizer.decode(en_ref_ids[0].cpu().tolist(), skip_special_tokens=True)
            #references.append([en_ref])  
            
    #bleu_score = bleu.compute(predictions=predictions, references=references)
    #print(f"Epoch {epoch} | BLEU: {bleu_score['bleu']:.4f}")
    #writer.add_scalar("Eval/BLEU", bleu_score['bleu'], epoch)
        
    if epoch % SAVE_EPOCH == 0:
        save_path = os.path.join(CHECKPOINT_DIR, f"epoch_{epoch}.pth")
        save_checkpoint(epoch, model, optimizer, avg_loss, save_path)

EPOCH: 1/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 1 | Avg Loss: 10.4652 | Perplexity: 35072.9414


EPOCH: 2/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 2 | Avg Loss: 10.4266 | Perplexity: 33746.3711


EPOCH: 3/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 3 | Avg Loss: 10.3504 | Perplexity: 31268.5332


EPOCH: 4/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 4 | Avg Loss: 10.2367 | Perplexity: 27909.9141


EPOCH: 5/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 5 | Avg Loss: 10.0934 | Perplexity: 24183.8203


EPOCH: 6/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 6 | Avg Loss: 9.9331 | Perplexity: 20600.0977


EPOCH: 7/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 7 | Avg Loss: 9.7699 | Perplexity: 17499.1406


EPOCH: 8/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 8 | Avg Loss: 9.6301 | Perplexity: 15215.6426


EPOCH: 9/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 9 | Avg Loss: 9.5267 | Perplexity: 13720.8047


EPOCH: 10/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 10 | Avg Loss: 9.4497 | Perplexity: 12703.9580
Saved checkpoint: ./checkpoints\epoch_10.pth


EPOCH: 11/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 11 | Avg Loss: 9.3896 | Perplexity: 11963.0029


EPOCH: 12/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 12 | Avg Loss: 9.3408 | Perplexity: 11393.2949


EPOCH: 13/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 13 | Avg Loss: 9.2990 | Perplexity: 10927.1992


EPOCH: 14/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 14 | Avg Loss: 9.2613 | Perplexity: 10522.8652


EPOCH: 15/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 15 | Avg Loss: 9.2263 | Perplexity: 10161.1113


EPOCH: 16/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 16 | Avg Loss: 9.1939 | Perplexity: 9836.5566


EPOCH: 17/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 17 | Avg Loss: 9.1599 | Perplexity: 9508.1582


EPOCH: 18/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 18 | Avg Loss: 9.1256 | Perplexity: 9187.6289


EPOCH: 19/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 19 | Avg Loss: 9.0922 | Perplexity: 8886.0293


EPOCH: 20/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 20 | Avg Loss: 9.0563 | Perplexity: 8572.3105
Saved checkpoint: ./checkpoints\epoch_20.pth


EPOCH: 21/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 21 | Avg Loss: 9.0221 | Perplexity: 8283.8828


EPOCH: 22/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 22 | Avg Loss: 8.9845 | Perplexity: 7978.3169


EPOCH: 23/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 23 | Avg Loss: 8.9493 | Perplexity: 7702.4233


EPOCH: 24/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 24 | Avg Loss: 8.9111 | Perplexity: 7413.6099


EPOCH: 25/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 25 | Avg Loss: 8.8770 | Perplexity: 7165.2754


EPOCH: 26/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 26 | Avg Loss: 8.8426 | Perplexity: 6923.2983


EPOCH: 27/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 27 | Avg Loss: 8.8113 | Perplexity: 6709.9453


EPOCH: 28/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 28 | Avg Loss: 8.7791 | Perplexity: 6497.2715


EPOCH: 29/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 29 | Avg Loss: 8.7511 | Perplexity: 6317.8359


EPOCH: 30/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 30 | Avg Loss: 8.7197 | Perplexity: 6122.5283
Saved checkpoint: ./checkpoints\epoch_30.pth


EPOCH: 31/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 31 | Avg Loss: 8.6913 | Perplexity: 5950.9102


EPOCH: 32/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 32 | Avg Loss: 8.6653 | Perplexity: 5797.9814


EPOCH: 33/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 33 | Avg Loss: 8.6379 | Perplexity: 5641.5259


EPOCH: 34/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 34 | Avg Loss: 8.6144 | Perplexity: 5510.2725


EPOCH: 35/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 35 | Avg Loss: 8.5925 | Perplexity: 5390.8315


EPOCH: 36/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 36 | Avg Loss: 8.5655 | Perplexity: 5247.6196


EPOCH: 37/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 37 | Avg Loss: 8.5431 | Perplexity: 5131.0669


EPOCH: 38/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 38 | Avg Loss: 8.5183 | Perplexity: 5005.3511


EPOCH: 39/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 39 | Avg Loss: 8.4985 | Perplexity: 4907.2329


EPOCH: 40/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 40 | Avg Loss: 8.4786 | Perplexity: 4810.8501
Saved checkpoint: ./checkpoints\epoch_40.pth


EPOCH: 41/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 41 | Avg Loss: 8.4554 | Perplexity: 4700.2095


EPOCH: 42/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 42 | Avg Loss: 8.4371 | Perplexity: 4615.2153


EPOCH: 43/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 43 | Avg Loss: 8.4149 | Perplexity: 4513.8623


EPOCH: 44/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 44 | Avg Loss: 8.3907 | Perplexity: 4405.7554


EPOCH: 45/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 45 | Avg Loss: 8.3749 | Perplexity: 4336.9912


EPOCH: 46/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 46 | Avg Loss: 8.3536 | Perplexity: 4245.2852


EPOCH: 47/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 47 | Avg Loss: 8.3393 | Perplexity: 4185.0278


EPOCH: 48/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 48 | Avg Loss: 8.3190 | Perplexity: 4100.8857


EPOCH: 49/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 49 | Avg Loss: 8.3036 | Perplexity: 4038.3279


EPOCH: 50/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 50 | Avg Loss: 8.2816 | Perplexity: 3950.5894
Saved checkpoint: ./checkpoints\epoch_50.pth


EPOCH: 51/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 51 | Avg Loss: 8.2656 | Perplexity: 3887.9578


EPOCH: 52/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 52 | Avg Loss: 8.2474 | Perplexity: 3817.5679


EPOCH: 53/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 53 | Avg Loss: 8.2324 | Perplexity: 3760.8882


EPOCH: 54/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 54 | Avg Loss: 8.2118 | Perplexity: 3684.3176


EPOCH: 55/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 55 | Avg Loss: 8.1959 | Perplexity: 3625.9871


EPOCH: 56/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 56 | Avg Loss: 8.1765 | Perplexity: 3556.2407


EPOCH: 57/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 57 | Avg Loss: 8.1624 | Perplexity: 3506.7537


EPOCH: 58/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 58 | Avg Loss: 8.1476 | Perplexity: 3455.1270


EPOCH: 59/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 59 | Avg Loss: 8.1271 | Perplexity: 3384.9656


EPOCH: 60/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 60 | Avg Loss: 8.1129 | Perplexity: 3337.2732
Saved checkpoint: ./checkpoints\epoch_60.pth


EPOCH: 61/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 61 | Avg Loss: 8.0982 | Perplexity: 3288.5994


EPOCH: 62/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 62 | Avg Loss: 8.0825 | Perplexity: 3237.1975


EPOCH: 63/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 63 | Avg Loss: 8.0669 | Perplexity: 3187.3193


EPOCH: 64/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 64 | Avg Loss: 8.0551 | Perplexity: 3149.8105


EPOCH: 65/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 65 | Avg Loss: 8.0358 | Perplexity: 3089.6655


EPOCH: 66/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 66 | Avg Loss: 8.0238 | Perplexity: 3052.8691


EPOCH: 67/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 67 | Avg Loss: 8.0091 | Perplexity: 3008.2544


EPOCH: 68/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 68 | Avg Loss: 7.9968 | Perplexity: 2971.4807


EPOCH: 69/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 69 | Avg Loss: 7.9841 | Perplexity: 2933.9241


EPOCH: 70/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 70 | Avg Loss: 7.9681 | Perplexity: 2887.3774
Saved checkpoint: ./checkpoints\epoch_70.pth


EPOCH: 71/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 71 | Avg Loss: 7.9543 | Perplexity: 2847.7043


EPOCH: 72/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 72 | Avg Loss: 7.9342 | Perplexity: 2791.0259


EPOCH: 73/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 73 | Avg Loss: 7.9241 | Perplexity: 2763.1052


EPOCH: 74/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 74 | Avg Loss: 7.9133 | Perplexity: 2733.5027


EPOCH: 75/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 75 | Avg Loss: 7.8982 | Perplexity: 2692.4048


EPOCH: 76/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 76 | Avg Loss: 7.8975 | Perplexity: 2690.4246


EPOCH: 77/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 77 | Avg Loss: 7.8742 | Perplexity: 2628.7036


EPOCH: 78/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 78 | Avg Loss: 7.8631 | Perplexity: 2599.4973


EPOCH: 79/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 79 | Avg Loss: 7.8534 | Perplexity: 2574.3445


EPOCH: 80/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 80 | Avg Loss: 7.8404 | Perplexity: 2541.0957
Saved checkpoint: ./checkpoints\epoch_80.pth


EPOCH: 81/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 81 | Avg Loss: 7.8259 | Perplexity: 2504.5559


EPOCH: 82/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 82 | Avg Loss: 7.8169 | Perplexity: 2482.1472


EPOCH: 83/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 83 | Avg Loss: 7.8051 | Perplexity: 2453.1848


EPOCH: 84/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 84 | Avg Loss: 7.7846 | Perplexity: 2403.3936


EPOCH: 85/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 85 | Avg Loss: 7.7813 | Perplexity: 2395.4236


EPOCH: 86/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 86 | Avg Loss: 7.7645 | Perplexity: 2355.3628


EPOCH: 87/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 87 | Avg Loss: 7.7533 | Perplexity: 2329.3088


EPOCH: 88/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 88 | Avg Loss: 7.7457 | Perplexity: 2311.7231


EPOCH: 89/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 89 | Avg Loss: 7.7327 | Perplexity: 2281.7043


EPOCH: 90/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 90 | Avg Loss: 7.7240 | Perplexity: 2262.0293
Saved checkpoint: ./checkpoints\epoch_90.pth


EPOCH: 91/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 91 | Avg Loss: 7.7133 | Perplexity: 2237.9712


EPOCH: 92/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 92 | Avg Loss: 7.7041 | Perplexity: 2217.3408


EPOCH: 93/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 93 | Avg Loss: 7.6901 | Perplexity: 2186.4863


EPOCH: 94/200:   0%|          | 0/594 [00:00<?, ?it/s]

Epoch 94 | Avg Loss: 7.6792 | Perplexity: 2162.8975


EPOCH: 95/200:   0%|          | 0/594 [00:00<?, ?it/s]

In [None]:
torch.save(model.state_dict(), 'outputs/model_state.pth')

In [None]:
torch.save(model, 'outputs/ja_en_transformer_translator_model.pth')

In [None]:
def translate(model, ja_sentence, ja_tokenizer=ja_tokenizer, en_tokenizer=en_tokenizer, max_len=80, device=device):
    model.eval()
    ja_ids = ja_tokenizer.encode(ja_sentence, add_special_tokens=False, return_tensors="pt").to(torch.long).to(device)
    dec_input = torch.tensor([[en_tokenizer.cls_token_id]], device=device)
    preds = []
    with torch.no_grad():
        for i in range(max_len):
            outputs = model(ja_ids, dec_input)
            next_token = outputs[:, -1, :].argmax(-1)
            next_token_id = next_token.item()
            
            preds.append(next_token_id)
            dec_input = torch.cat([dec_input, next_token.unsqueeze(1)], dim=1)
            
            if next_token_id == en_tokenizer.sep_token_id:
                break
    tokens = en_tokenizer.decode(dec_input, skip_special_tokens=True)
    print(translation)

In [None]:
model = Transformer(d_model=D_MODEL,
                    ja_vocab_size=JA_VOCAB_SIZE,
                    en_vocab_size=EN_VOCAB_SIZE,
                    max_seq_len=MAX_SEQ_LEN, 
                    n_heads=N_HEADS,
                    dropout=DROPOUT,
                   n_layers=N_LAYERS)
                   
model.load_state_dict(torch.load('projects/jp-eng_machine_translation/outputs/ja_en_transformer_translator_model.pt'))

In [None]:
os.getcwd()

In [None]:
translate(model, "親友は、私のいいところも悪いところも全部ひっくるめて、受け入れてくれている。", max_len=80)