# Imports

In [None]:
from setup import neurotransmitters, model_size, device, feat_dim, resize_size, curated_idx, few_shot_transforms,  embeddings_path, model
from setup import tqdm, torch, np, os, plt, tqdm, gc, sample
from analysis_utils import display_hdf_image_grid, resize_hdf_image, get_augmented_coordinates
from setup import cosine_similarity, euclidean_distances
from perso_utils import get_fnames, load_image, get_latents
from DINOSim import DinoSim_pipeline
from napari_dinosim.utils import get_img_processing_f

### Importing model

In [None]:
# Build the DinoSim few-shot pipeline around the `model` imported from `setup`.
# The image pre-processing function and `dino_image_size` are both derived
# from the same `resize_size`.
few_shot = DinoSim_pipeline(model,
                            model.patch_size,
                            device,
                            get_img_processing_f(resize_size),
                            feat_dim, 
                            dino_image_size=resize_size
                            )

# get_fnames() is unpacked into two parallel tuples: file entries and labels.
files, labels = zip(*get_fnames()) 

### Resize coordinates

In [None]:
# Ratio between the resized resolution and 130 (130 is also the size images
# are padded to in compute_ref_embeddings; presumably the native crop size —
# TODO confirm).
resize_factor = resize_size / 130


def resize(x):
    """Scale a single coordinate value `x` by `resize_factor`."""
    return resize_factor * x


from perso_utils import filter_f

# NOTE(review): this rebinding shadows the builtin `filter` for the rest of the notebook.
filter = filter_f

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Extracting reference data

In [None]:
# Six short codes, presumably abbreviations of the neurotransmitter classes
# (TODO confirm); not referenced by the other cells shown here.
l = ['A','D','Ga','Glu','O','S']

In [None]:
# indices[i] lists the positions, inside the i-th class slice of `files`, of
# the images used as references by compute_ref_embeddings.
indices = [
    [1,3,6,8,9],
    [1,2,4,5,6],
    [0,1,2,4,5],
    [1,2,3,6,7],
    [1,2,5,6,8],
    [0,6,10,11,14]
    ]

In [None]:
# coords[i][j] is the reference point paired with the image selected by
# indices[i][j]; every value is scaled with `resize` before use.
# Presumably (x, y) pixel positions in the un-resized image — TODO confirm.
coords = [
    [(69,63.5),(68,61),(83,57),(76,62),(60,63)],
    [(66,62),(58.5,64),(64,60),(62.5,65),(64,71)],
    [(65,67),(72,60),(63,72),(60,67),(69,66.5)],
    [(65,66),(64,71),(62,58.5),(62,68),(69,55)],
    [(66,60),(60,70),(61,66.6),(58.5,63.5),(62.5,70.5)],
    [(63,73),(58,69),(60,69),(66,64),(62,71)]
    ]

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Compute Reference Embeddings

In [None]:
def compute_ref_embeddings(saved_ref_embeddings=False, 
                           embs_path=None, 
                           k=10,
                           data_aug=False,
                           images_per_class=600):
    """Load, or compute and save, one mean reference embedding per class.

    Args:
        saved_ref_embeddings: When true, load the object stored at
            `embs_path` and return it without computing anything.
        embs_path: Path of the saved reference embeddings; required when
            `saved_ref_embeddings` is true.
        k: Number of closest embeddings averaged around each reference point.
        data_aug: Augmented references are not implemented; passing a true
            value raises NotImplementedError (the previous code ran some
            preprocessing and then silently returned None).
        images_per_class: Number of consecutive entries of `files` that
            belong to each class (previously hard-coded to 600).

    Returns:
        The loaded object, or a numpy array with one row per entry of the
        global `indices` holding that class's mean reference embedding.

    Raises:
        ValueError: `saved_ref_embeddings` is true but `embs_path` is None.
        NotImplementedError: `data_aug` is true.
    """
    if saved_ref_embeddings:
        if embs_path is None:
            raise ValueError('embs_path is required when saved_ref_embeddings=True')
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load files that were produced by this notebook.
        return torch.load(embs_path, weights_only=False)

    if data_aug:
        # The disabled legacy cell of this notebook still holds the
        # augmentation preprocessing; it was never wired to a result here.
        raise NotImplementedError('data_aug=True is not supported by compute_ref_embeddings')

    ref_embs_list = []
    for class_idx, image_indices in tqdm(enumerate(indices), total=len(indices)):
        start = class_idx * images_per_class
        dataset_slice = files[start:start + images_per_class]
        # `img_idx` (not `k`) so the loop variable cannot be confused with the parameter k.
        imgs = [resize_hdf_image(load_image(dataset_slice[img_idx])[0]) for img_idx in image_indices]
        coordinates = [list(map(resize, c)) for c in coords[class_idx]]

        class_wise_embs = []
        for image, reference in zip(imgs, coordinates):
            few_shot.pre_compute_embeddings(
                image[None,:,:,:],
                verbose=False,
                batch_size=1
            )
            few_shot.set_reference_vector(get_augmented_coordinates(reference), filter=None)
            closest_embds = few_shot.get_k_closest_elements(k=k, return_indices=False)
            # Mean of the k closest embeddings for this reference image.
            class_wise_embs.append(torch.mean(closest_embds.cpu(), dim=0))
        # Convert explicitly instead of letting numpy coerce a list of tensors.
        ref_embs_list.append(torch.stack(class_wise_embs).detach().numpy())

    # Average over the reference images of each class.
    ref_embs = np.array([np.mean(class_embs, axis=0) for class_embs in ref_embs_list])

    # NOTE(review): saved without a '.pt' suffix, while other cells load
    # '..._k=10.pt' — confirm which file name is expected.
    torch.save(ref_embs, os.path.join(embeddings_path, f'{model_size}_mean_ref_{resize_size}_Aug={data_aug}_k={k}'))

    return ref_embs

