In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
from dataset import CVFConfigForTransformerDataset, CVFConfigForTransformerTestDataset

In [3]:
# Prefer the GPU when one is present, but fall back to the CPU so that the
# notebook still runs top to bottom on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
def generate_local_mask(seq_len, device=None):
    """Build a float attention mask where each position sees only its predecessor.

    Entries are ``0`` where attention is allowed and ``-inf`` everywhere else.
    Row ``i`` (for ``i >= 1``) allows column ``i - 1`` only; row ``0`` allows
    column ``0`` so that the first position is not left with every entry at
    ``-inf``.

    Args:
        seq_len: Side length of the square mask. ``0`` yields an empty
            ``(0, 0)`` mask instead of raising.
        device: Optional device on which to allocate the mask. Defaults to
            PyTorch's default device, which matches the previous behaviour.

    Returns:
        A ``(seq_len, seq_len)`` tensor of the default float dtype.
    """
    mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
    if seq_len == 0:
        # Nothing to unmask; indexing [0, 0] below would raise on an empty tensor.
        return mask

    # Unmask the first sub-diagonal in one shot: (1, 0), (2, 1), ..., (n-1, n-2).
    rows = torch.arange(1, seq_len, device=mask.device)
    mask[rows, rows - 1] = 0
    mask[0, 0] = 0  # the first token has no predecessor, so it attends to itself
    return mask

In [5]:
# --- Data configuration ---------------------------------------------------
batch_size = 256
TRAIN_FRACTION = 1.0  # 1.0 => every sample goes to training; the held-out split is empty
SEED = 0  # fixes both the split and the shuffle order so reruns are reproducible

dataset = CVFConfigForTransformerDataset(
    device,
    "implicit_graph_n5",
    "implicit_graph_n5_pt_adj_list.txt",
    "implicit_graph_n5_config_rank_dataset.csv",
    D=5,
    program="dijkstra",
)

# One CPU generator drives both the random split and the loader's shuffling.
data_rng = torch.Generator().manual_seed(SEED)

train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=data_rng
)

loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, generator=data_rng
)

Total configs: 243.


In [6]:
class EmbeddingProjectionModel(nn.Module):
    """Linear stand-in for an embedding layer.

    Applies a single affine map to the trailing dimension of its input,
    taking it from ``input_dim`` features to ``output_dim`` features. All
    leading dimensions pass through unchanged.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Single learnable projection; acts on the last axis only.
        self.projection = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project ``x`` of shape ``(..., input_dim)`` to ``(..., output_dim)``."""
        projected = self.projection(x)
        return projected

