In [None]:
import torch
# Sanity check: True if PyTorch can see a CUDA device from this kernel.
torch.cuda.is_available()

In [None]:
from analysis_utils import get_threshold, resize_tiff_image, resize_hdf_image, compute_similarity_matrix, get_augmented_coordinates, display_tiff_image_grid, get_resized_volume, get_resized_ground_truth
from setup import neurotransmitters, model_size, device, feat_dim, resize_size, curated_idx, few_shot_transforms, model, embeddings_path, save_path
from setup import np, plt, torch, tqdm, sns, os, h5py, tifffile
from perso_utils import get_fnames, load_image
from DINOSim import DinoSim_pipeline
from napari_dinosim.utils import get_img_processing_f

In [None]:
# Few-shot DINOSim pipeline shared by the cells below, built with the processing
# function for `resize_size`.
few_shot = DinoSim_pipeline(
    model,
    model.patch_size,
    device,
    get_img_processing_f(resize_size),
    feat_dim,
    dino_image_size=resize_size,
)

# get_fnames() returns (file, label) pairs; unzip them into two parallel tuples.
files, labels = zip(*get_fnames())


In [None]:
# Reference coordinates are expressed in the original 130-pixel frame; scale them to the
# DINO input size.
resize_factor = resize_size/130


def resize(x):
    """Scale a coordinate from the 130-pixel original frame to ``resize_size``.

    Reads the module-level ``resize_factor`` at call time.
    """
    return resize_factor*x


# One class label per scalar of the flattened embeddings: feat_dim * 600 consecutive
# entries per class (600 files per class), in ``neurotransmitters`` order.
# np.repeat builds the same 1-D string array as hstacking per-class Python lists, without
# materialising millions of list elements first.
# FIXME: LOOKS WEIRD — these are labels per scalar, not per embedding vector.
emb_labels = np.repeat(neurotransmitters, feat_dim*600)

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Reference data: hand-picked files and reference points per class

In [None]:
# Presumably abbreviated class names: 6 entries, matching the 6 rows of `indices`/`coords`.
# NOTE(review): `l` is not referenced anywhere else in this section — confirm it is still needed.
l = ['A','D','Ga','Glu','O','S']

# Row i: positions of the reference files inside class i's block of 600 files.
# compute_ref_embeddings reads them as files[i*600:(i+1)*600][k].
indices = [
    [1,3,6,8,9],
    [1,2,4,5,6],
    [0,1,2,4,5],
    [1,2,3,6,7],
    [1,2,5,6,8],
    [0,6,10,11,14]
    ]

# Row i: one reference point per file listed in indices[i] (same order), in the un-resized
# frame; compute_ref_embeddings scales each value by resize_size/130 before use.
# NOTE(review): (61,66.6) in row 4 breaks the .0/.5 pattern of every other value — possibly a typo for 66.5.
coords = [
    [(69,63.5),(68,61),(83,57),(76,62),(60,63)],
    [(66,62),(58.5,64),(64,60),(62.5,65),(64,71)],
    [(65,67),(72,60),(63,72),(60,67),(69,66.5)],
    [(65,66),(64,71),(62,58.5),(62,68),(69,55)],
    [(66,60),(60,70),(61,66.6),(58.5,63.5),(62.5,70.5)],
    [(63,73),(58,69),(60,69),(66,64),(62,71)]
    ]

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Compute Reference Embeddings

In [None]:
def _prepare_augmented_references():
    """Load the curated reference images and apply every transform in ``few_shot_transforms``.

    Unfinished data-augmentation path kept for future work. Nothing turns its output into
    reference embeddings yet, so ``compute_ref_embeddings`` refuses ``data_aug=True``.

    Returns:
        (good_datasets, good_coordinates): the transformed images and their (0, x, y) points,
        chunked into consecutive groups of
        ``len(curated_idx) / len(neurotransmitters) * len(few_shot_transforms)`` items.
    """
    nb_transformations = len(few_shot_transforms)

    # Preload images and metadata once
    good_images = []
    transformed_coordinates = []
    for idx in curated_idx:
        img, coord_x, coord_y = load_image(files[idx])
        good_images.append(img.transpose(1, 2, 0))
        # NOTE(review): the same point is reused for every transform, even geometric ones — confirm.
        transformed_coordinates.append([(0, coord_x, coord_y)] * nb_transformations)

    transformed_images = []
    for image in good_images:
        transformed_images.extend(t(image).permute(1, 2, 0) for t in few_shot_transforms)

    # Centre-pad any transformed image whose shape is no longer 130x130x1 back to that shape.
    for j, img in enumerate(transformed_images):
        if img.shape != torch.Size([130, 130, 1]):
            h, w = img.shape[:2]
            h_diff = (130 - h) // 2
            w_diff = (130 - w) // 2
            padded_img = torch.zeros(130, 130, 1)
            padded_img[h_diff:h + h_diff, w_diff:w + w_diff, :] = img
            transformed_images[j] = padded_img

    batch_size = int(len(curated_idx) / len(neurotransmitters) * nb_transformations)  # nb of images per class
    good_datasets = np.array(
        [transformed_images[i:i + batch_size] for i in range(0, len(transformed_images), batch_size)]
    )

    transformed_coordinates = np.vstack(transformed_coordinates)
    good_coordinates = [
        transformed_coordinates[i:i + batch_size]
        for i in range(0, len(transformed_coordinates), batch_size)
    ]
    return good_datasets, good_coordinates


