In [1]:
import torch 
from torch import nn 
from torchvision import transforms 
from torch.utils.data import DataLoader 
import torchvision
from tqdm.autonotebook import tqdm

device = ("mps"
         if torch.backends.mps.is_available()
         else "cuda"
         if torch.cuda.is_available()
         else "cpu")


In [2]:
class LinearScheduler:
    def __init__(self, beta_start, beta_end, num_train_timesteps):
        self.beta_start = beta_start 
        self.beta_end = beta_end 
        self.num_train_timesteps = num_train_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim = 0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        
    def add_noise(self, original, noise, t):
        original_shape = original.shape 
        batch_size = original_shape[0]
        sqrt_alpha_cumprod = self.sqrt_alphas_cumprod[t].reshape(batch_size)
        sqrt_one_minus_alpha_cumprod = self.sqrt_one_minus_alphas_cumprod[t].reshape(batch_size)
        
        for _ in range(len(original_shape) -1):
            sqrt_alpha_cumprod = sqrt_alpha_cumprod.unsqueeze(-1)
            sqrt_one_minus_alpha_cumprod = sqrt_one_minus_alpha_cumprod.unsqueeze(-1)
            
        return sqrt_alpha_cumprod * original + sqrt_one_minus_alpha_cumprod * noise
    
    
    def sample_prev_timestep(self, xt, noise_pred, t):
        x0 = (xt - self.sqrt_one_minus_alphas_cumprod[t] * noise_pred)/self.sqrt_alphas_cumprod[t]
        x0 = torch.clamp(x0, -1, 1)
        
        mean = xt - (self.betas[t] * noise_pred)/self.sqrt_one_minus_alphas_cumprod[t]
        mean = xt/ self.sqrt_alphas_cumprod[t]
        
        if t==0:
            return mean, x0
        else:
            variance = self.betas[t] * (1. - self.alphas_cumprod[t-1])
            variance = variance / (1. - self.alphas_cumprod[t])
            sigma = variance ** 0.5 
            z = torch.randn_like(xt).to(xt.device)
            return mean + sigma * z, x0
        
    
            

In [3]:
def get_time_embedding(timesteps, t_emb_dim = 512):
    factor = 10000 ** (torch.arange(start = 0, 
                                    end = t_emb_dim // 2, 
                                    device = timesteps.device))
    t_emb = timesteps[:, None].repeat(1, t_emb_dim//2)/factor
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim = -1)
    return t_emb



In [4]:
class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, t_emb_dim, num_heads, downsample = True):
        super().__init__()
        self.downsample = downsample 
        self.conv_first = nn.Sequential(
                            nn.GroupNorm(8, in_channels), 
                            nn.SiLU(), 
                            nn.Conv2d(in_channels = in_channels, 
                                     out_channels = out_channels, 
                                     kernel_size =3, 
                                     stride = 1, 
                                     padding = 1))
        self.t_emb_proj = nn.Sequential(
                            nn.SiLU(), 
                            nn.Linear(t_emb_dim, out_channels))
        self.conv_second = nn.Sequential(
                            nn.GroupNorm(8, out_channels),
                            nn.SiLU(), 
                            nn.Conv2d(in_channels = out_channels, 
                                     out_channels = out_channels, 
                                     kernel_size = 3, 
                                     stride = 1, 
                                     padding = 1))
        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(out_channels, num_heads)
        self.downsample = nn.Conv2d(in_channels = out_channels, 
                                   out_channels = out_channels, 
                                   kernel_size = 3, 
                                   padding = 1, 
                                   stride = 2) if downsample else nn.Identity()
        
        self.residual_input_conv = nn.Conv2d(in_channels = in_channels, 
                                            out_channels = out_channels, 
                                            kernel_size = 1)
        
    def forward(self, x, t_emb):
        out = x 
        resnet_input = out 
        
        out = self.conv_first(out)
        t_emb = self.t_emb_proj(t_emb)
        out = out + t_emb[:, :, None, None]
        out = self.conv_second(out)
        out= out + self.residual_input_conv(resnet_input)
        
        batch_size, channels, h, w = out.shape
        in_attn = out.reshape(batch_size, channels, h * w)
        in_attn = self.attn_norm(in_attn)
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attn(in_attn, in_attn, in_attn)
        out_attn = out_attn.reshape(batch_size, channels, h, w)
        out = out + out_attn
        
        
        out = self.downsample(out)
        return out 
    
        

