# Data Setup

In [9]:
"""
Contains functionality for creating PyTorch DataLoaders for 
image classification data.
"""
import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

NUM_WORKERS = os.cpu_count()

def create_dataloaders(
    train_dir: str, 
    test_dir: str, 
    transform: transforms.Compose, 
    batch_size: int, 
    num_workers: int=NUM_WORKERS
):
  """Creates the training DataLoader from an image-folder directory.

  Wraps `train_dir` in a torchvision `ImageFolder` dataset and returns a
  shuffled PyTorch DataLoader over it.

  Note: despite the plural name, only the training DataLoader is built and
  returned. `test_dir` is accepted so existing call sites keep working, but
  it is not read by this function.

  Args:
    train_dir: Path to training directory.
    test_dir: Path to testing directory (currently unused).
    transform: torchvision transforms to perform on the training data.
    batch_size: Number of samples per batch in the DataLoader.
    num_workers: An integer for number of workers for the DataLoader.
      `None` (which `os.cpu_count()` can return) falls back to 0, i.e.
      data is loaded in the main process.

  Returns:
    The training DataLoader (a single object, not a tuple).

    Example usage:
      train_dataloader = create_dataloaders(train_dir="path/to/train_dir",
                                            test_dir="path/to/test_dir",
                                            transform=some_transform,
                                            batch_size=32,
                                            num_workers=4)
  """
  # os.cpu_count() may return None; DataLoader needs a non-negative int.
  workers = 0 if num_workers is None else num_workers

  # Use ImageFolder to create the training dataset
  train_data = datasets.ImageFolder(train_dir, transform=transform)

  # Turn images into a data loader
  train_dataloader = DataLoader(
      train_data,
      batch_size=batch_size,
      shuffle=True,
      num_workers=workers,
      pin_memory=True,
  )

  return train_dataloader

In [11]:
# Only tensor conversion is applied here (no resizing or augmentation).
transform = transforms.Compose([transforms.ToTensor()])

# NOTE(review): the same directory is passed as both train_dir and test_dir.
train_dataloader = create_dataloaders(
    train_dir="/Users/rachan/Desktop/CelebAx2_train",
    test_dir="/Users/rachan/Desktop/CelebAx2_train",
    transform=transform,
    batch_size=32,
)

In [12]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x7fd7d89781f0>

In [16]:
# Pull a single batch from the training DataLoader for inspection.
img, label = next(iter(train_dataloader))

In [17]:
# Display the shapes of the image batch and the label batch.
img.shape, label.shape

(torch.Size([32, 3, 128, 128]), torch.Size([32]))

# U-Net Model

In [67]:
from torch import nn
import math


class Block(nn.Module):
    """Time-conditioned two-convolution stage followed by a stride-2 resample.

    Args:
        in_ch: Nominal input channel count. When ``up`` is true the first
            convolution is built for ``2 * in_ch`` channels, so the caller is
            expected to pass an input with twice as many channels.
        out_ch: Channel count produced by both convolutions and the resample.
        time_emb_dim: Size of the time-embedding vector fed to ``forward``.
        up: If true, resample with ``ConvTranspose2d`` (stride 2); otherwise
            with a stride-2 ``Conv2d``.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        # Submodules are created in a fixed order so that parameter
        # registration and random initialisation stay stable.
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if not up:
            self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            self.transform = nn.Conv2d(
                out_ch, out_ch, kernel_size=4, stride=2, padding=1
            )
        else:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, kernel_size=3, padding=1)
            self.transform = nn.ConvTranspose2d(
                out_ch, out_ch, kernel_size=4, stride=2, padding=1
            )
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply conv -> time injection -> conv -> resample.

        Args:
            x: Feature map fed to the first convolution.
            t: Time embedding whose last dimension is ``time_emb_dim``.

        Returns:
            The output of the stride-2 (transposed) convolution.
        """
        # First stage: convolution, then ReLU, then batch norm.
        features = self.conv1(x)
        features = self.bnorm1(self.relu(features))

        # Project the time embedding to out_ch and append two singleton axes
        # so it broadcasts over the spatial dimensions.
        time_bias = self.relu(self.time_mlp(t))
        time_bias = time_bias[..., None, None]
        features = features + time_bias

        # Second stage: convolution, then ReLU, then batch norm.
        features = self.conv2(features)
        features = self.bnorm2(self.relu(features))

        # Down- or up-sample.
        return self.transform(features)


