In [1]:
from fastai.callback.hook import *
from fastbook import *
from fastai.vision.all import *
import fastbook
import torch.nn as nn
fastbook.setup_book()

<h3><b>Diffusion Models (DDPM)</b></h3>

<h5><b>1) U-Net</b></h5>

<img src='../../docs/diffusion/unet.png' />

<h5><b>1.1) Double Convolution</b></h5>

In [2]:
class DoubleConv(nn.Module):
    """Two stacked (Conv2d -> BatchNorm2d -> ReLU) stages.

    Args:
        ics: number of input channels.
        ocs: number of output channels.
        mcs: number of channels between the two convolutions; falls back to
            ``ocs`` when omitted (or falsy).
        ks: kernel size used by both convolutions (int or per-dimension tuple).
            Padding is derived from it (``ks // 2``), so odd kernel sizes keep
            the spatial size of the input unchanged. For the default ``ks=3``
            this is the same ``padding=1`` that was previously hard-coded.
    """
    def __init__(self, ics, ocs, mcs=None, ks=3):
        super().__init__()
        if not mcs:
            mcs = ocs
        self.ks = ks
        # Derive "same" padding from the kernel size instead of hard-coding 1,
        # otherwise any ks != 3 silently changes the spatial size of the output.
        pad = ks // 2 if isinstance(ks, int) else tuple(k // 2 for k in ks)
        self.dbc = nn.Sequential(
            # bias is disabled because the following BatchNorm has its own shift term
            nn.Conv2d(ics, mcs, kernel_size=ks, padding=pad, bias=False),
            nn.BatchNorm2d(mcs),
            nn.ReLU(inplace=True),
            nn.Conv2d(mcs, ocs, kernel_size=ks, padding=pad, bias=False),
            nn.BatchNorm2d(ocs),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.dbc(x)



<h5><b>1.2) Down Scale Convolution</b></h5>

In [3]:
class DownScale(nn.Module):
    """Downscaling with max pooling followed by a double convolution.

    Args:
        ics: number of input channels.
        ocs: number of output channels produced by the double convolution.
    """
    def __init__(self, ics, ocs):
        super(DownScale, self).__init__()
        # 2x2 max pooling first, then the channel-changing DoubleConv.
        stages = [nn.MaxPool2d(2), DoubleConv(ics, ocs)]
        self.maxconv = nn.Sequential(*stages)

    def forward(self, x):
        """Pool ``x`` and run it through the double convolution."""
        pooled_and_convolved = self.maxconv(x)
        return pooled_and_convolved


<h5><b>1.3) Up Scale Convolution</b></h5>

In [14]:
# torch based upsampling
class UpScale(nn.Module):
    """Upsample a feature map, merge it with a skip connection, then double-convolve.

    Args:
        ics: number of channels entering the final DoubleConv (after concatenation).
        ocs: number of output channels.
        bilinear: when True, upsample with ``nn.Upsample`` (bilinear, x2) and let the
            DoubleConv use ``ics // 2`` intermediate channels; when False, upsample
            with a stride-2 ``nn.ConvTranspose2d`` that maps ``ics`` to ``ics // 2``
            channels.
    """
    def __init__(self, ics, ocs, bilinear=True):
        super().__init__()
        half = ics // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(ics, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(ics, ocs)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(ics, ocs, half)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the spatial size of ``x2``, concatenate and convolve.

        ``x2`` (the skip connection) comes first in the channel-wise concatenation.
        """
        upsampled = self.up(x1)
        # Tensors are indexed as (N, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        # Split each gap as evenly as possible; the odd pixel goes to the right/bottom.
        left, top = gap_w // 2, gap_h // 2
        right, bottom = gap_w - left, gap_h - top
        upsampled = F.pad(upsampled, [left, right, top, bottom])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


<h5><b>1.4) Out Convolution</b></h5>

In [5]:
class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``ics`` feature channels to ``ocs`` output channels."""
    def __init__(self, ics, ocs) -> None:
        super().__init__()
        # A kernel of size 1 mixes channels per pixel and leaves H and W untouched.
        self.conv = nn.Conv2d(in_channels=ics, out_channels=ocs, kernel_size=1)

    def forward(self, x):
        """Project ``x`` to the output channel count."""
        projected = self.conv(x)
        return projected

