In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

# Assignment: Implementing and Fitting PRF Models

---
## **Objective**
Implement a **Population Receptive Field (PRF) model** using **TensorFlow, PyTorch, or JAX**. Start with grid search, then refine with gradient descent, and compare methods.

---
## **Background**
A PRF models visual cortex responses as a 2D Gaussian:
$$
f(S_t; \theta) = A \cdot \sum_{x,y} \left[ \exp\left(-\frac{(x - x_0)^2 + (y - y_0)^2}{2\sigma^2}\right) \cdot S_t(x,y) \right] + b
$$
where $\theta = \{x_0, y_0, \sigma, A, b\}$.

---
## **Data**
You are provided with:
- `stimulus`: (n_timepoints, n_pixels_x, n_pixels_y), one 2D frame per timepoint
- `convolved_stimulus`: (n_timepoints, n_pixels_x, n_pixels_y), the stimulus convolved with the HRF
- `v1_ts`: (n_timepoints, n_voxels)
- `x_coordinates`, `y_coordinates`: (n_pixels_x, n_pixels_y)

---
## **Tasks**

### **1. Stimulus Animation (Warm-up)**
Create an animation of the raw stimulus using `matplotlib.animation.FuncAnimation`.

### **2. Implement PRF Model**
Choose ONE framework (TensorFlow/PyTorch/JAX) and implement:
- A function to compute the 2D Gaussian PRF
- Vectorized prediction of BOLD responses
- Use broadcasting for efficient computation

### **3. Fit PRF Model**
#### **A. Grid Search**
- Define parameter grids for {x₀, y₀, σ, A, b}
- Evaluate using R² score
- Find best-fitting parameters

#### **B. Gradient Descent**
- Initialize with grid search results
- Minimize negative R² using Adam optimizer
- Compare convergence speed and fit quality with the grid search

---
### **4. Extend to All Voxels (Bonus)**
- Fit all voxels simultaneously
- Create maps of PRF centers and R² values

### **5. Mexican Hat Model (New Bonus)**
Implement a **Mexican hat model** where:
$$
f(x,y) = A_1 \exp\left(-\frac{(x-x_0)^2 + (y-y_0)^2}{2\sigma_1^2}\right) - A_2 \exp\left(-\frac{(x-x_0)^2 + (y-y_0)^2}{2\sigma_2^2}\right) + b
$$
where $\sigma_1 < \sigma_2$ and $A_1 > A_2 > 0$

Requirements:
1. Implement the Mexican hat function
2. Fit to at least one voxel
3. Compare with standard Gaussian PRF
4. Visualize the center-surround pattern

---
### **6. Brain Data (Advanced Bonus)**
- Ask the instructor for the real brain data
- Apply your best model

---
## **Deliverables**
1. Jupyter notebook with your implementation
2. Stimulus animation
3. PRF visualizations and R² values
4. (Bonus) Mexican hat model comparison
5. (Advanced) Brain data results

---
## **Tips**
- Use broadcasting for efficiency
- Start with single voxel, then extend
- Plot intermediate results
- Vectorize operations


## Load in data

In [None]:
# np.load on an .npz archive returns an NpzFile that keeps the archive open; the `with`
# block closes it once every array below has been read into memory.
with np.load('resources/assignment3.npz') as assignment_data:
    stimulus = assignment_data['stimulus']                      # (n_timepoints, n_pixels_x, n_pixels_y)
    convolved_stimulus = assignment_data['convolved_stimulus']  # same shape as `stimulus`, HRF-convolved
    v1_ts = assignment_data['v1_ts']                            # (n_timepoints, n_voxels)
    x_coordinates = assignment_data['x_coordinates']            # (n_pixels_x, n_pixels_y)
    y_coordinates = assignment_data['y_coordinates']            # (n_pixels_x, n_pixels_y)

# Invariants the rest of the notebook relies on when broadcasting a 2D PRF map against
# every stimulus frame and scoring the prediction against each voxel's time series.
assert stimulus.ndim == 3, stimulus.shape
assert convolved_stimulus.shape == stimulus.shape, (convolved_stimulus.shape, stimulus.shape)
assert x_coordinates.shape == y_coordinates.shape == stimulus.shape[1:], (x_coordinates.shape, stimulus.shape)
assert v1_ts.shape[0] == stimulus.shape[0], (v1_ts.shape, stimulus.shape)

