In [1]:
import csv
import datetime

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import ConcatDataset, DataLoader, random_split

In [2]:
from dataset import (
    CVFConfigForTransformerMDataset,
    CVFConfigForTransformerTestDatasetWName,
)

In [3]:
device = "cuda"

In [4]:
def generate_local_mask(seq_len):
    mask = torch.full((seq_len, seq_len), float("-inf"))
    for i in range(1, seq_len):
        mask[i, i - 1] = 0  # Only allow attending to the previous token
    mask[0, 0] = 0  # Optional: allow first token to attend to itself
    return mask

In [5]:

dataset_implicit_n5 = CVFConfigForTransformerMDataset(
    device,
    "implicit_graph_n5",
    "implicit_graph_n5_pt_adj_list.txt",
    "implicit_graph_n5_config_rank_dataset.csv",
    D=5,
    program="dijkstra",
)

# dataset = CVFConfigForTransformerDataset(
#     device,
#     "implicit_graph_n7",
#     "implicit_graph_n7_pt_adj_list.txt",
#     "implicit_graph_n7_config_rank_dataset.csv",
#     D=7,
#     program="dijkstra",
# )

dataset_s_n7 = CVFConfigForTransformerMDataset(
    device,
    "star_graph_n7",
    "star_graph_n7_pt_adj_list.txt",
    "star_graph_n7_config_rank_dataset.csv",
    D=7,
)

dataset_rr_n7 = CVFConfigForTransformerMDataset(
    device,
    "graph_random_regular_graph_n7_d4",
    "graph_random_regular_graph_n7_d4_pt_adj_list.txt",
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    D=7,
)


dataset_plc_n7 = CVFConfigForTransformerMDataset(
    device,
    "graph_powerlaw_cluster_graph_n7",
    "graph_powerlaw_cluster_graph_n7_pt_adj_list.txt",
    "graph_powerlaw_cluster_graph_n7_config_rank_dataset.csv",
    D=7,
)


# train_size = int(0.8 * len(dataset))
# test_size = len(dataset) - train_size
# train_dataset, test_dataset = random_split(dataset, [train_size, test_size])


# loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Total configs: 243.
Total configs: 448.
Total configs: 78,125.
Total configs: 27,000.


In [6]:
dataset_coll = [
    # dataset_implicit_n5,
    dataset_s_n7,
    dataset_rr_n7,
    dataset_plc_n7,
]

batch_size = 1024

train_sizes = [int(0.8 * len(ds)) for ds in dataset_coll]
test_sizes = [len(ds) - trs for ds, trs in zip(dataset_coll, train_sizes)]

train_test_datasets = [
    random_split(ds, [tr_s, ts])
    for ds, tr_s, ts in zip(dataset_coll, train_sizes, test_sizes)
]

train_datasets = [ds[0] for ds in train_test_datasets]
test_datasets = [ds[1] for ds in train_test_datasets]

In [7]:
datasets = ConcatDataset(train_datasets)
print(f"Train Dataset size: {len(datasets):,}")

loader = DataLoader(datasets, batch_size=batch_size)

sequence_length = max(d.sequence_length for d in dataset_coll)
print(f"Max sequence length: {sequence_length:,}")

Train Dataset size: 2,470,382
Max sequence length: 10


In [8]:
N = dataset_coll[0].D

In [9]:
class EmbeddingProjectionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(EmbeddingProjectionModel, self).__init__()
        self.projection = nn.Linear(input_dim, output_dim)  # Project Z to D

    def forward(self, x):
        # Apply the linear transformation to the input tensor
        return self.projection(x)  # Output shape: (B, S, D)

In [10]:
class CausalTransformer(nn.Module):
    def __init__(self, vocab_size, hidden_dim, num_layers):
        super().__init__()
        self.embedding = EmbeddingProjectionModel(vocab_size, hidden_dim)
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=4)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(hidden_dim, 1)

    def forward(self, x, padding_mask):
        # x: [B, T]
        x = self.embedding(x).transpose(0, 1)  # [T, B, D]
        mask = generate_local_mask(sequence_length).to(x.device)
        # print("mask", mask)
        out = self.transformer(
            x,
            memory=torch.zeros(1, x.size(1), x.size(2)).to(x.device),
            tgt_mask=mask,
            tgt_key_padding_mask=padding_mask,
        )
        out = self.output_head(out.transpose(0, 1)).squeeze(-1)  # [B, T]
        return out

    def fit(self, num_epochs, dataloader):
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.parameters(), lr=1e-3)

        for epoch in range(num_epochs):
            self.train()
            total_loss = 0
            count = 0
            for i, batch in enumerate(dataloader):
                x = batch[0][0]
                # padding_mask = (~batch[0][1]).float()
                padding_mask = batch[0][1]  # 0 if not padded, 1 if padded
                y = batch[1]
                out = self(x, padding_mask)
                # print(out.shape, y.shape)
                optimizer.zero_grad()
                # if (epoch + 1) % 5 == 0 and i == 0:
                #     print("Out", out)
                #     print("y", y)
                loss = criterion(out, y)
                total_loss += loss
                count += 1
                loss.backward()
                optimizer.step()

            print(
                "Training set | Epoch %s | MSE Loss: %s"
                % (
                    epoch + 1,
                    round((total_loss / count).item(), 4),
                )
            )

