### Load Model

In [2]:
import wandb
import torch
from gorillatracker.model import EfficientNetV2Wrapper, SwinV2BaseWrapper
from torchvision.transforms import transforms

wandb.login()
wandb.init(mode="disabled")
api = wandb.Api()

artifact = api.artifact(
    "gorillas/Embedding-SwinV2-CXL-Open/model-8vymlbht:v3",
    type="model",
)
artifact_dir = artifact.download()
model = artifact_dir + "/model.ckpt"

# load model
checkpoint = torch.load(model, map_location=torch.device("cpu"))

model = SwinV2BaseWrapper(  # switch this with the model you want to use
    model_name_or_path="SwinV2_Base",
    from_scratch=False,
    loss_mode="offline/native",
    weight_decay=0.0001,
    lr_schedule="linear",
    warmup_mode="linear",
    warmup_epochs=0,
    max_epochs=10,
    initial_lr=0.00001,
    start_lr=0.00001,
    end_lr=0.00001,
    beta1=0.9,
    beta2=0.999,
    embedding_size=128,
)
# the following lines are necessary to load a model that was trained with arcface (the prototypes are saved in the state dict)
#model.loss_module_train.prototypes = torch.nn.Parameter(checkpoint["state_dict"]["loss_module_train.prototypes"])
#model.loss_module_val.prototypes = torch.nn.Parameter(checkpoint["state_dict"]["loss_module_val.prototypes"])

transform=transforms.Compose(  # use the transforms that were used for the model (except of course data augmentations)
        [
            transforms.ToTensor(),
            transforms.Resize((192, 192)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]), # if your model was trained with normalization, you need to normalize the images here as well
        ]
    )

