In [116]:
#!unzip data_evensplit.zip

In [117]:
import os
import time
from glob import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import random
import numpy as np
import cv2
from operator import add
from tqdm import tqdm
import imageio
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score, roc_auc_score


In [118]:
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every RNG.

    Note:
        ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it here
        only affects Python processes launched after this call, not the
        currently running kernel.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; silently a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms per run and may pick
    # non-deterministic kernels, defeating `deterministic = True`.
    torch.backends.cudnn.benchmark = False

def create_dir(path):
    """Create directory ``path`` (including missing parents) unless it already exists.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok=True removes the race between an exists() check and makedirs().
    os.makedirs(path, exist_ok=True)

def train_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole (minutes, seconds)."""
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover = duration - whole_minutes * 60
    return whole_minutes, int(leftover)

In [119]:
class DriveDataset(Dataset):
    """Paired grayscale image / mask dataset read from disk with OpenCV.

    Each item is ``(image, mask)``: float32 tensors of shape ``(1, H, W)``
    with pixel values divided by 255. No resizing is performed.

    Args:
        images_path: sequence of image file paths.
        masks_path: sequence of mask file paths, paired with ``images_path`` by index.

    Raises:
        ValueError: if the two path lists have different lengths.
    """
    def __init__(self, images_path, masks_path):
        # Pairing is positional, so a length mismatch means images and masks
        # would silently be matched to the wrong partner.
        if len(images_path) != len(masks_path):
            raise ValueError(
                f"got {len(images_path)} images but {len(masks_path)} masks"
            )
        self.images_path = images_path
        self.masks_path = masks_path
        self.n_samples = len(images_path)

    @staticmethod
    def _load(path):
        """Read ``path`` as grayscale and return a ``(1, H, W)`` float32 tensor scaled by 1/255."""
        array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if array is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError(f"could not read image: {path}")
        array = np.expand_dims(array / 255.0, axis=0).astype(np.float32)
        return torch.from_numpy(array)

    def __getitem__(self, index):
        return self._load(self.images_path[index]), self._load(self.masks_path[index])

    def __len__(self):
        return self.n_samples

In [120]:
class Conv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages with dropout between them.

    padding=1 keeps H and W unchanged; channels go in_channels -> out_channels.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names (and their order) define the state_dict keys that
        # saved checkpoints are loaded against — keep them stable.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.drop = nn.Dropout(0.2)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        hidden = self.drop(self.relu(self.bn1(self.conv1(inputs))))
        return self.relu(self.bn2(self.conv2(hidden)))

class Downs(nn.Module):
    """Encoder stage: a ``Conv`` block followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``: the full-resolution features
    (used as the skip connection) and their pooled version (fed to the next stage).
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = Conv(in_channels, out_channels)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        features = self.conv(inputs)
        return features, self.pool(features)

class Ups(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat with the skip, then ``Conv``.

    Args:
        in_channels: channels of the deeper feature map being upsampled.
        out_channels: channels after upsampling; the skip must have the same count.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.conv = Conv(out_channels + out_channels, out_channels)

    def forward(self, inputs, skip):
        x = self.up(inputs)
        # If an encoder input had an odd H or W, max-pooling floored it, so the
        # upsampled map is smaller than `skip`; zero-pad it so torch.cat works.
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class UNET(nn.Module):
    """U-Net with four encoder/decoder levels (64-128-256-512, bottleneck 1024).

    Args:
        in_channels: channels of the input image (default 1, the grayscale data used here).
        out_channels: number of output maps (default 1).

    ``forward`` returns raw logits (no activation applied); callers apply
    ``torch.sigmoid`` themselves. H and W divisible by 16 (four 2x poolings)
    keep every skip connection aligned and the output the same size as the input.
    """
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down1 = Downs(in_channels, 64)
        self.down2 = Downs(64, 128)
        self.down3 = Downs(128, 256)
        self.down4 = Downs(256, 512)

        self.bottleneck = Conv(512, 1024)

        self.up1 = Ups(1024, 512)
        self.up2 = Ups(512, 256)
        self.up3 = Ups(256, 128)
        self.up4 = Ups(128, 64)

        self.outputs = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        # Each encoder stage returns (skip features, pooled features).
        skip1, down1 = self.down1(inputs)
        skip2, down2 = self.down2(down1)
        skip3, down3 = self.down3(down2)
        skip4, down4 = self.down4(down3)

        b = self.bottleneck(down4)

        up1 = self.up1(b, skip4)
        up2 = self.up2(up1, skip3)
        up3 = self.up3(up2, skip2)
        up4 = self.up4(up3, skip1)

        return self.outputs(up4)

class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    loss = 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth),
    where p = sigmoid(inputs) and t = targets, both flattened.
    """
    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted for signature compatibility only; unused.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # reshape (not view) so non-contiguous tensors, e.g. after transpose, work too.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)

        intersect = (probs * targets).sum()
        dice = (2.0 * intersect + smooth) / (probs.sum() + targets.sum() + smooth)

        return 1 - dice

