### Load Model

In [14]:
import wandb
import torch
from gorillatracker.model import EfficientNetV2Wrapper
from torchvision.transforms import transforms

wandb.login()
wandb.init(mode="disabled")
api = wandb.Api()

artifact = api.artifact(
    "gorillas/Embedding-ALL-SPAC-Open/model-3ag1c2vf:v1",  # your artifact name
    type="model",
)
artifact_dir = artifact.download()
model = artifact_dir + "/model.ckpt"

# load model
checkpoint = torch.load(model, map_location=torch.device("cpu"))

model = EfficientNetV2Wrapper(  # switch this with the model you want to use
    model_name_or_path="EfficientNetV2_Large",
    from_scratch=False,
    loss_mode="softmax/arcface",
    weight_decay=0.001,
    lr_schedule="cosine",
    warmup_mode="cosine",
    warmup_epochs=10,
    max_epochs=100,
    initial_lr=0.01,
    start_lr=0.01,
    end_lr=0.0001,
    beta1=0.9,
    beta2=0.999,
    embedding_size=128,
)
# the following lines are necessary to load a model that was trained with arcface (the prototypes are saved in the state dict)
model.loss_module_train.prototypes = torch.nn.Parameter(checkpoint["state_dict"]["loss_module_train.prototypes"])
model.loss_module_val.prototypes = torch.nn.Parameter(checkpoint["state_dict"]["loss_module_val.prototypes"])

transform=transforms.Compose(  # use the transforms that were used for the model (except of course data augmentations)
        [
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]), # if your model was trained with normalization, you need to normalize the images here as well
        ]
    )

model.load_state_dict(checkpoint["state_dict"])
model.eval()

