In [2]:
import sys
import torch
import pickle
from glob import glob
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import healpy as hp

In [3]:
import os
import sys
import torch
import pickle
from glob import glob
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import healpy as hp

In [6]:
sys.path.append('/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/scripts')
sys.path.append('/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/')
from diffusion.diffusionclass import Diffusion
from diffusion.schedules import TimestepSampler, linear_beta_schedule
from diffusion.ResUnet_timeembed import Unet
from maploader.maploader import get_data, get_minmaxnormalized_data, get_loaders, MapDataset
from utils.filter_boost import filter_boost, calculate_k, batch_filter_boost
from utils.run_utils import initialize_config, setup_trainer

In [7]:
config_file = "/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/config/config_diffusion.yaml"
config_dict = initialize_config(config_file)

pl.seed_everything(1234)

### get training data
lrmaps_dir = "/gpfs02/work/akira.tokiwa/gpgpu/FastPM/healpix/nc128/"
hrmaps_dir = "/gpfs02/work/akira.tokiwa/gpgpu/FastPM/healpix/nc256/"
n_maps = len(glob(lrmaps_dir + "*.fits"))
nside = 512
order = 4

CONDITIONAL = True
BATCH_SIZE = 24
TRAIN_SPLIT = 0.8

config_dict['train']['batch_size'] = BATCH_SIZE
config_dict["data"]["conditional"] = CONDITIONAL

[rank: 0] Global seed set to 1234


In [8]:
lr = get_data(lrmaps_dir, n_maps, nside, order, issplit=True)
hr = get_data(hrmaps_dir, n_maps, nside, order, issplit=True)

lr, transforms_lr, inverse_transforms_lr, range_min_lr, range_max_lr = get_minmaxnormalized_data(lr)
print("LR data loaded. min: {}, max: {}".format(range_min_lr, range_max_lr))

hr, transforms_lr, inverse_transforms_hr, range_min_hr, range_max_hr = get_minmaxnormalized_data(hr)
print("HR data loaded. min: {}, max: {}".format(range_min_hr, range_max_hr))

data_input, data_condition = hr-lr, lr
train_loader, val_loader = get_loaders(data_input, data_condition, TRAIN_SPLIT, BATCH_SIZE)

LR data loaded. min: 0.0, max: 2.0128371715545654
HR data loaded. min: 0.0, max: 2.903632402420044


In [6]:
#get sampler type
sampler = TimestepSampler(timesteps=int(config_dict['diffusion']['timesteps']), **config_dict['diffusion']['sampler_args'])

Sampler type uniform


In [7]:
sys.path.append('/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/run')
from run_diffusion import Unet_pl

In [8]:
#get model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Unet_pl(Unet, config_dict, sampler = sampler).to(device)

In [9]:
ckpt_path = "/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/ckpt_logs/diffusion/HR_LR_normalized/version_1/checkpoints/Run_10-29_19-42epoch=12-val_loss=0.03.ckpt"

