In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import random
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os

# 1. Dataset Paths 
train_dir = 'chest_xray//train'
test_dir = 'chest_xray//test'

# 2. Image Preprocessing
img_height, img_width = 150, 150
batch_size = 32

data_generator = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

train_data = data_generator.flow_from_directory(
    train_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary'
)

test_data = data_generator.flow_from_directory(
    test_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='binary'
)

original_labels = train_data.classes.copy() #Create a copy of the original labels.

# 3. Label-Flip Poisoning
def apply_label_flip(labels, flip_rate=0.05):
    poisoned_labels = labels.copy()
    num_flips = int(len(labels) * flip_rate)
    flip_indices = random.sample(range(len(labels)), num_flips)
    for idx in flip_indices:
        poisoned_labels[idx] = 1 - poisoned_labels[idx]
    return poisoned_labels, flip_indices

poisoned_labels, flipped_indices = apply_label_flip(original_labels)

# 4. Baseline Model Construction and Training
def build_cnn(input_shape=(img_height, img_width, 3)):
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

baseline_model = build_cnn()

train_data.classes = poisoned_labels #Overwrite the generator labels.

baseline_model.fit(
    train_data,
    epochs=10,
    validation_data=test_data
)

# 5. Anomaly Detection and Refinement.
def identify_label_discrepancies(original, poisoned):
    discrepancies = [i for i, (o, p) in enumerate(zip(original, poisoned)) if o != p]
    return discrepancies

detected_anomalies = identify_label_discrepancies(original_labels, poisoned_labels)

# 6. Robust Model Training.
train_data.classes = original_labels #Restore the original labels.
refined_indices = [i for i in range(len(original_labels)) if i not in detected_anomalies]

refined_data_generator = ImageDataGenerator(rescale=1./255)

def refined_generator(directory, target_size, batch_size, class_mode, indices):
    generator = refined_data_generator.flow_from_directory(
        directory,
        target_size=target_size,
        batch_size=batch_size,
        class_mode=class_mode,
        shuffle=False
    )
    for i in range(len(generator.filenames)):
        if i not in indices:
            generator.filenames[i] = None
    return generator

refined_train = refined_generator(train_dir, (img_height, img_width), batch_size, 'binary', refined_indices)

robust_model = build_cnn()
robust_model.fit(refined_train, epochs=10, validation_data=test_data)

# 7. Evaluation
_, baseline_accuracy = baseline_model.evaluate(test_data)
_, robust_accuracy = robust_model.evaluate(test_data)

print(f"Baseline Accuracy (Poisoned): {baseline_accuracy}")
print(f"Robust Model Accuracy: {robust_accuracy}")

Found 5216 images belonging to 2 classes.
Found 624 images belonging to 2 classes.


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


Epoch 1/10
[1m163/163[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m84s[0m 506ms/step - accuracy: 0.6856 - loss: 0.6674 - val_accuracy: 0.6346 - val_loss: 0.6022
Epoch 2/10
[1m163/163[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 568ms/step - accuracy: 0.7617 - loss: 0.4886 - val_accuracy: 0.7404 - val_loss: 0.5041
Epoch 3/10
[1m163/163[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m106s[0m 652ms/step - accuracy: 0.8227 - loss: 0.4303 - val_accuracy: 0.6683 - val_loss: 0.6220
Epoch 4/10
[1m163/163[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 641ms/step - accuracy: 0.8325 - loss: 0.3999 - val_accuracy: 0.7676 - val_loss: 0.5068
Epoch 5/10
[1m163/163[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m101s[0m 621ms/step - accuracy: 0.8450 - loss: 0.3899 - val_accuracy: 0.7901 - val_loss: 0.4297
Epoch 6/10
[1m163/163[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 610ms/step - accuracy: 0.8550 - loss: 0.3710 - val_accuracy: 0.8494 - val_loss: 0.3776
Epoch 7