<h4><b>U-Net Construction</b></h4>

In [15]:
class UNet(nn.Module):
    """U-Net assembled from DoubleConv, DownScale, UpScale and OutConv blocks.

    Args:
        nc: number of input channels.
        ncls: number of output channels (classes) produced by the final 1x1 conv.
        bilinear: forwarded to every UpScale block; when True the channel counts
            around the bottleneck are halved (``factor = 2``).
    """
    def __init__(self, nc, ncls, bilinear=False):
        super().__init__()
        self.nc = nc      # number of channels
        self.ncls = ncls  # number of classes
        factor = 2 if bilinear else 1

        # Encoder: 64 -> 128 -> 256 -> 512 -> 1024 // factor channels
        self.inc = DoubleConv(nc, 64)
        self.down1 = DownScale(64, 128)
        self.down2 = DownScale(128, 256)
        self.down3 = DownScale(256, 512)
        self.down4 = DownScale(512, 1024 // factor)

        # Decoder: each stage consumes the previous output plus one encoder skip
        self.up1 = UpScale(1024, 512 // factor, bilinear)
        self.up2 = UpScale(512, 256 // factor, bilinear)
        self.up3 = UpScale(256, 128 // factor, bilinear)
        self.up4 = UpScale(128, 64, bilinear)

        # Head
        self.outc = OutConv(64, ncls)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, and return the logits."""
        # Encoder pass: keep every intermediate map for the skip connections.
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        bottom = self.down4(skip4)

        # Decoder pass: pair each upsampling stage with the matching encoder map,
        # deepest first.
        feat = self.up1(bottom, skip4)
        feat = self.up2(feat, skip3)
        feat = self.up3(feat, skip2)
        feat = self.up4(feat, skip1)

        return self.outc(feat)



In [7]:
# Fetch the PETS dataset through fastai's `untar_data` and point `path`
# at the 'images' subfolder of the returned location.
path = untar_data(URLs.PETS)/'images'

In [8]:
# Take the first image file under `path` and turn it into a model-ready batch.
p = get_image_files_sorted(path)
img = tensor(Image.open(p[0]))
# The freshly loaded tensor is laid out height x width x channels.
# NOTE(review): assumes a 3-D (colour) image; a 2-D grayscale image would fail to unpack here.
H, W, C = img.shape
# Channels first, add a leading batch dimension, convert to float.
# NOTE(review): pixel values are not rescaled/normalised here — confirm that is intended.
img = img.permute(2,0,1).unsqueeze(0).float()

In [9]:
# Sanity check: expect (1, C, H, W), i.e. batch, channels, height, width.
img.shape

torch.Size([1, 3, 400, 600])

In [16]:
# Build the U-Net with C input channels and a single output channel
# (bilinear defaults to False, so the UpScale blocks use transposed convolutions).
unet = UNet(C, 1)
# Forward pass on the single-image batch. It is not wrapped in torch.no_grad(),
# so the result carries a grad_fn; no seed is set in the visible cells, so the
# freshly initialised weights (and these values) will differ between runs.
x = unet(img)
x

tensor([[[[ 0.3275, -0.3614,  0.0464,  ...,  0.1696,  0.4264, -0.0396],
          [ 0.3945,  0.6337,  0.0862,  ...,  0.1472,  0.2186, -0.5170],
          [ 0.6860,  0.3366,  0.4820,  ...,  0.2100,  0.5643, -0.0712],
          ...,
          [ 0.4541,  0.0038,  0.4837,  ..., -0.2348,  0.5824, -0.5123],
          [ 0.3414, -0.1465,  0.3276,  ..., -0.0782,  0.3636, -0.3238],
          [ 0.2112,  0.1453,  0.3182,  ..., -0.0858,  0.3465, -0.0974]]]], grad_fn=<ConvolutionBackward0>)

In [18]:
# Expect (1, 1, H, W): same spatial size as the input, one output channel.
x.shape

torch.Size([1, 1, 400, 600])