In [28]:
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from typing import List, Dict
from enum import Enum
from datetime import datetime
import logging
# 20 sec

In [29]:
# import this

The Zen of Python, by Tim Peters  
  
Beautiful is better than ugly.  
Explicit is better than implicit.  
Simple is better than complex.  
Complex is better than complicated.  
Flat is better than nested.  
Sparse is better than dense.  
Readability counts.  
Special cases aren't special enough to break the rules.  
Although practicality beats purity.  
Errors should never pass silently.  
Unless explicitly silenced.  
In the face of ambiguity, refuse the temptation to guess.  
There should be one-- and preferably only one --obvious way to do it.  
Although that way may not be obvious at first unless you're Dutch.  
Now is better than never.  
Although never is often better than *right* now.  
If the implementation is hard to explain, it's a bad idea.  
If the implementation is easy to explain, it may be a good idea.  
Namespaces are one honking great idea -- let's do more of those!

In [30]:
from concurrent.futures import ThreadPoolExecutor

In [31]:
import wandb

In [32]:
import sqlite3 as sql
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import pyarrow.compute as pc

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
# import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningDataModule, LightningModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# fails after 1min 10 sec and reruns within 10 sec

In [34]:
sys.path.append('/groups/icecube/cyan/Utils')
from PlotUtils import setMplParam, getColour, getHistoParam 
# getHistoParam:
# Nbins, binwidth, bins, counts, bin_centers  = 
from DB_lister import list_content, list_tables
from ExternalFunctions import nice_string_output, add_text_to_ax
setMplParam()

In [35]:
class EnergyRange(Enum):
    ER_100_GEV_10_TEV = (0, ["22010", "22013", "22016"])
    ER_10_TEV_1_PEV   = (1, ["22011", "22014", "22017"])
    ER_1_PEV_100_PEV  = (2, ["22012", "22015", "22018"])

    def __init__(self, value, subdirs):
        self._value_ = value
        self._subdirs = subdirs

    def get_subdirs(self):
        return self._subdirs

In [36]:
root_dir = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/Snowstorm/"
NuE_PeV_root = root_dir + "22015/"
NuMu_PeV_root = root_dir + "22012/"
NuTau_PeV_root = root_dir + "22018/"

truth_NuE_PeV_1 = NuE_PeV_root + "truth_1.parquet"
truth_NuMu_PeV_1 = NuMu_PeV_root + "truth_1.parquet"
truth_NuTau_PeV_1 = NuTau_PeV_root + "truth_1.parquet"

PMTfied_NuE_PeV_1 = NuE_PeV_root + "1/"
PMTfied_NuE_PeV_1_1 = PMTfied_NuE_PeV_1 + "PMTfied_1.parquet"

PMTfied_NuMu_PeV_1 = NuMu_PeV_root + "1/"
PMTfied_NuMu_PeV_1_1 = PMTfied_NuMu_PeV_1 + "PMTfied_1.parquet"

PMTfied_NuTau_PeV_1 = NuTau_PeV_root + "1/"
PMTfied_NuTau_PeV_1_1 = PMTfied_NuTau_PeV_1 + "PMTfied_1.parquet"

In [37]:
PMTfied_NuE_PeV_1_1

'/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/Snowstorm/22015/1/PMTfied_1.parquet'

In [38]:
def get_files_in_dir(directory, extension='.parquet'):
    return [f for f in os.listdir(directory) if f.endswith(extension)]
def get_subdir_in_dir(directory):
    return [name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))]

In [39]:
get_subdir_in_dir(NuE_PeV_root)

['4', '11', '9', '3', '7', '12', '2', '8', '5', '10', '1', '6']

In [40]:
def convertParquetToDF(file:str) -> pd.DataFrame:
    table = pq.read_table(file)
    df = table.to_pandas()
    return df

In [41]:
df_PMTfied_NuE_PeV_1_1 = convertParquetToDF(PMTfied_NuE_PeV_1_1)
# 16 sec
# 695403 rows × 24 columns

In [42]:
def isClean(df:pd.DataFrame) -> bool:
    return not df.isna().any().any()

In [43]:
print(isClean(df_PMTfied_NuE_PeV_1_1))

True


In [44]:
class PseudoNormaliser:
    def __init__(self):
        self.position_scaler = 2e-3  # 1/500
        self.t_scaler = 3e-4         # 1/30000
        self.t_shifter = 1e4         # (-) 10000
        self.Q_shifter = 2           # (-) 2 in log10

    def __call__(self, table: pa.Table) -> pa.Table:
        """
        Apply the normalisation steps to the given PyArrow table.
        """
        table = self._log10_charge(table)
        table = self._pseudo_normalise_dom_pos(table)
        table = self._pseudo_normalise_time(table)
        return table

    def _log10_charge(self, table: pa.Table) -> pa.Table:
        """
        Apply log10 transformation and shift on charge-related columns.
        """
        q_columns = ['q1', 'q2', 'q3', 'q4', 'q5', 'Q25', 'Q75', 'Qtotal']
        for col in q_columns:
            if col in table.column_names:
                col_array = table[col].to_pandas()
                new_col = np.where(col_array > 0, np.log10(col_array), 0) - self.Q_shifter
                idx = table.column_names.index(col)
                table = table.set_column(idx, col, pa.array(new_col))
        return table

    def _pseudo_normalise_dom_pos(self, table: pa.Table) -> pa.Table:
        """
        Apply scaling to DOM position columns.
        """
        pos_columns = ['dom_x', 'dom_y', 'dom_z', 'dom_x_rel', 'dom_y_rel', 'dom_z_rel']
        for col in pos_columns:
            if col in table.column_names:
                new_col = table[col].to_pandas() * self.position_scaler
                idx = table.column_names.index(col)
                table = table.set_column(idx, col, pa.array(new_col))
        return table

    def _pseudo_normalise_time(self, table: pa.Table) -> pa.Table:
        """
        Apply shifting and scaling to time-related columns.
        """
        t_columns = ['t1', 't2', 't3', 'T10', 'T50', 'sigmaT']
        t_columns_shift = ['t1', 't2', 't3']

        # Time shifting
        for col in t_columns_shift:
            if col in table.column_names:
                shifted = table[col].to_pandas() - self.t_shifter
                idx = table.column_names.index(col)
                table = table.set_column(idx, col, pa.array(shifted))

        # Time scaling
        for col in t_columns:
            if col in table.column_names:
                scaled = table[col].to_pandas() * self.t_scaler
                idx = table.column_names.index(col)
                table = table.set_column(idx, col, pa.array(scaled))

        return table


