In [1]:
pip install torchsummary


Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install torchvision 

Note: you may need to restart the kernel to use updated packages.


In [3]:
import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
#from einops import rearrange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
from torchvision import transforms


In [21]:
class Encoder(nn.Module):

	def __init__(self, encoded_space_dim):
		super().__init__()

		### Convolutional section
		self.encoder_cnn = nn.Sequential(
		nn.Conv2d(1, 8, 3, stride=2, padding=1),
		nn.ReLU(True),
		nn.Conv2d(8, 16, 3, stride=2, padding=1),
		nn.BatchNorm2d(16),
		nn.ReLU(True),
		nn.Conv2d(16, 32, 3, stride=2, padding=0),
		nn.ReLU(True)
		)

		### Flatten layer
		self.flatten = nn.Flatten(start_dim=1)
### Linear section
		self.encoder_lin = nn.Sequential(
			nn.Linear(3 * 3 * 32, 128),
			nn.ReLU(True),
			nn.Linear(128, encoded_space_dim)
		)

	def forward(self, x):
		x = self.encoder_cnn(x)
		x = self.flatten(x)
		x = self.encoder_lin(x)
		return x
	#output is of dimension (1, encoded_space_dim)
class Decoder(nn.Module):
	def __init__(self, encoded_space_dim):
		super().__init__()
		self.decoder_lin = nn.Sequential(
			nn.Linear(encoded_space_dim, 128),
			nn.ReLU(True),
			nn.Linear(128, 3 * 3 * 32),
			nn.ReLU(True)
		)

		self.unflatten = nn.Unflatten(dim=1,
		unflattened_size=(32, 3, 3))

		self.decoder_conv = nn.Sequential(
			nn.ConvTranspose2d(32, 16, 3,
			stride=2, output_padding=0),
			nn.BatchNorm2d(16),
			nn.ReLU(True),
			nn.ConvTranspose2d(16, 8, 3, stride=2,
			padding=1, output_padding=1),
			nn.BatchNorm2d(8),
			nn.ReLU(True),
			nn.ConvTranspose2d(8, 1, 3, stride=2,
			padding=1, output_padding=1)
		)

	def forward(self, x):
		x = self.decoder_lin(x)
		x = self.unflatten(x)
		x = self.decoder_conv(x)
		x = torch.sigmoid(x)
		return x

In [34]:
class Encoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):

        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), 
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), 
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2),
            act_fn(),
            nn.Flatten(), 
            nn.Linear(2*16*c_hid, latent_dim)
        )

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):

        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), 
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=0),
            act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2),
            nn.Tanh() 
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x

In [36]:
enc = Encoder(1, 256,128)
x        = torch.randn(1000, 1, 28, 28)
enc(x).shape


torch.Size([1000, 128])

In [38]:
class ResNet(nn.Module):
    def __init__(self, in_ch, out_ch, num_blocks=4, num_layers=4, num_filters=64, kernel_size=3, stride=1, padding=1,
                 dilation=1, groups=1, bias=True, padding_mode='zeros', activation=nn.ReLU, norm=nn.BatchNorm2d,
                 dropout=nn.Dropout2d, residual=True, **kwargs):
        super().__init__()
        self.residual = residual
        self.activation = activation
        self.norm = norm
        self.dropout = dropout
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.padding_mode = padding_mode
        self.kwargs = kwargs

        self.blocks = nn.ModuleList([self._make_block() for _ in range(self.num_blocks)])
        self.head = nn.Conv2d(self.num_filters, self.out_ch, 1)

    def _make_block(self):
        layers = []
        for _ in range(self.num_layers):
            layers.append(nn.Conv2d(self.in_ch, self.num_filters, self.kernel_size, self.stride, self.padding,
                                    self.dilation, self.groups, self.bias, self.padding_mode))
            if self.norm is not None:
                layers.append(self.norm(self.num_filters))
            if self.activation is not None:
                layers.append(self.activation())
            if self.dropout is not None:
                layers.append(self.dropout())
            self.in_ch = self.num_filters
        return nn.Sequential(*layers)

    def forward(self, x, cond=None):
        for block in self.blocks:
            res = x
            x = block(x)
            if cond is not None:
                x += nn.Linear(cond.shape[1], x.shape[1], bias=False)(cond)
            if self.residual:
                x = x + res
        return self.head(x)


