In [2]:
import os
import argparse

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Jupyter starts the kernel with its own command-line flags (e.g. ``-f <connection file>``).
# Parsing sys.argv would make argparse exit on those unknown flags, so parse an
# explicit empty list and fill the namespace by hand below.
parser = argparse.ArgumentParser()
args = parser.parse_args([])

# Hyper-parameters and paths read by train() and get_data().
args.run_name = "DDPM_Unconditional"  # train() saves samples under results/<run_name>/
args.epochs = 500
args.batch_size = 12
args.image_size = 64
args.dataset_path = r"C:\Users\dome\datasets\landscape_img_folder"
# Fall back to the CPU when no CUDA device is visible instead of failing on .to("cuda").
args.device = "cuda" if torch.cuda.is_available() else "cpu"
args.lr = 3e-4
# NOTE: train() and the classes it uses are defined in later cells; run those cells first.
train(args)

In [4]:
def train(args):
  """Train the UNet noise predictor and save a grid of samples after every epoch.

  Args:
    args: namespace providing ``device``, ``lr``, ``image_size``, ``epochs`` and
      ``run_name``, plus whatever ``get_data`` reads from it.

  Side effects:
    Creates ``results/<run_name>/`` and writes ``<epoch>.jpg`` into it each epoch.
  """
  device = args.device
  dataloader = get_data(args)
  # Build the model for the same device it is moved to, not UNet's default.
  model = UNet(device=device).to(device)
  opt = torch.optim.AdamW(model.parameters(), lr=args.lr)
  loss_fn = nn.MSELoss()
  diffusion = Diffusion(img_size=args.image_size, device=device)

  # The image writer does not create missing directories, so make sure the target exists.
  out_dir = os.path.join("results", args.run_name)
  os.makedirs(out_dir, exist_ok=True)

  for epoch in range(args.epochs):
    for images, _ in dataloader:
      images = images.to(device)
      # One random timestep per image; the model is trained to predict the added noise.
      t = diffusion.sample_timesteps(images.shape[0]).to(device)
      x_t, noise = diffusion.noise_images(images, t)
      predicted_noise = model(x_t, t)
      loss = loss_fn(noise, predicted_noise)
      opt.zero_grad()
      loss.backward()
      opt.step()
    # Sample as many images as the last training batch contained.
    sampled_images = diffusion.sample(model, images.shape[0])
    save_images(sampled_images, os.path.join(out_dir, f"{epoch}.jpg"))

In [3]:
def plot_images(images):
  """Display a batch of images next to each other in one matplotlib figure.

  The batch is iterated along its first axis and the items are concatenated
  along the last axis, then the channel axis is moved last for ``imshow``.
  """
  plt.figure(figsize=(14, 14))
  on_cpu = images.cpu()
  side_by_side = torch.cat(list(on_cpu), dim=-1)
  # Single-row "grid": stacking one row along the height axis leaves it unchanged.
  grid = torch.cat([side_by_side], dim=-2)
  channels_last = grid.permute(1, 2, 0)
  plt.imshow(channels_last.cpu())
  plt.show()


def save_images(images, path, **kwargs):
  """Arrange a batch of images into one grid and write it to ``path``.

  Args:
    images: batch passed to ``torchvision.utils.make_grid``.
    path: destination file handed to ``PIL.Image.save``.
    **kwargs: forwarded unchanged to ``make_grid``.
  """
  tiled = torchvision.utils.make_grid(images, **kwargs)
  # Move channels last and convert to a NumPy array, as Image.fromarray expects.
  pixels = tiled.permute(1, 2, 0).to('cpu').numpy()
  Image.fromarray(pixels).save(path)


def get_data(args):
  """Build a shuffled DataLoader over the image folder at ``args.dataset_path``.

  Each image is resized to 80, randomly crop-resized to ``args.image_size``,
  converted to a tensor and normalised with mean 0.5 / std 0.5 per channel.
  """
  tv_transforms = torchvision.transforms
  pipeline = tv_transforms.Compose([
      tv_transforms.Resize(80),
      tv_transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
      tv_transforms.ToTensor(),
      tv_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  ])
  image_folder = torchvision.datasets.ImageFolder(args.dataset_path, transform=pipeline)
  return DataLoader(image_folder, batch_size=args.batch_size, shuffle=True)

