In [11]:
import torch
import torchvision
from torch.utils import data
from torch import nn
import torch.nn.functional as F

In [12]:
image_to_tensor = torchvision.transforms.ToTensor()
tensor_to_image = torchvision.transforms.ToPILImage()

mnist_train = torchvision.datasets.MNIST('../data/', train= True, transform=image_to_tensor, download=True)
mnist_test = torchvision.datasets.MNIST('../data/', train= False, transform=image_to_tensor, download=True)

In [13]:
assert len(mnist_train) == 60000
assert len(mnist_test) == 10000
figure, label = mnist_train[0]
D = figure.numel()
assert D == 784

In [14]:
trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(mnist_test, batch_size=32, shuffle=True)

In [15]:
class FeedForward(nn.Module):
    def __init__(self, D, nlayers =8):
        super(FeedForward, self).__init__()
        assert type(D) == int
        assert type(nlayers) == int
        self.D = D
        self.nlayers = nlayers
        self.first = nn.Linear(D+1, D)
        self.linears = nn.ModuleList([nn.Linear(D, D) for i in range(nlayers-1)])
        

    def forward(self, x, t):
        x = torch.cat((x, t.reshape(-1, 1)), axis=1)
        x = self.first(x)
        for lin in self.linears:
            x = torch.clamp(lin(F.relu(x)), min=0., max=1.)
        return x
    
    

class Diffusion():
    def __init__(self, T, D, betas = 0.01):
        assert type(T) == int
        self.T = T
        self.betas = betas = torch.tensor(0.5).repeat(T) if type(betas) == float else betas
        self.alphas = 1-self.betas
        self.alphas_hat = torch.cumprod(self.alphas, axis=0)
        self.mvn = torch.distributions.MultivariateNormal(torch.zeros(D), torch.eye(D))
        

    def sample_t(self, B, all_equal = True):
        if all_equal:
            t = torch.randint(0, self.T, torch.Size([1])).repeat(B)
        else:
            t = torch.randint(0, self.T, torch.Size([B]))
        return t



In [16]:
epochs = 500
loss_hist = []

diff = Diffusion(T = 200, D = D, betas = torch.linspace(1e-04, 0.02, 200))
eps_theta = FeedForward(D)

optimizer = torch.optim.Adam(eps_theta.parameters(), lr = 1e-03)
L2_loss = torch.nn.MSELoss(reduction='sum')

In [17]:
for epoch in range(epochs):
    epoch_loss = 0

    for images, _ in trainloader:

        optimizer.zero_grad()

        B = images.shape[0]
        eps = diff.mvn.sample(torch.Size([B]))

        x0 = torch.flatten(images, start_dim = 1)
        t = diff.sample_t(B, all_equal=False)

        alphas_hat_array = torch.broadcast_to(diff.alphas_hat[t].reshape(-1, 1), (B, D))
        x = x0*torch.sqrt(alphas_hat_array) + eps*torch.sqrt(1-alphas_hat_array)

        eps_pred  = eps_theta(x, t/diff.T)
        loss = L2_loss(eps, eps_pred)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f'epoch: {epoch} \t loss: {epoch_loss}')
    loss_hist.append(epoch_loss)


epoch: 0 	 loss: 47010550.69921875
epoch: 1 	 loss: 46955217.61328125
epoch: 2 	 loss: 46929584.890625
epoch: 3 	 loss: 46903828.900390625
epoch: 4 	 loss: 46885236.080078125
epoch: 5 	 loss: 46887525.328125
epoch: 6 	 loss: 46873618.0859375
epoch: 7 	 loss: 46850020.08203125
epoch: 8 	 loss: 46838223.017578125
epoch: 9 	 loss: 46840237.521484375
epoch: 10 	 loss: 46832952.642578125
epoch: 11 	 loss: 46831529.24609375
epoch: 12 	 loss: 46831905.294921875
epoch: 13 	 loss: 46813189.1875
epoch: 14 	 loss: 46800065.53125
epoch: 15 	 loss: 46806977.55078125
epoch: 16 	 loss: 46804328.1640625
epoch: 17 	 loss: 46797081.08984375
epoch: 18 	 loss: 46801228.59375
epoch: 19 	 loss: 46775368.7421875
epoch: 20 	 loss: 46756717.56640625
epoch: 21 	 loss: 46749003.36328125
epoch: 22 	 loss: 46756823.021484375
epoch: 23 	 loss: 46753701.275390625
epoch: 24 	 loss: 46762081.12109375
epoch: 25 	 loss: 46745911.037109375
epoch: 26 	 loss: 46750144.5546875
epoch: 27 	 loss: 46749096.53125
epoch: 28 	 lo

KeyboardInterrupt: 

In [84]:
img = images[0]
img.shape

torch.Size([1, 28, 28])

In [86]:
conv_11 = nn.Conv2d(in_channels=1, out_channels= 16, kernel_size=3, padding=1)
conv_12 = nn.Conv2d(in_channels=16, out_channels= 16, kernel_size=3, padding=1)

