In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.optim.lr_scheduler as sched

import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

from matplotlib import pyplot as plt

In [9]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Parameters

In [10]:
# Training hyperparameters.
batch_size = 1_000  # number of images per mini-batch passed to the DataLoader
epoch = 30  # NOTE(review): presumably the number of training epochs - not used in the visible code


In [11]:
# Read every image found under "trafic_32" (one sub-folder per class, as
# ImageFolder expects) and convert each one to a tensor.
full_DS = ImageFolder(root="trafic_32", transform=transforms.ToTensor())

# Shuffled mini-batch iterator over the whole dataset; loading happens in the
# main process (num_workers=0) and batches are placed in pinned memory.
full_loader = DataLoader(
    full_DS,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

# Displayed as the cell output: total number of images.
len(full_DS)

39209

In [12]:
class Unet_Block(nn.Module):
    """
    Base U-Net block: three 3x3 convolutions (stride 1, padding 1), each
    followed by a LeakyReLU, and a final BatchNorm2d. Height and width are
    preserved; only the channel count changes.

    in_channels - number of input channels
    out_channels - number of out channels
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.activation = nn.LeakyReLU(negative_slope=0.02)

        self.layers = nn.ModuleList([
            # Only the first convolution sees the block's input channels.
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=(1,1)),
            self.activation,
            # The following convolutions consume the previous convolution's
            # output, so their input width must be out_channels.
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=(1,1)),
            self.activation,
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=(1,1)),
            self.activation,
            nn.BatchNorm2d(out_channels)
        ])

    def forward(self, x):
        """
        Run x through every layer in order and return the result.

        x - input tensor with in_channels channels
        """

        for layer in self.layers:
            x = layer(x)

        return x

    # Kept so that code calling the original (misspelled) method name still works.
    froward = forward

In [13]:
class Rescaler(nn.Module):
    """
    Rescaler changes the spatial size of its input with a single learned
    layer using kernel size 2 and stride 2.

    in_channels - number of input channels
    out_channels - number of out channels
    upscale - True: ConvTranspose2d (size x2); False: Conv2d (size /2)
    """

    def __init__(self, in_channels, out_channels, upscale:bool):
        super().__init__()

        # Both layer types take the same arguments, so choose the class
        # first and build it once.
        layer_type = nn.ConvTranspose2d if upscale else nn.Conv2d
        self.rescaler = layer_type(in_channels, out_channels, kernel_size=2, stride=(2,2))

    def forward(self, x):
        rescaled = self.rescaler(x)
        return rescaled

In [None]:
class U_Net(nn.Module):
    """
    U-Net built from Unet_Block / Rescaler pairs.

    in_channels - number of channels of the input image

    NOTE(review): only the downscaling stack is defined here; the upscaling
    path and forward() are not visible in this part of the notebook.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Each Unet_Block changes the channel count and the Rescaler after it
        # halves height and width while keeping that channel count.
        # The sizes in the trailing comments assume a 32x32 input - TODO confirm.
        self.layers = nn.ModuleList([

            Unet_Block(in_channels, 32),  # was hard-coded to 3, ignoring the argument
            Rescaler(32, 32, upscale=False), # 16x16

            Unet_Block(32, 64),
            Rescaler(64, 64, upscale=False), # 8x8

            Unet_Block(64, 128),
            Rescaler(128, 128, upscale=False), # 4x4

            Unet_Block(128, 192),
            # Must match the 192 channels produced by the block above.
            Rescaler(192, 192, upscale=False), # 2x2

        ])