In [1]:
device = 'cuda'

# model parameter setting
batch_size = 128
max_len = 256
d_model = 512
n_layers = 6
n_heads = 8
ffn_hidden = 512
drop_prob = 0.1

# optimizer parameter setting
init_lr = 1e-5
factor = 0.9
adam_eps = 5e-9
patience = 10
warmup = 100
epoch = 1000
clip = 1.0
weight_decay = 5e-4
inf = float('inf')

import torch
from torch import nn, optim
from torch.optim import Adam
import torch.utils.data as data
import math
from collections import Counter
import collections
import numpy as np
import copy
import time
import spacy
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import time

start_time = time.time()

class DataLoader:
    source: Field = None
    target: Field = None

    def __init__(self, ext, tokenize_en, tokenize_vi, init_token, eos_token):
        self.ext = ext
        self.tokenize_en = tokenize_en
        self.tokenize_vi = tokenize_vi
        self.init_token = init_token
        self.eos_token = eos_token
        print('dataset initializing start')

    def make_dataset(self):
        if self.ext == ('.vi', '.en'):
            self.source = Field(tokenize=self.tokenize_vi, init_token=self.init_token, eos_token=self.eos_token,
                                lower=True, batch_first=True)
            self.target = Field(tokenize=self.tokenize_en, init_token=self.init_token, eos_token=self.eos_token,
                                lower=True, batch_first=True)

        elif self.ext == ('.en', '.vi'):
            self.source = Field(tokenize=self.tokenize_en, init_token=self.init_token, eos_token=self.eos_token,
                                lower=True, batch_first=True)
            self.target = Field(tokenize=self.tokenize_vi, init_token=self.init_token, eos_token=self.eos_token,
                                lower=True, batch_first=True)

        train_data, valid_data, test_data = Multi30k.splits(exts=self.ext, fields=(self.source, self.target))
        return train_data, valid_data, test_data

    def build_vocab(self, train_data, min_freq):
        self.source.build_vocab(train_data, min_freq=min_freq)
        self.target.build_vocab(train_data, min_freq=min_freq)

    def make_iter(self, train, validate, test, batch_size):
        train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train, validate, test),
                                                                              batch_size=batch_size)
        print('dataset initializing done')
        return train_iterator, valid_iterator, test_iterator

from pyvi import ViTokenizer
import spacy
from spacy.language import Language
from spacy.tokens import Doc

# Custom tokenizer function
def custom_tokenizer(nlp, text):
    words = ViTokenizer.tokenize(text).split()
    return Doc(nlp.vocab, words=words)

class Tokenizer:
    def __init__(self):
        self.spacy_vi = spacy.blank("en")
        self.spacy_vi.tokenizer = custom_tokenizer.__get__(self.spacy_vi)
        
        self.spacy_en = spacy.load('en_core_web_sm')

    def tokenize_vi(self, text):
        return [tok.text for tok in self.spacy_vi.tokenizer(text)]

    def tokenize_en(self, text):
        return [tok.text for tok in self.spacy_en.tokenizer(text)]


# In[4]:

tokenizer = Tokenizer()
loader = DataLoader(ext=('.en', '.vi'),
                    tokenize_en=tokenizer.tokenize_en,
                    tokenize_vi=tokenizer.tokenize_vi,
                    init_token='<sos>',
                    eos_token='<eos>')

train, valid, test = loader.make_dataset()
loader.build_vocab(train_data=train, min_freq=1)
train_iter, valid_iter, test_iter = loader.make_iter(train, valid, test,
                                                     batch_size=batch_size)

src_pad_idx = loader.source.vocab.stoi['<pad>']
trg_pad_idx = loader.target.vocab.stoi['<pad>']
trg_sos_idx = loader.target.vocab.stoi['<sos>']

enc_voc_size = len(loader.source.vocab)
dec_voc_size = len(loader.target.vocab)

device = 'cuda'

dataset initializing start
dataset initializing done


In [2]:
print("Target Vocabulary:")
for idx, token in enumerate(loader.source.vocab.itos):
    print(f"{idx}: {token}")

print(f"Target Vocabulary Size: {len(loader.source.vocab)}")