# Score neural network for the diffusion process. Approximates what you should do at each timestep
class ScoreNet(nn.Module):

    def __init__(self, latent_dim, embedding_dim, n_blocks=32):
        super().__init__()
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.resnet = ResNet(self.latent_dim, self.latent_dim, num_blocks=n_blocks, num_layers=4,
                             num_filters=64, kernel_size=1, stride=1, padding=1, dilation=1, groups=1, bias=True,
                             padding_mode='zeros', activation=nn.ReLU, norm=nn.BatchNorm2d, dropout=nn.Dropout2d,
                             residual=True)

    def forward(self, x, t, conditioning):
        
        timestep = get_timestep_embedding(t, self.embedding_dim)
        print("Timstep dim {}, Cond dim {}",timestep.shape,  conditioning.shape)
        cond = torch.cat([timestep, conditioning], dim=1)
        #cond=nn.Flatten(0)(cond)
        print("Input to liearn",cond.shape)
        
        cond = nn.SiLU()(nn.Linear(288, self.embedding_dim * 4)(cond))
        cond = nn.SiLU()(nn.Linear(self.embedding_dim * 4, self.embedding_dim * 4)(cond))
        cond = nn.Linear(self.embedding_dim * 4, self.embedding_dim)(cond)
        print(x.shape)
        #x=nn.Flatten(0)(x)
        print(x.shape)
        h = nn.Linear(self.latent_dim, self.embedding_dim)(x)
        h = self.resnet(h, cond)
        return x + h

cond=nn.rand()
cond=nn.Flatten()(cond)

In [41]:
def get_timestep_embedding(timesteps, embedding_dim):
    assert len(timesteps.shape) == 1
    timesteps *= 1000
    half_dim = embedding_dim // 2
    emb = np.log(10000) / (half_dim - 1)
    emb = np.exp(np.arange(half_dim) * -emb)
    emb = np.outer(timesteps, emb)
    emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=1)
    print(timesteps.shape, embedding_dim)
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return torch.from_numpy(emb).float()

In [42]:
# Testing whole scorenet
Time=4
def gamma(ts, gamma_min=-6, gamma_max=6):
    return gamma_max + (gamma_min - gamma_max) * ts
g_t = gamma(Time)
embed=256
conditioning = torch.arange(128) % (10 + 26 + 26 + 1)
conditioning=torch.nn.Embedding( num_embeddings=128,embedding_dim=embed)(conditioning)
scorenet = ScoreNet(latent_dim=128,embedding_dim=32)
x        = torch.randn(1,128)
#setting the t to be a vector
t= g_t * np.ones(x.shape[1])

scorenet(x,t,conditioning)

(128,) 32
Timstep dim {}, Cond dim {} torch.Size([128, 32]) torch.Size([128, 256])
Input to liearn torch.Size([128, 288])
torch.Size([1, 128])
torch.Size([1, 128])


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 32]

# Testing dimension match

In [74]:
# Stesting get_timestep_embedding
Time=4
def gamma(ts, gamma_min=-6, gamma_max=6):
    return gamma_max + (gamma_min - gamma_max) * ts
g_t = gamma(Time)


#Dimneison of conditioning must be same as th second dimenison of the x in this case
conditioning = torch.arange(128) % (10 + 26 + 26 + 1)
conditioning=torch.nn.Embedding( num_embeddings=10 + 26 + 26 + 1,embedding_dim=32)(conditioning)
scorenet = ScoreNet(latent_dim=128,embedding_dim=128)
x        = torch.randn(1,128)#encoder output

#Setting the t to be a vector
t= g_t * np.ones(x.shape[1])
timestep=get_timestep_embedding(t, 32)
print(timestep.shape,  conditioning.shape)

(128,) 32
torch.Size([128, 32]) torch.Size([128, 32])


In [73]:
conditioning = torch.arange(256) % (10 + 26 + 26 + 1)
conditioning=torch.nn.Embedding( num_embeddings=10 + 26 + 26 + 1,embedding_dim=128)(conditioning)
conditioning.shape

torch.Size([256, 128])

In [52]:
cond       = torch.randn(10,128)
cond=nn.Flatten(0)(cond)

In [76]:
# Testing whole scorenet
Time=4
def gamma(ts, gamma_min=-6, gamma_max=6):
    return gamma_max + (gamma_min - gamma_max) * ts
g_t = gamma(Time)
embed=256
latent=49152
conditioning = torch.arange(128) % (10 + 26 + 26 + 1)
conditioning=torch.nn.Embedding( num_embeddings=128,embedding_dim=embed)(conditioning)
scorenet = ScoreNet(latent_dim=latent,embedding_dim=128)
x        = torch.randn(1,128)
#setting the t to be a vector
t= g_t * np.ones(x.shape[1])

scorenet(x,t,conditioning)


(128,) 128
Timstep dim {}, Cond dim {} torch.Size([128, 128]) torch.Size([128, 256])
Input to liearn torch.Size([49152])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x128 and 49152x128)

In [40]:
conditioning = torch.arange(128) % (10 + 26 + 26 + 1)
conditioning.shape

torch.Size([128])