In [45]:
os.path.basename(os.path.normpath(NuMu_PeV_root))

'22012'

In [46]:
class MaxNDOMFinder:
    def __init__(self, root_dir: str, energy_band: EnergyRange, part: int = None, shard: int = None, verbosity: int = 0):
        self.root_dir = root_dir
        self.energy_band = energy_band
        self.part = part
        self.shard = shard
        self.verbosity = verbosity
        self.subdirectories = self.energy_band.get_subdirs()

    def __call__(self) -> int:
        max_n_doms_list = [self._get_max_n_doms_for_subdirectory(subdir) for subdir in self.subdirectories]
        max_n_doms_list = [value for value in max_n_doms_list if value is not None]

        global_max_n_doms = max(max_n_doms_list, default=0)
        if self.verbosity > 0:
            print(f"Global max_n_doms across all data: {global_max_n_doms}")
        return global_max_n_doms

    def _get_max_n_doms_for_subdirectory(self, subdirectory: str) -> int:
        if self.part is not None: # only specific parts
            part_path = os.path.join(self.root_dir, subdirectory, str(self.part))
            truth_path = os.path.join(self.root_dir, subdirectory, f"truth_{self.part}.parquet")

            return self._get_max_n_doms_for_part(part_path, truth_path)
        else: # across all parts in the subdirectory
            return self._get_max_n_doms_for_entire_subdirectory(subdirectory)

    def _get_max_n_doms_for_part(self, part_path: str, truth_path: str) -> int:
        # Get the maximum `n_doms` across all shards in a part.
        if not os.path.exists(truth_path):
            if self.verbosity > 0:
                print(f"Truth file missing for {truth_path}. Skipping.")
            return None

        truth_data = pq.read_table(truth_path)

        if self.shard is not None: # only specific shards
            
            shard_filter = self._filter_shard_data(truth_data)
            return self._compute_max_n_doms(shard_filter)
        else: # across all shards in the part
            shard_files = [
                f for f in os.listdir(part_path) if f.startswith("PMTfied_") and f.endswith(".parquet")
            ]
            max_n_doms_list = []
            for shard_file in shard_files:
                shard_no = int(shard_file.split("_")[1].split(".")[0])
                self.shard = shard_no
                max_n_doms_list.append(self._get_max_n_doms_for_shard(truth_data))
            return max(max_n_doms_list, default=None)

    def _get_max_n_doms_for_shard(self, truth_data: pa.Table) -> int:
        shard_filter = self._filter_shard_data(truth_data)
        return self._compute_max_n_doms(shard_filter)

    def _get_max_n_doms_for_entire_subdirectory(self, subdirectory: str) -> int:
        # Get the maximum `n_doms` across all parts and shards in a subdirectory.
        part_dirs = [
            d for d in os.listdir(os.path.join(self.root_dir, subdirectory)) 
            if os.path.isdir(os.path.join(self.root_dir, subdirectory, d)) and d.isdigit()
        ]

        max_n_doms_list = []
        for part in part_dirs:
            self.part = int(part)
            part_path = os.path.join(self.root_dir, subdirectory, part)
            truth_path = os.path.join(self.root_dir, subdirectory, f"truth_{part}.parquet")
            max_n_doms_list.append(self._get_max_n_doms_for_part(part_path, truth_path))

        return max(max_n_doms_list, default=None)

    def _filter_shard_data(self, truth_data: pa.Table) -> pa.Table:
        shard_mask = pc.equal(truth_data.column("shard_no"), self.shard)
        return truth_data.filter(shard_mask)

    def _compute_max_n_doms(self, shard_filter: pa.Table) -> int:
        return shard_filter.column("N_doms").combine_chunks().to_numpy().max()


In [47]:
maxNDOMFinder_PeV_1 = MaxNDOMFinder(
    root_dir=root_dir,
    energy_band=EnergyRange.ER_1_PEV_100_PEV,
    part=1,
    verbosity=1
)

global_max_n_doms = maxNDOMFinder_PeV_1()
print(f"Global max_n_doms: {global_max_n_doms}")


Global max_n_doms across all data: 2674
Global max_n_doms: 2674


