# In [None]:
# Import necessary libraries
import os
import glob
import numpy as np
import pandas as pd
import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from datetime import datetime
import tensorflow_hub as hub
import json
import xml.etree.ElementTree as ET
import time

# ---------------------------
# Configuration & Constants
# ---------------------------

# Run identity and training hyper-parameters.
MODEL_NAME = 'FasterRCNN_BreastCancer_OD'
IMG_WIDTH = 224
IMG_HEIGHT = 224
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 1e-4

# Dataset layout: both metadata CSVs and the JPEG tree sit under one root.
BASE_DATASET_PATH = './k_CBIS-DDSM/'
CALC_METADATA_CSV_PATH = os.path.join(BASE_DATASET_PATH, 'calc_case(with_jpg_img).csv')
MASS_METADATA_CSV_PATH = os.path.join(BASE_DATASET_PATH, 'mass_case(with_jpg_img).csv')
ACTUAL_IMAGE_FILES_BASE_DIR = os.path.join(BASE_DATASET_PATH, 'jpg_img')

# Every run gets its own timestamped output directory, created right away.
OUTPUT_DIR = os.path.join('./', f"run_{MODEL_NAME}_{IMG_WIDTH}_{BATCH_SIZE}_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# The CSVs are expected to carry one bounding box per row in these columns;
# add them to the CSVs if they are not already present.
BOX_COLUMNS = ['x_min', 'y_min', 'x_max', 'y_max']
CLASS_COLUMN = 'pathology'

# Pathology string -> integer class id. Id 0 is left for background.
LABEL_MAP = {
    "BENIGN": 1,
    "MALIGNANT": 2,
}
NUM_CLASSES = len(LABEL_MAP)  # Foreground classes only; background is class 0

# ---------------------------
# Load and Preprocess Data
# ---------------------------

def load_metadata(csv_paths=None):
    """Load the case-metadata CSVs into a single DataFrame.

    Args:
        csv_paths: Optional iterable of CSV paths to read. When omitted, the
            calcification and mass CSVs configured at module level are used.

    Returns:
        The rows of every CSV that exists, concatenated with a fresh index.
        A ``case_type`` column is added holding the lower-cased part of the
        file name before its first underscore (``calc`` / ``mass`` for the
        default files).

    Raises:
        ValueError: If none of the requested CSV files exist.
    """
    if csv_paths is None:
        csv_paths = [CALC_METADATA_CSV_PATH, MASS_METADATA_CSV_PATH]
    csv_paths = list(csv_paths)

    frames = []
    for csv_path in csv_paths:
        if not os.path.exists(csv_path):
            # Continue with whatever is available, but make the gap visible
            # instead of silently training on half of the data.
            print(f"Warning: metadata CSV not found, skipping: {csv_path}")
            continue
        frame = pd.read_csv(csv_path)
        frame['case_type'] = os.path.basename(csv_path).split('_')[0].lower()
        frames.append(frame)

    if not frames:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError(f"No metadata CSV files found; looked for: {csv_paths}")
    return pd.concat(frames, ignore_index=True)

def heuristic_find_image_path(row, base_dir):
    """Placeholder for the image-path lookup; not implemented yet.

    Args:
        row: Metadata row for one case (currently unused).
        base_dir: Directory the image files are searched under (currently unused).

    Returns:
        Always ``None``. ``prepare_dataset`` skips any row for which this
        returns ``None``, so no records are produced until the real lookup
        is filled in.
    """
    # TODO: port the path-resolution logic from the original code.
    return None

def prepare_dataset(df):
    """Convert metadata rows into detection records.

    A row is skipped when its image cannot be located, when any bounding-box
    column is missing (NaN or None), or when its class label is not a key of
    ``LABEL_MAP``.

    Args:
        df: DataFrame containing the ``BOX_COLUMNS`` and ``CLASS_COLUMN`` columns.

    Returns:
        A list of dicts with keys ``image_path``, ``bboxes`` (the
        ``BOX_COLUMNS`` values as floats, in that order) and ``label`` (the
        integer id from ``LABEL_MAP``).
    """
    records = []
    unknown_label_rows = 0
    for _, row in df.iterrows():
        img_path = heuristic_find_image_path(row, ACTUAL_IMAGE_FILES_BASE_DIR)
        if img_path is None:
            continue

        box_values = [row[col] for col in BOX_COLUMNS]
        # pandas represents missing cells as NaN, which an `is None` test
        # never matches; pd.isna covers both NaN and None.
        if any(pd.isna(value) for value in box_values):
            continue

        label = LABEL_MAP.get(row[CLASS_COLUMN])
        if label is None:
            # Labels outside LABEL_MAP used to abort the whole loop with a
            # bare KeyError; skip the row and report the count instead.
            unknown_label_rows += 1
            continue

        records.append({
            'image_path': img_path,
            'bboxes': [float(value) for value in box_values],
            'label': label,
        })

    if unknown_label_rows:
        print(f"Skipped {unknown_label_rows} rows whose '{CLASS_COLUMN}' value is not in LABEL_MAP")
    return records

print("Loading metadata...")
metadata_df = load_metadata()
# Rows missing any box coordinate cannot be used for detection.
metadata_df = metadata_df.dropna(subset=BOX_COLUMNS)
dataset_records = prepare_dataset(metadata_df)

# Hold out 20% for testing, then 10% of the remainder for validation.
trainval_records, test_records = train_test_split(
    dataset_records, test_size=0.2, random_state=42
)
train_records, val_records = train_test_split(
    trainval_records, test_size=0.1, random_state=42
)

print(f"Total samples: {len(dataset_records)}")
print(f"Train: {len(train_records)}, Val: {len(val_records)}, Test: {len(test_records)}")

# ---------------------------
# Data Loading and Augmentation
# ---------------------------

def parse_example(record):
    """Load one record's image and build its detection target.

    Args:
        record: Dict with ``image_path`` (path to a JPEG file), ``bboxes``
            (``[x_min, y_min, x_max, y_max]``) and ``label`` (integer class id).

    Returns:
        A tuple ``(image, target)``. ``image`` is the decoded picture resized
        to ``(IMG_HEIGHT, IMG_WIDTH)`` with 3 channels and divided by 255.
        ``target`` is a dict with ``boxes`` of shape ``(1, 4)`` (float32,
        divided by the source image width/height) and ``classes`` of shape
        ``(1,)`` (int64).
    """
    image = tf.io.read_file(record['image_path'])
    image = tf.image.decode_jpeg(image, channels=3)

    # Capture the size BEFORE resizing. Reading the shape after the resize
    # always gives (IMG_HEIGHT, IMG_WIDTH), which would scale every box by
    # 224 regardless of how large the source image is.
    source_size = tf.cast(tf.shape(image)[:2], tf.float32)
    source_height, source_width = source_size[0], source_size[1]

    image = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
    image /= 255.0  # Normalize to [0, 1]

    bboxes = tf.constant(record['bboxes'], dtype=tf.float32)
    # assumes the CSV boxes are pixel coordinates in the source JPEG -- TODO confirm
    bboxes /= tf.stack([source_width, source_height, source_width, source_height])
    classes = tf.constant([record['label']], dtype=tf.int64)

    return image, {
        'boxes': bboxes[tf.newaxis, :],
        'classes': classes
    }

def create_dataset(records):
    """Build a batched, prefetched ``tf.data`` pipeline over ``records``.

    Args:
        records: Sequence of record dicts accepted by ``parse_example``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, targets)`` batches of up to
        ``BATCH_SIZE`` examples, where ``targets`` has the keys ``boxes`` and
        ``classes``.
    """
    def gen():
        for r in records:
            yield parse_example(r)

    # `output_signature` supersedes the deprecated `output_types` /
    # `output_shapes` pair and states dtype and shape together.
    output_signature = (
        tf.TensorSpec(shape=(IMG_HEIGHT, IMG_WIDTH, 3), dtype=tf.float32),
        {
            'boxes': tf.TensorSpec(shape=(1, 4), dtype=tf.float32),
            'classes': tf.TensorSpec(shape=(1,), dtype=tf.int64),
        },
    )
    dataset = tf.data.Dataset.from_generator(gen, output_signature=output_signature)
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# One input pipeline per split, built in train / val / test order.
train_dataset, val_dataset, test_dataset = (
    create_dataset(split_records)
    for split_records in (train_records, val_records, test_records)
)

# ---------------------------
# Load Pretrained Detection Model
# ---------------------------

print("Loading Faster R-CNN model from TF Hub...")
hub_url = "https://tfhub.dev/tensorflow/faster_rcnn/resnet50_v1_800x1333/1"
# NOTE(review): DetectionModel below calls hub.load on this same URL again,
# and nothing visible in this file reads `model` afterwards -- confirm this
# extra copy is still needed.
model = hub.load(hub_url)

class DetectionModel(Model):
    """Keras ``Model`` wrapper around a detector loaded with ``hub.load``.

    Args:
        model_url: TF Hub handle passed to ``hub.load``.
        num_classes: Currently unused; the loaded module is called as-is.
        **kwargs: Forwarded to the Keras ``Model`` constructor.
    """

    def __init__(self, model_url, num_classes, **kwargs):
        super(DetectionModel, self).__init__(**kwargs)
        loaded_detector = hub.load(model_url)
        self.model = loaded_detector

    @tf.function
    def call(self, inputs, training=None, mask=None):
        """Run the wrapped module on ``inputs`` and return its output unchanged.

        ``training`` and ``mask`` are accepted but not forwarded to the module.
        """
        detector_outputs = self.model(inputs)
        return detector_outputs

detection_model = DetectionModel(
    model_url=hub_url,
    num_classes=NUM_CLASSES + 1,  # foreground classes + background
)

# ---------------------------
# Training Loop
# ---------------------------

optimizer = Adam(learning_rate=LEARNING_RATE)

@tf.function
def train_step(images, targets):
    """Run one optimisation step on a batch and return its loss terms.

    Args:
        images: Batch of input images passed to ``detection_model``.
        targets: Ground truth for the batch. Not used by this function.

    Returns:
        The ``'losses'`` entry of the model output.
    """
    # NOTE(review): `targets` never reaches the model, and this step relies on
    # the model output exposing a 'losses' mapping -- confirm the loaded
    # TF Hub module actually provides one.
    with tf.GradientTape() as tape:
        predictions = detection_model(images, training=True)
        losses = predictions['losses']
        total_loss = sum(losses.values())

    # Read the variable list only after the forward pass, as before.
    trainable_vars = detection_model.trainable_variables
    grads = tape.gradient(total_loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))
    return losses

