In [None]:
from ZW_model import *
from ZW_utils import *
import torch.optim as optim
# --- Data -------------------------------------------------------------------
dataset_id = "v21D0_m1.npy"  # identifier handed to `dataloading` further down
classes = std_classes        # class list provided by the star imports above
vocab_size = len(classes)    # one vocabulary entry per class

# --- Optimisation -----------------------------------------------------------
loss_function = std_loss     # loss callable provided by the star imports above
learning_rate = 1e-3
batch_size = 10
max_epochs = 20              # upper bound on epochs for the training loop

# --- Model ------------------------------------------------------------------
# These values are forwarded positionally to GPTModel; `block_size` is also
# passed to GPTDataset when the datasets are built.
block_size = 22
n_embd = 32
n_head = 4
n_layer = 2
dropout = 0.1

model = GPTModel(
    vocab_size,
    n_embd,
    n_head,
    n_layer,
    block_size,
    dropout,
)



In [None]:
# `make_dir` is a project helper; its return value is kept as `save_path`
# (presumably the output location for this run -- verify against ZW_utils).
save_path = make_dir(model, batch_size, learning_rate)

# Adam over every parameter of the model, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Load the data selected by `dataset_id` through the project helper
# `dataloading` (not imported by name here; presumably supplied by one of the
# star imports at the top of the notebook -- verify).
dataset = dataloading(dataset_id)

In [None]:
from torch.utils.data import DataLoader
from ZW_dataset import GPTDataset
# Contiguous split: the leading `data_split_ratio` share of `dataset` is used
# for training and the remaining tail for validation.
data_split_ratio = 0.85
split_index = int(data_split_ratio * len(dataset))
t_data, v_data = dataset[:split_index], dataset[split_index:]

# The training set is built in "augmented" mode, the validation set in
# "standard" mode.
training_set = GPTDataset(
    t_data,
    classes,
    block_size,
    training_type="augmented",
)
validation_set = GPTDataset(
    v_data,
    classes,
    block_size,
    training_type="standard",
)

# Only the training batches are shuffled; validation order stays fixed.
train_loader = DataLoader(
    training_set,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    validation_set,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Train `model` for up to `max_epochs` epochs with early stopping on the
# validation loss. After the loop:
#   * `best_model` holds a state_dict snapshot of the epoch with the lowest
#     validation loss (or None if no epoch ran),
#   * `train_losses` / `validation_losses` hold one mean loss per epoch.

# Imported explicitly so the cell does not depend on `copy` / `torch` being
# re-exported by a star import.
import copy

import torch

# Number of consecutive non-improving epochs tolerated before stopping.
max_patience = 5

best_loss = float("inf")
best_model = None
train_losses = []
validation_losses = []
patience = max_patience

for epoch in range(max_epochs):
    # ---- training pass ----------------------------------------------------
    model.train()
    epoch_loss = 0
    train_correct = 0
    train_total = 0
    for x, y in train_loader:
        optimizer.zero_grad()
        y_pred = model(x)
        # Collapse every leading dimension so each row is one prediction over
        # the last axis, and flatten the targets to match.
        y_pred = y_pred.view(-1, y_pred.size(-1))
        y = y.view(-1)
        loss = loss_function(y_pred, y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # Accuracy bookkeeping only: detach so no graph is kept alive.
        predicted = y_pred.detach().argmax(dim=1)
        train_total += y.size(0)
        train_correct += (predicted == y).sum().item()
    # Mean of the per-batch losses.
    epoch_loss = epoch_loss / len(train_loader)

    # ---- validation pass --------------------------------------------------
    model.eval()
    correct = 0
    total = 0
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            y_pred = model(x)
            y_pred = y_pred.view(-1, y_pred.size(-1))
            y = y.view(-1)
            loss = loss_function(y_pred, y)
            predicted = y_pred.argmax(dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
            val_loss += loss.item()
    val_loss = val_loss / len(val_loader)

    # ---- early-stopping bookkeeping --------------------------------------
    if val_loss < best_loss:
        best_loss = val_loss
        # Deep copy: state_dict() references the live parameter tensors, which
        # later optimizer steps would otherwise overwrite.
        best_model = copy.deepcopy(model.state_dict())
        patience = max_patience
    else:
        patience -= 1

    train_losses.append(epoch_loss)
    validation_losses.append(val_loss)
    # Both accuracies use two decimals (validation accuracy was previously
    # truncated to an integer by `%d`).
    print(
        "Epoch [%d/%d], T.Loss: %.4f, V.Loss: %.4f, T.Acc: %.2f%%, V.Acc: %.2f%%"
        % (
            epoch + 1,
            max_epochs,
            epoch_loss,
            val_loss,
            100 * train_correct / train_total,
            100 * correct / total,
        )
    )

    if patience <= 0:
        break