In [None]:
#mean_refs = compute_ref_embeddings(True, os.path.join(embeddings_path, 'giant_mean_ref_518_Aug=False_k=10'))
#mean_refs = compute_ref_embeddings(False)

In [None]:
'''
def compute_ref_embeddings(saved_ref_embeddings=False, 
                           embs_path=None, 
                           k=10,
                           data_aug=True):

    if saved_ref_embeddings:
        
        mean_ref = torch.load(embs_path)

    else:

        if data_aug:    
            nb_transformations = len(few_shot_transforms)
            
            # Preload images and metadata once
            good_images = []
            transformed_coordinates = []

            for idx in curated_idx:
                img, coord_x, coord_y = load_image(files[idx])
                good_images.append(img.transpose(1,2,0))
                transformed_coordinates.append([(0, coord_x, coord_y)] * nb_transformations)

            transformed_images = []
            for image in good_images:
                transformed = [t(image).permute(1,2,0) for t in few_shot_transforms]
                transformed_images.extend(transformed)

            for j, img in enumerate(transformed_images):
                if img.shape != torch.Size([130, 130, 1]):
                    h, w = img.shape[:2]
                    h_diff = (130 - h) // 2
                    w_diff = (130 - w) // 2
                    padded_img = torch.zeros(130, 130, 1)
                    padded_img[h_diff:h+h_diff, w_diff:w+w_diff, :] = img
                    transformed_images[j] = padded_img
                    
            batch_size = int(len(curated_idx)/len(neurotransmitters)*nb_transformations) # nb of images in per class
            good_datasets = [transformed_images[i:i+batch_size] for i in range(0,len(transformed_images),batch_size)]
            good_datasets = np.array(good_datasets)
            
            transformed_coordinates = np.vstack(transformed_coordinates)
            good_coordinates = [transformed_coordinates[i:i+batch_size] for i in range(0,len(transformed_coordinates),batch_size)]

        else:

            imgs_coords = [load_image(files[idx]) for idx in curated_idx]
            imgs, xs, ys = zip(*imgs_coords)

            batch_size = int(len(curated_idx)/len(neurotransmitters))
            imgs = [imgs[i:i+batch_size] for i in range(0,len(imgs),batch_size)]
            good_datasets = np.array(imgs).transpose(0,1,3,4,2)
            
            good_coordinates = [(0, x, y) for x, y in zip(xs, ys)]
            good_coordinates = [good_coordinates[i:i+batch_size] for i in range(0,len(good_coordinates),batch_size)]
            good_coordinates = np.array(good_coordinates)


        unfiltered_ref_latents_list, filtered_latent_list, filtered_label_list = [], [], []
        for dataset, batch_label, coordinates in tqdm(zip(good_datasets, neurotransmitters, good_coordinates), desc='Iterating through neurotransmitters'):
            
            # Pre-compute embeddings
            few_shot.pre_compute_embeddings(
                dataset,  # Pass numpy array of images
                overlap=(0.5, 0.5),
                padding=(0, 0),
                crop_shape=(518, 518, 1),
                verbose=True,
                batch_size=10
            )
            
            # Set reference vectors
            few_shot.set_reference_vector(coordinates, filter=None)
            ref = few_shot.get_refs()
            
            # Get closest elements - using the correct method name
            close_embedding =  few_shot.get_k_closest_elements(k=k)
            k_labels =  [batch_label for _ in range(k)]

            
            # Convert to numpy for storing
            close_embedding_np = close_embedding.cpu().numpy() if isinstance(close_embedding, torch.Tensor) else close_embedding
            
            filtered_latent_list.append(close_embedding_np)
            filtered_label_list.append(k_labels)
            
            # Clean up to free memory
            few_shot.delete_precomputed_embeddings()
            few_shot.delete_references()

        mean_ref = torch.from_numpy(np.vstack([np.mean(l, axis=0) for l in filtered_latent_list]))
        # Stack all embeddings and labels
        ref_latents = np.vstack(filtered_latent_list)
        ref_labels = np.hstack(filtered_label_list)
        
        torch.save(mean_ref, os.path.join(dataset_path, f'{model_size}_mean_ref_{resize_size}_Aug={data_aug}_k={k}'))
        torch.save(ref_latents, os.path.join(dataset_path, f'{model_size}_ref_latents_{resize_size}_Aug={data_aug}_k={k}'))
        torch.save(ref_labels, os.path.join(dataset_path, f'{model_size}_ref_labels_{resize_size}_Aug={data_aug}_k={k}'))
'''

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Generate Ground-truth

