### Скачать библиотеки и инициализировать device

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from torchvision.transforms import ToTensor
from torchvision.utils import save_image, make_grid
import os
from PIL import Image
import numpy as np

from typing import Dict, Tuple

In [2]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

### Датасет

In [3]:
!mkdir datasets
!gdown --id 139GsP9CqFCW1LA1Mf3e1gZpWz2uXmfHf -O datasets/tiny-floodnet-challenge.tar.gz
!tar -xzf datasets/tiny-floodnet-challenge.tar.gz -C datasets
!rm datasets/tiny-floodnet-challenge.tar.gz

Downloading...
From: https://drive.google.com/uc?id=139GsP9CqFCW1LA1Mf3e1gZpWz2uXmfHf
To: /content/datasets/tiny-floodnet-challenge.tar.gz
100% 50.4M/50.4M [00:01<00:00, 34.2MB/s]


In [4]:
class FloodNet(Dataset):

    def __init__(self, data_path: str, phase: str, augment: bool, img_size: int):
        self.num_classes = 8
        self.data_path = data_path
        self.phase = phase
        self.augment = augment
        self.img_size = img_size
        self.items = [filename.split('.')[0] for filename in os.listdir(f'{data_path}/{phase}/image')]
        if augment:
          self.transform = A.Compose([
                                      A.RandomScale(), 
                                      A.RandomCrop(width=self.img_size, height=self.img_size),
                                      A.Rotate(limit=30),
                                      A.Flip(p=0.5),
                                      A.HueSaturationValue()])
        else:
          self.transform = A.RandomCrop(width=self.img_size, height=self.img_size)
        self.to_tensor = ToTensor()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        image = np.asarray(Image.open(f'{self.data_path}/{self.phase}/image/{self.items[index]}.jpg'))
        if self.phase == 'train':
          transformed = self.transform(image=image)
          image = transformed['image']
        image = self.to_tensor(image.copy())
        if self.phase == 'train':
          assert isinstance(image, torch.FloatTensor) and image.shape == (3, self.img_size, self.img_size)
        return image

In [5]:
dataset = FloodNet(data_path = 'datasets/tiny-floodnet-challenge',
                   phase = 'train',
                   augment = True,
                   img_size = 64)

loader = DataLoader(dataset, 
                    num_workers = 8,
                    batch_size = 32 , 
                    shuffle = True)

  cpuset_checked))


In [6]:
# check 
for data in loader:
  print(data.shape)
  break

  cpuset_checked))


torch.Size([32, 3, 64, 64])


### Model

