In [33]:

import pandas as pd
import numpy as np
import gdown
import os


In [34]:
!pip install gdown

import gdown


# To improve the speed of setting up, retrieve only the files for the two groups used in training
# Swap with the line below to get the entire dataset as a zip file
# gdown.download('https://drive.google.com/uc?id=1-QTtycxsVNeym17zrMBSZCAAtEZEs05p', '/content/miniimagenet.zip', quiet=False)


# Use the correct file ID from the Google Drive link
file_id = '18cAvtcJc4jMkLi_QgA7vsN1oe2z026Dp'
gdown.download(f'https://drive.google.com/uc?id={file_id}', 'miniimagenet.zip', quiet=False)

# After downloading, check if the file exists
import os
if os.path.exists('miniimagenet.zip'):
    print("File downloaded successfully.")
else:
    print("File download failed.")

images_directory = '/content/miniimagenet/images/'
if not os.path.exists(images_directory):
  os.makedirs(images_directory)
!unzip -qq /content/miniimagenet.zip -d {images_directory}




Downloading...
From (original): https://drive.google.com/uc?id=18cAvtcJc4jMkLi_QgA7vsN1oe2z026Dp
From (redirected): https://drive.google.com/uc?id=18cAvtcJc4jMkLi_QgA7vsN1oe2z026Dp&confirm=t&uuid=a6c680ea-e582-4fcd-b30c-785aadbf89ac
To: /content/miniimagenet.zip
100%|██████████| 6.74G/6.74G [00:36<00:00, 183MB/s]


File downloaded successfully.
replace /content/miniimagenet/images/n01532829/n01532829_10006.JPEG? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [35]:


# get the CSV with the list of all file names
mini_imagenet_file_list_csv = '/content/all_imagenet_file_names.csv'


# Download the CSV file from Google Drive using the file ID
file_id = '112iJze_LpZrGZbtR6XohaycpMRGCXLel'
gdown.download(f'https://drive.google.com/uc?id={file_id}', mini_imagenet_file_list_csv, quiet=False)


import pandas as pd


file_list = pd.read_csv(mini_imagenet_file_list_csv,  on_bad_lines='skip')
print(file_list.head())

selected_groups = ['n01532829', 'n01558993']
samples_miniimagenet = file_list[file_list['label'].isin(selected_groups)].groupby('label').first()


# Other options for group pairs
# n01532829, n01558993
# n02108551, n02108915

Downloading...
From: https://drive.google.com/uc?id=112iJze_LpZrGZbtR6XohaycpMRGCXLel
To: /content/all_imagenet_file_names.csv
100%|██████████| 1.82M/1.82M [00:00<00:00, 46.6MB/s]

              file_name      label
0  n07697537_49047.JPEG  n07697537
1  n07697537_33979.JPEG  n07697537
2  n07697537_10215.JPEG  n07697537
3  n07697537_15450.JPEG  n07697537
4   n07697537_5990.JPEG  n07697537





<br>
<br>
For this analysis, the first two groups above are used n01532829 and n01558993. These are two similar-looking bird species. The goal is to train the classifier to detect minor details that distinguish the two groups.

In [36]:
# show a sample from the two selected classes

from PIL import Image
from importlib import reload
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


print(samples_miniimagenet.columns)

for image_path in samples_miniimagenet['filename'].unique():
    print('Sample for group ID ', image_path[:9])
    print(images_directory, image_path )
    reload(plt)

    img=mpimg.imread(images_directory, image_path)
    #plt.close('all')
    imgplot = plt.imshow(img)
    plt.show()




Index(['file_name'], dtype='object')
Sample for group ID  n01532829
/content/miniimagenet/images/ n01532829_930.JPEG


IsADirectoryError: [Errno 21] Is a directory: '/content/miniimagenet/images'

In [37]:
# Since the imagenet files are all in the same directory, they can be used as-is
# with the maml-pytorch setup and do not need to be processed further at this point.
# However, the directory does need to be added to the Python training file

file_list[file_list['label'].isin(selected_groups)].to_csv(images_directory + '/train.csv')



In [38]:
#Get the EEG spectrograms zip file and unzip it
eeg_image_directory = '/content/eeg_sz_spectrograms'
gdown.download('https://drive.google.com/uc?id=1WZ1yIFE2bng0McnY_4UBJHqTttsXTTNX', '{}.zip'.format(eeg_image_directory), quiet=False)
!unzip -qq {eeg_image_directory}.zip -d {eeg_image_directory}

Downloading...
From (original): https://drive.google.com/uc?id=1WZ1yIFE2bng0McnY_4UBJHqTttsXTTNX
From (redirected): https://drive.google.com/uc?id=1WZ1yIFE2bng0McnY_4UBJHqTttsXTTNX&confirm=t&uuid=618b5645-37d7-4f93-9eca-4af61ac35690
To: /content/eeg_sz_spectrograms.zip
100%|██████████| 44.2M/44.2M [00:00<00:00, 48.5MB/s]


