In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules
# (e.g. my_train, model) are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

## initialization

In [59]:
from my_train import create_training_dataset, load_vocabulary
from model import Transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse

# Notebook stand-in for command-line arguments: vocabulary files and device.
# Source and target point at the same vocabulary file.
args = argparse.Namespace(
    src_vocab='../corpus/vocab_en_fr.txt',
    tgt_vocab='../corpus/vocab_en_fr.txt',
    device='cpu',
)

# load_vocabulary returns a pair; only the first element is kept for the
# source side, while both elements are kept for the target side.
source_vocabulary, _ = load_vocabulary(args.src_vocab)
target_vocabulary, target_vocabulary_rev = load_vocabulary(args.tgt_vocab)

# Ids of the sentence-boundary tokens in the target vocabulary.
bos = target_vocabulary["<s>"]
eos = target_vocabulary["</s>"]

# Build the Transformer from the two vocabulary sizes, with embedding sharing on.
model = Transformer(len(source_vocabulary), len(target_vocabulary), share_embeddings=True)

# Move the model to the configured device and switch it to training mode.
model.to(args.device)
model.train()

# Total number of scalar parameters (209129472 in the recorded run).
n = sum(param.numel() for param in model.parameters())
print(n)

# Corpus files used in this notebook.
source_path = './corpus/rep_test.en.tok'
target_path = './corpus/rep_test.fr.tok'

# Alternative corpus files (disabled):
# source_path = '../corpus/train.en.tok'
# target_path = '../corpus/train.fr.tok'

# batch_size is expressed in units of batch_type; when the unit is tokens
# the number has to be larger than it would be for example-based batching.
batch_type = "tokens"
batch_size = 1200
effective_batch_size = 2400
world_size = 1

label_smoothing = 0.1
padding_idx = 0

max_source_len = 150
max_target_len = 150

# How many batches are grouped together to reach the effective batch size.
# Falls back to 1 when no effective batch size is configured.
if effective_batch_size is None:
    accum_steps = 1
else:
    accum_steps = effective_batch_size // (batch_size * world_size)

# Keyword options forwarded to create_training_dataset: batching unit and size,
# length limits, target device, and how many batches are grouped per item.
dataset_options = dict(
    batch_size=batch_size,
    batch_type=batch_type,
    maximum_source_length=max_source_len,
    maximum_target_length=max_target_len,
    device=args.device,
    num_accum_batches=accum_steps,
)

# Training dataset over the source/target files, encoded with the two vocabularies.
dataset = create_training_dataset(
    source_path,
    target_path,
    source_vocabulary,
    target_vocabulary,
    **dataset_options,
)

209129472


In [61]:
# Accumulation factor handed to create_training_dataset as num_accum_batches:
# effective_batch_size // (batch_size * world_size) = 2400 // (1200 * 1).
accum_steps

2

In [62]:
# Fetch the first item the dataset yields from a fresh iterator.
# NOTE(review): with num_accum_batches=accum_steps this appears to be a group of
# accum_steps batches (len(batches) == 2 in the recorded run) — confirm against
# create_training_dataset.
batches = next(iter(dataset))

2023-Jul-07 14:01:27.278000 UTC [dataset@SpawnProcess-10] INFO Shuffling 739 elements
2023-Jul-07 14:01:27.308000 UTC [dataset@SpawnProcess-10] INFO Shuffling 739 elements


In [63]:
# Number of batches in the fetched group (2 in the recorded run, matching
# accum_steps — TODO confirm this equality is guaranteed).
len(batches)

2

In [67]:
# Take the first batch of the group and unpack its three elements.
# NOTE(review): judging by later use — model(s, ti) and target=to — these look
# like source ids, decoder-input ids and decoder-output (label) ids; confirm
# against my_train.
batch = s, ti, to = batches[0]

In [68]:
# Shape of the source tensor; torch.Size([8, 76]) in the recorded run
# (presumably batch x padded source length — TODO confirm axis order).
s.shape

torch.Size([8, 76])

In [69]:
# Total element count of s (8 * 76 = 608 in the recorded run); every position
# of the tensor is counted.
# NOTE(review): how this relates to the 1200 batch_size budget depends on how
# create_training_dataset counts tokens — verify there.
s.numel()

608

In [70]:
# Forward pass on the source and decoder-input tensors of the first batch.
# The model was switched to train() in the initialization cell and no
# torch.no_grad() context is used here.
logits = model(s, ti)

In [71]:
# Shape of the model output; torch.Size([8, 77, 32000]) in the recorded run.
# The last axis is the one the loss cells treat as the class dimension
# (logits.shape[-1]).
logits.shape

torch.Size([8, 77, 32000])

In [77]:
# Sum over the last axis, giving one value per position. The recorded values
# are far from 1, so the model output is not a normalised probability
# distribution.
# NOTE(review): this displays the full result; slice it (e.g. [0, :10]) if only
# a spot check is needed.
logits.sum(dim=2)

tensor([[ 1.0671e+01, -3.8562e+01,  1.1161e+01,  2.8411e+01,  2.2927e+01,
          1.0163e-01,  4.0305e+01,  7.4982e+00, -5.0273e+01, -4.1213e+01,
         -1.3429e+01, -7.7591e+00, -2.7842e+01,  2.6610e+01, -2.5975e+01,
         -3.9492e+00,  4.0974e+00,  8.8947e+00, -4.1727e+01, -1.4730e+01,
         -6.0218e+00, -3.2468e+01,  1.9227e+01, -8.0778e+00, -1.3399e+01,
         -3.4041e+01,  9.2794e+00, -3.2463e+01, -9.1612e+00, -4.5999e+01,
         -3.7868e+01, -1.5562e+01,  1.3546e+01, -3.7701e+01, -1.7544e+01,
          2.2611e+01,  8.0529e+00, -6.9963e+00,  4.8505e-01, -1.1369e+01,
         -1.6192e+01,  1.5604e+00, -2.0238e+01, -5.1166e+01,  1.4958e+01,
         -9.6400e+00, -2.5776e+00, -9.1963e+00, -1.5836e+01, -9.5776e-01,
         -2.2288e+00,  1.3019e+00, -1.9664e+01,  1.0099e+01, -9.1341e+00,
         -1.0610e+01, -1.3550e+01, -3.8314e+00, -1.2875e-01, -6.4148e+01,
         -4.4209e+01, -5.5544e+01, -3.2088e+01, -2.6983e+01, -1.4732e+00,
         -9.7586e+00, -1.2993e+01, -2.

In [7]:
# Load the precomputed penalty mask and display the sum of its values
# (21762 in the recorded run; that is the number of set entries if the mask
# is 0/1 — TODO confirm).
# map_location pins the loaded tensor to the same device as the model and the
# dataset batches, so a mask saved from a GPU session still loads on a
# CPU-only machine and never ends up on a different device than the logits.
# NOTE(review): torch.load unpickles the file — only load masks from a trusted
# source.
mask = torch.load('./penalty_mask.pt', map_location=args.device)
mask.sum()

tensor(21762.)

In [None]:
# Alias the penalty mask for use as F.cross_entropy's `weight=` argument, which
# the baseline timing cell currently leaves disabled.
# NOTE(review): this cell has no execution count and nothing live reads
# `weight`. F.cross_entropy requires `weight` to be a 1-D tensor with one entry
# per class; mask's shape is not shown here — confirm before enabling.
weight = mask

In [103]:
%%timeit -n 100
l = F.cross_entropy(
    input=logits.view(-1, logits.shape[-1]),
    target=to.view(-1),
    reduction='sum',
    # weight=weight,
    label_smoothing=label_smoothing
)

8 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [85]:
# Re-check of the logits shape before timing the custom loss
# (torch.Size([8, 77, 32000]) in the recorded run).
logits.shape

torch.Size([8, 77, 32000])

In [89]:
# Load the penalty mask that is passed to ce_loss_with_rep_penalty.
# map_location pins the loaded tensor to the same device as the model and the
# dataset batches, so a mask saved from a GPU session still loads on a
# CPU-only machine and sits on the same device as the logits.
# NOTE(review): torch.load unpickles the file — only load masks from a trusted
# source.
penalty_mask = torch.load('./penalty_mask.pt', map_location=args.device)

In [90]:
# NOTE(review): mid-notebook import; the other my_train imports live in the
# initialization cell — consider moving this one there so a fresh
# Restart & Run All shows every dependency up front.
from my_train import ce_loss_with_rep_penalty

In [105]:
%%timeit -n 10
# Time the custom loss on the same logits/targets as the F.cross_entropy
# baseline. Recorded run: 182 ms per loop versus 8 ms for the baseline.
# NOTE(review): the two positional 0.1 values are unnamed here; presumably one
# is the label-smoothing factor (label_smoothing = 0.1 in the config) and the
# other a penalty weight — confirm against the my_train signature and pass
# them by keyword.
ce_loss_with_rep_penalty(logits, to, penalty_mask, 0.1, 0.1)

182 ms ± 8.07 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
