In [None]:
import argparse
from train import train_sn
from test import test_ncsn, inpaint_ncsn, test_mix
from load_data import load_dataset
import torch
from models import *
import logging
import torch.distributions as TD
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from utils import plot_score_function, distribution2score
import os
import pandas as pd
import seaborn as sns
import json
import numpy as np
import matplotlib.pyplot as plt
from models import NoiseConditionalScoreNetwork
import torch
import os
from load_data import load_dataset
import gc
from tqdm import tqdm
from torchvision.utils import save_image, make_grid
from PIL import Image
from utils import distribution2score
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import numpy as np
from utils import batch_jacobian
from sklearn.decomposition import PCA

In [None]:
# Notebook-wide logger; the root logger is configured to emit INFO and above.
logg = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

Implementation of the sliced score matching (SSM) objectives from the paper "Sliced Score Matching: A Scalable Approach to Density and Score Estimation" (Song et al., 2019).

In [None]:
def sliced_score_matching(model, x, labels, M, distribution, v=None):
    """Sliced score matching (SSM) loss averaged over the batch and M projections.

    For each sample ``x_i`` and projection vector ``v_ij`` the estimator is
    ``v^T J v + 0.5 * (v^T s)^2`` with ``s = model(x, labels)`` and ``J`` the
    Jacobian of ``s`` w.r.t. ``x`` (Eq. 7 of the Sliced Score Matching paper).
    The previous implementation multiplied the Jacobian term by 0.5 as well,
    whose minimiser is half of the true score.

    Args:
        model: callable ``model(x, labels)`` returning one score vector per sample.
        x: input batch; must have ``requires_grad`` enabled by the caller so that
            ``batch_jacobian`` can differentiate through it.
            NOTE(review): the math below assumes ``x`` is ``(N, D)`` — confirm.
        labels: per-sample conditioning labels forwarded to ``model``.
        M: number of projection vectors per sample.
        distribution: ``"normal"``, ``"rademacher"`` or ``"pca"``.
        v: only used for ``"pca"``: array/tensor of fixed directions with shape
            ``(n_directions, D)``; the first ``M`` rows are used.

    Returns:
        0-dim tensor holding the mean loss.

    Raises:
        ValueError: unknown ``distribution``, missing ``v`` for ``"pca"``, or
            fewer than ``M`` directions supplied.
    """
    N = x.shape[0]
    if distribution == "normal":
        v = torch.randn(M, *x.shape)
    elif distribution == "rademacher":
        # {0, 1} -> {-1, +1}
        v = torch.randint(0, 2, (M, *x.shape)) * 2 - 1
    elif distribution == "pca":
        if v is None:
            raise ValueError('distribution="pca" requires the directions `v`')
        v = torch.as_tensor(v)
        if v.shape[0] < M:
            raise ValueError(
                f"M={M} projections requested but only {v.shape[0]} directions given"
            )
        # Share the same M directions across every sample: (M, D) -> (M, N, D).
        v = v[:M].unsqueeze(1).expand(-1, N, -1)
    else:
        raise ValueError(
            f'Unknown distribution "{distribution}"; expected "normal", "rademacher" or "pca"'
        )
    v = v.to(x.device).float()

    sm = model(x, labels.to(x.device))
    # Presumably per-sample Jacobians of shape (N, D, D); the original per-item
    # matmul indexing relied on the same layout — verify against utils.batch_jacobian.
    grad_sm = batch_jacobian(input=x, output=sm)

    # v^T J v for every (projection, sample) pair.
    quad = torch.einsum("mni,nij,mnj->mn", v, grad_sm, v)
    # v^T s for every (projection, sample) pair.
    proj = torch.einsum("mni,ni->mn", v, sm)
    return (quad + 0.5 * proj ** 2).mean()