In [11]:
# Parameters
vocab_size = N
hidden_dim = 8
num_layers = 2
num_epochs = 5


# 3. Training Loop
model = CausalTransformer(vocab_size, hidden_dim, num_layers).to(device)
print(model)
print()
print("Total parameters:", f"{sum(p.numel() for p in model.parameters()):,}")
print()
model.fit(num_epochs, loader)

CausalTransformer(
  (embedding): EmbeddingProjectionModel(
    (projection): Linear(in_features=7, out_features=8, bias=True)
  )
  (transformer): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
        )
        (linear1): Linear(in_features=8, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=8, bias=True)
        (norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2

KeyboardInterrupt: 

# Testing

In [None]:
model.eval()

criterion = torch.nn.MSELoss()

f = open(
    f"test_results/test_result_transformer_same_node_seql_{datetime.datetime.now().strftime("%Y_%m_%d_%H_%M")}.csv",
    "w",
    newline="",
)
csv_writer = csv.writer(f)
csv_writer.writerow(["Actual", "Predicted", "Correct"])


dataset_s_n7_test = CVFConfigForTransformerTestDatasetWName(
    device,
    "star_graph_n7",
    "star_graph_n7_config_rank_dataset.csv",
    D=7,
)


dataset_rr_n7_test = CVFConfigForTransformerTestDatasetWName(
    device,
    "graph_random_regular_graph_n7_d4",
    "graph_random_regular_graph_n7_d4_config_rank_dataset.csv",
    D=7,
)

dataset_plc_n7_test = CVFConfigForTransformerTestDatasetWName(
    device,
    "graph_powerlaw_cluster_graph_n7",
    "graph_powerlaw_cluster_graph_n7_config_rank_dataset.csv",
    D=7,
)


dataset_implicit_n5 = CVFConfigForTransformerTestDatasetWName(
    device,
    "implicit_graph_n5",
    "implicit_graph_n5_config_rank_dataset.csv",
    D=5,
    program="dijkstra",
)

test_datasets = ConcatDataset([dataset_s_n7_test])
sp_emb_dim = test_datasets.datasets[0].sp_emb_dim

with torch.no_grad():
    test_dataloader = DataLoader(test_datasets, batch_size=1024)

    total_loss = 0
    total_matched = 0
    count = 0
    total_seq_count = 0
    for batch in test_dataloader:
        x = batch[0][:, 0 : sp_emb_dim + 1, :]
        padd = torch.full((sequence_length - (sp_emb_dim + 1), vocab_size), -1).to(
            device
        )
        padded_batches = [torch.cat([b, padd]) for b in x]
        x = torch.stack(padded_batches)
        padding_mask = torch.full(
            (x.shape[0], sequence_length), 1, dtype=torch.bool
        ).to(device)
        padding_mask[:, 0 : sp_emb_dim + 1] = False
        padding_mask = padding_mask.float()
        y = batch[1]
        out = model(x, padding_mask)
        out = out[:, sp_emb_dim].unsqueeze(-1)
        matched = torch.round(out) == y
        csv_writer.writerows(
            (j.item(), k.item(), z.item())
            for (j, k, z) in zip(
                y.detach().cpu().numpy(), out.detach().cpu().numpy(), matched
            )
        )
        loss = criterion(out, y)
        total_loss += loss
        out = torch.round(out)
        matched = matched.sum().item()
        total_seq_count += out.numel()
        total_matched += matched
        count += 1

    print(
        f"Test set | MSE loss: {round((total_loss / count).item(), 4)} | Total matched: {total_matched:,} out of {total_seq_count:,} (Accuracy: {round(total_matched / total_seq_count * 100, 2):,}%)",
    )

f.close()

Total configs: 448.
Total configs: 78,125.
Total configs: 27,000.
Total configs: 243.


Test set | MSE loss: 0.0304 | Total matched: 426 out of 448 (Accuracy: 95.09%)