class SinusoidalPositionEmbeddings(nn.Module):
    """Turns a tensor of timesteps into sine/cosine feature vectors.

    Args:
        dim: Target embedding size. ``dim // 2`` frequencies are used, each
            contributing one sine and one cosine feature.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``.

        Args:
            time: Tensor of timesteps; indexed as ``time[:, None]``, so a
                1-D tensor is the expected input.

        Returns:
            Tensor whose last axis holds all sine features followed by all
            cosine features (``2 * (dim // 2)`` values).
        """
        n_freqs = self.dim // 2
        # Log-spaced frequencies running from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        positions = torch.arange(n_freqs, device=time.device)
        freqs = torch.exp(positions * -log_step)
        # Outer product: one row per timestep, one column per frequency.
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        # TODO: Double check the ordering here
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.

    The network embeds the timestep, projects the image to 64 channels,
    runs four down-sampling ``Block``s (64 -> 128 -> 256 -> 512 -> 1024),
    then four up-sampling ``Block``s that each consume the matching
    down-sampling output as extra channels, and finally maps back to
    ``image_channels`` with a 1x1 convolution.

    Args:
        image_channels: Number of channels of the input image; the output
            has the same number of channels. Defaults to 3.
        time_emb_dim: Size of the timestep embedding. Defaults to 32.
    """
    def __init__(self, image_channels=3, time_emb_dim=32):
        super().__init__()
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        # Kernel size of the final projection (a 1x1 convolution).
        out_kernel_size = 1

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )
        
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([
            Block(down_channels[i], down_channels[i + 1], time_emb_dim)
            for i in range(len(down_channels) - 1)
        ])
        # Upsample
        self.ups = nn.ModuleList([
            Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True)
            for i in range(len(up_channels) - 1)
        ])

        # Output channels follow image_channels instead of a hard-coded 3,
        # so the prediction always matches the input's channel count.
        self.output = nn.Conv2d(up_channels[-1], image_channels, out_kernel_size)

    def forward(self, x, timestep):
        """Run the network on a batch of images at the given timesteps.

        Args:
            x: Image batch with ``image_channels`` channels.
                NOTE(review): each down block halves and each up block
                doubles the spatial size, so height/width presumably need to
                be divisible by 16 for the skip concatenations to line up.
            timestep: Tensor of timesteps passed to the time embedding.

        Returns:
            Tensor produced by the final 1x1 convolution, with
            ``image_channels`` channels.
        """
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            # Skip connections are consumed in reverse (deepest first).
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)

In [68]:
# Instantiate the U-Net and report its total number of parameters.
model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
# Bare last expression: shows the module tree as the cell output.
model

Num params:  62438883


SimpleUnet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=128, bias=True)
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=256, bias=True)
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transfor

# Forward Diffusion

In [69]:
import torch.nn.functional as F

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` evenly spaced beta values from ``start`` to ``end``.

    Both endpoints are included in the returned 1-D tensor.
    """
    betas = torch.linspace(start=start, end=end, steps=timesteps)
    return betas