In [None]:
def sliced_score_matching_vr(model, x, labels, M, distribution, v=None):
    """Variance-reduced sliced score matching (SSM-VR) loss.

    Per sample and projection the estimator is ``v^T J v + 0.5 * ||s||^2`` with
    ``s = model(x, labels)`` and ``J`` its Jacobian w.r.t. ``x`` (Eq. 8 of the
    Sliced Score Matching paper). The previous implementation also halved the
    Jacobian term, whose minimiser is half of the true score.

    Args:
        model: callable ``model(x, labels)`` returning one score vector per sample.
        x: input batch with ``requires_grad`` enabled by the caller.
            NOTE(review): the math below assumes ``x`` is ``(N, D)`` — confirm.
        labels: per-sample conditioning labels forwarded to ``model``.
        M: number of random projection vectors per sample.
        distribution: ``"normal"`` or ``"rademacher"``.
        v: accepted only so the function can be called with the same positional
            arguments as ``sliced_score_matching``; it is not used.

    Returns:
        0-dim tensor holding the mean loss.

    Raises:
        ValueError: for any other ``distribution``. ``"pca"`` is rejected on
            purpose: replacing ``(v^T s)^2`` by ``||s||^2`` is only valid when
            the projections have identity covariance.
    """
    N = x.shape[0]
    if distribution == "normal":
        v = torch.randn(M, *x.shape)
    elif distribution == "rademacher":
        # {0, 1} -> {-1, +1}
        v = torch.randint(0, 2, (M, *x.shape)) * 2 - 1
    else:
        raise ValueError(
            f'Unsupported distribution "{distribution}" for the variance-reduced '
            'loss; expected "normal" or "rademacher"'
        )
    v = v.to(x.device).float()

    sm = model(x, labels.to(x.device))
    # Presumably per-sample Jacobians of shape (N, D, D); the original per-item
    # matmul indexing relied on the same layout — verify against utils.batch_jacobian.
    grad_sm = batch_jacobian(input=x, output=sm)

    # v^T J v for every (projection, sample) pair.
    quad = torch.einsum("mni,nij,mnj->mn", v, grad_sm, v)
    # ||s_i||^2 does not depend on the projection, so it is averaged over samples only.
    sq_norm = sm.reshape(N, -1).pow(2).sum(dim=1)
    return quad.mean() + 0.5 * sq_norm.mean()

In [None]:
def save_model(model: object, optimizer: object, path: str) -> None:
    """Write a two-element checkpoint ``[model state, optimizer state]`` to ``path``."""
    checkpoint = [model.state_dict(), optimizer.state_dict()]
    torch.save(checkpoint, path)

In [None]:
def train_sn(
    model: object,
    train_loader: object,
    n_epochs: int,
    lr: float,
    sigmas: torch.Tensor = torch.Tensor([0.1]),
    use_cuda: bool = False,
    conditional: bool = True,
    loss_type: str = "denoising_score_matching",
    n_vectors: int = 1,
    dist_type: str = "normal",
) -> tuple:
    """Train a score network with denoising or sliced score matching.

    Args:
        model: network called as ``model(x, labels)`` (conditional) or
            ``model(x)`` (unconditional, denoising loss only).
        train_loader: iterable of batches; the data tensor is element 0 of each
            batch. For ``dist_type="pca"`` it must expose ``dataset.tensors[0]``.
        n_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        sigmas: noise levels; a random level is drawn per sample when
            ``conditional`` is true, otherwise ``sigmas[0]`` is used.
        use_cuda: move the model and every batch to CUDA.
        conditional: whether the model is conditioned on the noise-level index.
        loss_type: ``"denoising_score_matching"``, ``"sliced_score_matching"``
            or ``"sliced_score_matching_vr"``.
        n_vectors: number of projection vectors per sample (sliced losses).
        dist_type: projection distribution (``"normal"``, ``"rademacher"`` or
            ``"pca"``).

    Returns:
        ``(model, history)`` where ``history["loss"]`` holds the mean batch loss
        of every epoch.

    Raises:
        ValueError: unknown ``loss_type``, or a sliced loss requested with
            ``conditional=False`` (the sliced losses need per-sample labels).
    """
    sliced_loss_types = ("sliced_score_matching", "sliced_score_matching_vr")
    if loss_type != "denoising_score_matching" and loss_type not in sliced_loss_types:
        raise ValueError(
            f'Unknown loss_type "{loss_type}"; expected "denoising_score_matching", '
            '"sliced_score_matching" or "sliced_score_matching_vr"'
        )
    if loss_type in sliced_loss_types and not conditional:
        raise ValueError(f'loss_type "{loss_type}" requires conditional=True')

    if use_cuda:
        model.cuda()
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0, 0.9))
    batch_loss_history = {"loss": []}

    v = None
    if dist_type == "pca":
        # Use the leading principal components of the training data as fixed
        # projection directions. scikit-learn needs a 2-D CPU array.
        data = train_loader.dataset.tensors[0]
        data = data.detach().cpu().reshape(data.shape[0], -1).numpy()
        pca = PCA(n_components=n_vectors)
        pca.fit(data)
        v = pca.components_

    # Guard against an empty loader when averaging the epoch loss.
    n_batches = max(len(train_loader), 1)
    for epoch_i in tqdm(range(n_epochs)):
        mean_loss = 0
        for batch in train_loader:
            x = batch[0]
            batch_size = x.shape[0]
            if use_cuda:
                x = x.cuda()

            if conditional:
                # One random noise level per sample.
                labels = torch.randint(len(sigmas), (batch_size,))
                sigma_batch = sigmas[labels].to(x.device).reshape(-1, 1)
            else:
                sigma_batch = (
                    sigmas[0] * torch.ones(batch_size, 1, device=x.device).float()
                )

            optimizer.zero_grad()
            if loss_type == "denoising_score_matching":
                standart_noise = torch.randn_like(x)
                x_noisy = x + standart_noise * sigma_batch
                if conditional:
                    pred_scores = model(x_noisy, labels.to(x.device))
                else:
                    pred_scores = model(x_noisy)
                # Score of the Gaussian perturbation kernel at x_noisy.
                noisy_scores = -standart_noise / sigma_batch
                losses = torch.sum((pred_scores - noisy_scores) ** 2, axis=-1) / 2
                # sigma^2 weighting balances the loss across noise levels.
                loss = torch.mean(losses * sigma_batch.flatten() ** 2)
            elif loss_type == "sliced_score_matching":
                x.requires_grad_(True)
                loss = sliced_score_matching(
                    model, x, labels, n_vectors, dist_type, v
                )
            else:  # "sliced_score_matching_vr"
                x.requires_grad_(True)
                # The VR loss takes no fixed directions, so `v` is not forwarded.
                loss = sliced_score_matching_vr(
                    model, x, labels, n_vectors, dist_type
                )
            loss.backward()
            optimizer.step()
            mean_loss += loss.data.cpu().numpy()
        batch_loss_history["loss"].append(mean_loss / n_batches)
    return model, batch_loss_history

