### Reformer

Dopo aver installato le dipendenze con: 
- conda create -n nlp-hands-on-3 python=3.8
- pip install -r requirements.txt

Vediamo in azione il nostro Reformer.

In [1]:
from reformer_pytorch import ReformerLM
from reformer_pytorch.generative_tools import TrainingWrapper
from reformer_pytorch import ReformerEncDec

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import os

In [2]:
# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 512
SEQ_LEN = 4096



In [3]:
# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))



In [4]:
# instantiate model

model = ReformerLM(
    dim = 512,
    depth = 6,
    max_seq_len = SEQ_LEN,
    num_tokens = 256,
    heads = 8,
    bucket_size = 64,
    n_hashes = 4,
    ff_chunks = 10,
    lsh_dropout = 0.1,
    weight_tie = True,
    causal = True,
    n_local_attn_heads = 4,
    use_full_attn = False # set this to true for comparison with full attention
)

model = TrainingWrapper(model)
#model.cuda()



In [7]:
# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
    trX, vaX = np.split(X, [int(90e6)])
    data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq#.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))



  X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)


In [8]:
# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'%s \n\n %s', (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

training:   0%|                                                                                                     | 0/100000 [00:00<?, ?it/s]

training loss: 5.712687969207764
validation loss: 5.393490314483643
%s 

 %s (' Database and HFR Industry Reports] *[http://icf.som.yale.edu/research/hedgefund.shtml Hedge Fund Research Initiative] of the International Center for Finance at the [[Yale School of Management]] *[http://www.hedgefund.net/HFN_Averages_January_06_Report.pdf/ HFN Averages January Performance Report] *[http://www.hedgefund.net/Strategy_Focus_Report_022806.pdf/ HedgeFund.net Strategy Focus Report: HFN Small/Micro Cap Average]  {{finance-footer}}  [[category:Funds]]  [[de:Hedge-Fonds]] [[fr:Gestion alternative]] [[ja:&amp;#12504;&amp;#12483;&amp;#12472;&amp;#12501;&amp;#12449;&amp;#12531;&amp;#12489;]] [[zh:å¯¹å\x86²å\x9fºé\x87\x91]]</text>     </revision>   </page>   <page>     <title>Hydrocodone</title>     <id>14413</id>     <revision>       <id>42147808</id>       <timestamp>2006-03-04T03:21:19Z</timestamp>       <contributor>         <username>Shanel</username>         <id>301280</id>       </contributor>  

training:   0%|                                                                                     | 1/100000 [16:42<27843:56:04, 1002.39s/it]

ºµhïäK|ÆÖ   ¯'® =oY  é^'<C5 Çü}ñsÏX:Ï côuñþ¦  ®o+n"æÏ{%±éxx¯'®®®  0Ñá ã©ÄêÂÑÿ    x  Fm^¹òt Ê º 6à]uÖ ÙtÖ¸xÂ33í× äRXÔ ½v{'cáè0¹òsýLÞÊ ½á=Ää¸e%o0c (R¡nvÐ  ShaÍs38ÿäf e´ ¥W|°Â®xÂÖûrá 5GV:QhbFv{>o  7 ÅÁ8à ø-ºîéèNøÈ¢tjhlæE=?QkC ¾ ¤f éþ»  KÊÂ3 ®ß¹¬þ2ÚÖ] ei¯¾äsn %÷u Ñ K|GXèÐ7#5àÒXnád ¤m}$¡nr®oY _7¿°i  7 ~?¼sWàÑáÂéxQ Ü^Ye~u­ fÄöçÂÔ¿i³ÜS³÷ µ Î=9Nì~g«%ÏËýLÎÏÊ%¦<{2ïèÖeF8 o0U\0~Và÷ cK WÄêªeæ _¯Ý(õµÖÌ²¸e[½Å©ªmïX ýdLþÊg]% sÏ ½;  }cXSoÓó/sW+E  K×g°X·C ÓóÂo©


