In [None]:
# -*- coding: utf-8 -*-
"""
AI-Driven Doctor Stress Detection and Shift Optimization Platform

This notebook implements an end-to-end system for:
1. Detecting doctor stress levels from facial expressions using deep learning.
2. Optimizing doctor shifts based on stress levels and operational constraints.

Author: Gemini (Acting as Senior Software Engineer & AI Expert)
Date: May 3, 2025
"""

# @title << 0. Setup: Install Dependencies and Configure APIs >>
# Install necessary libraries
!pip install -q kaggle pandas numpy scikit-learn tensorflow tf2onnx onnxruntime-gpu fastapi uvicorn pyngrok nest-asyncio pydantic[email] Pillow requests matplotlib seaborn opencv-python-headless mediapipe google-generativeai python-dotenv ortools albumentations ray[tune] optuna sqlalchemy pyodbc

import os
import sys
import json
import zipfile
import random
import shutil
import time
import datetime
import asyncio
import threading
from pathlib import Path
from getpass import getpass # For safer key handling in Colab

import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder, MinMaxScaler

# Deep Learning / ML
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, GRU, LSTM, TimeDistributed, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50, EfficientNetB0, MobileNetV3Small
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import albumentations as A
import mediapipe as mp
# import optuna # Or Ray Tune - Placeholder for HPO
# import tf2onnx
# import onnxruntime as ort
# import tensorrt as trt # Placeholder - Requires specific environment

# Shift Optimization
from ortools.sat.python import cp_model

# Backend & API
import fastapi
import uvicorn
from pydantic import BaseModel, EmailStr, Field
from pyngrok import ngrok # To expose Colab API endpoint
import nest_asyncio

# Gemini API
import google.generativeai as genai

# Database (conceptual schema)
import sqlalchemy as db
from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, DateTime, Float, ForeignKey

# Plotting and Utils
from IPython.display import display, Javascript, Image
from google.colab.output import eval_js
from base64 import b64decode, b64encode
import PIL.Image

# Report the TensorFlow build and which accelerator (if any) will be used.
print("TensorFlow Version:", tf.__version__)
physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_devices))

if not physical_devices:
    print("Using CPU")
else:
    print("Using GPU")
    # Optional: let TensorFlow grow GPU memory on demand instead of
    # allocating all of it at once.
    try:
        for gpu in physical_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

# --- Configuration ---

# !! IMPORTANT SECURITY NOTE !!
# Hardcoding API keys directly in code (as provided in the prompt) is highly insecure.
# In a real application, use environment variables, secret managers (like GCP Secret Manager or AWS Secrets Manager),
# or secure configuration files. For Colab, using Colab Secrets or getpass is better.

# Using placeholders - replace with your actual credentials or use Colab Secrets
# KAGGLE_USERNAME = "YOUR_KAGGLE_USERNAME" # Replace or use Colab secrets
# KAGGLE_KEY = "YOUR_KAGGLE_API_KEY"      # Replace or use Colab secrets
# GEMINI_API_KEY = "YOUR_GEMINI_API_KEY"  # Replace or use Colab secrets

# Using getpass for slightly better security in demo:
# Prompt for Kaggle credentials only when they are not already in the environment.
if 'KAGGLE_USERNAME' not in os.environ:
    os.environ['KAGGLE_USERNAME'] = getpass('Enter Kaggle Username: ')
if 'KAGGLE_KEY' not in os.environ:
    os.environ['KAGGLE_KEY'] = getpass('Enter Kaggle Key: ')

# Resolve the Gemini key from, in order: the environment, Colab Secrets, an
# interactive prompt. GEMINI_API_KEY is always defined after this section.
GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
if GEMINI_API_KEY:
    print("Gemini Key loaded from environment.")
else:
    try:
        from google.colab import userdata
    except ImportError:
        GEMINI_API_KEY = getpass('Enter Gemini API Key: ')
        print("Gemini Key entered via getpass (Colab Secrets not available).")
    else:
        try:
            GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
        except Exception as secret_error:
            # Deliberately broad: the secret lookup may raise instead of
            # returning an empty value (e.g. when the secret is absent), and
            # any failure here should simply fall through to the prompt.
            print(f"Could not read Gemini key from Colab Secrets: {secret_error}")
            GEMINI_API_KEY = None
        if GEMINI_API_KEY:
            print("Gemini Key loaded from Colab Secrets.")
        else:
            GEMINI_API_KEY = getpass('Enter Gemini API Key: ')
            print("Gemini Key entered via getpass.")

# Configure the client on every path, including when the key came from the
# environment (previously that path skipped configuration entirely).
genai.configure(api_key=GEMINI_API_KEY)


# --- Kaggle API Setup ---
# Location where the Kaggle CLI looks for its credentials file.
KAGGLE_CONFIG_DIR = os.path.join(Path.home(), '.kaggle')
os.makedirs(KAGGLE_CONFIG_DIR, exist_ok=True)
KAGGLE_JSON_PATH = os.path.join(KAGGLE_CONFIG_DIR, 'kaggle.json')

# Write kaggle.json if non-empty credentials were provided via getpass or env
if os.environ.get('KAGGLE_USERNAME') and os.environ.get('KAGGLE_KEY'):
    kaggle_credentials = {
        "username": os.environ['KAGGLE_USERNAME'],
        "key": os.environ['KAGGLE_KEY']
    }
    with open(KAGGLE_JSON_PATH, 'w') as f:
        json.dump(kaggle_credentials, f)
    # The mode must be an octal literal: decimal 600 equals 0o1130, which is
    # not the intended owner-only read/write permission.
    os.chmod(KAGGLE_JSON_PATH, 0o600)
    print("Kaggle API configured.")
else:
    print("Kaggle credentials not found. Please provide them or upload kaggle.json manually.")

# --- Constants ---
# Working directories used by the notebook (relative to the current directory).
DATA_DIR, MODEL_DIR, ONNX_MODEL_DIR, LOG_DIR = map(
    Path, ("./data", "./models", "./onnx_models", "./logs")
)

# Training / inference settings.
IMG_SIZE = (224, 224)  # Input size for models like ResNet, EfficientNet
BATCH_SIZE = 32
EPOCHS = 15  # Reduced for demo purposes; increase for real training (e.g., 50-100)
SEED = 42
STRESS_THRESHOLD = 0.6  # Example threshold for triggering alerts/optimization

# Make sure every working directory exists.
for _work_dir in (DATA_DIR, MODEL_DIR, ONNX_MODEL_DIR, LOG_DIR):
    _work_dir.mkdir(exist_ok=True)

# Seed the Python, NumPy and TensorFlow RNGs for reproducibility.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)

In [None]:
# @title << 1. Data Collection & Preprocessing >>

# --- 1.1 Dataset Download ---
# Define datasets to use. We'll use FER2013 as a base for emotions,
# and simulate or find a smaller stress-specific dataset.
# NOTE: Finding large, public, ethically sourced facial *stress* datasets is challenging.
# We will use FER2013 for emotion recognition which can be a *proxy* or feature for stress,
# and supplement with simulated data or a smaller dataset if available.

# Dataset Links (as requested):
# FER2013: https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge/data (Requires joining competition)
# AffectNet: http://mohammadmahoor.com/affectnet/ (Requires license agreement)
# RAF-DB: http://www.whdeng.cn/RAF/model1.html (Requires license agreement)
# Simulated Stress/Custom Data: Often needed for specific stress detection.

# Example: Download FER2013
# You might need to accept competition rules on Kaggle first.
print("Downloading FER2013 dataset...")
!kaggle competitions download -c challenges-in-representation-learning-facial-expression-recognition-challenge -p {DATA_DIR} --force

fer2013_zip = DATA_DIR / "challenges-in-representation-learning-facial-expression-recognition-challenge.zip"
fer2013_dir = DATA_DIR / "fer2013"

if fer2013_zip.exists():
    print("Extracting FER2013...")
    with zipfile.ZipFile(fer2013_zip, 'r') as zip_ref:
        zip_ref.extractall(fer2013_dir)
    # Clean up zip file
    # os.remove(fer2013_zip) # Keep the zip for record if needed
    print("FER2013 extracted.")
    # Check for the actual data file (often fer2013.csv)
    if not (fer2013_dir / 'fer2013.csv').exists():
         print(f"Warning: Expected 'fer2013.csv' not found in {fer2013_dir}. Check extraction.")
         # Attempt to find it in subdirectories if necessary
         potential_csv = list(fer2013_dir.glob('**/fer2013.csv'))
         if potential_csv:
             # Move csv to the main fer2013_dir for consistency
             shutil.move(str(potential_csv[0]), str(fer2013_dir / 'fer2013.csv'))
             print(f"Moved {potential_csv[0]} to {fer2013_dir}")
         else:
             print("Could not locate fer2013.csv. Manual check required.")

else:
    print("FER2013 download failed. Ensure you have joined the competition on Kaggle and API key is correct.")
    # As a fallback, we'll create dummy data structure later if download fails

# Placeholder for other datasets (AffectNet, RAF-DB, Stress Faces)
# These often require manual download and agreement forms.
# Example structure if downloaded manually:
# (DATA_DIR / "AffectNet").mkdir(exist_ok=True)
# (DATA_DIR / "RAF-DB").mkdir(exist_ok=True)
# (DATA_DIR / "StressFaces").mkdir(exist_ok=True)
print("Placeholder directories created for AffectNet, RAF-DB, StressFaces.")
print("Please download these manually if needed, respecting their licenses.")


# --- 1.2 Data Loading and Initial Exploration (FER2013 Example) ---
fer_csv_path = fer2013_dir / 'fer2013.csv'
# Emotion mapping: 0=Angry, 1=Disgust, 2=Fear, 3=Happy, 4=Sad, 5=Surprise, 6=Neutral
emotion_map = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}

fer_df = None  # Stays None when the CSV is missing or cannot be loaded
if not fer_csv_path.exists():
    print(f"FER2013 CSV not found at {fer_csv_path}. Cannot load data.")
else:
    print(f"Loading {fer_csv_path}...")
    try:
        fer_df = pd.read_csv(fer_csv_path)
        print("FER2013 Dataframe Info:")
        fer_df.info()
        print("\nFER2013 Emotion Distribution:")
        print(fer_df['emotion'].value_counts())
        fer_df['emotion_label'] = fer_df['emotion'].map(emotion_map)
    except Exception as e:
        # Top-level notebook boundary: report the problem and fall back to
        # "no data" so later cells can switch to dummy data.
        print(f"Error loading or processing FER2013 CSV: {e}")
        fer_df = None  # Indicate failure

    # The preview is best-effort and kept out of the loading `try`, so a
    # plotting problem no longer throws away a successfully loaded DataFrame.
    if fer_df is not None and len(fer_df) > 0:
        try:
            plt.figure(figsize=(12, 6))
            for i in range(5):
                idx = random.randint(0, len(fer_df) - 1)
                # Positional access: idx is a row position, not an index label.
                pixels = np.array(fer_df['pixels'].iloc[idx].split(), dtype='uint8')
                img = pixels.reshape(48, 48)
                plt.subplot(1, 5, i + 1)
                plt.imshow(img, cmap='gray')
                plt.title(f"Emotion: {fer_df['emotion_label'].iloc[idx]}")
                plt.axis('off')
            plt.suptitle("Sample FER2013 Images")
            plt.show()
        except Exception as e:
            print(f"Warning: could not display FER2013 sample images: {e}")


# --- 1.3 Preprocessing Functions ---

# Face Detection (using MediaPipe)
mp_face_detection = mp.solutions.face_detection
mp_drawing = mp.solutions.drawing_utils
# One module-level detector instance, reused by every detect_face call.
face_detector = mp_face_detection.FaceDetection(model_selection=1, min_detection_confidence=0.5)

