# Goals and methods
In this notebook, we use `segmentation_models.pytorch` to set up different neural network architectures and experiment with their parameters. Because our dataset is quite small, we use `Albumentations` for data augmentation. The whole implementation is based on PyTorch, which acts as the 'glue' that holds the project together.

# 1. Imports

In [9]:
# Data tools
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import numpy as np
import json

# Data visualization
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw

# Data loading
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset

# 2. Loading dataset

### 2.1. Directories

In [4]:
# Root folder that holds every dataset split
DATA_DIR = './data/'

# Each split (train / validation / test) is described by one JSON file
# located directly under DATA_DIR.
train_file, val_file, test_file = (
    os.path.join(DATA_DIR, split + '.json')
    for split in ('train', 'val', 'test')
)

### 2.2. From polygon points to mask
Our dataset is structured differently from the one in the tutorial, which stores its annotations as mask images. Our annotations are lists of polygon points, so the function below converts each polygon into a mask that fits the method we are using.

In [3]:
def pol_to_mask(img_width, img_height, pol_array):
    """Rasterize a polygon into a mask array.

    Args:
        img_width (int): width of the mask canvas, in pixels
        img_height (int): height of the mask canvas, in pixels
        pol_array: polygon vertices, passed unchanged to
            ``ImageDraw.polygon`` (presumably ``[(x, y), ...]`` or a flat
            ``[x0, y0, x1, y1, ...]`` list -- confirm against the annotations)

    Returns:
        numpy.ndarray: the canvas converted to an array, holding 1 on the
        polygon outline and interior and 0 everywhere else.
    """
    # Single-channel ('L') canvas initialised to 0 (background).
    canvas = Image.new('L', (img_width, img_height), 0)

    # Paint both the outline and the interior with the value 1.
    painter = ImageDraw.Draw(canvas)
    painter.polygon(pol_array, outline=1, fill=1)

    return np.array(canvas)

### 2.3. Visualization

In [5]:
def visualize(**images):
    """Plot images in one row.

    Each keyword argument becomes one subplot: the keyword name, with
    underscores replaced by spaces and title-cased, is used as the subplot
    title and the value is handed to ``plt.imshow``.
    """
    total = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, total, position)
        # Hide the axis ticks; only the picture and its title are shown.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

### 2.4. Dataloader
Helper class for data extraction, transformation and preprocessing.

In [8]:
class Dataset(BaseDataset):
    """Solar Panel dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        data_dir (str): path to the JSON file describing the split; it must
            contain an ``images`` list whose entries have ``id`` and
            ``file_name`` keys
        masks_dir (str): path to segmentation masks folder; each mask is
            looked up under the same ``file_name`` as its image
        classes (list): names of the classes to extract from the segmentation
            mask (matched case-insensitively against ``CLASSES``); ``None``
            selects every class in ``CLASSES``
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)
        images_dir (str): folder the ``file_name`` entries are relative to;
            ``None`` (default) uses the folder that contains ``data_dir``

    Raises:
        ValueError: if a name in ``classes`` is not listed in ``CLASSES``
            (raised by ``list.index``).
    """

    CLASSES = ['solar panel']

    def __init__(
            self,
            data_dir,
            masks_dir,
            classes=None,
            augmentation=None,
            preprocessing=None,
            images_dir=None,
    ):
        # Load the split description.
        with open(data_dir, encoding='utf-8') as json_file:
            data_file = json.load(json_file)

        # Recover image ids and file names in the order they are listed.
        image_entries = data_file['images']
        self.ids = [entry['id'] for entry in image_entries]
        file_names = [entry['file_name'] for entry in image_entries]

        # NOTE(review): assumes file names are relative to the folder holding
        # the JSON file unless images_dir is given -- confirm against the
        # dataset layout. Absolute file names pass through os.path.join as-is.
        if images_dir is None:
            images_dir = os.path.dirname(data_dir)

        # Paths are built from file names, not ids: ids may be integers,
        # which os.path.join rejects.
        # NOTE(review): assumes each mask shares its image's file name.
        self.images_fps = [os.path.join(images_dir, name) for name in file_names]
        self.masks_fps = [os.path.join(masks_dir, name) for name in file_names]

        # convert str names to class values on masks
        if classes is None:
            classes = self.CLASSES
        self.class_values = [
            self.CLASSES.index(class_name.lower()) for class_name in classes
        ]

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``.

        The image is read with OpenCV and converted from BGR to RGB. The mask
        is read as a single-channel image and turned into one binary channel
        per requested class, stacked on the last axis as floats. Augmentation
        and then preprocessing are applied when configured.

        Raises:
            FileNotFoundError: if the image or the mask cannot be read.
        """
        # read data (cv2.imread signals failure by returning None)
        image = cv2.imread(self.images_fps[i])
        if image is None:
            raise FileNotFoundError(
                'Could not read image: {}'.format(self.images_fps[i])
            )
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[i], 0)
        if mask is None:
            raise FileNotFoundError(
                'Could not read mask: {}'.format(self.masks_fps[i])
            )

        # extract certain classes from mask (e.g. solar panel)
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        return image, mask

    def __len__(self):
        """Return the number of images listed in the JSON file."""
        return len(self.ids)