training:   0%|                                                                                      | 2/100000 [18:58<13693:17:37, 492.97s/it]

training loss: 5.367817401885986


training:   0%|                                                                                       | 3/100000 [21:13<9151:57:14, 329.48s/it]

training loss: 4.965239524841309


training:   0%|                                                                                       | 4/100000 [23:28<7019:08:04, 252.70s/it]

training loss: 4.744425296783447


training:   0%|                                                                                       | 5/100000 [25:43<5835:31:36, 210.09s/it]

training loss: 4.450093746185303


training:   0%|                                                                                       | 6/100000 [27:56<5113:25:02, 184.09s/it]

training loss: 4.1302666664123535


training:   0%|                                                                                       | 7/100000 [30:12<4670:46:10, 168.16s/it]

training loss: 4.087907791137695


training:   0%|                                                                                       | 8/100000 [32:26<4372:49:29, 157.43s/it]

training loss: 3.7613089084625244


training:   0%|                                                                                       | 9/100000 [34:43<4192:36:02, 150.95s/it]

training loss: 3.7313404083251953


training:   0%|                                                                                      | 10/100000 [36:56<4040:06:04, 145.46s/it]

training loss: 3.635711669921875


training:   0%|                                                                                      | 11/100000 [39:10<3944:32:41, 142.02s/it]

training loss: 3.606860876083374


training:   0%|                                                                                      | 12/100000 [41:27<3896:48:41, 140.30s/it]

training loss: 3.6015586853027344


training:   0%|                                                                                      | 13/100000 [43:46<3889:23:15, 140.04s/it]

training loss: 3.5999655723571777


training:   0%|                                                                                      | 14/100000 [46:07<3896:12:07, 140.28s/it]

training loss: 3.521846294403076


training:   0%|                                                                                      | 15/100000 [48:25<3881:57:20, 139.77s/it]

training loss: 3.5089263916015625


training:   0%|                                                                                      | 16/100000 [50:45<3879:41:20, 139.69s/it]

training loss: 3.380735158920288


training:   0%|                                                                                      | 17/100000 [53:05<3879:22:19, 139.68s/it]

training loss: 3.2193524837493896


training:   0%|                                                                                      | 18/100000 [55:24<3877:12:49, 139.60s/it]

training loss: 3.167842388153076


training:   0%|                                                                                      | 19/100000 [57:44<3877:35:12, 139.62s/it]

training loss: 3.2234718799591064


training:   0%|                                                                                    | 20/100000 [1:00:03<3873:05:08, 139.46s/it]

training loss: 3.238429307937622


training:   0%|                                                                                    | 21/100000 [1:02:23<3877:51:44, 139.63s/it]

training loss: 3.3115792274475098


training:   0%|                                                                                    | 22/100000 [1:04:42<3875:30:30, 139.55s/it]

training loss: 3.2402753829956055


training:   0%|                                                                                    | 23/100000 [1:07:02<3879:40:47, 139.70s/it]

training loss: 3.3628571033477783


training:   0%|                                                                                    | 24/100000 [1:09:23<3886:12:58, 139.94s/it]

training loss: 3.3242132663726807


training:   0%|                                                                                    | 25/100000 [1:11:42<3876:27:40, 139.59s/it]

training loss: 3.1888089179992676


training:   0%|                                                                                    | 26/100000 [1:14:00<3869:47:05, 139.35s/it]

training loss: 3.3380186557769775


training:   0%|                                                                                    | 27/100000 [1:16:20<3874:32:03, 139.52s/it]

training loss: 3.503314256668091


training:   0%|                                                                                    | 28/100000 [1:18:40<3879:00:35, 139.68s/it]

training loss: 3.2019283771514893


training:   0%|                                                                                    | 29/100000 [1:21:01<3883:24:08, 139.84s/it]

training loss: 3.08240008354187


training:   0%|                                                                                    | 30/100000 [1:23:20<3878:31:02, 139.67s/it]

training loss: 3.000864267349243


