# fMRI to CLIP Representation


## MICS

---

### Importing Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from torch.optim.lr_scheduler import StepLR
import os
import scipy.io as sio
import pickle
import glob

import numpy as np
import pandas as pd
# import matplotlib.pyplot as p
from scipy.io import loadmat

In [None]:
import glob
from pathlib import Path
from PIL import Image
import pickle
import wget
import zipfile

from scipy import stats
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt

from transformers import CLIPProcessor, CLIPModel

import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.transforms import Compose, CenterCrop, Normalize, ToTensor
from torch.autograd import Variable
import torchvision.datasets as dset
import clip

In [None]:
from diffusion import generate_img

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# All tensors and models in this notebook are placed on the CPU; the CUDA
# selection above is kept commented out for reference.
device = 'cpu'

---
### Setting seed

In [None]:
import random

# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) so runs are
# reproducible. NumPy was previously left unseeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

---
### Evaluation Function

In [None]:
def pairwise_evaluation(vectors_real, vectors_new):
	"""Pairwise (2-vs-2) matching accuracy between predicted and real vectors.

	For every unordered pair (i, j) with i < j, the pair counts as correct when
	the matched correlations beat the swapped ones:

		corr(new_i, real_i) + corr(new_j, real_j) > corr(new_i, real_j) + corr(new_j, real_i)

	Args:
		vectors_real: tensor with one ground-truth vector per row (n rows).
		vectors_new: tensor with one predicted vector per row; only its first
			n rows are used.

	Returns:
		Fraction of correct pairs in [0, 1], or ``nan`` when n < 2 (no pairs).
	"""
	real = vectors_real.detach().cpu().numpy()
	new = vectors_new.detach().cpu().numpy()
	n = real.shape[0]
	if n < 2:
		# No pairs to compare; previously this raised ZeroDivisionError.
		return float('nan')

	# One corrcoef call yields every corr(new_i, real_j): it is the top-right
	# n x n block of the joint (2n x 2n) correlation matrix.
	corr = np.corrcoef(new[:n], real)[:n, n:]
	matched = np.diag(corr)
	matched_sum = matched[:, None] + matched[None, :]
	swapped_sum = corr + corr.T

	upper = np.triu_indices(n, k=1)  # all (i, j) with i < j
	count = int(np.count_nonzero(matched_sum[upper] > swapped_sum[upper]))
	total = upper[0].size
	return count / total

In [None]:
def top_5(pred, real):
	"""Top-1 / top-5 / top-10 classification accuracy.

	The target class of row i is ``argmax(real[i, :])``. It counts as a hit
	at k when it is among the k highest-scoring columns of ``pred[i, :]``.

	Returns:
		Tuple ``(top1, top5, top10)`` of fractions of rows.
	"""
	n_rows = pred.shape[0]
	hits_1 = hits_5 = hits_10 = 0.0
	for i in range(n_rows):
		scores = pred[i, :]
		target = np.argmax(real[i, :])
		ranking = np.flip(np.argsort(scores))  # best-scoring column first
		if np.argmax(scores) == target:
			hits_1 += 1
		if np.isin(target, ranking[:5]):
			hits_5 += 1
		if np.isin(target, ranking[:10]):
			hits_10 += 1
	return hits_1 / n_rows, hits_5 / n_rows, hits_10 / n_rows

In [None]:
class CosineProximityLoss(nn.Module):
	"""Loss equal to one minus the mean cosine similarity along the last dimension."""

	def __init__(self):
		super().__init__()

	def forward(self, input1, input2):
		# L2-normalize each vector so their dot product is the cosine similarity.
		unit_a = F.normalize(input1, p=2, dim=-1)
		unit_b = F.normalize(input2, p=2, dim=-1)
		mean_similarity = (unit_a * unit_b).sum(dim=-1).mean()
		return 1 - mean_similarity

In [None]:
class MSELoss(nn.Module):
	"""Mean squared error averaged over every element of the inputs."""

	def __init__(self):
		super().__init__()

	def forward(self, input1, input2):
		squared_error = (input1 - input2).pow(2)
		return squared_error.mean()

In [None]:
class MeanDistanceLoss(nn.Module):
	"""Positive-pair loss minus the mean loss over in-batch negative pairs.

	Negatives are formed by rolling ``y_pred`` along the batch dimension by
	1, 2, ..., num_shifts positions. Each shift pairs every target with a
	different sample's prediction.

	Args:
		cosine_loss: callable ``(y_true, y_pred) -> scalar tensor``.
	"""

	def __init__(self, cosine_loss):
		super(MeanDistanceLoss, self).__init__()
		self.cosine_loss = cosine_loss
		# Upper bound on the number of negative shifts (179 == batch of 180).
		self.val = 179

	def forward(self, y_true, y_pred):
		positive = self.cosine_loss(y_true, y_pred)
		# Only batch_size - 1 shifts are distinct. Larger shifts wrap around and
		# would pair samples with their own prediction again (a positive
		# counted as a negative), which is what the fixed 179 did for smaller batches.
		num_shifts = min(self.val, y_pred.shape[0] - 1)
		if num_shifts <= 0:
			return positive
		negative_total = sum(
			self.cosine_loss(y_true, torch.roll(y_pred, shift, dims=0))
			for shift in range(1, num_shifts + 1)
		)
		return positive - negative_total / num_shifts

---
## Loading Data from BOLD5000

### Getting the fMRI data

In [None]:
def get_images_from_CSI(patient_id):
	"""Return the stimulus image names shown to subject ``CSI{patient_id}``.

	Reads every ``*.txt`` run list under
	``Stimuli_Presentation_Lists/CSI{patient_id}/CSI{patient_id}_sess*/`` and
	concatenates the first column of each (tab-separated, no header).

	Session folders and run files are visited in sorted (lexicographic) order.
	``glob`` returns entries in arbitrary filesystem order, so without sorting
	the list could come out in a different order on another machine and no
	longer line up with the rows of the fMRI data.
	NOTE(review): lexicographic order equals numeric order only for zero-padded
	names (e.g. ``sess01``..``sess15``); confirm the dataset uses zero padding.
	"""
	root_path = 'Stimuli_Presentation_Lists/'
	path = os.path.join(root_path, f'CSI{patient_id}/')

	list_images_names = []
	for folder in sorted(glob.glob(os.path.join(path, f'CSI{patient_id}_sess*/'))):
		for file_path in sorted(glob.glob(os.path.join(folder, '*.txt'))):
			data_from_file = pd.read_csv(file_path, sep='\t', header=None)
			list_images_names.extend(data_from_file[0].values)

	return list_images_names

# Build the list of stimulus names presented to subject CSI1 and expose it as
# the notebook-global `list_images_names` used by the masking cells below.
list_images_names_CSI1 = get_images_from_CSI(1)

# print("Total images:", list_images_names_CSI1)
print("Length of list_images_names:", len(list_images_names_CSI1))

list_images_names = list_images_names_CSI1

In [None]:
# images file path
coco_images_path = "./Scene_Stimuli/Original_Images/COCO"

# loading all list of images names 
coco_images = glob.glob(coco_images_path + '/*')

# load images
def load_image(image_path):
	"""Load an image file as an RGB numpy array of shape (H, W, 3).

	PIL converts grayscale, palette, RGBA and CMYK files to RGB. Previously,
	RGBA/CMYK failed the channel assertion and palette images were stacked
	as raw palette indices. The file handle is closed before returning.
	"""
	with Image.open(image_path) as img:
		image = np.array(img.convert('RGB'))

	assert image.shape[2] == 3, "Image should have 3 channels"

	return image

def load_images_from_names(image_name):
	"""Load a COCO stimulus by file name from ``coco_images_path``.

	Best-effort: any failure (non-COCO name, missing/corrupt file, non-string
	name) is printed and ``None`` is returned instead of raising.
	"""
	try:
		if not image_name.startswith('COCO'):
			raise Exception("Image not of coco dataset")
		return load_image(coco_images_path + '/' + image_name)
	except Exception as err:
		print("Error loading image:", image_name, err)
		return None

