In [None]:
%env CUDA_VISIBLE_DEVICES = 7

In [None]:
from warnings import filterwarnings
filterwarnings("ignore")  # silence all warnings, including any emitted while importing the modules below

# NOTE(review): star imports from project modules. They appear to supply many names used
# later (e.g. pd, np, torch, os, copy, tqdm, DataLoader, StimulusSet, RidgeCV, scale,
# pearsonr, scoring_metrics) — confirm, and prefer explicit imports.
from dataset import *
from feature_analysis import *

import seaborn as sns
from siuba import *  # NOTE: siuba's `filter` shadows the builtin filter() in this notebook
from plotnine import *
from plotnine import options
options.figure_size = (10,5)  # default figure size for all plotnine plots below

In [None]:
model_string = 'swin_base_patch4_window7_224_imagenet'
model_options = get_model_options()
image_transforms = get_recommended_transforms(model_string)

# Pull the fields this notebook needs out of the chosen model's options entry.
model_name, train_type, model_call = (model_options[model_string][field]
                                      for field in ('model_name', 'train_type', 'call'))

In [None]:
# Image set analysed in this first section; also used to name the regression cache directory.
target_imageset = 'oasis'
image_data = load_image_data(target_imageset)

In [None]:
# Batch the current image set's stimuli (with the model's transforms) for feature extraction.
stimulus_loader = DataLoader(StimulusSet(image_data['image_path'], image_transforms), batch_size = 64)

In [None]:
# Map each model_string to its row in superlative_layers.csv; that row's 'model_layer'
# names the layer whose features are analysed below.
superlative_layers = (pd.read_csv('../results/superlative_layers.csv')
                      .set_index('model_string')
                      .to_dict('index'))
target_layer = superlative_layers[model_string]['model_layer']

In [None]:
# Extract feature maps for every stimulus, retaining only the target layer.
# numpy=False — presumably returns torch tensors (the regression cell converts them if so).
feature_maps = get_all_feature_maps(model_string, stimulus_loader, numpy=False, layers_to_retain = [target_layer])

In [None]:
# Reduce each layer's feature maps (via the project's get_feature_map_srps); indexed by layer name below.
stimulus_features = get_feature_map_srps(feature_maps)

In [None]:
# Individual (not subject-averaged) OASIS ratings.
subject_data = load_response_data('oasis', average = False)

In [None]:
# Long format (one rating per row), drop missing ratings, then count ratings
# per subject x measurement x image type.
ratings_tally = (subject_data
                 >> gather('measurement', 'rating', _.arousal, _.valence, _.beauty)
                 >> filter(_.rating.notna())
                 >> group_by(_.subject, _.measurement, _.image_type)
                 >> count())

In [None]:
# Number of distinct subjects with at least one rating for each measurement.
# (GroupBy.nunique's only parameter is `dropna`; the column must be selected first.)
ratings_tally.groupby('measurement')['subject'].nunique()

In [None]:
# Subject x measurement x image-type cells with the fewest ratings first.
arrange(ratings_tally, _.n)

In [None]:
# Total ratings per subject and measurement (summed over image types), smallest first.
(ratings_tally
 >> group_by(_.measurement, _.subject)
 >> summarize(n = _.n.sum())
 >> arrange(_.n))

In [None]:
# Distinct per-subject beauty rating counts observed within each image type.
(ratings_tally
 >> filter(_.measurement == 'beauty')
 >> arrange(_.n)
 >> group_by(_.image_type)
 >> distinct(_.n))

In [None]:
# Long format, drop missing ratings, then count ratings per image x measurement x image type.
image_tally = (subject_data
               >> gather('measurement', 'rating', _.arousal, _.valence, _.beauty)
               >> filter(_.rating.notna())
               >> group_by(_.image_name, _.measurement, _.image_type)
               >> count())

In [None]:
# Total number of non-missing ratings per measurement.
(image_tally
 >> group_by(_.measurement)
 >> summarize(count = _.n.sum()))

In [None]:
# Images with the fewest ratings first.
arrange(image_tally, _.n)

In [None]:
# Per-subject oracle data; rename columns to match the regression results' naming.
oracle_data = pd.read_csv('response/oasis_oracle_data.csv').rename(
    columns = {'category': 'image_type', 'item_count': 'image_count'})

In [None]:
# Mean oracle_corr per measurement and image type.
oracle_data.groupby(['measurement', 'image_type'], as_index = False)['oracle_corr'].mean()

