In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os

def get_generators(base_dir, img_size=(224,224), batch_size=32):
    """Build the train / validation / test directory iterators.

    Images are read from the ``train``, ``val`` and ``test`` sub-directories
    of ``base_dir``. Only the training generator applies augmentation
    (rotation, shifts, shear, zoom, horizontal flip, brightness); all three
    rescale pixel values by 1/255.

    Args:
        base_dir: Directory containing the ``train``, ``val`` and ``test``
            folders.
        img_size: ``target_size`` handed to every ``flow_from_directory`` call.
        batch_size: Batch size used by every generator.

    Returns:
        Tuple ``(train_gen, val_gen, test_gen)``. The test generator is
        created with ``shuffle=False``.
    """
    augmenting_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=30,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        brightness_range=[0.8, 1.2],
        fill_mode="nearest",
    )
    # Validation and test images are only rescaled, never augmented.
    plain_val_datagen = ImageDataGenerator(rescale=1./255)
    plain_test_datagen = ImageDataGenerator(rescale=1./255)

    # Options shared by all three splits.
    shared_flow_options = dict(
        target_size=img_size,
        batch_size=batch_size,
        class_mode='categorical',
    )

    train_gen = augmenting_datagen.flow_from_directory(
        os.path.join(base_dir, 'train'), **shared_flow_options
    )
    val_gen = plain_val_datagen.flow_from_directory(
        os.path.join(base_dir, 'val'), **shared_flow_options
    )
    # shuffle=False is set only for the test split.
    test_gen = plain_test_datagen.flow_from_directory(
        os.path.join(base_dir, 'test'), shuffle=False, **shared_flow_options
    )

    return train_gen, val_gen, test_gen


In [None]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam

def get_resnet_model(input_shape=(224,224,3), num_classes=3):
    """Create and compile a ResNet50-based classifier with a frozen backbone.

    The ImageNet-initialised ResNet50 (without its top) has every layer
    marked non-trainable. On top of it sit global average pooling, a
    256-unit ReLU layer, 50% dropout and a softmax layer with
    ``num_classes`` outputs.

    NOTE(review): the Keras documentation for ResNet50 describes
    ``resnet50.preprocess_input`` as the expected input preprocessing for
    the ImageNet weights, whereas this file feeds images rescaled by 1/255.
    Confirm that is intentional.

    Args:
        input_shape: Input image shape passed to ``ResNet50``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The compiled model (Adam with learning rate 1e-4, categorical
        cross-entropy loss, accuracy metric).
    """
    backbone = ResNet50(weights="imagenet", include_top=False, input_shape=input_shape)
    # Freeze every backbone layer; only the new head is left trainable.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False

    # Classification head, applied layer by layer to the backbone output.
    pooled = GlobalAveragePooling2D()(backbone.output)
    hidden = Dense(256, activation="relu")(pooled)
    dropped = Dropout(0.5)(hidden)
    preds = Dense(num_classes, activation="softmax")(dropped)

    model = Model(inputs=backbone.input, outputs=preds)
    compile_options = dict(
        optimizer=Adam(learning_rate=1e-4),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    model.compile(**compile_options)
    return model

In [None]:
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

def train_model(model, train_gen, val_gen, epochs=20, model_path="saved_model/chestxray_resnet.h5"):
    """Fit ``model``, checkpoint the best epoch and plot the learning curves.

    Two callbacks are used: ``ModelCheckpoint`` writes to ``model_path``
    whenever ``val_accuracy`` improves, and ``EarlyStopping`` watches
    ``val_loss`` with a patience of 5 epochs and ``restore_best_weights=True``.
    After training, accuracy and loss curves are saved to
    ``saved_model/training_curves.png`` and shown.

    Args:
        model: Compiled model exposing ``fit``; it must report an
            ``accuracy`` metric so the history keys read below exist.
        train_gen: Training data passed to ``model.fit``.
        val_gen: Validation data passed to ``model.fit``.
        epochs: Maximum number of epochs.
        model_path: File path handed to ``ModelCheckpoint``.

    Returns:
        The history object returned by ``model.fit``.
    """
    # Function-scope import so this block does not depend on another cell's imports.
    import os

    curves_path = "saved_model/training_curves.png"

    # Create the output directories up front so neither the checkpoint nor
    # plt.savefig can fail on a missing folder after training has already run.
    for output_path in (model_path, curves_path):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    checkpoint = ModelCheckpoint(model_path, monitor="val_accuracy", save_best_only=True, verbose=1)
    early_stop = EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True, verbose=1)

    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        callbacks=[checkpoint, early_stop]
    )

    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    # Use the number of recorded epochs, which may be fewer than `epochs`
    # if early stopping ended training.
    epochs_range = range(1, len(acc)+1)

    plt.figure(figsize=(12,5))

    plt.subplot(1,2,1)
    plt.plot(epochs_range, acc, label="Train Accuracy")
    plt.plot(epochs_range, val_acc, label="Val Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(loc="lower right")
    plt.title("Training vs Validation Accuracy")

    plt.subplot(1,2,2)
    plt.plot(epochs_range, loss, label="Train Loss")
    plt.plot(epochs_range, val_loss, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc="upper right")
    plt.title("Training vs Validation Loss")

    # Save before show(): some backends clear the figure once it is shown.
    plt.savefig(curves_path)
    plt.show()

    return history

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

def evaluate_model(model, test_gen):
    """Print a classification report and confusion matrix for ``model``.

    Predictions from ``model.predict(test_gen)`` are reduced with argmax and
    compared against ``test_gen.classes``. The report and matrix are
    printed; the matrix is also drawn as a heatmap, saved to
    ``saved_model/confusion_matrix.png`` and shown.

    Args:
        model: Model exposing ``predict``.
        test_gen: Generator exposing ``reset()``, ``classes`` and
            ``class_indices`` (a mapping of class name to integer index).
    """
    # Function-scope imports so this block does not depend on another cell's imports.
    import os
    import warnings

    # True labels are read from test_gen.classes as-is; if the generator
    # shuffles its batches the predictions presumably will not line up.
    if getattr(test_gen, "shuffle", False):
        warnings.warn(
            "test_gen has shuffle enabled; predictions may not align with test_gen.classes",
            stacklevel=2,
        )

    test_gen.reset()
    pred = model.predict(test_gen)
    pred_classes = np.argmax(pred, axis=1)
    true_classes = test_gen.classes

    # Order the names by their integer index rather than relying on dict key order.
    class_indices = test_gen.class_indices
    class_labels = sorted(class_indices, key=class_indices.get)
    label_ids = [class_indices[name] for name in class_labels]

    # Passing `labels` explicitly keeps the report and matrix at one row per
    # known class even when a class never appears in true or predicted
    # labels. Without it, target_names can mismatch the labels found.
    print("Classification Report:")
    print(classification_report(true_classes, pred_classes,
                                labels=label_ids, target_names=class_labels))

    cm = confusion_matrix(true_classes, pred_classes, labels=label_ids)
    print("\nConfusion Matrix:")
    print(cm)

    plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_labels, yticklabels=class_labels)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    # Make sure the output folder exists before writing the figure.
    os.makedirs("saved_model", exist_ok=True)
    plt.savefig("saved_model/confusion_matrix.png")
    plt.show()