# NOTE(review): this repeats an earlier cell that already built
# list_images_names_CSI1; it re-reads every presentation list and should
# produce the same list again.
list_images_names_CSI1 = get_images_from_CSI(1)

# print("Total images:", list_images_names_CSI1)
print("Length of list_images_names:", len(list_images_names_CSI1))

list_images_names = list_images_names_CSI1

In [None]:
def get_mask(list_images_names):
	"""Mark the first presentation of each distinct COCO stimulus.

	A name is identified by the part before its first ``.``. Entries whose
	stem starts with ``rep`` or does not start with ``COCO``, and every
	repeat of an already-seen stem, get 0.

	Returns:
		float numpy array, same length as the input, with 1 at kept positions.
	"""
	mask = np.zeros(len(list_images_names))
	seen_stems = set()
	for position, file_name in enumerate(list_images_names):
		stem = file_name.split('.')[0]
		keep = (
			not stem.startswith('rep')
			and stem not in seen_stems
			and stem.startswith('COCO')
		)
		if keep:
			mask[position] = 1
			seen_stems.add(stem)
	return mask

# Mark the first presentation of each COCO stimulus (1) vs. repeats / non-COCO (0).
mask = get_mask(list_images_names)
	
print("Unique images:", np.sum(mask))
print("len of mask:", len(mask))

In [None]:
# apply mask on list_images_names: keep the names whose mask entry is 1, in presentation order
unique_images = [image for i, image in enumerate(list_images_names) if mask[i] == 1]

print("Length of unique images:", len(unique_images))

# unique_images

In [None]:
# images file path
coco_images_path = "./Scene_Stimuli/Original_Images/COCO"

# loading all list of images names 
coco_images = glob.glob(coco_images_path + '/*')
# imgnet_images = glob.glob(imgnet_images_path + '/*')
# scene_images = glob.glob(scene_images_path + '/*')

# load images
def load_image(image_path):
	"""Load an image file as an RGB numpy array of shape (H, W, 3).

	PIL converts grayscale, palette, RGBA and CMYK files to RGB. Previously,
	RGBA/CMYK failed the channel assertion and palette images were stacked
	as raw palette indices. The file handle is closed before returning.
	"""
	with Image.open(image_path) as img:
		image = np.array(img.convert('RGB'))

	assert image.shape[2] == 3, "Image should have 3 channels"

	return image

def load_images_from_names(image_name):
	"""Load a COCO stimulus by file name from ``coco_images_path``.

	Best-effort: any failure (non-COCO name, missing/corrupt file, non-string
	name) is printed and ``None`` is returned instead of raising.
	"""
	try:
		if not image_name.startswith('COCO'):
			raise Exception("Image not of coco dataset")
		return load_image(coco_images_path + '/' + image_name)
	except Exception as err:
		print("Error loading image:", image_name, err)
		return None

---

### Get clip features

In [None]:
# 1. Load the CLIP model
# device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): `model` and `preprocess` are notebook globals read by
# generate_clip_image_activations. Later cells rebind `model` to an
# Autoencoder, so re-run this cell before encoding images again.
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
def generate_clip_image_activations(input_images):
	"""Encode each named stimulus image with the global CLIP ``model``.

	Args:
		input_images: iterable of stimulus file names understood by
			``load_images_from_names``.

	Returns:
		List with one image-feature tensor per input name, in input order.

	Raises:
		ValueError: if an image cannot be loaded. Skipping it silently would
			shift every later feature out of alignment with the fMRI rows.
	"""
	feature_vectors = []
	for img_path in tqdm.tqdm(input_images):
		image = load_images_from_names(img_path)
		if image is None:
			raise ValueError(f"Could not load image: {img_path}")
		pil_image = Image.fromarray(np.uint8(image))
		# CLIP preprocessing, then add a batch dimension of 1.
		image_input = preprocess(pil_image).unsqueeze(0).to(device)

		with torch.no_grad():
			image_features = model.encode_image(image_input)
		feature_vectors.append(image_features[0])

	return feature_vectors

# Example usage: encode a single stimulus and print its feature shape
ex_input_images = ["COCO_train2014_000000475840.jpg"]
ex_image_features = generate_clip_image_activations(ex_input_images)
print(ex_image_features[0].shape)

In [None]:
# raise Exception("Stop here")

In [None]:
# now for all the unique images: one feature tensor per entry of unique_images, same order
unique_image_features = generate_clip_image_activations(unique_images)

In [None]:
# stack tensor: list of per-image feature tensors -> one tensor with one row per image
tensor_image_features = torch.stack(unique_image_features)
tensor_image_features.shape

In [None]:
# Persist the per-image CLIP features so the encoding step can be skipped on
# later runs. `with` closes the file (the bare open() leaked the handle), and
# the target directory is created if missing.
os.makedirs("./data", exist_ok=True)
with open("./data/unique_image_features.pkl", "wb") as f:
	pickle.dump(unique_image_features, f)

================================================================

In [None]:
# load the variable saved above (a list of per-image feature tensors).
# `with` closes the file handle, which the bare open() leaked.
# NOTE: pickle executes code on load; only load files this notebook wrote.
with open("./data/unique_image_features.pkl", "rb") as f:
	unique_images_features = pickle.load(f)

In [None]:
# load the captions and their representations:
#   caption_to_feature: caption text -> stored feature vector
#   img_best_caption:   image file name -> caption text
# Forward slashes work on every OS. The old 'data\caption...' path only
# worked on Windows and relied on an invalid escape sequence ('\c').
with open('data/caption_to_feature.pkl', 'rb') as f:
	caption_to_feature = pickle.load(f)

with open("data/img_best_caption.pkl", 'rb') as f:
	img_best_caption = pickle.load(f)

In [None]:
# Re-key the captions by image name with the file extension stripped
# (the part before the first ".").
img_name_to_caption = {}

for k,v in img_best_caption.items():
	img_name_to_caption[k.split('.')[0]] = v
 
print(list(img_name_to_caption.values())[:10])

In [None]:
# Stack the reloaded per-image features into one tensor (one row per image).
# NOTE: `unique_images_features` (list from the pickle) and
# `unique_image_features` (stacked tensor used below) differ only by an "s".
unique_image_features = torch.stack(unique_images_features)

# Parallel lists of caption strings and their stored feature vectors (same order).
unique_captions = list(caption_to_feature.keys())
unique_captions_features = list(caption_to_feature.values())

len(unique_captions),len(unique_captions_features), len(unique_images_features)

---
### Apply Mask to fMRI Data

In [None]:
roi_data_path = 'data/CSI1/mat/CSI1_ROIs_TR34.mat'

# get the ROI data (a dict of arrays loaded from the .mat file)
from scipy.io import loadmat
roi_data = loadmat(roi_data_path)

print("ROI data keys:", roi_data.keys())

# drop the .mat header entries (keys starting with "__"), keeping only the ROI arrays
roi_data = {k: v for k, v in roi_data.items() if not k.startswith('__')}

print("ROI data keys:", roi_data.keys())
# NOTE: a bare expression in the middle of a cell is evaluated but not displayed
roi_data['RHLOC'].shape

# total voxel count = sum of the column counts of all ROI arrays
roi_data_sum = 0 

for roi in roi_data:
	roi_data_sum += roi_data[roi].shape[1]
 
print("Total number of voxels in all ROIs:", roi_data_sum)

In [None]:
# apply mask on roi_data
def apply_mask_on_roi_data(roi_data, mask):
	"""Keep, for every ROI array, only the rows where ``mask == 1``.

	Returns:
		New dict with the same keys (same order) and row-filtered arrays.
	"""
	keep_rows = mask == 1
	return {roi: trials[keep_rows, :] for roi, trials in roi_data.items()}

# apply mask on roi_data: keep only the rows selected by `mask` in every ROI
# (assumes ROI rows follow the same presentation order as list_images_names)
roi_data_masked = apply_mask_on_roi_data(roi_data, mask)

# total voxel count after masking (column count is unchanged by the row mask)
roi_data_sum = 0
for roi in roi_data_masked:
	# print shape of roi
	print(f"ROI: {roi}, Shape: {roi_data_masked[roi].shape}")
	roi_data_sum += roi_data_masked[roi].shape[1]
 
print("Total number of voxels in all ROIs after applying mask:", roi_data_sum)

In [None]:
# concat all the roi voxels
def concat_roi_voxels(roi_data):
	"""Concatenate all ROI arrays column-wise, in the dict's key order."""
	blocks = list(roi_data.values())
	return np.concatenate(blocks, axis=1)

# One row per kept sample, all ROI voxels side by side (ROI blocks in key order).
roi_data_concat = concat_roi_voxels(roi_data_masked)
# convert to tensor
roi_data_concat = torch.tensor(roi_data_concat, dtype=torch.float32)
print("Shape of roi_data_concat:", roi_data_concat.shape)

## Making new npy files.

In [None]:
# Per-ROI voxel counts after masking (in roi_data_masked's key order). The
# Autoencoder reads these files to split its input into per-ROI branches.
sizess = []
for roi in roi_data_masked:
	sizess.append(roi_data_masked[roi].shape[1])
print(sizess)
# np.save does not create missing directories.
os.makedirs('./data/look_ups', exist_ok=True)
np.save('./data/look_ups/sizes_clip.npy', sizess)
# Reduced branch widths: half of each ROI size. np.round rounds halves to
# even (2.5 -> 2, 3.5 -> 4).
sizess = np.array(sizess)
sizess = np.round(sizess/2).astype(int)
print(sizess)
np.save('./data/look_ups/reduced_sizes_clip.npy', sizess)

---
## Data loaders

In [None]:
# # split data set into train and test sklearn
# from sklearn.model_selection import train_test_split

# x_data = roi_data_concat
# y_data = tensor_image_features

# # print("Shape of x_data:", x_data.shape)
# # print("Shape of y_data:", y_data.shape)

# # tensor split data
# x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=42)

# # print shape of train and test data
# print("Shape of x_train:", x_train.shape)
# print("Shape of y_train:", y_train.shape)
# print("Shape of x_test:", x_test.shape)
# print("Shape of y_test:", y_test.shape)

In [None]:
# Global lookup tables used by the retrieval/evaluation helpers: every fMRI
# row, every image feature and name (the retrieval gallery), and every
# caption with its feature vector.
glob_roi_fmri_data = roi_data_concat.to(device)
glob_img_clip_reps = unique_image_features.to(device)
glob_img_names = unique_images
glob_captions = unique_captions
glob_captions_reps = torch.stack(unique_captions_features).to(device)

In [None]:
# Hold out the first 200 samples for testing and train on the remaining rows.
# NOTE(review): the previous version trained on the full data set, test rows
# included. That leaks the test set into training and inflates every
# reported test metric.
n_test = 200

x_test = roi_data_concat[:n_test].to(device)
y_test = unique_image_features[:n_test].to(device)
img_test_names = unique_images[:n_test]

_x_train = roi_data_concat[n_test:].to(device)
_y_train = unique_image_features[n_test:].to(device)

x_train = _x_train
y_train = _y_train
img_names_train = unique_images[n_test:]

# print shape of train and test data
print("Shape of x_train:", x_train.shape)
print("Shape of y_train:", y_train.shape)
print("len of img_names_train:", len(img_names_train))
print()
print("Shape of x_test:", x_test.shape)
print("Shape of y_test:", y_test.shape)
print("len of img_test_names:", len(img_test_names))

In [None]:
# Peek at the first two training rows
x_train[:2]

In [None]:
# Example subset: the first 50 samples. The test split is the first 200 rows,
# so these names, image features and fMRI rows line up position-by-position.
ex_image_name = glob_img_names[0:50]
ex_image_rep = glob_img_clip_reps[0:50].to(device)
ex_fmri_data = x_test[0:50].to(device)

indexes = [27,6,7,8,14,13]  # Example indices (positions within the 50-sample subset)

---
## Evaluation Functions

In [None]:
def evaluate_model(model, test_loader, criterion_accuracy, criterion_glove, get_ranks, get_mean_ranks, img_test_names,model_type=None):
	"""Evaluate ``model`` on ``test_loader`` and return batch-averaged metrics.

	The model is put in eval mode for the duration of the call (Dropout off,
	BatchNorm uses and does not update its running statistics). The caller's
	train/eval mode is restored afterwards.

	Args:
		model: network to evaluate.
		test_loader: yields ``(inputs, targets)`` batches. It must not shuffle,
			because batch outputs are paired with ``img_test_names`` by position.
		criterion_accuracy: ``(outputs, targets) -> accuracy``; applied to the
			first 8 samples of each batch.
		criterion_glove: ``(outputs, targets) -> loss`` on the whole batch.
		get_ranks / get_mean_ranks: ``(outputs, names) -> ranks / mean rank``;
			applied to the first 10 samples of each batch.
		img_test_names: names aligned with the samples ``test_loader`` yields.
		model_type: ``None`` for models returning a tuple whose second element
			is the CLIP-space prediction; ``"small"`` for models returning it
			directly.

	Returns:
		``(avg_accuracy, avg_glove_loss, avg_rank, plot_ranks)``.

	Raises:
		Exception: invalid ``model_type``.
		ValueError: ``test_loader`` yielded no batches.
	"""
	if model_type not in (None, "small"):
		raise Exception("Invalid model type")

	plot_ranks = []
	accuracy_test = 0
	rank = 0
	loss_glove = 0
	cnt = 0
	offset = 0  # position of the current batch within img_test_names

	was_training = model.training
	model.eval()
	try:
		with torch.no_grad():
			for inputs, targets_glove in test_loader:
				output = model(inputs)
				outputs_glove = output[1] if model_type is None else output
				# Names of THIS batch. Previously every batch was ranked against
				# img_test_names[:10], which is wrong for all batches but the first.
				batch_names = img_test_names[offset:offset + len(inputs)]
				offset += len(inputs)

				plot_ranks.extend(get_ranks(outputs_glove[:10], batch_names[:10]))
				accuracy_test += criterion_accuracy(outputs_glove[:8], targets_glove[:8])
				rank += get_mean_ranks(outputs_glove[:10], batch_names[:10])
				loss_glove += criterion_glove(outputs_glove, targets_glove)
				cnt += 1
	finally:
		model.train(was_training)

	if cnt == 0:
		raise ValueError("test_loader yielded no batches")

	avg_accuracy_test = accuracy_test / cnt
	avg_loss_glove = loss_glove / cnt
	avg_rank = rank / cnt

	return avg_accuracy_test, avg_loss_glove, avg_rank, plot_ranks

def plot_ranks_distribution(plot_ranks):
	"""Show a 50-bin histogram of retrieval ranks, x-axis clipped to [0, 2000]."""
	fig, ax = plt.subplots(figsize=(10, 6))
	ax.hist(plot_ranks, bins=50, color='skyblue', edgecolor='black')
	ax.set_title('Distribution of Ranks')
	ax.set_xlabel('Rank')
	ax.set_ylabel('Frequency')
	ax.set_xlim(0, 2000)  # Limit x-axis from 0 to 2000
	ax.grid(True)
	plt.show()

