In [1]:
from utils import Encoder
import utils
import torch
from datasets import RoboEireanDataModule
import lightning.pytorch as pl
from PIL import ImageDraw, Image

from models import JetNet, SingleShotDetector, ObjectDetectionTask
from lightning.pytorch.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch.nn.functional as F
import torchvision.transforms as T

from visualize import draw_bounding_box

#### Load model checkpoint and data

In [7]:
# --- Configuration -------------------------------------------------------------
# NOTE(review): these presumably have to match the values the checkpoint was trained
# with (the model architecture below is built from NUM_CLASSES and DEFAULT_SCALINGS).
LEARNING_RATE = 1e-1  # passed to ObjectDetectionTask; not used by the inference code below
ALPHA = 2.0           # passed to SingleShotDetector
NUM_CLASSES = 1
BATCH_SIZE = 128      # third argument of RoboEireanDataModule — presumably the batch size
# 6 x 2 scalings shared by the Encoder and JetNet; presumably anchor (width, height)
# scalings relative to the image size — confirm.
DEFAULT_SCALINGS = torch.tensor(
    [
        [0.06549374, 0.12928654],
        [0.11965626, 0.26605093],
        [0.20708716, 0.38876095],
        [0.31018215, 0.47485098],
        [0.415882, 0.8048184],
        [0.7293086, 0.8216225],
    ]
)
encoder = Encoder(DEFAULT_SCALINGS, NUM_CLASSES)
model = JetNet(NUM_CLASSES, DEFAULT_SCALINGS.shape[0])
loss = SingleShotDetector(ALPHA)

# --- Data ----------------------------------------------------------------------
# Get the data we want to visualize and predict on.
data_module = RoboEireanDataModule("data/raw/", encoder, BATCH_SIZE)
data_module.setup("fit")  # TODO: inspect different stages

# --- Checkpoint ----------------------------------------------------------------
version = 1
checkpoint_folder = f"new_logs/lightning_logs/version_{version}/checkpoints/"
checkpoint_file = "epoch=1-step=56.ckpt"
checkpoint_path = checkpoint_folder + checkpoint_file

loaded_model = ObjectDetectionTask.load_from_checkpoint(
    checkpoint_path=checkpoint_path,
    model=model,
    loss=loss,
    encoder=encoder,
    learning_rate=LEARNING_RATE,
)
# load_from_checkpoint does not switch the module to eval mode (Lightning's docs call
# .eval() explicitly after loading). Without this, any dropout / batch-norm layers
# would run in training mode during the predictions in the following cells.
loaded_model.eval()
loaded_model  # last expression: display the model summary

### Run the model on a single batch from the validation set

In [3]:
# Take the first validation batch and run the model on it. The results (`image`,
# `predicted_classes`, `decoded_boxes`, `predicted_logits`, ...) are reused by the cells below.
image, _, target_bb, target_class = next(iter(data_module.val_dataloader()))

# Inference only: no_grad skips building the autograd graph, so the outputs do not
# carry requires_grad=True into later cells (saves memory, allows .numpy()).
with torch.no_grad():
    predicted_boxes, predicted_logits = loaded_model.model(image)
    predicted_classes, softmax = utils.calculate_predicted_classes(predicted_logits)
    # NOTE(review): .squeeze() drops every size-1 dim — with a batch of 1 this would
    # also drop the batch dimension; confirm that is acceptable.
    decoded_boxes = encoder.decode(predicted_boxes).squeeze()

#### Iterate over 50 batches and check for positive predictions

In [10]:
NUM_BATCHES_TO_SCAN = 50

# Scan the first NUM_BATCHES_TO_SCAN validation batches and report every sample for which
# at least one anchor is predicted as class 1.
# NOTE(review): class 1 is presumably the single object class with 0 as background —
# confirm against utils.calculate_predicted_classes.
#
# The DataLoader is iterated once. Calling next(iter(val_dataloader())) inside the loop
# would rebuild the loader (and restart its worker processes) on every iteration and, for
# an unshuffled loader, return the same first batch every time. The multiprocessing
# tempdir errors previously seen in this cell's output were likely related to that churn.
# Batch results use their own names so `image`, `predicted_classes`, `decoded_boxes`, ...
# from the single-batch cell above are not overwritten.
num_positive_samples = 0
with torch.no_grad():
    val_batches = zip(range(NUM_BATCHES_TO_SCAN), data_module.val_dataloader())
    for batch_idx, (batch_images, _, _, _) in val_batches:
        _, batch_logits = loaded_model.model(batch_images)
        batch_classes, _ = utils.calculate_predicted_classes(batch_logits)

        # Iterate over the real batch length: a final batch may be smaller than BATCH_SIZE.
        for sample_idx, sample_classes in enumerate(batch_classes):
            num_positive_anchors = int((sample_classes == 1).sum())
            if num_positive_anchors > 0:
                num_positive_samples += 1
                print(
                    f"batch {batch_idx}, sample {sample_idx}: "
                    f"{num_positive_anchors} anchor(s) predicted as class 1"
                )

print(
    f"{num_positive_samples} sample(s) with a positive prediction "
    f"in the first {NUM_BATCHES_TO_SCAN} validation batches"
)


Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/usr/lib/python3.10/shutil.py", line 730, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/usr/lib/python3.10/shutil.py", line 728, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-obazcqxe'
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.10/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/usr/lib/py

#### Check for duplicate prediction logits

In [None]:
# Count pairs of samples in the batch whose predicted logits are exactly identical,
# e.g. to detect a model that produces the same output regardless of the input image.
# NOTE(review): assumes predicted_logits is batch-first (one entry per sample) — confirm.
num_samples = len(predicted_logits)
num_pairs = num_samples * (num_samples - 1) // 2

equal_count = 0             # identical (i, j) pairs with i < j
duplicated_samples = set()  # samples identical to at least one other sample
for i in range(num_samples - 1):
    for j in range(i + 1, num_samples):
        if torch.equal(predicted_logits[i], predicted_logits[j]):
            equal_count += 1
            duplicated_samples.update((i, j))

print(
    f"{equal_count} of {num_pairs} sample pairs have identical logits; "
    f"{len(duplicated_samples)} of {num_samples} predictions duplicate another prediction"
)


In [None]:
# Overlay the boxes predicted as non-background on every image of the batch loaded in
# the single-batch cell (`image`, `decoded_boxes`, `predicted_classes`).
for i in range(len(image)):
    # NOTE(review): uses channel 0 only — assumes single-channel (C, H, W) images; confirm.
    image_pil = T.ToPILImage()(image[i][0]).convert("RGBA")

    # Mask with *this* sample's predicted classes (previously predicted_classes[0] was
    # used for every image, so all images showed sample 0's detections).
    object_boxes = decoded_boxes[i][predicted_classes[i] > 0]
    annotated = draw_bounding_box(image_pil, object_boxes)

    fig, ax = plt.subplots()
    ax.imshow(annotated)
    ax.set_title(f"Validation sample {i}: {len(object_boxes)} predicted box(es)")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop