# CA3 model

This notebook is for building out the CA3 model as previous implementation doesn't work.

Formula is:
$$
min_{\vec{w}_{x_1}, \vec{w}_{x_2}, \vec{w}_a, \vec{w}_b} \left( -\vec{w}_{x_1}^T S_{xa} \vec{w}_a - \vec{w}_{x_2}^T S_{xb} \vec{w}_b + \sum_{i \in \{x_1, x_2, a, b\}} \frac{1}{2} \lambda_i \left( \vec{w}_i^T S_{ii} \vec{w}_i - 1 \right) + \theta_{rr}(\vec{w}_{x_1}, \vec{w}_{x_2}) \right)
$$


In [2]:
from gemmr.generative_model import GEMMR
from gemmr.estimators import SVDCCA
import numpy as np
from scipy.optimize import minimize
import itertools
import scipy.stats
from sklearn.cross_decomposition import CCA

## Generate some data

import numpy as np
from sklearn.cross_decomposition import CCA

np.random.seed(0)

n_samples = 200
n_features_x = 10
n_features_y = 8

# Latent variables (shared source of variance)
latent = np.random.randn(n_samples, 3)

# Mix latent variables with some noise
behavioural_data_study1 = latent @ np.random.randn(3, n_features_x) + 0.1 * np.random.randn(n_samples, n_features_x)
imging_data_study1 = latent @ np.random.randn(3, n_features_y) + 0.1 * np.random.randn(n_samples, n_features_y)


In [3]:
model_definition = GEMMR('cca', wx=2, wy=50, r_between=0.3)
behavioural_data_study1, imging_data_study1 = model_definition.generate_data(n=200)
behavioural_data_study2, imging_data_study2 = model_definition.generate_data(n=190)
study1 = (imging_data_study1, behavioural_data_study1) 
study2 = (imging_data_study2, behavioural_data_study2)

In [4]:
r2= []
for val in range(imging_data_study1.shape[1]):
    corr = scipy.stats.pearsonr(behavioural_data_study1[:, 1], imging_data_study1[:, val])[0]
    r2.append(corr)
np.mean(r2)

np.float64(-0.004693325562976217)

## Step 1: Covariance matrix

In [6]:
def mean_center(data: np.ndarray) -> np.ndarray:
    """
    Function to demean data.

    Parmeteres
    ----------
    data: np.ndarray
        data to demean

    Returns
    -------
    np.ndarray: array
        demeaned data
    """
    return data - data.mean(axis=0)

In [7]:
def cross_cov(matrix_1: np.ndarray, matrix_2: np.ndarray) -> np.ndarray:
    """
    Function to calculate 
    covariance matrix

    Parameters
    ----------
    matrix_1: np.ndarray
        A matrix tht should 
        correspond to subject by 
        features
    matrix_2: np.ndarray
        A matrix that should 
        correspond to features by
        feautres 

    Returns
    -------
    np.ndarray: array
        array of cross covariance matrix
    """
    return (matrix_1.T @ matrix_2) / matrix_1.shape[0] 

In [8]:
def data_able_to_process(study_pair: tuple, behav_data: np.ndarray, img_data: np.ndarray) -> bool:
    """
    Function to check that data
    is in correct format to be processed

    Parameters
    ----------
     study_pair: tuple, 
         tuple of behavioural data 
         and imging data
     behav_data: np.ndarray
         array of behav_data 
     img_data: np.ndarray
         array of img_data
    
    Returns
    -------
    bool: boolean
        bool of if failed or not
    """
    if not isinstance(study_pair, (tuple, list)) or len(study_pair) != 2:
        print("Given argument isn't a pair of datasets")
        return False
    if not isinstance(behav_data, np.ndarray) or not isinstance(img_data, np.ndarray):
        print("Data provided isn't a numpy array")
        return False
    if behav_data.shape[0] == 0 or img_data.shape[0] == 0 or behav_data.shape[0] != img_data.shape[0]:
        print(f"Mismatch between ({behav_data.shape[0]} and {img_data.shape[0]})")

    return True

