# Predict AI Model Runtimes 🤖💨

## Introduction 🌟
Welcome to this Jupyter notebook developed for the Google - Fast or Slow? Predict AI Model Runtime competition! This notebook is designed to help you participate in the competition, whose goal is to predict the runtime of AI model computation graphs under different compiler configurations.

### Inspiration and Credits 🙌
This notebook is inspired by the work of Bhukya Satheesh, available at [this Kaggle project](https://www.kaggle.com/code/satheeshbhukya1/google-fast-or-slow/notebook). I extend my gratitude to Bhukya Satheesh for sharing their insights and code.

🌟 Explore my profile and other public projects, and don't forget to share your feedback! 
👉 [Visit my Profile](https://www.kaggle.com/zulqarnainali) 👈

🙏 Thank you for taking the time to review my work, and please give it a thumbs-up if you found it valuable! 👍

## Purpose 🎯
The primary purpose of this notebook is to:
- Load and preprocess the competition data 📁
- Engineer relevant features for model training 🏋️‍♂️
- Train predictive models to make target variable predictions 🧠
- Submit predictions to the competition environment 📤

## Notebook Structure 📚
This notebook is structured as follows:
1. **Data Preparation**: In this section, we load and preprocess the competition data.
2. **Feature Engineering**: We generate and select relevant features for model training.
3. **Model Training**: We train machine learning models on the prepared data.
4. **Prediction and Submission**: We make predictions on the test data and submit them for evaluation.


## How to Use 🛠️
To use this notebook effectively, please follow these steps:
1. Ensure you have the competition data and environment set up.
2. Execute each cell sequentially to perform data preparation, feature engineering, model training, and prediction submission.
3. Customize and adapt the code as needed to improve model performance or experiment with different approaches.

**Note**: Make sure to replace any placeholder paths or configurations with your specific information.

## Acknowledgments 🙏
We acknowledge the Google organizers of the competition for providing the dataset and the competition platform.

Let's get started! Feel free to reach out if you have any questions or need assistance along the way.
👉 [Visit my Profile](https://www.kaggle.com/zulqarnainali) 👈

## 📦 Importing necessary libraries


In [1]:
# 📚 Importing libraries
import os  # For interacting with the operating system
import json  # For saving and loading configurations (used by GraphConfig)
from pathlib import Path  # For working with file paths
from typing import Dict, Optional, List, Union, Tuple  # For type hints
from dataclasses import dataclass, asdict  # For creating data classes and converting them to dicts
import math  # For mathematical operations
import numpy as np  # For numerical computations
import pandas as pd  # For data manipulation and analysis
from datasets import Dataset  # For handling datasets
from tqdm import tqdm  # For progress tracking
import torch  # For deep learning with PyTorch
from torch import nn  # For neural network modules
from torch.nn import functional as F  # For various functions used in neural networks
from torch.nn.utils.rnn import pad_sequence  # For padding sequences
from torch.utils.data import DataLoader  # For creating data loaders
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions  # For transformer model outputs
from transformers.pytorch_utils import apply_chunking_to_forward  # For handling chunking during model forward pass
from transformers.activations import ACT2FN  # For transformer activations
import pytorch_lightning as pl  # For PyTorch Lightning, a useful library for training
import torchmetrics as tm  # For additional metrics
# import bitsandbytes as bnb  


## 📊 Define constants and configuration values

**Explanation:**

```python
NODE_OP_CODES = 120
```

- This line defines a constant variable named `NODE_OP_CODES`.
- It assigns the value `120` to the `NODE_OP_CODES` variable.
- This variable likely represents the number of node operation codes used in your project.

```python
NODE_FEATS = 140
```

- This line defines another constant variable named `NODE_FEATS`.
- It assigns the value `140` to the `NODE_FEATS` variable.
- This variable likely represents the number of node features used in your project.

```python
CONFIG_FEATS = 24
```

- This line defines a constant variable named `CONFIG_FEATS`.
- It assigns the value `24` to the `CONFIG_FEATS` variable.
- This variable likely represents the number of configuration features used in your project.

```python
NODE_CONFIG_FEATS = 18
```

- This line defines a constant variable named `NODE_CONFIG_FEATS`.
- It assigns the value `18` to the `NODE_CONFIG_FEATS` variable.
- In this notebook it is the number of zero columns appended to each 24-wide tile configuration vector in `TileDataset`, giving 24 + 18 = 42 features per configuration. The value 18 presumably matches the per-node configuration features of the layout collection, so that tile and layout data can share one width.

These lines define the feature sizes as named constants. Defining them once keeps the dataset, collator, and model code consistent and makes the values easy to change in one place.
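Two of these constants meet in `TileDataset.__getitem__`, where each sampled configuration vector is right-padded with `NODE_CONFIG_FEATS` zeros. The minimal sketch below repeats only that padding step on random data (a stand-in for a real tile's `config_feat`), which shows where the width of 42 in `node_config_feat` comes from.

```python
import torch
import torch.nn.functional as F

CONFIG_FEATS = 24       # width of one tile configuration vector
NODE_CONFIG_FEATS = 18  # zero columns appended by TileDataset

# Random stand-in for 10 sampled configurations of one tile
config_feat = torch.rand(10, CONFIG_FEATS)

# Same step as in TileDataset.__getitem__: add a "node" axis of length 1,
# then right-pad the last dimension with NODE_CONFIG_FEATS zeros
node_config_feat = F.pad(config_feat.unsqueeze(1), (0, NODE_CONFIG_FEATS))

print(node_config_feat.shape)  # torch.Size([10, 1, 42])
```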

In [2]:
NODE_OP_CODES = 120  # Number of node operation codes
NODE_FEATS = 140     # Number of node features
CONFIG_FEATS = 24    # Number of configuration features
NODE_CONFIG_FEATS = 18  # Number of combined node and configuration features


## 📄 Function to Generate Tile DataFrame

**Explanation:**

- `DATA_DIR` is a constant that stores the directory path for the data files.
- `generate_tile_df` is a function used to generate a Pandas DataFrame containing information about tiles.
- Inside the function:
  - A DataFrame called `tile_df` is created with a 'paths' column containing file paths.
  - Additional columns are added to `tile_df` using the `.assign` method and lambda functions:
    - `split`: the name of the immediate parent directory (`train`, `valid`, or `test`).
    - `configuration`: the name of the parent's parent directory (`xla` for the tile collection).
    - `extra`: the name of the parent's parent's parent directory (`tile`).
    - `model_name`: the stem of the file path, i.e. the file name without the `.npz` extension.
    - `collection`: `extra` and `configuration` joined with `:`, e.g. `tile:xla`.
    - `ID`: `collection` and `model_name` joined with `:`, an identifier for each file.
    - `paths`: converts the `Path` objects to strings. This column is reassigned last because the earlier columns rely on `Path` attributes such as `.parent` and `.stem`.

This function processes the file paths and organizes the extracted metadata in a DataFrame, which makes it easy to filter the data, for example selecting the training files with `tile_df[tile_df.split == 'train']`.
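To make the directory arithmetic concrete, here is a small sketch that applies the same `.parent` / `.stem` chain to a single path. The path is illustrative: it is assembled from `DATA_DIR` and a model name that appears in the `tile_df` output, and assumes the `npz/<extra>/<configuration>/<split>/<model_name>.npz` layout. `PurePosixPath` is used so that nothing touches the file system.

```python
from pathlib import PurePosixPath

# Illustrative path following npz/<extra>/<configuration>/<split>/<model_name>.npz
path = PurePosixPath(
    "../input/predict-ai-model-runtime/npz_all/npz/tile/xla/valid/"
    "inception_v3_batch_128_train_40fa8f86f121f00a.npz"
)

split = path.parent.name                  # 'valid'
configuration = path.parent.parent.name   # 'xla'
extra = path.parent.parent.parent.name    # 'tile'
model_name = path.stem                    # file name without '.npz'
collection = extra + ":" + configuration  # 'tile:xla'
ID = collection + ":" + model_name

print(ID)  # tile:xla:inception_v3_batch_128_train_40fa8f86f121f00a
```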


In [3]:
DATA_DIR = "../input/predict-ai-model-runtime/npz_all/npz"


def generate_tile_df() -> pd.DataFrame:
    tile_df = pd.DataFrame({'paths': [elem for elem in (Path(DATA_DIR) / 'tile').rglob("*") if elem.is_file()]}).assign(
        split=lambda df: df.paths.apply(lambda x: x.parent.name),
        configuration=lambda df: df.paths.apply(lambda x: x.parent.parent.name),
        extra=lambda df: df.paths.apply(lambda x: x.parent.parent.parent.name),
        model_name=lambda df: df.paths.apply(lambda x: x.stem),
        collection=lambda df: df.extra + ':' + df.configuration ,
        ID=lambda df: df.collection + ':' + df.model_name ,
        paths = lambda df: df.paths.apply(lambda x: str(x))
    )
    return tile_df

## 🧩 Generating and Displaying Tile DataFrame

**Explanation:**

- In this cell, you are executing the `generate_tile_df` function to create a DataFrame named `tile_df` containing information about tiles.
- `tile_df.head()` is used to display the first few rows of the DataFrame, allowing you to inspect the data and its structure.


In [4]:
# 🧩 Generate the tile DataFrame using the previously defined function
tile_df = generate_tile_df()

# 📋 Display the first few rows of the tile DataFrame
tile_df.head()


| | paths | split | configuration | extra | model_name | collection | ID |
|---|---|---|---|---|---|---|---|
| 0 | ../input/predict-ai-model-runtime/npz_all/npz/... | valid | xla | tile | resnet_v1_50_official_batch_128_bf16_2bea628b7... | tile:xla | tile:xla:resnet_v1_50_official_batch_128_bf16_... |
| 1 | ../input/predict-ai-model-runtime/npz_all/npz/... | valid | xla | tile | inception_v3_batch_128_train_40fa8f86f121f00a | tile:xla | tile:xla:inception_v3_batch_128_train_40fa8f86... |
| 2 | ../input/predict-ai-model-runtime/npz_all/npz/... | valid | xla | tile | inception_v3_batch_128_train_-23e94c034a65a177 | tile:xla | tile:xla:inception_v3_batch_128_train_-23e94c0... |
| 3 | ../input/predict-ai-model-runtime/npz_all/npz/... | valid | xla | tile | inception_v3_batch_128_train_171f4371caf28639 | tile:xla | tile:xla:inception_v3_batch_128_train_171f4371... |
| 4 | ../input/predict-ai-model-runtime/npz_all/npz/... | valid | xla | tile | mlperf_bert_batch_24_2x2_-25e30862c042a2b8 | tile:xla | tile:xla:mlperf_bert_batch_24_2x2_-25e30862c04... |


## 📊 Definition of Functions and a Custom Dataset Class

**Explanation:**

```python
def edges_adjacency(edges: torch.Tensor, add_diagonal=True) -> torch.Tensor:
```

- This line defines a Python function named `edges_adjacency`.
- The function takes two arguments: `edges`, a PyTorch tensor of shape `(num_edges, 2)` in which each row is one edge given as a pair of node indices, and `add_diagonal`, a boolean that controls whether self-loops (ones on the diagonal) are added to the adjacency matrix.


```python
    adjacency_matrix = torch.zeros((edges.max() + 1, edges.max() + 1))
```

- This line initializes `adjacency_matrix` as an `(N, N)` matrix of zeros, where `N = edges.max() + 1`. This assumes that nodes are numbered from 0 and that the highest-numbered node appears in at least one edge; a trailing node with no edges would be left out of the matrix.

```python
    adjacency_matrix[edges[:, 0], edges[:, 1]] = 1
```

- Here, each edge sets the entry at row `edges[:, 0]` and column `edges[:, 1]` to `1`. Only that one entry is set per edge, so the matrix is not made symmetric.

```python
    if add_diagonal:
        diag_idx = torch.arange(adjacency_matrix.shape[0])
        adjacency_matrix[diag_idx, diag_idx] = 1
```

- If `add_diagonal` is `True`, every diagonal entry is set to `1`. These self-loops connect each node to itself, a common choice in graph neural networks.

```python
    return adjacency_matrix
```

- Finally, the function returns the generated adjacency matrix.

In summary, this function turns an edge list into a dense adjacency matrix, optionally with self-loops on the diagonal.
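A small worked example makes the behaviour easy to check. The sketch below re-declares the same logic (so it runs on its own) and applies it to a made-up three-node chain. The last call shows that a highest-numbered node with no edges is dropped, because the matrix size is taken from `edges.max()`.

```python
import torch

def edges_adjacency(edges: torch.Tensor, add_diagonal: bool = True) -> torch.Tensor:
    # Same logic as the notebook's function: N = largest node id + 1
    n = int(edges.max()) + 1
    adjacency_matrix = torch.zeros((n, n))
    adjacency_matrix[edges[:, 0], edges[:, 1]] = 1  # one entry per edge, not symmetrised
    if add_diagonal:
        diag_idx = torch.arange(n)
        adjacency_matrix[diag_idx, diag_idx] = 1    # self-loops
    return adjacency_matrix

# Made-up chain: edges (0, 1) and (1, 2)
edges = torch.tensor([[0, 1], [1, 2]])
print(edges_adjacency(edges))
# tensor([[1., 1., 0.],
#         [0., 1., 1.],
#         [0., 0., 1.]])

# A graph with nodes 0, 1, 2 but only the edge (0, 1): node 2 is lost
print(edges_adjacency(torch.tensor([[0, 1]]), add_diagonal=False).shape)  # torch.Size([2, 2])
```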


In [5]:
def edges_adjacency(edges: torch.Tensor, add_diagonal=True) -> torch.Tensor:
    """
    Generate an adjacency matrix from the edges
    Args:
        edges: Tensor of shape (num_edges, 2) with the edges
        add_diagonal: Boolean indicating if the diagonal should be added to the adjacency matrix
    Returns:
        adjacency_matrix: Tensor of shape (num_nodes, num_nodes) with the adjacency matrix
    """
    adjacency_matrix = torch.zeros((edges.max() + 1, edges.max() + 1))
    adjacency_matrix[edges[:, 0], edges[:, 1]] = 1
    if add_diagonal:
        diag_idx = torch.arange(adjacency_matrix.shape[0])
        adjacency_matrix[diag_idx, diag_idx] = 1
    return adjacency_matrix

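def tile_loader(path):
    """
    Load one tile .npz file as a dict of tensors and add its adjacency matrix
    (stored under the key 'edges_adjecency', spelled this way throughout the notebook)
    """
    tile_dict = dict(np.load(path))
    tile_dict = {k: torch.from_numpy(v) for k, v in tile_dict.items()}
    tile_dict['edges_adjecency'] = edges_adjacency(tile_dict['edge_index'])
    return tile_dict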

def node_cls_token(elem_dict, shift_node_config_ids:bool=True):
    """
    Add a cls token to the node opcode, features, edges adjacency matrix, shift node_config_ids by 1 to account for the cls token
    Args:
        elem_dict: Dictionary with the elements of the tile
    Returns:
        elem_dict: Dictionary with the elements of the tile with the cls token
    """
    elem_dict['node_opcode'] = torch.cat([torch.tensor([0]), elem_dict['node_opcode']])
    elem_dict['node_feat'] = torch.cat([torch.zeros((1, elem_dict['node_feat'].shape[1])), elem_dict['node_feat']])
    elem_dict['edges_adjecency'] = F.pad(elem_dict['edges_adjecency'], (1,0,1,0), value=1)
    if 'node_config_ids' in elem_dict and shift_node_config_ids:
        elem_dict['node_config_ids'] = elem_dict['node_config_ids'] + 1
    return elem_dict


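class TileDataset(torch.utils.data.Dataset):
    """
    Dataset of tile graphs. Each item loads one .npz file, samples `num_configs` configurations
    (all of them when num_configs == -1, only among the first `max_configs` when that is set)
    and optionally prepends a cls token node.
    """
    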
    def __init__(self, df:pd.DataFrame ,add_cls_token:bool=True, num_configs:int=10,  max_configs:Optional[int]=None):
        self.df = df
        self.add_cls_token = add_cls_token
        self.num_configs = num_configs
        self.max_configs = max_configs  
        
    def __len__(self) -> int:
        return len(self.df)
    
    def select_configs(self, total_configs:int):
        if self.max_configs is not None:
            total_configs = min(total_configs, self.max_configs)
        if self.num_configs == -1:
            return np.arange(total_configs)
        if total_configs < self.num_configs:
            return np.random.choice(total_configs, self.num_configs, replace=True)
        return  np.random.choice(total_configs, self.num_configs, replace=False)
    
    def __getitem__(self, idx:int, selected_configs:List[int]=None):
        tile_dict = tile_loader(self.df.paths.iloc[idx])  # positional lookup, also works on filtered DataFrames
        if selected_configs is None:
            selected_configs = self.select_configs(tile_dict['config_feat'].shape[0])
        tile_dict['node_config_feat'] = tile_dict.pop('config_feat')[selected_configs]
        tile_dict['node_config_feat'] = F.pad(tile_dict['node_config_feat'].unsqueeze(1), (0,NODE_CONFIG_FEATS))
        tile_dict['config_runtime'] = tile_dict['config_runtime'][selected_configs].float()
        tile_dict['config_runtime'] /= tile_dict['config_runtime_normalizers'][selected_configs].float()
        tile_dict['node_config_ids'] = torch.zeros((1,))
        tile_dict['selected_idxs'] = selected_configs
        if self.add_cls_token:
            tile_dict = node_cls_token(tile_dict, False)
        return tile_dict

## Creating a Tile Dataset



**Explanation:**


```python
tile_dataset = TileDataset(tile_df)
```

- `TileDataset` is the custom `torch.utils.data.Dataset` defined above. Each item corresponds to one tile graph stored in an `.npz` file.

- `tile_df` is the DataFrame built earlier; the dataset reads the file path of each item from its `paths` column.

- With the default arguments (`add_cls_token=True`, `num_configs=10`, `max_configs=None`), every access to an item samples 10 configurations through `select_configs` (without replacement when the tile has at least 10 configurations, with replacement otherwise) and prepends a cls token node to the graph.

- The resulting `tile_dataset` can be indexed directly or wrapped in a `DataLoader` for training.
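The sampling rule in `TileDataset.select_configs` is easier to see in isolation. The sketch below is a standalone copy of that method with one deliberate change: it takes a NumPy `Generator` through a hypothetical `rng` parameter instead of calling the global `np.random.choice`, so the draws are reproducible. The value 3246 is the number of configurations of the first tile, as reported by `config_runtime_normalizers` in the dataset output.

```python
from typing import Optional

import numpy as np

def select_configs(total_configs: int, num_configs: int = 10,
                   max_configs: Optional[int] = None,
                   rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Standalone copy of TileDataset.select_configs with an explicit RNG (hypothetical parameter)."""
    if rng is None:
        rng = np.random.default_rng()
    if max_configs is not None:
        total_configs = min(total_configs, max_configs)  # only the first max_configs are eligible
    if num_configs == -1:
        return np.arange(total_configs)                   # keep every configuration, in order
    if total_configs < num_configs:
        return rng.choice(total_configs, num_configs, replace=True)  # too few: repeats allowed
    return rng.choice(total_configs, num_configs, replace=False)     # distinct configurations

rng = np.random.default_rng(0)
print(select_configs(3246, rng=rng))      # 10 distinct indices in [0, 3246)
print(select_configs(4, rng=rng))         # 10 indices in [0, 4), necessarily with repeats
print(select_configs(5, num_configs=-1))  # [0 1 2 3 4]
```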

In [6]:
tile_dataset = TileDataset(tile_df)

## Accessing Tile Data from the Dataset



**Explanation:**

The provided code retrieves and inspects data from the `tile_dataset` created earlier. Let's go through it line by line:

```python
elem = tile_dataset[0]
```

- This line calls `tile_dataset.__getitem__(0)`, which loads the first tile's `.npz` file and returns a dictionary of tensors describing that graph and a sample of its configurations. The dictionary is stored in `elem`.

- Because the configurations are sampled at random, indexing the same item again can return a different set of 10 configurations.

```python
for k, v in elem.items():
    print(k, v.shape)
```

- This block iterates through the key-value pairs in the `elem` dictionary and prints each key (the name of a data component) together with the shape of its value.

- All values are PyTorch tensors except `selected_idxs`, which is a NumPy array, so its shape prints as a plain tuple such as `(10,)`.

This lets you inspect the structure of one sample, the names of its components and their shapes, which is a useful check before further processing or model training.


In [7]:
elem = tile_dataset[0]
for k,v in elem.items():
    print(k, v.shape)

node_feat torch.Size([81, 140])
node_opcode torch.Size([81])
edge_index torch.Size([86, 2])
config_runtime torch.Size([10])
config_runtime_normalizers torch.Size([3246])
edges_adjecency torch.Size([81, 81])
node_config_feat torch.Size([10, 1, 42])
node_config_ids torch.Size([1])
selected_idxs (10,)
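These shapes follow from a few steps in `TileDataset.__getitem__`. `node_feat`, `node_opcode` and `edges_adjecency` have 81 rows because `node_cls_token` prepends one node (so the file presumably holds 80 nodes). `config_runtime` has 10 entries because only the sampled configurations are kept, while `config_runtime_normalizers` is not subsampled and still has one entry per configuration (3246). The sketch below re-declares `node_cls_token` and applies it to a made-up three-node graph with two features per node, to show what the cls node looks like.

```python
import torch
import torch.nn.functional as F

def node_cls_token(elem_dict, shift_node_config_ids: bool = True):
    # Same steps as the notebook's node_cls_token
    elem_dict['node_opcode'] = torch.cat([torch.tensor([0]), elem_dict['node_opcode']])
    elem_dict['node_feat'] = torch.cat([torch.zeros((1, elem_dict['node_feat'].shape[1])), elem_dict['node_feat']])
    # Pad one row on top and one column on the left with ones: the cls node links to every node
    elem_dict['edges_adjecency'] = F.pad(elem_dict['edges_adjecency'], (1, 0, 1, 0), value=1)
    if 'node_config_ids' in elem_dict and shift_node_config_ids:
        elem_dict['node_config_ids'] = elem_dict['node_config_ids'] + 1
    return elem_dict

# Made-up tile: 3 nodes, 2 features each, only self-loops
tile = {
    'node_opcode': torch.tensor([5, 7, 9]),
    'node_feat': torch.ones(3, 2),
    'edges_adjecency': torch.eye(3),
}
out = node_cls_token(tile, shift_node_config_ids=False)
print(out['node_opcode'])      # tensor([0, 5, 7, 9])
print(out['node_feat'].shape)  # torch.Size([4, 2])
print(out['edges_adjecency'])
# tensor([[1., 1., 1., 1.],
#         [1., 1., 0., 0.],
#         [1., 0., 1., 0.],
#         [1., 0., 0., 1.]])
```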


## Custom Functions for Data Preparation and Collation

**Explanation:**

```python
def pad_edge_adjacency(edges_adjacency_list):
    max_len = max([elem.shape[0] for elem in edges_adjacency_list])
    return torch.stack([F.pad(elem, (0, max_len-elem.shape[0], 0, max_len-elem.shape[0]), value=0) for elem in edges_adjacency_list], dim=0)
```

- `pad_edge_adjacency` is a function that takes a list of edge adjacency matrices (`edges_adjacency_list`) and pads them to have the same size.
- It calculates the maximum length among the adjacency matrices.
- Then, it uses a list comprehension to iterate through the matrices, pad each matrix to the maximum size, and stack them into a tensor along a new dimension (dimension 0).

```python
@dataclass
class LayoutCollator:
    pad_to_multiple_of: int = 64
    targets: bool = True
    padding_idx: int = 120
    node_padding_idx: int = 0
```

- `LayoutCollator` is a dataclass that defines a collation strategy for preparing batches of data.
- It includes several parameters:
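  - `pad_to_multiple_of`: node-level and config-level sequence lengths are padded up to a multiple of this value (default is 64).
  - `targets`: whether `config_runtime` is included in the batch (default is True); set it to False for data without runtime labels.
  - `padding_idx`: the opcode value used to pad `node_opcode` (default is 120, i.e. `NODE_OP_CODES`, one past the largest real opcode if opcodes run from 0 to 119).
  - `node_padding_idx`: the value written into padded positions of `node_config_feat` (default is 0).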

```python
def __call__(self, batch):
    # Implementation of collation logic
    # ...
    return output
```

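- The `LayoutCollator` class defines a `__call__` method, so an instance can be called like a function, for example as a `DataLoader`'s `collate_fn`.
- Inside `__call__`, the node-level tensors (`node_opcode`, `node_feat`, `edges_adjecency`, and a `node_attn_mask` marking real nodes) are padded to a common length, `node_config_ids` and `node_config_feat` are padded along the config-node axis, the sampled `config_idxs` are stacked, and, if `targets` is set, `config_runtime` is padded.
- The processed data is returned as the dictionary `output`.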

The purpose of this code is to provide a data collation strategy for creating batches of data, ensuring that the data has consistent dimensions and padding. This is often necessary when working with neural network models that require inputs of the same size within a batch.
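The amount of padding is computed as `pad_to_multiple_of - max_len % pad_to_multiple_of`. The sketch below compares that formula with a common alternative, `(-max_len) % multiple`, which is not what the notebook uses and assumes `multiple > 0`. Both pad an 81-node graph to 128, but when the longest sequence is already a multiple of 64, the notebook's formula adds a full extra block of 64 positions. That is presumably harmless as long as padded positions are masked, but it costs extra compute.

```python
def original_pad_amount(max_len: int, multiple: int = 64) -> int:
    # Formula used in LayoutCollator.__call__
    return multiple - max_len % max(multiple, 1)

def next_multiple_pad_amount(max_len: int, multiple: int = 64) -> int:
    # Alternative (not the notebook's): pad only up to the next multiple, 0 if already aligned
    return (-max_len) % multiple

for n in (1, 81, 128):
    print(n, original_pad_amount(n), next_multiple_pad_amount(n))
# 1 63 63
# 81 47 47
# 128 64 0
```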


In [8]:
def pad_edge_adjacency(edges_adjacency_list):
    max_len = max([elem.shape[0] for elem in edges_adjacency_list])
    return torch.stack([F.pad(elem, (0, max_len-elem.shape[0], 0, max_len-elem.shape[0]), value=0) for elem in edges_adjacency_list], dim=0)

@dataclass
class LayoutCollator:
    pad_to_multiple_of: int = 64
    targets:bool = True
    padding_idx:int = 120
    node_padding_idx:int = 0
    
    def __call__(self, batch):
        output = {}
        max_node_len = max([elem['node_opcode'].shape[0] for elem in batch])
        node_pad_amount = self.pad_to_multiple_of - max_node_len % max(self.pad_to_multiple_of, 1)
        output['node_opcode'] = F.pad(pad_sequence([elem['node_opcode'] for elem in batch], batch_first=True, padding_value=self.padding_idx),
                                      (0, node_pad_amount), value=self.padding_idx).long()
        output['node_feat'] = F.pad(pad_sequence([elem['node_feat'] for elem in batch], batch_first=True),
                                    (0,0,0, node_pad_amount), value=0)
        output['edges_adjecency'] = F.pad(pad_edge_adjacency([elem['edges_adjecency'] for elem in batch]),
                                          (0, node_pad_amount, 0, node_pad_amount), value=0)
        output['node_attn_mask'] = F.pad(pad_sequence([torch.ones(len(elem['node_opcode'])) for elem in batch], batch_first=True),
                                         (0, node_pad_amount), value=0)

        max_node_config_len = max([elem['node_config_ids'].shape[0] for elem in batch])
        node_config_pad_amount = self.pad_to_multiple_of - max_node_config_len % max(self.pad_to_multiple_of, 1)
        output['node_config_ids'] = F.pad(pad_sequence([elem['node_config_ids'] for elem in batch], batch_first=True),
                                         (0, node_config_pad_amount), value=0).long()
        padded_node_config_feat = pad_sequence([elem['node_config_feat'].permute(1,0,2) for elem in batch], batch_first=True, padding_value=-1)
        padded_node_config_feat = F.pad(padded_node_config_feat.permute(0,2,1,3),
                                           (0,0,0, node_config_pad_amount,0,0), value=-1)
        
        output['node_config_feat'] = torch.where(padded_node_config_feat!=-1, padded_node_config_feat, self.node_padding_idx)
                                      
        output['config_idxs'] = torch.stack([torch.from_numpy(elem['selected_idxs']) for elem in batch])
        
        if self.targets:
            output['config_runtime'] = pad_sequence([elem['config_runtime'].float() for elem in batch], batch_first=True)
        return output

## Creating a Collate Function Using `LayoutCollator`

**Explanation:**


```python
collate_fn = LayoutCollator(64)
```

- `LayoutCollator` is the dataclass defined earlier, which is used to define a collation strategy for preparing batches of data.

- The code line creates an instance of the `LayoutCollator` class and assigns it to the variable `collate_fn`.

- The value `64` is passed positionally, so it sets the first field, `pad_to_multiple_of`: node and config sequence lengths in each batch are padded up to a multiple of 64.

By creating this instance, you have defined a collation function (`collate_fn`) that can be used with data loaders to prepare batches of data. The collation function ensures that data within a batch is padded and structured consistently, which is often a requirement when training neural network models.

You can later use `collate_fn` when creating data loaders to customize how data is processed and collated when loading batches. 
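A collator instance is handed to `torch.utils.data.DataLoader` through its `collate_fn` argument. Running `LayoutCollator` itself needs the competition files, so the sketch below uses a hypothetical toy dataset and a much smaller hypothetical dataclass collator (`ToyGraphDataset`, `ToyCollator`) to show the same pattern: any callable object works as `collate_fn`, and each batch is padded to its own longest sequence.

```python
from dataclasses import dataclass

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class ToyGraphDataset(Dataset):
    """Hypothetical stand-in for TileDataset: returns dicts of variable-length tensors."""
    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        return {'node_opcode': torch.arange(self.lengths[idx])}

@dataclass
class ToyCollator:
    """Hypothetical, much smaller collator in the spirit of LayoutCollator."""
    padding_idx: int = 120

    def __call__(self, batch):
        # Pad every sequence in the batch to the longest one, using padding_idx
        return {'node_opcode': pad_sequence([b['node_opcode'] for b in batch],
                                            batch_first=True, padding_value=self.padding_idx)}

loader = DataLoader(ToyGraphDataset([3, 5, 2, 4]), batch_size=2, shuffle=False,
                    collate_fn=ToyCollator())
for batch in loader:
    print(batch['node_opcode'].shape)
# torch.Size([2, 5])
# torch.Size([2, 4])
```

With the notebook's own objects, the equivalent call would be `DataLoader(tile_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)`; the batch size there is only an example.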


In [9]:
collate_fn = LayoutCollator(64)

## Collating a Batch of Data Using `collate_fn`

**Explanation:**

```python
batch = collate_fn([tile_dataset[0], tile_dataset[1]])
```

- `collate_fn` is the collation function that you previously created using the `LayoutCollator` dataclass. It defines how data should be prepared and padded when creating batches.

- `[tile_dataset[0], tile_dataset[1]]` is a list containing two elements from the `tile_dataset`. These two elements represent two samples or tiles.

- The `collate_fn` is called with this list of elements, resulting in the collation of these samples into a batch. The `batch` variable holds the collated batch of data.

```python
for k, v in batch.items():
    print(k, v.shape)
```

- This block of code iterates through the key-value pairs in the `batch` dictionary.

- For each key-value pair, it prints the key (the name of a data component) and the shape of the corresponding tensor.

In summary, this code uses `collate_fn` to collate two samples from `tile_dataset` into one batch and then prints the shape of each component (node opcodes, features, adjacency matrices, and more). In the output below, the node dimension is padded to 128, a multiple of 64 that covers the 81 nodes of the first tile, and the config-node dimension, which has length 1 for tiles, is padded to 64.

This is a crucial step when working with deep learning models that require consistent batch input shapes, as it ensures that the data is appropriately padded and structured for model training.

In [10]:
batch = collate_fn([tile_dataset[0], tile_dataset[1]])
for k,v in batch.items():
    print(k,v.shape)

node_opcode torch.Size([2, 128])
node_feat torch.Size([2, 128, 140])
edges_adjecency torch.Size([2, 128, 128])
node_attn_mask torch.Size([2, 128])
node_config_ids torch.Size([2, 64])
node_config_feat torch.Size([2, 10, 64, 42])
config_idxs torch.Size([2, 10])
config_runtime torch.Size([2, 10])


## Defining `GraphConfig` Dataclass

**Explanation:**

- `@dataclass` is a decorator that simplifies the creation of classes for storing data with default values.

- The `GraphConfig` class defines various hyperparameters as class attributes, each with a default value. These attributes include the number of hidden layers, hidden size, number of attention heads, intermediate size, dropout probabilities, and other hyperparameters commonly used in neural network models.

- The `__post_init__` method runs after the generated `__init__` and sets `embedding_size` equal to `hidden_size`. Because `embedding_size` is a plain attribute rather than a dataclass field, `save_config` does not write it; it is recomputed when a configuration is loaded.

- The `validate` method checks if the hidden size is a multiple of the number of attention heads and raises a `ValueError` if it's not. This is a common validation step in models using attention mechanisms.

- `save_config` and `load_config` methods are defined to save the configuration to a JSON file and load it from a JSON file, respectively. These methods allow you to save and load model configurations easily.

This `GraphConfig` dataclass serves as a convenient way to configure hyperparameters for graph-based models and provides methods for validation and serialization of configurations.
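The sketch below exercises validation and the JSON round trip on `ToyGraphConfig`, a hypothetical two-field subset of `GraphConfig`. Its `validate` raises whenever `hidden_size` is not divisible by `num_attention_heads`, and the output shows that `embedding_size`, being set in `__post_init__`, is not written to the file but is recomputed on load.

```python
import json
import tempfile
from dataclasses import dataclass, asdict
from pathlib import Path

@dataclass
class ToyGraphConfig:
    """Hypothetical two-field subset of GraphConfig."""
    hidden_size: int = 256
    num_attention_heads: int = 16

    def __post_init__(self):
        self.embedding_size = self.hidden_size  # plain attribute, not a dataclass field

    def validate(self):
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(f"hidden_size ({self.hidden_size}) is not a multiple of "
                             f"num_attention_heads ({self.num_attention_heads})")

    def save_config(self, path):
        with open(path, 'w') as f:
            json.dump(asdict(self), f)  # only declared fields are serialised

    @classmethod
    def load_config(cls, path):
        with open(path) as f:
            return cls(**json.load(f))

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "config.json"
    cfg = ToyGraphConfig(hidden_size=128, num_attention_heads=8)
    cfg.validate()
    cfg.save_config(path)
    print(path.read_text())                      # {"hidden_size": 128, "num_attention_heads": 8}
    loaded = ToyGraphConfig.load_config(path)
    print(loaded == cfg, loaded.embedding_size)  # True 128
```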


In [11]:
@dataclass
class GraphConfig:
    num_hidden_layers: int = 8
    hidden_size: int = 256
    num_attention_heads: int = 16
    intermediate_size: int = 64
    chunk_size_feed_forward: int = 64
    attention_probs_dropout_prob: float = 0.0
    max_position_embeddings: int = 512
    hidden_dropout_prob: float = 0.0
    layer_norm_eps: float = 1e-12
    hidden_act: str = 'gelu'
    initializer_range: float = 0.02
    output_hidden_states: bool = False
    output_attentions: bool = False
    gradient_checkpointing: bool = False
    margin: float = 0.1
    number_permutations: int = 10
    
    def __post_init__(self):
        self.embedding_size = self.hidden_size
    
    def validate(self):
        # The hidden size must split evenly across the attention heads
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )
            
    def save_config(self, path):
        config = asdict(self)
        with open(path, 'w') as f:
            json.dump(config, f)
            
    @classmethod
    def load_config(cls, path):
        with open(path, 'r') as f:
            config = json.load(f)
        return cls(**config)

## Definition of the `MultiElementRankLoss` Module

**Explanation:**

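This cell defines a custom PyTorch module, `MultiElementRankLoss`, that implements a pairwise ranking loss. It randomly permutes the sampled configurations, pairs every configuration with its permuted partner, and applies `MarginRankingLoss` so that the model's scores are ordered the same way as the true runtimes.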

```python
class MultiElementRankLoss(nn.Module):
    """
    Loss function that compares the output of the model with the output of the model with a permutation of the elements
    """
```

- This class definition begins with a docstring that briefly describes the purpose of the loss function.

```python
    def __init__(self, margin: float = 0.0, number_permutations: int = 1) -> None:
        super().__init__()
        self.loss_fn = torch.nn.MarginRankingLoss(margin=margin, reduction='none')
        self.number_permutations = number_permutations
```

- The `MultiElementRankLoss` class has an `__init__` method used for initializing instances of the class.
- It takes two arguments:
  - `margin`: A margin parameter for the margin ranking loss (default is 0.0).
  - `number_permutations`: The number of permutations to consider when calculating the loss (default is 1).
- Within the `__init__` method:
  - `super().__init__()` initializes the parent class (`nn.Module`).
  - `self.loss_fn` is created as an instance of `torch.nn.MarginRankingLoss` with the specified margin and a reduction mode of 'none'.
  - `self.number_permutations` is set based on the provided `number_permutations`.

```python
    def generate_permutation(self, config_attn_mask: torch.Tensor):
        # Implementation of permutation generation logic
        # ...
        return permutation
```

- The `generate_permutation` method generates a permutation of the elements in the batch based on a provided attention mask. It calculates the number of elements in each sequence, generates random permutations, and returns the resulting permutation tensor.

```python
    def permute_tensor(self, tensor: torch.Tensor, permutation: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
        # Implementation of tensor permutation logic
        # ...
        return permuted_tensor
```

- The `permute_tensor` method takes a tensor, a permutation tensor, and coordinates for elements to be permuted. It creates a new tensor with elements permuted according to the provided permutation.

```python
    def calculate_rank_loss(self, outputs: torch.Tensor, config_runtime: torch.Tensor, config_idxs: torch.Tensor):
        # Implementation of rank loss calculation logic
        # ...
        return loss
```

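- The `calculate_rank_loss` method draws one random permutation of the config positions (shared by all rows in the batch) and pairs each configuration with the one at its permuted position. The label is +1 when the configuration's runtime is larger than its partner's and -1 otherwise, so ties count as -1. Pairs whose config indices are identical are masked out, and the mean of the masked margin ranking losses over all positions is returned as a scalar.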

```python
    def forward(self, outputs: torch.Tensor, config_runtime: torch.Tensor, config_idxs: torch.Tensor):
        loss = 0
        for _ in range(self.number_permutations):
            loss += self.calculate_rank_loss(outputs, config_runtime, config_idxs)
        return loss / self.number_permutations
```

- The `forward` method is the main entry point of the module. It calls `calculate_rank_loss` `self.number_permutations` times, each time with a fresh random permutation, and returns the average. The helpers `generate_permutation` and `permute_tensor` are not called by `forward`.

This custom loss function, `MultiElementRankLoss`, is designed for ranking tasks, where only the relative order of the predictions matters, not their absolute values. `margin` sets how far apart two scores must be before a correctly ordered pair stops contributing to the loss, and `number_permutations` controls how many random pairings are averaged per call.
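To see what the loss rewards, the sketch below re-implements the core of `calculate_rank_loss` with the permutation passed in rather than drawn by `torch.randperm`, so the result is deterministic (`pairwise_rank_loss` is a hypothetical name). With three configurations whose runtimes are 2, 1 and 3, scores ordered like the runtimes give zero loss, while swapping the first two scores is penalized. A pair stops contributing once the slower configuration's score exceeds its partner's by at least `margin`, so the model learns scores that increase with runtime.

```python
import torch

def pairwise_rank_loss(outputs, runtimes, config_idxs, permutation, margin=0.1):
    """Deterministic version of calculate_rank_loss: the permutation is an argument."""
    loss_fn = torch.nn.MarginRankingLoss(margin=margin, reduction='none')
    permuted_idxs = config_idxs[:, permutation]
    config_mask = (config_idxs != permuted_idxs).float()  # drop pairs of a config with itself
    permuted_runtime = runtimes[:, permutation]
    # +1 if a config is slower than its partner, -1 otherwise (ties included)
    labels = (2 * ((runtimes - permuted_runtime) > 0) - 1).float()
    permuted_output = outputs[:, permutation]
    loss = loss_fn(outputs.view(-1, 1), permuted_output.view(-1, 1), labels.view(-1, 1))
    return (loss.view(outputs.shape) * config_mask).mean()

runtimes = torch.tensor([[2.0, 1.0, 3.0]])
config_idxs = torch.tensor([[0, 1, 2]])
permutation = torch.tensor([1, 2, 0])

good = torch.tensor([[0.5, 0.1, 0.9]])  # same order as the runtimes
bad = torch.tensor([[0.1, 0.5, 0.9]])   # first two scores swapped
print(pairwise_rank_loss(good, runtimes, config_idxs, permutation))  # zero loss
print(pairwise_rank_loss(bad, runtimes, config_idxs, permutation))   # about 0.1667 (0.5 / 3)
```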

In [12]:
class MultiElementRankLoss(nn.Module):
    """
    Loss function that compares the output of the model with the output of the model with a permutation of the elements
    """
    
    def __init__(self, margin:float=0.0, number_permutations:int = 1) -> None:
        super().__init__()
        self.loss_fn = torch.nn.MarginRankingLoss(margin=margin, reduction = 'none')
        self.number_permutations = number_permutations
        
    def generate_permutation(self,
                             config_attn_mask: torch.Tensor
                             ):
        """
        Generate a permutation of the elements in the batch
        Args:
            config_attn_mask: Tensor of shape (bs, seq_len) with 1 in the positions of the elements
            and 0 in the positions of the padding
        Returns:
            permutation: Tensor of shape (2, bs*seq_len) with the permutation of the elements
        """
        num_elements = config_attn_mask.sum(1)
        permutation_list = [torch.randperm(int(elem)) for elem in  num_elements.cpu().numpy()]
        idxs_list = [num*torch.ones_like(elem) for num, elem in enumerate(permutation_list)]
        permutation = torch.stack([torch.cat(idxs_list), torch.cat(permutation_list)])
        return permutation
    
    def permute_tensor(self,
                       tensor:torch.Tensor,
                       permutation:torch.Tensor,
                       x:torch.Tensor,
                       y:torch.Tensor
                       ):
        """
        Permute the tensor according to the permutation
        Args:
            tensor: Tensor of shape (bs, seq_len) to be permuted
            permutation: Tensor of shape (2, bs*seq_len) with the permutation of the elements
            x: Tensor of shape (bs*seq_len) with the x coordinates of the elements to be permuted
            y: Tensor of shape (bs*seq_len) with the y coordinates of the elements to be permuted
        Returns:
            permuted_tensor: Tensor of shape (bs, seq_len) with the permuted elements
        """
        new_tensor = tensor.clone()
        new_tensor[x, y] = new_tensor[permutation[0, :], permutation[1, :]]
        return new_tensor
    
    def calculate_rank_loss(self,
                            outputs: torch.Tensor,
                            config_runtime: torch.Tensor,
                            config_idxs: torch.Tensor
                            ):
        """
        Compares each prediction with the prediction at a randomly permuted position and applies MarginRankingLoss to every pair
        Args:
            outputs: Tensor of shape (bs, num_configs) with the outputs of the model
            config_runtime: Tensor of shape (bs, num_configs) with the measured runtimes of the configurations
            config_idxs: Tensor of shape (bs, num_configs) with the configuration indices; pairs whose indices match
            (an element compared with itself) are masked out
        Returns:
            loss: Scalar tensor with the mean masked pairwise loss over the batch
        """
        bs, num_configs = outputs.shape
        permutation = torch.randperm(num_configs) 
        permuted_idxs = config_idxs[:, permutation]
        config_mask = torch.where(config_idxs != permuted_idxs, 1, 0)
        permuted_runtime = config_runtime[:, permutation]
        labels = 2*((config_runtime - permuted_runtime) > 0) -1
        permuted_output = outputs[:, permutation]
        loss = self.loss_fn(outputs.view(-1,1), permuted_output.view(-1,1), labels.view(-1,1))
        loss = loss.view(bs, num_configs) * config_mask
        return loss.mean()
                
    
    def forward(self,
                outputs: torch.Tensor,
                config_runtime: torch.Tensor,
                config_idxs: torch.Tensor
                ):
        loss = 0 
        for _ in range(self.number_permutations):
            loss += self.calculate_rank_loss(outputs, config_runtime, config_idxs)
        return loss/ self.number_permutations
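To see what `calculate_rank_loss` computes, it can be reproduced on a toy batch. The sketch below is a minimal standalone version under a hypothetical name, `permutation_rank_loss`. Unlike the class, it takes the permutation as an argument instead of drawing it with `torch.randperm`, so the result is reproducible. Scores that follow the runtimes give zero loss, while reversed scores are penalized on every pair.

```python
import torch


def permutation_rank_loss(outputs: torch.Tensor,
                          runtimes: torch.Tensor,
                          idxs: torch.Tensor,
                          permutation: torch.Tensor,
                          margin: float = 0.0) -> torch.Tensor:
    """Hypothetical standalone version of MultiElementRankLoss.calculate_rank_loss.

    outputs, runtimes, idxs: (bs, num_configs); permutation: (num_configs,).
    """
    loss_fn = torch.nn.MarginRankingLoss(margin=margin, reduction="none")
    # Ignore pairs where a configuration would be compared with itself
    mask = (idxs != idxs[:, permutation]).float()
    # +1 where the configuration is slower than its partner, -1 otherwise,
    # so slower configurations are pushed towards larger scores
    labels = 2 * (runtimes > runtimes[:, permutation]).float() - 1
    loss = loss_fn(outputs.reshape(-1), outputs[:, permutation].reshape(-1), labels.reshape(-1))
    return (loss.view_as(outputs) * mask).mean()


runtimes = torch.tensor([[3.0, 1.0, 4.0, 2.0]])
idxs = torch.arange(4).unsqueeze(0)
permutation = torch.tensor([1, 2, 3, 0])  # fixed derangement: no element maps to itself

# Scores ordered like the runtimes: every pair is ranked correctly
good = permutation_rank_loss(runtimes.clone(), runtimes, idxs, permutation)
# Scores in the opposite order: every pair is ranked incorrectly
bad = permutation_rank_loss(-runtimes, runtimes, idxs, permutation)
print(good.item(), bad.item())  # 0.0 2.0
```

Because `forward` averages this quantity over `number_permutations` random permutations, each prediction is compared against several partners per batch rather than one.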

## Definition of the `TileTopK` Metric

**Explanation**

This cell defines a custom metric named `TileTopK` that extends the `torchmetrics.Metric` class. This metric is designed to measure the performance of a model in a top-k ranking task for tile runtimes.

```python
class TileTopK(tm.Metric):
    
    higher_is_better = True
```

- `TileTopK` is defined as a subclass of `torchmetrics.Metric`. Its `higher_is_better` attribute is set to `True`, meaning higher values of this metric indicate better rankings.

```python
    def __init__(self, k: int = 5) -> None:
        super().__init__()
        self.add_state("runtimes", default=[], dist_reduce_fx=None)
        self.k = k
```

- The `__init__` method initializes instances of the `TileTopK` metric. It takes one argument, `k`, which specifies the value of k for the top-k ranking.
- Inside the method:
  - `super().__init__()` initializes the parent class (`torchmetrics.Metric`).
  - `self.add_state("runtimes", default=[], dist_reduce_fx=None)` registers a list state named "runtimes" that accumulates one tensor per `update` call. It starts empty, and with `dist_reduce_fx=None` the per-process lists are gathered during distributed synchronization without any reduction being applied.
  - `self.k` is set based on the provided `k` argument.

```python
    def update(self, preds: torch.Tensor, target: torch.Tensor, config_attn_mask: torch.Tensor) -> None:
        # Implementation of metric state update logic (see the full code cell below)
        ...
```

- The `update` method updates the metric state from the predicted scores (`preds`), the true runtimes (`target`), and an attention mask (`config_attn_mask`) that marks valid (non-padding) configurations with 1.
- Within this method, the following logic is implemented:
  - It computes the best (lowest) true runtime of each graph, replacing padded positions with `inf` so they are ignored.
  - It replaces padded predictions with `inf` and uses `torch.topk(..., largest=False)` to get the indices of the k configurations with the lowest predicted scores.
  - It constructs `bottom_k_positions`, a (2, bs*k) tensor of (row, configuration) index pairs, to look up the true runtimes of those k configurations.
  - It takes the lowest true runtime among the k selected configurations.
  - It appends the ratio of that runtime to the best true runtime (always at least 1) to the "runtimes" state variable.

```python
    def compute(self) -> torch.Tensor:
        # Implementation of metric computation logic (see the full code cell below)
        ...
```

- The `compute` method calculates the final metric value from the accumulated state.
- It concatenates the stored ratios and returns the mean of `2 - ratio`. A graph scores 1 when the fastest configuration is among the k selected ones, and its score drops as the best selected configuration gets slower than the true optimum.

The `TileTopK` metric is designed to assess the model's performance in ranking tile runtimes using a top-k approach. It maintains a state variable to accumulate values during updates and computes the final metric value when requested.


In [13]:

class TileTopK(tm.Metric):
    
    higher_is_better = True
    
    def __init__(self, k:int=5) -> None:
        super().__init__()
        self.add_state("runtimes", default=[], dist_reduce_fx=None)
        self.k = k
        
    def update(self, preds: torch.Tensor, target: torch.Tensor, config_attn_mask:torch.Tensor) -> None:
        """
        Update the metric state
        Args:
            preds: Tensor of shape (bs, seq_len) with the predicted runtimes orders
            target: Tensor of shape (bs, seq_len) with the target runtimes
            config_attn_mask: Tensor of shape (bs, seq_len) with 1 in the positions of the elements
        """
        best_runtimes = torch.where(config_attn_mask==1, target, torch.tensor(float('inf'))).min(1).values
        masked_preds = torch.where(config_attn_mask==1, preds, torch.tensor(float('inf')))
        pred_bottomk_indices = torch.topk(masked_preds, k=self.k, largest=False).indices
        bs = preds.shape[0]
        bottom_k_positions = torch.stack([torch.arange(bs).repeat_interleave(self.k).to(config_attn_mask.device), pred_bottomk_indices.view(-1)])
        predicted_runtimes = target[bottom_k_positions[0], bottom_k_positions[1]].view(bs,self.k)
        best_predicted_runtimes = predicted_runtimes.min(1).values
        self.runtimes.append(best_predicted_runtimes/ best_runtimes)
        
    def compute(self) -> torch.Tensor:
        return (2-torch.cat(self.runtimes)).mean()
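For a single batch, `update` followed by `compute` can be condensed into one function. The sketch below uses a hypothetical helper, `tile_topk_score`, that is not part of the notebook. It uses `torch.gather` in place of the explicit `bottom_k_positions` index; both select the same entries. The second row shows that a padded configuration is ignored even when its predicted score is the lowest.

```python
import torch


def tile_topk_score(preds: torch.Tensor,
                    target: torch.Tensor,
                    mask: torch.Tensor,
                    k: int = 5) -> torch.Tensor:
    """Hypothetical standalone version of TileTopK.update + compute for one batch."""
    inf = torch.tensor(float("inf"), device=target.device)
    # Best true runtime per graph, ignoring padded configurations
    best_runtimes = torch.where(mask == 1, target, inf).min(dim=1).values
    # Indices of the k configurations with the lowest predicted scores
    masked_preds = torch.where(mask == 1, preds, inf)
    topk_idx = torch.topk(masked_preds, k=k, largest=False).indices  # (bs, k)
    # True runtime of the best configuration among those k
    best_in_topk = torch.gather(target, 1, topk_idx).min(dim=1).values
    return (2 - best_in_topk / best_runtimes).mean()


target = torch.tensor([[3.0, 1.0, 4.0, 2.0],
                       [2.0, 3.0, 8.0, 100.0]])  # last entry of row 2 is padding
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0]])
preds = torch.tensor([[0.5, 0.2, 0.1, 0.9],
                      [0.3, 0.1, 0.2, -5.0]])  # padded score is ignored by the mask

# Row 1 finds the fastest config (score 1.0); row 2's best pick is 3.0 vs 2.0 (score 0.5)
print(tile_topk_score(preds, target, mask, k=2).item())  # 0.75
```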

## Custom Model

**Explanation**:
The code defines a neural network model for graph encoding and ranking. 

1. **BertEncoder**: This class is a part of the BERT-based architecture used for graph encoding. It is responsible for stacking multiple BertLayer instances, and it can optionally return the attention maps and hidden states of every layer.

2. **BertLayer**: This class represents a single layer within the BertEncoder. It consists of BertAttention, BertIntermediate, and BertOutput components. Each layer processes the input sequentially through these components.

3. **BertIntermediate**: This class takes the output of the attention layer and applies a linear transformation followed by an activation function (e.g., GELU).

4. **BertOutput**: This class further processes the intermediate output, applying linear transformation, dropout, and layer normalization.

5. **BertAttention**: This class wraps the self-attention block. It runs `BertSelfAttention` on the input and passes the result, together with the input as a residual, through `BertSelfOutput`.

6. **BertSelfAttention**: This class handles the core self-attention computation: query, key, and value projections, scaled dot-product scores, softmax, dropout, and optional head masking. It also supports relative position embeddings when `position_embedding_type` is `relative_key` or `relative_key_query`.

7. **BertSelfOutput**: This class takes the attention output and applies linear transformation, dropout, and layer normalization.

8. **NodeEncoder**: This class encodes node information. It sums a learned embedding of each node's opcode with a linear projection of its features and applies layer normalization.

9. **BertNodeEncoder**: This class combines `NodeEncoder` with a `BertEncoder` to encode node-level information in a BERT-like fashion. It passes the edge adjacency matrix as the head mask, so attention probabilities between unconnected nodes are set to zero.

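10. **transform_node_positional_embeddings**: This function scatters the per-configurable-node embeddings into a dense tensor of shape (bs, num_configs, num_nodes, dim), placing each embedding at the node index given by `node_config_ids` and leaving all other nodes at zero.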

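11. **NodeFeatEmbeddings**: This class projects the node configuration features with a linear layer, applies layer normalization, and places the result at the corresponding node positions using `transform_node_positional_embeddings`.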

12. **BertGraphEncoder**: This class repeats the node embeddings once per configuration, adds the node configuration embeddings, and passes the result through a BERT-like encoder with the adjacency matrix as the head mask.

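13. **GraphEncoder**: This class is the overall model. It encodes each (graph, configuration) pair with `BertGraphEncoder`, scores each configuration by applying a linear head to the hidden state of the first node, and returns the scores together with their ascending order. When `config_runtime` is provided, it also computes the `MultiElementRankLoss` ranking loss.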

Overall, the model combines BERT-like layers with custom components for encoding both nodes and node configurations, restricts attention to graph edges, and uses the result to rank the configurations of each graph.
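The scatter step in `transform_node_positional_embeddings` is the least obvious part of the list above. The sketch below restates it under a hypothetical name, `scatter_config_features`, on a tiny example where only nodes 1 and 3 are configurable. It assumes `node_config_ids` has shape (bs, num_config_nodes), which is what the `unsqueeze(1).repeat(1, num_configs, 1)` call in the original function requires.

```python
import torch


def scatter_config_features(config_emb: torch.Tensor,
                            node_config_ids: torch.Tensor,
                            num_nodes: int) -> torch.Tensor:
    """Hypothetical restatement of transform_node_positional_embeddings.

    config_emb: (bs, num_configs, num_config_nodes, dim)
    node_config_ids: (bs, num_config_nodes), node index of each configurable node
    """
    bs, num_configs, _, dim = config_emb.shape
    # Broadcast the node indices over configurations and embedding dimensions
    idx = node_config_ids.unsqueeze(1).repeat(1, num_configs, 1)
    idx = idx.unsqueeze(-1).repeat(1, 1, 1, dim)
    dense = torch.zeros(bs, num_configs, num_nodes, dim, dtype=config_emb.dtype)
    # Sum each configurable node's embedding into its slot along the node axis
    dense.scatter_reduce_(2, idx, config_emb, reduce="sum")
    return dense


config_emb = torch.arange(12, dtype=torch.float32).view(1, 2, 2, 3)
node_config_ids = torch.tensor([[1, 3]])  # configurable nodes are nodes 1 and 3
dense = scatter_config_features(config_emb, node_config_ids, num_nodes=4)
print(dense.shape)  # torch.Size([1, 2, 4, 3])
```

Non-configurable nodes receive zeros, so adding this tensor to the node embeddings leaves them unchanged.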

In [14]:
# Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
class BertEncoder(nn.Module):
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask #DONE: Same Head Mask for all layers

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs,  output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=None,
        )
        
        
class BertLayer(nn.Module):
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs


        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
    
class BertIntermediate(nn.Module):
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
    
class BertOutput(nn.Module):
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
    
class BertAttention(nn.Module):
    def __init__(self, config:GraphConfig, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()


    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
    
    
class BertSelfAttention(nn.Module):
    def __init__(self, config:GraphConfig, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)


    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)


        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask #DONE: Same Head Mask for all Heads

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
    
    
class NodeEncoder(nn.Module):
    
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.node_opcode_embeddings = nn.Embedding(NODE_OP_CODES+1 , config.embedding_size, padding_idx=NODE_OP_CODES)
        self.linear = nn.Linear(NODE_FEATS, config.embedding_size, bias=False)
        self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        
        
    def forward(self,
                node_opcode: torch.Tensor,
                node_feat: torch.Tensor
                ) -> torch.Tensor:
        opcode_embeddings = self.node_opcode_embeddings(node_opcode) 
        node_feats =  self.linear(node_feat)
        features = opcode_embeddings + node_feats
        features = self.layer_norm(features)
        return features
    
    
class BertNodeEncoder(nn.Module):
    
    def __init__(self, config:GraphConfig) -> None:
        super().__init__()
        self.config = config
        self.node_embeddings = NodeEncoder(config)
        self.node_encoder = BertEncoder(config)
        
    def forward(self,
                node_opcode: torch.Tensor,
                node_feat: torch.Tensor,
                edges_adjecency: torch.Tensor,
                node_attn_mask: torch.Tensor
                ):
        node_embeddings = self.node_embeddings(node_opcode, node_feat)
        node_attn_mask = node_attn_mask.unsqueeze(1).unsqueeze(-1)
        node_encoder_outputs = self.node_encoder(node_embeddings,
                                                 attention_mask=node_attn_mask,
                                                 head_mask=edges_adjecency.unsqueeze(0).repeat(self.config.num_hidden_layers, 1, 1, 1).unsqueeze(2),
                                                 output_attentions=True)
        return node_encoder_outputs
    
def transform_node_positional_embeddings(embeddings_output:torch.Tensor,
                                         node_config_ids:torch.Tensor,
                                         num_nodes:int
                                         ) -> torch.Tensor:
    bs, num_configs, _, dim = embeddings_output.shape
    idxs = node_config_ids.unsqueeze(1).repeat(1,num_configs,1)
    zeros = torch.zeros(bs, num_configs, num_nodes, dim, device=embeddings_output.device, dtype=embeddings_output.dtype)
    idxs = idxs.unsqueeze(-1).repeat(1,1,1,dim)
    zeros.scatter_reduce_(2, idxs, embeddings_output, reduce='sum')
    return zeros

class NodeFeatEmbeddings(nn.Module):
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.config = config
        self.node_feat_embeddings = nn.Linear(NODE_CONFIG_FEATS + CONFIG_FEATS, config.embedding_size, bias=False)
        self.layer_norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
        
    def forward(self, node_config_feat: torch.Tensor, node_config_ids: torch.Tensor, num_nodes:int) -> torch.Tensor:
        node_config_feat_embeddings = self.node_feat_embeddings(node_config_feat)
        node_config_feat_embeddings = self.layer_norm(node_config_feat_embeddings)
        node_config_feat_embeddings = transform_node_positional_embeddings(node_config_feat_embeddings, node_config_ids, num_nodes)
        return node_config_feat_embeddings
        
    
class BertGraphEncoder(nn.Module):
    def __init__(self, config:GraphConfig) -> None:
        super().__init__()
        self.config = config
        self.node_embeddings = NodeEncoder(config)
        self.node_encoder = BertEncoder(config)
        self.node_feat_embeddings = NodeFeatEmbeddings(config)
        
    def forward(self,
                node_opcode: torch.Tensor, # (bs, num_nodes)
                node_feat: torch.Tensor, # (bs, num_nodes, num_node_feats)
                edges_adjecency: torch.Tensor, # (bs, num_nodes, num_nodes)
                node_attn_mask: torch.Tensor, # (bs, num_nodes)
                node_config_feat: torch.Tensor, # (bs, num_configs, num_config_nodes, num_node_feats)
                node_config_ids: torch.Tensor, # (bs, num_config_nodes)
                ):
        bs, num_nodes = node_opcode.shape
        num_configs = node_config_feat.shape[1]
        node_embeddings = self.node_embeddings(node_opcode, node_feat)
        node_config_feat_embeddings = self.node_feat_embeddings(node_config_feat, node_config_ids, num_nodes)
        
        node_embeddings = node_embeddings.unsqueeze(1).repeat(1, num_configs, 1, 1)
        node_embeddings += node_config_feat_embeddings
        node_attn_mask = node_attn_mask.unsqueeze(1).repeat(1, num_configs, 1)
        node_embeddings = node_embeddings.reshape(bs *num_configs, num_nodes, -1)
        node_attn_mask = node_attn_mask.reshape(bs *num_configs, num_nodes)
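        # NOTE: this 0/1 mask is added to the attention scores as-is. With shape (bs*num_configs, 1, num_nodes, 1)
        # it adds one constant per query row, which softmax ignores, so in practice masking comes from
        # edges_adjecency (passed as head_mask below)
        node_attn_mask = node_attn_mask.unsqueeze(1).unsqueeze(-1)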
        edges_adjecency = edges_adjecency.unsqueeze(1).repeat(1, num_configs, 1, 1).reshape(bs *num_configs, num_nodes, num_nodes)
        edges_adjecency = edges_adjecency.unsqueeze(1)
        

        node_encoder_outputs = self.node_encoder(node_embeddings,
                                                 attention_mask=node_attn_mask,
                                                 # head_mask=edges_adjecency.unsqueeze(0).repeat(self.config.num_hidden_layers, 1, 1, 1).unsqueeze(2),
                                                 head_mask=edges_adjecency,
                                                 output_attentions=True)
        
        return node_encoder_outputs.last_hidden_state.reshape(bs, num_configs, num_nodes, -1)
    
    
class GraphEncoder(nn.Module):
    
    config_class = GraphConfig
    
    def __init__(self, config:GraphConfig):
        super().__init__()
        self.config = config
        self.node_encoder = BertGraphEncoder(config)
        self.head = nn.Linear(config.hidden_size, 1)
        self.loss_fn = MultiElementRankLoss(margin=config.margin, number_permutations=config.number_permutations)
        
        
    def forward(self,
                node_opcode: torch.Tensor, # (bs, num_nodes)
                node_feat: torch.Tensor, # (bs, num_nodes, num_node_feats)
                edges_adjecency: torch.Tensor, # (bs, num_nodes, num_nodes)
                node_attn_mask: torch.Tensor, # (bs, num_nodes)
                node_config_feat: torch.Tensor, # (bs, num_configs, num_config_nodes, num_node_feats)
                node_config_ids: torch.Tensor, # (bs, num_config_nodes)
                config_idxs: Optional[torch.Tensor] = None, # (bs, num_configs)
                config_runtime: Optional[torch.Tensor] = None,):
        
        last_hidden_state = self.node_encoder(node_opcode,
                                    node_feat,
                                    edges_adjecency,
                                    node_attn_mask,
                                    node_config_feat,
                                    node_config_ids)
        
        output = self.head(last_hidden_state[:,:,0]).squeeze(-1)
        outputs = {'outputs': output, 'order': torch.argsort(output, dim=1)}
        if config_runtime is not None:
            loss = 0
            loss += self.loss_fn(output, config_runtime, config_idxs)
            outputs['loss'] = loss
        return outputs
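In `BertGraphEncoder`, the adjacency matrix is passed as `head_mask`, so `BertSelfAttention` multiplies the post-softmax attention probabilities by it. `node_attn_mask` is added to the scores as a raw 0/1 tensor of shape `(bs * num_configs, 1, num_nodes, 1)`. The sketch below isolates that computation in a hypothetical `edge_masked_attention` function (a single step with no dropout or relative positions) to show two consequences. First, attention to non-neighbours becomes exactly zero. Second, a mask of that shape broadcasts over the key dimension and therefore leaves the softmax unchanged. For contrast, it also applies a conventional additive key mask (`-1e4` on padded keys). That mask is an assumption for illustration and is not what the notebook does.

```python
import math

import torch


def edge_masked_attention(q, k, v, adjacency, additive_mask=None):
    """Hypothetical sketch of BertSelfAttention with head_mask = adjacency.

    q, k, v: (batch, heads, nodes, head_dim); adjacency: (batch, 1, nodes, nodes) of 0/1.
    """
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
    if additive_mask is not None:
        scores = scores + additive_mask
    probs = torch.softmax(scores, dim=-1)
    # Zero out attention to non-neighbours; rows are not renormalized afterwards
    probs = probs * adjacency
    return probs @ v, probs


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
adjacency = torch.tensor([[1, 1, 0, 0],
                          [1, 1, 1, 0],
                          [0, 1, 1, 1],
                          [0, 0, 1, 1]], dtype=torch.float32).view(1, 1, 4, 4)

_, probs = edge_masked_attention(q, k, v, adjacency)

# A 0/1 mask shaped (batch, 1, nodes, 1) adds one constant per query row,
# which softmax ignores, so the probabilities do not change
node_mask = torch.tensor([1.0, 1.0, 1.0, 0.0]).view(1, 1, 4, 1)
_, probs_row_mask = edge_masked_attention(q, k, v, adjacency, node_mask)
print(torch.allclose(probs, probs_row_mask, atol=1e-6))  # True

# A conventional additive key mask shaped (batch, 1, 1, nodes) does exclude node 3
key_mask = (1.0 - node_mask.view(1, 1, 1, 4)) * -1e4
_, probs_key_mask = edge_masked_attention(q, k, v, adjacency, key_mask)
print(probs_key_mask[..., 3].abs().max().item())  # 0.0
```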

##  `LightningWrapper` class

**Explanation**:

The `LightningWrapper` class is a PyTorch Lightning module that wraps the neural network model, so it can be trained, validated, and tested with the PyTorch Lightning framework.

1. **Initialization**:
   - The `__init__` method initializes the `LightningWrapper` object. It takes a neural network model (`model`) as a parameter.
   - It initializes a `TileTopK` instance called `topk`. This is used for computing top-k metrics during validation.

2. **Forward Pass**:
   - The `forward` method defines how input data should be passed through the model. It delegates the forward pass to the underlying `model`.

3. **Training Step**:
   - The `training_step` method is used for computing the loss during training. It takes a batch of data (`batch`) and a batch index (`batch_idx`) as parameters.
   - It calls the forward pass of the model with the input batch (`batch`) and retrieves the `loss` from the model's output.
   - It returns the computed loss, which will be used for optimization.

4. **Validation Step**:
   - The `validation_step` method is similar to `training_step` but is used during the validation phase.
   - It also calls the forward pass of the model with the input batch and retrieves the loss.
   - Additionally, it logs the validation loss using `self.log` and updates the `TileTopK` instance with top-k metrics using `self.topk.update`.

5. **Validation End**:
   - The `on_validation_end` method is called at the end of the validation phase. It computes the top-k metric from the collected data using `self.topk.compute()` and prints it.
   - After printing, it resets the `TileTopK` instance using `self.topk.reset()`.

6. **Test Step**:
   - The `test_step` method is used for computing the loss during testing. It takes a batch of data (`batch`) and a batch index (`batch_idx`) as parameters.
   - It runs the model on the batch and computes the loss against the ground-truth runtimes.
   - The test loss is logged using `self.log`.

7. **Optimizer Configuration**:
   - The `configure_optimizers` method configures the optimizer to be used during training. In this case, it sets up an AdamW optimizer for the model's parameters with a learning rate of 1e-3.

The `LightningWrapper` class provides the hooks PyTorch Lightning expects for training, validation, and testing, and it logs the relevant metrics in each phase, so the model can be trained and evaluated with PyTorch Lightning's high-level abstractions.
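Lightning calls `training_step`, then runs the backward pass and the optimizer step itself. The sketch below spells out that loop in plain PyTorch. It uses a hypothetical `ToyRanker` module in place of `GraphEncoder`, which follows the same contract: a keyword-argument batch goes in and a dictionary with a `'loss'` key comes out. It also uses the AdamW settings from `configure_optimizers`. This is a simplification, since Lightning also handles devices, logging, and validation.

```python
from typing import Optional

import torch
from torch import nn


class ToyRanker(nn.Module):
    """Hypothetical stand-in for GraphEncoder: takes keyword arguments, returns a dict."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(1, 1)

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None) -> dict:
        out = self.head(x).squeeze(-1)
        result = {"outputs": out}
        if y is not None:  # like GraphEncoder, only add a loss when targets are given
            result["loss"] = nn.functional.mse_loss(out, y)
        return result


torch.manual_seed(0)
model = ToyRanker()
# Same optimizer settings as LightningWrapper.configure_optimizers
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.linspace(-1, 1, 32).unsqueeze(-1)
batch = {"x": x, "y": 3 * x.squeeze(-1)}

losses = []
for step in range(200):
    loss = model(**batch)["loss"]  # what training_step returns
    optimizer.zero_grad()
    loss.backward()  # Lightning runs backward and the optimizer step for you
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```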

In [15]:
class LightningWrapper(pl.LightningModule):
    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model
        self.topk = TileTopK()
        
    def forward(self, batch):
        # GraphEncoder.forward takes keyword arguments, so unpack the batch dict
        return self.model(**batch)

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        return outputs['loss']

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss = outputs['loss']
        self.log("val_loss", loss, prog_bar=True)
        config_attn_mask = torch.ones_like(batch['config_runtime'], device=batch['config_runtime'].device)
        self.topk.update(outputs['outputs'], batch['config_runtime'], config_attn_mask)
        return loss
    
    def on_validation_end(self) -> None:
        topk = self.topk.compute()
        self.print(f"topk {topk:.3f}")
        self.topk.reset()
        return super().on_validation_end()

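    def test_step(self, batch, batch_idx):
        # Batches are dicts of keyword arguments for GraphEncoder, as in training/validation
        outputs = self.model(**batch)
        loss = outputs['loss']  # only present when the batch includes 'config_runtime'
        self.log("test_loss", loss, prog_bar=True)
        return loss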

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-3)
        return optimizer

## Training

**Explanation**:

- `hidden_size`: The size of the hidden layers in the model. In this case, it's set to 128.

- `num_attention_heads`: The number of attention heads in the model's attention mechanism. It is set to 4, so the model uses 4 parallel attention heads.

- `num_hidden_layers`: The number of hidden layers in the model. It is set to 2.

- `intermediate_size`: The size of the intermediate (feed-forward) layers in the model. It's set to 64.

- `gradient_checkpointing`: A boolean flag indicating whether gradient checkpointing should be used. When set to `True`, the model keeps fewer intermediate activations during the forward pass and recomputes them during the backward pass, reducing memory usage at the cost of extra computation.

- `margin`: A margin value used in the loss function (e.g., MarginRankingLoss). It's set to 0.1, indicating a small margin for ranking losses.

- `number_permutations`: The number of permutations to be used in the loss function. This is related to how the model's outputs are compared to permutations of the inputs during training.
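To make `margin` and `number_permutations` more concrete, here is a hedged pure-Python sketch of a pairwise margin ranking loss. It assumes, rather than confirms from `GraphEncoder`, that each configuration is paired with a partner drawn from a random permutation, repeated `number_permutations` times, and that a higher score should mean a slower runtime (consistent with the inference code later taking the lowest scores as fastest). The per-pair term follows the formula documented for PyTorch's `MarginRankingLoss`, `max(0, -y * (x1 - x2) + margin)`; the function name `permutation_ranking_loss` is hypothetical.

```python
import random
from typing import Sequence


def permutation_ranking_loss(
    scores: Sequence[float],
    runtimes: Sequence[float],
    margin: float = 0.1,
    number_permutations: int = 4,
    seed: int = 0,
) -> float:
    """Hypothetical pairwise ranking loss over randomly permuted config pairs."""
    rng = random.Random(seed)
    n = len(scores)
    losses = []
    for _ in range(number_permutations):
        partners = list(range(n))
        rng.shuffle(partners)  # pair config i with config partners[i]
        for i, j in enumerate(partners):
            if runtimes[i] == runtimes[j]:
                continue  # ties (including i == j) carry no ordering signal
            # y = 1 means config i is slower, so its score should be the larger one.
            y = 1.0 if runtimes[i] > runtimes[j] else -1.0
            # MarginRankingLoss per pair: max(0, -y * (x1 - x2) + margin)
            losses.append(max(0.0, -y * (scores[i] - scores[j]) + margin))
    return sum(losses) / len(losses) if losses else 0.0


runtimes = [1.0, 2.0, 3.0, 4.0]
print(permutation_ranking_loss([1.0, 2.0, 3.0, 4.0], runtimes))  # 0.0: right order, gaps > margin
print(permutation_ranking_loss([4.0, 3.0, 2.0, 1.0], runtimes) > 0)  # True: reversed order is penalized
```

Under this reading, a larger `number_permutations` simply yields more comparison pairs per graph, and `margin` sets how far apart two scores must be before a correctly ordered pair stops contributing to the loss.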


In [16]:
config_kwargs = dict(hidden_size= 128,
    num_attention_heads= 4,
    num_hidden_layers= 2,
    intermediate_size= 64,
    gradient_checkpointing= True,
    margin= 0.1,
    number_permutations= 4,
    )

We use this config object to create an instance of the model with the desired architecture and hyperparameters.

In [17]:
config = GraphConfig(**config_kwargs)

In [18]:
model = GraphEncoder(config)
model = LightningWrapper(model)

In [19]:
train_df = tile_df.query("split == 'train'").reset_index(drop=True)
valid_df = tile_df.query("split == 'valid'").reset_index(drop=True)
train_dataset = TileDataset(train_df, num_configs=24)
valid_dataset = TileDataset(valid_df, num_configs=24)

In [20]:
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=8, num_workers=2, shuffle=True, persistent_workers=True)
valid_dataloader = DataLoader(valid_dataset, collate_fn=collate_fn, batch_size=8, num_workers=2)

## `trainer_config`

**Explanation**:

- `max_epochs`: The maximum number of training epochs.
- `precision`: The numerical precision used for training (e.g., 32-bit floating-point precision).
- `gradient_clip_val`: The gradient-clipping threshold. With Lightning's default `gradient_clip_algorithm="norm"`, gradients are rescaled whenever their global norm exceeds this value.
- `accumulate_grad_batches`: The number of batches over which gradients are accumulated before performing an optimization step. This can be useful for simulating larger batch sizes when you have limited GPU memory.
- `check_val_every_n_epoch`: How often to perform validation during training (every `n` epochs).
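Two of these settings involve simple arithmetic worth spelling out. With the `batch_size=8` used by the data loaders and `accumulate_grad_batches=4`, each optimizer step uses gradients from 8 × 4 = 32 graphs. Gradient clipping by norm rescales all gradients by the same factor whenever their global L2 norm exceeds `gradient_clip_val`. The pure-Python sketch below illustrates both on a flat list of gradient values; `clip_by_global_norm` is a hypothetical helper, simplified relative to PyTorch's `clip_grad_norm_` (which also adds a small epsilon to the norm).

```python
import math
from typing import List, Sequence

batch_size = 8               # from the DataLoader cell
accumulate_grad_batches = 4  # from trainer_config
effective_batch_size = batch_size * accumulate_grad_batches  # graphs per optimizer step


def clip_by_global_norm(grads: Sequence[float], max_norm: float = 1.0) -> List[float]:
    """Hypothetical helper: rescale grads so their global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)  # already within the limit: leave untouched
    scale = max_norm / total_norm
    return [g * scale for g in grads]  # same direction, norm reduced to max_norm


print(effective_batch_size)             # 32
print(clip_by_global_norm([3.0, 4.0]))  # norm 5.0 is scaled down to 1.0
```

Clipping by norm preserves the direction of the update and only shrinks its size, which is why it is the usual default over clipping each value independently.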


In [21]:
trainer_config = dict(
    max_epochs= 50,
    precision= 32,
    gradient_clip_val= 1.0,
    accumulate_grad_batches= 4,
    check_val_every_n_epoch= 10)

**Explanation**:
The next cell configures and trains the PyTorch Lightning model using the trainer configuration above together with the model and data loaders defined earlier:

1. `torch.set_float32_matmul_precision("medium")`: This line allows PyTorch to use faster, lower-precision internal arithmetic (such as bfloat16) for float32 matrix multiplications when the hardware supports it, trading some numerical accuracy for training speed.

2. `trainer = pl.Trainer(**trainer_config)`: You're creating a PyTorch Lightning Trainer instance with the configuration specified in the `trainer_config` dictionary. This trainer will be used to train your model.

3. `trainer.fit(model, train_dataloader, valid_dataloader)`: You're calling the `fit` method of the trainer to start the training process. This method takes the following arguments:
   - `model`: The PyTorch Lightning model you want to train.
   - `train_dataloader`: The data loader for training data.
   - `valid_dataloader`: The data loader for validation data.

This code trains the model for `max_epochs` epochs and runs validation every `check_val_every_n_epoch` epochs, logging `val_loss` and printing the `topk` score each time. PyTorch Lightning handles the training loop, gradient accumulation, gradient clipping, and logging.
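Because `check_val_every_n_epoch=10` and `max_epochs=50`, the scheduled validation loop runs five times (not counting Lightning's short sanity-check pass before training), which matches the five `topk` lines in the training output below. A quick check of that schedule in plain Python:

```python
max_epochs = 50               # from trainer_config
check_val_every_n_epoch = 10  # from trainer_config

# Epochs (1-indexed) after which the validation loop is scheduled to run.
validation_epochs = [
    epoch for epoch in range(1, max_epochs + 1) if epoch % check_val_every_n_epoch == 0
]
print(validation_epochs)  # [10, 20, 30, 40, 50]
```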

In [22]:
torch.set_float32_matmul_precision("medium")
trainer = pl.Trainer(**trainer_config,)
trainer.fit(model, train_dataloader, valid_dataloader)

topk 0.983
topk 0.980
topk 0.989
topk 0.990
topk 0.992


In [23]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [24]:
split = 'test'
test_tile_df = tile_df.query("split == @split").reset_index(drop=True)
test_tile_ds = TileDataset(test_tile_df, num_configs=-1)
collate_fn = LayoutCollator(64, targets=split!="test")
test_dataloader = DataLoader(test_tile_ds, batch_size=1, shuffle=False, num_workers=0, collate_fn=collate_fn)

In [25]:
model.to(device)
model = model.eval()


## `chunk_batch` function

**Explanation**:

1. The function starts by creating an `output` dictionary, which will store the selected components of the batch. These components include 'node_opcode', 'node_feat', 'edges_adjecency', 'node_attn_mask', and 'node_config_ids'.

2. Next, it slices the 'node_config_feat' component of the batch. The slice `[:, start_idx: end_idx]` keeps every item along the first (batch) dimension and selects configurations `start_idx` through `end_idx - 1` along the second (configuration) dimension.

3. The sliced 'node_config_feat' is added to the `output` dictionary under the key 'node_config_feat'.

4. Finally, the function returns the `output` dictionary, which contains the selected components of the batch along with the sliced 'node_config_feat', creating a smaller chunk of the batch.
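The same selection can be mimicked on plain nested lists to see exactly what `[:, start_idx: end_idx]` does: every item in the batch is kept, and only the configuration axis is cut. `chunk_batch_lists` below is a hypothetical list-based stand-in for illustration only; the real function operates on tensors. It also shows that the graph-level entries are passed through by reference rather than copied.

```python
from typing import Any, Dict

# Keys that stay identical across chunks (names copied from chunk_batch, including its spelling).
SHARED_KEYS = ['node_opcode', 'node_feat', 'edges_adjecency', 'node_attn_mask', 'node_config_ids']


def chunk_batch_lists(batch: Dict[str, Any], start_idx: int, end_idx: int) -> Dict[str, Any]:
    """List-based stand-in for chunk_batch: slice axis 1 of node_config_feat only."""
    output = {k: batch[k] for k in SHARED_KEYS}  # same objects, not copies
    # Equivalent of tensor[:, start_idx:end_idx]: keep every batch item, cut the config axis.
    output['node_config_feat'] = [item[start_idx:end_idx] for item in batch['node_config_feat']]
    return output


# One graph (batch size 1) with 5 configurations, each a single-feature placeholder.
batch = {k: f"<{k}>" for k in SHARED_KEYS}
batch['node_config_feat'] = [[[0], [1], [2], [3], [4]]]
print(chunk_batch_lists(batch, 1, 3)['node_config_feat'])  # [[[1], [2]]]
```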

In [26]:
def chunk_batch(batch, start_idx, end_idx):
    # Create an output dictionary to store the selected batch components.
    output = {k: batch[k] for k in ['node_opcode', 'node_feat', 'edges_adjecency', 'node_attn_mask', 'node_config_ids']}
    
    # Slice the 'node_config_feat' component to create a smaller chunk.
    output['node_config_feat'] = batch['node_config_feat'][:, start_idx: end_idx]
    
    return output


**Explanation**:
1. `pred_order = []`: You initialize an empty list to store the predicted orders.

2. `for batch in tqdm(test_dataloader)`: You iterate through batches in the `test_dataloader`. `tqdm` is used to display a progress bar while iterating through the batches.

3. `batch.pop('config_idxs')`: You remove the 'config_idxs' key from the batch dictionary. This is likely because 'config_idxs' is not needed for making predictions on the test set.

4. `batch = {k: v.to(device) for k, v in batch.items()}`: You move the data in the batch to the specified device (e.g., GPU) for inference.

5. `num_configs = batch['node_config_feat'].shape[1]`: You determine the number of configurations in the batch.

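6. `configs_cut_points = list(range(0, num_configs, 100)) + [num_configs]`: You create a list of cut points that divides the configurations into chunks of at most 100, so the model processes only a bounded number of configurations at a time. For example, 250 configurations give the cut points `[0, 100, 200, 250]`.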

7. `chunk_order = []`: You initialize an empty list to collect the model's predicted scores from every chunk of the current graph.

8. Inside the loop over consecutive pairs of `configs_cut_points`, you split the batch into smaller chunks using the `chunk_batch` function defined in the previous cell. For each chunk, you perform the following steps:
   - You run inference with the trained model by passing the chunked batch to `model.model(**chunked_batch)` inside `torch.no_grad()`, so no gradients are tracked. This returns an output dictionary.
   - You extend the `chunk_order` list with the predicted scores (`output['outputs']`) from the current chunk, moved to the CPU as NumPy arrays.

9. After processing all chunks, you concatenate the scores from every chunk and use `np.argsort` to get the configuration indices ordered from lowest to highest score. Since a lower score is treated as a faster predicted runtime, the first five indices are the five configurations predicted to be fastest, and they are appended to `pred_order`.

The `pred_order` list therefore contains the top 5 predicted configuration indices for each graph in the test dataset (the test loader uses `batch_size=1`, so each batch is one graph).
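The chunking and ranking logic can be exercised without a GPU or the trained model. The pure-Python sketch below keeps the cell's structure (cut points every 100 configurations, one model call per chunk, then the five lowest scores) but swaps in a hypothetical `score_fn` in place of `model.model(**chunked_batch)` and a plain `sorted` in place of `np.argsort`.

```python
from typing import Callable, List, Sequence


def cut_points(num_configs: int, chunk_size: int = 100) -> List[int]:
    """Chunk boundaries, mirroring list(range(0, num_configs, 100)) + [num_configs]."""
    return list(range(0, num_configs, chunk_size)) + [num_configs]


def predict_top_k(
    score_fn: Callable[[Sequence[float]], List[float]],
    config_feats: Sequence[float],
    chunk_size: int = 100,
    k: int = 5,
) -> List[int]:
    """Score configs chunk by chunk, then return indices of the k lowest scores."""
    points = cut_points(len(config_feats), chunk_size)
    scores: List[float] = []
    for start, end in zip(points, points[1:]):
        # One call per chunk; in the notebook this bounds how many configs the model sees at once.
        scores.extend(score_fn(config_feats[start:end]))
    # Lowest predicted score = predicted fastest, as with np.argsort(...)[:5].
    return sorted(range(len(scores)), key=scores.__getitem__)[:k]


print(cut_points(250))  # [0, 100, 200, 250]
# Toy scorer: the score is the feature itself; config 249 has the smallest feature.
feats = [float(x) for x in range(250)][::-1]
print(predict_top_k(lambda chunk: list(chunk), feats))  # [249, 248, 247, 246, 245]
```

When `num_configs` is an exact multiple of 100, `range` does not emit the final boundary twice, so no empty chunk is produced.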

In [27]:
pred_order = []
for batch in tqdm(test_dataloader):
    batch.pop('config_idxs')
    batch = {k: v.to(device) for k, v in batch.items()}
    num_configs = batch['node_config_feat'].shape[1]
    configs_cut_points = list(range(0,num_configs, 100)) + [num_configs]
    chunk_order = []
    for start, end in zip(configs_cut_points, configs_cut_points[1:]):
        chunked_batch = chunk_batch(batch, start, end)
        with torch.no_grad():
            output = model.model(**chunked_batch)
        chunk_order.extend(output['outputs'].cpu().numpy())
    pred_order.append(np.argsort(np.concatenate(chunk_order))[:5])

100%|██████████| 844/844 [19:58<00:00,  1.42s/it]


In [28]:
idxs_string = [";".join(map(str,elem)) for elem in pred_order]
test_tile_df['TopConfigs'] = idxs_string
test_tile_df = test_tile_df[['ID', 'TopConfigs']]
test_tile_df.head()

|   | ID | TopConfigs |
|---|----|------------|
| 0 | tile:xla:04ae9238c653f8ae08f60f2c03615f0b | 746;788;688;709;554 |
| 1 | tile:xla:85d157d3b1848c6b6fff0c633876e2e6 | 5560;4526;5409;1066;910 |
| 2 | tile:xla:862900d42397d03be2762e1bf7518bea | 935;287;1409;1344;161 |
| 3 | tile:xla:0afa527a7022415fda1dd69d11e908a4 | 210;212;158;234;176 |
| 4 | tile:xla:2d09e3ab92e184c561abaf8d9efe7b87 | 170;147;24;89;6 |


## submission.csv

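This code reads the sample submission, drops the rows whose `ID` already appears in `test_tile_df` (the tile predictions made above), places the `test_tile_df` rows first, and saves the combined DataFrame to `submission.csv`. Rows not covered by the tile predictions, such as the `layout:` entries, keep the default orderings from the sample file.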
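The pure-Python sketch below mirrors that merge on lists of `(ID, TopConfigs)` rows, including the `;`-joined formatting from the previous cell, and writes the result with the standard `csv` module. The helper names `format_top_configs` and `merge_submission` and the sample IDs are hypothetical; the notebook itself does this with pandas.

```python
import csv
import io
from typing import List, Sequence, Tuple

Row = Tuple[str, str]


def format_top_configs(indices: Sequence[int]) -> str:
    """Same formatting as ';'.join(map(str, elem)) in the previous cell."""
    return ";".join(map(str, indices))


def merge_submission(predicted: List[Row], sample: List[Row]) -> List[Row]:
    """Predicted rows first, then the sample rows whose ID was not predicted."""
    predicted_ids = {row_id for row_id, _ in predicted}
    kept = [row for row in sample if row[0] not in predicted_ids]
    return predicted + kept


predicted = [("tile:xla:a", format_top_configs([3, 1, 4, 0, 2]))]
sample = [("tile:xla:a", "0;1;2;3;4"), ("layout:nlp:random:b", "0;1;2")]
rows = merge_submission(predicted, sample)

buffer = io.StringIO()  # stands in for open('submission.csv', 'w', newline='')
writer = csv.writer(buffer)
writer.writerow(["ID", "TopConfigs"])
writer.writerows(rows)
print(buffer.getvalue())
```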

In [29]:
submission_df = pd.read_csv('../input/predict-ai-model-runtime/sample_submission.csv')
# Drop the sample rows that the tile predictions replace.
submission_df = submission_df[~submission_df['ID'].isin(test_tile_df['ID'])]
# Put the tile predictions first, followed by the remaining sample rows.
submission_df = pd.concat([test_tile_df, submission_df])
submission_df.to_csv('submission.csv', index=False)
submission_df

|   | ID | TopConfigs |
|---|----|------------|
| 0 | tile:xla:04ae9238c653f8ae08f60f2c03615f0b | 746;788;688;709;554 |
| 1 | tile:xla:85d157d3b1848c6b6fff0c633876e2e6 | 5560;4526;5409;1066;910 |
| 2 | tile:xla:862900d42397d03be2762e1bf7518bea | 935;287;1409;1344;161 |
| 3 | tile:xla:0afa527a7022415fda1dd69d11e908a4 | 210;212;158;234;176 |
| 4 | tile:xla:2d09e3ab92e184c561abaf8d9efe7b87 | 170;147;24;89;6 |
| ... | ... | ... |
| 889 | layout:nlp:random:60880ed76de53f4d7a1b960b24f2... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |
| 890 | layout:nlp:random:23559853d9702baaaacbb0c83fd3... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |
| 891 | layout:nlp:random:f6c146fc5cf10be4f3accbaca989... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |
| 892 | layout:nlp:random:32531d07a084b319dce484f53a4c... | 0;1;2;3;4;5;6;7;8;9;10;11;12;13;14;15;16;17;18... |


## Explore More! 👀
Thank you for exploring this notebook! If you found this notebook insightful or if it helped you in any way, I invite you to explore more of my work on my profile.

👉 [Visit my Profile](https://www.kaggle.com/zulqarnainali) 👈

## Feedback and Gratitude 🙏
I value your feedback! Your insights and suggestions are essential for my continuous improvement. If you have any comments, questions, or ideas to share, please don't hesitate to reach out.

📬 Contact me via email: [zulqar445ali@gmail.com](mailto:zulqar445ali@gmail.com)

I would like to express my heartfelt gratitude for your time and engagement. Your support motivates me to create more valuable content.

Happy coding and best of luck in your data science endeavors! 🚀
