## 1.&nbsp;Imports


In [1]:
%%capture
!pip install torch
!pip install torchtext
!pip install torchvision
!pip install pytorch-lightning
!pip install pytorch-lightning-bolts
!pip install torchmetrics
!pip install matplotlib
!pip install numpy
!pip install ipywidgets

In [2]:
import torch
torch.__version__

'2.0.1'

In [3]:
# Deep learning imports.
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset

import torchvision
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks, make_grid
from torchvision.ops import masks_to_boxes
import torchvision.transforms.functional as TF
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

import torchmetrics

# Standard imports.
from typing import List, Union
import gc
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Necessary for creating our images.
from skimage.draw import line_aa

# Interactive widgets for data viz.
import ipywidgets as widgets
from ipywidgets import interact

  "class": algorithms.Blowfish,


## 2.&nbsp;PyTorch Dataset

**base classes (identical in each tutorial)**

In [4]:
class Draw:

    def __init__(self, img_size, rng):

        self.img_size = img_size
        self.rng = rng

        return None

    def rectangle(self):

        # Min and max rectangle height.
        a = self.rng.integers(self.img_size * 3 / 20, self.img_size * 7 / 20)
        b = self.rng.integers(self.img_size * 3 / 20, self.img_size * 7 / 20)

        # Initial coordinates of the rectangle.
        xx, yy = np.where(np.ones((a, b)) == 1)

        # Place the rectangle randomly in the image.
        cx = self.rng.integers(0 + a, self.img_size - a)
        cy = self.rng.integers(0 + b, self.img_size - b)

        rectangle = xx + cx, yy + cy

        return rectangle

    def line(self):

        # Randomly choose the start and end coordinates.
        a, b = self.rng.integers(0, self.img_size / 3, size=2)
        c, d = self.rng.integers(self.img_size / 2, self.img_size, size=2)

        # Flip a coin to see if slope of line is + or -.
        coin_flip = self.rng.integers(low=0, high=2)
        # Use a skimage.draw method to draw the line.
        if coin_flip:
            xx, yy, _ = line_aa(a, b, c, d)
        else:
            xx, yy, _ = line_aa(a, d, c, b)

        line = xx, yy

        return line

    def donut(self):
        """
        Returns the image coordinates of an elliptical donut.
        """
        # Define a grid
        xx, yy = np.mgrid[: self.img_size, : self.img_size]
        cx = self.rng.integers(0, self.img_size)
        cy = self.rng.integers(0, self.img_size)

        # Define the width of the donut.
        width = self.rng.uniform(self.img_size / 3, self.img_size)

        # Give the donut some elliptical nature.
        e0 = self.rng.uniform(0.1, 5)
        e1 = self.rng.uniform(0.1, 5)

        # Use the forumula for an ellipse.
        ellipse = e0 * (xx - cx) ** 2 + e1 * (yy - cy) ** 2

        donut = (ellipse < (self.img_size + width)) & (
            ellipse > (self.img_size - width)
        )

        return donut

