### Sinkhorn Transformer

Dopo aver installato le dipendenze con: 
- conda create -n nlp-hands-on-3 python=3.8
- pip install -r requirements.txt

Vediamo in azione il nostro Sinkhorn Transformer.

In [1]:
import os
import tqdm
import torch
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer import Autopadder
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

In [5]:

model = SinkhornTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    bucket_size = 128,        # size of the buckets
    causal = False,           # auto-regressive or not
    n_sortcut = 2,            # use sortcut to reduce memory complexity to linear
    n_top_buckets = 2,        # sort specified number of key/value buckets to one query bucket. paper is at 1, defaults to 2
    ff_chunks = 10,           # feedforward chunking, from Reformer paper
    reversible = True,        # make network reversible, from Reformer paper
    emb_dropout = 0.1,        # embedding dropout
    ff_dropout = 0.1,         # feedforward dropout
    attn_dropout = 0.1,       # post attention dropout
    attn_layer_dropout = 0.1, # post attention layer dropout
    layer_dropout = 0.1,      # add layer dropout, from 'Reducing Transformer Depth on Demand' paper
    weight_tie = True,        # tie layer parameters, from Albert paper
    emb_dim = 128,            # embedding factorization, from Albert paper
    dim_head = 64,            # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    ff_glu = True,            # use GLU in feedforward, from paper 'GLU Variants Improve Transformer'
    n_local_attn_heads = 2,   # replace N heads with local attention, suggested to work well from Routing Transformer paper
    pkm_layers = (4,7),       # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer is best
    pkm_num_keys = 128,       # defaults to 128, but can be increased to 256 or 512 as memory allows
)

x = torch.randint(0, 20000, (1, 2048))
model(x) # (1, 2048, 20000)

tensor([[[ 0.1955, -0.4047, -0.3095,  ..., -0.7498, -1.0335,  0.2164],
         [-0.1427, -0.6244, -0.4162,  ..., -1.0227, -1.0701,  1.0770],
         [ 0.1593,  0.1160, -0.0282,  ..., -0.2226, -0.8753,  0.9878],
         ...,
         [ 0.0035, -0.5261,  0.9306,  ...,  0.5631, -0.1365,  0.3768],
         [ 0.2647, -1.1849,  0.6168,  ...,  0.3333, -0.5049,  0.2688],
         [ 0.0386, -0.7799,  0.5175,  ...,  0.2789,  0.7087, -0.3994]]],
       grad_fn=<AddBackward0>)

In [2]:
tr_en_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en"
tr_de_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de"

ts_en_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.en"
ts_de_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.de"

if not os.path.exists("./train.en"):
    !wget $tr_en_file
if not os.path.exists("./train.de"):
    !wget $tr_de_file
if not os.path.exists("./newstest2015.en"):
    !wget $ts_en_file
if not os.path.exists("./newstest2015.de"):
    !wget $ts_de_file

In [3]:
def construct_vocab(dataset):
    char_to_id = {
        "SOS": 0,
        "EOS": 1,
        "PAD": 2
    }
    id_to_char = {
        0: "SOS",
        1: "EOS",
        2: "PAD"
    }
    for line in tqdm.tqdm(dataset):
        for tk in line.lower():
            idx = len(char_to_id.keys())
            if tk not in char_to_id:
                char_to_id[tk] = idx
                id_to_char[idx] = tk
    return char_to_id, id_to_char

train_en = open("./train.en").readlines()
char_to_id_en, id_to_char_en = construct_vocab(train_en)
print(f"Totale vocabulary en: {len(char_to_id_en)}")

