In [None]:
from fastai.vision.data import *
import fastai
from fastai.vision import *
from PIL import Image
import torch
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from helpers import *

In [None]:
# Training data layout: image patches and their masks live in sibling folders under training/.
path_data = Path('training/')

path_lbl = path_data/'croppedLabels'
path_img = path_data/'croppedImages'

# get images and labels filenames
img_names = get_image_files(path_img)
lbl_names = get_image_files(path_lbl)

# Sanity check: the two counts should match if every image patch has a mask.
print(len(img_names), len(lbl_names))

In [None]:

def get_lbl_fn(img_fn: Path):
    """Return the mask path that belongs to an image patch.

    The mask has the same file name as the image and sits in a ``croppedLabels``
    folder next to the image's own folder, e.g.
    ``training/croppedImages/x.png`` -> ``training/croppedLabels/x.png``.
    """
    labels_dir = img_fn.parent.parent / 'croppedLabels'
    return labels_dir / img_fn.name

# Visual check: show one image patch next to its mask.
fname = Path('training/croppedImages/satImage_2_crop_2.png')

img = open_image(fname)
# NOTE(review): opened without div=True (unlike SegLabelListCustom.open), so the raw
# stored mask values are shown; fine for display, not the values used in training.
mask = open_mask(get_lbl_fn(fname))

fig, ax = plt.subplots(1,2, figsize=(10,6))

img.show(ax[0])
mask.show(ax[1])

In [None]:
# Classes for segmentation with 0,255 labels:
class SegLabelListCustom(SegmentationLabelList):
    """Segmentation label list that opens every mask with ``div=True``.

    NOTE(review): presumably ``div=True`` turns the stored 0/255 mask values into
    0/1 class indices; confirm against the installed fastai version.
    """
    def open(self, fn):
        # Called by fastai for each label file name.
        return open_mask(fn, div=True)
class SegItemListCustom(SegmentationItemList):
    """Segmentation item list whose labels are built as ``SegLabelListCustom``."""
    # fastai uses _label_cls as the class for the labels created by label_from_func.
    _label_cls = SegLabelListCustom

bs = 4
patch_shape = 16
# Fixed seed so the random train/validation split (and therefore the validation
# samples picked by index further down) is identical on every run.
split_seed = 42

print(f'Batch size:{bs}')
print(f'Patch shape:{patch_shape}')

# Items are the image patches; each label is the file with the same relative path
# under path_lbl. Class index 0 is 'rest', class index 1 is 'road'.
src = (SegItemListCustom.from_folder(path_img)
       .split_by_rand_pct(seed=split_seed)
       .label_from_func(lambda x: path_lbl / x.relative_to(path_img),
                        classes=['rest', 'road']))
# Same augmentations are applied to masks (tfm_y=True); inputs are normalized with
# ImageNet statistics to match the pretrained encoder.
data = (src.transform(get_transforms(flip_vert=True),
                      size=patch_shape,
                      tfm_y=True)
        .databunch(bs=bs)
        .normalize(imagenet_stats))
data

In [None]:
# Preview a few samples with their masks (2 is the first positional argument, `rows` in fastai v1).
data.show_batch(2)

In [None]:
def acc_metric(input, target):
    """Pixel accuracy for segmentation.

    input: per-class scores with the class axis at dim 1
        (assumed (batch, n_classes, H, W)).
    target: class-index mask with a singleton dim 1 (assumed (batch, 1, H, W)).
    Returns the fraction of pixels whose argmax class equals the target,
    as a 0-d float tensor.
    """
    predicted = input.argmax(dim=1)
    labels = target.squeeze(1)
    return predicted.eq(labels).float().mean()

# weight decay
wd = 1e-2
# learning rate (passed to fit_one_cycle as its max_lr)
lr=1e-3

# U-Net with a ResNet-34 encoder, built from the DataBunch (classes, input size, ...),
# reporting pixel accuracy (acc_metric) during training.
learn = unet_learner(data, models.resnet34, metrics=acc_metric, wd=wd)

# Train for 3 epochs with the one-cycle schedule.
learn.fit_one_cycle(3, lr)

In [None]:
# Uncomment to checkpoint the trained weights.
#learn.save('road_Resnet34')

# select one image from the validation dataset (index is arbitrary and depends on the split)
img = learn.data.valid_ds.x[22]
mask = learn.data.valid_ds.y[22]
# Element 0 of predict's result is the predicted segmentation shown in the third panel.
pred = learn.predict(img)[0]

# Side by side: input, ground-truth mask, prediction.
fig, ax = plt.subplots(1,3, figsize=(12,6))

img.show(ax[0])
mask.show(ax[1])
pred.show(ax[2])
img.shape

# Prediction part

