# If you find this code useful, please cite:
Ren Z, Kong X, Zhang Y, et al. UKSSL: Underlying Knowledge based Semi-Supervised Learning for Medical Image Classification[J]. IEEE Open Journal of Engineering in Medicine and Biology, 2023

In [None]:
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import os
import pathlib
import random
import numpy as np
from PIL import Image

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import applications
from tensorflow.keras.models import Model, Sequential
import tensorflow_addons as tfa

from keras import Model

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, ConfusionMatrixDisplay, classification_report

import warnings
warnings.filterwarnings('ignore')

# Hyperparameters

In [None]:
project_name = 'UKSSL'
dataset_dir = '/kaggle/input/leukemia/Original'
# One class per immediate sub-directory of the dataset root.
class_num = len(next(os.walk(dataset_dir))[1])

data_source = pathlib.Path(dataset_dir)
# Class names are sorted so the name <-> index mapping is reproducible.
label_names = sorted(entry.name for entry in data_source.glob('*/') if entry.is_dir())
label_to_index = {name: idx for idx, name in enumerate(label_names)}
index_to_label = {idx: name for idx, name in enumerate(label_names)}

# Every entry two levels below the root: <root>/<class>/<image>.
all_image_paths = list(map(str, data_source.glob('*/*')))
# Shuffle once; the dataset splits are later taken sequentially from this order.
random.shuffle(all_image_paths)
# An image's label is the name of the folder that contains it.
all_image_labels = [label_to_index[pathlib.Path(p).parent.name] for p in all_image_paths]

dataset_size = len(all_image_paths)

print('img indices: ', label_names)
print('img num: ', dataset_size)

In [None]:
# Dataset hyperparameters
# Fractions of the shuffled image list assigned to each split. Only the
# unlabeled and labeled ratios are applied directly; the test split gets
# whatever is left (including rounding remainders from int()).
# NOTE(review): `test_ratio` is not read by the visible code; it documents the
# intended remainder only.
unlabeled_ratio = 0.675
labeled_ratio = 0.225
test_ratio = 0.1

unlabeled_dataset_size = int(dataset_size*unlabeled_ratio)
labeled_dataset_size = int(dataset_size*labeled_ratio)
test_dataset_size = dataset_size-unlabeled_dataset_size-labeled_dataset_size

print('unlabeled size: ',unlabeled_dataset_size,'labeled size: ',labeled_dataset_size,'test size: ',test_dataset_size)
image_size = 96
image_channels = 3

# Algorithm hyperparameters
num_epochs = 200
# Combined (unlabeled + labeled) batch size. Steps per epoch are
# (unlabeled_dataset_size + labeled_dataset_size) // batch_size, so the step
# count depends on the dataset size rather than being a fixed number.
batch_size = 500
# Width of the features fed to the projection head and the linear probe.
width = 128
# Temperature dividing the cosine similarities in the contrastive loss.
temperature = 0.1
# Stronger augmentations for contrastive, weaker ones for supervised training
contrastive_augmentation = {"min_area": 0.25, "brightness": 0.6, "jitter": 0.2}
classification_augmentation = {"min_area": 0.75, "brightness": 0.3, "jitter": 0.1}

In [None]:
# Hyperparameters for the ViT encoder
transformer_layers = 1  # original setting: 8
# Derived from the shared image settings (instead of a hard-coded (96, 96, 3))
# so the encoder input always matches the images produced by the data pipeline
# and the finetuning model's input layer.
input_shape = (image_size, image_size, image_channels)
patch_size = 6  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
num_heads = 4
projection_dim = 64
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Hidden sizes of the MLP inside each Transformer block
# NOTE(review): only referenced by commented-out code in the visible notebook.
mlp_head_units = [2048, 1024, 512, 256, 128]
num_classes = class_num

In [None]:
# Show a 2x2 grid of randomly chosen images, each labelled with its class folder.
from random import seed