In [39]:
# rename
dl_link = '/content/eeg_sz_spectrograms/gen_data_20s_70pct_overlap_-_high_nfft_all_channels_sml/'
!mv "{dl_link}/hc" {eeg_image_directory}
!mv "{dl_link}/sz" {eeg_image_directory}



In [40]:
# Use EEG of Sz for validation and testing
# Extract files from an eeg_sz spectrogram directory where files are saved by subject

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


rand_seed = 1
# List of raw patient file IDs that should be skipped based on categorization as outliers
ignore_list = ['h09', 'h10', 's10', 's11', 's12']
hc_subject_ids = ['hc' + str(i) for i in range(14) if "h{:02}".format(i) not in ignore_list]
sz_subject_ids = ['sz' + str(i) for i in range(14) if "s{:02}".format(i) not in ignore_list]
all_subject_ids = np.concatenate([hc_subject_ids, sz_subject_ids], axis=0)
validate_hc, test_hc = train_test_split(hc_subject_ids, test_size=0.5, random_state=rand_seed)
validate_sz, test_sz = train_test_split(sz_subject_ids, test_size=0.5, random_state=rand_seed)

validation_ids = np.concatenate([validate_hc, validate_sz])
test_ids = np.concatenate([test_hc, test_sz])


print('\nSubjects assigned to groups using sklearn.model_selection.train_test_split')
print('Test group: ', ", ".join(test_ids), "\n")
print('Validation group: ', ", ".join(validation_ids), "\n")



from shutil import copyfile
import pandas as pd
import os


test_images_output_directory = 'all_test_images'
validation_images_output_directory = 'all_validation_images'

if not os.path.exists(test_images_output_directory):
    os.mkdir(test_images_output_directory)
if not os.path.exists(validation_images_output_directory):
    os.mkdir(validation_images_output_directory)



# Note: CSV is only used for MAML and Prototypical networks
def gen_csv_and_copy_sz_files(image_dir, img_output_dir, participant_ids, output_name, split_with_csv=False):
    subdir_data = []
    for group in ['hc', 'sz']: #['Healthy_Control', 'Sz_Patient']:
        for pid in os.listdir(image_dir + '/' + group): # by participant IDs
            if pid in participant_ids:
              for file in os.listdir(image_dir + '/' + group + '/' + pid):
                file_data = {'filename': file, 'label': group}
                subdir_data.append(file_data)
                destination = img_output_dir + '/' + file if split_with_csv else  '{}/{}/{}'.format(img_output_dir, group, file)
                if not os.path.exists('{}/{}'.format(img_output_dir, group)):
                  os.makedirs('{}/{}'.format(img_output_dir, group))
                copyfile(image_dir + '/' + group + '/' + pid + '/' + file,  destination )
    if split_with_csv:
      pd.DataFrame(subdir_data).to_csv(img_output_dir + '/' + output_name)
    return pd.DataFrame(subdir_data)


df = gen_csv_and_copy_sz_files(image_dir=eeg_image_directory,
                                img_output_dir=test_images_output_directory,
                                participant_ids=test_ids,
                               split_with_csv=True,
                                output_name= 'test.csv')
df = gen_csv_and_copy_sz_files(image_dir=eeg_image_directory,
                                img_output_dir=validation_images_output_directory,
                                participant_ids=validation_ids,
                               split_with_csv=True,
                                output_name= 'test.csv') #file must be name test.csv
print(df.head())




Subjects assigned to groups using sklearn.model_selection.train_test_split
Test group:  hc2, hc3, hc4, hc12, hc1, hc6, sz2, sz3, sz4, sz9, sz1, sz6 

Validation group:  hc0, hc7, hc13, hc11, hc8, hc5, sz0, sz7, sz13, sz8, sz5 

      filename label
0  hc11_48.png    hc
1  hc11_13.png    hc
2  hc11_12.png    hc
3  hc11_22.png    hc
4  hc11_73.png    hc


In [41]:
# Get the project files from github
!git clone https://github.com/MTynes/MAML-Pytorch.git maml_pytorch


