In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from torch import nn
from torchvision import datasets, transforms
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from torch.optim import lr_scheduler
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the notebook draws from so Restart & Run All is reproducible:
# the corruption noise comes from torch.randn, `rng` is the numpy generator.
SEED = 0
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

In [None]:
# Model / data configuration.
num_decoders = 8        # corruption steps == number of U-Nets in the reverse model
decoder_depth = 4       # down/up-sampling levels inside each U-Net
decoder_channels = 32   # channels after a U-Net's first conv; doubles every level
img_dim = 32            # CIFAR-10 images are 32x32
img_channels = 3        # RGB
batch_size = 256
PATH = 'model_checkpoint.pth'  # checkpoint file used by the save / load cells

# Every U-Net level halves H and W with a 2x2 max-pool and later doubles them
# again before adding the skip connection, so img_dim must be divisible by
# 2**decoder_depth; fail here with a clear message instead of a shape error.
assert img_dim % (2 ** decoder_depth) == 0, (
    f'img_dim={img_dim} must be divisible by 2**decoder_depth={2 ** decoder_depth}'
)

In [None]:
class normal_distribute_block(nn.Module):
    """One forward-corruption step: multiplicative Gaussian noise.

    Every call draws a fresh ``(img_dim, img_dim)`` factor map
    ``1 + sd * N(0, 1)`` and multiplies it into ``x``. The map broadcasts over
    all leading dims of ``x``, so every image and channel in a batch gets the
    same per-pixel factor.

    Args:
        img_dim: side length of the square noise map; must match the last two
            dims of ``x``.
        sd: standard deviation of the factor around 1.
    """

    def __init__(self, img_dim, sd):
        super().__init__()
        self.img_dim = img_dim
        self.sd = sd

    def up_sd(self):
        """Raise the noise standard deviation by 0.05 (stronger corruption)."""
        self.sd = self.sd + 0.05

    def forward(self, x):
        # Draw the noise directly on x's device rather than on the CPU followed
        # by a copy to a global `device`: no host->device transfer, and the
        # block works whatever device the input lives on.
        noise = torch.randn((self.img_dim, self.img_dim), device=x.device)
        deviations = self.sd * noise + 1
        return deviations * x

In [None]:
class diffusion_forward(nn.Module):
    """Forward (corruption) process: a chain of ``num_layers`` noise blocks.

    ``forward`` returns all intermediate states stacked along a new leading
    dim: index 0 is the clean input and index ``k`` is the input after ``k``
    successive corruption blocks, i.e. shape ``(num_layers + 1, *x.shape)``.
    """

    def __init__(self, img_dim, num_layers):
        super().__init__()

        self.transformation = nn.ModuleList([
            normal_distribute_block(img_dim, 0.1) for _ in range(num_layers)
        ])

        self.num_layers = num_layers
        self.img_dim = img_dim
        # NOTE(review): rm_rec / rm_sq are only incremented by up_removed_box and
        # never read in this notebook; kept for interface compatibility.
        self.rm_rec = 0
        self.rm_sq = 0

    def inc_diff(self):
        """Ask every corruption block to raise its noise level via ``up_sd``."""
        for block in self.transformation:
            block.up_sd()

    def up_removed_box(self):
        """Increment the ``rm_sq`` / ``rm_rec`` counters by one."""
        self.rm_sq = self.rm_sq + 1
        self.rm_rec = self.rm_rec + 1

    def forward(self, x):
        states = [x]
        for blur in self.transformation:
            states.append(blur(states[-1]))
        # torch.stack handles any input rank (a repeat with a fixed number of
        # sizes only fits a 4-D batch) and builds the result without in-place
        # writes into a preallocated tensor.
        return torch.stack(states)

