# Testing the McGurk effect on machine learning models

### Defining the McGurk effect experiments

In [None]:
from experiments import McGurkExperiment

# Each triplet is (auditory syllable, visual syllable, syllable expected to be
# perceived when the two conflicting streams are presented together).
syllable_triplets = [
    ("ba", "ga", "da"),   # fusion:      ba (audio) + ga (video) -> da
    ("ba", "fa", "va"),   # fusion:      ba (audio) + fa (video) -> va
    ("ga", "ba", "bga"),  # combination: ga (audio) + ba (video) -> bga
]

# One McGurkExperiment per triplet, in the order listed above.
experiments = [McGurkExperiment(*triplet) for triplet in syllable_triplets]

### Ensuring reproducible experiment results

In [None]:
# Seed every random number generator the experiments may draw from so that
# repeated runs give the same results.
# NOTE: PyTorch is seeded with 42, while Python's `random` module and NumPy's
# global generator are seeded with 0.
import random

import numpy as np
import torch

TORCH_SEED = 42
PYTHON_SEED = 0
NUMPY_SEED = 0

random.seed(PYTHON_SEED)
np.random.seed(NUMPY_SEED)
torch.manual_seed(TORCH_SEED)

## Testing the McGurk effect on pretrained Perceiver IO models with a regression mapping

### Training the models

In [None]:
from models import McGurkPerceiver

# Training hyperparameters shared by every model, kept in one place so they
# can be tuned without editing the training loop.
EPOCHS = 100000
LEARNING_RATE = 0.003
TRAIN_WITH_MASKS = True

# One Perceiver model per McGurk experiment, in the same order as `experiments`.
perceiver_models = [McGurkPerceiver(experiment) for experiment in experiments]

# Keep what `train` returns instead of discarding it, so the training outputs
# can be inspected or plotted later. `tuple(...)` forces the result to be
# materialised eagerly, as the original four-way unpacking did.
training_results = []
for model in perceiver_models:
    print(model.name())
    outputs = tuple(
        model.train(
            epochs=EPOCHS,
            learning_rate=LEARNING_RATE,
            train_with_masks=TRAIN_WITH_MASKS,
        )
    )
    training_results.append(outputs)

### Generating the predictions

In [None]:
# Run every trained model on its McGurk test videos. Each model's name and raw
# predictions are printed, and the predictions are collected in the same order
# as `perceiver_models`.
model_predictions = []
for model in perceiver_models:
    print(model.name())
    predictions = model.test()
    print(predictions)
    model_predictions.append(predictions)

### Results

In [None]:
import matplotlib.pyplot as plt  # kept for the plotting TODOs below
import numpy as np
import torch  # imported here so this cell does not rely on the seeding cell having run

# TODO: Plot the average confidence scores for each normal sample of each experiment -> If possible, with test set
# TODO: Plot also for McGurk samples
for model, prediction in zip(perceiver_models, model_predictions):
    # Expected to be an (n_samples, 3) tensor whose columns are read below as
    # the auditory, visual and McGurk confidences.
    prediction = torch.as_tensor(prediction)
    if prediction.ndim != 2 or prediction.shape[1] != 3:
        # Fail with a readable message instead of an opaque unpacking error.
        raise ValueError(
            f"{model.name()}: expected predictions of shape (n_samples, 3), "
            f"got {tuple(prediction.shape)}"
        )
    # Average over the samples (dim 0) -> one mean confidence per column.
    average_confidence = torch.mean(prediction, dim=0)
    auditory_confidence, visual_confidence, mcgurk_confidence = average_confidence.tolist()
    # Label each block of results with the model it belongs to.
    print(model.name())
    print(f"average auditory : {auditory_confidence}")
    print(f"average visual : {visual_confidence}")
    print(f"average mcgurk : {mcgurk_confidence}")

# TODO (maybe): plot the confidence increase from normal samples to McGurk samples
#   for the McGurk syllable (if it's interesting), possibly on a log scale.
# TODO (at home, on the PC): test the aggregate function, the masked pipeline,
#   and longer training with a more aggressive learning rate.