def get_index_from_list(vals, t, x_shape):
    """Look up ``vals[t]`` per batch element and shape it for broadcasting.

    Args:
        vals: 1-D tensor of per-timestep values. It is gathered with a CPU
            copy of ``t``.
        t: Tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be multiplied with;
            only its length is used.

    Returns:
        Tensor of shape ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on ``t``'s device.
    """
    picked = vals.gather(-1, t.cpu())
    # One singleton axis for every non-batch dimension of x.
    trailing_ones = (1,) * (len(x_shape) - 1)
    shaped = picked.reshape(t.shape[0], *trailing_ones)
    return shaped.to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """Noise ``x_0`` to timestep ``t`` in a single step.

    Reads the module-level ``sqrt_alphas_cumprod`` and
    ``sqrt_one_minus_alphas_cumprod`` schedule tensors.

    Args:
        x_0: Clean input tensor.
        t: Tensor of timestep indices.
        device: Device on which the result is computed and returned.

    Returns:
        A ``(noisy_x, noise)`` tuple, both on ``device``.
    """
    noise = torch.randn_like(x_0)
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    noise_scale = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    noise = noise.to(device)
    # Scaled signal (mean) plus scaled Gaussian noise.
    noisy = signal_scale.to(device) * x_0.to(device) + noise_scale.to(device) * noise
    return noisy, noise


# Define beta schedule
T = 10  # number of diffusion timesteps
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1.0 - betas
# Running product of alphas up to and including each timestep.
alphas_cumprod = alphas.cumprod(dim=0)
# Same product shifted right by one step, with 1.0 as the value before step 0.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
posterior_variance = (
    betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
)


In [71]:
!python data_setup.py

<class 'torch.utils.data.dataloader.DataLoader'>


In [73]:
!python model_builder.py

Num params:  62438883
SimpleUnet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=128, bias=True)
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=256, bias=True)
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(

# Check loss

In [74]:
!python loss_function.py

MSELoss()


# Check Forward Diffusion

In [163]:
class Diffusion():
    """Forward-diffusion helper built on a linear beta schedule.

    The schedule tensors are created on the CPU at construction time and
    stored as attributes; ``forward_diffusion_sample`` uses them to noise an
    input to an arbitrary timestep in one step.

    Args:
        timesteps: Number of diffusion steps (length of every schedule tensor).
        device: Device name, stored as ``self.device``. The schedule tensors
            themselves are not moved to it.
    """

    def __init__(self, timesteps: int, device:str):
        self.T = timesteps
        # Stored so the argument is no longer silently discarded.
        self.device = device
        self.betas = torch.linspace(0.0001, 0.02, self.T)
        self.alphas = 1. - self.betas
        self.initialize()

    def initialize(self):
        """Pre-compute the schedule terms derived from ``self.alphas``."""
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)
        # Cumulative product shifted right by one step, padded with 1.0.
        self.alphas_cumprod_prev = F.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(
            1. - self.alphas_cumprod)
        self.posterior_variance = self.betas * \
            (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)

    def get_index_from_list(self, vals, t, x_shape):
        """ 
        Returns a specific index t of a passed list of values vals
        while considering the batch dimension.

        The result has shape ``(t.shape[0], 1, ..., 1)`` with ``len(x_shape)``
        dimensions and is placed on ``t``'s device. ``vals`` is gathered with
        a CPU copy of ``t``.
        """
        batch_size = t.shape[0]
        out = vals.gather(-1, t.cpu())
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

    def forward_diffusion_sample(self, x_0, t, device='cpu'):
        """ 
        Takes an image and a timestep as input and 
        returns the noisy version of it

        Args:
            x_0: Clean input tensor.
            t: Tensor of timestep indices into this instance's schedule.
            device: Device on which the result is computed and returned.

        Returns:
            A ``(noisy_x, noise)`` tuple, both on ``device``.
        """
        noise = torch.randn_like(x_0)
        # Use this instance's own schedule and helper rather than a
        # module-level function or a global instance.
        sqrt_alphas_cumprod_t = self.get_index_from_list(
            self.sqrt_alphas_cumprod, t, x_0.shape)
        sqrt_one_minus_alphas_cumprod_t = self.get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape)
        noise = noise.to(device)
        # mean + variance
        noisy = sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
            + sqrt_one_minus_alphas_cumprod_t.to(device) * noise
        return noisy, noise


In [170]:
# Build a 100-step schedule. Note: `df` is a Diffusion instance, not a DataFrame.
df = Diffusion(timesteps=100, device='cpu')

