In [None]:
import os
import torch
import numpy as np
from tqdm.autonotebook import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from model import Transformer, build_transformer

In [None]:
# TODO: split the data into separate train and test sets here; this notebook does not do it for you.

In [None]:
# Build the transformer. Every hyper-parameter is passed by keyword, grouped
# below by the part of the network it configures.
model: Transformer = build_transformer(
	# Vocabulary sizes and context length
	source_vocab_size = 58,
	target_vocab_size = 49,
	context_length = 20,
	# Number of blocks in each stack
	encoder_block_count = 8,
	decoder_block_count = 3,
	# Attention head counts
	encoder_self_attention_head_count = 4,
	decoder_self_attention_head_count = 4,
	decoder_cross_attention_head_count = 4,
	# Attention "abstraction" coefficients (meaning defined in model.py)
	encoder_self_attention_abstraction_coef = 0.25,
	decoder_self_attention_abstraction_coef = 0.25,
	decoder_cross_attention_abstraction_coef = 0.25,
	# Feed-forward "abstraction" coefficients (meaning defined in model.py)
	encoder_feed_forward_abstraction_coef = 2.0,
	decoder_feed_forward_abstraction_coef = 2.0,
	# Model width, numerical epsilon and dropout rate
	dim = 256,
	epsilon = 1e-9,
	dropout = 0.2,
)

# Report the total number of scalar parameters.
# NOTE: `parameters` is a one-shot iterator and is exhausted by the sum below.
parameters = model.parameters()
parameter_count = sum(map(lambda p: p.numel(), parameters))
print(f"Parameters : {parameter_count}")

# Move the model to the CUDA device
model.to("cuda")

In [None]:
# Training configuration shared by the cells below.
info = dict(
	# Batching and sequence length (context_length also sizes the attention masks)
	batch_size=1024,
	context_length=20,
	# Optimiser and loss settings
	learning_rate=1e-4,
	weight_decay=1e-3,
	label_smoothing=0.1,
	num_epochs=40,
	# Checkpoints are written to <model_folder><model_file><epoch>.pth
	# every `model_save_interval` epochs
	model_folder="save/",
	model_file="g2p_model_",
	model_save_interval=1,
)

In [None]:
# Pre-compute one attention mask per possible sequence length.
# masks[k - 1] has shape (1, ctx, ctx): ones in the first k columns of the last
# axis and zeros from column k onwards. The last entry (k == ctx) stays all ones.
ctx = info["context_length"]
masks = torch.ones((ctx, 1, ctx, ctx), device="cuda")
for length in range(1, ctx):
	# Zero out every column from index `length` onwards
	masks[length - 1, ..., length:] = 0

print(masks.shape)

In [None]:
class CustomDataset(Dataset):
	"""Dataset backed by five row-aligned ``.npy`` arrays.

	Sample ``idx`` is row ``idx`` of each array: encoder input, decoder input,
	encoder length, decoder length and decoder target.
	"""

	def __init__(self, x_enc_path, x_dec_path, l_enc_path, l_dec_path, y_dec_path):
		"""Open the five arrays as read-only memory maps.

		Args:
			x_enc_path: path to the encoder input array.
			x_dec_path: path to the decoder input array.
			l_enc_path: path to the encoder length array.
			l_dec_path: path to the decoder length array.
			y_dec_path: path to the decoder target array.
		"""
		def open_readonly(path):
			# Memory-map the file so rows are read from disk on demand
			# instead of loading the whole array into host memory.
			return np.load(path, mmap_mode='r')

		self.X_enc = open_readonly(x_enc_path)
		self.X_dec = open_readonly(x_dec_path)
		self.L_enc = open_readonly(l_enc_path)
		self.L_dec = open_readonly(l_dec_path)
		self.Y_dec = open_readonly(y_dec_path)

	def __len__(self):
		"""Return the number of samples, taken from the first axis of ``X_enc``."""
		sample_count = self.X_enc.shape[0]
		return sample_count

	def __getitem__(self, idx):
		"""Return ``(x_enc, x_dec, l_enc, l_dec, y_dec)`` for sample ``idx``.

		The first four entries are ``torch.int`` tensors; the target ``y_dec``
		is a ``torch.float32`` tensor.
		"""
		integer_sources = (self.X_enc, self.X_dec, self.L_enc, self.L_dec)

		# torch.tensor copies the selected row out of the memory map
		x_enc, x_dec, l_enc, l_dec = (
			torch.tensor(source[idx], dtype=torch.int) for source in integer_sources
		)
		y_dec = torch.tensor(self.Y_dec[idx], dtype=torch.float32)

		return x_enc, x_dec, l_enc, l_dec, y_dec

