# Calculate mean Intersection-Over-Union (mIOU) metric

A ready-to-use script that computes the mean Intersection-over-Union (IoU) between pairs of ground-truth and predicted classes


**Input**:
- Existing Project (e.g. "london_roads")
- At least one pair of classes (e.g. ("car_gt", "car_lb"))

**Output**:
- per-image IoU statistics (count, mean and standard deviation) for each class pair


## Imports

In [1]:
import supervisely_lib as sly
import os
import collections
from prettytable import PrettyTable
from tqdm import tqdm
from matplotlib import pyplot as plt
import numpy as np

## Configuration

Edit the following settings for your own case

In [2]:
# Change this field to the name of your team, where the target workspace exists.
team_name = "The AI Company"  # Automatically inserted
# Change this field to the name of your workspace, where the target project exists.
workspace_name = "Journal"  # Automatically inserted
# Change this field to the name of your target project.
project_name = "DL AC test"  # Automatically inserted

# (ground-truth class, predicted class) pairs between which IoU will be calculated.
# Both classes of every pair must be defined in the target project; add or remove
# pairs to match your own classes.
classes_mapping = dict(
    [
        ("Fabric", "Fabric_dl"),
        ("Gripper", "Gripper_dl"),
        ("Wrinkle", "Wrinkle_dl"),
    ]
)

# Connection details. On a Supervisely web instance they are read from the
# SERVER_ADDRESS and API_TOKEN environment variables automatically.
#
# When running locally, replace the two lookups below with your own values.
# Your access token is at
# "Your name on the top right" -> "Account settings" -> "API token".
address = os.environ['SERVER_ADDRESS']
token = os.environ['API_TOKEN']

## Script setup

Initialize the Supervisely API to remotely manage your projects

In [3]:
# Initialize API object
# Initialize API object: every remote call in the following cells (team, workspace,
# project, dataset, image and annotation lookups) goes through this client, which
# connects to `address` authenticated with `token`.
api = sly.Api(address, token)

## Verify input values

Check that the context (team / workspace / project) exists

In [4]:
def _require(info, message):
    """Return *info* unchanged, or raise RuntimeError(message) when it is None."""
    if info is None:
        raise RuntimeError(message)
    return info


# Resolve team -> workspace -> project, failing fast on the first missing level.
team = _require(
    api.team.get_info_by_name(team_name),
    "Team {!r} not found".format(team_name),
)
workspace = _require(
    api.workspace.get_info_by_name(team.id, workspace_name),
    "Workspace {!r} not found".format(workspace_name),
)
project = _require(
    api.project.get_info_by_name(workspace.id, project_name),
    "Project {!r} not found".format(project_name),
)

for template, info in (
    ("Team: id={}, name={}", team),
    ("Workspace: id={}, name={}", workspace),
    ("Project: id={}, name={}", project),
):
    print(template.format(info.id, info.name))

Team: id=18450, name=The AI Company
Workspace: id=26283, name=Journal
Project: id=59049, name=DL AC test


## Get Project Meta of Source Project

Project Meta contains information about classes and tags

In [5]:
# Download the project meta (the classes and tags defined in the project).
meta_json = api.project.get_meta(project.id)
meta = sly.ProjectMeta.from_json(meta_json)

# Every ground-truth and predicted class named in classes_mapping must be defined in the project.
project_classes_names = [*classes_mapping.keys(), *classes_mapping.values()]
known_class_names = set(meta.obj_classes.keys())
for class_name in project_classes_names:
    if class_name in known_class_names:
        continue
    raise RuntimeError("Class {!r} not found in source project {!r}".format(class_name, project.name))

## Iterate over all images and calculate the metric for each annotation pair

In [7]:
def safe_ratio(num, denom):
    """Return ``num / denom``, or the sentinel ``-1`` when ``denom`` is zero."""
    if denom == 0:
        return -1
    return num / denom

def get_intersection(mask_1, mask_2):
    """Return the number of pixels set in both masks (the sum of their element-wise AND)."""
    overlap = np.bitwise_and(mask_1, mask_2)
    return overlap.sum()


def get_union(mask_1, mask_2):
    """Return the number of pixels set in either mask (the sum of their element-wise OR)."""
    combined = np.bitwise_or(mask_1, mask_2)
    return combined.sum()


def get_iou(mask_1, mask_2):
    """Return intersection / union of two masks, or -1 when their union is empty."""
    intersection = get_intersection(mask_1, mask_2)
    union = get_union(mask_1, mask_2)
    return safe_ratio(intersection, union)

def _render_labels_for_class_name(labels, class_name, canvas):
    for label in labels:
        if label.obj_class.name == class_name:
            label.geometry.draw(canvas, True)

