In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import math
import sqlite3
import pandas as pd
from timm.models.layers import drop_path, trunc_normal_
import torch.utils.checkpoint as checkpoint
from torch.utils.data import ConcatDataset, DataLoader, Dataset, Sampler, SequentialSampler #random_split
from torch.nn.utils.rnn import pad_sequence
# from torch_geometric.utils import to_dense_batch
# from torch_geometric.nn.pool import knn_graph
# from torch_geometric.typing import Adj
# from torch_scatter import scatter_max, scatter_mean, scatter_min, scatter_sum
from typing import Any, Callable, List, Dict, Optional, Sequence, Tuple, Union, Iterator
from torch import Tensor, LongTensor

  warn(
  from .autonotebook import tqdm as notebook_tqdm


# Dataset, batch_sampler, collate_fn and dataloader

In [2]:
def combine_dom_types_and_rde_one_hot( dom_type, rde):
        # pDom dom with low efficiency
        pdom_low_qe = ((dom_type == 20) & (rde == 1)).float().unsqueeze(-1)
        # pDOM dom with high efficiency
        pdom_high_qe = ((dom_type == 20) & (rde == 1.35)).float().unsqueeze(-1)
        # pDOM upgrade == 110
        pdom_upgrade = (dom_type == 110).float().unsqueeze(-1)
        # D-EGG == 120
        d_egg = (dom_type == 120).float().unsqueeze(-1)
        # mDOM == 130
        mdom = (dom_type == 130).float().unsqueeze(-1)

        return torch.cat([pdom_low_qe, pdom_high_qe, pdom_upgrade, d_egg, mdom], dim=-1)

def combine_dom_types_and_rde(dom_type, rde):
    # pDom dom with low efficiency
    pdom_low_qe = ((dom_type == 20) & (rde == 1)).long() * 0
    # pDOM dom with high efficiency
    pdom_high_qe = ((dom_type == 20) & (rde == 1.35)).long() * 1
    # pDOM upgrade == 110
    pdom_upgrade = (dom_type == 110).long() * 2
    # D-EGG == 120
    d_egg = (dom_type == 120).long() * 3
    # mDOM == 130
    mdom = (dom_type == 130).long() * 4

    return pdom_low_qe + pdom_high_qe + pdom_upgrade + d_egg + mdom

class ChunkDataset(Dataset):
    """
    PyTorch dataset for loading chunked data from an SQLite database.
    This dataset retrieves pulsemap and truth data for each event from the database.

    Args:
        db_filename (str): Filename of the SQLite database.
        csv_filenames (list of str): List of CSV filenames containing event numbers.
        pulsemap_table (str): Name of the table containing pulsemap data.
        truth_table (str): Name of the table containing truth data.
        truth_variable (str): Name of the variable to query from the truth table.
        feature_variables (list of str): List of variable names to query from the pulsemap table.
    """

    def __init__(
        self,
        db_path: str,
        chunk_csvs: List[str],
        pulsemap: str,
        truth_table: str,
        target_cols: str,
        input_cols: List[str]
    ) -> None:
        self.conn = sqlite3.connect(db_path)  # Connect to the SQLite database
        self.c = self.conn.cursor()
        self.event_nos = []
        for csv_filename in chunk_csvs:
            df = pd.read_csv(csv_filename)
            self.event_nos.extend(df['event_no'].tolist())  # Collect event numbers from CSV files
        self.pulsemap = pulsemap  # Name of the table containing pulsemap data
        self.truth_table = truth_table  # Name of the table containing truth data
        self.target_cols = target_cols  # Name of the variable to query from the truth table
        self.input_cols = input_cols  # List of variable names to query from the pulsemap table


    def __len__(self) -> int:
        return len(self.event_nos)
    
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        event_no = idx # self.event_nos[idx]

        # Query the truth variable for the given event number
        self.c.execute(f"SELECT {self.target_cols} FROM {self.truth_table} WHERE event_no = ?", (event_no,))
        truth_value = self.c.fetchone()[0]
        

        pos_cols = ['dom_x', 'dom_y', 'dom_z']

        rde_index = self.input_cols.index('rde')
        dom_type_index = self.input_cols.index('dom_type')
        pos_indices = [self.input_cols.index(col) for col in pos_cols]
        rest_indices = [i for i in range(len(self.input_cols)) if i not in [rde_index, dom_type_index] + pos_indices]

        input_query = ', '.join(self.input_cols)
        # Query the feature variables from the pulsemap table for the given event number
        self.c.execute(f"SELECT {input_query} FROM {self.pulsemap} WHERE event_no = ?", (event_no,))
        pulsemap_data_rows = self.c.fetchall()
    
        # Convert pulsemap_data_rows into a dictionary of tensors
        
        pulsemap_data = {self.input_cols[i]: torch.tensor( [row[i] for row in pulsemap_data_rows], dtype=torch.float32)
                        for i in rest_indices}
        
        # Get the necessary data for combined_dom_type and pos
        dom_type_data = [row[dom_type_index] for row in pulsemap_data_rows] #[row[i] for row in pulsemap_data_rows for i in combined_dom_type_indices]
        rde_data = [row[rde_index] for row in pulsemap_data_rows]
        pos_data = [[row[i] for row in pulsemap_data_rows] for i in pos_indices]

        pulsemap_data["combined_dom_type"] = combine_dom_types_and_rde(torch.tensor(dom_type_data, dtype=torch.float32),
                                                                    torch.tensor(rde_data, dtype=torch.float32))
        pulsemap_data["pos"] = torch.stack([torch.tensor(col_data, dtype=torch.float32) for col_data in pos_data], dim=1)
        pulsemap_data["L0"] = torch.tensor(len(pulsemap_data_rows), dtype=torch.int32)
        return pulsemap_data, torch.tensor(truth_value, dtype=torch.float32)
    


class ChunkSampler(Sampler):
    """
    PyTorch sampler for creating chunks from event numbers.

    Args:
        csv_filenames (List[str]): List of CSV filenames containing event numbers.
        batch_sizes (List[int]): List of batch sizes for each CSV file.
    """

    def __init__(
        self, 
        chunk_csvs: List[str], 
        batch_sizes: List[int]
    ) -> None:
        self.event_nos = []
        for csv_filename, batch_size in zip(chunk_csvs, batch_sizes):
            event_nos = pd.read_csv(csv_filename)['event_no'].tolist()
            self.event_nos.extend([event_nos[i:i + batch_size] for i in range(0, len(event_nos), batch_size)])

    def __iter__(self) -> Iterator:
        return iter(self.event_nos)

    def __len__(self) -> int:
        return len(self.event_nos)
    

def collate_fn(batch):
    """
    This collate function is specifically designed for the dataset that
    returns a dictionary of tensors. It will pad the sequences to the same
    length and concatenate along the batch dimension given a list of such 
    dictionaries (i.e., a batch).
    """
    batch_keys = batch[0][0].keys()
    collated_batch = {}

    max_len = max(max(item[0][key].size(0) for item in batch) for key in batch_keys if key != 'L0')

    for key in batch_keys:
        if key != 'L0':
            # Pad the sequences to the same length and stack along a new batch dimension
            collated_batch[key] = torch.stack([torch.cat([item[0][key], item[0][key].new_zeros(max_len - item[0][key].size(0), item[0][key].size(1)) if len(item[0][key].shape) > 1 else item[0][key].new_zeros(max_len - item[0][key].size(0))]) for item in batch])
        else:
            # If the key is 'L0', simply collect the values into a list
            collated_batch[key] = torch.tensor([item[0][key] for item in batch])

    # Create a mask that indicates where the original sequence ends and the padding begins
    collated_batch['mask'] = collated_batch['L0'].new_ones((len(batch), max_len)).bool()
    for i, l0 in enumerate(collated_batch['L0']):
        collated_batch['mask'][i, l0:] = False
    # Stack all target tensors along a new batch dimension
    targets = torch.stack([item[1] for item in batch])
    
    return collated_batch, targets

## tests

In [3]:
chunk_csv_train = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_7.csv",
]
chunk_csv_test = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_7.csv",
]
chunk_csv_val = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_7.csv",
]
batch_sizes = [512, 256, 128, 64, 32, 16, 8]
truth_table = "truth"
db_path = "/groups/icecube/petersen/GraphNetDatabaseRepository/Upgrade_Data/sqlite3/dev_step4_upgrade_028_with_noise_dynedge_pulsemap_v3_merger_aftercrash.db"
pulsemap = "SplitInIcePulses_dynedge_v2_Pulses"
input_cols =  ["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "dom_type"]
target_cols = "inelasticity"

print()
print("Dataset test")
print()

dataset = ChunkDataset(
    db_path=db_path, 
    chunk_csvs=chunk_csv_train, 
    pulsemap=pulsemap, 
    truth_table=truth_table, 
    target_cols=target_cols, 
    input_cols=input_cols
    )
for i, (features, truth) in enumerate(dataset):
    print(i)
    # print(features, truth)
    print(features['L0'])
    if i >= 2:
        break
features, truth = dataset[6]

    # print(features, truth)
print(features['L0'])

dl = DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    batch_sampler=ChunkSampler(chunk_csv_train, batch_sizes),
    num_workers=0,
    )