In [None]:
class SelfAttention(nn.Module):
  """Multi-head self-attention over the spatial positions of a square feature map.

  Args:
    channels: number of feature channels; used as the attention embedding size.
      The attention layer is built with 4 heads, so this must be divisible by 4
      (the default of 3 is kept for compatibility but is rejected by
      ``nn.MultiheadAttention``).
    size: height/width of the feature map; used to restore the spatial layout.
  """

  def __init__(self,
               channels: int=3,
               size: int=256):
    super().__init__()
    self.channels = channels
    self.size = size
    self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
    self.ln = nn.LayerNorm([channels])
    self.ff = nn.Sequential(
        nn.LayerNorm([channels]),
        nn.Linear(channels, channels),
        nn.GELU(),
        nn.Linear(channels, channels)
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Apply attention plus a feed-forward layer, both with residual connections.

    Args:
      x: tensor that can be flattened to ``(batch, channels, size * size)``.

    Returns:
      Tensor of shape ``(batch, channels, size, size)``.
    """
    # Flatten the spatial axes into a sequence: (B, C, H, W) -> (B, H*W, C).
    # reshape (not view) so non-contiguous inputs are accepted as well.
    x = x.reshape(x.shape[0], self.channels, -1).swapaxes(1, 2)
    out = self.ln(x)
    attention_val, _ = self.mha(out, out, out)
    attention_val = attention_val + x
    attention_val = self.ff(attention_val) + attention_val
    # Back to image layout: (B, H*W, C) -> (B, C, size, size).
    return attention_val.swapaxes(2, 1).reshape(-1, self.channels, self.size, self.size)


class ConvBlock(nn.Module):
  """Two 3x3 convolutions with GroupNorm, a GELU in between, and an optional residual.

  Args:
    in_channels: channels of the input.
    out_channels: channels of the output.
    mid_channels: channels between the two convolutions; defaults to ``out_channels``.
    residual: if True, the input is added to the output and passed through GELU
      (requires ``in_channels == out_channels``).

  The parameter order is ``(in_channels, out_channels, mid_channels)`` to match how
  the other blocks in this file call it (two positional arguments, or the middle
  width as an optional third).
  """

  def __init__(self, in_channels: int, out_channels: int, mid_channels: int=None, residual=False):
    super().__init__()
    self.residual = residual
    if not mid_channels:
      mid_channels = out_channels
    # padding=1 keeps height/width unchanged for a 3x3 kernel; without it each
    # convolution shrinks the map by 2 pixels and the residual add cannot work.
    self.layers = nn.Sequential(
        nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.GroupNorm(1, mid_channels),
        nn.GELU(),
        nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.GroupNorm(1, out_channels)
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return ``gelu(x + layers(x))`` when residual, else ``layers(x)``."""
    if self.residual:
      return nn.functional.gelu(x + self.layers(x))
    return self.layers(x)


class EncodingBlock(nn.Module):
  """Down-sampling stage: 2x max-pool, two ConvBlocks, plus a time-embedding bias.

  Args:
    in_channels: channels of the incoming feature map.
    out_channels: channels of the returned feature map.
    emb_dim: size of the time-embedding vector passed to ``forward``.
  """

  def __init__(self, in_channels, out_channels, emb_dim: int=256):
    super().__init__()
    # ConvBlock arguments are passed by keyword so the channel roles are unambiguous.
    self.block = nn.Sequential(
        nn.MaxPool2d(2),
        ConvBlock(in_channels=in_channels, mid_channels=in_channels,
                  out_channels=in_channels, residual=True),
        ConvBlock(in_channels=in_channels, mid_channels=out_channels,
                  out_channels=out_channels)
    )
    self.embedding_block = nn.Sequential(
        nn.SiLU(),
        nn.Linear(emb_dim, out_channels)
    )

  def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Process ``x`` and add the projected time embedding ``t`` to every pixel.

    Args:
      x: feature map with ``in_channels`` channels.
      t: time embedding whose last dimension is ``emb_dim``.
    """
    x = self.block(x)
    # (B, out_channels) -> (B, out_channels, H, W): tile the per-channel bias
    # over the spatial grid so it can be added to x.
    emb = self.embedding_block(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
    return x + emb


class DecodingBlock(nn.Module):
  """Up-sampling stage: 2x bilinear upsample, skip concat, two ConvBlocks, time bias.

  Args:
    in_channels: channels after concatenating the skip connection with the
      upsampled input.
    out_channels: channels of the returned feature map.
    emb_dim: size of the time-embedding vector passed to ``forward``.
  """

  def __init__(self, in_channels, out_channels, emb_dim: int=256):
    super().__init__()
    self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    # Keyword arguments make explicit that the second block narrows
    # in_channels -> in_channels // 2 -> out_channels.
    self.block = nn.Sequential(
        ConvBlock(in_channels=in_channels, mid_channels=in_channels,
                  out_channels=in_channels, residual=True),
        ConvBlock(in_channels=in_channels, mid_channels=in_channels // 2,
                  out_channels=out_channels)
    )
    self.embedding_block = nn.Sequential(
        nn.SiLU(),
        nn.Linear(emb_dim, out_channels)
    )

  def forward(self, x: torch.Tensor, skip_x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Upsample ``x``, merge it with ``skip_x`` and add the projected time embedding.

    Args:
      x: feature map from the previous (coarser) stage.
      skip_x: encoder feature map concatenated along the channel axis.
      t: time embedding whose last dimension is ``emb_dim``.
    """
    x = self.up(x)
    x = torch.cat([skip_x, x], dim=1)
    x = self.block(x)
    # Tile the per-channel bias over the spatial grid before adding it.
    emb = self.embedding_block(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
    return x + emb


class UNet(nn.Module):
  """U-Net noise predictor with self-attention and sinusoidal time conditioning.

  Args:
    in_channels: channels of the input image.
    out_channels: channels of the predicted output.
    emb_dim: size of the sinusoidal time embedding fed to every stage.
    device: stored on the instance for callers; the time embedding itself is
      created on the device of the timestep tensor.

  The SelfAttention sizes are hard-coded (32, 16, 8, 16, 32, 64), so the network
  expects 64x64 inputs, assuming each EncodingBlock halves the resolution.
  """

  def __init__(self, in_channels: int=3, out_channels: int=3, emb_dim: int=256, device=torch.device('cuda')):
    super().__init__()
    self.device = device
    self.emb_dim = emb_dim
    # ConvBlock arguments are passed by keyword so the channel roles are unambiguous.
    self.inc = ConvBlock(in_channels=in_channels, mid_channels=64, out_channels=64)
    self.down1 = EncodingBlock(64, 128, emb_dim=emb_dim)
    self.sa1 = SelfAttention(128, 32)
    self.down2 = EncodingBlock(128, 256, emb_dim=emb_dim)
    self.sa2 = SelfAttention(256, 16)
    self.down3 = EncodingBlock(256, 256, emb_dim=emb_dim)
    self.sa3 = SelfAttention(256, 8)

    self.bot1 = ConvBlock(in_channels=256, mid_channels=512, out_channels=512)
    self.bot2 = ConvBlock(in_channels=512, mid_channels=512, out_channels=512)
    self.bot3 = ConvBlock(in_channels=512, mid_channels=256, out_channels=256)

    self.up1 = DecodingBlock(512, 128, emb_dim=emb_dim)
    self.sa4 = SelfAttention(128, 16)
    self.up2 = DecodingBlock(256, 64, emb_dim=emb_dim)
    self.sa5 = SelfAttention(64, 32)
    self.up3 = DecodingBlock(128, 64, emb_dim=emb_dim)
    self.sa6 = SelfAttention(64, 64)
    self.outc = nn.Conv2d(64, out_channels, kernel_size=1)

  def pos_encoding(self, t: torch.Tensor, channels: int) -> torch.Tensor:
    """Return the sinusoidal embedding of ``t``.

    Args:
      t: float tensor of shape ``(n, 1)`` holding the timesteps.
      channels: embedding width; the first half is sine, the second half cosine.

    Returns:
      Tensor of shape ``(n, channels)``.
    """
    # Build the frequencies on t's device so the model works wherever it was moved.
    inv_freq = 1.0 / (
        10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
    )
    pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
    pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
    pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
    return pos_enc

  def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Predict the output for image batch ``x`` at timesteps ``t``.

    Args:
      x: image batch with ``in_channels`` channels.
      t: one timestep per image, shape ``(n,)``.
    """
    t = t.unsqueeze(-1).type(torch.float)
    t = self.pos_encoding(t, self.emb_dim)

    # Encoder: keep x1..x3 for the skip connections.
    x1 = self.inc(x)
    x2 = self.down1(x1, t)
    x2 = self.sa1(x2)
    x3 = self.down2(x2, t)
    x3 = self.sa2(x3)
    x4 = self.down3(x3, t)
    x4 = self.sa3(x4)

    # Bottleneck.
    x4 = self.bot1(x4)
    x4 = self.bot2(x4)
    x4 = self.bot3(x4)

    # Decoder.
    x = self.up1(x4, x3, t)
    x = self.sa4(x)
    x = self.up2(x, x2, t)
    x = self.sa5(x)
    x = self.up3(x, x1, t)
    x = self.sa6(x)
    return self.outc(x)



In [None]:
class Diffusion:
  """Linear-schedule DDPM helper: forward noising and reverse sampling.

  Args:
    noise_steps: number of diffusion timesteps.
    beta_start: first value of the linear beta schedule.
    beta_end: last value of the linear beta schedule.
    img_size: height/width of the square images produced by ``sample``.
    device: device holding the schedule tensors and the generated samples.
  """

  def __init__(self,
               noise_steps: int=1000,
               beta_start: float =1e-4,
               beta_end: float=0.02,
               img_size: int=256,
               device: torch.device=torch.device('cuda')
               ):
    self.noise_steps = noise_steps
    self.beta_start = beta_start
    self.beta_end = beta_end
    self.img_size = img_size
    self.device = device

    self.beta = self.prepare_noise_schedule().to(device)
    self.alpha = 1.0 - self.beta
    # Cumulative product of alpha, indexed by timestep.
    self.alpha_bar = torch.cumprod(self.alpha, dim=0)

  def prepare_noise_schedule(self) -> torch.Tensor:
    """Return ``noise_steps`` betas spaced linearly from ``beta_start`` to ``beta_end``."""
    return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

  def noise_images(self, x: torch.Tensor, t: torch.Tensor):
    """Noise ``x`` to timesteps ``t`` in one step.

    Args:
      x: image batch of shape ``(n, C, H, W)``.
      t: integer tensor of shape ``(n,)`` indexing the schedule.

    Returns:
      ``(x_t, eps)`` where ``x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps``
      and ``eps`` is the Gaussian noise that was added.
    """
    eps = torch.randn_like(x, device=self.device)
    alpha_bar = self.alpha_bar[t][:, None, None, None]
    x_t = torch.sqrt(alpha_bar) * x + torch.sqrt(1.0 - alpha_bar) * eps
    return x_t, eps

  def sample_timesteps(self, n: int) -> torch.Tensor:
    """Draw ``n`` random integer timesteps from ``[0, noise_steps)``."""
    return torch.randint(low=0, high=self.noise_steps, size=(n,))

  def sample(self, model: nn.Module, n: int) -> torch.Tensor:
    """Generate ``n`` images by running the full reverse process.

    Args:
      model: callable as ``model(x, t)`` returning the predicted noise.
      n: number of images to generate.

    Returns:
      ``uint8`` tensor of shape ``(n, 3, img_size, img_size)`` with values in 0..255.

    The model is put into eval mode while sampling and returned to whichever
    mode it was in before the call.
    """
    was_training = model.training
    model.eval()
    try:
      with torch.no_grad():
        x = torch.randn((n, 3, self.img_size, self.img_size), device=self.device)
        for i in range(self.noise_steps - 1, 0, -1):
          t = torch.full((n,), i, dtype=torch.long, device=self.device)
          predicted_noise = model(x, t)
          alpha = self.alpha[t][:, None, None, None]
          alpha_bar = self.alpha_bar[t][:, None, None, None]
          beta = self.beta[t][:, None, None, None]
          # Compare the Python step index, not the batched tensor t; no fresh
          # noise is added on the final step.
          if i > 1:
            noise = torch.randn_like(x, device=self.device)
          else:
            noise = torch.zeros_like(x, device=self.device)
          x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_bar))) * predicted_noise) + torch.sqrt(beta) * noise
    finally:
      model.train(was_training)
    # Post-process only after every denoising step has run: [-1, 1] -> [0, 255].
    x = (x.clamp(-1, 1) + 1.0) / 2
    x = (x * 255).type(torch.uint8)
    return x





