## Environment SetUP


In [5]:
%pip install torch_cluster-1.6.1-cp310-cp310-linux_x86_64.whl
%pip install torch_geometric-2.3.1-py3-none-any.whl
%pip install torch_scatter-2.1.1-cp310-cp310-linux_x86_64.whl
%pip install torch_sparse-0.6.17-cp310-cp310-linux_x86_64.whl
%pip install torch_spline_conv-1.2.2-cp310-cp310-linux_x86_64.whl

Processing ./torch_cluster-1.6.1-cp310-cp310-linux_x86_64.whl
Installing collected packages: torch-cluster
  Attempting uninstall: torch-cluster
    Found existing installation: torch-cluster 1.6.3+pt20cu118
    Uninstalling torch-cluster-1.6.3+pt20cu118:
      Successfully uninstalled torch-cluster-1.6.3+pt20cu118
Successfully installed torch-cluster-1.6.1
Note: you may need to restart the kernel to use updated packages.
Processing ./torch_geometric-2.3.1-py3-none-any.whl
Installing collected packages: torch-geometric
  Attempting uninstall: torch-geometric
    Found existing installation: torch_geometric 2.4.0
    Uninstalling torch_geometric-2.4.0:
      Successfully uninstalled torch_geometric-2.4.0
Successfully installed torch-geometric-2.3.1
Note: you may need to restart the kernel to use updated packages.
Processing ./torch_scatter-2.1.1-cp310-cp310-linux_x86_64.whl
Installing collected packages: torch-scatter
  Attempting uninstall: torch-scatter
    Found existing installation

In [1]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm

import sklearn, sklearn.model_selection
import torch
from torch import nn
from torch import Tensor
from torch_geometric.nn import GCNConv, SAGEConv, aggr
from torch_geometric.datasets import Planetoid
from torch.utils.data import DataLoader, Dataset
from timm.scheduler import CosineLRScheduler
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"



## Load Data


In [2]:
def load_df(directory):
    splits = ["train", "valid", "test"]
    dfs = dict()

    for split in splits:
        path = os.path.join(directory, split)
        files = os.listdir(path)
        list_df = []

        for file in files:
            d = dict(np.load(os.path.join(path, file)))
            d["file"] = file
            list_df.append(d)
        dfs[split] = pd.DataFrame.from_dict(list_df)
    return dfs


tile_xla = load_df("./data/tpugraphs/npz_all/npz/tile/xla/")

In [23]:
tile_xla["train"]

