# Import config

In [1]:
%load_ext yamlmagic
# %cd /Users/moustholmes/Projects/lightning-hydra-template

import os
import numpy as np
from omegaconf import DictConfig, OmegaConf
import hydra
from hydra import initialize, compose
from hydra.utils import instantiate


from pprint import pprint

# Set the PROJECT_ROOT environment variable
os.environ['PROJECT_ROOT'] = '/Users/moustholmes/Projects/METAL-AI'

# Change the current working directory to the project root
os.chdir('/Users/moustholmes/Projects/METAL-AI')

# Initialize Hydra with the config path relative to the project root
initialize(version_base=None, config_path='configs', job_name='notebook')

cfg_train = compose(config_name='train',) #overrides=['experiment=effect_gaussian_nll'])

print(OmegaConf.to_yaml(cfg_train))

task_name: train
tags:
- dev
train: true
test: true
ckpt_path: null
seed: null
data:
  _target_: src.data.dict_datamodule.DictDataModule
  data_dir: ${paths.data_dir}
  batch_size: 16
  train_val_splitter:
    _target_: src.data.components.data_utils.TripleTrainValSplitter
    validation_percentage: 0.05
    ASF_size_percentage: 0.03
    include_ion:
    - - 21
      - 21
    - - 22
      - 21
    - - 23
      - 21
    - - 24
      - 21
    - - 22
      - 22
    - - 24
      - 22
    - - 25
      - 22
    - - 23
      - 23
    - - 24
      - 23
    - - 25
      - 23
    - - 26
      - 23
    - - 24
      - 24
    - - 25
      - 24
    - - 26
      - 24
    - - 27
      - 24
    unseen_ion:
    - - 23
      - 22
    remove_nan_effect: false
  num_workers: 0
  shuffle: true
  pin_memory: true
  persistent_workers: false
model:
  scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    _partial_: true
    T_0: 5
    T_mult: 2
    eta_min: 0.0
  _target_: src.model

In [2]:
cfg_train = compose(config_name='train', overrides=['experiment=effect_gaussian_nll'])
print(OmegaConf.to_yaml(cfg_train))

task_name: train
tags:
- effect
- transformer_encoder
train: true
test: true
ckpt_path: null
seed: 12345
data:
  _target_: src.data.dict_datamodule.DictDataModule
  data_dir: ${paths.data_dir}
  batch_size: 128
  train_val_splitter:
    _target_: src.data.components.data_utils.RandomTrainValSplitter
    validation_percentage: 0.15
  num_workers: 0
  shuffle: true
  remove_nan_effect: true
  pin_memory: true
  persistent_workers: false
model:
  scheduler:
    _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
    _partial_: true
    T_0: 30
    T_mult: 1
    eta_min: 0.0
  _target_: src.models.metalAI_module.MetalAILitModule
  target_name: effect
  model:
    _target_: src.models.components.Transformer_encoder_model.simple_transformer_encoder_model
    csf_encoder:
      _target_: src.models.components.CSF_encoders.simple_CSF_encoder
      output_size: 4
    d_model: 32
    nhead: 2
    dim_forward: 16
    num_layers: 4
    output_size: 2
    dropout: 0.0
    output_activati

# Load data_dict


In [2]:
import pickle
with open(cfg_train.data.data_dir, 'rb') as file:
    data_dict = pickle.load(file)

## Inspect raw data

In [4]:
for ion_key in data_dict.keys():
    for asf_key in data_dict[ion_key].keys():
        print(ion_key, asf_key)
        pprint(data_dict[ion_key][asf_key])
        break
    break

