In [None]:
import json
import os
import urllib
import urllib.request

import numpy as np
import scipy.io as sio
import tensorflow as tf
from scipy import stats
from scipy.stats import entropy
from sklearn.decomposition import PCA as RandomizedPCA
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

from rcnn_sat import preprocess_image

# Transfer learning: train the classification layer on scenes

## Building the model

In [None]:
# Architecture to build and fine-tune; the model-building cell handles 'b', 'bl' and 'b_d'.
model_type = 'b' #b, bl, b_d

In [None]:
# Build the selected network on a 128x128 RGB input with a 2-way scene readout.
input_layer = tf.keras.layers.Input((128, 128, 3))
if model_type=='b':
    from rcnn_sat import b_net
    model = b_net(input_layer, classes_scenes=2)
elif model_type=='b_d':
    from rcnn_sat import b_d_net
    model = b_d_net(input_layer, classes_scenes=2)
elif model_type=='bl':
    from rcnn_sat import bl_net
    model = bl_net(input_layer, classes_scenes=2, cumulative_readout=False)
else:
    # Fail here with a clear message instead of a NameError on `model` in a later cell.
    raise ValueError(
        "Unknown model_type {!r}; expected 'b', 'b_d' or 'bl'".format(model_type))

## Load the weights of the frozen layers into the model

In [None]:
# Local file holding the ecoset-pretrained weights for the selected architecture.
weights_name='{}_ecoset.h5'.format(model_type)
#download from OSF if not done already
# NOTE(review): the OSF URL is fixed while the file name depends on model_type --
# confirm the link serves the matching checkpoint before switching model_type.
if not os.path.isfile(weights_name):
    # Download to a temporary name and rename only on success, so an interrupted
    # download cannot leave a truncated file that passes the isfile() check above.
    partial_name = weights_name + '.part'
    _, msg = urllib.request.urlretrieve(
        'https://osf.io/9td5p/download', partial_name)
    os.replace(partial_name, weights_name)
    print(msg)
# by_name=True matches checkpoint weights to layers by layer name; presumably this
# leaves the new scene readout (absent from the checkpoint) at its initial values.
model.load_weights(weights_name,by_name=True)

In [None]:
# Print the layer-by-layer overview of the model that was just built and loaded.
model.summary()

## Define trainable layers (for fine-tuning: all)

In [None]:
# Fine-tuning setup: mark every layer of the model as trainable.
for keras_layer in model.layers:
    keras_layer.trainable = True

## Define the train and validation data

In [None]:
#load and preprocess training & validation data

def _load_preprocessed_images(base_path, relative_paths, target_size=(128, 128)):
    """Load images from disk and return them as one preprocessed array.

    Parameters
    ----------
    base_path : str
        Prefix that is string-concatenated with each entry of `relative_paths`
        (plain concatenation, so entries may start with a path separator).
    relative_paths : list of str
        Image locations relative to `base_path`.
    target_size : tuple of int
        (height, width) every image is resized to on load.

    Returns
    -------
    numpy.ndarray of shape (len(relative_paths), height, width, 3), float64,
    holding the output of `preprocess_image` for each image.
    """
    # NaN-initialised so that any slot that is never filled stays visible.
    images = np.full([len(relative_paths), target_size[0], target_size[1], 3], np.nan)
    for idx, image_path in enumerate(relative_paths):
        image = load_img(base_path+image_path, target_size=target_size)
        image = img_to_array(image)
        image = np.uint8(image)  # preprocess_image is fed uint8 pixel values
        images[idx,:,:,:] = preprocess_image(image)
    return images

#training
# The JSON keys are the image paths relative to the dataset root.
with open('2400_selected_scenes_places365_train_standard.json') as json_file:
    subset_scenes_dict_train = json.load(json_file)
train_set_path = '/scratch/agnek95/PDM/places_365_256_train_val/data_256/'
train_image_paths = list(subset_scenes_dict_train.keys())
train_imgs_prep = _load_preprocessed_images(train_set_path, train_image_paths)

# Full paths of the training images (kept for use outside this cell).
train_images_paths = [train_set_path+image_path for image_path in train_image_paths]

#validation
with open('1200_selected_scenes_places365_val_standard.json') as json_file:
    subset_scenes_dict_val = json.load(json_file)
val_set_path = '/scratch/agnek95/PDM/places_365_256_train_val/val_256/'
val_image_paths = list(subset_scenes_dict_val.keys())
val_imgs_prep = _load_preprocessed_images(val_set_path, val_image_paths)