In [104]:
def calculate_covariance_matricies(*study_pairs) -> dict:
    """
    Calculates covariance matrices and auto covariance
    matricies

    Parameters
    ----------
    study_pairs: tuple
        a tuple or list containing two numpy arrays:
        (behavioural_data, imaging_data).
        Assumes data is (subjects x features).

    Returns
    -------
    covariance_results: dict
        dictionary of covariance and auto-covariance matrices

    """
    covariance_results = {}
    for idx, study_pair in enumerate(study_pairs):
        img_data, behav_data  = study_pair
        if not data_able_to_process(study_pair, behav_data, img_data):
            continue
        behav_data = mean_center(behav_data)
        img_data = mean_center(img_data)
        study_num = idx + 1
        try:
            covariance_results[f"s_behav{study_num}_behav{study_num}"] = cross_cov(behav_data, behav_data)
            covariance_results[f"s_img{study_num}_img{study_num}"] = cross_cov(img_data, img_data)
            covariance_results[f"s_img{study_num}_behav{study_num}"] = cross_cov(img_data, behav_data)

        except Exception as e:
            print(f"Error calculating covariances for Study {study_num}: {e}")
            return None
    return covariance_results

In [125]:
covariance_mat = calculate_covariance_matricies(study1, study2)

In [116]:
s_x1a = cross_cov(behavioural_data_study1, imging_data_study1)
s_x2b = cross_cov(behavioural_data_study2, imging_data_study2)
s_x1x1 = cross_cov(behavioural_data_study1, behavioural_data_study1)
s_x2x2 = cross_cov(behavioural_data_study2, behavioural_data_study2)
s_aa = cross_cov(imging_data_study1, imging_data_study1)
s_bb = cross_cov(imging_data_study2, imging_data_study2)

## Step 2. Intialization of weights 

In [12]:
def weight_intialization(*weights) -> np.ndarray:
    """
    Define a set of random starting 
    weights
    
    Parameters
    ----------
    weights: tuple(int)
        tuple of set amount
        of int values
    
    Returns
    -------
    np.ndarrray
        array of numpy values
    """
    return np.random.randn(sum(weights))

In [13]:
dim = {
  '1': behavioural_data_study1.shape[1], 
   '2': behavioural_data_study2.shape[1],
    '3': imging_data_study1.shape[1],
    '4': imging_data_study2.shape[1]
 }

In [14]:
weights_0 = weight_intialization(
    behavioural_data_study1.shape[1], 
    behavioural_data_study2.shape[1],
    imging_data_study1.shape[1],
    imging_data_study2.shape[1]
    )

## Step 3. Objective function

In [15]:
dx1_shape = s_x1x1.shape[0]
dx2_shape = s_x2x2.shape[0]
da_shape = s_aa.shape[0]
db_shape = s_bb.shape[0]
dx_shape = dx1_shape + dx2_shape
dac_shape =  dx_shape + da_shape

def get_dimensions(s_x1x1, s_x2x2, s_aa):
    dx1_shape = s_x1x1.shape[0]
    dx2_shape = s_x2x2.shape[0]
    da_shape = s_aa.shape[0]
    dx_shape =  dx1_shape + dx2_shape
    return {
        'dx1_shape': dx1_shape,
        'dx_shape' : dx_shape,
        'dac_shape':  dx_shape + da_shape
    }

def get_weights(weight_array, dx1_shape, dx_shape, dac_shape):
    return {
        "wx1": weight_array[:dx1_shape],
        "wx2": weight_array[dx1_shape:dx_shape],
        "wa": weight_array[dx_shape:dac_shape],
        "wb": weight_array[dac_shape:]
    }


In [138]:
def cross_cov_term(weight_beh, cov_mat, weight_img):
    return -weight_img.T @ (cov_mat @ weight_beh)

In [139]:
def regularization_term(weight, cov_mat):
    return 0.5 * 1.0 * (weight.T @ (cov_mat @ weight) - 1)

In [None]:
def dissimilarity_penality(theta_r, img_weight1, img_weight2):
    return theta_r * 0.5 * np.sum((img_weight1 - img_weight2) ** 2)


