In [None]:
import yaml
import os
from PIL import Image
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import Adam
import torchvision.transforms as transforms
import torchvision
from torchvision.utils import make_grid

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# YAML file holding the model / dataset / training configuration.
DDPM_CONFIG = "/kaggle/input/ddpm-configs/other/ddpm-configs/1/ddpm.yaml"

# Where sample grids and model checkpoints are written.
IMG_SAVE_ROOT = "outputs/DDPM"
CKPT_SAVE = "checkpoints/DDPM"

# Create both output directories up front (no error if they already exist).
for _out_dir in (IMG_SAVE_ROOT, CKPT_SAVE):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
# Dataset location and hyper-parameters consumed by the cells below.
ROOT = "/kaggle/input/celebav-randomframes/data"  # image directory handed to the Dataset class
BATCH_SIZE = 2  # batch size handed to the DataLoader
IMG_SIZE = 64  # side length used by the Resize transform
SAMPLE_STEP = 2500  # passed to train() as sample_step: optimisation steps between sample grids
NUM_TIMESTEPS = 1000  # number of diffusion steps T (scheduler, training and sampling)
CKPT_PATH = "/kaggle/input/hueshift-ddpm/pytorch/v4/1/ddpm.pth"  # checkpoint loaded when RESUME is True
RESUME = False  # set to True to initialise the model from CKPT_PATH

In [None]:
# Read the config file #
# A malformed YAML file is reported and then re-raised: continuing would only
# fail a line later with a confusing NameError because `config` was never bound.
with open(DDPM_CONFIG, 'r') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        print(exc)
        raise

# Split the configuration into the three sections used by the notebook.
ddpm_model_config = config['model_config']
ddpm_dataset_config = config['dataset_config']
ddpm_training_config = config['training_config']

In [None]:
class Dataset(Dataset):
    """Image-folder dataset that yields each image converted to the LAB colour space.

    Note: this class deliberately keeps the name ``Dataset`` (shadowing the
    imported ``torch.utils.data.Dataset`` it derives from) because other cells
    refer to it by that name.

    :param root: directory whose entries are all treated as image files
    :param transform: optional list of torchvision transforms; it is wrapped
        in ``transforms.Compose`` and applied to the LAB image
    """

    def __init__(self , root , transform=None):
        self.root = root
        self.files = os.listdir(root)
        self.len = len(self.files)
        if transform is not None:
            self.transforms = transforms.Compose(transform)
        else:
            self.transforms = None

    def __getitem__(self , i):
        """Load the ``i``-th file, convert BGR -> LAB and apply the transforms.

        :raises OSError: if OpenCV cannot decode the file (``cv2.imread``
            signals failure by returning ``None`` instead of raising)
        """
        path = os.path.join(self.root, self.files[i])
        im = cv2.imread(path)
        if im is None:
            # Fail with the offending path instead of an opaque cvtColor error.
            raise OSError(f"cv2.imread could not read image: {path}")
        im = cv2.cvtColor(im, cv2.COLOR_BGR2LAB)
        if self.transforms is not None:
            im = self.transforms(im)
        return im

    def __len__(self):
        """Number of files found in ``root`` when the dataset was created."""
        return self.len

