# Demo

## Importing required libraries

In [101]:
import torch
from PIL import Image
from torchvision.transforms import transforms 
import pandas as pd
from torchvision.models import resnet50
import torch.nn as nn

In [102]:
torch.manual_seed(2023)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
detector_path = "data/models/bird_only.pth"
img_path = 'data/tiled_small_subset/DJI_20210517104803_0029_0_0.jpg'
num_class = 23

## Detection

In [103]:
# Load trained bird detector
detector = torch.load(detector_path, map_location=device)
detector.eval()

# Upload and transform image 
transformer = transforms.Compose([transforms.PILToTensor(),
                                  transforms.ConvertImageDtype(torch.float)])

image = Image.open(img_path).convert('RGB')
image_tensor = transformer(image)
image_tensor = image_tensor.unsqueeze_(0)  # So the image is treated as a batch 

# Detect birds in image
boxes = detector(image_tensor)
print(boxes)

# Have a table of coornidates of the bounding boxes
boxes_array = boxes[0]['boxes'].detach().numpy()
boxes_df = pd.DataFrame(boxes_array, columns=['x1', 'y1', 'x2', 'y2'])

# Extract the bounding boxes from the image
bird_images = []
for box in boxes_array:
    x1, y1, x2, y2 = box
    bird_images.append(image.crop((x1, y1, x2, y2)))

[{'boxes': tensor([[376.5843, 513.5974, 437.9249, 573.2222],
        [  0.0000, 292.9141,  23.3197, 377.2767],
        [ 68.1124, 465.9175, 143.2446, 495.8057]], grad_fn=<StackBackward0>), 'labels': tensor([1, 1, 1]), 'scores': tensor([0.9873, 0.1312, 0.0755], grad_fn=<IndexBackward0>)}]


## Classification

In [108]:
# Load bird classifier
resnet = resnet50(weights='DEFAULT')
resnet.fc = nn.Linear(resnet.fc.in_features, num_class)
resnet.eval()

# Transform images to be compatible with resnet
resnet_transormer = transforms.Compose([transforms.Resize((80, 80)),
                                        transforms.PILToTensor(),
                                        transforms.ConvertImageDtype(torch.float)])
                                            
bird_tensors = torch.stack([resnet_transormer(bird_image) for bird_image in bird_images])

# Classify birds
label_scores = resnet(bird_tensors)
labels = []
scores = []
for label_score in label_scores:
    label = label_score.argmax().item()
    score = label_score.max().item()
    labels.append(label)
    scores.append(score)
    # print(label, score)

# Add labels and scores to the table
boxes_df['label'] = labels
boxes_df['score'] = scores
print(boxes_df)
print(f"Number of birds detected: {len(boxes_df)}")

           x1          y1          x2          y2  label     score
0  376.584320  513.597412  437.924866  573.222229      1  0.584811
1    0.000000  292.914093   23.319675  377.276703     21  0.668366
2   68.112427  465.917542  143.244629  495.805725     13  0.711072
Number of birds detected: 3