# Locations of the five arrays, keyed by the CustomDataset constructor argument
data_files = {
	"x_enc_path": "data/X_enc.npy",
	"x_dec_path": "data/X_dec.npy",
	"l_enc_path": "data/L_enc.npy",
	"l_dec_path": "data/L_dec.npy",
	"y_dec_path": "data/Y_dec.npy",
}
dataset = CustomDataset(**data_files)

# Shuffled batches loaded by 16 worker processes into memory pinned for CUDA
dataloader = DataLoader(
	dataset,
	batch_size=info["batch_size"],
	shuffle=True,
	num_workers=16,
	pin_memory=True,
	pin_memory_device="cuda",
)

In [None]:
# Optimiser: AdamW over all model parameters, configured from `info`
optimizer = torch.optim.AdamW(
	params=model.parameters(),
	lr=info["learning_rate"],
	weight_decay=info["weight_decay"],
)

# Training loss: cross-entropy with label smoothing
criterion = torch.nn.CrossEntropyLoss(
	label_smoothing=info["label_smoothing"],
)

In [None]:
# Smoke test: fetch a single batch, move every tensor to the GPU and run it
# through the model once.
x_enc, x_dec, l_enc, l_dec, y_dec = (
	tensor.to("cuda") for tensor in next(iter(dataloader))
)

# Pick the pre-computed attention mask matching each sample's length
# (the length is read from column 0; index = length - 1)
enc_mask_index = l_enc[:, 0] - 1
dec_mask_index = l_dec[:, 0] - 1
mask_enc = masks[enc_mask_index]
mask_dec = masks[dec_mask_index]

# Forward pass: encoder -> decoder -> output projection
encoder_output = model.encode(x_enc, mask_enc)
decoder_output = model.decode(x_dec, encoder_output, mask_dec)
model_output = model.project(decoder_output)

In [None]:
# Resume training: set `pre_training_epoch` to the epoch number of an existing
# checkpoint to reload its weights; leave it at 0 to start from scratch.
# The training loop also adds this value to the epoch number in saved file names.
pre_training_epoch = 0
if pre_training_epoch > 0:
	checkpoint_path = f"{info['model_folder']}{info['model_file']}{pre_training_epoch}.pth"
	checkpoint_state = torch.load(checkpoint_path, weights_only=True)
	model.load_state_dict(checkpoint_state)

In [None]:
# Training loop
# Make sure the checkpoint folder exists (no error if it is already there)
os.makedirs(info["model_folder"], exist_ok=True)

# Per-epoch history: mean batch loss and mean batch accuracy
losses = []
accuracies = []