Target Vocabulary:
0: <unk>
1: <pad>
2: <sos>
3: <eos>
4: .
5: the
6: ,
7: to
8: of
9: &
10: and
11: we
12: you
13: in
14: that
15: a
16: this
17: it
18: apos;s
19: brain
20: is
21: can
22: look
23: on
24: do
25: have
26: they
27: apos;re
28: as
29: i
30: one
31: pain
32: --
33: at
34: but
35: into
36: like
37: or
38: your
39: all
40: are
41: for
42: their
43: there
44: what
45: atmospheric
46: be
47: control
48: inside
49: when
50: fly
51: going
52: headlines
53: molecule
54: our
55: scientific
56: see
57: these
58: time
59: will
60: with
61: -
62: ;
63: able
64: apos;t
65: apos;ve
66: areas
67: climate
68: from
69: his
70: hundreds
71: make
72: peter
73: real
74: scientists
75: so
76: technology
77: which
78: about
79: activation
80: an
81: by
82: each
83: every
84: few
85: field
86: form
87: get
88: he
89: if
90: mind
91: molecules
92: now
93: order
94: own
95: people
96: research
97: soon
98: study
99: system
100: take
101: thousands
102: up
103: weight
104: :
105: aircraft
106: ar

In [3]:
print("Target Vocabulary:")
for idx, token in enumerate(loader.target.vocab.itos):
    print(f"{idx}: {token}")

print(f"Target Vocabulary Size: {len(loader.target.vocab)}")

Target Vocabulary:
0: <unk>
1: <pad>
2: <sos>
3: <eos>
4: .
5: ,
6: bạn
7: của
8: một
9: trong
10: là
11: và
12: -
13: những
14: này
15: não
16: chúng_tôi
17: các
18: có_thể
19: để
20: được
21: có
22: như
23: thấy
24: vào
25: đó
26: về
27: khi
28: nghiên_cứu
29: tôi
30: chúng_ta
31: sẽ
32: bộ
33: cho
34: khoa_học
35: nhìn
36: ra
37: ta
38: đây
39: đã
40: hàng
41: phân_tử
42: với
43: bay
44: họ
45: trên
46: từ
47: đau
48: đến
49: ở
50: không
51: làm
52: thế
53: điều_khiển
54: bên
55: nhà
56: điều
57: chính
58: cách
59: cơn
60: hay
61: khác
62: mà
63: nhưng
64: nó
65: phải
66: thời_gian
67: anh
68: khí_hậu
69: khí_quyển
70: mình
71: nhiều
72: nhỏ
73: phép
74: rất
75: ;
76: chiếc
77: còn
78: cả
79: gì
80: lại
81: lớn
82: năm
83: peter
84: theo
85: thực
86: trăm
87: viết
88: đề_tài
89: đều
90: &
91: bài
92: bản
93: chỉ
94: cái
95: cánh_tay
96: công_nghệ
97: cùng
98: cần
99: hơn
100: hệ_thống
101: isoprene
102: lượng
103: máy_bay
104: mỗi
105: ngàn
106: người
107: nói
108: nếu
109: thực_hiệ

In [4]:
class ScaleDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, q, k ,v, mask = None, e = 1e-12):
        batch_size, head, length, d_tensor = k.size()
        k_t = k.transpose(2,3)
        score = (q @ k_t) / math.sqrt(d_tensor)

        if mask is not None:
            score = score.masked_fill(mask == 0, -100000)

        score = self.softmax(score)
        v = score @ v
        return v, score

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head # Number of attention heads
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model) # Query transformation
        self.w_k = nn.Linear(d_model, d_model) # Key
        self.w_v = nn.Linear(d_model, d_model) # Value
        self.w_concat = nn.Linear(d_model, d_model)

    def split(self, tensor):
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1,2).to(device)
        return tensor

    def concat(self, tensor):
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor
        tensor = tensor.transpose(1,2).contiguous().view(batch_size, length, d_model).to(device)
        return tensor

    def forward(self, q, k, v, mask = None):
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)
        q, k, v = self.split(q), self.split(k), self.split(v)

        out, attention = self.attention(q, k, v, mask = mask)

        out = self.concat(out)
        out = self.w_concat(out)
        return out


