# Segmentation tasks

This notebook runs semantic segmentation on four sets of images:
- raw images
- noisy augmented images
- denoised augmented images
- denoised raw images

It runs several segmentation models on each of these sets and computes evaluation metrics on their predictions.
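The overall plan (every model on every image set, scored with every metric) can be sketched as a plain PyTorch loop. This is a minimal sketch, not the notebook's code: `evaluate`, the set names and the pixel-accuracy metric are hypothetical, and a 1×1 convolution plus random tensors stand in for a real model and real data.

```python
import torch
import torch.nn as nn

def evaluate(models, image_sets, metrics):
    """Score every model on every image set with every metric (hypothetical helper).

    models:     name -> nn.Module mapping (N, 3, H, W) images to (N, C, H, W) logits
    image_sets: name -> list of (image, target) pairs, target of shape (N, H, W)
    metrics:    name -> callable(logits, target) returning a scalar tensor
    Returns {(model_name, set_name, metric_name): mean score over the pairs}.
    """
    results = {}
    with torch.no_grad():
        for model_name, model in models.items():
            model.eval()
            for set_name, pairs in image_sets.items():
                totals = {metric_name: 0.0 for metric_name in metrics}
                for image, target in pairs:
                    logits = model(image)  # run the model once, reuse the logits for all metrics
                    for metric_name, metric in metrics.items():
                        totals[metric_name] += float(metric(logits, target))
                for metric_name, total in totals.items():
                    results[(model_name, set_name, metric_name)] = total / len(pairs)
    return results

# Toy usage: a 1x1 convolution stands in for a 19-class segmentation model,
# and random tensors stand in for the four image sets.
torch.manual_seed(0)
sets = ["raw", "noisy_augmented", "denoised_augmented", "denoised_raw"]
image_sets = {
    name: [(torch.randn(1, 3, 8, 16), torch.randint(0, 19, (1, 8, 16)))]
    for name in sets
}
models = {"unet_stand_in": nn.Conv2d(3, 19, kernel_size=1)}
metrics = {"pixel_accuracy": lambda logits, t: (logits.argmax(dim=1) == t).float().mean()}
results = evaluate(models, image_sets, metrics)
for key, score in results.items():
    print(key, round(score, 3))
```

Keying the results by `(model, set, metric)` makes it easy to compare, for example, raw against denoised inputs for the same model afterwards.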

In [23]:
%load_ext autoreload
%autoreload 2

from models_loading import *
from data_loading import * 
from tasks_loading import *
from metrics_loading import *

from utils import *

# Used for testing purposes
import torch



## Models

### Model loading

In [2]:
load_model("", list_available=True)

# sinet = load_model("SINet").eval()
unet = load_model("UNet").eval()

Available models aliases:
  SINet
  UNet
  HRNet


In [3]:
# Fake data: a random input at Cityscapes resolution (batch, channels, height, width)
x = torch.randn(1, 3, 1024, 2048)

In [4]:
# y = sinet(x)
# print(y.shape)
# y = unet(x)
# print(y.shape)
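A cheaper variant of the commented-out shape checks runs a random batch at a reduced resolution and asserts the shape of the logits. This is a minimal sketch, assuming the model maps `(N, 3, H, W)` images to `(N, C, H, W)` logits; `check_segmentation_output` is a hypothetical helper and the stand-in model is a 1×1 convolution.

```python
import torch
import torch.nn as nn

def check_segmentation_output(model, num_classes=19, height=64, width=128):
    """Run one random image through `model` and verify the logits shape (hypothetical helper)."""
    x = torch.randn(1, 3, height, width)
    with torch.no_grad():
        y = model.eval()(x)  # eval() returns the module itself
    expected = (1, num_classes, height, width)
    assert tuple(y.shape) == expected, f"got {tuple(y.shape)}, expected {expected}"
    return tuple(y.shape)

stand_in = nn.Conv2d(3, 19, kernel_size=1)
print(check_segmentation_output(stand_in))  # (1, 19, 64, 128)
```

The default 64×128 keeps the 1:2 aspect ratio of the 1024×2048 input. Real networks may require height and width to be multiples of their total stride, so the reduced size should be chosen accordingly.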

### Data loading

In [5]:
# Set the locations of the data directories used by the evaluation task

cityscapes_dir = 'HRNet/data/cityscapes'
denoized_dir = '../candidate_model_predictions'
# cityscapes_distorted_dir = 'HRNet/data/cityscapes'
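If `cityscapes_dir` follows the standard Cityscapes layout (images under `leftImg8bit/<split>/<city>/`, fine annotations under `gtFine/<split>/<city>/`), image/label pairs can be collected with `pathlib`. This is a sketch under that assumption; `pair_cityscapes` is hypothetical and not part of `data_loading`.

```python
import tempfile
from pathlib import Path

def pair_cityscapes(root, split="val", label_suffix="gtFine_labelIds.png"):
    """Return sorted (image_path, label_path) pairs for one split (hypothetical helper).

    Assumes the standard layout:
      root/leftImg8bit/<split>/<city>/<name>_leftImg8bit.png
      root/gtFine/<split>/<city>/<name>_gtFine_labelIds.png
    """
    root = Path(root)
    pairs = []
    for image in sorted((root / "leftImg8bit" / split).glob("*/*_leftImg8bit.png")):
        stem = image.name[: -len("leftImg8bit.png")]  # keeps "<city>_<seq>_<frame>_"
        label = root / "gtFine" / split / image.parent.name / (stem + label_suffix)
        if label.exists():  # skip images without a matching annotation
            pairs.append((image, label))
    return pairs

# Toy usage on a temporary directory that mimics the layout
with tempfile.TemporaryDirectory() as tmp:
    for sub in ("leftImg8bit/val/aachen", "gtFine/val/aachen"):
        (Path(tmp) / sub).mkdir(parents=True)
    (Path(tmp) / "leftImg8bit/val/aachen/aachen_000000_000019_leftImg8bit.png").touch()
    (Path(tmp) / "gtFine/val/aachen/aachen_000000_000019_gtFine_labelIds.png").touch()
    demo_pairs = [(i.name, l.name) for i, l in pair_cityscapes(tmp)]
print(demo_pairs)
```

With the real data, `pair_cityscapes(cityscapes_dir, split="val")` would list the validation pairs.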

In [6]:
load_data("", list_available=True)

# cs = load_data("cityscapes", location=cityscapes_dir)

Available dataloaders aliases:
  cityscapes
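The `*_gtFine_labelIds.png` annotations use the full set of Cityscapes label IDs, while the models here predict 19 classes. Evaluating against ground truth therefore needs the standard labelId → trainId mapping, with every other ID set to an ignore value. This is a sketch using a NumPy lookup table; the helper name and the ignore value 255 are assumptions, and the `cityscapes` dataloader may already perform this step.

```python
import numpy as np

# Standard Cityscapes labelId -> trainId mapping for the 19 evaluated classes
# (road, sidewalk, building, ..., bicycle). Every other labelId becomes the ignore value.
LABEL_TO_TRAIN = {
    7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
    23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
}

def label_ids_to_train_ids(label_ids, ignore_index=255):
    """Map an integer array of labelIds (values 0..255) to trainIds (hypothetical helper)."""
    lut = np.full(256, ignore_index, dtype=np.uint8)  # default: ignored
    for label_id, train_id in LABEL_TO_TRAIN.items():
        lut[label_id] = train_id
    return lut[np.asarray(label_ids)]  # vectorised lookup, same shape as the input

example = np.array([[7, 8, 0], [33, 26, 4]], dtype=np.uint8)
print(label_ids_to_train_ids(example).tolist())  # [[0, 1, 255], [18, 13, 255]]
```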


### Evaluation task

In [7]:
load_tasks("", list_available=True)

semantic_task = load_tasks("semantic")
semantic_unet = semantic_task(unet)

Available tasks aliases:
  semantic


In [8]:
# y = semantic_unet(x)
y, colored = semantic_unet(x, get_colored=True)
y, classified = semantic_unet(x, get_classification=True)
y.shape, classified.shape, colored.shape

Using 19 classes.
Using 19 classes.


(torch.Size([1, 19, 1024, 2048]), (1, 1024, 2048), (1, 3, 1024, 2048))
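`semantic_unet` returns raw logits `y` of shape `(1, 19, H, W)` and, on request, a class map of shape `(1, H, W)` and a colored image of shape `(1, 3, H, W)`. One plausible way to derive the last two is an argmax over the class dimension followed by a palette lookup. This sketch assumes the standard Cityscapes trainId palette; the task module may use different colors, and `classify` and `colorize` are hypothetical names.

```python
import numpy as np
import torch

# Standard Cityscapes RGB colors, indexed by trainId 0..18 (road ... bicycle)
CITYSCAPES_PALETTE = np.array([
    (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153),
    (153, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152),
    (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
    (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32),
], dtype=np.uint8)

def classify(logits):
    """(N, C, H, W) logits -> (N, H, W) int64 NumPy class map via argmax over C."""
    return logits.argmax(dim=1).cpu().numpy()

def colorize(class_map):
    """(N, H, W) class map -> (N, 3, H, W) uint8 RGB, channel-first like `colored`."""
    rgb = CITYSCAPES_PALETTE[class_map]  # fancy indexing gives (N, H, W, 3)
    return rgb.transpose(0, 3, 1, 2)

logits = torch.randn(1, 19, 4, 8)
cls = classify(logits)
col = colorize(cls)
print(cls.shape, col.shape)  # (1, 4, 8) (1, 3, 4, 8)
```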

### Metrics

In [13]:
load_metrics("", list_available=True)

iou = load_metrics("iou")
dice = load_metrics("dice")

Available tasks aliases:
  iou
  dice
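For each class $c$, with $P_c$ the predicted pixels and $G_c$ the ground-truth pixels, IoU is $|P_c \cap G_c| / |P_c \cup G_c|$ and Dice is $2|P_c \cap G_c| / (|P_c| + |G_c|)$. The sketch below takes the argmax of the logits as the prediction, skips pixels labeled 255, and averages each score over the classes that appear in the prediction or the target. These are assumed conventions, not necessarily those of `metrics_loading`; `mean_iou` and `mean_dice` are hypothetical names.

```python
import torch

def _per_class_counts(logits, target, num_classes, ignore_index=255):
    """Per-class intersection, predicted-pixel and ground-truth-pixel counts."""
    pred = logits.argmax(dim=1)                     # (N, H, W) predicted class per pixel
    valid = target != ignore_index                  # drop ignored pixels
    pred, target = pred[valid], target[valid]       # flattened to 1-D
    classes = torch.arange(num_classes, device=pred.device).view(-1, 1)
    p = pred.view(1, -1) == classes                 # (C, P) one boolean mask per class
    g = target.view(1, -1) == classes
    inter = (p & g).sum(dim=1).float()
    return inter, p.sum(dim=1).float(), g.sum(dim=1).float()

def mean_iou(logits, target, num_classes=19, ignore_index=255):
    inter, n_pred, n_gt = _per_class_counts(logits, target, num_classes, ignore_index)
    union = n_pred + n_gt - inter
    present = union > 0                             # avoid 0/0 for absent classes
    return (inter[present] / union[present]).mean()

def mean_dice(logits, target, num_classes=19, ignore_index=255):
    inter, n_pred, n_gt = _per_class_counts(logits, target, num_classes, ignore_index)
    denom = n_pred + n_gt
    present = denom > 0
    return (2 * inter[present] / denom[present]).mean()

logits = torch.randn(2, 19, 16, 32)
target = logits.argmax(dim=1)                       # target equal to the prediction
print(mean_iou(logits, target), mean_dice(logits, target))
print(mean_iou(logits, (target + 1) % 19))          # every label shifted: no overlap
```

Scoring the logits against their own argmax yields 1, like the `tensor(1.)` outputs of the metrics check; shifting every label by one class yields 0.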


In [14]:
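# Sanity check: `classified` is derived from the same prediction `y`,
# so both scores are expected to be 1
print(dice(y, torch.from_numpy(classified)))
print(iou(y, torch.from_numpy(classified)))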

tensor(1.)
tensor(1.)


In [52]:
store_results(colored, 'output/')

(1, 1024, 2048, 3)
(1, 1024, 2048, 3)
output/result_0.png
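The printed shapes `(1, 1024, 2048, 3)` suggest that `store_results` moves the channel axis last before writing PNG files. A comparable sketch with NumPy and Pillow follows; `save_colored` and its `prefix` parameter are hypothetical, and this is not the notebook's `store_results`.

```python
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image

def save_colored(colored, out_dir, prefix="result"):
    """Write each (3, H, W) image of an (N, 3, H, W) batch to <out_dir>/<prefix>_<i>.png."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    images = np.asarray(colored).transpose(0, 2, 3, 1)  # (N, H, W, 3): channel-last for PIL
    paths = []
    for i, image in enumerate(images):
        path = out_dir / f"{prefix}_{i}.png"
        # ascontiguousarray gives PIL a C-ordered uint8 buffer after the transpose
        Image.fromarray(np.ascontiguousarray(image, dtype=np.uint8)).save(path)
        paths.append(path)
    return paths

with tempfile.TemporaryDirectory() as tmp:
    demo = np.zeros((2, 3, 4, 5), dtype=np.uint8)
    demo[1, 0] = 255                                   # second image: pure red
    saved = save_colored(demo, tmp)
    names = [p.name for p in saved]
    with Image.open(saved[1]) as im:
        size, red = im.size, im.getpixel((0, 0))       # size is (width, height)
print(names, size, red)  # ['result_0.png', 'result_1.png'] (5, 4) (255, 0, 0)
```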