In [None]:
#one_hot_neurotransmitters = np.eye(len(neurotransmitters))
#emb_labels = np.hstack([[neuro]*int((resize_size/14)**2 * 600) for neuro in neurotransmitters]).reshape(-1,1)

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Compute Dataset-wide Embeddings

In [None]:
import gc

#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [None]:
#files, _ = zip(*get_fnames()) 
#images = np.array([resize_hdf_image(load_image(file)[0]) for file in tqdm(files, desc='Loading images')])

In [None]:
def embedding_generator(batch_size=50):
    """Yield CPU embeddings for successive slices of the global `images`.

    Each slice of `batch_size` images is pushed through
    `few_shot.pre_compute_embeddings` and the embeddings are read back with
    `get_embeddings(reshape=False)`.

    TODO: if another batch_size is used, adapt get_d_closest_elements as well.
    """
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        few_shot.pre_compute_embeddings(chunk, batch_size=1, verbose=False)
        # Free cached GPU blocks and Python garbage before the next slice.
        torch.cuda.empty_cache()
        gc.collect()
        del chunk
        yield few_shot.get_embeddings(reshape=False).cpu()
                
def compute_embeddings(batch_size=50):
    """Compute embeddings for every image in `images`, save them and return them.

    Args:
        batch_size: Number of images per generator step. It is forwarded to
            `embedding_generator`, which previously always ran with its own
            default regardless of this argument.

    Returns:
        A list with one CPU embedding tensor per batch. The same list is
        saved to `<embeddings_path>/<model_size>_dataset_embs_<resize_size>.pt`.
    """
    # Ceil division so a trailing partial batch is counted; the generator is
    # consumed to exhaustion instead of a fixed number of next() calls.
    nb_batches = (len(images) + batch_size - 1) // batch_size
    embeddings_list = list(tqdm(embedding_generator(batch_size), total=nb_batches))

    torch.save(embeddings_list, os.path.join(embeddings_path, f'{model_size}_dataset_embs_{resize_size}.pt'))
    return embeddings_list

In [None]:
#new_embeddings = compute_embeddings()
#new_embeddings = torch.load(os.path.join(embeddings_path, 'small_dataset_embs_518.pt')) # takes ~ 45 s

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Compute class-wise accuracies

In [None]:
'''
from analysis_utils import get_threshold

def compute_accuracies(reference_embeddings = mean_refs, 
                       embeddings = new_embeddings,
                       metric = euclidean_distances,
                       distance_threshold = 0.01
                       ):

    batch_size = int(len(embeddings)/6)

    for n, i in tqdm(enumerate(range(0, len(embeddings), batch_size))):
        batch = embeddings[i:i+batch_size]

        #embeddings = embeddings.reshape(-1, feat_dim)
        similarity_matrix = metric(reference_embeddings, batch)
        similarity_matrix_normalized = (similarity_matrix - np.min(similarity_matrix)) / (np.max(similarity_matrix) - np.min(similarity_matrix))
        threshold = get_threshold(similarity_matrix_normalized, 0.9)
        similarity_matrix_normalized_filtered = np.where(similarity_matrix_normalized <= threshold, similarity_matrix_normalized, 0)

        batch_score_list = []
        for k in range(batch_size):

            column = similarity_matrix_normalized_filtered[:,k]
            j=0
            if sum(column) == 0:
                j+=1
            else:
                patch_wise_distances_filtered = np.where(column == 0, 1, column)
                output_class = one_hot_neurotransmitters[np.argmin(patch_wise_distances_filtered)]
                gt_index = n
                ground_truth = one_hot_neurotransmitters[gt_index]
                score = np.sum(output_class*ground_truth)
                batch_score_list.append(score)
                
        yield batch_score_list

g = compute_accuracies()
score_list = []

for _ in range(6): score_list.append(next(g))

accuracies = [np.mean(scores)*100 for scores in score_list]
#print(f'{j} embeddings did not pass the threshold')
#return accuracies
'''

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

# Plot results

In [None]:
'''
plt.figure(figsize=(12,7), dpi=300)
plt.bar(neurotransmitters, accuracies)
plt.xlabel('Classes')
plt.ylabel('Mean hard accuracy')
#plt.title(f'Mean hard accuracies across classes - {model_size} DINOv2 - 140x140 images - Threshold = {distance_threshold} - Data augmentation: {data_aug}')
plt.axhline(np.mean(accuracies), color='r', linestyle='--', label='Average')
plt.axhline(y=(100/6), color='b', linestyle='--', label='Randomness')
plt.legend()
ax = plt.gca()
ax.set_ylim([0,110])
plt.show()
'''

=============================================================================================================================================================================================================================

# Custom Dataloader

In [None]:
# Load the saved dataset embeddings (a sequence of tensors: it is passed to
# torch.cat in the unfiltered branch below).
DATA = torch.load(os.path.join(embeddings_path, 'small_dataset_embs_518.pt'))
print('Done loading embeddings')

#TODO:
# Switch between the reference-based selection and the full patch-level dataset.
filtering = False

if filtering:

    # One label per image: 600 images per class.
    LABELS = np.hstack([[neuro]*600 for neuro in neurotransmitters]).reshape(-1,1)

    REFS = compute_ref_embeddings(True, os.path.join(embeddings_path, 'small_mean_ref_518_Aug=False_k=10.pt'))

    # NOTE(review): LABELS is not combined with DATASET in this branch; the
    # return format of get_d_closest_elements is not visible here.
    DATASET = few_shot.get_d_closest_elements(embeddings = DATA, 
                                            reference_emb = torch.from_numpy(REFS))