def _reference_embeddings_from_points(k):
    """Average the k closest embeddings around each reference point, then average per class."""
    # Scale computed locally rather than through the global `resize`, so a later reassignment
    # of the global `resize_factor` cannot corrupt the coordinates.
    scale = resize_size/130

    ref_embs_list = []
    for class_idx, file_positions in enumerate(tqdm(indices)):
        # assumes files are grouped by class in consecutive blocks of 600, ordered like `indices`
        dataset_slice = files[class_idx * 600:(class_idx + 1) * 600]
        imgs = [resize_hdf_image(load_image(dataset_slice[pos])[0]) for pos in file_positions]
        coordinates = [[scale*value for value in point] for point in coords[class_idx]]

        class_wise_embs_list = []
        for image, reference in zip(imgs, coordinates):
            few_shot.pre_compute_embeddings(
                image[None, :, :, :],
                verbose=False,
                batch_size=1
            )
            few_shot.set_reference_vector(get_augmented_coordinates(reference), filter=None)
            closest_embds = few_shot.get_k_closest_elements(k=k, return_indices=False)
            class_wise_embs_list.append(torch.mean(closest_embds.cpu(), dim=0))
        ref_embs_list.append(class_wise_embs_list)

    return np.array([np.mean(class_closest_embs, axis=0) for class_closest_embs in ref_embs_list])


def compute_ref_embeddings(saved_ref_embeddings=False,
                           embs_path=None,
                           k=10,
                           data_aug=False):
    """Return one mean reference embedding per class, either computed or loaded from disk.

    Args:
        saved_ref_embeddings: If True, load the object saved at ``embs_path`` instead of computing.
        embs_path: Path of a file written by this function. Required when
            ``saved_ref_embeddings`` is True.
        k: Number of closest embeddings (from ``few_shot.get_k_closest_elements``) averaged
            around each reference point.
        data_aug: Augmented references are not implemented. True raises NotImplementedError;
            previously the function loaded and transformed images and then returned None.

    Returns:
        numpy array with one row per class (order of ``indices``/``coords``). When computed,
        it is also saved under ``embeddings_path``.

    Raises:
        ValueError: ``saved_ref_embeddings`` is True but ``embs_path`` is None.
        NotImplementedError: ``data_aug`` is True.
    """
    if saved_ref_embeddings:
        if embs_path is None:
            raise ValueError('embs_path is required when saved_ref_embeddings=True')
        # weights_only=False allows arbitrary unpickling (the saved object is a numpy array,
        # not a plain tensor dict): only load files produced by this notebook.
        return torch.load(embs_path, weights_only=False)

    if data_aug:
        raise NotImplementedError(
            'data_aug=True is not implemented: _prepare_augmented_references() only prepares '
            'augmented inputs, nothing computes reference embeddings from them yet.'
        )

    ref_embs = _reference_embeddings_from_points(k)
    torch.save(ref_embs, os.path.join(embeddings_path, f'{model_size}_mean_ref_{resize_size}_Aug={data_aug}_k={k}'))
    return ref_embs

In [None]:
# Compute the reference embeddings from the hand-picked points (also cached under embeddings_path).
# To reuse a cached file instead:
#   mean_refs = compute_ref_embeddings(saved_ref_embeddings=True,
#                                      embs_path='/home/tomwelch/Cambridge/Embeddings/giant_mean_ref_518_Aug=False_k=10')
mean_refs = compute_ref_embeddings(saved_ref_embeddings=False)

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Compute Dataset-wide Embeddings


