In [None]:
# Turn off Jedi-based completion for the IPython completer in this session.
%config Completer.use_jedi = False

In [None]:
## Generate the .rst report.
# Runs ./scripts/report.py in a subshell with the current working directory
# exported as PYTHONPATH for that command.
!PYTHONPATH=$(pwd) python ./scripts/report.py

In [None]:
import sys

# Directory prepended to sys.path so that the `VideoToolkit` package can be imported.
VIDEO_TOOLKIT_DIR = "/workspace8/video_toolkit/"
# Insert only once: re-running this cell must not keep growing sys.path
# with duplicate entries.
if VIDEO_TOOLKIT_DIR not in sys.path:
    sys.path.insert(0, VIDEO_TOOLKIT_DIR)

from VideoToolkit.tools import rescal_to_image, get_cv_resize_function

# Resize helper obtained from VideoToolkit.
resize_func = get_cv_resize_function()

In [None]:
import torch
from torch import nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset


from retinanet.model.detection.retinanet import retinanet_resnet50_fpn
from retinanet.model.utils import outputs_to_logits, logits_to_preds

from retinanet.datasets.bird import BirdDetection, BirdClassification
from retinanet.datasets.transforms import *
from retinanet.datasets.utils import TransformDatasetWrapper, train_val_split

from retinanet.utils.visualizatioin import vis_features, vis_features_CAM

import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt

# %matplotlib inline

In [None]:
# Prefer CUDA when torch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device_str = "cuda"
else:
    device_str = "cpu"

device = torch.device(device_str)
print("Torch Using device:", device)

In [None]:
# Directory the serialized dataset splits are loaded from.
data_log_dir = "/workspace8/RetinaNet/experiments/dataset"

# Preprocessing pipeline shared by every dataset in this notebook:
# tensor conversion (given the selected device) followed by per-channel
# normalization with the mean / std values below.
to_tensor_step = ToTensor(device)
normalize_step = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_transform = Compose([to_tensor_step, normalize_step])

In [None]:
# Image-level classification splits: load each split from data_log_dir,
# then wrap it so train_transform is attached to it.
train_cls_raw = BirdClassification()
train_cls_raw.load(data_log_dir, file_name="train_cls")

val_cls_raw = BirdClassification()
val_cls_raw.load(data_log_dir, file_name="validation_cls")

# The validation split is wrapped with the same train_transform as the training split.
train_dataset_cls = TransformDatasetWrapper(train_cls_raw, train_transform)
val_dataset_cls = TransformDatasetWrapper(val_cls_raw, train_transform)

In [None]:
# Detection splits: load each split from data_log_dir, then wrap it so
# train_transform is attached to it.
train_det_raw = BirdDetection()
train_det_raw.load(data_log_dir, file_name="train_detection")

val_det_raw = BirdDetection()
val_det_raw.load(data_log_dir, file_name="validation_detection")

# The validation split is wrapped with the same train_transform as the training split.
train_dataset_det = TransformDatasetWrapper(train_det_raw, train_transform)
val_dataset_det = TransformDatasetWrapper(val_det_raw, train_transform)

In [None]:
# Detection dataset built directly from an image / annotation directory pair
# under data_dir, wrapped with the shared transform.
data_dir = "../dataset/abc"
images_dir = os.path.join(data_dir, "JPEGImages")
annotations_dir = os.path.join(data_dir, "Annotations")

large_dataset = TransformDatasetWrapper(
    BirdDetection(images_dir=images_dir, annotations_dir=annotations_dir),
    train_transform,
)

In [None]:
# Build the two-class RetinaNet without any pretrained weights; checkpoints
# are loaded into it explicitly in the sections below.
model = retinanet_resnet50_fpn(
    num_classes=2,
    pretrained=False,
    pretrained_backbone=False,
)

# Move the model to the selected device and switch it to evaluation mode.
model = model.to(device)
model.eval()

In [None]:
def get_res(model, inp):
    """Run ``model`` on one dataset sample and return its image-level result.

    Args:
        model: Model that is called as ``model([image], [label])`` and returns
            a ``(losses, bb_pred, cls_pred)`` triple. It is switched to
            evaluation mode before the call.
        inp: An ``(image, label)`` pair; ``label`` must support
            ``label["img_cls_labels"]``.

    Returns:
        A ``(logit, pred, label)`` tuple where ``logit`` is
        ``outputs_to_logits(cls_pred)``, ``pred`` is ``logits_to_preds(logit)``
        and ``label`` is ``inp[1]["img_cls_labels"]``.
    """
    model.eval()
    image, target = inp

    # Pure inference: disable autograd so no graph is built or kept alive for
    # every sample. The loss and box outputs are not needed here.
    with torch.no_grad():
        _, _, cls_pred = model([image], [target])
        logit = outputs_to_logits(cls_pred)

    pred = logits_to_preds(logit)
    return logit, pred, target["img_cls_labels"]

# def logits_to_preds(logits):
#     return (logits > torch.min(logits)).float()

