# First test to see if pretrained models are already good enough to classify our DarWild animals

In [None]:
import cv2
import numpy as np
from PIL import Image

## Check the resolution of the clips

In [None]:
## doesn't work in VSCode interactive
# def display_frame(frame):
#     cv2.imshow("Frame", frame)
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()

def extract_frames(video_path):
    """Read every frame of a video file into memory.

    Args:
        video_path: Path of the video file, handed to ``cv2.VideoCapture``.

    Returns:
        A list with one entry per frame, in the order ``cap.read()``
        delivered them. The list is empty when no frame could be read.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []

    # Release the capture handle even if reading or appending raises
    # (for example a MemoryError on a long clip).
    try:
        while True:
            ret, frame = cap.read()

            # ``ret`` is falsy once no further frame can be read.
            if not ret:
                break

            frames.append(frame)
    finally:
        cap.release()

    return frames


def crop_frames(frame, pixel_frombottom=100):
    """Cut a horizontal strip off the bottom of a frame.

    Args:
        frame: Array whose first axis is the image height (rows). Any
            further axes (width, channels) are kept untouched.
        pixel_frombottom: Number of rows to remove from the bottom edge.
            ``0`` leaves the frame uncropped; a value of at least the
            frame height yields a result with zero rows.

    Returns:
        The frame sliced down to its remaining top rows.

    Raises:
        ValueError: If ``pixel_frombottom`` is negative.
    """
    if pixel_frombottom < 0:
        raise ValueError(
            f'pixel_frombottom must be non-negative, got {pixel_frombottom}'
        )

    # Slice up to an explicit end row: a negative offset would turn
    # ``pixel_frombottom=0`` into ``frame[:0]``, i.e. an empty frame.
    rows_kept = max(frame.shape[0] - pixel_frombottom, 0)
    return frame[:rows_kept]

In [None]:
def get_frames(video_path):
    """Load all frames of a clip and report how many were found.

    Args:
        video_path: Path of the video file, forwarded to ``extract_frames``.

    Returns:
        The list of frames returned by ``extract_frames``.
    """
    clip_frames = extract_frames(video_path)
    frame_count = len(clip_frames)
    print(f'The clip contains {frame_count} frames.')
    return clip_frames

In [None]:
video_path = '../images/DSCF0935.MP4'

# Load every frame of the clip; get_frames also prints the frame count.
frames = get_frames(video_path)

# for i, frame in enumerate(frames):
#     # Print resolution
#     resolution = (frame.shape[1], frame.shape[0])  # (width, height)
#     print(f"Frame {i + 1} resolution: {resolution}")

#     # Print RGB channels
#     channels = frame.shape[2]  # Number of channels (3 for RGB)
#     print(f"Frame {i + 1} channels: {channels}")

#     # Convert frame to numpy array with desired shape
#     frame_array = np.array(frame.transpose(2, 0, 1))  # (channels, height, width)
#     print(f"Frame {i + 1} array shape: {frame_array.shape}")

#     # Save frame_array to file if needed
#     # np.save(f"frame_{i + 1}.npy", frame_array)

## Load a pretrained model
Pretrained models can be found here: https://github.com/huggingface/pytorch-image-models
They have all been trained on the same collection of images (ImageNet), which has 1000 class labels.

Available models and their performance can be found here: https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv

According to Artem we want:
1. high top1/5 score
2. small param_count (too many parameters would make the model very large and not easy to run on my laptop)

Quickstart guide: https://huggingface.co/docs/timm/quickstart

In [None]:
import timm
from PIL import Image
import torch

import requests
from collections import Counter

In [None]:
def numpy_to_pil(numpy_image):
    """Turn an OpenCV frame into a PIL image.

    OpenCV loads frames with the colour channels in BGR order, whereas PIL
    expects RGB, so the channel order is converted before wrapping the array.

    Args:
        numpy_image: Frame array in BGR channel order.

    Returns:
        The frame as a ``PIL.Image.Image`` in RGB channel order.
    """
    rgb_array = cv2.cvtColor(numpy_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)

In [None]:
class WildClip:
    """A wildlife video clip together with the animal expected to be in it.

    The clip is classified frame by frame, and the per-frame top-5 classes
    are tallied over the whole clip.
    """

    def __init__(self, filepath, animal_exp):
        """Store the clip location and the expected animal.

        Args:
            filepath: Path of the video file.
            animal_exp: Name of the animal expected in the clip; only used
                in the printed report.
        """
        self.filepath = filepath
        self.expectedanimal = animal_exp
        # Frames of the clip; ``None`` until get_frames() has loaded them.
        self.frames = None
        # One (probabilities, class indices) pair per classified frame,
        # filled in by find_animals_clip().
        self.predictions = []

    def get_frames(self):
        """Load all frames of the clip into ``self.frames`` and print the count."""
        self.frames = extract_frames(self.filepath)
        print(f'The clip contains {len(self.frames)} frames.')

    def classify_image(self, frame, transform, model):
        """Classify a single frame and return its five most likely classes.

        Args:
            frame: One frame of the clip, as produced by ``extract_frames``.
            transform: Callable that turns a PIL image into the model's
                input tensor.
            model: Callable classification model returning a batch of
                class scores.

        Returns:
            A ``(values, indices)`` tuple from ``torch.topk``: the five
            highest softmax probabilities and their class indices.
        """
        # Remove the time stamp banner at the bottom of the frame.
        frame_crop = crop_frames(frame, pixel_frombottom=110)
        pil_image = numpy_to_pil(frame_crop)
        image_tensor = transform(pil_image)

        # The model expects a batch dimension, hence unsqueeze(0).
        # Gradients are not tracked so that memory does not run out.
        with torch.no_grad():
            output = model(image_tensor.unsqueeze(0))

        # Softmax over the single batch entry gives a tensor of shape
        # (num_classes,) holding the predicted probabilities.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        values, indices = torch.topk(probabilities, 5)

        return values, indices

    @staticmethod
    def _tally_votes(predictions):
        """Count how often each class index occurs across all predictions.

        Args:
            predictions: Iterable of ``(probabilities, indices)`` pairs as
                returned by ``classify_image``.

        Returns:
            A list of ``(class_index, count)`` tuples, most frequent first.
        """
        votes = Counter(
            class_idx
            for _, indices in predictions
            for class_idx in indices.tolist()
        )
        return votes.most_common()

    def find_animals_clip(self, transform, model, IMAGENET_1k_LABELS):
        """Classify every frame and print the classes found in the clip.

        Each frame contributes its top-5 classes; the printed count of a
        class is the number of frames that had it among their top 5. The
        per-frame results are stored in ``self.predictions``.

        Args:
            transform: Callable that turns a PIL image into the model's
                input tensor.
            model: Callable classification model.
            IMAGENET_1k_LABELS: Sequence mapping a class index to its label.
        """
        # Load the frames on demand, so calling this before get_frames()
        # does not fail on a missing attribute.
        if self.frames is None:
            self.get_frames()

        predictions = []
        for i, frame in enumerate(self.frames):
            # Progress indicator: print the index of every 20th frame.
            if i % 20 == 0:
                print(i)
            predictions.append(self.classify_image(frame, transform, model))

        ranked_classes = self._tally_votes(predictions)

        # Look up the label of every class index the model predicted.
        print(f'For {self.expectedanimal} model finds following classes:')
        print([
            {'label': IMAGENET_1k_LABELS[class_idx], 'number of frames:': n_frames}
            for class_idx, n_frames in ranked_classes
        ])

        self.predictions = predictions
        

In [None]:
# Get the ImageNet labels: one label per line, where the line position is the class index.
IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
# Bound the wait and fail loudly on an HTTP error, so that an error page is
# never parsed as if it were the list of class labels.
_labels_response = requests.get(IMAGENET_1k_URL, timeout=30)
_labels_response.raise_for_status()
IMAGENET_1k_LABELS = _labels_response.text.strip().split('\n')

In [None]:
# Name of the pretrained timm checkpoint used for classification.
MODEL_NAME = 'tf_efficientnet_b5.ns_jft_in1k'
model = timm.create_model(MODEL_NAME, pretrained=True)
# Note: the returned PyTorch model is in train mode by default, so .eval() must be called before using it for inference.
model.eval()


In [None]:
# Figure out which transforms were applied to the model inputs.
# Important parameters: resolution and normalisation.
# model.pretrained_cfg
timm.data.resolve_data_config(model.pretrained_cfg)

In [None]:
# Create a transform that converts images into the input format the model expects.
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)
# Bare expression: evaluates to the transform so it can be inspected as the cell result.
transform

In [None]:
# Load the duck clip and save one of its frames as a preview image.
Ducks = WildClip(filepath='../images/DSCF0935.MP4', animal_exp='ducks')
Ducks.get_frames()
ducks_preview = numpy_to_pil(Ducks.frames[4])
ducks_preview.save('../images/Ducks.png')
# Ducks.find_animals_clip(transform, model, IMAGENET_1k_LABELS)

In [None]:
# Load the otter clip and save one of its frames as a preview image.
Otter = WildClip(filepath='../images/DSCF0005.MP4', animal_exp='otter')
Otter.get_frames()
otter_preview = numpy_to_pil(Otter.frames[1])
otter_preview.save('../images/Otter.png')

# Otter.find_animals_clip(transform, model, IMAGENET_1k_LABELS)

In [None]:
# Load the squirrel clip and save one of its frames as a preview image.
Squirrel = WildClip(filepath='../images/DSCF0006.MP4', animal_exp='squirrel')
Squirrel.get_frames()
squirrel_preview = numpy_to_pil(Squirrel.frames[50])
squirrel_preview.save('../images/Squirrel.png')

# Squirrel.find_animals_clip(transform, model, IMAGENET_1k_LABELS)

In [None]:
# Load the mouse clip and save one of its frames as a preview image.
Mouse = WildClip(filepath='../images/DSCF0115.MP4', animal_exp='mouse')
Mouse.get_frames()
mouse_preview = numpy_to_pil(Mouse.frames[100])
mouse_preview.save('../images/Mouse.png')

# Mouse.find_animals_clip(transform, model, IMAGENET_1k_LABELS)

In [None]:
# Load the pigeon clip and save one of its frames as a preview image.
Pigeon = WildClip(filepath='../images/DSCF0025.MP4', animal_exp='pigeon')
Pigeon.get_frames()
pigeon_preview = numpy_to_pil(Pigeon.frames[200])
pigeon_preview.save('../images/Pigeon.png')

# Pigeon.find_animals_clip(transform, model, IMAGENET_1k_LABELS)