In [1]:
import torch
import torch.utils
import torch.utils.data
from tqdm.auto import tqdm
from torch import nn
import argparse
import torch.nn.functional as F
import utils
import dataset
import os
import matplotlib.pyplot as plt
import numpy as np
import math
from typing import Optional, Union, List, Tuple
from helperClasses import TimeEmbedding, UNetModel

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class DDPM(nn.Module):
    """Noise-prediction network for a DDPM.

    Composed of two separate submodules:
      * ``time_embed`` -- a ``TimeEmbedding`` built from ``n_steps``; it can be
        learned or a fixed function, depending on ``TimeEmbedding``.
      * ``model`` -- a ``UNetModel`` that takes the data and the embedded
        timestep and produces the noise prediction.
    """

    def __init__(
        self,
        n_dim=3,
        n_steps=200,
        num_channels=128,
        *,
        num_downsampling=5,
        num_blocks=2,
        num_intermediate_channels=64,
        num_residual_channels=64,
        num_feature_channels=64,
        num_top_channels=64,
    ):
        """Build the timestep embedding and the U-Net.

        Args:
            n_dim: int, number of data channels. Passed to ``UNetModel`` as both
                ``imageChannels`` and ``numOutputChannels``, so the prediction
                has the same channel count as the input.
            n_steps: int, number of steps in the diffusion process; forwarded
                to ``TimeEmbedding``.
            num_channels: int, forwarded to ``UNetModel`` as ``numChannels``.
            num_downsampling: int, forwarded as ``numDownsampling``.
            num_blocks: int, forwarded as ``numBlocks``.
            num_intermediate_channels: int, forwarded as
                ``numIntermediateChannels``.
            num_residual_channels: int, forwarded as ``numResidualChannels``.
            num_feature_channels: int, forwarded as ``numFeatureChannels``.
            num_top_channels: int, forwarded as ``numTopChannels``.

        The keyword-only U-Net hyperparameters default to the values that were
        previously hard-coded. With the defaults, the resulting architecture
        (and therefore the state_dict layout) is unchanged.
        """
        super().__init__()
        # Plain ints kept for introspection (e.g. by a sampling loop); they are
        # not parameters or buffers, so they do not appear in state_dict.
        self.n_dim = n_dim
        self.n_steps = n_steps
        self.time_embed = TimeEmbedding(n_steps)
        self.model = UNetModel(
            imageChannels=n_dim,
            numChannels=num_channels,
            numDownsampling=num_downsampling,
            numBlocks=num_blocks,
            numIntermediateChannels=num_intermediate_channels,
            numResidualChannels=num_residual_channels,
            numFeatureChannels=num_feature_channels,
            numTopChannels=num_top_channels,
            numOutputChannels=n_dim,
        )

    def forward(self, x, t):
        """Predict the noise in ``x`` at timestep ``t``.

        Args:
            x: torch.Tensor, the (noisy) input batch. Its channel dimension
                must equal ``n_dim``, because that is ``UNetModel``'s
                ``imageChannels``. NOTE(review): presumably an image batch of
                shape (batch_size, n_dim, H, W) -- confirm against
                ``UNetModel``; H and W may need to be divisible by
                2**num_downsampling.
            t: torch.Tensor, the timestep for each sample, presumably of shape
                [batch_size] with integer steps in [0, n_steps) -- TODO
                confirm against ``TimeEmbedding``.

        Returns:
            torch.Tensor, the ``UNetModel`` output, i.e. the predicted noise,
            with ``n_dim`` output channels.
        """
        # Embed the raw timestep first; the U-Net consumes the embedding,
        # not the integer step.
        t_emb = self.time_embed(t)
        return self.model(x, t_emb)