In [None]:
import torch

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
import transformers
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from transformers import DataCollatorForLanguageModeling

from peft import get_peft_model, LoraConfig, TaskType, PeftConfig, PeftModel

from datasets import load_dataset, DatasetDict

import numpy as np

import torch

from sklearn.model_selection import KFold, cross_val_predict, GridSearchCV

from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge

from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr
from sklearn.metrics.pairwise import cosine_similarity

import time

#### Import the dataset

In [None]:
# Specify the dataset name (identifier handed to `load_dataset`)
dataset_name = "helena-balabin/pereira_fMRI_sentences"

# Specify the path to save or load the dataset (passed as `cache_dir`)
save_path = "./data"

# Load the dataset, use the cache if available
pereira_dataset = load_dataset(dataset_name, cache_dir=save_path)

In [None]:
type(pereira_dataset)

In [None]:
# modelname = "bert-base-uncased"
modelname = "gpt2"

In [None]:
# Load the tokenizer that belongs to the checkpoint named in `modelname`
tokenizer = AutoTokenizer.from_pretrained(modelname)

In [None]:
transformers.logging.set_verbosity_info()

##### Preprocessing Function 1 - Map the data to the tokenizer function

In [None]:
def preprocess_function(tokenizer, examples):
    """Tokenize a batch of examples, one joined text per example.

    Every entry of ``examples["sentences"]`` is an iterable of strings. The strings
    of one entry are joined with single spaces into one text, and the resulting
    list of texts is handed to ``tokenizer`` in a single call.

    :param tokenizer: Callable that accepts a list of strings
    :param examples: Mapping with a ``"sentences"`` key (one iterable of strings per example)
    :return: Whatever ``tokenizer`` returns for the list of joined texts
    """
    joined_texts = []
    for sentence_list in examples["sentences"]:
        joined_texts.append(" ".join(sentence_list))
    return tokenizer(joined_texts)

In [None]:
from functools import partial

# Bind the tokenizer so that `map` can call the function with the batch as its only argument
partial_tokenize_function = partial(preprocess_function, tokenizer)

# Tokenize in batches with 4 worker processes. The columns listed for the train
# split are removed, so only the tokenizer outputs remain in the result.
tokenized_pereira = pereira_dataset.map(
    partial_tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=pereira_dataset['train'].column_names,
)

In [None]:
tokenized_pereira

##### Preprocessing Function 2 - Concatenate the tokenized texts and divide them into blocks of `block_size` tokens. The remainder is dropped if the total length is not evenly divisible by the block size.

