In [None]:
%env CUDA_VISIBLE_DEVICES = '7'

In [None]:
from scripts.feature_analysis import *

In [None]:
import sys

# Make the project-local `model_opts` helpers importable; the guard keeps the
# cell idempotent so re-running it does not stack duplicate sys.path entries.
if 'model_opts' not in sys.path:
    sys.path.append('model_opts')
from feature_extraction import *
from model_options import *

In [None]:
# Image set analysed throughout the notebook (also the dataset label in outputs).
imageset = 'oasis'
# Presumably unaveraged ratings (average=False). NOTE(review): this name is
# rebound later to a dict of ratings loaded with default arguments.
response_data = load_response_data(imageset, average=False)
image_data = load_image_data(imageset)

In [None]:
model_string = 'ViT-L/14_clip'
model_options = get_model_options()
selected_options = model_options[model_string]
model_name = selected_options['model_name']
train_type = selected_options['train_type']
model_call = selected_options['call']

# `call` is a Python expression string from the project's model registry.
# eval() is only acceptable because that registry is trusted local config;
# never feed it externally supplied strings.
model = eval(model_call).eval()
if torch.cuda.is_available():
    model = model.cuda()

image_transforms = get_recommended_transforms(model_string)

In [None]:
# shuffle is left at DataLoader's default (False), so batches follow the row
# order of image_data -- the metric loop below relies on that ordering.
stimulus_set = StimulusSet(image_data.image_path, image_transforms)
stimulus_loader = DataLoader(dataset=stimulus_set, batch_size=64)

In [None]:
# Per-model layer table (presumably produced by an earlier layer-selection
# analysis), keyed by model_string.
layer_table = pd.read_csv('results/superlative_layers.csv').set_index('model_string')
target_layers = layer_table.to_dict(orient='index')
target_layer = target_layers[model_string]['model_layer']

In [None]:
# Extract activations for the selected layer only. numpy=False presumably keeps
# torch tensors (the metric loop below calls .abs()/.item() on them).
stimulus_features = get_all_feature_maps(
    model,
    stimulus_loader,
    numpy=False,
    layers_to_retain=[target_layer],
)

In [None]:
def treves_rolls(x):
    """Treves-Rolls sparseness of an activation pattern.

    Computes ``mean(x) ** 2 / mean(x ** 2)`` over *all* elements of ``x``, so
    inputs of any shape are normalised by their total element count (not by
    the length of the first axis).

    Args:
        x: a ``torch.Tensor``, ``np.ndarray`` or array-like of activations.
            Integer inputs are promoted to floating point first.

    Returns:
        A 0-d ``torch.Tensor`` for tensor input, otherwise a NumPy float.
        NaN when every element is zero (0 / 0).
    """
    if isinstance(x, torch.Tensor):
        if not torch.is_floating_point(x):
            x = x.float()
        return x.mean() ** 2 / (x ** 2).mean()
    x = np.asarray(x)
    if x.dtype.kind not in 'fc':
        # Avoid integer overflow in x ** 2.
        x = x.astype(float)
    return np.mean(x) ** 2 / np.mean(x ** 2)

In [None]:
# Moment helpers adapted from tntorch:
# https://tntorch.readthedocs.io/en/latest/_modules/metrics.html

def torch_skewness(x):
    """Moment coefficient of skewness ``E[((x - mean) / sd) ** 3]`` over all elements.

    Uses population (1/n) moments throughout, matching ``scipy.stats.skew``
    with its default ``bias=True``. (Mixing a 1/n third moment with the
    Bessel-corrected ``torch.std`` gives neither standard estimator.)

    Args:
        x: floating-point tensor of any shape.

    Returns:
        0-d tensor; NaN when ``x`` is constant (zero spread).
    """
    centered = x - torch.mean(x)
    population_sd = torch.sqrt(torch.mean(centered ** 2))
    return torch.mean((centered / population_sd) ** 3)

def torch_kurtosis(x, fisher=True):
    """Kurtosis ``E[((x - mean) / sd) ** 4]`` over all elements of ``x``.

    Uses population (1/n) moments, matching ``scipy.stats.kurtosis`` with its
    default ``bias=True``.

    Args:
        x: floating-point tensor of any shape.
        fisher: if true (default) subtract 3 so a normal distribution scores 0
            (excess kurtosis); if false return Pearson's kurtosis.

    Returns:
        0-d tensor; NaN when ``x`` is constant (zero spread).
    """
    centered = x - torch.mean(x)
    variance = torch.mean(centered ** 2)
    return torch.mean(centered ** 4) / variance ** 2 - fisher * 3