In [None]:
def get_dic_sorted_by_cosine_similarity(query,glob_names=glob_img_names,glob_rep=glob_img_clip_reps):
	"""Rank every entry of ``glob_names`` by cosine similarity to ``query``.

	Args:
		query: 1-D tensor of length 512.
		glob_names: names aligned row-for-row with ``glob_rep`` (extra entries
			on either side are ignored, as with ``zip``).
		glob_rep: tensor with one 512-length row per name, or a sequence of
			512-length tensors.

	Returns:
		dict name -> similarity (Python float), ordered from most to least
		similar.
	"""
	assert query.shape == torch.Size([512]), "Query should be 512"
	if torch.is_tensor(glob_rep):
		reps = glob_rep
	else:
		rep_list = list(glob_rep)
		if not rep_list:
			return {}
		reps = torch.stack(rep_list)
	assert reps.dim() == 2 and reps.shape[1] == 512, "Feature should be 512"

	# One batched call instead of a Python loop over the gallery. no_grad keeps
	# autograd from recording when `query` is a model output.
	with torch.no_grad():
		common = torch.promote_types(query.dtype, reps.dtype)
		sims = F.cosine_similarity(
			query.to(device=reps.device, dtype=common).unsqueeze(0),
			reps.to(common),
			dim=1,
		)

	cosine_similarities = {}
	for img_name, sim in zip(glob_names, sims.tolist()):
		cosine_similarities[img_name] = sim

	return dict(sorted(cosine_similarities.items(), key=lambda item: item[1], reverse=True))

def find_top5_img(query):
	"""Return ``{name: feature}`` for the 5 gallery images most similar to ``query``.

	Ordered from most to least similar; features come from ``glob_img_clip_reps``.
	"""
	ranked_names = list(get_dic_sorted_by_cosine_similarity(query))
	return {
		name: glob_img_clip_reps[glob_img_names.index(name)]
		for name in ranked_names[:5]
	}

def evaluate_gt_feature_in_top5(y_preds, y_gts):
	"""Percentage of predictions whose ground truth appears among their top-5 retrievals.

	A retrieved gallery feature "is" the ground truth when its cosine
	similarity with ``y_gt`` lies in [0.98, 1.1].
	"""
	hits = 0
	for pred, gt in zip(y_preds, y_gts):
		assert pred.shape == torch.Size([512]), "y_pred should be 512"
		candidates = find_top5_img(pred).values()
		if any(0.98 <= torch.nn.functional.cosine_similarity(gt, rep, dim=0) <= 1.1 for rep in candidates):
			hits += 1
	return hits * 100 / len(y_preds)

In [None]:
def get_k_closest_names(query, k=5, glob_names=glob_img_names, glob_rep=glob_img_clip_reps.to(device)):
	"""Return the ``k`` names from ``glob_names`` whose representations are most cosine-similar to ``query``.

	Ordered from most to least similar. Fewer than ``k`` names are returned
	when the gallery is smaller.
	"""
	ranked = get_dic_sorted_by_cosine_similarity(query, glob_names, glob_rep)
	# (The unreachable duplicate `return` of the previous version is removed.)
	return list(ranked)[:k]

def get_rank_query(query, search):
	"""1-based rank of image ``search`` when the gallery is sorted by similarity to ``query``.

	If ``search`` is not in the gallery, the gallery size is returned
	(indistinguishable from being ranked last).
	"""
	ordered_names = list(get_dic_sorted_by_cosine_similarity(query))
	if search in ordered_names:
		return ordered_names.index(search) + 1
	return len(ordered_names)

def get_ranks(y_preds, gt_names):
	"""Gallery rank of each ground-truth name under its paired prediction, as a numpy array."""
	def _rank_one(pred, name):
		assert pred.shape == torch.Size([512]), "y_pred should be 512"
		return get_rank_query(pred, name)

	return np.array([_rank_one(pred, name) for pred, name in zip(y_preds, gt_names)])

def get_mean_ranks(y_preds, gt_names):
	"""Mean of ``get_ranks(y_preds, gt_names)``."""
	ranks = get_ranks(y_preds, gt_names)
	return np.mean(ranks)

In [None]:
def show_original_and_closest_image(indexes, model, glob_img_names, glob_img_clip_reps, x_test, k=1, ex_fmri_data=None, ex_image_rep=None):
	"""Plot, for each index, the stimulus image next to the ``k`` gallery images closest to the model's prediction.

	Args:
		indexes: positions within ``ex_fmri_data`` to show.
		model: network whose output tuple has the CLIP-space prediction at index 1.
		glob_img_names: names aligned with ``ex_fmri_data`` rows.
		glob_img_clip_reps: unused; kept for interface compatibility.
		x_test: only its length is used, as the batch size.
		k: number of closest images to show.
		ex_fmri_data / ex_image_rep: fMRI rows and matching image features.
	"""
	test_dataset = TensorDataset(ex_fmri_data.to(device), ex_image_rep.to(device))
	ex_test_loader = DataLoader(test_dataset, batch_size=len(x_test))

	# Predict in eval mode (Dropout off, BatchNorm running stats) without
	# building an autograd graph; restore the caller's mode afterwards.
	was_training = model.training
	model.eval()
	try:
		with torch.no_grad():
			roi_data, _ = next(iter(ex_test_loader))  # only the first batch is used
			outputs_glove = model(roi_data)[1]
	finally:
		model.train(was_training)

	for idx in indexes:
		closest_names = get_k_closest_names(outputs_glove[idx], k)
		ex_image = load_images_from_names(glob_img_names[idx])

		fig, axes = plt.subplots(1, k+1, figsize=(15, 5))
		axes[0].imshow(ex_image)
		axes[0].set_title('Original Image')

		for i, name in enumerate(closest_names):
			closest_img = load_images_from_names(name)
			axes[i+1].imshow(closest_img)
			axes[i+1].set_title(f'Closest Image Rank {i+1}')

		plt.show()
        
        
        
def print_predictions(model, indexes, x_test,
		glob_img_names=glob_img_names,
		glob_img_clip_reps=glob_img_clip_reps,
		glob_captions=glob_captions,
		glob_caption_reps=glob_captions_reps,
		ex_fmri_data=None,
		ex_image_rep=None,
		model_type=None):
	"""For each index, show the stimulus, its closest gallery image and closest captions, then generate an image.

	The model is moved to ``device`` and left in eval mode.

	Args:
		model: network to run; see ``model_type``.
		indexes: positions within ``ex_fmri_data`` to show.
		x_test: only its length is used, as the batch size.
		glob_img_names: names aligned with ``ex_fmri_data`` rows.
		glob_captions / glob_caption_reps: caption gallery to retrieve from.
		ex_fmri_data / ex_image_rep: fMRI rows and matching image features.
		model_type: ``None`` (tuple output, prediction at index 1) or ``"small"``.

	Raises:
		Exception: invalid ``model_type`` (previously a NameError further down).
	"""
	if model_type not in (None, "small"):
		raise Exception("Invalid model type")

	test_dataset = TensorDataset(ex_fmri_data.to(device), ex_image_rep.to(device))
	ex_test_loader = DataLoader(test_dataset, batch_size=len(x_test))
	model.to(device)
	model.eval()

	# Inference only: no autograd graph. Only the first batch is used.
	with torch.no_grad():
		roi_data, _ = next(iter(ex_test_loader))
		output = model(roi_data)
	outputs_glove = output[1] if model_type is None else output

	for idx in indexes:
		closest_name = get_k_closest_names(outputs_glove[idx], 1)[0]
		closest_captions = get_k_closest_names(outputs_glove[idx], 2, glob_captions, glob_caption_reps)[:2]
		actual_name = glob_img_names[idx]

		print(f"Actual Image: {actual_name}")
		plt.imshow(load_images_from_names(actual_name))
		plt.show()
		print("Original Caption:")
		print("\"| ",img_best_caption[actual_name]," |\"")
		print("\nPredictions:")
		print(f"Closest Image: {closest_name}")
		plt.imshow(load_images_from_names(closest_name))
		plt.show()
		print("Caption:")
		print("\"| ",closest_captions[0]," |\"")
		print("Gen image using rep:",end=" ")
		# NOTE(review): this generates from the SECOND-closest caption while the
		# closest one is printed above; confirm this is intended.
		generate_img(closest_captions[1])
		print("=="*50,"\n\n"*4)
  