In [53]:
class DatasetMonoFlavourShard(Dataset):
    def __init__(self, root_dir: str, 
                 subdirectory_no: int,
                 part: int, 
                 shard: int, 
                 max_n_doms: int,
                 verbosity: int = 0,
                 ):
        """
        Args:
            root_dir (str): The root directory of the flavour.
            part (int): The part of the dataset.
            shard (int): The shard number.
            verbosity (int): The verbosity level.
        """
        self.root_dir = root_dir
        self.subdirectory_no = subdirectory_no
        self.part = part
        self.shard = shard
        self.verbosity = verbosity
        self.max_n_doms = max_n_doms
        self.transform = PseudoNormaliser()
 
        self.feature_file = os.path.join(self.root_dir, f"{self.subdirectory_no}", f"{self.part}", f"PMTfied_{self.shard}.parquet")
        self.truth_file = os.path.join(self.root_dir, f"{self.subdirectory_no}", f"truth_{self.part}.parquet")

        self.truth_data = self._load_truth_data()
        self.feature_data = self._load_feature_data()

        if verbosity > 0:
            self._show_info()

    def __len__(self):
        return len(self.truth_data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Retrieve truth data
        truth_row = self.truth_data.slice(idx, 1)
        event_no = truth_row.column("event_no").to_pylist()[0]
        original_event_no = truth_row.column("original_event_no").to_pylist()[0]
        offset = truth_row.column("offset").to_pylist()[0]
        n_doms = truth_row.column("N_doms").to_pylist()[0]
        flavour = truth_row.column("flavour").to_pylist()[0]

        # Extract and pad features
        features = self._extract_features(offset, n_doms)
        features_padded = np.zeros((self.max_n_doms, features.shape[1]), dtype=np.float32)
        features_padded[:features.shape[0], :] = features

        # Create the mask
        mask = np.zeros((self.max_n_doms,), dtype=np.float32)
        mask[:features.shape[0]] = 1.0

        # Convert to tensors
        event_no_tensor = torch.tensor([event_no, original_event_no], dtype=torch.int64)
        features_tensor = torch.tensor(features_padded, dtype=torch.float32)
        target_tensor = torch.tensor(flavour, dtype=torch.int64)
        mask_tensor = torch.tensor(mask, dtype=torch.float32)

        return {
            "event_no": event_no_tensor,
            "features": features_tensor,
            "target": target_tensor,
            "mask": mask_tensor,
        }

    def _load_feature_data(self):
        table = pq.read_table(self.feature_file)
        table = self.transform(table)
        return table
    
    def _load_truth_data(self):
        """
        Load and filter the truth data for the specific shard.
        Dynamically create the 'flavour' column if it is missing.
        """
        # Read the truth data
        truth_table = pq.read_table(self.truth_file)

        # Filter rows matching the shard number
        shard_mask = pc.equal(truth_table.column("shard_no").combine_chunks(), self.shard)
        shard_filter = truth_table.filter(shard_mask)

        # Check if 'flavour' column exists; if not, create it
        if 'flavour' not in shard_filter.column_names:
            if 'pid' not in shard_filter.column_names:
                raise ValueError("The truth data is missing both 'flavour' and 'pid' columns. Cannot determine flavours.")

            # Define PID to flavour mapping
            UNKNOWN_FLAVOUR = -1
            pid_to_class = {
                12: 0,   # NuE
                -12: 0,  # NuE
                14: 1,   # NuMu
                -14: 1,  # NuMu
                16: 2,   # NuTau
                -16: 2,  # NuTau
            }

            # Create 'flavour' column based on 'pid'
            pid_column = shard_filter.column("pid").combine_chunks().to_numpy()
            flavour_array = [
                pid_to_class.get(pid, UNKNOWN_FLAVOUR) 
                for pid in pid_column
            ]

            # Convert to PyArrow Array and append as 'flavour'
            flavour_arrow_array = pa.array(flavour_array, type=pa.int64())
            shard_filter = shard_filter.append_column("flavour", flavour_arrow_array)

        # Validate 'flavour' values
        flavour_column = shard_filter.column("flavour").combine_chunks().to_numpy()
        if not np.all(np.isin(flavour_column, [0, 1, 2])):
            raise ValueError("The 'flavour' column contains invalid values. Expected 0, 1, or 2.")

        return shard_filter

    def _extract_features(self, offset, n_rows):
        """
        Extract a specific slice of features based on offset and number of rows using PyArrow.
        """
        features_slice = self.feature_data.slice(offset, n_rows)

        # Drop columns "event_no" and "original_event_no" if present
        columns_to_keep = [
            col for col in features_slice.column_names if col not in ["event_no", "original_event_no"]
        ]
        features_slice = features_slice.select(columns_to_keep)

        # Convert PyArrow table to a NumPy array efficiently
        features_slice = np.stack([features_slice.column(col).to_numpy() for col in columns_to_keep], axis=1)
        return features_slice

    def _show_info(self):
        print(f"------------- Statistics (subdirectory {self.subdirectory_no}, part {self.part}, shard {self.shard}) -------------")
        num_events = len(self.truth_data)
        print(f"Total {num_events} events from shard {self.shard}")


In [49]:
dataset_NuMu_PeV_1_1 = DatasetMonoFlavourShard(root_dir=root_dir,
                                            subdirectory_no=22012,
                                            part=1, 
                                            shard=1, 
                                            max_n_doms=3000,
                                            verbosity=0)

dataset_NuE_PeV_1_1 = DatasetMonoFlavourShard(root_dir=root_dir,
                                            subdirectory_no=22015,
                                            part=1, 
                                            shard=1, 
                                            max_n_doms=3000,
                                            verbosity=1)

dataset_NuTau_PeV_1_1 = DatasetMonoFlavourShard(root_dir=root_dir,
                                            subdirectory_no=22018,
                                            part=1, 
                                            shard=1, 
                                            max_n_doms=3000,
                                            verbosity=1)


  result = getattr(ufunc, method)(*inputs, **kwargs)


------------- Statistics (subdirectory 22015, part 1, shard 1) -------------
Total 2000 events from shard 1
------------- Statistics (subdirectory 22018, part 1, shard 1) -------------
Total 2000 events from shard 1


In [54]:
maxNDOMFinder_PeV_1_1 = MaxNDOMFinder(root_dir, EnergyRange.PEV_1_TO_PEV_100, 1, 1, verbosity=1)

In [55]:
maxNdomFinder_PeV_1_1()

Global max_n_doms across all data: 2421


2421

In [None]:
class DatasetMultiFlavourShard(Dataset):
    def __init__(self, root_dir: str, 
                 energy_band: EnergyRange, 
                 part: int, 
                 shard: int = None,
                 max_n_doms: int = None,
                 verbosity: int = 0):
        """
        Args:
            root_dir (str): The root directory of the dataset.
            energy_band (EnergyRange): The energy band (enum) defining the subdirectories.
            part (int): The part number to collect.
            shard (int): The shard number to collect.
            max_n_doms (int): The maximum number of DOMs in the dataset.
            verbosity (int): The verbosity level.
        """
        self.root_dir = root_dir
        self.energy_band = energy_band
        self.part = part
        self.shard = shard
        self.verbosity = verbosity
        
        if max_n_doms is None:
            max_n_doms_finder = MaxNDOMFinder(root_dir, energy_band, part, shard, verbosity=self.verbosity)
            self.max_n_doms = max_n_doms_finder()
        else:
            self.max_n_doms = max_n_doms
        
        self.datasets = self._collect_shards()
        self.cumulative_lengths = self._compute_cumulative_lengths()
        
        if verbosity > 0:
            self._show_info()
    
    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)
    
    def __getitem__(self, idx):
        dataset_idx, local_idx = self._global_to_local_index(idx)
        return self.datasets[dataset_idx][local_idx]
    
    def _collect_shards(self):
        datasets = []
        
        for subdir in self.energy_band.get_subdirs():
            dataset = DatasetMonoFlavourShard(
                root_dir = self.root_dir, 
                subdirectory_no = int(subdir),
                part = self.part,
                shard = self.shard,
                max_n_doms = self.max_n_doms,
                verbosity = self.verbosity - 1
            )
            datasets.append(dataset)
            
        return datasets
    
    def _global_to_local_index(self, idx):
        for dataset_idx, start in enumerate(self.cumulative_lengths[:-1]):
            if start <= idx < self.cumulative_lengths[dataset_idx + 1]:
                local_idx = idx - start
                return dataset_idx, local_idx
        raise IndexError(f"Index {idx} is out of range.")
    
    def _compute_cumulative_lengths(self):
        lengths = [len(dataset) for dataset in self.datasets]
        return [0] + list(np.cumsum(lengths))
    
    def _show_info(self):
        print(f"------------- Multi-Flavour Shard (Energy Band: {self.energy_band.name}, Part: {self.part}, Shard: {self.shard}) -------------")

