In [None]:
from fastai.vision.data import *
import fastai
from fastai.vision import *
from PIL import Image
import torch
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from helpers import *

In [None]:
path_data = Path('training/')

path_lbl = path_data/'croppedLabels'
path_img = path_data/'croppedImages'

# get images and labels filenames
img_names = get_image_files(path_img)
lbl_names = get_image_files(path_lbl)

# Labels are looked up by file name (see get_lbl_fn / label_from_func below),
# so every image must have a mask with the same name in path_lbl.
missing_labels = sorted({p.name for p in img_names} - {p.name for p in lbl_names})
assert not missing_labels, f'{len(missing_labels)} images have no label, e.g. {missing_labels[:5]}'

len(img_names), len(lbl_names)

In [None]:

def get_lbl_fn(img_fn: Path):
    """Map an image path to the path of its label mask.

    The label has the same file name as the image and lives in a
    ``croppedLabels`` folder next to the image's folder, e.g.
    ``training/croppedImages/x.png`` -> ``training/croppedLabels/x.png``.
    """
    dataset_root = img_fn.parent.parent
    return dataset_root / ('croppedLabels/' + img_fn.name)

# Sanity check: show one cropped training image next to its label mask.
# Reuse path_img from the config cell instead of repeating the folder path.
fname = path_img / 'satImage_2_crop_2.png'

img = open_image(fname)
mask = open_mask(get_lbl_fn(fname))

fig, ax = plt.subplots(1, 2, figsize=(10, 6))

img.show(ax[0], title='image')
mask.show(ax[1], title='label')

In [None]:
# Custom fastai item/label lists for masks whose pixel values are 0 and 255.
class SegLabelListCustom(SegmentationLabelList):
    "Label list that opens masks with `div=True` (presumably scaling 0/255 to 0/1 — fastai v1 `open_image` divides by 255)."
    def open(self, fn):
        label_mask = open_mask(fn, div=True)
        return label_mask

class SegItemListCustom(SegmentationItemList):
    "Segmentation item list whose labels are created as `SegLabelListCustom`."
    _label_cls = SegLabelListCustom

bs = 4            # batch size
patch_shape = 16  # side length (px) items are resized to
split_seed = 42   # fixes the random train/valid split so results are reproducible
# NOTE(review): the saved 'road_Resnet34' weights were trained on an unknown split;
# confirm they were produced with this seed, otherwise validation images may overlap training ones.

print(f'Batch size:{bs}')
print(f'Patch shape:{patch_shape}')

src = (SegItemListCustom.from_folder(path_img)
       .split_by_rand_pct(seed=split_seed)
       .label_from_func(lambda x: path_lbl / x.relative_to(path_img),
                        classes=['rest', 'road']))
data = (src.transform(get_transforms(flip_vert=True),
                      size=patch_shape,
                      tfm_y=True)
        .databunch(bs=bs)
        .normalize(imagenet_stats))
data

In [None]:
data.show_batch(rows=2)

In [None]:
def acc_metric(input, target):
    """Pixel accuracy: fraction of pixels whose arg-max class equals the target.

    :param input: logits with the class dimension at axis 1
    :param target: integer labels with a singleton channel at axis 1
    :return: scalar float tensor in [0, 1]
    """
    labels = target.squeeze(1)
    predicted = input.argmax(dim=1)
    hits = (predicted == labels).float()
    return hits.mean()

# Optimisation hyper-parameters
wd = 1e-2  # weight decay given to the learner
lr = 1e-3  # learning rate for fit_one_cycle (only needed when training is re-enabled)

# U-Net with a ResNet-34 encoder, built from the DataBunch above
learn = unet_learner(data, models.resnet34, wd=wd, metrics=acc_metric)

# Training is disabled; pretrained weights are loaded in the next cell instead.
# learn.fit_one_cycle(3, lr)

In [None]:
# Load pretrained weights; the trailing ';' hides the very long Learner repr.
learn.load('road_Resnet34');


In [None]:

# Compare input, ground truth and prediction for one validation sample.
sample_idx = 22
img = learn.data.valid_ds.x[sample_idx]
mask = learn.data.valid_ds.y[sample_idx]
pred = learn.predict(img)[0]

fig, ax = plt.subplots(1, 3, figsize=(12, 6))

img.show(ax[0], title='image')
mask.show(ax[1], title='ground truth')
pred.show(ax[2], title='prediction')
img.shape

# Prediction on full-size images