In [None]:
def main():
    """Parse the command-line options of the experiment and return the namespace.

    Boolean options (``--use_cuda``, ``--save``) accept ``true/false``,
    ``yes/no``, ``on/off`` and ``1/0``. With ``type=bool`` any non-empty
    string, including ``"False"``, was parsed as ``True``.
    """

    def _str2bool(value):
        # argparse hands over the raw string; bool("False") would be True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "on", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "off", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="Geneerate samples by estimating the gradient of the data distribution"
    )
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--lambda_max", type=float, default=0.01, help="lambda max")
    parser.add_argument("--lambda_min", type=float, default=1e-4, help="lambda min")
    parser.add_argument("--n_lambdas", type=int, default=10, help="number of lambdas")
    parser.add_argument("--use_cuda", type=_str2bool, default=True, help="use cuda")
    parser.add_argument("--mode", type=str, default="train", help="mode")
    parser.add_argument("--n_samples", type=int, default=5, help="number of samples")
    parser.add_argument("--n_steps", type=int, default=100, help="number of steps")
    parser.add_argument("--save_freq", type=int, default=50, help="save frequency")
    parser.add_argument("--model_name", type=str, default="simple", help="model name")
    parser.add_argument("--eps", type=float, default=5e-5, help="eps")
    parser.add_argument("--dataset", type=str, default="cifar10")
    parser.add_argument("--directions", type=str, default="right")
    # The training loop only knows "denoising_score_matching",
    # "sliced_score_matching" and "sliced_score_matching_vr"; the old default
    # "denoising" matched none of them.
    parser.add_argument("--loss_type", type=str, default="denoising_score_matching")
    parser.add_argument("--n_vectors", type=int, default=1)
    parser.add_argument("--dist_type", type=str, default="normal")
    parser.add_argument("--save", type=_str2bool, default=True)
    args = parser.parse_args()
    return args