In [56]:
dataset_PeV_1_1 = DatasetMultiFlavourShard(root_dir=root_dir,
                                            subdirectory_no=22012,
                                            part=1, 
                                            shard=1, 
                                            max_n_doms=maxNDOMFinder_PeV_1_1(),
                                            verbosity=1)

Global max_n_doms across all data: 2421
------------- Statistics (subdirectory 22012, part 1, shard 1) -------------
Total 2000 events from shard 1


In [57]:
dataset_PeV_1_1.__getitem__(0)['target'].item()

1

In [58]:
class DatasetMultiFlavourPart(Dataset):
    def __init__(self, root_dir: str, 
                 energy_band: EnergyRange, 
                 part: int, 
                 max_n_doms: int = None,
                 verbosity: int = 0):
        """
        Args:
            root_dir (str): The root directory of the dataset.
            energy_band (EnergyRange): The energy band (enum) defining the subdirectories.
            part (int): The part number to collect.
            verbosity (int): The verbosity level.
        """
        self.root_dir = root_dir
        self.energy_band = energy_band
        self.part = part
        self.verbosity = verbosity
        
        if max_n_doms is None:
            max_n_doms_finder = MaxNDOMFinder(root_dir, energy_band, part, verbosity=verbosity)
            self.max_n_doms = max_n_doms_finder()
        else:
            self.max_n_doms = max_n_doms

        # Collect all shards for the part from each flavour
        self.datasets = self._collect_shards()

        # Compute cumulative lengths for indexing
        self.cumulative_lengths = self._compute_cumulative_lengths()

        if verbosity > 0:
            self._show_info()

    def _collect_shards(self):
        datasets = []
        subdirectories = self.energy_band.get_subdirs()
        common_shards, unique_shards = self._get_common_and_unique_shard_numbers(subdirectories)
        
        for subdir in subdirectories:
            for shard in common_shards:
                datasets.append(
                    DatasetMultiFlavourShard(
                        root_dir=self.root_dir,
                        energy_band=self.energy_band,
                        part=self.part,
                        shard=shard,
                        max_n_doms=self.max_n_doms,
                        verbosity=self.verbosity - 1
                    )
                )
            for shard in unique_shards[subdir]:
                datasets.append(
                    DatasetMonoFlavourShard(
                        root_dir=self.root_dir,
                        subdirectory_no=int(subdir),
                        part=self.part,
                        shard=shard,
                        max_n_doms=self.max_n_doms,
                        verbosity=self.verbosity - 1
                    )
                )
        return datasets
        
    def _get_common_and_unique_shard_numbers(self, subdirectories):
        shard_sets = []
        all_shard_numbers = {}
        
        for subdir in subdirectories:
            shard_dir = os.path.join(self.root_dir, subdir, str(self.part))
            shard_numbers = {
                int(f.split('_')[1].split('.')[0])
                for f in os.listdir(shard_dir) if f.startswith("PMTfied_") and f.endswith(".parquet")
            }
            shard_sets.append(shard_numbers)
            all_shard_numbers[subdir] = shard_numbers

        common_shards = sorted(set.intersection(*shard_sets))

        unique_shards = {
            subdir: sorted(shard_numbers - set(common_shards))
            for subdir, shard_numbers in all_shard_numbers.items()
        }

        return common_shards, unique_shards

    def __len__(self):
        return sum(len(dataset) for dataset in self.datasets)

    def __getitem__(self, idx):
        dataset_idx, local_idx = self._global_to_local_index(idx)
        return self.datasets[dataset_idx][local_idx]

    def _compute_cumulative_lengths(self):
        lengths = [len(dataset) for dataset in self.datasets]
        return [0] + list(np.cumsum(lengths))

    def _global_to_local_index(self, idx):
        for dataset_idx, start in enumerate(self.cumulative_lengths[:-1]):
            if start <= idx < self.cumulative_lengths[dataset_idx + 1]:
                local_idx = idx - start
                return dataset_idx, local_idx
        raise IndexError(f"Index {idx} is out of range.")

    def _show_info(self):
        print(f"------------- Multi-Flavour Part (Energy Band: {self.energy_band.name}, Part: {self.part}) -------------")
        for dataset in self.datasets:
            dataset._show_info()


In [59]:
maxNDOMFinder_PeV_1 = MaxNDOMFinder(
    root_dir = root_dir,
    energy_band = EnergyRange.PEV_1_TO_PEV_100,
    part=1,
    verbosity=1
)

In [60]:
dataset_PeV_1 = DatasetMultiFlavourPart(root_dir=root_dir,
                                            energy_band=EnergyRange.PEV_1_TO_PEV_100,
                                            part=1, 
                                            max_n_doms=maxNDOMFinder_PeV_1(),
                                            verbosity=1)
# 30 sec -> 15 sec

Global max_n_doms across all data: 2674
------------- Multi-Flavour Part Statistics (Energy Band: PEV_1_TO_PEV_100, Part: 1) -------------
------------- Statistics (subdirectory 22012, part 1, shard 1) -------------
Total 2000 events from shard 1
------------- Statistics (subdirectory 22012, part 1, shard 2) -------------
Total 2000 events from shard 2
------------- Statistics (subdirectory 22012, part 1, shard 3) -------------
Total 2000 events from shard 3
------------- Statistics (subdirectory 22012, part 1, shard 4) -------------
Total 2000 events from shard 4
------------- Statistics (subdirectory 22012, part 1, shard 5) -------------
Total 2000 events from shard 5
------------- Statistics (subdirectory 22012, part 1, shard 6) -------------
Total 2000 events from shard 6
------------- Statistics (subdirectory 22012, part 1, shard 7) -------------
Total 2000 events from shard 7
------------- Statistics (subdirectory 22012, part 1, shard 8) -------------
Total 2000 events from shard

In [61]:
%%time
for _ in dataset_PeV_1:
    pass
# ~ 2 min

CPU times: user 23min 46s, sys: 16.7 s, total: 24min 3s
Wall time: 1min 56s


In [62]:
len(dataset_PeV_1)

92146

In [63]:
len(dataset_PeV_1)/74

1245.2162162162163