In [7]:
class DDPM(nn.Module):
  def __init__(self, eps_model: nn.Module, betas: Tuple[float, float], n_T: int, criterion: nn.Module = nn.MSELoss()) -> None:
    super(DDPM, self).__init__()
    self.eps_model = eps_model

    # register_buffer allows us to freely access these tensors by name. It helps device placement.
    for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
        self.register_buffer(k, v)

    self.n_T = n_T
    self.criterion = criterion

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Makes forward diffusion x_t, and tries to guess epsilon value from x_t using eps_model.
    This implements Algorithm 1 in the paper.
    """

    _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(x.device)
    # t ~ Uniform(0, n_T)
    eps = torch.randn_like(x)  # eps ~ N(0, 1)

    x_t = (
        self.sqrtab[_ts, None, None, None] * x
        + self.sqrtmab[_ts, None, None, None] * eps
    )  # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
    # We should predict the "error term" from this x_t. Loss is what we return.

    return self.criterion(eps, self.eps_model(x_t, _ts / self.n_T))

  def sample(self, n_sample: int, size, device) -> torch.Tensor:

    x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1)

    # This samples accordingly to Algorithm 2. It is exactly the same logic.
    for i in range(self.n_T, 0, -1):
      z = torch.randn(n_sample, *size).to(device) if i > 1 else 0
      eps = self.eps_model(x_i, torch.tensor(i / self.n_T).to(device).repeat(n_sample, 1))
      x_i = (self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z)

    return x_i


def ddpm_schedules(beta1: float, beta2: float, T: int) -> Dict[str, torch.Tensor]:
  """
  Returns pre-computed schedules for DDPM sampling, training process.
  """
  assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

  beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
  sqrt_beta_t = torch.sqrt(beta_t)
  alpha_t = 1 - beta_t
  log_alpha_t = torch.log(alpha_t)
  alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

  sqrtab = torch.sqrt(alphabar_t)
  oneover_sqrta = 1 / torch.sqrt(alpha_t)

  sqrtmab = torch.sqrt(1 - alphabar_t)
  mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

  return {
    "alpha_t": alpha_t,  # \alpha_t
    "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
    "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
    "alphabar_t": alphabar_t,  # \bar{\alpha_t}
    "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
    "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
    "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
  }

In [8]:
class Conv3(nn.Module):
  def __init__(
    self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
    super().__init__()
    self.main = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, 3, 1, 1),
      nn.GroupNorm(8, out_channels),
      nn.ReLU(),
    )
    self.conv = nn.Sequential(
      nn.Conv2d(out_channels, out_channels, 3, 1, 1),
      nn.GroupNorm(8, out_channels),
      nn.ReLU(),
      nn.Conv2d(out_channels, out_channels, 3, 1, 1),
      nn.GroupNorm(8, out_channels),
      nn.ReLU(),
    )

    self.is_res = is_res

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.main(x)
    if self.is_res:
      x = x + self.conv(x)
      return x / 1.414
    else:
      return self.conv(x)


class UnetDown(nn.Module):
  def __init__(self, in_channels: int, out_channels: int) -> None:
    super(UnetDown, self).__init__()
    layers = [Conv3(in_channels, out_channels), nn.MaxPool2d(2)]
    self.model = nn.Sequential(*layers)

  def forward(self, x: torch.Tensor) -> torch.Tensor:

    return self.model(x)


class UnetUp(nn.Module):
  def __init__(self, in_channels: int, out_channels: int) -> None:
    super(UnetUp, self).__init__()
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
        Conv3(out_channels, out_channels),
        Conv3(out_channels, out_channels),
    ]
    self.model = nn.Sequential(*layers)

  def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
    x = torch.cat((x, skip), 1)
    x = self.model(x)

    return x

class TimeSiren(nn.Module):
  def __init__(self, emb_dim: int) -> None:
    super(TimeSiren, self).__init__()

    self.lin1 = nn.Linear(1, emb_dim, bias=False)
    self.lin2 = nn.Linear(emb_dim, emb_dim)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = x.view(-1, 1)
    x = torch.sin(self.lin1(x))
    x = self.lin2(x)
    return x


In [9]:
class NaiveUnet(nn.Module):
  def __init__(self, in_channels: int, out_channels: int, n_feat: int = 256) -> None:
    super(NaiveUnet, self).__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels

    self.n_feat = n_feat

    self.init_conv = Conv3(in_channels, n_feat, is_res=True)

    self.down1 = UnetDown(n_feat, n_feat)
    self.down2 = UnetDown(n_feat, 2 * n_feat)
    self.down3 = UnetDown(2 * n_feat, 2 * n_feat)

    self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.ReLU())

    self.timeembed = TimeSiren(2 * n_feat)

    self.up0 = nn.Sequential(
        nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 4, 4),
        nn.GroupNorm(8, 2 * n_feat),
        nn.ReLU(),
    )

    self.up1 = UnetUp(4 * n_feat, 2 * n_feat)
    self.up2 = UnetUp(4 * n_feat, n_feat)
    self.up3 = UnetUp(2 * n_feat, n_feat)
    self.out = nn.Conv2d(2 * n_feat, self.out_channels, 3, 1, 1)

  def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:

    x = self.init_conv(x)

    down1 = self.down1(x)
    down2 = self.down2(down1)
    down3 = self.down3(down2)

    thro = self.to_vec(down3)
    temb = self.timeembed(t).view(-1, self.n_feat * 2, 1, 1)

    thro = self.up0(thro + temb)

    up1 = self.up1(thro, down3) + temb
    up2 = self.up2(up1, down2)
    up3 = self.up3(up2, down1)

    out = self.out(torch.cat((up3, x), 1))

    return out

In [10]:
model = DDPM(eps_model=NaiveUnet(3, 3, n_feat=128), betas=(1e-4, 0.02), n_T=1000).to(device)

### Train


In [11]:
lr = 3e-4
num_epochs = 10

optim = torch.optim.Adam(model.parameters(), lr=lr)

In [12]:
for i in range(num_epochs):
  model.train()
  for idx, data in enumerate(loader):
    data = data.to(device)
    optim.zero_grad()
    loss = model(data)
    loss.backward()
    optim.step()
  model.eval()
  with torch.no_grad():
    xh = model.sample(8, (3, 128, 128), device)
    xset = torch.cat([xh, data[:8]], dim=0)
    grid = make_grid(xset, normalize=True, value_range=(-1, 1), nrow=4)
    save_image(grid, f"./contents/model_sample{i:03d}.png")

    # save model
    torch.save(model.state_dict(), f"./model_floodnet.pth")

  cpuset_checked))


KeyboardInterrupt: ignored