def detect_face(image_bgr):
    """Detect the largest face in a BGR image using MediaPipe.

    Args:
        image_bgr: Image array in BGR channel order (converted to RGB before
            it is handed to the detector).

    Returns:
        ``(face_img, bbox)`` where ``face_img`` is the cropped face region and
        ``bbox`` is ``(x, y, w, h)`` in pixels, clipped to the image so that it
        describes exactly the returned crop. ``(None, None)`` is returned when
        the image is empty, no face is detected, or the clipped box is empty.
    """
    if image_bgr is None or image_bgr.size == 0:
        return None, None

    results = face_detector.process(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    if not results.detections:
        return None, None  # No face detected

    ih, iw = image_bgr.shape[:2]

    # Convert every relative bounding box to pixel coordinates.
    candidate_boxes = []
    for detection in results.detections:
        rel_box = detection.location_data.relative_bounding_box
        candidate_boxes.append((
            int(rel_box.xmin * iw), int(rel_box.ymin * ih),
            int(rel_box.width * iw), int(rel_box.height * ih),
        ))

    # Largest area wins (first one on ties). Unlike the previous strict
    # "area > 0" scan, this cannot leave the selection unset when every
    # detection is degenerate; those are rejected by the check below instead.
    x, y, w, h = max(candidate_boxes, key=lambda box: box[2] * box[3])

    # Clip the box to the image using corner coordinates, so a box that starts
    # outside the image shrinks by the clipped amount instead of shifting.
    x0, y0 = max(0, x), max(0, y)
    x1, y1 = min(iw, x + w), min(ih, y + h)
    if x1 <= x0 or y1 <= y0:
        return None, None  # Invalid bbox dimensions

    face_img = image_bgr[y0:y1, x0:x1]
    return face_img, (x0, y0, x1 - x0, y1 - y0)

# Landmark Extraction (using MediaPipe Face Mesh) - Optional but useful feature
mp_face_mesh = mp.solutions.face_mesh
# Shared Face Mesh instance configured for single still images.
face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, min_detection_confidence=0.5)

def extract_landmarks(image_rgb):
    """Return the Face Mesh landmarks of the first detected face.

    Args:
        image_rgb: Image passed unchanged to ``face_mesh.process``.

    Returns:
        Array with one ``[x, y, z]`` row per landmark (z is a depth
        estimate), or ``None`` when no face landmarks are found.
    """
    mesh_output = face_mesh.process(image_rgb)
    detected_faces = mesh_output.multi_face_landmarks
    if not detected_faces:
        return None

    # Only the first face is used (the mesh is configured for one face).
    first_face_points = detected_faces[0].landmark
    coordinates = [[point.x, point.y, point.z] for point in first_face_points]
    return np.array(coordinates)

# rPPG Estimation (Placeholder - Very Complex)
# Real rPPG requires analyzing subtle color changes in skin over time (video).
# Requires advanced signal processing (e.g., ICA, CHROM, POS methods).
# Here, we'll just have a placeholder function.
def estimate_rppg(face_video_frames):
    """Placeholder for rPPG estimation; returns a simulated HRV metric.

    A real implementation would process the frame sequence (ROI selection on
    forehead/cheeks, colour-channel analysis, filtering, FFT/peak detection).
    Here ``face_video_frames`` is ignored and a random SDNN-like value
    (standard deviation of NN intervals) is drawn uniformly from [30, 80).
    """
    print("Warning: rPPG estimation is complex and simulated here.")
    sdnn_low, sdnn_high = 30, 80
    return np.random.uniform(sdnn_low, sdnn_high)

# Normalization and Resizing
def preprocess_image(image, target_size):
    """Resize ``image`` to ``target_size`` and scale its pixels by 1/255.

    Returns the result as a float32 array.
    """
    scaled = cv2.resize(image, target_size) / 255.0  # Scale pixels to [0, 1]
    return scaled.astype(np.float32)


# --- 1.4 Data Augmentation ---
# Using Albumentations for flexibility
# Geometric + Photometric
# Training-time augmentations: geometric first, then photometric.
_train_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5, border_mode=cv2.BORDER_CONSTANT),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.6),
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.4),
    A.MotionBlur(blur_limit=7, p=0.3),
    # Cutout-style A.CoarseDropout and further augmentations can be appended here.
]
transform_train = A.Compose(_train_augmentations)

# Validation/test data is not augmented; resizing and normalization happen
# in preprocess_image.
transform_val = A.Compose([])

# Temporal Smoothing (Conceptual - Applied during sequence processing)
# Could involve averaging features/predictions over a short time window.

# GAN-based Oversampling (Advanced - Placeholder)
# Requires training a GAN (e.g., StyleGAN, Diffusers) on specific stress expressions
# or using augmentation techniques like MixAugment.
def apply_mixaugment(images, labels):
    """Placeholder for GAN-based/advanced oversampling.

    Prints a note and hands back ``images`` and ``labels`` untouched. A real
    implementation could generate synthetic stressed faces or interpolate
    between samples.
    """
    print("Note: GAN-based oversampling (MixAugment/Diffusion) requires separate complex setup.")
    passthrough = (images, labels)
    return passthrough

# --- 1.5 Prepare Data for Models ---
# Process FER2013 (if loaded) into images and labels suitable for TF/Keras
# We will map FER emotions to a simplified Stress/NoStress binary classification for this demo.
# Mapping: Angry, Fear, Sad -> Stress (1); Happy, Neutral, Surprise -> NoStress (0); Disgust -> Exclude (often low samples)

def prepare_fer_data(df, target_size, stress_map, use_face_detection=True):
    """Convert FER2013 rows into model-ready images and binary stress labels.

    Args:
        df: DataFrame with an ``emotion`` column (integer class) and a
            ``pixels`` column (space-separated values of a 48x48 grayscale
            image), or ``None``.
        target_size: Size passed to ``preprocess_image`` for resizing.
        stress_map: Mapping from emotion code to target label. Emotions that
            are absent from the mapping are skipped; emotion 1 ('Disgust') is
            always skipped.
        use_face_detection: When True, crop to the face found by
            ``detect_face``; the uncropped image is used if detection fails.

    Returns:
        ``(images, labels)`` as NumPy arrays. Both are empty when ``df`` is
        ``None`` or when every row was skipped.
    """
    if df is None:
        print("FER DataFrame is None. Skipping preparation.")
        return np.array([]), np.array([])  # Return empty arrays

    images = []
    labels = []
    skipped_count = 0
    total_rows = len(df)

    print(f"Preparing data with target size {target_size}...")
    # Iterate over the two needed columns positionally. This is much cheaper
    # than iterrows() and keeps the progress counter independent of the
    # DataFrame's index labels (which may be non-contiguous or non-numeric).
    for position, (emotion, pixel_str) in enumerate(zip(df['emotion'], df['pixels']), start=1):
        # 'Disgust' (1) is excluded; anything unmapped is excluded as well.
        target_label = None if emotion == 1 else stress_map.get(emotion)

        if target_label is None:
            skipped_count += 1
        else:
            pixels = np.array(pixel_str.split(), dtype='uint8')
            img = pixels.reshape(48, 48)
            # Convert to 3 channels for models like ResNet
            img_bgr = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            face = img_bgr  # Default to original if face detection fails or is off
            if use_face_detection:
                detected_face, _ = detect_face(img_bgr)
                if detected_face is not None and detected_face.size > 0:
                    face = detected_face

            # Preprocess: Resize and normalize
            images.append(preprocess_image(face, target_size))
            labels.append(target_label)

        if position % 5000 == 0:
            print(f"Processed {position} / {total_rows} images...")

    print(f"Finished processing. Skipped {skipped_count} images.")
    return np.array(images), np.array(labels)

# Define the stress mapping
# 0=Angry, 2=Fear, 4=Sad -> Stress (1)
# 3=Happy, 5=Surprise, 6=Neutral -> NoStress (0)
# Emotion code -> binary label.
# 0=Angry, 2=Fear, 4=Sad -> Stress (1); 3=Happy, 5=Surprise, 6=Neutral -> NoStress (0)
stress_mapping = {0: 1, 2: 1, 4: 1, 3: 0, 5: 0, 6: 0}


def _make_dummy_generators():
    """Create random stand-in splits and batched tf.data pipelines.

    Used when no real data is available, so the remaining cells can still run.
    Returns ``(train, val, test, train_ds, val_ds, test_ds)`` where the first
    three items are ``(X, y)`` array pairs.
    """
    def _random_split(n_samples):
        return (np.random.rand(n_samples, IMG_SIZE[0], IMG_SIZE[1], 3),
                np.random.randint(0, 2, n_samples))

    def _batched(split):
        return tf.data.Dataset.from_tensor_slices(split).batch(BATCH_SIZE)

    train_split = _random_split(100)
    val_split = _random_split(20)
    test_split = _random_split(20)
    print("Using dummy data for demonstration.")

    train_ds = _batched(train_split).prefetch(tf.data.AUTOTUNE)
    val_ds = _batched(val_split).prefetch(tf.data.AUTOTUNE)
    test_ds = _batched(test_split)  # No shuffle for test evaluation
    return train_split, val_split, test_split, train_ds, val_ds, test_ds


_real_data_ready = False

if fer_df is None:
    print("FER DataFrame not loaded. Cannot prepare data.")
else:
    # Face detection is off for FER2013 because its images are already cropped.
    X, y = prepare_fer_data(fer_df, IMG_SIZE, stress_mapping, use_face_detection=False)
    print(f"Data shapes: X={X.shape}, y={y.shape}")

    if not (X.size > 0 and y.size > 0):
        print("Data preparation resulted in empty arrays. Cannot proceed with training.")
    else:
        _real_data_ready = True

        # Split data: 70% Train, 15% Validation, 15% Test (0.5 * 0.3 = 0.15)
        X_train, X_temp, y_train, y_temp = train_test_split(
            X, y, test_size=0.3, random_state=SEED, stratify=y)
        X_val, X_test, y_val, y_test = train_test_split(
            X_temp, y_temp, test_size=0.5, random_state=SEED, stratify=y_temp)

        print(f"Train set: X={X_train.shape}, y={y_train.shape}")
        print(f"Validation set: X={X_val.shape}, y={y_val.shape}")
        print(f"Test set: X={X_test.shape}, y={y_test.shape}")

        # Check distribution
        print("\nTrain label distribution:", np.bincount(y_train))
        print("Validation label distribution:", np.bincount(y_val))
        print("Test label distribution:", np.bincount(y_test))

        # Keras built-in augmentation for training (Albumentations could be
        # plugged in through preprocessing_function, at a speed cost).
        _augmentation_options = dict(
            rotation_range=15,
            width_shift_range=0.1,
            height_shift_range=0.1,
            shear_range=0.1,
            zoom_range=0.1,
            horizontal_flip=True,
            fill_mode='nearest',
        )
        train_datagen = ImageDataGenerator(**_augmentation_options)
        # No augmentation for validation/test; rescaling was done in preprocess_image.
        val_datagen = ImageDataGenerator()

        train_generator = train_datagen.flow(X_train, y_train, batch_size=BATCH_SIZE, seed=SEED)
        validation_generator = val_datagen.flow(X_val, y_val, batch_size=BATCH_SIZE, seed=SEED)
        # No shuffle for test evaluation
        test_generator = val_datagen.flow(X_test, y_test, batch_size=BATCH_SIZE, shuffle=False)

if not _real_data_ready:
    # Create dummy data (and matching generators) so the code structure can continue.
    ((X_train, y_train), (X_val, y_val), (X_test, y_test),
     train_generator, validation_generator, test_generator) = _make_dummy_generators()


print("Data preparation and generator setup complete.")

In [None]:
# @title << 2. Model Development – Facial Stress Detection >>

# --- 2.1 Model Definitions ---

def build_resnet_gru_model(input_shape, num_classes, gru_units=128, freeze_base=True):
    """Build and compile a ResNet50 backbone with a dense classification head.

    Despite the name, no recurrent layer is created: ``gru_units`` is accepted
    but not used by the current single-image head.

    Args:
        input_shape: Shape of one input image, passed to ResNet50 and ``Input``.
        num_classes: 1 selects a sigmoid output with binary cross-entropy;
            any other value selects softmax with sparse categorical
            cross-entropy.
        gru_units: Unused (reserved for a sequence variant).
        freeze_base: When True the ImageNet backbone is frozen and called with
            ``training=False``; otherwise it is trainable.

    Returns:
        The compiled Keras ``Model`` (its summary is printed).
    """
    backbone = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
    backbone.trainable = not freeze_base
    if freeze_base:
        print("ResNet50 base frozen.")
    else:
        print("ResNet50 base trainable (fine-tuning).")

    is_binary = num_classes == 1

    image_in = Input(shape=input_shape)
    # training=False keeps the backbone's BatchNorm layers in inference mode
    # while the base is frozen.
    features = backbone(image_in, training=not freeze_base)

    head_layers = [
        GlobalAveragePooling2D(),
        BatchNormalization(),
        Dropout(0.5),
        Dense(256, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
    ]
    for head_layer in head_layers:
        features = head_layer(features)

    output_activation = 'sigmoid' if is_binary else 'softmax'  # Sigmoid for binary
    predictions = Dense(num_classes, activation=output_activation)(features)

    classifier = Model(inputs=image_in, outputs=predictions)

    # Compile the model
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='binary_crossentropy' if is_binary else 'sparse_categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.AUC(name='roc_auc'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
        ],
    )

    print("ResNet50 + Dense model built.")
    classifier.summary()
    return classifier

