In [None]:
import os
from torchvision import transforms
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, Dataset
from glob import glob
from matplotlib import pyplot as plt
import numpy as np
from torchvision import transforms
import cv2
import random
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# ImageNet channel statistics (RGB order, for pixel values in [0, 1]).
# Both the forward and inverse transforms are derived from these constants so
# they can never drift apart (the previous hand-typed inverse used 0.255
# instead of 0.225 for the third channel).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# x -> (x - mean) / std
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
# Exact inverse of `normalize`: (x - (-m/s)) / (1/s) == x * s + m
denormalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)],
    std=[1 / s for s in IMAGENET_STD]
)

def preprocess_image(img):
    """Turn one H x W x C uint8 image array into a normalized 1 x C x H x W float tensor.

    Args:
        img: image array in channels-last layout (as returned by ``cv2.imread``).

    Returns:
        torch.Tensor: float32 tensor of shape (1, C, H, W), scaled to [0, 1],
        normalized with ``normalize`` and moved to ``device``.
    """
    # Scale 0..255 -> 0..1 first: the ImageNet mean/std in `normalize` are defined
    # for [0, 1] inputs (what transforms.ToTensor produces). Without this, the
    # normalized values land in the hundreds instead of roughly [-2, 3].
    img = torch.tensor(img).permute(2, 0, 1)[None].float() / 255.0
    # NOTE(review): images in this notebook come from cv2.imread (BGR order), while
    # the ImageNet statistics and pretrained VGG weights assume RGB - consider
    # converting the channel order before normalizing.
    img = normalize(img)
    return img.to(device)

def stems(split, root=r'D:\pycharm\DL-Pytorch\Dataset1'):
    """Return the sorted file stems of the PNG images of one dataset split.

    Matches ``<root>/images_prepped_<split>/*.png`` and strips both the directory
    and the extension, e.g. ``.../0001TP_006690.png`` -> ``'0001TP_006690'``.

    Args:
        split: split suffix, e.g. ``'train'`` or ``'test'``.
        root: dataset directory; defaults to the original hard-coded location.

    Returns:
        list[str]: stems in sorted (deterministic) order; empty if nothing matches.
    """
    pattern = os.path.join(root, f'images_prepped_{split}', '*.png')
    # basename/splitext handle the platform separator (glob may join matches with
    # '\\' on Windows, which split('/') missed) and keep dots inside the stem intact.
    return sorted(os.path.splitext(os.path.basename(path))[0] for path in glob(pattern))

def get_segmentation_arr(img, n_classes):
    """One-hot encode a 2-D label image.

    Args:
        img: 2-D integer array of class ids, shape (H, W). Any H and W are
            accepted (previously the output was fixed to 224 x 224).
        n_classes: number of classes C; ids outside ``0..C-1`` give all-zero pixels.

    Returns:
        np.ndarray: float64 array of shape (H, W, C) with ``[..., c] == 1.0``
        exactly where ``img == c``.
    """
    img = np.asarray(img)
    # Broadcast (H, W, 1) against (C,) -> (H, W, C) instead of looping per class.
    return (img[..., None] == np.arange(n_classes)).astype(np.float64)

In [None]:
class SegData(Dataset):
    """Dataset of (image, mask) pairs for one split of the prepped segmentation data.

    Images come from ``images_prepped_<split>`` and masks from
    ``annotations_prepped_<split>`` under a fixed dataset directory; both are
    resized to 224x224 and returned as numpy arrays. ``collate_fn`` turns a list
    of samples into batched tensors on ``device``.
    """

    def __init__(self, split):
        # Shared file stems: the image and mask of a sample have the same file name.
        self.items = stems(split)
        self.split = split

    def __len__(self):
        return len(self.items)

    def __getitem__(self, ix):
        """Return ``(image, mask)``: a 224x224 colour image (BGR, as loaded by
        OpenCV) and a 224x224 single-channel label mask.

        Raises:
            FileNotFoundError: if either file cannot be read (``cv2.imread``
                returns ``None`` instead of raising).
        """
        stem = self.items[ix]
        # Raw f-strings: same paths as before, without the invalid "\p"/"\D" escapes.
        image_path = rf'D:\pycharm\DL-Pytorch\Dataset1/images_prepped_{self.split}/{stem}.png'
        mask_path = rf'D:\pycharm\DL-Pytorch\Dataset1/annotations_prepped_{self.split}/{stem}.png'

        image = cv2.imread(image_path, 1)
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_path}')
        image = cv2.resize(image, (224, 224))

        mask = cv2.imread(mask_path, 0)
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')
        # Nearest-neighbour keeps every pixel a real class id; the default
        # bilinear interpolation would blend neighbouring ids into spurious classes.
        mask = cv2.resize(mask, (224, 224), interpolation=cv2.INTER_NEAREST)
        return image, mask

    def choose(self):
        """Return one uniformly random sample of the dataset."""
        # random.randint needs two bounds (and includes the upper one);
        # randrange(n) draws a valid index in 0..n-1.
        return self[random.randrange(len(self))]

    def collate_fn(self, batch):
        """Stack a list of ``(image, mask)`` samples into batched tensors on ``device``.

        Returns:
            images: (B, C, H, W) float tensor built with ``preprocess_image``.
            ce_masks: (B, H, W) int64 class-id tensor (cross-entropy target layout).
        """
        # NOTE: tensors are moved to `device` here, which is fine with the default
        # num_workers=0; creating CUDA tensors inside DataLoader worker processes
        # is problematic.
        ims = [preprocess_image(img) for img, _ in batch]
        ce_masks = [torch.tensor(mask)[None].long().to(device) for _, mask in batch]
        return torch.cat(ims).to(device), torch.cat(ce_masks).to(device)

