In [1]:
import csv

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
from dataset import CVFConfigForTransformerDataset, CVFConfigForTransformerTestDataset

In [3]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without a usable CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
def generate_local_mask(seq_len):
    """Build the additive attention mask used by the transformer.

    Returns a ``(seq_len, seq_len)`` float tensor that is ``-inf`` everywhere
    except for zeros at ``[i, i - 1]`` for every ``i >= 1`` (each position may
    only attend to the previous token) and at ``[0, 0]`` (the first token may
    attend to itself).
    """
    mask = torch.full((seq_len, seq_len), float("-inf"))
    # The first sub-diagonal is exactly the set of (i, i - 1) entries;
    # `diagonal` returns a view, so the in-place fill writes into `mask`.
    mask.diagonal(offset=-1).fill_(0)
    mask[0, 0] = 0
    return mask

In [5]:
batch_size = 256
SPLIT_SEED = 42  # fixes the train/test partition so reruns see the same split

# Alternative dataset configurations, kept for reference.
# dataset = CVFConfigForTransformerDataset(
#     device,
#     "implicit_graph_n5",
#     "implicit_graph_n5_pt_adj_list.txt",
#     "implicit_graph_n5_config_rank_dataset.csv",
#     D=5,
#     program="dijkstra",
# )

# dataset = CVFConfigForTransformerDataset(
#     device,
#     "implicit_graph_n7",
#     "implicit_graph_n7_pt_adj_list.txt",
#     "implicit_graph_n7_config_rank_dataset.csv",
#     D=7,
#     program="dijkstra",
# )

dataset = CVFConfigForTransformerDataset(
    device,
    "graph_random_regular_graph_n7_d4",
    "graph_random_regular_graph_n7_d4_pt_adj_list.txt",
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    D=7,
)

# 80/20 train/test split. Without an explicit generator `random_split` draws
# from the global RNG, so the partition would differ on every kernel restart.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training portion is wrapped in a loader here.
loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Total configs: 78,125.


In [6]:
# Sanity check: number of samples in the full `dataset` (train and test portions combined).
print(f"Dataset size: {len(dataset):,}")

Dataset size: 2,604,798


In [7]:
class EmbeddingProjectionModel(nn.Module):
    """Learned linear projection applied to the last axis of the input.

    Maps feature vectors of width ``input_dim`` (Z) to width ``output_dim`` (D).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Single affine layer: Z -> D.
        self.projection = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Project ``x``; an input of shape (B, S, Z) yields (B, S, D)."""
        projected = self.projection(x)
        return projected

In [8]:
class CausalTransformer(nn.Module):
    """Decoder-only transformer that emits one scalar per sequence position.

    Each input vector is projected to ``hidden_dim`` by
    ``EmbeddingProjectionModel``, passed through ``num_layers``
    ``nn.TransformerDecoderLayer`` blocks (4 heads) under the mask produced by
    ``generate_local_mask``, and reduced to a scalar by a linear head.
    """

    def __init__(self, vocab_size, hidden_dim, num_layers):
        """
        Args:
            vocab_size: width of each input feature vector.
            hidden_dim: model width (``d_model``); must be divisible by 4.
            num_layers: number of stacked decoder layers.
        """
        super().__init__()
        self.embedding = EmbeddingProjectionModel(vocab_size, hidden_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=4)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(hidden_dim, 1)

    def forward(self, x, padding_mask):
        """Run the model.

        Args:
            x: input of shape [B, T, vocab_size].
            padding_mask: [B, T] tensor forwarded unchanged as
                ``tgt_key_padding_mask``.

        Returns:
            Tensor of shape [B, T] with one prediction per position.
        """
        x = self.embedding(x).transpose(0, 1)  # [T, B, D]
        # Size the mask from the input itself instead of reading the
        # notebook-level `dataset`, so the model works for any sequence
        # length and does not depend on hidden global state.
        mask = generate_local_mask(x.size(0)).to(x.device)
        # The decoder API requires a memory tensor; a single all-zero slot is
        # supplied, allocated directly on the input's device.
        memory = torch.zeros(1, x.size(1), x.size(2), device=x.device)
        out = self.transformer(
            x,
            memory=memory,
            tgt_mask=mask,
            tgt_key_padding_mask=padding_mask,
        )
        out = self.output_head(out.transpose(0, 1)).squeeze(-1)  # [B, T]
        return out

    def fit(self, num_epochs, dataloader):
        """Train with Adam (lr=1e-3) on MSE loss and print the mean loss per epoch.

        Args:
            num_epochs: number of passes over ``dataloader``.
            dataloader: iterable of ``((x, mask), y)`` batches, where ``mask``
                is a boolean tensor that is inverted before being passed to
                ``forward``.
        """
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.parameters(), lr=1e-3)

        for epoch in range(num_epochs):
            self.train()
            total_loss = 0
            count = 0
            for batch in dataloader:
                x = batch[0][0]
                # NOTE(review): PyTorch adds a *float* key-padding mask to the
                # attention scores instead of treating it as a boolean
                # "ignore" flag, so this 0/1 mask does not hide any position.
                # Confirm that is intended before changing it.
                padding_mask = (~batch[0][1]).float()
                y = batch[1]

                optimizer.zero_grad()
                out = self(x, padding_mask)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()

                # Detach before accumulating so the running sum does not keep
                # every iteration's autograd history alive.
                total_loss += loss.detach()
                count += 1

            # An empty dataloader reports NaN instead of dividing by zero.
            mean_loss = (total_loss / count).item() if count else float("nan")
            print(
                "Training set | Epoch %s | MSE Loss: %s"
                % (
                    epoch + 1,
                    round(mean_loss, 4),
                )
            )

