In [1]:
%matplotlib inline
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt
import os,sys

In [2]:
import torch
import torch.nn as nn
from datasets import RoadsDatasetTrain, RoadsDatasetTest
import torch.utils.data as data
from model import UNet
from train import train
from torchvision import transforms
from PIL import Image
from predict import predict
from helpers import get_train_dataset

In [3]:
# --- Training hyper-parameters -------------------------------------------------
BATCH_SIZE = 1
EPOCHS = 2
# NOTE(review): LEARNING_RATE is not passed to `train(...)` in this notebook, so
# it currently has no effect unless train.py reads it some other way — confirm.
LEARNING_RATE = 0.0001
# Model init and the shuffled training DataLoader are random; seed them so a
# Restart-&-Run-All reproduces the same run.
SEED = 42

# --- Patch geometry ------------------------------------------------------------
# Each dataset is told how many PATCH_SIZE x PATCH_SIZE patches it yields per image.
# LARGE_PATCH_SIZE is forwarded as `large_patch_size` (presumably a context window
# around each patch — confirm in datasets.py).
PATCH_SIZE = 16
LARGE_PATCH_SIZE = 96

TRAIN_IMAGE_INITIAL_SIZE = 400
TEST_IMAGE_INITIAL_SIZE = 608

# The patch counts below use floor division; if an image side were not a multiple
# of PATCH_SIZE the count would no longer tile the whole image, so fail loudly.
assert TRAIN_IMAGE_INITIAL_SIZE % PATCH_SIZE == 0, "train image size must be a multiple of PATCH_SIZE"
assert TEST_IMAGE_INITIAL_SIZE % PATCH_SIZE == 0, "test image size must be a multiple of PATCH_SIZE"