else:
    
    # One label per patch: (resize_size/14)**2 patches per image times 600
    # images per class (14 is presumably the ViT patch size — TODO confirm).
    LABELS = np.hstack([[neuro]*int((resize_size/14)**2 * 600) for neuro in neurotransmitters]).reshape(-1, 1)
    
    # Flatten everything into rows of length feat_dim.
    DATA = torch.cat(DATA)
    DATA = DATA.reshape(-1, feat_dim)
    
    DATASET = list(zip(DATA, LABELS))

# `sample` comes from `setup`; presumably random.sample, i.e. a shuffled copy — TODO confirm.
DATASET = sample(DATASET, len(DATASET))

# First 20% of the shuffled data is the test set, the rest is the training set.
SPLIT = int(len(DATASET)*0.2)
TRAINING_SET = DATASET[SPLIT:]
TEST_SET = DATASET[:SPLIT]

# Row i is the one-hot target of neurotransmitters[i].
one_hot_neurotransmitters = np.eye(len(neurotransmitters))

In [None]:
from torch.utils.data import Dataset
import torch.utils.data as utils

class Custom_LP_Dataset(Dataset):
    """Dataset over the global linear-probing split.

    `set == 'training'` selects TRAINING_SET; any other value selects
    TEST_SET. Items are `(embedding, one_hot_target)` pairs.
    """

    def __init__(self, 
                 set):
        self.data = TRAINING_SET if set == 'training' else TEST_SET

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        embedding, label = self.data[idx]
        # `label` is a length-1 row holding the class name.
        class_position = neurotransmitters.index(label[0])
        return embedding, one_hot_neurotransmitters[class_position]

In [None]:
# Mini-batch sizes for the linear-probing loaders.
train_batch_size, test_batch_size = 50, 50

training_dataset = Custom_LP_Dataset('training') 
test_dataset = Custom_LP_Dataset('test')

# NOTE(review): the test loader is shuffled as well; order does not affect
# the accuracy computed later, so shuffle=False would also work.
training_loader = utils.DataLoader(training_dataset, batch_size=train_batch_size, shuffle=True, pin_memory=True)
test_loader = utils.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True, pin_memory=True)

=============================================================================================================================================================================================================================

# Define MLP Head

In [None]:
from torch import nn 