In [None]:
def concatenate_mask(list_masks, patches_per_column=None):
    """Reassemble equally sized 2-D patches into one image, column by column.

    Patches are ordered column-major, matching the NumPy ``img_crop``: the first
    ``patches_per_column`` entries form the leftmost column from top to bottom
    (grid positions [0, 0] .. [patches_per_column - 1, 0]), the next ones form the
    second column, and so on. With this notebook's defaults that is 625 patches of
    16x16 forming a 400x400 mask.

    Args:
        list_masks: sequence (or stacked array) of equally sized 2-D patches.
        patches_per_column: number of patches stacked vertically in each column.
            Defaults to sqrt(len(list_masks)), i.e. a square grid.

    Returns:
        float64 array of shape (patches_per_column * patch_h, n_columns * patch_w).

    Raises:
        ValueError: if list_masks is empty, the default grid is not square, or the
            number of patches is not a multiple of patches_per_column.
    """
    masks = [np.asarray(m) for m in list_masks]
    n_masks = len(masks)
    if n_masks == 0:
        raise ValueError("concatenate_mask needs at least one patch")
    if patches_per_column is None:
        patches_per_column = int(round(n_masks ** 0.5))
        if patches_per_column * patches_per_column != n_masks:
            raise ValueError(f"{n_masks} patches do not form a square grid")
    if patches_per_column <= 0 or n_masks % patches_per_column:
        raise ValueError(
            f"cannot arrange {n_masks} patches in columns of {patches_per_column}")
    columns = [np.concatenate(masks[start:start + patches_per_column], axis=0)
               for start in range(0, n_masks, patches_per_column)]
    # float64 output, matching the np.zeros canvas used originally.
    return np.concatenate(columns, axis=1).astype(np.float64)

In [None]:
def predictImage(img, patch_size=16,
                 crop_dir='training/croppedPredictions',
                 out_path='training/Prediction/satImage.png',
                 return_mask=False):
    """Predict a road mask for a whole image by running `learn` on each patch.

    The image is cut into patch_size x patch_size tiles with `img_crop`. Each tile is
    written to `crop_dir` as a PNG and re-opened with `open_image`, so the model gets
    the same kind of input it was trained on. The per-tile predictions are stitched
    back together with `concatenate_mask`. The stitched mask (predicted values scaled
    by 255) is saved as a grayscale PNG at `out_path`.

    NOTE(review): relies on the NumPy `img_crop` (H x W x C array input); a later cell
    redefines `img_crop` for fastai Images, after which this function no longer works.

    Args:
        img: image array as returned by `load_image` — assumed H x W x 3 with float
            values in [0, 1], since tiles are scaled by 255 and saved as 'RGB'.
        patch_size: tile side length in pixels (the model was trained on 16).
        crop_dir: folder for the temporary tile PNGs (created if missing).
        out_path: where the stitched mask PNG is written (parent folder created).
        return_mask: if True, return the stitched mask array instead of 1.

    Returns:
        1 (legacy status value), or the stitched float mask when return_mask is True.
    """
    crop_dir = Path(crop_dir)
    out_path = Path(out_path)
    # PIL's Image.save does not create missing folders.
    crop_dir.mkdir(parents=True, exist_ok=True)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    tiles = img_crop(img, patch_size, patch_size)
    tile_preds = []
    for idx, tile in enumerate(tiles):
        tile_path = crop_dir / ('satImage_' + str(idx) + '_crop.png')
        Image.fromarray((tile * 255).astype(np.uint8), 'RGB').save(tile_path)
        tile_preds.append(learn.predict(open_image(tile_path))[0])

    # Stack the predictions once, after the loop (one (patch, patch) mask per tile).
    pred_stack = np.array([np.array(p.data) for p in tile_preds])
    pred_stack = pred_stack.reshape((len(tile_preds), patch_size, patch_size))
    full_mask = concatenate_mask(pred_stack)
    Image.fromarray((full_mask * 255).astype(np.uint8), 'L').save(out_path)
    return full_mask if return_mask else 1
    

In [None]:
# Run the full-image prediction on one training image; the stitched mask is written
# to training/Prediction/satImage.png.
# NOTE(review): load_image and img_crop are only defined further down in this notebook,
# so on a fresh kernel this cell relies on `helpers` providing them.
imagetotest = load_image('training/images/satImage_009.png')
# predictImage returns 1, not the mask; the mask is re-read from disk in the next cell.
predictedMask = predictImage(imagetotest)
realmask = open_image('training/labels/satImage_009.png')




In [None]:
# Reload the test image, its ground-truth label and the stitched prediction written by
# predictImage, all with fastai's open_image so they can be displayed with .show().
imagetotest = open_image('training/images/satImage_009.png')
realmask = open_image('training/labels/satImage_009.png')
predictedMask = open_image('training/Prediction/satImage.png')