In [None]:
def group_texts(examples, block_size=128):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size blocks.

    Every column of ``examples`` (a mapping from column name to a list of token
    lists) is flattened into one long list and then cut into consecutive chunks
    of ``block_size`` items. A ``"labels"`` column is added as a copy of the
    ``"input_ids"`` blocks.

    When the concatenated length is at least ``block_size``, the trailing remainder
    that does not fill a whole block is dropped. When it is shorter than
    ``block_size``, the tokens are kept as a single, shorter block.

    :param examples: Mapping of column name to a list of token lists; must contain ``"input_ids"``
    :param block_size: Number of tokens per block, defaults to 128
    :return: Dict with the same columns, re-chunked, plus ``"labels"``
    :raises ValueError: If ``block_size`` is not a positive integer
    """
    from itertools import chain

    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size!r}")

    # Concatenate all texts. `chain` flattens in linear time, whereas
    # `sum(lists, [])` re-copies the growing accumulator at every step.
    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal LM the targets are the inputs themselves (shallow copy of the block list).
    result["labels"] = result["input_ids"].copy()
    return result

In [None]:
# Re-chunk the tokenized data into fixed-size blocks, batch by batch, using 4 worker processes
preprocessed_dataset = tokenized_pereira.map(group_texts, batched=True, num_proc=4)

##### Data Collator for (Causal) LM. With `mlm=False`, the collator copies each sequence's input IDs into its labels, so that every token is trained to predict the token that follows it (the shift by one position happens inside the model).

In [None]:
# mlm=False selects causal (next-token) language modeling instead of masked language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

#### Configure LoRA from the PEFT library: set its parameters and load the model with the LoRA adapters applied

In [None]:
# LoRA hyper-parameters used to fine-tune the causal language model
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # the adapters are set up for causal language modeling
    inference_mode=False,  # training mode: the adapter weights stay trainable
    r=8,  # rank of the low-rank update matrices
    lora_alpha=32,  # scaling factor applied to the LoRA update
    lora_dropout=0.1)  # dropout probability used inside the LoRA layers

We can see the reduced number of trainable parameters below.

In [None]:
# Load the pretrained base model for the checkpoint named in `modelname`
model_without_peft = AutoModelForCausalLM.from_pretrained(modelname)
# model_without_peft = DebertaV2ForMaskedLM.from_pretrained(modelname)

# Wrap the base model with the LoRA adapters described by `peft_config`
model = get_peft_model(model_without_peft, peft_config)

# Report how many parameters are trainable, then show which device holds the weights
model.print_trainable_parameters()
print(next(model.parameters()).device)

If the tokenizer doesn't have a padding token by default, use the end-of-sequence (EOS) token. If it doesn't have that either, fall back to a separator or classification token.

In [None]:
# Reuse the end-of-sequence token as the padding token (the commented line is the
# classification-token alternative mentioned above).
# tokenizer.pad_token = tokenizer.cls_token
tokenizer.pad_token = tokenizer.eos_token

# Display the padding token that is now in effect
tokenizer.pad_token

Ensure that the model is running on the GPU and not on the CPU.

In [None]:
print(next(model.parameters()).device)

In [None]:
model.to(device)

In [None]:
print(next(model.parameters()).device)

In [None]:
import random

def train_test_split(dataset, test_size=0.2, seed=None):
    """
    Splits a Hugging Face dataset into training and testing sets.

    The row indices are always shuffled before splitting. A private random
    generator is used, so the state of the global ``random`` module is left
    untouched. For a given ``seed`` the resulting order is identical to what
    ``random.seed(seed); random.sample(...)`` would produce.

    Args:
    dataset (Dataset): The dataset to split. It must support ``len()`` and ``select(indices)``.
    test_size (float): The proportion of the dataset to include in the test split (between 0 and 1).
    seed (int, optional): A seed for random shuffling for reproducibility. When omitted,
        the shuffle is seeded from system entropy and differs between calls.

    Returns:
    tuple: Two datasets, the first being the training set and the second the testing set.

    Raises:
    ValueError: If ``test_size`` is outside the range [0, 1].
    """
    if not 0 <= test_size <= 1:
        raise ValueError(f"test_size must be between 0 and 1, got {test_size!r}")

    n_examples = len(dataset)

    # Shuffle the indices with a local generator instead of reseeding the global one
    rng = random.Random(seed)
    shuffled_indices = rng.sample(range(n_examples), n_examples)

    # Calculate the split index (the train share is truncated towards zero)
    split_index = int(n_examples * (1 - test_size))

    # Split the dataset
    train_indices = shuffled_indices[:split_index]
    test_indices = shuffled_indices[split_index:]

    train_dataset = dataset.select(train_indices)
    test_dataset = dataset.select(test_indices)

    return train_dataset, test_dataset

In [None]:
preprocessed_dataset

In [None]:
# Hold out 20% of the grouped train split for evaluation; the fixed seed keeps the split reproducible
train_set, test_set = train_test_split(preprocessed_dataset["train"], test_size=0.2, seed=42)

# Create a new DatasetDict with the new splits
final_dataset = DatasetDict({
    'train': train_set,
    'test': test_set
})

In [None]:
final_dataset

#### Set the Training Arguments

In [None]:
# Training configuration for the Trainer below.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the installed version.
training_args = TrainingArguments(
    output_dir=f"mymodels/{modelname}-conference",  # output directory, derived from the checkpoint name
    evaluation_strategy="epoch",  # evaluate once per epoch
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=True,  # upload the result to the Hugging Face Hub
    report_to="all",  # report metrics to every installed logging integration
    logging_dir='./logs',  # directory for the training logs
    logging_steps=100,  # log every 100 steps
)

#### Finally create the Trainer class and train the model

In [None]:
# Wire the LoRA-wrapped model, the train/test splits, the causal-LM collator
# and the tokenizer into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=final_dataset["train"],
    eval_dataset=final_dataset["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

In [None]:
trainer.push_to_hub()

In [None]:
# repo_name = "alitolga/bert-base-uncased-conference"
repo_name = "alitolga/gpt2-conference"

In [None]:
# Read the adapter configuration stored under `repo_name`
config = PeftConfig.from_pretrained(repo_name)

# Re-create the base model named in the adapter configuration
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)

# NOTE: this rebinds the notebook-level `tokenizer` to a freshly loaded instance
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# Attach the adapter weights from `repo_name`; this rebinds the notebook-level `model`
model = PeftModel.from_pretrained(base_model, repo_name, config=config)
# model = PeftModel.from_pretrained(base_model, repo_name)

In [None]:
model.print_trainable_parameters()

In [None]:
# Take the sentence list stored in the first row of the train split (the 0th subject)
sentences = pereira_dataset["train"]["sentences"][0]
print(len(sentences))

#### Get the sentence embeddings from the Peft model

In [None]:
def get_embeddings(sentence, layer=-1):
    """Compute mean-pooled sentence embeddings with the notebook-level ``model``.

    The text is tokenized with the notebook-level ``tokenizer`` (padded and
    truncated), moved to the device that holds the model weights, and passed
    through ``model`` without gradient tracking. The hidden states of one layer
    are then averaged over the real (non-padding) tokens of every sentence.

    :param sentence: A single string or a list of strings to embed
    :param layer: Index into ``outputs.hidden_states``. Index 0 is the output of the
        embedding layer, i.e. before any transformer block has run, so it carries no
        contextual information; the default of -1 selects the last entry instead
    :return: CPU tensor with one embedding row per input sentence
    """
    inputs = tokenizer(sentence, return_tensors='pt', truncation=True, padding=True)

    # Run on whichever device currently holds the model weights
    model_device = next(model.parameters()).device
    inputs = inputs.to(model_device)

    # Inference only: skip building the autograd graph
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    hidden = outputs.hidden_states[layer]

    # Masked mean: padding positions (attention_mask == 0) must not dilute the average
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    token_counts = mask.sum(dim=1).clamp(min=1)
    embeddings = (hidden * mask).sum(dim=1) / token_counts

    # Return on the CPU so the result can be converted to NumPy directly
    return embeddings.cpu()

In [None]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token

In [None]:
embeddings = get_embeddings(sentences)
print(embeddings.shape)

#### Do the Brain Decoding Part

In [None]:
# Get the voxels. For simplicity we start with all the brain regions
fmri_data = pereira_dataset["train"]["all"]

# fMRI data of the first subject out of 8
voxels = np.array(fmri_data[0])
print(voxels.shape)

In [None]:
# Normalize the embeddings
# embeddings_normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)

In [None]:
# Convert the embeddings to a NumPy array for scikit-learn.
# `.cpu()` covers tensors that live on the GPU, and the tensor check keeps this
# cell safe to re-run once `embeddings` has already been converted.
if torch.is_tensor(embeddings):
    embeddings = embeddings.detach().cpu().numpy()

In [None]:
# Prepare nested CV.
# Inner CV is responsible for hyperparameter optimization;
# Outer CV is responsible for prediction.

n_folds = 5

# Fixed seed (same value as the train/test split) so that the fold assignment, and
# therefore every reported metric, is reproducible across kernel restarts.
# A wall-clock seed produced different folds on every run.
state = 42
inner_cv = KFold(n_splits=n_folds, shuffle=True, random_state=state)
outer_cv = KFold(n_splits=n_folds, shuffle=True, random_state=state)

# Final data prep: normalize.
# Each column is mean-centred, then each row is scaled to unit L2 norm.
# NOTE(review): the column means are computed over all rows, including rows that
# later land in held-out folds - confirm this is acceptable for the analysis.
X = voxels - voxels.mean(axis=0)
X = X / np.linalg.norm(X, axis=1, keepdims=True)
Y = embeddings - embeddings.mean(axis=0)
Y = Y / np.linalg.norm(Y, axis=1, keepdims=True)

In [None]:
######## Run learning.

# Number of parallel jobs used by the grid search
n_jobs = 4

# Candidate ridge regression regularization parameters.
ALPHAS = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e1]

# Run inner CV.
# The grid search over `alpha` is only configured here; it is fitted later, when
# `gs` is handed to `cross_val_predict`. No intercept is fitted by the ridge model.
gs = GridSearchCV(Ridge(fit_intercept=False),
                {"alpha": ALPHAS}, cv=inner_cv, n_jobs=n_jobs, verbose=10)

"""
Purpose of This Line