def init_model(module):
    """Initialise an `nn.Linear` in place; leave any other module untouched.

    Weights get Xavier-normal values and the bias, when present, is zeroed.
    Meant to be passed to `nn.Module.apply`.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_normal_(module.weight)
    if module.bias is None:
        return
    nn.init.zeros_(module.bias)

class MLP_Head(nn.Module):
    """MLP classification head applied to `feat_dim`-sized embeddings.

    Three hidden Linear+ReLU layers of widths 3/4, 1/2 and 1/4 of `feat_dim`
    (truncated to integers), followed by a Linear layer with `nb_outputs`
    units and a Sigmoid.

    Args:
        device: Device the module is moved to at construction.
        feat_dim: Size of the input embedding.
        nb_outputs: Number of output units (defaults to 6, the value that was
            previously hard-coded).
    """

    def __init__(self, device, feat_dim, nb_outputs=6):
        super().__init__()
        self.device = device
        self.nb_outputs = nb_outputs
        self.feat_dim = feat_dim
        self.hidden_dims = self.feat_dim*np.array([3/4, 1/2, 1/4])
        self.hidden_dims = self.hidden_dims.astype(int)

        # Give nn.Linear plain Python ints rather than numpy integer scalars.
        layer_dims = [int(self.feat_dim)] + [int(d) for d in self.hidden_dims]

        layers = []
        for in_dim, out_dim in zip(layer_dims[:-1], layer_dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(layer_dims[-1], int(self.nb_outputs)))
        layers.append(nn.Sigmoid())
        # Same layer order and positions as before, so state_dict keys are unchanged.
        self.stack = nn.Sequential(*layers)

        self.apply(init_model)
        self.to(self.device)

    def forward(self, x):
        """Return the sigmoid activations of the head for input `x`."""
        return self.stack(x)

## Call MLP Head

In [None]:
# MLP_Head.__init__ already moves the module to `device`; the explicit
# .to(device) below repeats that move.
head = MLP_Head(device=device, feat_dim=feat_dim)
head.to(device)
optimizer = torch.optim.Adam(head.parameters(), lr=3e-4)
# Binary cross-entropy on the head's sigmoid outputs against one-hot targets.
loss_fn = nn.BCELoss()

# Plot MLP Head

In [None]:
import torchvision
from torchview import draw_graph

# Trace the head with torchview for one input of size feat_dim; the bare
# expression on the last line shows the graph as the cell output.
model_graph = draw_graph(head, input_size=(1,feat_dim), expand_nested=True)
model_graph.visual_graph

# Training

In [None]:
def use(epochs):
    """Train the global `head` on `training_loader` and evaluate it after every epoch.

    Args:
        epochs: Number of passes over `training_loader`.

    Returns:
        loss_list: Mean of the per-batch training losses, one value per epoch.
        proportion_list: One-hot ground-truth batches seen during the last
            epoch (empty when `epochs` is 0).
        prediction_list: `[predicted_idx, true_idx]` for every test sample,
            accumulated over all epochs.
        test_accuracies: Test accuracy in percent, one value per epoch
            (previously one value per test batch, which did not line up with
            the per-epoch x axis used when plotting).
    """
    loss_list = []
    proportion_list = []
    prediction_list = []
    test_accuracies = []
    for _ in tqdm(range(epochs), desc=f'Epoch:'):
        # Re-enter training mode every epoch: the evaluation below leaves the
        # module in eval mode.
        head.train()
        epoch_loss_list = []
        proportion_list = []
        for embeddings, one_hot_gts in tqdm(training_loader, desc='Training', leave=False):
            embeddings = embeddings.to(device)
            output = head(embeddings).to(torch.float64)

            gt = one_hot_gts.to(device)
            # Sum of the per-sample BCE losses of the batch.
            loss = sum(loss_fn(out, true) for out, true in zip(output, gt))

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            epoch_loss_list.append(loss.detach().cpu().numpy())
            proportion_list.append(one_hot_gts)

        loss_list.append(np.mean(epoch_loss_list))

        head.eval()
        score = 0
        total = 0
        with torch.no_grad():
            for embeddings, one_hot_gts in tqdm(test_loader, desc='Testing', leave=False):
                embeddings = embeddings.to(device)
                outputs = head(embeddings) # shape (batch_size, nb_classes)

                # Batched argmax avoids one .item() call per sample.
                predicted = torch.argmax(outputs, dim=1).tolist()
                truth = torch.argmax(one_hot_gts, dim=1).tolist()
                prediction_list.extend([p, t] for p, t in zip(predicted, truth))

                score += sum(p == t for p, t in zip(predicted, truth))
                total += len(truth)

        # One accuracy value per epoch; NaN when the test loader is empty.
        test_accuracies.append(100*score/total if total else float('nan'))

    return loss_list, proportion_list, prediction_list, test_accuracies

In [None]:
# Number of training epochs for the MLP head; also read by the plotting cells.
epochs = 1
loss_list, proportion_list, prediction_list, test_accuracies = use(epochs) # (pred, truth)

# Class proportions during training:

In [None]:
# Stack the one-hot ground-truth batches returned by use() into a single
# (nb_samples, nb_classes) array.
gts = np.concatenate([np.asarray(batch) for batch in proportion_list])

# Column sums of one-hot rows are the per-class counts, already in
# `neurotransmitters` order. The previous version printed counts in
# np.unique's sorted-row order, which is not the class order and is shorter
# when a class is absent.
counts = gts.sum(axis=0)
proportions = (100*counts/counts.sum()).reshape(-1, 1)

fig, ax = plt.subplots(figsize=(10,4), dpi=150)

img = ax.imshow(proportions.T, cmap='RdYlGn')

for k, (prop, count) in enumerate(zip(proportions.ravel(), counts)):
    ax.text(x=k, y=0, s=f'{round(prop.item(), ndigits=2)}%\n({int(count)})', ha="center", va="center", color="black")

ax.set_xticks(range(len(neurotransmitters)), labels=neurotransmitters, rotation=-45, ha="right", rotation_mode="anchor")
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

# Total number of samples, written just right of the last cell.
ax.text(len(neurotransmitters) - 0.3, 0, f'({int(np.sum(counts))})', va='center', ha='left', color='black')

ax.set_title('Class proportions')
ax.set_yticklabels([])

fig.colorbar(img, ax=ax, orientation='horizontal', label='Proportion')

plt.show()

# Confusion matrix:

In [None]:
nb_classes = len(neurotransmitters)

# Rows are the true class, columns the predicted class.
confusion_matrix = np.zeros((nb_classes, nb_classes))
for prediction, truth in prediction_list:
    confusion_matrix[truth, prediction] += 1

# Raw counts, kept for the annotations.
initial_confusion_matrix = confusion_matrix.copy()

# Per-row totals (the previous loop appended the rows themselves).
total_list = confusion_matrix.sum(axis=1)
# Row-normalise to percentages; a class without test samples keeps an
# all-zero row instead of NaN from 0/0.
row_totals = total_list[:, None]
confusion_matrix = 100*np.divide(confusion_matrix,
                                 row_totals,
                                 out=np.zeros_like(confusion_matrix),
                                 where=row_totals > 0)

fig, ax = plt.subplots(figsize=(7,7), dpi=150)

# NOTE(review): seaborn is imported but not used in this cell.
import seaborn as sns
im = ax.imshow(confusion_matrix, cmap='YlGn')

ax.set_yticks(range(nb_classes), labels=neurotransmitters)
ax.set_xticks(range(nb_classes), labels=neurotransmitters, rotation=-45, ha="right", rotation_mode="anchor")
ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)
for i in range(nb_classes):
    for j in range(nb_classes):
        text = ax.text(j, i, s=f'{round(confusion_matrix[i, j], ndigits=2)}%\n({int(initial_confusion_matrix[i,j])})', ha="center", va="center", color="black")

# Margin annotations sit just outside the last row/column.
margin = nb_classes - 0.25

for i, row in enumerate(initial_confusion_matrix):
    ax.text(margin, i, f'({int(sum(row))})',
            va='center', ha='left', color='black')

for j, column in enumerate(initial_confusion_matrix.T):
    ax.text(j, margin, f'({int(sum(column))})',
            va='center', ha='center', color='black')

ax.text(margin, margin, f'({int(np.sum(initial_confusion_matrix))})',
            va='center', ha='left', color='black')

fig.tight_layout()
ax.set_title(f'Confusion matrix for filtered class-wise predictions - 20% Dataset - Epochs={epochs}')

#fig.colorbar(im, ax=ax, orientation='vertical', label='Accuracy')

plt.show()

# General results:

In [None]:
# Training loss (left axis) and test accuracy (right axis) against the epoch index.
x = list(range(epochs))
fig, ax1 = plt.subplots(figsize=(5,5), dpi=150)
ax2 = ax1.twinx()

lns1 = ax1.plot(x, loss_list, label='Train loss')
ax1.set_ylim(0,max(loss_list)*1.05)

lns2 = ax2.plot(x, test_accuracies, label='Test accuracy', color='red')
ax2.set_ylim(0,105)

# Merge the handles of both axes into a single legend.
lns = lns1 + lns2
labs = [line.get_label() for line in lns]
ax1.legend(lns, labs, loc=0)

plt.show()

=============================================================================================================================================================================================================================

# Freeze MLP Head:

In [None]:
# Exclude every parameter of the trained head from gradient computation.
for param in head.parameters():
    param.requires_grad = False

# Load DINOv2 Model

In [None]:
# Load a DINOv2 ViT/14 model with registers from torch.hub; the first letter
# of `model_size` selects the variant (e.g. 'small' -> 'dinov2_vits14_reg').
# This rebinds the `model` name imported from `setup`; the `few_shot`
# pipeline was constructed with the earlier object.
model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_size[0]}14_reg')
model.to(device)
model.eval()

# Freeze weights

In [None]:
# Freeze all backbone parameters. Modules added afterwards are not affected by this loop.
for param in model.parameters():
    param.requires_grad = False

# Plot DINOv2 Model

In [None]:
# Trace the backbone with torchview for one 3x518x518 input, laid out
# top-to-bottom; the last line shows the graph as the cell output.
model_graph = draw_graph(model, input_size=(1,3,518,518), expand_nested=True, graph_dir='TB', strict=False, depth=10)
model_graph.visual_graph

# Define adapter

In [None]:
from torch import nn

class AdaptMLP(nn.Module):
    """Frozen module plus a trainable bottleneck branch added to its output.

    forward(x) = original_mlp(x) + s * up_proj(dropout(relu(down_proj(x))))

    `up_proj` starts at zero, so a freshly built adapter returns exactly
    `original_mlp(x)`.

    Args:
        device: Device the adapter (and the wrapped module) is moved to.
        original_mlp: Module to wrap; its parameters are frozen.
        in_dim: Input/output width of the bottleneck branch.
        mid_dim: Hidden width of the bottleneck.
        dropout: Dropout probability applied after the activation.
        s: Scaling factor applied to the bottleneck output.
    """

    def __init__(self, device, original_mlp, in_dim, mid_dim, dropout=0.0, s=0.1):
        super().__init__()

        self.device = device
        self.original_mlp = original_mlp

        # Bottleneck: project down, non-linearity, project back up.
        self.down_proj = nn.Linear(in_dim, mid_dim)
        self.act = nn.ReLU()
        self.up_proj = nn.Linear(mid_dim, in_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = s

        # Kaiming-uniform for the down projection, zeros everywhere else.
        nn.init.kaiming_uniform_(self.down_proj.weight)
        for zeroed in (self.up_proj.weight, self.down_proj.bias, self.up_proj.bias):
            nn.init.zeros_(zeroed)

        # Only the bottleneck is trainable.
        self.original_mlp.requires_grad_(False)

        self.to(self.device)

    def forward(self, x):
        """Return the wrapped module's output plus the scaled bottleneck output."""
        bottleneck = self.dropout(self.act(self.down_proj(x)))
        correction = self.up_proj(bottleneck) * self.scale
        return self.original_mlp(x) + correction