In [None]:
# Display the test image, its ground-truth mask and the predicted mask.
imagetotest.show()
realmask.show()
predictedMask.show()

In [None]:
# Load another training image as a raw array and show it as the cell output.
imagetotest = load_image('training/images/satImage_005.png')
imagetotest

In [None]:
%matplotlib inline
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt
import os,sys
from PIL import Image

In [None]:
# Helper functions

def load_image(infilename):
    """Read the image file at `infilename` into an array with matplotlib.image.imread."""
    return mpimg.imread(infilename)

def img_float_to_uint8(img):
    """Linearly rescale an array to the full uint8 range [0, 255].

    The minimum maps to 0 and the maximum to 255. A constant array (max == min)
    maps to all zeros instead of dividing by zero.
    """
    rimg = img - np.min(img)
    peak = np.max(rimg)
    if peak == 0:
        # Nothing to stretch; avoids 0/0 -> NaN -> undefined uint8 cast.
        return np.zeros_like(rimg, dtype=np.uint8)
    return (rimg / peak * 255).round().astype(np.uint8)

# Concatenate an image and its groundtruth
def concatenate_images(img, gt_img):
    """Return `img` and its ground truth `gt_img` joined side by side along axis 1.

    If gt_img has 3 dimensions, the two arrays are concatenated unchanged.
    Otherwise gt_img is treated as a single-channel map: it is rescaled to uint8
    and copied into all three channels of an RGB canvas. img is rescaled to uint8
    as well, and the two are concatenated.
    """
    dims = gt_img.shape
    n_rows, n_cols = dims[0], dims[1]
    if len(dims) == 3:
        return np.concatenate((img, gt_img), axis=1)

    gt_as_uint8 = img_float_to_uint8(gt_img)
    gt_rgb = np.zeros((n_rows, n_cols, 3), dtype=np.uint8)
    for channel in range(3):
        gt_rgb[:, :, channel] = gt_as_uint8
    img_as_uint8 = img_float_to_uint8(img)
    return np.concatenate((img_as_uint8, gt_rgb), axis=1)

def img_crop(im, w, h):
    """Cut a 2-D (H x W) or 3-D (H x W x C) array into tiles.

    Tiles span `w` entries along axis 0 and `h` along axis 1. They are returned
    column by column: all tiles at axis-1 offset 0 from top to bottom, then the next
    column, and so on. Edge tiles are smaller when the size is not a multiple.
    Each tile is a NumPy slice (view) of `im`.
    """
    n_rows, n_cols = im.shape[0], im.shape[1]
    return [im[top:top + w, left:left + h]
            for left in range(0, n_cols, h)
            for top in range(0, n_rows, w)]

In [None]:

# select one image from the validation dataset
# NOTE(review): near-duplicate of the earlier validation-sample cell (index 21 instead of 22).
img = learn.data.valid_ds.x[21]
mask = learn.data.valid_ds.y[21]
# Element 0 of predict's result is the predicted segmentation shown in the third panel.
pred = learn.predict(img)[0]

# Side by side: input, ground-truth mask, prediction.
fig, ax = plt.subplots(1,3, figsize=(12,6))

img.show(ax[0])
mask.show(ax[1])
pred.show(ax[2])

In [None]:
# fastai's built-in grid of ground truth vs. predictions (validation set by default in fastai v1).
learn.show_results()

In [None]:
testimg = open_image('training/images/satImage_001.png')
# NOTE(review): fastai v1's Image.crop is a transform taking `size` (int or (h, w)) plus
# row/col percentages, not a PIL-style (left, top, right, bottom) box; this call probably
# does not crop the top-left 200x200 region — verify.
testimg.crop((0,0,200,200))


In [None]:
# fastai Image tensors are channel-first (C, H, W): keep every channel and take the
# top-left 16x16 pixel window. Indexing px[0:16, 0:16] would slice the channel and
# row axes instead, giving a C x 16 x W strip rather than a 16x16 patch.
ff = fastai.vision.Image(testimg.px[:, 0:16, 0:16])
pred = learn.predict(ff)[0]

In [None]:
# Debug: check the Python type of a validation item (the kind of object passed to learn.predict above).
print(type(learn.data.valid_ds.x[7]))

