# CA3 model

This notebook builds out the CA3 model step by step, because the previous implementation does not work.

Formula is:
$$
\min_{\vec{w}_{x_1}, \vec{w}_{x_2}, \vec{w}_a, \vec{w}_b} \left( -\vec{w}_{x_1}^T S_{x_1 a} \vec{w}_a - \vec{w}_{x_2}^T S_{x_2 b} \vec{w}_b + \sum_{i \in \{x_1, x_2, a, b\}} \frac{1}{2} \lambda_i \left( \vec{w}_i^T S_{ii} \vec{w}_i - 1 \right) + \theta_{rr}(\vec{w}_{x_1}, \vec{w}_{x_2}) \right)
$$


In [2]:
from gemmr.generative_model import GEMMR
import numpy as np

## Generate some data

In [127]:
# Draw two datasets (n=200 and n=190) from the same GEMMR model definition.
# Each draw is kept as a pair and also unpacked into its two named arrays.
model_definition = GEMMR('cca', wx=2, wy=50, r_between=0.3)
study1 = tuple(model_definition.generate_data(n=200))
study2 = tuple(model_definition.generate_data(n=190))
behavioural_data_study1, imging_data_study1 = study1
behavioural_data_study2, imging_data_study2 = study2

## Step 1: Covariance matrix

In [86]:
def mean_center(data: np.ndarray) -> np.ndarray:
    """
    Subtract the axis-0 mean from an array.

    Parameters
    ----------
    data: np.ndarray
        array whose mean along axis 0
        is removed from every row

    Returns
    -------
    np.ndarray: array
        new array holding ``data`` minus
        its axis-0 mean
    """
    axis0_mean = data.mean(axis=0)
    centered = data - axis0_mean
    return centered

In [None]:
def cross_cov(matrix_1: np.ndarray, matrix_2: np.ndarray) -> np.ndarray:
    """
    Cross-product of two matrices scaled by the
    number of rows of the first one.

    Evaluates ``matrix_1.T @ matrix_2 / matrix_1.shape[0]``.
    No centring happens here, so the result is a
    cross covariance only if the inputs were
    demeaned beforehand.

    Parameters
    ----------
    matrix_1: np.ndarray
        first matrix; presumably subjects by
        features (e.g. behavioural data)
    matrix_2: np.ndarray
        second matrix; presumably subjects by
        features (e.g. imaging data), with rows
        matching those of ``matrix_1``

    Returns
    -------
    np.ndarray: array
        scaled cross-product, one row per column
        of ``matrix_1`` and one column per column
        of ``matrix_2`` for 2-D inputs
    """
    n_rows = matrix_1.shape[0]
    cross_product = matrix_1.T @ matrix_2
    return cross_product / n_rows

In [121]:
def data_able_to_process(study_pair: tuple, behav_data: np.ndarray, img_data: np.ndarray) -> bool:
    """
    Check that a study's data can be processed.

    A message is printed for the first check
    that fails.

    Parameters
    ----------
     study_pair: tuple,
         tuple (or list) of behavioural data
         and imging data
     behav_data: np.ndarray
         array of behav_data
     img_data: np.ndarray
         array of img_data

    Returns
    -------
    bool: boolean
        True when the pair has exactly two
        entries, both datasets are numpy arrays,
        and they have the same non-zero number
        of rows; False otherwise
    """
    if not isinstance(study_pair, (tuple, list)) or len(study_pair) != 2:
        print("Given argument isn't a pair of datasets")
        return False
    if not isinstance(behav_data, np.ndarray) or not isinstance(img_data, np.ndarray):
        print("Data provided isn't a numpy array")
        return False
    n_behav, n_img = behav_data.shape[0], img_data.shape[0]
    if n_behav == 0 or n_img == 0 or n_behav != n_img:
        print(f"Mismatch between ({n_behav} and {n_img})")
        # Previously this branch fell through to ``return True``, so empty
        # or row-mismatched data was reported as processable.
        return False

    return True

In [132]:
def calculate_covariance_matricies(*study_pairs):
    """
    Calculates within-study covariance matrices (behav-behav, img-img, behav-img, img-behav)
     for a given set of study pairs.

    Each pair is validated with ``data_able_to_process`` and both datasets
    are mean-centred before the covariances are formed. A pair that fails
    validation, or whose covariances cannot be computed, is skipped with a
    printed message; results of the other studies are still returned.

    Parameters
    ----------
    study_pairs: tuple
        a tuple or list containing two numpy arrays:
        (behavioural_data, imaging_data).
        Assumes data is (subjects x features).

    Returns
    -------
    covariance_results: dict
        A dictionary where keys indicate the pair of matrices involved for each study
        (e.g., 's_behav1_img1', 's_behav1_behav1', 's_img1_img1', 's_img1_behav1') and values
        are the corresponding covariance matrices. The study number in the key is the
        1-based position of the pair in ``study_pairs``.
        Returns an empty dictionary if no valid pairs are provided or calculated.

    """
    covariance_results = {}
    for study_num, study_pair in enumerate(study_pairs, start=1):
        # Unpack defensively: a malformed argument must be skipped like any
        # other invalid pair instead of raising before validation runs.
        try:
            behav_data, img_data = study_pair
        except (TypeError, ValueError):
            print("Given argument isn't a pair of datasets")
            continue
        if not data_able_to_process(study_pair, behav_data, img_data):
            continue
        behav_data = mean_center(behav_data)
        img_data = mean_center(img_data)
        try:
            # Build the study's results separately so a failure part-way
            # through cannot leave a partial set of keys in the output.
            study_results = {
                f"s_behav{study_num}_behav{study_num}": cross_cov(behav_data, behav_data),
                f"s_img{study_num}_img{study_num}": cross_cov(img_data, img_data),
                f"s_behav{study_num}_img{study_num}": cross_cov(behav_data, img_data),
                f"s_img{study_num}_behav{study_num}": cross_cov(img_data, behav_data),
            }
        except Exception as e:
            # Best effort: report the failing study and keep the others
            # rather than discarding everything computed so far.
            print(f"Error calculating covariances for Study {study_num}: {e}")
            continue
        covariance_results.update(study_results)
    return covariance_results