print()
print("DataLoader test")
print()
for i, (features, truth) in enumerate(dl):
    print(i)
    print(features['L0'].shape)
    print(features['pos'].shape)
    # print(features['L0'][0].item())
    # print(features['L0'].max(), features['L0'].min())

    if i == 2:
        break


Dataset test

0
tensor(0, dtype=torch.int32)
1
tensor(0, dtype=torch.int32)
2
tensor(0, dtype=torch.int32)
tensor(24, dtype=torch.int32)

DataLoader test

0
torch.Size([256])
torch.Size([256, 9, 3])
1
torch.Size([256])
torch.Size([256, 9, 3])
2
torch.Size([256])
torch.Size([256, 9, 3])


# Model

In [4]:
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commit this for the orignal BERT implement
        x = self.fc2(x)
        x = self.drop(x)
        return x


# BEiTv2 block
class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=drop, batch_first=True,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if init_values is not None:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        if self.gamma_1 is None:
            xn = self.norm1(x)
            x_attn, attn_weights = self.attn(
                    xn,
                    xn,
                    xn,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask,
                    need_weights=True,
                    average_attn_weights=False
                )
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            xn = self.norm1(x)
            x_attn, attn_weights = self.attn(
                    xn,
                    xn,
                    xn,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask,
                    need_weights=True,
                    average_attn_weights=False
                )
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))

        return x, attn_weights