training:   0%|                                                                                    | 31/100000 [1:25:38<3869:06:05, 139.33s/it]

training loss: 3.0344667434692383


training:   0%|                                                                                    | 32/100000 [1:27:57<3859:51:12, 139.00s/it]

training loss: 2.973862409591675


training:   0%|                                                                                    | 33/100000 [1:30:18<3876:48:38, 139.61s/it]

training loss: 2.9405009746551514


training:   0%|                                                                                    | 34/100000 [1:32:39<3888:30:18, 140.03s/it]

training loss: 2.9729928970336914


training:   0%|                                                                                    | 35/100000 [1:34:59<3890:33:46, 140.11s/it]

training loss: 3.010070323944092


training:   0%|                                                                                    | 36/100000 [1:37:19<3892:41:05, 140.19s/it]

training loss: 3.2033281326293945


training:   0%|                                                                                    | 37/100000 [1:39:39<3888:53:18, 140.05s/it]

training loss: 2.9056718349456787


training:   0%|                                                                                    | 38/100000 [1:41:55<3852:56:46, 138.76s/it]

training loss: 2.812208890914917


training:   0%|                                                                                    | 39/100000 [1:44:09<3812:59:28, 137.32s/it]

training loss: 2.9044740200042725


training:   0%|                                                                                    | 40/100000 [1:46:23<3787:15:25, 136.40s/it]

training loss: 2.8990018367767334


training:   0%|                                                                                    | 41/100000 [1:48:39<3784:15:39, 136.29s/it]

training loss: 2.8606436252593994


training:   0%|                                                                                    | 42/100000 [1:50:54<3777:31:09, 136.05s/it]

training loss: 3.040532350540161


training:   0%|                                                                                    | 43/100000 [1:53:10<3770:21:51, 135.79s/it]

training loss: 2.940197706222534


training:   0%|                                                                                    | 44/100000 [1:55:24<3760:46:17, 135.45s/it]

training loss: 2.8159096240997314


training:   0%|                                                                                    | 45/100000 [1:57:39<3754:37:59, 135.23s/it]

training loss: 2.8694612979888916


training:   0%|                                                                                    | 46/100000 [1:59:53<3745:45:49, 134.91s/it]

training loss: 2.848743438720703


training:   0%|                                                                                    | 47/100000 [2:02:07<3737:04:09, 134.60s/it]

training loss: 2.881574869155884


training:   0%|                                                                                    | 48/100000 [2:04:22<3744:02:01, 134.85s/it]

training loss: 2.826660633087158


training:   0%|                                                                                    | 49/100000 [2:06:38<3745:22:02, 134.90s/it]

training loss: 2.8716864585876465


training:   0%|                                                                                    | 50/100000 [2:08:54<3761:24:56, 135.48s/it]

training loss: 2.8646767139434814


training:   0%|                                                                                    | 51/100000 [2:11:12<3777:38:55, 136.06s/it]

training loss: 2.8590614795684814


training:   0%|                                                                                    | 52/100000 [2:13:28<3781:08:32, 136.19s/it]

training loss: 2.9650213718414307


training:   0%|                                                                                    | 52/100000 [2:15:02<4326:11:56, 155.82s/it]


KeyboardInterrupt: 

### More models

Adesso che abbiamo visto un semplice Language Model, passiamo a qualcosa di più interessante.

## Translation 1
Sequence to Sequence - Full

In [31]:
import torch
from reformer_pytorch import ReformerLM

DE_SEQ_LEN = 2048
EN_SEQ_LEN = 2048

encoder_engine = ReformerLM(
    num_tokens = 10000,#20000,
    emb_dim = 128,
    dim = 256, # 1024,
    depth = 6, #12,
    heads = 4, #8,
    max_seq_len = DE_SEQ_LEN,
    fixed_position_emb = True,
    return_embeddings = True # return output of last attention layer
)#.cuda()