In [None]:
class Unet(nn.Module):
    """Small U-Net used as one restoration stage of the reverse process.

    Encoder: ``num_layers`` levels. Each level applies an entry conv, then one
    shared conv + ReLU twice, keeps that activation as a skip, and 2x2
    max-pools. Channels start at ``initial_channels`` and double every level.
    Decoder: for each level (deepest first) 2x nearest upsampling, a conv that
    halves the channels, an *additive* skip, then the level's shared conv twice.

    The output has as many channels as the input, so stages can be chained.
    H and W must be divisible by ``2 ** num_layers``.

    Args:
        img_dim: not used by this module; kept for interface compatibility.
        num_layers: number of down/up-sampling levels.
        in_channels: channels of the input (and of the output).
        initial_channels: channels produced by the first level's entry conv.
    """

    def __init__(self, img_dim, num_layers, in_channels, initial_channels):
        super().__init__()

        self.num_layers = num_layers
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(2, 2)
        # Project back to the *input* channel count (previously hard-coded to 3,
        # which broke non-RGB inputs and chaining stages with in_channels != 3).
        self.output_conv = nn.Conv2d(initial_channels, in_channels, 1)

        # Per level: [0] entry conv (in -> c), [1] c -> c conv applied twice in
        # both encoder and decoder, [2] decoder conv (2c -> c) after upsampling.
        self.convolution_list = nn.ModuleList([])
        for _ in range(num_layers):
            self.convolution_list.append(nn.ModuleList([
                nn.Conv2d(in_channels, initial_channels, 3, padding=1),
                nn.Conv2d(initial_channels, initial_channels, 3, padding=1),
                nn.Conv2d(initial_channels * 2, initial_channels, 3, padding=1),
            ]))
            in_channels = initial_channels
            initial_channels = initial_channels * 2

        self.intermediate_conv = nn.Conv2d(in_channels, initial_channels, 3, padding=1)
        self.middle_conv = nn.Conv2d(initial_channels, initial_channels, 3, padding=1)

    def upscale(self, x):
        """2x nearest-neighbour upsampling of dims 2 and 3 (H, W)."""
        x = x.repeat_interleave(2, dim=2)
        x = x.repeat_interleave(2, dim=3)
        return x

    def forward(self, x):
        skips = []
        for layer in range(self.num_layers):
            convs = self.convolution_list[layer]
            x = convs[0](x)
            for _ in range(2):
                x = self.relu(convs[1](x))
            skips.append(x)
            x = self.max_pool(x)

        x = self.intermediate_conv(x)
        x = self.middle_conv(x)

        for backward_layer in reversed(range(self.num_layers)):
            convs = self.convolution_list[backward_layer]
            x = self.upscale(x)
            x = convs[2](x)
            x = x + skips[backward_layer]
            # NOTE(review): unlike the encoder, no ReLU is applied here, so the
            # decoder path is purely linear — confirm this is intended.
            for _ in range(2):
                x = convs[1](x)

        x = self.output_conv(x)
        return x


In [None]:
class diffusion_backward(nn.Module):
    """Reverse (restoration) process: ``num_decoders`` U-Nets applied in sequence.

    ``forward`` takes a stack laid out like the output of ``diffusion_forward``
    (index ``k`` = state after ``k`` corruption steps). It starts from the most
    corrupted state ``x[num_decoders]`` and lets U-Net ``i`` predict state
    ``num_decoders - i - 1``. Predictions are returned stacked in the same
    layout, with the last slot holding a copy of the starting state, so the
    result can be compared element-wise with the input stack.
    """

    def __init__(self, img_dim, num_layers, in_channels, initial_channels, num_decoders):
        super().__init__()

        self.num_decoders = num_decoders
        self.unets = nn.ModuleList([
            Unet(img_dim, num_layers, in_channels, initial_channels) for _ in range(num_decoders)
        ])

    def disabe_unet(self, unet_no):
        """Freeze U-Net ``unet_no`` so its parameters stop receiving gradients.

        ``nn.Module`` has no ``detach()``; freezing is done by switching off
        ``requires_grad`` on the sub-module's parameters.
        """
        self.unets[unet_no].requires_grad_(False)

    # Correctly spelled alias; the misspelt name stays for existing callers.
    disable_unet = disabe_unet

    def forward(self, x):
        n = self.num_decoders  # use the instance's count, not a notebook global
        current = x[n]
        outputs = [None] * (n + 1)
        outputs[n] = current.clone()
        for i, unet in enumerate(self.unets):
            # Detach each stage's input so every U-Net is only trained on its
            # own output slot (each has its own optimizer in the training cell).
            current = unet(current.clone().detach())
            outputs[n - i - 1] = current
        return torch.stack(outputs)