In [10]:
state_dict = torch.load(ckpt_path)['state_dict']
model.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=['model.init_conv.laplacian', 'model.init_conv_lr.laplacian', 'model.down_blocks.0.0.block1.conv.laplacian', 'model.down_blocks.0.0.block2.conv.laplacian', 'model.down_blocks.0.1.block1.conv.laplacian', 'model.down_blocks.0.1.block2.conv.laplacian', 'model.down_blocks.1.0.block1.conv.laplacian', 'model.down_blocks.1.0.block2.conv.laplacian', 'model.down_blocks.1.0.res_conv.laplacian', 'model.down_blocks.1.1.block1.conv.laplacian', 'model.down_blocks.1.1.block2.conv.laplacian', 'model.down_blocks.2.0.block1.conv.laplacian', 'model.down_blocks.2.0.block2.conv.laplacian', 'model.down_blocks.2.0.res_conv.laplacian', 'model.down_blocks.2.1.block1.conv.laplacian', 'model.down_blocks.2.1.block2.conv.laplacian', 'model.down_blocks.3.0.block1.conv.laplacian', 'model.down_blocks.3.0.block2.conv.laplacian', 'model.down_blocks.3.0.res_conv.laplacian', 'model.down_blocks.3.1.block1.conv.laplacian', 'model.down_blocks.3.1.block2.conv.laplacian', 'model.mid_block1.block

In [11]:
beta_func = linear_beta_schedule
beta_args = config_dict['diffusion']['schedule_args']
betas = beta_func(timesteps=int(config_dict['diffusion']['timesteps']), **beta_args)
tmp_diffusion = Diffusion(betas)

In [12]:
map_dir = "/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/results/imgs/diffusion/HR_LR_normalized/"
if not os.path.exists(map_dir):
    os.makedirs(map_dir)

In [13]:
BATCH_SIZE = 8
PATCH_SIZE = 128

In [14]:
timesteps = int(config_dict['diffusion']['timesteps'])

In [15]:
i = 0
model.eval()
tmp_sample =data_input[BATCH_SIZE*i:BATCH_SIZE*(i+1)].to(device)
tmp_lr = data_condition[BATCH_SIZE*i:BATCH_SIZE*(i+1)].to(device)
q_sample = tmp_diffusion.q_sample(tmp_sample, torch.full((BATCH_SIZE,), timesteps-1, device=device))
img = torch.randn(tmp_sample.shape, device=device)
batch_img = []
for j in reversed(range(0, timesteps)):
    t = torch.full((BATCH_SIZE,), j, device=device, dtype=torch.long)
    #loss = model.diffusion.p_losses(model.model, tmp_sample, t, tmp_lr)
    img = tmp_diffusion.p_sample(model.model, img, t, tmp_lr, j)
    if (j % 10 == 0):
        batch_img.append(np.hstack(img.detach().cpu().numpy()[:BATCH_SIZE, : , 0]))
    #print('Step {}, Loss {}'.format(j, loss), flush=True)

In [17]:
batch_img[0].shape

(131072,)

In [18]:
print("Start Diffusion")
model.eval()
all_img = []
for i in range(int(PATCH_SIZE/BATCH_SIZE)):
    tmp_sample =data_input[BATCH_SIZE*i:BATCH_SIZE*(i+1)].to(device)
    tmp_lr = data_condition[BATCH_SIZE*i:BATCH_SIZE*(i+1)].to(device)
    q_sample = tmp_diffusion.q_sample(tmp_sample, torch.full((BATCH_SIZE,), timesteps-1, device=device))
    img = torch.randn(tmp_sample.shape, device=device)
    batch_img = []
    for j in reversed(range(0, timesteps)):
        t = torch.full((BATCH_SIZE,), j, device=device, dtype=torch.long)
        loss = model.diffusion.p_losses(model.model, tmp_sample, t, tmp_lr)
        img = tmp_diffusion.p_sample(model.model, img, t, tmp_lr, j)
        print('Step {}, Loss {}'.format(j, loss), flush=True)
        if (j % 10 == 0):
            batch_img.append(np.hstack(img.detach().cpu().numpy()[:BATCH_SIZE, : , 0]))
    all_img.append(np.vstack(batch_img))

Start Diffusion
Step 999, Loss 0.06477835774421692
Step 998, Loss 0.06598452478647232
Step 997, Loss 0.0682447999715805
Step 996, Loss 0.06764979660511017
Step 995, Loss 0.06644610315561295
Step 994, Loss 0.0660264641046524
Step 993, Loss 0.06633004546165466
Step 992, Loss 0.06774276494979858
Step 991, Loss 0.06653936207294464
Step 990, Loss 0.06412060558795929
Step 989, Loss 0.06319484114646912
Step 988, Loss 0.06302626430988312
Step 987, Loss 0.06381027400493622
Step 986, Loss 0.06598766148090363
Step 985, Loss 0.06837636232376099
Step 984, Loss 0.06833910942077637
Step 983, Loss 0.06839028000831604
Step 982, Loss 0.0655847042798996
Step 981, Loss 0.06436821818351746
Step 980, Loss 0.06390494108200073
Step 979, Loss 0.06535329669713974
Step 978, Loss 0.06618858873844147
Step 977, Loss 0.0670400857925415
Step 976, Loss 0.06536988914012909
Step 975, Loss 0.06421730667352676
Step 974, Loss 0.06456606835126877
Step 973, Loss 0.06423711031675339
Step 972, Loss 0.0649961531162262
Step 971,

In [18]:
def read_maps(map_dir, diffsteps=100, batch_size=16):
    maps = sorted(glob(map_dir + "/*.fits"), key=lambda x: (int(x.split("/")[-1].split("_")[2]), int(x.split("/")[-1].split(".")[0].split("_")[-1])))
    map_diffused = []
    for i in range(diffsteps):
        map_steps = []
        for j in range(batch_size):
            map_steps.append(hp.read_map(maps[i*batch_size+j]))
        map_steps = np.array(map_steps)
        map_steps = np.hstack(map_steps)
        map_diffused.append(map_steps)
    map_diffused = np.array(map_diffused)
    return map_diffused

In [19]:
map_diffused = read_maps(map_dir, diffsteps=100, batch_size=16)

In [20]:
map_diffused.shape

(100, 3145728)

In [9]:
map_dir ="/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR"
png_files = sorted(glob(map_dir + "/png/*.png"), key=lambda x: int(x.split("/")[-1].split(".")[0].split("_")[-1]))[::-1]
ps_files = sorted(glob(map_dir + "/ps/*.png"), key=lambda x: int(x.split("/")[-1].split(".")[0].split("_")[-2]))[::-1]

In [10]:
png_files

['/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_990.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_980.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_970.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_960.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_950.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_940.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_930.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_920.png',
 '/gpfs02/work/akira.tokiwa/gpgpu/Github/SR-SPHERE/LEGACY/srsphere/diffusion/diffused_map_HR/png/step_91