In [5]:
 class MidBlock(nn.Module):
        def __init__(self, in_channels, out_channels, num_heads, t_emb_dim):
            super().__init__()
            self.resnet_conv_first = nn.ModuleList([
                nn.Sequential(nn.GroupNorm(8, in_channels), 
                             nn.SiLU(), 
                             nn.Conv2d(in_channels, out_channels, 
                                      kernel_size = 3, stride = 1, 
                                      padding = 1)), 
                nn.Sequential(nn.GroupNorm(8, out_channels), 
                             nn.SiLU(), 
                             nn.Conv2d(out_channels, out_channels, 
                                      kernel_size = 3, 
                                      stride = 1, padding = 1))
            ])
            
            self.time_projection = nn.ModuleList([
                nn.Sequential(nn.SiLU(), 
                             nn.Linear(t_emb_dim, out_channels)), 
                nn.Sequential(nn.SiLU(), 
                             nn.Linear(t_emb_dim, out_channels))
            ])
            
            self.attn_norm = nn.GroupNorm(8, out_channels)
            self.attn = nn.MultiheadAttention(out_channels, num_heads)
            self.residual_input_conv = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, kernel_size = 1), 
                nn.Conv2d(out_channels, out_channels, kernel_size = 1)
            ])
            
            self.resnet_conv_second = nn.ModuleList([
                nn.Sequential(nn.GroupNorm(8, out_channels), 
                             nn.SiLU(), 
                             nn.Conv2d(out_channels, out_channels, kernel_size = 3, 
                                      stride =1, padding = 1)), 
                nn.Sequential(nn.GroupNorm(8, out_channels), 
                             nn.SiLU(), 
                             nn.Conv2d(out_channels, out_channels, kernel_size = 3, 
                                      stride = 1, padding = 1))
            ])
            
        def forward(self, x, t_emb):
            out = x 
            residual_input = x
            
            out = self.resnet_conv_first[0](x)
            out = out + self.time_projection[0](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[0](out)
            out = out + self.residual_input_conv[0](residual_input)
            
            batch_size, channels, h, w = out.shape
            attn_in = out.reshape(batch_size, channels, h * w)
            attn_in = self.attn_norm(attn_in)
            attn_in  = attn_in.transpose(1, 2)
            out_attn, _ = self.attn(attn_in, attn_in, attn_in)
            out_attn = out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)
            out = out + out_attn
            
            residual_input = out 
            out = self.resnet_conv_first[1](out)
            out = out + self.time_projection[1](t_emb)[:, :, None, None]
            out = self.resnet_conv_second[1](out)
            out = out + self.residual_input_conv[1](residual_input)
            
            return out 
            

In [6]:
class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_heads, t_emb_dim, upsample = True):
        super().__init__()
        
        self.resnet_conv_first = nn.Sequential(nn.GroupNorm(8, in_channels), 
                                              nn.SiLU(), 
                                              nn.Conv2d(in_channels, out_channels,
                                                       kernel_size =3, stride = 1, 
                                                       padding = 1))
        self.time_projection = nn.Sequential(nn.SiLU(), 
                                            nn.Linear(t_emb_dim, out_channels))
        self.resnet_conv_second = nn.Sequential(nn.GroupNorm(8, out_channels), 
                                               nn.SiLU(), 
                                               nn.Conv2d(out_channels, out_channels, 
                                                        kernel_size =3, stride = 1, 
                                                        padding = 1))
        self.attn_norm = nn.GroupNorm(8, out_channels)
        self.attn = nn.MultiheadAttention(out_channels, num_heads)
        self.residual_input_conv = nn.Conv2d(in_channels, out_channels, 
                                            kernel_size = 1)
        self.upsample = nn.ConvTranspose2d(in_channels, 
                                          in_channels, 
                                          kernel_size = 4, 
                                          stride = 2, 
                                          padding = 1) if upsample else nn.Identity()
        
    def forward(self, x, out_down, t_emb):
        
        out = torch.cat([x, out_down], dim = 1)
        out = self.upsample(out)
        
        residual_input = out
        out = self.resnet_conv_first(out)
        out = out + self.time_projection(t_emb)[:, :, None, None]
        out = self.resnet_conv_second(out)
        out = out + self.residual_input_conv(residual_input)
        
        batch_size, channels, h, w = out.shape
        attn_in = self.attn_norm(out).reshape(batch_size, channels, h * w)
        attn_in = attn_in.transpose(1, 2)
        attn_out, _ = self.attn(attn_in, attn_in, attn_in)
        attn_out = attn_out.transpose(1, 2).reshape(batch_size, channels, h, w)
        out =attn_out + out 

        
        return out
    