In [None]:
# Distribution of per-subject oracle_corr by image type, one panel per measurement.
(ggplot(oracle_data, aes(x = 'image_type', y = 'oracle_corr'))
 + geom_jitter(width = 0.3, height = 0)
 + geom_boxplot(outlier_alpha = 0)
 + facet_wrap('~measurement'))

In [None]:
output_dir = 'incoming/subject_regs/{}'.format(target_imageset)
os.makedirs(output_dir, exist_ok = True)

output_file = os.path.join(output_dir, model_string + '.csv')

# Per-subject ridge regressions are slow, so results are cached on disk per model.
# NOTE(review): cache files written before feature rows were looked up by image name
# (below) used misaligned features and should be deleted and regenerated.
if os.path.exists(output_file):
    subject_reg_data = pd.read_csv(output_file)
else:
    target_features = stimulus_features[target_layer]
    if isinstance(target_features, torch.Tensor):
        target_features = target_features.numpy()

    model_layer = target_layer
    model_layer_index = 0

    # Rows of target_features follow image_data's row order (the order the StimulusSet /
    # DataLoader was built from). Look features up by image name, NOT by a rating's
    # position within one subject's (possibly image-type-filtered) rows.
    assert image_data['image_name'].is_unique, 'image_name must identify exactly one feature row'
    feature_row_by_image = pd.Series(np.arange(len(image_data)), index = image_data['image_name'].to_numpy())

    data_i = copy(subject_data).merge(image_data, on = ['image_name'])

    score_dictlist = []
    for measurement in tqdm(['arousal','valence','beauty']):
        data_i_sub1 = data_i[['subject', 'image_name', 'image_type', measurement]]
        for image_type in tqdm(data_i['image_type'].unique().tolist() + ['Combo'], leave = False):
            # 'Combo' pools all image types together.
            if image_type == 'Combo':
                data_i_subset = data_i_sub1
            else:
                data_i_subset = data_i_sub1[data_i_sub1['image_type'] == image_type]
            for subject in tqdm(data_i_subset['subject'].unique(), leave = False):
                subject_rows = data_i_subset[data_i_subset['subject'] == subject]
                subject_data_i = subject_rows[measurement].to_numpy()
                item_indices = np.argwhere(~np.isnan(subject_data_i)).flatten()
                # Only fit subjects with more than 10 non-missing ratings in this subset.
                if len(item_indices) > 10:
                    y = subject_data_i[item_indices]
                    rated_images = subject_rows['image_name'].to_numpy()[item_indices]
                    feature_rows = feature_row_by_image.loc[rated_images].to_numpy()

                    X = scale(target_features[feature_rows, :])

                    alpha_values = [1000]
                    # NOTE(review): store_cv_values / cv_values_ are renamed store_cv_results /
                    # cv_results_ in newer scikit-learn releases.
                    regression = RidgeCV(alphas=alpha_values, store_cv_values=True,
                                         scoring='explained_variance').fit(X, y)

                    for alpha_index, alpha_value in enumerate(alpha_values):
                        # With `scoring` set, cv_values_ holds per-sample leave-one-out
                        # predictions (one column per alpha) rather than squared errors.
                        y_pred = regression.cv_values_[:, alpha_index]

                        for score_type in scoring_metrics:
                            score_dictlist.append({'model': model_name, 'train_type': train_type,
                                                   'model_layer_index': model_layer_index+1,
                                                   'model_layer': model_layer,
                                                   'subject': subject,
                                                   'measurement': measurement,
                                                   'image_type': image_type,
                                                   'image_count': len(item_indices),
                                                   'score_type': score_type,
                                                   'score': scoring_metrics[score_type](y, y_pred),
                                                   # alpha whose predictions were scored
                                                   'alpha': float(alpha_value)})

    subject_reg_data = pd.DataFrame(score_dictlist)
    subject_reg_data.to_csv(output_file, index = None)

In [None]:
# Pair each subject's regression scores with that subject's oracle data by key.
oracle_reg_data = pd.merge(subject_reg_data, oracle_data,
                           on = ['measurement', 'image_type', 'image_count', 'subject'])

In [None]:
# Mean per-subject Pearson r of the ridge predictions, by measurement and image type.
(subject_reg_data
 .query("score_type == 'pearson_r'")
 .groupby(['measurement', 'image_type'], as_index = False)['score']
 .mean())