In [None]:
def img_crop(im, w, h):
    """Cut a fastai image into tiles, each wrapped as a fastai.vision.Image.

    NOTE: this redefines the NumPy `img_crop` defined earlier; once this cell runs,
    code that passes an H x W x C NumPy array (e.g. predictImage) stops working.

    Tiles span `w` pixels along axis -2 (rows) and `h` along axis -1 (columns), and
    every channel is kept. Tiles are produced column by column: all tiles of the
    first column from top to bottom, then the next column. Edge tiles are smaller
    when the size is not a multiple.

    Args:
        im: object whose `.data` is a (C, H, W) or (H, W) tensor
            (e.g. fastai.vision.Image).
        w: tile height in pixels.
        h: tile width in pixels.

    Returns:
        list of fastai.vision.Image, each wrapping a tensor slice of `im.data`.
    """
    px = im.data
    n_rows, n_cols = px.shape[-2], px.shape[-1]
    patches = []
    for left in range(0, n_cols, h):
        for top in range(0, n_rows, w):
            # Keep the torch tensor rather than converting to NumPy: fastai Image
            # methods (e.g. clone, used when learn.predict applies transforms)
            # call tensor methods on px.
            patches.append(fastai.vision.Image(px[..., top:top + w, left:left + h]))
    return patches
def img_crop_2(im, w, h):
    """Cut a fastai image into tiles using its `.px` tensor.

    Tiles span `w` pixels along axis -2 (rows) and `h` along axis -1 (columns), and
    every channel is kept. Tiles are produced column by column: all tiles of the
    first column from top to bottom, then the next column. Edge tiles are smaller
    when the size is not a multiple.

    Args:
        im: fastai.vision.Image (px of shape (C, H, W)), or an object whose `.px`
            is a 2-D (H, W) tensor.
        w: tile height in pixels.
        h: tile width in pixels.

    Returns:
        list of fastai.vision.Image, each wrapping a tensor slice of `im.px`.
    """
    px = im.px
    n_rows, n_cols = px.shape[-2], px.shape[-1]
    patches = []
    for left in range(0, n_cols, h):
        for top in range(0, n_rows, w):
            # Slice only the last two (spatial) axes so the channel axis of a
            # channel-first tensor is left intact.
            patches.append(fastai.vision.Image(px[..., top:top + w, left:left + h]))
    return patches
def img_float_to_uint8(img):
    """Linearly rescale an array to the full uint8 range [0, 255].

    Redefinition of the helper of the same name defined earlier in the notebook.
    The minimum maps to 0 and the maximum to 255. A constant array (max == min)
    maps to all zeros instead of dividing by zero.
    """
    rimg = img - np.min(img)
    peak = np.max(rimg)
    if peak == 0:
        # Nothing to stretch; avoids 0/0 -> NaN -> undefined uint8 cast.
        return np.zeros_like(rimg, dtype=np.uint8)
    return (rimg / peak * 255).round().astype(np.uint8)


In [None]:

def predictOneImage(img):
    """Run `learn.predict` on every 16x16 tile of `img` produced by `img_crop`.

    Returns a list with one raw `learn.predict` result per tile, in `img_crop` order.
    """
    return [learn.predict(tile) for tile in img_crop(img, 16, 16)]
# Tile a full training image with the most recently defined (fastai-Image) img_crop.
testimg = open_image('training/images/satImage_001.png')   
patches = img_crop(testimg,16,16)


In [None]:
def reconstruct_img(patches, patch_size, patches_per_column=None):
    """Reassemble tiles from the channel-first `img_crop` into one fastai Image.

    Tiles are expected in `img_crop` order, column by column: the first
    `patches_per_column` tiles form the leftmost column from top to bottom, the next
    ones the second column, and so on. Tiles are joined along the row axis (-2)
    within a column and the columns along the width axis (-1), keeping all channels.

    Args:
        patches: sequence of objects with a `.data` tensor or array, e.g. the
            fastai Images returned by `img_crop`.
        patch_size: kept for API compatibility; tile sizes are read from the tiles
            themselves.
        patches_per_column: tiles per column; defaults to sqrt(len(patches)),
            i.e. a square grid.

    Returns:
        fastai.vision.Image wrapping the reassembled tensor.

    Raises:
        ValueError: if there are no patches, the default grid is not square, or the
            count is not a multiple of patches_per_column.
    """
    tiles = [torch.as_tensor(p.data) for p in patches]
    n_tiles = len(tiles)
    if n_tiles == 0:
        raise ValueError("reconstruct_img needs at least one patch")
    if patches_per_column is None:
        patches_per_column = int(round(n_tiles ** 0.5))
        if patches_per_column * patches_per_column != n_tiles:
            raise ValueError(f"{n_tiles} patches do not form a square grid")
    if patches_per_column <= 0 or n_tiles % patches_per_column:
        raise ValueError(
            f"cannot arrange {n_tiles} patches in columns of {patches_per_column}")
    columns = [torch.cat(tiles[start:start + patches_per_column], dim=-2)
               for start in range(0, n_tiles, patches_per_column)]
    return fastai.vision.Image(torch.cat(columns, dim=-1))