def print_predictions_one_rep(model, index, x_test,
		glob_img_names=glob_img_names,
		glob_img_clip_reps=glob_img_clip_reps,
		glob_captions=glob_captions,
		glob_caption_reps=glob_captions_reps,
		ex_fmri_data=None,
		ex_image_rep=None,
		model_type=None,k=2):
	"""Show the top-``k`` retrieved images and captions for one sample, then generate images from them.

	The model runs in eval mode without autograd. Before, Dropout was active
	here, so repeated calls gave different predictions. The caller's
	train/eval mode is restored afterwards.

	Args:
		model: network to run; see ``model_type``.
		index: position within ``ex_fmri_data`` to show.
		x_test: only its length is used, as the batch size.
		glob_img_names: names aligned with ``ex_fmri_data`` rows.
		glob_captions / glob_caption_reps: caption gallery to retrieve from.
		ex_fmri_data / ex_image_rep: fMRI rows and matching image features.
		model_type: ``None`` (tuple output, prediction at index 1) or ``"small"``.
		k: number of images/captions to retrieve.

	Raises:
		Exception: invalid ``model_type``.
	"""
	if model_type not in (None, "small"):
		raise Exception("Invalid model type")

	test_dataset = TensorDataset(ex_fmri_data.to(device), ex_image_rep.to(device))
	ex_test_loader = DataLoader(test_dataset, batch_size=len(x_test))

	was_training = model.training
	model.eval()
	try:
		with torch.no_grad():
			roi_data, _ = next(iter(ex_test_loader))  # only the first batch is used
			output = model(roi_data)
	finally:
		model.train(was_training)
	outputs_glove = output[1] if model_type is None else output

	query = outputs_glove[index].cpu()
	closest_names = get_k_closest_names(query, k)
	closest_captions = get_k_closest_names(query, k, glob_captions, glob_caption_reps)

	actual_name = glob_img_names[index]
	print(f"Actual Image: {actual_name}")
	plt.imshow(load_images_from_names(actual_name))
	plt.show()
	print("=="*50)

	# Iterate over what was actually retrieved (may be fewer than k).
	for name in closest_names:
		print("=="*50)
		print("\nPredictions:")
		print(f"Closest Image: {name}")
		plt.imshow(load_images_from_names(name))
		plt.title(f"Caption: {img_best_caption[name]}")
		plt.show()

	for caption in closest_captions:
		print("Caption:",end=" ")
		print("\"| ",caption," |\"")

	for name in closest_names:
		print("Gen image using rep:",end=" ")
		generate_img(img_best_caption[name])
	print("=="*50,"\n\n"*4)

---
### Sanity check of functions

In [None]:
# top 5 sanity check: each ground-truth feature is used as its own prediction,
# so the result is expected to be 100.0
evaluate_gt_feature_in_top5(y_test[:5], y_test[:5])

In [None]:
# rank sanity check
# NOTE(review): predictions y_test[1:10] are paired with names [0:9] (shifted
# by one), so these are ranks of mismatched pairs; confirm the shift is intended.
ranks = get_ranks(y_test[1:10], img_test_names[0:9])

print("Ranks:", ranks)

In [None]:
# sanity check for rank
# select a fixed image and get its representation
index = 1349
index_name = unique_images[index]

# load and show the image
img = load_images_from_names(index_name)
plt.imshow(img)
plt.show()

y_pred = unique_image_features[index]
print("Shape of y_pred:", y_pred.shape)

print(img_test_names)
print(y_test.shape)

# the 5 gallery images most similar to this image's own feature
# (expected to include the image itself at rank 1)
closest_names = get_k_closest_names(y_pred, k=5)

for i, name in enumerate(closest_names):
    img = load_images_from_names(name)
    plt.imshow(img)
    plt.show()
    print(f"Rank: {i+1}, Name: {name}")

---
## Models

### Model with Pretraining (Reconstruction and Classification Separately)

#### Model class

In [None]:
class Autoencoder(nn.Module):
	"""Multi-branch fMRI network with a reconstruction path, a 512-d CLIP-space head and a 180-way softmax head.

	Each input row is the column-wise concatenation of per-ROI voxel blocks
	(widths ``sizes``). Every block is reduced by its own Linear branch, the
	branch outputs are concatenated and passed through shared dense layers.
	``dense2`` produces the CLIP-space prediction (``out_glove``), ``dense3``
	a softmax over 180 classes, and ``dense4`` plus the per-ROI transpose
	branches map back to the input width (``out``).

	Args:
		mean: if True, ``forward`` also returns the concatenated branch
			activations and their BatchNorm/Dropout-processed version.
		sizes: per-ROI input widths. Defaults to ``./data/look_ups/sizes_clip.npy``.
		reduced: per-ROI reduced widths. Defaults to
			``./data/look_ups/reduced_sizes_clip.npy``.
	"""

	def __init__(self, mean = False, sizes=None, reduced=None):
		super(Autoencoder, self).__init__()

		rate = 0.4
		dense_size = 512
		glove_size = 512

		if sizes is None:
			sizes = np.load('./data/look_ups/sizes_clip.npy')
		if reduced is None:
			reduced = np.load('./data/look_ups/reduced_sizes_clip.npy')
		# Stored once as plain int tuples (not parameters/buffers, so the
		# state_dict layout is unchanged) instead of re-reading the .npy files
		# on every forward pass.
		self.sizes = tuple(int(s) for s in sizes)
		self.reduced = tuple(int(r) for r in reduced)
		if len(self.sizes) != len(self.reduced):
			raise ValueError("sizes and reduced must have the same length")
		# Width of the concatenated branch outputs. It was previously hard-coded
		# as 843, which only worked when the look-up files summed to exactly that.
		reduced_size = sum(self.reduced)

		# Module creation order is unchanged, so seeded initialization matches.
		self.dense_layers = nn.ModuleList([nn.Linear(s, r) for s, r in zip(self.sizes, self.reduced)])
		self.batch_norm = nn.ModuleList([nn.BatchNorm1d(r) for r in self.reduced])
		self.dropout = nn.Dropout(rate)

		self.batch_norm11 = nn.BatchNorm1d(reduced_size)
		self.dropout11 = nn.Dropout(rate)

		self.dense1 = nn.Linear(reduced_size, dense_size)
		self.leaky_relu = nn.LeakyReLU(0.3)
		self.batch_norm1 = nn.BatchNorm1d(dense_size)
		self.dropout1 = nn.Dropout(rate)

		self.dense2 = nn.Linear(dense_size, glove_size)
		self.batch_norm2 = nn.BatchNorm1d(glove_size)
		self.dropout2 = nn.Dropout(rate)

		self.dense3 = nn.Linear(glove_size, 180)
		self.softmax = nn.Softmax(dim=1)

		self.dense4 = nn.Linear(dense_size, reduced_size)
		self.batch_norm3 = nn.BatchNorm1d(reduced_size)
		self.dropout3 = nn.Dropout(rate)

		self.dense_layers_transpose = nn.ModuleList([nn.Linear(r, s) for s, r in zip(self.sizes, self.reduced)])
		self.batch_norm4 = nn.ModuleList([nn.BatchNorm1d(s) for s in self.sizes])

		self.mean = mean

	def forward(self, x):
		"""Return ``(out, out_glove, out_class)``; with ``mean=True`` also ``concat`` and ``dense1_out``."""
		# Encoder: one reduction branch per ROI column block.
		branch_outputs = []
		index = 0
		for i, width in enumerate(self.sizes):
			small_input = x[:, index:index + width]
			dense_out = self.leaky_relu(self.dense_layers[i](small_input))
			dense_out = self.batch_norm[i](dense_out)
			dense_out = self.dropout(dense_out)
			branch_outputs.append(dense_out)
			index += width

		concat = torch.cat(branch_outputs, dim=1)
		dense1_out = self.batch_norm11(concat)
		dense1_out = self.dropout11(dense1_out)

		out_further = self.leaky_relu(self.dense1(dense1_out))
		out_further = self.batch_norm1(out_further)
		out_further = self.dropout1(out_further)

		# CLIP-space head and class head.
		out_glove = self.leaky_relu(self.dense2(out_further))
		out_glove = self.batch_norm2(out_glove)
		out_glove = self.dropout2(out_glove)
		out_class = self.softmax(self.dense3(out_glove))

		# Decoder: expand back to the per-ROI widths.
		dense4_out = self.leaky_relu(self.dense4(out_further))
		dense4_out = self.batch_norm3(dense4_out)
		dense4_out = self.dropout3(dense4_out)

		branch_outputs1 = []
		index1 = 0
		for j, width in enumerate(self.reduced):
			small_input = dense4_out[:, index1:index1 + width]
			dense_out = self.leaky_relu(self.dense_layers_transpose[j](small_input))
			dense_out = self.batch_norm4[j](dense_out)
			branch_outputs1.append(dense_out)
			index1 += width
		out = torch.cat(branch_outputs1, dim=1)

		if not self.mean:
			return out, out_glove, out_class
		return out, out_glove, out_class, concat, dense1_out

