In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Initialization

In [2]:
from my_train import create_training_dataset, load_vocabulary
from model import Transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse

# Notebook stand-in for command-line arguments: vocabulary paths and device.
args = argparse.Namespace(
    src_vocab='../corpus/vocab_en_fr.txt',
    tgt_vocab='../corpus/vocab_en_fr.txt',
    device='cpu',
)

# Source and target sides read the same vocabulary file. Only the target side
# keeps the second value returned by load_vocabulary (stored as the "_rev" map).
source_vocabulary, _ = load_vocabulary(args.src_vocab)
target_vocabulary, target_vocabulary_rev = load_vocabulary(args.tgt_vocab)

# Entries of the "<s>" and "</s>" tokens in the target vocabulary.
bos, eos = target_vocabulary["<s>"], target_vocabulary["</s>"]

# Build the Transformer sized by the source and target vocabularies, with
# share_embeddings enabled.
model = Transformer(
    len(source_vocabulary),
    len(target_vocabulary),
    share_embeddings=True,
)
model.to(args.device)
model.train()

# Report the total number of parameter elements in the model.
n = sum(param.numel() for param in model.parameters())
print(n)

# Parallel corpus files used by this notebook.
# NOTE(review): these paths start with './corpus' while the vocabulary file and
# the commented-out training paths below use '../corpus' — confirm this is intended.
source_path = './corpus/rep_test.en.tok'
target_path = './corpus/rep_test.fr.tok'

# source_path = '../corpus/train.en.tok'
# target_path = '../corpus/train.fr.tok'

# batch_type is the unit that batch_size is measured in; when the unit is
# "tokens", batch_size has to be a larger number.
batch_type = "tokens"
batch_size = 1200
effective_batch_size = 2400

label_smoothing = 0.1
padding_idx = 0

max_source_len = 150
max_target_len = 150
world_size = 1