# Change MLP blocks

In [None]:
# Wrap the MLP of every transformer block in an AdaptMLP.
#
# The adapter replaces only the `mlp` attribute, so it must wrap only the
# block's MLP. A DINOv2 block computes ls2(mlp(norm2(x))) around that
# attribute (see the DINOv2 block implementation); wrapping
# Sequential(norm2, mlp, ls2) as before applied norm2 and ls2 twice.
for block in model.blocks:
    if isinstance(block.mlp, AdaptMLP):
        # Already wrapped (cell re-run): do not nest adapters.
        adapter = block.mlp
        continue

    in_dim = block.norm2.normalized_shape[0]
    mid_dim = in_dim // 4  # bottleneck is a quarter of the block width

    adapter = AdaptMLP(device, block.mlp, in_dim, mid_dim)
    block.mlp = adapter

# Plot Adapter

In [None]:
# Trace the adapter left bound by the wiring loop (the one of the last
# block) for one feat_dim-sized input and show the graph.
model_graph = draw_graph(adapter, input_size=(1,feat_dim), expand_nested=True)
model_graph.visual_graph

# Add MLP Head to modified DINOv2

In [None]:
# Chain the adapted backbone and the frozen head: the backbone's output is
# fed directly to the head.
augmented_model = nn.Sequential(model, head)