mp = nn.MaxPool2d(kernel_size = (2,2), stride= 2)

conv_21 = nn.Conv2d(in_channels=16, out_channels= 32, kernel_size=3, padding=1)
conv_22 = nn.Conv2d(in_channels=32, out_channels= 32, kernel_size=3, padding=1)

conv_31 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
conv_32 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

upconv_1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2) 

conv_41 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
conv_42 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

upconv_2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2) 

conv_51 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
conv_52 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)

conv_out = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)

mp = nn.MaxPool2d(kernel_size=2, stride= 2)

x_11 = img
x_12 = conv_11(x_11)
x_13 = conv_12(x_12)

print(x_11.shape, x_12.shape, x_13.shape)

x_21 = mp(x_13)
x_22 = conv_21(x_21)
x_23 = conv_22(x_22)

print(x_21.shape, x_22.shape, x_23.shape)

x_31 = mp(x_23)
x_32 = conv_31(x_31)
x_33 = conv_32(x_32)

print(x_31.shape, x_32.shape, x_33.shape)

x_41 = upconv_1(x_33)
x_42 = conv_41(x_41)
x_43 = conv_42(x_42)

print(x_41.shape, x_42.shape, x_43.shape)

x_51 = upconv_2(x_43)
x_52 = conv_51(x_51)
x_53 = conv_52(x_52)

print(x_51.shape, x_52.shape, x_53.shape)

x_out = conv_out(x_53)

print(x_out.shape)


torch.Size([1, 28, 28]) torch.Size([16, 28, 28]) torch.Size([16, 28, 28])
torch.Size([16, 14, 14]) torch.Size([32, 14, 14]) torch.Size([32, 14, 14])
torch.Size([32, 7, 7]) torch.Size([64, 7, 7]) torch.Size([64, 7, 7])
torch.Size([64, 14, 14]) torch.Size([32, 14, 14]) torch.Size([32, 14, 14])
torch.Size([32, 28, 28]) torch.Size([16, 28, 28]) torch.Size([16, 28, 28])
torch.Size([1, 28, 28])


In [88]:
class TimeEmbedding(nn.Module):
    def __init__(self, out_channels, out_fig_side):
        super(TimeEmbedding, self).__init__()
        assert type(out_fig_side) == int and type(out_channels) == int
        self.out_square_side = out_fig_side
        self.out_channels = out_channels
        self.linear = nn.Linear(in_features=1, out_features=self.out_channels*(out_fig_side**2))
        self.dropout == nn.Dropout()

    def forward(self, x):
        x = self.dropout(self.linear(x)).view(-1, self.out_channels, self.out_square_side, self.out_square_side)

        


class Conv2Block(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(Conv2Block, self).__init__()
        assert type(in_channels) == int and type(out_channels) == int
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv2d_1 = nn.Conv2d(in_channels=self.in_channels, out_channels= self.out_channels, kernel_size=3, padding=1)
        self.conv2d_2 = nn.Conv2d(in_channels=self.out_channels, out_channels= self.out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv2d_1(x))
        x = F.relu(self.conv2d_2(x))
        return x

        





class UNet(nn.Module):


    def __init__(self):
        super(Unet, self).__init__()
        self.conv_11 = nn.Conv2d(in_channels=1, out_channels= 16, kernel_size=3, padding=1)
        self.conv_12 = nn.Conv2d(in_channels=16, out_channels= 16, kernel_size=3, padding=1)

        self.mp = nn.MaxPool2d(kernel_size=2, stride= 2)

        self.conv_21 = nn.Conv2d(in_channels=16, out_channels= 32, kernel_size=3, padding=1)
        self.conv_22 = nn.Conv2d(in_channels=32, out_channels= 32, kernel_size=3, padding=1)

        self.conv_31 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv_32 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.upconv_1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2) 

        self.conv_41 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        self.conv_42 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

        self.upconv_2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2) 

        self.conv_51 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
        self.conv_52 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1)

        self.conv_out = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)


        


In [None]:
timemb = TimeEmbedding(16, 28)

In [None]:
conv_11 = nn.Conv2d(in_channels=1, out_channels= 16, kernel_size=3)
conv_12 = nn.Conv2d(in_channels=16, out_channels= 16, kernel_size=3)

conv_21 = nn.Conv2d(in_channels=16, out_channels= 32, kernel_size=3)
conv_22 = nn.Conv2d(in_channels=32, out_channels= 32, kernel_size=3)

upconv_1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2) 

conv_31 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3)
conv_32 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3)

mp = nn.MaxPool2d(kernel_size = (2,2), stride= 2)

x_11 = img
x_12 = conv_11(x_11)
x_13 = conv_12(x_12)

print(x_11.shape, x_12.shape, x_13.shape)

x_21 = mp(x_13)
x_22 = conv_21(x_21)
x_23 = conv_22(x_22)

print(x_21.shape, x_22.shape, x_23.shape)

x_31 = upconv_1(x_23)
x_32 = conv_31(x_31)
x_33 = conv_32(x_32)

print(x_31.shape, x_32.shape, x_33.shape)