# Tarea 6: Segmentación semántica.


In [15]:
import os

In [16]:
# One-time environment setup. Each cell is guarded by an existence check so
# re-running the notebook does not clone or download again.
if not os.path.exists('pytorch_unet'):
  # Clone the U-Net reference implementation
  !git clone https://github.com/milesial/Pytorch-UNet.git
  # Install the requirements listed by the repo
  !pip install -r Pytorch-UNet/requirements.txt
  # Rename the checkout so its functions can be imported (the later import
  # cell uses the package name `pytorch_unet`)
  !mv Pytorch-UNet pytorch_unet

In [18]:
# Download and extract the data_semantics archive; skipped when the zip file
# is already present in the working directory.
if not os.path.exists('data_semantics.zip'):
  !wget https://s3.eu-central-1.amazonaws.com/avg-kitti/data_semantics.zip
  !unzip data_semantics.zip

In [None]:
# Fetch the helper module that provides kitti_inverse_map_1channel, imported
# by a later cell.
if not os.path.exists('kitti_inverse_map_1channel.py'):
  !wget https://raw.githubusercontent.com/Diego-II/Procesamiento-Avanzado-de-Imagenes/master/Tarea6/kitti_inverse_map_1channel.py

In [29]:
# Importamos las funciones del repositorio
from pytorch_unet.unet import UNet
from kitti_inverse_map_1channel import kitti_inverse_map_1channel

In [27]:
from os.path import splitext
from os import listdir
import numpy as np
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image


class BasicDataset(Dataset):
  """Dataset that pairs every image in ``imgs_dir`` with a mask in ``masks_dir``.

  Sample ids are the file names found in ``imgs_dir`` with their extension
  removed (dot-files are skipped). For an id ``x`` the image is looked up as
  ``imgs_dir + x + '.*'`` and the mask as
  ``masks_dir + x + mask_suffix + '.*'``. The directory strings are
  concatenated as-is, so both must end with a path separator.

  Args:
    imgs_dir: Directory containing the input images.
    masks_dir: Directory containing the masks.
    read_mask: Stored on the instance as ``self.read_mask``.
      NOTE(review): nothing in this class consults it; confirm the intended
      meaning with the caller before relying on it.
    scale: Resize factor in ``(0, 1]`` applied to both image and mask.
    mask_suffix: Text inserted between the id and the extension of mask files.
  """

  def __init__(self, imgs_dir, masks_dir, read_mask, scale=1, mask_suffix=''):
    self.imgs_dir = imgs_dir
    self.masks_dir = masks_dir
    self.read_mask = read_mask
    self.scale = scale
    self.mask_suffix = mask_suffix
    assert 0 < scale <= 1, 'Scale must be between 0 and 1'

    self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
                if not file.startswith('.')]
    logging.info(f'Creating dataset with {len(self.ids)} examples')

  def __len__(self):
    """Return the number of sample ids discovered in ``imgs_dir``."""
    return len(self.ids)

  @classmethod
  def preprocess(cls, pil_img, scale):
    """Resize a PIL image by ``scale`` and return it as a CHW numpy array.

    A 2-D (single channel) image gets a trailing channel axis before the
    HWC -> CHW transpose. If the largest value exceeds 1 the array is divided
    by 255, so this must not be used on label maps whose values are class ids.
    """
    w, h = pil_img.size
    newW, newH = int(scale * w), int(scale * h)
    assert newW > 0 and newH > 0, 'Scale is too small'
    pil_img = pil_img.resize((newW, newH))

    img_nd = np.array(pil_img)

    if len(img_nd.shape) == 2:
      img_nd = np.expand_dims(img_nd, axis=2)

    # HWC to CHW
    img_trans = img_nd.transpose((2, 0, 1))
    if img_trans.max() > 1:
      img_trans = img_trans / 255

    return img_trans

  def __getitem__(self, i):
    """Load sample ``i`` and return ``{'image': ..., 'mask': ...}`` float tensors.

    The image goes through ``preprocess``. The mask is opened with PIL,
    converted to an int32 array and passed through
    ``kitti_inverse_map_1channel``; the layout of that function's result is
    defined outside this file (presumably an array of class ids - TODO
    confirm).

    Raises:
      AssertionError: if the id does not resolve to exactly one image and one
        mask file, or if image and mask differ in size.
    """
    idx = self.ids[i]
    mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
    img_file = glob(self.imgs_dir + idx + '.*')

    assert len(mask_file) == 1, \
      f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
    assert len(img_file) == 1, \
      f'Either no image or multiple images found for the ID {idx}: {img_file}'
    mask = Image.open(mask_file[0])
    img = Image.open(img_file[0])

    # Compare while both are still PIL images: PIL's ``size`` is a
    # (width, height) tuple, whereas ndarray.size is an element count.
    assert img.size == mask.size, \
      f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'

    mask_w, mask_h = mask.size
    # torch.from_numpy needs an ndarray, so the PIL image must be converted.
    img = self.preprocess(img, self.scale)

    if self.scale != 1:
      # Keep the mask aligned with the resized image. Nearest-neighbour
      # resampling keeps every pixel an original label value.
      mask = mask.resize((int(self.scale * mask_w), int(self.scale * mask_h)),
                         resample=Image.NEAREST)
    mask = np.asarray(kitti_inverse_map_1channel(np.array(mask, dtype=np.int32)))

    return {
      'image': torch.from_numpy(img).type(torch.FloatTensor),
      'mask': torch.from_numpy(mask).type(torch.FloatTensor)
    }


class CarvanaDataset(BasicDataset):
  """BasicDataset variant whose mask files are named ``<id>_mask.<ext>``.

  Args:
    imgs_dir: Directory containing the input images.
    masks_dir: Directory containing the masks.
    scale: Resize factor in ``(0, 1]``, forwarded to BasicDataset.
    read_mask: Forwarded unchanged to BasicDataset's ``read_mask`` parameter;
      defaults to ``None``.
  """

  def __init__(self, imgs_dir, masks_dir, scale=1, read_mask=None):
    # Pass everything by keyword: BasicDataset's third positional parameter is
    # ``read_mask``, so a positional ``scale`` would land there and the actual
    # scale would silently stay at its default of 1.
    super().__init__(imgs_dir, masks_dir, read_mask=read_mask, scale=scale,
                     mask_suffix='_mask')