[34m[1mwandb[0m: Downloading large artifact model-3ag1c2vf:v1, 1346.85MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:2.0


EfficientNetV2Wrapper(
  (loss_module_train): ArcFaceLoss(
    (ce): CrossEntropyLoss()
  )
  (loss_module_val): ArcFaceLoss(
    (ce): CrossEntropyLoss()
  )
  (model): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): FusedMBConv(
          (block): Sequential(
            (0): Conv2d

In [None]:
import cv2
from gorillatracker.utils.video_models import VideoClip, _parse_tracked_video_clip
from datetime import datetime
from PIL import Image

json_path = "/workspaces/gorillatracker/data/derived_data/spac_gorillas_converted_labels_tracked/M002_20220529_031_tracked.json"
mp4_path = "/workspaces/gorillatracker/video_data" + json_path.split("spac_gorillas_converted_labels_tracked")[1].replace("_tracked.json", ".mp4")
print(mp4_path)
v = VideoClip(video_id="", camera_id="", start_time=datetime.now())
v = _parse_tracked_video_clip(v, json_path)
video = cv2.VideoCapture(mp4_path)

for i, _ in enumerate(v.trackings):
    gorilla = v.trackings[i]
    if len(gorilla.bounding_boxes_face) > 0:
        print(i)
        print(gorilla.bounding_boxes_face)

In [23]:
gorilla = v.trackings[0]

faces = []
embeddings = []
for frame in gorilla.bounding_boxes_face:
    video.set(cv2.CAP_PROP_POS_FRAMES, frame.f)
    ret, img = video.read()
    cropped_img = Image.fromarray(img).crop(frame.bb[0] + frame.bb[1])
    faces.append(cropped_img)
    img = transform(cropped_img)
    img = img.unsqueeze(0)
    embedding = model(img)
    embeddings.append(embedding)

/workspaces/gorillatracker/video_data/M002_20220529_031.mp4


0
[TrackedFrame(f=13, bb=((1755, 608), (1813, 689)), c=0.5903230905532837), TrackedFrame(f=14, bb=((1755, 606), (1814, 692)), c=0.6314746737480164), TrackedFrame(f=15, bb=((1766, 610), (1824, 694)), c=0.5695499181747437), TrackedFrame(f=16, bb=((1772, 610), (1827, 693)), c=0.7277531027793884), TrackedFrame(f=17, bb=((1775, 609), (1831, 696)), c=0.709379255771637), TrackedFrame(f=18, bb=((1778, 610), (1839, 695)), c=0.7195687890052795), TrackedFrame(f=19, bb=((1778, 609), (1848, 694)), c=0.6800578832626343), TrackedFrame(f=22, bb=((1793, 612), (1865, 695)), c=0.627274751663208), TrackedFrame(f=23, bb=((1800, 612), (1871, 696)), c=0.7736772298812866), TrackedFrame(f=24, bb=((1803, 613), (1874, 697)), c=0.7796608209609985), TrackedFrame(f=25, bb=((1805, 613), (1878, 694)), c=0.7743455171585083), TrackedFrame(f=26, bb=((1808, 617), (1884, 701)), c=0.7462272644042969), TrackedFrame(f=27, bb=((1813, 619), (1890, 699)), c=0.7425493597984314), TrackedFrame(f=28, bb=((1816, 619), (1891, 703)), 

In [16]:
from sklearn.manifold import Isomap, LocallyLinearEmbedding, MDS, SpectralEmbedding, TSNE
from sklearn.decomposition import PCA
import umap.umap_ as umap

class EmbeddingProjector:
    def __init__(self):
        self.algorithms = {
            "tsne": TSNE(n_components=2),
            "isomap": Isomap(n_components=2),
            "lle": LocallyLinearEmbedding(n_components=2),
            "mds": MDS(n_components=2),
            "spectral": SpectralEmbedding(n_components=2),
            "pca": PCA(n_components=2),
            "umap": umap.UMAP(),
        }

    def reduce_dimensions(self, embeddings, method="tsne"):
        algorithm = self.algorithms.get(method, TSNE(n_components=2))
        return algorithm.fit_transform(embeddings)

In [34]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
from io import BytesIO

low_dim_embeddings = EmbeddingProjector().reduce_dimensions(torch.cat(embeddings).detach().numpy(), method="tsne")
low_dim_embeddings = low_dim_embeddings.tolist()

x_axis, y_axis = zip(*low_dim_embeddings)
plot_list = []
for embedding in low_dim_embeddings:
    plt.plot(embedding, marker='+', linestyle='None', markersize=10, color='blue')
    
    plt.xlim(min(x_axis) - 1, max(y_axis) + 1)
    plt.ylim(min(y_axis) - 1, max(y_axis) + 1)
    
    plt.grid(True)
    
    buffer = BytesIO()
    plt.savefig(buffer, format='png')
    buffer.seek(0)
    pil_image = Image.open(buffer)
    plot_list.append(pil_image)
    plt.close()

images_per_page = 2

def display_images(page):
    start = page
    print(start)
    fig, axs = plt.subplots(1, images_per_page, figsize=(15, 5))  # Create subplots
    for i in range(images_per_page):
        if start + i < len(faces) + len(plot_list):
            if i == 0:
                axs[i].imshow(faces[start])
            else:
                axs[i].imshow(plot_list[start])
            axs[i].axis('off')
        else:
            axs[i].axis('off')  # Hide axes for empty subplots
    plt.tight_layout()
    plt.show()


page_selector = widgets.IntSlider(min=0, max=(len(faces) + len(plot_list) - 1) // images_per_page, description='Page:')
widgets.interact(display_images, page=page_selector)

[[-4.594196796417236, 3.5996522903442383], [-4.751836776733398, 3.549981117248535], [-5.030726432800293, 3.63887095451355], [-5.187254905700684, 3.4718680381774902], [-5.154721736907959, 3.462038278579712], [-5.162607669830322, 3.5103681087493896], [-5.095883846282959, 3.5854573249816895], [-5.92625617980957, 3.237489938735962], [-5.905942916870117, 2.999976873397827], [-5.8959150314331055, 2.8569793701171875], [-5.839818000793457, 2.8563294410705566], [-6.041816711425781, 2.4889419078826904], [-5.800403594970703, 2.6677162647247314], [-5.743133068084717, 2.525880813598633], [-5.573644638061523, 2.4899508953094482], [-5.793516635894775, 1.9225950241088867], [-5.9566850662231445, 1.8136204481124878], [-5.867406845092773, 1.6885408163070679], [-5.864697456359863, 1.6284581422805786], [-5.764241695404053, 1.6217055320739746], [-5.668868064880371, 1.636839747428894], [-5.524453163146973, 1.3691431283950806], [-5.416896343231201, 1.231851577758789], [-5.39024019241333, 1.4424231052398682], 

interactive(children=(IntSlider(value=0, description='Page:', max=36), Output()), _dom_classes=('widget-intera…

<function __main__.display_images(page)>