In [None]:
# Generate the .rst report: run scripts/report.py in a shell (IPython `!`),
# with the current working directory exported as PYTHONPATH for that command.
!PYTHONPATH=$(pwd) python ./scripts/report.py

In [None]:
# Turn off Jedi-based completion in the IPython completer.
%config Completer.use_jedi = False

In [None]:
# Put the VideoToolkit checkout first on sys.path so the import below resolves.
# The path is hard-coded and absolute; adjust it per machine.
import sys
sys.path.insert(0, "/workspace8/video_toolkit/")
from VideoToolkit.tools import rescal_to_image, get_cv_resize_function
# Resize helper obtained from VideoToolkit; not referenced by any other visible cell.
resize_func = get_cv_resize_function()

In [None]:
import torch
from torch import nn
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset


from retinanet.model.detection.retinanet import retinanet_resnet50_fpn
from retinanet.model.utils import outputs_to_logits, logits_to_preds

from retinanet.datasets.bird import BirdDetection, BirdClassification
from retinanet.datasets.transforms import *
from retinanet.datasets.utils import TransformDatasetWrapper, train_val_split

from retinanet.model.utils import load_chpt

from retinanet.utils.visualizatioin import vis_features, vis_features_CAM
from retinanet.utils import create_directory

import os
import cv2
import random
import numpy as np

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pylab import cm

%matplotlib inline

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_str = "cuda"
else:
    device_str = "cpu"
device = torch.device(device_str)
print("Torch Using device:", device)

In [None]:
# Directory holding the serialized dataset splits loaded by the following cells.
data_log_dir = "/workspace8/RetinaNet/experiments/dataset"

# Per-channel mean / std handed to Normalize (these match the widely used
# ImageNet channel statistics).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Shared preprocessing pipeline: tensor conversion on `device`, then normalization.
train_transform = Compose([ToTensor(device), Normalize(norm_mean, norm_std)])

In [None]:
# Load the serialized image-level classification splits.
raw_train_cls = BirdClassification()
raw_train_cls.load(data_log_dir, file_name="train_cls")
raw_val_cls = BirdClassification()
raw_val_cls.load(data_log_dir, file_name="validation_cls")

# Wrap both splits with the shared preprocessing pipeline.
train_dataset_cls = TransformDatasetWrapper(raw_train_cls, train_transform)
val_dataset_cls = TransformDatasetWrapper(raw_val_cls, train_transform)

In [None]:
# Load the serialized detection splits.
raw_train_det = BirdDetection()
raw_train_det.load(data_log_dir, file_name="train_detection")
raw_val_det = BirdDetection()
raw_val_det.load(data_log_dir, file_name="validation_detection")

# Wrap both splits with the shared preprocessing pipeline.
train_dataset_det = TransformDatasetWrapper(raw_train_det, train_transform)
val_dataset_det = TransformDatasetWrapper(raw_val_det, train_transform)

In [None]:
# Detection datasets built directly from the labeled-data folders; images and
# annotations are read from the same directory for each split.
data_dir = "../data/labeled_data/"

train_dir = os.path.join(data_dir, "train")
raw_train_large = BirdDetection(images_dir=train_dir, annotations_dir=train_dir)
train_large_dataset = TransformDatasetWrapper(raw_train_large, train_transform)

test_dir = os.path.join(data_dir, "test")
raw_test_large = BirdDetection(images_dir=test_dir, annotations_dir=test_dir)
test_large_dataset = TransformDatasetWrapper(raw_test_large, train_transform)

In [None]:
# Build an uninitialized (no pretrained weights) RetinaNet with the extra "cls"
# head, move it to `device`, and switch it to evaluation mode.
model = retinanet_resnet50_fpn(
    num_classes=2,
    pretrained=False,
    pretrained_backbone=False,
    extra_heads=["cls"],
).to(device)
model.eval()