# Iterate over the full training set once per epoch, logging every 10th step.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    for step, (images, targets) in enumerate(train_dataset):
        loss_dict = train_step(images, targets)
        if step % 10 != 0:
            continue
        print(f"Step {step}: Losses: {loss_dict}")

# ---------------------------
# Evaluation on Test Set
# ---------------------------

def evaluate(model, dataset):
    """Run ``model`` in inference mode over ``dataset`` and collect its outputs.

    Args:
        model: Callable invoked as ``model(images, training=False)``.
        dataset: Iterable of ``(images, targets)`` pairs; targets are ignored.

    Returns:
        A list of model outputs. When the model returns a dict, the list
        holds one output dict per batch. Any other iterable output is
        flattened into the list element by element, as before.
    """
    all_detections = []
    for images, _ in dataset:
        outputs = model(images, training=False)
        if isinstance(outputs, dict):
            # list.extend() on a dict would add only its key strings and
            # throw away every detection tensor, so keep the dict intact.
            all_detections.append(outputs)
        else:
            all_detections.extend(outputs)
    return all_detections

model_save_path = os.path.join(OUTPUT_DIR, "faster_rcnn_breast_cancer_od.keras")

print("\nEvaluating model...")
detections = evaluate(
    model=detection_model,
    dataset=test_dataset,
)

# Save model
# NOTE(review): tf.saved_model.save presumably writes a SavedModel directory
# rather than a single `.keras` archive -- confirm the suffix is intended.
tf.saved_model.save(detection_model, model_save_path)
print(f"Model saved to {model_save_path}")