# CNN Extractor + Regressor

### Constants

In [1]:
# Inputs to `dataset.get_dataset` (and `DATA_PATH` also to the plotting
# helpers): the subject's data directory, the feature-extractor identifier
# and the name of the layer whose activations are used as features.
DATA_PATH: str = "../../data.nosync/subj01"
EXTRACTOR: str = "resnet50-imagenet1k-v2"
LAYER: str = "avgpool"

## Load data

In [2]:
from sklearn.model_selection import train_test_split
from src import dataset

# Fixed seed so the train/test split (and every number derived from it) is
# reproducible across kernel restarts.
SPLIT_SEED = 0

feat, l_frmi, r_frmi = dataset.get_dataset(
    DATA_PATH, EXTRACTOR, LAYER, True)

# Split BEFORE normalising: the scaling statistics must come from the training
# split only; computing them on all samples leaks test information into training.
(X_train, X_test,
 l_fmri_train, l_fmri_test,
 r_fmri_train, r_fmri_test) = train_test_split(
    feat, l_frmi, r_frmi, train_size=0.8, random_state=SPLIT_SEED)

# Standardise each feature column with training-split statistics. Per-feature
# scaling (rather than one scalar pooled over every feature) gives all
# features a comparable scale for the L1 penalty of the Lasso models.
# Note: `feat` itself is left unnormalised; only X_train / X_test are scaled.
feat_mean = X_train.mean(axis=0)
feat_std = X_train.std(axis=0)
# A feature that is constant on the training split has std 0; keep it centred
# instead of dividing by zero (which would produce NaN/inf).
feat_std[feat_std == 0] = 1.0
X_train = (X_train - feat_mean) / feat_std
X_test = (X_test - feat_mean) / feat_std

print("X_train shape: {}".format(X_train.shape))
print("l_fmri_train shape: {}".format(l_fmri_train.shape))
print("r_fmri_train shape: {}".format(r_fmri_train.shape))

print()

print("X_test shape: {}".format(X_test.shape))
print("l_fmri_test shape: {}".format(l_fmri_test.shape))
print("r_fmri_test shape: {}".format(r_fmri_test.shape))


X_train shape: (7872, 2048)
l_fmri_train shape: (7872, 19004)
r_fmri_train shape: (7872, 20544)

X_test shape: (1969, 2048)
l_fmri_test shape: (1969, 19004)
r_fmri_test shape: (1969, 20544)


## Modelling

In [3]:
from sklearn.linear_model import Lasso

# Left hemisphere: fit a Lasso (scikit-learn defaults) mapping image features
# to voxel responses, then predict responses for the held-out split.
model_l = Lasso()
model_l.fit(X_train, l_fmri_train)  # fit() returns the estimator itself
y_pred_l = model_l.predict(X=X_test)

In [4]:
# Right hemisphere: same default Lasso set-up as the left-hemisphere model.
model_r = Lasso()
model_r.fit(X_train, r_fmri_train)  # fit() returns the estimator itself
y_pred_r = model_r.predict(X=X_test)

In [None]:
import numpy as np
from scipy.stats import pearsonr

def compute_pearson(pred, target):
    """Return the Pearson correlation between matching columns of two arrays.

    Column ``j`` of ``pred`` is correlated with column ``j`` of ``target``
    using ``scipy.stats.pearsonr``.

    Parameters
    ----------
    pred, target : array-like, shape (n_samples, n_outputs)
        Predicted and measured values; both must have the same 2-D shape.

    Returns
    -------
    numpy.ndarray, shape (n_outputs,)
        One coefficient per column. ``pearsonr`` yields NaN for a column
        where either input is constant.

    Raises
    ------
    ValueError
        If the inputs are not 2-D or their shapes differ. (Previously
        ``zip`` silently truncated to the smaller column count.)
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.ndim != 2 or pred.shape != target.shape:
        raise ValueError(
            "pred and target must be 2-D arrays of the same shape, "
            "got {} and {}".format(pred.shape, target.shape))

    # Iterate over columns (transpose -> rows) without shadowing the inputs.
    coefficients = [
        pearsonr(x=pred_col, y=target_col)[0]
        for pred_col, target_col in zip(pred.T, target.T)
    ]
    return np.array(coefficients)


# Backward-compatible alias for the original, misspelled name.
compute_perason = compute_pearson


# Per-column correlation between predicted and measured test responses,
# computed separately for each hemisphere.
lh_correlation = compute_perason(pred=y_pred_l, target=l_fmri_test)
rh_correlation = compute_perason(pred=y_pred_r, target=r_fmri_test)



## Visualization

In [6]:
from src.visualize import histogram, box_plot

# Histogram of the per-hemisphere correlations (title and output location
# passed positionally to the project's `histogram` helper).
hist = histogram(
    DATA_PATH,
    lh_correlation,
    rh_correlation,
    "Lasso, ResNet50-Imagenet1k-AvgPooling",
    "./img/lasso_baseline",
)
hist  # last expression of the cell, so Jupyter displays the returned object

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [7]:
# Box plot of the same per-hemisphere correlations, with the same title and
# output location as the histogram cell.
box = box_plot(
    DATA_PATH,
    lh_correlation,
    rh_correlation,
    "Lasso, ResNet50-Imagenet1k-AvgPooling",
    "./img/lasso_baseline",
)
box  # last expression of the cell, so Jupyter displays the returned object