### Setup

In [1]:
from typing import Union
import warnings
from random_word import RandomWords
from tqdm import tqdm
from pathlib import Path
import random
from datetime import datetime, timezone
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import MetricCollection

# Silence the UserWarning emitted when predictions/targets have (near-)zero
# variance, which makes the Pearson correlation numerically unstable
_LOW_VARIANCE_WARNING = (
    "The variance of predictions or target "
    "is close to zero. This can cause instability"
    " in Pearson correlationcoefficient, leading"
    " to wrong results."
)
warnings.filterwarnings('ignore', message=_LOW_VARIANCE_WARNING, category=UserWarning)

# Seed every random number generator in use (must run before the model is built)
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Allow TensorFloat32 tensor cores for float32 matrix multiplication
torch.set_float32_matmul_precision('high')

### Standalone Metric

In [2]:
"""
Adapted from `torchmetrics.PearsonCorrCoef <https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/regression/pearson.py>`__.
"""
from typing import Optional, Union, Any, Tuple
from jaxtyping import Float, Float64, Int, Int64, Bool
import plotly.graph_objects as go
import torch
from torchmetrics import Metric
from torchmetrics.regression.pearson import _final_aggregation as _pearson_corrcoef_final_aggregation
from torchmetrics.functional.regression.pearson import _pearson_corrcoef_compute
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE


def safe_divide(numerator: torch.Tensor, denominator: torch.Tensor) -> torch.Tensor:
    """Element-wise ``numerator / denominator`` that yields 0 wherever ``denominator == 0``.

    Args:
        numerator: dividend tensor (broadcastable against ``denominator``)
        denominator: divisor tensor; entries equal to zero produce a zero output

    Returns:
        Tensor of quotients with the broadcast shape of the inputs, living on the
        same device and with the same dtype as ``numerator / denominator``.
    """
    nonzero = denominator != 0
    # Swap zero divisors for one *before* dividing: the discarded branch then never
    # contains inf/NaN, which would otherwise leak NaN into the gradients
    safe_denominator = torch.where(nonzero, denominator, torch.ones_like(denominator))
    quotient = numerator / safe_denominator
    # Build the fill value from the quotient so device and dtype always match
    return torch.where(nonzero, quotient, torch.zeros_like(quotient))


def _pearson_corrcoef_update(
    preds: Float64[torch.Tensor, "n_samples n_neurons"],
    target: Float64[torch.Tensor, "n_samples n_neurons"],
    mask: Bool[torch.Tensor, "n_samples n_neurons"],
    mean_x: Float64[torch.Tensor, "n_neurons"],
    mean_y: Float64[torch.Tensor, "n_neurons"],
    var_x: Float64[torch.Tensor, "n_neurons"],
    var_y: Float64[torch.Tensor, "n_neurons"],
    corr_xy: Float64[torch.Tensor, "n_neurons"],
    n_total: Int64[torch.Tensor, "n_neurons"],
) -> Tuple[
    Float64[torch.Tensor, "n_neurons"],
    Float64[torch.Tensor, "n_neurons"],
    Float64[torch.Tensor, "n_neurons"],
    Float64[torch.Tensor, "n_neurons"],
    Float64[torch.Tensor, "n_neurons"],
    Int64[torch.Tensor, "n_neurons"]
]:
    """Update and return the running statistics required to compute the Pearson
        Correlation Coefficient, for a subset of population neurons

    Only samples where ``mask`` is set contribute to the statistics of a neuron.

    Args:
        preds: estimated scores
        target: ground truth scores
        mask: binary mask indicating which samples to include in update
        mean_x: current mean estimate of x tensor
        mean_y: current mean estimate of y tensor
        var_x: current (unnormalized) variance estimate of x tensor
        var_y: current (unnormalized) variance estimate of y tensor
        corr_xy: current (unnormalized) covariance estimate between x and y tensor
        n_total: current number of observed observations

    Returns:
        Tuple ``(mean_x, mean_y, var_x, var_y, corr_xy, n_total)`` holding the updated
        statistics. ``var_x``, ``var_y``, ``corr_xy`` and ``n_total`` are updated
        in place (and returned); the means are returned as new tensors.

    NOTE: Adapted (for population and masking) from :func:`torchmetrics.functional.regression.pearson._pearson_corrcoef_update`
    """
    # Count obs (cast so that a float-valued binary mask can still be accumulated
    #  into the integer counter)
    num_obs = mask.sum(0).to(n_total.dtype)
    # Running mean updates
    mx_new = safe_divide(n_total * mean_x + (preds * mask).sum(0), n_total + num_obs)
    my_new = safe_divide(n_total * mean_y + (target * mask).sum(0), n_total + num_obs)
    # Update obs counts
    n_total += num_obs
    # Running variance updates: each term pairs the deviation from the NEW mean
    #  with the deviation from the OLD mean
    var_x += ((preds - mx_new) * (preds - mean_x) * mask).sum(0)
    var_y += ((target - my_new) * (target - mean_y) * mask).sum(0)
    # Running covariance update follows the same new-mean/old-mean pairing: using
    #  the new mean on both sides under-counts the cross-batch contribution
    corr_xy += ((preds - mx_new) * (target - mean_y) * mask).sum(0)

    return mx_new, my_new, var_x, var_y, corr_xy, n_total