# Plot Augmented Model + MLP Head

In [None]:
# Trace backbone + head with torchview for one 3x518x518 input and show the graph.
model_graph = draw_graph(augmented_model, input_size=(1,3,518,518), expand_nested=True, depth=10)
model_graph.visual_graph

# Number of trainable parameters:

In [None]:
# Split the parameters of backbone + head by whether they receive gradients.
trainable_params, frozen_params_list = [], []
for _param in augmented_model.parameters():
    bucket = trainable_params if _param.requires_grad else frozen_params_list
    bucket.append(_param)

# Element counts of each group.
params = sum(np.prod(t.size()) for t in trainable_params)
frozen_params = sum(np.prod(t.size()) for t in frozen_params_list)

total_params = params + frozen_params

print(f'Proportion of trainable parameters: {params / total_params * 100:.2f}%')

# Dataset for fine-tuning:

In [None]:
files, _ = zip(*get_fnames())

# Load every image, resize it and repeat it along axis 2 to get three
# channels, then move the channel axis first: (N, 3, H, W).
IMAGES = []
for file in tqdm(files, desc='Loading images'):
    im = resize_hdf_image(load_image(file)[0])
    IMAGES.append(np.concatenate([im, im, im], axis=2))
IMAGES = np.array(IMAGES).transpose(0,3,1,2)

# One label per image: 600 images per class.
FT_LABELS = np.hstack([[neuro]*600 for neuro in neurotransmitters]).reshape(-1,1)

# zip() truncates silently, so check that images and labels line up.
if len(FT_LABELS) != len(IMAGES):
    raise ValueError(f'{len(IMAGES)} images but {len(FT_LABELS)} labels')

# Pair images with the image-level FT_LABELS. The previous code zipped with
# LABELS, the label array built for the linear-probing dataset.
FT_DATASET = list(zip(IMAGES, FT_LABELS))

FT_DATASET = sample(FT_DATASET, len(FT_DATASET))

# First 20% of the shuffled data is the test set, the rest is the training set.
FT_SPLIT = int(len(FT_DATASET)*0.2)
FT_TRAINING_SET = FT_DATASET[FT_SPLIT:]
FT_TEST_SET = FT_DATASET[:FT_SPLIT]

In [None]:
class Custom_FT_Dataset(Dataset):
    """Dataset over the global fine-tuning split.

    `set == 'training'` selects FT_TRAINING_SET; any other value selects
    FT_TEST_SET. Items are `(image, one_hot_target)` pairs.
    """

    def __init__(self, 
                 set):
        self.data = FT_TRAINING_SET if set == 'training' else FT_TEST_SET

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, label = self.data[idx]
        # `label` is a length-1 row holding the class name.
        class_position = neurotransmitters.index(label[0])
        return image, one_hot_neurotransmitters[class_position]

In [None]:
# Fine-tuning processes one image per step.
ft_train_batch_size, ft_test_batch_size = 1, 1

ft_training_dataset = Custom_FT_Dataset('training') 
ft_test_dataset = Custom_FT_Dataset('test')

# NOTE(review): the test loader is shuffled as well; order does not affect
# the accuracy computed later.
ft_training_loader = utils.DataLoader(ft_training_dataset, batch_size=ft_train_batch_size, shuffle=True, pin_memory=True)
ft_test_loader = utils.DataLoader(ft_test_dataset, batch_size=ft_test_batch_size, shuffle=True, pin_memory=True)

# Training augmented model with frozen MLP Head (Fine-Tuning)

# UMAP Before:

In [None]:
# Load the saved reference embeddings and the saved dataset embeddings, then
# concatenate the latter into a single tensor.
# NOTE(review): the following cell repeats these three statements.
REFS = compute_ref_embeddings(True, os.path.join(embeddings_path, 'small_mean_ref_518_Aug=False_k=10.pt'))
EMBEDDINGS = torch.load(os.path.join(embeddings_path, 'small_dataset_embs_518.pt'))
EMBEDDINGS = torch.cat(EMBEDDINGS)

In [None]:
import umap
# Fixed seed so the 2-D projection is reproducible.
reducer = umap.UMAP(random_state=42)

# Same loads as in the previous cell.
REFS = compute_ref_embeddings(True, os.path.join(embeddings_path, 'small_mean_ref_518_Aug=False_k=10.pt'))
EMBEDDINGS = torch.load(os.path.join(embeddings_path, 'small_dataset_embs_518.pt'))
EMBEDDINGS = torch.cat(EMBEDDINGS)

# Entry of EMBEDDINGS to visualise.
idx = 1

# Rows of length feat_dim for the selected entry, and the first reference embedding.
EX_EMBEDDINGS = EMBEDDINGS[idx].reshape(-1, feat_dim)
EX_REF = torch.tensor(REFS[0])

# Row 0 is the reference; the remaining rows are the selected embeddings.
embeddings_and_ref = np.vstack([EX_REF, EX_EMBEDDINGS])

