In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as sched

import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

from matplotlib import pyplot as plt

In [9]:
device = ("cuda" if torch.cuda.is_available() else "cpu")

### Parameters

In [10]:
batch_size = 1000
epoch = 30


In [11]:
full_DS = ImageFolder(root="trafic_32", transform=transforms.ToTensor())
full_loader = DataLoader(dataset=full_DS, pin_memory=True, num_workers=0, shuffle=True, batch_size=batch_size)

len(full_DS)

39209

In [20]:
class Unet_Block(nn.Module):
    """
    Base U-Net block
    in_channels - number of input channels
    out_channels - number of out channels
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.activation = nn.SiLU()

        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=(1,1)),
            self.activation,
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=(1,1)),
            self.activation,
            nn.BatchNorm2d(out_channels) # to moze na GN zmienic
        ])

    def froward(self, x):

        for layer in self.layers:
            x = layer(x)

        return x

In [21]:
class Rescaler(nn.Module):
    """
    Resacler either upscales given input or downscales it (size x2 or /2 only)
    """

    def __init__(self, in_channels, out_channels, upscale:bool):
        super().__init__()

        if upscale:
            self.rescaler = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=(2,2))
        else:
            self.rescaler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=(2,2))

    def forward(self, x):
        return self.rescaler(x)

In [31]:
class U_Net(nn.Module):

    def __init__(self, in_channels):
        super().__init__()


        self.encoder = nn.ModuleList([

            Unet_Block(in_channels, 32),
            Rescaler(32, 32, upscale=False), # 16x16

            Unet_Block(32, 64),
            Rescaler(64, 64, upscale=False), #8x8

            Unet_Block(64, 128),
            Rescaler(128, 128, upscale=False), #4x4

            Unet_Block(128, 172),
            Rescaler(32, 32, upscale=False), # 2x2
        ])

        self.latent = nn.Sequential(
            nn.Conv2d(in_channels=172, out_channels=226, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            nn.Conv2d(in_channels=226, out_channels=226, kernel_size=3, padding=1, stride=1),
            nn.SiLU(),
            nn.BatchNorm2d(226)
        )

        self.decoder = nn.ModuleList([

            Rescaler(226, 226, upscale=True), # 4x4
            Unet_Block(226, 172),

            Rescaler(64, 64, upscale=True), #8x8
            Unet_Block(192, 128),

            Rescaler(128, 128, upscale=True), #16x16
            Unet_Block(128, 64),

            Rescaler(32, 32, upscale=True), # 32x32
            Unet_Block(64, 32),

            nn.Conv2d(in_channels=32, out_channels=in_channels, kernel_size=1, stride=1, padding=0)

        ])

    def forward(self, x):

        residual = []
        # encoder
        for idx, layer in enumerate(self.encoder):
            x = layer(x)
        if(idx%2 == 0):
            residual.append(x)

        # latent
        x = self.latent(x)

        # decoder

        for idx, layer in enumerate(self.decoder):
            x = layer(x)
        if(idx%2 == 0):
            x = torch.cat([x, residual[int((idx/2)-1)]], dim=1) # (B, CH, H, W) - dim=1

        return x




In [32]:
model = U_Net(3)

TypeError: list is not a Module subclass

In [28]:
int(1.1)

1

In [23]:
params_sum = 0
for params in model.parameters():
    params_sum+=params.view(-1).size(0)
params_sum

3103109