In [None]:
if __name__ == "__main__":
    # Entry point: parse options, build the noise schedule, load the dataset,
    # instantiate the model and dispatch to training / testing / inpainting.
    torch.cuda.empty_cache()
    args = main()
    batch_size = args.batch_size

    # check cuda
    if args.use_cuda:
        assert torch.cuda.is_available(), "CUDA is not available"
    torch.cuda.empty_cache()
    # create sigmas: geometric sequence from lambda_max down to lambda_min
    sigmas = torch.tensor(
        np.exp(
            np.linspace(
                np.log(args.lambda_max), np.log(args.lambda_min), args.n_lambdas
            )
        ),
        dtype=torch.float32,
    )
    print(args.dataset)
    # `mixture` and `path` only exist for some datasets; initialise them so the
    # later stages can test for them instead of failing with a NameError.
    mixture = None
    path = None
    if args.dataset == "mnist":
        train_data, test_data = load_dataset(
            args.dataset, flatten=False, binarize=False
        )
        path = "./mnist.pth"
    elif args.dataset == "cifar10":
        train_data, test_data = load_dataset(
            args.dataset, flatten=False, binarize=False
        )
        path = "./pretrained_models/cifar10.pth"
    elif args.dataset == "celeba":
        train_data, test_data = load_dataset(
            args.dataset, flatten=False, binarize=False
        )
        path = "./pretrained_models/celeba.pth"
    elif args.dataset == "mixture":
        p = 0.2
        noise = 0.1
        # NOTE(review): the mixture is always built on CUDA, independent of
        # --use_cuda; left unchanged because test_mix also moves data to CUDA.
        mix = TD.Categorical(torch.tensor([p, 1 - p]).cuda())
        mv_normals = TD.MultivariateNormal(
            torch.tensor([[1.0, 1.0], [-1.0, -1.0]]).cuda(),
            noise * torch.eye(2).unsqueeze(0).cuda(),
        )
        mixture = TD.MixtureSameFamily(mix, mv_normals)
        # check if the dataset is already created
        if os.path.exists("datasets/train_data.json"):
            with open("datasets/train_data.json", "r") as f:
                train_data = torch.tensor(json.load(f))
            with open("datasets/test_data.json", "r") as f:
                test_data = torch.tensor(json.load(f))
            print("Dataset is loaded from json file")
        else:
            train_data, test_data = train_test_split(mixture.sample((10000,)))
            # save the dataset in a json file
            os.makedirs("datasets", exist_ok=True)
            with open("datasets/train_data.json", "w") as f:
                json.dump(train_data.tolist(), f)
            with open("datasets/test_data.json", "w") as f:
                json.dump(test_data.tolist(), f)
            print('Dataset is created and saved in "datasets" folder')
    else:
        raise ValueError(
            'The argument dataset must have the values "mnist", "cifar10", "celeba" or "mixture"'
        )

    # choose model
    if args.model_name == "ncsn":
        model = NoiseConditionalScoreNetwork()
    elif args.model_name == "simple_ncsn":
        model = SimpleNoiseConditionalScoreNetwork(
            hidden_dim=512, data_dim=2, num_sigmas=len(sigmas)
        )
    elif args.model_name == "condrefinenet":
        model = CondRefineNetDilated()
    elif args.model_name == "simple":
        model = SimpleScoreNetwork(hidden_dim=512)
    else:
        raise ValueError(
            'The argument model_name must have the values "ncsn", "simple_ncsn", "condrefinenet" or "simple"'
        )

    # train or test
    if args.mode == "train":
        logg.info("Starting training")
        n_epochs = args.n_epochs
        lr = args.lr
        train_loader = torch.utils.data.DataLoader(
            # as_tensor avoids the copy-construct warning when the data is
            # already a tensor (mixture dataset).
            torch.utils.data.TensorDataset(torch.as_tensor(train_data)),
            batch_size=batch_size,
            shuffle=True,
        )
        # save model in trained_models folder, create it if it does not exist
        dataset, model_name, loss_type, epochs, samples, n_vectors, dist_type = (
            args.dataset,
            args.model_name,
            args.loss_type,
            args.n_epochs,
            args.n_samples,
            args.n_vectors,
            args.dist_type,
        )
        os.makedirs("trained_models", exist_ok=True)
        model_path = f"trained_models/{dataset}_{model_name}_{loss_type}_{epochs}_{dist_type}_{n_vectors}.pth"
        # check if the model is already trained; a cached model has no loss
        # history, which is signalled by batch_loss staying None
        batch_loss = None
        if os.path.exists(model_path):
            # NOTE(review): this unpickles a whole model object; only load
            # checkpoints produced by this script.
            model = torch.load(model_path)
            print("Model is loaded from the saved file")
        else:
            model, batch_loss = train_sn(
                model,
                train_loader=train_loader,
                n_epochs=n_epochs,
                sigmas=sigmas,
                lr=lr,
                conditional=True,
                use_cuda=args.use_cuda,
                loss_type=args.loss_type,
                n_vectors=args.n_vectors,
                dist_type=args.dist_type,
            )
            # Only a freshly trained model needs to be written back.
            if args.save:
                torch.save(model, model_path)

        if loss_type == "sliced_score_matching" or loss_type == "sliced_score_matching_vr":
            exp_folder = (
                f"{dataset}_{model_name}_{loss_type}_epochs_{epochs}_samples_{samples}_n_vectors_{args.n_vectors}_dist_type_{args.dist_type}"
            )
        else:
            exp_folder = (
                f"{dataset}_{model_name}_{loss_type}_epochs_{epochs}_samples_{samples}"
            )
        full_path = os.path.join("training_experiments", exp_folder)
        os.makedirs(full_path, exist_ok=True)

        # save a config file with parameters of the experiment
        with open(f"{full_path}/config.txt", "w") as f:
            for key, value in vars(args).items():
                f.write("%s:%s\n" % (key, value))

        if batch_loss is not None:
            # save loss values
            with open(f"{full_path}/loss.txt", "w") as f:
                for loss in batch_loss["loss"]:
                    f.write("%s\n" % loss)
            # save plot of loss after each epoch
            plt.plot(batch_loss["loss"])
            plt.title("Loss")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.savefig(os.path.join(full_path, "loss.png"))
        else:
            logg.info("Model was loaded from disk; no loss history to save")

        # The score plots compare against the analytic mixture, so they are
        # only produced for the "mixture" dataset.
        if mixture is not None:
            # save plot of distribution,score
            fig, ax = plt.subplots(1, 3, figsize=(12, 4))
            plot_score_function(
                distribution2score(mixture),
                sigmas,
                train_data,
                "TwoGaussMixture Score and Samples",
                ax=ax[0],
                npts=30,
            )

            generated = model.sample(
                n_samples=samples, n_steps=args.n_steps, eps=args.eps, sigmas=sigmas
            )
            plot_score_function(
                model,
                sigmas,
                generated,
                "Predicted scores",
                ax=ax[1],
                npts=30,
                plot_scatter=False,
            )

            plot_score_function(
                model, sigmas, generated, "Predicted scores and samples", ax=ax[2], npts=30
            )
            plt.savefig(os.path.join(full_path, "cond_score.png"))

            # create an image with only the predicted samples
            fig, ax = plt.subplots(1, 1, figsize=(4, 4))
            plot_score_function(
                model, sigmas, generated, f"Predicted samples", ax=ax, npts=30
            )
            plt.savefig(
                os.path.join(
                    full_path,
                    f"Predicted_samples {loss_type}_{dist_type}_{n_vectors}.png",
                )
            )

            # save an image with the actual samples
            fig, ax = plt.subplots(1, 1, figsize=(4, 4))
            plot_score_function(
                distribution2score(mixture),
                sigmas,
                train_data,
                "TwoGaussMixture Score Function",
                ax=ax,
                npts=30,
            )
            plt.savefig(os.path.join(full_path, "actual_samples.png"))

    elif args.mode == "test":
        logg.info("Starting testing")

        if args.dataset == "mixture":
            # NOTE(review): the training split is passed as `test_data` —
            # confirm this is intended.
            predicted_losses = test_mix(
                mixture=mixture,
                test_data=train_data,
                sigmas=sigmas,
            )
            print(predicted_losses)
            # Label every bar with the tail of the checkpoint file name.
            # Relies on os.listdir returning the same order as inside test_mix.
            labels = [
                "_".join(name.split("_")[6:]) for name in os.listdir("trained_models")
            ]
            predicted_losses = [loss.item() for loss in predicted_losses]
            sorted_data = sorted(zip(labels, predicted_losses), key=lambda x: x[1])
            sorted_labels, sorted_losses = zip(*sorted_data)
            # Plot bar chart of the losses, sorted ascending
            plt.figure(figsize=(10, 6))
            sns.barplot(x=list(sorted_labels), y=list(sorted_losses), palette='viridis')
            plt.xlabel('Experiment Labels')
            plt.ylabel('Predicted Losses')
            plt.title('Histogram of Predicted Losses by Experiment')
            plt.xticks(rotation=45, ha='right')  # Rotate x-axis labels for better visibility
            plt.tight_layout()
            os.makedirs("training_experiments", exist_ok=True)
            plt.savefig(os.path.join("training_experiments", "histogram.png"))

        else:
            test_ncsn(
                path=path,
                sigmas=sigmas,
                visualize=True,
                use_cuda=args.use_cuda,
                n_samples=args.n_samples,
                n_steps=args.n_steps,
                save_freq=args.save_freq,
                eps=args.eps,
                dataset=args.dataset,
            )
    elif args.mode == "inpaint":
        logg.info("Starting inpainting")
        if path is None:
            raise ValueError(
                f'Inpainting needs a pretrained checkpoint, none is defined for dataset "{args.dataset}"'
            )
        inpaint_ncsn(
            path=path,
            sigmas=sigmas,
            use_cuda=args.use_cuda,
            n_samples=args.n_samples,
            n_steps=args.n_steps,
            dataset=args.dataset,
            direction=args.directions,
        )
    else:
        raise ValueError(
            'The argument mode must have the values "train", "test" or "inpaint"'
        )

