In [None]:
import torch
import torchvision
import torch.nn.functional as F


from retinanet.model.detection.retinanet import retinanet_resnet50_fpn
from retinanet.model.detection.transform import GeneralizedRCNNTransform

from retinanet.datasets.bird import BirdDetection
from retinanet.datasets.transforms import *
from retinanet.datasets.utils import TransformDatasetWrapper

import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
import os
import sys

# Location of the external VideoToolkit checkout. It can be overridden with the
# VIDEO_TOOLKIT_ROOT environment variable instead of editing this cell.
VIDEO_TOOLKIT_ROOT = os.environ.get("VIDEO_TOOLKIT_ROOT", "/workspace8/video_toolkit/")
if VIDEO_TOOLKIT_ROOT not in sys.path:  # keep the cell idempotent on re-runs
    sys.path.insert(0, VIDEO_TOOLKIT_ROOT)

from VideoToolkit.tools import rescal_to_image, get_cv_resize_function

# Resize helper; the visualisation cells use it to upsample feature maps to the image size.
resize_func = get_cv_resize_function()

In [None]:
def get_features(model, images, device=None):
    transform = GeneralizedRCNNTransform(800, 1333, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    images, _ = transform(images, None)
    
    # get the features from the backbone
    features = model.backbone(images.tensors.to(device))
    
    if isinstance(features, torch.Tensor):
        features = OrderedDict([("0", features)])

    features = list(features.values())
    features = [feat.mean(1) for feat in features]
    return features

In [None]:
# Seed every RNG the notebook draws from (random.randint picks the example
# images below) so that Restart & Run All shows the same examples.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device_str = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_str)
print("Torch Using device:", device)

# Per-sample preprocessing: tensor conversion (onto `device`) + mean/std normalisation.
transform = Compose(
    [
        ToTensor(device),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

raw_dataset = BirdDetection(image_dir="../dataset/data", annotations_dir="../dataset/ann")
dataset = TransformDatasetWrapper(raw_dataset, transform)

# 2-class RetinaNet with no pretrained weights; the cells below load trained checkpoints into it.
model = retinanet_resnet50_fpn(num_classes=2, pretrained=False, pretrained_backbone=False)
model = model.to(device)
model.eval();  # trailing ';' suppresses the very long model repr

## Train the detector from scratch

In [None]:
# Baseline detection run on the ../dataset data with no pretrained-weight flag
# (tag 0_1_det_scratch); --train_percent .7 is presumably the training split fraction.
# NOTE(review): the load cell below reads experiments/checkpoints/best_chpt_<tag>.pth,
# so the script presumably writes its best checkpoint under --log_dir/checkpoints.
!PYTHONPATH=$(pwd) python ./scripts/retinanet_train.py \
                            --lr 3e-5 \
                            --lr_delta 1e-5 \
                            --max_epoch 100 \
                            --batch_size 4 \
                            --tag 0_1_det_scratch \
                            --train_percent .7 \
                            --use_p_of_data 1 \
                            --data_dir $(pwd)/../dataset \
                            --log_dir $(pwd)/experiments

In [None]:
# Load the best checkpoint of the 0_1_det_scratch run. The path mirrors the training
# cell's `--log_dir $(pwd)/experiments` rather than a machine-specific absolute path.
# map_location lets a checkpoint saved from GPU tensors load when `device` is the CPU.
ckpt_path = os.path.join("experiments", "checkpoints", "best_chpt_0_1_det_scratch.pth")
model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Scratch-trained detector on one random image: detections drawn on the
# de-normalised input, plus channel-averaged backbone feature maps.
score_thresh = 0.2  # draw only detections scoring above this
nms_iou = 0.1       # IoU threshold for an extra NMS pass over the model output
pixel_mean = np.array([0.485, 0.456, 0.406])  # must match the dataset's Normalize
pixel_std = np.array([0.229, 0.224, 0.225])

idx = random.randint(0, len(dataset) - 1)
img = dataset[idx][0]

model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    # get features
    features = get_features(model, [img], device)
    imact = [feat.squeeze().cpu().numpy() for feat in features]

    # get predictions
    pred = model([img])[0]
    keep = torchvision.ops.nms(pred["boxes"], pred["scores"], nms_iou).cpu().numpy()
    boxes = np.floor(pred["boxes"].cpu().numpy()[keep]).astype(int)
    scores = pred["scores"].cpu().numpy()[keep]

# Undo Normalize for display and clip to [0, 1] so imshow does not clip/warn.
# numpy keeps the permuted (non C-contiguous) layout, so force a contiguous
# buffer before drawing with cv2.
img_vis = np.ascontiguousarray(
    np.clip(img.cpu().permute(1, 2, 0).numpy() * pixel_std + pixel_mean, 0.0, 1.0)
)

# visualize boxes; the image is float in [0, 1], so the colour is too (blue in RGB)
n_drawn = 0
for (x1, y1, x2, y2), score in zip(boxes, scores):
    if score > score_thresh:
        cv2.rectangle(img_vis, (int(x1), int(y1)), (int(x2), int(y2)), (0.0, 0.0, 1.0), 2)
        n_drawn += 1
print(f"{len(boxes)} boxes after NMS, {n_drawn} drawn (score > {score_thresh})")

fig, axarr = plt.subplots(2, 3, figsize=(15, 10))
axarr[0, 0].imshow(img_vis)
axarr[0, 0].set_title("detections")
# visualize features: the 2x3 grid holds the image plus up to 5 feature maps
for j, fmap in enumerate(imact[:5], start=1):
    ax = axarr[j // 3, j % 3]
    ax.imshow(resize_func(fmap, img_vis.shape[:2]))
    ax.set_title(f"feature map {j - 1}")

## Train the detector with transfer learning

In [None]:
# Same detection run as the scratch baseline except for the tag and the
# --pretrained_backend flag (presumably pretrained backbone weights -- confirm in the script).
# NOTE(review): check that the script's argparse defines `--pretrained_backend`; the
# model constructor keyword used in this notebook is `pretrained_backbone`, so this may be a typo.
!PYTHONPATH=$(pwd) python ./scripts/retinanet_train.py \
                            --lr 3e-5 \
                            --lr_delta 1e-5 \
                            --max_epoch 100 \
                            --batch_size 4 \
                            --tag 0_2_det_transferlr \
                            --pretrained_backend \
                            --train_percent .7 \
                            --use_p_of_data 1 \
                            --data_dir $(pwd)/../dataset \
                            --log_dir $(pwd)/experiments

In [None]:
# Load the best checkpoint of the 0_2_det_transferlr run. The path mirrors the training
# cell's `--log_dir $(pwd)/experiments` rather than a machine-specific absolute path.
# map_location lets a checkpoint saved from GPU tensors load when `device` is the CPU.
ckpt_path = os.path.join("experiments", "checkpoints", "best_chpt_0_2_det_transferlr.pth")
model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Transfer-learned detector on one random image: detections drawn on the
# de-normalised input, plus channel-averaged backbone feature maps.
score_thresh = 0.2  # draw only detections scoring above this
nms_iou = 0.1       # IoU threshold for an extra NMS pass over the model output
pixel_mean = np.array([0.485, 0.456, 0.406])  # must match the dataset's Normalize
pixel_std = np.array([0.229, 0.224, 0.225])

idx = random.randint(0, len(dataset) - 1)
img = dataset[idx][0]

model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    # get features
    features = get_features(model, [img], device)
    imact = [feat.squeeze().cpu().numpy() for feat in features]

    # get predictions
    pred = model([img])[0]
    keep = torchvision.ops.nms(pred["boxes"], pred["scores"], nms_iou).cpu().numpy()
    boxes = np.floor(pred["boxes"].cpu().numpy()[keep]).astype(int)
    scores = pred["scores"].cpu().numpy()[keep]

# Undo Normalize for display and clip to [0, 1] so imshow does not clip/warn.
# numpy keeps the permuted (non C-contiguous) layout, so force a contiguous
# buffer before drawing with cv2.
img_vis = np.ascontiguousarray(
    np.clip(img.cpu().permute(1, 2, 0).numpy() * pixel_std + pixel_mean, 0.0, 1.0)
)

# visualize boxes; the image is float in [0, 1], so the colour is too (blue in RGB)
n_drawn = 0
for (x1, y1, x2, y2), score in zip(boxes, scores):
    if score > score_thresh:
        cv2.rectangle(img_vis, (int(x1), int(y1)), (int(x2), int(y2)), (0.0, 0.0, 1.0), 2)
        n_drawn += 1
print(f"{len(boxes)} boxes after NMS, {n_drawn} drawn (score > {score_thresh})")

fig, axarr = plt.subplots(2, 3, figsize=(15, 10))
axarr[0, 0].imshow(img_vis)
axarr[0, 0].set_title("detections")
# visualize features: the 2x3 grid holds the image plus up to 5 feature maps
for j, fmap in enumerate(imact[:5], start=1):
    ax = axarr[j // 3, j % 3]
    ax.imshow(resize_func(fmap, img_vis.shape[:2]))
    ax.set_title(f"feature map {j - 1}")

## Train an image-level classifier from scratch

In [None]:
# Train the image-level classifier from scratch with SGD (tag 1_1_img_cls_scratch).
# Its best checkpoint is later passed as --pretrained to the 1_2_ft_det_scratch run.
# --accumulation_steps 3 with --batch_size 12 presumably gives an effective batch of 36 (confirm).
# NOTE(review): --data_dir is ../data/train here vs ../dataset for detection, and
# --weight_decay is 0.01 vs 1e-4 in the transfer-learning classifier run -- confirm both are intended.
!PYTHONPATH=$(pwd) python ./scripts/image_cls_train.py \
                        --opt sgd \
                        --lr 3e-4 \
                        --lr_delta 1e-6 \
                        --weight_decay 0.01 \
                        --max_epoch 100 \
                        --batch_size 12 \
                        --accumulation_steps 3 \
                        --tag 1_1_img_cls_scratch \
                        --train_percent .90 \
                        --use_p_of_data 1 \
                        --data_dir $(pwd)/../data/train \
                        --log_dir $(pwd)/experiments

In [None]:
# Load the best checkpoint of the 1_1_img_cls_scratch run. The path mirrors the training
# cell's `--log_dir $(pwd)/experiments` rather than a machine-specific absolute path.
# map_location lets a checkpoint saved from GPU tensors load when `device` is the CPU.
# NOTE(review): this checkpoint comes from the image-classifier script. If its keys do
# not match every RetinaNet module, load_state_dict (strict by default) raises --
# confirm, or load with strict=False and inspect the returned missing/unexpected keys.
ckpt_path = os.path.join("experiments", "checkpoints", "best_chpt_1_1_img_cls_scratch.pth")
model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Channel-averaged backbone feature maps of the scratch-trained image-level
# classifier on one random image.
pixel_mean = np.array([0.485, 0.456, 0.406])  # must match the dataset's Normalize
pixel_std = np.array([0.229, 0.224, 0.225])

idx = random.randint(0, len(dataset) - 1)
img = dataset[idx][0]

model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    features = get_features(model, [img], device)
imact = [feat.squeeze().cpu().numpy() for feat in features]

# Move to host memory first: ToTensor(device) may have placed the image on the
# GPU, and mixing a CUDA tensor with numpy arrays fails. Clip so imshow does not warn.
img_vis = np.clip(img.cpu().permute(1, 2, 0).numpy() * pixel_std + pixel_mean, 0.0, 1.0)

fig, axarr = plt.subplots(2, 3, figsize=(15, 10))
axarr[0, 0].imshow(img_vis)
axarr[0, 0].set_title("input")
# the 2x3 grid holds the image plus up to 5 feature maps
for j, fmap in enumerate(imact[:5], start=1):
    ax = axarr[j // 3, j % 3]
    ax.imshow(resize_func(fmap, img_vis.shape[:2]))
    ax.set_title(f"feature map {j - 1}")

### Fine-tune on the detection task
#### (initialised from the scratch-trained classifier)

In [None]:
# Fine-tune the detector starting from the scratch-trained classifier's best
# checkpoint (--pretrained). The other hyper-parameters match the detection baseline.
!PYTHONPATH=$(pwd) python ./scripts/retinanet_train.py \
                            --lr 3e-5 \
                            --lr_delta 1e-5 \
                            --max_epoch 100 \
                            --batch_size 4 \
                            --tag 1_2_ft_det_scratch \
                            --pretrained $(pwd)/experiments/checkpoints/best_chpt_1_1_img_cls_scratch.pth \
                            --train_percent .7 \
                            --use_p_of_data 1 \
                            --data_dir $(pwd)/../dataset \
                            --log_dir $(pwd)/experiments

In [None]:
# Load the best checkpoint of the 1_2_ft_det_scratch run. The path mirrors the training
# cell's `--log_dir $(pwd)/experiments` rather than a machine-specific absolute path.
# map_location lets a checkpoint saved from GPU tensors load when `device` is the CPU.
ckpt_path = os.path.join("experiments", "checkpoints", "best_chpt_1_2_ft_det_scratch.pth")
model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Channel-averaged backbone feature maps of the detector fine-tuned from the
# scratch-trained classifier, on one random image.
pixel_mean = np.array([0.485, 0.456, 0.406])  # must match the dataset's Normalize
pixel_std = np.array([0.229, 0.224, 0.225])

idx = random.randint(0, len(dataset) - 1)
img = dataset[idx][0]

model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    features = get_features(model, [img], device)
imact = [feat.squeeze().cpu().numpy() for feat in features]

# Move to host memory first: ToTensor(device) may have placed the image on the
# GPU, and mixing a CUDA tensor with numpy arrays fails. Clip so imshow does not warn.
img_vis = np.clip(img.cpu().permute(1, 2, 0).numpy() * pixel_std + pixel_mean, 0.0, 1.0)

fig, axarr = plt.subplots(2, 3, figsize=(15, 10))
axarr[0, 0].imshow(img_vis)
axarr[0, 0].set_title("input")
# the 2x3 grid holds the image plus up to 5 feature maps
for j, fmap in enumerate(imact[:5], start=1):
    ax = axarr[j // 3, j % 3]
    ax.imshow(resize_func(fmap, img_vis.shape[:2]))
    ax.set_title(f"feature map {j - 1}")

## Train an image-level classifier with transfer learning

In [None]:
# Same classifier recipe as the scratch run (tag 2_1_img_cls_transferlr), but with
# --pretrained_backend and a lower weight decay (1e-4 instead of 0.01).
# NOTE(review): confirm `--pretrained_backend` is the flag name the script defines
# (the model constructor keyword used in this notebook is `pretrained_backbone`).
!PYTHONPATH=$(pwd) python ./scripts/image_cls_train.py \
                            --opt sgd \
                            --lr 3e-4 \
                            --lr_delta 1e-6 \
                            --weight_decay 1e-4 \
                            --max_epoch 100 \
                            --batch_size 12 \
                            --accumulation_steps 3 \
                            --tag 2_1_img_cls_transferlr \
                            --pretrained_backend \
                            --train_percent .90 \
                            --use_p_of_data 1 \
                            --data_dir $(pwd)/../data/train \
                            --log_dir $(pwd)/experiments

In [None]:
# Load the best checkpoint of the 2_1_img_cls_transferlr run. The path mirrors the training
# cell's `--log_dir $(pwd)/experiments` rather than a machine-specific absolute path.
# map_location lets a checkpoint saved from GPU tensors load when `device` is the CPU.
# NOTE(review): this checkpoint comes from the image-classifier script. If its keys do
# not match every RetinaNet module, load_state_dict (strict by default) raises --
# confirm, or load with strict=False and inspect the returned missing/unexpected keys.
ckpt_path = os.path.join("experiments", "checkpoints", "best_chpt_2_1_img_cls_transferlr.pth")
model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Channel-averaged backbone feature maps of the transfer-learned image-level
# classifier on one random image.
pixel_mean = np.array([0.485, 0.456, 0.406])  # must match the dataset's Normalize
pixel_std = np.array([0.229, 0.224, 0.225])

idx = random.randint(0, len(dataset) - 1)
img = dataset[idx][0]

model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    features = get_features(model, [img], device)
imact = [feat.squeeze().cpu().numpy() for feat in features]

# Move to host memory first: ToTensor(device) may have placed the image on the
# GPU, and mixing a CUDA tensor with numpy arrays fails. Clip so imshow does not warn.
img_vis = np.clip(img.cpu().permute(1, 2, 0).numpy() * pixel_std + pixel_mean, 0.0, 1.0)

fig, axarr = plt.subplots(2, 3, figsize=(15, 10))
axarr[0, 0].imshow(img_vis)
axarr[0, 0].set_title("input")
# the 2x3 grid holds the image plus up to 5 feature maps
for j, fmap in enumerate(imact[:5], start=1):
    ax = axarr[j // 3, j % 3]
    ax.imshow(resize_func(fmap, img_vis.shape[:2]))
    ax.set_title(f"feature map {j - 1}")

### Fine-tune on the detection task
#### (initialised from the transfer-learned classifier)

In [None]:
# Fine-tune the detector starting from the transfer-learned classifier's best
# checkpoint (--pretrained). The other hyper-parameters match the detection baseline.
!PYTHONPATH=$(pwd) python ./scripts/retinanet_train.py \
                            --lr 3e-5 \
                            --lr_delta 1e-5 \
                            --max_epoch 100 \
                            --batch_size 4 \
                            --tag 2_2_ft_det_transferlr \
                            --pretrained $(pwd)/experiments/checkpoints/best_chpt_2_1_img_cls_transferlr.pth \
                            --train_percent .7 \
                            --use_p_of_data 1 \
                            --data_dir $(pwd)/../dataset \
                            --log_dir $(pwd)/experiments

In [None]:
# Load the best checkpoint of the 2_2_ft_det_transferlr run. The path mirrors the training
# cell's `--log_dir $(pwd)/experiments` rather than a machine-specific absolute path.
# map_location lets a checkpoint saved from GPU tensors load when `device` is the CPU.
ckpt_path = os.path.join("experiments", "checkpoints", "best_chpt_2_2_ft_det_transferlr.pth")
model.load_state_dict(torch.load(ckpt_path, map_location=device))

In [None]:
# Detector fine-tuned from the transfer-learned classifier, on one random image:
# detections drawn on the de-normalised input, plus channel-averaged feature maps.
score_thresh = 0.3  # draw only detections scoring above this
nms_iou = 0.1       # IoU threshold for an extra NMS pass over the model output
pixel_mean = np.array([0.485, 0.456, 0.406])  # must match the dataset's Normalize
pixel_std = np.array([0.229, 0.224, 0.225])

idx = random.randint(0, len(dataset) - 1)
img = dataset[idx][0]

model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    # get features
    features = get_features(model, [img], device)
    imact = [feat.squeeze().cpu().numpy() for feat in features]

    # get predictions
    pred = model([img])[0]
    keep = torchvision.ops.nms(pred["boxes"], pred["scores"], nms_iou).cpu().numpy()
    boxes = np.floor(pred["boxes"].cpu().numpy()[keep]).astype(int)
    scores = pred["scores"].cpu().numpy()[keep]

# Undo Normalize for display and clip to [0, 1] so imshow does not clip/warn.
# numpy keeps the permuted (non C-contiguous) layout, so force a contiguous
# buffer before drawing with cv2.
img_vis = np.ascontiguousarray(
    np.clip(img.cpu().permute(1, 2, 0).numpy() * pixel_std + pixel_mean, 0.0, 1.0)
)

# visualize boxes; the image is float in [0, 1], so the colour is too (blue in RGB)
n_drawn = 0
for (x1, y1, x2, y2), score in zip(boxes, scores):
    if score > score_thresh:
        cv2.rectangle(img_vis, (int(x1), int(y1)), (int(x2), int(y2)), (0.0, 0.0, 1.0), 2)
        n_drawn += 1
print(f"{len(boxes)} boxes after NMS, {n_drawn} drawn (score > {score_thresh})")

fig, axarr = plt.subplots(2, 3, figsize=(15, 10))
axarr[0, 0].imshow(img_vis)
axarr[0, 0].set_title("detections")
# visualize features: the 2x3 grid holds the image plus up to 5 feature maps
for j, fmap in enumerate(imact[:5], start=1):
    ax = axarr[j // 3, j % 3]
    ax.imshow(resize_func(fmap, img_vis.shape[:2]))
    ax.set_title(f"feature map {j - 1}")