In [1]:
import torch
import numpy as np

In [2]:
class Embedder(torch.nn.Module):
    
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, d_model)
        
    def forward(self, x):
        # [123, 0, 23, 5] -> [[..512..], [...512...], ...]
        return self.embed(x)  

In [3]:
import math

class PositionalEncoder(torch.nn.Module):
    
    def __init__(self, d_model, max_seq_len=80):
        super().__init__()
        self.d_model = d_model
        
        # create constant positional encoding matrix
        pe_matrix = torch.zeros(max_seq_len, d_model)
        
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe_matrix[pos, i] = math.sin(pos/10000**(2*i/d_model))
                pe_matrix[pos, i+1] = math.cos(pos/10000**(2*i/d_model))
        pe_matrix = pe_matrix.unsqueeze(0)     # Add one dimension for batch size
        self.register_buffer('pe', pe_matrix)  # Register as persistent buffer
        
    def forward(self, x):
        # x is a sentence after embedding with dim (batch, number of words, vector dimension)
        seq_len = x.size()[1]
        x = x + self.pe[:, :seq_len]
        return x

In [4]:
import math
import torch.nn.functional as F

# Given Query, Key, Value, calculate the final weighted value
def scaled_dot_product_attention(q, k, v, mask=None, dropout=None):
    # Shape of q and k are the same, both are (batch_size, seq_len, d_k)
    # Shape of v is (batch_size, seq_len, d_v)
    attention_scores = torch.matmul(q, k.transpose(-2, -1))/math.sqrt(q.shape[-1])  # size (batch_size, seq_len, seq_len)
    
    # Apply mask to scores
    # 
    if mask is not None:
        attention_scores = attention_scores.masked_fill(mask == 0, value=-1e9)
        
    # Softmax along the last dimension
    attention_weights = F.softmax(attention_scores, dim=-1)
    
    if dropout is not None:
        attention_weights = dropout(attention_weights)
        
    output = torch.matmul(attention_weights, v)
    return output

In [5]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, n_heads, d_model, dropout=0.1):
        super().__init__()
        
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = self.d_v = d_model//n_heads
        
        # self attention linear layers
        # Linear layers for q, k, v vectors generation in different heads
        self.q_linear_layers = []
        self.k_linear_layers = []
        self.v_linear_layers = []
        for i in range(n_heads):
            self.q_linear_layers.append(torch.nn.Linear(d_model, self.d_k))
            self.k_linear_layers.append(torch.nn.Linear(d_model, self.d_k))
            self.v_linear_layers.append(torch.nn.Linear(d_model, self.d_v))
        
        self.dropout = torch.nn.Dropout(dropout)
        self.out = torch.nn.Linear(n_heads*self.d_v, d_model)
        
    def forward(self, q, k, v, mask=None):
        multi_head_attention_outputs = []
        for q_linear, k_linear, v_linear in zip(self.q_linear_layers,
                                                self.k_linear_layers,
                                                self.v_linear_layers):
            new_q = q_linear(q)  # size: (batch_size, seq_len, d_k)
            new_k = k_linear(k)  # size: (batch_size, seq_len, d_k)
            new_v = v_linear(v)  # size (batch_size, seq_len, d_v)
            
            # Scaled Dot-Product attention
            head_v = scaled_dot_product_attention(new_q, new_k, new_v, mask, self.dropout)  # (batch_size, seq_len, d_v)
            multi_head_attention_outputs.append(head_v)
            
        # Concat
        #import pdb; pdb.set_trace()
        concat = torch.cat(multi_head_attention_outputs, -1)  # (batch_size, seq_len, n_heads*d_v)
        
        # Linear layer to recover to original shap
        output = self.out(concat)  # (batch_size, seq_len, d_model)
        
        return output