In [None]:
#labels: man-made: 0, natural: 1
# The first half of each image array is labelled 0 and the second half 1.
# NOTE(review): this relies on the JSON files listing all man-made scenes before
# all natural ones -- confirm against the selection files.
y_train = np.repeat(np.arange(2), train_imgs_prep.shape[0] // 2)
y_val = np.repeat(np.arange(2), val_imgs_prep.shape[0] // 2)

# The preprocessed image arrays are used directly as the model inputs.
x_train, x_val = train_imgs_prep, val_imgs_prep


## Train the classifier

In [None]:
# Step size for the Adam optimizer during fine-tuning.
base_learning_rate = 0.0001
# `learning_rate` is the supported keyword; the older `lr` alias is deprecated
# and rejected by newer Keras releases.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
# Shape printout that only runs for a model named 'bl_net_edl'.
if model.name == 'bl_net_edl':

    print(x_train.shape)
    print(y_train.shape)
    print(x_val.shape)
    print(y_val.shape)

# Train on the scene images and evaluate on the validation set after each epoch.
history = model.fit(
    x_train,
    y_train,
    shuffle=True,
    batch_size=10,
    epochs=20, #5?
    validation_data=(x_val, y_val),
)

# y_pred = model.predict(x_val, batch_size = 10)


In [None]:
# Run the trained model on the validation images and print its raw outputs.
y_pred = model.predict(x_val, batch_size = 10)
print(y_pred)

## Save the weights

In [None]:
# Write the current weights to '<model_type>_net_scenes_weights.h5' in the working directory.
model_name='{}_net_scenes'.format(model_type)
model.save_weights(model_name+'_weights.h5')

# If you want to get the RTs right away, follow the next steps; otherwise use the separate notebook (rnn_dth_collect_activations.ipynb)

In [None]:
def get_RTs(test_images_path,batch_size,entropy_thresh):
    """Convert the network's per-timepoint readouts into one reaction time per image.

    Every image is loaded at 128x128, preprocessed and passed through the global
    `model` in batches. The entropy of the class readout is computed for each image
    and timepoint; the reaction time (RT) of an image is the index of the first
    timepoint whose entropy is <= `entropy_thresh`, or the number of timepoints (8)
    when the threshold is never reached.

    Parameters
    ----------
    test_images_path : sequence of str
        Paths of the images to evaluate.
    batch_size : int
        Number of images sent through the model at once. The image count does not
        need to be a multiple of it; the last batch is simply shorter.
    entropy_thresh : float
        Entropy level at or below which the network counts as having responded.

    Returns
    -------
    numpy.ndarray of float, shape (len(test_images_path),)
        RT per image, in timepoint indices (0-based), 8 meaning "never reached".
    """
    num_images_all = len(test_images_path)
    num_timepoints = 8
    num_classes = 2
    # Readouts for all images: images x timepoints x classes.
    pred = np.full([num_images_all, num_timepoints, num_classes], np.nan)

    for img_idx in range(0, num_images_all, batch_size):
        batch_paths = test_images_path[img_idx:img_idx + batch_size]
        batch_len = len(batch_paths)
        # Size the batch by the paths actually present so a shorter final batch is
        # neither padded with blank images nor written past the end of `pred`.
        batch_images = np.zeros((batch_len,128,128,3))
        for i, image_path in enumerate(batch_paths):
            image = load_img(image_path, target_size=(128, 128))
            image = img_to_array(image)
            image = np.uint8(image)
            image = preprocess_image(image)
            batch_images[i,:,:,:] = image

        #predictions
        # assumes the model output is (or broadcasts to) num_timepoints x batch x classes
        batch_pred = np.empty([num_timepoints, batch_len, num_classes])
        batch_pred[:] = model(batch_images)
        # store image-major so that row k of `pred` belongs to test_images_path[k]
        pred[img_idx:img_idx + batch_len] = np.transpose(batch_pred, (1, 0, 2))

    #get entropies for each image & each timepoint
    entropies_pred = np.full([num_images_all, num_timepoints], np.nan)
    for image in range(num_images_all):
        for tp in range(num_timepoints):
            entropies_pred[image,tp] = entropy(pred[image,tp])

    #for each image, determine the first timepoint at which entropy reaches the threshold
    reached = entropies_pred <= entropy_thresh
    first_tp = reached.argmax(axis=1)  # index of the first True per row (0 if none)
    #if it never reaches the threshold, use the number of timepoints (8)
    rt_thresh = np.where(reached.any(axis=1), first_tp, num_timepoints).astype(float)

    return rt_thresh

In [None]:
# Leave-one-subject-out fit of the entropy threshold that maps RNN readouts to RTs.
# NOTE(review): `test_images_paths` is not defined anywhere in the visible notebook
# (the data cell only builds `train_images_paths`); it must hold the paths of the
# test scenes before this cell is run.

#load RTs
rts_eeg_dict = sio.loadmat('RT_all_subjects_5_35_categorization.mat')
rts_eeg = rts_eeg_dict.get('RTs')  # indexed below as subjects x scenes

#define some variables
num_subjects = rts_eeg.shape[0]
entropies = np.arange(0.01,0.1,0.01)
best_entropy = np.full([num_subjects], np.nan)
# Index into `entropies` chosen in each fold. Indices are tracked next to the rounded
# values because comparing rounded floats with `entropies` is not reliable.
best_entropy_idx = np.zeros(num_subjects, dtype=int)
correlation_test = np.full([num_subjects,3], np.nan) #all,artificial,natural
num_scenes = len(test_images_paths)

#get RNN RTs for every entropy threshold and correlate with humans
rts_rnn = np.full([len(entropies),len(test_images_paths)], np.nan)
for idx,e in enumerate(entropies):
    rts_rnn[idx,:] = get_RTs(test_images_paths,20,e)

#for each fold, fit the entropy threshold on all subjects but one
for s in range(num_subjects):
    # scenes 0-29 are treated as artificial, 30-59 as natural
    artificial_idx = np.arange(30)
    natural_idx = np.arange(30,60)

    test_sub = rts_eeg[s,:]
    fit_sub = np.nanmean(rts_eeg[np.arange(num_subjects)!=s,:],0)
    correlation_fit = np.full([len(entropies),2], np.nan)
    corr_diff = np.full([len(entropies)], np.nan)

    for idx,e in enumerate(entropies):
        correlation_fit[idx,0] = stats.pearsonr(rts_rnn[idx,artificial_idx],fit_sub[artificial_idx])[0] #artificial
        correlation_fit[idx,1] = stats.pearsonr(rts_rnn[idx,natural_idx],fit_sub[natural_idx])[0] #natural
        corr_diff[idx] = np.abs(correlation_fit[idx,0]-correlation_fit[idx,1])

    #select the entropy with the lowest art/nat RNN-human correlation difference
    # Undefined correlations (NaN, e.g. constant RNN RTs) are skipped; plain argmin
    # would return the first NaN entry. If every entry is NaN, fall back to index 0.
    if np.all(np.isnan(corr_diff)):
        best_idx = 0
    else:
        best_idx = int(np.nanargmin(corr_diff))
    best_entropy_idx[s] = best_idx
    best_entropy[s] = round(entropies[best_idx],2)
    print(correlation_fit)
    print(corr_diff)

    #remove every scene for which the left-out subject has no RT
    selected_rnn_rts = rts_rnn[best_idx,:]
    missing = np.isnan(test_sub)
    if missing.any():
        print(s)
        artificial_idx = artificial_idx[~missing[artificial_idx]]
        natural_idx = natural_idx[~missing[natural_idx]]

    #correlate with leftout subject
    all_idx = np.concatenate((artificial_idx,natural_idx))
    correlation_test[s,0] = stats.pearsonr(selected_rnn_rts[all_idx],test_sub[all_idx])[0]
    correlation_test[s,1] = stats.pearsonr(selected_rnn_rts[artificial_idx],test_sub[artificial_idx])[0]
    correlation_test[s,2] = stats.pearsonr(selected_rnn_rts[natural_idx],test_sub[natural_idx])[0]

print(best_entropy)
print(correlation_test)
# Most frequently selected threshold across folds (ties resolve to the smallest one).
# Computed on indices, so it does not depend on the return shape of scipy.stats.mode
# or on exact float equality.
final_idx = int(np.bincount(best_entropy_idx, minlength=len(entropies)).argmax())
RT_entropy = round(entropies[final_idx],2)
RT_RNN_final = rts_rnn[final_idx,:]

np.save('correlation_RT_human_{}_net_cross-validated'.format(model_type),correlation_test)
np.save('{}_net_RTs_entropy_threshold_{}'.format(model_type,RT_entropy),RT_RNN_final)