## Initial setup

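**Before running this notebook, create the following two folders in the working directory: `TEM-virus-best_classify_models` (best classifier weights are saved there) and `TEM-virus-classify_scores` (classification results are saved there).**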

In [None]:
# Mount Google Drive at /content/drive (Google Colab only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Report the versions of the main libraries available in this runtime.
import tensorflow as tf
import torch
import matplotlib

for _library in (tf, torch, matplotlib):
    print(_library.__version__)

In [None]:
# Query the NVIDIA driver to see which GPU (if any) is attached to this runtime.
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [None]:
# Other imports
! pip install tensorflow_addons
! pip install tensorflow_io

import os
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from keras.callbacks import Callback, EarlyStopping, ModelCheckpoint
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import load_img

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator
from imutils import paths
from tqdm import tqdm
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_io as tfio
import tensorflow_hub as hub
import numpy as np
import cv2
import pandas as pd
import seaborn as sns
from scipy.stats import mannwhitneyu
from sklearn.preprocessing import LabelEncoder
from sklearn.cluster import KMeans
import sklearn.manifold
from sklearn.metrics.pairwise import cosine_similarity as cos
from sympy.utilities.iterables import multiset_permutations
from sklearn.metrics import accuracy_score, f1_score,precision_score, recall_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import *
from sklearn.preprocessing import StandardScaler
from IPython.display import Image, display

import zipfile
import concurrent.futures

# Fix the global NumPy and TensorFlow seeds so stochastic steps are repeatable.
random_seed = 42
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow_addons
  Downloading tensorflow_addons-0.17.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 4.8 MB/s 
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.17.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow_io
  Downloading tensorflow_io-0.26.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (25.9 MB)
[K     |████████████████████████████████| 25.9 MB 123.4 MB/s 
Installing collected packages: tensorflow-io
Successfully installed tensorflow-io-0.26.0


## Dataset gathering and preparation

In [None]:
# Mini-batch size used by the tf.data pipelines and by the classifier fits.
training_batch_size = BATCH_SIZE = 4

# Side length (pixels) every image is resized to before entering a network.
imageSize = 224

# Class names and plot styling.
# NOTE(review): four category names are listed while the classifier head in this
# notebook has 9 outputs -- confirm these names belong to this dataset.
category_names = 'bundle dispersed network singular'.split()
color_method = ['C%d' % k for k in range(5)]
color = 'black magenta cyan yellow'.split()
marker = list('os<>^')
seaborn_palette = sns.color_palette("colorblind")

In [None]:
np.random.seed(random_seed)

# Root folder of the TEM virus dataset; it holds the augmented_train / validation / test sub-folders.
peptide_morph_train_path = "/content/drive/MyDrive/TEM image datasets/2021-TEM virus/context_virus_1nm_256x256"

# The listings are sorted so the file order is the same in every session.
# Everything downstream is positional: feature rows follow the dataset order, and
# groupby(...).sample(random_state=...) picks rows by position, so an unstable
# directory-walk order would silently change which images a given seed selects.
train_images_directory = np.array(sorted(paths.list_files(basePath=peptide_morph_train_path + "/augmented_train", validExts='jpg')))
validation_images_directory = np.array(sorted(paths.list_files(basePath=peptide_morph_train_path + "/validation", validExts='jpg')))
test_images_directory = np.array(sorted(paths.list_files(basePath=peptide_morph_train_path + "/test", validExts='jpg')))


print("number of training images: %i" % len(train_images_directory))
print("number of validation images: %i" % len(validation_images_directory))
print("number of test images: %i" % len(test_images_directory))

number of training images: 6553
number of validation images: 1368
number of test images: 1091


In [None]:
def _class_folder_names(image_paths):
  """Return, for every path, the name of the folder that directly contains the file."""
  return [image_path.split("/")[-2] for image_path in image_paths]


# The class of an image is the name of its parent folder.
train_labels = _class_folder_names(train_images_directory)
validation_labels = _class_folder_names(validation_images_directory)
test_labels = _class_folder_names(test_images_directory)

# Fit the encoder on the training classes only, then reuse the same mapping for the other splits.
le = LabelEncoder()
train_images_label = le.fit_transform(train_labels)
validation_images_label = le.transform(validation_labels)
test_images_label = le.transform(test_labels)

In [None]:
def _as_label_frame(image_paths, encoded_labels):
  """Pair file paths with their encoded labels in a two-column DataFrame.

  Both columns go through one NumPy array, so the labels are stored with the same
  (string) dtype as the paths; consumers cast them back with .astype(int).
  """
  return pd.DataFrame(np.column_stack((image_paths, encoded_labels)), columns=['filename', 'label'])


# One (filename, label) table per split.
train_label_storage = _as_label_frame(train_images_directory, train_images_label)
validation_label_storage = _as_label_frame(validation_images_directory, validation_images_label)
test_label_storage = _as_label_frame(test_images_directory, test_images_label)

In [None]:
# Image preprocessing utils
@tf.function
def parse_images(image_path):
    """Load one JPEG file and return it as a float32 image of size imageSize x imageSize.

    Args:
        image_path: path of the JPEG file to read.

    Returns:
        The image decoded with 3 channels, converted to float32 with
        tf.image.convert_image_dtype and resized to [imageSize, imageSize].
    """
    raw_bytes = tf.io.read_file(image_path)
    # For TIFF input, tfio.experimental.image.decode_tiff(raw_bytes)[:, :, :3] would be
    # used instead (that decoder returns a 4th opacity channel that is not needed).
    pixels = tf.image.convert_image_dtype(tf.image.decode_jpeg(raw_bytes, channels=3), tf.float32)
    return tf.image.resize(pixels, size=[imageSize, imageSize])

In [None]:
def _make_image_dataset(image_paths):
  """Build the batched input pipeline for one split.

  Paths are decoded with parse_images, grouped into batches of training_batch_size
  (the last batch may be smaller) and prefetched. Nothing is shuffled, so batches
  follow the order of image_paths.
  """
  path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
  image_ds = path_ds.map(parse_images, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return image_ds.batch(training_batch_size).prefetch(tf.data.experimental.AUTOTUNE)


train_ds = _make_image_dataset(train_images_directory)

validation_ds = _make_image_dataset(validation_images_directory)

test_ds = _make_image_dataset(test_images_directory)

## Initiate self-supervised models

In [None]:
# ImageNet-pretrained ResNet50 without its classification head and without a
# pooling layer on top; its weights start out frozen.
Resnet50_transfer = tf.keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=(imageSize, imageSize, 3),
)

Resnet50_transfer.trainable = False

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
# Resnet as backbone
def get_resnet_self_supervise_model(hidden_1, hidden_2, hidden_3):
    """Build the ResNet50 encoder followed by a three-layer projection head.

    The backbone is the shared module-level Resnet50_transfer instance: it is marked
    trainable here and is called with training=True. Its output is global-average
    pooled and passed through Dense -> ReLU -> BatchNorm (twice) and a final
    Dense -> BatchNorm.

    Args:
        hidden_1: units of the first projection layer.
        hidden_2: units of the second projection layer.
        hidden_3: units of the last projection layer (the model output width).

    Returns:
        A Keras Model mapping (imageSize, imageSize, 3) images to the projection output.
    """
    backbone = Resnet50_transfer
    backbone.trainable = True

    image_input = Input((imageSize, imageSize, 3))
    x = backbone(image_input, training=True)
    x = GlobalAveragePooling2D()(x)

    # The two hidden projection layers share the same Dense -> ReLU -> BatchNorm pattern.
    for width in (hidden_1, hidden_2):
        x = Dense(width)(x)
        x = Activation("relu")(x)
        x = BatchNormalization(epsilon=0.001)(x)

    # The output projection has no non-linearity.
    x = Dense(hidden_3)(x)
    x = BatchNormalization(epsilon=0.001)(x)

    return Model(image_input, x)

In [None]:
# State-of-the-art SimCLR encoder trained on ImageNet (official SimCLRv2 checkpoint)
saved_model_path = 'gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/'
imagenet_simclr_model = tf.saved_model.load(saved_model_path)

# Run a single batch through the encoder to read the width of its 'final_avg_pool' output.
for image_batch in train_ds.take(1):
  simclr_logits = imagenet_simclr_model(image_batch, trainable=False)

simclr_layer_shape = simclr_logits['final_avg_pool'].shape[1]


def _simclr_pooled_features(batched_ds, n_images):
  """Return an (n_images, simclr_layer_shape) array of 'final_avg_pool' outputs.

  Row k holds the output for the k-th image of batched_ds; each batch is written to
  the slice of rows it occupies (batches hold training_batch_size images, the last
  one possibly fewer).
  """
  pooled = np.zeros((n_images, simclr_layer_shape))
  for batch_idx, batch_images in enumerate(batched_ds):
    first_row = batch_idx * training_batch_size
    pooled[first_row: first_row + training_batch_size] = imagenet_simclr_model(batch_images, trainable=False)['final_avg_pool']
  return pooled


# Feature maps of every split from the SimCLR encoder trained on ImageNet
train_feature = _simclr_pooled_features(train_ds, len(train_images_directory))
validation_feature = _simclr_pooled_features(validation_ds, len(validation_images_directory))
test_feature = _simclr_pooled_features(test_ds, len(test_images_directory))



## Initiate downstream classification model

In [None]:
def get_linear_model(features, num_classes=9):
    """Build the linear (softmax) classifier that is trained on frozen encoder features.

    Args:
        features: length of the input feature vector.
        num_classes: number of output classes; defaults to 9, the value that was
            previously hard-coded.

    Returns:
        An uncompiled Keras Sequential model: Input(features) -> Dense(num_classes, softmax).
    """
    return Sequential([
        Input(shape=(features,)),
        Dense(num_classes, activation="softmax"),
    ])

## Performance assessment of label-efficient training for the downstream classification task (main manuscript)

In [None]:
# Seeds of the pretrained encoders (42..46); one weight file is loaded per seed.
random_seed_list = np.arange(42, 47)
# Seeds used to draw the labelled training subsets (42..61, i.e. 20 splits).
random_seed_for_split = np.arange(42, 42 + 20)
# Number of labelled training images sampled per class.
training_image_size = np.array([600, 200, 120, 80, 40, 20, 10])

# Encoders under comparison
models = ['barlow_tem', 'simclr_tem', 'barlow_ImageNet', 'simclr_ImageNet', 'simclr_ImageNet_sota']


# Training setup of the linear classifier: stop after 10 epochs without a better
# validation accuracy and restore the best weights.
earlystop_criterion = EarlyStopping(monitor='val_accuracy', mode='auto', patience=10,
                                    restore_best_weights=True, verbose=0)
adam = tf.keras.optimizers.Adam(learning_rate=0.001)
metrics = ['accuracy']

In [None]:
# Linear evaluation on the TEM virus data for the self-supervised models trained on
# TEM images or ImageNet images (832 images), i.e. models[0..3].
# For every encoder seed, the frozen encoder's features are extracted once; a linear
# classifier is then trained for every (split seed, images-per-class) combination.
for m in range(0, 4):
  # scores[..., :] = accuracy, weighted precision, weighted recall, weighted F1
  linear_scores = np.zeros((len(training_image_size), len(random_seed_list), len(random_seed_for_split), 4))
  fusion_matrix = np.zeros((len(training_image_size), len(random_seed_list), len(random_seed_for_split), 9, 9))
  # Passing the full label set to confusion_matrix keeps it 9x9 even when a class
  # is absent from both the predictions and the test labels.
  class_ids = np.arange(fusion_matrix.shape[-1])

  for i in range(len(random_seed_list)):
    resnet_model = get_resnet_self_supervise_model(128, 64, 1024)
    if models[m] == 'barlow_tem':
      resnet_model.load_weights('barlow_resnet_batch64_project128_64_1024_seed%i.h5' % (random_seed_list[i]))
    if models[m] == 'simclr_tem':
      resnet_model.load_weights('simclr_resnet_batch64_project128_64_1024_nocrop_seed%i.h5' % (random_seed_list[i]))
    if models[m] == 'barlow_ImageNet':
      resnet_model.load_weights('%s_batch64_project128_64_1024_seed%i.h5' % (models[m], random_seed_list[i]))
    if models[m] == 'simclr_ImageNet':
      resnet_model.load_weights('%s_batch64_project128_64_1024_nocrop_seed%i.h5' % (models[m], random_seed_list[i]))

    # Freeze the ResNet50 backbone (layer 1 of the model).
    resnet_model.layers[1].trainable = False

    # layers[-9] of the 11-layer model is the GlobalAveragePooling2D output, i.e. the
    # encoder representation in front of the projection head.
    feature_extraction_model = Model(resnet_model.input, resnet_model.layers[-9].output)

    # Extract the features of every split. They are kept under names local to this
    # cell so that the train_feature / validation_feature / test_feature arrays of
    # the official SimCLR encoder are not overwritten.
    encoder_train_feature = feature_extraction_model.predict(train_ds)
    encoder_validation_feature = feature_extraction_model.predict(validation_ds)
    encoder_test_feature = feature_extraction_model.predict(test_ds)

    for j in range(len(random_seed_for_split)):
      for n in range(len(training_image_size)):
        # Draw training_image_size[n] images per class, then shuffle the rows. The
        # draw is evaluated once; labels and feature rows are both taken from it.
        train_subset = (train_label_storage.groupby('label')
                        .sample(n=training_image_size[n], random_state=random_seed_for_split[j])
                        .sample(frac=1, random_state=42))
        subset_labels = train_subset['label'].to_numpy().astype(int)
        # The DataFrame index is the row position of the image in train_ds.
        train_feature_map = encoder_train_feature[train_subset.index.to_numpy()]

        checkpoint_path = ('TEM-virus-best_classify_models/%s_%iper_class_seed%i_seed%i_linear.h5'
                           % (models[m], training_image_size[n], random_seed_list[i], random_seed_for_split[j]))
        checkpoint_model_linear = ModelCheckpoint(checkpoint_path, monitor='val_accuracy', mode='auto',
                                                  verbose=0, save_best_only=True, save_weights_only=True)

        # train linear classifier model
        linear_model = get_linear_model(encoder_train_feature.shape[1])
        # Every classifier gets its own optimizer, built from the configuration of
        # `adam`, so no optimizer state is carried over from the previous fits.
        linear_model.compile(loss="sparse_categorical_crossentropy", metrics=metrics,
                             optimizer=adam.__class__.from_config(adam.get_config()))
        linear_history = linear_model.fit(train_feature_map, subset_labels,
                                          validation_data=(encoder_validation_feature, validation_images_label),
                                          batch_size=training_batch_size, epochs=300, workers=8, use_multiprocessing=True,
                                          verbose=1, callbacks=[earlystop_criterion, checkpoint_model_linear])
        # log best classification model performance
        linear_model.load_weights(checkpoint_path)
        y_pred_linear = np.argmax(linear_model.predict(encoder_test_feature), axis=-1)

        # scikit-learn metrics take (y_true, y_pred); with the arguments swapped the
        # weighted precision and F1 are computed against the wrong reference.
        linear_scores[n, i, j] = np.array([accuracy_score(test_images_label, y_pred_linear),
                                           precision_score(test_images_label, y_pred_linear, average='weighted'),
                                           recall_score(test_images_label, y_pred_linear, average='weighted'),
                                           f1_score(test_images_label, y_pred_linear, average='weighted')])
        # Stored transposed (rows = predicted class, columns = true class), which is the
        # layout this notebook has always written to the result files.
        fusion_matrix[n, i, j] = confusion_matrix(test_images_label, y_pred_linear, labels=class_ids).T

  # One result file per encoder type.
  np.savez_compressed('TEM-virus-classify_scores/%s_classification_result.npz' %(models[m]), scores=linear_scores, fusion_matrix=fusion_matrix)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 33/300
Epoch 34/300
Epoch 35/300
Epoch 36/300
Epoch 37/300
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11

In [None]:
# Seeds of the pretrained encoders (42..46).
random_seed_list = np.arange(42, 47)
# Seeds used to draw the labelled training subsets (42..141, i.e. 100 splits).
random_seed_for_split = np.arange(42, 42 + 100)
# Number of labelled training images sampled per class.
training_image_size = np.array([600, 200, 120, 80, 40, 20, 10])

# Encoders under comparison
models = ['barlow_tem', 'simclr_tem', 'barlow_ImageNet', 'simclr_ImageNet', 'simclr_ImageNet_sota']


# Training setup of the linear classifier: stop after 10 epochs without a better
# validation accuracy and restore the best weights.
earlystop_criterion = EarlyStopping(monitor='val_accuracy', mode='auto', patience=10,
                                    restore_best_weights=True, verbose=0)
adam = tf.keras.optimizers.Adam(learning_rate=0.001)
metrics = ['accuracy']

In [None]:
# Linear evaluation for the sota self-supervised encoder obtained from the SimCLR
# official implementation (models[-1]).
# NOTE: train_feature / validation_feature / test_feature must hold the
# 'final_avg_pool' outputs of imagenet_simclr_model. Re-run the cell that extracts
# them first if any other cell has reassigned these names in the meantime.

# scores[..., :] = accuracy, weighted precision, weighted recall, weighted F1
linear_scores = np.zeros((len(training_image_size), len(random_seed_for_split), 4))
fusion_matrix = np.zeros((len(training_image_size), len(random_seed_for_split), 9, 9))
# Passing the full label set to confusion_matrix keeps it 9x9 even when a class is
# absent from both the predictions and the test labels.
class_ids = np.arange(fusion_matrix.shape[-1])

for j in range(len(random_seed_for_split)):
  for n in range(len(training_image_size)):
    # Draw training_image_size[n] images per class, then shuffle the rows. The draw
    # is evaluated once; labels and feature rows are both taken from it.
    train_subset = (train_label_storage.groupby('label')
                    .sample(n=training_image_size[n], random_state=random_seed_for_split[j])
                    .sample(frac=1, random_state=42))
    subset_labels = train_subset['label'].to_numpy().astype(int)
    # The DataFrame index is the row position of the image in the feature array.
    train_feature_map = train_feature[train_subset.index.to_numpy()]

    checkpoint_path = ('TEM-virus-best_classify_models/%s_%iper_class_seed%i_linear.h5'
                       % (models[-1], training_image_size[n], random_seed_for_split[j]))
    checkpoint_model_linear = ModelCheckpoint(checkpoint_path, monitor='val_accuracy', mode='auto',
                                              verbose=0, save_best_only=True, save_weights_only=True)

    # train linear classifier model
    linear_model = get_linear_model(train_feature.shape[1])
    # Every classifier gets its own optimizer, built from the configuration of
    # `adam`, so no optimizer state is carried over from the previous fits.
    linear_model.compile(loss="sparse_categorical_crossentropy", metrics=metrics,
                         optimizer=adam.__class__.from_config(adam.get_config()))
    linear_history = linear_model.fit(train_feature_map, subset_labels,
                                      validation_data=(validation_feature, validation_images_label),
                                      batch_size=training_batch_size, epochs=300, workers=8, use_multiprocessing=True,
                                      verbose=1, callbacks=[earlystop_criterion, checkpoint_model_linear])
    # log best classification model performance
    linear_model.load_weights(checkpoint_path)
    y_pred_linear = np.argmax(linear_model.predict(test_feature), axis=-1)

    # scikit-learn metrics take (y_true, y_pred); with the arguments swapped the
    # weighted precision and F1 are computed against the wrong reference.
    linear_scores[n, j] = np.array([accuracy_score(test_images_label, y_pred_linear),
                                    precision_score(test_images_label, y_pred_linear, average='weighted'),
                                    recall_score(test_images_label, y_pred_linear, average='weighted'),
                                    f1_score(test_images_label, y_pred_linear, average='weighted')])
    # The confusion matrix used to be allocated and saved but never filled in.
    # Stored with rows = predicted class, columns = true class, the same layout the
    # other evaluation cell of this notebook writes.
    fusion_matrix[n, j] = confusion_matrix(test_images_label, y_pred_linear, labels=class_ids).T

  # Written after every split seed, so partial results survive an interrupted run;
  # the file written after the last seed contains everything.
  np.savez_compressed('TEM-virus-classify_scores/%s_classification_result.npz' %(models[-1]), scores=linear_scores, fusion_matrix=fusion_matrix)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 31/300
Epoch 32/300
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300
Epoch 26/300
Epoch 27/300
Epoch 28/300
Epoch 29/300
Epoch 30/300
Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