def torch_frobnorm(x):
    """Frobenius (L2) norm ``sqrt(sum(x ** 2))`` over all elements of ``x``.

    Accepts tensors of any shape, not only 1-D vectors. A sum of squares is
    never negative, so no clamping is needed before the square root.

    Returns:
        0-d tensor.
    """
    return torch.sqrt(torch.sum(x * x))

In [None]:
# One row per (layer, image) with summary statistics of that image's activations.
# NOTE: 'var_activity' / 'var_absolute' hold standard deviations (.std()), not
# variances; the keys are kept as-is because later cells select them by name.
metric_dictlist = []
for model_layer_index, model_layer in enumerate(tqdm(stimulus_features)):
    layer_map = stimulus_features[model_layer]
    for row_index, activity in enumerate(layer_map):
        # Rows are matched to images by position: the loader was built from
        # image_data.image_path without shuffling.
        magnitude = activity.abs()
        peak = activity.max().item()
        trough = activity.min().item()

        metric_dictlist.append({
            'image': image_data.image_name.iloc[row_index],
            'model': model_name,
            'train_type': train_type,
            'model_layer': model_layer,
            'model_layer_index': model_layer_index,
            'mean_absolute': magnitude.mean().item(),
            'mean_activity': activity.mean().item(),
            'var_activity': activity.std().item(),
            'var_absolute': magnitude.std().item(),
            'max_activity': peak,
            'min_activity': trough,
            'range': peak - trough,
            'sparseness': treves_rolls(activity).item(),
            'skewness': torch_skewness(magnitude).item(),
            'kurtosis': torch_kurtosis(magnitude).item(),
            'frobenius': torch_frobnorm(magnitude).item(),
        })

metric_data_raw = pd.DataFrame(metric_dictlist)

In [None]:
# Pairwise Pearson correlations between the per-image metrics (rows from all
# extracted layers pooled). Includes every computed metric, skewness too.
metric_data_raw[['mean_activity', 'var_activity', 'max_activity', 'min_activity',
                 'range', 'sparseness', 'skewness', 'kurtosis', 'frobenius',
                 'mean_absolute', 'var_absolute']].corr()

In [None]:
# Ratings keyed by image set: process_metric_data looks up response_data[imageset],
# so the key must follow `imageset` rather than a hard-coded name.
response_data = {imageset: load_response_data(imageset)}

In [None]:
def process_metric_data(metric_data, orient='wide'):
    """Join per-image model metrics with human ratings and reshape them.

    Reads the notebook globals ``imageset`` (dataset label and lookup key) and
    ``response_data`` (dict mapping image-set name -> ratings DataFrame with an
    ``image_name`` column).

    Args:
        metric_data: one row per (model layer, image) with metric columns; the
            image column may be called ``image`` or ``image_name``. The frame is
            not modified.
        orient: ``'wide'`` -> one row per (image, layer, measurement) with a
            ``rating`` column and one column per metric; ``'long'`` -> that frame
            further melted into ``metric`` / ``value`` columns.

    Returns:
        pandas DataFrame in the requested orientation.

    Raises:
        ValueError: if ``orient`` is neither ``'wide'`` nor ``'long'``.
    """
    if orient not in ('wide', 'long'):
        raise ValueError(f"orient must be 'wide' or 'long', got {orient!r}")

    # assign() returns a copy, so the caller's frame is left untouched.
    metric_data = metric_data.assign(dataset=imageset)
    if 'image' in metric_data.columns:
        metric_data = metric_data.rename(columns={'image': 'image_name'})

    data_wide = pd.merge(metric_data, response_data[imageset], on='image_name')
    # NOTE(review): with k layers this maps indices 0..k-1 onto [0, (k-1)/k].
    data_wide['model_layer_depth'] = (data_wide['model_layer_index'] /
                                      data_wide['model_layer'].nunique())

    id_columns = ['dataset', 'image_name', 'image_type', 'model', 'train_type',
                  'model_layer', 'model_layer_index', 'model_layer_depth']
    measurement_columns = [col for col in data_wide.columns
                           if col in ['arousal', 'beauty', 'valence']]
    analysis_columns = [col for col in data_wide.columns
                        if col not in id_columns + measurement_columns]

    data_wide = data_wide[id_columns + measurement_columns + analysis_columns]
    # Stack the rating columns: one row per (image, layer, measurement).
    data_wide = pd.melt(data_wide, id_vars=id_columns + analysis_columns,
                        var_name='measurement', value_name='rating')
    if orient == 'wide':
        return data_wide

    # Only build the fully long frame when it is actually requested.
    return pd.melt(data_wide, id_vars=id_columns + ['measurement', 'rating'],
                   var_name='metric', value_name='value')
    