train_de = open("./train.de").readlines()
char_to_id_de, id_to_char_de = construct_vocab(train_de)
print(f"Totale vocabulary de: {len(char_to_id_de)}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4468840/4468840 [01:05<00:00, 68303.72it/s]


Totale vocabulary en: 2805


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4468840/4468840 [01:14<00:00, 59694.24it/s]

Totale vocabulary de: 2505





In [4]:
SEQ_LEN = 4096
SEQ_LEN = 4096

BATCH_SIZE = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512

In [5]:
enc = SinkhornTransformerLM(
    num_tokens = len(char_to_id_en),
    dim = 512,
    depth = 6,
    heads = 8,
    bucket_size = 128,
    max_seq_len = SEQ_LEN,
    reversible = True,
    return_embeddings = True
)#.cuda()

enc = AutoregressiveWrapper(enc)

#enc = Autopadder(enc, pad_left=True)

In [6]:
dec = SinkhornTransformerLM(
    num_tokens = len(char_to_id_de),
    dim = 512,
    depth = 6,
    causal = True,
    bucket_size = 128,
    max_seq_len = SEQ_LEN,
    receives_context = True,
    context_bucket_size = 128,  # context key / values can be bucketed differently
    reversible = True
)#.cuda()

dec = AutoregressiveWrapper(dec)

#dec = Autopadder(dec, pad_left=True)

In [7]:
enc_optim = torch.optim.Adam(enc.parameters(), lr=LEARNING_RATE)
dec_optim = torch.optim.Adam(dec.parameters(), lr=LEARNING_RATE)

In [8]:
def create_tensor(sentence, char_to_id):
    tokens = [char_to_id["SOS"]]
    for char in sentence:
        if len(tokens) < (SEQ_LEN-1):
            tokens.append(char_to_id[char])
    tokens.append(char_to_id["EOS"])

    while len(tokens) < SEQ_LEN:
        tokens.append(char_to_id["PAD"])
    
    return tokens

In [9]:
for i in tqdm.tqdm(range(0,len(train_en), BATCH_SIZE), mininterval=10., desc='training'):
    en_input = [] # Batch!
    de_input = []
    for j in range(BATCH_SIZE):
        en_sentence = train_en[i+j].lower()
        en_tokens = create_tensor(sentence=en_sentence, char_to_id=char_to_id_en)
        en_input.append(en_tokens)
        
        de_sentence = train_de[i+j].lower()
        de_tokens = create_tensor(sentence=de_sentence, char_to_id=char_to_id_de)
        de_input.append(de_tokens)
        
    en_input = torch.tensor(en_input)
    de_input = torch.tensor(de_input)
    
    en_mask = torch.ones_like(en_input).bool()#.cuda()
    de_mask = torch.ones_like(de_input).bool()#.cuda()
        
    enc.train()
    dec.train()
    
    context = enc(en_input, input_mask=en_mask)
    loss = dec(de_input, context=context, input_mask=de_mask, context_mask=en_mask, return_loss = True) 
    loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(enc.parameters(), 0.5)
    torch.nn.utils.clip_grad_norm_(dec.parameters(), 0.5)
    enc_optim.step()
    enc_optim.zero_grad()
    dec_optim.step()
    dec_optim.zero_grad()
    
    if i > 24:
        break

training:   0%|                                                                                                                   | 1/1117210 [02:49<52664:25:57, 169.70s/it]

training loss: 7.434681415557861


training:   0%|                                                                                                                   | 2/1117210 [05:34<51709:35:15, 166.62s/it]

training loss: 5.121203422546387


training:   0%|                                                                                                                   | 3/1117210 [07:47<47044:15:45, 151.59s/it]

training loss: 3.3080813884735107


training:   0%|                                                                                                                   | 4/1117210 [10:01<44867:20:56, 144.58s/it]

training loss: 1.6939918994903564


training:   0%|                                                                                                                   | 5/1117210 [11:49<40706:12:15, 131.17s/it]

training loss: 0.8909008502960205


training:   0%|                                                                                                                   | 6/1117210 [13:33<37871:51:23, 122.04s/it]

training loss: 0.6237004399299622
training loss: 0.25593942403793335


training:   0%|                                                                                                                   | 7/1117210 [15:35<37888:44:03, 122.09s/it]

training loss: 0.3690425753593445


training:   0%|                                                                                                                   | 7/1117210 [18:54<50301:56:57, 162.09s/it]