In [None]:
data_transform = transforms.Compose([
    transforms.ToTensor(),
    # Map pixel values from [0, 1] to [-1, 1] (per channel: (v - 0.5) / 0.5).
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.CIFAR10(root='.', train=True, download=True, transform=data_transform)
testset = datasets.CIFAR10(root='.', train=False, download=True, transform=data_transform)

# One loader per split (the train loader used to be constructed twice).
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)

# Grab a single shuffled training batch for the sanity-check plots.
dataiter = iter(train_loader)
images, labels = next(dataiter)

In [None]:
def imshow(img):
    """Display one normalized image tensor laid out as (C, H, W).

    Undoes the ``Normalize(0.5, 0.5)`` transform ([-1, 1] -> [0, 1]) and clamps
    to [0, 1], the range matplotlib accepts for float RGB data, so heavily
    corrupted images are not clipped with a warning. The tensor is detached
    and moved to the CPU here, so GPU tensors and tensors that require grad
    can be passed directly.
    """
    img = img.detach().cpu() / 2 + 0.5     # unnormalize
    img = img.clamp(0, 1)
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.show()

transform = T.ToPILImage()

imshow(images.cpu().detach()[0])

model_forward = diffusion_forward(img_dim, num_decoders).to(device)
# Use the configured device so the notebook also runs on CPU-only machines.
model_backward = diffusion_backward(img_dim, decoder_depth, img_channels, decoder_channels, num_decoders).to(device)

loss = nn.MSELoss()
model_backward.train()

# Sanity check: show the first image at every corruption level. Run the forward
# process once so all plots come from the same noise draws.
images = images.to(device)
with torch.no_grad():
    corrupted_preview = model_forward(images)
for layer in range(num_decoders + 1):  # include the most corrupted level
    imshow(corrupted_preview[layer][0].cpu().detach())


In [None]:
load_model = False  # set True to load weights previously saved to PATH

if load_model:
    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only
    # machine. torch.load unpickles the file: only load trusted checkpoints.
    model_backward.load_state_dict(torch.load(PATH, map_location=device))

In [None]:
model_backward.train()

LR = 1            # NOTE(review): lr=1 with eps=1 is unusual for Adam — confirm.
DETECT_ANOMALY = False  # True only when debugging NaNs; it slows training a lot
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

# One optimizer per U-Net (listed in reverse U-Net order), each with its own
# StepLR schedule that halves the learning rate every 10 epochs.
optimizers = [
    torch.optim.Adam(model_backward.unets[num_decoders - i - 1].parameters(), lr=LR, eps=1)
    for i in range(len(model_backward.unets))
]
schedulers = [lr_scheduler.StepLR(opt, step_size=10, gamma=0.5) for opt in optimizers]

model_forward.inc_diff()

for epoch in range(1, 180):
    epoch_loss = 0.0
    for data, _ in train_loader:
        images = data.to(device)
        # The forward process has no trainable parameters; no graph is needed.
        with torch.no_grad():
            corrupted = model_forward(images)
        restored = model_backward(corrupted)
        total_loss = loss(restored, corrupted)
        epoch_loss += total_loss.item()

        for opt in optimizers:
            opt.zero_grad()
        total_loss.backward()
        for opt in optimizers:
            opt.step()

    for sched in schedulers:
        sched.step()

    # Every 30 epochs: harder corruption and restart the learning rate.
    if epoch % 30 == 0:
        model_forward.inc_diff()
        for opt in optimizers:
            opt.param_groups[0]['lr'] = LR

    # MSELoss defaults to a per-batch mean, so average over the batch count
    # (dividing by a fixed 60000 matched neither CIFAR-10's 50000 training
    # images nor the number of batches).
    print(f'Epoch no: {epoch} , Loss: {epoch_loss / len(train_loader)}')

In [None]:
# Persist the reverse model's weights (the forward process has no parameters).
torch.save(model_backward.state_dict(), PATH)

In [None]:
model_backward.eval()
image_num = 2  # which image of `images` (the last training batch) to visualise

imshow(images.cpu().detach()[image_num])
with torch.no_grad():
    max_blurred = model_forward(images)
    restored = model_backward(max_blurred)[:, image_num]
# Index num_decoders is the most corrupted state, where restoration starts.
imshow(max_blurred[num_decoders, image_num].cpu().detach())
for i in range(num_decoders):
    imshow(restored[i].cpu().detach())