# In[7]:


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = drop_prob)

    def forward(self, x):

        return self.relu(self.linear2(self.dropout(self.relu(self.linear1(x)))))
        #return self.linear2(self.dropout(self.relu(self.linear1(x))))


# In[8]:


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super(PositionalEncoding, self).__init__()

        self.encoding = torch.zeros(max_len, d_model).to(device)
        self.encoding.requires_grad = False
        pos = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1).to(device)
        _2i = torch.arange(0, d_model, step = 2).float().to(device)
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i/d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i/d_model)))

    def forward(self, x):
        batch_size, seq_len = x.size()
        return self.encoding[:seq_len, :].to(device)


# In[9]:


class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        # '-1' means last dimension.

        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out

class TokenEmbedding(nn.Embedding):
   def __init__(self, vocab_size, d_model):
       super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)

class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, drop_prob):
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len)
        self.drop_out = nn.Dropout(p=drop_prob)
    
    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.drop_out(tok_emb + pos_emb)


# In[10]:


class EncoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model, n_head)
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p = drop_prob)
        self.dropout2 = nn.Dropout(p = drop_prob)

    def forward(self, x, src_mask):
        _x = x
        x = self.attention(x, x, x, src_mask)
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.dropout2(x)
        x = self.norm2(x + _x)
        return x

class Encoder(nn.Module):
    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob):
        super().__init__()
        self.emb = TransformerEmbedding(d_model=d_model,
                                        max_len=max_len,
                                        vocab_size=enc_voc_size,
                                        drop_prob=drop_prob)

        self.layers = nn.ModuleList([EncoderLayer(d_model=d_model,
                                                  ffn_hidden=ffn_hidden,
                                                  n_head=n_head,
                                                  drop_prob=drop_prob)
                                     for _ in range(n_layers)])

    def forward(self, x, src_mask):
        x = self.emb(x)

        for layer in self.layers:
            x = layer(x, src_mask)

        return x


# In[11]:


class DecoderLayer(nn.Module):
    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, n_head)
        self.enc_dec_attention = MultiHeadAttention(d_model, n_head)
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p = drop_prob)
        self.dropout2 = nn.Dropout(p = drop_prob)
        self.dropout3 = nn.Dropout(p = drop_prob)

    def forward(self, dec, enc, trg_mask, src_mask):
        _x = dec
        x = self.self_attention(dec, dec, dec, trg_mask)
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        if enc is not None:
            _x = x
            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)

            x = self.dropout2(x)
            x = self.norm2(x + _x)

        _x = x
        x = self.ffn(x)
        x = self.dropout3(x)
        x = self.norm3(x + _x)
        return x

class Decoder(nn.Module):
    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob):
        super().__init__()
        self.emb = TransformerEmbedding(d_model=d_model,
                                        drop_prob=drop_prob,
                                        max_len=max_len,
                                        vocab_size=dec_voc_size)

        self.layers = nn.ModuleList([DecoderLayer(d_model=d_model,
                                                  ffn_hidden=ffn_hidden,
                                                  n_head=n_head,
                                                  drop_prob=drop_prob)
                                     for _ in range(n_layers)])

        self.linear = nn.Linear(d_model, dec_voc_size)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        trg = self.emb(trg)

        for layer in self.layers:
            trg = layer(trg, enc_src, trg_mask, src_mask)
        output = self.linear(trg)
        return output


# In[12]:


class Transformer(nn.Module):

    def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.trg_sos_idx = trg_sos_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_voc_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers)

        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_voc_size=dec_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        output = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output

    def make_src_mask(self, src):
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(device)

    def make_trg_mask(self, trg):
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(device)
        trg_mask = trg_pad_mask & trg_sub_mask
        return trg_mask.to(device)


# In[13]:


def get_bleu(pred_seq, label_seq, k = 4):
    """Compute the BLEU."""
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))


    for n in range(1, min(k, len_pred) + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[' '.join(label_tokens[i: i + n])] += 1

        for i in range(len_pred - n + 1):
            if label_subs[' '.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[' '.join(pred_tokens[i: i + n])] -= 1
        
        #print(num_matches)

        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score


def idx_to_word(x, vocab):
    words = []
    for i in x:
        word = vocab.itos[i]
        if '<' not in word:
            words.append(word)
    words = " ".join(words)
    return words


# In[14]:


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.kaiming_uniform(m.weight.data)



model = Transformer(src_pad_idx=src_pad_idx,
                    trg_pad_idx=trg_pad_idx,
                    trg_sos_idx=trg_sos_idx,
                    d_model=d_model,
                    enc_voc_size=enc_voc_size,
                    dec_voc_size=dec_voc_size,
                    max_len=max_len,
                    ffn_hidden=ffn_hidden,
                    n_head=n_heads,
                    n_layers=n_layers,
                    drop_prob=drop_prob).to(device)

print(f'The model has {count_parameters(model):,} trainable parameters')
model.apply(initialize_weights)
optimizer = Adam(params=model.parameters(),
                 lr=init_lr,
                 weight_decay=weight_decay,
                 eps=adam_eps)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 verbose=True,
                                                 factor=factor,
                                                 patience=patience)

criterion = nn.CrossEntropyLoss(ignore_index=src_pad_idx)

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src.to(device)
        trg = batch.trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg[:, :-1])
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output_reshape, trg)
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())

    return epoch_loss / len(iterator)


def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg

            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)

                    #print(output_words + ' lmao')
                    #print(trg_words + ' bruh')
                    #print(get_bleu(output_words, trg_words))
                    #print('-' * 30)

                    bleu = get_bleu(output_words, trg_words)

                    total_bleu.append(bleu)
                except:
                    pass

            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)

            #src = src.detach()
            #trg = trg.detach()
            torch.cuda.empty_cache()

    batch_bleu = sum(batch_bleu) / len(batch_bleu)
    return epoch_loss / len(iterator), batch_bleu

from pathlib import Path
def save_model(name):
    MODEL_PATH = Path('/home/dominhnhat/Pose_Estimation/MFSvi/debug/models')
    MODEL_NAME = Path(name + '.pth')
    MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME
    #print(f'Saving model to : {MODEL_SAVE_PATH}')
    torch.save(obj = model.state_dict(),
            f = MODEL_SAVE_PATH)

def run(total_epoch, best_loss, best_bleu):
    #train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        #end_time = time.time()

        if step > warmup:
            scheduler.step(valid_loss)

        #train_losses.append(train_loss)
        #test_losses.append(valid_loss)
        #bleus.append(bleu)
        #epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_loss:
            best_loss = valid_loss
            #save_model('best_loss')
        
        if (bleu > best_bleu):
            best_bleu = bleu
            #save_model('best_bleu')

        print(f'Epoch: {step + 1}')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} |  Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')

The model has 19,278,387 trainable parameters


  nn.init.kaiming_uniform(m.weight.data)
  from .autonotebook import tqdm as notebook_tqdm


In [5]:
run(total_epoch = 1000, best_loss = inf, best_bleu = 0.0)

step : 0.0 % , loss : 7.507517337799072
Epoch: 1
	Train Loss: 7.508 | Train PPL: 1821.685
	Val Loss: 7.369 |  Val PPL: 1585.333
	BLEU Score: 0.000
step : 0.0 % , loss : 7.363132476806641
Epoch: 2
	Train Loss: 7.363 | Train PPL: 1576.768
	Val Loss: 7.187 |  Val PPL: 1321.495
	BLEU Score: 0.000
step : 0.0 % , loss : 7.274956226348877
Epoch: 3
	Train Loss: 7.275 | Train PPL: 1443.688
	Val Loss: 7.036 |  Val PPL: 1137.138
	BLEU Score: 0.000
step : 0.0 % , loss : 7.147759437561035
Epoch: 4
	Train Loss: 7.148 | Train PPL: 1271.254
	Val Loss: 6.919 |  Val PPL: 1011.336
	BLEU Score: 0.000
step : 0.0 % , loss : 7.054183483123779
Epoch: 5
	Train Loss: 7.054 | Train PPL: 1157.692
	Val Loss: 6.830 |  Val PPL: 925.233
	BLEU Score: 0.000
step : 0.0 % , loss : 6.925076961517334
Epoch: 6
	Train Loss: 6.925 | Train PPL: 1017.473
	Val Loss: 6.764 |  Val PPL: 866.521
	BLEU Score: 0.000