def build_efficientnet_model(input_shape, num_classes, base_model_name='B0', freeze_base=True,
                             rescale_inputs=True):
    """Build and compile an EfficientNet or MobileNetV3 classifier.

    Args:
        input_shape: Shape of one input image.
        num_classes: 1 selects a sigmoid output with binary cross-entropy;
            any other value selects softmax with sparse categorical
            cross-entropy.
        base_model_name: Backbone selector. Accepts an EfficientNet variant
            suffix (``'B0'`` ... ``'B7'``, ``'V2S'`` ...), a full EfficientNet
            class name (``'EfficientNetB0'``), or a MobileNetV3 class name
            (``'MobileNetV3Small'``).
        freeze_base: When True the ImageNet backbone is frozen and called with
            ``training=False``; otherwise it is trainable.
        rescale_inputs: When True (default) inputs are multiplied by 255
            before the backbone. The notebook's ``preprocess_image`` produces
            pixels in [0, 1], while these Keras backbones ship their own
            preprocessing and expect [0, 255]. Pass False if the caller
            already feeds [0, 255] images.

    Returns:
        The compiled Keras ``Model`` (its summary is printed).

    Raises:
        ValueError: If ``base_model_name`` does not name a supported backbone.
    """
    # Resolve the selector to a class name in tf.keras.applications. The old
    # check required 'EfficientNet' to be in the name and then prefixed it
    # again, so neither 'B0' nor 'EfficientNetB0' could ever resolve.
    if base_model_name.startswith(('EfficientNet', 'MobileNetV3')):
        application_name = base_model_name
    elif base_model_name.startswith(('B', 'V2')):
        application_name = f"EfficientNet{base_model_name}"
    else:
        raise ValueError("Unsupported base_model_name")

    base_model_cls = getattr(tf.keras.applications, application_name, None)
    if base_model_cls is None:
        raise ValueError(f"Unsupported base_model_name: {base_model_name!r}")
    print(f"Using {application_name}")

    base_model = base_model_cls(weights='imagenet', include_top=False, input_shape=input_shape)

    base_model.trainable = not freeze_base
    if freeze_base:
        print(f"{base_model_name} base frozen.")
    else:
        print(f"{base_model_name} base trainable (fine-tuning).")

    inputs = Input(shape=input_shape)
    x = inputs
    if rescale_inputs:
        # Undo the [0, 1] normalization so the backbone's built-in
        # preprocessing sees the pixel range it was trained on.
        x = tf.keras.layers.Rescaling(255.0)(x)
    x = base_model(x, training=not freeze_base)
    x = GlobalAveragePooling2D()(x)
    # Rebuild top based on recommendations for EfficientNet/MobileNet
    x = Dropout(0.3)(x)
    x = Dense(128, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    outputs = Dense(num_classes, activation='sigmoid' if num_classes == 1 else 'softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)

    # Compile
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    loss = 'binary_crossentropy' if num_classes == 1 else 'sparse_categorical_crossentropy'
    metrics = ['accuracy',
               tf.keras.metrics.AUC(name='roc_auc'),
               tf.keras.metrics.Precision(name='precision'),
               tf.keras.metrics.Recall(name='recall')]

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    print(f"{base_model_name} + Dense model built.")
    model.summary()
    return model


# --- 2.2 Model Training ---
NUM_CLASSES = 1  # Binary: Stress (1) vs NoStress (0)
INPUT_SHAPE = (IMG_SIZE[0], IMG_SIZE[1], 3)

# Where the best version of each model is stored (.keras format).
resnet_checkpoint_path = str(MODEL_DIR / "resnet_stress_best.keras")
efficientnet_checkpoint_path = str(MODEL_DIR / "efficientnet_stress_best.keras")


def _make_best_checkpoint(checkpoint_path):
    """Return a callback that keeps only the weights with the highest val_accuracy."""
    return ModelCheckpoint(checkpoint_path, monitor='val_accuracy', save_best_only=True, mode='max', verbose=1)


# Callbacks driven by validation loss: stop early and lower the learning rate
# when it plateaus.
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6, verbose=1)

# Model checkpointing (optional but good practice)
resnet_checkpoint = _make_best_checkpoint(resnet_checkpoint_path)
efficientnet_checkpoint = _make_best_checkpoint(efficientnet_checkpoint_path)


# --- Train Model 1: ResNet50 + Dense ---
print("\n--- Training ResNet50 Model ---")
# Start with frozen base
model_resnet = build_resnet_gru_model(INPUT_SHAPE, NUM_CLASSES, freeze_base=True)

# Work out the number of batches per epoch for the kind of input pipeline in use.
steps_per_epoch_train = None
validation_steps = None
_keras_iterator_types = (
    tf.keras.preprocessing.image.DirectoryIterator,
    tf.keras.preprocessing.image.NumpyArrayIterator,
)
if isinstance(train_generator, _keras_iterator_types):
    # Keras iterators report their own length in batches.
    steps_per_epoch_train, validation_steps = len(train_generator), len(validation_generator)
elif isinstance(train_generator, tf.data.Dataset) and hasattr(X_train, 'shape'):
    # For tf.data, estimate the step counts from data size / batch size.
    steps_per_epoch_train = int(np.ceil(X_train.shape[0] / BATCH_SIZE))
    validation_steps = int(np.ceil(X_val.shape[0] / BATCH_SIZE))


# Train only when both generators are available.
if train_generator is not None and validation_generator is not None:
    history_resnet = model_resnet.fit(
        train_generator,
        validation_data=validation_generator,
        epochs=EPOCHS,
        steps_per_epoch=steps_per_epoch_train,
        validation_steps=validation_steps,
        callbacks=[early_stopping, reduce_lr, resnet_checkpoint],
        verbose=1,
    )
    print("ResNet Training Finished.")
    # Optional fine-tuning phase (not run here): unfreeze the base, recompile
    # with a lower learning rate, and train for a few more epochs.
else:
    print("Error: Data generators are not initialized. Cannot train.")


# --- Train Model 2: EfficientNetB0 ---
print("\n--- Training EfficientNetB0 Model ---")
model_efficientnet = build_efficientnet_model(INPUT_SHAPE, NUM_CLASSES, base_model_name='B0', freeze_base=True)

# Same guard and step counts as the ResNet run; only the checkpoint differs.
if train_generator is not None and validation_generator is not None:
    history_efficientnet = model_efficientnet.fit(
        train_generator,
        validation_data=validation_generator,
        epochs=EPOCHS,
        steps_per_epoch=steps_per_epoch_train,
        validation_steps=validation_steps,
        callbacks=[early_stopping, reduce_lr, efficientnet_checkpoint],
        verbose=1,
    )
    print("EfficientNet Training Finished.")
else:
    print("Error: Data generators are not initialized. Cannot train.")


# --- 2.3 Model Evaluation ---
def evaluate_model(model, model_name, test_data, history):
    """Plot the training curves and report test-set metrics for a binary classifier.

    Args:
        model: Compiled Keras model whose ``evaluate`` output starts with
            loss, accuracy and ROC AUC, in that order.
        model_name: Label used in printed headings and plot titles.
        test_data: A ``tf.data.Dataset`` yielding ``(features, labels)``
            batches, a Keras ``NumpyArrayIterator`` (which should be created
            with ``shuffle=False``), or an ``(X, y)`` pair of arrays.
        history: The ``History`` object returned by ``model.fit``.

    Returns:
        dict with keys ``loss``, ``accuracy``, ``roc_auc`` and ``f1``.
    """
    print(f"\n--- Evaluating {model_name} ---")

    # Plot training history; curves missing from the history are skipped.
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    history_curves = (
        ('accuracy', 'Train Accuracy'),
        ('val_accuracy', 'Val Accuracy'),
        ('loss', 'Train Loss'),
        ('val_loss', 'Val Loss'),
    )
    for key, label in history_curves:
        if key in history.history:
            plt.plot(history.history[key], label=label)
    plt.title(f'{model_name} Training History')
    plt.xlabel('Epoch')
    plt.legend()

    plt.subplot(1, 2, 2)
    if 'roc_auc' in history.history:
        plt.plot(history.history['roc_auc'], label='Train ROC AUC')
        if 'val_roc_auc' in history.history:
            plt.plot(history.history['val_roc_auc'], label='Val ROC AUC')
        plt.title(f'{model_name} ROC AUC')
        plt.xlabel('Epoch')
        plt.legend()

    plt.tight_layout()
    plt.show()

    # Evaluate on Test Set
    print("Evaluating on Test Set...")
    if isinstance(test_data, tf.data.Dataset):
        results = model.evaluate(test_data, verbose=1)
        y_pred_prob = model.predict(test_data)
        # Recover the true labels by iterating the dataset once more.
        y_true = np.concatenate([batch_labels for _, batch_labels in test_data], axis=0)
    elif isinstance(test_data, tf.keras.preprocessing.image.NumpyArrayIterator):
        if getattr(test_data, 'shuffle', False):
            print("Warning: test iterator shuffles its data; predictions may not line up with labels.")
        results = model.evaluate(test_data, verbose=1)
        y_pred_prob = model.predict(test_data)
        # NumpyArrayIterator keeps its targets in `.y` (it has no `.labels`).
        y_true = test_data.y
    else:  # Assume test_data is (X_test, y_test) numpy arrays
        test_features, y_true = test_data
        results = model.evaluate(test_features, y_true, verbose=1)
        y_pred_prob = model.predict(test_features)

    # Assumes loss, accuracy and AUC are the first three values returned.
    test_loss, test_accuracy, test_roc_auc = results[:3]

    y_pred = (np.asarray(y_pred_prob) > 0.5).astype(int).flatten()
    # Ensure y_true and y_pred are correctly aligned and sized
    y_true = np.asarray(y_true).astype(int).flatten()[:len(y_pred)]

    # Fixing the label set keeps the report and the 2x2 matrix well-defined
    # even when one of the classes does not occur in this test set.
    class_ids = [0, 1]

    test_f1 = f1_score(y_true, y_pred, zero_division=0)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Test ROC AUC: {test_roc_auc:.4f}")
    print(f"Test F1-Score: {test_f1:.4f}")

    # Classification Report
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, labels=class_ids,
                                target_names=['NoStress (0)', 'Stress (1)'], zero_division=0))

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['NoStress', 'Stress'], yticklabels=['NoStress', 'Stress'])
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(f'{model_name} Confusion Matrix')
    plt.show()

    return {'loss': test_loss, 'accuracy': test_accuracy, 'roc_auc': test_roc_auc, 'f1': test_f1}


# Load best weights before evaluation
# Restore the best checkpointed weights (when a checkpoint exists) before evaluating.
_resnet_has_checkpoint = os.path.exists(resnet_checkpoint_path)
if _resnet_has_checkpoint:
    print("Loading best ResNet weights...")
    model_resnet.load_weights(resnet_checkpoint_path)
_efficientnet_has_checkpoint = os.path.exists(efficientnet_checkpoint_path)
if _efficientnet_has_checkpoint:
    print("Loading best EfficientNet weights...")
    model_efficientnet.load_weights(efficientnet_checkpoint_path)

# Prefer the in-memory arrays for evaluation; otherwise use the test generator.
if 'X_test' in locals():
    test_data_eval = (X_test, y_test)
else:
    test_data_eval = test_generator

# Evaluate each model only if both the model and its training history exist.
if not ('model_resnet' in locals() and 'history_resnet' in locals()):
    print("ResNet model or history not available for evaluation.")
else:
    results_resnet = evaluate_model(model_resnet, "ResNet50", test_data_eval, history_resnet)

if not ('model_efficientnet' in locals() and 'history_efficientnet' in locals()):
    print("EfficientNet model or history not available for evaluation.")
else:
    results_efficientnet = evaluate_model(model_efficientnet, "EfficientNetB0", test_data_eval, history_efficientnet)

# --- 2.4 Hyperparameter Optimization (Conceptual) ---
# Using Optuna or Ray Tune involves defining an objective function that trains
# a model with given hyperparameters and returns a metric (e.g., validation accuracy).
# Example structure with Optuna:
"""
def objective(trial):
    # Suggest hyperparameters
    lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
    dropout_rate = trial.suggest_float('dropout', 0.1, 0.6)
    gru_units = trial.suggest_categorical('gru_units', [64, 128, 256])
    # ... other hyperparameters ...

    # Build model with suggested params
    model = build_resnet_gru_model(INPUT_SHAPE, NUM_CLASSES, gru_units=gru_units, ...)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr), ...)

    # Train model (potentially smaller dataset/epochs for speed)
    history = model.fit(...)

    # Return the metric to optimize (best validation accuracy over all epochs)
    return max(history.history['val_accuracy'])

# study = optuna.create_study(direction='maximize')
# study.optimize(objective, n_trials=50) # Number of trials
# print("Best trial:", study.best_trial.params)
"""
# Status note: hyperparameter optimization is only sketched above, not run.
print(
    "\nHyperparameter Optimization (Optuna/Ray Tune) conceptualized.",
    "Implementation requires defining an objective function and running the study.",
    sep="\n",
)

