In [None]:
import torch 
import torch.nn as nn
import copy
import math

from pathlib import Path

import sys
import import_ipynb
dir = Path('notebooks')
sys.path.insert(0, str(dir.resolve()))
import globals

In [None]:
# Decay factor used when blending the EMA network's state with the trained
# network's state (new = EMA * old + (1 - EMA) * current).
EMA = 0.999
# Number of channels produced by sinusoidal_embedding: half of them are cosine
# features and the other half sine features.
NOISE_EMBEDDING_SIZE = 32

In [None]:
def offset_cosine_diffusion_schedule(diffusion_time, min_signal_rate=0.02, max_signal_rate=0.95):
    """Map diffusion times to signal and noise rates on an offset cosine curve.

    The diffusion time is interpolated linearly between the angles
    ``acos(max_signal_rate)`` (time 0) and ``acos(min_signal_rate)`` (time 1);
    the cosine and sine of that angle are the signal and noise rates, so
    ``signal_rate**2 + noise_rate**2 == 1`` for every time.

    Args:
        diffusion_time: Tensor of diffusion times; 0 gives the highest signal
            rate and 1 the lowest.
        min_signal_rate: Signal rate reached at ``diffusion_time == 1``.
        max_signal_rate: Signal rate used at ``diffusion_time == 0``.

    Returns:
        Tuple ``(signal_rate, noise_rate)`` - in that order - of tensors
        broadcast to the shape of ``diffusion_time``.

    Raises:
        ValueError: If the rates do not satisfy
            ``0 <= min_signal_rate < max_signal_rate <= 1``.
    """
    if not 0.0 <= min_signal_rate < max_signal_rate <= 1.0:
        raise ValueError(
            "expected 0 <= min_signal_rate < max_signal_rate <= 1, got "
            "min_signal_rate={} and max_signal_rate={}".format(min_signal_rate, max_signal_rate)
        )

    # The largest signal rate corresponds to the smallest angle.
    start_angle = torch.acos(torch.tensor(max_signal_rate))
    end_angle = torch.acos(torch.tensor(min_signal_rate))

    diffusion_angle = start_angle + diffusion_time * (end_angle - start_angle)

    signal_rate = torch.cos(diffusion_angle)
    noise_rate = torch.sin(diffusion_angle)

    return signal_rate, noise_rate


In [None]:
def sinusoidal_embedding(x, embedding_size=None):
    """Embed values as cosine/sine features at log-spaced frequencies.

    Args:
        x: Tensor whose last dimension has size 1 and which has four
            dimensions, e.g. ``(batch, 1, 1, 1)``; the features are
            concatenated along dimension 3 and then moved to dimension 1.
        embedding_size: Total number of output channels (half cosine, half
            sine). Defaults to the module-level ``NOISE_EMBEDDING_SIZE``.

    Returns:
        Tensor of shape ``(batch, embedding_size, 1, 1)`` for an input of
        shape ``(batch, 1, 1, 1)``: cosine features first, sine features last.
    """
    if embedding_size is None:
        embedding_size = NOISE_EMBEDDING_SIZE

    # Build the frequencies on the input's device so the product below does
    # not mix devices when x does not live on the CPU.
    # NOTE(review): the lowest frequency is 10 (log(10)); some reference
    # implementations start at 1.0 - confirm this is intended.
    frequencies = torch.exp(
        torch.linspace(
            math.log(10), math.log(1000.0), embedding_size // 2, device=x.device
        )
    )
    angular_speed = 2.0 * math.pi * frequencies
    phase = angular_speed * x
    embeddings = torch.cat([torch.cos(phase), torch.sin(phase)], dim=3)
    # (batch, 1, 1, features) -> (batch, features, 1, 1)
    return embeddings.permute(0, 3, 1, 2)

def Swish(x):
    """Return ``x * sigmoid(x)`` element-wise."""
    return x * torch.sigmoid(x)


