# Numerical comparison of models: FID computation

In [None]:
# Importing dependencies 
from utils import *
from ddpms import *

In [None]:
# Locations of model checkpoints, generated samples and real evaluation data.

# Directory holding the trained model checkpoints.
models_dir = './model_checkpoints'

# Checkpoint file names, appended to models_dir.
# The first six entries line up with gen_paths_non_guided (the non-guided loop
# zips the two lists); the last two are used by the guided cells via [-2]/[-1].
models_dicts = [
    "/model_classic.pth",         # DDPM provided
    "/model_lds_simple.pth",      # DDPM + low-discrepancy sampling (VDM)
    "/model_lds_sobol.pth",       # DDPM + low-discrepancy sampling (Sobol)
    "/model_is.pt",               # DDPM + importance sampling
    "/model_x0.pt",               # DDPM predicting x_0
    "/model_mu.pt",               # DDPM predicting mu
    "/model_classic.pth",         # DDPM classifier guidance (reuses the classic weights)
    "/model_classifier_free.pt",  # DDPM classifier-free guidance
]

# Output directories for samples of the non-guided models; each is shared by
# the train and test FID evaluations.
# NOTE(review): these use "lsd" while the checkpoints use "lds"; kept as-is
# because existing sample directories may already use this spelling.
gen_paths_non_guided = [
    "./gen_classic",      # DDPM provided
    "./gen_lsd_simple",   # DDPM + low-discrepancy sampling (VDM)
    "./gen_lsd_sobol",    # DDPM + low-discrepancy sampling (Sobol)
    "./gen_is",           # DDPM + importance sampling
    "./gen_x0",           # DDPM predicting x_0
    "./gen_mu",           # DDPM predicting mu
]

# Guided models get one directory per split, named "<prefix>_<w>_<split>";
# the guided cells parse w and the split back out of these names.
# Order is always train first, then test.
gen_paths_class = [f"./gen_class_10_{split}" for split in ("train", "test")]  # classifier guidance, w = 10
gen_paths_class_free = [f"./gen_class-free_10_{split}" for split in ("train", "test")]  # classifier-free guidance, w = 10

# Real MNIST data roots: index 0 is the training split, index 1 the test split.
eval_paths = [f"./eval_img_{split}" for split in ("train", "test")]

In [None]:
# Hyper-parameters and backbone networks shared by the DDPM variants below.
T = 1000          # passed as T to every DDPM wrapper (presumably the number of diffusion steps)
batch_size = 256  # batch size of the real-MNIST data loaders

# Run on the first GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Backbone networks. The constant-one callables are handed to the ScoreNet
# constructors; their exact role is defined in the project modules (TODO confirm).
mnist_unet = ScoreNet(lambda t: torch.ones(1).to(device))  # unconditional variants and classifier guidance
mnist_unet_mu = ScoreNet2()  # backbone of the mu-predicting DDPM
mnist_unet_class_free = ScoreNet_class(lambda t: torch.ones(1).to(device))  # classifier-free guidance backbone

In [None]:
# Real MNIST train and test loaders. The seed makes the train-loader shuffling
# reproducible.
torch.manual_seed(42)

# Training split (shuffled), stored under eval_paths[0]; download=True fetches
# the data if it is missing (presumably torchvision's MNIST).
dataloader_mnist_train = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root=eval_paths[0],
                           train=True,
                           download=True,
                           transform=transforms.ToTensor()),
    shuffle=True,
    batch_size=batch_size,
)

# Test split (fixed order), stored under eval_paths[1].
dataloader_mnist_test = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root=eval_paths[1],
                           train=False,
                           download=True,
                           transform=transforms.ToTensor()),
    shuffle=False,
    batch_size=batch_size,
)

In [None]:
# Results store: maps "<model>_<split>" keys to their FID scores.
FIDs = dict()

## Non-guided DDPMs

In [None]:
# Generate samples and compute train/test FIDs for every non-guided DDPM.
# zip() stops at the shorter list, so only the first len(gen_paths_non_guided)
# checkpoints (the six non-guided variants) are visited. The guided checkpoints
# at the end of models_dicts are handled in their own cells below.

# Maps a checkpoint file name to a factory that builds the matching DDPM wrapper.
# Names not listed here fall back to DDPM_classic (the provided DDPM).
non_guided_factories = {
    "/model_lds_simple.pth": lambda: DDPM_low_discrepancy(mnist_unet, T=T, sampler="simple"),
    "/model_lds_sobol.pth": lambda: DDPM_low_discrepancy(mnist_unet, T=T, sampler="sobol"),
    "/model_is.pt": lambda: DDPM_importance(mnist_unet, T=T),
    "/model_x0.pt": lambda: DDPM_x0(mnist_unet, T=T),
    "/model_mu.pt": lambda: DDPM_mu(mnist_unet_mu, T=T),
}