In [11]:
# Parameters
vocab_size = dataset.D
seq_len = dataset.sequence_length
# NOTE(review): this value is not used by the training below. `loader` was
# already built with the earlier batch_size, so rebuild the loader if 10240
# is meant to take effect.
batch_size = 10240
hidden_dim = 16
num_layers = 2
num_epochs = 3

# Seed the global RNG so weight initialisation and DataLoader shuffling are
# repeatable across kernel restarts.
torch.manual_seed(0)

# Build the model, report its size, then train.
model = CausalTransformer(vocab_size, hidden_dim, num_layers).to(device)
print(model)
print()
print("Total parameters:", f"{sum(p.numel() for p in model.parameters()):,}")
print()
model.fit(num_epochs, loader)

CausalTransformer(
  (embedding): EmbeddingProjectionModel(
    (projection): Linear(in_features=7, out_features=16, bias=True)
  )
  (transformer): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
        )
        (linear1): Linear(in_features=16, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=16, bias=True)
        (norm1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
       

KeyboardInterrupt: 

# Testing

In [10]:
# Evaluate the trained model on the samples of CVFConfigForTransformerTestDataset
# and write one CSV row per sample: actual value, prediction, and whether the
# rounded prediction equals the actual value.
model.eval()

criterion = torch.nn.MSELoss()

test_dataset = CVFConfigForTransformerTestDataset(
    device,
    "graph_random_regular_graph_n7_d4",
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    D=7,
)
test_dataloader = DataLoader(test_dataset, batch_size=1024)

# Filler rows (all -1) appended after the first token of every sample. They
# are identical for every batch, so build them once instead of per iteration.
padd = torch.full((dataset.sequence_length - 1, dataset.D), -1).to(device)

total_loss = 0.0  # sum of squared errors over all predictions
total_matched = 0
total_seq_count = 0

# `with` guarantees the CSV is flushed and closed even if evaluation raises.
with open("test_results/test_result_trans2.csv", "w", newline="") as f:
    csv_writer = csv.writer(f)
    csv_writer.writerow(["Actual", "Predicted", "Correct"])

    with torch.no_grad():
        for batch in test_dataloader:
            first_tokens = batch[0][:, 0, :]  # first position of every sample
            n_samples = first_tokens.shape[0]
            # [B, 1, D] real token followed by [B, S - 1, D] filler -> [B, S, D]
            x = torch.cat(
                [
                    first_tokens.unsqueeze(1),
                    padd.unsqueeze(0).expand(n_samples, -1, -1),
                ],
                dim=1,
            )

            # NOTE(review): after the inversion this float mask is 1.0 at
            # position 0 and 0.0 elsewhere. PyTorch adds a float key-padding
            # mask to the attention scores rather than using it as a boolean
            # "ignore" flag, so it does not hide the filler positions.
            # Confirm this matches the mask convention used during training.
            padding_mask = torch.full(
                (n_samples, dataset.sequence_length), 1, dtype=torch.bool
            ).to(device)
            padding_mask[:, 0] = False
            padding_mask = (~padding_mask).float()
            y = batch[1]

            out = model(x, padding_mask)
            out = out[:, 0].unsqueeze(-1)  # only the first position is scored
            matched = torch.round(out) == y

            # One bulk device-to-host transfer per column instead of a
            # per-element `.item()` call.
            csv_writer.writerows(
                zip(
                    y.flatten().tolist(),
                    out.flatten().tolist(),
                    matched.flatten().tolist(),
                )
            )

            # Weight each batch's mean loss by its size so the reported MSE is
            # the true per-sample mean even when the last batch is smaller.
            loss = criterion(out, y)
            total_loss += loss.item() * out.numel()
            total_matched += matched.sum().item()
            total_seq_count += out.numel()

# An empty test set reports NaN instead of dividing by zero.
mse = total_loss / total_seq_count if total_seq_count else float("nan")
accuracy = total_matched / total_seq_count * 100 if total_seq_count else float("nan")
print(
    f"Test set | MSE loss: {round(mse, 4)} | Total matched: {total_matched:,} out of {total_seq_count:,} (Accuracy: {round(accuracy, 2):,}%)",
)

Total configs: 78,125.


Test set | MSE loss: 0.0949 | Total matched: 69,282 out of 78,125 (Accuracy: 88.68%)