# --- 2.5 Model Export to ONNX and TensorRT Acceleration ---
# Ensure tf2onnx is installed: pip install tf2onnx
# ONNX Runtime (CPU or GPU): pip install onnxruntime or onnxruntime-gpu
import tf2onnx
import onnxruntime as ort

def export_to_onnx(model, model_name, output_dir):
    """Export a Keras model to ONNX and smoke-test the exported file.

    Args:
        model: Keras model to convert.
        model_name: Base file name (without extension) of the ``.onnx`` file.
        output_dir: Target directory; must support the ``/`` operator
            (e.g. a ``pathlib.Path``).

    Returns:
        The exported file path as a string, or ``None`` when conversion or
        verification raised an exception (the error is printed, not re-raised).
    """
    onnx_model_path = str(output_dir / f"{model_name}.onnx")
    print(f"Exporting {model_name} to ONNX: {onnx_model_path}")

    try:
        # The input signature is spelled out explicitly (batch dimension left
        # dynamic) instead of being derived from model.input.
        input_signature = [tf.TensorSpec([None, INPUT_SHAPE[0], INPUT_SHAPE[1], INPUT_SHAPE[2]], tf.float32, name='input_image')]

        tf2onnx.convert.from_keras(
            model, input_signature=input_signature, opset=13, output_path=onnx_model_path
        )
        print(f"Model successfully converted to ONNX format at {onnx_model_path}")

        # Verify the exported file by loading it and running one prediction.
        print("Verifying ONNX model...")
        ort_session = ort.InferenceSession(onnx_model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) # Use GPU if available
        print(f"ONNX model loaded successfully with providers: {ort_session.get_providers()}")
        input_name = ort_session.get_inputs()[0].name
        output_name = ort_session.get_outputs()[0].name
        print(f"Input name: {input_name}, Output name: {output_name}")

        # Prefer a real test sample. X_test is a module-level name, so it must
        # be looked up in globals(): locals() inside this function only holds
        # the function's own variables and would never contain it.
        test_samples = globals().get('X_test')
        if test_samples is not None and test_samples.shape[0] > 0:
            # Batch of one, cast to the float32 dtype declared in the signature.
            dummy_input = np.asarray(test_samples[0:1], dtype=np.float32)
        else:
            dummy_input = np.random.rand(1, INPUT_SHAPE[0], INPUT_SHAPE[1], INPUT_SHAPE[2]).astype(np.float32)

        ort_outs = ort_session.run([output_name], {input_name: dummy_input})
        print(f"ONNX model prediction (test): {ort_outs[0]}")

    except Exception as e:
        print(f"Error during ONNX conversion or verification for {model_name}: {e}")
        onnx_model_path = None # Indicate failure

    return onnx_model_path

# Select the model to export.
# For the demo EfficientNet is assumed to be preferred (e.g. for mobile).
# In a real scenario compare results_resnet and results_efficientnet, e.g.:
# best_model_name = "ResNet50" if results_resnet['accuracy'] > results_efficientnet['accuracy'] else "EfficientNetB0"
chosen_model = globals().get('model_efficientnet')  # None when the model was never built
chosen_model_name = "efficientnet_stress_model"

if chosen_model is not None:
    onnx_path = export_to_onnx(chosen_model, chosen_model_name, ONNX_MODEL_DIR)
else:
    print("No model was successfully trained or selected for ONNX export.")
    onnx_path = None

# TensorRT Acceleration (Conceptual - Requires NVIDIA GPU, CUDA, cuDNN, TensorRT installed)
# Typically done during deployment on the target NVIDIA hardware.
# Process:
# 1. Build a TensorRT engine from the ONNX model using `trtexec` command-line tool or TensorRT Python API.
#    `trtexec --onnx=model.onnx --saveEngine=model.trt --fp16` (Example command)
# 2. Use NVIDIA Triton Inference Server or custom C++/Python code with the TensorRT runtime to load the .trt engine for inference.
trt_engine_path = ONNX_MODEL_DIR / (chosen_model_name + '.trt')
print("\nTensorRT Acceleration:")
print("1. Export the ONNX model (done above).")
print("2. Use `trtexec` or TensorRT API on target NVIDIA hardware to build a .trt engine.")
# f-string so the real paths are interpolated; a placeholder is shown when the
# ONNX export did not produce a file.
print(f"   Example: trtexec --onnx={onnx_path or '<model.onnx>'} --saveEngine={trt_engine_path} --fp16")
print("3. Deploy the .trt engine using Triton Inference Server or custom TensorRT runtime application.")

# Final Trained Models (Keras format)
final_resnet_path = str(MODEL_DIR / "resnet_stress_final.keras")
final_efficientnet_path = str(MODEL_DIR / "efficientnet_stress_final.keras")

if 'model_resnet' in locals():
    model_resnet.save(final_resnet_path)
if 'model_efficientnet' in locals():
    model_efficientnet.save(final_efficientnet_path)
print(f"Final Keras models saved to {MODEL_DIR}")

In [None]:
# @title << 3. Model Development – Shift Optimization >>

# --- 3.1 Shift Optimization Engine using Google OR-Tools (CP-SAT) ---

# Define Inputs (Example Data)
# Everything below is synthetic demo data for the CP-SAT model built further down.
num_doctors = 5
num_days = 7
num_shifts = 3 # e.g., Morning (0), Afternoon (1), Night (2)

# Doctor availability (doctor, day): 1 if available, 0 otherwise
# Could be more complex (e.g., preferred days off)
# NOTE: this random matrix is overwritten by the all-ones matrix two lines
# below; the call still consumes random numbers from NumPy's global RNG.
availability = np.random.randint(0, 2, size=(num_doctors, num_days))
# For demo, assume everyone is available every day initially
availability = np.ones((num_doctors, num_days), dtype=int)

# Doctor skills (doctor, skill): 1 if possesses skill
# Example skills: 0=General, 1=Surgery, 2=Cardiology
num_skills = 3
skills = np.random.randint(0, 2, size=(num_doctors, num_skills))
# Ensure at least one doctor per skill for feasibility
# NOTE(review): one doctor per skill may not be enough. The requirements below
# ask for a surgeon and a cardiologist on each of the 7 days, while
# max_shifts_per_week is 5, so a skill held by a single doctor cannot be
# covered for the whole week - confirm whether an occasional INFEASIBLE
# result is acceptable for the demo.
skills[0, 1] = 1 # Doctor 0 has Surgery skill
skills[1, 2] = 1 # Doctor 1 has Cardiology skill
skills[:, 0] = 1 # All doctors have General skill

# Shift requirements (day, shift, skill): Min number of doctors needed
# Making this up for the example
shift_skill_requirements = np.zeros((num_days, num_shifts, num_skills), dtype=int)
shift_skill_requirements[:, :, 0] = 1 # Need 1 general doctor per shift
shift_skill_requirements[:, 0, 1] = 1 # Need 1 surgeon in the morning shift
shift_skill_requirements[:, 1, 2] = 1 # Need 1 cardiologist in the afternoon

# Workload Logs (Conceptual - used to estimate future stress or load)
# Example: workload[doctor, day] = hours worked or complexity score
workload_logs = np.random.uniform(6, 10, size=(num_doctors, num_days))

# Stress Levels (Input from the Facial Stress Detection Model)
# One value per doctor (1-D array of length num_doctors), drawn here from
# [0.1, 0.8) as a stand-in for the model's predicted stress score.
# This would be updated dynamically based on real-time predictions
current_stress_levels = np.random.uniform(0.1, 0.8, size=num_doctors) # Current stress for each doctor

# Constraints (Legal, Policy, Preferences)
max_shifts_per_week = 5
min_shifts_per_week = 2 # Ensure fairness
max_consecutive_shifts = 2
min_rest_between_shifts = 2 # E.g., must have 2 shifts off (16 hours if shifts are 8h) after a shift
max_stress_threshold = 0.7 # Doctors exceeding this should ideally get lighter load or rest

# --- CP-SAT Model ---
# Builds the shift-assignment model from the example inputs defined above and
# solves it. Constraints 1, 5 and 6 are documented but not yet enforced.
model = cp_model.CpModel()

# Decision Variables: shifts[(d, day, s)] = 1 if doctor d works shift s on day, 0 otherwise
shifts = {}
for d in range(num_doctors):
    for day in range(num_days):
        for s in range(num_shifts):
            shifts[(d, day, s)] = model.NewBoolVar(f'shift_d{d}_day{day}_s{s}')

# --- Constraints ---

# 1. Availability: Doctor cannot work if unavailable (if availability matrix is used)
# Kept disabled, as in the original demo:
# for d in range(num_doctors):
#     for day in range(num_days):
#         if availability[d, day] == 0:
#             for s in range(num_shifts):
#                 model.Add(shifts[(d, day, s)] == 0)

# 2. Shift Skill Requirements: Ensure enough skilled doctors are assigned
for day in range(num_days):
    for s in range(num_shifts):
        for skill in range(num_skills):
            # Cast NumPy scalars to plain ints before handing them to the solver.
            required_count = int(shift_skill_requirements[day, s, skill])
            if required_count > 0:
                model.Add(
                    sum(shifts[(d, day, s)] * int(skills[d, skill]) for d in range(num_doctors))
                    >= required_count
                )

# 3. One Shift Per Doctor Per Day (at most): A doctor can only work one shift on a given day
for d in range(num_doctors):
    for day in range(num_days):
        model.Add(sum(shifts[(d, day, s)] for s in range(num_shifts)) <= 1)

# 4. Weekly Workload Limits: Min/Max number of shifts per doctor per week
for d in range(num_doctors):
    total_shifts_worked = sum(shifts[(d, day, s)] for day in range(num_days) for s in range(num_shifts))
    model.Add(total_shifts_worked >= min_shifts_per_week)
    model.Add(total_shifts_worked <= max_shifts_per_week)

# 5. Max Consecutive Shifts: (Slightly more complex constraint)
# Not implemented. The sketch below is fully commented out; the original left
# a live, indented `pass` after it, which raised IndentationError because the
# loop headers above it are only comments.
# for d in range(num_doctors):
#     for day in range(num_days - max_consecutive_shifts):
#         for s in range(num_shifts):
#             # This needs careful indexing if shifts wrap around days/nights
#             # Simplified: Check within a day first
#             # TODO: Implement robust consecutive shift constraint (might need helper variables)
#             pass # Placeholder for brevity


# 6. Minimum Rest Between Shifts: (Also complex, depends on shift timing)
# Similar to consecutive shifts, requires careful indexing across days.
# Example: If doctor works day 'd' shift 's', cannot work day 'd' shift 's+1', 's+2' (if < min_rest)
# and potentially day 'd+1' shift 0 etc.
# TODO: Implement robust min rest constraint

# 7. Stress Constraint (Example): Doctors with high stress should not work night shifts (shift 2)
for d in range(num_doctors):
    if current_stress_levels[d] > max_stress_threshold:
        print(f"Applying stress constraint for Doctor {d} (Stress: {current_stress_levels[d]:.2f})")
        for day in range(num_days):
            model.Add(shifts[(d, day, 2)] == 0) # Cannot work night shift (index 2)

# --- Objective Function ---
# Example Objective: Minimize total number of shifts assigned (implies efficiency)
# OR: Maximize fairness (e.g., minimize variance in shifts worked)
# OR: Minimize assignments for high-stress doctors (weighted)

# Objective: Minimize sum of (stress_level * shift_assignment)
# This encourages assigning fewer shifts to stressed doctors.
objective_terms = []
for d in range(num_doctors):
    # CP-SAT needs integer coefficients: scale the 0-1 stress score to 0-100.
    # The cost depends only on the doctor, so it is computed once per doctor.
    stress_cost = int(current_stress_levels[d] * 100)
    for day in range(num_days):
        for s in range(num_shifts):
            objective_terms.append(shifts[(d, day, s)] * stress_cost)

model.Minimize(sum(objective_terms))

# --- Solve ---
solver = cp_model.CpSolver()
solver.parameters.max_time_in_seconds = 60.0 # Set time limit
solver.parameters.log_search_progress = True
status = solver.Solve(model)

