### Load Model

In [8]:
import wandb
import torch
from gorillatracker.model import EfficientNetV2Wrapper, SwinV2BaseWrapper
from torchvision.transforms import transforms

wandb.login()
wandb.init(mode="disabled")
api = wandb.Api()

artifact = api.artifact(
    "gorillas/Embedding-SwinV2-CXL-Open/model-8vymlbht:v3",
    type="model",
)
artifact_dir = artifact.download()
model = artifact_dir + "/model.ckpt"

# load model
checkpoint = torch.load(model, map_location=torch.device("cpu"))

model = SwinV2BaseWrapper(  # switch this with the model you want to use
    model_name_or_path="SwinV2_Base",
    from_scratch=False,
    loss_mode="offline/native",
    weight_decay=0.0001,
    lr_schedule="linear",
    warmup_mode="linear",
    warmup_epochs=0,
    max_epochs=10,
    initial_lr=0.00001,
    start_lr=0.00001,
    end_lr=0.00001,
    beta1=0.9,
    beta2=0.999,
    embedding_size=128,
)
# the following lines are necessary to load a model that was trained with arcface (the prototypes are saved in the state dict)
#model.loss_module_train.prototypes = torch.nn.Parameter(checkpoint["state_dict"]["loss_module_train.prototypes"])
#model.loss_module_val.prototypes = torch.nn.Parameter(checkpoint["state_dict"]["loss_module_val.prototypes"])

transform=transforms.Compose(  # use the transforms that were used for the model (except of course data augmentations)
        [
            transforms.ToTensor(),
            transforms.Resize((192, 192)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]), # if your model was trained with normalization, you need to normalize the images here as well
        ]
    )

model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("Model loaded successfully")