In [7]:
class UNet(nn.Module):
    def __init__(self, im_channels = 3):
        super().__init__()
        self.down_channels = [32, 64, 128, 256]
        self.mid_channels = [256, 256, 256]
        self.conv_in = nn.Conv2d(in_channels = im_channels, 
                                out_channels = self.down_channels[0], kernel_size = 3, 
                                stride = 1, padding = 1)
        self.t_emb_dim = 128
        self.t_proj = nn.Sequential(
                    nn.Linear(self.t_emb_dim, self.t_emb_dim), 
                    nn.SiLU(), 
                    nn.Linear(self.t_emb_dim, self.t_emb_dim))
        self.downblocks = nn.ModuleList()
        for i in range(len(self.down_channels) - 1):
            self.downblocks.append(DownBlock(self.down_channels[i], self.down_channels[i+1], 
                                            t_emb_dim = self.t_emb_dim, num_heads = 4, 
                                            downsample = True))
        self.midblocks = nn.ModuleList()
        for i in range(len(self.mid_channels) - 1):
            self.midblocks.append(MidBlock(in_channels = self.mid_channels[i], 
                                          out_channels = self.mid_channels[i+1], 
                                          num_heads = 4, 
                                          t_emb_dim = self.t_emb_dim))
            
        self.upblocks = nn.ModuleList()
        for i in reversed(range(len(self.down_channels))):
            if i==0:
                break
            self.upblocks.append(UpBlock(in_channels = self.down_channels[i] * 2, 
                                        out_channels = self.down_channels[i - 1], 
                                        num_heads = 4, 
                                        t_emb_dim = self.t_emb_dim, 
                                        upsample = True))
        self.norm_out = nn.GroupNorm(8, self.down_channels[0])
        self.conv_out = nn.Conv2d(in_channels = self.down_channels[0], 
                                 out_channels = im_channels, 
                                 kernel_size = 1)
        
    def forward(self, x, t):
        out = self.conv_in(x)
        t_emb = get_time_embedding(t, t_emb_dim = self.t_emb_dim)
        t_emb = self.t_proj(t_emb)
        
        down_outs = []
        for downblock in self.downblocks:
            out = downblock(out, t_emb = t_emb)
#             print(out.shape)
            down_outs.append(out)
        
        for midblock in self.midblocks:
            out = midblock(out, t_emb = t_emb)
#             print(out.shape)
            
        for upblock in self.upblocks:
            d_out = down_outs.pop()
            out = upblock(out, out_down = d_out, t_emb = t_emb)
#             print(out.shape)
            
        out = self.norm_out(out)
        out = nn.SiLU()(out)
        out = self.conv_out(out)
        return out
        

In [8]:


BETA_START = 0.0001
BETA_END = 0.02
NUM_TRAIN_TIMESTEPS = 1000
TIME_EMB_DIM= 128
NUM_EPOCHS = 5

image_size = 32
batch_size = 32


scheduler = LinearScheduler(beta_start = BETA_START, 
                           beta_end = BETA_END, 
                           num_train_timesteps = NUM_TRAIN_TIMESTEPS)




In [9]:
import os
from skimage.io import imread
from tqdm.autonotebook import tqdm
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from PIL import Image

low_res_4x_train_base_path = "/kaggle/input/div2k-dataset-for-super-resolution/Dataset/DIV2K_train_LR_bicubic_X4/X4/"

low_res_4x_train_image_files = sorted(os.listdir(low_res_4x_train_base_path))


low_res_size = 112



lowres_transform = transforms.Compose(
    [
        transforms.Resize((low_res_size, low_res_size), interpolation = Image.NEAREST),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        
    ]
)


def preprocess(image_filename, base_path):
    image_file_path = os.path.join(base_path, image_filename)
    im = Image.open(image_file_path)
    
    # Resize the image to a common shape, for example, (256, 256)
    im = lowres_transform(im)
    return im

# Create separate lists for high-res and low-res images

low_res_4x_images = [preprocess(image_file, low_res_4x_train_base_path) 
                     for image_file in tqdm(low_res_4x_train_image_files, total=len(low_res_4x_train_image_files))]

# Using TensorDataset with separate tensors
train_dataset = TensorDataset(torch.stack(low_res_4x_images))

batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

  0%|          | 0/800 [00:00<?, ?it/s]

In [19]:
NUM_EPOCHS = 10
criterion = nn.MSELoss()


unet = UNet(im_channels = 3).to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr = 1e-3)

losses = []

for epoch in tqdm(range(NUM_EPOCHS)):
    for idx, (batch) in tqdm(enumerate(train_dataloader), total = len(train_dataloader)):
        optimizer.zero_grad()
        batch = batch[0].to(device)
        t = torch.randint(0, 1000, (batch_size, ))
        noise = torch.randn_like(batch).to(device)
        noised_images = scheduler.add_noise(batch, noise, t).to(device)
        noise_pred = unet(noised_images, t)
        loss = criterion(noise_pred, noise)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        

[nan, nan, nan]