In [None]:
import json
import math
import os
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Optional, List, Union, Tuple

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.utils.checkpoint
import torchmetrics as tm
from datasets import Dataset
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import  DataLoader
from tqdm import tqdm
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.pytorch_utils import apply_chunking_to_forward
# import bitsandbytes as bnb

In [None]:
# Size of the opcode vocabulary. NodeEncoder builds an embedding table with NODE_OP_CODES + 1 rows and
# reserves index NODE_OP_CODES as its padding index (LayoutCollator.padding_idx defaults to the same value).
NODE_OP_CODES = 120
# Input width of the linear projection that NodeEncoder applies to `node_feat`.
NODE_FEATS = 140
# Presumably the width of the tile-level `config_feat` vectors; TileDataset right-pads them with
# NODE_CONFIG_FEATS zeros to reach the input width of NodeFeatEmbeddings — TODO confirm against the npz files.
CONFIG_FEATS = 24
# Number of zero columns TileDataset appends to each config feature vector; NodeFeatEmbeddings expects
# NODE_CONFIG_FEATS + CONFIG_FEATS input features.
NODE_CONFIG_FEATS = 18

In [None]:
DATA_DIR = "../input/predict-ai-model-runtime/npz_all/npz"


def generate_tile_df(data_dir: Optional[Union[str, Path]] = None) -> pd.DataFrame:
    """
    Index every file found below ``<data_dir>/tile`` in a DataFrame.

    The metadata columns are derived from the directory layout of each file
    (``<extra>/<configuration>/<split>/<model_name>.<ext>``).

    Args:
        data_dir: Root directory that contains the ``tile`` folder. Defaults to the module-level
            ``DATA_DIR`` (looked up at call time, so re-assigning ``DATA_DIR`` in a later cell is honoured).
    Returns:
        tile_df: DataFrame with one row per file and the columns ``paths`` (str), ``split``,
            ``configuration``, ``extra``, ``model_name``, ``collection`` (``extra:configuration``)
            and ``ID`` (``collection:model_name``).
    """
    tile_root = Path(DATA_DIR if data_dir is None else data_dir) / 'tile'
    # rglob yields files in a filesystem-dependent order; sort so that row positions (and therefore
    # dataset indices) are reproducible across machines and runs.
    files = sorted(elem for elem in tile_root.rglob("*") if elem.is_file())
    tile_df = pd.DataFrame({'paths': files}).assign(
        split=lambda df: df.paths.apply(lambda x: x.parent.name),
        configuration=lambda df: df.paths.apply(lambda x: x.parent.parent.name),
        extra=lambda df: df.paths.apply(lambda x: x.parent.parent.parent.name),
        model_name=lambda df: df.paths.apply(lambda x: x.stem),
        collection=lambda df: df.extra + ':' + df.configuration ,
        ID=lambda df: df.collection + ':' + df.model_name ,
        # Converted last: the columns above need the Path objects.
        paths = lambda df: df.paths.apply(lambda x: str(x))
    )
    return tile_df

In [None]:
# Index the tile files once; the cells below build the dataset from this frame.
tile_df = generate_tile_df()
tile_df.head()

# Dataset
* Creates an adjacency matrix that is used to mask the attention
* Creates a virtual first node, equivalent to the [CLS] token, which holds the global configuration for tile cases, while for layout cases each node configuration goes to the corresponding node position

In [None]:
def edges_adjacency(edges: torch.Tensor, add_diagonal=True, num_nodes: Optional[int] = None) -> torch.Tensor:
    """
    Generate an adjacency matrix from the edges
    Args:
        edges: Tensor of shape (num_edges, 2) with the edges
        add_diagonal: Boolean indicating if the diagonal should be added to the adjacency matrix
        num_nodes: Optional number of nodes in the graph. When omitted the size is inferred from the
            largest node index found in ``edges``; pass it explicitly when trailing nodes may have no
            edges or when ``edges`` may be empty.
    Returns:
        adjacency_matrix: Tensor of shape (num_nodes, num_nodes) with the adjacency matrix
    """
    edges = edges.long()
    has_edges = edges.numel() > 0
    # edges.max() raises on an empty tensor, so only consult it when there is at least one edge.
    inferred_nodes = int(edges.max()) + 1 if has_edges else 0
    # Never shrink below what the edges require, even if the caller passes a smaller num_nodes.
    size = inferred_nodes if num_nodes is None else max(int(num_nodes), inferred_nodes)
    adjacency_matrix = torch.zeros((size, size))
    if has_edges:
        adjacency_matrix[edges[:, 0], edges[:, 1]] = 1
    if add_diagonal:
        diag_idx = torch.arange(size)
        adjacency_matrix[diag_idx, diag_idx] = 1
    return adjacency_matrix

def tile_loader(path):
    """
    Load one npz file as a dictionary of tensors and attach its adjacency matrix.
    Args:
        path: Path of the npz file; it must contain an ``edge_index`` array.
    Returns:
        tile_dict: Dictionary with every array of the file converted with ``torch.from_numpy`` plus the
            key ``edges_adjecency`` holding the adjacency matrix built from ``edge_index``.
    """
    # The context manager closes the underlying file handle once all arrays have been read.
    with np.load(path) as npz_file:
        tile_dict = {k: torch.from_numpy(v) for k, v in npz_file.items()}
    # Size the adjacency matrix from the node list when it is available, so that it always matches the
    # node tensors even if the highest-numbered nodes do not appear in any edge.
    num_nodes = tile_dict['node_opcode'].shape[0] if 'node_opcode' in tile_dict else None
    tile_dict['edges_adjecency'] = edges_adjacency(tile_dict['edge_index'], num_nodes=num_nodes)
    return tile_dict

def node_cls_token(elem_dict, shift_node_config_ids:bool=True):
    """
    Prepend a virtual [CLS] node to a sample. The dictionary is updated in place and returned.

    The [CLS] node gets opcode 0, an all-zero feature row and is connected to every node in both
    directions (the adjacency matrix receives a new first row and first column filled with ones).
    Args:
        elem_dict: Dictionary with the keys ``node_opcode``, ``node_feat`` and ``edges_adjecency``
            (and optionally ``node_config_ids``)
        shift_node_config_ids: Whether ``node_config_ids`` (if present) is increased by 1 so that it
            keeps pointing at the same nodes after the [CLS] node has been inserted
    Returns:
        elem_dict: The same dictionary, now including the [CLS] node
    """
    node_feat = elem_dict['node_feat']
    cls_opcode = torch.tensor([0])
    cls_feat = torch.zeros((1, node_feat.shape[1]))
    elem_dict.update({
        'node_opcode': torch.cat([cls_opcode, elem_dict['node_opcode']]),
        'node_feat': torch.cat([cls_feat, node_feat]),
        # One extra column on the left and one extra row on top, both filled with ones.
        'edges_adjecency': F.pad(elem_dict['edges_adjecency'], (1, 0, 1, 0), value=1),
    })
    if shift_node_config_ids and 'node_config_ids' in elem_dict:
        elem_dict['node_config_ids'] = elem_dict['node_config_ids'] + 1
    return elem_dict


class TileDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that loads one tile graph per row of ``df`` and samples some of its configurations.

    Args:
        df: DataFrame with a ``paths`` column holding the npz file of each graph
        add_cls_token: Whether a virtual [CLS] node is prepended to every graph
        num_configs: Number of configurations sampled per item; ``-1`` returns all of them in order
        max_configs: If set, only the first ``max_configs`` configurations of a file are candidates
    """

    def __init__(self, df:pd.DataFrame ,add_cls_token:bool=True, num_configs:int=10,  max_configs:Optional[int]=None):
        self.df = df
        self.add_cls_token = add_cls_token
        self.num_configs = num_configs
        self.max_configs = max_configs

    def __len__(self) -> int:
        return len(self.df)

    def select_configs(self, total_configs:int):
        """
        Choose which configuration indices to use for one item.
        Args:
            total_configs: Number of configurations available in the file
        Returns:
            Array of indices: all candidates when ``num_configs == -1``, otherwise ``num_configs`` random
            ones (drawn with replacement only when there are fewer candidates than ``num_configs``)
        """
        if self.max_configs is not None:
            total_configs = min(total_configs, self.max_configs)
        if self.num_configs == -1:
            return np.arange(total_configs)
        if total_configs < self.num_configs:
            return np.random.choice(total_configs, self.num_configs, replace=True)
        return  np.random.choice(total_configs, self.num_configs, replace=False)

    def __getitem__(self, idx:int, selected_configs:Optional[List[int]]=None):
        """
        Load item ``idx`` (a position in ``df``, not an index label).
        Args:
            idx: Row position in ``df``
            selected_configs: Optional explicit configuration indices; sampled with ``select_configs``
                when omitted
        Returns:
            tile_dict: Dictionary of tensors for the graph plus ``selected_idxs`` (numpy array with the
                configuration indices that were used)
        """
        # Positional lookup: `idx` comes from range(len(self)), so a label lookup (`paths[idx]`) would
        # fail or pick the wrong row whenever `df` is a filtered/shuffled frame without a fresh index.
        tile_dict = tile_loader(self.df['paths'].iloc[idx])
        if selected_configs is None:
            selected_configs = self.select_configs(tile_dict['config_feat'].shape[0])
        # The collator converts `selected_idxs` with torch.from_numpy, so always hand over an ndarray.
        selected_configs = np.asarray(selected_configs, dtype=np.int64)
        tile_dict['node_config_feat'] = tile_dict.pop('config_feat')[selected_configs]
        # (num_configs, feats) -> (num_configs, 1, feats + NODE_CONFIG_FEATS): a single "config node"
        # whose feature vector is right-padded with zeros.
        tile_dict['node_config_feat'] = F.pad(tile_dict['node_config_feat'].unsqueeze(1), (0,NODE_CONFIG_FEATS))
        tile_dict['config_runtime'] = tile_dict['config_runtime'][selected_configs].float()
        tile_dict['config_runtime'] /= tile_dict['config_runtime_normalizers'][selected_configs].float()
        # The only config node is id 0; it is left unshifted below (shift_node_config_ids=False), so it
        # coincides with the [CLS] node when add_cls_token is enabled.
        tile_dict['node_config_ids'] = torch.zeros((1,))
        tile_dict['selected_idxs'] = selected_configs
        if self.add_cls_token:
            tile_dict = node_cls_token(tile_dict, False)
        return tile_dict

In [None]:
# Dataset over all indexed tile files with the default settings (10 sampled configs, [CLS] node added).
tile_dataset = TileDataset(tile_df)

In [None]:
# Sanity check: load the first sample and print the shape of every entry it contains.
elem = tile_dataset[0]
for k,v in elem.items():
    print(k, v.shape)

In [None]:
# Display the adjacency matrix of the sample; with the [CLS] node enabled its first row and column are all ones.
elem['edges_adjecency']

## Collator

In [None]:
def pad_edge_adjacency(edges_adjacency_list):
    """
    Zero-pad square adjacency matrices to a common size and stack them.
    Args:
        edges_adjacency_list: Non-empty list of tensors of shape (n_i, n_i)
    Returns:
        Tensor of shape (len(list), n_max, n_max); every matrix keeps its values in the top-left
        corner and is padded with zeros on the right and at the bottom
    """
    target_size = max(adjacency.shape[0] for adjacency in edges_adjacency_list)
    padded = []
    for adjacency in edges_adjacency_list:
        missing = target_size - adjacency.shape[0]
        padded.append(F.pad(adjacency, (0, missing, 0, missing), value=0))
    return torch.stack(padded, dim=0)

@dataclass
class LayoutCollator:
    """
    Collate dataset items into one padded batch.

    Node-level tensors and config-node tensors are padded so that their lengths are a multiple of
    ``pad_to_multiple_of`` (no extra padding is added when the length already is one).
    """
    # Node / config-node dimensions are padded up to a multiple of this value (<= 0 disables it).
    pad_to_multiple_of: int = 64
    # Whether `config_runtime` is included in the batch.
    targets:bool = True
    # Value used to pad `node_opcode`.
    padding_idx:int = 120
    # Value written into the padded positions of `node_config_feat`.
    node_padding_idx:int = 0

    def _pad_amount(self, length: int) -> int:
        """Number of elements needed to round `length` up to the next multiple of `pad_to_multiple_of`."""
        if self.pad_to_multiple_of <= 0:
            return 0
        # (-length) % m is 0 for aligned lengths; `m - length % m` would add a full extra block there.
        return (-length) % self.pad_to_multiple_of

    def __call__(self, batch):
        """
        Args:
            batch: List of dictionaries as returned by the dataset
        Returns:
            output: Dictionary with the batched tensors ``node_opcode``, ``node_feat``,
                ``edges_adjecency``, ``node_attn_mask``, ``node_config_ids``, ``node_config_feat``,
                ``config_idxs`` and, when ``targets`` is set, ``config_runtime``
        """
        output = {}
        max_node_len = max([elem['node_opcode'].shape[0] for elem in batch])
        node_pad_amount = self._pad_amount(max_node_len)
        output['node_opcode'] = F.pad(pad_sequence([elem['node_opcode'] for elem in batch], batch_first=True, padding_value=self.padding_idx),
                                      (0, node_pad_amount), value=self.padding_idx).long()
        output['node_feat'] = F.pad(pad_sequence([elem['node_feat'] for elem in batch], batch_first=True),
                                    (0,0,0, node_pad_amount), value=0)
        output['edges_adjecency'] = F.pad(pad_edge_adjacency([elem['edges_adjecency'] for elem in batch]),
                                          (0, node_pad_amount, 0, node_pad_amount), value=0)
        # 1 for real nodes, 0 for padding.
        output['node_attn_mask'] = F.pad(pad_sequence([torch.ones(len(elem['node_opcode'])) for elem in batch], batch_first=True),
                                         (0, node_pad_amount), value=0)

        max_node_config_len = max([elem['node_config_ids'].shape[0] for elem in batch])
        node_config_pad_amount = self._pad_amount(max_node_config_len)
        output['node_config_ids'] = F.pad(pad_sequence([elem['node_config_ids'] for elem in batch], batch_first=True),
                                         (0, node_config_pad_amount), value=0).long()
        # Move the config-node axis first so pad_sequence pads it, then restore
        # (bs, num_configs, num_config_nodes, feats).
        padded_node_config_feat = pad_sequence([elem['node_config_feat'].permute(1,0,2) for elem in batch], batch_first=True, padding_value=-1)
        padded_node_config_feat = F.pad(padded_node_config_feat.permute(0,2,1,3),
                                           (0,0,0, node_config_pad_amount,0,0), value=-1)

        # NOTE(review): every -1 is replaced here, so a genuine feature value of -1 is also overwritten
        # with node_padding_idx — confirm the features never contain -1.
        output['node_config_feat'] = torch.where(padded_node_config_feat!=-1, padded_node_config_feat, self.node_padding_idx)

        output['config_idxs'] = torch.stack([torch.from_numpy(elem['selected_idxs']) for elem in batch])

        if self.targets:
            output['config_runtime'] = pad_sequence([elem['config_runtime'].float() for elem in batch], batch_first=True)
        return output

In [None]:
# Collator that pads node and config-node dimensions to a multiple of 64.
collate_fn = LayoutCollator(64)

In [None]:
# Sanity check: collate two samples and print the shape of every batched tensor.
batch = collate_fn([tile_dataset[0], tile_dataset[1]])
for k,v in batch.items():
    print(k,v.shape)

# Model - Config

In [None]:
@dataclass
class GraphConfig:
    """
    Hyper-parameters of the graph encoder, its loss and the Bert-style layers it is built from.

    ``embedding_size`` is not a dataclass field: it is derived in ``__post_init__`` and always equals
    ``hidden_size``, so it is neither saved by ``save_config`` nor accepted by the constructor.
    """
    num_hidden_layers: int = 8
    hidden_size: int = 256
    num_attention_heads: int = 16
    intermediate_size: int = 64
    chunk_size_feed_forward: int = 64
    attention_probs_dropout_prob: float = 0.0
    max_position_embeddings: int = 512
    hidden_dropout_prob: float = 0.0
    layer_norm_eps: float = 1e-12
    hidden_act: str = 'gelu'
    initializer_range: float = 0.02
    output_hidden_states: bool = False
    output_attentions: bool = False
    gradient_checkpointing: bool = False
    # Margin of the ranking loss.
    margin: float = 0.1
    # Number of random permutations the ranking loss averages over.
    number_permutations: int = 10

    def __post_init__(self):
        self.embedding_size = self.hidden_size

    def validate(self):
        """
        Check that the hidden size can be split evenly across the attention heads.
        Raises:
            ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads``
        """
        # `embedding_size` is always set by __post_init__, so it must not be used to skip this check.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

    def save_config(self, path):
        """Write the dataclass fields to ``path`` as JSON."""
        config = asdict(self)
        with open(path, 'w') as f:
            json.dump(config, f)

    @classmethod
    def load_config(cls, path):
        """Build a config from a JSON file written by ``save_config``."""
        with open(path, 'r') as f:
            config = json.load(f)
        return cls(**config)

## Loss
* Uses a ranking loss to compare different configurations
* Compares only those configurations with different indexes; the cases where the permutation returns the same element are masked
* Compares multiple configurations in each run

In [None]:
class MultiElementRankLoss(nn.Module):
    """
    Pairwise ranking loss: every prediction is compared against the prediction found at the same
    position after randomly permuting the configurations, averaged over several permutations.
    """

    def __init__(self, margin:float=0.0, number_permutations:int = 1) -> None:
        super().__init__()
        self.loss_fn = torch.nn.MarginRankingLoss(margin=margin, reduction = 'none')
        self.number_permutations = number_permutations

    def calculate_rank_loss(self,
                            outputs: torch.Tensor,
                            config_runtime: torch.Tensor,
                            config_idxs: torch.Tensor
                            ):
        """
        Draw one random permutation of the configurations and compute the MarginRankingLoss between
        each element and its permuted partner.
        Args:
            outputs: Tensor of shape (bs, num_configs) with the outputs of the model
            config_runtime: Tensor of shape (bs, num_configs) with the target runtimes
            config_idxs: Tensor of shape (bs, num_configs) with the index of each configuration; pairs
                whose two members share the same index are masked out
        Returns:
            loss: Scalar tensor, the mean over all (bs * num_configs) pairs with masked pairs counted as 0
        """
        bs, num_configs = outputs.shape
        shuffle = torch.randperm(num_configs)
        partner_outputs = outputs[:, shuffle]
        partner_runtime = config_runtime[:, shuffle]
        partner_idxs = config_idxs[:, shuffle]
        # +1 when the element is slower than its partner (its output should rank higher), else -1.
        is_slower = (config_runtime - partner_runtime) > 0
        labels = 2 * is_slower - 1
        # A configuration compared with itself carries no ranking information.
        distinct_pair = torch.where(config_idxs != partner_idxs, 1, 0)
        pair_loss = self.loss_fn(outputs.view(-1, 1), partner_outputs.view(-1, 1), labels.view(-1, 1))
        return (pair_loss.view(bs, num_configs) * distinct_pair).mean()

    def forward(self,
                outputs: torch.Tensor,
                config_runtime: torch.Tensor,
                config_idxs: torch.Tensor
                ):
        """Average ``calculate_rank_loss`` over ``number_permutations`` independent permutations."""
        total = 0
        for _ in range(self.number_permutations):
            total = total + self.calculate_rank_loss(outputs, config_runtime, config_idxs)
        return total / self.number_permutations

## Metric

In [None]:

class TileTopK(tm.Metric):
    """
    Top-k tile metric: ``2 - (best runtime among the k configurations predicted fastest) / (best runtime)``,
    averaged over all graphs seen since the last reset.
    """

    higher_is_better = True

    def __init__(self, k:int=5) -> None:
        super().__init__()
        self.add_state("runtimes", default=[], dist_reduce_fx=None)
        self.k = k

    def update(self, preds: torch.Tensor, target: torch.Tensor, config_attn_mask:torch.Tensor) -> None:
        """
        Update the metric state
        Args:
            preds: Tensor of shape (bs, seq_len) with the predicted runtimes orders
            target: Tensor of shape (bs, seq_len) with the target runtimes
            config_attn_mask: Tensor of shape (bs, seq_len) with 1 in the positions of the elements
        """
        valid = config_attn_mask == 1
        inf = torch.tensor(float('inf'), device=target.device)
        # Masked positions get an infinite runtime / prediction so they can never be the best one.
        masked_target = torch.where(valid, target, inf)
        masked_preds = torch.where(valid, preds, inf)
        best_runtimes = masked_target.min(1).values
        # topk raises when k exceeds the number of configurations, so clamp it.
        k = min(self.k, preds.shape[1])
        pred_bottomk_indices = torch.topk(masked_preds, k=k, largest=False).indices
        # Read the runtimes from the masked target: when a row has fewer than k valid configurations,
        # top-k also returns masked slots, whose raw target value must not win the minimum below.
        predicted_runtimes = masked_target.gather(1, pred_bottomk_indices)
        best_predicted_runtimes = predicted_runtimes.min(1).values
        self.runtimes.append(best_predicted_runtimes/ best_runtimes)

    def compute(self) -> torch.Tensor:
        """Mean of ``2 - ratio`` over every row accumulated by ``update``."""
        return (2-torch.cat(self.runtimes)).mean()

## Model
Modified version of the 🤗 Bert implementation to take into account [Graph Attention](https://arxiv.org/abs/1710.10903)
* Removed the parts corresponding to cross-attention
* Made layer_head_mask the same for all layers and heads
* The head mask corresponds to the edge adjacency


In [None]:
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
class BertEncoder(nn.Module):
    """
    Stack of ``config.num_hidden_layers`` BertLayer modules applied one after another.

    The same ``attention_mask`` and ``head_mask`` are handed to every layer.
    """

    def __init__(self, config:GraphConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        # Follow the config flag (it used to be hard-coded to False, silently ignoring the setting).
        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """
        Args:
            hidden_states: Input embeddings of the first layer
            attention_mask: Optional mask that each layer adds to its attention scores
            head_mask: Optional mask that each layer multiplies into its attention probabilities
            output_attentions: Whether the attention probabilities of every layer are collected
            output_hidden_states: Whether the input of every layer and the final output are collected
            return_dict: Return a BaseModelOutputWithPastAndCrossAttentions instead of a plain tuple
        Returns:
            The last hidden state plus, if requested, all hidden states and all attention probabilities
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask #DONE: Same Head Mask for all layers

            if self.gradient_checkpointing and self.training:

                # checkpoint() only forwards positional tensors, so output_attentions is bound by closure.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs,  output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Only the entries that were actually requested are part of the tuple.
            return tuple(
                v
                for v in [
                    hidden_states,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=None,
        )
        
        
class BertLayer(nn.Module):
    """One encoder layer: self-attention followed by a (chunked) feed-forward block."""

    def __init__(self, config:GraphConfig):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Returns:
            Tuple whose first element is the layer output, followed by whatever extra entries the
            attention module returned (the attention probabilities when ``output_attentions`` is set)
        """
        attention_output, *attention_extras = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        # The feed-forward block is applied in chunks along the sequence dimension.
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output, *attention_extras)

    def feed_forward_chunk(self, attention_output):
        """Feed-forward block; ``attention_output`` is also the residual input of the output layer."""
        return self.output(self.intermediate(attention_output), attention_output)
    
class BertIntermediate(nn.Module):
    """Expansion half of the feed-forward block: linear projection to ``intermediate_size`` + activation."""

    def __init__(self, config:GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        activation = config.hidden_act
        # A string is looked up in ACT2FN; anything else is used as the activation function directly.
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
    
class BertOutput(nn.Module):
    """Projection half of the feed-forward block: linear back to ``hidden_size``, dropout, residual + LayerNorm."""

    def __init__(self, config:GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: Output of BertIntermediate
            input_tensor: Residual input that is added before the layer normalisation
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
    
class BertAttention(nn.Module):
    """Self-attention followed by its output projection (dense, dropout, residual + LayerNorm)."""

    def __init__(self, config:GraphConfig, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Returns:
            Tuple with the attention output first, followed by the attention probabilities when the
            self-attention module returned them
        """
        context, *attention_probs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions,
        )
        # `hidden_states` is the residual input of the output projection.
        return (self.output(context, hidden_states), *attention_probs)
    
    