# --- Process Results ---
# On success: print the schedule as a doctor x day table, fill the `schedule`
# and `total_shifts` arrays, and spot-check day 0's skill requirements.
# Otherwise: report the solver status.
if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
    print(f'\nSolution found (Status: {solver.StatusName(status)})')
    print(f'Objective value (Stress-weighted cost): {solver.ObjectiveValue()}')

    schedule = np.zeros((num_doctors, num_days, num_shifts), dtype=int)
    total_shifts = np.zeros(num_doctors, dtype=int)

    print("\nGenerated Schedule:")
    header = "Doctor | " + " | ".join(f"Day {idx}" for idx in range(num_days))
    print(header)
    print("-" * len(header))

    for d in range(num_doctors):
        cells = []
        for day in range(num_days):
            # First shift assigned on this day (at most one is expected,
            # given the one-shift-per-day constraint), or None.
            worked = next(
                (s for s in range(num_shifts) if solver.Value(shifts[(d, day, s)]) == 1),
                None,
            )
            if worked is None:
                cells.append(" --- ")
            else:
                schedule[d, day, worked] = 1
                total_shifts[d] += 1
                cells.append(f"  S{worked} ") # Indicate Shift 0, 1, or 2
        row = f"  {d: <4} | " + "".join(cell + " | " for cell in cells)
        row += f" (Total: {total_shifts[d]}, Stress: {current_stress_levels[d]:.2f})"
        print(row)

    print("\nVerifying Skill Requirements (Example Day 0):")
    day_check = 0
    for s in range(num_shifts):
        for skill in range(num_skills):
            required = shift_skill_requirements[day_check, s, skill]
            if required <= 0:
                continue
            assigned_count = sum(solver.Value(shifts[(d, day_check, s)]) * skills[d, skill] for d in range(num_doctors))
            print(f"Day {day_check}, Shift {s}, Skill {skill}: Required={required}, Assigned={assigned_count} {'(OK)' if assigned_count >= required else '(PROBLEM!)'}")

elif status == cp_model.INFEASIBLE:
    print('\nSolver returned INFEASIBLE. Constraints cannot be satisfied.')
    print("Check constraints, requirements, and availability.")
    # TODO: Implement fallback logic or constraint relaxation here
elif status == cp_model.MODEL_INVALID:
    print('\nSolver returned MODEL_INVALID. Check model definition.')
else:
    print(f'\nSolver returned status: {solver.StatusName(status)}')


# --- 3.2 Reinforcement Learning Agent (Conceptual) ---
# Adaptive scheduling with RL is only outlined here, not implemented:
# - State: current schedule, predicted stress levels, workload backlog, doctor availability/preferences.
# - Action: swap shifts, assign open shifts, modify the schedule based on rules/predictions.
# - Reward: negative of the CP-SAT objective (e.g., -total_stress_cost), penalties for constraint
#   violations, bonuses for smooth operations.
# - Algorithm: Actor-Critic (e.g., PPO, A2C) or DQN if the action space is discrete and manageable.
# - Training: needs a simulator that models hospital operations and stress accumulation,
#   which is non-trivial to build.
print(
    "\nReinforcement Learning for Shift Optimization (Conceptual):",
    "- State: Schedule, Stress Levels, Workload, Availability",
    "- Action: Swap/Assign Shifts",
    "- Reward: -(Stress Cost), Constraint Penalties",
    "- Requires a simulation environment for training.",
    sep="\n",
)

# --- 3.3 Fallback Rule-Based Logic ---
def apply_fallback_rules(schedule, stress_levels, threshold, night_shift=2):
    """Run rule-based safety checks on a schedule and return a copy of it.

    The only rule so far warns (via ``print``) when a doctor whose stress is
    above ``threshold`` has at least one night shift. The returned copy is
    currently unmodified; it is the place where future rules would apply swaps
    or removals.

    Args:
        schedule: Array indexed as ``[doctor, day, shift]``; a positive entry
            marks an assigned shift.
        stress_levels: Per-doctor stress scores, in the same doctor order as
            ``schedule``.
        threshold: Stress value above which a doctor counts as highly stressed.
        night_shift: Index of the night shift on the last axis (default 2).

    Returns:
        A copy of ``schedule``.
    """
    new_schedule = schedule.copy()
    print("\nApplying Fallback Rules...")
    # Walk the doctors actually present in the inputs rather than relying on a
    # module-level doctor count that may not match this schedule.
    for d, (stress, doctor_schedule) in enumerate(zip(stress_levels, new_schedule)):
        if stress > threshold:
            # Rule: If highly stressed, try to remove demanding shifts (e.g., night shifts)
            # This example just prints a warning, but could modify 'new_schedule'
            if np.sum(doctor_schedule[:, night_shift]) > 0: # Check if assigned any night shifts
                print(f"Fallback Rule Warning: Doctor {d} has high stress ({stress:.2f}) and is assigned night shifts.")
                # In reality: Attempt to find a swap or remove the shift, potentially triggering re-optimization
    print("Fallback rule check complete.")
    return new_schedule

# Apply the rule-based safety check only when the solver produced a schedule.
if 'schedule' not in locals():
    print("No initial schedule generated, skipping fallback rules.")
else:
    final_schedule = apply_fallback_rules(schedule, current_stress_levels, STRESS_THRESHOLD)


# --- 3.4 Synthetic Training Data (NSPLib etc.) ---
# Benchmark datasets help when training RL agents or comparing optimization strategies:
# - NSPLib (Nurse Scheduling Problem Library): standard benchmark for nurse rostering.
# - Kaggle datasets on employee scheduling or rostering problems.
# These often provide realistic constraints and instance sizes.
print(
    "\nSynthetic Data Generation:",
    "- Use benchmarks like NSPLib for realistic scheduling problems.",
    "- Generate data by varying num_doctors, num_days, constraints, skill mixes.",
    sep="\n",
)

In [None]:
# @title << 4. System Architecture & Microservices >>

# --- 4.1 System Flowchart (Mermaid) ---
# The Mermaid source is kept in the string literal below; paste it into a Markdown cell or any Mermaid renderer to visualize it.

"""
%%mermaid
graph LR
    subgraph Edge Device (e.g., Camera at Workstation)
        A[Camera Input] --> B(Face Detection / Preprocessing);
        B --> C{Feature Encoding};
        C --> D[Send Features/Cropped Face];
    end

    subgraph Cloud Backend
        D --> E(API Gateway / Load Balancer);
        E --> F[FastAPI Inference Service];
        F -- Request Features --> C;
        F -- Image/Features --> G((Stress Detection Model ONNX/TRT));
        G -- Stress Score --> F;
        F -- Store Stress Log --> H{Time-Series Aggregator};
        H -- Aggregated Stress --> I[SQL Server Database];
        H -- Trigger? --> J(FastAPI Shift Optimizer Service);
        J -- Request Data (Stress, Workload) --> I;
        J -- Request Constraints --> K(Configuration/Rules DB);
        J --> L((OR-Tools/RL Engine));
        L -- Optimized Schedule --> J;
        J -- Store Schedule --> I;
    end

    subgraph Frontend / Users
        M(Flutter Mobile/Web App) --> E;
        M -- View Dashboards --> I;
        M -- Receive Alerts (WebSocket/Polling) --> F;
        M -- Request Manual Optimization --> J;
        N(Admin Interface) --> E;
        N -- Manage Settings --> K;
        N -- View Reports --> I;
    end

    %% Styling (Optional)
    classDef edge fill:#f9f,stroke:#333,stroke-width:2px;
    classDef cloud fill:#ccf,stroke:#333,stroke-width:2px;
    classDef user fill:#cfc,stroke:#333,stroke-width:2px;
    class A,B,C,D edge;
    class E,F,G,H,I,J,K,L cloud;
    class M,N user;
"""

# --- 4.2 Database Schema (SQLAlchemy - Conceptual) ---
# Declares the engine, the shared MetaData registry and four tables (doctors,
# stress_logs, schedules, workload_logs), then creates them in the database.

DB_URL = "sqlite:///hospital_schedule.db" # Example using SQLite for demo; replace with SQL Server connection string
# SQL Server example: "mssql+pyodbc://user:password@server/database?driver=ODBC+Driver+17+for+SQL+Server"
engine = create_engine(DB_URL)
metadata = MetaData()

# Doctors Table
# One row per doctor; the other three tables reference doctors.doctor_id.
doctors_table = Table('doctors', metadata,
    Column('doctor_id', Integer, primary_key=True),
    Column('name', String(100), nullable=False),
    Column('email', String(100), unique=True), # For notifications/login
    Column('specialty', String(100)),
    Column('current_avg_stress', Float, default=0.0), # Updated periodically
    # The function itself (not its result) is passed as the default.
    Column('created_at', DateTime, default=datetime.datetime.utcnow)
)

# Skills Table (if needed for complex skill management)
# skills_table = Table('skills', metadata, ...)
# doctor_skills_table = Table('doctor_skills', metadata, ...) # Many-to-many mapping

# Stress Logs Table
# One row per stress prediction; written by the /predict_stress endpoint.
stress_logs_table = Table('stress_logs', metadata,
    Column('log_id', Integer, primary_key=True),
    Column('doctor_id', Integer, ForeignKey('doctors.doctor_id'), nullable=False),
    Column('timestamp', DateTime, default=datetime.datetime.utcnow, index=True),
    Column('raw_stress_score', Float, nullable=False), # Output from the model
    Column('smoothed_stress_score', Float), # Optional: after temporal smoothing
    Column('image_ref', String(255)), # Optional reference to anonymized image/features for audit (use with caution!)
    Column('source_device', String(50)) # e.g., 'Workstation-1A'
)

# Schedules Table
# One row per (doctor, date, shift) assignment.
schedules_table = Table('schedules', metadata,
    Column('schedule_id', Integer, primary_key=True),
    Column('doctor_id', Integer, ForeignKey('doctors.doctor_id'), nullable=False),
    Column('schedule_date', DateTime, nullable=False, index=True),
    Column('shift_type', Integer, nullable=False), # 0=Morning, 1=Afternoon, 2=Night
    Column('assigned_at', DateTime, default=datetime.datetime.utcnow),
    Column('schedule_version', Integer, default=1) # To track optimization runs
)

# Workload Logs Table (Example)
workload_logs_table = Table('workload_logs', metadata,
     Column('log_id', Integer, primary_key=True),
     Column('doctor_id', Integer, ForeignKey('doctors.doctor_id'), nullable=False),
     Column('log_date', DateTime, nullable=False),
     Column('hours_worked', Float),
     Column('tasks_completed', Integer),
     # Other relevant workload metrics
)


# Create tables in the database
# Runs at import/cell-execution time against the engine configured above.
print("\nCreating database schema (if it doesn't exist)...")
metadata.create_all(engine)
print("Database schema defined.")

# --- 4.3 FastAPI Microservice (Inference & Basic Scheduling Trigger) ---

# Application object; the endpoints below register themselves on it via decorators.
app = fastapi.FastAPI(title="Doctor Stress & Schedule API")

# --- Pydantic Models for API Requests/Responses ---
# Request body of POST /predict_stress.
# (Documented with comments rather than a class docstring so the generated
# API schema is left untouched.)
class ImageInput(BaseModel):
    image_b64: str # Base64 encoded image string
    # Doctor the image belongs to; stored with the stress log entry.
    doctor_id: int
    # Optional device label; stored in the stress log's source_device column.
    source_device: str | None = None

# Response body of POST /predict_stress.
class StressPredictionResponse(BaseModel):
    doctor_id: int
    # Model output; the endpoint uses -1.0 to signal a failed prediction.
    stress_probability: float
    # True when stress_probability exceeds STRESS_THRESHOLD.
    is_stressed: bool
    timestamp: datetime.datetime
    # Human-readable status, including whether the stress log was saved.
    message: str

# Request body of POST /optimize_schedule.
class OptimizeScheduleRequest(BaseModel):
    target_date: datetime.date
    # NOTE(review): not read by the /optimize_schedule handler in this file - confirm intended use.
    force_run: bool = False # Option to force re-optimization

# Response body of POST /optimize_schedule.
class OptimizeScheduleResponse(BaseModel):
    # The handler returns "Success" or "Failed".
    status: str
    message: str
    schedule_version: int | None = None # Version ID of the generated schedule

# --- Load ONNX Model for Inference ---
# Module-level inference state shared by the API endpoints. All names start as
# None; they are filled in only when the exported ONNX file can be loaded.
onnx_model = onnx_session = onnx_input_name = onnx_output_name = None

if not (onnx_path and os.path.exists(onnx_path)):
    print("ONNX model file not found or not generated. Inference endpoint will be disabled.")