# Diffusion defines no __repr__, so this prints the default object repr.
print(df)

<__main__.Diffusion object at 0x7fd7df270460>


In [165]:
# Random stand-in for an image, without a batch dimension.
x_0 = torch.randn([3, 128, 128])
# One random timestep index in [0, T).
# NOTE(review): T is the module-level constant from the functional schedule
# cell (10), not the number of steps of the Diffusion instance — confirm
# whether df.T was intended.
t = torch.randint(0, T, (1,), device='cpu').long()

In [166]:
# Noise x_0 at timestep t; the returned tuple is displayed in full as the cell output.
df.forward_diffusion_sample(x_0, t, 'cpu')

cpu


(tensor([[[ 0.6509,  1.5814, -1.2953,  ...,  1.1174, -1.8700, -0.8699],
          [-0.1046,  0.8800, -2.1652,  ..., -0.6017,  1.4224, -1.5678],
          [-1.2751, -0.9727, -0.9866,  ..., -1.0991, -0.6591,  0.9050],
          ...,
          [-1.8942, -2.2021,  0.4821,  ...,  0.7815,  1.2127, -0.0481],
          [ 0.3326,  1.5892,  1.2392,  ..., -0.1221, -0.4928,  0.9036],
          [-0.0869,  0.6065, -0.8448,  ..., -0.6293, -0.7503,  0.7203]],
 
         [[-0.9900,  0.0377, -0.1942,  ..., -0.1260,  1.5745, -1.4457],
          [ 0.3763,  0.1199,  0.6197,  ..., -0.6600, -0.9085,  1.2543],
          [-0.5720, -0.6608,  0.6289,  ...,  1.0766, -1.1316, -0.3396],
          ...,
          [-1.2903,  0.3768,  0.5196,  ..., -0.5940, -2.2943, -0.8634],
          [-0.9143,  0.1486,  0.2058,  ..., -0.6821,  0.3741, -1.0009],
          [-1.0774,  0.8922,  0.6868,  ..., -0.0298, -0.2245,  0.2329]],
 
         [[ 1.7853, -0.3524, -1.5612,  ...,  0.0469, -1.2702,  1.3568],
          [ 1.5059, -2.3284,

In [168]:
!python forward.py

tensor([0.9999, 0.9997, 0.9995, 0.9993, 0.9991, 0.9989, 0.9987, 0.9985, 0.9983,
        0.9981, 0.9979, 0.9977, 0.9975, 0.9973, 0.9971, 0.9969, 0.9967, 0.9965,
        0.9963, 0.9961, 0.9959, 0.9957, 0.9955, 0.9953, 0.9951, 0.9949, 0.9947,
        0.9945, 0.9943, 0.9941, 0.9939, 0.9937, 0.9935, 0.9933, 0.9931, 0.9929,
        0.9927, 0.9925, 0.9923, 0.9921, 0.9919, 0.9917, 0.9915, 0.9913, 0.9911,
        0.9909, 0.9907, 0.9905, 0.9903, 0.9901, 0.9898, 0.9896, 0.9894, 0.9892,
        0.9890, 0.9888, 0.9886, 0.9884, 0.9882, 0.9880, 0.9878, 0.9876, 0.9874,
        0.9872, 0.9870, 0.9868, 0.9866, 0.9864, 0.9862, 0.9860, 0.9858, 0.9856,
        0.9854, 0.9852, 0.9850, 0.9848, 0.9846, 0.9844, 0.9842, 0.9840, 0.9838,
        0.9836, 0.9834, 0.9832, 0.9830, 0.9828, 0.9826, 0.9824, 0.9822, 0.9820,
        0.9818, 0.9816, 0.9814, 0.9812, 0.9810, 0.9808, 0.9806, 0.9804, 0.9802,
        0.9800])
(tensor([[[ 0.2565,  0.1360, -0.5729,  ...,  0.7214, -0.5790, -0.5875],
         [ 0.4977,  0.9400,  0.