In [None]:
def predictImage(img_path: Path, out_folder: Path, learner=None):
    '''
    Predict the road mask of a full-size image and save it as a grayscale PNG.

    The image is cut into 16x16 patches with ``img_crop``. Each patch is written
    to a temporary PNG in ``training/croppedPredictions`` and classified with the
    fastai learner. The predicted patch masks are stitched back together with
    ``concatenate_mask``.

    :param img_path: path to the image to predict (square, side a multiple of 16)
    :param out_folder: folder in which the mask is saved as ``<name>_prediction.png``;
        created, including parent folders, if it does not exist
    :param learner: fastai learner used for prediction; defaults to the global ``learn``
    :return: path of the saved prediction image
    :raises ValueError: if the image is not square with a side that is a multiple
        of 16, or ``img_crop`` yields too few patches
    '''
    patch = 16
    model = learner if learner is not None else learn
    tmp_folder = Path("training/croppedPredictions")
    tmp_folder.mkdir(parents=True, exist_ok=True)

    img = load_image(img_path)
    size = img.shape[1]
    if img.shape[0] != size or size % patch != 0:
        raise ValueError(f"{img_path}: expected a square image with side a multiple of {patch}, got {img.shape}")

    cropped = img_crop(img, patch, patch)
    numberOfPatches = (size // patch) ** 2
    # Guard against reading stale patch files left over from a previous image.
    if len(cropped) < numberOfPatches:
        raise ValueError(f"{img_path}: img_crop returned {len(cropped)} patches, expected {numberOfPatches}")

    preds = []
    for i, crop in enumerate(cropped[:numberOfPatches]):
        patch_file = tmp_folder / f"satImage_{i}_crop.png"
        # Round-trip through disk so the patch goes through fastai's open_image pipeline.
        Image.fromarray((crop * 255).astype(np.uint8), 'RGB').save(patch_file)
        preds.append(model.predict(open_image(patch_file))[0])

    # Build the stacked mask array once, after all patches are predicted.
    predmask = np.array([np.array(p.data) for p in preds]).reshape((numberOfPatches, patch, patch))
    full_mask = concatenate_mask(predmask, size)

    out_folder.mkdir(parents=True, exist_ok=True)
    out_path = out_folder / img_path.name.replace(".png", "_prediction.png")
    Image.fromarray((full_mask * 255).astype(np.uint8), 'L').save(out_path)
    return out_path
    

In [None]:
def concatenate_mask(cropped_masks, size, patch_size=16):
    '''
    Reassemble square patch masks into one square mask.

    Patches are consumed in column-major order: the first ``size // patch_size``
    patches form the left-most column (top to bottom), the next ones the second
    column, and so on. Extra patches beyond ``(size // patch_size) ** 2`` are ignored.

    :param cropped_masks: array-like of shape (n, patch_size, patch_size),
        e.g. (625, 16, 16) for a 400x400 image
    :param size: side length in pixels of the square output; a multiple of patch_size
    :param patch_size: side length of each patch (default 16)
    :return out: float nparray of shape (size, size)
    :raises ValueError: if size is not a positive multiple of patch_size or too few patches are given
    '''
    if size <= 0 or size % patch_size != 0:
        raise ValueError(f"size ({size}) must be a positive multiple of patch_size ({patch_size})")
    n_side = size // patch_size
    if len(cropped_masks) < n_side * n_side:
        raise ValueError(f"need {n_side * n_side} patches for size {size}, got {len(cropped_masks)}")

    out = np.zeros((size, size))
    for col in range(n_side):
        first = col * n_side
        column = np.concatenate(cropped_masks[first:first + n_side], axis=0)
        out[:, col * patch_size:(col + 1) * patch_size] = column
    return out

In [None]:
# Run the full-image pipeline on a training image that has a known label.
out = Path('training/Prediction/')
imgToTest_path = Path('training/images/satImage_009.png')
predictedMask = predictImage(img_path=imgToTest_path, out_folder=out)
realmask = open_image('training/labels/satImage_009.png')




In [None]:
# Load input, ground truth and saved prediction for satImage_009.
imagetotest, realmask, predictedMask = (
    open_image(p) for p in ('training/images/satImage_009.png',
                            'training/labels/satImage_009.png',
                            'training/Prediction/satImage_009_prediction.png')
)

In [None]:
# Side-by-side comparison in one figure instead of separate untitled plots.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
imagetotest.show(ax[0], title='image')
realmask.show(ax[1], title='ground truth')
predictedMask.show(ax[2], title='prediction')

In [None]:
# Test images have no ground truth, so only the prediction is produced here.
out = Path('training/Prediction/')
imgToTest_path = Path('test_set_images/test_set_images/test_4/test_4.png')
predictedMask = predictImage(img_path=imgToTest_path, out_folder=out)




In [None]:
imgToTest_path = Path('test_set_images/test_set_images/test_4/test_4.png')
predictedMask = open_image('training/Prediction/test_4_prediction.png')
imgtoTest = open_image(imgToTest_path)

# One titled figure with the test image and its predicted mask side by side.
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
imgtoTest.show(ax[0], title='test_4')
predictedMask.show(ax[1], title='prediction')

In [None]:
# Shape of the fastai Image loaded above (fastai v1 stores pixels channels-first — confirm)
imgtoTest.shape

# Create submission

In [None]:
from mask_to_submission import *
from tqdm.auto import tqdm
testPath = Path('test_set_images/test_set_images/')
out = Path('test_set_images/predictions/')


def createSub(out_folder: Path, submission_file_name, n_images: int = 50, test_path: Path = testPath):
    '''
    Predict every test image and write the submission file.

    For each i in 1..n_images, ``test_path/test_i/test_i.png`` is predicted with
    ``predictImage`` into ``out_folder/test_i/``. The saved mask is reopened with
    PIL, and all masks are passed in order to ``mask_to_submission``.

    :param out_folder: root folder for the per-image prediction sub-folders
    :param submission_file_name: file name passed to ``mask_to_submission``
    :param n_images: number of test images (test_1 ... test_n)
    :param test_path: folder containing the ``test_i`` sub-folders
    '''
    # Local list: a module-level list would accumulate duplicates across calls.
    masks = []
    for i in tqdm(range(1, n_images + 1)):
        pred_folder = out_folder / f"test_{i}"
        predictImage(test_path / f"test_{i}/test_{i}.png", pred_folder)
        masks.append(Image.open(pred_folder / f"test_{i}_prediction.png"))
    mask_to_submission(submission_file_name, masks)


In [None]:
createSub(out_folder=out, submission_file_name="testsub")