decoder_engine = ReformerLM(
    num_tokens = 10000,#20000,
    emb_dim = 128,
    dim = 256, # 1024,
    depth = 6, #12,
    heads = 4, #8,
    max_seq_len = EN_SEQ_LEN,
    fixed_position_emb = True,
    causal = True
)#.cuda()

In [34]:
# Creiamo il dataset e facciamo il train!
enc_optim = torch.optim.Adam(encoder_engine.parameters(), lr=LEARNING_RATE)
dec_optim = torch.optim.Adam(decoder_engine.parameters(), lr=LEARNING_RATE)

validate_every = 100
for tr_step in tqdm.tqdm(range(1000)):
    ## Sostituire questo con i token del vostro dataset
    src  = torch.randint(0, 10000, (1, DE_SEQ_LEN)).long()#.cuda()
    trg = torch.randint(0, 10000, (1, EN_SEQ_LEN)).long()#.cuda()
    ###

    encoder_engine.train()
    decoder_engine.train()
    src = src.to("cpu")
    trg = trg.to("cpu")
    
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        enc_keys = encoder_engine(src)
        loss = decoder_engine(trg, keys=enc_keys, return_loss=True)
        loss.sum().backward()

    print(f'training loss: {loss.sum().item()}')
    torch.nn.utils.clip_grad_norm_(encoder_engine.parameters(), 0.5)
    torch.nn.utils.clip_grad_norm_(decoder_engine.parameters(), 0.5)
    
    enc_optim.step()
    dec_optim.step()
    
    enc_optim.zero_grad()
    dec_optim.zero_grad()

    if tr_step % validate_every == 0:
        encoder_engine.eval()
        decoder_engine.eval()
        with torch.no_grad():
            ts_src = torch.randint(0, 10000, (1, DE_SEQ_LEN)).long()#.cuda()
            ts_trg = torch.randint(0, 10000, (1, EN_SEQ_LEN)).long()#.cuda()

            ts_src = ts_src.to("cpu")
            ts_trg = ts_trg.to("cpu")

            enc_keys = encoder_engine(ts_src)
            loss = decoder_engine(ts_trg,
                                  keys=enc_keys,
                                  return_loss=True)
                

        print(
            f'\tValidation Loss: {loss.sum().item()}'
        )



  0%|                                                                                                                    | 0/1 [00:00<?, ?it/s]

training loss: -711421.75


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:45<00:00, 45.54s/it]

	Validation Loss: -1059510.625





## Translation 2
Sequence to Sequence - Shortcut

In [39]:
DE_SEQ_LEN = 2048
EN_SEQ_LEN = 2048

enc_dec = ReformerEncDec(
    dim = 256,
    enc_num_tokens = 10000,
    enc_depth = 6,
    enc_max_seq_len = DE_SEQ_LEN,
    dec_num_tokens = 10000,
    dec_depth = 6,
    dec_max_seq_len = EN_SEQ_LEN
)#.cuda()

optim = torch.optim.Adam(enc_dec.parameters(), lr=LEARNING_RATE)

validate_every = 100
for tr_step in tqdm.tqdm(range(1000)):
    train_seq_in = torch.randint(0, 10000, (1, DE_SEQ_LEN)).long()#.cuda()
    train_seq_out = torch.randint(0, 10000, (1, EN_SEQ_LEN)).long()#.cuda()
    input_mask = torch.ones(1, DE_SEQ_LEN).bool()#.cuda()

    loss = enc_dec(train_seq_in, train_seq_out, return_loss = True, enc_input_mask = input_mask)
    loss.sum().backward()
    # learn
    torch.nn.utils.clip_grad_norm_(enc_dec.parameters(), 0.5)
   
    optim.step()    
    optim.zero_grad()

    # evaluate with the following
    if tr_step % validate_every == 0:
        eval_seq_in = torch.randint(0, 10000, (1, DE_SEQ_LEN)).long()#.cuda()
        eval_seq_out_start = torch.tensor([[0.]]).long()#.cuda() 
        samples = enc_dec.generate(eval_seq_in, eval_seq_out_start, seq_len = EN_SEQ_LEN, eos_token = 1) 
        print(samples.shape) 