In [None]:
class CausalTransformer(nn.Module):
    """Transformer decoder that regresses one scalar per sequence position.

    The input features are linearly projected to ``hidden_dim``, run through a
    stack of ``nn.TransformerDecoderLayer`` blocks under the mask produced by
    ``generate_local_mask`` (each position attends only to its predecessor,
    and the first position to itself), and mapped to a single value per
    position by a linear head.

    Args:
        vocab_size: Size of the trailing feature dimension of the input.
        hidden_dim: Model width (``d_model``); must be divisible by ``nhead``.
        num_layers: Number of stacked decoder layers.
        nhead: Number of attention heads per layer.
    """

    def __init__(self, vocab_size, hidden_dim, num_layers, nhead=4):
        super().__init__()
        self.embedding = EmbeddingProjectionModel(vocab_size, hidden_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=nhead)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(hidden_dim, 1)

    def forward(self, x, padding_mask):
        """Predict one value per position.

        Args:
            x: Input of shape ``[B, T, vocab_size]`` (the embedding is a linear
                layer over the last axis, so this must be a float tensor).
            padding_mask: Passed through as ``tgt_key_padding_mask``; expected
                to be ``[B, T]``.

        Returns:
            Tensor of shape ``[B, T]``.
        """
        x = self.embedding(x).transpose(0, 1)  # [T, B, hidden_dim]

        # Size the mask from the actual input rather than from a notebook
        # global, so the model works for any sequence length it is given.
        mask = generate_local_mask(x.size(0)).to(device=x.device, dtype=x.dtype)

        # The decoder layers require a memory input for cross-attention; with
        # no encoder here, a single all-zero memory token per batch item is fed.
        memory = torch.zeros(1, x.size(1), x.size(2), device=x.device, dtype=x.dtype)

        out = self.transformer(
            x,
            memory=memory,
            tgt_mask=mask,
            tgt_key_padding_mask=padding_mask,
        )
        out = self.output_head(out.transpose(0, 1)).squeeze(-1)  # [B, T]
        return out

    def fit(self, num_epochs, dataloader, lr=1e-3):
        """Train with Adam on an MSE objective, printing the mean loss per epoch.

        Args:
            num_epochs: Number of passes over ``dataloader``.
            dataloader: Iterable of batches laid out as
                ``((inputs, mask), targets)``.
            lr: Adam learning rate.
        """
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.parameters(), lr=lr)

        for epoch in range(num_epochs):
            self.train()
            total_loss = 0.0
            count = 0
            for i, batch in enumerate(dataloader):
                x = batch[0][0]
                # NOTE(review): batch[0][1] is presumably a boolean validity
                # mask (it is inverted with ~). Casting to float means PyTorch
                # adds these values to the attention scores instead of
                # excluding the padded keys - confirm that this is intended.
                padding_mask = (~batch[0][1]).float()
                y = batch[1]

                optimizer.zero_grad()
                out = self(x, padding_mask)
                if (epoch + 1) % 5 == 0 and i == 0:
                    print("Out", out)
                    print("y", y)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()

                # Accumulate a plain float so the autograd graph of every
                # batch is not kept alive for the whole epoch.
                total_loss += loss.item()
                count += 1

            # An empty dataloader reports nan rather than dividing by zero.
            mean_loss = total_loss / count if count else float("nan")
            print(
                "Training set | Epoch %s | MSE Loss: %s"
                % (
                    epoch + 1,
                    round(mean_loss, 4),
                )
            )

In [8]:
# Model hyperparameters; input width and sequence length come from the dataset.
vocab_size, seq_len = dataset.D, dataset.sequence_length
hidden_dim, num_layers = 32, 2
num_epochs = 20
# NOTE(review): rebinding batch_size here does not affect `loader`, which was
# already built earlier in the notebook with the previous batch_size value.
batch_size = 128


# Build the model on the target device and train it on the existing loader.
model = CausalTransformer(
    vocab_size=vocab_size,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
).to(device)
model.fit(num_epochs=num_epochs, dataloader=loader)

Training set | Epoch 1 | MSE Loss: 18.0051
Training set | Epoch 2 | MSE Loss: 9.7928
Training set | Epoch 3 | MSE Loss: 2.3927
Training set | Epoch 4 | MSE Loss: 0.7498
Out tensor([[11.3711, 12.0577, 12.0545,  ..., -0.9653, -0.9117, -1.1657],
        [13.1296, 12.8617, 12.7973,  ..., -1.1473, -0.6973, -1.1529],
        [12.0110, 12.0203, 10.4924,  ..., -0.9400, -1.0081, -1.1544],
        ...,
        [12.9280, 13.0854, 11.9656,  ..., -1.1800, -1.0282, -0.4209],
        [12.2427, 11.3849, 13.1736,  ..., -0.9184, -1.0417, -0.9462],
        [11.6279, 11.7566, 10.4550,  ..., -1.0684, -1.1555, -1.1640]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)
y tensor([[12., 12., 11.,  ..., -1., -1., -1.],
        [15., 13., 13.,  ..., -1., -1., -1.],
        [12., 11., 10.,  ..., -1., -1., -1.],
        ...,
        [15., 13., 13.,  ..., -1., -1., -1.],
        [12., 12., 12.,  ..., -1., -1., -1.],
        [12., 12., 10.,  ..., -1., -1., -1.]], device='cuda:0')
Training set | Epoch 5 | MSE Los