In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, TensorDataset, DataLoader, random_split

In [10]:
# Scratch check of Tensor.expand: a size-1 dimension can be broadcast to a
# larger size (here 1 -> 5 columns); -1 keeps the existing size of a dimension.
x = torch.tensor([1, 2, 3]).unsqueeze(-1)
print(x.shape)
x = x.expand(-1, 5)
print(x.shape)

torch.Size([3, 1])
torch.Size([3, 5])


In [23]:
def meshgrid(height, width):
  """
  Build per-pixel coordinate maps for a ``height`` x ``width`` grid.

  Returns:
    A pair ``(xt, yt)`` of tensors, each of shape ``(height, width)``, in
    torch's default floating dtype. ``xt[i, j]`` holds the column coordinate
    ``j`` (0 .. width-1) and ``yt[i, j]`` holds the row coordinate ``i``
    (0 .. height-1).
  """
  col_coords = torch.linspace(0.0, width - 1.0, width).reshape(1, width)
  row_coords = torch.linspace(0.0, height - 1.0, height).reshape(height, 1)

  # Multiplying by an all-ones vector tiles each coordinate vector across the
  # other axis of the grid.
  xt = torch.matmul(torch.ones((height, 1)), col_coords)
  yt = torch.matmul(row_coords, torch.ones((1, width)))

  return xt, yt

In [25]:
def repeat(x, num_repeats):
  """
  Repeat every element of ``x`` ``num_repeats`` times, consecutively.

  ``x`` is flattened first, so ``repeat(tensor([a, b]), 3)`` yields
  ``tensor([a, a, a, b, b, b])``.

  Args:
    x: tensor of any shape and dtype.
    num_repeats: non-negative int, number of copies of each element.

  Returns:
    A 1-D tensor of length ``x.numel() * num_repeats`` with the same dtype
    and device as ``x``. The result is always 1-D, including when it holds a
    single element.
  """
  # repeat_interleave preserves the input dtype. Tiling through a matmul with
  # a float32 ones matrix would raise a dtype-mismatch error for integer
  # inputs such as the index offsets built from torch.arange.
  return torch.repeat_interleave(x.reshape(-1), num_repeats)

In [27]:
def interpolate(im, x, y):
  """
  Bilinearly sample ``im`` at floating-point pixel coordinates.

  Args:
    im: image batch of shape ``(batch, height, width, channels)``
      (channels-last).
    x: column coordinates to sample at, shape ``(batch, out_height, out_width)``.
    y: row coordinates to sample at, same shape as ``x``.

  Returns:
    Float tensor of shape ``(batch, out_height, out_width, channels)``.
    The image is surrounded by a one-pixel zero border, so samples that fall
    outside the image fade to (and then stay at) zero.
  """
  # Zero-pad one pixel on each side of the height and width axes. F.pad takes
  # (left, right) pairs starting from the LAST dimension: C, W, H, B.
  im = torch.nn.functional.pad(im, (0, 0, 1, 1, 1, 1, 0, 0))
  batch_size, height, width, channels = im.shape
  _, out_height, out_width = x.shape

  # Flatten the coordinates and shift by +1 to account for the padding.
  x = x.reshape(-1).float() + 1.0
  y = y.reshape(-1).float() + 1.0

  # Integer corners of the cell that contains each sample point.
  x0 = torch.floor(x).long()
  x1 = x0 + 1
  y0 = torch.floor(y).long()
  y1 = y0 + 1

  # Keep every corner inside the padded image.
  x0 = torch.clamp(x0, 0, width - 1)
  x1 = torch.clamp(x1, 0, width - 1)
  y0 = torch.clamp(y0, 0, height - 1)
  y1 = torch.clamp(y1, 0, height - 1)

  # Offsets into the image flattened to (batch*height*width, channels):
  # index = batch_index * (height*width) + row * width + column.
  dim2 = width
  dim1 = width * height
  batch_index = torch.arange(batch_size, device=x.device)
  base = torch.repeat_interleave(batch_index, out_height * out_width) * dim1
  base_y0 = base + y0 * dim2
  base_y1 = base + y1 * dim2

  idx_a = base_y0 + x0  # (y0, x0)
  idx_b = base_y1 + x0  # (y1, x0)
  idx_c = base_y0 + x1  # (y0, x1)
  idx_d = base_y1 + x1  # (y1, x1)

  # Row indexing picks the full channel vector of each corner pixel.
  im_flat = torch.reshape(im, [-1, channels]).float()
  Ia = im_flat[idx_a]
  Ib = im_flat[idx_b]
  Ic = im_flat[idx_c]
  Id = im_flat[idx_d]

  # Distance from the sample point to the far corner along each axis.
  dx = x1.float() - x
  dy = y1.float() - y

  # Bilinear weights, with a trailing axis so they broadcast over channels.
  wa = (dx * dy).unsqueeze(1)
  wb = (dx * (1 - dy)).unsqueeze(1)
  wc = ((1 - dx) * dy).unsqueeze(1)
  wd = ((1 - dx) * (1 - dy)).unsqueeze(1)

  output = wa * Ia + wb * Ib + wc * Ic + wd * Id
  output = torch.reshape(output, [-1, out_height, out_width, channels])
  return output