for ckpt_name, folder in zip(models_dicts, gen_paths_non_guided):
    # Guard for the case where gen_paths_non_guided grows past the non-guided
    # models: the classifier-free checkpoint needs a different backbone.
    # The classic checkpoint is shared with classifier guidance, so it cannot be
    # excluded by file name. (The previous guard tested "/model_classic.pt",
    # which never matched anything.)
    if ckpt_name == "/model_classifier_free.pt":
        continue

    model_path = models_dir + ckpt_name

    # Build the wrapper and load its weights. map_location lets checkpoints
    # saved on a GPU load on a CPU-only machine.
    build_model = non_guided_factories.get(ckpt_name, lambda: DDPM_classic(mnist_unet, T=T))
    model = build_model().to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Generate samples into `folder`. A loader must be passed, but per the
    # original author the choice of loader does not matter for these
    # unconditional models.
    generate_save_samples(model,
                          dataloader_mnist_test,
                          root_dir=folder)

    # FID against the real train split (eval_batches=40, presumably a cap on
    # the number of real batches) and the full test split (eval_batches=None).
    model_name = ckpt_name.split('.')[0].strip('/')
    print(f"Evaluating {model_name}")
    for train, path in zip([True, False], eval_paths):
        fid_key = f"{model_name}_{'train' if train else 'test'}"
        FIDs[fid_key] = compute_fid(generated_images_dir=folder,
                                    evaluation_images_dir=path,
                                    train_mnist=train,
                                    download_mnist=False,
                                    eval_batches=40 if train else None,
                                    # Follow the notebook's device instead of
                                    # hard-coding "cuda", which fails without a GPU.
                                    device=device.type)

In [None]:
print(f"{FIDs}")  # show the FID scores collected so far

## Classifier-guided diffusion

In [None]:
# Classifier-guided DDPM: sample with guidance scale w and compute train/test FIDs.

# Values passed to ClassifierWrapper as beta_1 / beta_T. These are presumably
# the endpoints of the DDPM noise schedule (TODO confirm they match the DDPM's).
beta_1 = 1e-4
beta_T = 2e-2

# Load the classifier used for guidance (classifier guidance only).
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
model_classifier = RobustMNISTClassifier().to(device)
model_classifier.load_state_dict(torch.load(models_dir + "/classifier.pt", map_location=device))
model_classifier.eval()
classifier = ClassifierWrapper(model_classifier, T=T, beta_1=beta_1, beta_T=beta_T).to(device)

# Classifier guidance reuses the classic DDPM checkpoint (models_dicts[-2]).
model_path = models_dir + models_dicts[-2]
model = DDPM_class(mnist_unet, T=T).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))

# Maps each split to (loader given to the sampler, real-image dir, train flag,
# eval_batches). The train flag drives both train_mnist and shuffle, as in the
# original per-split branches.
classifier_split_settings = {
    "train": (dataloader_mnist_train, eval_paths[0], True, 40),
    "test": (dataloader_mnist_test, eval_paths[1], False, None),
}

for path in gen_paths_class:
    # Paths look like "./gen_class_<w>_<split>".
    params = path.split('_')
    w = int(params[2])
    split = params[-1]
    fid_key = f"model_classifier_w{w}_{split}"
    # An unexpected split suffix raises KeyError rather than silently being
    # evaluated as "test".
    loader, eval_dir, is_train, eval_batches = classifier_split_settings[split]

    # Generate guided samples.
    generate_save_samples(
        model,
        loader,
        root_dir=path,
        guided=True,
        w=w,
        classifier=classifier,
    )

    # Compute FID. Use the notebook's device instead of hard-coding "cuda",
    # which fails without a GPU.
    FIDs[fid_key] = compute_fid(generated_images_dir=path,
                                evaluation_images_dir=eval_dir,
                                train_mnist=is_train,
                                shuffle=is_train,
                                device=device.type,
                                eval_batches=eval_batches,
                                seed=42)

In [None]:
print(f"{FIDs}")  # show the FID scores collected so far

## Classifier-free guidance

In [None]:
# Classifier-free guided DDPM: sample with guidance scale w and compute train/test FIDs.

# Load the classifier-free model (last entry of models_dicts). map_location
# lets checkpoints saved on a GPU load on a CPU-only machine.
model_path = models_dir + models_dicts[-1]
model = DDPM_class_free(mnist_unet_class_free, T=T).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))

# Maps each split to (loader given to the sampler, real-image dir, train flag,
# eval_batches). The train flag drives both train_mnist and shuffle, as in the
# original per-split branches.
class_free_split_settings = {
    "train": (dataloader_mnist_train, eval_paths[0], True, 40),
    "test": (dataloader_mnist_test, eval_paths[1], False, None),
}

for path in gen_paths_class_free:
    # Paths look like "./gen_class-free_<w>_<split>".
    params = path.split('_')
    w = int(params[2])
    split = params[-1]
    fid_key = f"model_class-free_w{w}_{split}"
    # An unexpected split suffix raises KeyError rather than silently being
    # evaluated as "test".
    loader, eval_dir, is_train, eval_batches = class_free_split_settings[split]

    # Generate guided samples (no external classifier needed).
    generate_save_samples(
        model,
        loader,
        root_dir=path,
        guided=True,
        w=w,
    )

    # Compute FID.
    # NOTE(review): this cell computes FID on "cpu", while the classifier-guided
    # cell uses the GPU. Kept as-is in case it was chosen to save GPU memory; confirm.
    FIDs[fid_key] = compute_fid(generated_images_dir=path,
                                evaluation_images_dir=eval_dir,
                                train_mnist=is_train,
                                shuffle=is_train,
                                device="cpu",
                                eval_batches=eval_batches,
                                seed=42)

In [None]:
print(f"{FIDs}")  # show all FID scores collected in this notebook