In [None]:
import glob
import copy

In [None]:
import sys
sys.path.append("../")
sys.path.append("../imagen/")

from helpers import *
from imagen_pytorch import Unet, Imagen, ImagenTrainer, NullUnet

In [None]:
# Name of the training run to evaluate; it is also used as the lookup key
# into `best_epoch_dict` further down.
RUN_NAME = "64_FC_rot904_3e-4"

# Directory the trainer checkpoint is loaded from for this run.
BASE_DIR = (
    f"{BASE_HOME}/models/{RUN_NAME}"
    "/models/64_FC"
)

In [None]:
# Seed the torch RNGs (CPU, plus every CUDA device when one is present).
seed_value = 42
torch.manual_seed(seed_value)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

In [None]:
# Epoch whose checkpoint is loaded for evaluation, keyed by run name.
best_epoch_dict = dict(
    [
        ("64_FC_rot904_sep_3e-4", 180),
        ("64_FC_rot904_3e-4", 240),
        ("64_FC_3e-4", 235),
    ]
)

In [None]:
# Constructor arguments for the single U-Net used by this notebook.
unet1_kwargs = {
    "dim": 32,
    "cond_dim": 1024,
    "dim_mults": (1, 2, 4, 8),
    "num_resnet_blocks": 3,
    "layer_attns": (False, True, True, True),
}
unet1 = Unet(**unet1_kwargs)

# `Imagen` is given a list of U-Nets; here it holds just the one.
unets = [unet1]

In [None]:
class DDPMArgs:
    """Bare attribute container; settings are attached to an instance after construction."""


# Shared settings object populated by the assignments that follow.
args = DDPMArgs()
# Settings attached to `args`; the object is handed to `get_satellite_data`
# below and `args.lr` is read again when the trainer is created.
args.batch_size = 16
args.image_size = 64
args.o_size = 64
args.n_size = 128
args.continuous_embed_dim = 64*64*4
# NOTE(review): absolute, user-specific path — consider deriving it from a
# configurable base directory so the notebook runs on other machines.
args.dataset_path = f"/rds/general/user/zr523/home/researchProject/satellite/dataloader/{args.o_size}_FC"
args.datalimit = False
args.lr = 3e-4
args.mode = "fc"

train_dataloader, test_dataloader = get_satellite_data(args)
# The lengths are discarded; the calls are kept from the original in case the
# loaders' `__len__` has side effects — TODO confirm and remove if not.
_ = len(train_dataloader) ; _ = len(test_dataloader)

# Runs whose name contains '1k' use 1000 diffusion timesteps, all others 250.
timesteps = 1000 if '1k' in RUN_NAME else 250

imagen = Imagen(
    unets = unets,
    # Same value as the original `(64)`, which is the int 64, not a 1-tuple.
    image_sizes = 64,
    # Fix: pass the value selected above instead of a hard-coded 250, which
    # silently ignored the '1k' branch.
    timesteps = timesteps,
    cond_drop_prob = 0.1,
    condition_on_continuous = True,
    continuous_embed_dim = args.continuous_embed_dim,
)

In [None]:
# Template naming every metric that is accumulated; one empty list per metric.
metric_dict = {
    "kl_div": [],
    "rmse": [],
    "mae":  [],
    "psnr": [],
    "ssim": [],
    "fid": []
}

# Number of test batches that are sampled and scored.
N_EVAL_BATCHES = 5

# Independent copy so the template above stays empty.
test_metric_dict = copy.deepcopy(metric_dict)

# Load the trainer checkpoint saved at the selected epoch of this run.
best_epoch = best_epoch_dict[RUN_NAME]
ckpt_trainer_path = f"{BASE_DIR}/ckpt_trainer_1_{best_epoch:03}.pt"
trainer = ImagenTrainer(imagen, lr=args.lr, verbose=False).cuda()
trainer.load(ckpt_trainer_path)

for idx in tqdm(range(N_EVAL_BATCHES)):
    batch_idx = test_dataloader.random_idx[idx]
    img_64, _, era5 = test_dataloader.get_batch(batch_idx)
    # Flatten everything after the first dimension into one conditioning vector per sample.
    cond_embeds = era5.reshape(era5.shape[0], -1).float().cuda()
    # NOTE(review): sampling goes through `imagen` directly rather than through
    # `trainer`, although the result is named `ema_sampled_images` — confirm
    # which set of weights is intended for evaluation.
    ema_sampled_images = imagen.sample(
        batch_size = img_64.shape[0],
        cond_scale = 3.,
        continuous_embeds = cond_embeds,
        use_tqdm = False
    )

    y_true = img_64.cpu()
    y_pred = ema_sampled_images.cpu()
    # Fix: keep the per-batch result under its own name instead of rebinding
    # `metric_dict`, which previously overwrote the template defined above.
    batch_metrics = calculate_metrics(y_pred, y_true)
    for key, value in batch_metrics.items():
        test_metric_dict[key].append(value)

In [None]:
# Persist the accumulated per-batch test metrics alongside the run's models.
metrics_path = f"{BASE_HOME}/models/{RUN_NAME}/metrics_test.pkl"
with open(metrics_path, "wb") as metrics_file:
    pickle.dump(test_metric_dict, metrics_file)