In [19]:
def objective_function(weights, s_x1a, s_x2b, s_x1x1, s_x2x2, s_aa, s_bb, theta_r):
    dimensions = get_dimensions(s_x1x1, s_x2x2, s_aa)
    weights = get_weights(
            weights, 
            dimensions['dx1_shape'], 
            dimensions['dx_shape'], 
            dimensions['dac_shape'])
    term1 = cross_cov_term(weights['wx1'],s_x1a, weights['wa'])
    term2 = cross_cov_term(weights['wx2'],s_x2b, weights['wb'])
    reg_x1 = regularization_term(weights['wx1'], s_x1x1)
    reg_x2 = regularization_term(weights['wx2'], s_x2x2)
    reg_a = regularization_term(weights['wa'], s_aa)
    reg_b = regularization_term(weights['wb'], s_bb)
    theta_r = dissimilarity_penality(theta_r, weights['wa'], weights['wb'])
    return term1 + term2 + reg_x1 + reg_x2 + reg_a + reg_b + theta_r


## Step 4: Minimise

In [20]:
best_loss = float('inf')
optimal_theta_r = None
optimium_model = None

for theta_r in np.logspace(-3, 2, 10):
    weights_0 = weight_intialization(
        behavioural_data_study1.shape[1], 
        behavioural_data_study2.shape[1],
        imging_data_study1.shape[1],
        imging_data_study2.shape[1])  # re-init each time
    res = minimize(
        objective_function,
        weights_0,
        args=(s_x1a, s_x2b, s_x1x1, s_x2x2, s_aa, s_bb, theta_r),
        method='L-BFGS-B'
    )
    if res.status !=0:
        print(res.status)
        continue


    if res.fun < best_loss:
        best_loss = res.fun
        best_theta_r = theta_r
        optimium_model = res

print(f"Best θ_r: {best_theta_r}")
print(f"Best loss: {best_loss}")

1
1
1
Best θ_r: 0.046415888336127795
Best loss: -1.9999999921047418


In [22]:
dimensions = get_dimensions(s_x1x1, s_x2x2, s_aa)
weights = get_weights(optimium_model.x,
            dimensions['dx1_shape'], 
            dimensions['dx_shape'], 
            dimensions['dac_shape'])

In [23]:
# Get projections (scores)
scores_x1 = behavioural_data_study1 @ weights['wx1']
scores_x2 = behavioural_data_study2 @ weights['wx2']
scores_a  = imging_data_study1 @ weights['wa']
scores_b  = imging_data_study2 @ weights['wb']


In [24]:
display(np.corrcoef(scores_x1, scores_a)[0, 1])
display(np.corrcoef(scores_x2, scores_b)[0, 1])
display(np.linalg.norm(weights['wa'] - weights['wb']))

np.float64(0.11538425261898907)

np.float64(0.07889484464506155)

np.float64(0.00014508777456870848)

## Putting it all together