In [None]:
# Correlation between regression score and oracle_corr within each
# score_type x measurement x image_type group. Multiple GroupBy columns must be
# selected with a list; tuple selection raises in pandas >= 2.0.
# iloc[0::2, -1] keeps the corr(score, oracle_corr) entry of each 2x2 matrix.
(oracle_reg_data.groupby(['score_type','measurement','image_type'])[['score','oracle_corr']]
 .corr().iloc[0::2,-1].reset_index().drop('level_3', axis = 1))

In [None]:
# Per-subject Pearson r by image type (points coloured by number of rated images).
plot_data = subject_reg_data.query("score_type == 'pearson_r'")
(ggplot(plot_data, aes(x = 'image_type', y = 'score'))
 + geom_jitter(aes(color = 'image_count'), width = 0.3, height = 0)
 + geom_boxplot(outlier_alpha = 0)
 + facet_wrap('~measurement'))

In [None]:
# Does per-subject accuracy (Pearson r) grow with the number of images a subject rated?
# Single image types only; the pooled 'Combo' rows are excluded.
# reset_index(drop=True) avoids leaking stale 'index'/'level_0' columns into plot_data.
plot_data = subject_reg_data[subject_reg_data['score_type'] == 'pearson_r']
plot_data = plot_data[plot_data['image_type'] != 'Combo'].reset_index(drop = True)
(ggplot(plot_data, aes(x = 'image_count', y = 'score')) + facet_wrap('~measurement') +
 geom_point(aes(color = 'image_type')) + geom_smooth(method = 'lm'))

In [None]:
# Correlation of per-subject regression score (Pearson r) with oracle_corr, all image
# types pooled ('Combo'). Rows come from oracle_reg_data, which pairs scores with oracle
# values by key; the previous positional concat silently assumed identical row order.
pearson_oracle_regs = oracle_reg_data[oracle_reg_data['score_type'] == 'pearson_r']
for measurement in ['beauty', 'arousal', 'valence']:
    oracle_regs = pearson_oracle_regs >> filter(_.image_type == 'Combo', _.measurement == measurement)
    x = oracle_regs['oracle_corr']
    y = oracle_regs['score']
    nas = np.logical_or(np.isnan(x), np.isnan(y))
    corr = np.round(pearsonr(x[~nas], y[~nas]),5)
    print('{} (n = {}), r = {}, p = {}'.format(measurement, len(oracle_regs), corr[0], corr[1]))

In [None]:
# Same score-vs-oracle correlation, separately for every measurement x image type.
# Rows are paired by key through oracle_reg_data (not by row position).
pearson_oracle_regs = oracle_reg_data[oracle_reg_data['score_type'] == 'pearson_r']
for measurement in ['beauty', 'arousal', 'valence']:
    for category in ['Animal', 'Object', 'Person', 'Scene', 'Combo']:
        oracle_regs = pearson_oracle_regs >> filter(_.image_type == category, _.measurement == measurement)
        x = oracle_regs['oracle_corr']
        y = oracle_regs['score']
        nas = np.logical_or(np.isnan(x), np.isnan(y))
        corr = np.round(pearsonr(x[~nas], y[~nas]),5)
        print('{}, {} (n = {}), r = {}, p = {}'.format(measurement, category, len(oracle_regs), corr[0], corr[1]))

In [None]:
# Scatter + linear fit of regression score vs oracle_corr for beauty, image types pooled.
# Rows are paired by key through oracle_reg_data (not by row position).
oracle_regs = (oracle_reg_data[oracle_reg_data['score_type'] == 'pearson_r']
               >> filter(_.image_type == 'Combo', _.measurement == 'beauty'))
sns.lmplot(x = 'score', y = 'oracle_corr', data = oracle_regs);

In [None]:
# Stability check: the beauty score/oracle correlation over 1000 random subsamples
# (without replacement, up to 400 rows each) — subsamples, not strict split-halves.
# Rows are paired by key via oracle_reg_data; each draw is seeded for reproducibility,
# and the loop-invariant filtering is done once.
beauty_combo_regs = (oracle_reg_data[oracle_reg_data['score_type'] == 'pearson_r']
                     >> filter(_.image_type == 'Combo', _.measurement == 'beauty'))
sample_size = min(400, len(beauty_combo_regs))

