
-----
**General Deep Learning and CV Resources**

* [Pytorch docs](https://pytorch.org/docs/stable/index.html)
* [Torchvision docs](https://pytorch.org/vision/stable/index.html)
* [Intro to Pytorch in Google Colab](https://medium.com/dair-ai/pytorch-1-2-quickstart-with-google-colab-6690a30c38d)
* [Intro to Pytorch YouTube Series](https://www.youtube.com/playlist?list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz)
* [Pytorch Common Mistakes](https://www.youtube.com/watch?v=O2wJ3tkc-TU)
* [Intro to Pytorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction.html)
* [Debugging with Pytorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/common/debugging.html)
* [Deep Learning for Computer Vision](https://www.youtube.com/watch?v=dJYGatp4SvA&list=PLzfU85Pe_4tM17CM5NQ1yMdCdXeD9vBNH&index=6)
* [Definition of Classification vs Detection vs Segmentation](https://www.clarifai.com/blog/classification-vs-detection-vs-segmentation-models-the-differences-between-them-and-how-each-impact-your-results)

-----
**UNET Resources**
* [Original UNET Paper](https://arxiv.org/abs/1505.04597)
* [UNET FROM SCRATCH](https://www.youtube.com/watch?v=IHq1t7NxS8k)
* [UNET Implementation Source](https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/image_segmentation/semantic_segmentation_unet/model.py)
* [Why does the number of features double for each layer of UNET?](https://www.quora.com/Why-does-the-UNET-double-the-number-of-feature-channels-after-each-Maxpooling-layer)
* [Why do we need skip connections in the UNET architecture?](https://www.analyticsvidhya.com/blog/2021/08/all-you-need-to-know-about-skip-connections/#:~:text=Skip%20Connections%20(or%20Shortcut%20Connections,different%20problems%20in%20different%20architectures.)




## 1.&nbsp;Imports

In [1]:
# Deep learning imports.
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset

import torchvision
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks, make_grid
from torchvision.ops import masks_to_boxes
import torchvision.transforms.functional as TF
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

import torchmetrics

# Standard imports.
from typing import List, Union
import gc
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Necessary for creating our images.
from skimage.draw import line_aa

# Interactive widgets for data viz.
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

  "class": algorithms.Blowfish,


## 2.&nbsp;Pytorch Dataset

**base classes (identical in each tutorial)**

In [2]:
class Draw:
    """
    Random shape generator for square 2d canvases.

    Every method draws from the shared numpy Generator, so a seeded rng
    yields a reproducible sequence of shapes. Methods return something that
    can index a 2d array of shape (img_size, img_size): either a
    (rows, cols) tuple of index arrays or a boolean mask.

    Args:
    img_size (int): side length of the square canvas shapes are drawn onto.
    rng (Generator): numpy random generator, for deterministic behaviour,
        e.g. rng = np.random.default_rng(12345)

    """

    def __init__(self, img_size, rng):
        self.img_size = img_size
        self.rng = rng

    def rectangle(self):
        """
        Return (rows, cols) index arrays for a randomly sized, randomly
        placed filled rectangle.
        """
        # Both side lengths are drawn from [3/20, 7/20) of the image size.
        side_lo = self.img_size * 3 / 20
        side_hi = self.img_size * 7 / 20
        height = self.rng.integers(side_lo, side_hi)
        width = self.rng.integers(side_lo, side_hi)

        # Row-major offsets of every cell in a height x width block.
        rows, cols = np.nonzero(np.ones((height, width)))

        # Offsets drawn from [side, img_size - side) keep the block inside
        # the image.
        top = self.rng.integers(height, self.img_size - height)
        left = self.rng.integers(width, self.img_size - width)

        return rows + top, cols + left

    def line(self):
        """
        Return (rows, cols) index arrays for an anti-aliased line with random
        endpoints and a random slope sign.
        """
        # Start point in the first third of the image, end point in the
        # second half.
        r0, c0 = self.rng.integers(0, self.img_size / 3, size=2)
        r1, c1 = self.rng.integers(self.img_size / 2, self.img_size, size=2)

        # Coin flip: keep the column order or swap it, which flips the
        # sign of the slope.
        keep_order = self.rng.integers(low=0, high=2)
        start_col, end_col = (c0, c1) if keep_order else (c1, c0)

        rows, cols, _ = line_aa(r0, start_col, r1, end_col)
        return rows, cols

    def donut(self):
        """
        Return a boolean (img_size, img_size) mask of an elliptical ring.
        """
        rows, cols = np.mgrid[: self.img_size, : self.img_size]
        row_center = self.rng.integers(0, self.img_size)
        col_center = self.rng.integers(0, self.img_size)

        # Thickness of the ring.
        width = self.rng.uniform(self.img_size / 3, self.img_size)

        # Independent axis weights make the ring elliptical.
        row_weight = self.rng.uniform(0.1, 5)
        col_weight = self.rng.uniform(0.1, 5)

        level = (
            row_weight * (rows - row_center) ** 2
            + col_weight * (cols - col_center) ** 2
        )

        inner = self.img_size - width
        outer = self.img_size + width
        return (level > inner) & (level < outer)


class CV_DS_Base(torch.utils.data.Dataset):
    """
    Base class for a set of PyTorch computer vision datasets. This class
    contains all of the attributes and methods common to all datasets
    in this package.
    Alone this base class has no functionality. The utility of these datasets
    is that they enable the user to test cv models with very small and
    simple images with tunable complexity. It also requires no downloading
    of images and one can scale the size of the datasets easily.

    Args:
        ds_size (int): number of images in dataset.
        img_size (int): will build images of shape (3, img_size, img_size).
        shapes_per_image (Tuple[int, int]): will produce images containing
            minimum number of shapes Tuple[0] and maximum number of shapes
            Tuple[1]. For example shapes_per_image = (2,2) would create a
            dataset where each image contains exactly two shapes.
        class_probs (Tuple[float, float, float]): relative probability of
            each shape occurring in an image. Need not sum to 1. For example
            class_probs = (1,1,0) will create a dataset with 50% class 1
            shapes, 50% class 2 shapes, 0% class 3 shapes.
        rand_seed (int): used to instantiate a numpy random number generator.
        class_map (Dict[Dict]): the class map must contain keys (0,1,2,3)
            and contain names "background", "rectangle", "line", and "donut".
            "gs_range" specifies the upper and lower bound of the
            grayscale values (0, 255) used to color the shapes.
            "target_color" can be used by visualization tools to assign
            a color to masks and boxes. Note that class 0 is reserved for
            background in most instance seg models, so one can rearrange
            the class assignments of different shapes but 0 must correspond
            to "background". The utility of this Dict is to enable the user
            to change target colors, class assignments, and shape
            intensities. A valid example:
            class_map={
            0: {"name": "background","gs_range": (200, 255),"target_color": (255, 255, 255),},
            1: {"name": "rectangle", "gs_range": (0, 100), "target_color": (255, 0, 0)},
            2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
            3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)}}.

    Raises:
        ValueError: if img_size <= 20, if the class_map keys are not exactly
            {0, 1, 2, 3}, if shapes_per_image is not a (min, max) pair with
            0 <= min <= max, or if class_probs contains negative entries or
            does not sum to a positive value.
    """

    def __init__(
        self,
        ds_size,
        img_size,
        shapes_per_image,
        class_probs,
        rand_seed,
        class_map,
    ):
        # All validation happens before any random draws, so bad arguments
        # fail fast with a clear message instead of deep inside numpy.
        if img_size <= 20:
            raise ValueError(
                "Different shapes are hard to distinguish for images of shape (3, 20, 20) or smaller."
            )

        if set(class_map) != {0, 1, 2, 3}:
            raise ValueError("Dict class_map must contain keys 0,1,2,3.")

        min_shapes, max_shapes = shapes_per_image[0], shapes_per_image[1]
        if min_shapes < 0 or min_shapes > max_shapes:
            raise ValueError(
                "shapes_per_image must satisfy 0 <= min <= max, "
                f"got {tuple(shapes_per_image)}."
            )

        probs = np.asarray(class_probs, dtype=float)
        # `not (sum > 0)` also rejects a NaN sum.
        if np.any(probs < 0) or not probs.sum() > 0:
            raise ValueError(
                "class_probs must be non-negative and sum to a positive value, "
                f"got {tuple(class_probs)}."
            )

        self.ds_size = ds_size
        self.img_size = img_size
        self.rand_seed = rand_seed
        self.shapes_per_image = shapes_per_image
        # Normalize the relative weights into a probability vector.
        self.class_probs = probs / probs.sum()
        self.class_map = class_map

        self.rng = np.random.default_rng(self.rand_seed)

        # Draw shares this dataset's rng, so shape generation is part of the
        # same deterministic random stream.
        self.draw = Draw(self.img_size, self.rng)

        self.class_ids = np.array([1, 2, 3])
        # high is exclusive in Generator.integers, hence the +1.
        self.num_shapes_per_img = self.rng.integers(
            low=min_shapes,
            high=max_shapes + 1,
            size=self.ds_size,
        )

        # For each image, which class ids (shapes) will be drawn into it.
        self.chosen_ids_per_img = [
            self.rng.choice(a=self.class_ids, size=num_shapes, p=self.class_probs)
            for num_shapes in self.num_shapes_per_img
        ]

        # Subclasses are expected to populate these.
        self.imgs, self.targets = [], []

    def __getitem__(self, idx):
        return self.imgs[idx], self.targets[idx]

    def __len__(self):
        return len(self.imgs)

    def draw_shape(self, class_id):
        """
        Draws the shape whose class_map "name" matches class_id.

        Returns whatever the corresponding Draw method returns: index arrays
        for rectangles and lines, a boolean mask for donuts.

        Raises:
            ValueError: if the name is not rectangle, line, or donut.
        """
        drawers = {
            "rectangle": self.draw.rectangle,
            "line": self.draw.line,
            "donut": self.draw.donut,
        }
        name = self.class_map[class_id]["name"]
        if name not in drawers:
            raise ValueError(
                "You must include rectangle, donut, and line in your class_map."
            )
        return drawers[name]()

**image segmentation dataset**

In [3]:
class ImageSegmentation_DS(CV_DS_Base):
    """
    Self contained PyTorch Dataset for testing image segmentation models.

    Each item is a pair (img, target): img is a uint8 tensor of shape
    (3, img_size, img_size) and target is an int64 tensor of shape
    (img_size, img_size) holding one class id per pixel.

    Args:
        ds_size (int): number of images in dataset. May be 0, which yields
            an empty dataset.
        img_size (int): will build images of shape (3, img_size, img_size).
        shapes_per_image (Tuple[int, int]): will produce images containing
            minimum number of shapes Tuple[0] and maximum number of shapes
            Tuple[1]. For example shapes_per_image = (2,2) would create a
            dataset where each image contains exactly two shapes.
        class_probs (Tuple[float, float, float]): relative probability of
            each shape occurring in an image. Need not sum to 1. For example
            class_probs = (1,1,0) will create a dataset with 50% class 1
            shapes, 50% class 2 shapes, 0% class 3 shapes.
        rand_seed (int): used to instantiate a numpy random number generator.
        class_map (Dict[Dict], optional): the class map must contain keys
            (0,1,2,3) and contain names "background", "rectangle", "line",
            and "donut". "gs_range" specifies the lower (inclusive) and upper
            (exclusive) bound of the grayscale values (0, 255) used to color
            the shapes. "target_color" can be used by visualization tools to
            assign a color to masks and boxes. Note that class 0 is reserved
            for background in most instance seg models, so one can rearrange
            the class assignments of different shapes but 0 must correspond
            to "background". If None (default), a fresh copy of this map is
            used:
            class_map={
            0: {"name": "background","gs_range": (200, 255),"target_color": (255, 255, 255),},
            1: {"name": "rectangle", "gs_range": (0, 100), "target_color": (255, 0, 0)},
            2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
            3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)}}.
        object_count (bool): currently unused by this class.
    """

    def __init__(
        self,
        ds_size=100,
        img_size=256,
        shapes_per_image=(1, 3),
        class_probs=(1, 1, 1),
        rand_seed=12345,
        class_map=None,
        object_count=False,
    ):
        # Build a new dict per instance. A dict literal as the default would
        # be shared by every dataset (and stored as self.class_map), so an
        # edit to one dataset's map would leak into all later ones.
        if class_map is None:
            class_map = {
                0: {
                    "name": "background",
                    "gs_range": (200, 255),
                    "target_color": (255, 255, 255),
                },
                1: {"name": "rectangle", "gs_range": (0, 100), "target_color": (255, 0, 0)},
                2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
                3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)},
            }

        super().__init__(
            ds_size, img_size, shapes_per_image, class_probs, rand_seed, class_map
        )

        self.imgs, self.targets = self.build_imgs_and_targets()

    def build_imgs_and_targets(self):
        """
        Builds images and targets.

        Grayscale values are drawn with Generator.integers, so the upper
        bound of each "gs_range" is exclusive.

        Returns:
            imgs (torch.UInt8Tensor[ds_size, 3, img_size, img_size]): images
                containing different shapes. The images are gray-scale
                (each layer of the first (color) dimension is identical).
                This makes it easier to visualize targets and predictions.
            targets (torch.long[ds_size, img_size, img_size]):
                corresponding segmentation labels. Each pixel is assigned
                a class label, with class_id = 0 reserved for background.
            When ds_size is 0 both tensors have a leading dimension of 0.
        """
        side = self.img_size
        bg_range = self.class_map[0]["gs_range"]
        imgs = []
        targets = []

        for chosen_ids in self.chosen_ids_per_img:
            # Noisy background.
            img = self.rng.integers(bg_range[0], bg_range[1], (side, side))

            # Initially all pixels are labeled zero (background).
            target = np.zeros((side, side), dtype=np.int64)

            # Paint each shape; later shapes overwrite earlier ones in both
            # the image and the target.
            for class_id in chosen_ids:
                shape = self.draw_shape(class_id)
                gs_range = self.class_map[class_id]["gs_range"]
                img[shape] = self.rng.integers(
                    gs_range[0], gs_range[1], img[shape].shape
                )
                target[shape] = class_id

            # Replicate the grayscale layer into 3 identical channels.
            img = torch.from_numpy(img).unsqueeze(dim=0).repeat(3, 1, 1)
            imgs.append(img.to(torch.uint8))
            targets.append(torch.from_numpy(target))

        # torch.stack rejects an empty list, so build empty batches directly.
        if not imgs:
            return (
                torch.empty((0, 3, side, side), dtype=torch.uint8),
                torch.empty((0, side, side), dtype=torch.long),
            )

        # A list of (3, H, W) tensors becomes one (ds_size, 3, H, W) tensor.
        return torch.stack(imgs), torch.stack(targets)

## 3.&nbsp;Pytorch Lightning Data Module

In [4]:
class ImageSegmentation_DM(pl.LightningDataModule):
    """
    Self contained PyTorch Lightning DataModule for testing image
    segmentation models with PyTorch Lightning. Uses the torch dataset
    ImageSegmentation_DS.

    Args:
        train_val_size (int): total size of the training and validation
            sets combined.
        train_val_split (Tuple[float, float]): should sum to 1.0. For example
            if train_val_size = 100 and train_val_split = (0.80, 0.20)
            then the training set will contain 80 imgs and the validation
            set will contain 20 imgs. Fractional sizes are truncated.
        test_size (int): the size of the test data set.
        batch_size (int): batch size to be input to dataloaders. Applies
            for training, val, and test datasets.
        dataloader_shuffle (Dict, optional): whether or not to shuffle for
            each of the three dataloaders. Dict must contain the keys:
            "train", "val", "test". If None (default), uses
            {"train": True, "val": False, "test": False}.
        img_size (int): will build images of shape (3, img_size, img_size).
        shapes_per_image (Tuple[int, int]): will produce images containing
            minimum number of shapes Tuple[0] and maximum number of shapes
            Tuple[1]. For example shapes_per_image = (2,2) would create a
            dataset where each image contains exactly two shapes.
        class_probs (Tuple[float, float, float]): relative probability of
            each shape occurring in an image. Need not sum to 1. For example
            class_probs = (1,1,0) will create a dataset with 50% class 1
            shapes, 50% class 2 shapes, 0% class 3 shapes.
        rand_seed (int): used to instantiate a numpy rng. The train, val and
            test datasets use rand_seed, rand_seed + 111 and rand_seed + 222.
        class_map (Dict[Dict], optional): the class map must contain keys
            (0,1,2,3) and contain names "background", "rectangle", "line",
            and "donut". "gs_range" specifies the upper and lower bound of
            the grayscale values (0, 255) used to color the shapes.
            "target_color" can be used by visualization tools to assign
            a color to masks and boxes. Class 0 must correspond to
            "background". If None (default), a fresh copy of this map is
            used:
            class_map={
            0: {"name": "background","gs_range": (200, 255),"target_color": (255, 255, 255),},
            1: {"name": "rectangle", "gs_range": (0, 100), "target_color": (255, 0, 0)},
            2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
            3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)}}.

    Raises:
        ValueError: if dataloader_shuffle does not have exactly the keys
            "train", "val", "test".
    """

    def __init__(
        self,
        train_val_size=100,
        train_val_split=(0.80, 0.20),
        test_size=10,
        batch_size=8,
        dataloader_shuffle=None,
        img_size=50,
        shapes_per_image=(1, 3),
        class_probs=(1, 1, 1),
        rand_seed=12345,
        class_map=None,
    ):

        super().__init__()

        # Fresh dicts per instance; shared mutable defaults would leak edits
        # made through one datamodule into every later one.
        if dataloader_shuffle is None:
            dataloader_shuffle = {"train": True, "val": False, "test": False}
        if class_map is None:
            class_map = {
                0: {
                    "name": "background",
                    "gs_range": (200, 255),
                    "target_color": (255, 255, 255),
                },
                1: {"name": "rectangle", "gs_range": (0, 100), "target_color": (255, 0, 0)},
                2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
                3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)},
            }

        if set(dataloader_shuffle) != {"train", "val", "test"}:
            raise ValueError(
                "Dict dataloader_shuffle must contain the keys: train, val, test."
            )

        # Attributes to define datamodule.
        self.train_val_size = train_val_size
        self.train_val_split = np.array(train_val_split)
        # Truncate to whole images, but nudge by a tiny epsilon first so that
        # float error (e.g. 100 * 0.29 == 28.999999999999996) doesn't drop
        # an image from a split that is meant to be exact.
        self.train_val_sizes = np.floor(
            self.train_val_size * self.train_val_split + 1e-9
        ).astype(int)
        self.test_size = test_size
        self.batch_size = batch_size
        self.dataloader_shuffle = dataloader_shuffle

        # Attributes to define dataset.
        self.img_size = img_size
        self.rand_seed = rand_seed
        self.shapes_per_image = shapes_per_image
        self.class_probs = class_probs
        self.class_map = class_map

    def _make_ds(self, ds_size, seed_offset):
        """Build an ImageSegmentation_DS with this module's shared settings."""
        return ImageSegmentation_DS(
            ds_size=int(ds_size),
            img_size=self.img_size,
            rand_seed=self.rand_seed + seed_offset,
            shapes_per_image=self.shapes_per_image,
            class_probs=self.class_probs,
            class_map=self.class_map,
        )

    def setup(self, stage=None):
        """
        Build the datasets needed for `stage`.

        Args:
            stage (str, optional): "fit" builds train and val, "validate"
                builds val, "test" builds test, None builds all three.
        """
        if stage in ("fit", None):
            print("Setting up fit stage.")
            self.train = self._make_ds(self.train_val_sizes[0], 0)

        if stage in ("fit", "validate", None):
            if stage == "validate":
                print("Setting up validate stage.")
            self.val = self._make_ds(self.train_val_sizes[1], 111)

        if stage in ("test", None):
            print("Setting up test stage.")
            self.test = self._make_ds(self.test_size, 222)

        return None

    def _loader(self, dataset, split):
        """DataLoader for `dataset` using the shuffle flag of `split`."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.dataloader_shuffle[split],
        )

    def train_dataloader(self):
        return self._loader(self.train, "train")

    def val_dataloader(self):
        return self._loader(self.val, "val")

    def test_dataloader(self):
        return self._loader(self.test, "test")


## 4.&nbsp;Visualize Data

**show images**

In [5]:
def show(imgs, figsize=(10.0, 10.0)):
    """Plot one image, or a list of images side by side, with matplotlib.

    Adapted from the torchvision visualization example:
    https://pytorch.org/vision/main/auto_examples/plot_visualization_utils.html#visualizing-a-grid-of-images

    Args:
        imgs (Union[List[torch.Tensor], torch.Tensor]): a single (3, H, W)
            image tensor or a list of them.
        figsize (Tuple[float, float]): size of the matplotlib figure.

    Returns:
        None
    """
    images = imgs if isinstance(imgs, list) else [imgs]

    # squeeze=False keeps axes 2d even for a single column.
    _, axes = plt.subplots(ncols=len(images), figsize=figsize, squeeze=False)
    for ax, image in zip(axes[0], images):
        pil_image = TF.to_pil_image(image.detach())
        ax.imshow(np.asarray(pil_image))
        # Hide ticks and tick labels.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

    return None

**example of show function**

In [13]:
# imgs = [torch.rand((3, 5,10)),torch.rand((3, 10, 5)), torch.rand((3, 5, 10))]
# show(imgs, figsize = (5, 5) )

**labels_to_masks**

In [7]:
def labels_to_masks(labels, num_classes=None):
    """Converts a batch of segmentation labels into binary masks. Used
    with UNET or in other image segmentation tasks. This function works
    for both batches of labels or single (2d) image labels. The Args and
    return descriptions assume a full batch is input.

    Args:
        labels (torch.int64[batch_size, H, W]): A batch of segmentation
            labels (or a single (H, W) label image). Each pixel is assigned
            a class (an integer value).
        num_classes (int, optional): if given, produce exactly one mask per
            class id 0..num_classes-1 (empty masks for absent classes), so
            layer i is always class i. If None (default), produce one mask
            per label value present in `labels`.

    Returns:
    binary_masks (torch.bool[batch_size, num_obj_ids, H, W]): A batch of
        corresponding binary masks ((num_obj_ids, H, W) for 2d input).
        With num_classes=None, layer i is the mask for the i-th smallest
        label value present in `labels` (labels.unique() is sorted), which
        equals class i only if every smaller class id is present.

    Raises:
        ValueError: if labels is not 2d or 3d.
    """
    if labels.dim() not in (2, 3):
        raise ValueError(
            f"labels must be 2d (H, W) or 3d (batch_size, H, W), got {labels.dim()}d."
        )

    if num_classes is None:
        obj_ids = labels.unique()
    else:
        obj_ids = torch.arange(num_classes, device=labels.device, dtype=labels.dtype)

    if labels.dim() == 2:
        return labels == obj_ids[:, None, None]

    # Broadcasting gives (num_obj_ids, batch_size, H, W); move the object
    # axis behind the batch axis.
    return (labels == obj_ids[:, None, None, None]).permute(1, 0, 2, 3)

**understanding labels_to_masks**

In [14]:
# labels = torch.tensor([2,1,0,1,1]).long()
# print(f"labels:\n{labels}\n")
# print(f"labels.shape: \n{labels.shape}\n")
# obj_ids = labels.unique()
# print(f"obj_ids:\n{obj_ids}\n")
# print(f"obj_ids.shape:\n{obj_ids.shape}\n")
# print(f"obj_ids[:,None]:\n{obj_ids[:,None]}\n")
# print(f"obj_ids[:,None].shape:\n{obj_ids[:,None].shape}\n")
# masks = (labels == obj_ids[:,None])
# print(f"masks:\n{masks}\n")
# print(f"masks.shape:\n{masks.shape}\n")

**example of labels_to_masks**

In [15]:
# labels = torch.randint(high = 4, size = (32, 3, 3))
# binary_masks = labels_to_masks(labels)

# print(f"labels.shape:\n{labels.shape}\n")
# print(f"labels[0]:\n{labels[0]}\n")

# print(f"binary_masks.shape:\n{binary_masks.shape}\n")
# print(f"binary_masks[0]:\n{binary_masks[0]}\n")

**visualization functions**

In [10]:
def display_masks_unet(imgs, masks, class_map, alpha=0.4):
    """Takes a batch of images and a batch of masks of the same length and
    overlays the images with the masks using the "target_color" specified
    in the class_map.

    Args:
        imgs (List[torch.ByteTensor[batch_size, 3, H, W]]): a batch of
            images of shape (batch_size, 3, H, W).
        masks (torch.bool[batch_size, num_masks, H, W]]): a batch of
            corresponding boolean masks. masks[i][j] is drawn with the
            "target_color" of the j-th key of class_map (in iteration
            order), so the mask layers must be built in that order.
        class_map (Dict[Dict]): the class map must contain keys that
            correspond to the labels provided. Inner Dict must contain
            key "target_color". class 0 is reserved for background.
            A valid example ("name" not necessary):
            class_map={
            0: {"name": "background","target_color": (255, 255, 255),},
            1: {"name": "rectangle", "target_color": (255, 0, 0)},
            2: {"name": "line", "target_color": (0, 255, 0)},
            3: {"name": "donut", "target_color": (0, 0, 255)}}.
        alpha (float): transparency of masks. In range (0-1).

    Returns:
        result_imgs (List[torch.ByteTensor[3, H, W]]]): list of images
            with overlaid segmentation masks.
    """
    # The color list is the same for every image; build it once.
    colors = [class_map[key]["target_color"] for key in class_map]

    return [
        draw_segmentation_masks(img, masks[i], alpha=alpha, colors=colors)
        for i, img in enumerate(imgs)
    ]

**visualize a batch of images and targets**

In [12]:
%matplotlib inline
style = {'description_width': 'initial'}

@interact
def vizualize_label_targets(img_size = widgets.IntSlider(style= style,value=128,min=20,max=500,step=1, description = "img_size"),
                            shapes_per_img_lo = widgets.IntSlider(style=style,value=1,min=0,max=10,step=1, description = "shapes_per_img_lo"),
                            shapes_per_img_hi = widgets.IntSlider(style= style, value=4,min=0,max=10,step=1, description = "shapes_per_img_hi"),
                            class_prob_1 =  widgets.IntSlider(style= style,value=1,min=0,max=10,step=1, description = "class_prob_1"),
                            class_prob_2 =  widgets.IntSlider(style=style,value=1,min=0,max=10,step=1, description = "class_prob_2"),
                            class_prob_3 =  widgets.IntSlider(style= style,value=1,min=0,max=10,step=1, description = "class_prob_3"),
                            gs_range_0_lo = widgets.IntSlider(style= style,value=200,min=0,max=255,step=1, description = "gs_range_0_lo"),
                            gs_range_0_hi = widgets.IntSlider(style= style,value=255,min=0,max=255,step=1, description = "gs_range_0_hi"),
                            gs_range_1_lo = widgets.IntSlider(style= style,value=0,min=0,max=255,step=1, description = "gs_range_1_lo"),
                            gs_range_1_hi = widgets.IntSlider(style= style,value=100,min=0,max=255,step=1, description = "gs_range_1_hi"),
                            rand_seed = widgets.IntSlider(style= style,value=1232,min=1230,max=1239,step=1, description = "rand_seed"),
                            display_num_imgs= widgets.IntSlider(style= style,value=6,min=0,max=8,step=1, description = "display_num_imgs"),
                            display_size = widgets.IntSlider(style= style, value=30,min=5,max=50,step=1),
                            show_labels = widgets.Checkbox(style= style,value=False,description='target masks'),
                            ):
    """Interactively build a small segmentation datamodule and display a
    batch of its validation images, optionally overlaid with target masks.

    The widgets map onto ImageSegmentation_DM arguments (image size, shape
    counts, class probabilities, grayscale ranges for background and
    rectangles, seed). Slider combinations that would make dataset
    construction fail are reported with a message instead of a traceback.
    """
    # Guard slider combinations that would otherwise raise deep inside the
    # dataset / numpy, or leave make_grid with nothing to draw.
    if shapes_per_img_lo > shapes_per_img_hi:
        print("shapes_per_img_lo must be <= shapes_per_img_hi.")
        return
    if class_prob_1 + class_prob_2 + class_prob_3 == 0:
        print("At least one class_prob must be positive.")
        return
    if gs_range_0_lo >= gs_range_0_hi or gs_range_1_lo >= gs_range_1_hi:
        print("Each gs_range needs lo < hi.")
        return
    if display_num_imgs == 0:
        print("Nothing to display: display_num_imgs is 0.")
        return

    img_seg_dm = ImageSegmentation_DM(
                                       train_val_size = 20,
                                       img_size = img_size,
                                       train_val_split = (.2,.8),
                                       shapes_per_image = (shapes_per_img_lo, shapes_per_img_hi),
                                       class_probs=(class_prob_1, class_prob_2, class_prob_3),
                                       rand_seed=rand_seed,
                                       class_map={
                                            0: {"name": "background", "gs_range": (gs_range_0_lo, gs_range_0_hi), "target_color": (255,  255,  255)},
                                            1: {"name": "rectangle", "gs_range": (gs_range_1_lo, gs_range_1_hi), "target_color": (255, 0, 0)},
                                            2: {"name": "line", "gs_range": (0,50), "target_color": (0, 255, 0)},
                                            3: {"name": "donut", "gs_range": (0,50), "target_color": (0, 0, 255)},
                                        },
                                        dataloader_shuffle={"train": False, "val": False, "test": False},
                                       )
    # TODO: Consistency here on val or training.

    # Visualize and understand some random images
    img_seg_dm.setup(stage = "fit")
    imgs, labels = next(iter(img_seg_dm.val_dataloader()))
    imgs = imgs[:display_num_imgs]
    labels = labels[:display_num_imgs]

    # Use however many images the batch actually holds.
    result_images = list(imgs)

    if show_labels:
        # One mask layer per class_map key, in key order, even for classes
        # absent from this batch. display_masks_unet colors layer j with the
        # j-th key's target_color, so the layers must not shift when a class
        # is missing.
        class_ids = torch.tensor(list(img_seg_dm.class_map), dtype=labels.dtype)
        masks = labels[:, None] == class_ids[None, :, None, None]
        result_images = display_masks_unet(imgs, masks, img_seg_dm.class_map)

    grid = make_grid(result_images)
    show(grid, figsize = (display_size, display_size))

interactive(children=(IntSlider(value=128, description='img_size', max=500, min=20, style=SliderStyle(descript…