In [None]:
from google.colab import drive
# Mount Google Drive so the archive at /content/drive/MyDrive/hist.zip is reachable.
drive.mount(mountpoint="/content/drive")

In [None]:
!pip install keras-cv

In [None]:
import contrastive_trainer

In [None]:
import keras
import keras_cv
import tensorflow as tf
import tensorflow_datasets as tfds
from matplotlib import pyplot as plt

# Turn off tensorflow_datasets progress bars to keep cell output short.
# NOTE(review): tfds is not otherwise used in the visible notebook.
tfds.disable_progress_bar()

Using TensorFlow backend


In [None]:
# Extract the dataset archive from Google Drive into /content/cov; later cells read
# images from "/content/cov/BreaKHis 400X/...".
!unzip -u "/content/drive/MyDrive/hist.zip" -d "/content/cov"

In [None]:
!pip install split-folders
import splitfolders
import os
input_folder="/content/cov/BreaKHis 400X/test"
output="/content/output"
splitfolders.ratio(input_folder, output, seed=42, ratio=(.75,0,.25)) ### train 75%, val 10%, test 15%

Collecting split-folders
  Downloading split_folders-0.5.1-py3-none-any.whl (8.4 kB)
Installing collected packages: split-folders
Successfully installed split-folders-0.5.1


Copying files: 545 files [00:00, 1340.78 files/s]


In [None]:
# Input geometry: every image is resized to IMAGE_SIZE x IMAGE_SIZE with 3 channels.
IMAGE_SIZE = 32
IMAGE_CHANNELS = 3
NUM_CLASSES = 2

# Per-pipeline batch sizes and SimCLR hyper-parameters.
UNLABELED_BATCH_SIZE = 1024
LABELED_BATCH_SIZE = 128
TEST_BATCH_SIZE = 128
PROJECTION_WIDTH = 128
TEMPERATURE = 0.1

# Strong augmentation used during contrastive pre-training (keyword args for get_augmenter).
CONTRASTIVE_AUGMENTATION = dict(
    crop_area_factor=(0.08, 1.0),
    aspect_ratio_factor=(3 / 4, 4 / 3),
    color_jitter_rate=0.8,
    brightness_factor=0.2,
    contrast_factor=0.8,
    saturation_factor=(0.3, 0.7),
    hue_factor=0.2,
)

# Mild augmentation used for supervised training and fine-tuning.
CLASSIFICATION_AUGMENTATION = dict(
    crop_area_factor=(0.8, 1.0),
    aspect_ratio_factor=(3 / 4, 4 / 3),
    color_jitter_rate=0.05,
    brightness_factor=0.1,
    contrast_factor=0.1,
    saturation_factor=(0.1, 0.1),
    hue_factor=0.2,
)

# Sentinel passed to .prefetch() on the tf.data pipelines so tf.data tunes the buffer size.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
def prepare_dataset(
    unlabeled_dir="/content/cov/BreaKHis 400X/train/malignant",
    labeled_dir="/content/output/train",
    test_dir="/content/output/test",
    label_mode="int",
):
    """Build the unlabeled, labeled-train and test tf.data pipelines.

    Args:
        unlabeled_dir: Image folder used for contrastive pre-training (labels are dropped).
        labeled_dir: Class-per-subfolder training split written by splitfolders.
        test_dir: Class-per-subfolder test split written by splitfolders.
        label_mode: Label encoding for the labeled and test datasets. Defaults to "int"
            because the classifiers in this notebook use SparseCategoricalCrossentropy /
            SparseCategoricalAccuracy; pass "categorical" for one-hot labels.

    Returns:
        Tuple (unlabeled_train_dataset, labeled_train_dataset, test_dataset), each prefetched.
    """
    # NOTE(review): the default unlabeled pool is only the "malignant" folder, so benign
    # images are excluded from pre-training — confirm this is intended.
    unlabeled_train_dataset = tf.keras.utils.image_dataset_from_directory(
        unlabeled_dir,
        label_mode=None,
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=UNLABELED_BATCH_SIZE,
        shuffle=True,
    ).prefetch(AUTOTUNE)

    labeled_train_dataset = tf.keras.utils.image_dataset_from_directory(
        labeled_dir,
        label_mode=label_mode,
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=LABELED_BATCH_SIZE,
        shuffle=True,
    ).prefetch(AUTOTUNE)

    # Fixed order keeps evaluation reproducible and lets predictions be aligned with labels.
    test_dataset = tf.keras.utils.image_dataset_from_directory(
        test_dir,
        label_mode=label_mode,
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=TEST_BATCH_SIZE,
        shuffle=False,
    ).prefetch(AUTOTUNE)

    return unlabeled_train_dataset, labeled_train_dataset, test_dataset