plt.figure(figsize=(10, 10))
for n in range(1, 5):
    image_path = random.choice(all_image_paths)
    # Context manager closes the file handle; np.array copies the pixels so
    # they stay valid after the image is closed.
    with Image.open(image_path) as img:
        img_data = np.array(img)
    # Pass a plain string: a list would be rendered literally, e.g. "['Benign']".
    img_label = pathlib.Path(image_path).parent.name
    plt.subplot(2, 2, n)
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.xlabel(img_label)
    plt.imshow(img_data)
plt.show()

# Data Processing

In [None]:
def load_and_process_image(path):
    """Read one image file and return it resized to (image_size, image_size).

    Args:
        path: file path of the image (str or string tensor).

    Returns:
        A float32 tensor of shape (image_size, image_size, image_channels).
        Pixel values are not rescaled; the augmentation pipeline applies
        ``Rescaling(1 / 255)`` later.
    """
    raw = tf.io.read_file(path)
    # decode_image also handles PNG/BMP/GIF (decode_jpeg would fail on them).
    # expand_animations=False guarantees a 3-D tensor, and `channels` forces
    # grayscale or RGBA files to the configured channel count.
    image = tf.image.decode_image(raw, channels=image_channels, expand_animations=False)
    return tf.image.resize(image, size=[image_size, image_size])


# Decode every image eagerly into a single (N, image_size, image_size, C) tensor,
# in the (shuffled) order of all_image_paths.
all_images = tf.stack([load_and_process_image(path) for path in all_image_paths])

In [None]:
# Mapping function for a (path, label) dataset: each tuple element is unpacked
# into a positional argument.
# NOTE(review): not called anywhere in the visible notebook, and it needs a
# `load_and_process_image(path)` function to be defined before it is used.
def load_and_preprocess_from_path_label(path, label):
    return load_and_process_image(path), label


# In-memory dataset of (image, label) pairs, in the shuffled order of
# all_image_paths; the train/test splits are carved out of it with take/skip.
dataset_all = tf.data.Dataset.from_tensor_slices((all_images, all_image_labels))

In [None]:
# Split the (already shuffled) data sequentially into [unlabeled | labeled | test].
# The SSL pool is the unlabeled + labeled prefix.
unlabeled_train_datasets = dataset_all.take(unlabeled_dataset_size)
labeled_train_datasets = (
    dataset_all
    .skip(unlabeled_dataset_size)
    .take(labeled_dataset_size)
)
train_datasets_ssl = dataset_all.take(unlabeled_dataset_size + labeled_dataset_size)
test_datasets = dataset_all.skip(unlabeled_dataset_size + labeled_dataset_size)
print(*[
    token
    for title, split in (
        ('unlabeled_train_datasets: ', unlabeled_train_datasets),
        ('labeled_train_datasets: ', labeled_train_datasets),
        ('ssl_train_datasets: ', train_datasets_ssl),
        ('test_datasets: ', test_datasets),
    )
    for token in (title, tf.data.experimental.cardinality(split).numpy())
])