# (gt class, predicted class) -> list of per-image IoU values, filled by add_pair.
ious = dict()
def add_pair(ann_gt, ann_pred):
    """Record one image's IoU for every class pair in ``classes_mapping``.

    For each (gt class, predicted class) pair, the labels of the gt class from
    ``ann_gt`` and the labels of the predicted class from ``ann_pred`` are drawn
    into two boolean masks of the image size. Their IoU is appended to the
    module-level ``ious[(cls_gt, cls_pred)]`` list.

    An entry (possibly an empty list) is created for every pair. When both masks
    are empty, the union is zero and get_iou returns its -1 sentinel. That image
    is not recorded for that pair, because IoU is undefined there.

    :param ann_gt: annotation holding the ground-truth labels (its ``img_size``
        sets the mask shape).
    :param ann_pred: annotation holding the predicted labels; may be the same
        object as ``ann_gt``.
    """
    img_size = ann_gt.img_size
    for cls_gt, cls_pred in classes_mapping.items():
        mask_gt = np.zeros(img_size, dtype=bool)
        mask_pred = np.zeros(img_size, dtype=bool)
        _render_labels_for_class_name(ann_gt.labels, cls_gt, mask_gt)
        _render_labels_for_class_name(ann_pred.labels, cls_pred, mask_pred)
        iou = get_iou(mask_gt, mask_pred)
        # A single lookup creates the pair's list on first sight and returns it.
        pair_ious = ious.setdefault((cls_gt, cls_pred), [])
        if iou != -1:
            pair_ious.append(iou)
        
# The ground-truth and predicted classes are all defined in this one project, so
# each image's single annotation carries both label sets and is passed as both
# the ground truth and the prediction.
for dataset in api.dataset.get_list(project.id):
    dataset_images = api.image.get_list(dataset.id)
    for image_batch in sly.batched(dataset_images):
        batch_ids = [info.id for info in image_batch]
        downloaded = api.annotation.download_batch(dataset.id, batch_ids)
        for ann_info in downloaded:
            annotation = sly.Annotation.from_json(ann_info.annotation, meta)
            add_pair(annotation, annotation)

print(ious)
        

{('Fabric', 'Fabric_dl'): [0.7522791650883671, 0.8952347965013777, 0.8899784932430264, 0.7758187619183939, 0.8267382717683922, 0.808839704732108, 0.9376385600313693, 0.8814706105614737, 0.8884434250948617, 0.8083127782030564, 0.7301220661392475, 0.9666791283913694, 0.8487230190084586, 0.9119832994086292, 0.8419801501716737, 0.8180559940672428, 0.8644226360853996, 0.8920362115711594, 0.8429583086960244, 0.8438082022772762, 0.9195160885993249, 0.8841620461925019, 0.8725760321168301, 0.8654411209340626, 0.929344268723437, 0.8864008887447604, 0.8295237316521651, 0.893406293821594, 0.8072006431837596, 0.9034449786017978, 0.7304002880458875], ('Gripper', 'Gripper_dl'): [0.9299777546469945, 0.9580294705898469, 0.9355698927511685, 0.9568031580952535, 0.9563680616882417, 0.9463951220289374, 0.9209468889954032, 0.8179614905166548, 0.8787132615142075, 0.9397630467637611, 0.9156924301294964, 0.9437814529864982, 0.9203895765210937, 0.9503679764562528, 0.9408228272097435, 0.9287592396731327, 0.96425

## Print results manually

In [8]:
# Two-column results table: one row per (gt, predicted) class pair.
table = PrettyTable(field_names=["classes pair", "metrics values"])

def build_values_text(values):
    """Format the count, mean and standard deviation of a list of IoU values.

    A pair can end up with no values when neither of its classes appears in any
    image. In that case mean and std are reported as ``nan``, which is the same
    text NumPy would produce. This avoids the RuntimeWarnings that ``np.mean`` /
    ``np.std`` emit for empty input.

    :param values: sized sequence of IoU floats.
    :returns: ``"iou: count: <n> mean: <mean>, std: <std>"``.
    """
    if len(values) == 0:
        mean = std = float("nan")
    else:
        mean, std = np.mean(values), np.std(values)
    return "iou: count: {} mean: {}, std: {}".format(len(values), mean, std)
    
# One row per class pair, labelled "<gt> <-> <predicted>".
for (gt_name, pred_name), iou_values in ious.items():
    table.add_row(["{} <-> {}".format(gt_name, pred_name), build_values_text(iou_values)])

print(table.get_string())

+------------------------+--------------------------------------------------------------------+
|      classes pair      |                           metrics values                           |
+------------------------+--------------------------------------------------------------------+
|  Fabric <-> Fabric_dl  | iou: count: 31 mean: 0.8563529020508075, std: 0.05741958630369401  |
| Gripper <-> Gripper_dl | iou: count: 31 mean: 0.9236704923370911, std: 0.03281281307350667  |
| Wrinkle <-> Wrinkle_dl | iou: count: 31 mean: 0.40374091956920627, std: 0.16848883174247614 |
+------------------------+--------------------------------------------------------------------+


# Done!