unlabeled_train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()

Found 1148 files belonging to 1 classes.
Found 408 files belonging to 2 classes.
Found 137 files belonging to 2 classes.


In [None]:
def get_augmenter(
    crop_area_factor,
    aspect_ratio_factor,
    color_jitter_rate,
    brightness_factor,
    contrast_factor,
    saturation_factor,
    hue_factor,
):
    """Build a random image-augmentation pipeline.

    Pixels are rescaled by 1/255, flipped horizontally at random, randomly cropped and
    resized back to IMAGE_SIZE x IMAGE_SIZE, and colour-jittered with probability
    ``color_jitter_rate``.

    Args:
        crop_area_factor: (low, high) bounds on the cropped area relative to the image.
        aspect_ratio_factor: (low, high) bounds on the aspect ratio of the crop.
        color_jitter_rate: Probability of applying the colour jitter to an image.
        brightness_factor: Brightness range for RandomColorJitter.
        contrast_factor: Contrast range for RandomColorJitter.
        saturation_factor: Saturation range for RandomColorJitter.
        hue_factor: Hue range for RandomColorJitter.

    Returns:
        keras.Sequential accepting (IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS) images.
    """
    return keras.Sequential(
        [
            keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS)),
            keras_cv.layers.Rescaling(scale=1.0 / 255),
            keras_cv.layers.RandomFlip("horizontal"),
            # Previously crop_area_factor / aspect_ratio_factor were accepted but never used;
            # random resized crops are a core SimCLR augmentation.
            keras_cv.layers.RandomCropAndResize(
                target_size=(IMAGE_SIZE, IMAGE_SIZE),
                crop_area_factor=crop_area_factor,
                aspect_ratio_factor=aspect_ratio_factor,
            ),
            keras_cv.layers.RandomApply(
                keras_cv.layers.RandomColorJitter(
                    # Values are already in [0, 1] after the Rescaling layer above.
                    value_range=(0, 1),
                    brightness_factor=brightness_factor,
                    contrast_factor=contrast_factor,
                    saturation_factor=saturation_factor,
                    hue_factor=hue_factor,
                ),
                rate=color_jitter_rate,
            ),
        ]
    )

In [None]:
# Original images: a 3x3 sample of one raw unlabeled batch (pixels still in [0, 255]).
unlabeled_images = next(iter(unlabeled_train_dataset))
keras_cv.visualization.plot_image_gallery(
    images=unlabeled_images,
    rows=3,
    cols=3,
    value_range=(0, 255),
)

In [None]:
# Contrastive augmentations applied to the same batch; the augmenter rescales
# pixels by 1/255, hence value_range=(0, 1) for the gallery.
contrastive_augmenter = get_augmenter(**CONTRASTIVE_AUGMENTATION)
augmented_images = contrastive_augmenter(unlabeled_images)
keras_cv.visualization.plot_image_gallery(
    images=augmented_images,
    rows=3,
    cols=3,
    value_range=(0, 1),
)

In [None]:
# Classification (mild) augmentations applied to the same batch, shown in [0, 1].
classification_augmenter = get_augmenter(**CLASSIFICATION_AUGMENTATION)
augmented_images = classification_augmenter(unlabeled_images)
keras_cv.visualization.plot_image_gallery(
    images=augmented_images,
    rows=3,
    cols=3,
    value_range=(0, 1),
)

In [None]:
def get_encoder():
    """Build a fresh ResNet18 feature encoder.

    Returns a keras.Sequential named "encoder" that maps
    (IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS) images to a globally average-pooled
    feature vector. Backbone rescaling is disabled because the augmenters from
    get_augmenter, which precede the encoder in every model here, already rescale by 1/255.
    """
    return keras.Sequential(
        [
            keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS)),
            keras_cv.models.ResNet18Backbone(include_rescaling=False),
            # Average pooling; the layer was previously misnamed "max_pooling".
            keras.layers.GlobalAveragePooling2D(name="avg_pooling"),
        ],
        name="encoder",
    )

In [None]:
# Reload the labeled training split with integer class indices, as required by the
# sparse categorical loss/metric used by the classifiers below.
labeled_train_dataset = tf.keras.utils.image_dataset_from_directory(
    "/content/output/train",
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=LABELED_BATCH_SIZE,
    label_mode="int",
    shuffle=True,
).prefetch(AUTOTUNE)

Found 408 files belonging to 2 classes.