In [None]:
class LinearNoiseSchedule:
    """Linear-beta DDPM noise schedule that diffuses only the two chroma channels.

    Inputs are split along dim 1 into one leading channel (kept untouched and
    passed through as conditioning) and two trailing channels, which are the
    ones that get noised in ``forward`` and denoised in ``backward``.

    :param T: number of diffusion steps
    """

    def __init__(self, T):
        super().__init__()
        beta_start = 1E-4
        beta_end = 0.02

        self.beta = torch.linspace(beta_start,beta_end,T,dtype=torch.float32)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        # for forward process
        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat) # for mean
        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat) #for std_dev

        # for sampling process
        self.one_by_sqrt_alpha = 1. / torch.sqrt(self.alpha) # for mean
        self.one_by_sqrt_one_minus_alpha_hat = 1. / self.sqrt_one_minus_alpha_hat # for mean
        self.sqrt_beta = torch.sqrt(self.beta) # for std_dev

    @staticmethod
    def _expand(coeff, ndim):
        """Append ``ndim - 1`` trailing singleton axes so ``coeff`` broadcasts over an image batch."""
        for _ in range(ndim - 1):
            coeff = coeff.unsqueeze(-1)
        return coeff

    def forward(self, x0, t):
        """Noise the two trailing channels of ``x0`` to timestep ``t``.

        :param x0: clean batch with three channels along dim 1
        :param t: timestep index (or tensor of per-sample indices)
        :return: ``(sample, noise)`` where ``sample`` is the first channel of
            ``x0`` concatenated with the noised channels, and ``noise`` is the
            Gaussian noise that was added (the model's regression target)
        """
        # Separate L from A and B --> c0
        l, c0 = torch.split(x0, [1, 2], dim=1)
        noise = torch.randn_like(c0)

        # reshape to match no of dims (b,) -> (b,c,h,w)
        sqrt_alpha_hat = self._expand(self.sqrt_alpha_hat.to(x0.device)[t], len(x0.shape))
        sqrt_one_minus_alpha_hat = self._expand(
            self.sqrt_one_minus_alpha_hat.to(x0.device)[t], len(x0.shape))

        sample = sqrt_alpha_hat * c0 + sqrt_one_minus_alpha_hat * noise
        return torch.cat([l, sample], dim=1), noise

    def backward(self,xt,noise_pred,t):
        """Take one reverse-diffusion step from ``xt`` using the predicted noise.

        :param xt: current batch (first channel + two noisy channels along dim 1)
        :param noise_pred: model's noise estimate for the two noisy channels
        :param t: timestep, either a Python int / 0-dim tensor shared by the
            whole batch or a 1-D tensor with one timestep per sample
        :return: batch at the previous timestep; samples whose timestep is 0
            receive the posterior mean with no noise added
        """
        l, ct = torch.split(xt, [1, 2], dim=1)
        t = torch.as_tensor(t)
        final_step = (t == 0)

        one_by_sqrt_alpha = self.one_by_sqrt_alpha.to(xt.device)[t]
        beta = self.beta.to(xt.device)[t]
        one_by_sqrt_one_minus_alpha_hat = self.one_by_sqrt_one_minus_alpha_hat.to(xt.device)[t]
        sqrt_beta = self.sqrt_beta.to(xt.device)[t]

        # reshape to match no of dims (b,) -> (b,c,h,w)
        ndim = len(xt.shape)
        one_by_sqrt_alpha = self._expand(one_by_sqrt_alpha, ndim)
        one_by_sqrt_one_minus_alpha_hat = self._expand(one_by_sqrt_one_minus_alpha_hat, ndim)
        beta = self._expand(beta, ndim)

        mean = one_by_sqrt_alpha * (ct - beta * one_by_sqrt_one_minus_alpha_hat * noise_pred)

        # Every sample is at the last step: return the mean without drawing noise.
        if bool(final_step.all()):
            return torch.cat([l, mean], dim=1)

        # Zero the standard deviation for samples already at t == 0 so a mixed
        # batch stays correct; for a shared non-zero t this factor is exactly 1.
        keep_noise = (~final_step).to(sqrt_beta)
        std_dev = self._expand(sqrt_beta * keep_noise, ndim)

        z = torch.randn_like(ct)
        return torch.cat([l, mean + std_dev * z], dim=1)

In [None]:
def get_time_embedding(time_steps, temb_dim):
    r"""
    Sinusoidal embedding of diffusion time steps.
    The first half of every embedding holds the sines and the second
    half the cosines of ``t / 10000^(i / (temb_dim / 2))``.
    :param time_steps: 1D tensor of length batch size
    :param temb_dim: Dimension of the embedding (must be even)
    :return: BxD embedding representation of B time steps
    """
    assert temb_dim % 2 == 0, "time embedding dimension must be divisible by 2"

    half_dim = temb_dim // 2

    # One divisor per sin/cos pair: 10000^(i / half_dim), i = 0 .. half_dim - 1
    exponents = torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=time_steps.device
    ) / half_dim
    divisors = 10000 ** exponents

    # (B, 1) / (half_dim,) broadcasts to (B, half_dim)
    angles = time_steps[:, None] / divisors
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