class CV_DS_Base(torch.utils.data.Dataset):
    """Common base for the synthetic-shape datasets.

    The constructor validates the configuration, seeds a numpy Generator and
    pre-samples, for every image, how many shapes it contains and which class
    ids (1, 2, 3) those shapes have. Subclasses replace ``self.imgs`` and
    ``self.targets`` with the rendered data.
    """

    def __init__(
        self,
        ds_size,
        img_size,
        shapes_per_image,
        class_probs,
        rand_seed,
        class_map,
    ):
        """
        Args:
            ds_size: Number of images to pre-sample shapes for.
            img_size: Side length of the square images; must be greater than 20.
            shapes_per_image: ``(min, max)`` shapes per image, both inclusive.
            class_probs: Three finite, non-negative relative weights for class
                ids 1, 2 and 3; they are normalized to sum to 1.
            rand_seed: Seed for the numpy Generator (also used by ``Draw``).
            class_map: Dict with exactly the keys 0, 1, 2, 3; values are dicts
                with at least a ``"name"`` entry.

        Raises:
            ValueError: If img_size, class_map, shapes_per_image or class_probs
                is invalid.
        """
        if img_size <= 20:
            raise ValueError(
                "Different shapes are hard to distinguish for images of shape (3, 20, 20) or smaller."
            )

        if sorted(list(class_map.keys())) != sorted([0, 1, 2, 3]):
            raise ValueError("Dict class_map must contain keys 0,1,2,3.")

        # Without these checks, bad values only fail later inside rng.integers /
        # rng.choice with messages that do not name the offending argument.
        min_shapes, max_shapes = shapes_per_image[0], shapes_per_image[1]
        if min_shapes < 0 or min_shapes > max_shapes:
            raise ValueError(
                "shapes_per_image must be (min, max) with 0 <= min <= max."
            )

        probs = np.asarray(class_probs, dtype=float)
        if (
            probs.shape != (3,)
            or not np.all(np.isfinite(probs))
            or np.any(probs < 0)
            or probs.sum() <= 0
        ):
            raise ValueError(
                "class_probs must be three finite, non-negative weights with a positive sum."
            )

        self.ds_size = ds_size
        self.img_size = img_size
        self.rand_seed = rand_seed
        self.shapes_per_image = shapes_per_image
        self.class_probs = probs / probs.sum()
        self.class_map = class_map

        self.rng = np.random.default_rng(self.rand_seed)

        self.draw = Draw(self.img_size, self.rng)

        # Drawable class ids; id 0 is the background entry of class_map.
        self.class_ids = np.array([1, 2, 3])
        self.num_shapes_per_img = self.rng.integers(
            low=min_shapes,
            high=max_shapes + 1,  # +1: the upper bound of integers() is exclusive
            size=self.ds_size,
        )

        self.chosen_ids_per_img = [
            self.rng.choice(a=self.class_ids, size=num_shapes, p=self.class_probs)
            for num_shapes in self.num_shapes_per_img
        ]

        # Populated by subclasses.
        self.imgs, self.targets = [], []

    def __getitem__(self, idx):
        """Return the ``(img, target)`` pair at position ``idx``."""
        return self.imgs[idx], self.targets[idx]

    def __len__(self):
        """Number of rendered images."""
        return len(self.imgs)

    def draw_shape(self, class_id):
        """Draw one shape of ``class_id`` and return what the matching Draw method returns.

        Raises:
            ValueError: If class_map names class_id something other than
                "rectangle", "line" or "donut".
        """
        name = self.class_map[class_id]["name"]

        if name == "rectangle":
            return self.draw.rectangle()
        if name == "line":
            return self.draw.line()
        if name == "donut":
            return self.draw.donut()

        raise ValueError(
            "You must include rectangle, donut, and line in your class_map."
        )

**image classification and object counting dataset**

In [5]:
class ObjectCounting_DS(CV_DS_Base):
    """Synthetic dataset for multi-label classification or per-class object counting.

    Each item is ``(img, target)``. ``img`` is a uint8 tensor of shape
    (3, img_size, img_size): one noisy grayscale image repeated over 3 channels.
    ``target`` is an int64 tensor of length ``num_classes`` whose entry ``k``
    refers to class id ``k + 1``. It holds the shape count when
    ``object_count=True``, otherwise a 0/1 presence flag.
    """

    def __init__(
        self,
        ds_size=100,
        img_size=256,
        shapes_per_image=(1, 3),
        class_probs=(1, 1, 1),
        rand_seed=12345,
        class_map={
            0: {
                "name": "background",
                "gs_range": (200, 255),
                "target_color": (255, 255, 255),
            },
            1: {"name": "rectangle", "gs_range": (0, 100), "target_color": (255, 0, 0)},
            2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
            3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)},
        },
        object_count=False,
    ):
        """
        Args:
            ds_size, img_size, shapes_per_image, class_probs, rand_seed, class_map:
                Forwarded to ``CV_DS_Base``. Each class_map entry also needs a
                ``"gs_range"`` (low, high) used to sample pixel intensities.
            object_count: If True, targets are per-class counts; otherwise 0/1 flags.
        """
        super().__init__(
            ds_size, img_size, shapes_per_image, class_probs, rand_seed, class_map
        )

        self.object_count = object_count
        # Target length = highest class id that has a non-zero sampling weight.
        self.num_classes = np.max(np.nonzero(self.class_probs)) + 1

        self.imgs, self.targets = self.build_imgs_and_targets()

    def build_imgs_and_targets(self):
        """Render all images and their targets.

        Returns:
            ``(imgs, targets)``: a uint8 tensor (ds_size, 3, img_size, img_size)
            and an int64 tensor (ds_size, num_classes). Both have a leading
            dimension of 0 when ds_size is 0.
        """
        imgs = []
        targets = []

        for chosen_ids in self.chosen_ids_per_img:
            # Noisy background: pixels drawn from class 0's gs_range (high exclusive).
            bg_range = self.class_map[0]["gs_range"]
            img = self.rng.integers(
                bg_range[0],
                bg_range[1],
                (self.img_size, self.img_size),
            )

            target = np.zeros(self.num_classes)

            # Paint each shape with fresh noise from its own gs_range and count it.
            for class_id in chosen_ids:
                shape = self.draw_shape(class_id)
                gs_range = self.class_map[class_id]["gs_range"]
                img[shape] = self.rng.integers(
                    gs_range[0], gs_range[1], img[shape].shape
                )
                target[class_id - 1] += 1

            # Grayscale (H, W) -> (3, H, W) with identical channels, as uint8.
            img_t = torch.from_numpy(img).unsqueeze(dim=0).repeat(3, 1, 1).to(torch.uint8)

            target_t = torch.from_numpy(target).long()
            if not self.object_count:
                # Classification mode: collapse counts to presence flags.
                target_t = torch.clamp(target_t, max=1)

            imgs.append(img_t)
            targets.append(target_t)

        if not imgs:
            # torch.stack rejects an empty list (e.g. a zero-sized val split).
            return (
                torch.empty(
                    (0, 3, int(self.img_size), int(self.img_size)), dtype=torch.uint8
                ),
                torch.empty((0, int(self.num_classes)), dtype=torch.long),
            )

        # List of ds_size tensors -> one batched tensor per field.
        return torch.stack(imgs), torch.stack(targets)