def process_corr_data(data_wide, include_combo=True, orient='long'):
    """Correlate every metric with ``rating`` within each model/layer/measurement group.

    Args:
        data_wide: output of ``process_metric_data(..., orient='wide')``.
        include_combo: also compute correlations pooled over image types,
            labelled ``image_type == 'Combo'``.
        orient: ``'long'`` (default) -> one row per group and metric with a
            ``corr`` column; ``'wide'`` -> one column per metric.

    Returns:
        pandas DataFrame of Pearson correlations in the requested orientation.

    Raises:
        ValueError: if ``orient`` is neither ``'wide'`` nor ``'long'``.
    """
    if orient not in ('wide', 'long'):
        raise ValueError(f"orient must be 'wide' or 'long', got {orient!r}")

    id_columns = ['model', 'train_type', 'dataset', 'image_type', 'model_layer',
                  'model_layer_index', 'model_layer_depth', 'measurement']

    def _corr_with_rating(group_columns):
        # corrwith aligns on the index, so each group's numeric columns are
        # correlated with that group's own ratings; 'rating' vs itself is dropped.
        return (data_wide.groupby(group_columns)
                .corrwith(data_wide['rating'], numeric_only=True)
                .reset_index().drop('rating', axis=1))

    corr_data_wide = _corr_with_rating(id_columns)

    if include_combo:
        combo = _corr_with_rating([col for col in id_columns if col != 'image_type'])
        combo['image_type'] = 'Combo'
        corr_data_wide = pd.concat([corr_data_wide, combo])

    if orient == 'wide':
        return corr_data_wide
    return pd.melt(corr_data_wide, id_vars=id_columns,
                   var_name='metric', value_name='corr')


In [None]:
# Ratings-joined metrics: one row per (image, layer, measurement), metrics as columns.
metric_data = process_metric_data(metric_data_raw, orient='wide')

In [None]:
# Long-format metric/rating correlations plus their absolute values.
corr_data = (process_corr_data(metric_data)
             .assign(corr_abs=lambda frame: frame['corr'].abs()))

In [None]:
# Keep only the layer selected for this model in results/superlative_layers.csv --
# the layer whose features were extracted above. A hard-coded layer name would
# silently empty every downstream table whenever target_layer differs.
corr_data = corr_data[corr_data['model_layer'] == target_layer]

In [None]:
# Mean (signed) correlation per metric across measurements and image types.
corr_data.groupby(['metric'], as_index=False)['corr'].mean()

In [None]:
# Mean |r| per metric for beauty ratings, weakest to strongest.
beauty_corrs = corr_data.loc[corr_data['measurement'] == 'beauty']
(beauty_corrs
 .groupby(['metric'])['corr_abs'].mean()
 .reset_index()
 .sort_values(by='corr_abs'))

In [None]:
# Same as above, restricted to scene images.
is_beauty_scene = (corr_data['measurement'] == 'beauty') & (corr_data['image_type'] == 'Scene')
(corr_data.loc[is_beauty_scene]
 .groupby(['metric'])['corr_abs'].mean()
 .reset_index()
 .sort_values(by='corr_abs'))

In [None]:
# max_transform comes from scripts.feature_analysis; presumably it keeps the
# maximal 'corr' per (measurement, image_type, metric) group -- TODO confirm.
max_corrs = max_transform(corr_data,
                          group_vars=['measurement', 'image_type', 'metric'],
                          measure_var='corr')
max_corrs.groupby(['metric'])['corr'].mean().reset_index()

In [None]:
# Beauty ratings of scene images only, reduced by max_transform per metric.
beauty_scene_corrs = corr_data[(corr_data['measurement'] == 'beauty') &
                               (corr_data['image_type'] == 'Scene')]
beauty_scene_max = max_transform(beauty_scene_corrs, group_vars=['metric'],
                                 measure_var='corr')
beauty_scene_max.groupby(['metric'])['corr'].mean().reset_index()

In [None]:
import numba

NAN = float("nan")


def _any_nans(a):
    """Return True if any element of the floating-point array-like ``a`` is NaN."""
    return bool(np.isnan(np.asarray(a)).any())


def any_nans(a):
    """Return True if ``a`` is a floating-point NumPy array containing a NaN.

    Non-float arrays (int, bool, object) cannot hold NaN and return False.
    Vectorised NumPy gives the same answer as an element-by-element loop
    without needing a JIT compiler.
    """
    if a.dtype.kind != 'f':
        return False
    return _any_nans(a)

import pingouin as pg

In [None]:
import warnings

