In [None]:
import torch 
from torchvision import transforms
from data_load import get_subj_dataset
import numpy as np 
from torch.utils.data import DataLoader
from image_processing import ImageDataset, fit_pca, extract_features_with_pca
from torchvision.models.feature_extraction import create_feature_extractor
from text_processing import get_embeddings, get_coco_info
from transformers import BertTokenizer, BertForSequenceClassification
from evaluation import calculate_corr, get_roi_corr, plot_corr_all_rois
from sklearn.linear_model import LinearRegression

In [None]:
# Experiment configuration: everything a reader might want to tune lives here.
subj = 1  # subject index passed to get_subj_dataset
batch_size = 100  # DataLoader batch size for image feature extraction
n_components = 100  # NOTE(review): not passed to fit_pca in this notebook — confirm fit_pca's own default matches
rand_seed = 5  # seed for every RNG used downstream (train/val split, torch)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the two paths below are not passed to get_coco_info() in this
# notebook — confirm the text-feature helpers actually read these files.
nsd_stim_info_file_path = 'nsd_stim_info_merged.csv'
coco_annotation_path = 'annotations'
data_dir = 'algonauts_2023_challenge_data'

# Seed the global RNGs so Restart & Run All reproduces the same results;
# previously rand_seed was defined but never applied anywhere.
torch.manual_seed(rand_seed)
np.random.seed(rand_seed)

In [None]:
subj_paths, subj_data = get_subj_dataset(subj, data_dir)

# fMRI rows are selected with the same indices as the images, so both must
# describe the same stimuli in the same order.
assert len(subj_data.lh_fmri) == len(subj_data.rh_fmri) == len(subj_paths.img_list), \
    "fMRI responses and training images are not aligned"

# Hold out (1 - train_frac) of the training stimuli as a validation set.
train_frac = 0.9
num_train = int(np.round(len(subj_paths.img_list) * train_frac))

# A dedicated generator seeded from the config cell makes the split identical
# on every run, regardless of what else has touched the global NumPy RNG.
rng = np.random.default_rng(rand_seed)
idxs = rng.permutation(len(subj_paths.img_list))

idxs_train, idxs_val = idxs[:num_train], idxs[num_train:]
idxs_test = np.arange(len(subj_paths.img_list_test))

print(f'Training stimulus images: {len(idxs_train)}')
print(f'\nValidation stimulus images: {len(idxs_val)}')
print(f'\nTest stimulus images: {len(idxs_test)}')

lh_fmri_train = subj_data.lh_fmri[idxs_train]
rh_fmri_train = subj_data.rh_fmri[idxs_train]
lh_fmri_val = subj_data.lh_fmri[idxs_val]
rh_fmri_val = subj_data.rh_fmri[idxs_val]

In [None]:
# Image features: AlexNet activations at `model_layer`, reduced with PCA
# (PCA is fitted on the training images only, then applied to every split).

transform = transforms.Compose([
    transforms.Resize((224,224)),  # fixed input size for AlexNet
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
])


def make_img_loader(img_paths, idxs):
    """Build a DataLoader over the images selected by `idxs`.

    shuffle is left at its default (False) so extracted features stay
    row-aligned with the fMRI responses indexed by the same `idxs`.
    """
    return DataLoader(ImageDataset(img_paths, idxs, transform), batch_size=batch_size)


train_imgs_dataloader = make_img_loader(subj_paths.img_dir_list, idxs_train)
val_imgs_dataloader = make_img_loader(subj_paths.img_dir_list, idxs_val)
test_imgs_dataloader = make_img_loader(subj_paths.img_dir_list_test, idxs_test)

# pretrained=True: the v0.10.0 hub entry point defaults to pretrained=False,
# i.e. randomly initialised weights, which is inconsistent with the ImageNet
# normalisation above and makes the activations poor visual features.
model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)
model.to(device)
model.eval()
model_layer = "features.5"
feature_extractor = create_feature_extractor(model, return_nodes=[model_layer])
pca = fit_pca(feature_extractor, train_imgs_dataloader)
features_train = extract_features_with_pca(feature_extractor, train_imgs_dataloader, pca)
features_val = extract_features_with_pca(feature_extractor, val_imgs_dataloader, pca)
features_test = extract_features_with_pca(feature_extractor, test_imgs_dataloader, pca)
# The extractor is built from `model` and presumably keeps references to its
# parameters, so release it as well; nothing later in the notebook uses it.
del model, pca, feature_extractor

In [None]:
# Text features: one BERT embedding per stimulus for the train / val / test splits.
#
# NOTE(review): BertForSequenceClassification attaches an untrained
# classification head; if get_embeddings only reads hidden states a plain
# BertModel would suffice — confirm against text_processing.get_embeddings.
model_checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(model_checkpoint)
model = BertForSequenceClassification.from_pretrained(model_checkpoint)

# NOTE(review): neither returned mapping is used in this notebook, and the
# configured nsd_stim_info_file_path / coco_annotation_path are not passed in.
nsd_to_coco, coco_id_to_description = get_coco_info()

# Evaluated lazily, in order: train, then val, then test.
train_emb, val_emb, test_emb = (
    get_embeddings(img_paths, split_idxs, model, tokenizer)
    for img_paths, split_idxs in (
        (subj_paths.img_dir_list, idxs_train),
        (subj_paths.img_dir_list, idxs_val),
        (subj_paths.img_dir_list_test, idxs_test),
    )
)

# Free the language model once all splits are embedded.
del model, tokenizer

In [None]:
# Combine features in one model: concatenate text and image features column-wise
# and fit one ordinary-least-squares encoder per hemisphere.
#
# NOTE(review): features_test / test_emb are computed earlier but never used
# here; add test-set predictions if a submission is intended.

X_train = np.hstack([train_emb, features_train])
X_val = np.hstack([val_emb, features_val])

# Fit left first, then right (same order as the hemisphere tuple).
reg_lh, reg_rh = (
    LinearRegression().fit(X_train, hemi_target)
    for hemi_target in (lh_fmri_train, rh_fmri_train)
)

lh_fmri_val_pred = reg_lh.predict(X_val)
rh_fmri_val_pred = reg_rh.predict(X_val)

# Per-vertex prediction/response correlation on the held-out validation split.
lh_corr, rh_corr = (
    calculate_corr(pred, target)
    for pred, target in ((lh_fmri_val_pred, lh_fmri_val), (rh_fmri_val_pred, rh_fmri_val))
)

plot_corr_all_rois(lh_corr, rh_corr, '1 model with text and image features')
get_roi_corr(lh_corr, rh_corr)