for epoch in range(info["num_epochs"]):
	model.train()
	epoch_loss = 0.0
	epoch_accuracy = 0.0

	batch_iterator = tqdm(dataloader, desc=f"Epoch {epoch+1}/{info['num_epochs']}")
	for x_enc, x_dec, l_enc, l_dec, y_dec in batch_iterator:
		# Send data to GPU
		x_enc = x_enc.to("cuda", non_blocking=True)
		x_dec = x_dec.to("cuda", non_blocking=True)
		l_enc = l_enc.to("cuda", non_blocking=True)
		l_dec = l_dec.to("cuda", non_blocking=True)
		y_dec = y_dec.to("cuda", non_blocking=True)

		# Select the masks for attention (length is in column 0; index = length - 1)
		mask_enc = masks[l_enc[:, 0] - 1]
		mask_dec = masks[l_dec[:, 0] - 1]

		# Forward pass
		encoder_output = model.encode(x_enc, mask_enc)
		decoder_output = model.decode(x_dec, encoder_output, mask_dec)
		model_output = model.project(decoder_output)

		# Compute the loss (read the scalar once; .item() synchronises with the GPU)
		loss = criterion(model_output, y_dec)
		batch_loss = loss.item()
		epoch_loss += batch_loss

		# Compute the accuracy
		# NOTE(review): the match count is divided by the batch size only, which is a
		# per-sample accuracy only if the model emits one prediction per sample —
		# confirm against the shape of model_output.
		accuracy = (model_output.argmax(dim=-1) == y_dec.argmax(dim=-1)).sum().item() / y_dec.size(0)
		epoch_accuracy += accuracy

		# Optimisation
		# Order matters: clear old gradients, compute fresh ones, clip them, then
		# step. Clipping before backward() would act on stale gradients and leave
		# the gradients actually used by the step unclipped.
		optimizer.zero_grad()
		loss.backward()
		torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
		optimizer.step()

		# Display the results
		batch_iterator.set_postfix({"loss": batch_loss, "accuracy": accuracy})

	# Display the results (averaged over the number of batches)
	epoch_loss /= len(dataloader)
	epoch_accuracy /= len(dataloader)
	losses.append(epoch_loss)
	accuracies.append(epoch_accuracy)
	print(f"Epoch {epoch+1}/{info['num_epochs']} : loss = {epoch_loss:.4f}, accuracy = {epoch_accuracy:.4f}")

	# Save the model; the file name continues the numbering of a resumed run
	if (epoch+1) % info["model_save_interval"] == 0:
		torch.save(model.state_dict(), f"{info['model_folder']}{info['model_file']}{epoch+pre_training_epoch+1}.pth")

In [None]:
# Evaluation
# NOTE(review): this iterates over `dataloader`, the same loader the training
# loop uses, so the numbers below are training-set metrics. Swap in a loader
# built from a held-out split to measure generalisation.
model.eval()
epoch_loss = 0.0
epoch_accuracy = 0.0

# Plain cross-entropy (no label smoothing) for reporting. Kept under its own
# name so the training `criterion` is not overwritten if the training cell is
# re-run after this one.
eval_criterion = torch.nn.CrossEntropyLoss()
batch_iterator = tqdm(dataloader, desc="Evaluation")

# No gradients are needed for evaluation; disabling autograd avoids building
# the computation graph and lowers memory use.
with torch.no_grad():
	for x_enc, x_dec, l_enc, l_dec, y_dec in batch_iterator:
		# Send data to GPU
		x_enc = x_enc.to("cuda", non_blocking=True)
		x_dec = x_dec.to("cuda", non_blocking=True)
		l_enc = l_enc.to("cuda", non_blocking=True)
		l_dec = l_dec.to("cuda", non_blocking=True)
		y_dec = y_dec.to("cuda", non_blocking=True)

		# Select the masks for attention (length is in column 0; index = length - 1)
		mask_enc = masks[l_enc[:, 0] - 1]
		mask_dec = masks[l_dec[:, 0] - 1]

		# Forward pass
		encoder_output = model.encode(x_enc, mask_enc)
		decoder_output = model.decode(x_dec, encoder_output, mask_dec)
		model_output = model.project(decoder_output)

		# Compute the loss (read the scalar once; .item() synchronises with the GPU)
		loss = eval_criterion(model_output, y_dec)
		batch_loss = loss.item()
		epoch_loss += batch_loss

		# Compute the accuracy
		# NOTE(review): divided by the batch size only — see whether model_output
		# holds one prediction per sample before trusting this as a fraction.
		accuracy = (model_output.argmax(dim=-1) == y_dec.argmax(dim=-1)).sum().item() / y_dec.size(0)
		epoch_accuracy += accuracy

		# Display the results
		batch_iterator.set_postfix({"loss": batch_loss, "accuracy": accuracy})

# Display the results (averaged over the number of batches)
epoch_loss /= len(dataloader)
epoch_accuracy /= len(dataloader)
print(f"Evaluation : loss = {epoch_loss:.4f}, accuracy = {epoch_accuracy:.4f}")