In [None]:
# Reload the test split with integer class indices. Use TEST_BATCH_SIZE (as prepare_dataset
# does) and a fixed, unshuffled order so evaluation is reproducible and predictions from
# model.predict(test_dataset) line up with the dataset's labels.
test_dataset = tf.keras.utils.image_dataset_from_directory(
    "/content/output/test",
    label_mode="int",
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
).prefetch(AUTOTUNE)

Found 137 files belonging to 2 classes.


In [None]:
# Supervised baseline: mild augmentation + randomly initialised ResNet18 encoder +
# linear head emitting logits, trained on the labeled split only.
baseline_model = keras.Sequential(
    [
        keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS)),
        get_augmenter(**CLASSIFICATION_AUGMENTATION),
        get_encoder(),
        keras.layers.Dense(NUM_CLASSES),
    ],
    name="baseline_model",
)

baseline_model.compile(
    optimizer=keras.optimizers.Nadam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

# NOTE(review): test_dataset doubles as validation data, so the maximum below is
# selected on the test split and is an optimistic estimate.
baseline_history = baseline_model.fit(
    labeled_train_dataset,
    epochs=100,
    validation_data=test_dataset,
)

baseline_best_val_acc = max(baseline_history.history["val_acc"])
print(f"Maximal validation accuracy: {baseline_best_val_acc * 100:.2f}%")

In [None]:
# Baseline loss/accuracy on the test split (which was also used as validation data in fit).
test_loss, test_accuracy = baseline_model.evaluate(test_dataset)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

Test Loss: 0.1428
Test Accuracy: 94.33%


In [None]:
from contrastive_trainer import ContrastiveTrainer

In [None]:
class SimCLRTrainer(ContrastiveTrainer):
    """SimCLR pre-training wrapper around ContrastiveTrainer.

    Adds no behaviour of its own: encoder, augmenter, projector, the optional probe,
    and any extra keyword arguments are forwarded unchanged to ContrastiveTrainer.
    """

    def __init__(self, encoder, augmenter, projector, probe=None, **kwargs):
        forwarded = dict(
            encoder=encoder,
            augmenter=augmenter,
            projector=projector,
            probe=probe,
        )
        # kwargs cannot contain the names above (they are bound parameters), so no clash.
        forwarded.update(kwargs)
        super().__init__(**forwarded)


# SimCLR pre-training on the unlabeled pool: a fresh ResNet18 encoder, the strong
# contrastive augmentations, and a two-Dense projection head ending in BatchNormalization.
simclr_encoder = get_encoder()
simclr_augmenter = get_augmenter(**CONTRASTIVE_AUGMENTATION)
simclr_projector = keras.Sequential(
    [
        keras.layers.Dense(PROJECTION_WIDTH, activation="elu"),
        keras.layers.Dense(PROJECTION_WIDTH),
        keras.layers.BatchNormalization(),
    ],
    name="projector",
)
simclr_model = SimCLRTrainer(
    encoder=simclr_encoder,
    augmenter=simclr_augmenter,
    projector=simclr_projector,
)

simclr_model.compile(
    encoder_optimizer=keras.optimizers.Adam(),
    encoder_loss=keras_cv.losses.SimCLRLoss(temperature=TEMPERATURE),
)

simclr_history = simclr_model.fit(unlabeled_train_dataset, epochs=50)

In [None]:
# Fine-tune the SimCLR-pretrained encoder on the labeled split. simclr_model.encoder is
# used directly (not copied), so its weights are updated in place by this training.
finetune_model = keras.Sequential(
    [
        keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS)),
        get_augmenter(**CLASSIFICATION_AUGMENTATION),
        simclr_model.encoder,
        # Linear head with no activation: it outputs raw logits.
        keras.layers.Dense(NUM_CLASSES),
    ],
    name="finetuning_model",
)

finetune_model.compile(
    optimizer=keras.optimizers.Nadam(),
    # The head emits logits, so from_logits=True is required (matches the baseline);
    # the default from_logits=False treats logits as probabilities and gives a bogus loss.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Labels are integer class indices (label_mode="int"), so the sparse metric is
    # required; CategoricalAccuracy expects one-hot targets and reports meaningless values.
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

# NOTE(review): test_dataset doubles as validation data, so the maximum below is
# selected on the test split and is an optimistic estimate.
finetune_history = finetune_model.fit(
    labeled_train_dataset, epochs=100, validation_data=test_dataset
)

print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(finetune_history.history["val_acc"]) * 100
    )
)

In [None]:
# Overall loss/accuracy of the fine-tuned model on the test split
# (also used as validation data during fit).
test_loss, test_accuracy = finetune_model.evaluate(test_dataset)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

Test Loss: 0.6931
Test Accuracy: 98.18%