class DiceBCELoss(nn.Module):
    """Sum of mean binary cross-entropy and soft Dice loss, both from raw logits.

    Dice term: 1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)
    with p = sigmoid(inputs). The BCE term is computed directly from the logits.
    """
    def __init__(self, weight=None, size_average=True):
        # `weight` and `size_average` are accepted for signature compatibility only; unused.
        super().__init__()

    def forward(self, inputs, targets, smooth=1):
        # reshape (not view) so non-contiguous tensors work too.
        logits = inputs.reshape(-1)
        targets = targets.reshape(-1)
        probs = torch.sigmoid(logits)

        intersect = (probs * targets).sum()
        dice_loss = 1 - (2.0 * intersect + smooth) / (probs.sum() + targets.sum() + smooth)
        # BCE from logits uses the log-sum-exp form: it does not saturate (and
        # clamp at 100) for large |logit| the way BCE on sigmoid outputs does.
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='mean')

        return bce + dice_loss

In [121]:
def train(model, loader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: module to train; switched to train mode.
        loader: iterable of (inputs, targets) batches with a defined ``len``.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable(predictions, targets) -> scalar loss tensor.
        device: device the batches are moved to (as float32).
    """
    model.train()
    running = 0.0
    for images, masks in loader:
        images = images.to(device, dtype=torch.float32)
        masks = masks.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(images), masks)
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item()

    return running / len(loader)

def evaluate(model, loader, loss_fn, device):
    """Return the mean per-batch loss over ``loader`` with the model in eval mode.

    Gradients are disabled; ``model`` is left in eval mode afterwards.
    """
    model.eval()
    losses = []
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device, dtype=torch.float32)
            masks = masks.to(device, dtype=torch.float32)
            losses.append(loss_fn(model(images), masks).item())

    return sum(losses) / len(loader)

In [None]:
set_seed(42)
create_dir("files")

# NOTE(review): "validation" below is the *test* split, and the evaluation cell
# later scores the checkpoint selected here on those same files, so the final
# test metrics are optimistically biased. Prefer a validation split carved out
# of the training data.
X_train = sorted(glob(r"/content/data/train/image/*"))
y_train = sorted(glob(r"/content/data/train/mask/*"))

X_val = sorted(glob(r"/content/data/test/image/*"))
y_val = sorted(glob(r"/content/data/test/mask/*"))

# Fail fast if the data was not unpacked where expected, or if images and
# masks (paired by sorted filename) are out of step.
assert X_train and len(X_train) == len(y_train), f"train: {len(X_train)} images vs {len(y_train)} masks"
assert X_val and len(X_val) == len(y_val), f"val: {len(X_val)} images vs {len(y_val)} masks"

# DriveDataset does not resize; these sizes are not applied to the training data.
height = 512
width = 512
img_size = (height, width)
batch_size = 2
epochs = 100
lr = 1e-4
chkpt_path = "files/checkpoint.pth"

train_data = DriveDataset(X_train, y_train)
val_data = DriveDataset(X_val, y_val)

train_loader = DataLoader(
    dataset = train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

val_loader = DataLoader(
    dataset = val_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = UNET()
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# `verbose` is deprecated for LR schedulers in recent PyTorch; LR drops are logged below.
lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
loss_fn = DiceBCELoss()

best_loss = float("inf")
loss_dict = {'train':[],
             'val':[]}

for epoch in range(epochs):

    start = time.time()
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    valid_loss = evaluate(model, val_loader, loss_fn, device)
    loss_dict['train'].append(train_loss)
    loss_dict['val'].append(valid_loss)

    # ReduceLROnPlateau only acts when stepped with the monitored metric.
    prev_lr = optimizer.param_groups[0]['lr']
    lr_schedule.step(valid_loss)
    cur_lr = optimizer.param_groups[0]['lr']
    if cur_lr < prev_lr:
        print(f"Learning rate reduced from {prev_lr:.2e} to {cur_lr:.2e}")

    if valid_loss < best_loss:
        hist = "Validation loss has improved from {:2.4f} to {:2.4f}".format(best_loss, valid_loss)
        print(hist)
        best_loss = valid_loss
        torch.save(model.state_dict(), chkpt_path)

    end = time.time()
    epoch_mins, epoch_secs = train_time(start, end)
    hist = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s\n'
    hist += f'\tTrain Loss: {train_loss:.3f}\n'
    hist += f'\t Val. Loss: {valid_loss:.3f}\n'
    print(hist)



Validation loss has improved from inf to 1.3024
Epoch: 01 | Epoch Time: 0m 25s
	Train Loss: 1.348
	 Val. Loss: 1.302

Validation loss has improved from 1.3024 to 1.0501
Epoch: 02 | Epoch Time: 0m 26s
	Train Loss: 1.120
	 Val. Loss: 1.050

Validation loss has improved from 1.0501 to 1.0163
Epoch: 03 | Epoch Time: 0m 26s
	Train Loss: 1.057
	 Val. Loss: 1.016

Validation loss has improved from 1.0163 to 0.9788
Epoch: 04 | Epoch Time: 0m 26s
	Train Loss: 1.014
	 Val. Loss: 0.979

Validation loss has improved from 0.9788 to 0.9475
Epoch: 05 | Epoch Time: 0m 26s
	Train Loss: 0.976
	 Val. Loss: 0.948

Validation loss has improved from 0.9475 to 0.9259
Epoch: 06 | Epoch Time: 0m 26s
	Train Loss: 0.939
	 Val. Loss: 0.926

Validation loss has improved from 0.9259 to 0.9026
Epoch: 07 | Epoch Time: 0m 26s
	Train Loss: 0.907
	 Val. Loss: 0.903

Validation loss has improved from 0.9026 to 0.8773
Epoch: 08 | Epoch Time: 0m 27s
	Train Loss: 0.876
	 Val. Loss: 0.877

Validation loss has improved from 0

In [None]:
def calculate_metrics(y_true, y_pred):
    """Compute binary segmentation metrics for one prediction.

    Args:
        y_true: ground-truth mask tensor; values > 0.5 count as foreground.
        y_pred: predicted foreground probabilities, same number of elements as ``y_true``.

    Returns:
        ``[jaccard, f1, recall, precision, accuracy, roc_auc]``. The first five
        use predictions thresholded at 0.5. ROC-AUC ranks the raw probabilities,
        because AUC on 0/1 labels collapses to a single operating point. ROC-AUC is
        ``nan`` when ``y_true`` contains only one class, where it is undefined.
    """
    # Ground truth, flattened to {0, 1}.
    y_true = (y_true.cpu().numpy() > 0.5).astype(np.uint8).reshape(-1)

    # Prediction: keep continuous scores for AUC, threshold for the rest.
    y_score = y_pred.cpu().numpy().reshape(-1)
    y_bin = (y_score > 0.5).astype(np.uint8)

    score_jaccard = jaccard_score(y_true, y_bin)
    score_f1 = f1_score(y_true, y_bin)
    score_recall = recall_score(y_true, y_bin)
    score_precision = precision_score(y_true, y_bin)
    score_acc = accuracy_score(y_true, y_bin)
    if np.unique(y_true).size < 2:
        # roc_auc_score raises ValueError when only one class is present.
        roc_auc = float("nan")
    else:
        roc_auc = roc_auc_score(y_true, y_score)

    return [score_jaccard, score_f1, score_recall, score_precision, score_acc, roc_auc]

def mask_parse(mask):
    """Replicate ``mask`` into three identical channels on a new trailing axis.

    E.g. an ``(H, W)`` mask becomes ``(H, W, 3)``; dtype is preserved.
    """
    return np.stack((mask,) * 3, axis=-1)


set_seed(42)
create_dir("results")

# NOTE(review): these are the same files the training cell used as its
# validation set for checkpoint selection, so the metrics below are not an
# unbiased held-out estimate.
test_x = sorted(glob("/content/data/test/image/*"))
test_y = sorted(glob("/content/data/test/mask/*"))
assert test_x and len(test_x) == len(test_y), f"test: {len(test_x)} images vs {len(test_y)} masks"

height = 512
width = 512
size = (width, height)  # cv2.resize expects (width, height)
checkpoint_path = "files/checkpoint.pth"


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNET()
model = model.to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

metrics_score = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
time_taken = []
on_cuda = device.type == 'cuda'

for img, msk in tqdm(zip(test_x, test_y), total=len(test_x)):
    # splitext keeps inner dots, so "a.b.png" and "a.c.png" don't collide.
    name = os.path.splitext(os.path.basename(img))[0]

    image = cv2.resize(cv2.imread(img, cv2.IMREAD_GRAYSCALE), size)       ## (H, W)
    x = torch.from_numpy((image / 255.0).astype(np.float32))[None, None]  ## (1, 1, H, W)
    x = x.to(device)

    # Nearest-neighbour interpolation keeps a binary mask binary when resizing.
    mask = cv2.resize(cv2.imread(msk, cv2.IMREAD_GRAYSCALE), size, interpolation=cv2.INTER_NEAREST)
    y = torch.from_numpy((mask / 255.0).astype(np.float32))[None, None]   ## (1, 1, H, W)
    y = y.to(device)

    with torch.no_grad():
        if on_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        pred_y = torch.sigmoid(model(x))
        if on_cuda:
            # CUDA launches are asynchronous; wait for the kernels so the timing is real.
            torch.cuda.synchronize()
        time_taken.append(time.perf_counter() - start_time)

    score = calculate_metrics(y, pred_y)
    metrics_score = list(map(add, metrics_score, score))

    pred_mask = (pred_y[0, 0].cpu().numpy() > 0.5).astype(np.uint8)      ## (H, W), values {0, 1}

    # Side-by-side panel: input | ground truth | prediction, all uint8 BGR.
    line = np.full((size[1], 10, 3), 128, dtype=np.uint8)
    color = cv2.resize(cv2.imread(img, cv2.IMREAD_COLOR), size)
    cat_images = np.concatenate(
        [color, line, mask_parse(mask), line, mask_parse(pred_mask * 255)], axis=1
    )
    cv2.imwrite(os.path.join("results", f"{name}.png"), cat_images)

n_test = len(test_x)
jaccard, f1, recall, precision, acc, roc_auc = [total / n_test for total in metrics_score]
print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f} - ROC-AUC: {roc_auc:1.4f}")

fps = 1/np.mean(time_taken)
print("FPS: ", fps)

In [None]:
#!zip -r results.zip results/

In [None]:
# Shape of the last test image, referenced explicitly instead of through the
# `img` loop variable left over from the evaluation cell.
cv2.imread(test_x[-1], cv2.IMREAD_COLOR).shape

In [None]:
import matplotlib.pyplot as plt

# Use the recorded history for the x-axis: if training stopped before `epochs`
# (e.g. interrupted), range(epochs) would not match the series length.
epochs_run = range(1, len(loss_dict['train']) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_run, loss_dict['train'], label='train')
ax.plot(epochs_run, loss_dict['val'], label='val')
ax.set(xlabel='epoch', ylabel='loss', title='Training and validation loss')
ax.legend()
plt.show()