class Attention_rel(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.proj_q = nn.Linear(dim, all_head_dim, bias=False)
        self.proj_k = nn.Linear(dim, all_head_dim, bias=False)
        self.proj_v = nn.Linear(dim, all_head_dim, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v, rel_pos_bias=None, key_padding_mask=None):
        # rel_pos_bias: B L L C/h
        # key_padding_mask - float with -inf
        B, N, C = q.shape
        # qkv_bias = None
        # if self.q_bias is not None:
        #    qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        # qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = F.linear(input=q, weight=self.proj_q.weight, bias=self.q_bias)
        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = F.linear(input=k, weight=self.proj_k.weight, bias=None)
        k = k.reshape(B, k.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)
        v = F.linear(input=v, weight=self.proj_v.weight, bias=self.v_bias)
        v = v.reshape(B, v.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if rel_pos_bias is not None:
            bias = torch.einsum("bhic,bijc->bhij", q, rel_pos_bias).type_as(attn)
            attn = attn + bias
        if key_padding_mask is not None:
            assert (
                key_padding_mask.dtype == torch.float32
                or key_padding_mask.dtype == torch.float16
            ), "incorrect mask dtype"
            bias = torch.min(key_padding_mask[:, None, :], key_padding_mask[:, :, None]).type_as(attn)
            bias[
                torch.max(key_padding_mask[:, None, :], key_padding_mask[:, :, None])
                < 0
            ] = 0
            attn = attn + bias.unsqueeze(1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2)
        if rel_pos_bias is not None:
            x = x + torch.einsum("bhij,bijc->bihc", attn, rel_pos_bias)
        x = x.reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn


# BEiTv2 block
class Block_rel(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_rel(
            dim, num_heads, attn_drop=attn_drop, qkv_bias=qkv_bias
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if init_values is not None:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, key_padding_mask=None, rel_pos_bias=None, kv=None):
        if self.gamma_1 is None:
            xn = self.norm1(x)
            kv = xn if kv is None else self.norm1(kv)
            x_attn, attn_weights = self.attn(
                    xn,
                    kv,
                    kv,
                    rel_pos_bias=rel_pos_bias,
                    key_padding_mask=key_padding_mask,
                )
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            xn = self.norm1(x)
            kv = xn if kv is None else self.norm1(kv)
            x_attn, attn_weights =self.attn(
                        xn,
                        kv,
                        kv,
                        rel_pos_bias=rel_pos_bias,
                        key_padding_mask=key_padding_mask,
                    )
            x = x + self.drop_path( self.gamma_1 * self.drop_path(x_attn))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x, attn_weights


class LocalBlock(nn.Module):
    def __init__(
        self,
        dim=192,
        num_heads=192 // 64,
        mlp_ratio=4,
        drop_path=0,
        init_values=1,
        **kwargs,
    ):
        super().__init__()
        self.proj_rel_bias = nn.Linear(dim // num_heads, dim // num_heads)
        self.block = Block_rel(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop_path=drop_path,
            init_values=init_values,
        )

    def forward(self, x, nbs, key_padding_mask=None, rel_pos_bias=None):
        B, Lmax, C = x.shape
        mask = (
            key_padding_mask
            if not (key_padding_mask is None)
            else torch.ones(B, Lmax, dtype=torch.bool)
        )

        m = torch.gather(mask.unsqueeze(1).expand(-1, Lmax, -1), 2, nbs)
        attn_mask = torch.zeros(m.shape)
        attn_mask[~mask] = -torch.inf
        attn_mask = attn_mask[mask]

        if rel_pos_bias is not None:
            rel_pos_bias = torch.gather(
                rel_pos_bias,
                2,
                nbs.unsqueeze(-1).expand(-1, -1, -1, rel_pos_bias.shape[-1]),
            )
            rel_pos_bias = rel_pos_bias[mask]
            rel_pos_bias = self.proj_rel_bias(rel_pos_bias).unsqueeze(1)

        xl = torch.gather(
            x.unsqueeze(1).expand(-1, Lmax, -1, -1),
            2,
            nbs.unsqueeze(-1).expand(-1, -1, -1, C),
        )
        xl = xl[mask]
        # modify only the node (0th element)
        xl = self.block(
            xl[:, :1],
            rel_pos_bias=rel_pos_bias,
            key_padding_mask=attn_mask[:, :1],
            kv=xl,
        )
        x = torch.zeros(x.shape, dtype=xl.dtype)
        x[mask] = xl.squeeze(1)
        return x

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim) * (-emb) ).type_as(x[..., None])
        emb = x[..., None] * emb[None, ...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class Extractor(nn.Module):
    def __init__(self, dim_base=128, dim=384):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.combined_dom_type_emb = nn.Embedding(5, dim_base // 2)
        self.emb2 = SinusoidalPosEmb(dim=dim_base // 2)
        self.proj = nn.Sequential(
            nn.Linear(6 * dim_base, 6 * dim_base),
            nn.LayerNorm(6 * dim_base),
            nn.GELU(),
            nn.Linear(6 * dim_base, dim),
        )

    def forward(self, x, Lmax=None):
        pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
        charge = x["charge"] if Lmax is None else x["charge"][:, :Lmax]
        time = x["dom_time"] if Lmax is None else x["dom_time"][:, :Lmax]

        combined_dom_type = x["combined_dom_type"] if Lmax is None else x["combined_dom_type"][:, :Lmax]
        combined_dom_type.to(dtype=pos.dtype)
        length = torch.log10(x["L0"].to(dtype=pos.dtype))

        x = torch.cat(
            [
                self.emb(4096 * pos).flatten(-2),
                self.emb(1024 * charge),
                self.emb(4096 * time),
                self.combined_dom_type_emb(combined_dom_type),
                self.emb2(length).unsqueeze(1).expand(-1, pos.shape[1], -1),
            ],
            -1,
        )
        x = self.proj(x)
        return x

class Rel_ds(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, Lmax=None):
        pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
        time = x["dom_time"] if Lmax is None else x["dom_time"][:, :Lmax]
        ds2 = (pos[:, :, None] - pos[:, None, :]).pow(2).sum(-1) - (
            (time[:, :, None] - time[:, None, :]) * (3e4 / 500 * 3e-1)
        ).pow(2)
        d = torch.sign(ds2) * torch.sqrt(torch.abs(ds2))
        emb = self.emb(1024 * d.clip(-4, 4))
        rel_attn = self.proj(emb)
        return rel_attn, emb

class DeepIceModel(nn.Module):
    def __init__(
        self,
        # dim=384,
        dim_base=128,
        depth=12,
        head_size=32,
        n_heads_rel=12,
        depth_rel=4,
        n_rel=1,
        dim_out=1,
        out_act = nn.Softplus(),
        use_checkpoint=False,
        **kwargs,
    ):
        super().__init__()
        dim = head_size * n_heads_rel

        self.extractor = Extractor(dim_base, dim)
        self.rel_pos = Rel_ds(head_size)
        self.sandwich = nn.ModuleList(
            [Block_rel(dim=dim, num_heads = n_heads_rel) for i in range(depth_rel)]
        )
        self.cls_token = nn.Linear(dim, 1, bias=False)
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=dim,
                    num_heads=dim // head_size,
                    mlp_ratio=4,
                    drop_path=0.0 * (i / (depth - 1)),
                    init_values=1,
                )
                for i in range(depth)
            ]
        )
        self.proj_out = nn.Linear(dim, dim_out)
        self.out_act = out_act
        self.use_checkpoint = use_checkpoint
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=0.02)
        self.n_rel = n_rel

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_token"}

    def forward(self, x0):
        mask = x0["mask"]
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)

        mask = mask[:, :Lmax]
        B, _ = mask.shape

        attn_mask = torch.zeros(mask.shape).type_as(rel_pos_bias)
        attn_mask[~mask] = -torch.inf

        for i, blk in enumerate(self.sandwich):
            x, _ = blk(x, attn_mask, rel_pos_bias)
            if i + 1 == self.n_rel:
                rel_pos_bias = None

        mask = torch.cat(
            [
                torch.ones(
                    B,
                    1,
                ).type_as(mask),
                mask,
            ],
            1,
        )

        cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)

        attn_mask = torch.zeros(mask.shape, dtype=mask.dtype).type_as(cls_token)
        attn_mask[~mask] = -torch.inf

        x = torch.cat([cls_token, x], 1)

        for blk in self.blocks:
            if self.use_checkpoint:
                x, _ = checkpoint.checkpoint(blk, x, None, attn_mask)
            else:
                x, _ = blk(x, None, attn_mask)

        x = self.proj_out(x[:, 0])  # cls token
        x = self.out_act(x)
        return x.squeeze(1)

    def get_attn_weights(self, x0):
        mask = x0["mask"]
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)

        mask = mask[:, :Lmax]
        B, _ = mask.shape

        attn_mask = torch.zeros(mask.shape).type_as(rel_pos_bias)
        attn_mask[~mask] = -torch.inf

        attn_weight_rel_list = []
        for i, blk in enumerate(self.sandwich):
            x, attn_weight_rel = blk(x, attn_mask, rel_pos_bias)
            attn_weight_rel_list.append(attn_weight_rel)
            if i + 1 == self.n_rel:
                rel_pos_bias = None

        attn_weight_rel_tensor = torch.stack(attn_weight_rel_list, dim=1)

        mask = torch.cat(
            [
                torch.ones(
                    B,
                    1,
                ).type_as(mask),
                mask,
            ],
            1,
        )

        cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)

        attn_mask = torch.zeros(mask.shape, dtype=mask.dtype).type_as(cls_token)
        attn_mask[~mask] = -torch.inf

        x = torch.cat([cls_token, x], 1)

        attn_weight_list = []
        for blk in self.blocks:
            if self.use_checkpoint:
                x, attn_weight = checkpoint.checkpoint(blk, x, None, attn_mask)
                attn_weight_list.append(attn_weight)
            else:
                x, attn_weight = blk(x, None, attn_mask)
                attn_weight_list.append(attn_weight)

        attn_weight_tensor = torch.stack(attn_weight_list, dim=1)

        x = self.proj_out(x[:, 0])  # cls token
        x = self.out_act(x)
        return {
            "mask": attn_mask,
            "attn_weights": attn_weight_tensor,
            "rel_attn_weights": attn_weight_rel_tensor,
        }

    # def forward(self, x0, return_attention_weights=False):
    #     mask = x0["mask"]
    #     Lmax = mask.sum(-1).max()
    #     x = self.extractor(x0, Lmax)
    #     rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)

    #     mask = mask[:, :Lmax]
    #     B, _ = mask.shape
        
    #     attn_mask = torch.zeros(mask.shape).type_as(rel_pos_bias)
    #     attn_mask[~mask] = -torch.inf

    #     if return_attention_weights:
    #         attn_weight_rel_list = []
    #         for i, blk in enumerate(self.sandwich):
    #             x, attn_weight_rel = blk(x, attn_mask, rel_pos_bias)
    #             attn_weight_rel_list.append(attn_weight_rel)
    #             if i + 1 == self.n_rel:
    #                 rel_pos_bias = None
    #         attn_weight_rel_tensor = torch.stack(attn_weight_rel_list, dim=1)
    #     else:
    #         for i, blk in enumerate(self.sandwich):
    #             x, _ = blk(x, attn_mask, rel_pos_bias)
    #             if i + 1 == self.n_rel:
    #                 rel_pos_bias = None

        

    #     mask = torch.cat(
    #         [torch.ones(B, 1,).type_as(mask), mask], 1
    #     )

    #     cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)

    #     attn_mask = torch.zeros(mask.shape, dtype=mask.dtype).type_as(cls_token)
    #     attn_mask[~mask] = -torch.inf
        
    #     x = torch.cat([cls_token, x], 1)
        
    #     if return_attention_weights:
    #         attn_weight_list = []
    #         for blk in self.blocks:
    #             if self.use_checkpoint:
    #                 x,attn_weight = checkpoint.checkpoint(blk, x, None, attn_mask)
    #                 attn_weight_list.append(attn_weight)
    #             else:
    #                 x,attn_weight = blk(x, None, attn_mask)
    #                 attn_weight_list.append(attn_weight)
    #         attn_weight_tensor = torch.stack(attn_weight_list, dim=1)
    #     else:
    #         for blk in self.blocks:
    #             if self.use_checkpoint:
    #                 x,_ = checkpoint.checkpoint(blk, x, None, attn_mask)
    #             else:
    #                 x,_ = blk(x, None, attn_mask)
            

    #     x = self.proj_out(x[:, 0])  # cls token
    #     x = self.out_act(x)

    #     if return_attention_weights:
    #         return {
    #             "pred":x.squeeze(1), 
    #             "mask":attn_mask,
    #             "attn_weights": attn_weight_tensor, 
    #             "rel_attn_weights": attn_weight_rel_tensor
    #             }
    #     else: 
    #         return {
    #             "pred":x.squeeze(1),
    #             }

    # def forward(self, x0):
    #     mask = x0["mask"]
    #     Lmax = mask.sum(-1).max()
    #     x = self.extractor(x0, Lmax)
    #     rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)

    #     mask = mask[:, :Lmax]
    #     B, _ = mask.shape
        
    #     attn_mask = torch.zeros(mask.shape).type_as(rel_pos_bias)
    #     attn_mask[~mask] = -torch.inf

    #     attn_weight_rel_list = []
    #     for i, blk in enumerate(self.sandwich):

    #         x, attn_weight_rel = blk(x, attn_mask, rel_pos_bias)
    #         attn_weight_rel_list.append(attn_weight_rel)
    #         if i + 1 == self.n_rel:
    #             rel_pos_bias = None

    #     attn_weight_rel_tensor = torch.stack(attn_weight_rel_list, dim=1)

    #     mask = torch.cat(
    #         [torch.ones(B, 1,).type_as(mask), mask], 1
    #     )

    #     cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)

    #     attn_mask = torch.zeros(mask.shape, dtype=mask.dtype).type_as(cls_token)
    #     attn_mask[~mask] = -torch.inf
        
    #     x = torch.cat([cls_token, x], 1)

    #     attn_weight_list = []
    #     for blk in self.blocks:
    #         if self.use_checkpoint:
    #             x,attn_weight = checkpoint.checkpoint(blk, x, None, attn_mask)
    #             attn_weight_list.append(attn_weight)
    #         else:
    #             x,attn_weight = blk(x, None, attn_mask)
    #             attn_weight_list.append(attn_weight)

    #     attn_weight_tensor = torch.stack(attn_weight_list, dim=1)

    #     x = self.proj_out(x[:, 0])  # cls token
    #     x = self.out_act(x)
    #     return {
    #         "pred":x.squeeze(1), 
    #         "mask":attn_mask,
    #         "attn_weights": attn_weight_tensor, 
    #         "rel_attn_weights": attn_weight_rel_tensor
    #         }
    
    # def forward(self, x0):
    #     mask = x0["mask"]
    #     Lmax = mask.sum(-1).max()
    #     x = self.extractor(x0, Lmax)
    #     rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)
    #     # nbs = get_nbs(x0, Lmax)
    #     mask = mask[:, :Lmax]
    #     B, _ = mask.shape

    #     attn_mask = torch.zeros(mask.shape).type_as(
    #         rel_pos_bias
    #     )
    #     attn_mask[~mask] = -torch.inf

    #     for i, blk in enumerate(self.sandwich):
    #         # if isinstance(blk, LocalBlock):
    #         #     x = blk(x, nbs, mask, rel_enc)
    #         # else:
    #         x = blk(x, attn_mask, rel_pos_bias)
    #         if i + 1 == self.n_rel:
    #             rel_pos_bias = None

    #     mask = torch.cat(
    #         [
    #             torch.ones(
    #                 B,
    #                 1,
    #             ).type_as(mask),
    #             mask,
    #         ],
    #         1,
    #     )
    #     cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)

    #     attn_mask = torch.zeros(mask.shape, dtype=mask.dtype).type_as(cls_token)
    #     attn_mask[~mask] = -torch.inf

    #     x = torch.cat([cls_token, x], 1)

    #     for blk in self.blocks:
    #         if self.use_checkpoint:
    #             x = checkpoint.checkpoint(blk, x, None, attn_mask)
    #         else:
    #             x = blk(x, None, attn_mask)

    #     x = self.proj_out(x[:, 0])  # cls token
    #     x = self.out_act(x)
    #     return x.squeeze(1)