# Square grid of patches per (square) image: 25*25 = 625 for train, 38*38 = 1444 for test.
NUMBER_PATCH_PER_TRAIN_IMAGE = (TRAIN_IMAGE_INITIAL_SIZE // PATCH_SIZE) * (TRAIN_IMAGE_INITIAL_SIZE // PATCH_SIZE)
NUMBER_PATCH_PER_TEST_IMAGE = (TEST_IMAGE_INITIAL_SIZE // PATCH_SIZE) * (TEST_IMAGE_INITIAL_SIZE // PATCH_SIZE)

# BCELoss expects probabilities in [0, 1]; presumably UNet ends in a sigmoid — confirm in model.py.
CRITERION = nn.BCELoss()

# Seed every RNG this pipeline visibly depends on (torch also seeds CUDA devices).
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Training split: patches come from the training images and are shuffled each epoch.
train_data_dir = "./Datasets/training"
train_dataset = RoadsDatasetTrain(
    root_dir=train_data_dir,
    patch_size=PATCH_SIZE,
    large_patch_size=LARGE_PATCH_SIZE,
    number_patch_per_image=NUMBER_PATCH_PER_TRAIN_IMAGE,
    image_initial_size=TRAIN_IMAGE_INITIAL_SIZE,
)
train_dataloader = data.DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split: kept in dataset order (shuffle=False), presumably so predictions can be
# mapped back to their image/patch position — confirm in predict.py.
test_data_dir = "./Datasets/test_set_images"
test_dataset = RoadsDatasetTest(
    root_dir=test_data_dir,
    patch_size=PATCH_SIZE,
    large_patch_size=LARGE_PATCH_SIZE,
    number_patch_per_image=NUMBER_PATCH_PER_TEST_IMAGE,
    image_initial_size=TEST_IMAGE_INITIAL_SIZE,
)
test_dataloader = data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Build a freshly initialised U-Net and fit it on the training patches.
# NOTE(review): LEARNING_RATE from the config cell is not passed here, so `train`
# presumably uses its own default learning rate — check train.py's signature and
# pass it explicitly if supported.
unet = UNet()
train(
    model=unet,
    dataloader=train_dataloader,
    epochs=EPOCHS,
    criterion=CRITERION,
)

[Epoch 0, Batch 0/62500]:  [Loss: 0.66]
[Epoch 0, Batch 10/62500]:  [Loss: 0.66]
[Epoch 0, Batch 20/62500]:  [Loss: 0.61]
[Epoch 0, Batch 30/62500]:  [Loss: 0.58]
[Epoch 0, Batch 40/62500]:  [Loss: 0.57]
[Epoch 0, Batch 50/62500]:  [Loss: 0.65]
[Epoch 0, Batch 60/62500]:  [Loss: 0.55]
[Epoch 0, Batch 70/62500]:  [Loss: 0.56]
[Epoch 0, Batch 80/62500]:  [Loss: 0.53]
[Epoch 0, Batch 90/62500]:  [Loss: 0.59]
[Epoch 0, Batch 100/62500]:  [Loss: 0.51]
[Epoch 0, Batch 110/62500]:  [Loss: 0.57]
[Epoch 0, Batch 120/62500]:  [Loss: 0.81]
[Epoch 0, Batch 130/62500]:  [Loss: 0.50]
[Epoch 0, Batch 140/62500]:  [Loss: 0.53]
[Epoch 0, Batch 150/62500]:  [Loss: 0.54]
[Epoch 0, Batch 160/62500]:  [Loss: 0.53]
[Epoch 0, Batch 170/62500]:  [Loss: 0.49]
[Epoch 0, Batch 180/62500]:  [Loss: 0.53]
[Epoch 0, Batch 190/62500]:  [Loss: 0.61]
[Epoch 0, Batch 200/62500]:  [Loss: 0.60]
[Epoch 0, Batch 210/62500]:  [Loss: 0.46]
[Epoch 0, Batch 220/62500]:  [Loss: 0.57]
[Epoch 0, Batch 230/62500]:  [Loss: 0.46]
[Ep

[Epoch 0, Batch 1940/62500]:  [Loss: 0.28]
[Epoch 0, Batch 1950/62500]:  [Loss: 0.24]
[Epoch 0, Batch 1960/62500]:  [Loss: 0.27]
[Epoch 0, Batch 1970/62500]:  [Loss: 0.82]
[Epoch 0, Batch 1980/62500]:  [Loss: 0.40]
[Epoch 0, Batch 1990/62500]:  [Loss: 0.31]
[Epoch 0, Batch 2000/62500]:  [Loss: 0.32]
[Epoch 0, Batch 2010/62500]:  [Loss: 0.28]
[Epoch 0, Batch 2020/62500]:  [Loss: 0.34]
[Epoch 0, Batch 2030/62500]:  [Loss: 0.49]
[Epoch 0, Batch 2040/62500]:  [Loss: 0.77]
[Epoch 0, Batch 2050/62500]:  [Loss: 0.34]
[Epoch 0, Batch 2060/62500]:  [Loss: 0.85]
[Epoch 0, Batch 2070/62500]:  [Loss: 0.45]
[Epoch 0, Batch 2080/62500]:  [Loss: 0.48]
[Epoch 0, Batch 2090/62500]:  [Loss: 0.90]
[Epoch 0, Batch 2100/62500]:  [Loss: 0.33]
[Epoch 0, Batch 2110/62500]:  [Loss: 0.31]
[Epoch 0, Batch 2120/62500]:  [Loss: 0.29]
[Epoch 0, Batch 2130/62500]:  [Loss: 0.41]
[Epoch 0, Batch 2140/62500]:  [Loss: 0.76]
[Epoch 0, Batch 2150/62500]:  [Loss: 0.43]
[Epoch 0, Batch 2160/62500]:  [Loss: 0.28]
[Epoch 0, B

[Epoch 0, Batch 3850/62500]:  [Loss: 0.13]
[Epoch 0, Batch 3860/62500]:  [Loss: 0.11]
[Epoch 0, Batch 3870/62500]:  [Loss: 0.24]
[Epoch 0, Batch 3880/62500]:  [Loss: 0.27]
[Epoch 0, Batch 3890/62500]:  [Loss: 0.19]
[Epoch 0, Batch 3900/62500]:  [Loss: 0.40]
[Epoch 0, Batch 3910/62500]:  [Loss: 0.33]
[Epoch 0, Batch 3920/62500]:  [Loss: 0.32]
[Epoch 0, Batch 3930/62500]:  [Loss: 0.17]
[Epoch 0, Batch 3940/62500]:  [Loss: 0.33]
[Epoch 0, Batch 3950/62500]:  [Loss: 0.29]
[Epoch 0, Batch 3960/62500]:  [Loss: 0.30]
[Epoch 0, Batch 3970/62500]:  [Loss: 0.21]
[Epoch 0, Batch 3980/62500]:  [Loss: 0.16]
[Epoch 0, Batch 3990/62500]:  [Loss: 0.13]
[Epoch 0, Batch 4000/62500]:  [Loss: 0.29]
[Epoch 0, Batch 4010/62500]:  [Loss: 0.19]
[Epoch 0, Batch 4020/62500]:  [Loss: 0.46]
[Epoch 0, Batch 4030/62500]:  [Loss: 0.21]
[Epoch 0, Batch 4040/62500]:  [Loss: 0.13]
[Epoch 0, Batch 4050/62500]:  [Loss: 0.37]
[Epoch 0, Batch 4060/62500]:  [Loss: 0.23]
[Epoch 0, Batch 4070/62500]:  [Loss: 0.33]
[Epoch 0, B

[Epoch 0, Batch 5760/62500]:  [Loss: 0.08]
[Epoch 0, Batch 5770/62500]:  [Loss: 0.13]
[Epoch 0, Batch 5780/62500]:  [Loss: 0.23]


In [None]:
# Run the trained model over the test set.
# The result is bound to a name rather than left as the cell's last expression, so the
# notebook does not render the entire prediction output.
# NOTE(review): predict presumably returns (results, indices, masks) — confirm in
# predict.py before unpacking.
# NOTE(review): if UNet contains dropout/batch-norm, ensure predict switches it to eval mode.
with torch.no_grad():  # inference only: skip autograd bookkeeping (harmless if predict already does this)
    predictions = predict(unet, test_dataloader)