## Imports

In [2]:
# Mount Google Drive at /content/drive; later cells read the dataset archive
# and the saved model weights from /content/drive/MyDrive/Graduation_Project.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
pip install tensorflow_addons

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.15.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[?25l[K     |▎                               | 10 kB 19.8 MB/s eta 0:00:01[K     |▋                               | 20 kB 25.1 MB/s eta 0:00:01[K     |▉                               | 30 kB 28.5 MB/s eta 0:00:01[K     |█▏                              | 40 kB 20.3 MB/s eta 0:00:01[K     |█▌                              | 51 kB 15.9 MB/s eta 0:00:01[K     |█▊                              | 61 kB 14.6 MB/s eta 0:00:01[K     |██                              | 71 kB 11.2 MB/s eta 0:00:01[K     |██▍                             | 81 kB 12.3 MB/s eta 0:00:01[K     |██▋                             | 92 kB 13.4 MB/s eta 0:00:01[K     |███                             | 102 kB 12.6 MB/s eta 0:00:01[K     |███▎                            | 112 kB 12.6 MB/s eta 0:00:01[K     |███▌                            | 122 kB 12.6 MB/s eta 0:00:01[K

In [5]:
!unzip /content/drive/MyDrive/Graduation_Project/CheXpert-v1.0-small.zip > /dev/null

In [11]:
import tensorflow as tf

from keras.applications import imagenet_utils
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
#import tensorflow_datasets as tfds
import tensorflow_addons as tfa
from myowngen_v2 import DataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping
import pandas as pd
import cv2
#tfds.disable_progress_bar()

## Hyperparameters

In [12]:
# Model-wide hyperparameters read by the model-building functions below.
# Original note: "Values are from table 4." (presumably table 4 of the
# MobileViT paper — TODO confirm).
patch_size = 4 # Pixels per patch (2x2) when feature maps are unfolded for the Transformer blocks.
image_size = 256  # Side length, in pixels, of the square 256x256x3 model input.
expansion_factor = 2  # expansion factor for the MobileNetV2 blocks.

In [13]:
def conv_block(x, filters=16, kernel_size=3, strides=2):
    """Apply one swish-activated, same-padded 2D convolution to ``x``.

    Args:
        x: Input tensor passed to the convolution layer.
        filters: Number of output channels.
        kernel_size: Size of the convolution kernel.
        strides: Stride of the convolution.

    Returns:
        The result of calling a freshly created ``Conv2D`` layer on ``x``.
    """
    return layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding="same",
        activation=tf.nn.swish,
    )(x)

def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    """Inverted residual (MV2) block: expand -> depthwise conv -> project.

    Args:
        x: Input feature map; its last axis is treated as the channel axis.
        expanded_channels: Channel count produced by the 1x1 expansion conv.
        output_channels: Channel count produced by the 1x1 projection conv.
        strides: Stride of the 3x3 depthwise convolution (1 or 2).

    Returns:
        The projected feature map. When ``strides == 1`` and the input already
        has ``output_channels`` channels, the input is added back as a
        residual connection.
    """
    # 1x1 expansion.
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    # 3x3 depthwise conv. For stride 2 the input is zero-padded explicitly
    # (padding amounts come from correct_pad) and the conv uses "valid" padding.
    if strides == 2:
        m = layers.ZeroPadding2D(padding=imagenet_utils.correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = tf.nn.swish(m)

    # 1x1 linear projection (no activation after the batch norm).
    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)

    # The channel count is a static Python value, so compare it directly
    # instead of building a tensor with tf.math.equal (which would also fail
    # if the channel dimension were unknown, i.e. None).
    if strides == 1 and x.shape[-1] == output_channels:
        return layers.Add()([m, x])
    return m

def mlp(x, hidden_units, dropout_rate):
    """Run ``x`` through a stack of swish-activated Dense layers with dropout.

    Args:
        x: Input tensor.
        hidden_units: Iterable of layer widths; one Dense + Dropout pair is
            created per entry, in order.
        dropout_rate: Rate passed to every Dropout layer.

    Returns:
        The output of the last Dropout layer (or ``x`` unchanged when
        ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.swish)(out)
        out = layers.Dropout(dropout_rate)(hidden)
    return out


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    """Stack ``transformer_layers`` pre-norm Transformer encoder layers on ``x``.

    Each layer is: LayerNorm -> multi-head self-attention -> residual add ->
    LayerNorm -> MLP -> residual add.

    Args:
        x: Input tensor; its last axis size sets the MLP widths.
        transformer_layers: Number of encoder layers to apply.
        projection_dim: ``key_dim`` of each attention head.
        num_heads: Number of attention heads.

    Returns:
        The tensor produced by the last encoder layer.
    """
    encoded = x
    for _ in range(transformer_layers):
        channels = encoded.shape[-1]

        # Self-attention sub-layer with a skip connection around it.
        normed = layers.LayerNormalization(epsilon=1e-6)(encoded)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(normed, normed)
        residual = layers.Add()([attended, encoded])

        # Feed-forward sub-layer: widen to 2x channels, then project back.
        mlp_input = layers.LayerNormalization(epsilon=1e-6)(residual)
        mlp_output = mlp(
            mlp_input, hidden_units=[channels * 2, channels], dropout_rate=0.1
        )
        encoded = layers.Add()([mlp_output, residual])

    return encoded


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    """MobileViT block: local convs, Transformer over patches, then fusion.

    Args:
        x: Input feature map of shape (batch, height, width, channels).
        num_blocks: Number of Transformer layers applied to the unfolded patches.
        projection_dim: Channel width used by the local convs, the Transformer
            layers and the final fused output.
        strides: Stride forwarded to every convolution in the block.

    Returns:
        Feature map with ``projection_dim`` channels that fuses the input with
        the Transformer-processed features.
    """
    # Local representation: a 3x3 conv followed by a point-wise conv.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold the feature map into (patch_size, num_patches, dim) so the
    # Transformer layers can process it.
    spatial_dims = tuple(local_features.shape[1:-1])
    num_patches = int((spatial_dims[0] * spatial_dims[1]) / patch_size)
    unfolded = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(unfolded, num_blocks, projection_dim)

    # Fold back into a conv-like feature map with the original spatial size.
    folded = layers.Reshape(spatial_dims + (projection_dim,))(global_features)

    # Point-wise conv back to the input channel count, then concatenate with
    # the block input along the channel axis.
    folded = conv_block(folded, filters=x.shape[-1], kernel_size=1, strides=strides)
    fusion_input = layers.Concatenate(axis=-1)([x, folded])

    # Fuse the local and global features using a convolution layer.
    return conv_block(fusion_input, filters=projection_dim, strides=strides)

In [14]:
def create_mobilevit(num_classes=14):
    """Build the MobileViT-style multi-label classifier used in this notebook.

    The input is a square RGB image whose side length is taken from the
    module-level ``image_size`` hyperparameter (256). The network applies seven
    stride-2 stages in total, and every MobileViT block needs
    height * width of its feature map to be divisible by ``patch_size``.

    Args:
        num_classes: Number of independent sigmoid outputs.

    Returns:
        An uncompiled ``keras.Model`` mapping images to ``num_classes``
        per-label probabilities.
    """
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)

    # First conv stem (stride 2) -> MV2 block.
    x = conv_block(x, filters=8)
    x = inverted_residual_block(
        x, expanded_channels=8 * expansion_factor, output_channels=8
    )

    # Second conv stem (stride 2) -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )

    # Downsampling with MV2 block, followed by two stride-1 MV2 blocks.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )

    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)

    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)

    # Two downsampling MV2 blocks -> third MobileViT block -> 1x1 conv.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=128)
    x = conv_block(x, filters=192, kernel_size=1, strides=1)

    # Fourth MobileViT block -> 1x1 conv.
    x = mobilevit_block(x, num_blocks=3, projection_dim=192)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)

    # Classification head: one sigmoid unit per label (multi-label output).
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="sigmoid")(x)

    return keras.Model(inputs, outputs)


# Build the model with the default 14 outputs and print its layer summary.
mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 rescaling (Rescaling)          (None, 256, 256, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 8)  224         ['rescaling[0][0]']              
                                                                                                  
 conv2d_1 (Conv2D)              (None, 128, 128, 16  128         ['conv2d[0][0]']             

## Dataset preparation

We will be using the
[CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/)
chest X-ray dataset (the `CheXpert-v1.0-small` release, with 14 observation labels
per image) to train and evaluate the model. Unlike other Transformer-based architectures,
MobileViT uses a simple augmentation pipeline primarily because it has the properties
of a CNN.

In [15]:
# File locations used by the rest of the notebook.
sample_path  =  'sample_4.csv'   # training sample CSV (relative to the working directory)
valid_path   = "val_sample_2.csv"   # validation sample CSV used alongside the training sample
data_path    = '/content/'   # first argument handed to every DataGenerator
weights_path = '/content/drive/MyDrive/Graduation_Project/Big_Vit.hdf5'   # saved model weights on Drive

In [16]:
# Load the training sample and preview its first rows.
train = pd.read_csv(sample_path)
#train.drop(columns=['Binary'], inplace=True)
train.head()

Unnamed: 0,Path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices
0,CheXpert-v1.0-small/train/patient00057/study2/...,Female,48,Frontal,AP,1.0,,,,,,,,,0.0,,,,
1,CheXpert-v1.0-small/train/patient00060/study1/...,Female,44,Frontal,PA,1.0,0.0,,,,,,,,,0.0,,,
2,CheXpert-v1.0-small/train/patient00060/study1/...,Female,44,Lateral,,1.0,0.0,,,,,,,,,0.0,,,
3,CheXpert-v1.0-small/train/patient00066/study1/...,Male,61,Frontal,PA,1.0,0.0,,,,0.0,0.0,,,,0.0,,,
4,CheXpert-v1.0-small/train/patient00066/study1/...,Male,61,Lateral,,1.0,0.0,,,,0.0,0.0,,,,0.0,,,


In [17]:
# Label columns start at position 5 (everything after the metadata columns).
label_columns = list(train.columns[5:])

# Unmentioned findings (NaN) are treated as negative.
train.loc[:, label_columns] = train.loc[:, label_columns].fillna(0)

# Uncertain labels (-1) become negative (0) for every finding except
# Edema and Atelectasis ...
to_take = [col for col in label_columns if col not in ('Edema', 'Atelectasis')]
train.loc[:, to_take] = train.loc[:, to_take].replace({-1:0})

# ... where uncertain labels are mapped to positive (1) instead.
train.loc[:, ['Edema', 'Atelectasis']] = train.loc[:, ['Edema', 'Atelectasis']].replace({-1:1})

In [18]:
# Load the validation sample without its 'Sex_y' column, then preview it.
valid = pd.read_csv(valid_path).drop(columns='Sex_y')
valid.head()

Unnamed: 0,Path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices
0,CheXpert-v1.0-small/train/patient04947/study3/...,Male,56,Frontal,PA,0,0,0,0,0,0,0,0,0,0,0,0,0,0
1,CheXpert-v1.0-small/train/patient38193/study3/...,Male,67,Frontal,AP,0,0,0,1,0,1,0,0,0,0,0,0,0,1
2,CheXpert-v1.0-small/train/patient47458/study7/...,Male,57,Frontal,AP,0,0,0,1,0,0,0,0,0,0,0,0,1,1
3,CheXpert-v1.0-small/train/patient38830/study3/...,Male,56,Frontal,AP,0,0,0,0,0,0,0,0,1,0,0,0,0,1
4,CheXpert-v1.0-small/train/patient26417/study1/...,Female,50,Lateral,,0,0,0,0,0,0,0,0,0,0,0,0,0,1


In [19]:
# Both generators share the same batching / image-shape / shuffling settings.
generator_options = dict(batch_size=32, shape=(256, 256, 3), shuffle=True)

train_dataset = DataGenerator(data_path, train, 14, **generator_options)
val_dataset = DataGenerator(data_path, valid, 14, **generator_options)

The authors use a multi-scale data sampler to help the model learn representations at
varied scales. In this example, we omit that part and use a fixed 256x256 input resolution.

## Load and prepare the dataset

## Train a MobileViT (XXS) model

In [20]:
# Training configuration.
learning_rate = 0.001
label_smoothing_factor = 0.1  # only referenced by the commented-out loss below
epochs = 10
batch_size = 32
#auto = tf.data.AUTOTUNE
#resize_bigger = 280
num_classes =14  # number of sigmoid outputs passed to create_mobilevit

# Adam optimizer built from learning_rate; in the cells shown here it is only
# referenced by the commented-out compile() call inside run_experiment.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

#loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)


def run_experiment(epochs=epochs):
    """Create the model and restore its weights from ``weights_path``.

    Training is currently disabled (see the commented-out recipe below), so
    this function only rebuilds the architecture and loads saved weights.

    Args:
        epochs: Number of training epochs; unused while the training code
            stays commented out.

    Returns:
        The ``keras.Model`` with weights loaded from ``weights_path``.
    """
    model = create_mobilevit(num_classes=num_classes)
    model.load_weights(weights_path)

    # --- Training recipe, currently disabled ---------------------------------
    # model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=["binary_accuracy", tf.keras.metrics.AUC(multi_label=True) ])
    #
    # checkpoint = ModelCheckpoint(weights_path, monitor='val_auc', verbose=1, save_best_only=False, mode='auto', save_freq = 'epoch')
    # early = EarlyStopping(monitor="val_auc", mode='auto', patience=5, restore_best_weights=False)
    # callbacks_list = [checkpoint, early]
    #
    # model.fit(train_dataset, validation_data=val_dataset, epochs=epochs, batch_size=batch_size, verbose=1, callbacks=callbacks_list)
    # #model.load_weights(weights_path)
    # _, accuracy = model.evaluate(val_dataset)
    # print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    # --------------------------------------------------------------------------

    return model


# Rebuild the model and restore the saved weights; this replaces the
# freshly initialised model created earlier under the same name.
mobilevit_xxs = run_experiment()

## Results and TFLite conversion

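The cells below evaluate the model on the CheXpert `valid.csv` split and report the mean
ROC AUC over five pathologies (Cardiomegaly, Edema, Consolidation, Atelectasis and
Pleural Effusion). The figure of ~85% top-1 accuracy at 256x256 resolution with about one
million parameters, quoted in the tutorial text this notebook appears to be adapted from,
does not describe this multi-label CheXpert model. This MobileViT model is fully compatible
with TensorFlow Lite (TFLite) and can be converted with the code in the final
(currently commented-out) cell.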

In [21]:
from sklearn.metrics import roc_auc_score

def custom_auc(y_true, y_pred):
    """Compute per-pathology ROC AUC for five selected label columns.

    Args:
        y_true: 2-D array of ground-truth labels, one column per label.
        y_pred: 2-D array of predicted scores with the same column layout.

    Returns:
        A DataFrame indexed by pathology name with a single ``'AUC'`` column.
    """
    # Pathology name -> column position within the label arrays.
    column_of = {
        'Cardiomegaly': 2,
        'Edema': 5,
        'Consolidation': 6,
        'Atelectasis': 8,
        'Pleural Effusion': 10,
    }

    auc_values = [
        roc_auc_score(y_true[:, col], y_pred[:, col]) for col in column_of.values()
    ]

    return pd.DataFrame({'AUC': auc_values}, index=list(column_of))

In [22]:
# Load the valid.csv shipped inside the extracted CheXpert-v1.0-small folder.
val_path   = '/content/CheXpert-v1.0-small/valid.csv'
val = pd.read_csv(val_path)

In [23]:
# Unshuffled generator with batch_size=1 over the valid.csv rows
# (presumably so predictions come back in the same order as the rows of
# `val` — confirm against DataGenerator).
#train_generator = DataGenerator(data_path, train, 14, batch_size=1, shape=(224,224, 3), shuffle=False)
val_generator = DataGenerator(data_path, val, 14, batch_size=1, shape=(256,256, 3), shuffle=False)

In [24]:
# Ground truth: every column of valid.csv from position 5 onward.
y_val_true = val.iloc[:, 5:].to_numpy()

# Model scores for the same images (the generator is unshuffled, so rows are
# expected to line up with `val` — verify against DataGenerator).
y_val_pred = mobilevit_xxs.predict(val_generator)

# Per-pathology AUC table, then the mean across the five pathologies.
results = custom_auc(y_val_true, y_val_pred)
results["AUC"].mean()

0.8272416355207272

In [None]:
# # Serialize the model as a SavedModel.
# mobilevit_xxs.save("mobilevit_xxs")

# # Convert to TFLite. This form of quantization is called
# # post-training dynamic-range quantization in TFLite.
# converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
# converter.optimizations = [tf.lite.Optimize.DEFAULT]
# converter.target_spec.supported_ops = [
#     tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
#     tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
# ]
# tflite_model = converter.convert()
# open("mobilevit_xxs.tflite", "wb").write(tflite_model)

To learn more about different quantization recipes available in TFLite and running
inference with TFLite models, check out
[this official resource](https://www.tensorflow.org/lite/performance/post_training_quantization).