class ResidualBlock(nn.Module):
    """BatchNorm -> 3x3 conv -> Swish -> 3x3 conv, added to a shortcut.

    The constructor only receives the output width, so the layers are created
    on the first forward call, when the input channel count is known. They are
    created once, registered as sub-modules and reused afterwards. Creating
    fresh layers on every call (as a plain closure would) re-randomises the
    weights each time and hides them from ``parameters()``.

    Because the parameters only exist after the first forward call, run one
    forward pass before handing ``parameters()`` to an optimizer, copying the
    module, or loading a state dict.

    Args:
        width: Number of output channels.
    """

    def __init__(self, width):
        super().__init__()
        self.width = width
        self.in_channels = None  # set by the first forward call
        self.projection = None   # 1x1 conv, only when in_channels != width
        self.norm = None
        self.conv_1 = None
        self.conv_2 = None

    def _build(self, x):
        """Create the layers for the channel count, device and dtype of ``x``."""
        in_channels = x.shape[1]
        if in_channels != self.width:
            # The shortcut must be projected so it can be added to the output.
            self.projection = nn.Conv2d(in_channels=in_channels, out_channels=self.width, kernel_size=1)
        self.norm = nn.BatchNorm2d(num_features=in_channels)
        self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=self.width, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(in_channels=self.width, out_channels=self.width, kernel_size=3, padding=1)
        self.in_channels = in_channels

        if x.is_floating_point():
            self.to(device=x.device, dtype=x.dtype)
        else:
            self.to(device=x.device)
        # New sub-modules start in training mode; align them with this block.
        self.train(self.training)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, channels, height, width)``.

        Raises:
            ValueError: If ``x`` has a different channel count than the input
                the block was first called with.
        """
        if self.in_channels is None:
            self._build(x)
        elif x.shape[1] != self.in_channels:
            raise ValueError(
                "ResidualBlock was built for {} input channels but received {}".format(
                    self.in_channels, x.shape[1]
                )
            )

        residual = x if self.projection is None else self.projection(x)

        x = self.norm(x)
        x = self.conv_1(x)
        x = Swish(x)
        x = self.conv_2(x)
        return x + residual


class DownBlock(nn.Module):
    """``block_depth`` residual blocks followed by 2x average pooling.

    Args:
        width: Output channels of every residual block.
        block_depth: Number of residual blocks.
    """

    def __init__(self, width, block_depth):
        super().__init__()
        self.blocks = nn.ModuleList([ResidualBlock(width) for _ in range(block_depth)])
        self.pool = nn.AvgPool2d(kernel_size=2)

    def forward(self, x):
        """Process ``[tensor, skips]``.

        The output of every residual block is appended to the ``skips`` list
        (mutated in place); the pooled tensor is returned.
        """
        x, skips = x
        for block in self.blocks:
            x = block(x)
            skips.append(x)
        return self.pool(x)


class UpBlock(nn.Module):
    """2x bilinear upsampling followed by ``block_depth`` residual blocks.

    Args:
        width: Output channels of every residual block.
        block_depth: Number of residual blocks; one skip tensor is consumed
            per block.
    """

    def __init__(self, width, block_depth):
        super().__init__()
        self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
        self.blocks = nn.ModuleList([ResidualBlock(width) for _ in range(block_depth)])

    def forward(self, x):
        """Process ``[tensor, skips]``.

        Before each residual block the most recently stored skip tensor is
        popped from ``skips`` (mutated in place) and concatenated to the
        current tensor along the channel dimension.
        """
        x, skips = x
        x = self.upsample(x)
        for block in self.blocks:
            x = torch.cat([x, skips.pop()], dim=1)
            x = block(x)
        return x



In [None]:
# U-Net
class U_Net(nn.Module):
    """U-Net that predicts the noise contained in a noisy image.

    The noise variance is turned into a sinusoidal embedding, broadcast over
    the image's spatial size and concatenated with a 1x1-convolved image. The
    result passes through three down blocks, two residual blocks and three up
    blocks (consuming the skip tensors stored on the way down), and a final
    1x1 convolution maps back to the image's channel count.

    Args:
        channels: Number of image channels for the input and the predicted
            noise. Defaults to ``globals.CHANNELS``.
    """

    def __init__(self, channels=None):
        super(U_Net, self).__init__()
        if channels is None:
            channels = globals.CHANNELS

        self.conv2d_1 = nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=1)
        self.downblock_1 = DownBlock(width=32, block_depth=2)
        self.downblock_2 = DownBlock(width=64, block_depth=2)
        self.downblock_3 = DownBlock(width=96, block_depth=2)
        self.residualblock_1 = ResidualBlock(width=128)
        self.residualblock_2 = ResidualBlock(width=128)
        self.upblock_1 = UpBlock(width=96, block_depth=2)
        self.upblock_2 = UpBlock(width=64, block_depth=2)
        self.upblock_3 = UpBlock(width=32, block_depth=2)
        # The predicted noise is compared with noise of the image's shape, so
        # the output channel count must match the input channel count.
        self.conv2d_2 = nn.Conv2d(in_channels=32, out_channels=channels, kernel_size=1)

    def forward(self, noise_variance, noisy_image):
        """Predict the noise in ``noisy_image``.

        Args:
            noise_variance: Tensor passed to ``sinusoidal_embedding``; the
                callers in this file pass shape ``(batch, 1, 1, 1)``.
            noisy_image: Tensor of shape ``(batch, channels, height, width)``.
                Height and width are halved three times on the way down, so
                they presumably need to be divisible by 8 - TODO confirm.

        Returns:
            Tensor with the same shape as ``noisy_image``.
        """
        embedded_noise = sinusoidal_embedding(noise_variance)
        # Repeat the embedding over the image's own spatial size rather than a
        # fixed factor, so the concatenation works for any input resolution.
        upsampled_noise = nn.functional.interpolate(
            embedded_noise, size=noisy_image.shape[2:], mode="nearest"
        )

        conv_noisy_image = self.conv2d_1(noisy_image)

        x = torch.cat([upsampled_noise, conv_noisy_image], dim=1)

        skips = []
        x = self.downblock_1([x, skips])
        x = self.downblock_2([x, skips])
        x = self.downblock_3([x, skips])

        x = self.residualblock_1(x)
        x = self.residualblock_2(x)

        x = self.upblock_1([x, skips])
        x = self.upblock_2([x, skips])
        x = self.upblock_3([x, skips])

        pred_noise = self.conv2d_2(x)

        return pred_noise


In [None]:

class DiffusionModel(nn.Module):
    """Noise-prediction diffusion model with a trained network and an EMA copy.

    ``forward`` performs one complete optimisation step on a batch of images
    (normalise, add noise, predict the noise, back-propagate, step the
    optimizer, update the EMA network). ``evaluate`` computes the same loss
    with the EMA network without updating anything, and ``sample`` produces
    images by reverse diffusion.

    Attributes:
        network: U-Net trained by ``forward``.
        ema_network: Copy of ``network`` whose state follows it with decay ``EMA``.
        opt_network: Adam optimizer over ``network``'s parameters.
        data_mean, data_variance: Per-channel running statistics of the
            training images, used to normalise inputs and to undo the
            normalisation on sampled images.
    """

    def __init__(self):
        super(DiffusionModel, self).__init__()
        self.diffusion_schedule = offset_cosine_diffusion_schedule

        # Running per-channel statistics of the training data (updated by forward).
        self.register_buffer("data_mean", torch.zeros(globals.CHANNELS))
        self.register_buffer("data_variance", torch.ones(globals.CHANNELS))

        self.network = U_Net()
        # Run one forward pass first so that any layer which sizes itself from
        # its first input exists before the weights are initialised, the EMA
        # copy is taken and the optimizer collects parameters.
        self._warm_up()
        self._initialize_weights()
        # Copy after initialisation so both networks start from the same weights.
        self.ema_network = copy.deepcopy(self.network)
        self.opt_network = torch.optim.Adam(params=self.network.parameters(), lr=0.0001)

    def _warm_up(self):
        """Send one all-zero batch through ``network`` without tracking gradients."""
        was_training = self.network.training
        self.network.eval()
        with torch.no_grad():
            self.network(
                noise_variance=torch.zeros(1, 1, 1, 1),
                noisy_image=torch.zeros(
                    1, globals.CHANNELS, globals.IMAGE_SIZE, globals.IMAGE_SIZE
                ),
            )
        self.network.train(was_training)

    def _initialize_weights(self):
        """Draw every ``nn.Conv2d`` weight from a normal distribution (mean 0, std 0.1)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.1)

    def _normalize(self, images, training):
        """Standardise ``images`` per channel.

        With ``training=True`` the batch statistics are used and the running
        statistics are updated; otherwise the stored running statistics are
        used. The result is detached from the autograd graph.
        """
        normalized = nn.functional.batch_norm(
            images, self.data_mean, self.data_variance, training=training
        )
        return normalized.detach()

    @torch.no_grad()
    def _update_ema(self):
        """Blend ``network``'s state into ``ema_network`` in place.

        Floating-point entries become ``EMA * old + (1 - EMA) * current``;
        other entries (e.g. integer counters) are copied.
        """
        ema_state = self.ema_network.state_dict()
        for name, value in self.network.state_dict().items():
            # state_dict tensors share storage with the module, so in-place
            # updates change the EMA network itself.
            target = ema_state[name]
            if target.is_floating_point():
                target.mul_(EMA).add_(value, alpha=1.0 - EMA)
            else:
                target.copy_(value)

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        """Predict the noise component and the corresponding clean images.

        Args:
            noisy_images: Batch of noisy images.
            noise_rates: Noise rates used to create ``noisy_images``.
            signal_rates: Signal rates used to create ``noisy_images``.
            training: Use ``network`` when true, ``ema_network`` otherwise.

        Returns:
            Tuple ``(pred_noises, pred_images)``.
        """
        network = self.network if training else self.ema_network
        pred_noises = network(noise_variance=noise_rates ** 2, noisy_image=noisy_images)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        """Denoise ``initial_noise`` in ``diffusion_steps`` equal steps.

        Returns:
            The predicted images of the last step.

        Raises:
            ValueError: If ``diffusion_steps`` is smaller than 1.
        """
        if diffusion_steps < 1:
            raise ValueError("diffusion_steps must be at least 1, got {}".format(diffusion_steps))

        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps
        current_images = initial_noise
        for step in range(diffusion_steps):
            diffusion_times = (
                torch.ones(
                    (num_images, 1, 1, 1),
                    device=initial_noise.device,
                    dtype=initial_noise.dtype,
                )
                - step * step_size
            )
            # The schedule returns (signal_rate, noise_rate), in that order.
            signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                current_images, noise_rates, signal_rates, training=False
            )
            next_diffusion_times = diffusion_times - step_size
            next_signal_rates, next_noise_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            # Re-mix the predictions at the next (lower) noise level.
            current_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )
        return pred_images

    def sample(self, num_images, diffusion_steps, initial_noise=None):
        """Generate images by reverse diffusion with the EMA network.

        Args:
            num_images: Number of images to draw when ``initial_noise`` is None.
            diffusion_steps: Number of reverse diffusion steps.
            initial_noise: Optional starting noise of shape
                ``(num_images, CHANNELS, IMAGE_SIZE, IMAGE_SIZE)``.

        Returns:
            Generated images mapped back to the data scale with the stored
            per-channel mean and variance.
        """
        if initial_noise is None:
            initial_noise = torch.randn(
                size=(num_images, globals.CHANNELS, globals.IMAGE_SIZE, globals.IMAGE_SIZE),
                device=self.data_mean.device,
            )
        with torch.no_grad():
            generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
            mean = self.data_mean.view(1, -1, 1, 1)
            std = self.data_variance.view(1, -1, 1, 1) ** 0.5
            return mean + generated_images * std

    def forward(self, images):
        """Run one training step on ``images`` and return the noise loss.

        Side effects: updates the running data statistics, steps
        ``opt_network``, updates ``ema_network`` and prints the loss.
        """
        self.opt_network.zero_grad()
        images = self._normalize(images, training=True)

        # Size everything from the actual batch so a smaller final batch works.
        noises = torch.randn_like(images)
        diffusion_times = torch.rand(
            size=(images.shape[0], 1, 1, 1), device=images.device, dtype=images.dtype
        )
        signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)

        noisy_images = signal_rates * images + noise_rates * noises

        pred_noises, _ = self.denoise(
            noisy_images, noise_rates, signal_rates, training=True
        )
        noise_loss = nn.functional.mse_loss(pred_noises, noises)

        noise_loss.backward()
        self.opt_network.step()
        self._update_ema()

        print('Loss is  {}'.format(noise_loss))

        return noise_loss

    def evaluate(self, images):
        """Compute the noise loss of the EMA network on ``images``.

        Nothing is updated: stored statistics are used for normalisation and
        no gradients are tracked. The loss is printed and returned.
        """
        with torch.no_grad():
            images = self._normalize(images, training=False)
            noises = torch.randn_like(images)
            diffusion_times = torch.rand(
                size=(images.shape[0], 1, 1, 1), device=images.device, dtype=images.dtype
            )
            signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)

            noisy_images = signal_rates * images + noise_rates * noises

            pred_noises, _ = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )
            noise_loss = nn.functional.mse_loss(pred_noises, noises)

        print('Loss is  {}'.format(noise_loss))

        return noise_loss

    def generate(self, images):
        """Return the EMA network's noise loss on ``images``.

        Kept under this name for existing callers; it delegates to
        ``evaluate``. Use ``sample`` to produce images.
        """
        return self.evaluate(images)

        