Unnamed: 0,node_feat,node_opcode,edge_index,config_feat,config_runtime,config_runtime_normalizers,file
0,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 11, 63, 11, 63, 41, 63, 41, 26, 63, 63, 41]","[[1, 0], [3, 2], [5, 1], [5, 4], [7, 3], [7, 6...","[[32.0, 32.0, 0.0, 0.0, 0.0, 0.0, 64.0, 1024.0...","[263238, 2029255, 1192602, 1027600, 1962135, 5...","[263238, 263238, 263238, 263238, 263238, 26323...",alexnet_train_batch_32_-1bae27a41d70f4dc.npz
1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[24, 13, 48, 87, 63, 13, 25, 52, 25, 63, 24, 1...","[[1, 0], [3, 1], [3, 2], [5, 4], [6, 5], [7, 3...","[[6.0, 12.0, 2.0, 2.0, 0.0, 0.0, 22.0, 288.0, ...","[155012, 3950817, 2048285, 1528077, 682642, 77...","[155012, 155012, 155012, 155012, 155012, 15501...",alexnet_train_batch_32_-21d9f3b8c41eb3e3.npz
2,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 11, 63, 11, 63, 63, 13, 63, 41, 63, 41, 2...","[[1, 0], [3, 2], [6, 5], [8, 1], [8, 7], [10, ...","[[3.0, 12.0, 4.0, 3.0, 0.0, 0.0, 22.0, 432.0, ...","[113020, 667977, 966760, 5897798, 1554171, 308...","[113020, 113020, 113020, 113020, 113020, 11302...",alexnet_train_batch_32_-282ddd3271de7d28.npz
3,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 63, 2, 24, 13, 48, 87, 63, 13, 25, 52, 25...","[[2, 0], [2, 1], [4, 3], [6, 4], [6, 5], [8, 7...","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0,...","[13580, 35675, 63934, 62597, 40362, 27707, 319...","[13580, 13580, 13580, 13580, 13580, 13580, 135...",alexnet_train_batch_32_-3545610a073feea6.npz
4,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 63, 2, 63, 11, 63, 11, 24, 13, 48, 87, 63...","[[2, 0], [2, 1], [4, 3], [6, 5], [8, 7], [10, ...","[[3.0, 3.0, 16.0, 3.0, 0.0, 0.0, 25.0, 432.0, ...","[216908, 5999505, 14326342, 861357, 1297804, 2...","[216908, 216908, 216908, 216908, 216908, 21690...",alexnet_train_batch_32_-444744203bcb5069.npz
...,...,...,...,...,...,...,...
5704,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 63, 13, 95, 63, 13, 59, 63, 13, 59, 63, 1...","[[2, 1], [3, 0], [3, 2], [5, 4], [6, 3], [6, 5...","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 11.0...","[284464, 270720, 892307, 324908, 294581, 59365...","[284464, 284464, 284464, 284464, 284464, 28446...",xception_imagenet_754825b353c7974f.npz
5705,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 13, 63, 24, 13, 59, 63, 13, 95, 63, 13, 6...","[[1, 0], [4, 3], [5, 2], [5, 4], [7, 6], [8, 5...","[[8.0, 37.0, 2.0, 1.0, 0.0, 0.0, 48.0, 592.0, ...","[638608, 11962432, 6114872, 7372895, 5149337, ...","[638608, 638608, 638608, 638608, 638608, 63860...",xception_imagenet_7560f673e5820c82.npz
5706,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 13, 63, 63, 13, 95, 63, 13, 59, 63, 13, 5...","[[1, 0], [4, 3], [5, 2], [5, 4], [7, 6], [8, 5...","[[7.0, 7.0, 2.0, 4.0, 0.0, 0.0, 20.0, 392.0, 1...","[244198, 7123302, 4481747, 1745201, 1627652, 8...","[244198, 244198, 244198, 244198, 244198, 24419...",xception_imagenet_7720f6dabe293cfe.npz
5707,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...","[63, 13, 63, 24, 13, 59, 63, 13, 95, 63, 13, 6...","[[1, 0], [4, 3], [5, 2], [5, 4], [7, 6], [8, 5...","[[1.0, 28.0, 2.0, 2.0, 0.0, 0.0, 33.0, 112.0, ...","[95785, 807268, 437115, 220154, 264825, 179818...","[95785, 95785, 95785, 95785, 95785, 95785, 957...",xception_imagenet_7eaa46ca4812dfb2.npz


In [12]:
def write_edge_list_to_file(edges, filename):
    with open(filename, "w") as file:
        for edge in edges:
            file.write(f"{edge[0]} {edge[1]}\n")


write_edge_list_to_file(edge_list, "edge_list.txt")

In [22]:
import os

# Create the folder if it doesn't exist
if not os.path.exists("all_edge_list"):
    os.makedirs("all_edge_list")

for i, edge in enumerate(edge_list):
    write_edge_list_to_file(edge, f"./all_edge_list/edge_list_{i}.txt")

In [14]:
# tile_xla['valid'].loc[0, 'config_runtime']/(tile_xla['valid'].loc[0, 'config_runtime_normalizers'] + 1e-5)

In [3]:
NODE_OP_CODES = 120  # Number of node operation codes
# opcode padding value 121
NODE_FEATS = 140  # Number of node features
CONFIG_FEATS = 24  # Number of configuration features
NODE_CONFIG_FEATS = 18  # Number of combined node and configuration featueres
STRECHING_CONSTANT = 1

In [24]:
!pip install networkx



In [4]:
import networkx as nx


def topological_sort(graph_edges):
    G = nx.DiGraph(graph_edges)
    return np.array(list(nx.topological_sort(G)))


topological_sort([tuple(edge) for edge in tile_xla["train"]["edge_index"][0]])

array([11,  8,  9, 10,  5,  7,  1,  4,  3,  6,  0,  2])

In [5]:
class TileDataset(Dataset):
    def __init__(self, df, n=STRECHING_CONSTANT, reverse=False):
        self.df = df
        self.streching_constant = n
        self.reverse = reverse

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        config_feat = torch.tensor(row["config_feat"].astype(np.float32))
        node_feat = torch.tensor(row["node_feat"].astype(np.float32))
        node_opcode = torch.tensor(row["node_opcode"].astype(np.int64))
        if self.reverse:
            node_sequence = torch.tensor(
                np.flip(
                    topological_sort([tuple(edge) for edge in row["edge_index"]])
                ).astype(np.int64)
            )
        else:
            node_sequence = torch.tensor(
                topological_sort([tuple(edge) for edge in row["edge_index"]]).astype(
                    np.int64
                )
            )
        # edge_index = torch.tensor(np.swapaxes(row["edge_index"], 0, 1).astype(np.int64))
        target = (
            row["config_runtime"] / (row["config_runtime_normalizers"] + 1e-5)
        ).astype(np.float32)
        # minmax scale the target, we only care about order
        # target = (
        #     (target - np.min(target))
        #     / (np.max(target) - np.min(target) + 1e-5)
        #     * self.streching_constant
        # )
        target = torch.tensor(target)
        return config_feat, node_feat, node_opcode, node_sequence, target

In [6]:
def custom_collate(batch):
    config_feats, node_feats, node_opcodes, node_sequences, targets = zip(*batch)

    bs = len(config_feats)  # batch_size
    max_len_config = max([config_feat.shape[0] for config_feat in config_feats])
    max_len_node = max([len(node_sequence) for node_sequence in node_sequences])

    padded_config = torch.zeros(bs, max_len_config, config_feats[0].shape[-1])
    padded_target = torch.zeros(bs, max_len_config)
    padded_feat = torch.zeros(bs, max_len_node, node_feats[0].shape[-1])
    padded_opcode = torch.ones(bs, max_len_node) * 121  # opcode padding value
    padded_sequence = torch.ones(bs, max_len_node) * -1  # not a valid node index
    config_mask = torch.zeros(bs, max_len_config)

    for idx, (config_feat, node_feat, node_opcode, node_sequence, target) in enumerate(
        batch
    ):
        padded_config[idx, : config_feat.shape[0]] = config_feat
        padded_target[idx, : target.shape[0]] = target
        config_mask[idx, : target.shape[0]] = 1
        padded_feat[idx, : node_feat.shape[0]] = node_feat
        padded_opcode[idx, : node_opcode.shape[0]] = node_opcode
        padded_sequence[idx, : node_sequence.shape[0]] = node_sequence

    return (
        padded_sequence.to(torch.int64),
        padded_opcode.to(torch.int64),
        padded_feat,
        padded_config,
        config_mask,
        padded_target,
    )

In [64]:
for (
    padded_sequence,
    padded_opcode,
    padded_feat,
    padded_config,
    config_mask,
    padded_target,
) in train_loader:
    gather_indices = torch.where(
        padded_sequence == -1, torch.zeros_like(padded_sequence), padded_sequence
    )
    print(padded_target)
    # print(gather_indices.unsqueeze(-1).expand(-1, -1, padded_feat.shape[-1]))
    break

tensor([[1.3587e-04, 6.5726e-01, 3.4595e-01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [3.0426e-04, 2.3086e-01, 1.1530e-01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.1099e-05, 3.6438e-02, 5.6050e-02,  ..., 2.3217e-03, 5.2476e-03,
         4.8365e-04],
        [4.3841e-03, 1.1808e-01, 2.6349e-01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


In [None]:
class SequenceModel(torch.nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.op_embedding_dim = embedding_dim
        self.embedding = torch.nn.Embedding(
            NODE_OP_CODES, self.op_embedding_dim, padding_idx=121
        )

    def forward(self, node_sequence, node_opcode, node_feat, configs, config_mask):
        node_features = torch.concat(
            [node_feat, self.embedding(node_opcode)], dim=-1
        )  # [bs, # of nodes, feat_dim]
        # node_features = self.pre_net(node_features)
        gather_indices = torch.where(
            node_sequence == -1, torch.zeros_like(node_sequence), node_sequence
        )
        sequence = torch.gather(
            node_features,
            1,
            gather_indices.unsqueeze(-1).expand(-1, -1, node_features.shape[-1]),
        )
        mask = (node_sequence != -1).float()

In [285]:
train_data = TileDataset(tile_xla["train"])

In [286]:
train_loader = DataLoader(
    train_data, batch_size=8, shuffle=True, collate_fn=custom_collate
)
# train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=0)

In [79]:
class SimpleModel(torch.nn.Module):
    def __init__(
        self,
        hidden_channels,
        embedding_dim,
        graph_in,
        graph_out,
        hidden_dim,
        activation_fn,
        dropout=0.0,
        n=STRECHING_CONSTANT,
        aggregator=aggr.MultiAggregation,
    ):
        assert len(hidden_channels) > 0
        super().__init__()
        self.op_embedding_dim = embedding_dim  # I choose 4-dimensional embedding
        self.node_dim = graph_in
        self.hidden_dim = hidden_dim
        self.activation_fn = activation_fn
        self.dropout = dropout
        self.streching_constant = n
        self.agg = aggregator
        self.gnn_conv = SAGEConv
        self.embedding = torch.nn.Embedding(
            NODE_OP_CODES,
            self.op_embedding_dim,
        )
        self.pre_net = torch.nn.Sequential(
            nn.Linear(self.op_embedding_dim + NODE_FEATS, self.hidden_dim),
            nn.Dropout(p=self.dropout),
            self.activation_fn(),
            nn.Linear(self.hidden_dim, self.node_dim),
        )

        self.convs = torch.nn.ModuleList()
        aggr1 = self.agg(
            aggrs=["mean", "std"],
            mode="attn",
            mode_kwargs=dict(
                in_channels=self.node_dim, out_channels=hidden_channels[0], num_heads=4
            ),
        )
        self.convs.append(self.gnn_conv(self.node_dim, hidden_channels[0]))
        for i in range(len(hidden_channels) - 1):
            aggr2 = self.agg(
                aggrs=["mean", "std"],
                mode="attn",
                mode_kwargs=dict(
                    in_channels=hidden_channels[i],
                    out_channels=hidden_channels[i + 1],
                    num_heads=4,
                ),
            )
            self.convs.append(self.gnn_conv(hidden_channels[i], hidden_channels[i + 1]))
        aggr3 = self.agg(
            aggrs=["mean", "std"],
            mode="attn",
            mode_kwargs=dict(
                in_channels=hidden_channels[-1], out_channels=graph_out, num_heads=4
            ),
        )
        self.convs.append(self.gnn_conv(hidden_channels[-1], graph_out))

        self.post_net = torch.nn.Sequential(
            nn.Linear(graph_out * 2 + 24, self.hidden_dim),
            nn.Dropout(p=self.dropout),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Dropout(p=self.dropout),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 1),
        )

    #         self.dropout = nn.Dropout(p=dropout)

    def forward(
        self, x_cfg: Tensor, x_feat: Tensor, x_op: Tensor, edge_index: Tensor
    ) -> Tensor:
        # get graph features
        x = torch.concat([x_feat, self.embedding(x_op)], dim=1)
        x = self.pre_net(x)
        # pass though conv layers
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        # get 1d graph embedding using average pooling
        x_mean = x.mean(0)
        x_max = x.max(0).values

        # put graph data into config data
        x = torch.concat(
            [x_cfg, x_max.repeat((len(x_cfg), 1)), x_mean.repeat((len(x_cfg), 1))],
            axis=1,
        )
        # put into dense nn
        x = torch.flatten(self.post_net(x))
        x = (
            (x - torch.min(x))
            / (torch.max(x) - torch.min(x) + 1e-5)
            * self.streching_constant
        )
        return x

In [12]:
def score_tile_mean(predictions, df):
    score = 0
    for i in range(len(df)):
        predbest = np.mean(df.iloc[i]["config_runtime"][predictions[i]])
        best = np.mean(np.sort(df.iloc[i]["config_runtime"])[:50])
        score += 2 - predbest / best
    score /= len(df)
    return score


def score_tile_max(predictions, df):
    score = 0
    for i in range(len(df)):
        predbest = np.min(df.iloc[i]["config_runtime"][predictions[i][:5]])
        best = np.min(df.iloc[i]["config_runtime"])
        #         print(best,predbest)
        score += 2 - predbest / best
    score /= len(df)
    return score

In [13]:
class CustomWeightedMSELoss(nn.Module):
    def __init__(self, epsilon=1e-6):
        super(CustomWeightedMSELoss, self).__init__()
        self.epsilon = epsilon
        self.mse_loss = nn.MSELoss(
            reduction="none"
        )  # We will handle the reduction ourselves

    def forward(self, y_pred, y_true):
        # Calculate the weights based on the true values
        weights = 1 / (y_true + self.epsilon)  # Add epsilon to avoid division by zero
        # Calculate the per-element squared error
        per_element_loss = self.mse_loss(y_pred, y_true)
        # Apply the weights and calculate the mean loss
        weighted_loss = weights * per_element_loss
        loss = torch.mean(weighted_loss)
        return loss

In [67]:
class CustomMarginRankingLoss(nn.Module):
    def __init__(self, margin=0.0, reduction="none"):
        super(CustomMarginRankingLoss, self).__init__()
        self.loss = nn.MarginRankingLoss(margin=margin, reduction=reduction)

    def forward(self, y_pred, y_true):
        true = y_true.view(y_true.shape[0], -1).half()
        pred = y_pred.view(y_pred.shape[0], -1).half()
        true_diffs = true - true.T
        pred_diffs = pred - pred.T
        s_ij = torch.sign(true_diffs).to(device)
        s_ij.fill_diagonal_(0)
        pred_diffs.fill_diagonal_(0)
        cost = self.loss(pred_diffs, torch.zeros_like(pred_diffs), s_ij)
        cost = cost[s_ij != 0].mean()
        return cost

In [23]:
model = SimpleModel(
    hidden_channels=[32, 48, 64, 84],
    embedding_dim=32,
    graph_in=64,
    graph_out=64,
    hidden_dim=128,
    activation_fn=nn.ReLU,
    dropout=0.2,
).to(device)
train_dataset = TileDataset(tile_xla["train"])
val_dataset = TileDataset(tile_xla["valid"])
criterion = CustomWeightedMSELoss()
epochs = 20
steps = len(train_dataset) * epochs
warmup_steps = int(steps * 0.2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=steps,
    warmup_t=warmup_steps,
    warmup_lr_init=1e-6,
    lr_min=2e-8,
)

In [31]:
df = pd.concat((tile_xla["train"], tile_xla["valid"]), axis=0).reset_index(drop=True)

In [80]:
kfold = sklearn.model_selection.KFold(n_splits=10, shuffle=True, random_state=0)
score_means = []
score_maxs = []
for fold, (tr_idx, va_idx) in enumerate(kfold.split(df)):
    train_dataset = TileDataset(df.iloc[tr_idx])
    val_dataset = TileDataset(df.iloc[va_idx])
    model = SimpleModel(
        hidden_channels=[32, 48, 64, 84],
        embedding_dim=32,
        graph_in=64,
        graph_out=64,
        hidden_dim=128,
        activation_fn=nn.ReLU,
        dropout=0.2,
    ).to(device)
    criterion = CustomWeightedMSELoss()
    epochs = 15
    steps = len(train_dataset) * epochs
    warmup_steps = int(steps * 0.15)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = CosineLRScheduler(
        optimizer,
        t_initial=steps,
        warmup_t=warmup_steps,
        warmup_lr_init=1e-6,
        lr_min=2e-8,
    )

    best_score = 0
    best_score_max = 0
    for epoch in range(15):
        model.train()
        pbar = tqdm(range(len(train_dataset)), leave=False)
        loss_sum = 0
        n = 0
        for i in pbar:
            cfg_ft, nd_ft, nd_op, ind, target = train_dataset[i]
            cfg_ft, nd_ft, nd_op, ind, target = (
                cfg_ft.to(device),
                nd_ft.to(device),
                nd_op.to(device),
                ind.to(device),
                target.to(device),
            )

            out = model(cfg_ft, nd_ft, nd_op, ind)
            loss = criterion(out, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
            scheduler.step(i + len(train_dataset) * epoch)
            optimizer.step()
            loss_sum += loss.item()
            n += 1
            pbar.set_description(
                f"running loss: {(loss_sum/n):.2f},current loss: {(loss.item()):.2f}"
            )
        pbar.close()
        model.eval()

        tile_xla_predictions = []
        pbar = tqdm(range(len(val_dataset)), leave=False)
        for i in pbar:
            cfg_ft, nd_ft, nd_op, ind, target = val_dataset[i]
            cfg_ft, nd_ft, nd_op, ind, target = (
                cfg_ft.to(device),
                nd_ft.to(device),
                nd_op.to(device),
                ind.to(device),
                target.to(device),
            )
            with torch.no_grad():
                out = model(cfg_ft, nd_ft, nd_op, ind)
            tile_xla_predictions.append(np.argsort(out.detach().cpu().numpy())[:50])
        pbar.close()
        score_mean = score_tile_mean(tile_xla_predictions, val_dataset.df)
        score_max = score_tile_max(tile_xla_predictions, val_dataset.df)
        print(
            f"fold {fold} epoch {epoch}, comp_score = {score_max:.3f}, mean_score = {score_mean:.3f},"
        )
        if score_mean > best_score:
            best_score = score_mean
            best_score_max = score_max
            torch.save(model.state_dict(), f"best_model_{fold}.pth")
    score_means.append(best_score)
    score_maxs.append(best_score_max)
print(f"comp_score = {np.mean(score_maxs)}, mean_score = {np.mean(score_means)},")

                                                                                                 

fold 0 epoch 0, comp_score = 0.518, mean_score = 0.147,


                                                                                                 

fold 0 epoch 1, comp_score = -0.933, mean_score = -2.074,


                                                                                                 

fold 0 epoch 2, comp_score = 0.971, mean_score = 0.227,


                                                                                                 

fold 0 epoch 3, comp_score = 0.969, mean_score = 0.200,


                                                                                                 

fold 0 epoch 4, comp_score = 0.969, mean_score = 0.324,


                                                                                                 

fold 0 epoch 5, comp_score = 0.961, mean_score = 0.178,


                                                                                                 

fold 0 epoch 6, comp_score = 0.969, mean_score = 0.323,


                                                                                                 

fold 0 epoch 7, comp_score = 0.936, mean_score = 0.474,


                                                                                                 

fold 0 epoch 8, comp_score = 0.970, mean_score = 0.369,


                                                                                                 

fold 0 epoch 9, comp_score = 0.919, mean_score = 0.468,


                                                                                                 

fold 0 epoch 10, comp_score = 0.950, mean_score = 0.446,


                                                                                                 

fold 0 epoch 11, comp_score = 0.494, mean_score = 0.331,


                                                                                                 

fold 0 epoch 12, comp_score = 0.910, mean_score = 0.576,


                                                                                                 

fold 0 epoch 13, comp_score = 0.931, mean_score = 0.525,


                                                                                                 

fold 0 epoch 14, comp_score = 0.934, mean_score = 0.521,


                                                                                                 

fold 1 epoch 0, comp_score = 0.421, mean_score = 0.104,


                                                                                                 

fold 1 epoch 1, comp_score = 0.970, mean_score = -1.187,


                                                                                                 

fold 1 epoch 2, comp_score = 0.966, mean_score = 0.147,


                                                                                                 

fold 1 epoch 3, comp_score = 0.971, mean_score = 0.252,


                                                                                                 

fold 1 epoch 4, comp_score = 0.967, mean_score = 0.171,


                                                                                                 

fold 1 epoch 5, comp_score = 0.963, mean_score = 0.227,


                                                                                                 

fold 1 epoch 6, comp_score = 0.961, mean_score = 0.309,


                                                                                                 

fold 1 epoch 7, comp_score = 0.960, mean_score = 0.298,


                                                                                                 

fold 1 epoch 8, comp_score = 0.974, mean_score = 0.288,


                                                                                                 

fold 1 epoch 9, comp_score = 0.921, mean_score = 0.474,


                                                                                                 

fold 1 epoch 10, comp_score = 0.929, mean_score = 0.491,


                                                                                                 

fold 1 epoch 11, comp_score = 0.932, mean_score = 0.628,


                                                                                                 

fold 1 epoch 12, comp_score = 0.945, mean_score = 0.520,


                                                                                                 

fold 1 epoch 13, comp_score = 0.950, mean_score = 0.547,


                                                                                                 

fold 1 epoch 14, comp_score = 0.937, mean_score = 0.559,


                                                                                                 

fold 2 epoch 0, comp_score = 0.438, mean_score = 0.005,


                                                                                                 

fold 2 epoch 1, comp_score = 0.971, mean_score = -1.216,


                                                                                                 

fold 2 epoch 2, comp_score = 0.962, mean_score = 0.242,


                                                                                                 

fold 2 epoch 3, comp_score = 0.952, mean_score = 0.278,


                                                                                                 

fold 2 epoch 4, comp_score = 0.963, mean_score = 0.193,


                                                                                                 

fold 2 epoch 5, comp_score = 0.968, mean_score = 0.111,


                                                                                                 

fold 2 epoch 6, comp_score = 0.957, mean_score = 0.407,


                                                                                                 

fold 2 epoch 7, comp_score = 0.963, mean_score = 0.323,


                                                                                                 

fold 2 epoch 8, comp_score = 0.906, mean_score = 0.674,


                                                                                                 

fold 2 epoch 9, comp_score = 0.943, mean_score = 0.476,


                                                                                                 

fold 2 epoch 10, comp_score = 0.958, mean_score = 0.383,


                                                                                                 

fold 2 epoch 11, comp_score = 0.957, mean_score = 0.414,


                                                                                                 

fold 2 epoch 12, comp_score = 0.899, mean_score = 0.520,


                                                                                                 

fold 2 epoch 13, comp_score = 0.917, mean_score = 0.573,


                                                                                                 

fold 2 epoch 14, comp_score = 0.911, mean_score = 0.539,


                                                                                                 

fold 3 epoch 0, comp_score = -0.328, mean_score = -0.308,


                                                                                                 

fold 3 epoch 1, comp_score = 0.969, mean_score = -1.325,


                                                                                                 

fold 3 epoch 2, comp_score = 0.961, mean_score = 0.180,


                                                                                                 

fold 3 epoch 3, comp_score = 0.956, mean_score = 0.159,


                                                                                                 

fold 3 epoch 4, comp_score = 0.961, mean_score = 0.157,


                                                                                                 

fold 3 epoch 5, comp_score = 0.963, mean_score = 0.049,


                                                                                                 

fold 3 epoch 6, comp_score = 0.959, mean_score = 0.097,


                                                                                                 

fold 3 epoch 7, comp_score = 0.902, mean_score = 0.395,


                                                                                                 

fold 3 epoch 8, comp_score = 0.958, mean_score = 0.304,


                                                                                                 

fold 3 epoch 9, comp_score = 0.963, mean_score = 0.256,


                                                                                                 

fold 3 epoch 10, comp_score = 0.948, mean_score = 0.392,


                                                                                                 

fold 3 epoch 11, comp_score = 0.955, mean_score = 0.387,


                                                                                                 

fold 3 epoch 12, comp_score = 0.914, mean_score = 0.509,


                                                                                                 

fold 3 epoch 13, comp_score = 0.932, mean_score = 0.525,


                                                                                                 

fold 3 epoch 14, comp_score = 0.875, mean_score = 0.621,


                                                                                                 

fold 4 epoch 0, comp_score = -1.500, mean_score = -1.318,


                                                                                                 

fold 4 epoch 1, comp_score = 0.968, mean_score = -0.504,


                                                                                                 

fold 4 epoch 2, comp_score = 0.960, mean_score = 0.251,


                                                                                                 

fold 4 epoch 3, comp_score = 0.967, mean_score = 0.179,


                                                                                                 

fold 4 epoch 4, comp_score = 0.963, mean_score = 0.308,


                                                                                                 

fold 4 epoch 5, comp_score = 0.920, mean_score = 0.554,


                                                                                                 

fold 4 epoch 6, comp_score = 0.954, mean_score = 0.388,


                                                                                                 

fold 4 epoch 7, comp_score = 0.956, mean_score = 0.441,


                                                                                                 

fold 4 epoch 8, comp_score = 0.929, mean_score = 0.585,


                                                                                                 

fold 4 epoch 9, comp_score = 0.923, mean_score = 0.631,


                                                                                                 

fold 4 epoch 10, comp_score = 0.887, mean_score = 0.524,


                                                                                                 

fold 4 epoch 11, comp_score = 0.943, mean_score = 0.579,


                                                                                                 

fold 4 epoch 12, comp_score = 0.902, mean_score = 0.578,


                                                                                                 

fold 4 epoch 13, comp_score = 0.937, mean_score = 0.578,


                                                                                                 

fold 4 epoch 14, comp_score = 0.939, mean_score = 0.550,


                                                                                                 

fold 5 epoch 0, comp_score = 0.536, mean_score = 0.158,


                                                                                                 

fold 5 epoch 1, comp_score = 0.941, mean_score = 0.564,


                                                                                                 

fold 5 epoch 2, comp_score = 0.967, mean_score = 0.157,


                                                                                                 

fold 5 epoch 3, comp_score = 0.974, mean_score = 0.028,


                                                                                                 

fold 5 epoch 4, comp_score = 0.963, mean_score = 0.168,


                                                                                                 

fold 5 epoch 5, comp_score = 0.970, mean_score = 0.132,


                                                                                                 

fold 5 epoch 6, comp_score = 0.693, mean_score = 0.448,


                                                                                                 

fold 5 epoch 7, comp_score = 0.961, mean_score = 0.371,


                                                                                                 

fold 5 epoch 8, comp_score = 0.923, mean_score = 0.482,


                                                                                                 

fold 5 epoch 9, comp_score = 0.963, mean_score = 0.392,


                                                                                                 

fold 5 epoch 10, comp_score = 0.950, mean_score = 0.525,


                                                                                                 

fold 5 epoch 11, comp_score = 0.924, mean_score = 0.554,


                                                                                                 

fold 5 epoch 12, comp_score = 0.946, mean_score = 0.541,


                                                                                                 

fold 5 epoch 13, comp_score = 0.947, mean_score = 0.602,


                                                                                                 

fold 5 epoch 14, comp_score = 0.929, mean_score = 0.600,


                                                                                                 

fold 6 epoch 0, comp_score = 0.545, mean_score = 0.159,


                                                                                                 

fold 6 epoch 1, comp_score = 0.971, mean_score = -0.419,


                                                                                                 

fold 6 epoch 2, comp_score = 0.965, mean_score = 0.126,


                                                                                                 

fold 6 epoch 3, comp_score = 0.925, mean_score = 0.445,


                                                                                                 

fold 6 epoch 4, comp_score = 0.964, mean_score = 0.164,


                                                                                                 

fold 6 epoch 5, comp_score = 0.954, mean_score = 0.338,


                                                                                                 

fold 6 epoch 6, comp_score = 0.949, mean_score = 0.487,


                                                                                                 

fold 6 epoch 7, comp_score = 0.963, mean_score = 0.224,


                                                                                                 

fold 6 epoch 8, comp_score = 0.960, mean_score = 0.271,


                                                                                                 

fold 6 epoch 9, comp_score = 0.923, mean_score = 0.524,


                                                                                                 

fold 6 epoch 10, comp_score = 0.938, mean_score = 0.496,


                                                                                                 

fold 6 epoch 11, comp_score = 0.921, mean_score = 0.589,


                                                                                                 

fold 6 epoch 12, comp_score = 0.939, mean_score = 0.561,


                                                                                                 

fold 6 epoch 13, comp_score = 0.946, mean_score = 0.588,


                                                                                                 

fold 6 epoch 14, comp_score = 0.932, mean_score = 0.576,


                                                                                                 

fold 7 epoch 0, comp_score = 0.544, mean_score = 0.237,


                                                                                                 

fold 7 epoch 1, comp_score = 0.973, mean_score = -0.954,


                                                                                                 

fold 7 epoch 2, comp_score = 0.965, mean_score = 0.202,


                                                                                                 

fold 7 epoch 3, comp_score = 0.931, mean_score = 0.489,


                                                                                                 

fold 7 epoch 4, comp_score = 0.964, mean_score = 0.317,


                                                                                                 

fold 7 epoch 5, comp_score = 0.623, mean_score = 0.350,


                                                                                                 

fold 7 epoch 6, comp_score = 0.911, mean_score = 0.556,


                                                                                                 

fold 7 epoch 7, comp_score = 0.917, mean_score = 0.585,


                                                                                                 

fold 7 epoch 8, comp_score = 0.953, mean_score = 0.418,


                                                                                                 

fold 7 epoch 9, comp_score = 0.903, mean_score = 0.497,


                                                                                                 

fold 7 epoch 10, comp_score = 0.948, mean_score = 0.492,


                                                                                                 

fold 7 epoch 11, comp_score = 0.938, mean_score = 0.523,


                                                                                                 

fold 7 epoch 12, comp_score = 0.938, mean_score = 0.544,


                                                                                                 

fold 7 epoch 13, comp_score = 0.931, mean_score = 0.567,


                                                                                                 

fold 7 epoch 14, comp_score = 0.899, mean_score = 0.557,


                                                                                                 

fold 8 epoch 0, comp_score = 0.453, mean_score = 0.172,


                                                                                                 

fold 8 epoch 1, comp_score = 0.973, mean_score = -1.206,


                                                                                                 

fold 8 epoch 2, comp_score = 0.969, mean_score = -0.089,


                                                                                                 

fold 8 epoch 3, comp_score = 0.951, mean_score = 0.202,


                                                                                                 

fold 8 epoch 4, comp_score = 0.961, mean_score = 0.137,


                                                                                                 

fold 8 epoch 5, comp_score = 0.907, mean_score = 0.508,


                                                                                                 

fold 8 epoch 6, comp_score = 0.965, mean_score = 0.088,


                                                                                                 

fold 8 epoch 7, comp_score = 0.937, mean_score = 0.421,


                                                                                                 

fold 8 epoch 8, comp_score = 0.957, mean_score = 0.341,


                                                                                                 

fold 8 epoch 9, comp_score = 0.956, mean_score = 0.444,


                                                                                                 

fold 8 epoch 10, comp_score = 0.939, mean_score = 0.420,


                                                                                                 

fold 8 epoch 11, comp_score = 0.724, mean_score = 0.561,


                                                                                                 

fold 8 epoch 12, comp_score = 0.922, mean_score = 0.547,


                                                                                                 

fold 8 epoch 13, comp_score = 0.926, mean_score = 0.586,


                                                                                                 

fold 8 epoch 14, comp_score = 0.796, mean_score = 0.547,


                                                                                                 

fold 9 epoch 0, comp_score = 0.630, mean_score = 0.313,


                                                                                                 

fold 9 epoch 1, comp_score = 0.971, mean_score = -0.459,


                                                                                                 

fold 9 epoch 2, comp_score = 0.964, mean_score = 0.171,


                                                                                                 

fold 9 epoch 3, comp_score = 0.961, mean_score = 0.466,


                                                                                                 

fold 9 epoch 4, comp_score = 0.969, mean_score = 0.283,


                                                                                                 

fold 9 epoch 5, comp_score = 0.968, mean_score = 0.346,


                                                                                                 

fold 9 epoch 6, comp_score = 0.967, mean_score = 0.179,


                                                                                                 

fold 9 epoch 7, comp_score = 0.953, mean_score = 0.506,


                                                                                                 

fold 9 epoch 8, comp_score = 0.962, mean_score = 0.427,


                                                                                                 

fold 9 epoch 9, comp_score = 0.967, mean_score = 0.397,


                                                                                                 

fold 9 epoch 10, comp_score = 0.938, mean_score = 0.514,


                                                                                                 

fold 9 epoch 11, comp_score = 0.936, mean_score = 0.566,


                                                                                                 

fold 9 epoch 12, comp_score = 0.700, mean_score = 0.540,


                                                                                                 

fold 9 epoch 13, comp_score = 0.934, mean_score = 0.613,


                                                                                                 

fold 9 epoch 14, comp_score = 0.940, mean_score = 0.615,
comp_score = 0.9196846112576488, mean_score = 0.6107592581974552,


In [81]:
dataset = TileDataset(tile_xla["test"])
tile_xla_predictions = [[] for i in range(len(dataset))]
for fold in range(10):
    model.load_state_dict(torch.load(f"best_model_{fold}.pth"))
    model.eval()
    pbar = tqdm(range(len(dataset)))
    for i in pbar:
        cfg_ft, nd_ft, nd_op, ind, target = dataset[i]
        cfg_ft, nd_ft, nd_op, ind, target = (
            cfg_ft.to(device),
            nd_ft.to(device),
            nd_op.to(device),
            ind.to(device),
            target.to(device),
        )

        out = model(cfg_ft, nd_ft, nd_op, ind)
        tile_xla_predictions[i].append(out.detach().cpu().numpy())
tile_xla_predictions = [
    np.argsort(np.mean(pred, axis=0))[:5] for pred in tile_xla_predictions
]

100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 438.74it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 524.75it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 535.08it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 522.98it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 507.70it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 514.68it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 526.96it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 517.41it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 514.42it/s]
100%|█████████████████████████████████████████████████████████| 844/844 [00:01<00:00, 505.30it/s]


In [82]:
sub = pd.read_csv("./data/tpugraphs/sample_submission.csv")
for i, filename in enumerate(tile_xla["test"]["file"].values):
    id = "tile:xla:" + filename[:-4]
    sub.loc[sub.ID == id, "TopConfigs"] = ";".join(tile_xla_predictions[i].astype(str))
sub.to_csv("submission.csv", index=False)
sub

Unnamed: 0,ID,TopConfigs
0,tile:xla:d6f5f54247bd1e58a10b9e7062c636ab,0;22;21;20;19
1,tile:xla:e3a655daa38e34ec240df959b650ac16,513;1290;1282;866;697
2,tile:xla:f8c2c1a1098b2a361c26df668b286c87,41;116;101;202;166
3,tile:xla:4dd1716853ed46ee4e7d09ede1732de8,6939;3321;1910;4644;7374
4,tile:xla:d0a69155b6340748c36724e4bfc34be3,171;554;810;576;229
...,...,...
889,layout:nlp:random:60880ed76de53f4d7a1b960b24f2...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
890,layout:nlp:random:23559853d9702baaaacbb0c83fd3...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
891,layout:nlp:random:f6c146fc5cf10be4f3accbaca989...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
892,layout:nlp:random:32531d07a084b319dce484f53a4c...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...


In [24]:
for epoch in range(epochs):
    model.train()
    pbar = tqdm(range(len(train_dataset)), leave=False)
    loss_sum = 0
    n = 0
    for i in pbar:
        cfg_ft, nd_ft, nd_op, ind, target = train_dataset[i]
        cfg_ft, nd_ft, nd_op, ind, target = (
            cfg_ft.to(device),
            nd_ft.to(device),
            nd_op.to(device),
            ind.to(device),
            target.to(device),
        )
        out = model(cfg_ft, nd_ft, nd_op, ind)
        loss = criterion(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
        scheduler.step(i + len(train_dataset) * epoch)
        optimizer.step()
        loss_sum += loss.item()
        n += 1
        pbar.set_description(
            f"running loss: {(loss_sum/n):.2f},current loss: {(loss.item()):.2f}"
        )
    pbar.close()
    model.eval()
    tile_xla_predictions = []
    pbar = tqdm(range(len(val_dataset)), leave=False)
    for i in pbar:
        cfg_ft, nd_ft, nd_op, ind, target = val_dataset[i]
        cfg_ft, nd_ft, nd_op, ind, target = (
            cfg_ft.to(device),
            nd_ft.to(device),
            nd_op.to(device),
            ind.to(device),
            target.to(device),
        )
        with torch.no_grad():
            out = model(cfg_ft, nd_ft, nd_op, ind)
        tile_xla_predictions.append(np.argsort(out.detach().cpu().numpy())[:50])
    pbar.close()
    score_mean = score_tile_mean(tile_xla_predictions, val_dataset.df)
    score_max = score_tile_max(tile_xla_predictions, val_dataset.df)
    print(
        f"epoch {epoch}, comp_score = {score_max:.3f}, mean_score = {score_mean:.3f},"
    )
    if score_mean > best_score:
        best_score = score_mean
        best_score_max = score_max
        torch.save(model.state_dict(), f"best_model.pth")
        print(" * [@%i] Validation (NEW BEST): %s" % (epoch + 1, str(best_score)))

running loss: 231.48,current loss: 52.64:   0%|                 | 3/5709 [00:00<03:11, 29.82it/s]

                                                                                                 

epoch 0, comp_score = 0.448, mean_score = 0.192,


                                                                                                 

epoch 1, comp_score = 0.979, mean_score = -0.967,


                                                                                                 

epoch 2, comp_score = 0.974, mean_score = 0.334,


                                                                                                 

epoch 3, comp_score = 0.972, mean_score = 0.289,


                                                                                                 

epoch 4, comp_score = 0.908, mean_score = 0.268,


                                                                                                 

epoch 5, comp_score = 0.943, mean_score = 0.329,


                                                                                                 

epoch 6, comp_score = 0.960, mean_score = 0.382,


                                                                                                 

epoch 7, comp_score = 0.967, mean_score = 0.385,


                                                                                                 

epoch 8, comp_score = 0.958, mean_score = 0.350,


                                                                                                 

epoch 9, comp_score = 0.927, mean_score = 0.448,


                                                                                                 

epoch 10, comp_score = 0.938, mean_score = 0.326,


                                                                                                 

epoch 11, comp_score = 0.948, mean_score = 0.508,


                                                                                                 

epoch 12, comp_score = 0.969, mean_score = 0.399,


                                                                                                 

epoch 13, comp_score = 0.965, mean_score = 0.455,


                                                                                                 

epoch 14, comp_score = 0.945, mean_score = 0.485,


                                                                                                 

epoch 15, comp_score = 0.949, mean_score = 0.467,


                                                                                                 

epoch 16, comp_score = 0.948, mean_score = 0.476,


                                                                                                 

epoch 17, comp_score = 0.937, mean_score = 0.532,


                                                                                                 

epoch 18, comp_score = 0.959, mean_score = 0.549,


                                                                                                 

epoch 19, comp_score = 0.946, mean_score = 0.565,


In [28]:
dataset = TileDataset(tile_xla["test"])
tile_xla_predictions = []
model.load_state_dict(torch.load(f"best_model.pth"))
model.eval()
pbar = tqdm(range(len(dataset)))
for i in pbar:
    cfg_ft, nd_ft, nd_op, ind, target = dataset[i]
    cfg_ft, nd_ft, nd_op, ind, target = (
        cfg_ft.to(device),
        nd_ft.to(device),
        nd_op.to(device),
        ind.to(device),
        target.to(device),
    )
    with torch.no_grad():
        out = model(cfg_ft, nd_ft, nd_op, ind)
    tile_xla_predictions.append(np.argsort(out.detach().cpu().numpy())[:5])
tile_xla_predictions

100%|█████████████████████████████████████████████████████████| 844/844 [00:03<00:00, 256.91it/s]


[array([   0, 5234, 4093, 3570, 3560]),
 array([291, 328, 521, 134, 327]),
 array([ 216,  958,  268,  425, 1007]),
 array([ 672,  198,   31, 1099,  978]),
 array([532, 435, 436, 456, 458]),
 array([66, 61, 18, 58,  7]),
 array([2631, 1621, 1620, 7916, 4079]),
 array([1251, 3192, 2150, 2801,  865]),
 array([ 393, 1209,  228,  234,  901]),
 array([241,  60, 219,  93,  72]),
 array([2306, 3286, 4163, 7846, 2855]),
 array([ 968,  226, 1293,  224,  567]),
 array([5340, 5659, 5663, 8188, 9199]),
 array([7764, 8034, 1619, 2609, 5195]),
 array([115, 128,  36,   1,  98]),
 array([749, 117, 333, 502, 725]),
 array([145, 355, 233, 352, 240]),
 array([ 657, 2192, 1836,  183, 1157]),
 array([   0, 3410, 3403, 3400, 3395]),
 array([ 805,  205,  332,  547, 1207]),
 array([428,  79, 794, 545, 968]),
 array([51, 43, 38, 35, 28]),
 array([ 245,  271,  689,  269, 1063]),
 array([209, 293, 621, 284, 283]),
 array([451, 112, 268, 538, 632]),
 array([ 0, 98, 97, 95, 94]),
 array([2286, 9697, 6940, 5381, 118

In [29]:
sub = pd.read_csv("./data/tpugraphs/sample_submission.csv")
for i, filename in enumerate(tile_xla["test"]["file"].values):
    id = "tile:xla:" + filename[:-4]
    sub.loc[sub.ID == id, "TopConfigs"] = ";".join(tile_xla_predictions[i].astype(str))
sub.to_csv("submission.csv", index=False)
sub

Unnamed: 0,ID,TopConfigs
0,tile:xla:d6f5f54247bd1e58a10b9e7062c636ab,0;22;21;20;19
1,tile:xla:e3a655daa38e34ec240df959b650ac16,1250;1060;963;292;776
2,tile:xla:f8c2c1a1098b2a361c26df668b286c87,41;101;116;202;166
3,tile:xla:4dd1716853ed46ee4e7d09ede1732de8,7172;964;2688;8696;6651
4,tile:xla:d0a69155b6340748c36724e4bfc34be3,0;264;655;262;261
...,...,...
889,layout:nlp:random:60880ed76de53f4d7a1b960b24f2...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
890,layout:nlp:random:23559853d9702baaaacbb0c83fd3...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
891,layout:nlp:random:f6c146fc5cf10be4f3accbaca989...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...
892,layout:nlp:random:32531d07a084b319dce484f53a4c...,0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18...


In [7]:
import os
from pathlib import Path
from typing import Dict, Optional, List, Union, Tuple
from dataclasses import dataclass
import math
import numpy as np
import pandas as pd
from datasets import Dataset
from tqdm import tqdm
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.pytorch_utils import apply_chunking_to_forward
from transformers.activations import ACT2FN
import pytorch_lightning as pl
import torchmetrics as tm

In [8]:
@dataclass
class GraphConfig:
    num_hidden_layers: int = 8
    embedding_size: int = 256
    num_attention_heads: int = 16
    intermediate_size: int = 64
    chunk_size_feed_forward: int = 64
    attention_probs_dropout_prob: float = 0.0
    max_position_embeddings: int = 512
    hidden_dropout_prob: float = 0.0
    layer_norm_eps: float = 1e-12
    hidden_act: torch.nn = torch.nn.GELU
    initializer_range: float = 0.02
    output_hidden_states: bool = False
    output_attentions: bool = False
    gradient_checkpointing: bool = False
    margin: float = 0.1
    number_permutations: int = 10

    def __post_init__(self):
        self.hidden_size = self.embedding_size + 140

    def validate(self):
        if self.hidden_size % self.num_attention_heads != 0 and not hasattr(
            self, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

    def save_config(self, path):
        config = asdict(self)
        with open(path, "w") as f:
            json.dump(config, f)

    @classmethod
    def load_config(cls, path):
        with open(path, "r") as f:
            config = json.load(f)
        return cls(**config)

In [9]:
config_kwargs = dict(
    embedding_size=128,
    num_attention_heads=4,
    num_hidden_layers=2,
    intermediate_size=64,
    gradient_checkpointing=True,
    margin=0.1,
    number_permutations=4,
)

config = GraphConfig(**config_kwargs)

In [10]:
import math


class BertSelfAttention(nn.Module):
    def __init__(self, config: GraphConfig, position_embedding_type=None):
        super().__init__()
        config.validate()

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        if (
            self.position_embedding_type == "relative_key"
            or self.position_embedding_type == "relative_key_query"
        ):
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            position_ids_l = torch.arange(
                query_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                key_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(-1)
            attention_mask = attention_mask.expand(-1, self.num_attention_heads, -1, -1)
            attention_scores = attention_scores + attention_mask
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = (
                attention_probs * head_mask
            )  # DONE: Same Head Mask for all Heads

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        return outputs

In [11]:
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

In [12]:
class BertAttention(nn.Module):
    def __init__(self, config: GraphConfig, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(
            config, position_embedding_type=position_embedding_type
        )
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[
            1:
        ]  # add attention_probs if we output them
        return outputs

In [13]:
class BertIntermediate(nn.Module):
    def __init__(self, config: GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn()(hidden_states)
        return hidden_states

In [14]:
class BertOutput(nn.Module):
    def __init__(self, config: GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(
            hidden_states + input_tensor
        )  # Residual Connection
        return hidden_states

In [15]:
from transformers.pytorch_utils import apply_chunking_to_forward


class BertLayer(nn.Module):
    def __init__(self, config: GraphConfig):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[
            1:
        ]  # add self attentions if we output attention weights
        layer_output = self.feed_forward_chunk(attention_output)
        outputs = (layer_output,) + outputs
        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

In [16]:
class BertEncoder(nn.Module):
    def __init__(self, config: GraphConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[torch.Tensor, BaseModelOutputWithPastAndCrossAttentions]:
        for layer_module in self.layer:
            if self.gradient_checkpointing and self.training:

                def custom_forward(module):
                    return module(hidden_states, attention_mask, output_attentions)

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    custom_forward, layer_module
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]

        pooled_output = hidden_states.mean(dim=1)

        if return_dict:
            return BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=pooled_output,
                past_key_values=None,
                cross_attentions=None,
            )

        return pooled_output

In [17]:
class GraphToSequence(nn.Module):
    def __init__(self, config: GraphConfig):
        super().__init__()
        self.op_embedding_dim = config.embedding_size
        self.embedding = torch.nn.Embedding(
            NODE_OP_CODES + 2, self.op_embedding_dim, padding_idx=121
        )

    def forward(self, node_sequence, node_opcode, node_feat):
        node_features = torch.concat(
            [node_feat, self.embedding(node_opcode)], dim=-1
        )  # [bs, # of nodes, feat_dim]
        # node_features = self.pre_net(node_features)
        gather_indices = torch.where(
            node_sequence == -1, torch.zeros_like(node_sequence), node_sequence
        ).to(torch.int64)
        sequence = torch.gather(
            node_features,
            1,
            gather_indices.unsqueeze(-1).expand(-1, -1, node_features.shape[-1]),
        )
        return sequence

In [18]:
class PostMLP(nn.Module):
    def __init__(self, config: GraphConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.embedding_size + 140 + 24, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = nn.GELU()(self.fc1(x))
        x = self.dropout(x)
        x = nn.GELU()(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)

In [19]:
class BertSequenceEncoder(nn.Module):
    def __init__(self, config: GraphConfig):
        super().__init__()
        self.config = config
        self.graph_sequence = GraphToSequence(config)
        self.encoder = BertEncoder(config)
        self.postnet = PostMLP(config)

    def forward(self, node_sequence, node_opcode, node_feat, configs, targets):
        sequence = self.graph_sequence(node_sequence, node_opcode, node_feat)
        attention_mask = (node_sequence != -1).float()
        attention_mask = attention_mask.masked_fill(attention_mask == 0, -1e9)
        attention_mask = attention_mask.masked_fill(attention_mask == 1, 0.0)
        outputs = self.encoder(sequence, attention_mask)
        x = outputs.last_hidden_state
        x_replicated = torch.repeat_interleave(x.unsqueeze(1), configs.shape[1], dim=1)
        x = torch.cat([x_replicated, configs], dim=-1)
        x = self.postnet(x)
        x = x.squeeze(-1)
        mask = (targets != 0).float()
        x = mask * x
        return x

In [279]:
model = BertSequenceEncoder(config)
loss = MultiElementRankLoss()

In [358]:
model.to("cpu")
for (
    padded_sequence,
    padded_opcode,
    padded_feat,
    padded_config,
    config_mask,
    padded_target,
) in train_loader:
    outputs = model(
        padded_sequence, padded_opcode, padded_feat, padded_config, padded_target
    )
    print(outputs)
    mask = (outputs != 0.0).int()
    print(mask)
    error = loss(outputs, padded_target, mask)
    print(error)
    break

tensor([[2.3022, 0.3571, 0.8268,  ..., 0.0000, 0.0000, 0.0000],
        [4.5267, 0.4191, 0.6969,  ..., 0.0000, 0.0000, 0.0000],
        [0.8229, 0.1834, 0.1814,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.9943, 0.3271, 0.2575,  ..., 0.0000, 0.0000, 0.0000],
        [1.1317, 0.4084, 0.5718,  ..., 0.0000, 0.0000, 0.0000],
        [0.8673, 0.3081, 0.3596,  ..., 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], dtype=torch.int32)
tensor(0., grad_fn=<DivBackward0>)


In [356]:
class MultiElementRankLoss(nn.Module):
    """
    Loss function that compares the output of the model with the output of the model with a permutation of the elements
    """

    def __init__(self, margin: float = 0.0, number_permutations: int = 1) -> None:
        super().__init__()
        self.loss_fn = torch.nn.MarginRankingLoss(margin=margin, reduction="none")
        self.number_permutations = number_permutations

    def calculate_rank_loss(
        self,
        outputs: torch.Tensor,
        config_runtime: torch.Tensor,
        config_idxs: torch.Tensor,
    ):
        """
        Generates a permutation of the predictions and targets and calculates the loss MarginRankingLoss against the permutation
        Args:
            outputs: Tensor of shape (bs, seq_len) with the outputs of the model
            config_runtime: Tensor of shape (bs, seq_len) with the runtime of the model
            config_mask: Tensor of shape (bs, seq_len) with 1 in the positions of the elements
            and 0 in the positions of the padding
        Returns:
            loss: Tensor of shape (bs, seq_len) with the loss for each element in the batch
        """
        bs, num_configs = outputs.shape
        permutation = torch.randperm(num_configs)
        permuted_idxs = config_idxs[:, permutation]
        # We mask those cases where we compare the same configuration
        config_mask = torch.where(config_idxs != permuted_idxs, 1, 0)
        permuted_runtime = config_runtime[:, permutation]
        labels = 2 * ((config_runtime - permuted_runtime) > 0) - 1
        print(labels)
        permuted_output = outputs[:, permutation]
        loss = self.loss_fn(
            outputs.view(-1, 1), permuted_output.view(-1, 1), labels.view(-1, 1)
        )
        print(loss)
        loss = loss.view(bs, num_configs) * config_mask
        return loss.mean()

    def forward(
        self,
        outputs: torch.Tensor,
        config_runtime: torch.Tensor,
        config_idxs: torch.Tensor,
    ):
        loss = 0
        for _ in range(self.number_permutations):
            loss += self.calculate_rank_loss(outputs, config_runtime, config_idxs)
        return loss / self.number_permutations

In [20]:
class TileTopK(tm.Metric):
    higher_is_better = True

    def __init__(self, k: int = 5) -> None:
        super().__init__()
        self.add_state("runtimes", default=[], dist_reduce_fx=None)
        self.k = k

    def update(
        self, preds: torch.Tensor, target: torch.Tensor, config_attn_mask: torch.Tensor
    ) -> None:
        """
        Update the metric state
        Args:
            preds: Tensor of shape (bs, seq_len) with the predicted runtimes orders
            target: Tensor of shape (bs, seq_len) with the target runtimes
            config_attn_mask: Tensor of shape (bs, seq_len) with 1 in the positions of the elements
        """
        best_runtimes = (
            torch.where(config_attn_mask == 1, target, torch.tensor(float("inf")))
            .min(1)
            .values
        )
        masked_preds = torch.where(
            config_attn_mask == 1, preds, torch.tensor(float("inf"))
        )
        pred_bottomk_indices = torch.topk(masked_preds, k=self.k, largest=False).indices
        bs = preds.shape[0]
        bottom_k_positions = torch.stack(
            [
                torch.arange(bs).repeat_interleave(self.k).to(config_attn_mask.device),
                pred_bottomk_indices.view(-1),
            ]
        )
        predicted_runtimes = target[bottom_k_positions[0], bottom_k_positions[1]].view(
            bs, self.k
        )
        best_predicted_runtimes = predicted_runtimes.min(1).values
        self.runtimes.append(best_predicted_runtimes / best_runtimes)

    def compute(self) -> torch.Tensor:
        return (2 - torch.cat(self.runtimes)).mean()

In [27]:
class PairwiseRankingLoss(nn.Module):
    def __init__(self, margin: float = 0.0):
        super().__init__()
        self.criterion = nn.MarginRankingLoss(margin=margin)

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        mask = (target != 0.0).float()
        mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        diff = preds.unsqueeze(-1) - preds.unsqueeze(-2)
        target_diff = target.unsqueeze(-1) - target.unsqueeze(-2)
        target_diff = torch.clamp(target_diff, min=0.0)
        loss = mask * torch.relu(self.margin + diff * target_diff - self.margin)
        loss = loss.sum() / mask.sum()
        return loss

In [21]:
class CustomRankingLoss(nn.Module):
    def __init__(self):
        super(CustomRankingLoss, self).__init__()
        self.margin_ranking_loss = nn.MarginRankingLoss(
            reduction="none"
        )  # none to handle averaging manually

    def forward(self, pred, target):
        mask = target > 0
        target_diffs = target.unsqueeze(2) - target.unsqueeze(1)
        pred_diffs = pred.unsqueeze(2) - pred.unsqueeze(1)
        labels = torch.sign(target_diffs)
        mask_diffs = mask.unsqueeze(2) & mask.unsqueeze(1)
        labels = labels * mask_diffs.float()
        losses = self.margin_ranking_loss(
            pred_diffs, torch.zeros_like(pred_diffs), labels
        )
        avg_loss = losses.sum() / mask_diffs.float().sum()
        return avg_loss

In [27]:
class CustomRankingLoss(nn.Module):
    def __init__(self, chunk_size=1000):
        super(CustomRankingLoss, self).__init__()
        self.margin_ranking_loss = nn.MarginRankingLoss(
            reduction="none"
        )  # none to handle averaging manually
        self.chunk_size = chunk_size

    def compute_chunked_loss(self, pred, target, mask):
        # Break the computation into chunks to save memory
        total_loss = 0.0
        total_count = 0

        num_configs = pred.size(1)
        for i in range(0, num_configs, self.chunk_size):
            for j in range(0, num_configs, self.chunk_size):
                target_diffs = target[:, i : i + self.chunk_size].unsqueeze(2) - target[
                    :, j : j + self.chunk_size
                ].unsqueeze(1)
                pred_diffs = pred[:, i : i + self.chunk_size].unsqueeze(2) - pred[
                    :, j : j + self.chunk_size
                ].unsqueeze(1)

                labels = torch.sign(target_diffs)

                mask_diffs = mask[:, i : i + self.chunk_size].unsqueeze(2) & mask[
                    :, j : j + self.chunk_size
                ].unsqueeze(1)
                labels = labels * mask_diffs.float()

                losses = self.margin_ranking_loss(
                    pred_diffs, torch.zeros_like(pred_diffs), labels
                )

                total_loss += losses.sum()
                total_count += mask_diffs.float().sum()

        avg_loss = total_loss / total_count
        return avg_loss

    def forward(self, pred, target):
        mask = target > 0
        return self.compute_chunked_loss(pred, target, mask)

In [22]:
class LightningWrapper(pl.LightningModule):
    def __init__(self, model: nn.Module, loss: nn.Module):
        super().__init__()
        self.model = model
        self.loss = loss
        self.topk = TileTopK()

    def forward(self, *inputs):
        return self.model(*inputs)

    def compute_loss(self, outputs, target, mask=None):
        if mask is None:
            mask = (outputs != 0.0).int()
        return self.loss(outputs, target)

    def training_step(self, batch, batch_idx):
        inputs = batch[:-1]
        target = batch[-1]
        outputs = self.model(*inputs)
        error = self.compute_loss(outputs, target)
        self.log("loss", error, prog_bar=True)
        return {"loss": error}

    def validation_step(self, batch, batch_idx):
        inputs = batch[:-1]
        target = batch[-1]
        outputs = self.model(*inputs)
        error = self.compute_loss(outputs, target)
        self.log("val_loss", error, prog_bar=True)

        mask = (outputs != 0.0).int()
        self.topk.update(outputs, target, mask)
        return {"val_loss": error}

    def on_validation_end(self) -> None:
        topk = self.topk.compute()
        self.print(f"topk {topk:.3f}")
        self.topk.reset()
        return super().on_validation_end()

    def test_step(self, batch, batch_idx):
        inputs = batch[:-1]
        target = batch[-1]
        outputs = self.model(*inputs)
        error = self.compute_loss(outputs, target)
        self.log("test_loss", error, prog_bar=True)
        return {"test_loss": error}

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer

In [28]:
model = LightningWrapper(BertSequenceEncoder(config), CustomRankingLoss())

In [24]:
train_dataset = TileDataset(tile_xla["train"])
valid_dataset = TileDataset(tile_xla["valid"])
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=custom_collate,
    batch_size=2,
    num_workers=2,
    shuffle=True,
    persistent_workers=True,
)
valid_dataloader = DataLoader(
    valid_dataset, collate_fn=custom_collate, batch_size=2, num_workers=2
)

In [25]:
pl.seed_everything(42)
trainer_config = dict(
    max_epochs=50,
    # precision=32,
    gradient_clip_val=1.0,
    accumulate_grad_batches=2,
    check_val_every_n_epoch=10,
)

In [29]:
torch.cuda.empty_cache()
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# device = "cpu"
model.to(device)
model = model.train()
# torch.set_float32_matmul_precision("medium")
trainer = pl.Trainer(
    **trainer_config,
)
trainer.fit(model, train_dataloader, valid_dataloader)

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

topk 0.333


Validation: 0it [00:00, ?it/s]

topk 0.956


Validation: 0it [00:00, ?it/s]

topk 0.980


Validation: 0it [00:00, ?it/s]

topk 0.979


Validation: 0it [00:00, ?it/s]

topk 0.979


In [31]:
torch.cuda.empty_cache()
model = model.cpu()
!nvidia-smi

Mon Oct 30 05:04:24 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.47                 Driver Version: 531.68       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 3090 Ti      On | 00000000:01:00.0 Off |                  Off |
| 45%   57C    P8               16W / 450W|   1445MiB / 24564MiB |      3%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    