In [64]:
class PMTfiedDataModule(LightningDataModule):
    def __init__(self, root_dir: str, 
                 energy_band: EnergyRange, 
                 dataset: Dataset, 
                 batch_size: int = 32, 
                 num_workers: int = 4, 
                 split_ratios=(0.8, 0.1, 0.1), verbosity=0):
        super().__init__()
        self.root_dir = root_dir
        self.energy_band = energy_band
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_ratios = split_ratios
        self.verbosity = verbosity

    def setup(self, stage=None):
        total_len = len(self.dataset)
        train_len = int(total_len * self.split_ratios[0])
        val_len = int(total_len * self.split_ratios[1])
        test_len = total_len - train_len - val_len

        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            self.dataset,
            [train_len, val_len, test_len],
            generator=torch.Generator().manual_seed(42)
        )

        # Compute class weights based on targets in train dataset
        targets = [sample["target"].item() for sample in self.train_dataset]
        class_counts = torch.bincount(torch.tensor(targets), minlength=4)  # Assuming 4 classes (0, 1, 2, 3)
        self.class_weights = 1.0 / class_counts.float()

        if self.verbosity > 0:
            print(f"Dataset split into train ({train_len}), val ({val_len}), and test ({test_len})")
            print(f"Class weights: {self.class_weights}")

    def _collate_fn(self, batch):
        features = [item["features"] for item in batch]
        targets = [item["target"] for item in batch]
        masks = [item["mask"] for item in batch]

        max_seq_length = max(f.shape[0] for f in features)
        padded_features = torch.zeros((len(features), max_seq_length, features[0].shape[1]), dtype=torch.float32)
        padded_masks = torch.zeros((len(masks), max_seq_length), dtype=torch.float32)

        for i, (feature, mask) in enumerate(zip(features, masks)):
            seq_length = feature.shape[0]
            padded_features[i, :seq_length, :] = feature
            padded_masks[i, :seq_length] = mask

        targets = torch.stack(targets)
        return {
            "features": padded_features,
            "target": targets,
            "mask": padded_masks,
        }

    def train_dataloader(self):
        return DataLoader(self.train_dataset, 
                          batch_size=self.batch_size, 
                          num_workers=self.num_workers, 
                          shuffle=True, 
                        #   collate_fn=self._collate_fn, 
                          pin_memory=True,
                          # skip the last batch if it is not full
                          drop_last=False)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, 
                          batch_size=self.batch_size, 
                          num_workers=self.num_workers, 
                          shuffle=True, 
                        #   collate_fn=self._collate_fn, 
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, 
                          batch_size=self.batch_size, 
                          num_workers=self.num_workers, 
                          shuffle=True, 
                        #   collate_fn=self._collate_fn, 
                          pin_memory=True)


In [65]:
datamodule_PeV_1_1 = PMTfiedDataModule(root_dir=root_dir,
                                        energy_band=EnergyRange.ER_1_PEV_100_PEV,
                                        dataset=dataset_PeV_1_1,
                                        batch_size=32,
                                        num_workers=4)

In [66]:
datamodule_PeV_1_1.setup()
dl  = datamodule_PeV_1_1.train_dataloader()

In [67]:
%%time 
for batch in dl:
    pass
# ~ 2.5 secs

CPU times: user 111 ms, sys: 510 ms, total: 620 ms
Wall time: 2.14 s


In [68]:
len(dl)

50

In [69]:
len(dl)/2.5

20.0

$$
x = [\text{batch size}, \text{sequence length (or } N_{\text{doms}}\text{)}, \text{input size (or } N_{\text{features}}\text{)}] \\

\frac{1}{N_{\text{DOM}}}  \sum_{i} W_{\gamma\beta}^{\text{output}} \text{ReLU}\left( W_{\beta\alpha}^{\text{input}} x_{\beta i \alpha} \right) \\
\gamma \text{ is the number of classes: 3,} \\
x \text{ is normalised.}
$$


I will tell more about the data. the PMTfied data has the features that I will feed the model. and one of the columns is 'event_no' which the truth data also has. I may want to do these: regression on columns: 'azimuth', 'zenith', or 'energy' or multiclass classification for 'pid'.

'azimuth', 'zenith', 'energy' and 'pid' are all columns of truth data

## Multi-head Attention with ALiBi 
$\boxed{\text{input layer}} \rightarrow \boxed{\text{query, key, value projections}} \rightarrow \boxed{\text{scaled dot-product attention}} \rightarrow \boxed{\text{ALiBi bias addition}} \rightarrow \boxed{\text{attention scores}} \rightarrow \boxed{\text{weighted output}} \rightarrow \boxed{\text{output layer}}$

$$
\text{batch\_size, } d_{\text{model}} \text{, and } N_{\text{head}} \text{ are pre-defined by model builder, thus, so is } d_{\text{head}} = \frac{d_{\text{model}}}{N_{\text{head}}}
$$

$$
(\text{embed\_dim  is } d_{\text{model}}) 
$$

$$
\text{similarity measure } s_{ij} = q_i \cdot k_j^T 
$$

$$
\text{attention score } S_{ij} = \frac{s_{ij}}{\sqrt{d_k}} + m_l \cdot (j-i) + \text{mask}_{ij} \in \mathbb{R}^{\text{batch\_size} \times N_{\text{head}} \times N_{\text{dom}} \times N_{\text{dom}}}
$$

$$
\text{softmax }(S_{ij}) = \frac{\exp(S_{ij})}{\sum_{j=1}^{N_{\text{dom}}} \exp(S_{ij})} 
$$

$$
\text{attention}_i = \text{softmax}(S_{ij}) \cdot v_j \in \mathbb{R}^{\text{batch\_size} \times N_{\text{dom}} \times N_{\text{head}} \times d_{\text{head}}}
$$

* `nn.Linear` works for input tensor $x$, already containing both the weight matrix and the bias term: $y = x\cdot W^T + b$