def logits_to_preds(logits, threshold=0.5):
    """Binarize ``logits`` by comparing them against ``threshold``.

    Note: this definition replaces the ``logits_to_preds`` imported from
    ``retinanet.model.utils`` for the rest of the notebook.

    Args:
        logits: Tensor of scores.
        threshold: Values strictly greater than this map to 1.0, all other
            values map to 0.0. Defaults to 0.5.

    Returns:
        A float tensor of 0.0 / 1.0 values with the same shape as ``logits``.
    """
    return (logits > threshold).float()

def get_errors(model, dataset):
    """Return the indices of ``dataset`` samples that ``model`` gets wrong.

    Each sample is passed through ``get_res``; a sample counts as an error
    when its predictions are not element-wise equal to its label. Every
    error is also printed together with its logits, predictions and label.

    Args:
        model: Model forwarded to ``get_res``.
        dataset: Indexable collection supporting ``len()``.

    Returns:
        List of integer indices of the mismatching samples, in ascending order.
    """
    mismatched = []
    for index in range(len(dataset)):
        logits, preds, label = get_res(model, dataset[index])

        # Nothing to report when every element of the prediction matches.
        if torch.eq(preds, label).all():
            continue

        print(f"\n index: {index}")
        print(f"logits: {logits}")
        print(f"preds : {preds}")
        print(f"label : {label}")
        mismatched.append(index)
    return mismatched

## Visualize Detection from scratch

In [None]:
# Load this section's checkpoint into the model. map_location=device lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine too.
model.load_state_dict(
    torch.load(
        "/workspace8/RetinaNet/experiments/checkpoints/best_chpt_0_1_det_scratch.pth",
        map_location=device,
    )
)

In [None]:
# Pick one random validation sample.
idx = random.randint(0, len(val_dataset_det) - 1)
# Index the dataset once so the image and its boxes come from the same fetch
# and the sample is not loaded twice.
sample = val_dataset_det[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device)

In [None]:
# Run vis_features over every sample of the large dataset, passing a
# per-index output path.
ds = large_dataset
for idx in range(len(ds)):
    # Index the dataset once per iteration instead of three times
    # (iteration plus two lookups).
    sample = ds[idx]
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    # NOTE(review): if vis_features returns an open matplotlib figure,
    # consider closing it here to bound memory use — confirm.
    fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device, path=f"reports/det_scratch/{idx}.jpg")

## Visualize Detection transfer learning

In [None]:
# Load this section's checkpoint into the model. map_location=device lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine too.
model.load_state_dict(
    torch.load(
        "/workspace8/RetinaNet/experiments/checkpoints/best_chpt_0_2_det_transferlr.pth",
        map_location=device,
    )
)

In [None]:
# Pick one random sample from the detection training split.
ds = train_dataset_det
idx = random.randint(0, len(ds) - 1)
# Index the dataset once so the image and its boxes come from the same fetch
# and the sample is not loaded twice.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device)

In [None]:
# Run vis_features over every sample of the large dataset, passing a
# per-index output path.
ds = large_dataset
for idx in range(len(ds)):
    # Index the dataset once per iteration instead of three times
    # (iteration plus two lookups).
    sample = ds[idx]
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    # NOTE(review): if vis_features returns an open matplotlib figure,
    # consider closing it here to bound memory use — confirm.
    fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device, path=f"reports/det_transferlr/{idx}.jpg")

## Visualize Image Level Classifier from scratch

In [None]:
# Load this section's checkpoint into the model. map_location=device lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine too.
model.load_state_dict(
    torch.load(
        "/workspace8/RetinaNet/experiments/checkpoints/best_chpt_1_1_img_cls_scratch.pth",
        map_location=device,
    )
)

In [None]:
# Collect the indices of training classification samples whose thresholded
# predictions differ from their labels (get_errors also prints each mismatch).
err_indices = get_errors(model, train_dataset_cls)

In [None]:
# Number of mismatching samples collected in err_indices.
len(err_indices)

In [None]:
# random.randint includes both endpoints, so the upper bound has to be
# len - 1; using len(...) itself can produce an out-of-range index.
idx = random.randint(0, len(train_dataset_cls) - 1)
img = train_dataset_cls[idx][0]

In [None]:
# Pick a random entry of err_indices. random.randint includes both endpoints,
# so the upper bound has to be len - 1 to stay inside the list.
idx = random.randint(0, len(err_indices) - 1)
img = train_dataset_cls[err_indices[idx]][0]

In [None]:
# Run vis_features on the currently selected `img`; no gt_boxes are passed here.
fig = vis_features(model, img, threshold=0.2, device=device)

In [None]:
# Run vis_features_CAM on the currently selected `img`; no gt_boxes are passed here.
fig = vis_features_CAM(model, img, threshold=0.2, device=device)

### Visualize Fine-tuned Detection Task (starting from the from-scratch classifier)