# Number of batches grouped together to reach effective_batch_size across
# world_size workers. Clamped to at least 1: plain integer division would give 0
# whenever effective_batch_size is smaller than batch_size * world_size.
accum_steps = (
    max(1, effective_batch_size // (batch_size * world_size))
    if effective_batch_size is not None
    else 1
)

# Build the training dataset from the source/target files and the two
# vocabularies, using the batching and length limits configured above.
# num_accum_batches=accum_steps: presumably makes the dataset yield batches in
# groups of accum_steps — verify against my_train.create_training_dataset.
dataset = create_training_dataset(
    source_path,
    target_path,
    source_vocabulary,
    target_vocabulary,
    batch_size=batch_size,
    batch_type=batch_type,
    maximum_source_length=max_source_len,
    maximum_target_length=max_target_len,
    device=args.device,
    num_accum_batches=accum_steps,
)

209129472


In [3]:
# Display the accumulation step count derived from the batch-size settings.
accum_steps

2

In [4]:
# Take the first item the dataset yields; the length check that follows shows it
# is a group of batches rather than a single batch.
batches = next(iter(dataset))

2023-Jul-09 09:37:58.618000 UTC [dataset@SpawnProcess-1] INFO Shuffling 739 elements
2023-Jul-09 09:37:58.660000 UTC [dataset@SpawnProcess-1] INFO Shuffling 739 elements


In [5]:
# Number of batches in the fetched group; expected to equal accum_steps.
len(batches)

2

In [6]:
# First batch of the group, unpacked into its three elements. `s` and `ti` are
# passed to the model and `to` is used as the cross-entropy target (presumably
# source, target-input and target-output token ids — confirm in my_train).
batch = (s, ti, to) = batches[0]

In [7]:
# Shape of the first tensor in the batch.
s.shape

torch.Size([8, 76])

In [8]:
# Total number of elements in `s`. With batch_type="tokens" this is presumably
# the quantity bounded by batch_size — TODO confirm how my_train counts tokens.
s.numel()

608

In [9]:
# Forward pass on the first batch. The model was put in train() mode in the
# setup cell, and no torch.no_grad() is used, so the result carries gradients.
logits = model(s, ti)

In [10]:
# Shape of the model output; its last dimension is the class dimension that the
# loss cell flattens against.
logits.shape

torch.Size([8, 77, 32000])

In [11]:
# Sum the raw logits over dim 2 (the last axis), giving one value per position.
# NOTE(review): this displays the entire result tensor; a slice or summary
# statistic would keep the notebook output small.
logits.sum(dim=2)

tensor([[-1.1177e+01,  5.6754e+00,  2.0808e+01,  5.3710e+00,  5.8044e+00,
         -1.6948e+01, -2.1228e+01, -1.1452e+01, -2.0809e+01, -1.5157e+01,
         -4.9674e+01, -1.0409e+01, -1.6564e+01,  7.3628e+00, -1.3288e+01,
         -4.0528e+01,  8.6779e+00, -1.5255e+00,  1.8660e+01, -1.9715e+01,
         -5.1162e+01, -4.9488e+01, -1.7580e+00,  6.2387e+00, -8.0234e+00,
         -7.8371e-01, -1.8518e+01, -1.3783e+01, -8.3405e+00, -5.0241e+00,
         -4.7232e+00, -1.1336e+01, -4.4768e+01, -1.3607e+01, -4.1738e+01,
         -3.3142e+01, -3.3414e+00, -2.8888e+01, -1.5144e+01, -2.0982e+01,
         -4.5710e+01,  1.5156e+01, -1.8768e+01,  2.0997e+00, -2.0527e+01,
          3.8572e+00, -5.5131e+01, -1.1827e+01, -2.0391e+00, -1.4929e+01,
          3.3819e+00, -4.3766e+01, -3.7580e+01, -4.7864e+01, -2.6709e+01,
         -2.5742e+01, -1.3922e+01, -4.9000e+01, -1.7419e+01, -4.2266e+01,
         -9.4362e+00, -3.8519e+01, -9.0500e+00, -3.5786e+01, -4.9980e+00,
         -4.3623e+01,  1.5637e+01, -2.

In [12]:
# Baseline loss: summed, label-smoothed cross-entropy over every position.
# Logits are flattened to (N, num_classes) and the targets to (N,).
# NOTE(review): `padding_idx` is defined in the setup cell but is not passed as
# `ignore_index` here — confirm whether padded positions are meant to count.
l = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    to.view(-1),
    label_smoothing=label_smoothing,
    reduction='sum',
)
l

tensor(6404.0410, grad_fn=<AddBackward0>)

In [13]:
# Logits shape again (repeats the earlier check).
logits.shape

torch.Size([8, 77, 32000])

In [14]:
# Load the penalty mask that is passed to ce_loss_with_rep_penalty further down.
# map_location places the loaded data on args.device, so the load also works
# when the file was saved from a different device (e.g. a GPU run).
# NOTE: torch.load unpickles the file — only load files from a trusted source.
penalty_mask = torch.load('./penalty_mask.pt', map_location=args.device)

In [15]:
# Sanity check: cross-entropy against hand-written soft targets should equal
# cross-entropy against class indices with label_smoothing=0.3 and class weights
# [2, 1, 1]. With 3 classes, smoothing 0.3 puts 0.8 on the true class and 0.1 on
# each other class; the weight of 2 on class 0 doubles every class-0 entry.
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because the following cell reads it.
input = torch.tensor([
    [0.1, 0.2, 0.3],
    [0.2, 0.1, 0.3],
    [0.2, 0.1, 0.3],
])
target = torch.tensor(
    [
        [2 * 0.8, 0.1, 0.1],
        [2 * 0.1, 0.8, 0.1],
        [2 * 0.1, 0.1, 0.8],
    ],
    dtype=torch.float32,
)
target2 = torch.arange(3)

# Soft (probability) targets, summed over the three rows.
print(F.cross_entropy(input, target, reduction='sum'))
# Class-index targets with smoothing and per-class weights; should print the same value.
print(
    F.cross_entropy(
        input,
        target2,
        weight=torch.tensor([2, 1, 1], dtype=torch.float32),
        label_smoothing=0.3,
        reduction='sum',
    )
)

tensor(4.5578)
tensor(4.5578)


In [16]:
# The toy logits and the soft targets have the same (N, C) shape, which is what
# the class-probability form of F.cross_entropy expects.
input.shape, target.shape

(torch.Size([3, 3]), torch.Size([3, 3]))

In [17]:
from my_train import ce_loss_with_rep_penalty

In [28]:
%%timeit -n 1
# Time the custom loss imported from my_train on the current logits/targets.
# NOTE(review): the two 0.1 positional arguments and device='cpu' are hard-coded
# rather than taken from the config cell (label_smoothing, args.device) — check
# them against the function's signature.
ce_loss_with_rep_penalty(logits, to, penalty_mask, 0.1, 0.1, device='cpu')

1.05 s ± 7.69 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [29]:
# Logits shape once more (same check as in the earlier cells).
logits.shape

torch.Size([8, 77, 32000])

In [29]:
# Shapes of the logits and of the target tensor side by side.
# NOTE(review): the saved output of this cell shows a different batch shape than
# the earlier `logits.shape` cells, and the execution counts around here are out
# of order — re-run from a fresh kernel before comparing the timings.
logits.shape, to.shape

(torch.Size([16, 57, 32000]), torch.Size([16, 57]))

In [26]:
%%timeit -n 10
F.cross_entropy(logits.view(-1, 32000), to.view(-1),
                label_smoothing=0.1, reduction='sum')

23.3 ms ± 490 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
