# Set path

In [100]:
import os
try:
    from google.colab import drive
    drive.mount('/content/drive/', force_remount=False)

    #unpack zipped file (reading files from drive is slow)
    os.chdir('/content')
    import shutil
    print("Unzipping data...")
    shutil.unpack_archive("/content/drive/My Drive/ai-side-projects/self-supervised-halos/data/freya_postprocess.zip", "./")
    print("Unzipping done")
    rootpath = '/content/freya_postprocess/'

except:
    %matplotlib inline
    rootpath = '/Users/sdbykov/work/self-supervised-halos/data/freya_postprocess/'


# Imports

In [101]:
import torch

In [102]:
import torchvision

In [103]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from glob import glob
from tqdm import tqdm
import time

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from torch.utils.data import TensorDataset, DataLoader



# Load the subhalo catalogue: try the Colab Drive location first, then the local checkout.
# Only a missing/unreadable file triggers the fallback; a corrupt pickle or any other
# error is raised instead of being masked by the second path.
# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
try:
    subhalos_df = pd.read_pickle('/content/drive/My Drive/ai-side-projects/self-supervised-halos/data/subhalos_df.pkl')
except OSError:
    subhalos_df = pd.read_pickle('/Users/sdbykov/work/self-supervised-halos/data/subhalos_df.pkl')

# log10 of SubhaloMass * 1e10 / 0.6774
# (presumably converts 1e10 Msun/h to Msun with h = 0.6774 — TODO confirm units).
subhalos_df['logSubhaloMass'] = np.log10(subhalos_df['SubhaloMass']*1e10/0.6774)


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# On GPU also report which card (index 0) is in use.
if device == 'cuda':
    gpu_properties = torch.cuda.get_device_properties(0)
    print(gpu_properties.name)



cpu


# Data loaders and transformers

In [104]:
# 11 equally spaced edges in log10 mass -> len(mass_bins) - 1 = 10 classes.
mass_bins = np.linspace(11, 14.7, num=11)

# Number of halos falling into each class.
mass_bins_nums = np.histogram(subhalos_df['logSubhaloMass'], bins=mass_bins)[0]
# Disabled alternative: log-scale the counts so the spread between bins is less pronounced.
#mass_bins_nums = np.log10(mass_bins_nums+1)

# Inverse-frequency weights (largest count / count), normalised to sum to one.
mass_bins_weights = mass_bins_nums.max() / mass_bins_nums
mass_bins_weights = mass_bins_weights / mass_bins_weights.sum()

print(f"bin weights: {mass_bins_weights}, bin counts: {mass_bins_nums}, bins: {mass_bins}")
# Author's note: log weighting is probably the better choice. With the linear weighting
# used here the low-mass halos barely contribute to the loss, because their weight is tiny
# relative to the rest (high-mass bins are 0.1-0.5, the smallest weight is 4e-4).
# The log variant above is currently switched off.
mass_bins_weights = torch.tensor(mass_bins_weights, dtype=torch.float32)