---
#### Training

In [None]:
# Create DataLoaders. The test loader must not shuffle: evaluation pairs its
# outputs with img_test_names by position.
batch_size_train = 64
train_dataset = TensorDataset(x_train.to(device), y_train.to(device))
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)
test_dataset = TensorDataset(x_test.to(device), y_test.to(device))
test_loader = DataLoader(test_dataset, batch_size=50)

# Build the model once (a throwaway instance used to be constructed first).
model = Autoencoder().to(device)

In [None]:
# Optimizer over all model parameters (used for the reconstruction phase).
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Optimizer restricted to dense2 + batch_norm2, the layers that produce the
# CLIP-space output (used for the fine-tuning phase).
optimizer_specific = torch.optim.Adam([
	{'params': model.dense2.parameters()},
	{'params': model.batch_norm2.parameters()}
], lr=0.001)

loss_fn_glove = CosineProximityLoss()
criterion_glove = CosineProximityLoss()
loss_fn_autoencoder = CosineProximityLoss()
criterion_accuracy = evaluate_gt_feature_in_top5

# Multiplies the learning rate of `optimizer` (not `optimizer_specific`) by 0.25 every 5 scheduler steps.
scheduler = StepLR(optimizer, step_size=5, gamma=0.25)

In [None]:
epochs = 8

# Phase 1: train the whole network on reconstruction only. The CLIP-space
# (glove) loss is tracked for monitoring but not optimized.
for epoch in range(epochs):
	model.train()
	running_loss_glove = 0.0
	running_loss_autoencoder = 0.0
	for batch_data, batch_glove in tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
		optimizer.zero_grad()
		out, output_glove, output_class = model(batch_data)
		loss_autoencoder = loss_fn_autoencoder(batch_data, out)
		loss_autoencoder.backward()
		optimizer.step()
		# Monitoring only: detached, so it adds nothing to the graph.
		loss_glove = loss_fn_glove(output_glove.detach(), batch_glove)
		running_loss_glove += loss_glove.item() * batch_data.size(0)
		running_loss_autoencoder += loss_autoencoder.item() * batch_data.size(0)
	# Per-sample means over the epoch (no further division by the batch count).
	epoch_loss_glove = running_loss_glove / len(train_loader.dataset)
	epoch_loss_autoencoder = running_loss_autoencoder / len(train_loader.dataset)

	# Validation in eval mode (no Dropout; BatchNorm uses and does not update
	# its running stats), averaged over the whole test set, not just the last batch.
	model.eval()
	val_loss_glove = 0.0
	with torch.no_grad():
		for inputs, targets_glove in test_loader:
			out, outputs_glove, outputs_class = model(inputs)
			val_loss_glove += criterion_glove(outputs_glove, targets_glove).item() * inputs.size(0)
	val_loss_glove /= len(test_loader.dataset)

	print(f'Epoch [{epoch+1}/{epochs}], Recon Loss: {epoch_loss_autoencoder:.4f}, Train Glove Loss: {epoch_loss_glove:.4f}, Test Glove Loss: {val_loss_glove:.4f}')
	scheduler.step()

In [None]:
epochs = 15

# Phase 2: fine-tune only dense2/batch_norm2 (optimizer_specific) on the
# CLIP-space cosine loss.
for epoch in range(epochs):
	model.train()
	running_loss_glove = 0.0
	for batch_data, batch_glove in tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
		# Clear every parameter's gradient, not only the optimized ones, so the
		# untouched layers do not keep accumulating gradients.
		model.zero_grad()
		out, output_glove, output_class = model(batch_data)
		loss_glove = loss_fn_glove(output_glove, batch_glove)
		loss_glove.backward()
		optimizer_specific.step()
		running_loss_glove += loss_glove.item() * batch_data.size(0)
	epoch_loss_glove = running_loss_glove / len(train_loader.dataset)

	# Validation in eval mode, averaged over the whole test set.
	model.eval()
	val_loss_glove = 0.0
	with torch.no_grad():
		for inputs, targets_glove in test_loader:
			out, outputs_glove, outputs_class = model(inputs)
			val_loss_glove += criterion_glove(outputs_glove, targets_glove).item() * inputs.size(0)
	val_loss_glove /= len(test_loader.dataset)

	print(f'Epoch [{epoch+1}/{epochs}], Train Glove Loss: {epoch_loss_glove:.4f}, Test Glove Loss: {val_loss_glove:.4f}')
	# No scheduler.step() here: `scheduler` is attached to `optimizer`, not to
	# `optimizer_specific`, so stepping it would not affect this phase.

---
#### Results

In [None]:
# Evaluate on test_loader: top-5 accuracy (percent), mean cosine loss, mean
# rank, and the individual ranks for the histogram.
avg_accuracy_test, avg_loss_glove, avg_rank, plot_ranks = evaluate_model(model, test_loader, criterion_accuracy, criterion_glove, get_ranks, get_mean_ranks, img_test_names)

# print(accuracy_test/cnt, loss_glove/cnt, rank/cnt)
print("Top 5 Accuracy:", avg_accuracy_test)
print("Glove Loss:", avg_loss_glove)
print("Mean Rank:", avg_rank)

# plot the ranks
plot_ranks_distribution(plot_ranks)

In [None]:
# Example usage: for each example index, show the stimulus, its closest
# retrieved image and captions, and a generated image.
print_predictions(model, indexes, x_test, 
				  glob_img_names=glob_img_names,
				  glob_img_clip_reps=glob_img_clip_reps,
				  glob_captions=glob_captions,
				  glob_caption_reps=glob_captions_reps,
				  ex_fmri_data=ex_fmri_data, 
				  ex_image_rep=ex_image_rep)

In [None]:
# Top-5 retrieved images and captions (and generated images) for example index 0.
print_predictions_one_rep(model, 0, x_test, 
				  glob_img_names=glob_img_names,
				  glob_img_clip_reps=glob_img_clip_reps,
				  glob_captions=glob_captions,
				  glob_caption_reps=glob_captions_reps,
				  ex_fmri_data=ex_fmri_data, 
				  ex_image_rep=ex_image_rep,k=5)

---
### Model with Pretraining (Reconstruction and Classification Together)

#### Training

In [None]:
# Create DataLoaders. The test loader must not shuffle: evaluation pairs its
# outputs with img_test_names by position.
batch_size_train = 16
train_dataset = TensorDataset(x_train.to(device), y_train.to(device))
train_loader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)
test_dataset = TensorDataset(x_test.to(device), y_test.to(device))
test_loader = DataLoader(test_dataset, batch_size=50)

