In [2]:
# importing dependencies 
from utils import *
from ddpms import *

In [3]:
# Checkpoint files (saved state dicts) paired with the folder that holds the
# samples generated by the corresponding model.
# NOTE(review): the folder names spell "lsd" while the checkpoints spell "lds";
# kept verbatim because these are on-disk paths - confirm before renaming.
_RUNS = [
    ("model_classic.pth", "./gen_classic"),
    ("model_lds_simple.pth", "./gen_lsd_simple"),
    ("model_lds_sobol.pth", "./gen_lsd_sobol"),
]
models_dicts = [checkpoint for checkpoint, _ in _RUNS]
gen_paths = [out_dir for _, out_dir in _RUNS]

# Root directory handed to the MNIST dataset and to the FID evaluation.
eval_path = "./eval_img"

In [4]:
# Model / run hyper-parameters.
T = 1000              # passed as `T=` to the DDPM wrappers in the comparison loop
learning_rate = 1e-3  # not used by the cells visible in this notebook section
epochs = 100          # not used by the cells visible in this notebook section
batch_size = 256      # batch size of the MNIST test dataloader

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def _constant_one(t):
    """Ignore `t` and return a one-element tensor of ones placed on `device`."""
    # `device` is looked up at call time, exactly like the original lambda did.
    return torch.ones(1).to(device)


# Network shared by every DDPM wrapper built in the comparison loop.
# NOTE(review): the meaning of ScoreNet's callable argument is not visible
# here - confirm against its definition in `ddpms`/`utils`.
mnist_unet = ScoreNet(_constant_one)

In [5]:
# MNIST test split, stored under `eval_path` (downloaded if missing) and
# converted to tensors on access.
mnist_test_set = datasets.MNIST(
    eval_path,
    download=True,
    train=False,  # test data, not the training split
    transform=transforms.ToTensor(),
)

# Shuffled loader over the test split, handed to `generate_save_samples`.
dataloader_gen = torch.utils.data.DataLoader(
    mnist_test_set,
    batch_size=batch_size,
    shuffle=True,
)

In [6]:
# One FID score per checkpoint, in the same order as `models_dicts`.
FIDs = []

# Comparison loop: for every (checkpoint, output folder) pair, rebuild the
# matching DDPM wrapper, load its weights, generate samples and score them.
for state_dict, folder in zip(models_dicts, gen_paths):
    # Pick the wrapper class that matches the checkpoint.
    # The branches must be mutually exclusive (if / elif / else): with two
    # independent `if`s, the `else` of the second one also ran for
    # "model_lds_simple.pth" and silently replaced the low-discrepancy model
    # with DDPM_classic.
    if state_dict == "model_lds_simple.pth":
        model = DDPM_low_discrepancy(mnist_unet, T=T, sampler = "simple").to(device)
    elif state_dict == "model_lds_sobol.pth":
        model = DDPM_low_discrepancy(mnist_unet, T=T, sampler = "sobol").to(device)
    else:
        model = DDPM_classic(mnist_unet, T=T).to(device)

    # Load the weights once, after the wrapper is chosen. `map_location`
    # remaps the stored tensors to the selected device, so a checkpoint saved
    # on a GPU can still be loaded when only the CPU is available.
    # NOTE(review): every wrapper receives the same `mnist_unet` object, so
    # presumably each load overwrites the shared network weights - confirm.
    model.load_state_dict(torch.load(state_dict, map_location=device))

    # Sample generation
    generate_save_samples(model,
                          dataloader_gen,
                          root_dir = folder)

    # FID computation
    print(f"Evaluating {state_dict.split('.')[0]}")
    # `device.type` is "cuda" whenever a GPU was selected (same value as the
    # previously hard-coded string) and "cpu" otherwise.
    # NOTE(review): assumes compute_fid accepts "cpu" as a device string - confirm.
    FID = compute_fid(generated_images_dir = folder,
                      evaluation_images_dir = eval_path,
                      device = device.type)
    FIDs.append(FID)

Sample have been already generated.

Evaluating model_classic
Evaluation on MNIST test set...


Loading real data into FID object:   0%|          | 0/40 [00:00<?, ?it/s]

Loading generated data into FID object:   0%|          | 0/40 [00:00<?, ?it/s]

Computing FID...
FID: 12.309158325195312.

Sample have been already generated.

Evaluating model_lds_simple
Evaluation on MNIST test set...


Loading real data into FID object:   0%|          | 0/40 [00:00<?, ?it/s]

Loading generated data into FID object:   0%|          | 0/40 [00:00<?, ?it/s]

Computing FID...
FID: 12.788650512695312.

Sample have been already generated.

Evaluating model_lds_sobol
Evaluation on MNIST test set...


Loading real data into FID object:   0%|          | 0/40 [00:00<?, ?it/s]

Loading generated data into FID object:   0%|          | 0/40 [00:00<?, ?it/s]

Computing FID...
FID: 10.245132446289062.