In [None]:
class DeepIceModel(nn.Module):
    def __init__(
        self,
        # dim=384,
        dim_base=128,
        depth=12,
        head_size=32,
        n_heads_rel=12,
        depth_rel=4,
        n_rel=1,
        dim_out=1,
        out_act = nn.Softplus(),
        use_checkpoint=False,
        **kwargs,
    ):
        super().__init__()
        dim = head_size * n_heads_rel

        self.extractor = Extractor(dim_base, dim)
        self.rel_pos = Rel_ds(head_size)
        self.sandwich = nn.ModuleList(
            [Block_rel(dim=dim, num_heads = n_heads_rel) for i in range(depth_rel)]
        )
        self.cls_token = nn.Linear(dim, 1, bias=False)
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=dim,
                    num_heads=dim // head_size,
                    mlp_ratio=4,
                    drop_path=0.0 * (i / (depth - 1)),
                    init_values=1,
                )
                for i in range(depth)
            ]
        )
        self.proj_out = nn.Linear(dim, dim_out)
        self.out_act = out_act
        self.use_checkpoint = use_checkpoint
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=0.02)
        self.n_rel = n_rel

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_token"}

    def forward(self, x0):
        mask = x0["mask"]
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)

        mask = mask[:, :Lmax]
        B, _ = mask.shape

        attn_mask = torch.zeros(mask.shape).type_as(rel_pos_bias)
        attn_mask[~mask] = -torch.inf

        for i, blk in enumerate(self.sandwich):
            x, _ = blk(x, attn_mask, rel_pos_bias)
            if i + 1 == self.n_rel:
                rel_pos_bias = None

        mask = torch.cat(
            [
                torch.ones(
                    B,
                    1,
                ).type_as(mask),
                mask,
            ],
            1,
        )

        cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)

        attn_mask = torch.zeros(mask.shape, dtype=mask.dtype).type_as(cls_token)
        attn_mask[~mask] = -torch.inf

        x = torch.cat([cls_token, x], 1)

        for blk in self.blocks:
            if self.use_checkpoint:
                x, _ = checkpoint.checkpoint(blk, x, None, attn_mask)
            else:
                x, _ = blk(x, None, attn_mask)

        x = self.proj_out(x[:, 0])  # cls token
        x = self.out_act(x)
        return x.squeeze(1)

    def get_attn_weights(self, x0):
        mask = x0["mask"]
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)

        mask = mask[:, :Lmax]
        B, _ = mask.shape

        attn_mask = torch.zeros(mask.shape).type_as(rel_pos_bias)
        attn_mask[~mask] = -torch.inf

        attn_weight_rel_list = []
        for i, blk in enumerate(self.sandwich):
            x, attn_weight_rel = blk(x, attn_mask, rel_pos_bias)
            attn_weight_rel_list.append(attn_weight_rel)
            if i + 1 == self.n_rel:
                rel_pos_bias = None

        attn_weight_rel_tensor = torch.stack(attn_weight_rel_list, dim=1)

        mask = torch.cat(
            [
                torch.ones(
                    B,
                    1,
                ).type_as(mask),
                mask,
            ],
            1,
        )

        cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)

        attn_mask = torch.zeros(mask.shape, dtype=mask.dtype).type_as(cls_token)
        attn_mask[~mask] = -torch.inf

        x = torch.cat([cls_token, x], 1)

        attn_weight_list = []
        for blk in self.blocks:
            if self.use_checkpoint:
                x, attn_weight = checkpoint.checkpoint(blk, x, None, attn_mask)
                attn_weight_list.append(attn_weight)
            else:
                x, attn_weight = blk(x, None, attn_mask)
                attn_weight_list.append(attn_weight)

        attn_weight_tensor = torch.stack(attn_weight_list, dim=1)

        x = self.proj_out(x[:, 0])  # cls token
        x = self.out_act(x)
        return {
            "mask": attn_mask,
            "attn_weights": attn_weight_tensor,
            "rel_attn_weights": attn_weight_rel_tensor,
        }

## tests