# Build the model once (a throwaway instance used to be constructed first).
model = Autoencoder().to(device)

In [None]:
# Optimizer over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Optimizer restricted to dense2 + batch_norm2, the layers that produce the
# CLIP-space output.
optimizer_specific = torch.optim.Adam([
	{'params': model.dense2.parameters()},
	{'params': model.batch_norm2.parameters()}
], lr=0.001)

loss_fn_glove = CosineProximityLoss()
criterion_glove = CosineProximityLoss()
loss_fn_autoencoder = CosineProximityLoss()
criterion_accuracy = evaluate_gt_feature_in_top5

# Multiplies the learning rate of `optimizer` (not `optimizer_specific`) by 0.25 every 5 scheduler steps.
scheduler = StepLR(optimizer, step_size=5, gamma=0.25)

In [None]:
epochs = 8

# Phase 1: train on the reconstruction loss only. The CLIP ("glove") head is
# not optimised here; its test loss is reported only to monitor it.
for epoch in range(epochs):
	model.train()
	running_loss_autoencoder = 0.0
	for batch_data, batch_glove in tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
		optimizer.zero_grad()
		out, output_glove, output_class = model(batch_data)
		loss_autoencoder = loss_fn_autoencoder(batch_data, out)
		loss_autoencoder.backward()
		optimizer.step()

		# Weight by batch size so the epoch value is a per-sample mean.
		running_loss_autoencoder += loss_autoencoder.item() * batch_data.size(0)

	epoch_loss_autoencoder = running_loss_autoencoder / len(train_loader.dataset)

	# Validate in eval mode: Dropout is disabled and BatchNorm uses (and does not
	# update) its running statistics, so test data never leaks into them.
	model.eval()
	val_loss_sum = 0.0
	with torch.no_grad():
		for inputs, targets_glove in test_loader:
			out, outputs_glove, outputs_class = model(inputs)
			val_loss_sum += criterion_glove(outputs_glove, targets_glove).item() * inputs.size(0)
	# Per-sample mean over the whole test set (not just the last batch).
	loss_glove = val_loss_sum / max(len(test_loader.dataset), 1)

	average_loss = epoch_loss_autoencoder
	print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {average_loss:.4f}, Glove Loss: {loss_glove:.4f}')
	scheduler.step()

In [None]:
epochs = 15

# Phase 2: train on the CLIP ("glove") target only, starting from the
# reconstruction-pretrained weights.
# NOTE(review): `optimizer` and `scheduler` carry over from phase 1, so the
# learning rate starts out already decayed by phase-1 scheduler.step() calls.
# Confirm this is intended.
for epoch in range(epochs):
	model.train()
	running_loss_glove = 0.0
	for batch_data, batch_glove in tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
		optimizer.zero_grad()
		out, output_glove, output_class = model(batch_data)
		loss_glove = loss_fn_glove(output_glove, batch_glove)
		loss_glove.backward()
		optimizer.step()

		# Weight by batch size so the epoch value is a per-sample mean.
		running_loss_glove += loss_glove.item() * batch_data.size(0)

	epoch_loss_glove = running_loss_glove / len(train_loader.dataset)

	# Validate in eval mode: Dropout is disabled and BatchNorm running statistics
	# are neither used in batch mode nor updated with test data.
	model.eval()
	val_loss_sum = 0.0
	with torch.no_grad():
		for inputs, targets_glove in test_loader:
			out, outputs_glove, outputs_class = model(inputs)
			val_loss_sum += criterion_glove(outputs_glove, targets_glove).item() * inputs.size(0)
	loss_glove = val_loss_sum / max(len(test_loader.dataset), 1)

	# Only the glove loss is optimised in this phase.
	average_loss = epoch_loss_glove
	print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {average_loss:.4f}, Glove Loss: {loss_glove:.4f}')
	scheduler.step()

---
#### Results 

In [None]:
avg_accuracy_test, avg_loss_glove, avg_rank, plot_ranks = evaluate_model(
	model,
	test_loader,
	criterion_accuracy,
	criterion_glove,
	get_ranks,
	get_mean_ranks,
	img_test_names,
)

# Aggregate test-set metrics returned by evaluate_model.
print("Top 5 Accuracy:", avg_accuracy_test)
print("Glove Loss:", avg_loss_glove)
print("Mean Rank:", avg_rank)

# Visualise the per-sample ranks returned above.
plot_ranks_distribution(plot_ranks)

In [None]:
# Show predictions for the test samples selected by `indexes`.
print_predictions(
	model,
	indexes,
	x_test,
	ex_fmri_data=ex_fmri_data,
	ex_image_rep=ex_image_rep,
	glob_img_names=glob_img_names,
	glob_img_clip_reps=glob_img_clip_reps,
	glob_captions=glob_captions,
	glob_caption_reps=glob_captions_reps,
)

In [None]:
# Inspect predictions for test sample index 0, with k=5.
print_predictions_one_rep(
	model,
	0,
	x_test,
	k=5,
	ex_fmri_data=ex_fmri_data,
	ex_image_rep=ex_image_rep,
	glob_img_names=glob_img_names,
	glob_img_clip_reps=glob_img_clip_reps,
	glob_captions=glob_captions,
	glob_caption_reps=glob_captions_reps,
)

---
### Model without Pretraining (Joint Reconstruction and CLIP-Embedding Loss)

#### Training

In [None]:
# Batch sizes for this run: larger training batches, fixed-order evaluation.
batch_size_train = 128
batch_size_test = 100

train_dataset = TensorDataset(x_train.to(device), y_train.to(device))
test_dataset = TensorDataset(x_test.to(device), y_test.to(device))

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size_train, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size_test)

# Fresh, untrained model on the target device.
model = Autoencoder().to(device)

In [None]:
# Adam over all parameters; the learning rate is multiplied by 0.25 every
# 5 scheduler steps.
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 15
scheduler = StepLR(optimizer, step_size=5, gamma=0.25)

# Losses (all CosineProximityLoss instances) and the top-5 metric.
loss_fn_glove = CosineProximityLoss()
criterion_glove = CosineProximityLoss()
loss_fn_autoencoder = CosineProximityLoss()
criterion_accuracy = evaluate_gt_feature_in_top5

In [None]:
# Baseline: test-set glove loss of the untrained model, averaged over batches.
# Run in eval mode so BatchNorm running statistics are not updated with test
# data before training starts, and Dropout is disabled.
model.eval()
with torch.no_grad():
	loss_glove = 0
	cnt = 0
	for inputs, targets_glove in test_loader:
		out, outputs_glove, outputs_class = model(inputs)
		loss_glove += criterion_glove(outputs_glove, targets_glove)
		cnt += 1
# Guard against an empty test loader.
print(loss_glove / max(cnt, 1))

In [None]:
# Joint training: reconstruction loss + CLIP ("glove") loss, optimised together.
for epoch in range(epochs):
	model.train()
	running_loss_glove = 0.0
	running_loss_autoencoder = 0.0
	for batch_data, batch_glove in tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
		optimizer.zero_grad()
		out, output_glove, output_class = model(batch_data)
		loss_glove = loss_fn_glove(output_glove, batch_glove)
		loss_autoencoder = loss_fn_autoencoder(batch_data, out)
		loss = loss_autoencoder + loss_glove
		loss.backward()
		optimizer.step()

		# Weight by batch size so the epoch values are per-sample means.
		batch_n = batch_data.size(0)
		running_loss_glove += loss_glove.item() * batch_n
		running_loss_autoencoder += loss_autoencoder.item() * batch_n

	epoch_loss_glove = running_loss_glove / len(train_loader.dataset)
	epoch_loss_autoencoder = running_loss_autoencoder / len(train_loader.dataset)

	# Validate in eval mode so Dropout is off and BatchNorm running statistics
	# are not updated with test data.
	model.eval()
	val_loss_sum = 0.0
	with torch.no_grad():
		for inputs, targets_glove in test_loader:
			out, outputs_glove, outputs_class = model(inputs)
			val_loss_sum += criterion_glove(outputs_glove, targets_glove).item() * inputs.size(0)
	loss_glove = val_loss_sum / max(len(test_loader.dataset), 1)

	# Per-sample mean of the combined objective that was optimised.
	average_loss = epoch_loss_glove + epoch_loss_autoencoder
	print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {average_loss:.4f}, Glove Loss: {loss_glove:.4f}')
	scheduler.step()