In [None]:
def precision_recall_curve(model, dataset, at_iou = 0.2, at_nms = 0.5):
    """Compute a precision/recall curve and average precision for a detector.

    Each sample of ``dataset`` is run through ``model`` one image at a time.
    Detections are filtered with batched NMS, ranked by score and greedily
    matched to ground-truth boxes: a detection is a true positive when its
    best-overlapping ground-truth box has IoU strictly above ``at_iou`` and that
    box was not already claimed by a higher-scoring detection of the same
    image; every other detection is a false positive.

    Args:
        model: detector with an ``eval()`` method; ``model([image])`` must return
            a list whose first element is a dict with ``"boxes"``, ``"scores"``
            and ``"labels"``.
        dataset: iterable of ``(image, target)`` pairs; ``target["boxes"]`` holds
            the ground-truth boxes of that image.
        at_iou: IoU threshold above which a detection counts as a match.
        at_nms: IoU threshold handed to ``torchvision.ops.batched_nms``.

    Returns:
        ``(precision, recall, ap)`` as NumPy values. ``precision`` and ``recall``
        are 1-D arrays with one entry per kept detection (descending score);
        ``ap`` is a 0-d array. All three are empty / zero when nothing was
        detected.
    """
    model.eval()
    score_list = []
    pseudo_y_true_list = []
    total_gt = 0

    with torch.no_grad():
        for image, target in dataset:
            detections = model([image])[0]
            boxes = detections["boxes"]
            scores = detections["scores"]
            labels = detections["labels"]
            gt_boxes = target["boxes"]
            total_gt += gt_boxes.shape[0]

            # No detections: the image only contributes missed ground truth
            # (via total_gt); it must not add a fake false positive.
            if boxes.shape[0] == 0:
                continue

            # remove overlapping boxes
            keep = torchvision.ops.batched_nms(boxes, scores, labels, at_nms)
            boxes, scores = boxes[keep], scores[keep]

            # Rank explicitly so boxes, scores and the labels built below stay
            # aligned regardless of the order NMS returned.
            order = torch.argsort(scores, descending=True)
            boxes, scores = boxes[order], scores[order]

            # 1 = true positive, 0 = false positive (per detection).
            # NOTE(review): matching ignores class labels; they are only used by NMS.
            pseudo_y_true = torch.zeros_like(scores)
            if gt_boxes.shape[0] > 0:
                iou = torchvision.ops.box_iou(boxes, gt_boxes)
                best_iou, best_gt = iou.max(dim=1)

                # A ground-truth box can be claimed once; later (lower-scoring)
                # detections of the same box stay false positives.
                claimed = set()
                for det_idx in torch.where(best_iou > at_iou)[0].tolist():
                    gt_idx = best_gt[det_idx].item()
                    if gt_idx not in claimed:
                        claimed.add(gt_idx)
                        pseudo_y_true[det_idx] = 1

            pseudo_y_true_list.append(pseudo_y_true)
            score_list.append(scores)

        if score_list:
            scores = torch.cat(score_list)
            pseudo_y_true = torch.cat(pseudo_y_true_list).float()
        else:
            # Empty dataset or no detections at all.
            scores = torch.zeros(0)
            pseudo_y_true = torch.zeros(0)

        # Sort all detections by score so cumulative sums give the TP / FP
        # counts at every score threshold.
        order = torch.argsort(scores, descending=True)
        pseudo_y_true = pseudo_y_true[order]

        tps = pseudo_y_true.cumsum(0)
        fps = (1 - pseudo_y_true).cumsum(0)

        # tps + fps >= 1 for every entry, so no division by zero here.
        precision = tps / (tps + fps)
        if total_gt > 0:
            recall = tps / total_gt
        else:
            recall = torch.zeros_like(tps)

        # AP = sum_n (R_n - R_{n-1}) * P_n with R_0 = 0, so the first detection
        # contributes and each recall step uses the precision at that step.
        recall_steps = torch.diff(recall, prepend=recall.new_zeros(1))
        ap = (recall_steps * precision).sum()

    return precision.cpu().numpy(), recall.cpu().numpy(), ap.cpu().numpy()