class HaloDataset(torch.utils.data.Dataset):
    """Dataset of per-halo ``.npz`` products that is fully preloaded into memory.

    Files are discovered with ``glob`` under ``root_dir + '2d/'``, ``'3d/'`` and
    ``'mass/'``. Because the patterns are built by plain string concatenation,
    ``root_dir`` must end with a path separator.

    Each item is ``((data_2d, data_3d, data_mass), (label_mass, label_class, halo_id))``.
    A modality that was not requested is replaced by a placeholder:
    ``np.zeros(1)`` for 2D/3D and ``(np.zeros(1), np.zeros(1))`` for mass.

    Args:
        root_dir: Directory prefix containing the ``2d``, ``3d`` and ``mass`` folders.
        subhalos_df: DataFrame indexed by halo id with a ``'logSubhaloMass'`` column.
        load_2d: Preload the three 2D projections (``map_2d_xy/xz/yz``).
        load_3d: Preload the 3D map (``map_3d``).
        load_mass: Preload ``mass_hist`` and ``snap``.
        DEBUG_LIMIT_FILES: If truthy, keep only the first N files of every sorted
            file list.

    Note:
        The list of halo ids (and therefore ``len(dataset)``) always comes from the
        2D file list, even when ``load_2d`` is False.
    """

    # Bin edges used to convert the continuous mass label into a class index.
    mass_bins = np.linspace(11, 14.7, 11)  # Define mass_bins globally or pass as argument

    def __init__(self, root_dir, subhalos_df,
                 load_2d=True, load_3d=False, load_mass=False,
                 DEBUG_LIMIT_FILES=None):
        self.root_dir = root_dir
        self.subhalos_df = subhalos_df
        self.files_3d = sorted(glob(root_dir +  '3d/*.npz'))
        self.files_2d = sorted(glob(root_dir + '2d/*.npz'))
        self.files_mass = sorted(glob(root_dir + 'mass/*.npz'))

        if DEBUG_LIMIT_FILES:
            self.files_3d = self.files_3d[:DEBUG_LIMIT_FILES]
            self.files_2d = self.files_2d[:DEBUG_LIMIT_FILES]
            self.files_mass = self.files_mass[:DEBUG_LIMIT_FILES]

        self.load_2d = load_2d
        self.load_3d = load_3d
        self.load_mass = load_mass

        self.halos_ids = [self._halo_id_from_path(file) for file in self.files_2d]
        self.loaded_data = self.preload_data()

    @staticmethod
    def _halo_id_from_path(file):
        """Return the halo id: the second-to-last ``'_'``-separated token of the path, as int."""
        return int(file.split('_')[-2].split('.')[0])

    def _preload_modality(self, files, keys, desc):
        """Read ``keys`` from every ``.npz`` in ``files`` into ``{halo_id: {key: array}}``."""
        loaded = {}
        for file in tqdm(files, desc=desc):
            # np.load on an .npz returns an NpzFile that keeps the archive open;
            # the context manager closes it as soon as the arrays have been read.
            with np.load(file) as data:
                loaded[self._halo_id_from_path(file)] = {key: data[key] for key in keys}
        return loaded

    def preload_data(self):
        """Load every requested modality into memory once.

        Returns:
            dict with a ``'2d'``, ``'3d'`` and/or ``'mass'`` entry (only for the
            modalities enabled in the constructor), each mapping halo id to a
            dict of arrays.
        """
        #lesson learned: loading all data at once is faster than loading it on the fly. Before that all files were loaded for each index separately and with the inference time of 0.1 sec the data loading was 30 sec
        data_dict = {}

        if self.load_2d:
            data_dict['2d'] = self._preload_modality(
                self.files_2d, ('map_2d_xy', 'map_2d_xz', 'map_2d_yz'), 'Preparing 2D data')

        if self.load_3d:
            data_dict['3d'] = self._preload_modality(
                self.files_3d, ('map_3d',), 'Preparing 3D data')

        if self.load_mass:
            data_dict['mass'] = self._preload_modality(
                self.files_mass, ('mass_hist', 'snap'), 'Preparing mass data')

        return data_dict

    def __len__(self):
        """Number of halos, i.e. number of 2D files found."""
        return len(self.halos_ids)

    def select_random_projection(self):
        """Pick one of the projection names 'xy', 'xz', 'yz' via ``np.random.choice``."""
        return np.random.choice(['xy', 'xz', 'yz'])

    def __getitem_2d__(self, idx):
        """Return one randomly chosen 2D projection with a new leading axis, or ``np.zeros(1)``."""
        if not self.load_2d:
            return np.zeros(1)

        halo_id = self.halos_ids[idx]
        data_2d = self.loaded_data['2d'][halo_id]

        # Select a random projection; the stored keys are 'map_2d_<projection>'.
        selected_projection = self.select_random_projection()
        return np.expand_dims(data_2d[f'map_2d_{selected_projection}'], axis=0)

    def __getitem_3d__(self, idx):
        """Return the 3D map with a new leading axis, or ``np.zeros(1)``."""
        if not self.load_3d:
            return np.zeros(1)

        halo_id = self.halos_ids[idx]
        data_3d = self.loaded_data['3d'][halo_id]
        return np.expand_dims(data_3d['map_3d'], axis=0)

    def __getitem_mass__(self, idx):
        """Return ``(snap, mass_hist)`` combined into one array with a new leading axis.

        When mass data was not loaded the placeholder is a *tuple* of two
        ``np.zeros(1)`` arrays rather than a single array.
        """
        if not self.load_mass:
            return (np.zeros(1), np.zeros(1))

        halo_id = self.halos_ids[idx]
        data_mass = self.loaded_data['mass'][halo_id]
        # expand_dims converts the pair to one array first, which assumes snap and
        # mass_hist have matching shapes — TODO confirm for all files.
        selected_data = (data_mass['snap'], data_mass['mass_hist'])
        return np.expand_dims(selected_data, axis=0)

    def __getitem_label__(self, idx):
        """Return ``(label_mass, label_class, halo_id)`` for the halo at ``idx``."""
        halo_id = self.halos_ids[idx]
        label_mass = self.subhalos_df.loc[halo_id]['logSubhaloMass']
        # digitize is 1-based for values inside the edges, hence the -1. Values below
        # the first edge give -1 and values at/above the last edge give len(mass_bins) - 1.
        label_class = np.digitize(label_mass, self.mass_bins) - 1
        return (label_mass, label_class, halo_id)

    def __getitem__(self, idx):
        """Return ``((data_2d, data_3d, data_mass), label)`` for index ``idx``."""
        data_2d = self.__getitem_2d__(idx)
        data_3d = self.__getitem_3d__(idx)
        data_mass = self.__getitem_mass__(idx)
        label = self.__getitem_label__(idx)

        return (data_2d, data_3d, data_mass), label


##full dataset
#dataset = HaloDataset(root_dir = rootpath, subhalos_df=subhalos_df, load_2d=True, load_3d=True, load_mass=True)
#dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

bin weights: [4.08623003e-04 8.36488906e-04 1.83216523e-03 3.89963599e-03
 8.45849615e-03 1.80333421e-02 4.49692201e-02 9.10914971e-02
 3.22960762e-01 5.07509769e-01], bin counts: [8694 4247 1939  911  420  197   79   39   11    7], bins: [11.   11.37 11.74 12.11 12.48 12.85 13.22 13.59 13.96 14.33 14.7 ]


In [105]:
# Preload only the 3D maps and the mass data (2D projections are skipped),
# then wrap the dataset in a shuffling loader with batches of 128 halos.
dataset = HaloDataset(
    root_dir=rootpath,
    subhalos_df=subhalos_df,
    load_2d=False,
    load_3d=True,
    load_mass=True,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=128)

Preparing 3D data: 100%|██████████| 16544/16544 [01:20<00:00, 204.37it/s]
Preparing mass data: 100%|██████████| 16544/16544 [00:27<00:00, 605.38it/s]


In [106]:
# Walk through the loader once without doing any work on the batches,
# so the progress bar shows the pure data-loading cost of one epoch.
for data, label in tqdm(dataloader):
    continue

100%|██████████| 130/130 [00:27<00:00,  4.74it/s]
