In [1]:
%matplotlib inline

In [2]:
import os
import torch

In [3]:
# GPU index this notebook prefers. If the machine exposes fewer GPUs than that,
# fall back to GPU 0 instead of failing later with an invalid device ordinal;
# with no CUDA at all, run on the CPU.
PREFERRED_GPU = 5
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if PREFERRED_GPU < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")
print(device)

cuda:5


In [4]:
%load_ext autoreload
%autoreload 2

# load data

In [5]:
from transformer_translation.dataset import ParallelLanguageDataset
from torch.utils.data import DataLoader

In [6]:
# Directory holding the processed pickles. The default is the original,
# machine-specific location; set the TRANSF_DATA_PATH environment variable to
# point somewhere else without editing this cell.
# Alternative location used previously: 'transformer_translation/data/processed'
data_path = os.environ.get(
    "TRANSF_DATA_PATH",
    r"/home/alex/data/nlp/agmir/transf_processed_data",
)

In [7]:
# Limits handed to the dataset constructor
num_tokens = 2000
max_seq_length = 96

# Source side comes from the tags pickle, target side from the reports pickle
# (the commented-out alternatives were 'en/train.pkl' and 'fr/train.pkl').
tags_path = os.path.join(data_path, 'tags/set.pkl')
reports_path = os.path.join(data_path, 'reports/set.pkl')

dataset = ParallelLanguageDataset(tags_path, reports_path, num_tokens, max_seq_length)

# batch_size=1: the training loop below indexes [0] on every tensor the loader yields
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# train

In [8]:
from transformer_translation.model import LanguageTransformer

from einops import rearrange
def gen_nopeek_mask(length, device=None):
    """Build the additive "no-peek" (causal) attention mask for a sequence.

    Args:
        length: Number of positions in the sequence; the mask is
            ``(length, length)``.
        device: Optional device on which to create the mask. Defaults to
            ``None`` (PyTorch's default device), so existing callers that
            move the result with ``.to(device)`` keep working.

    Returns:
        A float32 tensor with ``0.0`` on and below the main diagonal and
        ``-inf`` strictly above it, so position ``i`` can only attend to
        positions ``<= i``.
    """
    # Start from an all -inf matrix and keep only the part strictly above the
    # diagonal; torch.triu sets everything else to 0.
    blocked = torch.full((length, length), float('-inf'), device=device)
    mask = torch.triu(blocked, diagonal=1)

    return mask

In [9]:
# Vocabulary size given to the model: 10000 plus 4 extra ids
# (presumably special tokens — TODO confirm).
# NOTE(review): the dataset cell uses num_tokens = 2000; confirm the two are meant to differ.
vocab_size = 10000 + 4

# Transformer dimensions
d_model, nhead = 512, 8
num_encoder_layers = num_decoder_layers = 6
dim_feedforward = 2048

# Dropout rates passed to the model
pos_dropout = trans_dropout = 0.1

# Arguments stay positional, in the order the constructor is called with in this notebook.
model = LanguageTransformer(
    vocab_size,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    max_seq_length,
    pos_dropout,
    trans_dropout,
)
model = model.to(device)

In [17]:
    from transformer_translation.Optim import ScheduledOptim
    import torch.nn as nn
    from torch.optim import Adam
    
    # Xavier-normal initialisation for every parameter with more than one
    # dimension; parameters with dim <= 1 keep their constructed values.
    for p in model.parameters():
        if p.dim() <= 1:
            continue
        nn.init.xavier_normal_(p)

    # Adam wrapped in the project's ScheduledOptim, which also receives d_model
    # and the warmup step count (presumably for a warmup learning-rate schedule —
    # see transformer_translation.Optim).
    n_warmup_steps = 4000
    adam = Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)
    optim = ScheduledOptim(adam, d_model, n_warmup_steps)

    # Targets equal to 0 are excluded from the loss (presumably the padding id — TODO confirm).
    criterion = nn.CrossEntropyLoss(ignore_index=0)