splithalf_oracle_beauty = []
for i in tqdm(range(1000)):
    oracle_regs = beauty_combo_regs.sample(n = sample_size, random_state = i)
    x = oracle_regs['oracle_corr']
    y = oracle_regs['score']
    nas = np.logical_or(np.isnan(x), np.isnan(y))
    corr = np.round(pearsonr(x[~nas], y[~nas]),5)
    splithalf_oracle_beauty.append({'measurement': 'beauty', 'n': len(oracle_regs),
                                    'r': corr[0], 'p': corr[1]})
splithalf_oracle_beauty = pd.DataFrame(splithalf_oracle_beauty)

In [None]:
# Mean correlation across the beauty resamples.
splithalf_oracle_beauty.r.mean()

In [None]:
# Distribution of the resampled beauty correlations (histogram + KDE).
# sns.distplot is deprecated; histplot(kde=True, stat='density') is its documented replacement.
sns.histplot(splithalf_oracle_beauty['r'], kde = True, stat = 'density');

In [None]:
# Same stability check for valence: score/oracle correlation over 1000 seeded random
# subsamples (up to 400 rows each), with rows paired by key via oracle_reg_data.
# (Variable name kept for downstream cells; the rows are valence, and are now labeled so.)
valence_combo_regs = (oracle_reg_data[oracle_reg_data['score_type'] == 'pearson_r']
                      >> filter(_.image_type == 'Combo', _.measurement == 'valence'))
sample_size = min(400, len(valence_combo_regs))

splithalf_valence_beauty = []
for i in tqdm(range(1000)):
    oracle_regs = valence_combo_regs.sample(n = sample_size, random_state = i)
    x = oracle_regs['oracle_corr']
    y = oracle_regs['score']
    nas = np.logical_or(np.isnan(x), np.isnan(y))
    corr = np.round(pearsonr(x[~nas], y[~nas]),5)
    splithalf_valence_beauty.append({'measurement': 'valence', 'n': len(oracle_regs),
                                     'r': corr[0], 'p': corr[1]})
splithalf_valence_beauty = pd.DataFrame(splithalf_valence_beauty)

In [None]:
# Mean correlation across the valence resamples.
splithalf_valence_beauty.r.mean()

In [None]:
# Distribution of the resampled valence correlations (histogram + KDE).
# sns.distplot is deprecated; histplot(kde=True, stat='density') is its documented replacement.
sns.histplot(splithalf_valence_beauty['r'], kde = True, stat = 'density');

### Vessel Dataset

In [None]:
# Switch to the vessel image set: individual ratings plus image metadata.
# NOTE: this overwrites the OASIS subject_data / image_data used above.
subject_data = load_response_data('vessel', average = False)
image_data = load_image_data('vessel')

In [None]:
# Batch the vessel stimuli (with the model's transforms) for feature extraction.
stimulus_loader = DataLoader(StimulusSet(image_data['image_path'], image_transforms), batch_size = 64)

In [None]:
# Extract vessel feature maps, retaining only the same target layer used for OASIS.
feature_maps = get_all_feature_maps(model_string, stimulus_loader, numpy=False, layers_to_retain = [target_layer])

In [None]:
# Reduce the vessel feature maps via get_feature_map_srps; indexed by layer name below.
stimulus_features = get_feature_map_srps(feature_maps)

In [None]:
# Number of subjects with a non-missing mean beauty rating, per image type.
(subject_data.groupby(['image_type', 'subject'])['beauty'].mean()
 .groupby(level = 'image_type').count())

In [None]:
# Vessel oracle data, with columns renamed to match the regression results.
oracle_data = pd.read_csv('response/vessel_oracle_data.csv').rename(
    columns = {'category': 'image_type', 'item_count': 'image_count'})

In [None]:
# Inspect the vessel response data (displays the full frame).
subject_data

In [None]:
score_dictlist = []
data_i = copy(subject_data).merge(image_data, on = ['image_name'])

# Rows of the feature matrix follow vessel image_data's row order (the order the
# StimulusSet was built from). Look features up by image name, NOT by a rating's
# position within one subject's (possibly image-type-filtered) rows.
assert image_data['image_name'].is_unique, 'image_name must identify exactly one feature row'
feature_row_by_image = pd.Series(np.arange(len(image_data)), index = image_data['image_name'].to_numpy())