else:
    try:
        # Providers are listed in preference order: CUDA first, then CPU.
        onnx_session = ort.InferenceSession(
            onnx_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        onnx_input_name = onnx_session.get_inputs()[0].name
        onnx_output_name = onnx_session.get_outputs()[0].name
        print(f"ONNX stress detection model loaded successfully for API. Input: {onnx_input_name}, Output: {onnx_output_name}")
    except Exception as e:
        print(f"Error loading ONNX model for API: {e}")
        # A missing session is what the inference endpoint checks for.
        onnx_session = None

# --- Helper Functions ---
def decode_image(b64_string):
    """Decode a Base64-encoded image into an OpenCV colour image.

    Args:
        b64_string: Base64 text of an encoded image file.

    Returns:
        The image as returned by ``cv2.imdecode`` with ``cv2.IMREAD_COLOR``,
        or ``None`` when the payload decodes to zero bytes.
    """
    image_bytes = b64decode(b64_string)
    if not image_bytes:
        # An empty buffer makes cv2.imdecode raise instead of returning None;
        # report it the same way as an undecodable image.
        return None
    image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    return image

def preprocess_for_onnx(image, target_size):
    """Turn a decoded image into a float32 batch-of-one tensor for the ONNX model.

    Face detection is attempted first. When no face is found the whole image is
    used instead, which may be less accurate.
    """
    detected_face, _ = detect_face(image) # Use BGR image
    face_missing = detected_face is None or detected_face.size == 0
    if face_missing:
        print("No face detected in API input.")
        # Alternatives not taken here: raise ValueError("No face detected"),
        # or return None to signal a low-confidence result.
        detected_face = image

    # Same resize/normalisation helper as used for training.
    normalized = preprocess_image(detected_face, target_size) # Returns float32 [0,1]

    # Prepend the batch axis, then force the dtype.
    batched = np.expand_dims(normalized, axis=0)
    return batched.astype(np.float32)


# --- API Endpoints ---
# Root endpoint: returns a static welcome payload.
@app.get("/")
async def read_root():
    welcome = {"message": "Welcome to the Doctor Stress Detection and Shift Optimization API"}
    return welcome

@app.post("/predict_stress", response_model=StressPredictionResponse)
async def predict_stress(input_data: ImageInput):
    """Receives Base64 encoded image, performs inference, returns stress probability."""
    # Status codes produced by this handler:
    #   503 - the ONNX session is not loaded
    #   400 - the image cannot be decoded, or a ValueError is raised while processing
    #   500 - any other unexpected failure
    # A failed stress-log write is reported in `message` but does not fail the request.
    if onnx_session is None or onnx_input_name is None:
        raise fastapi.HTTPException(status_code=503, detail="Stress detection model not loaded.")

    try:
        start_time = time.time()
        image_bgr = decode_image(input_data.image_b64)

        if image_bgr is None or image_bgr.size == 0:
            raise fastapi.HTTPException(status_code=400, detail="Invalid image data.")

        # Preprocess image for the ONNX model
        input_tensor = preprocess_for_onnx(image_bgr, IMG_SIZE) # IMG_SIZE defined earlier
        if input_tensor is None:
            # Handle case where preprocessing failed (e.g., no face detected)
            return StressPredictionResponse(
                doctor_id=input_data.doctor_id,
                stress_probability=-1.0, # Indicate failure/low confidence
                is_stressed=False,
                timestamp=datetime.datetime.utcnow(),
                message="Prediction failed (e.g., face not detected)."
            )

        # Run inference
        ort_outs = onnx_session.run([onnx_output_name], {onnx_input_name: input_tensor})
        stress_prob = float(ort_outs[0][0][0]) # Output is likely [[prob]]

        is_stressed = stress_prob > STRESS_THRESHOLD
        processing_time = time.time() - start_time
        # Single timestamp so the stored log and the response agree.
        predicted_at = datetime.datetime.utcnow()

        # --- Store stress log in DB (Asynchronous task recommended for production) ---
        try:
            # The context manager closes the connection even when execute/commit fails.
            with engine.connect() as conn:
                conn.execute(
                    stress_logs_table.insert().values(
                        doctor_id=input_data.doctor_id,
                        timestamp=predicted_at,
                        raw_stress_score=stress_prob,
                        source_device=input_data.source_device,
                    )
                )
                conn.commit()
            db_log_msg = "Stress log saved."
        except Exception as db_err:
            # Best effort: a logging failure must not fail the prediction.
            print(f"Database Error logging stress: {db_err}")
            db_log_msg = "Failed to save stress log."
            # Consider retry logic or logging to a file as fallback

        return StressPredictionResponse(
            doctor_id=input_data.doctor_id,
            stress_probability=stress_prob,
            is_stressed=is_stressed,
            timestamp=predicted_at,
            message=f"Prediction successful ({processing_time:.3f}s). {db_log_msg}"
        )

    except fastapi.HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) through unchanged;
        # otherwise the generic handler below would turn them into a 500.
        raise
    except ValueError as ve: # Catch specific errors like no face detected
        raise fastapi.HTTPException(status_code=400, detail=str(ve)) from ve
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise fastapi.HTTPException(status_code=500, detail=f"Internal server error during prediction: {e}") from e


@app.post("/optimize_schedule", response_model=OptimizeScheduleResponse)
async def trigger_schedule_optimization(request: OptimizeScheduleRequest):
    """Triggers the shift optimization process (placeholder)."""
    # Placeholder flow: reads each doctor's current_avg_stress from the DB,
    # builds the stress vector the solver would use, and returns a simulated
    # schedule version. No solver is run and no schedule is stored yet.
    # Any exception is reported as HTTP 500.
    print(f"Received request to optimize schedule for date: {request.target_date}")
    # --- In a real system: ---
    # 1. Check if optimization is already running or recently completed.
    # 2. Fetch required data (current stress levels, workload, availability) from the database.
    # 3. Fetch constraints from config or DB.
    # 4. Run the OR-Tools solver (potentially as a background task using Celery, RQ, or FastAPI's BackgroundTasks).
    # 5. Store the new schedule in the database with a new version number.
    # 6. Handle infeasible solutions (alert admin, use fallback).

    # --- Placeholder Implementation ---
    print("Simulating optimization run...")
    try:
        # --- Fetch current average stress from DB (Example) ---
        # This part needs refinement - should likely average recent stress logs
        avg_stress_data = {}
        select_stmt = db.select(doctors_table.c.doctor_id, doctors_table.c.current_avg_stress)
        # The context manager closes the connection even if the query fails.
        with engine.connect() as conn:
            for row in conn.execute(select_stmt):
                doctor_id, avg_stress = row[0], row[1] # Column order of select_stmt
                # current_avg_stress is not declared NOT NULL; skip NULLs so the
                # default below applies instead of putting None into the array.
                if avg_stress is not None:
                    avg_stress_data[doctor_id] = avg_stress

        # Update the `current_stress_levels` used by the solver
        simulated_current_stress = np.array([avg_stress_data.get(d, 0.3) for d in range(num_doctors)]) # Default stress if not found
        print("Fetched/Simulated Stress Levels for Optimization:", simulated_current_stress)

        # --- Run OR-Tools Solver (using fetched data) ---
        # This would ideally be a function call:
        # success, objective_value, final_schedule, schedule_version = run_or_tools_solver(simulated_current_stress, ...)
        # For demo, we just report success
        success = True # Assume it worked for the demo
        simulated_version = random.randint(100, 999)

        if not success:
            return OptimizeScheduleResponse(
                status="Failed",
                message=f"Shift optimization failed or was infeasible for {request.target_date}.",
                schedule_version=None
            )

        # --- Store the generated schedule in the DB ---
        # This needs the actual 'final_schedule' array from solver
        # Loop through final_schedule and insert into 'schedules_table'
        # ... (DB insertion logic) ...
        print("Placeholder: Schedule would be saved to DB here.")

        return OptimizeScheduleResponse(
            status="Success",
            message=f"Shift optimization completed successfully for {request.target_date}. Schedule Version: {simulated_version}",
            schedule_version=simulated_version
        )

    except Exception as e:
        print(f"Error during schedule optimization trigger: {e}")
        raise fastapi.HTTPException(status_code=500, detail=f"Internal server error during optimization: {e}") from e


# --- Run FastAPI app using Uvicorn and ngrok (for Colab) ---
# Function to run FastAPI in a separate thread and expose via ngrok
def run_fastapi():
    """Open an ngrok tunnel to the local port and serve ``app`` with uvicorn.

    Meant to be used as a thread target; the final ``uvicorn.run`` call is the
    last statement and starts the server on all interfaces.
    """
    # Allow running uvicorn inside Colab's already-active event loop.
    nest_asyncio.apply()
    # Optional: an ngrok authtoken (from the ngrok dashboard) may be set here,
    # usually only needed for more features or longer sessions:
    # ngrok.set_auth_token("YOUR_NGROK_AUTHTOKEN")
    listen_port = 8000
    # Tear down any tunnels that are still around before opening a new one.
    ngrok.kill()
    tunnel = ngrok.connect(listen_port)
    print(f"FastAPI running on: {tunnel}")
    uvicorn.run(app, host="0.0.0.0", port=listen_port, log_level="info")

# Serve the API from a daemon thread so the remaining notebook cells keep executing.
print("\nStarting FastAPI server in background thread...")
fastapi_thread = threading.Thread(daemon=True, target=run_fastapi)
fastapi_thread.start()
# Fixed pause (not a readiness check) to give the tunnel and server time to start.
time.sleep(5)
print("FastAPI setup complete. Public URL should be printed above.")
# Note: The server will run until the Colab runtime is disconnected or the thread is stopped.

In [None]:
# @title << 5. Frontend Development with Flutter (Conceptual Outline) >>

# Flutter is a UI toolkit for building natively compiled applications for mobile, web, and desktop from a single codebase.
# We cannot run Flutter code directly in Colab, but we can outline the key components and interactions.

"""
--- Flutter Frontend Overview ---

1.  **Project Setup:**
    * Create a new Flutter project: `flutter create hospital_stress_dashboard`
    * Add dependencies to `pubspec.yaml`:
        * `http` or `dio`: For making REST API calls to the FastAPI backend.
        * `provider` or `flutter_bloc`: For state management.
        * `syncfusion_flutter_charts` or `fl_chart`: For displaying dashboards and graphs.
        * `web_socket_channel`: For real-time alerts (if using WebSockets).
        * `intl`: For date/time formatting.
        * `shared_preferences`: For storing basic user settings or tokens.

2.  **Core Widgets/Screens:**
    * `LoginScreen`: Handles user authentication (e.g., using email/password, OAuth against the backend).
    * `DashboardScreen`: Main screen showing:
        * Overall stress level overview (average, trends).
        * List of doctors with current stress scores (color-coded).
        * Upcoming schedule view (e.g., weekly calendar).
        * Alerts section.
    * `DoctorDetailScreen`: Shows detailed stress history (graph), recent logs, current schedule for a specific doctor.
    * `ScheduleScreen`: Displays the full schedule, possibly with filtering options (by date, doctor, shift). Allows admins to trigger manual optimization.
    * `SettingsScreen`: App settings, notification preferences.

3.  **Backend Integration (Services/Repositories):**
    * `ApiService` class:
        * `login(email, password)` -> Returns user info/token.
        * `getStressLevels()` -> Fetches aggregated stress data. Calls `GET /doctors` or a dedicated stats endpoint.
        * `getDoctorStressLog(doctorId)` -> Fetches `GET /stress_logs?doctor_id=...`.
        * `getCurrentSchedule()` -> Fetches `GET /schedules?date=...`.
        * `triggerOptimization(date)` -> Calls `POST /optimize_schedule`.
        * (If Edge processing is simulated/tested): `predictStress(imageBase64, doctorId)` -> Calls `POST /predict_stress`.
    * Use `http` or `dio` package to make calls to the FastAPI backend URL (the ngrok URL during development/testing in Colab). Handle responses, errors, and authentication headers (e.g., Bearer token).

4.  **State Management (`Provider` Example):**
    * `DoctorProvider`: Manages the list of doctors and their current stress states.
    * `ScheduleProvider`: Manages the current schedule data.
    * `AuthProvider`: Manages user authentication state.
    * Widgets listen to providers using `Consumer` or `context.watch` to rebuild when data changes.

5.  **Real-time Alerts:**
    * **Option 1 (WebSockets):**
        * Backend (FastAPI) needs WebSocket endpoint (`@app.websocket("/ws")`).
        * When high stress is detected or schedule changes, backend pushes message via WebSocket.
        * Flutter app connects using `web_socket_channel` and listens for messages, updating UI or showing notifications.
    * **Option 2 (Polling):**
        * Flutter app periodically calls API endpoints (e.g., `/stress_levels`, `/notifications`) every X seconds/minutes. Simpler but less efficient.

6.  **UI Components:**
    * Use `ListView.builder` to display lists of doctors/schedule entries.
    * Use chart libraries (`syncfusion_flutter_charts`, `fl_chart`) to render stress trend graphs and schedule visualizations.
    * Use `Card`, `ListTile`, `AppBar`, `BottomNavigationBar` for layout.
    * Implement visual cues for stress levels (e.g., green/yellow/red indicators).

7.  **Example Snippet (Conceptual API Call with `http`):**

    ```dart
    // // conceptual_api_service.dart (Flutter/Dart code)
    // import 'package:http/http.dart' as http;
    // import 'dart:convert';

    // class ApiService {
    //   final String _baseUrl = "YOUR_FASTAPI_BASE_URL"; // e.g., ngrok URL

    //   Future<Map<String, dynamic>> getDoctorStress(int doctorId) async {
    //     final response = await http.get(
    //       Uri.parse('$_baseUrl/stress_logs?doctor_id=$doctorId&latest=true'), // Example endpoint
    //       headers: {'Authorization': 'Bearer YOUR_AUTH_TOKEN'}, // If auth is needed
    //     );

    //     if (response.statusCode == 200) {
    //       return jsonDecode(response.body); // Expecting JSON response
    //     } else {
    //       throw Exception('Failed to load stress data: ${response.statusCode}');
    //     }
    //   }

    //    Future<Map<String, dynamic>> triggerOptimization(DateTime date) async {
    //      final response = await http.post(
    //        Uri.parse('$_baseUrl/optimize_schedule'),
    //        headers: {
    //          'Content-Type': 'application/json; charset=UTF-8',
    //          'Authorization': 'Bearer YOUR_AUTH_TOKEN',
    //        },
    //        body: jsonEncode(<String, dynamic>{
    //          'target_date': date.toIso8601String().substring(0, 10), // Format as YYYY-MM-DD
    //          'force_run': false,
    //        }),
    //      );

    //      if (response.statusCode == 200) {
    //        return jsonDecode(response.body);
    //      } else {
    //        throw Exception('Failed to trigger optimization: ${response.statusCode} ${response.body}');
    //      }
    //    }
    // }
    ```

--- End Conceptual Outline ---
"""
# Closing notice for the Flutter outline cell; both lines go out in a single call.
print(
    "Flutter Frontend conceptual outline generated.",
    "Actual implementation requires Flutter SDK and IDE (like VS Code or Android Studio).",
    sep="\n",
)