Nested Cross-Validation:

The use of cross_val_predict with GridSearchCV (gs in this context) as the estimator 
is a part of a nested cross-validation strategy. 
The key purpose here is to evaluate the model's performance in a way that is as unbiased as possible.

Independent Data Splits:

The outer cross-validation (cv=outer_cv) splits the dataset into training and test sets multiple times 
(based on the number of folds in outer_cv). For each of these splits, 
the inner cross-validation (within GridSearchCV) finds the best alpha value. 

This process ensures that the choice of hyperparameters (alpha in this case) is not biased by the 
particular split of data used for model training and evaluation.

Generating Unbiased Predictions:

cross_val_predict does not simply fit the model but generates predictions for each point 
when it is in the test set of the outer cross-validation. 
These predictions are made by a model that has never seen the data point during training, 
thereby providing an unbiased estimate of the model's performance on unseen data.
"""

In [None]:
# Run outer CV.
# Every row of X gets a prediction from a model fitted on the outer folds that exclude it
decoder_predictions = cross_val_predict(gs, X, Y, cv=outer_cv)

In [None]:
print(decoder_predictions.shape)
print(Y.shape)

##### Implementation of Pairwise Accuracy Functions

In [None]:
from sklearn.base import BaseEstimator, clone
from scipy.spatial.distance import cosine

def pairwise_accuracy(
    estimator: "BaseEstimator" = None,
    X: torch.Tensor = None,  # noqa
    y: torch.Tensor = None,
    topic_ids: torch.Tensor = None,
    scoring_variation: str = None,  # type: ignore
) -> float:
    """Calculate the average pairwise accuracy of all pairs of true and predicted vectors.

    Based on the pairwise accuracy as defined in Oota et al. 2022, Sun et al. 2021, Pereira et al. 2018.

    A pair ``(i, j)`` counts as correct when the summed cosine distance of the matched
    pairings ``(pred[i], y[i])`` and ``(pred[j], y[j])`` is strictly smaller than that of
    the swapped pairings ``(pred[i], y[j])`` and ``(pred[j], y[i])``.

    :param estimator: Fitted estimator object exposing ``predict`` (e.g., a Ridge regression)
    :type estimator: BaseEstimator
    :param X: Inputs handed to ``estimator.predict``
    :type X: torch.Tensor
    :param y: True target vectors, one per row of ``X``
    :type y: torch.Tensor
    :param topic_ids: Topic IDs for each row; required for the topic-based variations
    :type topic_ids: torch.Tensor
    :param scoring_variation: ``"same-topic"`` restricts the pairs to rows with equal topic IDs,
        ``"different-topic"`` to rows with unequal topic IDs; any other value (including the
        default ``None``) uses all pairs
    :type scoring_variation: str
    :raises ValueError: If a topic-based variation is requested without ``topic_ids``
    :return: Fraction of selected pairs that are correct, or ``nan`` when no pair is selected
    :rtype: float
    """
    pred = estimator.predict(X)  # noqa

    topic_based = scoring_variation in ("same-topic", "different-topic")
    if topic_based and topic_ids is None:
        raise ValueError(f"topic_ids is required when scoring_variation={scoring_variation!r}")

    def _pair_selected(i: int, j: int) -> bool:
        """Tell whether the pair (i, j) takes part in the requested variation."""
        if scoring_variation == "same-topic":
            return topic_ids[i] == topic_ids[j]
        if scoring_variation == "different-topic":
            return topic_ids[i] != topic_ids[j]
        return True

    n_samples = len(X)

    # The matched distance of a row does not depend on its partner: compute it once per row
    matched_dist = [cosine(pred[k], y[k]) for k in range(n_samples)]

    n_pairs = 0
    n_correct = 0
    for i in range(n_samples):
        for j in range(i + 1, n_samples):
            if not _pair_selected(i, j):
                continue
            # Is the distance of the correct matches smaller than that of the swapped pairing?
            swapped_dist = cosine(pred[i], y[j]) + cosine(pred[j], y[i])
            n_pairs += 1
            n_correct += bool(matched_dist[i] + matched_dist[j] < swapped_dist)

    # No pair selected (fewer than two rows, or no pair with matching topics):
    # report "undefined" instead of failing with a division by zero
    if n_pairs == 0:
        return float("nan")

    # Return the fraction of pairs for which the condition holds versus all selected pairs
    return n_correct / n_pairs


def pearson_scoring(
    estimator: "BaseEstimator" = None,
    X: torch.Tensor = None,  # noqa
    y: torch.Tensor = None,
) -> float:
    """Calculate the average Pearson correlation between true and predicted vectors.

    For every sample, the Pearson correlation coefficient between its true vector and
    the vector predicted by ``estimator`` is computed; the mean over all samples is returned.

    :param estimator: Fitted estimator object exposing ``predict`` (e.g., a Ridge regression)
    :type estimator: BaseEstimator
    :param X: Inputs handed to ``estimator.predict``
    :type X: torch.Tensor
    :param y: True target vectors, one per row of the predictions
    :type y: torch.Tensor
    :return: Mean of the per-sample Pearson correlation coefficients
    :rtype: float
    """
    pred = estimator.predict(X)  # noqa

    # One correlation coefficient per (true, predicted) row pair. Indexing with [0]
    # works both with the result object of recent SciPy releases and with the plain
    # tuple returned by older ones.
    res = [pearsonr(t, p)[0] for t, p in zip(y, pred)]

    # Average the per-sample correlations
    return np.mean(res)  # noqa

In [None]:
len(X)

In [None]:
Y[0].shape

In [None]:
######### Evaluate.

# Flattened copies of targets and predictions (not used by the metrics below)
Y_flatten = Y.flatten()
pred_flatten = decoder_predictions.flatten()

# Evaluate the performance (e.g., using mean squared error)
mse = mean_squared_error(Y, decoder_predictions)
print(f"Mean Squared Error: {mse}")

r2 = r2_score(Y, decoder_predictions)
print(f"R-squared (R2) Score: {r2}")

# Pearson Correlation Coefficient: one coefficient per sample, then averaged
res = [pearsonr(t, p)[0] for t, p in zip(Y, decoder_predictions)]
pearson_corr = np.mean(res)
print(f"Pearson Correlation Coefficient: {pearson_corr}")

# Cosine Similarity
# similarity[i, j] compares prediction i with target j. Only the diagonal holds the
# matched (prediction, target) pairs; averaging the full matrix would also include
# every mismatched pairing.
similarity = cosine_similarity(decoder_predictions, Y)
cosine_sim = np.mean(np.diag(similarity))
print(f"Cosine Similarity: {cosine_sim}")

# Pairwise Accuracy
# A pair (i, j) is correct when the matched distances d(i, i) + d(j, j) are smaller than
# the swapped distances d(i, j) + d(j, i). The distance matrix is derived once from the
# similarity matrix instead of calling a cosine routine four times per pair.
distance = 1.0 - similarity
matched = np.diag(distance)
i_idx, j_idx = np.triu_indices(len(X), k=1)  # all index pairs with i < j
res = (matched[i_idx] + matched[j_idx]) < (distance[i_idx, j_idx] + distance[j_idx, i_idx])
pairwise_acc = res.sum() / len(res)
print(f"Pairwise Accuracy: {pairwise_acc}")