In [18]:
    %%time
    print_every = 15
    num_epochs = 2
    model.train()

    # Bookkeeping initialised here; this cell never reads lowest_val or val_losses back.
    lowest_val = 1e9
    val_losses = []
    total_step = 0

    for epoch in range(num_epochs):

        # Loss summed since the last progress line (restarted at every epoch)
        total_loss = 0

        for step, batch in enumerate(loader):
            src, src_key_padding_mask, tgt, tgt_key_padding_mask = batch
            total_step += 1

            # Index 0 selects the single item the loader yields per step (batch_size=1)
            # before moving each tensor to the training device.
            src = src[0].to(device)
            src_key_padding_mask = src_key_padding_mask[0].to(device)
            tgt = tgt[0].to(device)
            tgt_key_padding_mask = tgt_key_padding_mask[0].to(device)

            memory_key_padding_mask = src_key_padding_mask.clone()

            # Teacher forcing: the decoder input drops the last position and the
            # expected output drops the first, so the two are shifted by one.
            tgt_inp = tgt[:, :-1]
            tgt_out = tgt[:, 1:]
            tgt_mask = gen_nopeek_mask(tgt_inp.shape[1]).to(device)

            optim.zero_grad()
            outputs = model(
                src,
                tgt_inp,
                src_key_padding_mask,
                tgt_key_padding_mask[:, :-1],
                memory_key_padding_mask,
                tgt_mask,
            )

            # Flatten the first two axes so the criterion sees one row per target position
            flat_logits = rearrange(outputs, 'b t v -> (b t) v')
            flat_targets = rearrange(tgt_out, 'b o -> (b o)')
            loss = criterion(flat_logits, flat_targets)

            loss.backward()
            optim.step_and_update_lr()

            total_loss += loss.item()

            # Report the mean loss of the last print_every steps, then restart the sum
            if (step + 1) % print_every == 0:
                print(f'Epoch [{epoch + 1} / {num_epochs}] \t Step [{step + 1} / {len(loader)}] \t '
                      f'Train Loss: {total_loss / print_every}')
                total_loss = 0

Epoch [1 / 2] 	 Step [10 / 92] 	 Train Loss: 9.223678493499756
Epoch [1 / 2] 	 Step [20 / 92] 	 Train Loss: 9.07329568862915
Epoch [1 / 2] 	 Step [30 / 92] 	 Train Loss: 8.81158094406128
Epoch [1 / 2] 	 Step [40 / 92] 	 Train Loss: 8.494706916809083
Epoch [1 / 2] 	 Step [50 / 92] 	 Train Loss: 8.2903902053833
Epoch [1 / 2] 	 Step [60 / 92] 	 Train Loss: 8.098780965805053
Epoch [1 / 2] 	 Step [70 / 92] 	 Train Loss: 7.920980644226074
Epoch [1 / 2] 	 Step [80 / 92] 	 Train Loss: 7.8710309028625485
Epoch [1 / 2] 	 Step [90 / 92] 	 Train Loss: 7.702437925338745
Epoch [2 / 2] 	 Step [10 / 92] 	 Train Loss: 7.611214828491211
Epoch [2 / 2] 	 Step [20 / 92] 	 Train Loss: 7.4070210456848145
Epoch [2 / 2] 	 Step [30 / 92] 	 Train Loss: 7.286341428756714
Epoch [2 / 2] 	 Step [40 / 92] 	 Train Loss: 7.071765279769897
Epoch [2 / 2] 	 Step [50 / 92] 	 Train Loss: 6.956470012664795
Epoch [2 / 2] 	 Step [60 / 92] 	 Train Loss: 6.735355424880981
Epoch [2 / 2] 	 Step [70 / 92] 	 Train Loss: 6.5008950710

In [22]:
# Number of entries in the dataset's `data_1` attribute
len(dataset.data_1)

3580

In [27]:
# Shapes of the four tensors currently bound to the batch variables
tuple(
    t.shape
    for t in (src, src_key_padding_mask, tgt, tgt_key_padding_mask)
)

(torch.Size([1, 27, 3]),
 torch.Size([1, 27, 3]),
 torch.Size([1, 27, 67]),
 torch.Size([1, 27, 67]))

In [57]:
# Pull a single item from the loader (shuffled, so which one is not fixed)
# and show the source / target shapes as the loader yields them.
sample_batch = next(iter(loader))
src, src_key_padding_mask, tgt, tgt_key_padding_mask = sample_batch
src.shape, tgt.shape

(torch.Size([1, 30, 3]), torch.Size([1, 30, 95]))