def plot_pr_rc_curve(model, dataset, title):
    """Plot precision/recall curves of ``model`` on ``dataset`` for several IoU thresholds.

    One curve is drawn per IoU threshold in ``[0.2, 0.5, 0.75, 0.9, 0.95]``;
    each legend entry shows the corresponding average precision. The dataset
    is re-evaluated once per threshold via ``precision_recall_curve``.

    Args:
        model: detector forwarded to ``precision_recall_curve``.
        dataset: dataset forwarded to ``precision_recall_curve``.
        title: figure title.
    """
    # Create figure and add axes object
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1])
    ax.set_title(title)
    # Curves are drawn as (recall, precision): recall is the x axis.
    ax.set_xlabel('Recall', labelpad=10, fontsize=16)
    ax.set_ylabel('Precision', labelpad=10, fontsize=16)
    for at_iou in [0.2, 0.5, 0.75, 0.9, 0.95]:
        precision, recall, ap = precision_recall_curve(model, dataset, at_iou, at_nms=0.5)
        ax.plot(recall, precision, label=f"AP@IOU{at_iou:.2f}:{float(ap):.3f}")
    # Add legend to plot
    ax.legend(bbox_to_anchor=(1, 1), loc=1, frameon=False)
    plt.show()

In [None]:
def get_res(model, inp):
    """Run the image-level classifier on one sample.

    Args:
        model: network called as ``model([image], [label])``; the third element
            of its return value is passed to ``outputs_to_logits``.
        inp: ``(image, label)`` pair where ``label["img_cls_labels"]`` holds the
            image-level class labels.

    Returns:
        ``(logit, pred, label)``: the output of ``outputs_to_logits``, the
        thresholded prediction from ``logits_to_preds``, and the
        ``"img_cls_labels"`` entry of the sample's label dict.
    """
    model.eval()
    image, label = inp

    # Evaluation only: skip autograd bookkeeping so looping over a whole
    # dataset does not keep computation graphs alive.
    with torch.no_grad():
        _, _, cls_pred = model([image], [label])
        logit = outputs_to_logits(cls_pred)
        pred = logits_to_preds(logit)

    return logit, pred, label["img_cls_labels"]

# def logits_to_preds(logits):
#     return (logits > torch.min(logits)).float()

def logits_to_preds(logits, threshold=0.5):
    """Binarize ``logits``: 1.0 where a value is strictly above ``threshold``, else 0.0.

    This definition shadows the ``logits_to_preds`` imported from
    ``retinanet.model.utils`` for every later call in the notebook.

    Args:
        logits: tensor of scores.
        threshold: decision threshold (default 0.5).
            NOTE(review): 0.5 presumes the scores are already probabilities
            rather than raw logits; confirm against ``outputs_to_logits``.

    Returns:
        Float tensor with the same shape as ``logits``.
    """
    return (logits > threshold).float()

def get_errors(model, dataset):
    """Return the indices of samples whose predictions differ from their labels.

    Every sample ``dataset[i]`` is evaluated with ``get_res``; when any element
    of the prediction disagrees with the label, the index, logits, prediction
    and label are printed and the index is recorded.

    Args:
        model: network forwarded to ``get_res``.
        dataset: indexable dataset supporting ``len()``.

    Returns:
        List of mismatching sample indices, in dataset order.
    """
    mismatched = []
    for sample_idx in range(len(dataset)):
        logits, preds, label = get_res(model, dataset[sample_idx])

        # Fully correct prediction: nothing to report.
        if torch.eq(preds, label).all():
            continue

        print(f"\n index: {sample_idx}")
        print(f"logits: {logits}")
        print(f"preds : {preds}")
        print(f"label : {label}")
        mismatched.append(sample_idx)
    return mismatched

## Visualize Detection from scratch

