In [1]:
#Importing packages
from transformers import ViTModel, ViTImageProcessor
import torch
from torch.utils.data import DataLoader
import gc

import chromadb

#Import library code
import dataloading
from model_functions import ViTEmbeddingNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed PyTorch's global RNG so torch-driven randomness in later cells is repeatable.
SEED = 1234
torch.manual_seed(SEED)

<torch._C.Generator at 0x1caf6cee050>

In [3]:
# Input locations: the image folder and the CSV holding the labels.
image_dir = "camera_data/images/"
labels_csv = "camera_data/coronado_hills_binary_10-24-2025.csv"

# Load everything through the project helper; both paths are passed by keyword.
data = dataloading.get_data(
    image_dir=image_dir,
    labels_csv=labels_csv,
)

In [4]:
# Split the loaded data into train / validation / test partitions.
# output_csvs=True is forwarded to the helper (presumably it also writes each
# split to disk -- confirm in dataloading.get_train_val_test).
train, val, test = dataloading.get_train_val_test(data=data, output_csvs=True)

# Wrap each partition in a dataset object the DataLoader can consume.
train_dataset, val_dataset, test_dataset = dataloading.get_datasets(train, val, test)

# Pinned host memory only helps host-to-GPU copies; requesting it on a CPU-only
# machine just produces a warning, so enable it only when CUDA is present.
pin_memory = torch.cuda.is_available()

# Training batches are shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=pin_memory)
# In this notebook the validation loader is only iterated for inference, so it
# is kept in dataset order: shuffling adds nothing and makes the batch order
# depend on RNG state.
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=pin_memory)

In [5]:
# Choose the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Backbone: a pretrained Vision Transformer. A ViT is used because transformer
# architectures are generally strong at extracting features from data.
model_name = "google/vit-base-patch16-224"
vit = ViTModel.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Wrap the ViT backbone in the project's embedding network and place it on the
# selected device in one step (assumes ViTEmbeddingNet follows the standard
# nn.Module contract where .to() returns the module itself -- confirm).
encoder = ViTEmbeddingNet(vit).to(device)

In [7]:
# Load the saved encoder weights (checkpoint file named for margin 0.2).
# map_location=device remaps the stored tensors onto whichever device was
# selected above, so a checkpoint written on a GPU machine still loads when
# only a CPU is available. weights_only=True restricts unpickling to tensors
# and plain containers.
state_dict = torch.load(
    'weights/varying_margin/model_with_margin_0.2.pth',
    map_location=device,
    weights_only=True,
)
encoder.load_state_dict(state_dict)

# Switch to evaluation mode for inference; the trailing ';' hides the module repr.
encoder.eval();

In [None]:
# Open the on-disk Chroma store and fetch (or create on first use) the
# collection that holds this model's embeddings.
persist_directory = "embedding_data/"
collection_name = "10-27-25_model_embeddings"

client = chromadb.PersistentClient(path=persist_directory)
collection = client.get_or_create_collection(name=collection_name)

In [10]:
# Embed every validation batch and store the vectors in the Chroma collection,
# keyed by annotation id.
#
# torch.no_grad(): this is pure inference, so autograd tracking is disabled.
# Without it each forward pass builds a graph that holds activations in memory,
# which is what the manual del / empty_cache / gc.collect calls were working
# around; with tracking off they are unnecessary, and emptying the CUDA cache
# every batch only slows the loop down.
with torch.no_grad():
    for batch in val_dataloader:
        images = batch['pixel_values'].to(device)
        # Ids are passed through exactly as the batch provides them.
        annotation_ids = batch['annotation_id']

        embedding = encoder(images)

        # upsert instead of add: the store is persistent, so re-running this
        # cell must overwrite vectors for ids that already exist rather than
        # collide with them.
        collection.upsert(
            embeddings=embedding.tolist(),
            ids=annotation_ids,
        )



In [11]:
# Read the stored vectors back for every validation annotation id.
# The ids are cast to str before the lookup.
val_ids = val['annotation_id'].astype(str).tolist()
received_embeddings = collection.get(
    ids=val_ids,
    include=['embeddings'],
)

In [12]:
# Quick sanity check: display the dimensions of the retrieved embedding array.
retrieved_vectors = received_embeddings['embeddings']
retrieved_vectors.shape

(1498, 768)