In [None]:
# Training and evaluation splits; each loader batches with its dataset's collate_fn.
trn_ds, val_ds = SegData('train'), SegData('test')
trn_dl = DataLoader(trn_ds, collate_fn=trn_ds.collate_fn, batch_size=4, shuffle=True)
val_dl = DataLoader(val_ds, collate_fn=val_ds.collate_fn, batch_size=1, shuffle=True)

# Sanity check: show one training image (OpenCV loads BGR, matplotlib expects RGB).
plt.imshow(cv2.cvtColor(trn_ds[11][0], cv2.COLOR_BGR2RGB))
plt.show()

In [None]:
def conv(in_channels, out_channels):
    """Build a 3x3 convolution -> BatchNorm -> ReLU unit that keeps H and W unchanged.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.

    Returns:
        nn.Sequential: the three layers, in that order.
    """
    layers = [
        # 3x3 kernel with stride 1 and padding 1 preserves the spatial size.
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def up_conv(in_channels, out_channels):
    """Build a 2x upsampling unit: 2x2 stride-2 transposed convolution, then ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after upsampling.

    Returns:
        nn.Sequential: transposed convolution followed by an in-place ReLU.
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    return nn.Sequential(upsample, nn.ReLU(inplace=True))

In [None]:
from torchvision.models import vgg16_bn
class UNet(nn.Module):
    """U-Net with a VGG16-BN encoder and a transposed-convolution decoder.

    The encoder is ``vgg16_bn(...).features`` sliced into five skip stages
    (``block1`` .. ``block5``) plus ``bottleneck``. Each decoder stage 6..10
    upsamples with ``up_convN``, concatenates the matching skip (deepest first)
    and fuses the result with ``convN``; ``conv11`` maps to ``out_channels``
    logits per pixel. Because the decoder has five stride-2 upsampling stages,
    H and W presumably need to be multiples of 32 for the skips to line up.
    """

    # Slice boundaries into vgg16_bn().features for block1 .. block5.
    _ENCODER_CUTS = (0, 6, 13, 20, 27, 34)
    # (upsample in, upsample out, skip channels) for decoder stages 6 .. 10.
    _DECODER_SPEC = (
        (1024, 512, 512),
        (512, 256, 512),
        (256, 128, 256),
        (128, 64, 128),
        (64, 32, 64),
    )

    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()

        self.encoder = vgg16_bn(pretrained=pretrained).features
        cuts = self._ENCODER_CUTS
        for idx, (start, stop) in enumerate(zip(cuts, cuts[1:]), start=1):
            setattr(self, f'block{idx}', nn.Sequential(*self.encoder[start:stop]))

        self.bottleneck = nn.Sequential(*self.encoder[cuts[-1]:])
        self.conv_bottleneck = conv(512, 1024)

        # Registered in the same order as up_conv6, conv6, up_conv7, conv7, ...
        for stage, (up_in, up_out, skip_ch) in enumerate(self._DECODER_SPEC, start=6):
            setattr(self, f'up_conv{stage}', up_conv(up_in, up_out))
            setattr(self, f'conv{stage}', conv(up_out + skip_ch, up_out))
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for idx in range(1, 6):
            x = getattr(self, f'block{idx}')(x)
            skips.append(x)

        x = self.conv_bottleneck(self.bottleneck(x))

        # Decoder stages 6..10 consume the skips deepest-first (block5 .. block1).
        for stage, skip in zip(range(6, 11), reversed(skips)):
            x = getattr(self, f'up_conv{stage}')(x)
            x = torch.cat([x, skip], dim=1)
            x = getattr(self, f'conv{stage}')(x)

        return self.conv11(x)

In [None]:
ce = nn.CrossEntropyLoss()


def UnetLoss(preds, targets):
    """Return ``(cross-entropy loss, pixel accuracy)`` for segmentation logits.

    Args:
        preds: logits with the class axis at dim 1, e.g. (B, C, H, W).
        targets: int64 class ids without the class axis, e.g. (B, H, W).

    Returns:
        Tuple of two scalar tensors: mean cross-entropy and the fraction of
        pixels whose arg-max class equals the target.
    """
    _, predicted = torch.max(preds, dim=1)
    accuracy = (predicted == targets).float().mean()
    loss = ce(preds, targets)
    return loss, accuracy

def train_batch(model, data, optimizer, criterion):
    """Run one optimisation step on a single batch.

    Args:
        model: network to train; switched to training mode.
        data: ``(inputs, targets)`` pair.
        optimizer: optimizer over the model's parameters.
        criterion: callable ``(outputs, targets) -> (loss, accuracy)`` returning tensors.

    Returns:
        ``(loss, accuracy)`` as Python numbers.
    """
    inputs, targets = data
    model.train()
    optimizer.zero_grad()
    outputs = model(inputs)
    loss, accuracy = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    return loss.item(), accuracy.item()

@torch.no_grad()
def validate_batch(model, data, criterion):
    """Evaluate one batch without tracking gradients.

    Args:
        model: network to evaluate; switched to eval mode.
        data: ``(inputs, targets)`` pair.
        criterion: callable ``(outputs, targets) -> (loss, accuracy)`` returning tensors.

    Returns:
        ``(loss, accuracy)`` as Python numbers.
    """
    model.eval()
    inputs, targets = data
    loss, accuracy = criterion(model(inputs), targets)
    return loss.item(), accuracy.item()

In [None]:
# Model (pretrained VGG16-BN encoder, 12 output classes), loss and optimiser.
model = UNet(pretrained=True, out_channels=12).to(device)
criterion = UnetLoss
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
n_epochs = 30

In [None]:
# Train for n_epochs, recording per epoch the mean per-batch loss and pixel
# accuracy on the training and validation loaders.
train_loss_epochs = []
val_loss_epochs = []
train_acc_epochs = []
val_acc_epochs = []
for epoch in range(n_epochs):
    trn_loss, trn_acc = [], []
    for data in trn_dl:
        loss, acc = train_batch(model, data, optimizer, criterion)
        trn_loss.append(loss)
        trn_acc.append(acc)
    train_loss_epochs.append(np.mean(trn_loss))
    train_acc_epochs.append(np.mean(trn_acc))

    val_loss, val_acc = [], []
    for data in val_dl:
        loss, acc = validate_batch(model, data, criterion)
        val_loss.append(loss)
        val_acc.append(acc)
    val_loss_epochs.append(np.mean(val_loss))
    val_acc_epochs.append(np.mean(val_acc))

In [None]:
# x positions come from what was actually recorded, so the plot still works if
# training was interrupted before n_epochs (even mid-epoch).
train_epochs = np.arange(1, len(train_loss_epochs) + 1)
val_epochs = np.arange(1, len(val_loss_epochs) + 1)
plt.plot(train_epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(val_epochs, val_loss_epochs, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# grid(False) hides the grid; a string such as 'off' is not a valid `visible`
# value and is truthy, so it does not turn the grid off.
plt.grid(False)
plt.show()

In [None]:
# Predict on one validation batch. This is inference only: eval() makes BatchNorm
# use its running statistics and no_grad() avoids building an autograd graph.
model.eval()
im, mask = next(iter(val_dl))
with torch.no_grad():
    _mask = model(im)

In [None]:
# Collapse per-class logits (B, C, H, W) into predicted class ids (B, H, W).
# Guarded so that re-running this cell does not reduce an already-reduced mask.
if _mask.dim() == 4:
    _, _mask = torch.max(_mask, dim=1)

In [None]:
# First sample of the batch: input preview, ground-truth mask, predicted mask.
# Both masks share one fixed gray scale (0 .. n_classes-1), so equal gray levels
# mean the same class in both panels; per-panel autoscaling made them incomparable.
n_classes = model.conv11.out_channels
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
# Only the first (normalized) input channel is shown, as a grayscale preview.
axes[0].imshow(im[0, 0].detach().cpu(), cmap='gray')
axes[0].set_title('Original image')
axes[1].imshow(mask[0].detach().cpu(), cmap='gray', vmin=0, vmax=n_classes - 1)
axes[1].set_title('Original mask')
axes[2].imshow(_mask[0].detach().cpu(), cmap='gray', vmin=0, vmax=n_classes - 1)
axes[2].set_title('Predicted mask')
plt.show()