In [None]:
# @title << 6. Gemini 2.0 Integration >>

# Ensure the google-generativeai library is installed and API key is configured.
# Done in Setup cell.

# --- Initialize Gemini Model ---
# Use a model suitable for text generation/analysis (e.g., 'gemini-1.5-flash' or 'gemini-pro')
# NOTE(review): this cell is titled "Gemini 2.0" but the code pins 'gemini-1.5-flash';
# confirm which model ID is intended and available for the configured API key.
try:
    gemini_model = genai.GenerativeModel('gemini-1.5-flash') # Or 'gemini-pro'
    print("Gemini Generative Model initialized.")
except Exception as e:
    # On failure, fall back to None so later code can test `if gemini_model:` and skip Gemini calls.
    print(f"Error initializing Gemini Model: {e}")
    print("Ensure API key is valid and configured.")
    gemini_model = None


# --- Function to Get Insights from Stress Data using Gemini ---
def _format_stress_log_line(log):
    """Render one stress-log mapping as a single bullet line for the Gemini prompt."""
    timestamp = log.get('timestamp', 'N/A')
    if isinstance(timestamp, datetime.datetime):
        timestamp_str = timestamp.strftime('%Y-%m-%d %H:%M')
    else:
        timestamp_str = str(timestamp)

    # Prefer the DB column name; fall back to the shorter 'score' key so logs
    # shaped either way are accepted.
    score = log.get('raw_stress_score', log.get('score'))
    try:
        score_str = f"{score:.2f}"
    except (TypeError, ValueError):
        # A missing, NULL or non-numeric score must not abort prompt construction.
        score_str = "N/A" if score is None else str(score)

    return f"- Timestamp: {timestamp_str}, Score: {score_str}\n"


def get_stress_insights_with_gemini(doctor_id, current_stress_score, recent_stress_logs=None):
    """
    Generates natural language insights about a doctor's stress using Gemini.

    Reads the module-level ``gemini_model`` (returns early when it is falsy) and
    ``STRESS_THRESHOLD`` (interpolated into the prompt).

    Args:
        doctor_id (int): The ID of the doctor.
        current_stress_score (float): The latest predicted stress score (0-1).
        recent_stress_logs (list[dict], optional): Recent logs shaped like
            [{'timestamp': ..., 'raw_stress_score': ...}]. The key 'score' is
            accepted as a fallback. Only the last five entries are used, and
            entries without a numeric score are rendered as "N/A".

    Returns:
        str: Natural language insights or an error message. API errors are
        caught and reported in the returned string rather than raised.
    """
    if not gemini_model:
        return "Gemini model not available."

    # --- Construct the Prompt ---
    prompt = f"""
    Analyze the stress level of Doctor ID {doctor_id}.

    Current Situation:
    - The doctor's latest predicted stress score is {current_stress_score:.2f} (where 0 is no stress, 1 is high stress).
    - A score above {STRESS_THRESHOLD} is considered potentially high stress requiring attention.

    """

    if recent_stress_logs:
        prompt += "\nRecent Stress Trend (last few readings):\n"
        # Show last 5 logs max to keep the prompt short.
        prompt += "".join(_format_stress_log_line(log) for log in recent_stress_logs[-5:])
    else:
        prompt += "\nNo recent stress trend data available.\n"

    prompt += f"""
    Based on this information:
    1. Briefly assess the current stress level (e.g., Low, Moderate, High, Critical).
    2. If the stress is high (above {STRESS_THRESHOLD}) or shows a concerning upward trend, suggest potential contributing factors (e.g., workload, consecutive shifts - you can infer possibilities).
    3. Recommend one or two brief, actionable insights or considerations for the scheduling system or supervisor (e.g., 'Consider assigning shorter shifts', 'Monitor closely', 'Ensure adequate rest period before next shift', 'Seems stable').

    Keep the response concise and professional for a hospital operations context. Do not give medical advice.
    """

    # --- Call Gemini API ---
    try:
        print(f"\n--- Calling Gemini for Doctor {doctor_id} (Stress: {current_stress_score:.2f}) ---")
        # print("Prompt:", prompt) # Uncomment to debug prompt

        response = gemini_model.generate_content(prompt)

        # --- Process Response ---
        # Only read response.text once a candidate with content parts is present.
        if response.candidates and hasattr(response.candidates[0], 'content') and response.candidates[0].content.parts:
            insight = response.text
            print(f"Gemini Insight for Doctor {doctor_id}:\n{insight}")
            return insight

        # Handle cases where the response is blocked or empty.
        print(f"Gemini Warning: Received no content or response was blocked for Doctor {doctor_id}.")
        feedback = getattr(response, 'prompt_feedback', None)
        if feedback:
            print(f"Prompt Feedback: {feedback}")
        return f"Gemini could not generate insights for Doctor {doctor_id}. The request might have been blocked or returned no content."

    except Exception as e:
        print(f"Error calling Gemini API: {e}")
        return f"Error communicating with Gemini: {e}"

# --- Example Usage (using data from previous steps) ---

# Fetch some recent logs for a doctor (Example Doctor 0)
doctor_to_analyze = 0
fetched_logs = []
try:
    # Newest-first, capped at five rows.
    select_logs = (
        db.select(stress_logs_table.c.timestamp, stress_logs_table.c.raw_stress_score)
        .where(stress_logs_table.c.doctor_id == doctor_to_analyze)
        .order_by(stress_logs_table.c.timestamp.desc())
        .limit(5)
    )
    # The context manager closes the connection even when execute() or the row
    # conversion raises; the previous explicit close() only ran on success.
    with engine.connect() as conn:
        results = conn.execute(select_logs)
        # Use .mappings().all() for easy dictionary conversion if using newer SQLAlchemy versions
        # Manual conversion for broader compatibility:
        fetched_logs = [{'timestamp': row[0], 'raw_stress_score': row[1]} for row in results]

    print(f"\nFetched {len(fetched_logs)} recent logs for Doctor {doctor_to_analyze}.")
except Exception as e:
    # Best-effort demo: report the failure and continue with an empty log list.
    print(f"Error fetching recent logs for Gemini example: {e}")


# Get the current stress score used in optimization (or predict again if needed).
# At notebook top level locals() is the module namespace, so these checks look
# for variables left behind by earlier cells.
if 'simulated_current_stress' in locals():
    current_score_example = simulated_current_stress[doctor_to_analyze]
elif 'current_stress_levels' in locals():
    current_score_example = current_stress_levels[doctor_to_analyze]
else:
    current_score_example = 0.5 # Default example score


# Call the Gemini function
if gemini_model:
    gemini_insight = get_stress_insights_with_gemini(
        doctor_id=doctor_to_analyze,
        current_stress_score=current_score_example,
        recent_stress_logs=fetched_logs
    )
else:
    print("Skipping Gemini insight generation as model is not initialized.")


# --- Conversational Assistant (Conceptual) ---
# Gemini can power a chatbot for hospital staff:
# - Staff could ask: "Show Dr. Smith's stress trend." -> App queries DB, formats data, sends to Gemini for summarization.
# - Staff could ask: "What are the constraints for scheduling night shifts?" -> Gemini accesses configured rules/knowledge base.
# - Staff could ask: "Suggest a replacement for Dr. Jones' shift tomorrow afternoon." -> Gemini interacts with the optimization engine/DB to find suitable candidates.

def ask_gemini_assistant(query, context_data=None):
    """Send a staff question (plus optional JSON-serialised context) to Gemini.

    Returns the model's text answer, or a fixed fallback string when the model
    is unavailable, returns no content, or the API call raises.
    """
    if not gemini_model:
        return "Gemini assistant not available."

    # Assemble the prompt from its sections, then join once.
    sections = [
        "You are a helpful assistant for hospital staff using a stress detection and scheduling system.\n",
        f"User Query: '{query}'\n\n",
    ]
    if context_data:
        # default=str stringifies values json cannot encode natively.
        serialized_context = json.dumps(context_data, indent=2, default=str)
        sections.append(f"Relevant Context Data:\n{serialized_context}\n\n")
    sections.append("Provide a helpful and concise answer based on the query and context.")
    prompt = "".join(sections)

    try:
        print("\n--- Calling Gemini Assistant ---")
        response = gemini_model.generate_content(prompt)

        candidates = response.candidates
        has_content = bool(
            candidates
            and hasattr(candidates[0], 'content')
            and candidates[0].content.parts
        )
        if not has_content:
            print("Gemini Assistant Warning: No content or blocked response.")
            return "Assistant could not process the query."

        answer = response.text
        print(f"Gemini Assistant Response:\n{answer}")
        return answer
    except Exception as e:
        print(f"Error calling Gemini Assistant API: {e}")
        return f"Error communicating with Gemini: {e}"

# Example assistant query, using a hard-coded Day 1 schedule as context.
# In a real app the schedule data for Day 1 would be fetched instead.
assistant_query = "Summarize the current schedule status for Day 1."
schedule_context = {"Day1_Schedule": "Dr. 0 (S0), Dr. 1 (S1), Dr. 2 (S2), Dr. 3 (S0), Dr. 4 (S1)"}
if not gemini_model:
    print("Skipping Gemini assistant query as model is not initialized.")
else:
    assistant_response = ask_gemini_assistant(
        query=assistant_query,
        context_data=schedule_context,
    )

In [None]:
# @title << 7. Privacy, Security & Deployment >>