In [70]:
class ALiBiAttention(nn.Module):
    def __init__(self, 
                 d_model: int, # the size of an embedded vector
                #  d_qk: int,
                #  d_v: int,
                 n_heads: int,
                 dropout: float = 0.1,
                nan_logger=None):
        super().__init__()
        self.d_model = d_model # d_model
        self.d_qk = self.d_model
        self.d_v = self.d_model
        self.n_heads = n_heads
        self.head_dim_qk = self.d_qk // self.n_heads
        self.head_dim_v = self.d_v // self.n_heads
        self.scale = self.head_dim_qk ** -0.5
        self.dropout = nn.Dropout(dropout)

        # a layer is a linear transformation
        self.q_proj = nn.Linear(self.d_model, self.d_qk) # Projects input vectors of size d_model into query vectors of size d_qk using a weight matrix W_q of size (d_model x d_qk)
        self.k_proj = nn.Linear(self.d_model, self.d_qk)
        self.v_proj = nn.Linear(self.d_model, self.d_v)
        self.out_proj = nn.Linear(self.d_v, self.d_model)

    # forward is invoked when calling the model
    # x is the input tensor
    # batch_size is the number of data samples in the batch
    # seq_length is the number of elements in the sequence(N_dom_max)
    # embed_dim is the dimension of the embedding
    def forward(self, x, mask=None):
        batch_size, seq_length, embed_dim = x.size()
        # print(f"batch_size: {batch_size}, seq_length: {seq_length}, embed_dim: {embed_dim}")
        # print(f"embed_dim: {embed_dim}, d_model: {self.d_model}")
        # assert embed_dim == self.d_model
        
        V = self.v_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_v)
        attention_scores = self._get_attention_pure_score(x, batch_size, seq_length)
        alibi_bias = self._get_ALiBi_bias(x, seq_length)
        
        attention_scores += alibi_bias

        # Mask attention scores
        # masked_fill() fills elements of the tensor with a value where mask is True
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask[:, None, None, :] == 0, float("-inf"))

        attention = F.softmax(attention_scores, dim=-1)
        attention_output = torch.einsum("bhqk,bkhd->bqhd", attention, V).reshape(batch_size, seq_length, embed_dim)
        
        self.nan_logger.info(f"---------- attention(ALiBi) ---------- ")
        self.nan_logger.info(f"attn_output hasn't nan: {not torch.isnan(attention_output).any()}")
        
        return self.out_proj(attention_output)
    
    def _get_attention_pure_score(self, x, batch_size, seq_length):
        Q = self.q_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_qk)
        K = self.k_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_qk)
        attention_scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) * self.scale
        return attention_scores
    
    # HACK consider movig this outside this class. That would only be possible N_dom_max is constant for all parts...?
    def _get_ALiBi_bias(self, x, seq_length):
        # arange(n) returns a 1-D tensor of size n with values from 0 to n - 1
        slopes = 1.0 / (2 ** (torch.arange(self.n_heads).float() / self.n_heads))
        # to() moves the tensor to the device
        slopes = slopes.to(x.device)

        # view() reshapes the tensor
        relative_positions = torch.arange(seq_length).view(1, 1, seq_length) - torch.arange(seq_length).view(1, seq_length, 1)
        relative_positions = relative_positions.to(x.device)
        
        alibi_bias = slopes.view(self.n_heads, 1, 1) * relative_positions
        # unsqueeze() adds a dimension to the tensor
        alibi_bias = alibi_bias.unsqueeze(0)
        return alibi_bias

In [71]:
class InnocentAttention(nn.Module):
    def __init__(self, 
                 d_model: int, # the size of an embedded vector
                 # d_qk: int,
                    # d_v: int,
                    n_heads: int,
                    dropout: float = 0.1,
                    nan_logger=None):
        super().__init__()
        self.d_model = d_model
        self.d_qk = self.d_model
        self.d_v = self.d_model
        self.n_heads = n_heads
        self.head_dim_qk = self.d_qk // self.n_heads
        self.head_dim_v = self.d_v // self.n_heads
        self.scale = self.head_dim_qk ** -0.5
        self.dropout = nn.Dropout(dropout)
        
        self.q_proj = nn.Linear(self.d_model, self.d_qk)
        self.k_proj = nn.Linear(self.d_model, self.d_qk)
        self.v_proj = nn.Linear(self.d_model, self.d_v)
        self.out_proj = nn.Linear(self.d_v, self.d_model)
        
    def forward(self, x, mask=None):
        batch_size, seq_length, embed_dim = x.size()
        V = self.v_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_v)
        attn_scores = self._get_attention_pure_score(x, batch_size, seq_length)
        
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask[:, None, None, :] == 0, float("-inf"))
        
        attention = F.softmax(attn_scores, dim=-1)
        attention_output = torch.einsum("bhqk,bkhd->bqhd", attention, V).reshape(batch_size, seq_length, embed_dim)
        
        self.nan_logger.info(f"---------- attention(Innocent) ---------- ")
        self.nan_logger.info(f"attn_output hasn't nan: {not torch.isnan(attention_output).any()}")
        return self.out_proj(attention_output)
    
    def _get_attention_pure_score(self, x, batch_size, seq_length):
        Q = self.q_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_qk)
        K = self.k_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_qk)
        attention_scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) * self.scale
        return attention_scores

In [72]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model: int, 
                    # d_qk: int,
                    # d_v: int,
                    n_heads: int, 
                    dropout: float = 0.1,
                    nan_logger=None):
        super().__init__()
        self.d_model = d_model
        self.d_qk = self.d_model
        self.d_v = self.d_model
        self.n_heads = n_heads
        self.head_dim_qk = self.d_qk // self.n_heads
        self.head_dim_v = self.d_v // self.n_heads
        self.scale = self.head_dim_qk ** -0.5
        self.dropout = dropout

        self.q_proj = nn.Linear(self.d_model, self.d_qk)
        self.k_proj = nn.Linear(self.d_model, self.d_qk)
        self.v_proj = nn.Linear(self.d_model, self.d_v)
        self.out_proj = nn.Linear(self.d_v, self.d_model)
        self.nan_logger = nan_logger

    def forward(self, x, mask=None):
        batch_size, seq_length, embed_dim = x.size()
        print(f"x shape: {x.shape}")

        # Compute Q, K, V
        Q = self.q_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_qk).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_qk).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_length, self.n_heads, self.head_dim_v).transpose(1, 2)
        
        # Q : [32, 8, 2421, 16] — batch_size, n_heads, seq_length, head_dim_qk
        # K : [32, 8, 2421, 16] — batch_size, n_heads, seq_length, head_dim_qk
        # V : [32, 8, 2421, 16] — batch_size, n_heads, seq_length, head_dim_v

        assert Q.shape == K.shape, "Q and K must have the same shape"
        
        if mask is not None:
            # [batch_size, 1, 1, seq_length]
            mask = mask[:, None, None, :]  # Broadcast along heads and query dimensions
        
        attn_output = F.scaled_dot_product_attention(
            query=Q, key=K, value=V,
            attn_mask=mask,
            scale=self.scale,
            dropout_p=self.dropout
        )# Q @ K^T

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, embed_dim)

        self.nan_logger.info(f"---------- attention(scaled dot-product) ---------- ")
        self.nan_logger.info(f"attn_output hasn't nan: {not torch.isnan(attn_output).any()}")
        return self.out_proj(attn_output)