In [None]:
class DoubleConv(nn.Module):
  """
  Two consecutive (3x3 convolution -> batch norm -> ReLU) stages.

  The convolutions use stride 1 and padding 1, so the spatial size of the
  input is preserved; only the channel count changes from ``in_channels`` to
  ``out_channels``. The convolutions carry no bias term.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    stages = []
    # First stage maps in_channels -> out_channels, second keeps out_channels.
    for stage_in in (in_channels, out_channels):
      stages.extend([
          nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
      ])
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    """Apply both convolution stages to ``x`` and return the result."""
    out = self.conv(x)
    return out

In [None]:
class UNet(nn.Module):
  """
  U-Net style encoder/decoder built from ``DoubleConv`` blocks.

  The encoder applies one ``DoubleConv`` per entry of ``features``, each
  followed by 2x2 max pooling. A bottleneck ``DoubleConv`` doubles the last
  width. The decoder mirrors the encoder: a stride-2 transposed convolution
  upsamples, the matching encoder output is concatenated on the channel axis,
  and a ``DoubleConv`` fuses them. A final 1x1 convolution maps to
  ``out_channels``.

  Args:
    in_channels: number of channels of the input image.
    out_channels: number of channels of the output map.
    features: channel widths of the encoder stages, shallowest first. Each
      decoder stage expects ``2*feature`` input channels, so every width
      must be twice the previous one (as in the default).

  Raises:
    ValueError: if ``features`` is empty.
  """

  def __init__(self, in_channels: int=1, out_channels: int=1, features: list=[64, 128, 256, 512]):
    super().__init__()
    if len(features) == 0:
      raise ValueError("features must contain at least one channel width")

    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()
    self.pool = nn.MaxPool2d(2, 2)

    # Encoder: each stage consumes the previous stage's channel count.
    for feature in features:
      self.downs.append(DoubleConv(in_channels, feature))
      in_channels = feature

    # Bottleneck doubles the deepest encoder width (an int, not the list).
    self.bottleneck = DoubleConv(features[-1], 2*features[-1])

    # Decoder: stored as alternating (upsample, DoubleConv) pairs.
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(2*feature, feature, kernel_size=2, stride=2))
      self.ups.append(DoubleConv(2*feature, feature))

    self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

  def forward(self, x):
    """
    Run the network on ``x`` of shape ``(batch, in_channels, H, W)``.

    Returns a tensor of shape ``(batch, out_channels, H, W)``.
    """
    skip_connections = []

    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)

    # The decoder consumes the skips deepest-first.
    skip_connections = skip_connections[::-1]
    for idx in range(0, len(self.ups), 2):
      x = self.ups[idx](x)
      skip = skip_connections[idx//2]
      # Max pooling floors odd sizes, so the upsampled map can be one pixel
      # smaller than the skip; resize so the concatenation is valid.
      if x.shape[2:] != skip.shape[2:]:
        x = nn.functional.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
      x = torch.cat([skip, x], dim=1)
      x = self.ups[idx+1](x)

    return self.final_conv(x)