In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm

from convert import load
from dataset import GPTChatDataset
from model import build_gpt_model
from tokenizer_utils import SubwordTokenizer

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEQ_LEN = 128  # max sequence length, passed to both datasets and the model
SEED = 42

# Seed torch so model weight init and DataLoader shuffling are reproducible
# across Restart-&-Run-All (the train/test split is seeded via random_state).
torch.manual_seed(SEED)

In [3]:


# Raw corpus lines from convert.load() (imported in the top imports cell so a
# fresh kernel has no hidden mid-notebook dependencies).
lines = load()


In [4]:
print(f"{len(lines)}")  # total number of loaded lines

72603


In [5]:
# Hold out 20% of the lines for validation; random_state pins the split.
train_lines, test_lines = train_test_split(
    lines,
    random_state=42,
    test_size=0.2,
)


In [6]:
print(f"{len(train_lines)}")  # number of lines in the training split

58082


In [7]:
tokenizer = SubwordTokenizer("saved/vi_bpe.model")
# NOTE(review): pad_id is not used in the visible cells, while the loss uses
# ignore_index=-100. Confirm GPTChatDataset marks padded label positions
# with -100 (not pad_id), otherwise padding contributes to the loss.
pad_id = tokenizer.word2idx["<pad>"]

# Build the train and test datasets with identical tokenizer/sequence settings.
train_dataset, test_dataset = (
    GPTChatDataset(split_lines, tokenizer, seq_len=SEQ_LEN)
    for split_lines in (train_lines, test_lines)
)



In [8]:
# from statistics import mean
# 
# lengths = [len(tokenizer.tokenize(line)) for line in lines]
# print("Trung bình:", mean(lengths), "| Max:", max(lengths), "| Min:", min(lengths))

In [9]:
def collate_fn(batch):
    """Stack a list of dataset samples into one batched dict of tensors.

    Only the "decoder_input", "decoder_mask" and "label" entries are kept;
    any other keys a sample carries are dropped. Each entry is combined with
    torch.stack, which adds a leading batch dimension, so all samples must hold
    equally shaped tensors per key (an empty batch raises, as torch.stack does).
    Expected shapes (per the original annotations): decoder_input and label
    (batch_size, seq_len); decoder_mask (batch_size, 1, seq_len, seq_len).
    """
    batched_keys = ("decoder_input", "decoder_mask", "label")
    return {key: torch.stack([sample[key] for sample in batch]) for key in batched_keys}

BATCH_SIZE = 32  # shared by the training and validation loaders

# Only the training loader shuffles; validation order stays fixed.
train_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_train, collate_fn=collate_fn)
    for ds, is_train in ((train_dataset, True), (test_dataset, False))
)

In [10]:
# for word in tokenizer.word2idx:
#     print(word)

In [11]:
LEARNING_RATE = 1e-4
vocab_size = len(tokenizer.word2idx)

model = build_gpt_model(vocab_size, seq_len=SEQ_LEN).to(device)
# -100 is CrossEntropyLoss's default ignore_index; presumably GPTChatDataset uses it
# to mark label positions that must not contribute to the loss. Confirm.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
n_epochs = 3
best_val_loss = float("inf")
save_path = "saved/best_model.pth"
# Create the checkpoint directory from save_path itself so the two cannot drift apart.
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)


def _batch_loss(batch):
    """Return the scalar loss tensor for one collated batch.

    Moves the batch tensors to `device`, runs the global `model`, and flattens
    logits to 2-D (tokens x classes) and labels to 1-D for `criterion`.
    Shared by the training and validation loops so they cannot diverge.
    """
    decoder_input = batch["decoder_input"].to(device)
    decoder_mask = batch["decoder_mask"].to(device)
    labels = batch["label"].to(device)

    logits = model(decoder_input, decoder_mask)
    # reshape (not view) also works if the model returns a non-contiguous tensor.
    return criterion(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))


for epoch in range(n_epochs):
    # ---- Training ----
    model.train()
    total_train_loss = 0.0
    for batch in tqdm(train_loader, desc=f"[Epoch {epoch+1}] Training"):
        loss = _batch_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()
    # Mean of per-batch mean losses: a smaller final batch weighs as much as a full one.
    avg_train_loss = total_train_loss / len(train_loader)

    # ---- Validation ----
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc=f"[Epoch {epoch+1}] Validating"):
            total_val_loss += _batch_loss(batch).item()
    avg_val_loss = total_val_loss / len(test_loader)

    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Checkpoint only when validation loss improves on the best seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), save_path)
        print("Saved best model.")


[Epoch 0] Validating:   0%|          | 3/1816 [00:03<35:11,  1.16s/it]