In [133]:
# Covariance matrices for both studies, keyed by matrix pair and study number.
covariance_mat = calculate_covariance_matricies(*(study1, study2))

In [135]:
# Centre every dataset before forming second moments: cross_cov only computes
# X.T @ Y / n, so un-centred inputs would give raw cross-products rather than
# covariances (calculate_covariance_matricies centres its inputs as well).
x1_centered = mean_center(behavioural_data_study1)
x2_centered = mean_center(behavioural_data_study2)
a_centered = mean_center(imging_data_study1)
b_centered = mean_center(imging_data_study2)

# Between-view (behaviour x imaging) covariance within each study.
s_x1a = cross_cov(x1_centered, a_centered)
s_x2b = cross_cov(x2_centered, b_centered)
# Within-view covariances used by the unit-variance constraints.
s_x1x1 = cross_cov(x1_centered, x1_centered)
s_x2x2 = cross_cov(x2_centered, x2_centered)
s_aa = cross_cov(a_centered, a_centered)
s_bb = cross_cov(b_centered, b_centered)

## Step 2. Initialization of weights

In [184]:
def weight_intialization(*weights) -> np.ndarray:
    """
    Draw one flat vector of random starting
    weights from a standard normal.

    Parameters
    ----------
    weights: tuple(int)
        sizes of the individual weight
        vectors; the result has one entry
        per unit of their sum

    Returns
    -------
    np.ndarray
        1-D array of length ``sum(weights)``
    """
    total_length = sum(weights)
    return np.random.randn(total_length)

In [185]:
# One starting weight per feature (column) of each dataset, concatenated in
# the order x1, x2, a, b.
weights_0 = weight_intialization(
    *(
        dataset.shape[1]
        for dataset in (
            behavioural_data_study1,
            behavioural_data_study2,
            imging_data_study1,
            imging_data_study2,
        )
    )
)

## Step 3. Objective function

In [191]:
# Size of each weight segment, read from its covariance matrix.
dx1_shape, dx2_shape, da_shape, db_shape = (
    cov.shape[0] for cov in (s_x1x1, s_x2x2, s_aa, s_bb)
)
# Cumulative offsets into the flat weight vector: end of the x2 segment and
# end of the a segment respectively.
dx_shape = dx1_shape + dx2_shape
dac_shape = dx_shape + da_shape


def get_weights(weight_array, dx1_shape, dx_shape, dac_shape):
    """
    Slice a flat weight vector into four
    named, consecutive segments.

    Parameters
    ----------
    weight_array: sequence
        flat weights ordered as wx1, wx2, wa, wb
    dx1_shape: int
        index where the wx1 segment ends
    dx_shape: int
        index where the wx2 segment ends
    dac_shape: int
        index where the wa segment ends;
        everything after it is wb

    Returns
    -------
    dict
        slices of ``weight_array`` under the
        keys "wx1", "wx2", "wa" and "wb"
    """
    segment_names = ("wx1", "wx2", "wa", "wb")
    starts = (None, dx1_shape, dx_shape, dac_shape)
    stops = (dx1_shape, dx_shape, dac_shape, None)
    return {
        name: weight_array[start:stop]
        for name, start, stop in zip(segment_names, starts, stops)
    }
# Cut the initial flat weight vector at the segment boundaries, giving the
# four per-view weight vectors in the order x1, x2, a, b.
wx1, wx2, wa, wb = np.split(weights_0, [dx1_shape, dx_shape, dac_shape])

In [190]:
def cross_cov_term(weight_beh, cov_mat, weight_img):
    """
    Negated bilinear form of two weight
    vectors with a covariance matrix.

    Parameters
    ----------
    weight_beh: np.ndarray
        weights applied on the left
        (transposed) side of ``cov_mat``
    cov_mat: np.ndarray
        matrix linking the two weight vectors
    weight_img: np.ndarray
        weights applied on the right
        side of ``cov_mat``

    Returns
    -------
    value of ``-weight_beh.T @ cov_mat @ weight_img``
    """
    negated_left = -weight_beh.T
    left_projection = negated_left @ cov_mat
    return left_projection @ weight_img

In [None]:
def regularization_term(weight, cov_mat, lam=1.0):
    """
    Variance-constraint penalty for one
    weight vector.

    Evaluates ``0.5 * lam * (weight.T @ cov_mat @ weight - 1)``,
    which is zero when the quadratic form equals one.

    Parameters
    ----------
    weight: np.ndarray
        weight vector being constrained
    cov_mat: np.ndarray
        matrix used in the quadratic form
        (square, matching ``weight``)
    lam: float
        multiplier on the constraint; defaults
        to 1.0, the value that was previously
        hard-coded

    Returns
    -------
    value of the penalty term
    """
    quadratic_form = weight.T @ cov_mat @ weight
    return 0.5 * lam * (quadratic_form - 1)