## Inspect the data

### Timeseries of BOLD data

In [None]:
# Wrap the raw (n_timepoints, n_voxels) array in a labelled DataFrame: rows are timepoints,
# columns are voxel indices, so `v1_ts[voxel]` selects one voxel's time series.
v1_ts = pd.DataFrame(
    v1_ts,
    index=pd.RangeIndex(v1_ts.shape[0], name='timepoint'),
    columns=pd.RangeIndex(v1_ts.shape[1], name='voxel'),
)

In [None]:
# Hand-picked voxels whose time series are shown as examples (and exported for the slides).
good_voxels = [82, 229, 538]

fig, ax = plt.subplots()
v1_ts[good_voxels].plot(ax=ax)  # x label / legend title come from the index/column names
ax.set(ylabel='BOLD signal', title='Selected V1 voxel time series')
sns.despine(ax=ax)

good_voxels_path = Path('../slides/resources/prf_good_voxels.png').resolve()
# Create the output folder so Restart & Run All does not crash on a checkout without it.
good_voxels_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(good_voxels_path, transparent=False, bbox_inches='tight', dpi=300)

In [None]:
# Quick check: selecting one voxel column yields a 1-D series with one value per timepoint.
v1_ts.loc[:, 89].shape

In [None]:
# Show one stimulus frame next to the x/y coordinate maps used to build the PRF.
# (`plt` is already imported in the notebook's first cell.)
example_frame = 15  # arbitrary timepoint, chosen only for display

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].imshow(stimulus[example_frame], origin='lower')
axes[0].set_title('Stimulus')

coordinate_panels = zip(axes[1:], (x_coordinates, y_coordinates), ('X Coordinates', 'Y Coordinates'))
for ax, coordinates, panel_title in coordinate_panels:
    im = ax.imshow(coordinates, origin='lower')
    fig.colorbar(im, ax=ax, fraction=0.046)  # smaller colorbar
    ax.set_title(panel_title)

fig.tight_layout()
plt.show()


# Super simple PRF example

## Define the PRF

In [None]:

# Pixel spacing used to scale every PRF map by pixel area.
# NOTE(review): presumably a 30-pixel grid spanning 2 * 3.16 coordinate units; confirm against x_coordinates.
dx = dy = (2 * 3.16) / 30


def get_prf(x_coordinates, y_coordinates, mu_x, mu_y, sigma, amplitude, baseline, dx=dx, dy=dy):
    """Evaluate an isotropic 2D Gaussian PRF on a coordinate grid, scaled by pixel area.

    Parameters
    ----------
    x_coordinates, y_coordinates : array-like or float
        Positions at which to evaluate the PRF; they must broadcast against each other.
    mu_x, mu_y : float
        Centre of the Gaussian.
    sigma : float
        Standard deviation (size) of the Gaussian.
    amplitude : float
        Peak height of the Gaussian before pixel-area scaling.
    baseline : float
        Constant added to every pixel before pixel-area scaling.
    dx, dy : float
        Pixel spacing; the result is multiplied by ``dx * dy``.

    Returns
    -------
    array-like
        ``(amplitude * gaussian + baseline) * dx * dy`` with the broadcast shape of the coordinates.

    NOTE(review): because ``baseline`` is added per pixel, summing this map against a stimulus
    makes its contribution scale with stimulus area instead of acting as an additive intercept.
    """
    squared_distance = (x_coordinates - mu_x) ** 2 + (y_coordinates - mu_y) ** 2
    unscaled = amplitude * np.exp(-squared_distance / (2 * sigma ** 2)) + baseline
    pixel_area = dx * dy
    return unscaled * pixel_area

# Example pRF parameters
mu_x = -2.5
mu_y = -1
sigma = 1
amplitude = 1
baseline = 0

fig, ax = plt.subplots()
im = ax.imshow(get_prf(x_coordinates, y_coordinates, mu_x, mu_y, sigma, amplitude, baseline), origin='lower')
# Keep the mathtext in a raw string so `\mu` / `\sigma` are not (invalid) escape sequences;
# the newline lives in a normal string. The rendered title is unchanged.
ax.set_title('Example PRF\n' + rf'$\mu_x={mu_x}$, $\mu_y={mu_y}$, $\sigma={sigma}$')
fig.colorbar(im, ax=ax)