(16, 13) ((2, 2, 6, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0), (2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0))
{'converged': array([nan, nan, nan, nan, nan, nan,  1., nan, nan, nan, nan, nan, nan,
       nan]),
 'done_effect': True,
 'effect': [np.float64(2.4898640200604847),
            np.float64(11.61487054542535),
            np.float64(11.1684417578352

# dataset

In [5]:
# %load /Users/moustholmes/Projects/METAL-AI/src/data/components/datasets.py
import torch
import h5py
import numpy as np
import scipy.spatial

from torch.utils.data.dataset import Dataset

class SimpleDictDataset(Dataset):
    def __init__(self, data_dict, remove_nan_effect=False):
        """
        data: The nested dictionary containing all your data
        """
        self.data = data_dict
        self.index_mapping = self._create_index_mapping(remove_nan_effect)

    def _create_index_mapping(self, remove_nan_effect):
        mapping = []
        for ion_key in self.data.keys():
            for asf_key in self.data[ion_key].keys():

                data_point = self.data[ion_key][asf_key]

                data_point["excitations"] = torch.tensor(data_point["excitations"])
                data_point["converged_mask"] = torch.from_numpy(
                    ~np.isnan(data_point["converged"])
                )
                data_point["converged"] = torch.tensor(
                    data_point["converged"], dtype=torch.bool
                )

                if "effect" not in data_point:
                    data_point["effect"] = torch.full(
                        (len(data_point["excitations"]),), float("nan")
                    )
                else:
                    data_point["effect"] = torch.tensor(data_point["effect"])

                data_point["n_protons"] = torch.tensor(ion_key[0], dtype=torch.long)
                data_point["n_electrons"] = torch.tensor(ion_key[1], dtype=torch.long)

                if remove_nan_effect:
                    if "effect" not in self.data[ion_key][asf_key]:
                        continue
                    if np.isnan(self.data[ion_key][asf_key]["effect"]).all():
                        continue
                mapping.append((ion_key, asf_key))

        return mapping

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        # Retrieve the actual keys from the index
        ion_key, asf_key = self.index_mapping[idx]

        data_point = self.data[ion_key][asf_key]

        return data_point
class GroupedDictDataset(Dataset):
    def __init__(self, data_dict, remove_nan_effect=False, ion_include=None, ion_exclude=None):
        self.data = data_dict
        self.ion_include = ion_include
        self.ion_exclude = ion_exclude
        self.index_mapping = self._create_index_mapping(remove_nan_effect)
        self.grouped_indices = self._group_by_excitations_length()

    def _create_index_mapping(self, remove_nan_effect):
        mapping = []
        for ion_key in self.data.keys():
            # Apply inclusion and exclusion filters
            if self.ion_include is not None and ion_key not in self.ion_include:
                continue
            if self.ion_exclude is not None and ion_key in self.ion_exclude:
                continue

            for asf_key in self.data[ion_key].keys():
                data_point = self.data[ion_key][asf_key]

                if len(data_point["excitations"]) == 1:
                    continue

                data_point["filling_numbers"] = torch.tensor(asf_key)

                if not isinstance(data_point["excitations"], torch.Tensor):
                    data_point["excitations"] = torch.tensor(data_point["excitations"])
                
                if not isinstance(data_point["converged"], torch.Tensor):
                    data_point["converged_mask"] = torch.from_numpy(~np.isnan(data_point["converged"]))
                    data_point["converged"] = torch.tensor(data_point["converged"], dtype=torch.bool)

                if not isinstance(data_point["effect"], torch.Tensor):
                    data_point["effect"] = torch.tensor(data_point["effect"], dtype=torch.float32)

                data_point["n_protons"] = torch.tensor(ion_key[0], dtype=torch.long)
                data_point["n_electrons"] = torch.tensor(ion_key[1], dtype=torch.long)

                if remove_nan_effect:
                    if "effect" not in data_point:
                        continue
                    if np.isnan(data_point["effect"]).all():
                        continue
                mapping.append((ion_key, asf_key))

        return mapping

    def _group_by_excitations_length(self):
        groups = {}
        for idx, (ion_key, asf_key) in enumerate(self.index_mapping):
            excitations_length = len(self.data[ion_key][asf_key]["excitations"])
            if excitations_length == 1:
                continue
            if excitations_length not in groups:
                groups[excitations_length] = []
            groups[excitations_length].append(idx)
        return groups

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        ion_key, asf_key = self.index_mapping[idx]
        data_point = self.data[ion_key][asf_key]

        selected_data = {
            'excitations': data_point['excitations'],
            'filling_numbers': data_point['filling_numbers'],
            'converged': data_point['converged'],
            'converged_mask': data_point['converged_mask'],
            'effect': data_point['effect'],
            'n_protons': data_point['n_protons'],
            'n_electrons': data_point['n_electrons'],
        }
        return selected_data

        # return data_point
        # return {
        #     "ion_key": ion_key,
        #     "asf_key": asf_key,
        #     **data_point,  # Unpacking the data point into the returned dictionary
        # }

In [96]:
dataset = GroupedDictDataset(data_dict,ion_include=[(4,4)],)
print(len(dataset))
for data in dataset:
    pprint(data)
    break

933
{'converged': tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
        False,  True,  True,  True]),
 'converged_mask': tensor([False, False, False, False, False, False, False, False, False, False,
         True, False, False, False]),
 'effect': tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]),
 'excitations': tensor([[5, 5],
        [4, 5],
        [4, 4],
        [3, 5],
        [3, 4],
        [3, 3],
        [2, 5],
        [2, 4],
        [2, 3],
        [2, 2],
        [1, 2],
        [0, 5],
        [0, 4],
        [0, 0]]),
 'filling_numbers': tensor([[2, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 

In [58]:
dataset = GroupedDictDataset(data_dict)
# used_ex = [[0,0],[1, 0], [2, 0], [2, 1], [3, 0], [3, 1], [3, 2], [4, 0], [4, 1], [4, 2], [4, 3], [5, 0], [5, 1], [5, 2], [5, 3], [6, 0], [6, 1], [6, 2]]
used_ex =  [(0, 0),
 (1, 0),
 (0, 1),
 (0, 2),
 (1, 1),
 (1, 2),
 (2, 1),
 (0, 3),
 (2, 2),
 (1, 3),
 (0, 4),
 (0, 5),
 (3, 2),
 (1, 4),
 (2, 3),
 (3, 3),
 (2, 4),
 (1, 5),
 (3, 4),
 (4, 3),
 (2, 5),
 (3, 5),
 (4, 4),
 (5, 4),
 (4, 5),
 (5, 5)]

used_ex = [list(ex) for ex in used_ex]


used_ex += [[0,0],[3, 4],[4,4],[3,3],[5,5],[5,4]]
print(len(dataset))
for data in dataset:
    for i in range(data['excitations'].shape[0]):
        if data['excitations'][i].tolist() not in used_ex:
            print(data['excitations'][i].tolist())
            # pprint(data)
            break

[[0, 0], [1, 0], [0, 1], [0, 2], [1, 1], [1, 2], [2, 1], [0, 3], [2, 2], [1, 3], [0, 4], [0, 5], [3, 2], [1, 4], [2, 3], [3, 3], [2, 4], [1, 5], [3, 4], [4, 3], [2, 5], [3, 5], [4, 4], [5, 4], [4, 5], [5, 5]]
True
125175


## split dataset

In [3]:
# %load /Users/moustholmes/Projects/METAL-AI/src/data/components/data_utils.py
import pickle
import random
import numpy as np

def load_data_dict(data_dir):
    # Load the data dictionary from the data directory
    with open(data_dir, "rb") as f:
        data_dict = pickle.load(f)

    return data_dict

def save_data_dict(data_dict, data_dir):
    # Save the data dictionary to the data directory
    with open(data_dir, "wb") as f:
        pickle.dump(data_dict, f)

def top_asf_size_train_val_split(data_dict, validation_percentage=0.05):
    # Flatten the data_dict to a list of tuples: (ion_key, asf_key, excitations_length)
    flat_data = [
        (ion_key, asf_key, len(data_dict[ion_key][asf_key]["excitations"]))
        for ion_key in data_dict.keys()
        for asf_key in data_dict[ion_key].keys()
    ]

    # Sort by the length of excitations in descending order
    flat_data.sort(key=lambda x: x[2], reverse=True)

    # Determine the number of validation data points
    validation_size = int(len(flat_data) * validation_percentage)

    # Create the validation and training datasets
    validation_data = {}
    training_data = {}

    # Add the top 5% to the validation set
    for ion_key, asf_key, _ in flat_data[:validation_size]:
        if ion_key not in validation_data:
            validation_data[ion_key] = {}
        validation_data[ion_key][asf_key] = data_dict[ion_key][asf_key]

    # Add the remaining 95% to the training set
    for ion_key, asf_key, _ in flat_data[validation_size:]:
        if ion_key not in training_data:
            training_data[ion_key] = {}
        training_data[ion_key][asf_key] = data_dict[ion_key][asf_key]

    return training_data, validation_data

class TopASFSizeTrainValSplitter:
    def __init__(self, validation_percentage: float = 0.05):
        self.validation_percentage = validation_percentage

    def __call__(self, data_dict):
        return top_asf_size_train_val_split(data_dict, self.validation_percentage)

def random_train_val_split(data_dict, validation_percentage=0.05):
    # Flatten the data_dict to a list of tuples: (ion_key, asf_key)
    flat_data = [
        (ion_key, asf_key)
        for ion_key in data_dict.keys()
        for asf_key in data_dict[ion_key].keys()
    ]
    


    
    # Shuffle the data randomly
    random.shuffle(flat_data)
    
    # Determine the number of validation data points
    validation_size = int(len(flat_data) * validation_percentage)
    
    # Create the validation and training datasets
    validation_data = {}
    training_data = {}
    
    # Add the first part to the validation set
    for ion_key, asf_key in flat_data[:validation_size]:
        if ion_key not in validation_data:
            validation_data[ion_key] = {}
        validation_data[ion_key][asf_key] = data_dict[ion_key][asf_key]
    
    # Add the remaining part to the training set
    for ion_key, asf_key in flat_data[validation_size:]:
        if ion_key not in training_data:
            training_data[ion_key] = {}
        training_data[ion_key][asf_key] = data_dict[ion_key][asf_key]
    
    return training_data, validation_data

class RandomTrainValSplitter:
    def __init__(self, validation_percentage: float = 0.05):
        self.validation_percentage = validation_percentage

    def __call__(self, data_dict):
        return random_train_val_split(data_dict, self.validation_percentage)

def triple_train_val_split(
    data_dict, 
    validation_percentage=0.05,
    ASF_size_percentage=0.05,
    include_ion=None,
    unseen_ion=None,
    remove_nan_effect=False,
):
    """
    Split data into:
    - `training_data`: Remaining training data after validation splits.
    - `validation_data`: Random subset for validation from training ions.
    - `unseen_data`: Data belonging to unseen ions (not part of include_ion).
    - `large_data`: ASFs larger than the calculated thresholds.
    """

    include_ion = [tuple(item) for item in include_ion] if include_ion else None
    unseen_ion = [tuple(item) for item in unseen_ion] if unseen_ion else None

    # Step 1: Calculate thresholds for large ASFs for each ion
    ion_threshold = {}
    for ion_key in data_dict.keys():
        # Get all ASF sizes for this ion, applying `remove_nan_effect` if True
        asf_sizes = [
            len(asf_data["excitations"])
            for asf_key, asf_data in data_dict[ion_key].items()
            if not (remove_nan_effect and ("effect" not in asf_data or np.isnan(asf_data["effect"]).all()))
        ]
        if not asf_sizes:  # If no valid ASFs remain, skip this ion
            ion_threshold[ion_key] = float('inf')  # No ASFs qualify
            continue
        # Sort ASF sizes in descending order
        asf_sizes.sort(reverse=True)
        # Calculate the threshold index
        threshold_index = max(0, int(len(asf_sizes) * ASF_size_percentage) - 1)
        # Determine the threshold size
        ion_threshold[ion_key] = asf_sizes[threshold_index]

    # Step 2: Initialize splits
    training_data = {}
    validation_data = {}
    unseen_data = {}
    large_data = {}

    # Step 3: Assign data points to the correct sets
    for ion_key, asfs in data_dict.items():
        for asf_key, asf_data in asfs.items():
            if remove_nan_effect:
                    if "effect" not in asf_data:
                        continue
                    if np.isnan(asf_data["effect"]).all():
                        continue
            asf_size = len(asf_data["excitations"])
            above_threshold = asf_size >= ion_threshold[ion_key]

            # Check if ion belongs to unseen_ion
            if unseen_ion is not None and ion_key in unseen_ion:
                if not above_threshold:
                    if ion_key not in unseen_data:
                        unseen_data[ion_key] = {}
                    unseen_data[ion_key][asf_key] = asf_data
            # Include ions in training or validation
            elif include_ion is None or ion_key in include_ion:
                if above_threshold:
                    if ion_key not in large_data:
                        large_data[ion_key] = {}
                    large_data[ion_key][asf_key] = asf_data
                else:
                    if ion_key not in training_data:
                        training_data[ion_key] = {}
                    training_data[ion_key][asf_key] = asf_data

    # Step 4: Randomly split training_data into training and validation subsets
    flat_training_data = [
        (ion_key, asf_key)
        for ion_key, asfs in training_data.items()
        for asf_key in asfs.keys()
    ]
    random.shuffle(flat_training_data)
    validation_size = int(len(flat_training_data) * validation_percentage)
    validation_subset = flat_training_data[:validation_size]

    final_training_data = {}
    for ion_key, asf_key in flat_training_data[validation_size:]:
        if ion_key not in final_training_data:
            final_training_data[ion_key] = {}
        final_training_data[ion_key][asf_key] = training_data[ion_key][asf_key]

    for ion_key, asf_key in validation_subset:
        if ion_key not in validation_data:
            validation_data[ion_key] = {}
        validation_data[ion_key][asf_key] = training_data[ion_key][asf_key]

    return final_training_data, validation_data, unseen_data, large_data

class TripleTrainValSplitter:
    def __init__(
        self,
        validation_percentage: float = 0.05,
        ASF_size_percentage: float = 0.05,
        include_ion=None,
        unseen_ion=None,
        remove_nan_effect=False,
    ):
        self.validation_percentage = validation_percentage
        self.ASF_size_percentage = ASF_size_percentage
        self.include_ion = include_ion
        self.unseen_ion = unseen_ion
        self.remove_nan_effect = remove_nan_effect

    def __call__(self, data_dict):
        return triple_train_val_split(
            data_dict,
            self.validation_percentage,
            self.ASF_size_percentage,
            self.include_ion,
            self.unseen_ion,
            self.remove_nan_effect,
        )



In [99]:
# train_val_splitter = TopASFSizeTrainValSplitter()
train_val_splitter = RandomTrainValSplitter()

data_dict_train, data_dict_val = train_val_splitter(data_dict)
dataset = GroupedDictDataset( data_dict, remove_nan_effect=True, ion_include=[(4,4)])
dataset_train = GroupedDictDataset( data_dict_train, remove_nan_effect=True, ion_include=[(4,4)])
dataset_val = GroupedDictDataset( data_dict_val, remove_nan_effect=True, ion_include=[(4,4)])

print(len(dataset))
print(len(dataset_train))
print(len(dataset_val))

  if np.isnan(data_point["effect"]).all():


392
372
20


In [6]:
# [(22,22),
# (24,22),
# (25,22),
# (23,23),
# (24,23),
# (25,23),
# (26,23),]


triple_train_val_splitter = TripleTrainValSplitter( 
    validation_percentage=0.2,
    ASF_size_percentage=0.05, 
    include_ion = [
        (22,22),
        (24,22),
        (25,22),
        (23,23),
        (24,23),
        (25,23),
        (26,23),
        ], 
    unseen_ion =[
        (23,22)
        ],
    remove_nan_effect=True
    
    )
data_dict_train, data_dict_val, data_dict_unseen, data_dict_large = triple_train_val_splitter(data_dict)
dataset_train = GroupedDictDataset(data_dict_train, remove_nan_effect=False)
dataset_val = GroupedDictDataset(data_dict_val, remove_nan_effect=False)
dataset_unseen = GroupedDictDataset(data_dict_unseen, remove_nan_effect=False)
dataset_large = GroupedDictDataset(data_dict_large, remove_nan_effect=False)

print(len(dataset_train))
print(len(dataset_val))
print(len(dataset_unseen))
print(len(dataset_large))
# print(len(dataset_val)/(len(dataset_val) +len(dataset_train)))

1478
368
100
150


## dataloader


### collate function

In [11]:
# %load /Users/moustholmes/Projects/METAL-AI/src/data/components/collate_fns.py
import torch

def dict_collate_fn(batch):
    # Find the maximum number of 'CSFs' across all samples in the batch.
    max_csf_count = max([len(item["excitations"]) for item in batch])

    # Initialize lists to hold padded 'CSFs' and the original lengths.
    original_lengths = []

    padded_csfs_list = []
    padded_converged_list = []
    padded_converged_mask_list = []
    padded_effect_list = []

    for item in batch:
        original_length = item["excitations"].size(0)
        original_lengths.append(original_length)
        padded_csfs = torch.cat(
            [
                item["excitations"],
                torch.zeros(max_csf_count - original_length, 2),  # change the 2 to the number of allowed excitations
            ]
        )
        # print(item['converged'])
        padded_converged = torch.cat(
            [
                item["converged"],
                torch.zeros(
                    max_csf_count - original_length,
                ),  # Pad with zeros. dtype=torch.bool
            ]
        )
        padded_converged_mask = torch.cat(
            [
                item["converged_mask"],
                torch.zeros(
                    max_csf_count - original_length,
                ),  # Pad with zeros. dtype=torch.bool
            ]
        )

        padded_effect = torch.cat(
            [
                item["effect"],
                torch.zeros(max_csf_count - original_length),  # Pad with zeros.
            ]
        )
        padded_csfs_list.append(padded_csfs)
        padded_converged_list.append(padded_converged)
        padded_converged_mask_list.append(padded_converged_mask)
        padded_effect_list.append(padded_effect)

    # Stack the padded 'CSFs'.
    padded_csfs = torch.stack(padded_csfs_list)
    padded_converged = torch.stack(padded_converged_list)
    padded_converged_mask = torch.stack(padded_converged_mask_list)
    padded_effect = torch.stack(padded_effect_list)

    # Create a mask based on the original lengths.
    mask = torch.ones_like(padded_csfs[:, :, 0], dtype=torch.bool)
    for i, length in enumerate(original_lengths):
        mask[i, :length] = False

    # Convert other attributes to tensors.
    n_electrons = torch.stack([item["n_electrons"] for item in batch])
    n_protons = torch.stack([item["n_protons"] for item in batch])

    # Return a dictionary with the batched data and the mask.
    return {
        "excitations": padded_csfs,
        "pad_mask": mask,
        "n_electrons": n_electrons,
        "n_protons": n_protons,
        "converged": padded_converged,
        "converged_mask": padded_converged_mask,
        "effect": padded_effect,
    }

### samplers

In [12]:
# %%writefile /Users/moustholmes/Projects/METAL-AI/src/data/components/samplers.py
from torch.utils.data import Sampler
import random

# class GroupedBatchSampler(Sampler):
#     def __init__(self, dataset, batch_size, shuffle=True):
#         self.dataset = dataset
#         self.batch_size = batch_size
#         self.grouped_indices = dataset.grouped_indices
#         if shuffle:
#             self.group_lengths = sorted(self.grouped_indices.keys())
#         else:
#             self.group_lengths = list(self.grouped_indices.keys())

#     def __iter__(self):
#         for length in self.group_lengths:
#             indices = self.grouped_indices[length]
#             random.shuffle(indices)
#             # Yield full batches only
#             for i in range(0, len(indices), self.batch_size):
#                 if i + self.batch_size <= len(indices):
#                     yield indices[i:i + self.batch_size]

#     def __len__(self):
#         # This is an approximation, as we drop the last incomplete batch in each group
#         return sum(len(indices) // self.batch_size for indices in self.grouped_indices.values())

class GroupedBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, shuffle=True):
        self.dataset = dataset
        self.batch_size = batch_size
        self.grouped_indices = dataset.grouped_indices
        self.shuffle = shuffle

    def __iter__(self):
        all_batches = []

        # Sort the group lengths if shuffle is False
        group_lengths = sorted(self.grouped_indices.keys()) if not self.shuffle else list(self.grouped_indices.keys())

        for length in group_lengths:
            indices = self.grouped_indices[length]
            # Shuffle within each group if shuffle is True
            if self.shuffle:
                random.shuffle(indices)
            
            # Form batches from the indices of this group
            for i in range(0, len(indices), self.batch_size):
                if i + self.batch_size <= len(indices):
                    all_batches.append(indices[i:i + self.batch_size])

        # Shuffle all batches if shuffle is True
        if self.shuffle:
            random.shuffle(all_batches)

        # Yield each batch one by one
        for batch in all_batches:
            yield batch

    def __len__(self):
        # This is an approximation, as we drop the last incomplete batch in each group
        return sum(len(indices) // self.batch_size for indices in self.grouped_indices.values())

In [13]:
from torch.utils.data import DataLoader
dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
# sampler = GroupedBatchSampler(dataset, batch_size=cfg_train.data.batch_size)
dataloader = DataLoader(dataset, batch_size=8, collate_fn=dict_collate_fn)

for batch in dataloader:
    print(batch['excitations'].shape[1])
    
    # break

  if np.isnan(data_point["effect"]).all():


24
24
23
24
24
9
22
24
24
23
22
20
22
22
17
25
23
22
24
24
23
25
20
23
20
19
22
21
23
24
20
19
23
24
23
23
25
23
23
21
19
24
22
24
24
24
23
22
25
23
22
23
23
21
22
24
23
23
20
17
21
22
23
22
24
24
24
21
21
23
23
24
21
23
23
18
25
19
24
22
19
20
8
21
23
23
23
18
24
23
23
24
23
20
24
22
22
17
19
19
19
21
19
18
19
17
19
17
14
18
18
18
20
17
20
20
19
19
18
17
8
18
17
18
18
19
18
19
18
17
18
19
17
18
18
19
19
19
18
17
20
18
19
17
18
19
19
18
16
17
20
18
16
16
19
17
18
18
15
19
17
19
19
17
18
18
17
19
19
20
17
18
17
17
17
19
18
19
19
19
17
18
19
18
18
20
19
17
19
17
17
14
16
19
18
19
19
17
18
15
18
15
19
18
17
18
19
16
17
19
16
17
18
17
18
18
18
17
18
6
19
20
19
19
19
19
18
18
20
17
16
17
19
19
16
18
9
19
19
18
17
18
17
19
18
19
19
18
20
18
18
18
18
20
9
19
18
18
18
19
17
19
14
15
19
18
19
20
18
11
19
18
19
8
5
18
19
18
24
23
21
23
22
23
23
23
23
24
23
19
25
24
23
24
23
23
20
23
23
23
21
22
22
23
23
24
20
24
23
22
24
22
23
22
24
24
23
22
23
23
23
22
17
23
20
24
25
23
5
9
25
23
23
22
24
21
24

In [14]:
from torch.utils.data import DataLoader
dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
# sampler = GroupedBatchSampler(dataset, batch_size=cfg_train.data.batch_size)
dataloader = DataLoader(dataset, batch_size=8, collate_fn=dict_collate_fn)

for batch in dataloader:
    print(batch)
    
    break

  if np.isnan(data_point["effect"]).all():


{'excitations': tensor([[[5., 5.],
         [5., 4.],
         [4., 5.],
         [4., 4.],
         [3., 5.],
         [3., 4.],
         [3., 3.],
         [3., 2.],
         [2., 5.],
         [2., 4.],
         [2., 1.],
         [1., 5.],
         [0., 5.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.]],

        [[5., 5.],
         [5., 4.],
         [4., 5.],
         [4., 4.],
         [4., 3.],
         [3., 5.],
         [3., 4.],
         [3., 3.],
         [3., 2.],
         [2., 5.],
         [2., 4.],
         [2., 3.],
         [2., 2.],
         [2., 1.],
         [1., 5.],
         [1., 3.],
         [1., 2.],
         [0., 5.],
         [0., 4.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.]],

        [[5., 5.],
         [5., 4.],
         [4., 5.],
         [4

In [15]:
from torch.utils.data import DataLoader
dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=cfg_train.data.batch_size)
dataloader = DataLoader(dataset, batch_sampler=sampler)

for batch in dataloader:
    print(batch['excitations'].shape)
    break

  if np.isnan(data_point["effect"]).all():


torch.Size([128, 17, 2])


In [16]:
from torch.utils.data import DataLoader
dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)

sampler = GroupedBatchSampler(dataset, batch_size=cfg_train.data.batch_size)
dataloader = DataLoader(dataset, batch_sampler=sampler, collate_fn=dict_collate_fn)

for batch in dataloader:
    print(batch['excitations'])
    

  if np.isnan(data_point["effect"]).all():


tensor([[[5., 5.],
         [5., 4.],
         [4., 5.],
         ...,
         [0., 3.],
         [0., 1.],
         [0., 0.]],

        [[5., 5.],
         [5., 4.],
         [4., 5.],
         ...,
         [0., 3.],
         [0., 2.],
         [0., 0.]],

        [[5., 5.],
         [5., 4.],
         [4., 5.],
         ...,
         [0., 3.],
         [0., 2.],
         [0., 0.]],

        ...,

        [[5., 5.],
         [4., 5.],
         [4., 4.],
         ...,
         [0., 2.],
         [0., 1.],
         [0., 0.]],

        [[5., 5.],
         [4., 5.],
         [4., 4.],
         ...,
         [0., 3.],
         [0., 2.],
         [0., 0.]],

        [[5., 5.],
         [5., 4.],
         [4., 5.],
         ...,
         [0., 4.],
         [0., 1.],
         [0., 0.]]])
tensor([[[5., 5.],
         [4., 5.],
         [4., 4.],
         ...,
         [0., 3.],
         [0., 1.],
         [0., 0.]],

        [[5., 5.],
         [4., 5.],
         [4., 4.],
         ...,
     

In [17]:
from torch.utils.data import DataLoader

splitter = RandomTrainValSplitter(0.5)

train_data_dict, val_data_dict = splitter(data_dict)

train_dataset = GroupedDictDataset( train_data_dict, remove_nan_effect=True)
val_dataset = GroupedDictDataset( val_data_dict, remove_nan_effect=True)

train_sampler = GroupedBatchSampler(train_dataset, batch_size=cfg_train.data.batch_size, shuffle = True)
val_sampler = GroupedBatchSampler(val_dataset, batch_size=cfg_train.data.batch_size,shuffle = True)
train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler)
val_dataloader = DataLoader(val_dataset, batch_sampler=val_sampler)

for batch in train_dataloader:
    print(batch['excitations'].shape[1])


print()
print('val!!!')
print()

for batch in val_dataloader:
    print(batch['excitations'].shape[1])

  if np.isnan(data_point["effect"]).all():


2
12
4
15
3
18
10
4
21
7
8
3
11
9
3
12
6
17
9
6
6
3
9
17
12
2
9
14
10
17
7
19
14
22
16
8
11
5
8
13
2
5
9
8
16
9
18
6
15
14
17
16
7
7
17
17
14
8
7
9
13
16
5
18
14
5
16
14
5
16
8
11
8
10
9
18
17
6
18
6
20
4
15
19
16
17
24
17
4
4
4
2
8
11
3
16
19
16
5
11
10
2
5
13
17
18
17
15
10
18
5
11
3
3
3
5
17
18
11
18
23
15
8
23
16
20
6
10
17
6
14
13
17
15
15
5
17
7
22
15
21
16
7
7
10
14
17
16
12
19
7
10
18
10
12
15
13
15
19
12
18
17
18
18
13
15
12
4
13
4
19
22
17
6
7
13
4
3
8
14
6
9
16
17
16
3
17
18
12
4
5
11
18
16
11
18
6
7
12
16
4
3
17
16
20
13
4
12
19
3

val!!!

19
16
6
22
19
5
16
7
5
20
5
16
21
19
10
20
15
3
6
18
10
20
7
11
17
8
6
13
16
9
17
12
8
3
8
10
19
13
13
18
3
2
7
14
19
10
8
4
17
17
5
8
9
12
17
17
17
4
4
12
18
11
4
12
11
5
5
7
4
12
18
15
10
4
6
23
14
18
21
15
3
4
9
9
7
13
3
6
6
7
12
12
7
9
19
14
14
14
10
18
3
8
3
4
13
14
17
13
11
6
17
7
8
18
17
16
12
2
17
3
7
19
10
6
9
11
18
16
17
17
2
17
6
15
4
23
15
17
2
11
2
10
8
16
10
3
18
11
11
4
3
14
16
13
22
18
17
12
15
17
18
6
3
16
16
17
15
6
7
15

# model

## encoders

In [93]:
# %%writefile /Users/moustholmes/Projects/METAL-AI/src/models/components/CSF_encoders.py
import torch
from torch import nn

class simple_CSF_encoder(nn.Module):
    def __init__(self, output_size=32):
        super(simple_CSF_encoder, self).__init__()
        # Assuming the input size is 6 because we append n_electrons (1) and n_protons (1) to each 4-dimensional CSF.
        self.output_size = output_size
        self.network = nn.Sequential(
        nn.Linear(4, 64), # number of allowed excitations + 2
        nn.ReLU(),
        nn.Linear(64, output_size)
        )

    def forward(self, input_dict):
        excitations = input_dict["excitations"]
        n_electrons = input_dict["n_electrons"]
        n_protons = input_dict["n_protons"]

        # Append n_electrons and n_protons to each excitations
        n_electrons = n_electrons.float().unsqueeze(-1).unsqueeze(-1).expand(-1, excitations.size(1), 1)
        n_protons = n_protons.float().unsqueeze(-1).unsqueeze(-1).expand(-1, excitations.size(1), 1)
        extended_excitations = torch.cat([excitations, n_electrons, n_protons], dim=-1)
        return self.network(extended_excitations)
    
class no_CSF_encoder(nn.Module):
    def __init__(self,):
        super(simple_CSF_encoder, self).__init__()
        # Assuming the input size is 6 because we append n_electrons (1) and n_protons (1) to each 4-dimensional CSF.


    def forward(self, input_dict):

        excitations = input_dict["excitations"]
        n_electrons = input_dict["n_electrons"]
        n_protons = input_dict["n_protons"]

        # Append n_electrons and n_protons to each CSF
        
        n_electrons = n_electrons.float().unsqueeze(-1).unsqueeze(-1).expand(-1, excitations.size(1), 1)
        n_protons = n_protons.float().unsqueeze(-1).unsqueeze(-1).expand(-1, excitations.size(1), 1)
        extended_excitations = torch.cat([excitations, n_electrons, n_protons], dim=-1)
        return extended_excitations

class RotaryEmbedding2Angle4D(nn.Module):
    def __init__(self, dim):
        """
        Initialize the RepeatedRotaryEmbedding module.

        Args:
            dim (int): Dimension of the embeddings, must be divisible by 4.
        """
        super().__init__()
        assert dim % 4 == 0, "Embedding dimension must be divisible by 4"
        self.dim = dim

    def forward(self, embeddings, theta, phi):
        """
        Applies the repeated rotation on the embeddings using angles theta and phi.

        Args:
            embeddings (torch.Tensor): Input embeddings of shape [batch_size, n, dim].
            theta (torch.Tensor): Rotation angles for xy planes of shape [batch_size].
            phi (torch.Tensor): Rotation angles for zw planes of shape [batch_size].

        Returns:
            torch.Tensor: Rotated embeddings of the same shape as input.
        """
        batch_size, n, dim = embeddings.shape
        assert dim == self.dim, f"Input embedding dimension {dim} does not match initialized dimension {self.dim}"

        num_blocks = dim // 4  # Number of 4x4 blocks per embedding

        # Compute rotation components
        cos_theta = torch.cos(theta).unsqueeze(1).repeat(1, num_blocks)  # Shape: [batch_size, num_blocks]
        sin_theta = torch.sin(theta).unsqueeze(1).repeat(1, num_blocks)  # Shape: [batch_size, num_blocks]
        cos_phi = torch.cos(phi).unsqueeze(1).repeat(1, num_blocks)      # Shape: [batch_size, num_blocks]
        sin_phi = torch.sin(phi).unsqueeze(1).repeat(1, num_blocks)      # Shape: [batch_size, num_blocks]

        # Create block-diagonal rotation matrix for the entire batch
        R = torch.zeros(batch_size, dim, dim, device=embeddings.device)

        for i in range(num_blocks):
            # Define indices for the i-th block
            start = i * 4
            end = start + 4

            # Fill the block-diagonal matrix for all batches
            R[:, start:start+2, start:start+2] = torch.stack([
                torch.stack([cos_theta[:, i], -sin_theta[:, i]], dim=-1),
                torch.stack([sin_theta[:, i], cos_theta[:, i]], dim=-1)
            ], dim=1)

            R[:, start+2:end, start+2:end] = torch.stack([
                torch.stack([cos_phi[:, i], -sin_phi[:, i]], dim=-1),
                torch.stack([sin_phi[:, i], cos_phi[:, i]], dim=-1)
            ], dim=1)

        # Apply the rotation to embeddings
        rotated_embeddings = torch.einsum('bnd,bdm->bnm', embeddings, R)

        return rotated_embeddings

class PairEmbedding(nn.Module):
    def __init__(self, embedding_dim, pair_list=None):
        """
        Args:
            embedding_dim (int): Dimension of the embedding.
            pair_list (list of tuples): List of unique (x, y) pairs to index.
        """
        super(PairEmbedding, self).__init__()
        if pair_list is None:
            pair_list = [
                (0, 0), (1, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 1),
                (0, 3), (2, 2), (1, 3), (0, 4), (0, 5), (3, 2), (1, 4),
                (2, 3), (3, 3), (2, 4), (1, 5), (3, 4), (4, 3), (2, 5),
                (3, 5), (4, 4), (5, 4), (4, 5), (5, 5)
            ]
        
        # Initialize the embedding layer
        self.embedding = nn.Embedding(len(pair_list), embedding_dim)
        
        # Create the lookup table
        self.max_val = max(max(pair) for pair in pair_list)
        lookup_table = torch.full((self.max_val + 1, self.max_val + 1), -1, dtype=torch.long)
        for idx, pair in enumerate(pair_list):
            lookup_table[pair[0], pair[1]] = idx
        
        # Register the lookup table as a buffer so it moves with the module
        self.register_buffer("lookup_table", lookup_table, persistent=True)

    def forward(self, pair_tensor):
        """
        Args:
            pair_tensor (torch.Tensor): Tensor of shape [batch_size, n, 2] containing pairs.

        Returns:
            torch.Tensor: Embedding tensor of shape [batch_size, n, embedding_dim].
        """
        # Use the lookup table to find indices for all pairs in the batch
        indices = self.lookup_table[pair_tensor[..., 0], pair_tensor[..., 1]]
        
        # Pass the indices to the embedding layer
        embedded = self.embedding(indices)  # This will be of shape [batch_size, n, embedding_dim]
        return embedded



class ExcitaionEmbeddingIonRoPE(nn.Module):
    def __init__(self, output_size, angle_scale=0.05):
        super(ExcitaionEmbeddingIonRoPE, self).__init__()
        self.output_size = output_size
        self.angle_scale = angle_scale

        self.excitation_embedding = PairEmbedding(embedding_dim=output_size)
        self.ion_rope = RotaryEmbedding2Angle4D(dim=output_size)
s
    def forward(self, input_dict):
        excitations = input_dict["excitations"]
        n_electrons = input_dict["n_electrons"]
        n_protons = input_dict["n_protons"]


        # Embed the excitations
        embedded_excitations = self.excitation_embedding(excitations)

        # rotary positional encoding for the ions
        embedding = self.ion_rope(embedded_excitations, n_electrons * self.angle_scale, n_protons * self.angle_scale)

        return embedding

## Transformer_encoder_model

In [89]:
# %%writefile /Users/moustholmes/Projects/METAL-AI/src/models/components/Transformer_encoder_model.py
import torch
from torch import nn
import torch.nn.functional as F
from typing import Callable, Optional


class simple_transformer_encoder_model(nn.Module):
    def __init__(
        self,
        csf_encoder,
        input_size: int = 4, # number of allowed excitations + 2
        d_model: int = 64,
        nhead: int = 8,
        dim_forward: int = 64,
        num_layers: int = 6,
        output_size: int =1,
        dropout: float = 0.5,
        output_activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        super(simple_transformer_encoder_model, self).__init__()
        self.csf_encoder = csf_encoder
        encoder_layers = nn.TransformerEncoderLayer(
            self.csf_encoder.output_size, nhead, dim_forward, dropout, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.decoder = nn.Linear(self.csf_encoder.output_size, output_size)
        self.output_activation = output_activation

    def forward(self, input_dict):
        encoded_csfs = self.csf_encoder(input_dict)
        output = self.transformer_encoder(encoded_csfs)#,src_key_padding_mask = input_dict['pad_mask']  src_key_padding_mask = ~input_dict['mask']
        # output = output[:, 0, :]
        output = self.decoder(output)

        if self.output_activation:
            output = self.output_activation(output)

        return output.squeeze(-1)


        # csfs = input_dict["excitations"]
        # encoded_csfs = self.csf_encoder(
        #     csfs, input_dict["n_electrons"], input_dict["n_protons"]
        # )
        # output = self.transformer_encoder(
        #     encoded_csfs, src_key_padding_mask=input_dict["pad_mask"]
        # )  # src_key_padding_mask = ~input_dict['mask']
        # # output = output[:, 0, :]
        # output = self.decoder(output)
        # return F.sigmoid(output).squeeze(-1)


Overwriting /Users/moustholmes/Projects/METAL-AI/src/models/components/Transformer_encoder_model.py


### Loss Functions

In [86]:
# %%writefile /Users/moustholmes/Projects/METAL-AI/src/models/components/loss_function_wrappers.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class LossFuncMaskWrapper(nn.Module):
    def __init__(self, loss_fn: nn.Module):
        """
        Args:
            loss_fn (nn.Module): A PyTorch loss function like CrossEntropyLoss, MSELoss, etc.
        """
        super(LossFuncMaskWrapper, self).__init__()
        self.loss_fn = loss_fn

    def forward(self, input, target, mask= None):
        """
        Args:
            input (Tensor): Predicted values of size (N, *) where * means any number of additional dimensions.
            mask (Tensor): Mask of size (N) to filter out the padded values.
            target (Tensor): True values of size (N, *).

        Returns:
            Tensor: Computed loss after applying the mask.
        """
        # Apply the mask to input and target
        if mask is None:
            return self.loss_fn(input, target)
        else:
            return self.loss_fn(input[mask], target[mask])
        
        # masked_input = input[mask].float()
        # masked_target = target[mask].float()

        # # Compute the loss using the provided loss function
        # loss = self.loss_fn(masked_input, masked_target)
        
        # return loss


class GaussianNLLLossWrapper(nn.Module):
    def __init__(self, loss_fn: nn.Module):
        """
        Args:
            loss_fn (nn.Module): A PyTorch loss function like GaussianNLLLoss which takes two inputs and returns the loss.
        """
        super(GaussianNLLLossWrapper, self).__init__()

        self.loss_fn = loss_fn

    def forward(self, input, target,  mask = None):
        """
        Args:
            input (Tensor): Predicted means of size (N, 2) first column is the mean and the second column is the variance.
            mask (Tensor): Mask of size (N) to filter out the padded values.
            target (Tensor): True values of size (N)

        Returns:
            Tensor: Computed Gaussian negative log likelihood loss.
        """
        # Ensure variances are non-negative by adding eps (if not handled elsewhere)
        mean = input[:,:,0]#[mask]
        var = input[:,:,1]#[mask]#.clamp(min=self.eps)
        target = target.float() #[mask]
        
        # Compute the loss using GaussianNLLLoss
        loss = self.loss_fn(mean, target, var)
        
        return loss

class DiscretizedNLLLoss(nn.Module):
    def __init__(self, loss_fn: nn.Module, num_bins: int, min_value: float, max_value: float):
        """
        :param num_bins: Number of bins to discretize the continuous range
        :param min_value: Minimum value of the range to discretize
        :param max_value: Maximum value of the range to discretize
        """
        super(DiscretizedNLLLoss, self).__init__()
        self.loss_fn = loss_fn
        self.num_bins = num_bins
        self.min_value = min_value
        self.max_value = max_value
        
        # Calculate the width of each bin
        self.bin_width = (max_value - min_value) / num_bins
        
        # Bins are represented by their center values
        self.bin_centers = torch.linspace(min_value + self.bin_width / 2,
                                          max_value - self.bin_width / 2, num_bins)

    def forward(self, logits, targets):
        """
        :param predictions: The continuous predictions from the model [batch_size, 1]
        :param targets: The continuous target values [batch_size, 1]
        """

        # Convert continuous target values into bin indices
        target_bin_indices = ((targets - self.min_value) / self.bin_width).long().clamp(0, self.num_bins - 1)
        
        # Apply CrossEntropyLoss (which includes log-softmax)
        loss = self.loss_fn(logits, target_bin_indices.squeeze(-1))

        return loss




Overwriting /Users/moustholmes/Projects/METAL-AI/src/models/components/loss_function_wrappers.py


### Inspect model


### simple encoder

In [29]:
import torch
from torch.utils.data import DataLoader

from src.models.components.CSF_encoders import simple_CSF_encoder
from src.models.components.Transformer_encoder_model import simple_transformer_encoder_model
from src.models.components.loss_function_wrappers import LossFuncMaskWrapper
# from src.data.components.datasets import GroupedDictDataset
from src.data.components.samplers import GroupedBatchSampler

import pickle
with open(cfg_train.data.data_dir, 'rb') as file:
    data_dict = pickle.load(file)

batch_size = 8 #cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = 8
csf_encoder= simple_CSF_encoder(d_model)
model = simple_transformer_encoder_model(csf_encoder, d_model=d_model,output_activation=torch.nn.Sigmoid())

loss_fn=LossFuncMaskWrapper(torch.nn.BCELoss(reduction='sum'))  #CrossEntropyLoss

i = 0
for batch in dataloader:
    preds= model(batch)
    targets = batch['converged']
    # print()

    # print(preds.shape)
    # print(batch['converged'].shape)

    mask_converged = batch['converged_mask']
    # mask_effect = batch['effect_mask']
    print('filling numbers')
    print(batch['filling_numbers'])
    print('targets')
    print(batch['converged'])
    print(batch['effect'])
    print('masks')
    print(mask_converged)
    # print(mask_effect)
    print('masked targets')
    print(batch['converged'][mask_converged])
    # print(batch['effect'][mask_effect])
    
    loss = loss_fn(preds, targets, mask_converged)
    print('loss')
    print(loss)

    print()
    
    i+=1
    if i == 10:
        break

  if np.isnan(data_point["effect"]).all():


filling numbers
tensor([[[2, 2, 6, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 

In [69]:
import torch
from torch.utils.data import DataLoader

from src.models.components.CSF_encoders import simple_CSF_encoder
# from src.models.components.Transformer_encoder_model import simple_transformer_encoder_model
from src.models.components.loss_function_wrappers import LossFuncMaskWrapper
# from src.data.components.datasets import GroupedDictDataset
from src.data.components.samplers import GroupedBatchSampler

import pickle
with open(cfg_train.data.data_dir, 'rb') as file:
    data_dict = pickle.load(file)

batch_size = 8 #cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = 8
csf_encoder= ExcitaionEmbeddingIonRoPE(d_model)
model = simple_transformer_encoder_model(csf_encoder, d_model=d_model,output_activation=torch.nn.Sigmoid())

loss_fn=LossFuncMaskWrapper(torch.nn.BCELoss(reduction='sum'))  #CrossEntropyLoss

i = 0
for batch in dataloader:
    preds= model(batch)
    targets = batch['converged']
    # print()

    # print(preds.shape)
    # print(batch['converged'].shape)

    mask_converged = batch['converged_mask']
    # mask_effect = batch['effect_mask']
    print('filling numbers')
    print(batch['filling_numbers'])
    print('targets')
    print(batch['converged'])
    print(batch['effect'])
    print('masks')
    print(mask_converged)
    # print(mask_effect)
    print('masked targets')
    print(batch['converged'][mask_converged])
    # print(batch['effect'][mask_effect])
    
    loss = loss_fn(preds, targets, mask_converged)
    print('loss')
    print(loss)

    print()
    
    i+=1
    if i == 10:
        break

  if np.isnan(data_point["effect"]).all():


filling numbers
tensor([[[2, 2, 6, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 

### rotary embedding

'/Users/moustholmes/Projects/METAL-AI/data/data_dict_Li-Rh_new_effect.pkl'

In [101]:
import torch
from torch.utils.data import DataLoader

from src.models.components.CSF_encoders import simple_CSF_encoder
from src.models.components.Transformer_encoder_model import simple_transformer_encoder_model
from src.models.components.loss_function_wrappers import LossFuncMaskWrapper
# from src.data.components.datasets import GroupedDictDataset
from src.data.components.samplers import GroupedBatchSampler

import pickle
with open(cfg_train.data.data_dir, 'rb') as file:
    data_dict = pickle.load(file)

batch_size = 8 #cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = 8
csf_encoder= simple_CSF_encoder(d_model)
model = simple_transformer_encoder_model(csf_encoder, d_model=d_model,output_activation=torch.nn.Sigmoid())

loss_fn=LossFuncMaskWrapper(torch.nn.BCELoss(reduction='sum'))  #CrossEntropyLoss

i = 0
for batch in dataloader:
    preds= model(batch)
    targets = batch['converged']
    # print()

    # print(preds.shape)
    # print(batch['converged'].shape)

    mask_converged = batch['converged_mask']
    # mask_effect = batch['effect_mask']
    print('filling numbers')
    print(batch['filling_numbers'])
    print('targets')
    print(batch['converged'])
    print(batch['effect'])
    print('masks')
    print(mask_converged)
    # print(mask_effect)
    print('masked targets')
    print(batch['converged'][mask_converged])
    print(batch['effect'])
    
    loss = loss_fn(preds, targets, mask_converged)
    print('loss')
    print(loss)

    print()
    
    i+=1
    if i == 10:
        break

  if np.isnan(data_point["effect"]).all():


filling numbers
tensor([[[2, 2, 6, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[2, 2, 6, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 2, 6, 2, 1, 0, 0, 0, 0, 0, 

In [62]:
import torch
from torch import nn
import matplotlib.pyplot as plt

class RotaryEmbedding2Angle4D(nn.Module):
    def __init__(self, dim):
        """
        Initialize the RepeatedRotaryEmbedding module.

        Args:
            dim (int): Dimension of the embeddings, must be divisible by 4.
        """
        super().__init__()
        assert dim % 4 == 0, "Embedding dimension must be divisible by 4"
        self.dim = dim

    def forward(self, embeddings, theta, phi):
        """
        Applies the repeated rotation on the embeddings using angles theta and phi.

        Args:
            embeddings (torch.Tensor): Input embeddings of shape [batch_size, n, dim].
            theta (torch.Tensor): Rotation angles for xy planes of shape [batch_size].
            phi (torch.Tensor): Rotation angles for zw planes of shape [batch_size].

        Returns:
            torch.Tensor: Rotated embeddings of the same shape as input.
        """
        batch_size, n, dim = embeddings.shape
        assert dim == self.dim, f"Input embedding dimension {dim} does not match initialized dimension {self.dim}"

        num_blocks = dim // 4  # Number of 4x4 blocks per embedding

        # Compute rotation components
        cos_theta = torch.cos(theta).unsqueeze(1).repeat(1, num_blocks)  # Shape: [batch_size, num_blocks]
        sin_theta = torch.sin(theta).unsqueeze(1).repeat(1, num_blocks)  # Shape: [batch_size, num_blocks]
        cos_phi = torch.cos(phi).unsqueeze(1).repeat(1, num_blocks)      # Shape: [batch_size, num_blocks]
        sin_phi = torch.sin(phi).unsqueeze(1).repeat(1, num_blocks)      # Shape: [batch_size, num_blocks]

        # Create block-diagonal rotation matrix for the entire batch
        R = torch.zeros(batch_size, dim, dim, device=embeddings.device)

        for i in range(num_blocks):
            # Define indices for the i-th block
            start = i * 4
            end = start + 4

            # Fill the block-diagonal matrix for all batches
            R[:, start:start+2, start:start+2] = torch.stack([
                torch.stack([cos_theta[:, i], -sin_theta[:, i]], dim=-1),
                torch.stack([sin_theta[:, i], cos_theta[:, i]], dim=-1)
            ], dim=1)

            R[:, start+2:end, start+2:end] = torch.stack([
                torch.stack([cos_phi[:, i], -sin_phi[:, i]], dim=-1),
                torch.stack([sin_phi[:, i], cos_phi[:, i]], dim=-1)
            ], dim=1)

        # Apply the rotation to embeddings
        rotated_embeddings = torch.einsum('bnd,bdm->bnm', embeddings, R)

        return rotated_embeddings



class PairEmbedding(nn.Module):
    def __init__(self, embedding_dim, pair_list=[(0, 0), (1, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 1), (0, 3), (2, 2), (1, 3), (0, 4), (0, 5), (3, 2), (1, 4), (2, 3), (3, 3), (2, 4), (1, 5), (3, 4), (4, 3), (2, 5), (3, 5), (4, 4), (5, 4), (4, 5), (5, 5)]):
        super(PairEmbedding, self).__init__()
        
        # Initialize the embedding layer
        self.embedding = nn.Embedding(len(pair_list), embedding_dim)
        
        # Create a lookup table
        self.max_val = max(max(pair) for pair in pair_list)  # Find the maximum integer in the pairs
        self.lookup_table = torch.full((self.max_val + 1, self.max_val + 1), -1, dtype=torch.long)
        
        # Fill the lookup table with indices
        for idx, pair in enumerate(pair_list):
            self.lookup_table[pair[0], pair[1]] = idx

    def forward(self, pair_tensor):
        # pair_tensor is of shape [batch_size, n, 2]
        
        # Use the lookup table to find indices for all pairs in the batch
        indices = self.lookup_table[pair_tensor[..., 0], pair_tensor[..., 1]]
        # indices will have shape [batch_size, n]
        # print(indices)
        
        # Pass the indices to the embedding layer
        embedded = self.embedding(indices)  # This will be of shape [batch_size, n, embedding_dim]
        return embedded

# # Define the pairs and instantiate the module
# pairs = [
#     (0, 0), (1, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 1),
#     (0, 3), (2, 2), (1, 3), (0, 4), (0, 5), (3, 2), (1, 4),
#     (2, 3), (3, 3), (2, 4), (1, 5), (3, 4), (4, 3), (2, 5),
#     (3, 5), (4, 4), (5, 4), (4, 5), (5, 5)
# ]
# embedding_dim = 8  # Example embedding dimension
# pair_embedding = PairEmbedding(embedding_dim, pairs)
# rotate = RotaryEmbedding2Angle4D(embedding_dim)

# # Example usage with a batch of pairs
# batch_pair_tensor = torch.tensor([[[1, 3], [0, 1], [0, 1], [5,5]], [[3, 4], [2, 5], [0, 1],[5,5]], [[3, 4], [2, 5], [0, 1],[5,5]]])  # Shape [2, 3, 2]
# print(batch_pair_tensor.shape)
# embedded = pair_embedding(batch_pair_tensor)
# print(embedded)  # Should output embeddings with shape [batch_size, n, embedding_dim]
# rotated = rotate(embedded, torch.tensor(np.pi), torch.tensor(np.pi))
# print(rotated)  # Should output embeddings with shape [batch_size, n, embedding_dim]



In [63]:
embedding_dim = 8  # Example embedding dimension
pairs = [
    (0, 0), (1, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 1),
    (0, 3), (2, 2), (1, 3), (0, 4), (0, 5), (3, 2), (1, 4),
    (2, 3), (3, 3), (2, 4), (1, 5), (3, 4), (4, 3), (2, 5),
    (3, 5), (4, 4), (5, 4), (4, 5), (5, 5)
]

pair_embedding = PairEmbedding(embedding_dim, pairs)
rotate = RotaryEmbedding2Angle4D(embedding_dim)

# Example batch of embeddings and angles
batch_pair_tensor = torch.tensor([
    [[1, 3], [0, 1], [0, 1], [5, 5]],
    [[3, 4], [2, 5], [0, 1], [5, 5]],
    [[3, 4], [2, 5], [0, 1], [5, 5]]
])  # Shape [batch_size, n, 2]
embedded = pair_embedding(batch_pair_tensor)

# Angles for each batch element
theta = torch.tensor([np.pi*2, np.pi*2, np.pi*2])  # Shape [batch_size]
phi = torch.tensor([np.pi*2, np.pi*2, np.pi*2])    # Shape [batch_size]

rotated = rotate(embedded, theta, phi)
print(rotated.shape)  # Should output embeddings with shape [batch_size, n, embedding_dim]

print(embedded.shape)

torch.Size([3, 4, 8])
torch.Size([3, 4, 8])


In [64]:
theta_2pi = torch.full((batch_size,), 2 * torch.pi)
phi_2pi = torch.full((batch_size,), 2 * torch.pi)
theta_pi = torch.full((batch_size,), torch.pi)
phi_pi = torch.full((batch_size,), torch.pi)

rotated_2pi = rotate(embeddings, theta_2pi, phi_2pi)
rotated_pi = rotate(embeddings, theta_pi, phi_pi)

identity_test = torch.allclose(embeddings, rotated_2pi, atol=1e-5)
negation_test = torch.allclose(embeddings * -1, rotated_pi, atol=1e-5)

print("Identity Test (2π):", identity_test)
print("Negation Test (π):", negation_test)


Identity Test (2π): True
Negation Test (π): True


In [65]:
# Simplified test case
embedding_dim = 8  # Must be divisible by 4
batch_size, n = 2, 4  # Example batch size and sequence length

# Example embeddings
embeddings = torch.randn(batch_size, n, embedding_dim)

# Angles for testing
theta_2pi = torch.full((batch_size,), 2 * torch.pi)
phi_2pi = torch.full((batch_size,), 2 * torch.pi)
theta_pi = torch.full((batch_size,), torch.pi)
phi_pi = torch.full((batch_size,), torch.pi)

# Initialize the rotary embedding layer
rotate = RotaryEmbedding2Angle4D(embedding_dim)

# Test rotations
rotated_2pi = rotate(embeddings, theta_2pi, phi_2pi)  # Rotate with 2*pi
rotated_pi = rotate(embeddings, theta_pi, phi_pi)     # Rotate with pi

# Compare results
identity_test = torch.allclose(embeddings, rotated_2pi, atol=1e-5)  # Check if rotation by 2*pi is identity
negation_test = torch.allclose(embeddings * -1, rotated_pi, atol=1e-5)  # Check if rotation by pi negates

identity_test, negation_test


(True, True)

In [104]:
batch_size = 8 #cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = 8
csf_encoder= simple_CSF_encoder(d_model)
# csf_encoder= ExcitaionEmbeddingIonRoPE(d_model)
model = simple_transformer_encoder_model(
    csf_encoder, 
    d_model=d_model,
    output_activation=torch.nn.ReLU(),
    output_size=2
    )

loss_fn=GaussianNLLLossWrapper(nn.GaussianNLLLoss(reduction='sum'))

i = 0
for batch in dataloader:
    preds= model(batch)
    # print(preds.shape)
    target = batch['converged']
    # print()

    # print(preds.shape)
    # print(batch['converged'].shape)

    mask_converged = batch['converged_mask']
    # mask_effect = batch['effect_mask']
    print('targets')
    print(batch['converged'].shape)
    print(batch['effect'].shape)
    print('preds')
    print(preds.shape)
    print('masks')
    print(mask_converged.shape)
    # print(mask_effect.shape)
    # print('masked targets')
    # print(batch['converged'][mask_converged])
    # print(batch['effect'][mask_effect])
    loss = loss_fn(preds, target, mask_converged)
    print('loss')
    print(loss)

    print()
    
    # i+=1
    # if i == 10:
    #     break

  if np.isnan(data_point["effect"]).all():


targets
torch.Size([8, 2])
torch.Size([8, 2])
preds
torch.Size([8, 2, 2])
masks
torch.Size([8, 2])
loss
tensor(1999982.7500, grad_fn=<SumBackward0>)

targets
torch.Size([8, 2])
torch.Size([8, 2])
preds
torch.Size([8, 2, 2])
masks
torch.Size([8, 2])
loss
tensor(2609322.5000, grad_fn=<SumBackward0>)

targets
torch.Size([8, 2])
torch.Size([8, 2])
preds
torch.Size([8, 2, 2])
masks
torch.Size([8, 2])
loss
tensor(3328031.2500, grad_fn=<SumBackward0>)

targets
torch.Size([8, 2])
torch.Size([8, 2])
preds
torch.Size([8, 2, 2])
masks
torch.Size([8, 2])
loss
tensor(1596908.5000, grad_fn=<SumBackward0>)

targets
torch.Size([8, 2])
torch.Size([8, 2])
preds
torch.Size([8, 2, 2])
masks
torch.Size([8, 2])
loss
tensor(1499991.1250, grad_fn=<SumBackward0>)

targets
torch.Size([8, 2])
torch.Size([8, 2])
preds
torch.Size([8, 2, 2])
masks
torch.Size([8, 2])
loss
tensor(1113145.7500, grad_fn=<SumBackward0>)

targets
torch.Size([8, 2])
torch.Size([8, 2])
preds
torch.Size([8, 2, 2])
masks
torch.Size([8, 2])
l

### model and data to GPU

In [82]:
print(f'is mps available ? {torch.backends.mps.is_available()}')
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

batch_size = 8 #cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = 8
csf_encoder= simple_CSF_encoder(d_model)
# csf_encoder= ExcitaionEmbeddingIonRoPE(d_model)
model = simple_transformer_encoder_model(csf_encoder, d_model=d_model,output_activation=F.sigmoid)


loss_fn=LossFuncMaskWrapper(torch.nn.CrossEntropyLoss(reduction='sum')).to(device)
model.to(device)

i = 0
for batch in dataloader:
    batch = {key: value.to(device) for key, value in batch.items()}

    preds= model(batch)
    targets = batch['converged']
    # print()

    mask_converged = batch['converged_mask']
    # mask_effect = batch['effect_mask']
    print('targets')
    print(batch['converged'])
    print(batch['effect'])
    print('masks')
    print(mask_converged)
    # print(mask_effect)
    print('masked targets')
    print(batch['converged'][mask_converged])
    # print(batch['effect'][mask_effect])
    loss = loss_fn(preds, targets, mask_converged)
    print('loss')
    print(loss)

    print()
    
    i+=1
    if i == 10:
        break

is mps available ? True


  if np.isnan(data_point["effect"]).all():


targets
tensor([[True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True]], device='mps:0')
tensor([[1.8431e+01, -0.0000e+00],
        [1.3922e+00, 3.5701e-04],
        [9.4553e+00, 1.5222e-10],
        [2.4841e+00, 2.3364e-06],
        [1.1950e+01, 4.8708e-13],
        [4.4626e+00, 1.4969e-05],
        [2.1566e+00, 1.0559e-05],
        [8.2161e-01, 4.9957e-03]], device='mps:0')
masks
tensor([[ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False],
        [ True, False]], device='mps:0')
masked targets
tensor([True, True, True, True, True, True, True, True], device='mps:0')
loss
tensor(16.7312, device='mps:0', grad_fn=<NegBackward0>)

targets
tensor([[True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
    

## pure pytorch train model 

### BCE converged

### L1 effect


In [None]:
from torch.optim import Adam

print(f'is mps available? {torch.backends.mps.is_available()}')
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

batch_size = cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = cfg_train.model.model.d_model
csf_encoder= simple_CSF_encoder(d_model)
# csf_encoder= ExcitaionEmbeddingIonRoPE(d_model)
# model = simple_transformer_encoder_model(csf_encoder, d_model=d_model, output_activation=F.sigmoid,output_size=1)
model = simple_transformer_encoder_model(csf_encoder, d_model=d_model, output_activation=F.relu,output_size=2)

# loss_fn=LossFuncMaskWrapper(torch.nn.BCELoss(reduction='sum')).to(device)  #CrossEntropyLoss
loss_fn=GaussianNLLLossWrapper(nn.GaussianNLLLoss(reduction='sum'))
model.to(device)


optimizer = Adam(model.parameters(), lr=0.001)  # Initialize the optimizer

num_epochs = 10  # Number of epochs to train for

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for batch in dataloader:
        # Assuming your DataLoader outputs a batch as a dictionary with 'CSFs', 'n_electrons', 'n_protons', and 'effect'
        batch = {key: value.to(device) for key, value in batch.items()}

        optimizer.zero_grad()  # Zero the parameter gradients

        # Forward pass
        preds = model(batch)
        print(preds.shape)
        # targets = batch['converged']
        targets = batch['effect']
        # print(targets.shape)

        mask_converged = batch['converged_mask']
        # mask_effect = batch['effect_mask']
        # print(mask_effect.shape)

        # Compute loss
        loss = loss_fn(preds, targets )#mask_converged
        # loss = loss_fn(preds, targets, mask_converged )#mask_converged
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * preds.size(0)  # Multiply by batch size for total loss

    epoch_loss = running_loss / len(dataloader.dataset)  # Average loss per sample
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

print('Finished Training')

In [None]:
from torch.optim import Adam

print(f'is mps available? {torch.backends.mps.is_available()}')
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

batch_size = cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = cfg_train.model.model.d_model
csf_encoder= simple_CSF_encoder(d_model)
# csf_encoder= ExcitaionEmbeddingIonRoPE(d_model)
# model = simple_transformer_encoder_model(csf_encoder, d_model=d_model, output_activation=F.sigmoid,output_size=1)
model = simple_transformer_encoder_model(csf_encoder, d_model=d_model, output_activation=F.relu,output_size=2)

# loss_fn=LossFuncMaskWrapper(torch.nn.BCELoss(reduction='sum')).to(device)  #CrossEntropyLoss
loss_fn=GaussianNLLLossWrapper(nn.GaussianNLLLoss(reduction='sum'))
model.to(device)


optimizer = Adam(model.parameters(), lr=0.001)  # Initialize the optimizer

num_epochs = 10  # Number of epochs to train for

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for batch in dataloader:
        # Assuming your DataLoader outputs a batch as a dictionary with 'CSFs', 'n_electrons', 'n_protons', and 'effect'
        batch = {key: value.to(device) for key, value in batch.items()}

        optimizer.zero_grad()  # Zero the parameter gradients

        # Forward pass
        preds = model(batch)
        # targets = batch['converged']
        targets = batch['effect']
        # print(targets.shape)

        mask_converged = batch['converged_mask']
        # mask_effect = batch['effect_mask']
        # print(mask_effect.shape)

        # Compute loss
        loss = loss_fn(preds, targets )#mask_converged
        # loss = loss_fn(preds, targets, mask_converged )#mask_converged
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * preds.size(0)  # Multiply by batch size for total loss

    epoch_loss = running_loss / len(dataloader.dataset)  # Average loss per sample
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

print('Finished Training')

### Gaussian NLL effect

In [105]:
from torch.optim import Adam

print(f'is mps available? {torch.backends.mps.is_available()}')
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

batch_size = cfg_train.data.batch_size

dataset = GroupedDictDataset( data_dict, remove_nan_effect=True)
sampler = GroupedBatchSampler(dataset, batch_size=batch_size, shuffle=False)
dataloader = DataLoader(dataset, batch_sampler=sampler)


d_model = cfg_train.model.model.d_model
csf_encoder= simple_CSF_encoder(d_model)
# csf_encoder= ExcitaionEmbeddingIonRoPE(d_model)
# model = simple_transformer_encoder_model(csf_encoder, d_model=d_model, output_activation=F.sigmoid,output_size=1)
model = simple_transformer_encoder_model(csf_encoder, d_model=d_model, output_activation=F.relu,output_size=2)

# loss_fn=LossFuncMaskWrapper(torch.nn.BCELoss(reduction='sum')).to(device)  #CrossEntropyLoss
loss_fn=GaussianNLLLossWrapper(nn.GaussianNLLLoss(reduction='sum'))
model.to(device)


optimizer = Adam(model.parameters(), lr=0.001)  # Initialize the optimizer

num_epochs = 10  # Number of epochs to train for

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0

    for batch in dataloader:
        # Assuming your DataLoader outputs a batch as a dictionary with 'CSFs', 'n_electrons', 'n_protons', and 'effect'
        batch = {key: value.to(device) for key, value in batch.items()}

        optimizer.zero_grad()  # Zero the parameter gradients

        # Forward pass
        preds = model(batch)
        print(preds.shape)
        # targets = batch['converged']
        targets = batch['effect']
        # print(targets.shape)

        mask_converged = batch['converged_mask']
        # mask_effect = batch['effect_mask']
        # print(mask_effect.shape)

        # Compute loss
        loss = loss_fn(preds, targets )#mask_converged
        # loss = loss_fn(preds, targets, mask_converged )#mask_converged
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * preds.size(0)  # Multiply by batch size for total loss

    epoch_loss = running_loss / len(dataloader.dataset)  # Average loss per sample
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

print('Finished Training')

is mps available? True


  if np.isnan(data_point["effect"]).all():


torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 2, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 3, 2])
torch.Size([128, 4, 2])
torch.Size([128, 4, 2])
torch.Size([128, 4, 2])
torch.Size([128, 4, 2])
torch.Size([128, 4, 2])
torch.Size([128, 4, 2])
torch.Size([128,

# pytorch lightning 

## datamodule

In [20]:
# %%writefile /Users/moustholmes/Projects/METAL-AI/src/data/dict_datamodule.py
from typing import Any, Dict, Optional, Tuple, Callable

import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms

import pickle

from src.data.components.datasets import GroupedDictDataset
from src.data.components.samplers import GroupedBatchSampler
from src.data.components.collate_fns import dict_collate_fn
from src.data.components.data_utils import load_data_dict

class DictDataModule(LightningDataModule):
    """`LightningDataModule` for the MNIST dataset.

    write about structure of the data, download, split, transform, etc...

    A `LightningDataModule` implements 7 key methods:

    ```python
        def prepare_data(self):
        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
        # Download data, pre-process, split, save to disk, etc...

        def setup(self, stage):
        # Things to do on every process in DDP.
        # Load data, set variables, etc...

        def train_dataloader(self):
        # return train dataloader

        def val_dataloader(self):
        # return validation dataloader

        def test_dataloader(self):
        # return test dataloader

        def predict_dataloader(self):
        # return predict dataloader

        def teardown(self, stage):
        # Called on every process in DDP.
        # Clean up after fit or test.
    ```

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "data/",
        batch_size: int = 64,
        num_workers: int = 0,
        train_val_splitter: Optional[Callable] = None,
        shuffle: bool = True,
        pin_memory: bool = False,
        remove_nan_effect: bool = False,
        ion_include = None,
        persistent_workers: bool = False,
    ) -> None:
        """Initialize a `HDF5DataModule`.

        :param data_dir: The data directory. Defaults to `"data/"`.
        :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
        :param batch_size: The batch size. Defaults to `64`.
        :param num_workers: The number of workers. Defaults to `0`.
        :param pin_memory: Whether to pin memory. Defaults to `False`.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_dict = None

        self.train_val_splitter = self.hparams.train_val_splitter

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
        self.data_unseen: Optional[Dataset] = None
        self.data_large: Optional[Dataset] = None

        self.batch_size_per_device = batch_size

    def prepare_data(self) -> None:
        """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
        within a single process on CPU, so you can safely add your downloading logic within. In
        case of multi-node training, the execution of this hook depends upon
        `self.prepare_data_per_node()`.

        Do not use it to assign state (self.x = y).
        """
        # Load the data dictionary from the data directory
        self.data_dict = load_data_dict( self.hparams.data_dir)
        self.data_dict_train, self.data_dict_val, self.data_dict_unseen, self.data_dict_large = self.train_val_splitter( self.data_dict)

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = (
                self.hparams.batch_size // self.trainer.world_size
            )

        
        # ion_include = [tuple(item) for item in self.hparams.ion_include] if self.hparams.ion_include else None
        
        # load and split datasets only if not loaded already
        if not self.data_train:
            self.data_train = GroupedDictDataset(
                data_dict=self.data_dict_train, 
                remove_nan_effect=self.hparams.remove_nan_effect,
                # ion_include=ion_include,
                )
        if not self.data_val:
            self.data_val = GroupedDictDataset(
                data_dict=self.data_dict_val, 
                remove_nan_effect=self.hparams.remove_nan_effect,
                # ion_include=ion_include,
                )
        if not self.data_test:
            self.data_test = GroupedDictDataset(
                data_dict=self.data_dict, 
                remove_nan_effect=self.hparams.remove_nan_effect,
                # ion_include=ion_include,
                )
        if not self.data_unseen:
            self.data_unseen = GroupedDictDataset(
                data_dict=self.data_dict_unseen, 
                remove_nan_effect=self.hparams.remove_nan_effect,
                # ion_include=ion_include,
                )
        if not self.data_large:
            self.data_large = GroupedDictDataset(
                data_dict=self.data_dict_large, 
                remove_nan_effect=self.hparams.remove_nan_effect,
                # ion_include=ion_include,
                )

        print(len(self.data_train),len(self.data_val),len(self.data_test),len(self.data_unseen),len(self.data_large))

    def train_dataloader(self) -> DataLoader[Any]:
        """Create and return the train dataloader.

        :return: The train dataloader.
        """
        return DataLoader(
            dataset=self.data_train,
            batch_sampler=GroupedBatchSampler(self.data_train, batch_size =self.batch_size_per_device, shuffle=self.hparams.shuffle),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.data_val,
            batch_sampler=GroupedBatchSampler(self.data_val, batch_size =self.batch_size_per_device, shuffle=self.hparams.shuffle),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return [
            DataLoader(
            dataset=self.data_val,
            batch_sampler=GroupedBatchSampler(self.data_val, batch_size =self.batch_size_per_device, shuffle=self.hparams.shuffle), #data_test,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
            ),
            DataLoader(
            dataset=self.data_unseen,
            batch_sampler=GroupedBatchSampler(self.data_unseen, batch_size =self.batch_size_per_device, shuffle=self.hparams.shuffle),
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
            ),
            DataLoader(
            dataset=self.data_large,
            batch_sampler=GroupedBatchSampler(self.data_large, batch_size = 4, shuffle=self.hparams.shuffle), #self.batch_size_per_device
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.persistent_workers,
            ),
        ]

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
        `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass


if __name__ == "__main__":
    _ = DictDataModule()


In [3]:
%load /Users/moustholmes/Projects/METAL-AI/configs/data/dict_dataset.yaml
_target_: src.data.dict_datamodule.DictDataModule
data_dir: ${paths.data_dir}
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
train_val_splitter:
    _target_: src.data.components.data_utils.TopASFSizeTrainValSplitter
    validation_percentage : 0.15
num_workers: 0
shuffle: False
remove_nan_effect: False
pin_memory: True
persistent_workers: False

SyntaxError: invalid syntax (235822373.py, line 3)

In [21]:

from hydra.utils import instantiate

cfg_train = compose(config_name='train')
datamodule = instantiate(cfg_train.data)
print(datamodule)


<src.data.dict_datamodule.DictDataModule object at 0x1169eed90>


## LightningModule

In [3]:
%%writefile /Users/moustholmes/Projects/METAL-AI/src/models/metalAI_module.py
from typing import Any, Dict, Tuple
import pickle
import copy
import torch
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric, MinMetric
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics.classification import BinaryF1Score, BinaryAccuracy


class MetalAILitModule(LightningModule):
    """Example of a `LightningModule` for MNIST classification.

    A `LightningModule` implements 8 key methods:

    ```python
    def __init__(self):
    # Define initialization code here.

    def setup(self, stage):
    # Things to setup before each stage, 'fit', 'validate', 'test', 'predict'.
    # This hook is called on every process when using DDP.

    def training_step(self, batch, batch_idx):
    # The complete training step.

    def validation_step(self, batch, batch_idx):
    # The complete validation step.

    def test_step(self, batch, batch_idx):
    # The complete test step.

    def predict_step(self, batch, batch_idx):
    # The complete predict step.

    def configure_optimizers(self):
    # Define and configure optimizers and LR schedulers.
    ```

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fn: torch.nn.Module,
        target_name: str,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
    ) -> None:
        """Initialize a `MNISTLitModule`.

        :param net: The model to train.
        :param optimizer: The optimizer to use for training.
        :param scheduler: The learning rate scheduler to use for training.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.model = model

        # target name for specifying convergence or effect prediction
        self.target_name = target_name

        # loss function
        self.loss_fn = loss_fn

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()

        #
        self.test_val_loss = MeanMetric()
        self.test_unseen_loss = MeanMetric()
        self.test_large_loss = MeanMetric()

        self.val_loss_best = MinMetric()


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Perform a forward pass through the model `self.net`.

        :param x: A tensor of images.
        :return: A tensor of logits.
        """
        return self.model(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_loss_best.reset()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.

        :return: A tuple containing (in order):
            - A tensor of losses.
            - A tensor of predictions.
            - A tensor of target labels.
        """

        preds = self.forward(batch)
        targets = batch[self.target_name]
        
        if self.target_name == 'converged':
            mask = batch[self.target_name + "_mask"]
            loss = self.loss_fn(preds, targets, mask)
        else:
            loss = self.loss_fn(preds, targets)
            mask = None
        return loss, preds, targets, mask # preds[mask], targets[mask],

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        :return: A tensor of losses between model predictions and targets.
        """
        loss, preds, targets, mask = self.model_step(batch)

        # update and log loss
        self.train_loss(loss)
        self.log(
            "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True
        )
        # return loss or backpropagation will fail
        return {"loss": loss, "preds": preds, "targets": targets, 'mask': mask}
    
    def on_train_epoch_end(self) -> None:
        "Lightning hook that is called when a training epoch ends."
        pass

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        """
        loss, preds, targets, mask = self.model_step(batch)

        # update and log loss
        self.val_loss(loss)
        self.val_loss_best(loss)

        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets, 'mask': mask}

    def on_validation_epoch_end(self) -> None:
        "Lightning hook that is called when a validation epoch ends."
        # acc = self.val_acc.compute()  # get current val acc
        # self.val_acc_best(acc)  # update best so far val acc
        # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/loss_best", self.val_loss_best.compute(), sync_dist=True, prog_bar=True)
        # pass

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        """
        loss, preds, targets, mask = self.model_step(batch)

        # update and log loss
        if dataloader_idx == 0:
            self.test_val_loss(loss)

            self.log("test/val_loss", self.test_val_loss, on_step=True, on_epoch=False, prog_bar=False)
        elif dataloader_idx == 1:
            self.test_unseen_loss(loss)

            self.log("test/unseen_loss", self.test_unseen_loss, on_step=True, on_epoch=False, prog_bar=False)
        elif dataloader_idx == 2:
            self.test_large_loss(loss)

            self.log("test/large_loss", self.test_large_loss, on_step=True, on_epoch=False, prog_bar=False)
   
        return {"loss": loss, "preds": preds, "targets": targets, 'mask': mask}

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


if __name__ == "__main__":
    _ = MetalAILitModule(None, None, None, None, None, None)


Overwriting /Users/moustholmes/Projects/METAL-AI/src/models/metalAI_module.py


In [82]:
%%writefile /Users/moustholmes/Projects/METAL-AI/configs/model/transformer_encoder_model.yaml
_target_: src.models.metalAI_module.MetalAILitModule

target_name: 'converged'

model:
  _target_: src.models.components.Transformer_encoder_model.simple_transformer_encoder_model
  csf_encoder: #${CSF_encoders.simple_csf_encoder}  # Reference the entire encoder config
    _target_: src.models.components.CSF_encoders.simple_CSF_encoder #/home/projects/ku_00258/people/mouhol/METAL-AI/src/models/components/CSF_encoder.py
    output_size: 4
  d_model: 32 #${model.CSF_encoders.simple_csf_encoder}  # Directly use the encoder's output_size
  nhead: 2
  dim_forward: 16
  num_layers: 4
  output_size: 1
  dropout: 0.00
  output_activation:
    _target_: torch.nn.Sigmoid
    

loss_fn:
  _target_: src.loss_functions.loss_function_wrappers.LossFuncMaskWrapper # torch.nn.MSELoss
  loss_fn: 
    _target_: torch.nn.BCELoss
    reduction: sum

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.1

scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
  _partial_: true
  T_0: 5
  T_mult: 2
  eta_min: 0.

# scheduler:
#   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
#   _partial_: true
#   mode: min
#   factor: 0.1
#   patience: 3


# compile model for faster training with pytorch 2.0
compile: false



Overwriting /Users/moustholmes/Projects/METAL-AI/configs/model/transformer_encoder_model.yaml


In [24]:
lightningmodule = instantiate(cfg_train.model)
# pprint(cfg_train.model.model)
print(OmegaConf.to_yaml(cfg_train.model))
print()
print(lightningmodule)


scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.01
  patience: 30
_target_: src.models.metalAI_module.MetalAILitModule
target_name: converged
model:
  _target_: src.models.components.Transformer_encoder_model.simple_transformer_encoder_model
  csf_encoder:
    _target_: src.models.components.CSF_encoders.simple_CSF_encoder
    output_size: 4
  d_model: 32
  nhead: 2
  dim_forward: 16
  num_layers: 4
  output_size: 1
  dropout: 0.0
  output_activation:
    _target_: torch.nn.Sigmoid
loss_fn:
  _target_: src.models.components.loss_function_wrappers.LossFuncMaskWrapper
  loss_fn:
    _target_: torch.nn.BCELoss
    reduction: sum
optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.01
compile: false


MetalAILitModule(
  (model): simple_transformer_encoder_model(
    (csf_encoder): simple_CSF_encoder(
      (network): Sequential(
        (0): Linear(in_features=4, out_features=64, bias=True)
    

/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'loss_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss_fn'])`.


## Trainer

In [2]:
from lightning import Trainer
import pickle

trainer = Trainer(max_epochs=5)
datamodule = instantiate(cfg_train.data)
lightningmodule = instantiate(cfg_train.model)

trainer.fit(model= lightningmodule, datamodule=datamodule)
trainer.test(model= lightningmodule, datamodule=datamodule)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/Users/moustholmes/miniconda3/envs/m

3199 799 133385 762 1022


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


3199 799 133385 762 1022


/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{}, {}, {}]

In [3]:
test_dataloaders = datamodule.test_dataloader()
test_dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x11288b7c0>,
 <torch.utils.data.dataloader.DataLoader at 0x112897b80>,
 <torch.utils.data.dataloader.DataLoader at 0x32d190c40>]

## calbacks

### Metric logger

In [35]:
# %%writefile /Users/moustholmes/Projects/METAL-AI/src/callbacks/metric_loggers.py
import torchmetrics
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torchmetrics.wrappers import MetricTracker
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError


class CSFMetricsLogger(Callback):
    """Callback for logging classification metrics using torchmetrics.
    
    This callback logs metrics at the end of each training, validation, and test batch.
    The metrics are logged to the PyTorch Lightning module's logger.
    
    Args:
        metrics (torchmetrics.MetricCollection): A collection of metrics to log."""
    

    def __init__(self,):
        classification_metrics = torchmetrics.MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall(), BinaryF1Score()])
        self.train_metrics = classification_metrics.clone(prefix="train/")
        self.val_metrics = classification_metrics.clone(prefix="val/")
        self.test_metrics = classification_metrics.clone(prefix="test/")
        prev_batch_size = 0

    # def log_metrics(self, pl_module, metrics, preds, targets):
    #     """Log the given metrics to the PyTorch Lightning module's logger.

    #     Args:
    #         pl_module (LightningModule): The Lightning module being trained.
    #         metrics (torchmetrics.MetricCollection): The metrics to log.
    #         preds (torch.Tensor): The predicted outputs.
    #         targets (torch.Tensor): The ground truth targets.
    #     """

    #     pl_module.log_dict(
    #         metrics(preds, targets),
    #         on_step=False,
    #         on_epoch=True,
    #         prog_bar=True,
    #     )

    def on_train_start(self, trainer, pl_module):
        self.val_metrics.reset()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        print(outputs["batch_shape"],outputs["batch_shape"][1])
        print(outputs["preds"], outputs["targets"])
        self.train_metrics.update( outputs["preds"], outputs["targets"])
        if outputs["batch_shape"][1] != self.prev_batch_size:
            pl_module.log_dict(self.train_metrics.compute())
            self.train_metrics.reset()

        self.prev_batch_size = outputs["batch_shape"][1]
            
        # self.log_metrics(pl_module, self.train_metrics, outputs["preds"], outputs["targets"])

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        self.val_metrics.update( outputs["preds"], outputs["targets"])
        if outputs["batch_shape"][1] != self.prev_batch_size:
            pl_module.log_dict(self.val_metrics.compute())
            self.val_metrics.reset()
        self.prev_batch_size = outputs["batch_shape"][1]

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.test_metrics.update( outputs["preds"], outputs["targets"])
        if outputs["batch_shape"][1] != self.prev_batch_size:
            pl_module.log_dict(self.test_metrics.compute())
            self.test_metrics.reset()
        self.prev_batch_size = outputs["batch_shape"][1]

class RegressionMetricsLogger(Callback):
    """Callback for logging training, validation, and test metrics using torchmetrics.

    This callback logs metrics at the end of each training, validation, and test batch.
    The metrics are logged to the PyTorch Lightning module's logger.

    """

    def __init__(self):# metrics: torchmetrics.MetricCollection

        regression_metrics = torchmetrics.MetricCollection([MeanAbsoluteError(),])# MeanAbsolutePercentageError()
        self.train_metrics = regression_metrics.clone(prefix="train/")
        self.val_metrics = regression_metrics.clone(prefix="val/")
        self.test_metrics = regression_metrics.clone(prefix="test/")

    def log_metrics(self, pl_module, metrics, preds, targets):
        """Log the given metrics to the PyTorch Lightning module's logger.

        Args:
            pl_module (LightningModule): The Lightning module being trained.
            metrics (torchmetrics.MetricCollection): The metrics to log.
            preds (torch.Tensor): The predicted outputs.
            targets (torch.Tensor): The ground truth targets.
        """

        pl_module.log_dict(
            metrics(preds, targets),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

    def setup(self, trainer, pl_module, stage):
        self.train_metrics.to(pl_module.device)
        self.val_metrics.to(pl_module.device)
        self.test_metrics.to(pl_module.device)

    def on_train_start(self, trainer, pl_module):
        self.val_metrics.reset()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.log_metrics(pl_module, self.train_metrics, outputs["preds"], outputs["targets"])

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        self.log_metrics(pl_module, self.val_metrics, outputs["preds"], outputs["targets"])

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.log_metrics(pl_module, self.test_metrics, outputs["preds"], outputs["targets"])

class GaussianNLLMetricsLogger(Callback):
    """Callback for logging training, validation, and test metrics using torchmetrics.

    This callback logs metrics at the end of each training, validation, and test batch.
    The metrics are logged to the PyTorch Lightning module's logger.

    """

    def __init__(self):# metrics: torchmetrics.MetricCollection

        regression_metrics = torchmetrics.MetricCollection([MeanAbsoluteError(),])# MeanAbsolutePercentageError()
        # variance_metrics = torchmetrics.MetricCollection([MeanAbsoluteError(),])# MeanAbsolutePercentageError()
        self.train_metrics = regression_metrics.clone(prefix="train/")
        self.train_variance = torchmetrics.MeanMetric()
        self.val_metrics = regression_metrics.clone(prefix="val/")
        self.val_variance = torchmetrics.MeanMetric()
        self.test_metrics = regression_metrics.clone(prefix="test/")
        self.test_variance = torchmetrics.MeanMetric()

    def log_metrics(self, pl_module, metrics, preds, targets):
        """Log the given metrics to the PyTorch Lightning module's logger.

        Args:
            pl_module (LightningModule): The Lightning module being trained.
            metrics (torchmetrics.MetricCollection): The metrics to log.
            preds (torch.Tensor): The predicted outputs.
            targets (torch.Tensor): The ground truth targets.
        """

        pl_module.log_dict(
            metrics(preds, targets),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

    def setup(self, trainer, pl_module, stage):
        self.train_metrics.to(pl_module.device)
        self.train_variance.to(pl_module.device)
        self.val_metrics.to(pl_module.device)
        self.val_variance.to(pl_module.device)
        self.test_metrics.to(pl_module.device)
        self.test_variance.to(pl_module.device)

    def on_train_start(self, trainer, pl_module):
        self.val_metrics.reset()
        self.val_variance.reset()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        mask = outputs["mask"]
        mean = outputs["preds"][:,:,0][mask]
        variance = outputs["preds"][:,:,1][mask]
        targets = outputs["targets"][mask]


        pl_module.log("train/mean_variace", self.test_variance(variance), on_step=False, on_epoch=True, prog_bar=True)
        self.log_metrics(pl_module, self.train_metrics, mean, targets)

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):  
        mask = outputs["mask"]
        mean = outputs["preds"][:,:,0][mask]
        variance = outputs["preds"][:,:,1][mask]
        targets = outputs["targets"][mask]

        pl_module.log("val/mean_variace", self.val_variance(variance), on_step=False, on_epoch=True, prog_bar=True)
        self.log_metrics(pl_module, self.val_metrics, mean, targets)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        mask = outputs["mask"]
        mean = outputs["preds"][:,:,0][mask]
        variance = outputs["preds"][:,:,1][mask]
        targets = outputs["targets"][mask]

        pl_module.log("test/mean_variace", self.test_variance(variance), on_step=False, on_epoch=True, prog_bar=True)
        self.log_metrics(pl_module, self.test_metrics, mean, targets)

class ClassificationMetricsLogger(Callback):
    """Callback for logging training, validation, and test metrics using torchmetrics.

    This callback logs metrics at the end of each training, validation, and test batch.
    The metrics are logged to the PyTorch Lightning module's logger.

    """

    def __init__(self):# metrics: torchmetrics.MetricCollection

        classification_metrics = torchmetrics.MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall(), BinaryF1Score()])
        self.train_metrics = classification_metrics.clone(prefix="train/")
        self.val_metrics = classification_metrics.clone(prefix="val/")
        self.test_metrics = classification_metrics.clone(prefix="test/")

    def log_metrics(self, pl_module, metrics, preds, targets):
        """Log the given metrics to the PyTorch Lightning module's logger.

        Args:
            pl_module (LightningModule): The Lightning module being trained.
            metrics (torchmetrics.MetricCollection): The metrics to log.
            preds (torch.Tensor): The predicted outputs.
            targets (torch.Tensor): The ground truth targets.
        """

        pl_module.log_dict(
            metrics(preds, targets),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

    def setup(self, trainer, pl_module, stage):
        self.train_metrics.to(pl_module.device)
        self.val_metrics.to(pl_module.device)
        self.test_metrics.to(pl_module.device)

    def on_train_start(self, trainer, pl_module):
        self.val_metrics.reset()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.log_metrics(pl_module, self.train_metrics, outputs["preds"], outputs["targets"])

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        self.log_metrics(pl_module, self.val_metrics, outputs["preds"], outputs["targets"])

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self.log_metrics(pl_module, self.test_metrics, outputs["preds"], outputs["targets"])


class BestMetricsLogger(ClassificationMetricsLogger):
    """Callback for logging and tracking the best validation metrics using torchmetrics
    MetricTracker.

    This callback extends MetricLogger to track the best validation metrics over epochs.
    It logs the best metrics observed so far at the end of each validation epoch.

    Args:
        metrics (torchmetrics.MetricCollection): A collection of metrics to log and track.

    Attributes:
        val_metrics (torchmetrics.MetricTracker): Metrics for tracking the best validation performance.
    """

    def __init__(self, metrics: torchmetrics.MetricCollection):
        super().__init__(metrics)
        self.val_metrics = MetricTracker(
            metrics.clone(prefix="val/"),
            maximize=[metric.higher_is_better for _, metric in metrics.items()],
        )

    def on_validation_epoch_start(self, trainer, pl_module):
        self.val_metrics.increment()

    def on_validation_epoch_end(self, trainer, pl_module):
        pl_module.log_dict(
            {f"{k}_best": v for k, v in self.val_metrics.best_metric().items()},
            prog_bar=True,
        )

Overwriting /Users/moustholmes/Projects/METAL-AI/src/callbacks/metric_loggers.py


In [16]:
cfg_train = compose(config_name='train',)
datamodule = instantiate(cfg_train.data)
lightningmodule = instantiate(cfg_train.model)

trainer = Trainer(max_epochs=2, callbacks=[ClassificationMetricsLogger()], accelerator="gpu")
trainer.fit(model= lightningmodule, datamodule=datamodule, )

/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'loss_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss_fn'])`.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name       | Type                             | Params | Mode 
------------------------------------------------------------------------
0 | model      | simple_transformer_encoder_model | 1.6 K  | train
1 | loss_fn    | LossFuncMaskWrapper              | 0      | train
2 | train_loss | Mea

106376 18799 125175


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
/Users/moustholmes/miniconda3/envs/metal-ai/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [34]:
cfg_train = compose(config_name='train', overrides=['experiment=effect_gaussian_nll'])
datamodule = instantiate(cfg_train.data)
lightningmodule = instantiate(cfg_train.model)

trainer = Trainer(
    max_epochs=2,
    limit_train_batches=200,
    limit_val_batches=200,
    limit_test_batches=200,
    callbacks=[GaussianNLLMetricsLogger()], 
    accelerator="gpu"
    )
trainer.fit(model= lightningmodule, datamodule=datamodule, )

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name       | Type                             | Params | Mode 
------------------------------------------------------------------------
0 | model      | simple_transformer_encoder_model | 1.6 K  | train
1 | loss_fn    | GaussianNLLLossWrapper           | 0      | train
2 | train_loss | MeanMetric                       | 0      | train
3 | val_loss   | MeanMetric                       | 0      | train
4 | test_loss  | MeanMetric                       | 0      | train
------------------------------------------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)


43832 7695 51527


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [45]:
# %%writefile /Users/moustholmes/Projects/METAL-AI/src/callbacks/save_test_inference_to_dict.py

import torch
from lightning.pytorch.callbacks import Callback
import pickle
from typing import Tuple


class SaveTestInferenceToDict(Callback):
    """Callback for saving test inference to a dictionary.

    Args:
        save_dir (str): path to save the dictionary.

    Attributes:
        save_dir (str):  path to save the dictionary.
        pred_dict (Dict): Dict used to save predictions.
    """

    def __init__(self, save_dir: str, filename: str = 'results'):
        self.save_dir = save_dir
        self.filename = filename

    def on_test_start(self, trainer, pl_module):
        self.pred_dict = {}

    def on_test_batch_end(
            self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
        ):
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        """
        mask = outputs['mask']
        preds = outputs['preds'][mask]
        targets = outputs['targets'][mask]
        
        
        # print(mask.shape)
        # print(csf.shape)
        # print(csf.unsqueeze(1).shape)
        
        # print(csf)
        excitations = batch["excitations"][mask]
        # print(excitations)
        # print(excitations[0])
        # print( mask.sum().item() )
        
        n = batch["converged"].shape[-1]
        
        # print(csf.repeat(n,1,1,1).shape)
        # print(csf.repeat(n,1,1,1).view(-1, n, 2)[mask.view(-1)].shape)
        # print(csf.repeat(n,1,1,1).view(-1, n, 2)[mask.view(-1)])
        # print(csf.repeat(n,1,1,1)[mask.unsqueeze(0).repeat(n, 1, 1)])
        n_protons = batch["n_protons"].unsqueeze(1).repeat(1,n)[mask]
        # print(csf.unsqueeze(1).repeat(1,1,n)[mask.unsqueeze(-1)])
        n_electrons = batch["n_electrons"].unsqueeze(1).repeat(1,n)[mask]
        csf = batch["excitations"].repeat(n,1,1,1).view(-1, n, 2)[mask.view(-1)]

        for i in range(len(targets)):
            ion_key = (n_protons[i].item(), n_electrons[i].item())
            # print(ion_key)
            csf_key = tuple(map(tuple, csf[i].cpu().numpy()))
            # print(csf_key)

            if ion_key not in self.pred_dict:
                self.pred_dict[ion_key] = {}
            
            if csf_key not in self.pred_dict[ion_key]:
                self.pred_dict[ion_key][csf_key] = {}

            self.pred_dict[ion_key][csf_key]['preds'] = preds[i].cpu().numpy()
            self.pred_dict[ion_key][csf_key]['targets'] = targets[i].cpu().numpy()
            self.pred_dict[ion_key][csf_key]['excitation'] = excitations[i].cpu().numpy()

    def on_test_end(self, trainer, pl_module):
        with open(self.save_dir+f'/{self.filename}.pkl', 'wb') as f:
            pickle.dump(self.pred_dict, f)


class GaussianNLLSaveTestInferenceToDict(Callback):
    """Callback for saving test inference to a dictionary.

    Args:
        save_dir (str): path to save the dictionary.

    Attributes:
        save_dir (str):  path to save the dictionary.
        pred_dict (Dict): Dict used to save predictions.
    """

    def __init__(self, save_dir: str, filename: str = 'results'):
        self.save_dir = save_dir
        self.filename = filename

    def on_test_start(self, trainer, pl_module):
        self.pred_dict = {}

    def on_test_batch_end(
            self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
        ):
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        """
        mask = outputs['mask']
        mean = outputs['preds'][:,:,0][mask]
        variance = outputs['preds'][:,:,1][mask]
        # preds = outputs['preds'][mask]
        targets = outputs['targets'][mask]
        
        
        # print(mask.shape)
        # print(csf.shape)
        # print(csf.unsqueeze(1).shape)
        
        # print(csf)
        excitations = batch["excitations"][mask]
        # print(excitations)
        # print(excitations[0])
        # print( mask.sum().item() )
        
        n = batch["converged"].shape[-1]
        
        # print(csf.repeat(n,1,1,1).shape)
        # print(csf.repeat(n,1,1,1).view(-1, n, 2)[mask.view(-1)].shape)
        # print(csf.repeat(n,1,1,1).view(-1, n, 2)[mask.view(-1)])
        # print(csf.repeat(n,1,1,1)[mask.unsqueeze(0).repeat(n, 1, 1)])
        n_protons = batch["n_protons"].unsqueeze(1).repeat(1,n)[mask]
        # print(csf.unsqueeze(1).repeat(1,1,n)[mask.unsqueeze(-1)])
        n_electrons = batch["n_electrons"].unsqueeze(1).repeat(1,n)[mask]
        csf = batch["excitations"].repeat(n,1,1,1).view(-1, n, 2)[mask.view(-1)]

        for i in range(len(mean)):
            ion_key = (n_protons[i].item(), n_electrons[i].item())
            # print(ion_key)
            csf_key = tuple(map(tuple, csf[i].cpu().numpy()))
            # print(csf_key)

            if ion_key not in self.pred_dict:
                self.pred_dict[ion_key] = {}
            
            if csf_key not in self.pred_dict[ion_key]:
                self.pred_dict[ion_key][csf_key] = {}

            self.pred_dict[ion_key][csf_key]['mean'] = mean[i].cpu().numpy()
            self.pred_dict[ion_key][csf_key]['variance'] = variance[i].cpu().numpy()
            self.pred_dict[ion_key][csf_key]['targets'] = targets[i].cpu().numpy()
            self.pred_dict[ion_key][csf_key]['excitation'] = excitations[i].cpu().numpy()

    def on_test_end(self, trainer, pl_module):
        with open(self.save_dir+f'/{self.filename}.pkl', 'wb') as f:
            pickle.dump(self.pred_dict, f)


Overwriting /Users/moustholmes/Projects/METAL-AI/src/callbacks/save_test_inference_to_dict.py


In [5]:
from lightning import Trainer

datamodule = instantiate(cfg_train.data)
lightningmodule = instantiate(cfg_train.model)


trainer = Trainer(
    max_epochs=3,         
    limit_train_batches=1,
    limit_val_batches=1,
    limit_test_batches=200,
    callbacks=[GaussianNLLSaveTestInferenceToDict('/Users/moustholmes/Projects/METAL-AI', name='gaussian_nll_results_test')], 
    # accelerator="gpu"
    )
trainer.fit(model= lightningmodule, datamodule=datamodule, )
trainer.test(model= lightningmodule, datamodule=datamodule, )

InstantiationException: Error in call to target 'src.data.components.data_utils.TripleTrainValSplitter':
TypeError("__init__() got an unexpected keyword argument 'unique_ion'")
full_key: data.train_val_splitter

In [43]:
from lightning import Trainer

cfg_train = compose(config_name='train', overrides=['experiment=effect_gaussian_nll'])
datamodule = instantiate(cfg_train.data)
lightningmodule = instantiate(cfg_train.model)



trainer = Trainer(
    max_epochs=3,         
    limit_train_batches=1,
    limit_val_batches=1,
    # limit_test_batches=6,  
    callbacks=[GaussianNLLSaveTestInferenceToDict('/Users/moustholmes/Projects/METAL-AI', name='gaussian_nll_results_test')], 
    # accelerator="gpu"
    )
trainer.fit(model= lightningmodule, datamodule=datamodule, )
trainer.test(model= lightningmodule, datamodule=datamodule, )

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.

  | Name       | Type                             | Params | Mode 
------------------------------------------------------------------------
0 | model      | simple_transformer_encoder_model | 1.6 K  | train
1 | loss_fn    | GaussianNLLLossWrapper           | 0      | train
2 | train_loss | MeanMetric                       | 0      | train
3 | val_loss   | MeanMetric                       | 0      | train
4 | test_loss  | MeanMetric                       | 0      | train
------------------------------------------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.006     Total estimated model params size (MB)


43815 7712 51527


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


43815 7712 51527


Testing: |          | 0/? [00:00<?, ?it/s]

[{}]