In [7]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as TF

import os
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

In [8]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

cuda


## Hyperparameters

In [9]:
# --- Training schedule ---
EPOCHS = 100
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_WORKERS = 4  # presumably the DataLoader worker count -- confirm at the call site

# --- Checkpointing ---
# NOTE(review): these look like resume/save toggles and the weights path;
# confirm where they are read further down the notebook.
LOAD_MODEL, SAVE_MODEL = False, True
CHECKPOINT_FILE = "checkpoint-test.pth"

# --- Input geometry ---
IMAGE_HEIGHT = IMAGE_WIDTH = 256
IMAGE_CHANNELS = 3  # used as the default in/out channel count of UNET




## Model Architecture

In [10]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by BatchNorm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes (in the first
    convolution).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first pass maps in_channels -> out_channels, the second keeps
        # out_channels. Layers are created in the same order they run.
        for stage_in in (in_channels, out_channels):
            stages += [
                # The convolution carries no bias term; the BatchNorm that
                # follows has its own learnable shift.
                nn.Conv2d(
                    stage_in,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv -> norm -> ReLU pair twice and return the result."""
        features = self.conv(x)
        return features


In [11]:
class UNET(nn.Module):
    """U-Net style encoder/decoder assembled from ``DoubleConv`` blocks.

    Encoder: four ``DoubleConv`` stages (64, 128, 256, 512 channels), each
    followed by 2x2 max pooling, then two 1024-channel stages (``conv5`` and
    ``bottleneck``).

    Decoder: four stages, each made of a stride-2 transposed convolution that
    halves the channel count, a concatenation with the encoder feature map of
    the same depth (skip connection), and a ``DoubleConv`` that reduces the
    concatenated channels again.

    A final 1x1 convolution maps the 64 decoder channels to ``out_channels``.
    No activation is applied to the output.

    Attribute names are part of the checkpoint format (state_dict keys) and
    must not be renamed.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the returned tensor.
    """

    def __init__(self, in_channels=IMAGE_CHANNELS, out_channels=IMAGE_CHANNELS):
        super(UNET, self).__init__()
        # Encoder (contracting path).
        self.conv1 = DoubleConv(in_channels, 64)
        self.conv2 = DoubleConv(64, 128)
        self.conv3 = DoubleConv(128, 256)
        self.conv4 = DoubleConv(256, 512)
        self.conv5 = DoubleConv(512, 1024)

        # One stateless pooling layer shared by every encoder stage.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder refinement blocks. The input width is twice the output width
        # because the upsampled map is concatenated with its skip connection.
        self.conv6 = DoubleConv(1024, 512)
        self.conv7 = DoubleConv(512, 256)
        self.conv8 = DoubleConv(256, 128)
        self.conv9 = DoubleConv(128, 64)

        # Decoder upsampling: each doubles H/W and halves the channels.
        self.tconv1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.tconv2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.tconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.tconv4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)

        self.bottleneck = DoubleConv(1024, 1024)
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def _decode_stage(self, out, skip, upsample, refine):
        """Upsample ``out``, merge it with ``skip`` and refine the result."""
        out = upsample(out)
        # Max pooling floors odd sizes, so the upsampled map can come out
        # smaller than the encoder map it has to be concatenated with.
        # Only H/W are compared: batch and channel counts match by
        # construction, and resizing cannot change them anyway.
        if out.shape[2:] != skip.shape[2:]:
            out = TF.resize(out, size=skip.shape[2:])
        return refine(torch.cat((out, skip), dim=1))

    def forward(self, x):
        """Run the full encoder/decoder pass.

        Args:
            x: 4-D tensor laid out as (batch, channels, height, width); the
                code concatenates on dim 1 and reads the spatial size from
                ``shape[2:]``.

        Returns:
            Tensor with ``out_channels`` channels and the same height and
            width as ``x``.
        """
        # Encoder: keep every pre-pooling feature map for the skip connections.
        c1 = self.conv1(x)
        c2 = self.conv2(self.maxpool(c1))
        c3 = self.conv3(self.maxpool(c2))
        c4 = self.conv4(self.maxpool(c3))

        out = self.conv5(self.maxpool(c4))
        out = self.bottleneck(out)

        # Decoder: walk back up, pairing each stage with the encoder map of
        # the same depth (deepest first).
        decoder_stages = (
            (self.tconv1, self.conv6, c4),
            (self.tconv2, self.conv7, c3),
            (self.tconv3, self.conv8, c2),
            (self.tconv4, self.conv9, c1),
        )
        for upsample, refine, skip in decoder_stages:
            out = self._decode_stage(out, skip, upsample, refine)

        return self.out(out)
    