In [6]:
class FeedForward(torch.nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        
        self.linear_1 = torch.nn.Linear(d_model, d_ff)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear_2 = torch.nn.Linear(d_ff, d_model)
        
    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x

In [7]:
class LayerNorm(torch.nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.d_model = d_model
        self.alpha = torch.nn.Parameter(torch.ones(self.d_model))
        self.beta = torch.nn.Parameter(torch.zeros(self.d_model))
        self.eps = eps
        
    def forward(self, x):
        # x size: (batch_size, seq_len, d_model)
        x_hat = (x - x.mean(dim=-1, keepdim=True))/(x.std(dim=-1, keepdim=True) + self.eps)
        x_tilde = self.alpha*x_hat + self.beta
        return x_tilde

In [8]:
class EncoderLayer(torch.nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.norm_1 = LayerNorm(d_model)
        self.norm_2 = LayerNorm(d_model)
        self.multi_head_attention = MultiHeadAttention(n_heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout_1 = torch.nn.Dropout(dropout)
        self.dropout_2 = torch.nn.Dropout(dropout)
        
    def forward(self, x, mask):
        #import pdb; pdb.set_trace()
        x = x + self.dropout_1(self.multi_head_attention(x, x, x, mask))
        x = self.norm_1(x)
        
        x = x + self.dropout_2(self.feed_forward(x))
        x = self.norm_2(x)
        return x

In [9]:
class DecoderLayer(torch.nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.norm_1 = LayerNorm(d_model)
        self.norm_2 = LayerNorm(d_model)
        self.norm_3 = LayerNorm(d_model)
        
        self.dropout_1 = torch.nn.Dropout(dropout)
        self.dropout_2 = torch.nn.Dropout(dropout)
        self.dropout_3 = torch.nn.Dropout(dropout)
        
        self.multi_head_attention_1 = MultiHeadAttention(n_heads, d_model)
        self.multi_head_attention_2 = MultiHeadAttention(n_heads, d_model)
        
        self.feed_forward = FeedForward(d_model)
        
    def forward(self, x, encoder_output, src_mask, trg_mask):
        x = self.dropout_1(self.multi_head_attention_1(x, x, x, trg_mask))
        x = x + self.norm_1(x)
        
        x = self.dropout_2(self.multi_head_attention_2(x, encoder_output, encoder_output, src_mask))
        x = x + self.norm_2(x)
        
        x = self.dropout_3(self.feed_forward(x))
        x = x + self.norm_3(x)
        
        return x

In [10]:
import copy

def clone_layer(module, N):
    return torch.nn.ModuleList([copy.deepcopy(module) for i in range(N)])

In [11]:
class Encoder(torch.nn.Module):
    def __init__(self, vocab_size, d_model, N, n_heads):
        super().__init__()
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)
        self.encoder_layers = clone_layer(EncoderLayer(d_model, n_heads), N)
        self.norm = LayerNorm(d_model)
        
    def forward(self, src, mask):
        x = self.embed(src)
        x = self.pe(x)
        for encoder in self.encoder_layers:
            x = encoder(x, mask)
        return self.norm(x)

In [12]:
class Decoder(torch.nn.Module):
    def __init__(self, vocab_size, d_model, N, n_heads):
        super().__init__()
        self.embed = Embedder(vocab_size, d_model)
        self.pe = PositionalEncoder(d_model)
        self.decoder_layers = clone_layer(DecoderLayer(d_model, n_heads), N)
        self.norm = LayerNorm(d_model)
        
    def forward(self, trg, encoder_output, src_mask, trg_mask):
        x = self.embed(trg)
        x = self.pe(x)
        for decoder in self.decoder_layers:
            x = decoder(x, encoder_output, src_mask, trg_mask)
        return self.norm(x)

In [13]:
class Transformer(torch.nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, d_model, N, n_heads):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, N, n_heads)
        self.decoder = Decoder(trg_vocab_size, d_model, N, n_heads)
        self.linear = torch.nn.Linear(d_model, trg_vocab_size)
        
    def forward(self, src, trg, src_mask, trg_mask):
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(trg, encoder_output, src_mask, trg_mask)
        output = self.linear(decoder_output)
        return output

In [2]:
import spacy
from tqdm.auto import tqdm
from torchtext import data
nlp = spacy.load('pt_core_news_sm')

tokenizer = lambda sentence: [tok.text for tok in nlp.tokenizer(sentence) if tok.text != " "]

eng = open('data/english.txt', 'r', encoding = 'utf8').readlines()
fr = open('data/french.txt', 'r', encoding = 'utf8').readlines()

engnew = open('data/english1.txt', 'w', encoding = 'utf8')
frnew = open('data/french1.txt', 'w', encoding = 'utf8')

for e, f in tqdm(zip(eng, fr)):
    e, f = tokenizer(e)[:-1], ['@'] + tokenizer(f)[:-1] + ['#']
    for i in range(1, len(f)):
#         print(e, f[:i], [f[i]])
        engnew.write(' '.join(e))
        engnew.write('\n')
        frnew.write(' '.join(f[:i+1]))
        frnew.write('\n')

engnew.close()
frnew.close()

0it [00:00, ?it/s]

In [22]:
engnew.close()
frnew.close()
ys.close()

In [5]:
import spacy
from torchtext import data
nlp = spacy.load('pt_core_news_sm')

tokenizer = lambda sentence: [tok.text for tok in nlp.tokenizer(sentence) if tok.text != " "]
SRC = data.Field(lower=True, tokenize=tokenizer)
TRG = data.Field(lower=True, tokenize=tokenizer)
Y = data.Field(lower=True, tokenize=tokenizer)
src_data = open('data/english1.txt', 'r', encoding = 'utf8')
trg_data = open('data/french1.txt', 'r', encoding = 'utf8')
raw_data = {'src': [line for line in src_data], 'trg': [line for line in trg_data]}
import pandas as pd
df = pd.DataFrame(raw_data, columns=['src', 'trg'])
df.tail(50)

Unnamed: 0,src,trg
1343731,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343732,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343733,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343734,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343735,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343736,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343737,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343738,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343739,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...
1343740,If someone who doesn't know your background sa...,@ Si quelqu'un qui ne connaît pas vos antécéde...


In [18]:
import spacy
from torchtext import data
nlp = spacy.load('pt_core_news_sm')

tokenizer = lambda sentence: [tok.text for tok in nlp.tokenizer(sentence) if tok.text != " "]
SRC = data.Field(lower=True, tokenize=tokenizer)
TRG = data.Field(lower=True, tokenize=tokenizer, init_token = '<sos>', eos_token = '<eos>')

src_data = open('data/english.txt', 'r', encoding = 'utf8')
trg_data = open('data/french.txt', 'r', encoding = 'utf8')
raw_data = {'src': [line for line in src_data], 'trg': [line for line in trg_data]}

import pandas as pd
df = pd.DataFrame(raw_data, columns=['src', 'trg'])
df.head(50)

Unnamed: 0,src,trg
0,Go.\n,Va !\n
1,Run!\n,Cours !\n
2,Run!\n,Courez !\n
3,Fire!\n,Au feu !\n
4,Help!\n,À l'aide !\n
5,Jump.\n,Saute.\n
6,Stop!\n,Ça suffit !\n
7,Stop!\n,Stop !\n
8,Stop!\n,Arrête-toi !\n
9,Wait!\n,Attends !\n


In [19]:
df.to_csv('en_to_fr.csv', index=False)

data_fields = [('src', SRC), ('trg', TRG)]
train_set = data.TabularDataset('./en_to_fr.csv', format='csv', fields=data_fields)
SRC.build_vocab(train_set)
print(len(SRC.vocab))

TRG.build_vocab(train_set)
print(len(TRG.vocab))

train_set

14115
28354


<torchtext.data.dataset.TabularDataset at 0x1d8ba9dfa00>

In [20]:
[SRC.vocab.stoi[tok] for tok in tokenizer("A Run carbon footprint is the amount of carbon")]

[0, 0, 3513, 5974, 9, 6, 2025, 15, 3513]

In [21]:
# set some parameters
d_model = 512
n_heads = 8
N = 6
src_vocab_size = len(SRC.vocab)
trg_vocab_size = len(TRG.vocab)
model = Transformer(src_vocab_size, trg_vocab_size, d_model, N, n_heads)
for p in model.parameters():
    if p.dim() > 1:
        torch.nn.init.xavier_uniform_(p)

In [22]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
train_iter = data.Iterator(train_set, batch_size=1, sort_key=lambda x: (len(x.src), len(x.trg)), shuffle=True, train=True)
import numpy as np
from tqdm.auto import tqdm

def create_mask(src_input, trg_input):
    # Source input mask
    pad = SRC.vocab.stoi['<pad>']
    src_mask = (src_input != pad).unsqueeze(1)
    
    # Target input mask
    trg_mask = (trg_input != pad).unsqueeze(1)
    
    seq_len = trg_input.size(1)
    nopeak_mask = np.tril(np.ones((1, seq_len, seq_len)), k=0).astype('uint8')
    nopeak_mask = torch.from_numpy(nopeak_mask) != 0
    trg_mask = trg_mask & nopeak_mask
    
    return src_mask, trg_mask

import time

def train_model(n_epochs, output_interval=100):
    model.train()
    start = time.time()
    
    for epoch in range(n_epochs):
        
        total_loss = 0
        for i, batch in enumerate(train_iter):
            
            src_input = batch.src.transpose(0, 1)  # size (batch_size, seq_len)
            trg = batch.trg.transpose(0, 1)  # size (batch_size, seq_len)
            
            trg_input = trg[:, :-1]
            ys = trg[:, 1:].contiguous().view(-1)
            # create src & trg masks
            src_mask, trg_mask = create_mask(src_input, trg_input)
            preds = model(src_input, trg_input, src_mask, trg_mask)
            
            optimizer.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=1)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.data

            if (i + 1) % output_interval == 0:
                avg_loss = total_loss/output_interval
                print('time = {}, epoch = {}, iter = {}, loss = {}'.format((time.time() - start)/60,
                                                                           epoch + 1,
                                                                           i + 1,
                                                                           avg_loss))
                total_loss = 0
                start = time.time()

In [27]:
for i, batch in enumerate(train_iter):
    src_input = batch.src.transpose(0,1)
    trg = batch.trg.transpose(0,1) # size (batch_size, seq_len)
    
    trg_input = trg[:, :-1]
    ys = trg[:, 1:].contiguous().view(-1)
    for s, t in zip(src_input, trg_input):
        print([SRC.vocab.itos[tok] for tok in s])
        print([TRG.vocab.itos[tok] for tok in t])
        print([TRG.vocab.itos[tok] for tok in ys])
        print('-'*100)
    print('='*100)
    
    src_mask, trg_mask = create_mask(src_input, trg_input)
    print(trg_mask)
    preds = model(src_input, trg_input, src_mask, trg_mask)
    preds = F.softmax(preds, dim = -1)
    _, out = preds.data.topk(1)
    print([TRG.vocab.itos[tok] for tok in out[0]])
    break

NameError: name 'train_iter' is not defined

In [25]:
def t(src, trg):
    src_input = torch.tensor([[SRC.vocab.stoi[tok] for tok in src]])
    trg = torch.tensor([[TRG.vocab.stoi[tok] for tok in trg]])
    
    trg_input = trg[:, :-1]
    
    src_mask, trg_mask = create_mask(src_input, trg_input)
    print(src_mask, trg_mask)
    preds = model(src_input, trg_input, src_mask, trg_mask)
    print(preds)
    preds = F.softmax(preds, dim = -1)
    _, out = preds.data.topk(1)
    print([TRG.vocab.itos[tok] for tok in out[0]])

src = ['have', 'you', 'ever', 'seen', 'a', 'tiger', 'around', 'here', '?'] 
trg = ['<sos>', 'avez-vous', 'déjà', 'vu', 'un', 'tigre', 'dans', 'les', 'environs', '?', '<eos>']

def translate(src, max_len = 50):
    src_input = torch.tensor([[SRC.vocab.stoi[tok] for tok in src]])
    trg_input = torch.tensor([[TRG.vocab.stoi['<sos>']]])
    
    for _ in range(max_len):
        src_mask, trg_mask = create_mask(src_input, trg_input)
        preds = model(src_input, trg_input, src_mask, trg_mask)
        preds = F.softmax(preds, dim = -1)
        _, out = preds.data.topk(1)
        trg_input = torch.cat((trg_input, torch.tensor([[out[0][-1]]])))
        if out[0][-1][0] == torch.tensor(TRG.vocab.stoi['<eos>']):
            break
    return ' '.join([TRG.vocab.itos[tok] for tok in trg_input])

translate(src)

'<sos> que pourquoi pourquoi comment comment pourquoi pourquoi pourquoi comment pourquoi que pourquoi pourquoi pourquoi comment pourquoi comment comment pourquoi comment que que comment pourquoi comment pourquoi pourquoi pourquoi pourquoi comment pourquoi comment que que pourquoi comment comment pourquoi pourquoi comment pourquoi comment comment pourquoi pourquoi comment pourquoi pourquoi que pourquoi'

In [None]:
train_model(3, output_interval=1)

time = 0.016568076610565186, epoch = 1, iter = 1, loss = 6.364777565002441
time = 0.01610981225967407, epoch = 1, iter = 2, loss = 7.785104274749756
time = 0.01474459966023763, epoch = 1, iter = 3, loss = 6.641254425048828
time = 0.015484038988749187, epoch = 1, iter = 4, loss = 8.166705131530762
time = 0.015028854211171469, epoch = 1, iter = 5, loss = 6.028402805328369
time = 0.015239071846008301, epoch = 1, iter = 6, loss = 5.850545883178711
time = 0.01653244098027547, epoch = 1, iter = 7, loss = 7.139897346496582
time = 0.020486164093017577, epoch = 1, iter = 8, loss = 6.179225921630859
time = 0.021252556641896566, epoch = 1, iter = 9, loss = 7.808310508728027
time = 0.01576830546061198, epoch = 1, iter = 10, loss = 7.755741119384766
time = 0.015081632137298583, epoch = 1, iter = 11, loss = 7.039430618286133
time = 0.014724949995676676, epoch = 1, iter = 12, loss = 6.282425880432129
time = 0.015323448181152343, epoch = 1, iter = 13, loss = 6.6631879806518555
time = 0.016498037179311

time = 0.014906509717305502, epoch = 1, iter = 110, loss = 7.603344917297363
time = 0.01490335464477539, epoch = 1, iter = 111, loss = 7.069977760314941
time = 0.015535569190979004, epoch = 1, iter = 112, loss = 6.010895252227783
time = 0.015240983168284098, epoch = 1, iter = 113, loss = 6.997994899749756
time = 0.015251576900482178, epoch = 1, iter = 114, loss = 5.298720359802246
time = 0.014977351824442545, epoch = 1, iter = 115, loss = 7.077021598815918
time = 0.015481801827748616, epoch = 1, iter = 116, loss = 6.601924896240234
time = 0.0149750550587972, epoch = 1, iter = 117, loss = 6.3649797439575195
time = 0.014788695176442464, epoch = 1, iter = 118, loss = 7.241326808929443
time = 0.015041780471801759, epoch = 1, iter = 119, loss = 6.903527736663818
time = 0.0151214599609375, epoch = 1, iter = 120, loss = 7.8203277587890625
time = 0.015495137373606364, epoch = 1, iter = 121, loss = 5.683116912841797
time = 0.014774616559346516, epoch = 1, iter = 122, loss = 6.963935852050781
ti

time = 0.016067111492156984, epoch = 1, iter = 217, loss = 7.981557369232178
time = 0.015085152784983317, epoch = 1, iter = 218, loss = 5.71095609664917
time = 0.015858805179595946, epoch = 1, iter = 219, loss = 5.640440464019775
time = 0.015534520149230957, epoch = 1, iter = 220, loss = 5.59948205947876
time = 0.015269529819488526, epoch = 1, iter = 221, loss = 6.7716827392578125
time = 0.015584909915924072, epoch = 1, iter = 222, loss = 7.336301326751709
time = 0.014829460779825847, epoch = 1, iter = 223, loss = 4.580948352813721
time = 0.014569528897603353, epoch = 1, iter = 224, loss = 7.896090030670166
time = 0.01504906415939331, epoch = 1, iter = 225, loss = 6.481283187866211
time = 0.015131346384684245, epoch = 1, iter = 226, loss = 5.767949104309082
time = 0.015105827649434408, epoch = 1, iter = 227, loss = 5.8119330406188965
time = 0.015232284863789877, epoch = 1, iter = 228, loss = 7.1722846031188965
time = 0.015360653400421143, epoch = 1, iter = 229, loss = 6.303107738494873

time = 0.01586843729019165, epoch = 1, iter = 324, loss = 6.362663745880127
time = 0.015098853905995687, epoch = 1, iter = 325, loss = 5.912585258483887
time = 0.015066742897033691, epoch = 1, iter = 326, loss = 6.430394649505615
time = 0.014976723988850912, epoch = 1, iter = 327, loss = 6.621278285980225
time = 0.015549282232920328, epoch = 1, iter = 328, loss = 5.325039386749268
time = 0.01523664395014445, epoch = 1, iter = 329, loss = 6.074916839599609
time = 0.015516308943430583, epoch = 1, iter = 330, loss = 5.36292028427124
time = 0.015260740121205648, epoch = 1, iter = 331, loss = 6.035889625549316
time = 0.01561439037322998, epoch = 1, iter = 332, loss = 6.732856273651123
time = 0.016059676806132, epoch = 1, iter = 333, loss = 6.188155174255371
time = 0.016783646742502847, epoch = 1, iter = 334, loss = 6.16618013381958
time = 0.017112255096435547, epoch = 1, iter = 335, loss = 7.148543357849121
time = 0.01581275463104248, epoch = 1, iter = 336, loss = 7.1462931632995605
time = 

time = 0.01578883727391561, epoch = 1, iter = 431, loss = 6.868432521820068
time = 0.015361074606577556, epoch = 1, iter = 432, loss = 6.299765110015869
time = 0.015064589182535807, epoch = 1, iter = 433, loss = 6.608158111572266
time = 0.015561517079671223, epoch = 1, iter = 434, loss = 6.641746520996094
time = 0.015254473686218262, epoch = 1, iter = 435, loss = 6.348840713500977
time = 0.014995201428731283, epoch = 1, iter = 436, loss = 5.750905990600586
time = 0.014969058831532796, epoch = 1, iter = 437, loss = 6.064032077789307
time = 0.015250396728515626, epoch = 1, iter = 438, loss = 5.910512447357178
time = 0.014989511171976725, epoch = 1, iter = 439, loss = 5.739227294921875
time = 0.01481003761291504, epoch = 1, iter = 440, loss = 6.02882194519043
time = 0.015041184425354005, epoch = 1, iter = 441, loss = 6.1530351638793945
time = 0.015590973695119222, epoch = 1, iter = 442, loss = 7.773530960083008
time = 0.015501093864440919, epoch = 1, iter = 443, loss = 7.6615400314331055


time = 0.015893320242563885, epoch = 1, iter = 538, loss = 6.259914398193359
time = 0.015069488684336345, epoch = 1, iter = 539, loss = 5.512271881103516
time = 0.015566802024841309, epoch = 1, iter = 540, loss = 6.708102703094482
time = 0.015528508027394613, epoch = 1, iter = 541, loss = 6.240634918212891
time = 0.015548340479532878, epoch = 1, iter = 542, loss = 6.054017543792725
time = 0.015286978085835774, epoch = 1, iter = 543, loss = 5.823151588439941
time = 0.015598909060160319, epoch = 1, iter = 544, loss = 6.878537178039551
time = 0.015277334054311116, epoch = 1, iter = 545, loss = 7.16290807723999
time = 0.015237808227539062, epoch = 1, iter = 546, loss = 6.123591899871826
time = 0.015049835046132406, epoch = 1, iter = 547, loss = 6.234869480133057
time = 0.015226014455159505, epoch = 1, iter = 548, loss = 6.190341949462891
time = 0.015098253885904947, epoch = 1, iter = 549, loss = 5.914502143859863
time = 0.015345486005147298, epoch = 1, iter = 550, loss = 5.773598670959473


time = 0.01605399449666341, epoch = 1, iter = 645, loss = 5.872409343719482
time = 0.015392263730367025, epoch = 1, iter = 646, loss = 7.319266319274902
time = 0.01531689167022705, epoch = 1, iter = 647, loss = 5.887308120727539
time = 0.01471861203511556, epoch = 1, iter = 648, loss = 7.161724090576172
time = 0.01654677391052246, epoch = 1, iter = 649, loss = 5.909853458404541
time = 0.015339076519012451, epoch = 1, iter = 650, loss = 5.13730525970459
time = 0.015055688222249348, epoch = 1, iter = 651, loss = 5.689321041107178
time = 0.015063528219858806, epoch = 1, iter = 652, loss = 5.093904495239258
time = 0.014971125125885009, epoch = 1, iter = 653, loss = 5.3283491134643555
time = 0.015591712792714436, epoch = 1, iter = 654, loss = 5.526270866394043
time = 0.015495193004608155, epoch = 1, iter = 655, loss = 5.972203254699707
time = 0.01525131861368815, epoch = 1, iter = 656, loss = 6.308418273925781
time = 0.015232503414154053, epoch = 1, iter = 657, loss = 5.789233684539795
time

time = 0.01578455368677775, epoch = 1, iter = 752, loss = 7.455382347106934
time = 0.015577876567840576, epoch = 1, iter = 753, loss = 5.971994400024414
time = 0.015235026677449545, epoch = 1, iter = 754, loss = 6.43003511428833
time = 0.015148325761159261, epoch = 1, iter = 755, loss = 5.789048671722412
time = 0.015433279673258464, epoch = 1, iter = 756, loss = 6.373017311096191
time = 0.01606483062108358, epoch = 1, iter = 757, loss = 6.5079522132873535
time = 0.01568281650543213, epoch = 1, iter = 758, loss = 5.5148138999938965
time = 0.015044021606445312, epoch = 1, iter = 759, loss = 5.774095058441162
time = 0.015488942464192709, epoch = 1, iter = 760, loss = 6.416515350341797
time = 0.015242457389831543, epoch = 1, iter = 761, loss = 5.407937526702881
time = 0.015761935710906984, epoch = 1, iter = 762, loss = 7.021275520324707
time = 0.015280914306640626, epoch = 1, iter = 763, loss = 7.027584075927734
time = 0.015509466330210367, epoch = 1, iter = 764, loss = 7.360857009887695
t

time = 0.01563899517059326, epoch = 1, iter = 859, loss = 6.415809631347656
time = 0.01575092077255249, epoch = 1, iter = 860, loss = 7.069057941436768
time = 0.015227830410003662, epoch = 1, iter = 861, loss = 5.422759056091309
time = 0.0158323605855306, epoch = 1, iter = 862, loss = 5.766216278076172
time = 0.015257561206817627, epoch = 1, iter = 863, loss = 6.620869159698486
time = 0.01550530195236206, epoch = 1, iter = 864, loss = 6.389482021331787
time = 0.016031217575073243, epoch = 1, iter = 865, loss = 6.847357749938965
time = 0.015586435794830322, epoch = 1, iter = 866, loss = 4.876928806304932
time = 0.015750586986541748, epoch = 1, iter = 867, loss = 6.16153621673584
time = 0.015027213096618652, epoch = 1, iter = 868, loss = 5.8123345375061035
time = 0.01549235184987386, epoch = 1, iter = 869, loss = 6.58230447769165
time = 0.015283397833506266, epoch = 1, iter = 870, loss = 6.3552398681640625
time = 0.015766735871632895, epoch = 1, iter = 871, loss = 7.44182014465332
time =

time = 0.01567550500233968, epoch = 1, iter = 966, loss = 6.994413375854492
time = 0.015614227453867594, epoch = 1, iter = 967, loss = 5.416916370391846
time = 0.015351208051045735, epoch = 1, iter = 968, loss = 5.403323173522949
time = 0.015519289175669353, epoch = 1, iter = 969, loss = 6.976091384887695
time = 0.01585277318954468, epoch = 1, iter = 970, loss = 6.545538425445557
time = 0.01594024101893107, epoch = 1, iter = 971, loss = 5.775811195373535
time = 0.017653115590413413, epoch = 1, iter = 972, loss = 4.958970069885254
time = 0.016574621200561523, epoch = 1, iter = 973, loss = 6.423995494842529
time = 0.016844316323598226, epoch = 1, iter = 974, loss = 6.933032989501953
time = 0.01556705633799235, epoch = 1, iter = 975, loss = 6.15789270401001
time = 0.01584074099858602, epoch = 1, iter = 976, loss = 6.569124698638916
time = 0.015783385435740153, epoch = 1, iter = 977, loss = 4.8014912605285645
time = 0.01531047821044922, epoch = 1, iter = 978, loss = 5.363208770751953
time 

time = 0.016027653217315675, epoch = 1, iter = 1072, loss = 6.197028636932373
time = 0.015292716026306153, epoch = 1, iter = 1073, loss = 5.18513298034668
time = 0.015461480617523194, epoch = 1, iter = 1074, loss = 5.9671630859375
time = 0.015548038482666015, epoch = 1, iter = 1075, loss = 5.298318386077881
time = 0.0157747745513916, epoch = 1, iter = 1076, loss = 5.815690994262695
time = 0.0158174196879069, epoch = 1, iter = 1077, loss = 4.744574069976807
time = 0.015554749965667724, epoch = 1, iter = 1078, loss = 5.889654159545898
time = 0.016187663873036703, epoch = 1, iter = 1079, loss = 6.628118515014648
time = 0.015532600879669189, epoch = 1, iter = 1080, loss = 6.067887306213379
time = 0.015784462292989094, epoch = 1, iter = 1081, loss = 6.733179092407227
time = 0.015839346249898276, epoch = 1, iter = 1082, loss = 7.246123790740967
time = 0.016041505336761474, epoch = 1, iter = 1083, loss = 6.342406749725342
time = 0.015796796480814616, epoch = 1, iter = 1084, loss = 6.364531040

time = 0.01579041083653768, epoch = 1, iter = 1178, loss = 6.7661871910095215
time = 0.01586079200108846, epoch = 1, iter = 1179, loss = 6.952496528625488
time = 0.015857863426208495, epoch = 1, iter = 1180, loss = 6.538609027862549
time = 0.016568430264790854, epoch = 1, iter = 1181, loss = 6.482394695281982
time = 0.016088855266571046, epoch = 1, iter = 1182, loss = 6.052360534667969
time = 0.016253976027170818, epoch = 1, iter = 1183, loss = 5.781850337982178
time = 0.015279849370320639, epoch = 1, iter = 1184, loss = 6.8523149490356445
time = 0.01586062510808309, epoch = 1, iter = 1185, loss = 5.337226867675781
time = 0.015766517321268717, epoch = 1, iter = 1186, loss = 6.612987041473389
time = 0.015782097975413006, epoch = 1, iter = 1187, loss = 5.1540608406066895
time = 0.016070838769276938, epoch = 1, iter = 1188, loss = 5.094840049743652
time = 0.015556283791859945, epoch = 1, iter = 1189, loss = 5.6255412101745605
time = 0.015552576382954915, epoch = 1, iter = 1190, loss = 5.2

time = 0.01580911080042521, epoch = 1, iter = 1284, loss = 7.1618170738220215
time = 0.015866323312123617, epoch = 1, iter = 1285, loss = 4.968854904174805
time = 0.01601260503133138, epoch = 1, iter = 1286, loss = 5.742913246154785
time = 0.01606459617614746, epoch = 1, iter = 1287, loss = 6.967636585235596
time = 0.01782956123352051, epoch = 1, iter = 1288, loss = 6.827091217041016
time = 0.017608853181203206, epoch = 1, iter = 1289, loss = 5.344006061553955
time = 0.017104156812032065, epoch = 1, iter = 1290, loss = 5.3916707038879395
time = 0.016095304489135744, epoch = 1, iter = 1291, loss = 7.349853515625
time = 0.016250904401143393, epoch = 1, iter = 1292, loss = 6.28725528717041
time = 0.016049714883168538, epoch = 1, iter = 1293, loss = 5.412097454071045
time = 0.01583108107248942, epoch = 1, iter = 1294, loss = 6.87129545211792
time = 0.015554551283518474, epoch = 1, iter = 1295, loss = 5.567989349365234
time = 0.015266652901967366, epoch = 1, iter = 1296, loss = 6.4248170852

time = 0.01569583813349406, epoch = 1, iter = 1390, loss = 5.121145725250244
time = 0.016556493441263833, epoch = 1, iter = 1391, loss = 6.796705722808838
time = 0.016577800114949543, epoch = 1, iter = 1392, loss = 6.730971813201904
time = 0.015809039274851482, epoch = 1, iter = 1393, loss = 5.661753177642822
time = 0.015789389610290527, epoch = 1, iter = 1394, loss = 7.307091236114502
time = 0.015828462441762288, epoch = 1, iter = 1395, loss = 6.665531158447266
time = 0.016064719359079997, epoch = 1, iter = 1396, loss = 5.2948760986328125
time = 0.01607560316721598, epoch = 1, iter = 1397, loss = 5.195597171783447
time = 0.016092952092488608, epoch = 1, iter = 1398, loss = 5.447049617767334
time = 0.016211076577504476, epoch = 1, iter = 1399, loss = 7.47901725769043
time = 0.01634089946746826, epoch = 1, iter = 1400, loss = 6.256782054901123
time = 0.015576120217641194, epoch = 1, iter = 1401, loss = 6.270019054412842
time = 0.016039605935414633, epoch = 1, iter = 1402, loss = 6.88553

time = 0.01609420379002889, epoch = 1, iter = 1496, loss = 6.49908971786499
time = 0.016433719793955484, epoch = 1, iter = 1497, loss = 6.5367326736450195
time = 0.015851465861002605, epoch = 1, iter = 1498, loss = 7.1840105056762695
time = 0.01595388650894165, epoch = 1, iter = 1499, loss = 7.8274335861206055
time = 0.016067532698313396, epoch = 1, iter = 1500, loss = 5.485763072967529
time = 0.016332876682281495, epoch = 1, iter = 1501, loss = 6.173087120056152
time = 0.016097056865692138, epoch = 1, iter = 1502, loss = 5.488361835479736
time = 0.01567512353261312, epoch = 1, iter = 1503, loss = 6.761080265045166
time = 0.0163664182027181, epoch = 1, iter = 1504, loss = 6.852267265319824
time = 0.01633371909459432, epoch = 1, iter = 1505, loss = 6.949765682220459
time = 0.016066718101501464, epoch = 1, iter = 1506, loss = 5.841769695281982
time = 0.015992538134257, epoch = 1, iter = 1507, loss = 7.536272048950195
time = 0.016050942738850913, epoch = 1, iter = 1508, loss = 6.030697822

time = 0.01584979295730591, epoch = 1, iter = 1602, loss = 6.348926067352295
time = 0.016120413939158123, epoch = 1, iter = 1603, loss = 4.697298526763916
time = 0.01630430221557617, epoch = 1, iter = 1604, loss = 5.534346103668213
time = 0.016416247685750326, epoch = 1, iter = 1605, loss = 6.687004089355469
time = 0.0158518115679423, epoch = 1, iter = 1606, loss = 7.13908052444458
time = 0.016069245338439942, epoch = 1, iter = 1607, loss = 6.844205379486084
time = 0.01622534195582072, epoch = 1, iter = 1608, loss = 5.928696155548096
time = 0.0163389523824056, epoch = 1, iter = 1609, loss = 5.843334674835205
time = 0.01607765754063924, epoch = 1, iter = 1610, loss = 6.020759105682373
time = 0.016083137194315592, epoch = 1, iter = 1611, loss = 5.704959869384766
time = 0.01648809512456258, epoch = 1, iter = 1612, loss = 5.703071117401123
time = 0.01607590119043986, epoch = 1, iter = 1613, loss = 5.280885219573975
time = 0.01633050839106242, epoch = 1, iter = 1614, loss = 6.78375816345214

time = 0.015839536984761555, epoch = 1, iter = 1708, loss = 6.698559284210205
time = 0.016087337334950765, epoch = 1, iter = 1709, loss = 7.0549492835998535
time = 0.016203991572062173, epoch = 1, iter = 1710, loss = 5.801385402679443
time = 0.016074140866597492, epoch = 1, iter = 1711, loss = 5.208427429199219
time = 0.016577486197153726, epoch = 1, iter = 1712, loss = 6.361332416534424
time = 0.016090905666351317, epoch = 1, iter = 1713, loss = 5.8738250732421875
time = 0.01622328758239746, epoch = 1, iter = 1714, loss = 6.75409460067749
time = 0.01608246167500814, epoch = 1, iter = 1715, loss = 6.0345563888549805
time = 0.016333397229512533, epoch = 1, iter = 1716, loss = 6.558958053588867
time = 0.01633443832397461, epoch = 1, iter = 1717, loss = 7.131885051727295
time = 0.016275636355082192, epoch = 1, iter = 1718, loss = 6.591344833374023
time = 0.015810823440551756, epoch = 1, iter = 1719, loss = 6.886975288391113
time = 0.015818838278452554, epoch = 1, iter = 1720, loss = 5.581

time = 0.01685742139816284, epoch = 1, iter = 1814, loss = 6.581652641296387
time = 0.01636651357014974, epoch = 1, iter = 1815, loss = 6.172017574310303
time = 0.016277488072713217, epoch = 1, iter = 1816, loss = 6.104529857635498
time = 0.01610183318456014, epoch = 1, iter = 1817, loss = 5.680546760559082
time = 0.016840132077534993, epoch = 1, iter = 1818, loss = 7.272225856781006
time = 0.01659482717514038, epoch = 1, iter = 1819, loss = 5.646300792694092
time = 0.01606004238128662, epoch = 1, iter = 1820, loss = 5.075171947479248
time = 0.01632365385691325, epoch = 1, iter = 1821, loss = 5.343080043792725
time = 0.01599882443745931, epoch = 1, iter = 1822, loss = 5.2662811279296875
time = 0.01607016324996948, epoch = 1, iter = 1823, loss = 5.8146443367004395
time = 0.016592363516489666, epoch = 1, iter = 1824, loss = 6.881935119628906
time = 0.01609810988108317, epoch = 1, iter = 1825, loss = 5.394584655761719
time = 0.016326630115509035, epoch = 1, iter = 1826, loss = 7.730834960

time = 0.016530911127726238, epoch = 1, iter = 1920, loss = 7.412380695343018
time = 0.01668868859608968, epoch = 1, iter = 1921, loss = 6.473843574523926
time = 0.016630903879801432, epoch = 1, iter = 1922, loss = 6.671797275543213
time = 0.01611518065134684, epoch = 1, iter = 1923, loss = 4.862601280212402
time = 0.01710484822591146, epoch = 1, iter = 1924, loss = 7.014828205108643
time = 0.016331223646799724, epoch = 1, iter = 1925, loss = 7.107028484344482
time = 0.016805346806844076, epoch = 1, iter = 1926, loss = 6.701931953430176
time = 0.01714940071105957, epoch = 1, iter = 1927, loss = 7.085399627685547
time = 0.016122682889302572, epoch = 1, iter = 1928, loss = 5.890432357788086
time = 0.01624565919240316, epoch = 1, iter = 1929, loss = 7.479166507720947
time = 0.01634820302327474, epoch = 1, iter = 1930, loss = 6.173525333404541
time = 0.01634134848912557, epoch = 1, iter = 1931, loss = 5.976771354675293
time = 0.016577557722727457, epoch = 1, iter = 1932, loss = 6.144996166

time = 0.016602869828542074, epoch = 1, iter = 2026, loss = 6.22150182723999
time = 0.016710742314656576, epoch = 1, iter = 2027, loss = 6.688034534454346
time = 0.016317311922709146, epoch = 1, iter = 2028, loss = 5.642022132873535
time = 0.016136868794759115, epoch = 1, iter = 2029, loss = 7.172186374664307
time = 0.016607511043548583, epoch = 1, iter = 2030, loss = 5.68389892578125
time = 0.015984527269999185, epoch = 1, iter = 2031, loss = 6.208333969116211
time = 0.01633521318435669, epoch = 1, iter = 2032, loss = 7.284924030303955
time = 0.016614445050557456, epoch = 1, iter = 2033, loss = 6.634978294372559
time = 0.016375251611073813, epoch = 1, iter = 2034, loss = 6.692601680755615
time = 0.01626344919204712, epoch = 1, iter = 2035, loss = 6.360152244567871
time = 0.016567687193552654, epoch = 1, iter = 2036, loss = 5.99299430847168
time = 0.016631941000620525, epoch = 1, iter = 2037, loss = 6.770416259765625
time = 0.01661480267842611, epoch = 1, iter = 2038, loss = 6.37726545

time = 0.016795289516448975, epoch = 1, iter = 2132, loss = 6.588396072387695
time = 0.016689149538675944, epoch = 1, iter = 2133, loss = 6.832013130187988
time = 0.017126703262329103, epoch = 1, iter = 2134, loss = 5.581646919250488
time = 0.016376280784606935, epoch = 1, iter = 2135, loss = 6.329041957855225
time = 0.01710265874862671, epoch = 1, iter = 2136, loss = 7.538498401641846
time = 0.016581114133199057, epoch = 1, iter = 2137, loss = 5.481276035308838
time = 0.016775735219319663, epoch = 1, iter = 2138, loss = 6.337430000305176
time = 0.01657569408416748, epoch = 1, iter = 2139, loss = 5.477395534515381
time = 0.01635056734085083, epoch = 1, iter = 2140, loss = 7.1345534324646
time = 0.01661368211110433, epoch = 1, iter = 2141, loss = 6.207897663116455
time = 0.0165718674659729, epoch = 1, iter = 2142, loss = 4.781488418579102
time = 0.016896132628122965, epoch = 1, iter = 2143, loss = 5.873712062835693
time = 0.01651674509048462, epoch = 1, iter = 2144, loss = 6.46436929702

time = 0.016926912466684978, epoch = 1, iter = 2238, loss = 5.678627014160156
time = 0.01660709778467814, epoch = 1, iter = 2239, loss = 6.701933860778809
time = 0.0166214386622111, epoch = 1, iter = 2240, loss = 6.798056602478027
time = 0.016394007205963134, epoch = 1, iter = 2241, loss = 6.416085720062256
time = 0.017119077841440837, epoch = 1, iter = 2242, loss = 5.668754577636719
time = 0.016683574517567953, epoch = 1, iter = 2243, loss = 6.249934673309326
time = 0.01680541435877482, epoch = 1, iter = 2244, loss = 5.6748809814453125
time = 0.017201101779937743, epoch = 1, iter = 2245, loss = 5.207412242889404
time = 0.017296862602233887, epoch = 1, iter = 2246, loss = 5.823445796966553
time = 0.01713098684946696, epoch = 1, iter = 2247, loss = 7.4591851234436035
time = 0.016908395290374755, epoch = 1, iter = 2248, loss = 6.74569845199585
time = 0.0165921688079834, epoch = 1, iter = 2249, loss = 6.229964733123779
time = 0.016618279616038005, epoch = 1, iter = 2250, loss = 5.66660213

time = 0.01685102383295695, epoch = 1, iter = 2344, loss = 7.416339874267578
time = 0.0169046680132548, epoch = 1, iter = 2345, loss = 6.514337539672852
time = 0.016579699516296387, epoch = 1, iter = 2346, loss = 5.551295757293701
time = 0.017571051915486652, epoch = 1, iter = 2347, loss = 5.904786586761475
time = 0.0185272216796875, epoch = 1, iter = 2348, loss = 8.340009689331055
time = 0.016853936513264976, epoch = 1, iter = 2349, loss = 5.163972854614258
time = 0.016845905780792238, epoch = 1, iter = 2350, loss = 6.059669017791748
time = 0.01690833568572998, epoch = 1, iter = 2351, loss = 5.521485805511475
time = 0.017150533199310303, epoch = 1, iter = 2352, loss = 5.470271110534668
time = 0.017108007272084554, epoch = 1, iter = 2353, loss = 6.666635990142822
time = 0.01638219753901164, epoch = 1, iter = 2354, loss = 5.7173590660095215
time = 0.01685646375020345, epoch = 1, iter = 2355, loss = 5.936824321746826
time = 0.017068056265513103, epoch = 1, iter = 2356, loss = 5.737821102

time = 0.01691158612569173, epoch = 1, iter = 2450, loss = 6.182827472686768
time = 0.01662473678588867, epoch = 1, iter = 2451, loss = 5.722742557525635
time = 0.01678064266840617, epoch = 1, iter = 2452, loss = 6.646022319793701
time = 0.01693191130956014, epoch = 1, iter = 2453, loss = 5.6137542724609375
time = 0.016843104362487794, epoch = 1, iter = 2454, loss = 5.217098712921143
time = 0.016803228855133058, epoch = 1, iter = 2455, loss = 6.3598313331604
time = 0.01665197213490804, epoch = 1, iter = 2456, loss = 6.544353008270264
time = 0.016863505045572918, epoch = 1, iter = 2457, loss = 5.3011064529418945
time = 0.01688900391260783, epoch = 1, iter = 2458, loss = 6.178002834320068
time = 0.016891165574391683, epoch = 1, iter = 2459, loss = 6.83354377746582
time = 0.016920944054921467, epoch = 1, iter = 2460, loss = 5.93489408493042
time = 0.016593726476033528, epoch = 1, iter = 2461, loss = 5.210522174835205
time = 0.016866334279378257, epoch = 1, iter = 2462, loss = 5.0762715339

time = 0.016636327902475993, epoch = 1, iter = 2556, loss = 6.253343105316162
time = 0.016646472613016765, epoch = 1, iter = 2557, loss = 5.549475193023682
time = 0.017097500960032146, epoch = 1, iter = 2558, loss = 5.240899085998535
time = 0.016904815038045248, epoch = 1, iter = 2559, loss = 6.155246257781982
time = 0.016597092151641846, epoch = 1, iter = 2560, loss = 5.721031188964844
time = 0.016884334882100425, epoch = 1, iter = 2561, loss = 7.856756687164307
time = 0.016886377334594728, epoch = 1, iter = 2562, loss = 7.078680515289307
time = 0.01737910509109497, epoch = 1, iter = 2563, loss = 5.85453987121582
time = 0.017116649945576986, epoch = 1, iter = 2564, loss = 5.925579071044922
time = 0.017133756478627523, epoch = 1, iter = 2565, loss = 5.837743759155273
time = 0.017131455739339194, epoch = 1, iter = 2566, loss = 8.427056312561035
time = 0.017042251427968343, epoch = 1, iter = 2567, loss = 7.425813674926758
time = 0.01686776081720988, epoch = 1, iter = 2568, loss = 6.12330

time = 0.017155889670054117, epoch = 1, iter = 2662, loss = 6.234306812286377
time = 0.01666329304377238, epoch = 1, iter = 2663, loss = 5.411689758300781
time = 0.016808112462361652, epoch = 1, iter = 2664, loss = 6.505671977996826
time = 0.01717069149017334, epoch = 1, iter = 2665, loss = 4.643852233886719
time = 0.017174283663431805, epoch = 1, iter = 2666, loss = 8.015874862670898
time = 0.01731397310892741, epoch = 1, iter = 2667, loss = 5.583611965179443
time = 0.017427881558736164, epoch = 1, iter = 2668, loss = 5.726089000701904
time = 0.016625046730041504, epoch = 1, iter = 2669, loss = 6.509227752685547
time = 0.01704560915629069, epoch = 1, iter = 2670, loss = 5.981228351593018
time = 0.01719288428624471, epoch = 1, iter = 2671, loss = 5.200624465942383
time = 0.017119014263153078, epoch = 1, iter = 2672, loss = 6.052367687225342
time = 0.016611178716023762, epoch = 1, iter = 2673, loss = 5.294187545776367
time = 0.016879900296529134, epoch = 1, iter = 2674, loss = 4.4682884

time = 0.016360286871592203, epoch = 1, iter = 2768, loss = 5.384061813354492
time = 0.016698797543843586, epoch = 1, iter = 2769, loss = 5.306075572967529
time = 0.016857012112935384, epoch = 1, iter = 2770, loss = 5.341803550720215
time = 0.01687831481297811, epoch = 1, iter = 2771, loss = 6.39037561416626
time = 0.01656860907872518, epoch = 1, iter = 2772, loss = 5.902225494384766
time = 0.016505873203277587, epoch = 1, iter = 2773, loss = 7.320228099822998
time = 0.016640313466389976, epoch = 1, iter = 2774, loss = 6.51936149597168
time = 0.016893037160237632, epoch = 1, iter = 2775, loss = 5.813525199890137
time = 0.016665824254353843, epoch = 1, iter = 2776, loss = 6.150733470916748
time = 0.016585310300191242, epoch = 1, iter = 2777, loss = 5.532759189605713
time = 0.016571541627248127, epoch = 1, iter = 2778, loss = 6.9021220207214355
time = 0.016472649574279786, epoch = 1, iter = 2779, loss = 6.051064968109131
time = 0.016838630040486652, epoch = 1, iter = 2780, loss = 6.14287

time = 0.016927282015482586, epoch = 1, iter = 2874, loss = 6.247829437255859
time = 0.016710551579793294, epoch = 1, iter = 2875, loss = 7.310619831085205
time = 0.01659337282180786, epoch = 1, iter = 2876, loss = 6.164979457855225
time = 0.01635837952295939, epoch = 1, iter = 2877, loss = 4.226750373840332
time = 0.016970574855804443, epoch = 1, iter = 2878, loss = 5.32991361618042
time = 0.016590829690297446, epoch = 1, iter = 2879, loss = 6.251145362854004
time = 0.016643182436625163, epoch = 1, iter = 2880, loss = 5.895031452178955
time = 0.016730920473734538, epoch = 1, iter = 2881, loss = 6.149298667907715
time = 0.016636598110198974, epoch = 1, iter = 2882, loss = 6.278608798980713
time = 0.0166152556737264, epoch = 1, iter = 2883, loss = 4.756671905517578
time = 0.01659911870956421, epoch = 1, iter = 2884, loss = 5.176300525665283
time = 0.016570969422658285, epoch = 1, iter = 2885, loss = 5.341194152832031
time = 0.016345874468485514, epoch = 1, iter = 2886, loss = 5.69591283

time = 0.01688077449798584, epoch = 1, iter = 2980, loss = 6.418293476104736
time = 0.016840040683746338, epoch = 1, iter = 2981, loss = 6.241025447845459
time = 0.01633330583572388, epoch = 1, iter = 2982, loss = 5.495951175689697
time = 0.01633478800455729, epoch = 1, iter = 2983, loss = 5.4083781242370605
time = 0.016593488057454427, epoch = 1, iter = 2984, loss = 5.530014991760254
time = 0.0168430765469869, epoch = 1, iter = 2985, loss = 5.797458171844482
time = 0.01684940258661906, epoch = 1, iter = 2986, loss = 6.331662654876709
time = 0.01676829655965169, epoch = 1, iter = 2987, loss = 5.435097694396973
time = 0.01635171175003052, epoch = 1, iter = 2988, loss = 6.375453472137451
time = 0.01688779592514038, epoch = 1, iter = 2989, loss = 6.859935760498047
time = 0.017102356751759848, epoch = 1, iter = 2990, loss = 7.020007610321045
time = 0.01714797814687093, epoch = 1, iter = 2991, loss = 4.74886417388916
time = 0.016902486483256023, epoch = 1, iter = 2992, loss = 5.611989974975

time = 0.01651770273844401, epoch = 1, iter = 3086, loss = 6.434435844421387
time = 0.016403273741404215, epoch = 1, iter = 3087, loss = 7.020534515380859
time = 0.016548657417297365, epoch = 1, iter = 3088, loss = 5.789434432983398
time = 0.016864633560180663, epoch = 1, iter = 3089, loss = 4.648015022277832
time = 0.01658575137456258, epoch = 1, iter = 3090, loss = 5.701669216156006
time = 0.01662361224492391, epoch = 1, iter = 3091, loss = 6.431480407714844
time = 0.016757456461588542, epoch = 1, iter = 3092, loss = 6.901547908782959
time = 0.016591743628184, epoch = 1, iter = 3093, loss = 6.873919486999512
time = 0.016857147216796875, epoch = 1, iter = 3094, loss = 5.614671230316162
time = 0.017060903708140056, epoch = 1, iter = 3095, loss = 5.630941867828369
time = 0.01686462163925171, epoch = 1, iter = 3096, loss = 6.288124084472656
time = 0.0163520614306132, epoch = 1, iter = 3097, loss = 5.813190937042236
time = 0.01661605437596639, epoch = 1, iter = 3098, loss = 4.911685466766

time = 0.016365679105122884, epoch = 1, iter = 3192, loss = 5.723267555236816
time = 0.01666028102238973, epoch = 1, iter = 3193, loss = 6.30318546295166
time = 0.016609299182891845, epoch = 1, iter = 3194, loss = 6.415153503417969
time = 0.016601022084554037, epoch = 1, iter = 3195, loss = 5.8011908531188965
time = 0.017583290735880535, epoch = 1, iter = 3196, loss = 7.176741600036621
time = 0.01701959768931071, epoch = 1, iter = 3197, loss = 6.340893745422363
time = 0.017142446835835774, epoch = 1, iter = 3198, loss = 5.657044410705566
time = 0.01691639026006063, epoch = 1, iter = 3199, loss = 5.991219520568848
time = 0.01755577325820923, epoch = 1, iter = 3200, loss = 5.132474899291992
time = 0.01662025849024455, epoch = 1, iter = 3201, loss = 5.789314270019531
time = 0.016891713937123617, epoch = 1, iter = 3202, loss = 6.13978385925293
time = 0.01662178834279378, epoch = 1, iter = 3203, loss = 6.261290550231934
time = 0.017093420028686523, epoch = 1, iter = 3204, loss = 5.533209800

time = 0.01670209566752116, epoch = 1, iter = 3298, loss = 5.495948791503906
time = 0.01665855646133423, epoch = 1, iter = 3299, loss = 4.947539329528809
time = 0.01677188475926717, epoch = 1, iter = 3300, loss = 6.396264553070068
time = 0.01689092715581258, epoch = 1, iter = 3301, loss = 7.497420787811279
time = 0.016614739100138345, epoch = 1, iter = 3302, loss = 5.8433637619018555
time = 0.016875672340393066, epoch = 1, iter = 3303, loss = 6.630441188812256
time = 0.017102277278900145, epoch = 1, iter = 3304, loss = 5.3698930740356445
time = 0.01664073864618937, epoch = 1, iter = 3305, loss = 6.597094535827637
time = 0.016869076093037925, epoch = 1, iter = 3306, loss = 5.85808801651001
time = 0.017164901892344157, epoch = 1, iter = 3307, loss = 7.159008502960205
time = 0.01689035495122274, epoch = 1, iter = 3308, loss = 7.285573482513428
time = 0.017131789525349935, epoch = 1, iter = 3309, loss = 6.49959659576416
time = 0.016791756947835287, epoch = 1, iter = 3310, loss = 5.23353385

time = 0.017071247100830078, epoch = 1, iter = 3404, loss = 6.611480712890625
time = 0.017729910214742024, epoch = 1, iter = 3405, loss = 6.208871841430664
time = 0.017163288593292237, epoch = 1, iter = 3406, loss = 5.815808296203613
time = 0.017334612210591634, epoch = 1, iter = 3407, loss = 6.984644412994385
time = 0.016899089018503826, epoch = 1, iter = 3408, loss = 5.485477924346924
time = 0.017678995927174885, epoch = 1, iter = 3409, loss = 6.037585258483887
time = 0.017348261674245198, epoch = 1, iter = 3410, loss = 5.94259786605835
time = 0.017158230145772297, epoch = 1, iter = 3411, loss = 7.067567825317383
time = 0.017391916116078696, epoch = 1, iter = 3412, loss = 7.208606719970703
time = 0.017354051272074383, epoch = 1, iter = 3413, loss = 6.450766563415527
time = 0.017426502704620362, epoch = 1, iter = 3414, loss = 6.029238700866699
time = 0.01766820748647054, epoch = 1, iter = 3415, loss = 6.048384189605713
time = 0.01984673341115316, epoch = 1, iter = 3416, loss = 5.77599

time = 0.017423780759175618, epoch = 1, iter = 3510, loss = 5.556254863739014
time = 0.017163499196370443, epoch = 1, iter = 3511, loss = 6.59650182723999
time = 0.017639350891113282, epoch = 1, iter = 3512, loss = 7.220539569854736
time = 0.017659127712249756, epoch = 1, iter = 3513, loss = 7.452730178833008
time = 0.01729419231414795, epoch = 1, iter = 3514, loss = 5.736181735992432
time = 0.017442703247070312, epoch = 1, iter = 3515, loss = 6.343694686889648
time = 0.017204443613688152, epoch = 1, iter = 3516, loss = 6.327396392822266
time = 0.018105522791544596, epoch = 1, iter = 3517, loss = 6.818541049957275
time = 0.017139307657877603, epoch = 1, iter = 3518, loss = 5.75188684463501
time = 0.017194743951161703, epoch = 1, iter = 3519, loss = 7.065069675445557
time = 0.017592350641886394, epoch = 1, iter = 3520, loss = 5.0008368492126465
time = 0.01718087593714396, epoch = 1, iter = 3521, loss = 6.620203971862793
time = 0.01740415096282959, epoch = 1, iter = 3522, loss = 6.505846

time = 0.017392675081888836, epoch = 1, iter = 3616, loss = 6.1977314949035645
time = 0.017165203889211018, epoch = 1, iter = 3617, loss = 5.04533052444458
time = 0.017412964502970377, epoch = 1, iter = 3618, loss = 5.977675437927246
time = 0.017347502708435058, epoch = 1, iter = 3619, loss = 5.380112171173096
time = 0.017682822545369466, epoch = 1, iter = 3620, loss = 6.043785572052002
time = 0.01739433209101359, epoch = 1, iter = 3621, loss = 5.202593803405762
time = 0.01704622507095337, epoch = 1, iter = 3622, loss = 5.181764125823975
time = 0.01744681199391683, epoch = 1, iter = 3623, loss = 5.394243240356445
time = 0.017962396144866943, epoch = 1, iter = 3624, loss = 6.689873695373535
time = 0.01738885243733724, epoch = 1, iter = 3625, loss = 6.527567386627197
time = 0.017394403616587322, epoch = 1, iter = 3626, loss = 5.030267715454102
time = 0.017404663562774658, epoch = 1, iter = 3627, loss = 4.501185894012451
time = 0.017711917559305828, epoch = 1, iter = 3628, loss = 5.541145

time = 0.019540786743164062, epoch = 1, iter = 3722, loss = 6.125877857208252
time = 0.018256521224975585, epoch = 1, iter = 3723, loss = 6.465667247772217
time = 0.017397181193033854, epoch = 1, iter = 3724, loss = 5.7624616622924805
time = 0.01794694662094116, epoch = 1, iter = 3725, loss = 6.260304927825928
time = 0.017326295375823975, epoch = 1, iter = 3726, loss = 7.118714809417725
time = 0.017923760414123534, epoch = 1, iter = 3727, loss = 5.532364368438721
time = 0.017920068899790444, epoch = 1, iter = 3728, loss = 5.516927242279053
time = 0.017340493202209473, epoch = 1, iter = 3729, loss = 6.022172451019287
time = 0.017966564496358237, epoch = 1, iter = 3730, loss = 6.108994483947754
time = 0.017413302262624105, epoch = 1, iter = 3731, loss = 6.992203235626221
time = 0.017863746484120688, epoch = 1, iter = 3732, loss = 6.663990020751953
time = 0.017896461486816406, epoch = 1, iter = 3733, loss = 6.340667724609375
time = 0.017897661526997885, epoch = 1, iter = 3734, loss = 7.88

time = 0.017994999885559082, epoch = 1, iter = 3828, loss = 7.852565288543701
time = 0.017639307181040446, epoch = 1, iter = 3829, loss = 6.51274299621582
time = 0.017918145656585692, epoch = 1, iter = 3830, loss = 7.572894096374512
time = 0.018161861101786296, epoch = 1, iter = 3831, loss = 5.251627445220947
time = 0.017875981330871583, epoch = 1, iter = 3832, loss = 5.853228569030762
time = 0.017709954579671224, epoch = 1, iter = 3833, loss = 5.271236896514893
time = 0.01770333449045817, epoch = 1, iter = 3834, loss = 5.390597343444824
time = 0.018104310830434164, epoch = 1, iter = 3835, loss = 7.429394245147705
time = 0.017704566319783527, epoch = 1, iter = 3836, loss = 7.374758243560791
time = 0.018439829349517822, epoch = 1, iter = 3837, loss = 6.714697360992432
time = 0.017694687843322753, epoch = 1, iter = 3838, loss = 6.491124629974365
time = 0.0179482897122701, epoch = 1, iter = 3839, loss = 5.55996036529541
time = 0.017619967460632324, epoch = 1, iter = 3840, loss = 5.0814638

time = 0.018229846159617105, epoch = 1, iter = 3934, loss = 5.171271324157715
time = 0.01773087978363037, epoch = 1, iter = 3935, loss = 7.748638153076172
time = 0.01783454418182373, epoch = 1, iter = 3936, loss = 6.59873628616333
time = 0.018166104952494305, epoch = 1, iter = 3937, loss = 6.829709053039551
time = 0.017700032393137614, epoch = 1, iter = 3938, loss = 4.5057053565979
time = 0.018075899283091227, epoch = 1, iter = 3939, loss = 5.9138569831848145
time = 0.017754840850830077, epoch = 1, iter = 3940, loss = 6.44536018371582
time = 0.0181752880414327, epoch = 1, iter = 3941, loss = 5.432521343231201
time = 0.018106762568155924, epoch = 1, iter = 3942, loss = 6.073098182678223
time = 0.017907615502675375, epoch = 1, iter = 3943, loss = 6.361081123352051
time = 0.01789497137069702, epoch = 1, iter = 3944, loss = 6.267362117767334
time = 0.018133715788523356, epoch = 1, iter = 3945, loss = 5.781929016113281
time = 0.017968432108561198, epoch = 1, iter = 3946, loss = 6.5528483390

time = 0.018216880162556966, epoch = 1, iter = 4040, loss = 5.879962921142578
time = 0.017965126037597656, epoch = 1, iter = 4041, loss = 6.761811256408691
time = 0.018205785751342775, epoch = 1, iter = 4042, loss = 4.909966945648193
time = 0.018093395233154296, epoch = 1, iter = 4043, loss = 6.80560302734375
time = 0.01765897274017334, epoch = 1, iter = 4044, loss = 5.163455009460449
time = 0.01792608102162679, epoch = 1, iter = 4045, loss = 6.3885345458984375
time = 0.01835795243581136, epoch = 1, iter = 4046, loss = 8.101806640625
time = 0.01819608211517334, epoch = 1, iter = 4047, loss = 6.678640365600586
time = 0.01842950185139974, epoch = 1, iter = 4048, loss = 5.888946533203125
time = 0.017902863025665284, epoch = 1, iter = 4049, loss = 6.641042232513428
time = 0.018268016974131267, epoch = 1, iter = 4050, loss = 5.31168270111084
time = 0.018159870306650797, epoch = 1, iter = 4051, loss = 6.913180351257324
time = 0.01815632184346517, epoch = 1, iter = 4052, loss = 6.456026554107

time = 0.019922180970509847, epoch = 1, iter = 4146, loss = 6.323391437530518
time = 0.020954608917236328, epoch = 1, iter = 4147, loss = 5.260069847106934
time = 0.018714559078216553, epoch = 1, iter = 4148, loss = 6.190357208251953
time = 0.018378170331319173, epoch = 1, iter = 4149, loss = 5.725307464599609
time = 0.01788225571314494, epoch = 1, iter = 4150, loss = 7.157690525054932
time = 0.01794639031092326, epoch = 1, iter = 4151, loss = 6.3947296142578125
time = 0.018736358483632407, epoch = 1, iter = 4152, loss = 5.629854679107666
time = 0.017957464853922526, epoch = 1, iter = 4153, loss = 5.401864528656006
time = 0.018422508239746095, epoch = 1, iter = 4154, loss = 6.431536674499512
time = 0.01762666702270508, epoch = 1, iter = 4155, loss = 5.739795207977295
time = 0.01818012793858846, epoch = 1, iter = 4156, loss = 5.713539123535156
time = 0.018160812060038247, epoch = 1, iter = 4157, loss = 6.129301071166992
time = 0.018560985724131267, epoch = 1, iter = 4158, loss = 5.56935

In [40]:
PATH = "Trained_model_NEW"
torch.save(model, PATH)

In [24]:
model = torch.load("Trained_model_NEW")

In [52]:
def translate(model, src, max_len = 80, custom_string=False):    
    model.eval()
    if custom_string == True:
            src = tokenizer(src)
            src = torch.LongTensor([[SRC.vocab.stoi[tok] for tok in src]])
    
#     pad = SRC.vocab.stoi['<pad>']
#     src_mask = (src != pad).unsqueeze(1)
#     e_outputs = model.encoder(src, src_mask)
    
    outputs = torch.zeros(max_len).type_as(src.data)
    outputs[0] = torch.LongTensor([TRG.vocab.stoi['<sos>']])
    for i in range(1, max_len):    

#             trg_mask = np.tril(np.ones((1, i, i)), k = 0).astype('uint8')
#             trg_mask = torch.from_numpy(trg_mask) != 0
#             print(outputs[:i], e_outputs)

            src_mask, trg_mask = create_mask(src, outputs[:i].unsqueeze(0))
            print(outputs[:i].unsqueeze(0))
            out = model(src, outputs[:i].unsqueeze(0), src_mask, trg_mask)
            out = F.softmax(out, dim=-1)
            val, ix = out[:, -1].data.topk(1)

            outputs[i] = ix[0][0]
            if ix[0][0] == TRG.vocab.stoi['<eos>']:
                break
    return ' '.join([TRG.vocab.itos[ix] for ix in outputs[:i]])
                               
translate(model, "A carbon footprint is the amount of carbon", custom_string = True)

tensor([[2]])
tensor([[ 2, 16]])
tensor([[ 2, 16, 16]])
tensor([[ 2, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16]])
tensor([[ 2, 16, 16, 16

tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16]])
tensor([[ 2, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
         16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 

'<sos> tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom tom'