example_prf_path = Path('../slides/resources/prf_example.png').resolve()
# Create the output folder so Restart & Run All does not crash on a checkout without it.
example_prf_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(example_prf_path, transparent=True, bbox_inches='tight', dpi=300)

## Predict the time series

In [None]:
prf = get_prf(x_coordinates, y_coordinates, mu_x, mu_y, sigma, amplitude, baseline)

# Predicted response = PRF-weighted sum of each stimulus frame over both pixel axes -> (n_timepoints,)
prediction_without_hrf = (prf[np.newaxis, ...] * stimulus).sum(axis=(1, 2))
# Same weighting applied to the HRF-convolved stimulus
prediction_with_hrf = (prf[np.newaxis, ...] * convolved_stimulus).sum(axis=(1, 2))

fig, ax = plt.subplots()
ax.plot(prediction_without_hrf, label='Prediction without HRF')
ax.plot(prediction_with_hrf, label='Prediction with HRF')
ax.set(xlabel='Timepoint', ylabel='Predicted response', title='Predicted time series for the example PRF')
# Without a legend the `label=`s above are never shown and the curves can't be told apart.
ax.legend()
plt.show()

In [None]:
from sklearn.metrics import r2_score
import itertools

# Brute-force grid search for a single voxel: keep the parameter set with the highest R²
# for the model  f(t) = A * sum_xy[gauss(x, y) * S_t(x, y)] * dx * dy + b.

x_coords = x_coordinates
y_coords = y_coordinates
voxel_idx = 2

# Example grid search
x_grid = np.linspace(-5, 5, 5)
y_grid = np.linspace(-5, 5, 5)
sigma_grid = np.linspace(1, 5, 3)
# NOTE(review): unsorted, and 0.01 may be a typo for 0.1; confirm the intended amplitudes.
amplitude_grid = [0.05, 0.075, 0.01, 0.02]
baseline_grid = np.linspace(-0.5, 0.5, 3)

observed = v1_ts[voxel_idx]

best_r2 = -np.inf
best_params = None  # (mu_x, mu_y, sigma, amplitude, baseline) of the best fit
best_predicted = None

# Same iteration order as nested loops over x, y, sigma, amplitude, baseline, so ties
# (strict `>`) resolve to the same grid point.
for grid_mu_x, grid_mu_y, grid_sigma in itertools.product(x_grid, y_grid, sigma_grid):
    # The Gaussian's response depends only on position and size: compute the
    # unit-amplitude, zero-baseline prediction once and rescale it for every (A, b).
    unit_prf = get_prf(x_coords, y_coords, grid_mu_x, grid_mu_y, grid_sigma, 1, 0)
    unit_response = (unit_prf[np.newaxis, ...] * convolved_stimulus).sum(axis=(1, 2))
    for grid_amplitude, grid_baseline in itertools.product(amplitude_grid, baseline_grid):
        # Baseline is the additive `+ b` intercept of the time series. Putting it inside the
        # PRF map would make its contribution scale with the stimulated area.
        candidate = grid_amplitude * unit_response + grid_baseline
        r2 = r2_score(observed, candidate)
        if r2 > best_r2:
            best_r2 = r2
            best_params = (grid_mu_x, grid_mu_y, grid_sigma, grid_amplitude, grid_baseline)
            best_predicted = candidate

# Leave `predicted` bound to the best fit rather than the last grid point evaluated,
# so any cell that plots `predicted` shows the winning model.
predicted = best_predicted

In [None]:
# Overlay the observed time series of `voxel_idx` and the stored prediction.
# NOTE(review): `predicted` is whatever the grid-search cell left bound; confirm it is the
# best-fitting prediction and not the one from the last grid point evaluated.
fig, ax = plt.subplots()
v1_ts[voxel_idx].plot(ax=ax, label='Actual')
ax.plot(predicted, label='Predicted')
ax.set(ylabel='BOLD signal', title=f'Voxel {voxel_idx} (best grid-search R² = {best_r2:.3f})')
# Without a legend the `label=`s above are never shown.
ax.legend()
plt.show()