In [1]:
import sys
import torch
import pickle
from glob import glob
import pytorch_lightning as pl

In [2]:
# Make the project's `scripts/` directory importable, then pull in the
# project-local modules. The sys.path entry must be added BEFORE the imports
# below, which resolve through it. The membership guard keeps this cell
# idempotent: re-running it no longer appends a duplicate entry each time.
_SR_SPHERE_SCRIPTS = '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/scripts'
if _SR_SPHERE_SCRIPTS not in sys.path:
    sys.path.append(_SR_SPHERE_SCRIPTS)
from diffusion.diffusionclass import Diffusion
from diffusion.schedules import TimestepSampler, linear_beta_schedule
from diffusion.ResUnet_timeembed import Unet
from maploader.maploader import get_data, get_minmaxnormalized_data, get_loaders, MapDataset
from utils.filter_boost import filter_boost, calculate_k, batch_filter_boost
from utils.run_utils import initialize_config, setup_trainer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Experiment configuration: load the base diffusion config, fix the global RNG
# seed, locate the low-/high-resolution map directories, and override the
# batch size / conditioning flags in the loaded config.
config_file = "/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/config/config_diffusion.yaml"
config_dict = initialize_config(config_file)

SEED = 1234
pl.seed_everything(SEED)

### get training data
lrmaps_dir = "/gpfs02/work/akira.tokiwa/gpgpu/FastPM/healpix/nc128/"
hrmaps_dir = "/gpfs02/work/akira.tokiwa/gpgpu/FastPM/healpix/nc256/"
lr_files = glob(lrmaps_dir + "*.fits")
hr_files = glob(hrmaps_dir + "*.fits")
n_maps = len(lr_files)
# glob silently returns [] for a wrong/unmounted path; fail here with a clear
# message instead of deep inside the data loaders. The HR directory must hold
# at least as many maps as the LR one, because n_maps (counted from LR) is
# also used to load the HR maps.
assert n_maps > 0, "no .fits maps found in {}".format(lrmaps_dir)
assert len(hr_files) >= n_maps, "HR dir has {} maps, fewer than the {} LR maps".format(len(hr_files), n_maps)
nside = 512
order = 4

CONDITIONAL = True
BATCH_SIZE = 24
TRAIN_SPLIT = 0.8

config_dict['train']['batch_size'] = BATCH_SIZE
config_dict["data"]["conditional"] = CONDITIONAL

[rank: 0] Global seed set to 1234


In [4]:
# Build one MapDataset per resolution (LR first, then HR), using identical
# map count / nside / order settings for both.
# NOTE(review): dataset_lr / dataset_hr are not referenced in the cells visible
# here, and this cell shares execution count [4] with the following cell
# (suggesting out-of-order execution) — confirm it is still needed.
dataset_lr, dataset_hr = (
    MapDataset(maps_dir, n_maps, nside, order)
    for maps_dir in (lrmaps_dir, hrmaps_dir)
)

In [4]:
# Load all LR and HR maps, min-max normalize each resolution separately, and
# build train/val loaders whose input is the HR - LR residual, conditioned on
# the normalized LR map. The split fraction is TRAIN_SPLIT (passed to
# get_loaders); issplit=False presumably disables any split inside get_data.
lr = get_data(lrmaps_dir, n_maps, nside, order, issplit=False)
hr = get_data(hrmaps_dir, n_maps, nside, order, issplit=False)

lr, inverse_transforms_lr, range_min_lr, range_max_lr = get_minmaxnormalized_data(lr)
print("LR data loaded. min: {}, max: {}".format(range_min_lr, range_max_lr))

hr, inverse_transforms_hr, range_min_hr, range_max_hr = get_minmaxnormalized_data(hr)
print("HR data loaded. min: {}, max: {}".format(range_min_hr, range_max_hr))

# The residual is elementwise; require identical shapes so an accidental
# broadcast cannot silently produce a wrong-sized training target.
# (assumes both are array-like with .shape, e.g. numpy/torch — TODO confirm)
assert hr.shape == lr.shape, "LR/HR shape mismatch: {} vs {}".format(lr.shape, hr.shape)

# NOTE(review): lr and hr are normalized independently (their printed ranges
# differ), so hr - lr is a difference of separately scaled fields — confirm this
# is the intended target, and that the matching inverse transforms are applied
# when reconstructing HR maps from predictions.
data_input, data_condition = hr-lr, lr
train_loader, val_loader = get_loaders(data_input, data_condition, TRAIN_SPLIT, BATCH_SIZE)

LR data loaded. min: 0.0, max: 2.0128371715545654
HR data loaded. min: 0.0, max: 2.903632402420044