## Layer Normalisation $\boxed{\text{input layer}} \rightarrow \boxed{\text{normalised input layer}}$
$$
d_n \text{ is the dimension of the normalisation layer, here, } d_n = d_{\text{model}} = 
$$

$$
(\text{embed\_dim  is } d_n)
$$

$$
\mu^l = \frac{\sum_{i=1}^{d_n} x_i}{d_n}
$$

$$
\sigma^l = \sqrt{\frac{\sum_{i=1}^{d_n} (x_i - \mu^l)^2}{d_n}}
$$

$$
g_i^l, b_i^l \in \mathbb{R}^{d_n} : \text{learnable parameters}
$$

$$
\text{LayerNormalisation}_i(x) = g_i^l \cdot \frac{x_i - \mu^l}{\sigma^l} + b_i^l \text{  (of feature } i)
$$

In [73]:
class LayerNormalisation(nn.Module):
    def __init__(self, d_n, eps=1e-10, nan_logger = None):
        # d_n (int): dimension of the normalisation layer, 
        # here d_n = d_model = embed_dim
        # eps: epsilon, a small number to avoid division by zero
        super().__init__()
        self.d_n = d_n
        self.eps = eps

        self.g = nn.Parameter(torch.ones(d_n)) # gain
        self.b = nn.Parameter(torch.zeros(d_n)) # bias
        self.nan_logger = nan_logger

    def forward(self, x):
        # (batch_size, seq_length, 1)
        mu = x.mean(dim=-1, keepdim=True)

        # (batch_size, seq_length, 1)
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        x_normalised = (x - mu) / torch.sqrt(var + self.eps)

        # (batch_size, seq_length, d_n)
        x_tilde = self.g * x_normalised + self.b
        
        self.nan_logger.info(f"---------Layer Normalisation-----------")
        self.nan_logger.info(f"x hasn't nan: {not torch.isnan(x).any()}")
        self.nan_logger.info(f"mu hasn't nan: {not torch.isnan(mu).any()}")
        self.nan_logger.info(f"var hasn't nan: {not torch.isnan(var).any()}")
        self.nan_logger.info(f"x_normalised hasn't nan: {not torch.isnan(x_normalised).any()}")
        self.nan_logger.info(f"x_tilde hasn't nan: {not torch.isnan(x_tilde).any()}")
        
        return x_tilde

## Feed Forward Network $\boxed{\text{input layer}} \rightarrow \boxed{\text{hidden layer}} \rightarrow \boxed{\text{output layer}}$ 

$$
x_i^l \text{ is the input of the } i^{th} token in l^{th} \text{ layer}
$$

$$
d_{\text{f}} \text{ is the dimension of the feed forward network}
$$

$$
W_{h,i}^l, b_i^l \in \mathbb{R}^{d_{\text{model}} \times d_{\text{f}}} \text{ are learnable parameters}
$$

$$
\text{summed input } s_i^l = x_i^l W_{h,i}^l + b_{h,i}^l
$$

$$
\text{hidden layer } h_i^l = \text{ReLU}(s_i^l)
$$

$$
\text{summed output } x_i^{l+1} = h_i^l W_{f,i}^l + b_{f,i}^l
$$

In [74]:
class FFN(nn.Module):
    def __init__(self, d_model, d_f, dropout=0.1, nan_logger=None):
        super().__init__()
        self.d_model = d_model
        self.d_f = d_f
        self.W_h = nn.Linear(self.d_model, self.d_f)
        self.W_f = nn.Linear(self.d_f, self.d_model)
        
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.nan_logger = nan_logger
        
    def forward(self, x):
        # (batch_size, seq_length, d_f)
        s_i = self.W_h(x)# summed input
        # (batch_size, seq_length, d_f)
        h_i = self.activation(s_i)
        h_i = self.dropout(h_i)
        
        # (batch_size, seq_length, d_model)
        x_next = self.W_f(h_i) # summed output
        
        self.nan_logger.info(f"---------FFN-----------")
        self.nan_logger.info(f"s_i hasn't nan : {not torch.isnan(s_i).any()}")
        self.nan_logger.info(f"h_i hasm't nan : {not torch.isnan(h_i).any()}")
        self.nan_logger.info(f"x_next hasn't nan : {not torch.isnan(x_next).any()}")
        
        return x_next

In [75]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_f, dropout=0.1, layer_idx=0, nan_logger=None):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_f = d_f
        self.layer_idx = layer_idx
        self.dropout = nn.Dropout(dropout)
        
        self.nan_logger = nan_logger
        
        # self.attention = InnocentAttention(self.d_model, self.n_heads, dropout)
        self.attention = ScaledDotProductAttention(self.d_model, self.n_heads, dropout, self.nan_logger)
        self.norm_attention = LayerNormalisation(self.d_model, self.nan_logger)
        
        self.ffn = FFN(self.d_model, self.d_f, dropout, self.nan_logger)
        self.norm_ffn = LayerNormalisation(self.d_model, self.nan_logger)
        
        
    def forward(self, x, mask=None):
        self.nan_logger.info(f"==============Entering Encoder Block {self.layer_idx}==============")
        
        # Attention
        attn_output = self.attention(x, mask)
        x = self.norm_attention(x + self.dropout(attn_output))
        
        # FFN
        ffn_output = self.ffn(x)
        x = self.norm_ffn(x + self.dropout(ffn_output))
        
        self.nan_logger.info(f"Encoder Block {self.layer_idx} output hasn't nan: {not torch.isnan(x).any()}")
        return x