In [None]:
# Load checkpoint 0_1 (detection from scratch) through the project's load_chpt helper.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "best_chpt_0_1.det_scratch_test.pth",
)

In [None]:
# Visualize one random validation sample together with its ground-truth boxes.
idx = random.randint(0, len(val_dataset_det)-1)
# Fetch the sample once instead of indexing the dataset separately for the
# image and for the boxes.
sample = val_dataset_det[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device)

In [None]:
# Write one visualization per test image to reports/det_scratch_test/.
# The sample yielded by the iteration is reused, so every item is loaded once
# rather than three times (iteration plus two ds[idx] lookups).
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.0, thickness=1, device=device, path=f"reports/det_scratch_test/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
exp = "0_1.det_scratch_test"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

## Visualize Detection transfer learning

In [None]:
# Load checkpoint 0_2 (detection with transfer learning) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "best_chpt_0_2.det_transferlr_test.pth",
)

In [None]:
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split (swap in train_dataset_det to inspect training samples).
ds = test_large_dataset
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.0, device=device, vis_boxes=True)

In [None]:
# Write one visualization per test image to reports/det_transferlr_test/.
# The sample yielded by the iteration is reused, so every item is loaded once
# rather than three times (iteration plus two ds[idx] lookups).
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.0, thickness=1, device=device, path=f"reports/det_transferlr_test/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
exp = "0_2.det_transferlr_test"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

## Visualize Image Level Classifier from scratch

In [None]:
# Load checkpoint 1_1 (image-level classifier from scratch) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "best_chpt_1_1.img_cls_scratch_test.pth",
)

In [None]:
# Indices of training samples that the loaded classifier gets wrong
# (get_errors also prints details for each of them).
err_indices = get_errors(model=model, dataset=train_dataset_cls)

In [None]:
# Number of misclassified training samples collected in `err_indices`.
len(err_indices)

In [None]:
# Pick a random training sample. random.randint includes both endpoints, so an
# upper bound of len(...) could yield an out-of-range index; randrange excludes it.
idx = random.randrange(len(train_dataset_cls))
img = train_dataset_cls[idx][0]

In [None]:
# Pick a random misclassified sample. randrange keeps the position inside
# err_indices (random.randint's inclusive upper bound could overshoot by one);
# it raises ValueError when err_indices is empty, i.e. there is nothing to show.
idx = random.randrange(len(err_indices))
img = train_dataset_cls[err_indices[idx]][0]

In [None]:
# Feature visualization (project helper vis_features) for the selected `img`.
fig = vis_features(
    model,
    img,
    threshold=0.2,
    device=device,
)

In [None]:
# CAM visualization (project helper vis_features_CAM) for the selected `img`.
fig = vis_features_CAM(
    model,
    img,
    threshold=0.2,
    device=device,
)

### Visualize fine-tuned detection task (on top of the from-scratch classifier)

In [None]:
# Load checkpoint 1_2 (detection fine-tuned, scratch variant) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "best_chpt_1_2.ft_det_scratch_test.pth",
)

In [None]:
# Visualize one random validation sample together with its ground-truth boxes.
idx = random.randint(0, len(val_dataset_det)-1)
# Fetch the sample once instead of indexing the dataset separately for the
# image and for the boxes.
sample = val_dataset_det[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device)

In [None]:
# CAM visualization for the selected `img` and its ground-truth boxes.
fig = vis_features_CAM(
    model,
    img,
    gt_boxes=gt_boxes,
    threshold=0.2,
    device=device,
)

In [None]:
# Write one CAM visualization per test image to reports/ft_det_scratch_test/.
# The sample yielded by the iteration is reused, so every item is loaded once
# rather than three times (iteration plus two ds[idx] lookups).
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.2, thickness=1, device=device, path=f"reports/ft_det_scratch_test/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
exp = "1_2.ft_det_scratch_test"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

## Visualize Image Level Classifier with transfer learning