In [None]:
def test_ncsn(
    path: str,
    sigmas: torch.Tensor,
    visualize: bool = True,
    use_cuda: bool = False,
    n_samples: int = 5,
    n_steps: int = 100,
    save_freq: int = 50,
    eps: float = 5e-5,
    dataset: str = "mnist",
):
    """Load a score-network checkpoint, draw samples and optionally save them.

    Args:
        path: checkpoint file; either a bare state dict or a two-element
            ``[model_state, optimizer_state]`` list.
        sigmas: noise levels passed to the sampler.
        visualize: when true, the sampling history is written to disk through
            ``visualize_history``.
        use_cuda: run the network on CUDA.
        n_samples: number of samples to draw.
        n_steps: sampler steps, forwarded to ``sample``.
        save_freq: forwarded to ``visualize_history`` for file naming.
        eps: sampler step-size parameter, forwarded to ``sample``.
        dataset: ``"mnist"``, ``"cifar10"`` or ``"celeba"``.

    Raises:
        ValueError: for any other ``dataset``.
    """
    if dataset == "mnist":
        refine_net = NoiseConditionalScoreNetwork(use_cuda=use_cuda)
    elif dataset in ("cifar10", "celeba"):
        if dataset == "cifar10":
            print("dataset is cifar10")
        refine_net = NoiseConditionalScoreNetwork(
            use_cuda=use_cuda, n_channels=3, image_size=32, num_classes=10, ngf=128
        )
    else:
        raise ValueError(
            f'Unknown dataset "{dataset}"; expected "mnist", "cifar10" or "celeba"'
        )
    print("dataset: ", dataset)
    print("path: ", path)
    # Load once, onto the CPU; load_state_dict copies into the model's own
    # tensors, so this also works on machines without CUDA.
    states = torch.load(path, map_location="cpu")
    # A two-element list/tuple means the optimizer state was also saved in the
    # checkpoint; the type check keeps a 2-key state dict from being misread.
    pretrained = isinstance(states, (list, tuple)) and len(states) == 2
    refine_net.load_state_dict(states[0] if pretrained else states)
    print("Model is pretrained: ", pretrained)
    if use_cuda:
        refine_net.cuda()
    refine_net.eval()
    samples, history = refine_net.sample(
        n_samples=n_samples, n_steps=n_steps, sigmas=sigmas, eps=eps, save_history=True
    )
    if visualize:
        visualize_history(
            samples,
            history,
            sigmas,
            save_freq,
            pretrained,
            dataset=dataset,
            save_folder=f"{n_samples}_samples_{n_steps}_steps_sigma_{sigmas[0]:.4f}_{sigmas[-1]:.4f}_eps_{eps:.5f}_dataset_{dataset}",
        )