## Without a generator (all embeddings kept in memory)

In [None]:
def compute_embeddings(saved_embeddings=False,
                       embs_path=None,
                       max_files=5):
    """Compute (or load) per-slice patch embeddings for the dataset volumes.

    Args:
        saved_embeddings: If True, load the object saved at ``embs_path`` instead of computing.
        embs_path: Path of a file written by this function. Required when ``saved_embeddings``
            is True.
        max_files: How many files (in ``get_fnames()`` order) to embed. ``None`` embeds all of
            them. Defaults to 5, the limit this function has always applied.

    Returns:
        numpy array stacking, per processed file, the embeddings of every z-slice. When
        computed, it is also saved under ``embeddings_path``.

    Raises:
        ValueError: ``saved_embeddings`` is True but ``embs_path`` is None.
    """
    if saved_embeddings:
        if embs_path is None:
            raise ValueError('embs_path is required when saved_embeddings=True')
        # Same loading mode as compute_ref_embeddings: the saved object is a numpy array.
        # Only load files produced by this notebook (weights_only=False unpickles arbitrary objects).
        return torch.load(embs_path, weights_only=False)

    dataset_files, _ = zip(*get_fnames())  # local name: do not shadow the notebook-level `files`
    dataset_embeddings = []
    for file in dataset_files[:max_files]:
        with h5py.File(file, 'r') as f:
            # Read fully into memory and add a leading channel axis; the last axis is iterated as z.
            volume = f['volumes/raw'][:][np.newaxis, :, :, :]

        volume_embeddings = []
        for z in range(volume.shape[-1]):
            resized_slice = resize_hdf_image(volume[:, :, :, z])
            few_shot.pre_compute_embeddings(
                resized_slice[None, ...],  # add the batch axis, as every other call site does
                batch_size=1
            )
            slice_embedding = few_shot.get_embeddings(reshape=False)
            # Move to host memory first: np.array cannot convert a CUDA tensor.
            volume_embeddings.append(np.array(slice_embedding.cpu()))
        dataset_embeddings.append(np.array(volume_embeddings))

    dataset_embeddings = np.array(dataset_embeddings)
    torch.save(dataset_embeddings, os.path.join(embeddings_path, f'{model_size}_dataset_embeddings_{resize_size}'))
    return dataset_embeddings

In [None]:
#dataset_embeddings = compute_embeddings()

