In [None]:
# `ViTEmbeddingNet` is imported from `helper_code.model_functions`, a local module in this
# repository — it is not a PyPI package, so nothing needs to be pip-installed for it
# (installing a PyPI package named `model_functions` would pull in unrelated code).
# If third-party dependencies are missing from the kernel's environment, install them with
# `%pip` (not `!pip`) so they land in the right interpreter, e.g.:
#   %pip install -q transformers torch chromadb

In [None]:
#Importing packages
import gc

import chromadb
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import ViTModel, ViTImageProcessor

#Import library code
from helper_code import dataloading
from helper_code.model_functions import ViTEmbeddingNet

In [None]:
# Fix torch's global RNG so DataLoader shuffling (and any other torch randomness) is reproducible.
SEED = 1234
torch.manual_seed(SEED);

In [None]:
# Inputs for helper_code.dataloading.get_data: the label CSV and the folder of images it refers to.
image_dir = "camera_data/images/"
labels_csv = "camera_data/coronado_hills_binary_10-24-2025.csv"

data = dataloading.get_data(image_dir=image_dir, labels_csv=labels_csv)

In [None]:
# Split the labelled data and wrap each split in a Dataset. output_csvs=True presumably also
# writes the splits to disk — see helper_code.dataloading.get_train_val_test.
train, val, test = dataloading.get_train_val_test(data=data, output_csvs=True)

train_dataset, val_dataset, test_dataset = dataloading.get_datasets(train, val, test)

BATCH_SIZE = 32
# Pinned host memory only speeds up host->GPU copies; on CPU-only machines PyTorch just warns.
PIN_MEMORY = torch.cuda.is_available()

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=PIN_MEMORY)
# The validation loader is only used to extract embeddings, so iterate in a fixed,
# reproducible order rather than shuffling.
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=PIN_MEMORY)

In [None]:
# Pick the compute device, then load the pretrained ViT backbone. A transformer encoder is
# used because that architecture tends to extract strong general-purpose image features.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model_name = "google/vit-base-patch16-224"
vit = ViTModel.from_pretrained(model_name, torch_dtype=torch.float32)

In [None]:
# Wrap the pretrained ViT backbone in the project's embedding network and move it to `device`.
# NOTE(review): ViTEmbeddingNet lives in helper_code.model_functions (not visible here); it
# presumably subclasses torch.nn.Module, so .to() moves its parameters in place — confirm.
# The trailing ';' only suppresses the notebook's display of the returned module.
encoder = ViTEmbeddingNet(vit)
encoder.to(device);

In [None]:
# Restore the fine-tuned encoder weights directly onto the active device, then switch the
# encoder to eval mode for inference. weights_only=True restricts torch.load to tensor data.
WEIGHTS_PATH = 'weights/model_weights_camera_10-27-25.pth'
encoder.load_state_dict(
    torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
)
encoder.eval();

In [None]:
# Persistent on-disk Chroma store that will hold the encoder's validation-set embeddings.
persist_directory = "embedding_data/"
client = chromadb.PersistentClient(path=persist_directory)
# get_or_create_collection (instead of create_collection) lets this cell be re-run, or the
# notebook restarted, without failing because the collection already exists on disk.
collection = client.get_or_create_collection(name="10-27-25_model_embeddings")

In [None]:
# Embed every validation image and store it in Chroma, keyed by its annotation id.
# No gradients are needed, so inference_mode avoids building an autograd graph per batch;
# that is what previously forced a del / empty_cache / gc.collect() on every iteration.
with torch.inference_mode():
    for batch in val_dataloader:
        images = batch['pixel_values'].to(device)

        # Chroma ids must be strings. Default collation turns integer ids into a tensor,
        # so unwrap it first; string ids pass through unchanged.
        annotation_ids = batch['annotation_id']
        if torch.is_tensor(annotation_ids):
            annotation_ids = annotation_ids.tolist()
        ids = [str(annotation_id) for annotation_id in annotation_ids]

        embeddings = encoder(images)

        # upsert (rather than add) keeps the cell idempotent when re-run against the
        # persisted collection: existing ids are overwritten instead of duplicated/skipped.
        collection.upsert(
            embeddings=embeddings.cpu().tolist(),
            ids=ids,
        )



In [None]:
# Fetch the stored embedding for every validation row.
requested_ids = val['annotation_id'].astype(str).tolist()
received_embeddings = collection.get(ids=requested_ids, include=['embeddings'])

# Chroma does not guarantee that get() returns rows in the order the ids were requested,
# so realign the result with the row order of `val` before anything pairs rows by position.
row_of = {returned_id: row for row, returned_id in enumerate(received_embeddings['ids'])}
missing_ids = [annotation_id for annotation_id in requested_ids if annotation_id not in row_of]
assert not missing_ids, (
    f"{len(missing_ids)} validation ids are not in the collection, e.g. {missing_ids[:5]}"
)
order = [row_of[annotation_id] for annotation_id in requested_ids]
received_embeddings['ids'] = requested_ids
# np.asarray also normalizes older chromadb releases that return a list of lists.
received_embeddings['embeddings'] = np.asarray(received_embeddings['embeddings'])[order]

In [None]:
# Expect (number of validation rows, embedding dimension). np.shape works whether chromadb
# returned an ndarray or (in older releases) a plain list of lists, where .shape would fail.
np.shape(received_embeddings['embeddings'])