In [None]:
# Load checkpoint 2_1 (image-level classifier with transfer learning) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "best_chpt_2_1_img_cls_transferlr.pth",
)

In [None]:
# Indices of training samples that the loaded classifier gets wrong
# (get_errors also prints details for each of them).
err_indices = get_errors(model=model, dataset=train_dataset_cls)

In [None]:
# Pick a random training sample. random.randint includes both endpoints, so an
# upper bound of len(...) could yield an out-of-range index; randrange excludes it.
idx = random.randrange(len(train_dataset_cls))
img = train_dataset_cls[idx][0]

In [None]:
# Pick a random misclassified sample, as in the from-scratch classifier section.
# The drawn position is actually used (it was previously ignored in favour of a
# hard-coded sample 12), and randrange keeps it inside err_indices; it raises
# ValueError when err_indices is empty.
idx = random.randrange(len(err_indices))
img = train_dataset_cls[err_indices[idx]][0]

In [None]:
# Feature visualization (project helper vis_features) for the selected `img`.
fig = vis_features(
    model,
    img,
    threshold=0.2,
    device=device,
)

In [None]:
# CAM visualization (project helper vis_features_CAM) for the selected `img`.
fig = vis_features_CAM(
    model,
    img,
    threshold=0.2,
    device=device,
)

### Visualize fine-tuned detection task (on top of the transfer-learning classifier)

In [None]:
# Load checkpoint 2_2 (detection fine-tuned, transfer-learning variant) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "best_chpt_2_2.ft_det_transferlr_test.pth",
)

In [None]:
# Show the file-name entry at position `idx` of the dataset wrapped by `ds`.
# NOTE(review): neither `ds` nor `idx` is set in this cell; both are whatever an
# earlier cell left behind, and `ds` is only reassigned in the next cell.
# Confirm the intended cell order.
ds.dataset.files_name[idx]

In [None]:
# CAM visualization of one random validation sample with its ground-truth boxes.
ds = val_dataset_det
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.0, device=device)

In [None]:
# CAM visualization for the selected `img` and its ground-truth boxes.
fig = vis_features_CAM(
    model,
    img,
    gt_boxes=gt_boxes,
    threshold=0,
    device=device,
)

In [None]:
# Write one CAM visualization per test image to reports/ft_det_transferlr_test/.
# The sample yielded by the iteration is reused, so every item is loaded once
# rather than three times (iteration plus two ds[idx] lookups).
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.0, thickness=1, device=device, path=f"reports/ft_det_transferlr_test/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
exp = "2_2.ft_det_transferlr_test"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

## Visualize intermittent training of image level cls and detection

In [None]:
# Load checkpoint 3_0 (intermittent cls/detection training, scratch) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "chpt_3_0.intermitent_cls_det_scratch_test3.pth",
)

In [None]:
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split.
ds = test_large_dataset
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.0, device=device)

In [None]:
# CAM visualization for the selected `img` and its ground-truth boxes.
fig = vis_features_CAM(
    model,
    img,
    gt_boxes=gt_boxes,
    threshold=0,
    device=device,
)

In [None]:
# Write one CAM visualization per image to reports/cls_det_scratch/.
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split. Each yielded sample is reused so every item is loaded once.
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device, path=f"reports/cls_det_scratch/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
exp = "3_0.intermitent_cls_det_scratch_test3"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

#### transfer learning

In [None]:
# Load checkpoint 3_0 (intermittent cls/detection training, transfer learning) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "chpt_3_0.intermitent_cls_det_transferlr_test.pth",
)

In [None]:
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split.
ds = test_large_dataset
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features(model, img, gt_boxes=gt_boxes, threshold=0.0, device=device)

In [None]:
# CAM visualization for the selected `img` and its ground-truth boxes.
fig = vis_features_CAM(
    model,
    img,
    gt_boxes=gt_boxes,
    threshold=0,
    device=device,
)