---
#### Results

In [None]:
avg_accuracy_test, avg_loss_glove, avg_rank, plot_ranks = evaluate_model(
	model,
	test_loader,
	criterion_accuracy,
	criterion_glove,
	get_ranks,
	get_mean_ranks,
	img_test_names,
)

# Aggregate test-set metrics returned by evaluate_model.
print("Top 5 Accuracy:", avg_accuracy_test)
print("Glove Loss:", avg_loss_glove)
print("Mean Rank:", avg_rank)

# Visualise the per-sample ranks returned above.
plot_ranks_distribution(plot_ranks)

In [None]:
# Show predictions for the test samples selected by `indexes`.
print_predictions(
	model,
	indexes,
	x_test,
	ex_fmri_data=ex_fmri_data,
	ex_image_rep=ex_image_rep,
	glob_img_names=glob_img_names,
	glob_img_clip_reps=glob_img_clip_reps,
	glob_captions=glob_captions,
	glob_caption_reps=glob_captions_reps,
)

In [None]:
# Inspect predictions for test sample index 0, with k=5.
print_predictions_one_rep(
	model,
	0,
	x_test,
	k=5,
	ex_fmri_data=ex_fmri_data,
	ex_image_rep=ex_image_rep,
	glob_img_names=glob_img_names,
	glob_img_clip_reps=glob_img_clip_reps,
	glob_captions=glob_captions,
	glob_caption_reps=glob_captions_reps,
)

---
### Small Model

#### Model class

In [None]:
# Input width of the small model below (in_features of its first Linear layer),
# taken from `roi_data_sum`, which is defined elsewhere in the notebook.
FMRI_size = roi_data_sum

In [None]:
class EncDecSmallModel(nn.Module):
	"""Single-hidden-layer regressor from fMRI features to an output embedding.

	Layers: Linear(fMRI_size -> dense_size) -> LeakyReLU(0.3) ->
	BatchNorm1d(dense_size) -> Dropout(rate) -> Linear(dense_size -> glove_size).

	Args:
		rate: dropout probability applied after the hidden layer.
		glove_size: output dimension.
		dense_size: width of the hidden layer.
		fMRI_size: input feature dimension.
	"""

	def __init__(self, rate=0.3, glove_size=512, dense_size=1000, fMRI_size=FMRI_size):
		super().__init__()
		hidden_layers = [
			nn.Linear(fMRI_size, dense_size),
			nn.LeakyReLU(0.3),
			nn.BatchNorm1d(dense_size),
			nn.Dropout(rate),
		]
		# Attribute names are kept stable so state_dict keys stay compatible.
		self.dense_first = nn.Sequential(*hidden_layers)
		self.out_glove = nn.Linear(dense_size, glove_size)

	def forward(self, x):
		"""Map a batch of fMRI feature vectors to output embeddings."""
		return self.out_glove(self.dense_first(x))

In [None]:
# Data pipeline for the small model.
batch_size = 32
train_dataset = TensorDataset(_x_train, _y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(x_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=128)

model = EncDecSmallModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Training loss and evaluation metrics. `criterion_glove` must stay a loss (it
# is passed as the glove-loss argument to evaluate_model), so the top-5 metric
# lives under its own name instead of overwriting it.
loss_fn_glove = CosineProximityLoss()
criterion_glove = CosineProximityLoss()
criterion_accuracy = evaluate_gt_feature_in_top5

epochs = 25
# Learning rate is multiplied by 0.25 every 5 epochs.
scheduler = StepLR(optimizer, step_size=5, gamma=0.25)

for epoch in range(epochs):
	model.train()
	running_loss_glove = 0.0
	for batch_data, batch_glove in tqdm.tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
		optimizer.zero_grad()
		output_glove = model(batch_data)
		loss_glove = loss_fn_glove(output_glove, batch_glove)
		loss_glove.backward()
		optimizer.step()

		# Weight by batch size so the epoch value is a per-sample mean.
		running_loss_glove += loss_glove.item() * batch_data.size(0)

	epoch_loss_glove = running_loss_glove / len(train_loader.dataset)

	# Validate in eval mode: Dropout off, and BatchNorm uses (without updating)
	# its running statistics.
	model.eval()
	cnt = 0
	pair = 0.0
	run_acc = 0
	avg_cos_dist = 0.0
	with torch.no_grad():
		for inputs, targets_glove in test_loader:
			outputs_glove = model(inputs)
			n_batch = outputs_glove.shape[0]
			# NOTE(review): the top-5 metric is computed on the first 8 samples
			# only but normalised by the full batch size; confirm intended.
			accuracy_test = criterion_accuracy(outputs_glove[:8], targets_glove[:8])
			# Row-wise cosine similarity to the ground truth, summed over the batch.
			cos_sim = F.cosine_similarity(outputs_glove, targets_glove, dim=1).sum().item()

			cnt += 1
			# Validation glove loss on this test batch (previously this summed the
			# stale last *training* batch loss instead).
			pair += criterion_glove(outputs_glove, targets_glove).item()
			run_acc += accuracy_test / n_batch
			avg_cos_dist += cos_sim / n_batch

	cnt = max(cnt, 1)  # guard against an empty test loader
	pair = pair / cnt
	acc = run_acc / cnt
	avg_cos_dist = avg_cos_dist / cnt
	print(f"Epoch {epoch+1}/{epochs}, Glove Loss: {epoch_loss_glove}, Val Glove Loss: {pair}, Top5 Accuracy: {acc}, cosine similarity to gt on avg: {avg_cos_dist}")
	scheduler.step()

---
#### Results

In [None]:
# Evaluate the small model. Pass a fresh CosineProximityLoss explicitly as the
# glove-loss criterion so this cell always reports a cosine-proximity loss,
# however `criterion_glove` was last bound.
avg_accuracy_test, avg_loss_glove, avg_rank, plot_ranks = evaluate_model(
	model,
	test_loader,
	criterion_accuracy,
	CosineProximityLoss(),
	get_ranks,
	get_mean_ranks,
	img_test_names,
	model_type="small",
)

# Aggregate test-set metrics returned by evaluate_model.
print("Top 5 Accuracy:", avg_accuracy_test)
print("Glove Loss:", avg_loss_glove)
print("Mean Rank:", avg_rank)

# Visualise the per-sample ranks returned above.
plot_ranks_distribution(plot_ranks)

In [None]:
# Show small-model predictions for the test samples selected by `indexes`.
print_predictions(
	model,
	indexes,
	x_test,
	model_type="small",
	ex_fmri_data=ex_fmri_data,
	ex_image_rep=ex_image_rep,
	glob_img_names=glob_img_names,
	glob_img_clip_reps=glob_img_clip_reps,
	glob_captions=glob_captions,
	glob_caption_reps=glob_captions_reps,
)

In [None]:
# Inspect small-model predictions for test sample index 0, with k=5.
print_predictions_one_rep(
	model,
	0,
	x_test,
	k=5,
	model_type="small",
	ex_fmri_data=ex_fmri_data,
	ex_image_rep=ex_image_rep,
	glob_img_names=glob_img_names,
	glob_img_clip_reps=glob_img_clip_reps,
	glob_captions=glob_captions,
	glob_caption_reps=glob_captions_reps,
)