In [1]:
import time

import torch
import torch.nn as nn

from Data_load import (get_batch_indices, load_cn_vocab,
                                        load_en_vocab, load_train_data,
                                        maxlen)
from model import Transformer

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------

# Optimisation settings.
batch_size = 64          # examples per optimizer step
lr = 1e-4                # learning rate handed to torch.optim.Adam
n_epochs = 60            # number of passes of the outer training loop

# Architecture settings, forwarded positionally to Transformer(...).
d_model = 512
d_ff = 2048
n_layers = 6
heads = 8
dropout_rate = 0.2

# Token id treated as padding: passed to the model, ignored by the loss
# (ignore_index) and excluded from the accuracy metric.
PAD_ID = 0


def masked_accuracy(logits, labels, pad_id):
    """Return token-level accuracy of ``logits`` against ``labels``, ignoring padding.

    Args:
        logits: Tensor of per-class scores in its last dimension; its leading
            dimensions must match ``labels``.
        labels: Integer tensor of target class ids.
        pad_id: Label value marking positions excluded from the metric.

    Returns:
        A 0-dim float tensor equal to
        (# non-pad positions where argmax(logits) == label) / (# non-pad positions).
        If every position is padding, returns 0.0 instead of NaN.
    """
    # The metric is for logging only; keep it out of the autograd graph.
    with torch.no_grad():
        mask = labels != pad_id
        correct = (logits.argmax(dim=-1) == labels) & mask
        # clamp avoids 0/0 -> NaN on an all-padding batch.
        return correct.sum().float() / mask.sum().clamp(min=1)


def main():
    """Train the English -> Chinese Transformer and save its weights.

    Loads both vocabularies and the training pairs, trains for ``n_epochs``
    with Adam and gradient-norm clipping, prints step count, elapsed time,
    loss and masked accuracy every ``print_interval`` steps, and finally
    writes ``model.state_dict()`` to ``model.pth``.
    """
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cn2idx, _ = load_cn_vocab()
    en2idx, _ = load_en_vocab()
    # X: English source sentences (encoder input)
    # Y: Chinese target sentences (decoder input / labels)
    Y, X = load_train_data()

    print_interval = 100

    model = Transformer(len(en2idx), len(cn2idx), PAD_ID, d_model, d_ff,
                        n_layers, heads, dropout_rate, maxlen)
    model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr)
    # Padding positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    tic = time.time()
    cnter = 0

    for _epoch in range(n_epochs):
        # NOTE(review): get_batch_indices presumably yields (row indices, ...)
        # per batch over len(X) rows -- confirm against Data_load.
        for index, _ in get_batch_indices(len(X), batch_size):
            # Build the batch directly on the target device as int64 ids.
            x_batch = torch.as_tensor(X[index], dtype=torch.long,
                                      device=device)
            y_batch = torch.as_tensor(Y[index], dtype=torch.long,
                                      device=device)
            # Teacher forcing: the decoder sees tokens [0, T-1) and is
            # trained to predict tokens [1, T).
            y_input = y_batch[:, :-1]
            y_label = y_batch[:, 1:]
            y_hat = model(x_batch, y_input)

            acc = masked_accuracy(y_hat, y_label, PAD_ID)

            # CrossEntropyLoss wants (N, C) scores and (N,) targets, so fold
            # the batch and time dimensions together. reshape (not view) is
            # required because y_label is a non-contiguous slice.
            n, seq_len = y_label.shape
            loss = criterion(y_hat.reshape(n * seq_len, -1),
                             y_label.reshape(n * seq_len))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

            if cnter % print_interval == 0:
                # Elapsed wall-clock time since training started.
                minutes, seconds = divmod(int(time.time() - tic), 60)
                print(f'{cnter:08d} {minutes:02d}:{seconds:02d}'
                      f' loss: {loss.item()} acc: {acc.item()}')
            cnter += 1

    model_path = 'model.pth'
    torch.save(model.state_dict(), model_path)

    print(f'Model saved to {model_path}')


if __name__ == '__main__':
    main()


  from .autonotebook import tqdm as notebook_tqdm
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


00000000 00:03 loss: 9.62743091583252 acc: 0.0
00000100 00:33 loss: 7.061116695404053 acc: 0.0989096537232399
00000200 01:04 loss: 6.839704990386963 acc: 0.11538461595773697
00000300 01:35 loss: 6.816007137298584 acc: 0.12024825811386108
00000400 02:07 loss: 6.536980152130127 acc: 0.143968865275383
00000500 02:38 loss: 6.213496208190918 acc: 0.16900311410427094
00000600 03:10 loss: 6.049178600311279 acc: 0.18562401831150055
00000700 03:42 loss: 5.697995662689209 acc: 0.20108695328235626
00000800 04:15 loss: 5.7117600440979 acc: 0.18230357766151428
00000900 04:48 loss: 5.310678005218506 acc: 0.21101629734039307
00001000 05:21 loss: 5.2506256103515625 acc: 0.22102008759975433
00001100 05:54 loss: 5.166344165802002 acc: 0.2167721539735794
00001200 06:27 loss: 4.845065593719482 acc: 0.26514554023742676
00001300 07:01 loss: 4.974267482757568 acc: 0.23222748935222626
00001400 07:34 loss: 4.77531099319458 acc: 0.2562893033027649
00001500 08:08 loss: 4.630521297454834 acc: 0.26299455761909485