In [None]:
def prepare_dataset():
    """Shuffle and batch the module-level splits for training and evaluation.

    Reads the globals ``unlabeled_train_datasets``, ``labeled_train_datasets``,
    ``train_datasets_ssl``, ``test_datasets``, their sizes and ``batch_size``.
    Per-step unlabeled/labeled batch sizes are chosen so both splits last for
    about the same number of steps.

    Returns:
        A 4-tuple:
        - ``train_dataset``: zipped ``((unlabeled_images, unlabeled_labels),
          (labeled_images, labels))`` batches.
        - ``labeled_train_dataset``: shuffled, batched labeled split.
        - ``train_datasets_ssl1``: shuffled unlabeled + labeled pool, batched
          with ``batch_size``.
        - ``test_dataset``: batched and NOT shuffled, so its iteration order is
          stable across passes.
    """
    # Labeled and unlabeled samples are loaded synchronously, with batch sizes
    # selected accordingly.
    # max(1, ...) guards against a dataset smaller than one global batch,
    # which would otherwise give steps_per_epoch == 0 (ZeroDivisionError
    # below). It also prevents a zero batch size, which tf.data rejects.
    steps_per_epoch = max(1, (unlabeled_dataset_size + labeled_dataset_size) // batch_size)
    unlabeled_batch_size = max(1, unlabeled_dataset_size // steps_per_epoch)
    labeled_batch_size = max(1, labeled_dataset_size // steps_per_epoch)
    print(
        f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
    )

    unlabeled_train_dataset = (
        unlabeled_train_datasets
        .shuffle(buffer_size=10 * unlabeled_batch_size)
        .batch(unlabeled_batch_size)
    )
    labeled_train_dataset = (
        labeled_train_datasets
        .shuffle(buffer_size=10 * labeled_batch_size)
        .batch(labeled_batch_size)
    )

    # Unlabeled + labeled pool; labels are included in every element.
    train_datasets_ssl1 = (
        train_datasets_ssl
        .shuffle(buffer_size=10 * batch_size)
        .batch(batch_size)
    )
    # Deliberately not shuffled so predictions and labels line up when the
    # dataset is iterated more than once.
    test_dataset = (
        test_datasets
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    # Labeled and unlabeled datasets are zipped together; like Python's zip,
    # iteration stops at the shorter of the two.
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=tf.data.AUTOTUNE)

    return train_dataset, labeled_train_dataset, train_datasets_ssl1, test_dataset


# Build the batched pipelines from the splits created above.
# NOTE(review): this rebinds the global `train_datasets_ssl` (which
# prepare_dataset reads) to its batched version, so re-running this cell would
# batch it a second time.
train_dataset, labeled_train_dataset, train_datasets_ssl, test_dataset = prepare_dataset()

In [None]:
# Number of batches in each pipeline (zip truncates to the shorter input).
print(*[
    token
    for title, pipeline in (
        ('train_dataset: ', train_dataset),
        ('labeled_train_dataset: ', labeled_train_dataset),
        ('test_dataset: ', test_dataset),
    )
    for token in (title, tf.data.experimental.cardinality(pipeline).numpy())
])

# Image augmentations

In [None]:
# Distorts the color distributions of images
class RandomColorAffine(layers.Layer):
    """Random per-image brightness scaling plus 3x3 color-channel mixing.

    Each image gets one brightness factor drawn from
    [1 - brightness, 1 + brightness] (shared by all channels). It also gets a
    3x3 jitter matrix with entries drawn from [-jitter, jitter]. The result is
    clipped to [0, 1]. The transform is only applied when ``training`` is truthy.
    """

    def __init__(self, brightness=0, jitter=0, **kwargs):
        super().__init__(**kwargs)
        self.brightness = brightness
        self.jitter = jitter

    def get_config(self):
        # Base-layer config extended with the constructor arguments.
        return {
            **super().get_config(),
            "brightness": self.brightness,
            "jitter": self.jitter,
        }

    def call(self, images, training=True):
        if not training:
            return images

        n = tf.shape(images)[0]
        # One brightness factor per image, broadcast over height, width and channels.
        scale = tf.random.uniform(
            (n, 1, 1, 1), minval=-self.brightness, maxval=self.brightness
        ) + 1
        # One channel-mixing perturbation per image.
        mixing = tf.random.uniform(
            (n, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
        )
        transform = tf.eye(3, batch_shape=[n, 1]) * scale + mixing
        recolored = tf.matmul(images, transform)
        return tf.clip_by_value(recolored, 0, 1)


# Image augmentation module
def get_augmenter(min_area, brightness, jitter):
    """Build the image augmentation pipeline.

    Args:
        min_area: smallest fraction of the image area kept after zooming;
            the zoom and translation factor is ``1 - sqrt(min_area)``.
        brightness: forwarded to ``RandomColorAffine``.
        jitter: forwarded to ``RandomColorAffine``.

    Returns:
        A ``keras.Sequential`` taking (image_size, image_size, image_channels)
        inputs. It first rescales them by 1/255, then applies flip,
        translation, zoom and color distortion.
    """
    zoom = 1.0 - math.sqrt(min_area)
    shift = zoom / 2
    stages = [
        keras.Input(shape=(image_size, image_size, image_channels)),
        layers.Rescaling(1 / 255),
        layers.RandomFlip("horizontal"),
        layers.RandomTranslation(shift, shift),
        layers.RandomZoom((-zoom, 0.0), (-zoom, 0.0)),
        RandomColorAffine(brightness, jitter),
    ]
    return keras.Sequential(stages)


def visualize_augmentations(num_images):
    """Plot original images next to one weak and two strong augmentations.

    Rows are: original, classification (weak) augmentation, and two
    independent contrastive (strong) augmentations. There is one column per
    image.

    Args:
        num_images: how many images to take from the first unlabeled batch.
    """
    # Sample a batch from a dataset: element [0] of the zipped train_dataset is
    # the (unlabeled_images, labels) pair; take its images.
    images = next(iter(train_dataset))[0][0][:num_images]
    # The raw images are floats in [0, 255] (straight from tf.image.resize),
    # while the augmenters output [0, 1] because they start with
    # Rescaling(1/255). imshow clips float data to [0, 1], so the originals
    # must be rescaled too, otherwise they render almost white.
    originals = tf.clip_by_value(images / 255.0, 0.0, 1.0)
    # Apply augmentations
    augmented_images = zip(
        originals,
        get_augmenter(**classification_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
    )
    row_titles = [
        "Original:",
        "Weakly augmented:",
        "Strongly augmented:",
        "Strongly augmented:",
    ]
    plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
    for column, image_row in enumerate(augmented_images):
        for row, image in enumerate(image_row):
            plt.subplot(4, num_images, row * num_images + column + 1)
            plt.imshow(image)
            if column == 0:
                plt.title(row_titles[row], loc="left")
            plt.axis("off")
    plt.tight_layout()


visualize_augmentations(num_images=8)

# Encoder architecture


In [None]:
# def LeNet5():
#     model = Sequential()
#     model.add(layers.Input(shape=(96, 96, 3)))
#     model.add(layers.Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='relu', padding="same"))
#     model.add(layers.AveragePooling2D(pool_size=(2, 2), strides=(1, 1), padding='valid'))
#     model.add(layers.Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='relu', padding='valid'))
#     model.add(layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
#     model.add(layers.Conv2D(120, kernel_size=(5, 5), strides=(1, 1), activation='relu', padding='valid'))
#     model.add(layers.Flatten())
#     model.add(layers.Dense(128,activation='relu'))
  
#     return model

# encoder = LeNet5()

# def get_encoder():
#     return encoder

In [None]:
class Patches(layers.Layer):
    """Split a batch of images into non-overlapping square patches.

    Output shape is (batch, num_patches, patch_size * patch_size * channels).
    Strides equal the patch size and padding is VALID, so patches do not
    overlap and partial border patches are dropped.
    """

    def __init__(self, patch_size, **kwargs):
        # **kwargs forwards base-layer arguments (name, dtype, trainable, ...),
        # which Keras passes when re-creating the layer from its config.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def get_config(self):
        # Required for model saving/cloning because __init__ takes an argument
        # that the base config does not know about.
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the spatial grid of patches into a single sequence axis.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: input Keras tensor.
        hidden_units: iterable of layer widths, applied in order.
        dropout_rate: dropout rate used after every Dense layer.

    Returns:
        The output Keras tensor (``x`` unchanged if ``hidden_units`` is empty).
    """
    out = x
    for n_units in hidden_units:
        dense = layers.Dense(n_units, activation=tf.nn.gelu)
        out = layers.Dropout(dropout_rate)(dense(out))
    return out

In [None]:
class PatchEncoder(layers.Layer):
    """Linearly project patches and add a learned position embedding.

    Input: (batch, num_patches, patch_dim); output: (batch, num_patches,
    projection_dim). The position embedding has one row per patch index.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # **kwargs forwards base-layer arguments (name, dtype, trainable, ...).
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def get_config(self):
        # Required for model saving/cloning because __init__ takes arguments
        # that the base config does not know about.
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # (num_patches, projection_dim) embedding broadcasts over the batch axis.
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

In [None]:
def create_vit_classifier():
    """Build the ViT encoder used as the backbone for SSL.

    Despite the name, no classification head is attached. The model maps an
    ``input_shape`` image to a ``width``-dimensional GELU feature vector, which
    is what the projection head and linear probe (input shape ``(width,)``)
    consume. Augmentation is applied outside the encoder.
    """
    inputs = layers.Input(shape=input_shape)
    # Create patches and encode them with position information.
    patches = Patches(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Pre-norm Transformer blocks: LN -> MHSA -> residual, LN -> MLP -> residual.
    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])

    # (batch, num_patches, projection_dim) -> flatten -> (batch, width).
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    # `width` instead of a literal 128 keeps the encoder output in sync with the
    # (width,) input expected by the projection head and linear probe.
    representation = layers.Dense(width, activation='gelu')(representation)
    return keras.Model(inputs=inputs, outputs=representation)

In [None]:
def get_encoder():
    """Return the shared module-level encoder (the same object on every call)."""
    return encoder


# Single encoder instance; the contrastive model fetches it via get_encoder().
encoder = create_vit_classifier()

In [None]:
# Print the encoder's layers, output shapes and parameter counts.
encoder.summary()

# Self-supervised model for contrastive pretraining


In [None]:
# Define the contrastive model with model-subclassing
class ContrastiveModel(keras.Model):
    """Contrastive pretraining model with an online linear probe.

    Components:
      * ``encoder``: shared backbone returned by ``get_encoder()``.
      * ``projection_head``: MLP whose outputs feed only the contrastive loss.
      * ``linear_probe``: one Dense layer trained on encoder features (encoder
        run in inference mode) to track representation quality.

    ``train_step`` expects ``((unlabeled_images, _), (labeled_images, labels))``
    batches. ``test_step`` expects ``(labeled_images, labels)`` batches.
    """

    def __init__(self):
        super().__init__()

        self.temperature = temperature
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        self.classification_augmenter = get_augmenter(**classification_augmentation)
        self.encoder = get_encoder()
        # Non-linear MLP as projection head
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        # Single dense layer for linear probing
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(class_num)], name="linear_probe"
        )

        self.encoder.summary()
        self.projection_head.summary()
        self.linear_probe.summary()

    def call(self, inputs, training=None):
        """Inference path: class logits for raw images.

        Mirrors ``test_step``: inputs go through the classification augmenter
        in inference mode (which rescales them by 1/255), then the encoder,
        then directly into the linear probe. The projection head is not used
        here; the probe is trained on encoder features, not on projections.
        ``training`` is accepted for Keras compatibility; this path always runs
        its sub-layers in inference mode.
        """
        preprocessed = self.classification_augmenter(inputs, training=False)
        features = self.encoder(preprocessed, training=False)
        return self.linear_probe(features, training=False)

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        """Store the two optimizers and create the loss and metric trackers."""
        super().compile(**kwargs)

        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

        # self.contrastive_loss is defined as a method
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )
        self.probe_loss_tracker = keras.metrics.Mean(name="p_loss")
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")

    @property
    def metrics(self):
        # Listed here so Keras resets them each epoch; test_step reports [2:].
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy,
            self.probe_loss_tracker,
            self.probe_accuracy,
        ]

    def contrastive_loss(self, projections_1, projections_2):
        """Symmetric NT-Xent (InfoNCE) loss between two batches of projections.

        Row i of ``projections_1`` and row i of ``projections_2`` are the
        positive pair; all other rows act as negatives. Returns a per-example
        loss vector and updates ``contrastive_accuracy`` as a side effect.
        """
        # Cosine similarity: the dot product of the l2-normalized feature vectors
        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
        similarities = (
            tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
        )

        # The similarity between the representations of two augmented views of the
        # same image should be higher than their similarity with other views
        batch_size = tf.shape(projections_1)[0]
        contrastive_labels = tf.range(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, tf.transpose(similarities)
        )

        # The temperature-scaled similarities are used as logits for cross-entropy;
        # a symmetrized version of the loss is used here
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, tf.transpose(similarities), from_logits=True
        )
        return (loss_1_2 + loss_2_1) / 2

    def train_step(self, data):
        """One contrastive update of encoder + head, then one probe update."""
        (unlabeled_images, _), (labeled_images, labels) = data

        # Both labeled and unlabeled images are used, without labels
        images = tf.concat((unlabeled_images, labeled_images), axis=0)
        # Each image is augmented twice, differently
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)
        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            # The representations are passed through a projection mlp
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        # contrastive_loss is a per-example vector; GradientTape.gradient
        # differentiates its sum.
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.contrastive_loss_tracker.update_state(contrastive_loss)

        # Labels are only used in evaluation for an on-the-fly logistic regression
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=True
        )
        with tf.GradientTape() as tape:
            # The encoder is used in inference mode here to avoid regularization
            # and updating the batch normalization parameters if they are used
            features = self.encoder(preprocessed_images, training=False)
            class_logits = self.linear_probe(features, training=True)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate only the linear probe on ``(labeled_images, labels)``."""
        labeled_images, labels = data

        # For testing the components are used with a training=False flag
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)
        self.probe_loss_tracker.update_state(probe_loss)
        self.probe_accuracy.update_state(labels, class_logits)

        # Only the probe metrics are logged at test time
        return {m.name: m.result() for m in self.metrics[2:]}


# Contrastive pretraining: separate Adam optimizers for encoder+head and probe.
pretraining_model = ContrastiveModel()
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
)

pretraining_history = pretraining_model.fit(
    train_dataset,
    epochs=num_epochs,
    validation_data=test_dataset,
)
# Best linear-probe accuracy on the test split across all epochs.
print(f"Maximal validation accuracy: {max(pretraining_history.history['val_p_acc']) * 100:.2f}%")

# Supervised finetuning of the pretrained encoder


In [None]:
# v4
def ft_model(
    pretraining_model,
    head_units=(256, 256, 256, 512, 512, 1024, 1024, 512, 512, 256, 256, 256),
):
    """Build the supervised finetuning model on top of the pretrained encoder.

    The encoder object is shared with ``pretraining_model`` rather than copied,
    so training this model also updates ``pretraining_model.encoder``.

    Args:
        pretraining_model: model exposing the pretrained ``encoder``.
        head_units: widths of the ReLU Dense layers stacked on the encoder
            output, applied in order. The default reproduces the "v4" head.

    Returns:
        A ``keras.Model`` mapping raw images to ``class_num`` logits.
    """
    inputs = layers.Input(shape=(image_size, image_size, image_channels))
    # Weak augmentation (and the 1/255 rescaling it starts with).
    x = get_augmenter(**classification_augmentation)(inputs)
    x = pretraining_model.encoder(x)
    for units in head_units:
        x = layers.Dense(units, activation='relu')(x)
    # Raw logits (no softmax).
    outputs = layers.Dense(class_num)(x)
    # Use tensorflow.keras' Model, like the rest of the notebook, instead of the
    # standalone `keras` Model imported at the top; mixing the two packages can
    # break when they resolve to different Keras installations.
    return keras.Model(
        inputs=inputs,
        outputs=outputs,
        name="finetuning_{}".format(pretraining_model.encoder.name),
    )

In [None]:
# Attach a fresh classification head to the pretrained encoder and show it.
finetuning_model = ft_model(pretraining_model=pretraining_model)
finetuning_model.summary()

In [None]:
# Finetune the whole model (shared encoder + new head) on the labeled split.
# To use the combined pool instead, pass `train_datasets_ssl`; note that its
# elements carry the labels of the "unlabeled" split as well.
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)
finetuning_history = finetuning_model.fit(
    labeled_train_dataset,
    epochs=num_epochs,
    validation_data=test_dataset,
)
print(f"Maximal validation accuracy: {max(finetuning_history.history['val_acc']) * 100:.2f}%")

# Print the results

In [None]:
# Collect the ground-truth labels of the test split in iteration order.
# test_dataset is not shuffled, so this order matches the row order of
# finetuning_model.predict(test_dataset).
# Only the label tensors are kept: unpacking with zip(*test_dataset) would
# also hold every image batch in memory just to discard it.
true_labels = np.concatenate([np.asarray(labels) for _, labels in test_dataset])

## Supervised Fine-tuned model

In [None]:
# Logits for every test image; the predicted class is the arg-max per row.
preds_f = finetuning_model.predict(test_dataset)
pred_results_f = preds_f.argmax(axis=1)
# Macro-averaged precision / recall / F1 (support is discarded).
pre_f, rec_f, f1s_f, _ = precision_recall_fscore_support(
    y_true=true_labels, y_pred=pred_results_f, average='macro'
)
print('Fine-tuned model performance: ')
print(classification_report(y_true=true_labels, y_pred=pred_results_f, digits=3))

50%
Fine-tuned model performance: 
              precision    recall  f1-score   support

           0      1.000     0.731     0.844        52
           1      0.879     1.000     0.936       109
           2      1.000     0.989     0.994        87
           3      1.000     1.000     1.000        78

    accuracy                          0.954       326
   macro avg      0.970     0.930     0.944       326
weighted avg      0.960     0.954     0.952       326

25%

Fine-tuned model performance: 
              precision    recall  f1-score   support

           0      0.604     0.558     0.580        52
           1      0.756     0.826     0.789       109
           2      0.976     0.943     0.960        88
           3      1.000     0.962     0.980        78

    accuracy                          0.847       327
   macro avg      0.834     0.822     0.827       327
weighted avg      0.849     0.847     0.847       327

10%

Fine-tuned model performance: 
              precision    recall  f1-score   support

           0      0.600     0.577     0.588        52
           1      0.721     0.807     0.762       109
           2      1.000     0.851     0.919        87
           3      0.975     1.000     0.987        78

    accuracy                          0.828       326
   macro avg      0.824     0.809     0.814       326
weighted avg      0.837     0.828     0.830       326

1%

Fine-tuned model performance: 
              precision    recall  f1-score   support

           0      0.000     0.000     0.000        52
           1      0.671     0.881     0.762       109
           2      0.743     0.862     0.798        87
           3      0.951     1.000     0.975        78

    accuracy                          0.764       326
   macro avg      0.591     0.686     0.634       326
weighted avg      0.650     0.764     0.701       326

# Comparison of self-supervised pretraining and supervised finetuning


In [None]:
# Validation accuracy and loss of the linear probe (pretraining) vs. the finetuned model:
def plot_training_curves(pretraining_history, finetuning_history):
    """Plot per-epoch validation accuracy and loss for both training phases.

    ``pretraining_history.history`` must contain ``val_p_acc``/``val_p_loss``.
    ``finetuning_history.history`` must contain ``val_acc``/``val_loss``.
    One figure is created per metric.
    """
    for key, label in (("acc", "accuracy"), ("loss", "loss")):
        plt.figure(figsize=(8, 5), dpi=100)
        plt.plot(
            pretraining_history.history["val_p_" + key],
            label="self-supervised pretraining",
        )
        plt.plot(
            finetuning_history.history["val_" + key],
            label="supervised finetuning",
        )
        plt.legend()
        plt.title(f"Classification {label} during training")
        plt.xlabel("epochs")
        plt.ylabel(f"validation {label}")


plot_training_curves(pretraining_history, finetuning_history)