step : 0.0 % , loss : 6.8362250328063965
Epoch: 7
	Train Loss: 6.836 | Train PPL: 930.968
	Val Loss: 6.711 |  Val PPL: 8

In [6]:
def evaluate_and_print(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src.to(device)
            trg = batch.trg.to(device)

            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            for j in range(batch_size):
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]

                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(output_words, trg_words)

                    print('-' * 30)
                    print(output_words.split())
                    print(trg_words.split())

                    total_bleu.append(bleu)
                except:
                    pass

            total_bleu = sum(total_bleu) / len(total_bleu)
            batch_bleu.append(total_bleu)

            #src = src.detach()
            #trg = trg.detach()
            torch.cuda.empty_cache()

    #batch_bleu = sum(batch_bleu) / len(batch_bleu)
    print('Final result :')
    print(epoch_loss / len(iterator),' ', batch_bleu) 

evaluate_and_print(model, valid_iter, criterion)

------------------------------
['chuyên_gia', 'chuyên_gia', 'chuyên_gia', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'về', 'biến_đổi', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-']
['trong', '4', 'phút', ',', 'chuyên_gia', 'hoá_học', 'khí_quyển', 'rachel', 'pike', 'giới_thiệu', 'sơ_lược', 'về', 'những', 'nỗ_lực', 'khoa_học', 'miệt_mài', 'đằng', 'sau', 'những', 'tiêu_đề', 'táo_bạo', 'về', 'biến_đổi', 'khí_hậu', ',', 'cùng', 'với', 'đoàn', 'nghiên_cứu', 'của', 'mình', '-', '-', 'hàng', 'ngàn', 'người', 'đã', 'cống_hiến', 'cho', 'dự_án', 'này', '-', '-', 'một', 'chuyến', 'bay', 'mạo_hiểm', 'qua', 'rừng_già', 'để', 'tìm_kiếm', 'thông_tin', 'về', 'một', 'phân_tử', 'then_chốt', '.']
------------------------------
['chúng_tôi', 'chúng_tôi', 'chúng_tôi', 'tìm', 'tìm', 'tìm', 'tìm', 'tìm', 'tìm', '

In [7]:
def test_model(model_path):
    loaded_model = Transformer(src_pad_idx=src_pad_idx,
                        trg_pad_idx=trg_pad_idx,
                        trg_sos_idx=trg_sos_idx,
                        d_model=d_model,
                        enc_voc_size=enc_voc_size,
                        dec_voc_size=dec_voc_size,
                        max_len=max_len,
                        ffn_hidden=ffn_hidden,
                        n_head=n_heads,
                        n_layers=n_layers,
                        drop_prob=drop_prob).to(device)

    path = Path(model_path)
    loaded_model.load_state_dict(torch.load(f = path))

    evaluate_and_print(loaded_model, valid_iter, criterion)

test_model('/home/dominhnhat/Pose_Estimation/MFSvi/debug/models/best_bleu.pth')
print('\n' * 3)
test_model('/home/dominhnhat/Pose_Estimation/MFSvi/debug/models/best_loss.pth')

RuntimeError: Error(s) in loading state_dict for Transformer:
	Missing key(s) in state_dict: "encoder.layers.2.attention.w_q.weight", "encoder.layers.2.attention.w_q.bias", "encoder.layers.2.attention.w_k.weight", "encoder.layers.2.attention.w_k.bias", "encoder.layers.2.attention.w_v.weight", "encoder.layers.2.attention.w_v.bias", "encoder.layers.2.attention.w_concat.weight", "encoder.layers.2.attention.w_concat.bias", "encoder.layers.2.ffn.linear1.weight", "encoder.layers.2.ffn.linear1.bias", "encoder.layers.2.ffn.linear2.weight", "encoder.layers.2.ffn.linear2.bias", "encoder.layers.2.norm1.gamma", "encoder.layers.2.norm1.beta", "encoder.layers.2.norm2.gamma", "encoder.layers.2.norm2.beta", "encoder.layers.3.attention.w_q.weight", "encoder.layers.3.attention.w_q.bias", "encoder.layers.3.attention.w_k.weight", "encoder.layers.3.attention.w_k.bias", "encoder.layers.3.attention.w_v.weight", "encoder.layers.3.attention.w_v.bias", "encoder.layers.3.attention.w_concat.weight", "encoder.layers.3.attention.w_concat.bias", "encoder.layers.3.ffn.linear1.weight", "encoder.layers.3.ffn.linear1.bias", "encoder.layers.3.ffn.linear2.weight", "encoder.layers.3.ffn.linear2.bias", "encoder.layers.3.norm1.gamma", "encoder.layers.3.norm1.beta", "encoder.layers.3.norm2.gamma", "encoder.layers.3.norm2.beta", "encoder.layers.4.attention.w_q.weight", "encoder.layers.4.attention.w_q.bias", "encoder.layers.4.attention.w_k.weight", "encoder.layers.4.attention.w_k.bias", "encoder.layers.4.attention.w_v.weight", "encoder.layers.4.attention.w_v.bias", "encoder.layers.4.attention.w_concat.weight", "encoder.layers.4.attention.w_concat.bias", "encoder.layers.4.ffn.linear1.weight", "encoder.layers.4.ffn.linear1.bias", "encoder.layers.4.ffn.linear2.weight", "encoder.layers.4.ffn.linear2.bias", "encoder.layers.4.norm1.gamma", "encoder.layers.4.norm1.beta", "encoder.layers.4.norm2.gamma", "encoder.layers.4.norm2.beta", "decoder.layers.2.self_attention.w_q.weight", "decoder.layers.2.self_attention.w_q.bias", "decoder.layers.2.self_attention.w_k.weight", "decoder.layers.2.self_attention.w_k.bias", "decoder.layers.2.self_attention.w_v.weight", "decoder.layers.2.self_attention.w_v.bias", "decoder.layers.2.self_attention.w_concat.weight", "decoder.layers.2.self_attention.w_concat.bias", "decoder.layers.2.enc_dec_attention.w_q.weight", "decoder.layers.2.enc_dec_attention.w_q.bias", "decoder.layers.2.enc_dec_attention.w_k.weight", "decoder.layers.2.enc_dec_attention.w_k.bias", "decoder.layers.2.enc_dec_attention.w_v.weight", "decoder.layers.2.enc_dec_attention.w_v.bias", "decoder.layers.2.enc_dec_attention.w_concat.weight", "decoder.layers.2.enc_dec_attention.w_concat.bias", "decoder.layers.2.ffn.linear1.weight", "decoder.layers.2.ffn.linear1.bias", "decoder.layers.2.ffn.linear2.weight", "decoder.layers.2.ffn.linear2.bias", "decoder.layers.2.norm1.gamma", "decoder.layers.2.norm1.beta", "decoder.layers.2.norm2.gamma", "decoder.layers.2.norm2.beta", "decoder.layers.2.norm3.gamma", "decoder.layers.2.norm3.beta", "decoder.layers.3.self_attention.w_q.weight", "decoder.layers.3.self_attention.w_q.bias", "decoder.layers.3.self_attention.w_k.weight", "decoder.layers.3.self_attention.w_k.bias", "decoder.layers.3.self_attention.w_v.weight", "decoder.layers.3.self_attention.w_v.bias", "decoder.layers.3.self_attention.w_concat.weight", "decoder.layers.3.self_attention.w_concat.bias", "decoder.layers.3.enc_dec_attention.w_q.weight", "decoder.layers.3.enc_dec_attention.w_q.bias", "decoder.layers.3.enc_dec_attention.w_k.weight", "decoder.layers.3.enc_dec_attention.w_k.bias", "decoder.layers.3.enc_dec_attention.w_v.weight", "decoder.layers.3.enc_dec_attention.w_v.bias", "decoder.layers.3.enc_dec_attention.w_concat.weight", "decoder.layers.3.enc_dec_attention.w_concat.bias", "decoder.layers.3.ffn.linear1.weight", "decoder.layers.3.ffn.linear1.bias", "decoder.layers.3.ffn.linear2.weight", "decoder.layers.3.ffn.linear2.bias", "decoder.layers.3.norm1.gamma", "decoder.layers.3.norm1.beta", "decoder.layers.3.norm2.gamma", "decoder.layers.3.norm2.beta", "decoder.layers.3.norm3.gamma", "decoder.layers.3.norm3.beta", "decoder.layers.4.self_attention.w_q.weight", "decoder.layers.4.self_attention.w_q.bias", "decoder.layers.4.self_attention.w_k.weight", "decoder.layers.4.self_attention.w_k.bias", "decoder.layers.4.self_attention.w_v.weight", "decoder.layers.4.self_attention.w_v.bias", "decoder.layers.4.self_attention.w_concat.weight", "decoder.layers.4.self_attention.w_concat.bias", "decoder.layers.4.enc_dec_attention.w_q.weight", "decoder.layers.4.enc_dec_attention.w_q.bias", "decoder.layers.4.enc_dec_attention.w_k.weight", "decoder.layers.4.enc_dec_attention.w_k.bias", "decoder.layers.4.enc_dec_attention.w_v.weight", "decoder.layers.4.enc_dec_attention.w_v.bias", "decoder.layers.4.enc_dec_attention.w_concat.weight", "decoder.layers.4.enc_dec_attention.w_concat.bias", "decoder.layers.4.ffn.linear1.weight", "decoder.layers.4.ffn.linear1.bias", "decoder.layers.4.ffn.linear2.weight", "decoder.layers.4.ffn.linear2.bias", "decoder.layers.4.norm1.gamma", "decoder.layers.4.norm1.beta", "decoder.layers.4.norm2.gamma", "decoder.layers.4.norm2.beta", "decoder.layers.4.norm3.gamma", "decoder.layers.4.norm3.beta". 
	size mismatch for encoder.emb.tok_emb.weight: copying a param with shape torch.Size([561, 128]) from checkpoint, the shape in current model is torch.Size([561, 512]).
	size mismatch for encoder.layers.0.attention.w_q.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.0.attention.w_q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.attention.w_k.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.0.attention.w_k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.attention.w_v.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.0.attention.w_v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.attention.w_concat.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.0.attention.w_concat.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.ffn.linear1.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for encoder.layers.0.ffn.linear1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.layers.0.ffn.linear2.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for encoder.layers.0.ffn.linear2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.norm1.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.norm1.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.norm2.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.0.norm2.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.attention.w_q.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.1.attention.w_q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.attention.w_k.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.1.attention.w_k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.attention.w_v.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.1.attention.w_v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.attention.w_concat.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for encoder.layers.1.attention.w_concat.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.ffn.linear1.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for encoder.layers.1.ffn.linear1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.layers.1.ffn.linear2.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for encoder.layers.1.ffn.linear2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.norm1.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.norm1.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.norm2.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.layers.1.norm2.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.emb.tok_emb.weight: copying a param with shape torch.Size([563, 128]) from checkpoint, the shape in current model is torch.Size([563, 512]).
	size mismatch for decoder.layers.0.self_attention.w_q.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.self_attention.w_q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.self_attention.w_k.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.self_attention.w_k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.self_attention.w_v.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.self_attention.w_v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.self_attention.w_concat.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.self_attention.w_concat.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_q.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_k.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_v.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_concat.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.0.enc_dec_attention.w_concat.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.ffn.linear1.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for decoder.layers.0.ffn.linear1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.layers.0.ffn.linear2.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for decoder.layers.0.ffn.linear2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.norm1.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.norm1.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.norm2.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.norm2.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.norm3.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.0.norm3.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.self_attention.w_q.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.self_attention.w_q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.self_attention.w_k.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.self_attention.w_k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.self_attention.w_v.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.self_attention.w_v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.self_attention.w_concat.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.self_attention.w_concat.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_q.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_k.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_v.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_concat.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 512]).
	size mismatch for decoder.layers.1.enc_dec_attention.w_concat.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.ffn.linear1.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for decoder.layers.1.ffn.linear1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.layers.1.ffn.linear2.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([512, 256]).
	size mismatch for decoder.layers.1.ffn.linear2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.norm1.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.norm1.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.norm2.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.norm2.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.norm3.gamma: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.layers.1.norm3.beta: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.linear.weight: copying a param with shape torch.Size([563, 128]) from checkpoint, the shape in current model is torch.Size([563, 512]).