In [None]:
from lib.data import get_generators
from lib.model import get_resnet_model
from lib.train import train_model
import os

# Training entry point: build the generators, train the model, save it.
data_dir = "data"
train_gen, val_gen, test_gen = get_generators(data_dir)

model = get_resnet_model()

# Create the output directory BEFORE training: the train_model defined in this
# file writes its checkpoint and training_curves.png under saved_model/ while
# it runs, so creating the folder afterwards is too late.
os.makedirs("saved_model", exist_ok=True)
history = train_model(model, train_gen, val_gen)

# NOTE(review): this is the same path train_model's ModelCheckpoint uses by
# default, so the best-val_accuracy checkpoint is overwritten with the model's
# state at the end of training. Confirm that is intended.
model.save("saved_model/chestxray_resnet.h5")
print("Model saved at saved_model/chestxray_resnet.h5")

In [None]:
from lib.data import get_generators
from lib.evaluate import evaluate_model
from tensorflow.keras.models import load_model

# Evaluation entry point: rebuild the generators, keep only the test split
# (third element of the returned tuple), and score the saved model on it.
data_dir = "data"
test_gen = get_generators(data_dir)[2]

model = load_model("saved_model/chestxray_resnet.h5")
evaluate_model(model, test_gen)

In [None]:
import sys
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np

# Class names, indexed by the position of the model's output units.
classes = ["NORMAL", "PNEUMONIA", "TUBERCULOSIS"]

def predict(img_path):
    """Load the saved model, classify one image file and print the result.

    The image is loaded at 224x224, converted to an array, divided by 255
    and wrapped in a batch of one. The printed label is ``classes`` indexed
    by the argmax of the model output; the raw output row is printed next
    to it.

    Args:
        img_path: Path of the image file to classify.
    """
    # The model is loaded first, then the image, on every call.
    classifier = load_model("saved_model/chestxray_resnet.h5")
    picture = image.load_img(img_path, target_size=(224,224))
    pixels = image.img_to_array(picture) / 255.0
    # Add a leading batch axis so the model receives a batch of one.
    batch = pixels[np.newaxis, ...]
    scores = classifier.predict(batch)
    label = classes[np.argmax(scores)]
    print(f"Prediction: {label} ({scores[0]})")

# Command-line entry point: classify the image given as the first argument.
if __name__ == "__main__":
    # Local import: this block only needs os for the path check below.
    import os

    if len(sys.argv) < 2:
        print("Usage: python -m scripts.infer <image_path>")
    elif not os.path.isfile(sys.argv[1]):
        # Validate the path before predict() loads the model, so a bad
        # argument gives a clear message instead of a traceback after the
        # model load. Inside a Jupyter kernel sys.argv typically holds the
        # kernel's own launcher options, which also land here.
        print(f"Image not found: {sys.argv[1]}", file=sys.stderr)
        print("Usage: python -m scripts.infer <image_path>")
    else:
        print("Infer started")
        predict(sys.argv[1])