class BertSelfAttention(nn.Module):
    """
    Multi-head self-attention.

    ``attention_mask`` is added to the scaled scores before the softmax, while ``head_mask`` is
    multiplied into the attention probabilities after the softmax (without renormalising them).
    """

    def __init__(self, config:GraphConfig, position_embedding_type=None):
        super().__init__()
        # The heads always split `hidden_size`, so the check must not be skipped just because the
        # config carries an `embedding_size` attribute (GraphConfig always does).
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)


    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (bs, seq, all_head_size) into (bs, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Args:
            hidden_states: Tensor of shape (bs, seq, hidden_size)
            attention_mask: Optional tensor added to the (bs, heads, seq, seq) attention scores
            head_mask: Optional tensor multiplied into the (bs, heads, seq, seq) attention probabilities
            output_attentions: Whether the attention probabilities are returned as well
        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)``; the context has shape
            (bs, seq, all_head_size)
        """

        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)


        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive: it is summed with the scores right before the softmax.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask #DONE: Same Head Mask for all Heads

        context_layer = torch.matmul(attention_probs, value_layer)

        # (bs, heads, seq, head_size) -> (bs, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class BertSelfOutput(nn.Module):
    """Output projection of the attention block: linear, dropout, residual + LayerNorm."""

    def __init__(self, config:GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: Context produced by the self-attention
            input_tensor: Residual input that is added before the layer normalisation
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
    
    
class NodeEncoder(nn.Module):
    """Node embedding: opcode embedding plus a bias-free linear projection of the node features, layer-normalised."""

    def __init__(self, config:GraphConfig):
        super().__init__()
        embedding_size = config.embedding_size
        # One extra row: index NODE_OP_CODES is the padding opcode.
        self.node_opcode_embeddings = nn.Embedding(NODE_OP_CODES+1 , embedding_size, padding_idx=NODE_OP_CODES)
        self.linear = nn.Linear(NODE_FEATS, embedding_size, bias=False)
        self.layer_norm = nn.LayerNorm(embedding_size, eps=config.layer_norm_eps)

    def forward(self,
                node_opcode: torch.Tensor,
                node_feat: torch.Tensor
                ) -> torch.Tensor:
        """
        Args:
            node_opcode: Integer tensor with the opcode of each node
            node_feat: Tensor whose last dimension has NODE_FEATS entries
        Returns:
            Tensor with ``config.embedding_size`` features per node
        """
        combined = self.node_opcode_embeddings(node_opcode) + self.linear(node_feat)
        return self.layer_norm(combined)
    
    
class BertNodeEncoder(nn.Module):
    """Encode the nodes of a graph (without configuration features) with the adjacency-masked Bert encoder."""

    def __init__(self, config:GraphConfig) -> None:
        super().__init__()
        self.config = config
        self.node_embeddings = NodeEncoder(config)
        self.node_encoder = BertEncoder(config)

    def forward(self,
                node_opcode: torch.Tensor,
                node_feat: torch.Tensor,
                edges_adjecency: torch.Tensor,
                node_attn_mask: torch.Tensor
                ):
        """
        Args:
            node_opcode: Tensor of shape (bs, num_nodes)
            node_feat: Tensor of shape (bs, num_nodes, num_node_feats)
            edges_adjecency: Tensor of shape (bs, num_nodes, num_nodes)
            node_attn_mask: Tensor of shape (bs, num_nodes) with 1 for real nodes and 0 for padding
        Returns:
            The output of BertEncoder (attention probabilities included)
        """
        node_embeddings = self.node_embeddings(node_opcode, node_feat)
        # The encoder ADDS the attention mask to the scores, so it has to be 0 for real keys and a very
        # negative number for padded keys, broadcast over heads and queries: (bs, 1, 1, num_nodes).
        # A raw 0/1 mask shaped (bs, 1, num_nodes, 1) only shifts whole rows and cancels in the softmax.
        key_mask = node_attn_mask.to(node_embeddings.dtype)[:, None, None, :]
        additive_mask = (1.0 - key_mask) * torch.finfo(node_embeddings.dtype).min
        # BertEncoder hands the same head mask to every layer and BertSelfAttention multiplies it into
        # (bs, heads, num_nodes, num_nodes) probabilities, so it needs shape (bs, 1, num_nodes, num_nodes);
        # a leading per-layer axis would broadcast the probabilities to 5-D and break the attention.
        node_encoder_outputs = self.node_encoder(node_embeddings,
                                                 attention_mask=additive_mask,
                                                 head_mask=edges_adjecency.unsqueeze(1),
                                                 output_attentions=True)
        return node_encoder_outputs
    
def transform_node_positional_embeddings(embeddings_output:torch.Tensor,
                                         node_config_ids:torch.Tensor,
                                         num_nodes:int
                                         ) -> torch.Tensor:
    """
    Scatter the embeddings of the config nodes to the positions of the graph nodes they belong to.
    Args:
        embeddings_output: Tensor of shape (bs, num_configs, num_config_nodes, dim)
        node_config_ids: Tensor of shape (bs, num_config_nodes) with the target node position of each
            config node (shared by all configurations of a graph)
        num_nodes: Number of node positions of the result
    Returns:
        Tensor of shape (bs, num_configs, num_nodes, dim); embeddings sent to the same position are
        summed and positions that receive nothing stay zero
    """
    bs, num_configs, _, dim = embeddings_output.shape
    # Expand the ids to the full shape of the source: (bs, num_configs, num_config_nodes, dim).
    target_positions = node_config_ids[:, None, :, None].repeat(1, num_configs, 1, dim)
    scattered = embeddings_output.new_zeros((bs, num_configs, num_nodes, dim))
    scattered.scatter_reduce_(2, target_positions, embeddings_output, reduce='sum')
    return scattered

class NodeFeatEmbeddings(nn.Module):
    """Embed the configuration features and place them at the positions of the nodes they configure."""

    def __init__(self, config:GraphConfig):
        super().__init__()
        self.config = config
        self.node_feat_embeddings = nn.Linear(NODE_CONFIG_FEATS + CONFIG_FEATS, config.embedding_size, bias=False)
        self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)

    def forward(self, node_config_feat: torch.Tensor, node_config_ids: torch.Tensor, num_nodes:int) -> torch.Tensor:
        """
        Args:
            node_config_feat: Tensor whose last dimension has NODE_CONFIG_FEATS + CONFIG_FEATS entries
            node_config_ids: Node position of each config node
            num_nodes: Number of node positions of the result
        Returns:
            The projected and layer-normalised features scattered with transform_node_positional_embeddings
        """
        projected = self.layer_norm(self.node_feat_embeddings(node_config_feat))
        return transform_node_positional_embeddings(projected, node_config_ids, num_nodes)
        
    
class BertGraphEncoder(nn.Module):
    """
    Encode every (graph, configuration) pair: the node embeddings of the graph are combined with the
    configuration embeddings and passed through the adjacency-masked Bert encoder.
    """

    def __init__(self, config:GraphConfig) -> None:
        super().__init__()
        self.config = config
        self.node_embeddings = NodeEncoder(config)
        self.node_encoder = BertEncoder(config)
        self.node_feat_embeddings = NodeFeatEmbeddings(config)

    def forward(self,
                node_opcode: torch.Tensor, # (bs, num_nodes)
                node_feat: torch.Tensor, # (bs, num_nodes, num_node_feats)
                edges_adjecency: torch.Tensor, # (bs, num_nodes, num_nodes)
                node_attn_mask: torch.Tensor, # (bs, num_nodes)
                node_config_feat: torch.Tensor, # (bs, num_configs, num_config_nodes, num_node_feats)
                node_config_ids: torch.Tensor, # (bs, num_config_nodes)
                ):
        """
        Returns:
            Tensor of shape (bs, num_configs, num_nodes, hidden_size) with the last hidden state of
            every node for every configuration
        """
        bs, num_nodes = node_opcode.shape
        num_configs = node_config_feat.shape[1]
        node_embeddings = self.node_embeddings(node_opcode, node_feat)
        node_config_feat_embeddings = self.node_feat_embeddings(node_config_feat, node_config_ids, num_nodes)

        # Broadcast the graph embeddings over the configurations and fold the configurations into the
        # batch dimension: (bs * num_configs, num_nodes, dim).
        node_embeddings = node_embeddings.unsqueeze(1) + node_config_feat_embeddings
        node_embeddings = node_embeddings.reshape(bs *num_configs, num_nodes, -1)

        # The encoder ADDS the attention mask to the scores, so it has to be 0 for real keys and a very
        # negative number for padded keys, broadcast over heads and queries: (bs * num_configs, 1, 1, num_nodes).
        # A raw 0/1 mask shaped (.., num_nodes, 1) only shifts whole rows and cancels in the softmax, which
        # let padded keys take part in the normalisation and made the result depend on the padding length.
        key_mask = node_attn_mask.unsqueeze(1).repeat(1, num_configs, 1).reshape(bs *num_configs, num_nodes)
        key_mask = key_mask.to(node_embeddings.dtype)[:, None, None, :]
        additive_mask = (1.0 - key_mask) * torch.finfo(node_embeddings.dtype).min

        # Same adjacency for every configuration of a graph; the singleton axis broadcasts over heads.
        edges_adjecency = edges_adjecency.unsqueeze(1).repeat(1, num_configs, 1, 1).reshape(bs *num_configs, num_nodes, num_nodes)
        edges_adjecency = edges_adjecency.unsqueeze(1)

        # Only the last hidden state is used, so the attention maps of the layers are not collected.
        node_encoder_outputs = self.node_encoder(node_embeddings,
                                                 attention_mask=additive_mask,
                                                 head_mask=edges_adjecency,
                                                 output_attentions=False)

        return node_encoder_outputs.last_hidden_state.reshape(bs, num_configs, num_nodes, -1)
    
    
class GraphEncoder(nn.Module):
    """
    Scores every configuration of a graph and, when runtimes are supplied,
    computes a ranking loss over those scores.

    The score of a configuration is a linear projection of the encoder output
    at node position 0.
    """

    config_class = GraphConfig

    def __init__(self, config:GraphConfig):
        super().__init__()
        self.config = config
        self.node_encoder = BertGraphEncoder(config)
        # Projects one hidden vector per configuration to a scalar score.
        self.head = nn.Linear(config.hidden_size, 1)
        self.loss_fn = MultiElementRankLoss(
            margin=config.margin,
            number_permutations=config.number_permutations,
        )

    def forward(self,
                node_opcode: torch.Tensor, # (bs, num_nodes)
                node_feat: torch.Tensor, # (bs, num_nodes, num_node_feats)
                edges_adjecency: torch.Tensor, # (bs, num_nodes, num_nodes)
                node_attn_mask: torch.Tensor, # (bs, num_nodes)
                node_config_feat: torch.Tensor, # (bs, num_configs, num_config_nodes, num_node_feats)
                node_config_ids: torch.Tensor, # (bs, num_configs, num_config_nodes)
                config_idxs: Optional[torch.Tensor] = None, # (bs, num_configs)
                config_runtime: Optional[torch.Tensor] = None,):
        """
        Args:
            config_idxs: forwarded unchanged to the loss function.
            config_runtime: when given, a ``'loss'`` entry is added to the result.

        Returns:
            dict with ``'outputs'`` (one score per configuration), ``'order'``
            (argsort of the scores along dim 1) and, optionally, ``'loss'``.
        """
        hidden = self.node_encoder(
            node_opcode,
            node_feat,
            edges_adjecency,
            node_attn_mask,
            node_config_feat,
            node_config_ids,
        )

        # Only node position 0 of each configuration feeds the scoring head.
        first_node_state = hidden[:, :, 0]
        scores = self.head(first_node_state).squeeze(-1)

        result = {'outputs': scores, 'order': torch.argsort(scores, dim=1)}
        if config_runtime is None:
            return result

        result['loss'] = 0 + self.loss_fn(scores, config_runtime, config_idxs)
        return result

In [None]:
class LightningWrapper(pl.LightningModule):
    """
    PyTorch Lightning wrapper around a model whose forward takes the collated
    batch as keyword arguments and returns a dict with ``'outputs'`` and,
    when targets are present, ``'loss'``.
    """

    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model
        # Validation metric; fed with (scores, runtimes, mask) in validation_step.
        self.topk = TileTopK()

    def forward(self, x):
        # Batches produced by the collator are dicts of keyword arguments;
        # anything else is passed through positionally as before.
        if isinstance(x, dict):
            return self.model(**x)
        return self.model(x)

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        return outputs['loss']

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss = outputs['loss']
        self.log("val_loss", loss, prog_bar=True)
        # Every configuration in the batch is treated as valid for the metric.
        config_attn_mask = torch.ones_like(batch['config_runtime'], device=batch['config_runtime'].device)
        self.topk.update(outputs['outputs'], batch['config_runtime'], config_attn_mask)
        return loss

    def on_validation_end(self) -> None:
        topk = self.topk.compute()
        self.print(f"topk {topk:.3f}")
        self.topk.reset()
        return super().on_validation_end()

    def test_step(self, batch, batch_idx):
        # Same calling convention as training/validation: the batch is a dict
        # of keyword arguments and the loss comes from the model's output dict.
        # (The previous `x, y = batch` / `self.model.loss` path could not work
        # with this model, which exposes no `loss` attribute.)
        outputs = self.model(**batch)
        loss = outputs.get('loss')
        # The model only returns a loss when the batch carries `config_runtime`.
        if loss is not None:
            self.log("test_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-3)
        return optimizer

# Training

In [None]:
# Hyper-parameters forwarded to GraphConfig.
config_kwargs = {
    "hidden_size": 128,
    "num_attention_heads": 4,
    "num_hidden_layers": 2,
    "intermediate_size": 64,
    "gradient_checkpointing": True,
    # Ranking-loss settings (read as config.margin / config.number_permutations).
    "margin": 0.1,
    "number_permutations": 4,
}

In [None]:
# Build the model configuration from the `config_kwargs` hyper-parameters.
config = GraphConfig(**config_kwargs)

In [None]:
# Instantiate the graph encoder, then rebind `model` to its Lightning wrapper
# (the bare encoder stays reachable as `model.model`).
model = GraphEncoder(config)
model = LightningWrapper(model)

In [None]:
# Display the tile file index built by generate_tile_df().
tile_df

In [None]:
# Partition the file index by its `split` column and build one dataset per split.
train_df = tile_df.query("split == 'train'").reset_index(drop=True)
valid_df = tile_df.query("split == 'valid'").reset_index(drop=True)
# NOTE(review): num_configs=24 is presumably the number of configurations
# sampled per graph — confirm against TileDataset.
train_dataset = TileDataset(train_df, num_configs=24)
valid_dataset = TileDataset(valid_df, num_configs=24)

In [None]:
# Both loaders rely on a `collate_fn` created in an earlier cell; only the
# training loader shuffles and keeps its worker processes alive between epochs.
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=8, num_workers=2, shuffle=True, persistent_workers=True)
valid_dataloader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=8, num_workers=2)

In [None]:
# Keyword arguments for pl.Trainer.
trainer_config = {
    "max_epochs": 50,
    "precision": 32,
    "gradient_clip_val": 1.0,
    "accumulate_grad_batches": 4,
    # Validation (and the top-k printout) only runs every 10th epoch.
    "check_val_every_n_epoch": 10,
}

In [None]:
# Allow lower-precision float32 matmul kernels (speed over accuracy).
torch.set_float32_matmul_precision("medium")
# Train with the settings from `trainer_config`, validating on `valid_dataloader`.
trainer = pl.Trainer(**trainer_config,)
trainer.fit(model, train_dataloader, valid_dataloader)

In [None]:
# Run tile inference on the first GPU when one is available, otherwise on CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Build the inference loader for the tile test split.
split = 'test'
test_tile_df = tile_df.query("split == @split").reset_index(drop=True)
# NOTE(review): num_configs=-1 presumably means "use every configuration" —
# confirm against TileDataset.
test_tile_ds = TileDataset(test_tile_df, num_configs=-1)
# Rebinds the global `collate_fn`; `targets` is False for the test split.
collate_fn = LayoutCollator(64, targets=split!="test")
# One graph per batch, in file-index order, so predictions line up with test_tile_df rows.
test_dataloader = DataLoader(test_tile_ds, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

In [None]:
# Move the Lightning-wrapped model to the inference device and switch to eval mode.
model.to(device)
model = model.eval()


In [None]:
def chunk_batch(batch, start_idx, end_idx):
    """
    Build a model input that covers only configurations [start_idx, end_idx).

    The graph-level entries are passed through by reference; only
    ``node_config_feat`` is sliced, along dim 1. Any other key of ``batch``
    is dropped.

    NOTE(review): ``node_config_ids`` is forwarded unsliced — confirm it has
    no per-configuration axis for the batches this is used with.
    """
    shared_keys = ('node_opcode', 'node_feat', 'edges_adjecency', 'node_attn_mask', 'node_config_ids')
    chunk = dict((key, batch[key]) for key in shared_keys)
    chunk['node_config_feat'] = batch['node_config_feat'][:, start_idx:end_idx]
    return chunk
    

In [None]:
# Tile inference: for every test graph, score all of its configurations and
# keep the indices of the 5 lowest scores.
pred_order = []
for batch in tqdm(test_dataloader):
    # `config_idxs` is not a tensor input needed for scoring; drop it before
    # moving the remaining entries to the device.
    batch.pop('config_idxs')
    batch = {k: v.to(device) for k, v in batch.items()}
    num_configs = batch['node_config_feat'].shape[1]
    # Chunk the configs to avoid OOM errors
    configs_cut_points = list(range(0,num_configs, 100)) + [num_configs]
    chunk_order = []
    for start, end in zip(configs_cut_points, configs_cut_points[1:]):
        chunked_batch = chunk_batch(batch, start, end)
        with torch.no_grad():
            # Call the wrapped encoder directly (model.model), bypassing the Lightning wrapper.
            output = model.model(**chunked_batch)
        # With batch_size=1 this adds a single row holding this chunk's scores.
        chunk_order.extend(output['outputs'].cpu().numpy())
    # Stitch the chunks back together; ascending argsort puts the lowest scores first.
    pred_order.append(np.argsort(np.concatenate(chunk_order))[:5])

In [None]:
# Serialise each top-5 index array as "i;j;k;l;m" and attach it to the test rows.
idxs_string = [";".join(map(str,elem)) for elem in pred_order]
test_tile_df['TopConfigs'] = idxs_string
# Keep only the two submission columns (rebinding makes this cell non-rerunnable
# without re-running the cell that creates test_tile_df).
test_tile_df = test_tile_df[['ID', 'TopConfigs']]
test_tile_df.head()

In [None]:
# Start from the sample submission, replace the rows we predicted for the tile
# collection, and write the merged file.
submission_df = pd.read_csv('../input/predict-ai-model-runtime/sample_submission.csv')
# Drop the sample rows whose ID we have a prediction for. `isin` does the same
# filtering as the former f-string `query`, without interpolating the whole ID
# list into an expression string.
submission_df = submission_df[~submission_df.ID.isin(test_tile_df.ID)]
submission_df = pd.concat([test_tile_df, submission_df])
submission_df.to_csv('submission.csv', index=False)
submission_df

In [None]:
# Install torch_geometric / torch_scatter from wheel files in an attached Kaggle dataset.
!pip install /kaggle/input/fast-slow-4-dataset-train/torch_geometric-2.3.1-py3-none-any.whl
!pip install /kaggle/input/fast-slow-4-dataset-train/torch_scatter-2.1.1-cp310-cp310-linux_x86_64.whl

In [None]:
# NOTE(review): unpinned install from the package index; needs network access.
!pip install timm


In [None]:
import timm
from timm.scheduler import  CosineLRScheduler
import numpy as np
import pandas as pd
import os
from tqdm import tqdm 

import sklearn,sklearn.model_selection
import torch
from torch import nn
from torch import Tensor
from torch_geometric.nn import GCNConv,SAGEConv
from torch_geometric.datasets import Planetoid
from torch.utils.data import DataLoader, Dataset
#from timm.scheduler import CosineLRScheduler
import matplotlib.pyplot as plt
# Rebinds the global `device`: the layout models below run on CPU.
device = 'cpu'

In [None]:
def load_df(directory):
    """
    Load every archive of the ``test`` split under ``directory``.

    Args:
        directory: folder that contains a ``test`` sub-folder of ``.npz`` files.

    Returns:
        dict mapping the split name (``"test"``) to a DataFrame with one row
        per file: one column per array stored in the archive plus a ``file``
        column holding the file name. Rows are ordered by file name.
    """
    splits = ["test"]
    dfs = dict()

    for split in splits:
        path = os.path.join(directory, split)
        records = []

        # os.listdir returns entries in arbitrary order; sort so that row
        # order is reproducible across runs and machines.
        for file in sorted(os.listdir(path)):
            # The context manager closes the archive's file handle once all
            # arrays have been materialised into the dict.
            with np.load(os.path.join(path, file)) as archive:
                record = dict(archive)
            record['file'] = file
            records.append(record)
        dfs[split] = pd.DataFrame(records)
    return dfs
# Test-split frames for the four layout collections ({xla, nlp} x {random, default}).
layout_xla_random = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/xla/random/")
layout_xla_default = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/xla/default/")
layout_nlp_default = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/nlp/default/")
layout_nlp_random = load_df("/kaggle/input/predict-ai-model-runtime/npz_all/npz/layout/nlp/random/")

In [None]:
class TileDataset(Dataset):
    """
    Row-wise view over a DataFrame of loaded ``.npz`` records.

    NOTE: this definition replaces the ``TileDataset`` used earlier in the
    notebook for every cell executed after it.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Returns:
            (config_feat, node_feat, node_opcode, edge_index, target) tensors.
            ``edge_index`` has its first two axes swapped relative to the
            stored array; ``target`` is the min-max scaled ``config_runtime``.
        """
        row = self.df.iloc[idx]
        config_feat = torch.tensor(row['node_config_feat'].astype(np.float32))
        node_feat = torch.tensor(row['node_feat'].astype(np.float32))
        node_opcode = torch.tensor(row['node_opcode'].astype(np.int64))
        edge_index = torch.tensor(np.swapaxes(row['edge_index'],0,1).astype(np.int64))
        target = (row['config_runtime']).astype(np.float32)
        # minmax scale the target, we only care about order
        if target.size:
            lowest = target.min()
            spread = target.max() - lowest
            # A constant runtime vector would otherwise divide by zero and
            # yield NaN targets; map it to all zeros instead.
            target = (target - lowest) / spread if spread > 0 else np.zeros_like(target)
        target = torch.tensor(target)
        return config_feat,node_feat,node_opcode,edge_index,target
    
class SimpleModel(torch.nn.Module):
    """
    GCN scorer for a single graph: one score per configuration.

    Node features are concatenated with a learned op-code embedding, passed
    through a stack of GCNConv layers and mean-pooled into one graph vector.
    That vector is appended to every configuration's mean-pooled feature
    vector and fed to a small MLP.

    Args:
        hidden_channels: non-empty list of GCN hidden widths.
        graph_feats: width of the last GCN layer (size of the graph vector).
        hidden_dim: width of the MLP hidden layers.
    """

    def __init__(self, hidden_channels, graph_feats, hidden_dim):
        super().__init__()
        op_embedding_dim = 4 # I choose 4-dimensional embedding
        node_feat_dim = 140   # per-node feature width
        config_feat_dim = 18  # per-node configuration feature width
        self.embedding = torch.nn.Embedding(120, #120 different op-codes
                                            op_embedding_dim,
                                           )
        assert len(hidden_channels)>0
        in_channels = op_embedding_dim+node_feat_dim
        # in_channels -> hidden_channels[0] -> ... -> hidden_channels[-1] -> graph_feats
        layer_dims = [in_channels, *hidden_channels, graph_feats]
        self.convs = torch.nn.ModuleList(
            [GCNConv(d_in, d_out) for d_in, d_out in zip(layer_dims, layer_dims[1:])]
        )

        # The MLP input is [config features | graph vector]; deriving its width
        # from graph_feats (instead of a literal 82) keeps the model usable for
        # any graph_feats, and hidden_dim now controls the hidden width. With
        # graph_feats=64 and hidden_dim=64 the layer shapes are unchanged.
        self.dense = torch.nn.Sequential(nn.Linear(config_feat_dim + graph_feats, hidden_dim),
                                         nn.ReLU(),
                                         nn.Linear(hidden_dim, hidden_dim),
                                         nn.ReLU(),
                                         nn.Linear(hidden_dim, 1),
                                        )

    def forward(self, x_cfg: Tensor,x_feat: Tensor, x_op: Tensor, edge_index: Tensor) -> Tensor:
        """Return a flat tensor with one score per configuration in ``x_cfg``."""
        # Collapse dim 1 of the configuration features into one vector per configuration.
        x_cfg = x_cfg.mean(dim=1)
        x = torch.concat([x_feat,self.embedding(x_op)],dim = 1)
        #pass though conv layers
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        # get 1d graph embedding using average pooling
        x_graph = torch.mean(x,0)

        # Append the same graph vector to every configuration row.
        x = torch.concat([x_cfg,x_graph.repeat((len(x_cfg),1))],axis=1)
        #put into dense nn
        x = torch.flatten(self.dense(x))
        return x

# Rebinds the global `model` (previously the Lightning-wrapped tile model) to the GCN scorer.
model = SimpleModel(hidden_channels = [16,32,16,48],graph_feats = 64,hidden_dim=64).to(device)
class SimpleModel2(torch.nn.Module):
    """
    GraphSAGE scorer for a single graph: one score per configuration.

    Same pipeline as SimpleModel but built from SAGEConv layers: node features
    plus an op-code embedding go through the convolution stack, are
    mean-pooled into one graph vector, and that vector is appended to every
    configuration's mean-pooled feature vector before a small MLP.

    Args:
        hidden_channels: non-empty list of convolution hidden widths.
        graph_feats: width of the last convolution (size of the graph vector).
        hidden_dim: width of the MLP hidden layers.
    """

    def __init__(self, hidden_channels, graph_feats, hidden_dim):
        super().__init__()
        op_embedding_dim = 4 # I choose 4-dimensional embedding
        node_feat_dim = 140   # per-node feature width
        config_feat_dim = 18  # per-node configuration feature width
        self.embedding = torch.nn.Embedding(120, #120 different op-codes
                                            op_embedding_dim,
                                           )
        assert len(hidden_channels)>0
        in_channels = op_embedding_dim+node_feat_dim
        # in_channels -> hidden_channels[0] -> ... -> hidden_channels[-1] -> graph_feats
        layer_dims = [in_channels, *hidden_channels, graph_feats]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(d_in, d_out) for d_in, d_out in zip(layer_dims, layer_dims[1:])]
        )

        # The MLP input is [config features | graph vector]; deriving its width
        # from graph_feats (instead of a literal 82) keeps the model usable for
        # any graph_feats, and hidden_dim now controls the hidden width. With
        # graph_feats=64 and hidden_dim=64 the layer shapes are unchanged.
        self.dense = torch.nn.Sequential(nn.Linear(config_feat_dim + graph_feats, hidden_dim),
                                         nn.ReLU(),
                                         nn.Linear(hidden_dim, hidden_dim),
                                         nn.ReLU(),
                                         nn.Linear(hidden_dim, 1),
                                        )

    def forward(self, x_cfg: Tensor,x_feat: Tensor, x_op: Tensor, edge_index: Tensor) -> Tensor:
        """Return a flat tensor with one score per configuration in ``x_cfg``."""
        # Collapse dim 1 of the configuration features into one vector per configuration.
        x_cfg = x_cfg.mean(dim=1)
        x = torch.concat([x_feat,self.embedding(x_op)],dim = 1)
        #pass though conv layers
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        # get 1d graph embedding using average pooling
        x_graph = torch.mean(x,0)

        # Append the same graph vector to every configuration row.
        x = torch.concat([x_cfg,x_graph.repeat((len(x_cfg),1))],axis=1)
        #put into dense nn
        x = torch.flatten(self.dense(x))
        return x

# GraphSAGE variant; NOTE(review): not referenced by the inference cells that follow.
model2 = SimpleModel2(hidden_channels = [16,32,16,48],graph_feats = 64,hidden_dim=64).to(device)

In [None]:
# Layout / XLA / default: average the scores of the 5 fold checkpoints per
# graph, rank the configurations by ascending mean score and write the ranking
# into the submission.
dataset = TileDataset(layout_xla_default["test"])
tile_xla_predictions = [[] for i in range(len(dataset))]
for fold in range(5):
    # NOTE(review): the 'xla_defalut' folder name is kept verbatim; presumably
    # it matches the attached dataset — confirm.
    model.load_state_dict(torch.load(f'/kaggle/input/fast-slow-sep/xla_defalut/layout_xla_default_best_model_{fold}.pth',map_location=torch.device('cpu') ))
    model.eval()
    pbar = tqdm(range(len(dataset)))
    for i in pbar:
        cfg_ft,nd_ft,nd_op,ind,_ = dataset[i]
        cfg_ft,nd_ft,nd_op,ind = cfg_ft.to(device),nd_ft.to(device),nd_op.to(device),ind.to(device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            out = model(cfg_ft,nd_ft,nd_op,ind)
        tile_xla_predictions[i].append(out.cpu().numpy())
# NOTE(review): [:-1] drops the last (highest-scored) configuration index from
# every ranking — confirm this truncation is intended.
tile_xla_predictions = [np.argsort(np.mean(pred,axis=0))[:-1] for pred in tile_xla_predictions]

# `sub` was never defined (both candidate assignments were commented out), so
# this cell failed on a fresh kernel. Start from the tile submission so the
# tile predictions are preserved.
sub = submission_df.copy()
# File names must come from the same frame the predictions were made on
# (previously the names of the *random* collection were used here).
for i,filename in enumerate(layout_xla_default["test"]['file'].values):
    submission_id = 'layout:xla:default:' +filename[:-4]
    print(submission_id)
    sub.loc[sub.ID == submission_id,'TopConfigs'] = ';'.join(tile_xla_predictions[i].astype(str))
sub.to_csv('submission.csv',index=False)
sub

In [None]:
# Layout / XLA / random: average the scores of the 5 fold checkpoints per
# graph and rank the configurations by ascending mean score.
dataset = TileDataset(layout_xla_random["test"])
tile_xla_predictions = [[] for i in range(len(dataset))]
for fold in range(5):
    model.load_state_dict(torch.load(f'/kaggle/input/fast-slow-sep/xla_random/layout_xla_default_best_model_{fold}.pth',map_location=torch.device('cpu') ))
    model.eval()
    pbar = tqdm(range(len(dataset)))
    for i in pbar:
        # `target` is returned by the dataset but not used for inference.
        cfg_ft,nd_ft,nd_op,ind,target = dataset[i]
        cfg_ft,nd_ft,nd_op,ind,target = cfg_ft.to(device),nd_ft.to(device),nd_op.to(device),ind.to(device),target.to(device)

        out = model(cfg_ft,nd_ft,nd_op,ind)
        tile_xla_predictions[i].append(out.cpu().detach().numpy())
# NOTE(review): [:-1] drops the last index of every ranking — confirm intended.
tile_xla_predictions = [np.argsort(np.mean(pred,axis=0))[:-1] for pred in tile_xla_predictions]
tile_xla_predictions[0]

# Requires the global `sub` submission frame to exist already; this cell only
# overwrites the rows of this collection and rewrites submission.csv.
for i,filename in enumerate(layout_xla_random["test"]['file'].values):
    # Strip the 4-character extension from the file name to form the ID.
    id = 'layout:xla:random:' +filename[:-4]
    print(id)
    sub.loc[sub.ID == id,'TopConfigs'] = ';'.join(tile_xla_predictions[i].astype(str))
sub.to_csv('submission.csv',index=False)
sub

In [None]:
# Layout / NLP / default: average the scores of the 5 fold checkpoints per
# graph and rank the configurations by ascending mean score.
dataset = TileDataset(layout_nlp_default["test"])
# NOTE(review): the `tile_xla_predictions` name is reused here for NLP layout predictions.
tile_xla_predictions = [[] for i in range(len(dataset))]
for fold in range(5):
    model.load_state_dict(torch.load(f'/kaggle/input/fast-slow-nlp-v3/nlp_default/layout_xla_default_best_model_{fold}.pth',map_location=torch.device('cpu') ))
    model.eval()
    pbar = tqdm(range(len(dataset)))
    for i in pbar:
        # `target` is returned by the dataset but not used for inference.
        cfg_ft,nd_ft,nd_op,ind,target = dataset[i]
        cfg_ft,nd_ft,nd_op,ind,target = cfg_ft.to(device),nd_ft.to(device),nd_op.to(device),ind.to(device),target.to(device)
        out = model(cfg_ft,nd_ft,nd_op,ind) 
        tile_xla_predictions[i].append(out.cpu().detach().numpy())
# NOTE(review): [:-1] drops the last index of every ranking — confirm intended.
tile_xla_predictions = [np.argsort(np.mean(pred,axis=0))[:-1] for pred in tile_xla_predictions]
tile_xla_predictions[0]

# Requires the global `sub` submission frame to exist already; this cell only
# overwrites the rows of this collection and rewrites submission.csv.
for i,filename in enumerate(layout_nlp_default["test"]['file'].values):
    # Strip the 4-character extension from the file name to form the ID.
    id = 'layout:nlp:default:' +filename[:-4]
    print(id)
    sub.loc[sub.ID == id,'TopConfigs'] = ';'.join(tile_xla_predictions[i].astype(str))
sub.to_csv('submission.csv',index=False)
sub

In [None]:
# Layout / NLP / random: average the scores of the fold checkpoints per graph
# and rank the configurations by ascending mean score.
dataset = TileDataset(layout_nlp_random["test"])
# NOTE(review): the `tile_xla_predictions` name is reused here for NLP layout predictions.
tile_xla_predictions = [[] for i in range(len(dataset))]
# NOTE(review): only 2 folds are ensembled here, versus 5 for the other
# collections — confirm this is intentional.
for fold in range(2):
    model.load_state_dict(torch.load(f'/kaggle/input/fast-slow-sep/nlp_random/layout_xla_default_best_model_{fold}.pth',map_location=torch.device('cpu') ))
    model.eval()
    
    pbar = tqdm(range(len(dataset)))
    for i in pbar:
        # `target` is returned by the dataset but not used for inference.
        cfg_ft,nd_ft,nd_op,ind,target = dataset[i]
        cfg_ft,nd_ft,nd_op,ind,target = cfg_ft.to(device),nd_ft.to(device),nd_op.to(device),ind.to(device),target.to(device)

        out = model(cfg_ft,nd_ft,nd_op,ind) 
        tile_xla_predictions[i].append(out.cpu().detach().numpy())
# NOTE(review): [:-1] drops the last index of every ranking — confirm intended.
tile_xla_predictions = [np.argsort(np.mean(pred,axis=0))[:-1] for pred in tile_xla_predictions]
tile_xla_predictions[0]

# Requires the global `sub` submission frame to exist already; this cell only
# overwrites the rows of this collection and rewrites submission.csv.
for i,filename in enumerate(layout_nlp_random["test"]['file'].values):
    # Strip the 4-character extension from the file name to form the ID.
    id = 'layout:nlp:random:' +filename[:-4]
    print(id)
    sub.loc[sub.ID == id,'TopConfigs'] = ';'.join(tile_xla_predictions[i].astype(str))
sub.to_csv('submission.csv',index=False)
sub