class MaskedPopulationPearsonCorrCoef(Metric):
    """Streaming per-neuron Pearson Correlation Coefficient over a neural population.

    Running statistics (means, unnormalized variances/covariance, observation
    counts) are kept for every neuron of the population. Each update only touches
    the neurons listed in ``neurons`` and only counts samples selected by the mask.

    Args:
        population_size: number of neurons to keep statistics for
        masked: selects which samples contribute to the metric
            - ``None``: every sample is used (the mask is ignored)
            - ``True``: the provided mask is logically inverted before use
            - ``False``: the provided mask is used as-is
        compute_with_cache: forwarded to :class:`torchmetrics.Metric`
        kwargs: additional keyword arguments forwarded to :class:`torchmetrics.Metric`
    """
    is_differentiable: bool = True
    higher_is_better: Optional[bool] = True
    full_state_update: bool = True
    plot_lower_bound: float = -1.0
    plot_upper_bound: float = 1.0

    mean_x: Float64[torch.Tensor, "population_size"]
    mean_y: Float64[torch.Tensor, "population_size"]
    var_x: Float64[torch.Tensor, "population_size"]
    var_y: Float64[torch.Tensor, "population_size"]
    corr_xy: Float64[torch.Tensor, "population_size"]
    n_total: Int64[torch.Tensor, "population_size"]

    def __init__(
        self,
        population_size: int,
        masked: Optional[bool] = None,
        # override default with `False` to enable plotting distribution
        compute_with_cache: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(compute_with_cache=compute_with_cache, **kwargs)
        # Initialize states (float64 statistics, int64 observation counts)
        self.add_state("mean_x", default=torch.zeros(population_size, dtype=torch.float64), dist_reduce_fx=None)
        self.add_state("mean_y", default=torch.zeros(population_size, dtype=torch.float64), dist_reduce_fx=None)
        self.add_state("var_x", default=torch.zeros(population_size, dtype=torch.float64), dist_reduce_fx=None)
        self.add_state("var_y", default=torch.zeros(population_size, dtype=torch.float64), dist_reduce_fx=None)
        self.add_state("corr_xy", default=torch.zeros(population_size, dtype=torch.float64), dist_reduce_fx=None)
        self.add_state("n_total", default=torch.zeros(population_size, dtype=torch.int64), dist_reduce_fx=None)
        # Masking behavior
        self.ignore_mask = masked is None
        self.invert_mask = masked is not None and bool(masked)

    def update(
        self,
        preds: Float[torch.Tensor, "batch n_neurons n_samples"],
        target: Float[torch.Tensor, "batch n_neurons n_samples"],
        neurons: Int[torch.Tensor, "batch n_neurons"],
        mask: Float[torch.Tensor, "batch n_neurons n_samples"],
    ) -> None:
        """Accumulate statistics for the neurons observed in a batch.

        Args:
            preds: estimated responses
            target: ground truth responses
            neurons: population index of each neuron, used to address the states
            mask: binary mask over samples (non-zero entries are treated as set)
        """
        # Work with a boolean mask: a float-valued mask would yield float counts,
        #  which cannot be accumulated in place into the int64 `n_total` state
        mask = mask.to(torch.bool)
        # Ignore or invert the mask depending on the configured masking behavior
        if self.ignore_mask:
            mask = torch.ones_like(mask)
        elif self.invert_mask:
            mask = mask.logical_not()
        # `scatter_` only accepts int64 indices
        neurons = neurons.to(torch.int64)
        # Increase precision of inputs to avoid overflow
        preds, target = preds.to(torch.float64), target.to(torch.float64)
        # Iterate over samples in batch (NOTE: batches might use overlapping neurons so can't parallelize)
        for preds_block, target_block, neurons_block, mask_block in zip(preds, target, neurons, mask):
            # Calculate new state values for neurons in block. Indexing with a tensor
            #  returns copies, so the in-place updates in the helper don't touch the states
            mean_x, mean_y, var_x, var_y, corr_xy, n_total = _pearson_corrcoef_update(
                preds_block.T, target_block.T, mask_block.T,
                self.mean_x[neurons_block],
                self.mean_y[neurons_block],
                self.var_x[neurons_block],
                self.var_y[neurons_block],
                self.corr_xy[neurons_block],
                self.n_total[neurons_block],
            )
            # Update population state values for subset of neurons
            self.mean_x.scatter_(dim=0, index=neurons_block, src=mean_x)
            self.mean_y.scatter_(dim=0, index=neurons_block, src=mean_y)
            self.var_x.scatter_(dim=0, index=neurons_block, src=var_x)
            self.var_y.scatter_(dim=0, index=neurons_block, src=var_y)
            self.corr_xy.scatter_(dim=0, index=neurons_block, src=corr_xy)
            self.n_total.scatter_(dim=0, index=neurons_block, src=n_total)

    def compute(
        self,
        reduce: bool = True,
        observed_only: bool = False,
        nonnan: bool = False,
    ) -> Union[float, Float[torch.Tensor, "selected_length"]]:
        """Compute the Pearson Correlation Coefficient of the population.

        Args:
            reduce: if ``True``, return the average over neurons that were observed
                and have a non-NaN coefficient, weighted by their number of
                observations; otherwise return the per-neuron coefficients
            observed_only: drop neurons that never received an observation
            nonnan: drop neurons whose coefficient is NaN

        Returns:
            A scalar tensor when ``reduce`` is ``True``, otherwise a 1D tensor with
            the coefficients of the selected neurons.
        """
        var_x, var_y, corr_xy, n_total = self.var_x, self.var_y, self.corr_xy, self.n_total
        # Multiple devices, need further reduction
        if self.mean_x.ndim > 1:
            _, _, var_x, var_y, corr_xy, n_total = _pearson_corrcoef_final_aggregation(
                self.mean_x, self.mean_y, self.var_x, self.var_y, self.corr_xy, self.n_total
            )
        # Compute Pearson Correlation Coefficient for entire population
        ret = _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total)
        # Keep a neuron dimension even if the result was squeezed to a scalar
        #  (population of size one), so the boolean selection below stays valid
        ret = torch.atleast_1d(ret)
        # Initialize selection of valid neurons
        selection = torch.ones_like(ret, dtype=torch.bool)
        # Ignore neurons with NAN
        if nonnan or reduce:
            selection = torch.logical_and(selection, ~ret.isnan())
        # Ignore neurons unseen in dataset (i.e. `n_total` updates is zero)
        if observed_only or reduce:
            selection = torch.logical_and(selection, (n_total > 0))
        # Select neurons with mask
        ret = ret[selection]
        if not reduce:
            return ret
        # If reduction specified, take average weighted by number of
        #  observations in the dataset (weights kept in the precision of `ret`)
        valid_n_total = n_total[selection].to(ret.dtype)
        weights = valid_n_total / valid_n_total.sum()
        return torch.sum(weights * ret)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Not supported: use :meth:`update` followed by :meth:`compute` instead."""
        raise NotImplementedError

    def plot(self, val: Any = None, ax: Any = None) -> _PLOT_OUT_TYPE:
        """
        Plots distribution of Pearson Correlation Coefficient values
         across neurons in population

        Args:
            val: currently ignored, the distribution is always recomputed from the states
            ax: currently ignored, a new plotly figure is always created

        Returns:
            Tuple ``(figure, None)`` where ``figure`` is a plotly figure.

        TODO: Handle plotting pre-computed values (i.e. `val`)
        """
        # Compute unreduced metric value for population, and prepare for plotting
        val = self.compute(
            reduce=False, observed_only=True, nonnan=True
        ).detach().cpu().numpy()
        #   Create figure
        fig = go.Figure()
        #   Create violin plot
        fig.add_trace(go.Violin(
            y=val,
            box_visible=True,
            meanline_visible=True,
            line_color='blue',
            name='Distribution'
        ))
        # Add lines at metric bounds and annotate optimal bound
        fig.add_shape(
            type='line',
            x0=0, x1=1,
            y0=self.plot_lower_bound,
            y1=self.plot_lower_bound,
            xref='paper', yref='y',
            line=dict(dash='dash', color='black',)
        )
        fig.add_shape(
            type='line',
            x0=0, x1=1,
            y0=self.plot_upper_bound,
            y1=self.plot_upper_bound,
            xref='paper',
            yref='y',
            line=dict(dash='dash', color='black')
        )
        fig.add_annotation(
            x=0.2,
            y=self.plot_upper_bound + .1,
            text="Optimal \n value",
            showarrow=False,
            xanchor='center',
            yanchor='middle'
        )
        # Calculate y-axis range
        pad = 0.1 * (self.plot_upper_bound - self.plot_lower_bound)
        yaxis_range = [self.plot_lower_bound - pad, self.plot_upper_bound + pad]
        # Update plot layout
        fig.update_layout(
            title=f'Population Distribution of Pearson Correlation Coefficient for '
            f'{"All" if self.ignore_mask else "Masked" if self.invert_mask else "Unmasked"} Samples',
            yaxis_title='Pearson Correlation Coefficient',
            yaxis=dict(showgrid=True, range=yaxis_range),
            xaxis=dict(showgrid=True),
            width=1000, height=1000
        )
        return fig, None

### Setup dataloader and model

In [3]:

from foundation_models.modeling import build_model
from foundation_models.modeling.mtm import MTMPerceiverArgs
from foundation_models.datasets.masked_neuro import MaskedNeuroDataset, MaskedNeuroDatapoint, MaskedNeuroSampler, MaskedNeuroSampler, SubsetSequentialSampler

# Model configuration
model_args = MTMPerceiverArgs(
    # Population / session layout
    neural_population_size=8000,
    num_sessions=1,
    num_samples_per_token=4,
    # Embedding and latent sizes
    dim_embedding=128,
    num_latent_groups=32,
    latent_group_size=8,
    partial_rot=True,
    # Temporal settings (seconds)
    context_window_len_s=4.1,
    temporal_precision_s=1.e-3,
    # Attention / feed-forward settings
    num_heads=8,
    dim_head=32,
    num_blocks=1,
    num_self_attends_per_block=2,
    ffn_hidden_mult=1,
    ffn_multiple_of=128,
    # Objective and initialization
    reconstruct_masked_only=False,
    init_std=0.02,
    depth_init=False,
)

model = build_model(model_args)

# Dataset shared by the training and validation loaders
dataset = MaskedNeuroDataset(
    data_root="/mnt/scratch09/foundational_model_data",
    session="0",
    num_samples_per_block=32,
    train_indice_step=4,
    validation_fraction=0.2,
)
dataset.setup()

# Options common to both loaders; only the sampler differs between splits
_shared_loader_options = dict(
    batch_size=1,
    num_workers=0,
    drop_last=True,
    pin_memory=True,
)

train_dataloader = DataLoader(
    dataset,
    sampler=MaskedNeuroSampler(dataset.train_indices),
    **_shared_loader_options,
)

val_dataloader = DataLoader(
    dataset,
    sampler=SubsetSequentialSampler(dataset.validation_indices),
    **_shared_loader_options,
)

def transfer_batch_to_device(
    batch: MaskedNeuroDatapoint,
    device: Union[str, torch.device] = "cuda",
):
    """Return a new ``MaskedNeuroDatapoint`` whose fields have all been moved to ``device``.

    Args:
        batch: datapoint to move; it is iterated field by field and every
            field must provide a ``.to(device)`` method
        device: target device of the returned datapoint

    Returns:
        ``MaskedNeuroDatapoint`` rebuilt from the moved fields, in the same order.
    """
    moved_fields = [field.to(device) for field in batch]
    return MaskedNeuroDatapoint(*moved_fields)

# Base directory for logs (created on demand, including missing parents)
log_dir = Path('logs', 'mtm_simple')
log_dir.mkdir(exist_ok=True, parents=True)



### Training Loop

In [None]:
# Training hyper-parameters
NUM_EPOCHS = 100
LEARNING_RATE = 1.e-3
# Initialize metric (`masked=None`: correlation is computed over all samples)
train_xcorr = MaskedPopulationPearsonCorrCoef(
    population_size=model_args.neural_population_size,
    masked=None,
)
# Move metric to GPU
train_xcorr.to('cuda')
# Move model to GPU
model.to('cuda')
# Set the model to training mode
model.train()
# Compile model
compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
# Configure Optimization
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Configure Logger
r = RandomWords()
# Name of general project. TODO: Change to your own!
project_name = 'foundation_models'
run_name = f'{r.get_random_word()}-{r.get_random_word()}-{datetime.now(timezone.utc).strftime("%y%m%d%H%M")}'
# Rebuild the run directory from the base log directory instead of extending
#  `log_dir` itself, so re-running this cell does not nest runs inside each other
log_dir = Path('logs') / 'mtm_simple' / project_name / run_name
log_dir.mkdir(parents=True, exist_ok=True)
print(f'TensorBoard logging training to {log_dir} ...')
writer = SummaryWriter(log_dir)
# Training loop
epoch_loss, epoch_xcorr = 0., -1.
epoch_progress = tqdm(range(NUM_EPOCHS), desc=f"Training MTM Model")
try:
    for epoch in epoch_progress:
        # Initialize the running loss and the number of samples seen
        running_loss = 0.0
        num_samples = 0
        # Reset metric
        train_xcorr.reset()
        # Epoch loop
        for i, batch in enumerate(train_dataloader):
            # Move data to GPU
            batch = transfer_batch_to_device(batch, 'cuda')
            # Zero the parameter gradients
            optimizer.zero_grad()
            # Forward pass
            loss, logits = compiled_model(
                batch.responses,
                batch.timestamps,
                batch.neurons,
                batch.mask,
                batch.session,
            )
            # Backward pass
            loss.backward()
            # Update parameters
            optimizer.step()
            # Update loss tracker
            batch_size = batch.responses.size(0)
            running_loss += loss.item() * batch_size
            num_samples += batch_size
            # Update metric states. Detach the predictions: otherwise the metric
            #  states would keep every batch's autograd graph alive until `reset()`
            train_xcorr.update(logits.detach(), batch.responses, batch.neurons, batch.mask)
            # Update tqdm description with the loss/correlation of the previous epoch
            epoch_progress.set_postfix(batch=f'{i+1}/{len(train_dataloader)}', loss=epoch_loss, masked_xcorr=epoch_xcorr)

        # Calculate average loss across epoch (over the samples actually seen)
        epoch_loss = running_loss / max(num_samples, 1)
        # Calculate per-neuron correlation over course of the
        #  epoch, and average over the population
        epoch_xcorr = train_xcorr.compute().item()
        # Log the loss and correlation to TensorBoard
        writer.add_scalar('train/loss', epoch_loss, epoch)
        writer.add_scalar('train/masked_xcorr', epoch_xcorr, epoch)

    print("Training complete!")
finally:
    # Cleanup Tensorboard (also when training is interrupted or fails)
    writer.close()


### Validation Loop

In [None]:
# NOTE: Expects you just ran training above. Use Lightning if you want to load checkpoints etc...
# Initialize metric (`masked=None`: correlation is computed over all samples)
# NOTE(review): the training cell sizes its metric with `model_args.neural_population_size`
#  while this one uses `dataset.num_neurons` -- confirm both describe the same population
val_xcorr = MaskedPopulationPearsonCorrCoef(
    population_size=dataset.num_neurons,
    masked=None,
)
# Move metric to GPU
val_xcorr.to('cuda')
# Move model to GPU
model.to('cuda')
# Set the model to eval mode
model.eval()
# Compile model
compiled_model = torch.compile(model, fullgraph=False, dynamic=True)
# Initialize the running loss and the number of samples seen
validation_loss = 0.0
num_samples = 0
# Reset metric
val_xcorr.reset()
# Re-open Logger
print(f'TensorBoard logging validation to {log_dir} ...')
writer = SummaryWriter(f'{log_dir}')
try:
    # Disable gradient tracking
    with torch.no_grad():
        # Validation loop
        batch_progress = tqdm(val_dataloader, desc=f"Validation Loop")
        for i, batch in enumerate(batch_progress):
            # Move data to GPU
            batch = transfer_batch_to_device(batch, 'cuda')
            # Forward pass
            loss, logits = compiled_model(
                batch.responses,
                batch.timestamps,
                batch.neurons,
                batch.mask,
                batch.session,
            )
            # Update loss tracker
            batch_size = batch.responses.size(0)
            validation_loss += loss.item() * batch_size
            num_samples += batch_size
            # Update metric states
            val_xcorr.update(logits, batch.responses, batch.neurons, batch.mask)
            # Update tqdm description with batch loss
            batch_progress.set_postfix(loss=loss.item())
    # Calculate average loss over the validation samples actually seen. The loader
    #  wraps the full dataset but only iterates `dataset.validation_indices`, so
    #  dividing by `len(val_dataloader.dataset)` would under-report the loss
    epoch_loss = validation_loss / max(num_samples, 1)
    # Calculate per-neuron correlation over course of the
    #  epoch, and average over the population
    epoch_xcorr = val_xcorr.compute().item()
    # Log the loss and correlation to TensorBoard
    writer.add_scalar('validation/loss', epoch_loss, 0)
    writer.add_scalar('validation/masked_xcorr', epoch_xcorr, 0)
    # Let 'em know!
    print(f"Validation complete: Loss = {epoch_loss} | Masked XCorr = {epoch_xcorr}")
finally:
    # Cleanup Tensorboard (also when validation is interrupted or fails)
    writer.close()