# Patches per image side (14 is presumably the ViT patch size — TODO confirm).
N = nb_patches_per_dim = int((resize_size/14))

ref_coords = list(map(resize, coords[0][0]))

# Grid cell used as centre; note the swap: element 1 gives the row, element 0 the column.
center = (ref_coords[1]//14+1,ref_coords[0]//14+1)
row, col = np.ogrid[:N, :N]

# Since N == nb_patches_per_dim this reduces to max(|row - r0|, |col - c0|),
# i.e. the Chebyshev distance of each grid cell to `center`.
distance_matrix = np.abs(N - np.maximum(np.abs(row - center[0]), np.abs(col - center[1])) - nb_patches_per_dim)

# Row-major flattening of distance_matrix.
# NOTE(review): `distances` is not used in the rest of this cell.
distances = []
for i in range(nb_patches_per_dim):
    for j in range(nb_patches_per_dim):
        distances.append(distance_matrix[i,j])

umap_embeddings = reducer.fit_transform(embeddings_and_ref)

from analysis_utils import compute_similarity_matrix

# Values used to colour the points, computed between the reference and each embedding.
semantic_distances = compute_similarity_matrix(EX_REF, EX_EMBEDDINGS)

# Embeddings coloured by semantic_distances; the reference (row 0) is drawn in red.
plt.scatter(umap_embeddings[1:,0], umap_embeddings[1:,1], c=semantic_distances.ravel(), s=2, cmap='bwr')
plt.scatter(umap_embeddings[0,0], umap_embeddings[0,1], c='red', marker='o')
plt.colorbar()
plt.show()

In [None]:
# The optimizer is handed every parameter of backbone + head, including the
# ones whose requires_grad was set to False earlier in the notebook.
ft_optimizer = torch.optim.Adam(augmented_model.parameters(), lr=3e-4)
augmented_model.to(device)
ft_loss_fn = nn.BCELoss()

In [None]:
def fine_tuning(epochs):
    """Fine-tune the global `augmented_model` and evaluate it after every epoch.

    Args:
        epochs: Number of passes over `ft_training_loader`.

    Returns:
        loss_list: Mean of the per-batch training losses, one value per epoch.
        prediction_list: `[predicted_idx, true_idx]` for every test sample,
            accumulated over all epochs.
        test_accuracies: Test accuracy in percent, one value per epoch.
    """
    loss_list = []
    prediction_list = []
    test_accuracies = []
    for _ in tqdm(range(epochs), desc=f'Epoch:'):
        # Re-enter training mode every epoch: the evaluation below leaves the
        # model in eval mode, so later epochs used to train in eval mode.
        augmented_model.train()
        epoch_loss_list = []
        for images, one_hot_gts in tqdm(ft_training_loader, desc='Training', leave=False):
            images = images.to(torch.float32).to(device)
            output = augmented_model(images).to(torch.float64)

            gt = one_hot_gts.to(device)
            # Sum of the per-sample BCE losses of the batch.
            loss = sum(ft_loss_fn(out, true) for out, true in zip(output, gt))

            loss.backward()
            ft_optimizer.step()
            ft_optimizer.zero_grad()

            epoch_loss_list.append(loss.detach().cpu().numpy())

        loss_list.append(np.mean(epoch_loss_list))

        augmented_model.eval()
        score = 0
        total = 0
        with torch.no_grad():
            for images, one_hot_gts in tqdm(ft_test_loader, desc='Testing', leave=False):
                images = images.to(torch.float32).to(device)
                outputs = augmented_model(images) # shape (batch_size, nb_classes)

                # Batched argmax avoids one .item() call per sample.
                predicted = torch.argmax(outputs, dim=1).tolist()
                truth = torch.argmax(one_hot_gts, dim=1).tolist()
                prediction_list.extend([p, t] for p, t in zip(predicted, truth))

                score += sum(p == t for p, t in zip(predicted, truth))
                total += len(truth)

        # NaN instead of an unbound variable when the test loader is empty.
        test_accuracies.append(100*score/total if total else float('nan'))

    return loss_list, prediction_list, test_accuracies

In [None]:
# Number of fine-tuning epochs. The results rebind loss_list,
# prediction_list and test_accuracies from the head-training run.
nb_epochs = 1
loss_list, prediction_list, test_accuracies = fine_tuning(nb_epochs) # (pred, truth)

# UMAP After:

In [None]:
#TODO: Create a hook to get representation

In [None]:
# Fine-tuning curves: training loss (left axis) and test accuracy (right axis).
# The x axis is derived from the results themselves; the previous version
# used `epochs`, the epoch count of the head-training run, instead of the
# number of fine-tuning epochs.
x = list(range(len(loss_list)))
fig, ax1 = plt.subplots(figsize=(5,5), dpi=150)
ax2 = ax1.twinx()
lns1 = ax1.plot(x, loss_list, label='Train loss')
ax1.set_ylim(0,max(loss_list)*1.05)
lns2 = ax2.plot(x, test_accuracies, label='Test accuracy', color='red')
ax2.set_ylim(0,105)

# Merge the handles of both axes into a single legend.
lns = lns1 + lns2
labs = [line.get_label() for line in lns]
ax1.legend(lns, labs, loc=0)

plt.show()