In [None]:
def test_mix(mixture, test_data: torch.Tensor, sigmas: torch.Tensor):
    """Score every checkpoint in ``trained_models`` against the analytic mixture score.

    For each file in ``trained_models`` the model is loaded, evaluated on
    ``test_data`` and compared with ``distribution2score(mixture)`` via
    ``0.5 * mean(||score - true_score||^2)``.

    Args:
        mixture: distribution handed to ``distribution2score``.
        test_data: evaluation points; moved to CUDA.
        sigmas: noise levels; only their count is used, to build the labels.

    Returns:
        List of 0-dim numpy arrays, one per file, in ``os.listdir`` order
        (callers pair them with a separate ``os.listdir`` call, so the order is
        deliberately left untouched).
    """
    data = test_data.cuda()
    true_scores = distribution2score(mixture)(data, None)
    n_points, n_levels = data.size(0), len(sigmas)
    # Contiguous, near-equal blocks of noise-level labels. Identical to
    # arange(n_levels).repeat_interleave(n_points // n_levels) when n_points is
    # a multiple of n_levels, but always yields exactly one label per point.
    labels = torch.div(
        torch.arange(n_points, device=data.device) * n_levels,
        n_points,
        rounding_mode="floor",
    )
    predicted_losses = []
    for model_name in os.listdir("trained_models"):
        # NOTE(review): unpickles whole model objects; only use trusted files.
        model = torch.load(f"trained_models/{model_name}")
        score = model(data, labels)
        loss = 0.5 * (torch.norm(score - true_scores, p=2, dim=-1) ** 2).mean()
        predicted_losses.append(loss.detach().cpu().numpy())
    return predicted_losses