Cloning into 'maml_pytorch'...
remote: Enumerating objects: 226, done.[K
remote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (14/14), done.[K
remote: Total 226 (delta 8), reused 7 (delta 3), pack-reused 209 (from 1)[K
Receiving objects: 100% (226/226), 672.94 KiB | 2.35 MiB/s, done.
Resolving deltas: 100% (129/129), done.


In [42]:
!python /content/maml_pytorch/train_custom_dataset.py --help


  if name is 'conv2d':
  elif name is 'convt2d':
  elif name is 'linear':
  elif name is 'bn':
  if name is 'conv2d':
  elif name is 'convt2d':
  elif name is 'linear':
  elif name is 'leakyrelu':
  elif name is 'avg_pool2d':
  elif name is 'max_pool2d':
  if name is 'conv2d':
  elif name is 'convt2d':
  elif name is 'linear':
  elif name is 'bn':
  elif name is 'flatten':
  elif name is 'reshape':
  elif name is 'relu':
  elif name is 'leakyrelu':
  elif name is 'tanh':
  elif name is 'sigmoid':
  elif name is 'upsample':
  elif name is 'max_pool2d':
  elif name is 'avg_pool2d':
usage: train_custom_dataset.py [-h] [--train_dir TRAIN_DIR]
                               [--further_training_dir FURTHER_TRAINING_DIR]
                               [--validation_dir VALIDATION_DIR] [--test_dir TEST_DIR]
                               [--run_further_training RUN_FURTHER_TRAINING] [--epochs EPOCHS]
                               [--further_training_epochs FURTHER_TRAINING_EPOCHS] [--n_way N_

In [44]:
# run the training file

import timeit

start = timeit.default_timer()

n_epochs = 400 * 10000 # must be a multiple of 10000

train_dir = '/content/miniimagenet/images'
!python /content/maml_pytorch/train_custom_dataset.py  --epochs {n_epochs} --run_further_training 'false'



stop = timeit.default_timer()
print('MAML execution time: {} hrs'.format((stop - start)/60/60) )






Namespace(train_dir='/content/miniimagenet/images', further_training_dir='/content/all_further_training_images', validation_dir='/content/all_validation_images', test_dir='/content/all_test_images', run_further_training=False, epochs=4000000, further_training_epochs=2000000, n_way=2, k_spt=1, k_qry=5, imgsz=84, imgc=3, task_num=4, meta_lr=0.001, update_lr=0.01, update_step=5, update_step_test=10, accuracy_log_file='/content/mean_test_accuracy.txt')
Meta(
  (net): Learner(
    conv2d:(ch_in:3, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, pad

In [None]:
import matplotlib.pyplot as plt
plt.figure(figsize=(20,10))


print('Mean Validation Accuracy over Epochs')

text_file = open('/content/mean_test_accuracy.txt', "r")
mean_accs = text_file.read().split('\n')
mean_accs = [(float(ma) * 100) for ma in mean_accs]

axes = plt.gca()
axes.set_ylim([np.int(min(mean_accs)-2), np.int(max(mean_accs)) +2])

plt.plot(mean_accs)
plt.xlabel('Epochs')
plt.ylabel('Mean Accuracy')
plt.title('Mean Validation Accuracy over Epochs')
plt.show();

In [None]:
metrics = pd.read_csv('mean_metrics.csv')
metrics.head()


metrics[['train_loss', 'val_loss']].plot(figsize=(10,5), title='Train and Validation Loss over Epochs')

In [None]:
metrics[['train_accuracy', 'val_accuracy']].plot(figsize=(10,5), title='Train and Validation Accuracy over Epochs')

In [None]:
import matplotlib.pyplot as plt
import itertools

# modified from main.py https://github.com/zhangrong1722/CheXNet-Pytorch

def plt_roc(test_y, probas_y, plot_micro=False, plot_macro=False):
    assert isinstance(test_y, list) and isinstance(probas_y, list), 'the type of input must be list'
    skplt.metrics.plot_roc(test_y, probas_y, plot_micro=plot_micro, plot_macro=plot_macro)
    plt.savefig('roc_auc_curve.png')
    plt.show()
    plt.close()


###########################################
# Define confusion matrix and ROC visualization functions
# from https://colab.research.google.com/drive/1ISfhxFDntfOos7cOeT7swduSqzLEqyFn#scrollTo=UiKRYOWPfhJs

def plot_confusion_matrix(cm, classes=None,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,
                          cv=10):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("\nNormalized confusion matrix")
    else:
        print('\nConfusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    if classes:
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 1.5
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.locator_params(nbins=2)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

In [None]:
pred_df = pd.read_csv('test_predictions_and_labels.csv')

In [None]:
from IPython.display import display

# Get the predicted labels for metric calculations.
#
# Since prototypical networks always sets the ground truth to 0,
# infer the predicted class label from true label and the Boolean value for 'correct'
# To do this, set hc to be -1 and sz to 1. This allows the opposite class to be selected
# for rows where correct is False by multiplying the true_label by -1
pred_df['true_label'] = pred_df.apply(lambda x: -1 if x['true_label'] == 0 else 1, axis=1)
pred_df['correct'] = pred_df.apply(lambda x: 1 if x['correct'] == True else 0, axis=1)
pred_df['prediction'] = pred_df.apply(lambda x: x['true_label']
  if x['correct'] == 1 else x['true_label'] * -1, axis=1)
pred_df.replace(-1, 0, inplace=True)
# display(pred_df.head())
# pred_df.tail()


In [None]:
from sklearn.metrics import confusion_matrix


pred_y = pred_df['prediction'].values
truth_y = pred_df['true_label'].values
#probas_y = [s.replace('[', '').replace(']', '').split(', ') for s in best_model_preds['probas_y'].values]
#probas_y = [[float(t[0]), float(t[1])] for t in probas_y]


confusion = confusion_matrix(pred_y, truth_y)
plot_confusion_matrix(confusion,
                      classes=['hc', 'sz'],
                      title='Confusion Matrix')

In [None]:
pd.read_csv('metrics_summary.csv')