In [None]:
class CA3:
    def __init__(self, theta: float=None, random_seed: int=None):
        self.theta_ = np.logspace(-3, 2, 10) if theta is None else theta
        self.intial_weights_ = None
        self.dims_ = []
        self.best_loss = float('inf')
        self.optimal_theta_r = None
        self.weights_ = None
        self.covariances_ = {}
        random_seed = 42 if random_seed is None else random_seed
        self.rng = np.random.RandomState(random_seed) 

    def fit(self, *data_sets):
        self._calculate_covariance_matricies(*data_sets)
        self._get_dimensions(*data_sets)
        self._weight_intialization(*list(itertools.chain(*self.dims_)))
        self._optimise()

    def transform(self, *data_sets):
        assert self.weights_ is not None, "Model must be fitted before transfomed can be called."
        assert len(data_sets) == len(self.dims_), "Model fitted with different number of datasets."
        
        scores = {}
        correlations = {}
        count = 0
        for (img_data, beh_data), (wx, wb) in zip(data_sets, self.weights_):
            imging_projections = img_data @ wx
            beh_projections = beh_data @ wb
            scores[f'study{count}'] = [imging_projections, beh_projections]
            corr = np.array([np.corrcoef(imging_projections, beh_projections)[0, 1]])
            correlations[f'study{count}'] = corr
            count += 1
    
        return {
              "correlations": correlations, 
              "projections": scores
        }
    
    def fit_transform(self, *data_sets):
        self.fit(*data_sets)
        return self.transform(*data_sets)
    
    def _weight_intialization(self, *weights) -> np.ndarray:
        """
        Define a set of random starting 
        weights
        
        Parameters
        ----------
        weights: tuple(int)
            tuple of set amount
            of int values
        
        Returns
        -------
        np.ndarrray
            array of numpy values
        """ 
        self.intial_weights_ = self.rng.randn(sum(weights))  # Assigning the result to the instance attribute

    def _calculate_covariance_matricies(self, *data_sets) -> dict:
        """
        Calculates covariance matrices and auto covariance
        matricies
    
        Parameters
        ----------
        study_pairs: tuple
            a tuple or list containing two numpy arrays:
            (behavioural_data, imaging_data).
            Assumes data is (subjects x features).
    
        Returns
        -------
        covariance_results: dict
            dictionary of covariance and auto-covariance matrices
    
        """
        for idx, study_pair in enumerate(data_sets):
            img_data, behav_data  = study_pair
            if not self._data_able_to_process(study_pair, behav_data, img_data):
                continue
            behav_data = self._mean_center(behav_data)
            img_data = self._mean_center(img_data)
            study_num = idx + 1
            try:
                self.covariances_[f"s_behav{study_num}_behav{study_num}"] = self._create_covariance_amtrix(behav_data, behav_data)
                self.covariances_[f"s_img{study_num}_img{study_num}"] = self._create_covariance_amtrix(img_data, img_data)
                self.covariances_[f"s_img{study_num}_behav{study_num}"] = self._create_covariance_amtrix(img_data, behav_data)
    
            except Exception as e:
                print(f"Error calculating covariances for Study {study_num}: {e}")

    def _data_able_to_process(self, study_pair: tuple, behav_data: np.ndarray, img_data: np.ndarray) -> bool:
        """
        Function to check that data
        is in correct format to be processed
    
        Parameters
        ----------
         study_pair: tuple, 
             tuple of behavioural data 
             and imging data
         behav_data: np.ndarray
             array of behav_data 
         img_data: np.ndarray
             array of img_data
        
        Returns
        -------
        bool: boolean
            bool of if failed or not
        """
        if not isinstance(study_pair, (tuple, list)) or len(study_pair) != 2:
            print("Given argument isn't a pair of datasets")
            return False
        if not isinstance(behav_data, np.ndarray) or not isinstance(img_data, np.ndarray):
            print("Data provided isn't a numpy array")
            return False
        if behav_data.shape[0] == 0 or img_data.shape[0] == 0 or behav_data.shape[0] != img_data.shape[0]:
            print(f"Mismatch between ({behav_data.shape[0]} and {img_data.shape[0]})")
    
        return True
            
    def _optimise(self):
        if isinstance(self.theta_, (list, np.ndarray)):
            for theta in self.theta_:
                model = self._optimising_model(theta)
                if model.status !=0:
                    continue
                if model.fun < self.best_loss:
                    self.best_loss = model.fun
                    self.optimal_theta_r = theta
                    self.weights_ = self._split_weights(model.x)
        else:
            model = self._optimising_model(theta)
            self.best_loss = model.fun
            self.optimal_theta_r = theta
            self.weights_ = self._split_weights(model.x)


    def _optimising_model(self, theta):
        return minimize(
            self._objective_function,
            self.intial_weights_,
            args=(self.covariances_, theta),
            method='L-BFGS-B'
        )
    
    def _get_dimensions(self, *data_sets):
        self.dims_ = [(behav.shape[1], img.shape[1]) for behav, img in data_sets]
    
    def _split_weights(self, w):
        """
        Splits the flat weight vector w into individual vectors
        for each behavioural and imaging dataset.
        """
        offset = 0
        weights = []
        for img_dim, behav_dim in self.dims_:
            wx = w[offset:offset + img_dim]
            offset +=img_dim  
            wb = w[offset:offset + behav_dim]
            offset += behav_dim
            weights.append((wx, wb))
        return weights 
    
    def _objective_function(self, weights, covariances, theta):
        total_loss = 0
        weights_ = self._split_weights(weights)
        for idx, (wx, wb) in enumerate(weights_):
           
            s_xb = covariances[f"s_img{idx+1}_behav{idx+1}"]
            s_xx = covariances[f"s_img{idx+1}_img{idx+1}"]
            s_bb = covariances[f"s_behav{idx+1}_behav{idx+1}"]
            total_loss += self._cross_cov_term(wb, s_xb, wx) 
            total_loss += self._regularization_term(wx, s_xx)
            total_loss += self._regularization_term(wb, s_bb)
    
        # Similarity penalty across imaging weights
        if theta > 0 and len(weights_) > 1:
            for img_data in range(len(weights_)):
                for next_img_data in range(img_data + 1, len(weights_)):
                    total_loss += dissimilarity_penality(theta, weights_[img_data][0], weights_[next_img_data][0])
    
        return total_loss
    
    def _create_covariance_amtrix(self, matrix_1: np.ndarray, matrix_2: np.ndarray) -> np.ndarray:
        """
        Function to calculate 
        covariance matrix
    
        Parameters
        ----------
        matrix_1: np.ndarray
            A matrix tht should 
            correspond to subject by 
            features
        matrix_2: np.ndarray
            A matrix that should 
            correspond to features by
            feautres 
    
        Returns
        -------
        np.ndarray: array
            array of cross covariance matrix
        """
        return (matrix_1.T @ matrix_2) / matrix_1.shape[0] 
    
    def _mean_center(self, data: np.ndarray) -> np.ndarray:
        """
        Function to demean data.
    
        Parmeteres
        ----------
        data: np.ndarray
            data to demean
    
        Returns
        -------
        np.ndarray: array
            demeaned data
        """
        return data - data.mean(axis=0)

    def _cross_cov_term(self, weight_beh, cov_mat, weight_img):
        return -weight_img.T @ (cov_mat @ weight_beh)

    def _regularization_term(self, weight, cov_mat):
        return 0.5 * 1.0 * (weight.T @ (cov_mat @ weight) - 1)

    def _dissimilarity_penality(self, theta_r, img_weight1, img_weight2):
        return theta_r * 0.5 * np.sum((img_weight1 - img_weight2) ** 2)