**Look at the dataset** (uncomment the cell below to print the attributes of a tiny dataset)

In [6]:
# obj_counting_ds = ObjectCounting_DS(ds_size = 3, shapes_per_image = (0, 9), object_count = True)

# for key, value in obj_counting_ds.__dict__.items():
#     print(key, ' : ', value)

## 3.&nbsp;PyTorch Lightning DataModule

In [7]:
class ObjectCounting_DM(pl.LightningDataModule):
    """LightningDataModule serving train/val/test splits of ``ObjectCounting_DS``.

    The three splits are generated independently from the seeds ``rand_seed``,
    ``rand_seed + 111`` and ``rand_seed + 222`` and otherwise share all dataset
    settings.
    """

    def __init__(
        self,
        train_val_size=100,
        train_val_split=(0.80, 0.20),
        test_size=10,
        batch_size=8,
        dataloader_shuffle={"train": True, "val": False, "test": False},
        img_size=50,
        shapes_per_image=(1, 3),
        class_probs=(1, 1, 1),
        rand_seed=12345,
        class_map={
            0: {
                "name": "background",
                "gs_range": (200, 255),
                "target_color": (255, 255, 255),
            },
            1: {"name": "rectangle", "gs_range": (0, 100), "target_color": (255, 0, 0)},
            2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
            3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)},
        },
        object_count=True,
    ):
        """
        Args:
            train_val_size: Total number of images divided between train and val.
            train_val_split: (train, val) fractions; each split gets
                floor(train_val_size * fraction) images.
            test_size: Number of test images.
            batch_size: Batch size used by all three dataloaders.
            dataloader_shuffle: Dict with keys "train", "val", "test" -> shuffle flag.
            img_size, shapes_per_image, class_probs, rand_seed, class_map,
            object_count: Forwarded to every ``ObjectCounting_DS``.

        Raises:
            ValueError: If dataloader_shuffle does not have exactly the keys
                train, val, test.
        """
        super().__init__()

        if sorted(list(dataloader_shuffle.keys())) != sorted(["train", "val", "test"]):
            raise ValueError(
                "Dict dataloader_shuffle must contain the keys: train, val, test."
            )
        # Attributes to define datamodule.
        self.train_val_size = train_val_size
        self.train_val_split = np.array(train_val_split)
        # Round away float noise before flooring: e.g. 100 * 0.29 evaluates to
        # 28.999999999999996 and plain int() truncation would yield 28, not 29.
        self.train_val_sizes = np.floor(
            np.round(self.train_val_size * self.train_val_split, 6)
        ).astype(int)
        self.test_size = test_size
        self.batch_size = batch_size
        self.dataloader_shuffle = dataloader_shuffle

        # Attributes to define dataset.
        self.img_size = img_size
        self.rand_seed = rand_seed
        self.shapes_per_image = shapes_per_image
        self.class_probs = class_probs
        self.class_map = class_map
        self.object_count = object_count

    def setup(self, stage=None):
        """Build the datasets for ``stage``: "fit", "test", or None for both."""
        if stage == "fit" or stage is None:
            print("Setting up fit stage.")
            self.train = self._make_dataset(self.train_val_sizes[0], self.rand_seed)
            self.val = self._make_dataset(self.train_val_sizes[1], self.rand_seed + 111)

        if stage == "test" or stage is None:
            print("Setting up test stage.")
            self.test = self._make_dataset(self.test_size, self.rand_seed + 222)

        return None

    def _make_dataset(self, ds_size, rand_seed):
        """Create one split with this module's shared dataset settings."""
        return ObjectCounting_DS(
            ds_size=ds_size,
            img_size=self.img_size,
            rand_seed=rand_seed,
            shapes_per_image=self.shapes_per_image,
            class_probs=self.class_probs,
            class_map=self.class_map,
            object_count=self.object_count,
        )

    def _dataloader(self, dataset, split):
        """DataLoader over ``dataset`` using the shuffle flag configured for ``split``."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.dataloader_shuffle[split],
        )

    def train_dataloader(self):
        return self._dataloader(self.train, "train")

    def val_dataloader(self):
        return self._dataloader(self.val, "val")

    def test_dataloader(self):
        return self._dataloader(self.test, "test")


## 4.&nbsp;Visualize Data

**show images**

In [8]:
def show(imgs, figsize=(10.0, 10.0)):
    """Display one image tensor, or a list of them side by side, with matplotlib.

    Args:
        imgs: A tensor, or a list of tensors, accepted by ``TF.to_pil_image``.
        figsize: Figure size in inches, passed to ``plt.subplots``.

    Returns:
        None; the figure is rendered with ``plt.show()``.
    """
    images = imgs if isinstance(imgs, list) else [imgs]

    # squeeze=False keeps the axes grid 2-D (1 x ncols) even for one image.
    _, axes = plt.subplots(ncols=len(images), figsize=figsize, squeeze=False)
    for ax, tensor in zip(axes[0], images):
        pil_image = TF.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image))
        # Strip ticks and tick labels.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

    return None

**add labels to images**

In [11]:
def add_labels(img, label, class_map, pred=False, object_count=False):
    """Return a copy of ``img`` with the names of the classes present in ``label`` drawn on it.

    Args:
        img: Image tensor of shape (C, H, W). It is converted to a uint8 copy
            for drawing, so the input tensor is not modified.
        label: 1-D tensor whose entry ``k`` is the count (or 0/1 flag) for
            class id ``k + 1``. An all-zero label is rendered as "background".
        class_map: Dict mapping class id -> dict with "name" and "target_color".
        pred: If True, the text block starts lower and further right and uses
            thicker strokes (presumably to distinguish predictions from targets).
        object_count: If True, append ": <count>" to each class name.

    Returns:
        uint8 tensor of shape (C, H, W) with the text drawn on it.
    """
    img_size = img.shape[-1]
    # The permuted array is not C-ordered; the explicit .copy() (C order by
    # default) gives cv2 a fresh contiguous buffer to draw into.
    canvas = img.permute(1, 2, 0).cpu().numpy().astype(np.uint8).copy()

    if label.sum() == 0:
        nonzero_classes = [0]
        label_colors = [class_map[0]["target_color"]]
        img_labels = ["background"]
    else:
        values = label.cpu().numpy()
        nonzero_classes = values.nonzero()[0] + 1
        label_colors = [class_map[c]["target_color"] for c in nonzero_classes]

        if object_count:
            # Format a plain Python scalar; formatting the 0-d torch tensor
            # label[c - 1] would print "tensor(2)" instead of "2".
            img_labels = [
                "{}: {}".format(class_map[c]["name"], values[c - 1].item())
                for c in nonzero_classes
            ]
        else:
            img_labels = [class_map[c]["name"] for c in nonzero_classes]

    # Layout constants are written for a 512-pixel-wide image and scaled.
    scaling_ratio = img_size / 512

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1 * scaling_ratio
    thickness = 1
    line_type = 1

    y0, x0, dy = 27 * scaling_ratio, 10 * scaling_ratio, 27 * scaling_ratio
    if pred:
        y0, x0 = 400 * scaling_ratio, 315 * scaling_ratio
        thickness = 2

    for text, color, class_id in zip(img_labels, label_colors, nonzero_classes):
        # One row per class id, so a class always appears at the same height.
        y = y0 + class_id * dy
        cv2.putText(
            canvas,
            text,
            (int(x0), int(y)),
            font,
            font_scale,
            color,
            thickness,
            line_type,
        )

    return torch.from_numpy(canvas).permute(2, 0, 1)

**visualize a batch of images and targets**

In [13]:
%matplotlib inline
style = {'description_width': 'initial'}

@interact
def vizualize_label_targets(img_size = widgets.IntSlider(style= style,value=256,min=128,max=1000,step=1, description = "img_size"),
                            shapes_per_img_lo = widgets.IntSlider(style=style,value=3,min=0,max=10,step=1, description = "shapes_per_img_lo"),
                            shapes_per_img_hi = widgets.IntSlider(style= style, value=6,min=0,max=10,step=1, description = "shapes_per_img_hi"),
                            class_prob_1 =  widgets.IntSlider(style= style,value=1,min=0,max=10,step=1, description = "class_prob_1"),
                            class_prob_2 =  widgets.IntSlider(style=style,value=1,min=0,max=10,step=1, description = "class_prob_2"),
                            class_prob_3 =  widgets.IntSlider(style= style,value=1,min=0,max=10,step=1, description = "class_prob_3"),
                            gs_range_0_lo = widgets.IntSlider(style= style,value=200,min=0,max=255,step=1, description = "gs_range_0_lo"),
                            gs_range_0_hi = widgets.IntSlider(style= style,value=255,min=0,max=255,step=1, description = "gs_range_0_hi"),
                            gs_range_1_lo = widgets.IntSlider(style= style,value=0,min=0,max=255,step=1, description = "gs_range_1_lo"),
                            gs_range_1_hi = widgets.IntSlider(style= style,value=100,min=0,max=255,step=1, description = "gs_range_1_hi"),
                            rand_seed = widgets.IntSlider(style= style,value=1232,min=1230,max=1239,step=1, description = "rand_seed"),
                            display_size = widgets.IntSlider(style= style, value=30,min=5,max=50,step=1),
                            show_labels = widgets.Checkbox(style= style,value=False,description='target labels'),
                            object_count = widgets.Checkbox(style= style,value=True,description='object_count')
                            ):


    obj_counting_dm = ObjectCounting_DM(
                                    train_val_size = 10,
                                    img_size = img_size,
                                    train_val_split = (.6, .4),
                                    shapes_per_image = (shapes_per_img_lo, shapes_per_img_hi),
                                    class_probs=(class_prob_1, class_prob_2, class_prob_3),
                                    rand_seed=rand_seed,
                                    class_map={
                                        0: {"name": "background", "gs_range": (gs_range_0_lo, gs_range_0_hi), "target_color": (255,255,255)},
                                        1: {"name": "rectangle", "gs_range": (gs_range_1_lo, gs_range_1_hi), "target_color": (255, 0, 0)},
                                        2: {"name": "line", "gs_range": (0, 100), "target_color": (0, 255, 0)},
                                        3: {"name": "donut", "gs_range": (0, 100), "target_color": (0, 0, 255)},
                                    },
                                    dataloader_shuffle={"train": False, "val": False, "test": False},
                                    object_count = object_count,
                                       )

    # Visualize and understand some random images
    obj_counting_dm.setup(stage = "fit")
    dataiter = iter(obj_counting_dm.train_dataloader())

    imgs, labels = next(dataiter)

    result_images = [imgs[i] for i in range(len(imgs))]

    if show_labels:
        result_images = [add_labels(img, label, obj_counting_dm.class_map, object_count = object_count, pred = False) for img, label in zip(imgs, labels)]

    grid = make_grid(result_images)
    show(grid, figsize = (display_size, display_size))

interactive(children=(IntSlider(value=256, description='img_size', max=1000, min=128, style=SliderStyle(descri…