In [None]:
#@title Mount Google Drive
from google.colab import drive
drive.mount('/content/drive/')
%cd "drive/MyDrive/Segmentation"

In [None]:
#@title PIP installations
!git clone https://github.com/milesial/Pytorch-UNet unet --quiet
!pip install -r ./unet/requirements.txt --quiet
!pip install torchmetrics --quiet

In [3]:
#@title Imports
import os
import gc
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, Subset, DataLoader
from torchmetrics.classification import BinaryJaccardIndex
from unet.unet import UNet
from unet.utils.dice_score import dice_loss

In [4]:
#@title Parameters
# --- Paths (relative to the current working directory) ---
dataset_path = "dataset"
images_path = f"{dataset_path}/images"  # input images
masks_path = f"{dataset_path}/masks"    # ground-truth masks
validation_path = "validation"          # validate() writes predicted masks here

# --- Size every image and mask is resized to by the dataset ---
width, height = 640, 360

# --- Training ---
epochs = 15           # number of the last epoch to train / evaluate
batch_size = 8        # batch size of both data loaders
last_epoch = 15       # checkpoint to resume from; None or <= 0 loads no checkpoint
learning_rate = 1e-8  # initial learning rate handed to the optimizer

In [5]:
#@title Model class
class InvSegmentater(torch.nn.Module):
    """Single-class U-Net wrapper that returns one logit map per image.

    The wrapped ``UNet`` is built with 3 input channels and 1 output class.
    No sigmoid is applied here; callers apply it themselves where needed.
    """
    def __init__(self):
        super().__init__()
        self.unet = UNet(n_channels=3, n_classes=1, bilinear=False)

    def forward(self, input):
        """Run the U-Net and drop the singleton class dimension.

        Args:
            input: Batch of images that is passed unchanged to the U-Net.

        Returns:
            The U-Net output without its class dimension, i.e. a tensor of
            shape (batch, H, W) when the U-Net returns (batch, 1, H, W).
        """
        masks_pred = self.unet(input)
        # squeeze(1) instead of reshape(batch, height, width): same result for
        # the configured image size, but independent of the notebook globals
        # and therefore usable with any input resolution.
        return masks_pred.squeeze(1)

In [6]:
#@title Dataset class
class SegmentationDataset(Dataset):
    """Image/mask pairs read from ``images_path`` and ``masks_path``.

    Sample ``idx`` is loaded from the file named ``idx + 1`` zero-padded to
    12 digits with a ``.png`` extension, in both directories.
    """
    def __init__(self):
        # The dataset size is the number of regular files in the image directory.
        self.length = len([name for name in os.listdir(images_path) if os.path.isfile(os.path.join(images_path, name))])

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Load, resize and convert one sample.

        Returns:
            Tuple ``(image, mask)`` of float32 NumPy arrays: ``image`` is
            channels-first RGB with shape (3, height, width) and the original
            0-255 value range; ``mask`` has shape (height, width) and is 1.0
            where the mask pixel equals 255 and 0.0 elsewhere.
        """
        filename = f"{str(idx+1).zfill(12)}.png"
        # Context managers close the underlying files deterministically;
        # resize() returns a new, fully loaded image that outlives the block.
        with Image.open(f"{images_path}/{filename}") as image_file:
            image = image_file.resize((width, height), Image.Resampling.NEAREST)
        with Image.open(f"{masks_path}/{filename}") as mask_file:
            mask = mask_file.resize((width, height), Image.Resampling.NEAREST)
        image = np.array(image.convert("RGB"))
        mask = np.array(mask.convert("L"))
        image = np.moveaxis(image, -1, 0) # channels first
        mask = mask == 255
        image = image.astype("float32")
        mask = mask.astype("float32")
        return image, mask

In [None]:
#@title Instancing
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on", device)
np.set_printoptions(suppress=True)

model = InvSegmentater().to(device)
# Resume from a saved checkpoint only when a positive epoch number is configured.
if last_epoch is not None and last_epoch > 0:
    checkpoint_file = f"checkpoints/checkpoint_epoch{str(last_epoch).zfill(3)}.ckpt"
    model.load_state_dict(torch.load(checkpoint_file, map_location=torch.device(device)))
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
# mode="max": the value passed to scheduler.step() is expected to increase while
# the model improves. `verbose=True` was dropped because the argument is
# deprecated since PyTorch 2.2 (NOTE(review): newer releases may reject it -
# confirm); the learning rate can be read from optimizer.param_groups instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=0)
lossfn = torch.nn.BCEWithLogitsLoss()
# "This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss..."
# see https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
# That is also the reason why the model doesn't include the sigmoid layer and why it has to be applied separately
metric = BinaryJaccardIndex().to(device)

dataset = SegmentationDataset()
dataset_length = len(dataset)
# Index-based split: the first 10% of the samples are used for validation,
# the remaining ones for training.
split_index = round(0.1*dataset_length)
validation_set = Subset(dataset, range(0, split_index))
training_set = Subset(dataset, range(split_index, dataset_length))
validation_loader = DataLoader(validation_set, batch_size=batch_size, shuffle=False)
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
print("dataset size:", dataset_length)
print("training set size:", len(training_set))
print("validation set size:", len(validation_set))

In [8]:
#@title Validation loop
# modified from https://github.com/milesial/Pytorch-UNet/blob/master/train.py
n_batches_val = len(validation_loader)

def validate(save_images=False):
    """Evaluate the global ``model`` on ``validation_loader``.

    For every batch the loss (BCE-with-logits plus Dice loss) and the IoU of
    the thresholded prediction are computed, printed and accumulated. The
    model is switched to evaluation mode and left in that mode.

    Args:
        save_images: When True, every predicted binary mask is additionally
            written to ``validation_path`` as a consecutively numbered
            8-bit PNG (0 = background, 255 = foreground).

    Returns:
        Tuple ``(avg_loss, avg_iou)``: the per-batch loss and IoU, averaged
        over the number of validation batches.
    """
    os.makedirs(validation_path, exist_ok=True)
    model.train(False)
    n_batches = len(validation_loader)
    n_saved = 0
    total_iou = 0.0
    total_loss = 0.0
    for batch, (images, masks) in enumerate(validation_loader):
        images = images.to(device=device)
        masks = masks.to(device=device).float()

        with torch.no_grad():
            # The model already returns (batch, height, width) logits.
            logits = model(images)
            loss = lossfn(logits, masks)
            loss += dice_loss(torch.sigmoid(logits), masks, multiclass=False)
            loss = loss.item()
            total_loss += loss

            # Threshold the probabilities once; the binary masks are reused
            # for the metric and for the optional image export.
            masks_bin = torch.round(torch.sigmoid(logits))
            iou = metric(masks_bin, masks).item()
            total_iou += iou
            print(f"validation batch: {batch+1}/{n_batches} loss: {loss} IoU: {iou}")

        if not save_images:
            continue
        for img in masks_bin.cpu().numpy().astype(np.uint8):
            n_saved += 1
            path = f"{validation_path}/{str(n_saved).zfill(12)}.png"
            Image.fromarray(img.squeeze()*255, "L").save(path, "PNG")
    # Average over the number of batches (not over the number of saved images,
    # which is zero when save_images is False); max() guards an empty loader.
    avg_iou = total_iou / max(n_batches, 1)
    avg_loss = total_loss / max(n_batches, 1)
    print("validation loss:", avg_loss, "validation IoU:", avg_iou)
    return avg_loss, avg_iou

In [None]:
#@title Training loop
# modified from https://github.com/milesial/Pytorch-UNet/blob/master/train.py

gc.collect()
# torch.save does not create missing directories.
os.makedirs("checkpoints", exist_ok=True)
loss_per_epoch = []
iou_per_epoch = []
n_batches_train = len(training_loader)
# Continue after the loaded checkpoint; None or a non-positive value means
# that no checkpoint was loaded, so training starts with epoch 1.
first_epoch = last_epoch + 1 if last_epoch is not None and last_epoch > 0 else 1
for epoch in range(first_epoch, epochs+1):
    model.train(True)
    for batch, (images, masks) in enumerate(training_loader):
        optimizer.zero_grad(set_to_none=True)

        images = images.to(device=device)
        masks = masks.to(device=device)

        masks_pred = model(images)

        # Combined objective: BCE on the logits plus Dice loss on the probabilities.
        loss = lossfn(masks_pred.squeeze(1), masks.float())
        loss += dice_loss(torch.sigmoid(masks_pred.squeeze(1)), masks.float(), multiclass=False)

        loss.backward()
        optimizer.step()

        # The IoU is only logged, so it is computed without gradient tracking.
        with torch.no_grad():
            batch_iou = metric(torch.round(torch.sigmoid(masks_pred.squeeze(1))), masks).item()
        print(f"epoch: {epoch} batch: {batch+1}/{n_batches_train} lr: {optimizer.param_groups[0]['lr']}")
        print(f"Loss: {loss.item()} IoU: {batch_iou}")
        print("-----------------------------------------")
    validation_loss, validation_iou = validate()
    loss_per_epoch.append(validation_loss)
    iou_per_epoch.append(validation_iou)
    # The scheduler is created with mode="max", i.e. it lowers the learning
    # rate when the tracked value stops increasing. It therefore has to follow
    # the IoU; stepping on the loss would cut the rate whenever the loss improves.
    scheduler.step(validation_iou)
    torch.save(model.state_dict(), f"checkpoints/checkpoint_epoch{str(epoch).zfill(3)}.ckpt")

In [None]:
#@title Validate for all epochs
# Rebuild the per-epoch curves from the saved checkpoints: every checkpoint
# from epoch 1 to `epochs` is loaded into a fresh model and validated.
loss_per_epoch, iou_per_epoch = [], []
for epoch in range(1, epochs+1):
    print("Epoch:", epoch)
    model = InvSegmentater().to(device)
    checkpoint_file = f"checkpoints/checkpoint_epoch{str(epoch).zfill(3)}.ckpt"
    state_dict = torch.load(checkpoint_file, map_location=torch.device(device))
    model.load_state_dict(state_dict)
    loss, iou = validate()
    loss_per_epoch += [loss]
    iou_per_epoch += [iou]

In [None]:
#@title Plot results and save validation masks
def plot_per_epoch(values, ylabel, filename):
    """Plot one value per epoch, save the figure to ``filename`` and show it.

    Args:
        values: Sequence with one entry per epoch.
        ylabel: Name of the plotted quantity; used as y-axis label and in the title.
        filename: Path of the image file the figure is saved to.
    """
    # Epoch numbers in this notebook start at 1, so the first entry is drawn
    # at x=1 rather than at the default x=0.
    plt.plot(range(1, len(values) + 1), values)
    plt.title(f"{ylabel} per Epoch")
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.savefig(filename)
    plt.show()

plot_per_epoch(loss_per_epoch, "Loss", "loss_per_epoch.png")
plot_per_epoch(iou_per_epoch, "IoU", "iou_per_epoch.png")


# Export the predicted masks of the final epoch's checkpoint. The file name is
# derived from `epochs` so it follows the configuration instead of a fixed 015.
model = InvSegmentater().to(device)
final_checkpoint = f"checkpoints/checkpoint_epoch{str(epochs).zfill(3)}.ckpt"
model.load_state_dict(torch.load(final_checkpoint, map_location=torch.device(device)))
loss, iou = validate(save_images=True)
print("Saved images with average loss:", loss, "and average IoU:", iou)