In [None]:
def visualize_history(
    samples, history, sigmas, save_freq, pretrained, dataset, save_folder="samples"
):
    """Save every entry of a sampling history as a PNG image grid.

    Args:
        samples: final samples; not used here (only ``history`` is rendered),
            kept for interface compatibility.
        history: sequence of image batches recorded during sampling.
        sigmas: noise levels; ``history`` is split evenly between them to name
            the files.
        save_freq: multiplier applied to the per-sigma index in the file name.
        pretrained: when true, ``"_pretrained"`` is appended to ``save_folder``.
        dataset: not used here, kept for interface compatibility.
        save_folder: output directory, created if missing.
    """
    print("Visualizing history")
    print("Saving images")
    # Decide the final folder name first, then create it idempotently. The old
    # code only added the suffix when the unsuffixed folder was missing and
    # failed with FileExistsError on a second pretrained run.
    if pretrained:
        save_folder = save_folder + "_pretrained"
    os.makedirs(save_folder, exist_ok=True)

    # At least one entry per sigma so the modulo below cannot divide by zero.
    steps_per_sigma = max(1, len(history) // len(sigmas))
    last_sigma_idx = len(sigmas) - 1
    for step in range(len(history)):
        sigma_step = step % steps_per_sigma
        # Left-over entries (history not a multiple of len(sigmas)) are
        # attributed to the last sigma instead of indexing out of range.
        sigma_idx = min(step // steps_per_sigma, last_sigma_idx)
        grid_samples = make_grid(history[step], nrow=5)
        # CHW -> HWC, clipped to the valid image range, as a numpy array.
        grid_img = grid_samples.permute(1, 2, 0).clip(0, 1).cpu().numpy()
        step_size = sigma_step * save_freq
        print("grid img min max: ", grid_img.min(), grid_img.max())
        plt.imsave(
            f"{save_folder}/sigma_{sigmas[sigma_idx]:.4f}_step_{step_size}.png",
            grid_img,
        )
    gc.collect()

In [None]:
def anneal_Langevin_dynamics_inpainting(
    x_mod,
    image,
    scorenet,
    sigmas,
    img_size,
    n_channels,
    direction="left",
    n_steps_each=100,
    step_lr=0.000008,
):
    """Annealed Langevin dynamics with one half of every image held fixed.

    The half selected by ``direction`` is taken from ``image`` (with noise of
    the current level added once per sigma) and re-imposed on the chain after
    every update; the other half is sampled.

    Args:
        x_mod: initial samples, viewed as ``(-1, n_channels, img_size, img_size)``.
            Its second dimension is the number of samples per source image.
            The first overwrite of the known half is applied in place.
        image: source images of shape ``(B, n_channels, img_size, img_size)``.
        scorenet: callable ``scorenet(x, labels)`` returning the score at ``x``.
        sigmas: noise levels, iterated in the given order.
        img_size: image height/width.
        n_channels: number of image channels.
        direction: which half is known: ``"left"``, ``"right"``, ``"top"`` or
            ``"bottom"``.
        n_steps_each: Langevin steps per noise level.
        step_lr: base step size, scaled by ``(sigma / sigmas[-1]) ** 2``.

    Returns:
        ``(images, occluded_img)``: the clamped CPU state recorded *before*
        every step, and ``image`` with the unknown half zeroed.

    Raises:
        ValueError: if ``direction`` is not one of the four supported values.
    """
    half = img_size // 2
    everything = slice(None)
    # Index of the known region in an (N, C, H, W) tensor.
    known_regions = {
        "left": (everything, everything, everything, slice(None, half)),
        "right": (everything, everything, everything, slice(half, None)),
        "top": (everything, everything, slice(None, half), everything),
        "bottom": (everything, everything, slice(half, None), everything),
    }
    if direction not in known_regions:
        raise ValueError(
            f'Unknown direction "{direction}"; expected "left", "right", "top" or "bottom"'
        )
    known = known_regions[direction]

    images = []
    # Repeat every source image once per sample drawn for it.
    original_image = image.unsqueeze(1).expand(-1, x_mod.shape[1], -1, -1, -1)
    original_image = original_image.contiguous().view(
        -1, n_channels, img_size, img_size
    )
    x_mod = x_mod.view(-1, n_channels, img_size, img_size)
    mask = torch.zeros_like(image)
    mask[known] = 1.0
    half_original_image = original_image[known]
    occluded_img = image * mask

    with torch.no_grad():
        for c, sigma in tqdm(
            enumerate(sigmas),
            total=len(sigmas),
            desc="annealed Langevin dynamics sampling",
        ):
            labels = torch.full(
                (x_mod.shape[0],), c, dtype=torch.long, device=x_mod.device
            )
            step_size = step_lr * (sigma / sigmas[-1]) ** 2
            # The known half is perturbed once per noise level and then held
            # fixed for all steps at that level.
            corrupted_half_image = (
                half_original_image + torch.randn_like(half_original_image) * sigma
            )
            x_mod[known] = corrupted_half_image
            for _ in range(n_steps_each):
                images.append(torch.clamp(x_mod, 0.0, 1.0).to("cpu"))
                # `** 0.5` works for Python floats and tensors on any device,
                # unlike np.sqrt on a tensor.
                noise = torch.randn_like(x_mod) * (step_size * 2) ** 0.5
                grad = scorenet(x_mod, labels)
                x_mod = x_mod + step_size * grad + noise
                x_mod[known] = corrupted_half_image
        return images, occluded_img

In [None]:
def inpaint_ncsn(path, sigmas, use_cuda, n_samples, n_steps, dataset, direction):
    """Inpaint test images with a trained score network and save PNG/GIF results.

    Takes the first ``n_samples`` test images, keeps the half selected by
    ``direction`` and samples the other half ``n_samples`` times per image
    with annealed Langevin dynamics. Writes the original batch, the last
    rendered frame and an animated GIF into the ``inpainting`` folder.

    Args:
        path: checkpoint file; a bare state dict or a two-element
            ``[model_state, optimizer_state]`` list.
        sigmas: noise levels for the annealed sampler.
        use_cuda: run the network and the samples on CUDA.
        n_samples: number of test images and of samples per image.
        n_steps: Langevin steps per noise level.
        dataset: ``"mnist"`` or ``"cifar10"``.
        direction: known half, forwarded to the sampler.

    Raises:
        ValueError: for any other ``dataset``.
    """
    if dataset == "mnist":
        refine_net = NoiseConditionalScoreNetwork(
            use_cuda=use_cuda, n_channels=1, image_size=28, num_classes=10
        )
    elif dataset == "cifar10":
        refine_net = NoiseConditionalScoreNetwork(
            use_cuda=use_cuda, n_channels=3, image_size=32, num_classes=10, ngf=128
        )
    else:
        raise ValueError(
            f'Inpainting is not supported for dataset "{dataset}"; expected "mnist" or "cifar10"'
        )
    # Load once, onto the CPU; load_state_dict copies into the model's tensors.
    states = torch.load(path, map_location="cpu")
    # A two-element list/tuple means the optimizer state was also saved.
    has_optimizer_state = isinstance(states, (list, tuple)) and len(states) == 2
    refine_net.load_state_dict(states[0] if has_optimizer_state else states)
    if use_cuda:
        refine_net.cuda()
    refine_net.eval()

    # First batch of the test split serves as the images to inpaint.
    train_data, test_data = load_dataset(dataset, flatten=False, binarize=False)
    data_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=n_samples,
        shuffle=False,
        num_workers=0,
        drop_last=True,
    )
    test_data = next(iter(data_loader))
    # n_samples random initialisations for each of the n_samples images.
    samples = torch.rand(
        n_samples,
        n_samples,
        refine_net.n_channels,
        refine_net.image_size,
        refine_net.image_size,
    )
    if use_cuda:
        samples = samples.cuda()

    os.makedirs("inpainting", exist_ok=True)
    # save this test
    save_image(
        test_data,
        "inpainting/original_{0}_{1}_{2}_{3}.png".format(
            dataset, direction, n_steps, n_samples
        ),
        nrow=5,
    )
    images, occluded_img = anneal_Langevin_dynamics_inpainting(
        samples,
        test_data,
        refine_net,
        sigmas,
        n_steps_each=n_steps,
        step_lr=0.00002,
        img_size=refine_net.image_size,
        n_channels=refine_net.n_channels,
        direction=direction,
    )
    imgs = []
    print("occluded image shape: ", occluded_img.shape)
    print("test data shape: ", test_data.shape)
    # One-column grids of the occluded inputs and of the ground truth, placed
    # left and right of every frame.
    occluded_grid = make_grid(
        occluded_img,
        nrow=1,
        normalize=True,
        scale_each=True,
    )
    test_grid = make_grid(
        test_data,
        nrow=1,
        normalize=True,
        scale_each=True,
    )
    print(occluded_grid.shape, test_grid.shape)
    for i, sample in tqdm(enumerate(images)):
        # Only every 10th step becomes a frame; skip the grid work otherwise.
        if i % 10 != 0:
            continue
        sample = sample.view(
            n_samples**2,
            refine_net.n_channels,
            refine_net.image_size,
            refine_net.image_size,
        )
        image_grid = make_grid(
            sample,
            nrow=n_samples,
            normalize=True,
            scale_each=True,
        )
        image_grid = torch.cat([occluded_grid, image_grid, test_grid], dim=2)
        # [0, 1] float CHW -> uint8 HWC for PIL.
        im = Image.fromarray(
            image_grid.mul_(255)
            .add_(0.5)
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to("cpu", torch.uint8)
            .numpy()
        )
        imgs.append(im)

    if not imgs:
        # No sampling steps were recorded (e.g. n_steps == 0).
        print("No inpainting frames were produced; nothing to save")
        return
    imgs[-1].save(
        "inpainting/latest_inpainting_{0}_{1}_{2}_{3}_sigma_{4}_{5}.png".format(
            dataset, direction, n_steps, n_samples, sigmas[0], sigmas[-1]
        )
    )
    imgs[0].save(
        "inpainting/inpainting_{0}_{1}_{2}_{3}_sigma_{4}_{5}.gif".format(
            dataset, direction, n_steps, n_samples, sigmas[0], sigmas[-1]
        ),
        save_all=True,
        append_images=imgs[1:],
        optimize=False,
        duration=40,
        loop=0,
    )