In [1]:
import sys

sys.path.insert(0, 'utils')

In [None]:
import os
from PIL import Image
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import torchvision.transforms.functional as TF
from tqdm import tqdm
import torch.optim as optim

from Network import *


%load_ext autoreload
%autoreload 2

In [50]:
"""
Definition of the dataset
"""
class RoadDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = os.path.join(self.mask_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32) # grayscale
        mask = np.where(mask > 4.0, 1.0, 0.0)
       
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask

In [51]:
"""
Returns the loader for the dataset
"""
def get_loaders(
    train_dir,
    train_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
):
    train_dataset = RoadDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=train_transform,
    )

    val_dataset = RoadDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        transform=val_transform,
    )
    
    train_ds, _ = random_split(
        train_dataset,
        [80, 20], 
        generator=torch.Generator().manual_seed(42))

    _, val_ds = random_split(
        val_dataset,
        [80, 20], 
        generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader

In [52]:
"""
Returns the set of transformations that will be performed on the dataset
depending on the parameters given
"""
def get_transform(
    image_height,
    image_width,
    max_rotation,
    p_hflip,
    p_vflip,
    normalize,
    crop_width,
    crop_height
):
    cwidth = int(crop_width)
    cheight = int(crop_height)
    transformations = []
    if(image_height != 0):
        transformations.append(A.Resize(height=image_height, width=image_width))
    if(max_rotation > 0):
        transformations.append(A.Rotate(limit=max_rotation, p=1.0))
    if(p_hflip > 0):
        transformations.append(A.HorizontalFlip(p=p_hflip))
    if(p_vflip > 0):
        transformations.append(A.VerticalFlip(p=p_vflip))
    if(normalize):
        transformations.append(A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0, # dividing by 237, get a value between 0 and 1
        ))
    if(cwidth > 0):
        transformations.append(A.RandomCrop(width=cwidth, height=cheight))
    transformations.append(ToTensorV2())
    return A.Compose(transformations)

In [53]:
image_height = 400  #  400 pixels originally
image_width = 400  #  400 pixels originally
train_dir = "data/train_images/"
train_maskdir = "data/train_masks/"
test_dir = 'data/test_images/'
batch_size = 4
num_workers = 0
pin_memory = True

train_transform = get_transform(image_height, image_width, 35, 0.5, 0.1, True, image_width/2, image_height/2)

val_transform = get_transform(image_height, image_width, 0, 0, 0, True, image_width/2, image_height/2)

train_loader, val_loader = get_loaders(
    train_dir,
    train_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers,
    pin_memory,
)

In [108]:
# Hyperparameters etc.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 10
NUM_EPOCHS = 25
NUM_WORKERS = 0
IMAGE_HEIGHT = 400  # 400 originally
IMAGE_WIDTH = 400  # 400 originally
PIN_MEMORY = True
LOAD_MODEL = False

def train_fn(loader, model, optimizer, loss_fn, scaler,scheduler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE) # add a channel dimension
        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
            
        scheduler.step(loss)
        
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix(loss=loss.item())

In [100]:
import os
import shutil

determination = 'data/test_images/'
if not os.path.exists(determination):
    os.makedirs(determination)

path = 'data/test/'
folders = os.listdir(path)
for folder in folders:
    dir = path + '/' + str(folder)
    files = os.listdir(dir)
    for file in files:
        source = dir + '/' + str(file)
        deter = determination + '/' + str(file)
        shutil.copyfile(source, deter)

In [101]:
path_list = os.listdir(test_dir)
path_list.sort(key=lambda x: int(x.split(".")[0].split("_")[1]))
path_list

"""
Defines the test set of the dataset
"""
class RoadData_test_set(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.images = path_list # list all the files that are in that folder

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        image = np.array(Image.open(img_path).convert("RGB"))

        if self.transform is not None:
            augmentations = self.transform(image=image)
            image = augmentations["image"]

        return image

test_transform = A.Compose(
    [
        A.Resize(height=image_height, width=image_width),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

test_dataset = RoadData_test_set(
    image_dir=test_dir,
    transform=test_transform,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=0,
    # num_workers =  NUM_WORKERS,
    pin_memory=pin_memory,
    shuffle=False
)

In [115]:
dict_double_conv = {"BatchNorm": True,
        "activation": nn.ReLU(inplace=True),
        "p_dropout": 0.2,
        "use_dropout": False,
        "bias": False}

dict_ups = {"BatchNorm": False,
        "p_dropout": 0.2,
        "use_dropout": False,
        "bias": False}

In [126]:
model = UNET(dict_double_conv, dict_ups,in_channels=3, out_channels=1,init=True).to(DEVICE)

#model = UNET_no_skip_connection(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)#LEARNING_RATE)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.05, patience = 20,verbose = True)

In [121]:
def check_accuracy(loader, model, device="cuda"):
    num_correct = 0
    num_pixels = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1) # the grayscale does not have channels, add
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)

    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f}"
    )
    model.train()

