In [None]:
from typing import Dict, Optional, Tuple
from sympy import Ci
from tqdm import tqdm
import os

from PIL import Image

import matplotlib.pyplot as plt
import numpy as np


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision

from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.utils import save_image, make_grid

from mindiffusion.unet import NaiveUnet
from mindiffusion.ddpm import DDPM

In [None]:
class MapsDataset(Dataset):
    """Highway maps dataset: every image file in a single flat directory.

    Each item is the image decoded as 3-channel RGB, scaled from uint8 to
    floats in [0, 1], with the optional transform applied afterwards.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Keep only regular, non-hidden files (skips e.g. `.ipynb_checkpoints/`
        # or `.DS_Store`, which read_image cannot decode) and sort them so the
        # index -> file mapping is deterministic across filesystems and runs.
        self.im_names = sorted(
            name
            for name in os.listdir(self.root_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.root_dir, name))
        )

    def __len__(self):
        """Number of image files found in ``root_dir``."""
        return len(self.im_names)

    def __getitem__(self, idx):
        """Load image ``idx`` as an RGB float tensor in [0, 1], then transform it."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.im_names[idx])
        # Force 3 channels: grayscale or RGBA files would otherwise break the
        # 3-channel Normalize and the 3-channel model downstream.
        image = torchvision.io.read_image(
            img_name, mode=torchvision.io.ImageReadMode.RGB
        ) / 255

        if self.transform:
            image = self.transform(image)

        return image

In [None]:
def train_maps(
    path_to_dir, n_epoch: int = 100, device: str = "cuda:0",
    load_pth: Optional[str] = None, Flip: bool = False,
    lr: float = 5e-5, batch_size: int = 128, sample_every: int = 10,
    ckpt_path: str = "./data/ddpm_maps_10000.pth",
) -> None:
    """Train (or resume training) a DDPM on the images in ``path_to_dir``.

    Args:
        path_to_dir: Directory of training images, read via ``MapsDataset``.
        n_epoch: Number of epochs to run.
        device: Torch device string the model and batches are moved to.
        load_pth: Optional state-dict checkpoint to resume from.
        Flip: If True, add random vertical/horizontal flips as augmentation.
        lr: Adam learning rate.
        batch_size: DataLoader batch size.
        sample_every: Save a sample grid every ``sample_every`` epochs
            (0 or negative disables sampling).
        ckpt_path: Where the model state dict is written after every epoch.

    Raises:
        ValueError: If ``path_to_dir`` contains no images.
    """
    ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=128), betas=(1e-4, 0.02), n_T=1000)

    if load_pth is not None:
        # Load onto CPU first so checkpoints saved on another device still
        # load; the model is moved to `device` right after.
        ddpm.load_state_dict(torch.load(load_pth, map_location="cpu"))
    ddpm.to(device)

    # Dataset images are in [0, 1]; Normalize with mean=std=0.5 maps them to [-1, 1].
    tf_steps = [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    if Flip:
        tf_steps += [transforms.RandomVerticalFlip(), transforms.RandomHorizontalFlip()]
    tf = transforms.Compose(tf_steps)

    maps_dataset = MapsDataset(root_dir=path_to_dir, transform=tf)
    if len(maps_dataset) == 0:
        raise ValueError(f"No images found in {path_to_dir!r}")

    dataloader = DataLoader(maps_dataset, batch_size=batch_size,
                            shuffle=True, num_workers=2)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lr)

    # Make sure the checkpoint directory exists before the first save.
    ckpt_dir = os.path.dirname(ckpt_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for i in range(n_epoch):
        print(f"Epoch {i} : ")
        ddpm.train()

        pbar = tqdm(dataloader)
        loss_ema = None
        for x in pbar:
            optim.zero_grad()
            x = x.to(device)
            loss = ddpm(x)
            loss.backward()
            # Exponential moving average of the loss, for a smoother progress bar.
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.96 * loss_ema + 0.04 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        ddpm.eval()
        if sample_every > 0 and (i + 1) % sample_every == 0:
            with torch.no_grad():
                # Sample at the training resolution (taken from the last batch)
                # so generated and real images can share one grid.
                xh = ddpm.sample(8, tuple(x.shape[1:]), device)
                xset = torch.cat([xh, x[:8]], dim=0)
                grid = make_grid(xset, normalize=True, nrow=4, scale_each=True)
                save_image(grid, f"ddpm_sample_maps_10000_{i}.png")

        # Save the model after every epoch.
        torch.save(ddpm.state_dict(), ckpt_path)

In [None]:
path_to_dir = "./data/my_dataset"
load_pth = "./data/ddpm_maps_10000.pth"
# Resume from the checkpoint when one exists; otherwise start from scratch so
# this cell also runs on a fresh checkout (torch.load fails on a missing file).
train_maps(
    path_to_dir,
    Flip=True,
    load_pth=load_pth if os.path.exists(load_pth) else None,
    n_epoch=100,
)

In [None]:
# Continue training from the saved checkpoint for 10 more epochs at a lower
# learning rate (1e-5, versus train_maps' default of 5e-5).
path_to_dir = "./data/my_dataset"
load_pth = "./data/ddpm_maps_10000.pth"
train_maps(
    path_to_dir,
    n_epoch=10,
    load_pth=load_pth,
    Flip=True,
    lr=1e-5,
)

In [None]:
# Fall back to CPU so sampling still works on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

ddpm = DDPM(eps_model=NaiveUnet(3, 3, n_feat=128), betas=(1e-4, 0.02), n_T=1000)
# Load onto CPU first: the checkpoint may have been saved from a CUDA device.
ddpm.load_state_dict(torch.load("./data/ddpm_maps_10000.pth", map_location="cpu"))
ddpm.to(device)
ddpm.eval()

with torch.no_grad():
    # 25 generated samples, saved as a 5x5 grid.
    xh = ddpm.sample(25, (3, 64, 64), device)
    grid = make_grid(xh, normalize=True, nrow=5, scale_each=True)
    save_image(grid, "ddpm_sample_maps.png")