In [76]:
class TransformerModelPMT_Classification(LightningModule):
    def __init__(self, 
                 d_model, 
                 n_heads, 
                 d_f, 
                 num_layers, 
                 d_input,
                 num_classes, 
                 dropout=0.1, 
                 learning_rate=1e-4):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_f = d_f
        self.num_layers = num_layers
        self.d_input = d_input
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.nan_logger = logging.getLogger("nan_logger")
        self.dropout = dropout

        # Input projection layer
        self.input_projection = nn.Linear(self.d_input, self.d_model)

        # Stacked encoder blocks
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(self.d_model, self.n_heads, self.d_f, self.dropout, layer_idx=i, nan_logger=self.nan_logger) for i in range(self.num_layers)]
        )

        # Classification head
        self.classification_output_layer = nn.Linear(self.d_model, self.num_classes)

    def forward(self, x, mask=None):
        # Input projection
        x = self.input_projection(x)

        # Encoder blocks
        for encoder in self.encoder_blocks:
            x = encoder(x, mask)

        # Classification head: Mean pooling across sequence and output logits
        x = x.mean(dim=1)
        logits = self.classification_output_layer(x)
        nan_logger.info(f"---------Classification Head-----------")
        nan_logger.info(f"logits hasn't nan: {not torch.isnan(logits).any()}")
        return logits

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        x = batch["features"]
        y = batch["target"]  # For classification, this is the class index
        mask = batch["mask"]

        logits = self(x, mask)
        loss = F.cross_entropy(logits, y)

        train_logger.info(f"Epoch {self.current_epoch}: train_loss={loss.item():.4f}")
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch["features"]
        y = batch["target"]
        mask = batch["mask"]

        logits = self(x, mask)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        train_logger.info(f"Epoch {self.current_epoch}: val_loss={loss.item():.4f}, val_acc={acc.item() * 100:.2f}%")
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        return loss

    def test_step(self, batch, batch_idx):
        x = batch["features"]
        y = batch["target"]
        mask = batch["mask"]

        logits = self(x, mask)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        train_logger.info(f"Epoch {self.current_epoch}: test_loss={loss.item():.4f}, test_acc={acc.item() * 100:.2f}%")
        self.log("test_loss", loss)
        self.log("test_acc", acc)
        return loss
    
    def on_train_epoch_start(self):
        nan_logger.info(f"####################Training epoch {self.current_epoch}####################")
        train_logger.info(f"####################Training epoch {self.current_epoch}####################")


In [77]:
Sample_root = "/lustre/hpc/project/icecube/HE_Nu_Aske_Oct2024/PMTfied/Snowstorm/99999"
# NuMu_PeV_root

In [78]:
model_class = TransformerModelPMT_Classification(
    d_model=128,
    n_heads=8,
    d_f=256,
    num_layers=3,
    d_input=32, # the number of PMTfied features
    num_classes=3, # the number of flavours: NuE, NuMu, NuTau
    dropout=0.1,
    learning_rate=1e-5
)

In [79]:
current_date = datetime.now().strftime("%Y%m%d")
current_time = datetime.now().strftime("%H%M%S")

base_log_dir = os.path.join("logs", current_date)

base_checkpoint_dir = os.path.join("checkpoints", current_date)
os.makedirs(base_checkpoint_dir, exist_ok=True)

# Training log
train_log_filename = os.path.join(base_log_dir, f"{current_time}_training.log")
train_logger = logging.getLogger("training")
train_logger.setLevel(logging.INFO)
train_handler = logging.FileHandler(train_log_filename)
train_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
train_handler.setFormatter(train_formatter)
train_logger.addHandler(train_handler)

# NaN log
nan_log_filename = os.path.join(base_log_dir, f"{current_time}_nan_checks.log")
nan_logger = logging.getLogger("nan_checks")
nan_logger.setLevel(logging.INFO)
nan_handler = logging.FileHandler(nan_log_filename)
nan_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
nan_handler.setFormatter(nan_formatter)
nan_logger.addHandler(nan_handler)


FileNotFoundError: [Errno 2] No such file or directory: '/lustre/hpc/icecube/cyan/factory/IceCubeTransformer/logs/20250128/100922_training.log'

In [117]:
tb_logger = TensorBoardLogger(
    save_dir=base_log_dir,
    name=f"{current_time}",  # Add time to the logger name
)

# Set up the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=base_checkpoint_dir, 
    filename=f"{current_time}_transformer-epoch{{epoch:02d}}-val_loss{{val_loss:.2f}}",  # Add time to filename
    save_top_k=3,
    mode="min"
)

# Set up the early stopping callback
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=10,
    verbose=True,
    mode="min"
)

In [118]:
def runTraining(model_class: nn.Module):
    wandb.init(
        project=f"[{current_date}_{current_time}]Neutrino Flavour Classification",
        config={
            "d_model": model_class.d_model,
            "n_heads": model_class.n_heads,
            "d_f": model_class.d_f,
            "num_layers": model_class.num_layers,
            "d_input": model_class.d_input,
            "num_classes": model_class.num_classes,
            "dropout": model_class.dropout,
            "learning_rate": model_class.learning_rate,
            "epochs": 5,
            "attention": "Scaled Dot-Product",
        },
    )

    train_logger.info(
    "| Parameter       | Value               |\n"
    "|-----------------|---------------------|\n"
    f"| attention       | Scaled Dot-Product |\n"
    f"| d_model         | {model_class.d_model:<15}|\n"
    f"| n_heads         | {model_class.n_heads:<15}|\n"
    f"| d_f             | {model_class.d_f:<15}|\n"
    f"| num_layers      | {model_class.num_layers:<15}|\n"
    f"| d_input         | {model_class.d_input:<15}|\n"
    f"| num_classes     | {model_class.num_classes:<15}|\n"
    f"| dropout         | {model_class.dropout:<15}|\n"
    f"| learning_rate   | {model_class.learning_rate:<15}|\n\n"
    )
    train_logger.info("Starting training...")

    trainer = Trainer(
        max_epochs=5,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,  # Use 2 GPUs or CPUs
        gradient_clip_val=1.0,
        callbacks=[checkpoint_callback, early_stopping_callback],
        log_every_n_steps=1,  # Log metrics after every batch
        logger=tb_logger
    )
    trainer.fit(model_class, datamodule=datamodule_PeV_1_1)
    
    # Finalize wandb
    wandb.finish()
    return trainer

# trainer = runTraining(model_class)


I have build a transformer encoder for multi class classification. I checked the dataset and the dataloader manytimes so I am pretty bit sure that the dataset, dataloader and data module are excellent and almost perfect. And I just finished the model build which seems fine in terms of logical structure. However, after few epochs of training I got NaN. I found this out by adding printing lines after each layer of the transformer. So I would like to go back to the stage before the model building while keeping the current model, 

In [119]:
# def init_weights(m):
#     if isinstance(m, nn.Linear):  # Check if the layer is Linear
#         nn.init.xavier_uniform_(m.weight)  # Xavier initialisation for weights
#         if m.bias is not None:
#             nn.init.zeros_(m.bias)  # Zero initialisation for biases

In [120]:
# model.apply(init_weights)

In [121]:
# trainer.fit(model, datamodule=data_module_part)

In [122]:
# trainer.test(model, datamodule=data_module_part)