model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("Model loaded successfully")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33memirhan404[0m ([33mgorillas[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Downloading large artifact model-8vymlbht:v3, 996.45MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:3.4
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Model loaded successfully


### Load Video (and look at all the gorilla ids which contain face images)

In [54]:
import cv2
from gorillatracker.utils.video_models import VideoClip, _parse_tracked_video_clip
from datetime import datetime
from PIL import Image

json_path = "/workspaces/gorillatracker/data/derived_data/spac_gorillas_converted_labels_tracked/M002_20220725_012_tracked.json"
mp4_path = "/workspaces/gorillatracker/video_data" + json_path.split("spac_gorillas_converted_labels_tracked")[1].replace("_tracked.json", ".mp4")
print(mp4_path)
v = VideoClip(video_id="", camera_id="", start_time=datetime.now())
v = _parse_tracked_video_clip(v, json_path)
video = cv2.VideoCapture(mp4_path)

counter = 0
frame_dict = {}
for gorilla in v.trackings:
    if len(gorilla.bounding_boxes_face) > 0:
        for frame in gorilla.bounding_boxes_face:
            try:
                frame_dict[frame.f].append((frame, gorilla.individual_id))
            except KeyError:
                frame_dict[frame.f] = [(frame, gorilla.individual_id)]
        counter+=1
print(f"Number of gorillas with tracked faces: {counter}")

/workspaces/gorillatracker/video_data/M002_20220725_012.mp4
Number of gorillas with tracked faces: 7


### Define Embedding Projector

In [4]:
from sklearn.manifold import Isomap, LocallyLinearEmbedding, MDS, SpectralEmbedding, TSNE
from sklearn.decomposition import PCA
import umap.umap_ as umap

class EmbeddingProjector:
    def __init__(self):
        self.algorithms = {
            "tsne": TSNE(n_components=2),
            "isomap": Isomap(n_components=2),
            "lle": LocallyLinearEmbedding(n_components=2),
            "mds": MDS(n_components=2),
            "spectral": SpectralEmbedding(n_components=2),
            "pca": PCA(n_components=2),
            "umap": umap.UMAP(),
        }

    def reduce_dimensions(self, embeddings, method="tsne"):
        algorithm = self.algorithms.get(method, TSNE(n_components=2))
        return algorithm.fit_transform(embeddings)

### Extract face images and embeddings

In [55]:
import colorcet as cc

def hex_to_rgb(hex):
    hex = hex.replace("#", "")
    return tuple(int(hex[i:i+2], 16) for i in (0, 2, 4))

colors = cc.glasbey
embedding_dict = {}
images = []

for f in frame_dict.keys():
    video.set(cv2.CAP_PROP_POS_FRAMES, f)
    ret, img = video.read()
    imgcopy = img.copy()
    for frame_gorilla in frame_dict[f]:
        cropped_img = Image.fromarray(imgcopy).crop(frame_gorilla[0].bb[0] + frame_gorilla[0].bb[1])
        cropped_img = transform(cropped_img).unsqueeze(0)
        embedding = model(cropped_img).detach()
        try:
            embedding_dict[f].append((embedding, frame_gorilla[1]))
        except KeyError:
            embedding_dict[f] = [(embedding, frame_gorilla[1])]
        img = cv2.rectangle(img, frame_gorilla[0].bb[0], frame_gorilla[0].bb[1], hex_to_rgb(colors[frame_gorilla[1]]), 2)
    images.append(img)

### Plot each embedding with a slider

In [56]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
import copy
from io import BytesIO

all_embeddings = []
for f in embedding_dict.keys():
    embeddings = [pair[0] for pair in embedding_dict[f]]
    all_embeddings.append(*embeddings)

low_dim_embeddings = EmbeddingProjector().reduce_dimensions(torch.cat(all_embeddings).detach().numpy(),
                                                            method="pca").tolist()
min_x = min([pair[0] for pair in low_dim_embeddings])
max_x = max([pair[0] for pair in low_dim_embeddings])
min_y = min([pair[1] for pair in low_dim_embeddings])
max_y = max([pair[1] for pair in low_dim_embeddings])

low_dim_embedding_dict = copy.deepcopy(embedding_dict)
for f in embedding_dict.keys():
    for i, pair in enumerate(embedding_dict[f]):
        low_dim_embedding_dict[f][i] = (low_dim_embeddings.pop(0), pair[1])

plot_list = []

for f in low_dim_embedding_dict.keys():
    plt.xlim(min_x - 1, max_x + 1)
    plt.ylim(min_y - 1, max_y + 1)
    plt.grid(True)
    for pair in low_dim_embedding_dict[f]:
        plt.plot(pair[0][0], pair[0][1], marker='+', linestyle='None', markersize=10, color=colors[pair[1]])

    plt.title(f"Frame {f}")
    buffer = BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    pil_image = Image.open(buffer)
    plot_list.append(pil_image)
    plt.close()

images_per_page = 2

def display_images(page):
    start = page
    print(start)
    fig, axs = plt.subplots(1, images_per_page, figsize=(15, 5))  # Create subplots
    for i in range(images_per_page):
        if start + i < len(images) + len(plot_list):
            if i == 0:
                axs[i].imshow(images[start])
            else:
                axs[i].imshow(plot_list[start])
            axs[i].axis('off')
        else:
            axs[i].axis('off')  # Hide axes for empty subplots
    plt.tight_layout()
    plt.show()


page_selector = widgets.IntSlider(min=0, max=(len(images) + len(plot_list) - 1) // images_per_page, description='Page:')
widgets.interact(display_images, page=page_selector)

TypeError: list.append() takes exactly one argument (2 given)

In [5]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
from io import BytesIO

low_dim_embeddings = EmbeddingProjector().reduce_dimensions(torch.cat(embeddings).detach().numpy(),
                                                            method="pca")
low_dim_embeddings = low_dim_embeddings.tolist()

x_axis, y_axis = zip(*low_dim_embeddings)
plot_list = []
for embedding in low_dim_embeddings:
    plt.xlim(min(x_axis) - 1, max(x_axis) + 1)
    plt.ylim(min(y_axis) - 1, max(y_axis) + 1)
    plt.plot(embedding[0], embedding[1], marker='+', linestyle='None', markersize=10, color='blue')
    plt.grid(True)
    
    buffer = BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    pil_image = Image.open(buffer)
    plot_list.append(pil_image)
    plt.close()

images_per_page = 2

def display_images(page):
    start = page
    print(start)
    fig, axs = plt.subplots(1, images_per_page, figsize=(15, 5))  # Create subplots
    for i in range(images_per_page):
        if start + i < len(faces) + len(plot_list):
            if i == 0:
                axs[i].imshow(faces[start])
            else:
                axs[i].imshow(plot_list[start])
            axs[i].axis('off')
        else:
            axs[i].axis('off')  # Hide axes for empty subplots
    plt.tight_layout()
    plt.show()


page_selector = widgets.IntSlider(min=0, max=(len(faces) + len(plot_list) - 1) // images_per_page, description='Page:')
widgets.interact(display_images, page=page_selector)

interactive(children=(IntSlider(value=0, description='Page:', max=36), Output()), _dom_classes=('widget-intera…

<function __main__.display_images(page)>