0it [00:00, ?it/s]


### Translation con dati reali

In [6]:
tr_en_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en"
tr_de_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de"

ts_en_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.en"
ts_de_file = "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/newstest2015.de"

if not os.path.exists("./train.en"):
    !wget $tr_en_file
if not os.path.exists("./train.de"):
    !wget $tr_de_file
if not os.path.exists("./newstest2015.en"):
    !wget $ts_en_file
if not os.path.exists("./newstest2015.de"):
    !wget $ts_de_file

In [4]:
def construct_vocab(dataset):
    char_to_id = {
        "SOS": 0,
        "EOS": 1,
        "PAD": 2
    }
    id_to_char = {
        0: "SOS",
        1: "EOS",
        2: "PAD"
    }
    for line in tqdm.tqdm(dataset):
        for tk in line.lower():
            idx = len(char_to_id.keys())
            if tk not in char_to_id:
                char_to_id[tk] = idx
                id_to_char[idx] = tk
    return char_to_id, id_to_char

train_en = open("./train.en").readlines()
char_to_id_en, id_to_char_en = construct_vocab(train_en)
print(f"Totale vocabulary en: {len(char_to_id_en)}")

train_de = open("./train.de").readlines()
char_to_id_de, id_to_char_de = construct_vocab(train_de)
print(f"Totale vocabulary de: {len(char_to_id_de)}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4468840/4468840 [01:04<00:00, 68898.55it/s]


Totale vocabulary en: 2805


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4468840/4468840 [01:13<00:00, 60663.15it/s]

Totale vocabulary de: 2505





In [5]:
DE_SEQ_LEN = 2048
EN_SEQ_LEN = 2048

enc_dec = ReformerEncDec(
    dim = 256,
    enc_num_tokens = len(char_to_id_de),
    enc_depth = 6,
    enc_max_seq_len = DE_SEQ_LEN,
    dec_num_tokens = len(char_to_id_en),
    dec_depth = 6,
    dec_max_seq_len = EN_SEQ_LEN
)

In [6]:
optim = torch.optim.Adam(enc_dec.parameters(), lr=LEARNING_RATE)

In [10]:
def create_tensor(sentence, char_to_id):
    tokens = [char_to_id["SOS"]]
    for char in sentence:
        if len(tokens) < (SEQ_LEN-1):
            if char in char_to_id:
                tokens.append(char_to_id[char])
    tokens.append(char_to_id["EOS"])

    while len(tokens) < SEQ_LEN:
        tokens.append(char_to_id["PAD"])
    
    return tokens

In [None]:
'''
[SOS] the cat is on the table     [EOS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
[SOS] il gatto è sul tavolo [EOS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]
'''

In [8]:
def get_eval_sentece():
    en_sentences = open("./newstest2015.en").readlines()
    de_sentences = open("./newstest2015.de").readlines()
    assert len(en_sentences) == len(de_sentences)
    
    sent_idx = random.randint(0, len(en_sentences))
    
    return en_sentences[sent_idx].lower(), de_sentences[sent_idx].lower()

In [11]:
validate_every = 100    
for tr_step in tqdm.tqdm(range(0,len(train_en), BATCH_SIZE), mininterval=10., desc='training'):
    en_input = []
    de_input = []
    for j in range(BATCH_SIZE):
        en_sentence = train_en[tr_step+j].lower()
        en_tokens = create_tensor(sentence=en_sentence, char_to_id=char_to_id_en)
        en_input.append(en_tokens)
        
        de_sentence = train_de[tr_step+j].lower()
        de_tokens = create_tensor(sentence=de_sentence, char_to_id=char_to_id_de)
        de_input.append(de_tokens)
        
    en_input = torch.tensor(en_input)
    de_input = torch.tensor(de_input)
    
    
    input_mask = torch.ones(BATCH_SIZE, DE_SEQ_LEN).bool()#.cuda()

    loss = enc_dec(de_input, en_input, return_loss = True, enc_input_mask = input_mask)
    loss.sum().backward()
    print(f"Training loss: {loss.sum().item()}")
    # learn
    torch.nn.utils.clip_grad_norm_(enc_dec.parameters(), 0.5)
   
    optim.step()    
    optim.zero_grad()

    # evaluate with the following
    if tr_step % validate_every == 0:
        en_ts_sent, de_ts_sent = get_eval_sentece()
        
        en_ts_tokens = create_tensor(sentence=en_ts_sent, char_to_id=char_to_id_en)
        de_ts_tokens = create_tensor(sentence=de_ts_sent, char_to_id=char_to_id_de)
        
        eval_seq_de = torch.tensor([de_ts_tokens])
        eval_seq_en_start = torch.tensor([[char_to_id_en["SOS"]]]).long()#.cuda() 
        samples = enc_dec.generate(eval_seq_de, eval_seq_en_start, seq_len = EN_SEQ_LEN, eos_token = char_to_id_en["EOS"]) 
        print(samples.shape)
    
    if tr_step > 24:
        break

training:   0%|                                                                                                                                  | 0/1117210 [00:00<?, ?it/s]

Training loss: 7.1914496421813965


training:   0%|                                                                                                                  | 1/1117210 [08:05<150639:53:34, 485.41s/it]

torch.Size([1, 141])


training:   0%|                                                                                                                   | 2/1117210 [11:18<97232:34:59, 313.31s/it]

Training loss: 5.719971179962158


training:   0%|                                                                                                                   | 3/1117210 [14:30<80039:55:56, 257.91s/it]

Training loss: 4.518599987030029


training:   0%|                                                                                                                   | 4/1117210 [17:23<69662:10:23, 224.47s/it]

Training loss: 3.384230136871338


training:   0%|                                                                                                                   | 5/1117210 [20:19<64190:19:15, 206.84s/it]

Training loss: 2.572434186935425


training:   0%|                                                                                                                   | 6/1117210 [23:17<61191:32:02, 197.18s/it]

Training loss: 1.978898525238037


training:   0%|                                                                                                                   | 7/1117210 [26:13<59028:34:06, 190.21s/it]

Training loss: 1.3155440092086792


training:   0%|                                                                                                                   | 7/1117210 [29:17<77894:50:40, 251.00s/it]

Training loss: 1.148115873336792





## Facciamo una evaluation di prova

In [12]:
en_ts_sent, de_ts_sent = get_eval_sentece()

en_ts_tokens = create_tensor(sentence=en_ts_sent, char_to_id=char_to_id_en)
de_ts_tokens = create_tensor(sentence=de_ts_sent, char_to_id=char_to_id_de)

eval_seq_de = torch.tensor([de_ts_tokens])
eval_seq_en_start = torch.tensor([[char_to_id_en["SOS"]]]).long()#.cuda() 
samples = enc_dec.generate(eval_seq_de, eval_seq_en_start, seq_len = EN_SEQ_LEN, eos_token = char_to_id_en["EOS"]) 
print(samples.shape)

torch.Size([1, 2048])


## Image Captioning

In [None]:
import torch
from torch.nn import Sequential
from torchvision import models
from reformer_pytorch import Reformer, ReformerLM

resnet = models.resnet50(pretrained=True)
resnet = Sequential(*list(resnet.children())[:-4])

SEQ_LEN = 4096

encoder = Reformer(
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = 4096
)

decoder = ReformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    heads = 8,
    max_seq_len = SEQ_LEN,
    causal = True
)

x  = torch.randn(1, 3, 512, 512)
yi = torch.randint(0, 20000, (1, SEQ_LEN)).long()

visual_emb = resnet(x)
b, c, h, w = visual_emb.shape
visual_emb = visual_emb.view(1, c, h * w).transpose(1, 2) # nchw to nte

enc_keys = encoder(visual_emb)
yo = decoder(yi, keys = enc_keys) # (1, 4096, 20000)