In [None]:
import glob
import copy

In [None]:
import sys
# Make sibling project directories importable (presumably the repo root for
# `helpers` and a local checkout of imagen_pytorch) — relative to the notebook's CWD.
sys.path.append("../")
sys.path.append("../imagen/")

# NOTE(review): this star import is the only visible source of several names used
# later in the notebook (BASE_HOME, torch, region_to_abbv, get_satellite_data,
# calculate_metrics, and apparently pickle). Confirm against helpers.py and prefer
# explicit imports so a fresh reader can see where each name comes from.
from helpers import *
from imagen_pytorch import Unet, Imagen, ImagenTrainer, NullUnet

In [None]:
# Run being evaluated; its checkpoints live under <run dir>/models/64_FC/.
RUN_NAME = "64_FC_nio_rot904_3e-4"
RUN_DIR = f"{BASE_HOME}/models/{RUN_NAME}"
BASE_DIR = f"{RUN_DIR}/models/64_FC/"

In [None]:
# Seed every RNG the notebook may touch. Batch selection happens in the dataloaders
# built later (get_satellite_data), which may draw from `random`/`numpy` rather than torch,
# so seeding torch alone does not make the evaluated batches reproducible.
import random

import numpy as np
import torch

seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

In [None]:
import re


def _natural_sort_key(path):
    """Sort key that compares digit runs numerically (…_2000 before …_10000).

    re.split with a capturing group puts text chunks at even indices and digit
    chunks at odd indices, so int/str elements always line up when comparing.
    """
    return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r"(\d+)", path))]


# Plain sorted() is lexicographic, which misorders un-padded step numbers and would
# make `ckpt_trainer_files[:N]` pick the wrong checkpoints.
ckpt_files = sorted(glob.glob(BASE_DIR + "ckpt_1_*"), key=_natural_sort_key)
ckpt_trainer_files = sorted(glob.glob(BASE_DIR + "ckpt_trainer_1_*"), key=_natural_sort_key)

# Fail loudly instead of silently evaluating (and pickling) nothing.
assert ckpt_trainer_files, f"no trainer checkpoints found under {BASE_DIR}"

In [None]:
# Hyper-parameters of the single U-Net in the cascade (kept in one place for reuse/logging).
unet1_kwargs = dict(
    dim=32,
    dim_mults=(1, 2, 4, 8),
    num_resnet_blocks=3,
    layer_attns=(False, True, True, True),  # attention at every level but the first
    cond_dim=1024,
)
unet1 = Unet(**unet1_kwargs)

unets = [unet1]

In [None]:
class DDPMArgs:
    """Empty attribute bag for run configuration.

    Takes no constructor arguments; fields are attached after construction
    (``args.batch_size = 16``), much like an ``argparse.Namespace``.
    """
    
args = DDPMArgs()

# Batching and resolutions.
args.batch_size = 16
args.image_size = 64
args.o_size = 64
args.n_size = 128

# Width of the flattened conditioning vector fed to Imagen. Presumably this must
# equal the per-sample size of `era5` once flattened in the evaluation loop
# (4 x 64 x 64) — confirm.
args.continuous_embed_dim = 4 * 64 * 64

# Data source and optimisation settings.
args.dataset_path = f"/rds/general/user/zr523/home/researchProject/satellite/dataloader/{args.o_size}_FC"
args.datalimit = False
args.lr = 3e-4
args.mode = "fc"
args.region = region_to_abbv["North Indian Ocean"]

train_dataloader, test_dataloader = get_satellite_data(args)

# len() is called only for its side effect. NOTE(review): presumably it forces the
# loaders to build their index — confirm in helpers.get_satellite_data.
for _loader in (train_dataloader, test_dataloader):
    _ = len(_loader)

In [None]:
# Out-of-region test set: every region except the training one.
# Configure a shallow copy rather than deleting/restoring fields on `args` itself:
# if get_satellite_data raised mid-way, `args` would be left without `region`
# (and a re-run of this cell would then fail on `del`). The copy also keeps the
# out-of-region loader from sharing a mutable config object with the in-region loaders.
oreg_args = copy.copy(args)
del oreg_args.region
oreg_args.exclude_region = region_to_abbv["North Indian Ocean"]

_, oreg_test_dataloader = get_satellite_data(oreg_args)
_ = len(oreg_test_dataloader)

In [None]:
# Number of diffusion steps: runs tagged "1k" in their name use 1000, all others 250.
timesteps = 1000 if '1k' in RUN_NAME else 250

imagen = Imagen(
    unets = unets,
    image_sizes = (64,),        # a 1-tuple; `(64)` is just the int 64
    timesteps = timesteps,      # previously hard-coded to 250, ignoring the line above
    cond_drop_prob = 0.1,
    condition_on_continuous = True,
    continuous_embed_dim = args.continuous_embed_dim,
)

In [None]:
# Score the first N_EVAL_CKPTS trainer checkpoints on one fixed batch from each split.
N_EVAL_CKPTS = 2
random_idx = [5]  # index into each loader's `random_idx` list selecting the batch to score

# Empty metric template; every split gets its own deep copy below.
metric_dict = {
    "kl_div": [],
    "rmse": [],
    "mae":  [],
    "psnr": [],
    "ssim": [],
    "fid": []
}

split_loaders = {
    "train": train_dataloader,
    "test": test_dataloader,
    "oreg_test": oreg_test_dataloader,
}
train_test_metric_dict = {split: copy.deepcopy(metric_dict) for split in split_loaders}

for ckpt_trainer_path in ckpt_trainer_files[:N_EVAL_CKPTS]:
    # The same checkpoint is evaluated on every split, so build the trainer and
    # load the checkpoint once here rather than once per split.
    trainer = ImagenTrainer(imagen, lr=args.lr, verbose=False).cuda()
    trainer.load(ckpt_trainer_path)

    for mode, dataloader in split_loaders.items():
        batch_idx = dataloader.random_idx[random_idx[0]]
        img_64, _, era5 = dataloader.get_batch(batch_idx)
        # Flatten each sample's conditioning into one vector per batch element.
        cond_embeds = era5.reshape(era5.shape[0], -1).float().cuda()
        # NOTE(review): this samples through `imagen` directly, i.e. with whatever
        # U-Net weights `imagen` holds. If EMA weights were intended (as the variable
        # name suggests), `trainer.sample(...)` may be the right call — confirm.
        ema_sampled_images = imagen.sample(
            batch_size = img_64.shape[0],
            cond_scale = 3.,
            continuous_embeds = cond_embeds,
        )

        y_true = img_64.cpu()
        y_pred = ema_sampled_images.cpu()
        # Separate name so the `metric_dict` template above is not overwritten.
        batch_metrics = calculate_metrics(y_pred, y_true)
        for key, value in batch_metrics.items():
            train_test_metric_dict[mode][key].append(value)

In [None]:
# Imported explicitly: `pickle` was otherwise only reachable via `from helpers import *`.
import pickle

# Persist the per-split, per-checkpoint metric lists in the run directory.
with open(f"{BASE_HOME}/models/{RUN_NAME}/metrics_v2.pkl", "wb") as file:
    pickle.dump(train_test_metric_dict, file)