In [5]:
chunk_csv_train = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_7.csv",
]
chunk_csv_test = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_7.csv",
]
chunk_csv_val = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_7.csv",
]
batch_sizes = [512, 256, 128, 64, 32, 16, 8]
truth_table = "truth"
db_path = "/groups/icecube/petersen/GraphNetDatabaseRepository/Upgrade_Data/sqlite3/dev_step4_upgrade_028_with_noise_dynedge_pulsemap_v3_merger_aftercrash.db"
pulsemap = "SplitInIcePulses_dynedge_v2_Pulses"
input_cols =  ["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "dom_type"]
target_cols = "inelasticity"

dataset = ChunkDataset(
    db_path=db_path, 
    chunk_csvs=chunk_csv_train, 
    pulsemap=pulsemap, 
    truth_table=truth_table, 
    target_cols=target_cols, 
    input_cols=input_cols
    )

dl = DataLoader(
    dataset=dataset,
    collate_fn=collate_fn,
    batch_sampler=ChunkSampler(chunk_csv_train, batch_sizes),
    num_workers=0,
    )

model = DeepIceModel(
        dim=384,
        dim_base=128, #128
        depth=6,
        use_checkpoint=False,
        head_size=32,
        depth_rel=4,
        n_rel=1,
        )

return_attention_weights = True
for i, (features, truth) in enumerate(dl):
    
    # Forward pass: Compute predicted y by passing x to the model
    pred = model(features,)

    print(f"Batch {i + 1}")
    print("Features:", features['charge'].shape)
    # print("Truth:", truth.shape)
    print("Predicted:", pred.shape)
    if return_attention_weights:
      output= model.get_attn_weights(features)
      print("Mask:", output["mask"].shape)
      print("Rel Attention:", output["rel_attn_weights"].shape)
      print("Attention:", output["attn_weights"].shape)
    print()
    if i == 3:
        break

Batch 1
Features: torch.Size([256, 9])
Predicted: torch.Size([256])
Mask: torch.Size([256, 10])
Rel Attention: torch.Size([256, 4, 12, 9, 9])
Attention: torch.Size([256, 6, 12, 10, 10])

Batch 2
Features: torch.Size([256, 9])
Predicted: torch.Size([256])
Mask: torch.Size([256, 10])
Rel Attention: torch.Size([256, 4, 12, 9, 9])
Attention: torch.Size([256, 6, 12, 10, 10])

Batch 3
Features: torch.Size([256, 9])
Predicted: torch.Size([256])
Mask: torch.Size([256, 10])
Rel Attention: torch.Size([256, 4, 12, 9, 9])
Attention: torch.Size([256, 6, 12, 10, 10])

Batch 4
Features: torch.Size([256, 9])
Predicted: torch.Size([256])
Mask: torch.Size([256, 10])
Rel Attention: torch.Size([256, 4, 12, 9, 9])
Attention: torch.Size([256, 6, 12, 10, 10])



In [10]:


# class Extractor(nn.Module):
#     def __init__(self, dim_base=128, dim=384):
#         super().__init__()
#         self.emb = SinusoidalPosEmb(dim=dim_base)
#         self.dom_type_rde_emb = nn.Embedding(5, dim_base // 2)
#         self.emb2 = SinusoidalPosEmb(dim=dim_base // 2)
#         self.proj = nn.Sequential(
#             nn.Linear(6 * dim_base, 6 * dim_base),
#             nn.LayerNorm(6 * dim_base),
#             nn.GELU(),
#             nn.Linear(6 * dim_base, dim),
#         )

#     def forward(self, x, L0=None):
#         Lmax = L0.max() if L0 is not None else max(len(item) for item in x['pos'])
#         pos = x["pos"][:, :Lmax]
#         charge = x["charge"][:, :Lmax]
#         time = x["time"][:, :Lmax]
#         dom_type_rde = x["dom_type_rde"][:, :Lmax]
#         length = torch.log10(L0.to(dtype=pos.dtype))

#         x = torch.cat(
#             [
#                 self.emb(4096 * pos).flatten(-2),
#                 self.emb(1024 * charge),
#                 self.emb(4096 * time),
#                 self.dom_type_rde_emb(dom_type_rde),
#                 self.emb2(length).unsqueeze(1).expand(-1, pos.shape[1], -1),
#             ],
#             -1,
#         )
#         x = self.proj(x)
#         return x



# class Extractor(nn.Module):
#     def __init__(self, dim_base=128, dim=384):
#         super().__init__()
#         self.emb = SinusoidalPosEmb(dim=dim_base)
#         self.aux_emb = nn.Embedding(2, dim_base // 2)
#         self.emb2 = SinusoidalPosEmb(dim=dim_base // 2)
#         self.proj = nn.Sequential(
#             nn.Linear(6 * dim_base, 6 * dim_base),
#             nn.LayerNorm(6 * dim_base),
#             nn.GELU(),
#             nn.Linear(6 * dim_base, dim),
#         )

#     def forward(self, x, Lmax=None):
#         pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
#         charge = x["charge"] if Lmax is None else x["charge"][:, :Lmax]
#         time = x["time"] if Lmax is None else x["time"][:, :Lmax]
#         auxiliary = x["auxiliary"] if Lmax is None else x["auxiliary"][:, :Lmax]
#         qe = x["qe"] if Lmax is None else x["qe"][:, :Lmax]
#         ice_properties = (
#             x["ice_properties"] if Lmax is None else x["ice_properties"][:, :Lmax]
#         )
#         length = torch.log10(x["L0"].to(dtype=pos.dtype))

#         x = torch.cat(
#             [
#                 self.emb(4096 * pos).flatten(-2),
#                 self.emb(1024 * charge),
#                 self.emb(4096 * time),
#                 self.aux_emb(auxiliary),
#                 self.emb2(length).unsqueeze(1).expand(-1, pos.shape[1], -1),
#             ],
#             -1,
#         )
#         x = self.proj(x)
#         return x



# class LocalBlock(nn.Module):
#     def __init__(
#         self,
#         dim=192,
#         num_heads=192 // 64,
#         mlp_ratio=4,
#         drop_path=0,
#         init_values=1,
#         **kwargs,
#     ):
#         super().__init__()
#         self.proj_rel_bias = nn.Linear(dim // num_heads, dim // num_heads)
#         self.block = Block_rel(
#             dim=dim,
#             num_heads=num_heads,
#             mlp_ratio=mlp_ratio,
#             drop_path=drop_path,
#             init_values=init_values,
#         )

#     def forward(self, x, nbs, key_padding_mask=None, rel_pos_bias=None):
#         B, Lmax, C = x.shape
#         mask = (
#             key_padding_mask
#             if not (key_padding_mask is None)
#             else torch.ones(B, Lmax, dtype=torch.bool, device=x.deice)
#         )

#         m = torch.gather(mask.unsqueeze(1).expand(-1, Lmax, -1), 2, nbs)
#         attn_mask = torch.zeros(m.shape, device=m.device)
#         attn_mask[~mask] = -torch.inf
#         attn_mask = attn_mask[mask]

#         if rel_pos_bias is not None:
#             rel_pos_bias = torch.gather(
#                 rel_pos_bias,
#                 2,
#                 nbs.unsqueeze(-1).expand(-1, -1, -1, rel_pos_bias.shape[-1]),
#             )
#             rel_pos_bias = rel_pos_bias[mask]
#             rel_pos_bias = self.proj_rel_bias(rel_pos_bias).unsqueeze(1)

#         xl = torch.gather(
#             x.unsqueeze(1).expand(-1, Lmax, -1, -1),
#             2,
#             nbs.unsqueeze(-1).expand(-1, -1, -1, C),
#         )
#         xl = xl[mask]
#         # modify only the node (0th element)
#         # print(xl[:,:1].shape,rel_pos_bias.shape,attn_mask[:,:1].shape,xl.shape)
#         xl = self.block(
#             xl[:, :1],
#             rel_pos_bias=rel_pos_bias,
#             key_padding_mask=attn_mask[:, :1],
#             kv=xl,
#         )
#         x = torch.zeros(x.shape, device=x.device, dtype=xl.dtype)
#         x[mask] = xl.squeeze(1)
#         return x
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        # commit this for the orignal BERT implement
        x = self.fc2(x)
        x = self.drop(x)
        return x


# BEiTv2 block
class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=drop, batch_first=True
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if init_values is not None:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        if self.gamma_1 is None:
            xn = self.norm1(x)
            x_attn, attn_weights = self.attn(
                    xn,
                    xn,
                    xn,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask,
                    need_weights=True,
                )
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            xn = self.norm1(x)
            x_attn, attn_weights = self.attn(
                    xn,
                    xn,
                    xn,
                    attn_mask=attn_mask,
                    key_padding_mask=key_padding_mask,
                    need_weights=True,
                )
            x = x + self.drop_path(x_attn)
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
            print(attn_weights)
        return x


class Attention_rel(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.proj_q = nn.Linear(dim, all_head_dim, bias=False)
        self.proj_k = nn.Linear(dim, all_head_dim, bias=False)
        self.proj_v = nn.Linear(dim, all_head_dim, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q, k, v, rel_pos_bias=None, key_padding_mask=None):
        # rel_pos_bias: B L L C/h
        # key_padding_mask - float with -inf
        B, N, C = q.shape
        # qkv_bias = None
        # if self.q_bias is not None:
        #    qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        # qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        q = F.linear(input=q, weight=self.proj_q.weight, bias=self.q_bias)
        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = F.linear(input=k, weight=self.proj_k.weight, bias=None)
        k = k.reshape(B, k.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)
        v = F.linear(input=v, weight=self.proj_v.weight, bias=self.v_bias)
        v = v.reshape(B, v.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        if rel_pos_bias is not None:
            bias = torch.einsum("bhic,bijc->bhij", q, rel_pos_bias)
            attn = attn + bias
        if key_padding_mask is not None:
            assert (
                key_padding_mask.dtype == torch.float32
                or key_padding_mask.dtype == torch.float16
            ), "incorrect mask dtype"
            bias = torch.min(key_padding_mask[:, None, :], key_padding_mask[:, :, None])
            bias[
                torch.max(key_padding_mask[:, None, :], key_padding_mask[:, :, None])
                < 0
            ] = 0
            # print(bias.shape,bias.min(),bias.max())
            attn = attn + bias.unsqueeze(1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2)
        if rel_pos_bias is not None:
            x = x + torch.einsum("bhij,bijc->bihc", attn, rel_pos_bias)
        x = x.reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# BEiTv2 block
class Block_rel(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        window_size=None,
        attn_head_dim=None,
        **kwargs,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention_rel(
            dim, num_heads, attn_drop=attn_drop, qkv_bias=qkv_bias
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if init_values is not None:
            self.gamma_1 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
            self.gamma_2 = nn.Parameter(
                init_values * torch.ones((dim)), requires_grad=True
            )
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, key_padding_mask=None, rel_pos_bias=None, kv=None):
        if self.gamma_1 is None:
            xn = self.norm1(x)
            kv = xn if kv is None else self.norm1(kv)
            x = x + self.drop_path(
                self.attn(
                    xn,
                    kv,
                    kv,
                    rel_pos_bias=rel_pos_bias,
                    key_padding_mask=key_padding_mask,
                )
            )
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            xn = self.norm1(x)
            kv = xn if kv is None else self.norm1(kv)
            x = x + self.drop_path(
                self.gamma_1
                * self.drop_path(
                    self.attn(
                        xn,
                        kv,
                        kv,
                        rel_pos_bias=rel_pos_bias,
                        key_padding_mask=key_padding_mask,
                    )
                )
            )
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim=16, M=10000):
        super().__init__()
        self.dim = dim
        self.M = M

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(self.M) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * (-emb))
        emb = x[..., None] * emb[None, ...]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class Extractor(nn.Module):
    def __init__(self, dim_base=128, dim=384):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim_base)
        self.combined_dom_type_linear = nn.Linear(5, dim_base // 2, bias=False)
        self.emb2 = SinusoidalPosEmb(dim=dim_base // 2)
        self.proj = nn.Sequential(
            nn.Linear(6 * dim_base, 6 * dim_base),
            nn.LayerNorm(6 * dim_base),
            nn.GELU(),
            nn.Linear(6 * dim_base, dim),
        )

    def forward(self, x, L0=None):
        Lmax = L0.max().item() if L0 is not None else max(len(item) for item in x['pos'])
        pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
        charge = x["charge"] if Lmax is None else x["charge"][:, :Lmax]
        time = x["time"] if Lmax is None else x["time"][:, :Lmax]

        combined_dom_type = x["combined_dom_type"] if Lmax is None else x["combined_dom_type"][:, :Lmax]
        length = torch.log10(x["L0"].to(dtype=pos.dtype))

        x = torch.cat(
            [
                self.emb(4096 * pos).flatten(-2),
                self.emb(1024 * charge),
                self.emb(4096 * time),
                self.combined_dom_type_linear(combined_dom_type),
                self.emb2(length).unsqueeze(1).expand(-1, pos.shape[1], -1),
            ],
            -1,
        )
        x = self.proj(x)
        return x

# class Extractor(nn.Module):
#     def __init__(self, dim_base=128, dim=384):
#         super().__init__()
#         self.emb = SinusoidalPosEmb(dim=dim_base)
#         self.dom_type_rde_linear = nn.Linear(5, dim_base // 2, bias=False)
#         self.emb2 = SinusoidalPosEmb(dim=dim_base // 2)
#         self.proj = nn.Sequential(
#             nn.Linear(6 * dim_base, 6 * dim_base),
#             nn.LayerNorm(6 * dim_base),
#             nn.GELU(),
#             nn.Linear(6 * dim_base, dim),
#         )

#     def forward(self, x, Lmax=None):
#         pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
#         charge = x["charge"] if Lmax is None else x["charge"][:, :Lmax]
#         time = x["dom_time"] if Lmax is None else x["dom_time"][:, :Lmax]
#         combined_dom_type = x["combined_dom_type"] if Lmax is None else x["combined_dom_type"][:, :Lmax]
#         length = torch.log10(x["L0"].to(dtype=pos.dtype))

#         x = torch.cat(
#             [
#                 self.emb(4096 * pos).flatten(-2),
#                 self.emb(1024 * charge),
#                 self.emb(4096 * time),
#                 self.emb2(length).unsqueeze(1).expand(-1, pos.shape[1], -1),
#             ],
#             -1,
#         )
#         x = self.proj(x)
#         return x

# class Extractor(nn.Module):
#     def __init__(self, dim_base=128, dim=384):
#         super().__init__()
#         self.emb = SinusoidalPosEmb(dim=dim_base)
#         self.aux_emb = nn.Embedding(2, dim_base // 2)
#         self.emb2 = SinusoidalPosEmb(dim=dim_base // 2)
#         self.proj = nn.Sequential(
#             nn.Linear(6 * dim_base, 6 * dim_base),
#             nn.LayerNorm(6 * dim_base),
#             nn.GELU(),
#             nn.Linear(6 * dim_base, dim),
#         )

#     def forward(self, x, Lmax=None):
#         pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
#         charge = x["charge"] if Lmax is None else x["charge"][:, :Lmax]
#         time = x["time"] if Lmax is None else x["time"][:, :Lmax]
#         auxiliary = x["auxiliary"] if Lmax is None else x["auxiliary"][:, :Lmax]
#         qe = x["qe"] if Lmax is None else x["qe"][:, :Lmax]
#         ice_properties = (
#             x["ice_properties"] if Lmax is None else x["ice_properties"][:, :Lmax]
#         )
#         length = torch.log10(x["L0"].to(dtype=pos.dtype))

#         x = torch.cat(
#             [
#                 self.emb(4096 * pos).flatten(-2),
#                 self.emb(1024 * charge),
#                 self.emb(4096 * time),
#                 self.aux_emb(auxiliary),
#                 self.emb2(length).unsqueeze(1).expand(-1, pos.shape[1], -1),
#             ],
#             -1,
#         )
#         x = self.proj(x)
#         return x


class Rel_ds(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.emb = SinusoidalPosEmb(dim=dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, Lmax=None):
        pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
        time = x["time"] if Lmax is None else x["time"][:, :Lmax]
        ds2 = (pos[:, :, None] - pos[:, None, :]).pow(2).sum(-1) - (
            (time[:, :, None] - time[:, None, :]) * (3e4 / 500 * 3e-1)
        ).pow(2)
        d = torch.sign(ds2) * torch.sqrt(torch.abs(ds2))
        emb = self.emb(1024 * d.clip(-4, 4))
        rel_attn = self.proj(emb)
        return rel_attn, emb


def get_nbs(x, Lmax=None, K=8):
    pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
    mask = x["mask"][:, :Lmax]
    B = pos.shape[0]

    d = -torch.cdist(pos, pos, p=2)
    d -= 100 * (~torch.min(mask[:, None, :], mask[:, :, None]))
    d -= 200 * torch.eye(Lmax, dtype=pos.dtype, device=pos.device).unsqueeze(0)
    nbs = d.topk(K - 1, dim=-1)[1]
    nbs = torch.cat(
        [
            torch.arange(Lmax, dtype=nbs.dtype, device=nbs.device)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(B, -1, -1),
            nbs,
        ],
        -1,
    )
    return nbs


class LocalBlock(nn.Module):
    def __init__(
        self,
        dim=192,
        num_heads=192 // 64,
        mlp_ratio=4,
        drop_path=0,
        init_values=1,
        **kwargs,
    ):
        super().__init__()
        self.proj_rel_bias = nn.Linear(dim // num_heads, dim // num_heads)
        self.block = Block_rel(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop_path=drop_path,
            init_values=init_values,
        )

    def forward(self, x, nbs, key_padding_mask=None, rel_pos_bias=None):
        B, Lmax, C = x.shape
        mask = (
            key_padding_mask
            if not (key_padding_mask is None)
            else torch.ones(B, Lmax, dtype=torch.bool, device=x.deice)
        )

        m = torch.gather(mask.unsqueeze(1).expand(-1, Lmax, -1), 2, nbs)
        attn_mask = torch.zeros(m.shape, device=m.device)
        attn_mask[~mask] = -torch.inf
        attn_mask = attn_mask[mask]

        if rel_pos_bias is not None:
            rel_pos_bias = torch.gather(
                rel_pos_bias,
                2,
                nbs.unsqueeze(-1).expand(-1, -1, -1, rel_pos_bias.shape[-1]),
            )
            rel_pos_bias = rel_pos_bias[mask]
            rel_pos_bias = self.proj_rel_bias(rel_pos_bias).unsqueeze(1)

        xl = torch.gather(
            x.unsqueeze(1).expand(-1, Lmax, -1, -1),
            2,
            nbs.unsqueeze(-1).expand(-1, -1, -1, C),
        )
        xl = xl[mask]
        # modify only the node (0th element)
        # print(xl[:,:1].shape,rel_pos_bias.shape,attn_mask[:,:1].shape,xl.shape)
        xl = self.block(
            xl[:, :1],
            rel_pos_bias=rel_pos_bias,
            key_padding_mask=attn_mask[:, :1],
            kv=xl,
        )
        x = torch.zeros(x.shape, device=x.device, dtype=xl.dtype)
        x[mask] = xl.squeeze(1)
        return x


class DeepIceModel(nn.Module):
    def __init__(
        self,
        dim=384,
        dim_base=128,
        depth=12,
        use_checkpoint=False,
        head_size=32,
        depth_rel=4,
        n_rel=1,
        **kwargs,
    ):
        super().__init__()
        self.extractor = Extractor(dim_base, dim)
        self.rel_pos = Rel_ds(head_size)
        self.sandwich = nn.ModuleList(
            [Block_rel(dim=dim, num_heads=dim // head_size) for i in range(depth_rel)]
        )
        self.cls_token = nn.Linear(dim, 1, bias=False)
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=dim,
                    num_heads=dim // head_size,
                    mlp_ratio=4,
                    drop_path=0.0 * (i / (depth - 1)),
                    init_values=1,
                )
                for i in range(depth)
            ]
        )
        self.proj_out = nn.Linear(dim, 3)
        self.use_checkpoint = use_checkpoint
        self.apply(self._init_weights)
        trunc_normal_(self.cls_token.weight, std=0.02)
        self.n_rel = n_rel

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def init_weights(self, pretrained=None):
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_token"}

    def forward(self, x0):
        mask = x0["mask"]
        Lmax = mask.sum(-1).max()
        x = self.extractor(x0, Lmax)
        rel_pos_bias, rel_enc = self.rel_pos(x0, Lmax)
        # nbs = get_nbs(x0, Lmax)
        mask = mask[:, :Lmax]
        B, _ = mask.shape
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf

        for i, blk in enumerate(self.sandwich):
            if isinstance(blk, LocalBlock):
                x = blk(x, nbs, mask, rel_enc)
            else:
                x = blk(x, attn_mask, rel_pos_bias)
                if i + 1 == self.n_rel:
                    rel_pos_bias = None

        mask = torch.cat(
            [torch.ones(B, 1, dtype=mask.dtype, device=mask.device), mask], 1
        )
        attn_mask = torch.zeros(mask.shape, device=mask.device)
        attn_mask[~mask] = -torch.inf
        cls_token = self.cls_token.weight.unsqueeze(0).expand(B, -1, -1)
        x = torch.cat([cls_token, x], 1)

        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, None, attn_mask)
            else:
                x = blk(x, None, attn_mask)

        x = self.proj_out(x[:, 0])  # cls token
        return x

In [13]:
chunk_csv_train = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/train/output_7.csv",
]
chunk_csv_test = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/test/output_7.csv",
]
chunk_csv_val = [
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_1.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_2.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_3.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_4.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_5.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_6.csv",
  "/groups/icecube/moust/storage/cached_event_no/upgrade_numu/val/output_7.csv",
]

batch_sizes = [512, 256, 128, 64, 32, 16, 8]
truth_table = "truth"
db_path = "/groups/icecube/petersen/GraphNetDatabaseRepository/Upgrade_Data/sqlite3/dev_step4_upgrade_028_with_noise_dynedge_pulsemap_v3_merger_aftercrash.db"
pulsemap = "SplitInIcePulses_dynedge_v2_Pulses"
input_cols =  ["dom_x", "dom_y", "dom_z", "dom_time", "charge"]
target_cols = "inelasticity"

# class ChunkDataset(Dataset):
#     """
#     PyTorch dataset for loading chunked data from an SQLite database.
#     This dataset retrieves pulsemap and truth data for each event from the database.

#     Args:
#         db_filename (str): Filename of the SQLite database.
#         csv_filenames (list of str): List of CSV filenames containing event numbers.
#         pulsemap_table (str): Name of the table containing pulsemap data.
#         truth_table (str): Name of the table containing truth data.
#         truth_variable (str): Name of the variable to query from the truth table.
#         feature_variables (list of str): List of variable names to query from the pulsemap table.
#     """

#     def __init__(
#         self,
#         db_path: str,
#         chunk_csvs: List[str],
#         pulsemap: str,
#         truth_table: str,
#         target_cols: str,
#         input_cols: List[str]
#     ) -> None:
#         self.conn = sqlite3.connect(db_path)  # Connect to the SQLite database
#         self.c = self.conn.cursor()
#         self.event_nos = []
#         for csv_filename in chunk_csvs:
#             df = pd.read_csv(csv_filename)
#             self.event_nos.extend(df['event_no'].tolist())  # Collect event numbers from CSV files
#         self.pulsemap = pulsemap  # Name of the table containing pulsemap data
#         self.truth_table = truth_table  # Name of the table containing truth data
#         self.target_cols = target_cols  # Name of the variable to query from the truth table
#         self.input_cols = input_cols  # List of variable names to query from the pulsemap table


#     def __len__(self) -> int:
#         return len(self.event_nos)

#     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
#         # print(idx)
#         event_no = idx # self.event_nos[idx]
#         # print(event_no)
#         # Query the truth variable for the given event number
#         self.c.execute(f"SELECT {self.target_cols} FROM {self.truth_table} WHERE event_no = ?", (event_no,))
#         truth_value = self.c.fetchone()[0]
#         input_query = ', '.join(self.input_cols)
#         # Query the feature variables from the pulsemap table for the given event number
#         self.c.execute(f"SELECT {input_query} FROM {self.pulsemap} WHERE event_no = ?", (event_no,))
#         pulsemap_data = self.c.fetchall()
#         return torch.tensor(truth_value, dtype=torch.float32), torch.tensor(pulsemap_data, dtype=torch.float32)
    
#     def close_connection(self) -> None:
#         self.conn.close()


# def get_nbs(x, Lmax=None, K=8):
#     pos = x["pos"] if Lmax is None else x["pos"][:, :Lmax]
#     mask = x["mask"][:, :Lmax]
#     B = pos.shape[0]

#     d = -torch.cdist(pos, pos, p=2)
#     d -= 100 * (~torch.min(mask[:, None, :], mask[:, :, None]))
#     d -= 200 * torch.eye(Lmax, dtype=pos.dtype, device=pos.device).unsqueeze(0)
#     nbs = d.topk(K - 1, dim=-1)[1]
#     nbs = torch.cat(
#         [
#             torch.arange(Lmax, dtype=nbs.dtype, device=nbs.device)
#             .unsqueeze(0)
#             .unsqueeze(-1)
#             .expand(B, -1, -1),
#             nbs,
#         ],
#         -1,
#     )
#     return nbs













# print()
# print("Model test")
# print()
model = DeepIceModel(
        dim=384,
        dim_base=128, #128
        depth=12,
        use_checkpoint=False,
        head_size=32,
        depth_rel=4,
        n_rel=1,
        )


for i, (features, truth) in enumerate(dl):
    
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(features)

    print(f"Batch {i + 1}")
    print("Features:", features['charge'].shape)
    print("Truth:", truth.shape)
    print("Predicted:", y_pred.shape)

    if i == 60:
        break
    

KeyError: 'time'

In [19]:
input_cols =  ["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "dom_type"]
combine_dom_types_and_rde = ["rde", "dom_type"]
posistion_cols = ["dom_x", "dom_y", "dom_z"]
list1 = [1, 2, 3, 4, 5, 6, 7]


In [33]:
class ChunkDataset(Dataset):
    """
    PyTorch dataset for loading chunked data from an SQLite database.
    This dataset retrieves pulsemap and truth data for each event from the database.

    Args:
        db_filename (str): Filename of the SQLite database.
        csv_filenames (list of str): List of CSV filenames containing event numbers.
        pulsemap_table (str): Name of the table containing pulsemap data.
        truth_table (str): Name of the table containing truth data.
        truth_variable (str): Name of the variable to query from the truth table.
        feature_variables (list of str): List of variable names to query from the pulsemap table.
    """

    def __init__(
        self,
        db_path: str,
        chunk_csvs: List[str],
        pulsemap: str,
        truth_table: str,
        target_cols: str,
        input_cols: List[str]
    ) -> None:
        self.conn = sqlite3.connect(db_path)  # Connect to the SQLite database
        self.c = self.conn.cursor()
        self.event_nos = []
        for csv_filename in chunk_csvs:
            df = pd.read_csv(csv_filename)
            self.event_nos.extend(df['event_no'].tolist())  # Collect event numbers from CSV files
        self.pulsemap = pulsemap  # Name of the table containing pulsemap data
        self.truth_table = truth_table  # Name of the table containing truth data
        self.target_cols = target_cols  # Name of the variable to query from the truth table
        self.input_cols = input_cols  # List of variable names to query from the pulsemap table


    def __len__(self) -> int:
        return len(self.event_nos)
    
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        event_no = idx # self.event_nos[idx]

        # Query the truth variable for the given event number
        self.c.execute(f"SELECT {self.target_cols} FROM {self.truth_table} WHERE event_no = ?", (event_no,))
        truth_value = self.c.fetchone()[0]
        

        pos_cols = ['dom_x', 'dom_y', 'dom_z']

        rde_index = self.input_cols.index('rde')
        dom_type_index = self.input_cols.index('dom_type')
        pos_indices = [self.input_cols.index(col) for col in pos_cols]
        rest_indices = [i for i in range(len(self.input_cols)) if i not in [rde_index, dom_type_index] + pos_indices]

        input_query = ', '.join(self.input_cols)
        # print(input_query)
        # Query the feature variables from the pulsemap table for the given event number
        self.c.execute(f"SELECT {input_query} FROM {self.pulsemap} WHERE event_no = ?", (event_no,))
        pulsemap_data_rows = self.c.fetchall()
    
        # Convert pulsemap_data_rows into a dictionary of tensors
        
        pulsemap_data = {self.input_cols[i]: torch.tensor( [row[i] for row in pulsemap_data_rows], dtype=torch.float32)
                        for i in rest_indices}
        
        # Get the necessary data for combined_dom_type and pos
        dom_type_data = [row[dom_type_index] for row in pulsemap_data_rows] #[row[i] for row in pulsemap_data_rows for i in combined_dom_type_indices]
        rde_data = [row[rde_index] for row in pulsemap_data_rows]  #[row[i] for row in pulsemap_data_rows for i in combined_dom_type_indices]
        pos_data = [[row[i] for row in pulsemap_data_rows] for i in pos_indices]

        pulsemap_data["combined_dom_type"] = combine_dom_types_and_rde(torch.tensor(dom_type_data, dtype=torch.float32),
                                                                    torch.tensor(rde_data, dtype=torch.float32))
        pulsemap_data["pos"] = torch.stack([torch.tensor(col_data, dtype=torch.float32) for col_data in pos_data], dim=1)
        pulsemap_data["L0"] = torch.tensor(len(pulsemap_data_rows), dtype=torch.float32)

        return pulsemap_data, torch.tensor(truth_value, dtype=torch.float32)
    

    # def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
    #     # print(idx)
    #     event_no = idx # self.event_nos[idx]
    #     # print(event_no)
    #     # Query the truth variable for the given event number
    #     self.c.execute(f"SELECT {self.target_cols} FROM {self.truth_table} WHERE event_no = ?", (event_no,))
    #     truth_value = self.c.fetchone()[0]
    #     input_query = ', '.join(self.input_cols)
    #     # Query the feature variables from the pulsemap table for the given event number
    #     self.c.execute(f"SELECT {input_query} FROM {self.pulsemap} WHERE event_no = ?", (event_no,))
    #     pulsemap_data = self.c.fetchall()
    #     # Convert pulsemap_data into a dictionary of tensors
    #     pulsemap_data = {col: torch.tensor([row[i] for row in pulsemap_data], dtype=torch.float32)
    #                      for i, col in enumerate(self.input_cols)}
    #     pulsemap_data["combined_dom_type"] = combine_dom_types_and_rde(pulsemap_data["dom_type"], pulsemap_data["rde"])
    #     pulsemap_data["pos"] = torch.stack([pulsemap_data["dom_x"], pulsemap_data["dom_y"], pulsemap_data["dom_z"]], dim=1) 
    #     del pulsemap_data['dom_x']
    #     del pulsemap_data['dom_y']
    #     del pulsemap_data['dom_z']
    #     del pulsemap_data['dom_type']
    #     del pulsemap_data['rde']
    #     return pulsemap_data, torch.tensor(truth_value, dtype=torch.float32)
    # {'truth': torch.tensor(truth_value, dtype=torch.float32), 'pulsemap': pulsemap_data}
        
    def close_connection(self) -> None:
        self.conn.close()

def combine_dom_types_and_rde( dom_type, rde):
        # pDom dom with low efficiency
        pdom_low_qe = ((dom_type == 20) & (rde == 1)).float().unsqueeze(-1)
        # pDOM dom with high efficiency
        pdom_high_qe = ((dom_type == 20) & (rde == 1.35)).float().unsqueeze(-1)
        # pDOM upgrade == 110
        pdom_upgrade = (dom_type == 110).float().unsqueeze(-1)
        # D-EGG == 120
        d_egg = (dom_type == 120).float().unsqueeze(-1)
        # mDOM == 130
        mdom = (dom_type == 130).float().unsqueeze(-1)

        return torch.cat([pdom_low_qe, pdom_high_qe, pdom_upgrade, d_egg, mdom], dim=-1)


class ChunkSampler(Sampler):
    """
    PyTorch sampler for creating chunks from event numbers.

    Args:
        csv_filenames (List[str]): List of CSV filenames containing event numbers.
        batch_sizes (List[int]): List of batch sizes for each CSV file.
    """

    def __init__(
        self, 
        chunk_csvs: List[str], 
        batch_sizes: List[int]
    ) -> None:
        self.event_nos = []
        for csv_filename, batch_size in zip(chunk_csvs, batch_sizes):
            event_nos = pd.read_csv(csv_filename)['event_no'].tolist()
            self.event_nos.extend([event_nos[i:i + batch_size] for i in range(0, len(event_nos), batch_size)])

    def __iter__(self) -> Iterator:
        return iter(self.event_nos)

    def __len__(self) -> int:
        return len(self.event_nos)
    
# def collate_fn(data: List[Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
#     truths = [item['truth'] for item in data]
#     pulsemap_data = [item['pulsemap'] for item in data]
#     # Pad each feature separately
#     max_length = max(len(item['pos']) for item in pulsemap_data)
#     for item in pulsemap_data:
#         for feature, tensor in item.items():
#             pad_length = max_length - len(tensor)
#             item[feature] = F.pad(tensor, (0, pad_length))
#     # Create the mask
#     mask = torch.stack([torch.cat([torch.ones(len(item['pos']), dtype=torch.bool),
#                                    torch.zeros(max_length - len(item['pos']), dtype=torch.bool)]) for item in pulsemap_data])
#     return {'truth': torch.stack(truths), 'pulsemap': pulsemap_data, 'mask': mask}
def collate_fn(data: List[Tuple[torch.Tensor, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    truths, pulsemap_data = zip(*data)

    for item in pulsemap_data:
        item['pos'] = torch.stack([item['dom_x'], item['dom_y'], item['dom_z']], dim=-1)
        item['dom_type_rde'] = combine_dom_types_and_efficiency(item['dom_type'], item['rde'])

    max_length = max(len(item['pos']) for item in pulsemap_data)

    pos_data = torch.zeros(len(pulsemap_data), max_length, 3)
    charge_data = torch.zeros(len(pulsemap_data), max_length)
    time_data = torch.zeros(len(pulsemap_data), max_length)
    dom_type_rde_data = torch.zeros(len(pulsemap_data), max_length, 5)
    mask = torch.zeros(len(pulsemap_data), max_length, dtype=torch.bool)

    for i, item in enumerate(pulsemap_data):
        pos_data[i, :len(item['pos'])] = item['pos']
        charge_data[i, :len(item['charge'])] = item['charge']
        time_data[i, :len(item['time'])] = item['time']
        dom_type_rde_data[i, :len(item['dom_type_rde'])] = item['dom_type_rde']
        mask[i, :len(item['pos'])] = 1

    return {
        'pulsemap': {
            'pos': pos_data,
            'charge': charge_data,
            'time': time_data,
            'dom_type_rde': dom_type_rde_data
        },
        'mask': mask,
        'truth': torch.stack(truths)
    }

# def combine_dom_types_and_efficiency(dom_type, rde):
#     print(f"Dom_type tensor: {dom_type}, RDE tensor: {rde}") 
#     # Normal dom with low efficiency
#     type1 = ((dom_type == 20) & (rde == 0)).float().unsqueeze(-1)
#     # Normal dom with high efficiency
#     type2 = ((dom_type == 20) & (rde == 1)).float().unsqueeze(-1)
#     # dom_type == 110
#     type3 = (dom_type == 110).float().unsqueeze(-1)
#     # dom_type == 120
#     type4 = (dom_type == 120).float().unsqueeze(-1)
#     # dom_type == 130
#     type5 = (dom_type == 130).float().unsqueeze(-1)

#     return torch.cat([type1, type2, type3, type4, type5], dim=-1)

input_cols =  ["dom_x", "dom_y", "dom_z", "dom_time", "charge", "rde", "dom_type"]

dataset = ChunkDataset(
    db_path=db_path, 
    chunk_csvs=chunk_csv_train, 
    pulsemap=pulsemap, 
    truth_table=truth_table, 
    target_cols=target_cols, 
    input_cols=input_cols
    )

# dl = DataLoader(
#     dataset=dataset,
#     collate_fn=collate_fn,
#     batch_sampler=ChunkSampler(chunk_csv_train, batch_sizes),
#     num_workers=12,
# )
# dl = DataLoader(
#     dataset=dataset,
#     collate_fn=collate_fn,
#     batch_sampler=ChunkSampler(chunk_csv_train, batch_sizes),
#     num_workers=12,
# )

# for i, (features, truth) in enumerate(dataset):
#     print(i)
#     print(features, truth)
#     print(features['dom_x'])
#     if i >= 10:
#         break

# # Iterate over the first 10 batches
# for i, batch in enumerate(dl):
#     if i >= 10:
#         break
#     features = batch['pulsemap']
#     mask = batch['mask']
#     truth = batch['truth']
#     print(i, features, mask, truth)

features, truth = dataset[6]
# print(features, truth)
print(features['L0'].max().item() if features['L0'] is not None else max(len(item) for item in x['pos']))
# print(features['pos'].shape)
print()

24.0



In [18]:


# Instantiate the extractor
extractor = Extractor()

# Create a batch of data
batch_size = 16
max_length = 100
batch = {
    'pulsemap': {
        'pos': torch.randn(batch_size, max_length, 3),
        'charge': torch.randn(batch_size, max_length),
        'time': torch.randn(batch_size, max_length),
        'dom_type_rde': torch.randint(0, 2, (batch_size, max_length, 5)).float()
    },
    'mask': torch.ones(batch_size, max_length, dtype=torch.bool),
    'L0': torch.tensor([max_length]*batch_size),  # adding the sequence length
    'truth': torch.randn(batch_size, 3)
}

# Feed the batch to the extractor
output = extractor(batch['pulsemap'], batch['L0'])

# Print the shape of the output
print("Output shape:", output.shape)





Output shape: torch.Size([16, 100, 384])


In [29]:
batch_size = 32
max_length = 10
batch = {
    'pulsemap': {
        'pos': torch.randn(batch_size, max_length, 3),
        'charge': torch.randn(batch_size, max_length),
        'time': torch.randn(batch_size, max_length),
        'dom_type_rde': torch.randn(batch_size, max_length, 5)
    },
    'L0': torch.tensor([max_length] * batch_size),
    'truth': torch.randn(batch_size, 3)
}

# Create the extractor
extractor = Extractor()

# Feed the batch to the extractor
output = extractor(batch['pulsemap'], batch['L0'])

# Print the shape of the output
print("Output shape:", output.shape)


Output shape: torch.Size([32, 10, 384])
