1. Collect Data

In [None]:
# ================================
# 1. Install & Import Libraries
# ================================
!pip install kaggle tensorflow tensorflow-datasets matplotlib

In [None]:


import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ================================
# 2. Download Chest X-ray Dataset
# (positive class = X-ray)
# ================================
# Make sure you already have Kaggle API key set up (~/.kaggle/kaggle.json)

!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia -p ./data
!unzip -q ./data/chest-xray-pneumonia.zip -d ./data

# We will use "chest_xray/train" folder
xray_dir = "./data/chest_xray/train"

# ================================
# 3. Get Non-X-ray Dataset
# (negative class = natural images)
# ================================
# CIFAR-10 natural images stand in for the "non-xray" class.
# as_supervised=True makes each element an (image, label) pair;
# with_info=True additionally returns the dataset metadata as `info`.
cifar, info = tfds.load(
    "cifar10",
    split="train",
    with_info=True,
    as_supervised=True,
)

def resize_and_save(dataset, outdir, limit=5000):
    """Write up to ``limit`` images from ``dataset`` as 224x224 JPEG files.

    Args:
        dataset: Dataset that yields ``(image, label)`` pairs when iterated
            through ``tfds.as_numpy``. The label is ignored.
        outdir: Destination directory; created if it does not exist.
        limit: Maximum number of images to save.

    Files are named ``nonxray_<i>.jpg`` where ``i`` is the zero-based
    position of the image in the dataset.
    """
    os.makedirs(outdir, exist_ok=True)
    to_pil = tf.keras.preprocessing.image.array_to_img
    for i, (image, _label) in enumerate(tfds.as_numpy(dataset)):
        if i >= limit:
            break
        resized = to_pil(image).resize((224, 224))
        resized.save(os.path.join(outdir, f"nonxray_{i}.jpg"))

# Export the first 5000 CIFAR-10 training images as resized JPEGs.
nonxray_dir = "./data/non_xray"
resize_and_save(dataset=cifar, outdir=nonxray_dir, limit=5000)

# ================================
# 4. Create Train/Validation Folders
# ================================
# Layout: <base_dir>/{train,val}/{xray,nonxray}
base_dir = "./data/xray_vs_nonxray"
train_dir, val_dir = (
    os.path.join(base_dir, split_name) for split_name in ("train", "val")
)

for d in (train_dir, val_dir):
    for class_name in ("xray", "nonxray"):
        os.makedirs(os.path.join(d, class_name), exist_ok=True)

# Copy some chest xray images into "xray"
!cp -r ./data/chest_xray/train/NORMAL/* $train_dir/xray/
!cp -r ./data/chest_xray/train/PNEUMONIA/* $train_dir/xray/
!cp -r ./data/chest_xray/val/NORMAL/* $val_dir/xray/
!cp -r ./data/chest_xray/val/PNEUMONIA/* $val_dir/xray/

# Copy CIFAR-10 images into "nonxray"
!cp ./data/non_xray/*.jpg $train_dir/nonxray/
!cp ./data/non_xray/*.jpg $val_dir/nonxray/

# ================================
# 5. Data Generators
# ================================
img_size = (224, 224)
batch_size = 32

# NOTE(review): MobileNetV2's ImageNet weights are documented to expect
# inputs scaled to [-1, 1] (mobilenet_v2.preprocess_input), while
# rescale=1./255 produces [0, 1]. Left unchanged here because switching would
# alter the input contract of the saved model -- confirm with its consumers.

# Training images: rescale plus random augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation images: rescale only. Reusing the augmenting generator would
# randomly distort the validation images, making val metrics noisy and not
# representative of real inputs. The same validation_split is required so
# that the "training" and "validation" subsets stay complementary.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
)

# Both subsets are carved out of the train folder. Class indices follow the
# alphanumeric order of the sub-folder names (nonxray -> 0, xray -> 1).
train_gen = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="binary",
    subset="training"
)

val_gen = val_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode="binary",
    subset="validation"
)

# ================================
# 6. Build Model
# ================================
# ImageNet-pretrained MobileNetV2 backbone without its classification head;
# its weights are frozen so only the new head is trained.
base_model = MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights="imagenet",
)
base_model.trainable = False

# Head: global average pooling followed by a single sigmoid unit.
pooled_features = GlobalAveragePooling2D()(base_model.output)
x = Dense(1, activation="sigmoid")(pooled_features)

xray_filter_model = Model(inputs=base_model.input, outputs=x)
xray_filter_model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# ================================
# 7. Train Model
# ================================
# Fit for 5 epochs; per-epoch metrics are kept in `history.history`.
history = xray_filter_model.fit(train_gen, epochs=5, validation_data=val_gen)

# ================================
# 8. Evaluate Model
# ================================
# Plot per-epoch training and validation accuracy on the same axes.
for metric_key, curve_label in (
    ("accuracy", "train_acc"),
    ("val_accuracy", "val_acc"),
):
    plt.plot(history.history[metric_key], label=curve_label)
plt.legend()
plt.show()

# ================================
# 9. Save Model
# ================================
# Persist the trained filter in HDF5 format under ./models.
model_path = "models/xray_filter.h5"
os.makedirs("models", exist_ok=True)
xray_filter_model.save(model_path)

print("✅ Model saved as models/xray_filter.h5")