In [None]:
class DownBlock(nn.Module):
    r"""
    Encoder stage of the U-Net.
    Runs ``num_layers`` repetitions of
    1. Resnet block (optionally biased by the time embedding)
    2. Self-attention block (only when ``attn`` is set)
    and finishes with an optional strided-conv downsample.
    ``cross_attn``, ``context_dim`` and the ``context`` argument of
    ``forward`` are stored / accepted but not used by this block.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim,
                 down_sample, num_heads, num_layers, attn, norm_channels, cross_attn=False, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.down_sample = down_sample
        self.attn = attn
        self.context_dim = context_dim
        self.cross_attn = cross_attn
        self.t_emb_dim = t_emb_dim

        # Only the first layer sees `in_channels`; later layers are out -> out.
        widths_in = [in_channels if k == 0 else out_channels for k in range(num_layers)]

        def norm_act_conv(width):
            # GroupNorm -> SiLU -> 3x3 conv mapping `width` to `out_channels`
            return nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            )

        self.resnet_conv_first = nn.ModuleList([norm_act_conv(w) for w in widths_in])

        if self.t_emb_dim is not None:
            # Projects the time embedding to a per-channel bias for each layer
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(self.t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList(
            [norm_act_conv(out_channels) for _ in range(num_layers)]
        )

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )

        # 1x1 conv so the skip path matches the resnet output width
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(w, out_channels, kernel_size=1) for w in widths_in]
        )

        # DownSampling: 4x4 conv, stride 2, padding 1 (or a no-op)
        if self.down_sample:
            self.down_sample_conv = nn.Conv2d(out_channels, out_channels, 4, 2, 1)
        else:
            self.down_sample_conv = nn.Identity()

    def forward(self, x, t_emb=None, context=None):
        feat = x
        for idx in range(self.num_layers):
            # Resnet block of Unet
            skip = feat
            feat = self.resnet_conv_first[idx](feat)
            if self.t_emb_dim is not None:
                time_bias = self.t_emb_layers[idx](t_emb)
                feat = feat + time_bias[:, :, None, None]
            feat = self.resnet_conv_second[idx](feat)
            feat = feat + self.residual_input_conv[idx](skip)  # residual connection

            if not self.attn:
                continue

            # Attention block of Unet: flatten space into a token axis
            b, c, h, w = feat.shape
            tokens = self.attention_norms[idx](feat.reshape(b, c, h * w))
            tokens = tokens.transpose(1, 2)
            attended, _ = self.attentions[idx](tokens, tokens, tokens)
            feat = feat + attended.transpose(1, 2).reshape(b, c, h, w)

        # Downsample
        return self.down_sample_conv(feat)


class MidBlock(nn.Module):
    r"""
    Bottleneck stage of the U-Net.
    One leading resnet block followed by ``num_layers`` repetitions of
    1. Self-attention block
    2. Resnet block (optionally biased by the time embedding)
    ``cross_attn``, ``context_dim`` and the ``context`` argument of
    ``forward`` are stored / accepted but not used by this block.
    """

    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, num_layers, norm_channels, cross_attn=None, context_dim=None):
        super().__init__()
        self.num_layers = num_layers
        self.t_emb_dim = t_emb_dim
        self.context_dim = context_dim
        self.cross_attn = cross_attn

        # There is one more resnet block than there are attention blocks.
        n_resnets = num_layers + 1
        widths_in = [in_channels if k == 0 else out_channels for k in range(n_resnets)]

        def norm_act_conv(width):
            # GroupNorm -> SiLU -> 3x3 conv mapping `width` to `out_channels`
            return nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            )

        self.resnet_conv_first = nn.ModuleList([norm_act_conv(w) for w in widths_in])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(n_resnets)
            ])

        self.resnet_conv_second = nn.ModuleList(
            [norm_act_conv(out_channels) for _ in range(n_resnets)]
        )

        self.attention_norms = nn.ModuleList(
            [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
        )
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
             for _ in range(num_layers)]
        )

        # 1x1 conv so the skip path matches the resnet output width
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(w, out_channels, kernel_size=1) for w in widths_in]
        )

    def _resnet(self, idx, feat, t_emb):
        # Resnet block `idx`: conv -> (+ time bias) -> conv, plus the 1x1 skip path
        skip = feat
        feat = self.resnet_conv_first[idx](feat)
        if self.t_emb_dim is not None:
            feat = feat + self.t_emb_layers[idx](t_emb)[:, :, None, None]
        feat = self.resnet_conv_second[idx](feat)
        return feat + self.residual_input_conv[idx](skip)

    def _attend(self, idx, feat):
        # Self-attention block `idx` over the flattened spatial positions, added residually
        b, c, h, w = feat.shape
        tokens = self.attention_norms[idx](feat.reshape(b, c, h * w))
        tokens = tokens.transpose(1, 2)
        attended, _ = self.attentions[idx](tokens, tokens, tokens)
        return feat + attended.transpose(1, 2).reshape(b, c, h, w)

    def forward(self, x, t_emb=None, context=None):
        # First resnet block
        out = self._resnet(0, x, t_emb)

        # Then alternate attention and resnet blocks
        for i in range(self.num_layers):
            out = self._attend(i, out)
            out = self._resnet(i + 1, out, t_emb)

        return out


class UpBlock(nn.Module):
    r"""
    Decoder stage of the U-Net.
    Sequence of following blocks
    1. Upsample (transposed conv, optional)
    2. Concatenate the matching down-block output (when one is given)
    3. Resnet block (optionally biased by the time embedding)
    4. Self-attention block (only when ``attn`` is set)
    Steps 3-4 repeat ``num_layers`` times. With ``task == "unet"`` the
    upsample conv works on ``in_channels // 2`` channels, otherwise on
    ``in_channels`` channels.
    """

    def __init__(self, task, in_channels, out_channels, t_emb_dim,
                 up_sample, num_heads, num_layers, attn, norm_channels):
        super().__init__()
        self.task = task
        self.num_layers = num_layers
        self.up_sample = up_sample
        self.t_emb_dim = t_emb_dim
        self.attn = attn

        # Only the first layer sees `in_channels`; later layers are out -> out.
        widths_in = [in_channels if k == 0 else out_channels for k in range(num_layers)]

        def norm_act_conv(width):
            # GroupNorm -> SiLU -> 3x3 conv mapping `width` to `out_channels`
            return nn.Sequential(
                nn.GroupNorm(norm_channels, width),
                nn.SiLU(),
                nn.Conv2d(width, out_channels, kernel_size=3, stride=1, padding=1),
            )

        self.resnet_conv_first = nn.ModuleList([norm_act_conv(w) for w in widths_in])

        if self.t_emb_dim is not None:
            self.t_emb_layers = nn.ModuleList([
                nn.Sequential(nn.SiLU(), nn.Linear(t_emb_dim, out_channels))
                for _ in range(num_layers)
            ])

        self.resnet_conv_second = nn.ModuleList(
            [norm_act_conv(out_channels) for _ in range(num_layers)]
        )

        if self.attn:
            self.attention_norms = nn.ModuleList(
                [nn.GroupNorm(norm_channels, out_channels) for _ in range(num_layers)]
            )
            self.attentions = nn.ModuleList(
                [nn.MultiheadAttention(out_channels, num_heads, batch_first=True)
                 for _ in range(num_layers)]
            )

        # 1x1 conv so the skip path matches the resnet output width
        self.residual_input_conv = nn.ModuleList(
            [nn.Conv2d(w, out_channels, kernel_size=1) for w in widths_in]
        )

        # Width the upsampling conv operates on (see class docstring)
        up_width = in_channels // 2 if self.task == "unet" else in_channels
        if self.up_sample:
            self.up_sample_conv = nn.ConvTranspose2d(up_width, up_width, 4, 2, 1)
        else:
            self.up_sample_conv = nn.Identity()

    def forward(self, x, out_down=None, t_emb=None):
        # Upsample
        feat = self.up_sample_conv(x)

        # Concat with the down-block output when the caller provides one
        # (the original author notes this is used for diffusion, not for an autoencoder).
        if out_down is not None:
            feat = torch.cat([feat, out_down], dim=1)

        for idx in range(self.num_layers):
            # Resnet Block
            skip = feat
            feat = self.resnet_conv_first[idx](feat)
            if self.t_emb_dim is not None:
                time_bias = self.t_emb_layers[idx](t_emb)
                feat = feat + time_bias[:, :, None, None]
            feat = self.resnet_conv_second[idx](feat)
            feat = feat + self.residual_input_conv[idx](skip)

            if not self.attn:
                continue

            # Self Attention over the flattened spatial positions
            b, c, h, w = feat.shape
            tokens = self.attention_norms[idx](feat.reshape(b, c, h * w))
            tokens = tokens.transpose(1, 2)
            attended, _ = self.attentions[idx](tokens, tokens, tokens)
            feat = feat + attended.transpose(1, 2).reshape(b, c, h, w)
        return feat

In [None]:
class Unet(nn.Module):
    r"""
    Time-conditioned U-Net built from DownBlocks, MidBlocks and UpBlocks.

    :param in_channels: number of channels of the input ``x``
    :param out_channels: number of channels of the prediction
    :param model_config: mapping with the keys read in ``__init__``
        (``DOWN_CHANNELS``, ``MID_CHANNELS``, ``TIME_EMB_DIM``, ``DOWN_SAMPLE``,
        ``NUM_DOWN_LAYERS``, ``NUM_MID_LAYERS``, ``NUM_UP_LAYERS``, ``ATTN``,
        ``NORM_CHANNELS``, ``NUM_HEADS``, ``CONV_OUT_CHANNELS`` and, when
        ``condition`` is set, ``CONDITION.COND_CHANNELS``)
    :param condition: when True the network expects an extra spatial
        conditioning tensor that is concatenated to ``x`` before the first conv
    """
    def __init__(self, in_channels, out_channels, model_config, condition=False):
        super().__init__()
        self.condition = condition
        # Trailing comments are example values noted by the original author.
        self.down_channels = model_config['DOWN_CHANNELS'] # [256, 384, 512, 768]
        self.mid_channels = model_config['MID_CHANNELS'] # [768, 512]
        self.t_emb_dim = model_config['TIME_EMB_DIM'] # 512
        self.down_sample = model_config['DOWN_SAMPLE'] # [True, True, True]
        self.num_down_layers = model_config['NUM_DOWN_LAYERS'] # 2
        self.num_mid_layers = model_config['NUM_MID_LAYERS'] # 2
        self.num_up_layers = model_config['NUM_UP_LAYERS'] # 2
        self.attns = model_config['ATTN'] # [True, True, True]
        self.norm_channels = model_config['NORM_CHANNELS'] # 32
        self.num_heads = model_config['NUM_HEADS'] # 16
        self.conv_out_channels = model_config['CONV_OUT_CHANNELS'] # 128

        # The mid stage must start at the last down width and end at the width
        # the first up block expects; one flag per down block.
        assert self.mid_channels[0] == self.down_channels[-1]
        assert self.mid_channels[-1] == self.down_channels[-2]
        assert len(self.down_sample) == len(self.down_channels) - 1
        assert len(self.attns) == len(self.down_channels) - 1

        # Spatial Conditioning: exactly one of the two input convs is created.
        if self.condition:
            self.cond_channels = model_config['CONDITION']['COND_CHANNELS']
            self.conv_in_concat = nn.Conv2d(in_channels + self.cond_channels,
                                            self.down_channels[0], kernel_size=3, padding=1)
        else:
            self.conv_in = nn.Conv2d(in_channels, self.down_channels[0], kernel_size=3, padding=1)

        # Initial projection from sinusoidal time embedding
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        self.downs = nn.ModuleList([])
        for i in range(len(self.down_channels) - 1):
            self.downs.append(DownBlock(self.down_channels[i], self.down_channels[i + 1],
                                        t_emb_dim=self.t_emb_dim,down_sample=self.down_sample[i],
                                        num_heads=self.num_heads,
                                        num_layers=self.num_down_layers,
                                        attn=self.attns[i],
                                        norm_channels=self.norm_channels))

        self.mids = nn.ModuleList([])
        for i in range(len(self.mid_channels) - 1):
            self.mids.append(MidBlock(self.mid_channels[i], self.mid_channels[i + 1], self.t_emb_dim,
                                      num_heads=self.num_heads,
                                      num_layers=self.num_mid_layers,
                                      norm_channels=self.norm_channels))

        # Up blocks mirror the down blocks; each takes twice the width because
        # the matching down-block input is concatenated in.
        self.ups = nn.ModuleList([])
        for i in reversed(range(len(self.down_channels) - 1)):
            self.ups.append(UpBlock("unet",self.down_channels[i] * 2, self.down_channels[i - 1] if i != 0 else self.conv_out_channels,
                                    self.t_emb_dim, up_sample=self.down_sample[i],
                                        num_heads=self.num_heads,
                                        num_layers=self.num_up_layers,
                                        attn=self.attns[i],
                                        norm_channels=self.norm_channels))

        self.norm_out = nn.GroupNorm(self.norm_channels, self.conv_out_channels)
        self.conv_out = nn.Conv2d(self.conv_out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, t, cond_in = None):
        """Predict the output for input ``x`` at timestep(s) ``t``.

        :param x: input batch, B x C x H x W
        :param t: timestep; a 1-D tensor with one entry per sample, or a
            scalar (int / 0-dim tensor) shared by the whole batch
        :param cond_in: spatial conditioning tensor; required when the model
            was built with ``condition=True`` and rejected otherwise. It is
            resized to the spatial size of ``x`` before concatenation.
        :raises ValueError: if ``cond_in`` does not match the ``condition``
            setting the model was built with
        :return: prediction with ``out_channels`` channels
        """
        if self.condition:
            if cond_in is None:
                raise ValueError("model was built with condition=True but cond_in is None")
            # Bring the conditioning map to the spatial size of x, then concatenate.
            cond_in = nn.functional.interpolate(cond_in, size=x.shape[-2:])
            out = self.conv_in_concat(torch.cat([x, cond_in], dim=1))
        else:
            if cond_in is not None:
                raise ValueError("cond_in was given but the model was built with condition=False")
            out = self.conv_in(x)

        # t_emb -> B x t_emb_dim
        t = torch.as_tensor(t, device=x.device).long()
        if t.dim() == 0:
            # A scalar timestep applies to every sample in the batch.
            t = t.expand(x.shape[0])
        t_emb = get_time_embedding(t, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # Keep the input of every down block for the skip connections.
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb)

        for mid in self.mids:
            out = mid(out, t_emb)

        # Consume the stored skip tensors in reverse order.
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)

        out = self.norm_out(out)
        out = nn.functional.silu(out)
        out = self.conv_out(out)
        return out

In [None]:
def lab_to_rgb(lab_image):
    """Convert one LAB image with values scaled to [0, 1] into an 8-bit RGB image.

    :param lab_image: array in H x W x 3 layout holding LAB values in [0, 1];
        values outside that range are clipped, because casting them to
        ``uint8`` unclipped would wrap around and produce garbage colours
    :return: ``uint8`` RGB image as returned by ``cv2.cvtColor``
    """
    # OpenCV expects the range [0, 255] for 8-bit colour images
    lab = (np.clip(lab_image, 0.0, 1.0) * 255).astype(np.uint8)
    return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)  # LAB -> RGB

In [None]:
def _lab_batch_to_rgb_tensor(batch):
    """Convert a B x 3 x H x W LAB batch scaled to [0, 1] into an RGB float tensor in [0, 1]."""
    # Clamp first: sampled (and resized) values can leave [0, 1], and the
    # uint8 cast inside lab_to_rgb would otherwise wrap around.
    arr = batch.detach().clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy()  # OpenCV layout
    rgb = np.array([lab_to_rgb(img) for img in arr])
    return torch.tensor(rgb).permute(0, 3, 1, 2).float() / 255.0


@torch.no_grad()
def sample(L, gt, model, scheduler, sample_no):
    """Colourise ``L`` by running the full reverse diffusion and save a comparison grid.

    The two chroma channels start as Gaussian noise and are denoised for
    ``NUM_TIMESTEPS`` steps while ``L`` is kept as the first channel. The grid
    (ground truth on the first row, samples on the second) is written to
    ``IMG_SAVE_ROOT`` and shown with matplotlib.

    :param L: B x 1 x H x W lightness channel used as conditioning
    :param gt: B x 3 x H x W ground-truth LAB batch scaled to [0, 1]
    :param model: noise-prediction network called as ``model(xt, timestep)``
    :param scheduler: object providing ``backward(xt, noise_pred, t)``
    :param sample_no: running index used in the output file name
    """
    num_samples = L.shape[0]
    device = L.device
    # Take the spatial size from L itself so any input resolution works.
    out_shape = (num_samples, 2, *L.shape[-2:])
    xt = torch.cat([L, torch.randn(out_shape).to(device)], dim=1)

    for t in reversed(range(NUM_TIMESTEPS)):
        # Get prediction of noise
        timestep = torch.full((num_samples,), t, dtype=torch.long, device=device)
        noise_pred = model(xt, timestep)

        # Use scheduler to get xt-1
        xt = scheduler.backward(xt, noise_pred, t)

    save_output = _lab_batch_to_rgb_tensor(xt)
    save_input = _lab_batch_to_rgb_tensor(gt)

    # Ground truths on top, generated samples below.
    grid = make_grid(torch.cat([save_input, save_output], dim=0), nrow=num_samples)
    img = torchvision.transforms.ToPILImage()(grid)

    img.save(f"{IMG_SAVE_ROOT}/{str(sample_no).zfill(10)}.jpg")
    plt.imshow(img)
    plt.show()

In [None]:
# Noise schedule shared by training (forward noising) and sampling (reverse steps).
scheduler = LinearNoiseSchedule(T=NUM_TIMESTEPS)

In [None]:
# Transforms applied to every LAB image: convert to a tensor, then resize.
# InterpolationMode replaces the PIL integer constant, which torchvision
# deprecates as a value for `interpolation`.
transform = [
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]


# Shuffled loader over the image folder.
data_loader = DataLoader(
    Dataset(ROOT,transform),
    batch_size= BATCH_SIZE,
    shuffle = True,
    num_workers = 2
)

In [None]:
# Three input channels, two predicted noise channels.
model = Unet(in_channels = 3, out_channels = 2, model_config = ddpm_model_config).to(DEVICE)

if RESUME: 
    print("loading state dict")
    # map_location lets a checkpoint saved on a GPU load on whatever DEVICE is in use.
    model.load_state_dict(torch.load(CKPT_PATH, map_location=DEVICE))

# Use GPUs 0 and 1 when they exist instead of assuming both are present
# (a hard-coded [0, 1] fails on a single-GPU machine). The model is always
# wrapped so that `model.module` stays valid for the rest of the notebook.
available_gpu_ids = [i for i in (0, 1) if i < torch.cuda.device_count()]
model = torch.nn.DataParallel(model, device_ids = available_gpu_ids or None)

In [None]:
# Memory footprint of the model: bytes held by parameters plus bytes held by buffers.
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())

# 2**20 bytes per MB
size_all_mb = (param_size + buffer_size) / 2**20
print('model size: {:.3f}MB'.format(size_all_mb))

In [None]:
# Adam over all model parameters; the training loss is the MSE between the
# predicted noise and the noise that was actually added.
optimizer = Adam(model.parameters(),lr=1E-5)
criterion = torch.nn.MSELoss()

In [None]:
def train(
        num_epochs,
        data_loader,
        optimizer,
        T,
        scheduler,
        model,
        criterion,
        sample_step,
):
    """Train the noise-prediction model and periodically save sample grids.

    For every batch a random timestep is drawn per sample, the chroma channels
    are noised with ``scheduler.forward`` and the model is trained to predict
    that noise. A checkpoint is written to ``CKPT_SAVE`` after every epoch.

    :param num_epochs: number of passes over ``data_loader``
    :param data_loader: iterable yielding B x 3 x H x W image batches
    :param optimizer: optimizer stepping the model parameters
    :param T: number of diffusion timesteps (timesteps are drawn from [0, T))
    :param scheduler: object providing ``forward(im, t)``; also used for sampling
    :param model: network called as ``model(noisy_im, t)``; may be wrapped in
        DataParallel / DistributedDataParallel or be a plain module
    :param criterion: loss between predicted and true noise
    :param sample_step: a sample grid is produced on the first step and then
        every ``sample_step`` steps
    """
    step_count = 0
    sample_no = 0
    for epoch in range(num_epochs):
        losses = []
        for im in tqdm(data_loader):
            step_count += 1
            optimizer.zero_grad()
            im = im.float().to(DEVICE)
            # First channel (lightness) is the conditioning shown to the sampler.
            L, _ = torch.split(im, [1, 2], dim=1)

            t = torch.randint(0, T, (im.shape[0],)).to(DEVICE)

            noisy_im, noise = scheduler.forward(im, t)
            noise_pred = model(noisy_im, t)

            loss = criterion(noise_pred, noise)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

            if step_count % sample_step == 0 or step_count == 1:
                with torch.no_grad():
                    sample(L, im, model, scheduler, sample_no)
                sample_no += 1

        # An empty loader yields no losses; report nan without a numpy warning.
        mean_loss = np.mean(losses) if losses else float("nan")
        print(f"{epoch} Loss {mean_loss}")

        # Save the underlying module's weights so the checkpoint has no
        # "module." prefix, whether or not the model is wrapped.
        is_wrapped = isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
        to_save = model.module if is_wrapped else model
        torch.save(to_save.state_dict(), f"{CKPT_SAVE}/ddpm.pth")

In [None]:
# Run one training epoch with the loader, model, optimizer and schedule built above.
train(
    num_epochs = 1,
    data_loader = data_loader,
    optimizer = optimizer,
    T = NUM_TIMESTEPS,
    scheduler = scheduler,
    model = model,
    criterion = criterion,
    sample_step = SAMPLE_STEP
)