[34m[1mwandb[0m: Downloading large artifact model-8vymlbht:v3, 996.45MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.5


Model loaded successfully


### Load Video (and look at all the gorilla ids which contain face images)

In [9]:
import cv2
from gorillatracker.utils.video_models import VideoClip, _parse_tracked_video_clip
from datetime import datetime
from PIL import Image

json_path = "/workspaces/gorillatracker/data/derived_data/spac_gorillas_converted_labels_tracked/M002_20220529_031_tracked.json"
mp4_path = "/workspaces/gorillatracker/video_data" + json_path.split("spac_gorillas_converted_labels_tracked")[1].replace("_tracked.json", ".mp4")
print(mp4_path)
v = VideoClip(video_id="", camera_id="", start_time=datetime.now())
v = _parse_tracked_video_clip(v, json_path)
video = cv2.VideoCapture(mp4_path)

for i, _ in enumerate(v.trackings):
    gorilla = v.trackings[i]
    if len(gorilla.bounding_boxes_face) > 0:
        print(i)
        print(gorilla.bounding_boxes_face)

/workspaces/gorillatracker/video_data/M002_20220529_031.mp4
0
[TrackedFrame(f=13, bb=((1755, 608), (1813, 689)), c=0.5903230905532837), TrackedFrame(f=14, bb=((1755, 606), (1814, 692)), c=0.6314746737480164), TrackedFrame(f=15, bb=((1766, 610), (1824, 694)), c=0.5695499181747437), TrackedFrame(f=16, bb=((1772, 610), (1827, 693)), c=0.7277531027793884), TrackedFrame(f=17, bb=((1775, 609), (1831, 696)), c=0.709379255771637), TrackedFrame(f=18, bb=((1778, 610), (1839, 695)), c=0.7195687890052795), TrackedFrame(f=19, bb=((1778, 609), (1848, 694)), c=0.6800578832626343), TrackedFrame(f=22, bb=((1793, 612), (1865, 695)), c=0.627274751663208), TrackedFrame(f=23, bb=((1800, 612), (1871, 696)), c=0.7736772298812866), TrackedFrame(f=24, bb=((1803, 613), (1874, 697)), c=0.7796608209609985), TrackedFrame(f=25, bb=((1805, 613), (1878, 694)), c=0.7743455171585083), TrackedFrame(f=26, bb=((1808, 617), (1884, 701)), c=0.7462272644042969), TrackedFrame(f=27, bb=((1813, 619), (1890, 699)), c=0.742549359

### Extract face images and embeddings

In [10]:
gorilla = v.trackings[0] # change this to the gorilla you want to extract the embeddings from

faces = []
embeddings = []
for frame in gorilla.bounding_boxes_face:
    video.set(cv2.CAP_PROP_POS_FRAMES, frame.f)
    ret, img = video.read()
    cropped_img = Image.fromarray(img).crop(frame.bb[0] + frame.bb[1])
    faces.append(cropped_img)
    img = transform(cropped_img)
    img = img.unsqueeze(0)
    embedding = model(img)
    embeddings.append(embedding)

### Define Embedding Projector

In [11]:
from sklearn.manifold import Isomap, LocallyLinearEmbedding, MDS, SpectralEmbedding, TSNE
from sklearn.decomposition import PCA
import umap.umap_ as umap

class EmbeddingProjector:
    def __init__(self):
        self.algorithms = {
            "tsne": TSNE(n_components=2),
            "isomap": Isomap(n_components=2),
            "lle": LocallyLinearEmbedding(n_components=2),
            "mds": MDS(n_components=2),
            "spectral": SpectralEmbedding(n_components=2),
            "pca": PCA(n_components=2),
            "umap": umap.UMAP(),
        }

    def reduce_dimensions(self, embeddings, method="tsne"):
        algorithm = self.algorithms.get(method, TSNE(n_components=2))
        return algorithm.fit_transform(embeddings)

### Plot each embedding with a slider

In [32]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
from io import BytesIO

low_dim_embeddings = EmbeddingProjector().reduce_dimensions(torch.cat(embeddings).detach().numpy(),
                                                            method="pca")
low_dim_embeddings = low_dim_embeddings.tolist()
display(low_dim_embeddings)

x_axis, y_axis = zip(*low_dim_embeddings)
plot_list = []
for embedding in low_dim_embeddings:
    plt.xlim(min(x_axis) - 1, max(x_axis) + 1)
    plt.ylim(min(y_axis) - 1, max(y_axis) + 1)
    plt.plot(embedding[0], embedding[1], marker='+', linestyle='None', markersize=10, color='blue')
    plt.grid(True)
    
    buffer = BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    pil_image = Image.open(buffer)
    plot_list.append(pil_image)
    plt.close()

images_per_page = 2

def display_images(page):
    start = page
    print(start)
    fig, axs = plt.subplots(1, images_per_page, figsize=(15, 5))  # Create subplots
    for i in range(images_per_page):
        if start + i < len(faces) + len(plot_list):
            if i == 0:
                axs[i].imshow(faces[start])
            else:
                axs[i].imshow(plot_list[start])
            axs[i].axis('off')
        else:
            axs[i].axis('off')  # Hide axes for empty subplots
    plt.tight_layout()
    plt.show()


page_selector = widgets.IntSlider(min=0, max=(len(faces) + len(plot_list) - 1) // images_per_page, description='Page:')
widgets.interact(display_images, page=page_selector)

[[6.069356441497803, -2.606229543685913],
 [5.707740306854248, -0.13445377349853516],
 [5.233815670013428, -2.5183310508728027],
 [4.776072978973389, -3.5770421028137207],
 [4.561035633087158, -3.3034815788269043],
 [2.3291049003601074, -3.09199595451355],
 [3.0421857833862305, -1.7251754999160767],
 [2.613830804824829, -2.098525285720825],
 [1.4418076276779175, -1.5222442150115967],
 [0.9208231568336487, -1.4479016065597534],
 [0.02792559191584587, -2.535956382751465],
 [0.042257774621248245, -2.443587064743042],
 [-0.5683120489120483, -1.1450175046920776],
 [-0.4032926857471466, -2.3702445030212402],
 [-1.182234287261963, -3.036203622817993],
 [-4.032674789428711, -1.3732041120529175],
 [-3.8200595378875732, -1.6110939979553223],
 [-3.826831579208374, -2.099912166595459],
 [-3.755432367324829, -0.9303005337715149],
 [-3.696626663208008, -1.093754529953003],
 [-3.7225329875946045, -1.333591103553772],
 [-3.424424409866333, -1.491300106048584],
 [-3.858016014099121, -0.698914647102356]

interactive(children=(IntSlider(value=0, description='Page:', max=36), Output()), _dom_classes=('widget-intera…

<function __main__.display_images(page)>