In [1]:
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


def plot(imgs, row_title=None, **imshow_kwargs):
    """Display images, optionally annotated with boxes/masks, in a subplot grid.

    Args:
        imgs: Either a flat list of images (drawn as a single row) or a list of
            rows (list of lists). Each entry is an image accepted by
            ``F.to_image`` or an ``(image, target)`` tuple, where ``target`` is
            a dict with optional ``"boxes"`` / ``"masks"`` entries or a
            ``tv_tensors.BoundingBoxes``. Rows may have different lengths; the
            grid is as wide as the longest row.
        row_title: Optional sequence of y-axis labels, one per row.
        **imshow_kwargs: Forwarded unchanged to ``Axes.imshow``.

    Raises:
        ValueError: If a tuple's target is neither a dict nor ``BoundingBoxes``.

    The caller's image tensors are never modified.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    # Size the grid on the longest row so ragged rows don't index past axs.
    num_cols = max(len(row) for row in imgs)
    _, axs = plt.subplots(nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        for col_idx, img in enumerate(row):
            boxes = None
            masks = None
            if isinstance(img, tuple):
                img, target = img
                if isinstance(target, dict):
                    boxes = target.get("boxes")
                    masks = target.get("masks")
                elif isinstance(target, tv_tensors.BoundingBoxes):
                    boxes = target
                else:
                    raise ValueError(f"Unexpected target type: {type(target)}")
            img = F.to_image(img)
            if img.dtype.is_floating_point and img.min() < 0:
                # Poor man's re-normalization for the colors to be OK-ish. This
                # is useful for images coming out of Normalize().
                # Out-of-place on purpose: F.to_image can share storage with a
                # tensor input, so in-place ops would overwrite the caller's image.
                img = img - img.min()
                max_val = img.max()
                if max_val > 0:  # a constant image would otherwise become 0/0 = NaN
                    img = img / max_val

            img = F.to_dtype(img, torch.uint8, scale=True)
            if boxes is not None:
                img = draw_bounding_boxes(img, boxes, colors="yellow", width=3)
            if masks is not None:
                img = draw_segmentation_masks(img, masks.to(torch.bool), colors=["green"] * masks.shape[0], alpha=.65)

            ax = axs[row_idx, col_idx]
            # CHW -> HWC for matplotlib; .cpu() so device tensors can be shown too.
            ax.imshow(img.permute(1, 2, 0).cpu().numpy(), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

In [2]:
from pathlib import Path

import torch
import torch.nn as nn

import torchvision.transforms as v1
from torchvision.io import read_image

import sys

# Reproducible randomness and tight bounding boxes for saved figures.
torch.manual_seed(1)
plt.rcParams["savefig.bbox"] = 'tight'

# To run this on Colab, download the assets and the helper modules from
# https://github.com/pytorch/vision/tree/main/gallery/
# NOTE(review): presumably the local `nn` / `topology` modules imported further
# down live in ../transforms; this path is relative to the kernel's CWD.
sys.path.extend(["../transforms"])
ASSETS_PATH = Path('assets')

In [3]:
# Load the two sample images from the assets directory.
dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))

# Random crop + flip pipeline, combined in nn.Sequential and compiled with
# TorchScript.
transforms = torch.nn.Sequential(
    v1.RandomCrop(224),
    v1.RandomHorizontalFlip(p=0.3),
)
scripted_transforms = torch.jit.script(transforms)

# One row: each original followed by its randomly transformed version.
# Evaluation order (dog1, t(dog1), dog2, t(dog2)) matches the RNG draw order.
plot([
    shown
    for dog in (dog1, dog2)
    for shown in (dog, scripted_transforms(dog))
])

In [4]:
from torchvision.models import resnet18, ResNet18_Weights, ResNet


class Predictor(nn.Module):
    """Pretrained ResNet-18 bundled with the preprocessing of its weights.

    ``forward`` preprocesses the input, runs the network without tracking
    gradients and returns the arg-max class index along dim 1 (one index per
    image when given a batch).
    """

    def __init__(self):
        super().__init__()
        pretrained = ResNet18_Weights.DEFAULT
        # Submodule registration order (network, then transforms) is unchanged.
        self.resnet18 = resnet18(weights=pretrained, progress=False).eval()
        self.transforms = pretrained.transforms(antialias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.resnet18(self.transforms(x)).argmax(dim=1)

In [5]:
from torch.utils.tensorboard import SummaryWriter
from nn import TopologyObserver
from topology import Persistence, IntrinsicDimension

predictor = Predictor()
writer = SummaryWriter('runs/resnet_experiment_1')

# NOTE(review): Filtration is created but not used by this observer.
Filtration = Persistence()
Dimension = IntrinsicDimension()

# Attach the same dimension analysis after conv1, layer1 and layer4; each
# layer gets its own options dict.
net = predictor.resnet18
observer = TopologyObserver(
    net,
    writer=writer,
    post_topology=[
        (layer, [(Dimension, {'label': 'Dimension Analysis'})])
        for layer in (net.conv1, net.layer1, net.layer4)
    ],
)

In [6]:
# Stack both images along a new leading batch dimension (they must share a
# shape for torch.stack) and classify them in one call.
batch = torch.stack((dog1, dog2))
res = predictor(batch)

In [6]:
# Print the output shape of each ResNet stage on every forward pass.
# The RemovableHandles are kept so the hooks can be detached later
# (`for h in shape_hooks: h.remove()`); re-running this cell first removes the
# hooks it registered previously instead of stacking duplicate printers.
for handle in globals().get("shape_hooks", []):
    handle.remove()
shape_hooks = [
    getattr(predictor.resnet18, stage).register_forward_hook(
        lambda module, args, output: print(output.shape)
    )
    for stage in ("conv1", "layer1", "layer2", "layer3", "layer4")
]

In [11]:
# Print shape[2] * shape[3] of each ResNet stage's output on every forward pass
# (presumably H * W for NCHW activations).
# The RemovableHandles are kept so the hooks can be detached later
# (`for h in area_hooks: h.remove()`); re-running this cell first removes the
# hooks it registered previously instead of stacking duplicate printers.
for handle in globals().get("area_hooks", []):
    handle.remove()
area_hooks = [
    getattr(predictor.resnet18, stage).register_forward_hook(
        lambda module, args, output: print(output.shape[2] * output.shape[3])
    )
    for stage in ("conv1", "layer1", "layer2", "layer3", "layer4")
]

In [None]:
import json

# Compile the eager predictor with TorchScript; both must predict the same class.
# NOTE(review): torch.jit.script also compiles forward hooks registered on the
# submodules, so the print-lambda hooks from earlier cells (and any hooks the
# TopologyObserver adds) may have to be removed before scripting succeeds — confirm.
scripted_predictor = torch.jit.script(predictor)

res = predictor(batch)
res_scripted = scripted_predictor(batch)

with open(ASSETS_PATH / 'imagenet_class_index.json', encoding='utf-8') as labels_file:
    labels = json.load(labels_file)

for i, (pred, pred_scripted) in enumerate(zip(res, res_scripted)):
    assert pred == pred_scripted, (
        f"Eager/scripted mismatch for Dog {i + 1}: {pred.item()} != {pred_scripted.item()}"
    )
    print(f"Prediction for Dog {i + 1}: {labels[str(pred.item())]}")

In [13]:
# Inspect the raw image's dimensions.
dog1.size()

In [15]:
# Predictor.forward takes argmax(dim=1) and is otherwise fed stacked batches,
# so a single image needs a leading batch dimension of size 1.
predictor(dog1.unsqueeze(0))

In [8]:
from torch.utils.tensorboard import SummaryWriter

# Separate TensorBoard run for the second experiment.
# NOTE(review): the writer from the first experiment is rebound here without
# being closed; the first observer may still hold a reference to it.
writer = SummaryWriter(log_dir='runs/resnet_experiment_2')

In [9]:
from nn import TopologyObserver
from topology import Persistence, IntrinsicDimension

Filtration = Persistence()
Dimension = IntrinsicDimension()

# Persistence + intrinsic-dimension analyses on the input of conv1 and on the
# outputs of layer1 / layer4, all with distances=False and batches=True.
# NOTE(review): the observer from the first experiment was attached to the same
# network and is not visibly detached here — confirm whether both stay active.
net = predictor.resnet18
observer = TopologyObserver(
    net,
    writer=writer,
    pre_topology=[
        (net.conv1, [
            (Filtration, dict(label='Input', distances=False, batches=True)),
            (Dimension, dict(label='Dimension Analysis 2', distances=False, batches=True)),
        ]),
    ],
    post_topology=[
        (net.layer1, [
            (Filtration, dict(label='Hidden 1', distances=False, batches=True)),
            (Dimension, dict(label='Dimension Analysis', distances=False, batches=True)),
        ]),
        (net.layer4, [
            (Filtration, dict(label='Hidden 4', distances=False, batches=True)),
            (Dimension, dict(label='Dimension Analysis', distances=False, batches=True)),
        ]),
    ],
)

In [None]:
# Forward the batch again so every forward hook still registered on the
# ResNet submodules fires (presumably including the TopologyObserver analyses).
predictor(batch)