In [None]:
from setup import neurotransmitters, model, model_size, device, feat_dim, resize_size, dataset_path, curated_idx, few_shot_transforms
from setup import tqdm, torch, np, os, h5py, sns, plt, tqdm, Trans, Image, chain
from setup import cosine_similarity, euclidean_distances
from perso_utils import get_fnames, load_image, get_latents
from DINOSim import DinoSim_pipeline, diplay_features
from napari_dinosim.utils import get_img_processing_f

# Build the image pre-processing callable for the configured resize size,
# then instantiate the few-shot pipeline (an object, not the class itself).
_img_processing_f = get_img_processing_f(resize_size)
few_shot = DinoSim_pipeline(
    model,
    model.patch_size,
    device,
    _img_processing_f,
    feat_dim,
    dino_image_size=resize_size,
)

In [None]:
# Number of closest embeddings requested from the pipeline for every dataset
# processed by the extraction loop further down.
k = 10
# When True, every curated image is expanded with `few_shot_transforms`.
data_aug = True

# get_fnames() yields (file, label) pairs; split them into parallel tuples.
files, labels = zip(*get_fnames())

# Load every curated image exactly once and reuse the result for both the
# pixel data and the coordinates. load_image() is indexed as [0] = image and
# [1], [2] = two coordinate values (presumably the annotated point -- TODO
# confirm against perso_utils.load_image).
curated_samples = [load_image(files[idx]) for idx in curated_idx]

if data_aug:
    nb_transformations = len(few_shot_transforms)
    nb_trans_images = nb_transformations * len(curated_idx)

    good_datasets = []
    # Every image, label and coordinate triple is repeated once per
    # transformation so the three arrays stay index-aligned.
    # The transpose moves axis 1 to the end (presumably channel-first ->
    # channel-last; TODO confirm the layout returned by load_image).
    good_images = np.vstack(
        [[sample[0]] * nb_transformations for sample in curated_samples]
    ).transpose(0, 2, 3, 1)
    good_labels = np.hstack(
        [[labels[idx]] * nb_transformations for idx in curated_idx]
    )
    # Named `good_coordinates` so that both branches define the variable the
    # extraction loop reads; the augmentation branch previously only defined
    # `coordinates`, leaving `good_coordinates` undefined.
    good_coordinates = np.vstack(
        [[(0, sample[1], sample[2])] * nb_transformations
         for sample in curated_samples]
    )
    # Backward-compatible alias for any later cell still using the old name.
    coordinates = good_coordinates

    batch = []
    # The loop index is deliberately NOT called `k`: reusing that name
    # overwrote the neighbour count defined at the top of this cell.
    for start in range(0, nb_trans_images, nb_transformations):
        for image in good_images[start:start + nb_transformations]:
            batch.append([transform(image) for transform in few_shot_transforms])
            # NOTE(review): the group size 10 is hard-coded, and a trailing
            # group with fewer than 10 entries is never appended to
            # good_datasets -- confirm this is intended.
            if len(batch) == 10:
                good_datasets.append(list(chain.from_iterable(batch)))
                batch = []

else:
    good_datasets = np.array(
        [sample[0] for sample in curated_samples]
    ).transpose(0, 2, 3, 1)
    good_labels = np.array([labels[idx] for idx in curated_idx])
    good_coordinates = np.array(
        [(0, sample[1], sample[2]) for sample in curated_samples]
    )


# Per-dataset results: the k closest embeddings and a matching list of labels.
latent_list, label_list = [], []

# NOTE(review): zip() stops at the shorter of good_datasets / good_labels, and
# pairs dataset i with label i. When the two sequences have different lengths
# the pairing may not be the intended one -- confirm against how they are built.
for dataset, batch_label in tqdm(zip(good_datasets, good_labels), desc='Iterating through neurotransmitters'):

    # Pre-compute embeddings for this dataset.
    few_shot.pre_compute_embeddings(
        dataset,  # Pass numpy array of images
        overlap=(0.5, 0.5),
        padding=(0, 0),
        crop_shape=(140, 140, 1),
        verbose=True,
        batch_size=10
    )

    # Set reference vectors from the curated coordinates.
    few_shot.set_reference_vector(good_coordinates, filter=None)
    # Kept as a module-level name so it can be inspected after the loop.
    ref = few_shot.get_refs()

    # Retrieve the k closest elements to the reference.
    close_embedding = few_shot.get_k_closest_elements(k=k)
    # One label per retrieved embedding.
    # NOTE(review): batch_label[0] takes the first element of the label paired
    # with this dataset; verify this matches the label type from get_fnames().
    k_labels = [batch_label[0] for _ in range(k)]

    # Convert to numpy for storing (tensors are moved to CPU first).
    close_embedding_np = close_embedding.cpu().numpy() if isinstance(close_embedding, torch.Tensor) else close_embedding

    latent_list.append(close_embedding_np)
    label_list.append(k_labels)

    # Clean up to free memory before the next dataset.
    few_shot.delete_precomputed_embeddings()
    few_shot.delete_references()

# One mean embedding per processed dataset, stacked row-wise.
# (Comprehension variable renamed so it no longer shadows the builtin `list`.)
mean_ref = torch.from_numpy(np.vstack([np.mean(latents, axis=0) for latents in latent_list]))
# Stack all embeddings and labels
ref_latents = np.vstack(latent_list)
ref_labels = np.hstack(label_list)