In [None]:
def compute_predictions(embeddings: np.ndarray, reference_embeddings, proportion, files_per_class=600):
    """Binary per-volume predictions from cosine similarity to each volume's class reference.

    Mirrors the per-file logic of ``get_prediction``, applied to embeddings already in memory.

    Args:
        embeddings: Array whose first axis indexes volumes (files in ``get_fnames()`` order).
            Its last axis is assumed to be the feature dimension, as ``get_prediction``
            reshapes (-1, feat_dim) — TODO confirm for the saved layout.
        reference_embeddings: One reference vector per class, indexed like ``neurotransmitters``.
        proportion: Forwarded to ``get_threshold`` for each volume.
        files_per_class: Number of consecutive volumes per class (the notebook assumes 600).

    Returns:
        List of uint8 arrays, one per volume, shaped ``embeddings[i].shape[:-1]``. Each is 1
        where the similarity exceeds that volume's threshold.
    """
    embeddings = np.asarray(embeddings)

    predictions = []
    for volume_idx, volume_embeddings in enumerate(tqdm(embeddings)):
        # Volumes are grouped by class in consecutive blocks of `files_per_class`.
        reference_vector = np.asarray(reference_embeddings[volume_idx // files_per_class])

        # Cosine similarity of every patch vector with the reference, vectorised.
        patches = volume_embeddings.reshape(-1, volume_embeddings.shape[-1])
        similarities = patches @ reference_vector / (
            np.linalg.norm(patches, axis=1) * np.linalg.norm(reference_vector)
        )
        similarity_volume = similarities.reshape(volume_embeddings.shape[:-1])

        threshold = get_threshold(similarity_volume, proportion)
        predictions.append((similarity_volume > threshold).astype(np.uint8))
    return predictions

## With a generator (one volume at a time)

In [None]:
def get_prediction(saved_embeddings=False,
                   embs_path=None,
                   reference_embeddings=None,
                   proportion=0.99):
    """Yield a binary prediction volume for every file returned by ``get_fnames()``.

    For each file, the raw volume is embedded with ``few_shot``. Every patch embedding is then
    compared (cosine similarity) with the reference vector of the file's class, and the
    similarity volume is thresholded with ``get_threshold(similarity_volume, proportion)``.

    Args:
        saved_embeddings: Loading precomputed embeddings is not supported by this generator.
            True raises NotImplementedError; previously the generator loaded them and then
            yielded nothing.
        embs_path: Only meaningful together with ``saved_embeddings``.
        reference_embeddings: One reference vector per class, indexed like ``neurotransmitters``.
            Required.
        proportion: Forwarded to ``get_threshold``.

    Yields:
        (pred, volume, neuro_class, number):
            - the uint8 prediction volume;
            - the raw ``volumes/raw`` array;
            - the class name (the file's parent directory);
            - the file index modulo 600.

    Raises:
        NotImplementedError: ``saved_embeddings`` is True.
        ValueError: ``reference_embeddings`` is None.
    """
    if saved_embeddings:
        raise NotImplementedError('get_prediction cannot consume saved embeddings yet')
    if reference_embeddings is None:
        raise ValueError('reference_embeddings is required (one vector per neurotransmitter class)')

    dataset_files, _ = zip(*get_fnames())  # local name: do not shadow the notebook-level `files`
    for n, file in enumerate(dataset_files):
        # Read fully into memory so the HDF5 file is closed before yielding to the consumer.
        with h5py.File(file, 'r') as f:
            volume = f['volumes/raw'][:]

        # (1, ...) then move the last axis first: one single-channel image per slice of the
        # last axis (assumed z, i.e. volume is (y, x, z) — TODO confirm).
        few_shot.pre_compute_embeddings(
            volume[None, ...].transpose(3, 1, 2, 0),
            batch_size=65,
            verbose=True
        )
        volume_embeddings = np.array(few_shot.get_embeddings(reshape=False).cpu())

        neuro_class = os.path.normpath(file).split(os.sep)[-2]
        reference_vector = reference_embeddings[neurotransmitters.index(neuro_class)]

        # Vectorised cosine similarity of every patch embedding with the class reference.
        flattened_volume_embeddings = volume_embeddings.reshape(-1, feat_dim)
        cosine_similarities = flattened_volume_embeddings @ reference_vector / (
            np.linalg.norm(flattened_volume_embeddings, axis=1) * np.linalg.norm(reference_vector)
        )
        similarity_volume = cosine_similarities.reshape(volume_embeddings.shape[:3])

        threshold = get_threshold(similarity_volume, proportion)
        pred = (similarity_volume > threshold).astype(np.uint8)

        number = n % 600  # index within the class, assuming 600 files per class

        torch.cuda.empty_cache()

        yield pred, volume, neuro_class, number

In [None]:
# Lazy generator: nothing is embedded until it is iterated by the export loop below.
prediction_generator = get_prediction(saved_embeddings=False, reference_embeddings=mean_refs)

In [None]:
# Export every prediction (.npy) and its raw volume (.tiff) under save_path.
predictions_dir = os.path.join(save_path, 'Predictions')
volumes_dir = os.path.join(save_path, 'Volumes')
os.makedirs(predictions_dir, exist_ok=True)
os.makedirs(volumes_dir, exist_ok=True)

# Iterate the generator itself: the loop ends cleanly if it yields fewer than len(files)
# items, instead of a bare next() raising StopIteration.
for prediction, volume, neuro_class, number in tqdm(prediction_generator, total=len(files)):
    np.save(file=os.path.join(predictions_dir, f'{neuro_class}_prediction_{number:04d}.npy'), arr=prediction) # FIXME: Check axis
    # Move the last axis of the raw volume first before writing the TIFF stack.
    tifffile.imwrite(os.path.join(volumes_dir, f'{neuro_class}_volume_{number:04d}.tiff'), volume.transpose(2, 0, 1))

In [None]:
# Overlay one slice of an exported prediction on the matching raw slice. The files are read
# from the same save_path layout the export loop writes.
vis_class, vis_number, slice_idx = 'acetylcholine', 0, 75

preds = np.load(
    os.path.join(save_path, 'Predictions', f'{vis_class}_prediction_{vis_number:04d}.npy')
).transpose(0, 2, 1)  # -> zxy to zyx (see the "Check axis" FIXME in the export loop)
pred = preds[slice_idx]
# Raw EM slice used as background (not ground truth).
raw_slice = tifffile.imread(
    os.path.join(save_path, 'Volumes', f'{vis_class}_volume_{vis_number:04d}.tiff')
)[slice_idx]

h, w = pred.shape

fig, ax = plt.subplots(figsize=(5, 5), dpi=150)

# imshow's extent is (left, right, bottom, top): x spans the w columns and y spans the h rows
# with row 0 at the top. This is the grid sns.heatmap draws for an (h, w) array, so the two
# layers line up even when h != w.
ax.imshow(raw_slice, cmap='gray', extent=[0, w, h, 0])

# Overlay the prediction heatmap
sns.heatmap(
    pred,
    cmap='coolwarm',
    alpha=0.5,             # Make the heatmap semi-transparent
    ax=ax,
    cbar=True,
    xticklabels=False,
    yticklabels=False
)

ax.set_title("Patch Similarities")
fig.tight_layout()
plt.show()

In [None]:
# Shape of the last raw volume produced by the export loop. The note records its axis order
# as (y, x, z); the export writes it as transpose(2, 0, 1), i.e. (z, y, x).
volume.shape # yxz -> zyx

In [None]:
import tifffile

# Debug export: write `volume` with its last axis moved first (axes order 2, 0, 1).
tifffile.imwrite('test.tif', np.transpose(volume, (2, 0, 1)))

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
from tifffile import imread
# Single test slice: set a reference point on slice 6 and embed that slice.
img = imread('/Users/tomw/Documents/MVA/Internship/Cambridge/Datasets/EM_Data_copy/original/test/x/frarobo_conn_18925456_images.tif')[6]
resized_image = resize_tiff_image(img)

# Per-axis scale from the original slice to the DINO input size. Deliberately NOT named
# `resize_factor`: `resize` reads that global (a scalar) at call time, and
# compute_ref_embeddings uses `resize`. Overwriting the global with an array here would
# silently corrupt the reference coordinates the next time that function runs.
tiff_scale = np.multiply([resize_size, resize_size], 1/np.array(img.shape))
# presumably a hand-picked point, in original-slice pixel units
coordinates = (tiff_scale[0] * 153, tiff_scale[1] * 131.5)
aug_coordinates = get_augmented_coordinates(coordinates)

few_shot.pre_compute_embeddings(
    dataset=resized_image[None, :, :, :],
    batch_size=1
    )
embeddings = few_shot.get_embeddings(reshape=False)

few_shot.set_reference_vector(aug_coordinates)
reference = few_shot.get_reference_embedding()

# Evaluate predictions against the ground-truth masks

In [None]:
import os
# Image (x) and mask (y) file names of the test split, each listed in sorted order.
path = '/Users/tomw/Documents/MVA/Internship/Cambridge/Datasets/EM_Data_copy/original/test'
# listdir already returns bare names, so sorting them directly equals sorting by basename.
x = sorted(os.listdir(os.path.join(path, 'x')))
y = sorted(os.listdir(os.path.join(path, 'y')))

In [None]:
from tifffile import imread
# Embed the whole test stack in one batch and compare every slice with `reference`.
test_volume = get_resized_volume('/Users/tomw/Documents/MVA/Internship/Cambridge/Datasets/EM_Data_copy/original/test/x/frarobo_conn_18925456_images.tif')
nb_slices = test_volume.shape[0]

few_shot.pre_compute_embeddings(dataset=test_volume, batch_size=nb_slices)
new_embeddings = few_shot.get_embeddings(reshape=False)

similarity_tensor = np.array([
    compute_similarity_matrix(reference, new_embeddings[slice_no])
    for slice_no in tqdm(range(nb_slices))
])

# NOTE(review): the threshold comes from slice 6 alone (the slice `reference` was taken from)
# and is then applied to every slice — confirm this is intended.
threshold = get_threshold(similarity_tensor[6], 0.95)
pred = (similarity_tensor > threshold).astype(np.uint8)

In [None]:
# Ground-truth masks for the same test stack, with singleton axes dropped.
# (Presumably resized like get_resized_volume's output — confirm shapes match `pred`.)
gt = get_resized_ground_truth(
    '/Users/tomw/Documents/MVA/Internship/Cambridge/Datasets/EM_Data_copy/original/test/y/frarobo_conn_18925456_masks.tif'
).squeeze()

In [None]:
from sklearn.metrics import jaccard_score

# Jaccard index (intersection over union) between the flattened mask and the binary prediction.
score = jaccard_score(y_true=gt.ravel(), y_pred=pred.ravel())
print(score)