In [None]:
# Load this section's checkpoint into the model. map_location=device lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine too.
model.load_state_dict(
    torch.load(
        "/workspace8/RetinaNet/experiments/checkpoints/best_chpt_1_2_ft_det_scratch.pth",
        map_location=device,
    )
)

In [None]:
# Pick one random validation sample.
idx = random.randint(0, len(val_dataset_det) - 1)
# Index the dataset once so the image and its boxes come from the same fetch
# and the sample is not loaded twice.
sample = val_dataset_det[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device)

In [None]:
# Run vis_features_CAM on the `img` / `gt_boxes` selected in the previous cell.
fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device)

In [None]:
# Run vis_features_CAM over every sample of the large dataset, passing a
# per-index output path.
ds = large_dataset
for idx in range(len(ds)):
    # Index the dataset once per iteration instead of three times
    # (iteration plus two lookups).
    sample = ds[idx]
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    # NOTE(review): if vis_features_CAM returns an open matplotlib figure,
    # consider closing it here to bound memory use — confirm.
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device, path=f"reports/ft_det_scratch/{idx}.jpg")

## Visualize Image Level Classifier with transfer learning

In [None]:
# Load this section's checkpoint into the model. map_location=device lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine too.
model.load_state_dict(
    torch.load(
        "/workspace8/RetinaNet/experiments/checkpoints/best_chpt_avg_2_1_img_cls_transferlr.pth",
        map_location=device,
    )
)

In [None]:
# Recompute the mismatching training-sample indices for the checkpoint loaded
# above (get_errors also prints each mismatch).
err_indices = get_errors(model, train_dataset_cls)

In [None]:
# random.randint includes both endpoints, so the upper bound has to be
# len - 1; using len(...) itself can produce an out-of-range index.
idx = random.randint(0, len(train_dataset_cls) - 1)
img = train_dataset_cls[idx][0]

In [None]:
# Pick a random entry of err_indices. random.randint includes both endpoints,
# so the upper bound has to be len - 1 to stay inside the list.
idx = random.randint(0, len(err_indices) - 1)
# Use the drawn index: this cell previously ignored `idx` and always showed
# sample 12, unlike the matching cell in the from-scratch section.
# NOTE(review): restore train_dataset_cls[12] if that sample was intentional.
img = train_dataset_cls[err_indices[idx]][0]

In [None]:
# Run vis_features on the currently selected `img`; no gt_boxes are passed here.
fig = vis_features(model, img, threshold=0.2, device=device)

In [None]:
# Run vis_features_CAM on the currently selected `img`; no gt_boxes are passed here.
fig = vis_features_CAM(model, img, threshold=0.2, device=device)

### Visualize Fine-tuned Detection Task (starting from the transfer-learning classifier)

In [None]:
# Load this section's checkpoint into the model. map_location=device lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine too.
model.load_state_dict(
    torch.load(
        "/workspace8/RetinaNet/experiments/checkpoints/best_chpt_2_2_ft_det_transferlr.pth",
        map_location=device,
    )
)

In [None]:
# Pick one random sample from the large dataset.
ds = large_dataset
idx = random.randint(0, len(ds) - 1)
# Index the dataset once so the image and its boxes come from the same fetch
# and the sample is not loaded twice.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.0, device=device)

In [None]:
# Run vis_features_CAM on the `img` / `gt_boxes` selected in the previous cell,
# with the threshold set to 0.
fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device)

In [None]:
# Run vis_features_CAM over every sample of the large dataset, passing a
# per-index output path.
ds = large_dataset
for idx in range(len(ds)):
    # Index the dataset once per iteration instead of three times
    # (iteration plus two lookups).
    sample = ds[idx]
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    # NOTE(review): if vis_features_CAM returns an open matplotlib figure,
    # consider closing it here to bound memory use — confirm.
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device, path=f"reports/ft_det_transferlr/{idx}.jpg")

## Visualize intermittent training of image-level classification and detection

In [None]:
# Load this section's checkpoint into the model. map_location=device lets a
# checkpoint that was saved from GPU tensors load on a CPU-only machine too.
model.load_state_dict(
    torch.load(
        "/workspace8/RetinaNet/experiments/checkpoints/chpt_3_0.cls_det_scratch.pth",
        map_location=device,
    )
)

In [None]:
# Pick one random sample from the large dataset.
ds = large_dataset
idx = random.randint(0, len(ds) - 1)
# Index the dataset once so the image and its boxes come from the same fetch
# and the sample is not loaded twice.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.0, device=device)

In [None]:
# Run vis_features_CAM on the `img` / `gt_boxes` selected in the previous cell,
# with the threshold set to 0.
fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device)

In [None]:
# Run vis_features_CAM over every sample of the large dataset, passing a
# per-index output path.
ds = large_dataset
for idx in range(len(ds)):
    # Index the dataset once per iteration instead of three times
    # (iteration plus two lookups).
    sample = ds[idx]
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    # NOTE(review): if vis_features_CAM returns an open matplotlib figure,
    # consider closing it here to bound memory use — confirm.
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device, path=f"reports/cls_det_scratch/{idx}.jpg")