# --- 7.1 Privacy & Consent ---
# Privacy guidance, emitted as one newline-joined write (text unchanged).
print("\n".join((
    "\n--- Privacy Considerations ---",
    "1.  **Explicit Consent:** Absolutely crucial. Before monitoring any facial expressions:",
    "    - Obtain clear, informed, written consent from each doctor.",
    "    - Explain exactly WHAT data is collected (images, features, stress scores), HOW it's used (scheduling, anonymized reporting), WHO can access it, and HOW LONG it's stored.",
    "    - Ensure consent is specific, granular (if possible), and easily revocable.",
    "2.  **Anonymization/De-identification:**",
    "    - **Edge Processing:** Ideally, perform face detection and feature extraction on the edge device. Only send numerical features or anonymized/cropped face images (if strictly necessary) to the backend.",
    "    - **Data Minimization:** Collect only the minimum data needed. Do you need full video, or just snapshots? Do you need high-res images, or can lower-res suffice for feature extraction?",
    "    - **Feature Focus:** Store landmarks, stress scores, and metadata. AVOID storing raw facial images long-term unless essential for regulatory/audit reasons (with strict access controls).",
    "    - **Aggregation:** Report on team/department stress levels in aggregate to avoid singling out individuals in general reports.",
    "3.  **Data Security (see below):** Protect the data collected.",
    "4.  **Purpose Limitation:** Use the data ONLY for the consented purposes (stress detection for scheduling optimization and well-being support). Do not repurpose for performance evaluation or disciplinary actions unless explicitly consented to and ethically reviewed.",
    "5.  **Transparency:** Allow doctors to view their own stress data and understand how it influences their schedule.",
    "6.  **Compliance:** Adhere to relevant regulations (e.g., GDPR in Europe, HIPAA in the US). Consult legal/privacy experts.",
    "7.  **Ethical Review:** Consider review by an institutional ethics committee due to the sensitive nature of biometric data and potential impact on staff.",
)))

# --- 7.2 Security Measures ---
# Security checklist, emitted as one newline-joined write (text unchanged).
print("\n".join((
    "\n--- Security Measures ---",
    "1.  **Authentication & Authorization:**",
    "    - Secure API endpoints (FastAPI): Use OAuth2 (e.g., with JWT Bearer tokens) or API Keys for server-to-server communication.",
    "    - Implement role-based access control (RBAC). Admins can trigger optimization, view all data; doctors can view own data; system components have specific permissions.",
    "    - Secure Frontend Login: Use robust password hashing (e.g., bcrypt), consider Multi-Factor Authentication (MFA).",
    "2.  **Data Encryption:**",
    "    - **In Transit:** Use HTTPS (TLS/SSL) for all API communication (ngrok provides this for Colab tunnels; use load balancers/reverse proxies like Nginx/Traefik in production).",
    "    - **At Rest:** Encrypt sensitive data in the database (SQL Server TDE, application-level encryption for specific fields). Encrypt backups.",
    "3.  **Input Validation:** Sanitize and validate all inputs to API endpoints (Pydantic helps in FastAPI) to prevent injection attacks.",
    "4.  **Secrets Management:** DO NOT hardcode API keys, database credentials, or secret keys. Use environment variables, Docker secrets, or dedicated secrets managers (GCP Secret Manager, AWS Secrets Manager, HashiCorp Vault).",
    "5.  **Secure Dependencies:** Regularly scan dependencies for vulnerabilities (e.g., `pip-audit`, `safety`, GitHub Dependabot/Snyk).",
    "6.  **Rate Limiting:** Protect API endpoints from abuse by implementing rate limiting.",
    "7.  **Logging & Monitoring:** Log security events (login attempts, access violations, errors). Monitor system health and performance.",
    "8.  **Container Security:** Scan Docker images for vulnerabilities. Run containers with least privilege.",
)))

# --- 7.3 Deployment Strategy ---
# Opening lines of the deployment section, written with a single print call.
print(
    "\n--- Deployment Strategy ---",
    "1.  **Containerization (Docker):**",
    "    - Create `Dockerfile` for the FastAPI backend (including ONNX runtime, OR-Tools, etc.).",
    sep="\n",
)
# Example Dockerfile structure (conceptual):
"""
# Dockerfile Example (Conceptual)
# Use an appropriate base image (e.g., Python slim with build tools)
# FROM python:3.10-slim

# WORKDIR /app

# Install system dependencies (like build-essential for some Python packages, ODBC drivers for SQL Server)
# RUN apt-get update && apt-get install -y --no-install-recommends \
#    build-essential \
#    unixodbc-dev \
#    && rm -rf /var/lib/apt/lists/*
# (unixodbc-dev is an example for pyodbc; add commands to install the Microsoft ODBC Driver if needed)

# Copy requirements and install Python packages
# COPY requirements.txt .
# RUN pip install --no-cache-dir -r requirements.txt

# Copy application code, models, etc.
# COPY . .
# Make sure ONNX models are copied into the image or loaded from a volume

# Expose the port FastAPI runs on
# EXPOSE 8000

# Command to run the application
# CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] # Assuming your FastAPI app object is in main.py
"""
# Remainder of the deployment guidance, emitted as one newline-joined write (text unchanged).
print("\n".join((
    "    - Create `Dockerfile` for any other services (e.g., time-series aggregator if separate).",
    "    - Use Docker Compose for local development/testing setup.",
    # Cloud platform options
    "2.  **Cloud Platform (GCP or AWS):**",
    "    - **Compute:**",
    "        - **GCP:** Google Kubernetes Engine (GKE) for orchestration, Cloud Run for serverless containers (good for stateless APIs like inference if cold starts are acceptable), Vertex AI Endpoints for managed ML model serving (supports ONNX, TF, PyTorch, custom containers, includes scaling, monitoring).",
    "        - **AWS:** Elastic Kubernetes Service (EKS), Elastic Container Service (ECS), Fargate (serverless containers), SageMaker Endpoints (managed ML serving).",
    "    - **Database:**",
    "        - **GCP:** Cloud SQL (managed PostgreSQL, MySQL, SQL Server), AlloyDB.",
    "        - **AWS:** RDS (managed instances), Aurora.",
    "    - **API Gateway:**",
    "        - **GCP:** API Gateway or Cloud Load Balancer + IAP (Identity-Aware Proxy).",
    "        - **AWS:** API Gateway.",
    "    - **Monitoring & Logging:**",
    "        - **GCP:** Cloud Monitoring, Cloud Logging.",
    "        - **AWS:** CloudWatch.",
    "    - **Secrets Management:** GCP Secret Manager, AWS Secrets Manager.",
    # CI/CD
    "3.  **CI/CD Pipeline:**",
    "    - Use tools like GitHub Actions, GitLab CI, Jenkins, Google Cloud Build, AWS CodePipeline.",
    "    - Automate testing (unit, integration, API tests), security scanning, building Docker images, and deploying to staging/production environments.",
    # Model deployment and MLOps
    "4.  **Model Deployment & MLOps:**",
    "    - **Model Registry:** Use MLflow, Vertex AI Model Registry, or SageMaker Model Registry to version and track models.",
    "    - **Inference Optimization:** Use ONNX Runtime, TensorRT (on NVIDIA GPUs), TensorFlow Lite (for edge/mobile if needed).",
    "    - **Monitoring:** Track model performance (accuracy, F1, AUC), prediction drift, data drift, inference latency, resource utilization.",
    "    - **Retraining Strategy:** Define triggers for retraining (e.g., performance degradation, new data available) and automate the retraining pipeline.",
    # Edge deployment
    "5.  **Edge Deployment (Optional):**",
    "    - If face detection/feature extraction runs on-site: Use smaller models (MobileNet, EfficientNet-Lite), TensorFlow Lite, ONNX Runtime for Edge.",
    "    - Manage edge devices using IoT platforms (GCP IoT Core - *being deprecated, consider alternatives*, AWS IoT Core).",
)))


# --- 7.4 Dockerfile Example Snippet ---
# Example Dockerfile kept as a string so the notebook can print it below.
# ENV uses the `KEY=value` form: the whitespace-separated `ENV KEY value`
# syntax is a legacy format that Docker's documentation discourages.
dockerfile_content = """
# Use an official Python runtime as a parent image
FROM python:3.10-slim

# Set environment variables (optional but good practice)
ENV PYTHONDONTWRITEBYTECODE=1
ENV PYTHONUNBUFFERED=1

# Set the working directory in the container
WORKDIR /app

# Install system dependencies if needed (e.g., for opencv, onnxruntime, pyodbc)
# Example: RUN apt-get update && apt-get install -y --no-install-recommends libgl1-mesa-glx libglib2.0-0 && rm -rf /var/lib/apt/lists/*
# Example for SQL Server ODBC: RUN apt-get update && apt-get install -y curl apt-transport-https gnupg2 unixodbc-dev ... && #<commands to install msodbcsql17>

# Install pip requirements
# Copy only requirements first to leverage Docker cache
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application code
COPY . .

# Ensure ONNX model is copied or accessible (e.g., via volume mount)
# COPY ./onnx_models /app/onnx_models

# Expose the port the app runs on
EXPOSE 8000

# Define the command to run the application
# Assumes your FastAPI app instance is named 'app' in a file named 'main.py'
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
"""
# Print the example wrapped in a fenced block so it renders as a Dockerfile snippet.
print("\n--- Example Dockerfile Structure ---")
print("```dockerfile")
print(dockerfile_content)
print("```")
print("Note: Adapt system dependencies and entrypoint based on the final 'main.py' structure.")


In [None]:
# @title << 8. Conclusion & Next Steps >>

# Project summary: one newline-joined write per section (text unchanged).
print("\n".join((
    "\n--- Project Summary ---",
    "This Colab notebook provides a comprehensive blueprint and implementation for an AI-driven platform to:",
    "1.  **Detect Doctor Stress:** Using facial expression analysis with Deep Learning models (ResNet, EfficientNet) trained on relevant datasets (like FER2013 mapped to stress) and exported to ONNX for efficient inference.",
    "2.  **Optimize Shifts:** Employing Google OR-Tools (CP-SAT solver) to generate fair and constraint-compliant schedules that consider predicted stress levels, skills, availability, and rules.",
    "3.  **Integrate Systems:** Outlining a microservices architecture using FastAPI for the backend, defining a SQL Server database schema, conceptualizing a Flutter frontend, and integrating the Gemini API for enhanced insights and assistant capabilities.",
    "4.  **Address MLOps & Productionization:** Covering essential aspects like data preprocessing, augmentation, model evaluation, hyperparameter tuning concepts, deployment strategies (Docker, Cloud), security, and privacy.",
)))

# Components demonstrated in the notebook.
print("\n".join((
    "\n--- Key Components Implemented/Demonstrated ---",
    "- Kaggle API for dataset download.",
    "- Data preprocessing pipeline (face detection placeholder, normalization, resizing).",
    "- Training and evaluation of two CNN models (ResNet, EfficientNet) for stress classification (binary).",
    "- Model export to ONNX format.",
    "- Shift optimization using OR-Tools CP-SAT with various constraints (skills, workload, stress).",
    "- Conceptual database schema using SQLAlchemy.",
    "- FastAPI backend skeleton with prediction and optimization endpoints.",
    "- Loading and using the ONNX model within FastAPI.",
    "- Integration with the Gemini API for generating stress insights and conversational support.",
    "- Discussion of privacy, security, deployment, and MLOps best practices.",
    "- Conceptual outline for a Flutter frontend.",
)))

# Known gaps and follow-up work.
print("\n".join((
    "\n--- Limitations & Future Work ---",
    "- **Stress Dataset:** Relied heavily on FER2013 mapped to stress. A dedicated, validated facial *stress* dataset would significantly improve model accuracy and reliability.",
    "- **rPPG:** Not implemented; incorporating physiological signals like heart rate variability via rPPG could enhance stress detection.",
    "- **Temporal Analysis:** Models trained on single images. Using LSTMs/GRUs/Transformers on sequences of features over time would capture temporal dynamics of stress.",
    "- **Hyperparameter Optimization:** Only conceptualized; full Optuna/Ray Tune integration needed for optimal model performance.",
    "- **TensorRT:** Only conceptualized; requires specific hardware and environment setup.",
    "- **RL for Scheduling:** Only conceptualized; requires building a complex simulation environment.",
    "- **Robust Constraints:** Consecutive shift and minimum rest constraints in OR-Tools need more robust implementation.",
    "- **Error Handling & Resilience:** Production code needs more comprehensive error handling, retries, and monitoring.",
    "- **Real-time Processing:** Handling real-time video streams efficiently requires optimization (e.g., frame skipping, batching).",
    "- **Full Frontend/Backend:** Only backend API skeleton and Flutter concepts provided. Full implementation required.",
    "- **Ethical & Bias Audit:** Thoroughly audit models and system for potential biases (e.g., performance differences across demographics) and ethical implications.",
    "- **User Feedback Loop:** Incorporate feedback from doctors and administrators to iteratively improve the system.",
)))

# Closing remarks.
print(
    "\n--- Final Thoughts ---",
    "Building such a system requires careful consideration of technical, ethical, and practical aspects. This notebook serves as a strong starting point, demonstrating the feasibility and integration of various AI and software engineering components. Continuous iteration, validation with real-world data and users, and a strong focus on privacy and ethics are paramount for success in a sensitive domain like healthcare.",
    sep="\n",
)