In [61]:
ca3 = CA3()
ca3.fit(study1, study2)


[(50, 2), (50, 2)]
[ 0.49671415 -0.1382643   0.64768854  1.52302986 -0.23415337 -0.23413696
  1.57921282  0.76743473 -0.46947439  0.54256004 -0.46341769 -0.46572975
  0.24196227 -1.91328024 -1.72491783 -0.56228753 -1.01283112  0.31424733
 -0.90802408 -1.4123037   1.46564877 -0.2257763   0.0675282  -1.42474819
 -0.54438272  0.11092259 -1.15099358  0.37569802 -0.60063869 -0.29169375
 -0.60170661  1.85227818 -0.01349722 -1.05771093  0.82254491 -1.22084365
  0.2088636  -1.95967012 -1.32818605  0.19686124  0.73846658  0.17136828
 -0.11564828 -0.3011037  -1.47852199 -0.71984421 -0.46063877  1.05712223
  0.34361829 -1.76304016  0.32408397 -0.38508228 -0.676922    0.61167629
  1.03099952  0.93128012 -0.83921752 -0.30921238  0.33126343  0.97554513
 -0.47917424 -0.18565898 -1.10633497 -1.19620662  0.81252582  1.35624003
 -0.07201012  1.0035329   0.36163603 -0.64511975  0.36139561  1.53803657
 -0.03582604  1.56464366 -2.6197451   0.8219025   0.08704707 -0.29900735
  0.09176078 -1.98756891 -0.2196

In [48]:
ca3.intial_weights_

array([], dtype=float64)