for model_layer_index, model_layer in enumerate(tqdm([target_layer], desc = 'Regression (Layer)')):
    target_features = stimulus_features[model_layer]
    if isinstance(target_features, torch.Tensor):
        target_features = target_features.numpy()

    # Only the rating columns this dataset actually has.
    for measurement in [col for col in subject_data.columns if col in ['arousal','beauty','valence']]:
        data_i_sub1 = data_i[['subject', 'image_name', 'image_type', measurement]]
        for image_type in tqdm(data_i['image_type'].unique().tolist() + ['Combo'], leave = False):
            # 'Combo' pools all image types together.
            if image_type == 'Combo':
                data_i_subset = data_i_sub1
            else:
                data_i_subset = data_i_sub1[data_i_sub1['image_type'] == image_type]
            for subject in tqdm(data_i_subset['subject'].unique(), leave = False):
                subject_rows = data_i_subset[data_i_subset['subject'] == subject]
                subject_data_i = subject_rows[measurement].to_numpy()
                item_indices = np.argwhere(~np.isnan(subject_data_i)).flatten()
                # Only fit subjects with more than 10 non-missing ratings in this subset.
                if len(item_indices) > 10:
                    y = subject_data_i[item_indices]
                    rated_images = subject_rows['image_name'].to_numpy()[item_indices]
                    feature_rows = feature_row_by_image.loc[rated_images].to_numpy()

                    X = scale(target_features[feature_rows, :])

                    alpha_values = [1000]
                    regression = RidgeCV(alphas=alpha_values, store_cv_values=True,
                                         scoring='explained_variance').fit(X, y)

                    for alpha_index, alpha_value in enumerate(alpha_values):
                        # With `scoring` set, cv_values_ holds per-sample leave-one-out
                        # predictions (one column per alpha) rather than squared errors.
                        y_pred = regression.cv_values_[:, alpha_index]

                        for score_type in scoring_metrics:
                            score_dictlist.append({'model': model_name, 'train_type': train_type,
                                                   'model_layer_index': model_layer_index+1,
                                                   'model_layer': model_layer,
                                                   'subject': subject,
                                                   'measurement': measurement,
                                                   'image_type': image_type,
                                                   'image_count': len(item_indices),
                                                   'score_type': score_type,
                                                   'score': scoring_metrics[score_type](y, y_pred),
                                                   # alpha whose predictions were scored
                                                   'alpha': float(alpha_value)})

subject_reg_data = pd.DataFrame(score_dictlist)

In [None]:
# Saving is disabled: unlike the OASIS section, vessel regressions are recomputed every run.
# Uncomment to write the results to disk.
#subject_reg_data.to_csv('vessel_subject_regressions.csv', index = None)

In [None]:
# Pair vessel regression scores with oracle data by image type and subject.
oracle_reg_data = pd.merge(subject_reg_data, oracle_data, on = ['image_type', 'subject'])

In [None]:
# Image types present after the merge.
oracle_reg_data['image_type'].unique()

In [None]:
# Correlation between regression score and oracle_corr within each
# score_type x measurement x image_type group. Multiple GroupBy columns must be
# selected with a list; tuple selection raises in pandas >= 2.0.
# iloc[0::2, -1] keeps the corr(score, oracle_corr) entry of each 2x2 matrix.
(oracle_reg_data.groupby(['score_type','measurement','image_type'])[['score','oracle_corr']]
 .corr().iloc[0::2,-1].reset_index().drop('level_3', axis = 1))

In [None]:
# Score-vs-oracle correlation and scatter/fit plot per measurement x image type.
# Restricted to pearson_r: oracle_reg_data holds one row per scoring metric, so
# correlating without this filter would pool different score types together.
pearson_oracle_reg_data = oracle_reg_data[oracle_reg_data['score_type'] == 'pearson_r']
for measurement in pearson_oracle_reg_data.measurement.unique():
    for image_type in pearson_oracle_reg_data.image_type.unique():
        oracle_reg_data_i = pearson_oracle_reg_data >> filter(_.measurement == measurement,
                                                              _.image_type == image_type)
        x = oracle_reg_data_i['score']
        y = oracle_reg_data_i['oracle_corr']
        nas = np.logical_or(np.isnan(x), np.isnan(y))
        corr = np.round(pearsonr(x[~nas], y[~nas]), 5)
        print('{}, {} (n = {}), r = {}, p = {}'.format(measurement, image_type, len(oracle_reg_data_i),
                                                       corr[0], corr[1]))
        sns.lmplot(x = 'score', y = 'oracle_corr', data = oracle_reg_data_i);