In [None]:
# Write one CAM visualization per image to reports/cls_det_transferlr/.
# The output folder is specific to the transfer-learning run so these files do
# not overwrite the from-scratch run's images in reports/cls_det_scratch/.
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split. Each yielded sample is reused so every item is loaded once.
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device, path=f"reports/cls_det_transferlr/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
exp = "3_0.intermitent_cls_det_transferlr_test"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

## Visualize training with autoencoder and image level cls and detection

#### from scratch

In [None]:
# Load checkpoint 4_1 (image cls + regeneration, scratch) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "chpt_4_1.img_cls_regen_scratch.pth",
)

In [None]:
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split.
ds = test_large_dataset
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device)

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
# The label names the checkpoint that is actually loaded in this section
# (chpt_4_1.img_cls_regen_scratch), so the plot titles identify the right run.
exp = "4_1.img_cls_regen_scratch"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

##### detection

In [None]:
# Load checkpoint 4_2 (detection fine-tuned, scratch) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "chpt_4_2.ft_det_scratch.pth",
)

In [None]:
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split.
ds = test_large_dataset
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.1, device=device)

In [None]:
# Write one CAM visualization per image to reports/cls_regen_det_scratch/.
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split. Each yielded sample is reused so every item is loaded once.
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device, path=f"reports/cls_regen_det_scratch/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
# The label names the checkpoint that is actually loaded in this section
# (chpt_4_2.ft_det_scratch), so the plot titles identify the right run.
exp = "4_2.ft_det_scratch"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")

#### transfer learning

In [None]:
# Load checkpoint 5_1 (image cls + regeneration, transfer learning) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "chpt_5_1.img_cls_regen_transferlr.pth",
)

In [None]:
# Snapshot the image-classification head's fc weights of the current checkpoint.
# A detached clone is taken because a bare reference to the Parameter would
# follow any later in-place weight load into the same `model`, which would make
# a subsequent equality check against this value trivially true.
weight = model.head.extra_heads.image_classification_head.fc.weight.detach().clone()

In [None]:
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split.
ds = test_large_dataset
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0, device=device)

##### detection

In [None]:
# Load checkpoint 5_2 (detection fine-tuned, transfer learning) through load_chpt.
model = load_chpt(
    model,
    "/workspace8/RetinaNet/experiments/checkpoints/"
    "chpt_5_2.ft_det_transferlr.pth",
)

In [None]:
# Check whether the classification head still holds the weights saved in
# `weight`; the classification head is required for CAM visualizations.
(weight == model.head.extra_heads.image_classification_head.fc.weight).all()

In [None]:
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split.
ds = test_large_dataset
idx = random.randint(0, len(ds)-1)
# Fetch the sample once and take image and boxes from it.
sample = ds[idx]
img = sample[0]
gt_boxes = sample[1]["boxes"]

fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.0, device=device)

In [None]:
# Write one CAM visualization per image to reports/cls_regen_det_transferlr/.
# `large_dataset` is not defined anywhere in the notebook; use the held-out
# large test split. Each yielded sample is reused so every item is loaded once.
ds = test_large_dataset
for idx, sample in enumerate(ds):
    img = sample[0]
    gt_boxes = sample[1]["boxes"]
    fig = vis_features_CAM(model, img, gt_boxes=gt_boxes, threshold=0.2, device=device, path=f"reports/cls_regen_det_transferlr/{idx}.jpg")

In [None]:
# Precision/recall curves of the loaded checkpoint on each detection split.
# The label names the checkpoint that is actually loaded in this section
# (chpt_5_2.ft_det_transferlr), so the plot titles identify the right run.
exp = "5_2.ft_det_transferlr"
for split_name, split_dataset in (
    ("train_dataset_det", train_dataset_det),
    ("val_dataset_det", val_dataset_det),
    ("test_large_dataset", test_large_dataset),
):
    plot_pr_rc_curve(model, split_dataset, f"EXP: {exp}  |  DS: {split_name}")