def check_F1_score(loader, model, device="cuda"):
    TP = 0
    FP = 0
    TN = 0
    FN = 0
    num_correct = 0
    num_pixels = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).unsqueeze(1) # the grayscale does not have channels, add
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            TP += ((preds == 1)*(y==1)).sum()
            FP += ((preds == 1)*(y==0)).sum()
            FN += ((preds == 0)*(y==1)).sum()
            num_pixels += torch.numel(preds)
            num_correct += (preds == y).sum()
    recall = TP/(TP+FN)
    precision = TP/(TP+FP)
    
    print(
        f"Got {num_correct}/{num_pixels} with acc {num_correct/num_pixels*100:.2f} and F1-score {2*recall*precision/(recall+precision):.2f}"
    )
    model.train()


def save_predictions_as_imgs(
    test_loader, model, folder="saved_images", device="cuda"
):
    model.eval()
    for idx, x in enumerate(test_loader):
        x = x.to(device=device)
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
        torchvision.utils.save_image(
            preds, f"{folder}/pred_{idx+1}.png"
        )

    model.train()

In [122]:
check_F1_score(val_loader, model, device=DEVICE)

Got 196351/800000 with acc 24.54 and F1-score 0.34


In [127]:
scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler,scheduler)

    # check accuracy
    check_F1_score(val_loader, model, device=DEVICE)

# # print some examples to a folder
# save_predictions_as_imgs(
#     test_loader, model, folder="saved_images", device=DEVICE
# )

100%|███████████████████████████████| 20/20 [01:53<00:00,  5.65s/it, loss=0.407]


Got 661067/800000 with acc 82.63 and F1-score 0.44


100%|████████████████████████████████| 20/20 [01:42<00:00,  5.14s/it, loss=0.61]


Got 646130/800000 with acc 80.77 and F1-score 0.18


 60%|██████████████████▌            | 12/20 [00:58<00:38,  4.85s/it, loss=0.433]

Epoch 00053: reducing learning rate of group 0 to 5.0000e-05.


100%|████████████████████████████████| 20/20 [01:39<00:00,  4.97s/it, loss=0.48]


Got 652343/800000 with acc 81.54 and F1-score 0.48


 65%|████████████████████▏          | 13/20 [01:08<00:36,  5.20s/it, loss=0.348]

Epoch 00074: reducing learning rate of group 0 to 2.5000e-06.


100%|███████████████████████████████| 20/20 [01:41<00:00,  5.09s/it, loss=0.482]


Got 674724/800000 with acc 84.34 and F1-score 0.57


100%|███████████████████████████████| 20/20 [01:53<00:00,  5.65s/it, loss=0.448]


Got 685512/800000 with acc 85.69 and F1-score 0.54


 70%|█████████████████████▋         | 14/20 [01:15<00:43,  7.19s/it, loss=0.467]

Epoch 00115: reducing learning rate of group 0 to 1.2500e-07.


100%|███████████████████████████████| 20/20 [01:51<00:00,  5.58s/it, loss=0.373]


Got 652123/800000 with acc 81.52 and F1-score 0.51


 75%|███████████████████████▎       | 15/20 [01:19<00:24,  4.82s/it, loss=0.395]

Epoch 00136: reducing learning rate of group 0 to 6.2500e-09.


100%|███████████████████████████████| 20/20 [01:44<00:00,  5.22s/it, loss=0.437]


Got 685090/800000 with acc 85.64 and F1-score 0.52


100%|███████████████████████████████| 20/20 [01:55<00:00,  5.78s/it, loss=0.348]


Got 676845/800000 with acc 84.61 and F1-score 0.57


100%|███████████████████████████████| 20/20 [01:44<00:00,  5.23s/it, loss=0.484]


Got 681930/800000 with acc 85.24 and F1-score 0.57


100%|███████████████████████████████| 20/20 [01:43<00:00,  5.20s/it, loss=0.379]


Got 672626/800000 with acc 84.08 and F1-score 0.52


100%|███████████████████████████████| 20/20 [01:43<00:00,  5.16s/it, loss=0.474]


Got 672302/800000 with acc 84.04 and F1-score 0.54


100%|███████████████████████████████| 20/20 [02:11<00:00,  6.56s/it, loss=0.482]


Got 680239/800000 with acc 85.03 and F1-score 0.57


100%|████████████████████████████████| 20/20 [02:21<00:00,  7.08s/it, loss=0.35]


Got 669259/800000 with acc 83.66 and F1-score 0.55


 35%|███████████▏                    | 7/20 [00:45<01:23,  6.45s/it, loss=0.367]


KeyboardInterrupt: 