target_metrics = ['mean_activity', 'mean_absolute', 'var_activity', 'var_absolute', 'max_activity', 'min_activity',
                  'range', 'sparseness', 'skewness', 'kurtosis', 'frobenius']

N_PERMUTATIONS = 1000
PERMUTATION_SEED = 0  # fixed so permutation p-values reproduce on Restart & Run All


def _max_abs_corr(layer_values, ratings):
    """Largest |Pearson r| between ``ratings`` and any array in ``layer_values``."""
    return max(abs(pearsonr(values, ratings)[0]) for values in layer_values)


results_dictlist = []
data_wide = metric_data
model_layers = data_wide['model_layer'].unique()
rng = np.random.default_rng(PERMUTATION_SEED)

for measurement in data_wide['measurement'].unique():
    for image_type in data_wide['image_type'].unique():
        data_i = data_wide[(data_wide['image_type'] == image_type) &
                           (data_wide['measurement'] == measurement)]
        # One frame per layer. Each layer's metric values and the ratings y must
        # refer to the same images in the same order, so check it explicitly.
        layer_frames = [data_i[data_i['model_layer'] == model_layer]
                        for model_layer in model_layers]
        image_order = layer_frames[0]['image_name'].to_numpy()
        assert all(np.array_equal(frame['image_name'].to_numpy(), image_order)
                   for frame in layer_frames), 'image order differs between layers'
        y = layer_frames[0]['rating'].to_numpy()

        for metric in target_metrics:
            X = np.stack([frame[metric].to_numpy() for frame in layer_frames], axis=1)
            # Layers whose metric contains NaNs are left out of the max statistic.
            valid_layers = [x for x in X.T if not any_nans(x)]
            if not valid_layers:
                warnings.warn(f'skipping {measurement}/{image_type}/{metric}: '
                              'every layer contains NaN values')
                continue

            actual_max = _max_abs_corr(valid_layers, y)

            # Max-statistic permutation null: shuffle the ratings ONCE per
            # iteration and correlate that shuffle with every layer. This keeps
            # the correlation structure between layers intact, which permuting
            # each layer independently would destroy.
            permuted_max_corrs = np.array([_max_abs_corr(valid_layers, rng.permutation(y))
                                           for _iteration in range(N_PERMUTATIONS)])

            # 2.5% / 97.5% quantiles of the NULL distribution of max |r|
            # (not a confidence interval around actual_max).
            permuted_lqt = np.quantile(permuted_max_corrs, 0.025)
            permuted_uqt = np.quantile(permuted_max_corrs, 0.975)
            # Add-one estimator (Phipson & Smyth 2010): a permutation p-value is never 0.
            permuted_pvalue = float((np.sum(permuted_max_corrs >= actual_max) + 1)
                                    / (N_PERMUTATIONS + 1))

            results_dictlist.append({'model': model_name, 'train_type': train_type,
                                     'dataset': imageset, 'image_type': image_type,
                                     'metric': metric, 'measurement': measurement,
                                     'model_depth': len(model_layers),
                                     'corr_max_score': actual_max,
                                     'corr_lower_ci': permuted_lqt,
                                     'corr_upper_ci': permuted_uqt,
                                     'corr_p_value': permuted_pvalue})


metric_permutations = pd.DataFrame(results_dictlist)

In [None]:
# FDR correction (pingouin's 'fdr' method) across every metric/measurement/
# image-type test. multicomp returns (reject flags, corrected p-values).
_reject, corrected_pvalues = pg.multicomp(metric_permutations['corr_p_value'].to_numpy(),
                                          alpha=0.05, method='fdr')
metric_permutations['corr_p_adj'] = corrected_pvalues

In [None]:
# Number of permutation tests run per measurement x image type.
metric_permutations.value_counts(subset=['measurement', 'image_type'])

In [None]:
# Tests significant at UNCORRECTED p < 0.05, per measurement x image type.
# NOTE(review): FDR-adjusted values are available in 'corr_p_adj'.
significant_tests = metric_permutations[metric_permutations['corr_p_value'] < 0.05]
significant_tests[['measurement', 'image_type']].value_counts()

In [None]:
# Tests significant at UNCORRECTED p < 0.05, per metric.
# NOTE(review): FDR-adjusted values are available in 'corr_p_adj'.
metric_permutations.loc[metric_permutations['corr_p_value'] < 0.05, ['metric']].value_counts()

In [None]:
# Permutation results for beauty ratings of scene images.
beauty_scene_rows = ((metric_permutations['measurement'] == 'beauty') &
                     (metric_permutations['image_type'] == 'Scene'))
metric_permutations.loc[beauty_scene_rows]