In [38]:
from sklearn.metrics import classification_report, confusion_matrix,roc_curve,auc

In [3]:
import os
import time
import warnings
from glob import glob

import albumentations as A
import cv2
import mlflow
import mlflow.sklearn
import numpy as np
import open_clip
import pandas as pd
import torch
from lightgbm import LGBMClassifier
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    auc,
    classification_report,
    confusion_matrix,
    f1_score,
    roc_curve,
)
from sklearn.model_selection import GridSearchCV, train_test_split
from torchvision import transforms
from tqdm import tqdm

# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides albumentations deprecation/validation warnings.
# The ``scale_min``/``scale_max``, ``var_limit`` and ``quality_lower``/``quality_upper``
# arguments below are older-style names — confirm the installed albumentations
# version still honours them rather than ignoring them.
warnings.simplefilter('ignore')

# Albumentations pipeline applied to raw training images (numpy arrays) in
# build_dataframe: an optional blur, brightness/contrast jitter, downscaling,
# a random resized crop, Gaussian noise and JPEG re-compression, each gated
# by its own probability ``p``.
albumentations_aug = A.Compose(
    transforms=[
        # At most one blur type is applied, and only 60% of the time.
        A.OneOf(
            transforms=[
                A.MotionBlur(p=0.3),
                A.GaussianBlur(blur_limit=5, p=0.5),
                A.MedianBlur(blur_limit=5, p=0.5),
            ],
            p=0.6,
        ),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.Downscale(scale_min=0.7, scale_max=0.95, p=0.3),
        A.RandomResizedCrop(size=(672, 672), scale=(0.8, 1.0), ratio=(0.75, 1.33), p=0.4),
        A.GaussNoise(var_limit=(5.0, 20.0), p=0.3),
        A.ImageCompression(quality_lower=30, quality_upper=80, compression_type='jpeg', p=0.2),
    ]
)

# === 1. Load CLIP Model ===
def load_clip(device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Load the open_clip ViT-B-32 model (LAION-2B weights) in eval mode.

    Args:
        device: Device to move the model to. The default is chosen once, when the
            function is defined: CUDA if available, otherwise CPU.

    Returns:
        Tuple ``(model, preprocess, tokenizer, device)``. ``preprocess`` is the
        evaluation image transform returned by open_clip.
    """
    arch = 'ViT-B-32'
    clip_model, _train_transform, eval_transform = open_clip.create_model_and_transforms(
        arch, pretrained='laion2b_s34b_b79k'
    )
    clip_model.to(device)
    clip_model.eval()
    return clip_model, eval_transform, open_clip.get_tokenizer(arch), device

# === 2. Pupil Cropping ===
def crop_to_pupil(image_path, output_size=(512, 512)):
    """Crop an eye image around a Hough-detected circle and resize it.

    The image is read with OpenCV (BGR channel order), converted to grayscale,
    median-blurred and searched with ``cv2.HoughCircles``. If at least one
    circle is found, a square window extending ``int(1.5 * r)`` pixels from the
    centre of the first returned circle is cut out, clipped to the image
    bounds. Otherwise the full image is used and a warning is printed.

    Args:
        image_path: Path of the image file.
        output_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        The cropped (or full) image resized to ``output_size``, as a BGR array.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image: {image_path}")

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray_blur = cv2.medianBlur(gray, 5)

    circles = cv2.HoughCircles(gray_blur, cv2.HOUGH_GRADIENT, dp=1.5, minDist=30,
                               param1=50, param2=30, minRadius=20, maxRadius=150)

    if circles is not None:
        # Use plain Python ints. With the previous np.uint16 values, ``x - pad``
        # wrapped around to a huge number whenever the circle lay closer than
        # ``pad`` to the left/top edge, which produced an empty crop.
        x, y, r = (int(v) for v in np.around(circles[0][0]))
        pad = int(r * 1.5)
        height, width = image.shape[:2]
        x1, y1 = max(0, x - pad), max(0, y - pad)
        x2, y2 = min(width, x + pad), min(height, y + pad)
        cropped = image[y1:y2, x1:x2]
    else:
        print(f"⚠️ Pupil not detected in {image_path}, using full image.")
        cropped = image

    return cv2.resize(cropped, output_size)

# === 3. Save Cropped Images ===
def preprocess_folder(input_folder, output_folder, size=(512, 512)):
    """Crop every image in ``input_folder`` with ``crop_to_pupil`` and save it.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) are processed.
    Each result is written to ``output_folder`` under its original file name;
    the output folder is created if needed. Other files are skipped.

    Args:
        input_folder: Directory of source images.
        output_folder: Destination directory.
        size: ``output_size`` forwarded to ``crop_to_pupil``.
    """
    os.makedirs(output_folder, exist_ok=True)
    accepted_suffixes = ('.png', '.jpg', '.jpeg')
    for name in tqdm(os.listdir(input_folder)):
        if not name.lower().endswith(accepted_suffixes):
            continue
        source_path = os.path.join(input_folder, name)
        target_path = os.path.join(output_folder, name)
        cv2.imwrite(target_path, crop_to_pupil(source_path, output_size=size))

# === 4. Augmentation (for training only) ===
# Light torchvision jitter used by augment_image. It applies a 224px random
# resized crop, up to ±15° rotation, slight brightness/contrast change, a
# horizontal flip with probability 0.5, and a 3x3 Gaussian blur 20% of the time.
augmentation_transforms = transforms.Compose(
    transforms=[
        transforms.RandomResizedCrop(size=224, scale=(0.9, 1.0)),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(transforms=[transforms.GaussianBlur(kernel_size=3)], p=0.2),
    ]
)

def augment_image(img_pil, n_augmentations=2):
    """Return extra views of a PIL image.

    The result contains ``n_augmentations`` random views from
    ``augmentation_transforms``, followed by a grayscale copy and a Canny
    edge map (thresholds 100/200). Both of the last two are converted back
    to 3-channel RGB.

    Args:
        img_pil: Source image; it is passed to ``cv2.COLOR_RGB2GRAY``, so it is
            expected to be RGB.
        n_augmentations: Number of random views to generate (may be 0).

    Returns:
        A list of ``n_augmentations + 2`` PIL images.
    """
    views = []
    for _ in range(n_augmentations):
        views.append(augmentation_transforms(img_pil))
    views.append(transforms.Grayscale()(img_pil).convert("RGB"))

    gray_array = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2GRAY)
    edge_map = cv2.Canny(gray_array, 100, 200)
    views.append(Image.fromarray(cv2.cvtColor(edge_map, cv2.COLOR_GRAY2RGB)))
    return views

# === 5. Build DataFrame for Training ===
def build_dataframe(folder, label, preprocess, model, device='cuda', augment=True, n_aug=3):
    """Encode every image in ``folder`` with CLIP and return one labelled row per view.

    The method accepts the same file types as ``preprocess_folder`` and
    ``infer_on_folder``: ``.png``, ``.jpg`` and ``.jpeg``, case-insensitive.

    - The original image is always encoded.
    - When ``augment`` is true, it also encodes ``n_aug`` views from
      ``albumentations_aug``.
    - It also encodes the grayscale and Canny-edge views from
      ``augment_image(..., n_augmentations=0)``.

    All views of one source image are encoded in a single batch.

    Args:
        folder: Directory containing the images.
        label: Value written to the ``category`` column of every row.
        preprocess: Image transform (PIL image -> tensor), e.g. from ``load_clip``.
        model: CLIP model exposing ``encode_image``.
        device: Device each batch is moved to. CUDA autocast is enabled only
            when this names a CUDA device.
        augment: Whether to add augmented views.
        n_aug: Number of ``albumentations_aug`` views per image.

    Returns:
        A DataFrame with feature columns ``0..511`` (L2-normalised embeddings)
        and a ``category`` column. Rows follow the sorted file-name order.
    """
    feature_cols = list(range(512))
    image_exts = ('.png', '.jpg', '.jpeg')
    # Sorted for reproducibility: os.listdir/glob order is filesystem-dependent.
    paths = sorted(
        os.path.join(folder, name) for name in os.listdir(folder)
        if name.lower().endswith(image_exts)
    )
    # torch.cuda.amp.autocast() is deprecated; this is its torch.autocast equivalent.
    use_amp = str(device).startswith('cuda')

    rows = []
    for image_path in tqdm(paths):
        # The context manager closes the file; convert() returns an in-memory copy.
        with Image.open(image_path) as src:
            image = src.convert("RGB")

        views = [image]
        if augment:
            np_img = np.array(image)
            for _ in range(n_aug):
                views.append(Image.fromarray(albumentations_aug(image=np_img)['image']))
            views += augment_image(image, n_augmentations=0)

        batch = torch.stack([preprocess(view) for view in views]).to(device)
        with torch.no_grad(), torch.autocast(device_type='cuda', enabled=use_amp):
            feats = model.encode_image(batch)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        # Upcast before leaving torch (autocast may yield half precision).
        rows.extend(feats.float().cpu().numpy())

    # Build the frame once instead of growing it row by row with .loc (quadratic,
    # object dtype). An empty folder still yields the expected columns.
    df = pd.DataFrame(rows, columns=feature_cols)
    df['category'] = label
    return df

# === 6. Train Classifier ===
def train_classifier(X, y):
    """Fit a LightGBM classifier (``random_state=42``) on ``X``/``y`` and return it."""
    # sklearn-style fit() returns the fitted estimator itself.
    return LGBMClassifier(random_state=42).fit(X, y)

# === 7. Evaluate Model ===
def evaluate_model(model, X_val, y_val):
    """Print the confusion matrix, then the classification report, for ``model`` on validation data."""
    predictions = model.predict(X_val)
    for summarize in (confusion_matrix, classification_report):
        print(summarize(y_val, predictions))

# === 8. Inference with Flags ===
def _load_inference_tensor(path, preprocess, crop, strict_preprocess):
    """Load one image and convert it into the CHW tensor fed to ``model.encode_image``.

    Args:
        path: Image file path.
        preprocess: Transform used when ``strict_preprocess`` is true.
        crop: If true, crop with ``crop_to_pupil`` (BGR output, converted to RGB
            here). Otherwise load the full image.
        strict_preprocess: Use ``preprocess`` instead of the plain resize + ``ToTensor`` path.
    """
    if crop:
        img = Image.fromarray(cv2.cvtColor(crop_to_pupil(path), cv2.COLOR_BGR2RGB))
    else:
        # The context manager closes the file; convert() returns an in-memory copy.
        with Image.open(path) as src:
            img = src.convert("RGB")

    if strict_preprocess:
        return preprocess(img)
    # NOTE(review): this branch skips whatever normalisation ``preprocess`` does and
    # feeds 512x512 input. Confirm the CLIP model accepts this resolution and that
    # the embeddings are comparable to the features the classifier was trained on.
    return transforms.ToTensor()(transforms.Resize((512, 512))(img))


def infer_on_folder(folder_path, clf, preprocess, model, device='cuda',
                    crop=True, strict_preprocess=True, batch_size=16):
    """Classify every image in ``folder_path`` as ``'cataract'`` or ``'normal'``.

    Images (``.png``/``.jpg``/``.jpeg``, case-insensitive) are loaded in batches
    of ``batch_size``. They are embedded with ``model.encode_image``,
    L2-normalised and passed to ``clf``.

    Args:
        folder_path: Directory of images to classify.
        clf: Fitted classifier with ``predict`` and ``predict_proba``. Class 1 is
            treated as cataract, and probability column 0/1 as normal/cataract.
        preprocess: Image transform, e.g. from ``load_clip``.
        model: CLIP model exposing ``encode_image``.
        device: Device for the image batches. CUDA autocast is enabled only for
            CUDA devices.
        crop: Crop each image with ``crop_to_pupil`` first.
        strict_preprocess: Use ``preprocess`` rather than a plain resize + ``ToTensor``.
        batch_size: Number of images per forward pass.

    Returns:
        A DataFrame with columns ``image_path``, ``prediction``, ``prob_cataract``
        and ``prob_normal``, in sorted file-name order. The columns are present
        even when the folder has no images.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    # Sorted for a reproducible output order (os.listdir order is arbitrary).
    image_paths = sorted(
        os.path.join(folder_path, f) for f in os.listdir(folder_path)
        if f.lower().endswith(image_exts)
    )
    # torch.cuda.amp.autocast() is deprecated; this is its torch.autocast equivalent.
    use_amp = str(device).startswith('cuda')
    results = []

    for start in tqdm(range(0, len(image_paths), batch_size), desc="🔍 Batched Inference"):
        batch_paths = image_paths[start:start + batch_size]
        batch = torch.stack([
            _load_inference_tensor(p, preprocess, crop, strict_preprocess) for p in batch_paths
        ]).to(device)

        with torch.no_grad(), torch.autocast(device_type='cuda', enabled=use_amp):
            embeddings = model.encode_image(batch)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Upcast before handing off to the classifier (autocast may yield half precision).
        embeddings = embeddings.float().cpu().numpy()

        preds = clf.predict(embeddings)
        probs = clf.predict_proba(embeddings)

        for path, pred, prob in zip(batch_paths, preds, probs):
            results.append({
                'image_path': path,
                'prediction': 'cataract' if pred == 1 else 'normal',
                'prob_cataract': round(prob[1], 4),
                'prob_normal': round(prob[0], 4),
            })

    return pd.DataFrame(
        results, columns=['image_path', 'prediction', 'prob_cataract', 'prob_normal']
    )


In [None]:
import random
import matplotlib.pyplot as plt

# Record MLflow runs under a local ``./mlruns`` directory (a ``file:`` URI, relative
# to the current working directory).
mlflow.set_tracking_uri("file:./mlruns")

# === Helper to log sample images ===
def log_sample_images(original_image_path, save_dir="sample_logs", n_aug=3, prefix=None):
    """Save an image plus its augmented views locally and log them to MLflow.

    The method writes ``<prefix>_original.png`` and ``<prefix>_aug_<i>.png`` into
    ``save_dir``, one file for each view returned by ``augment_image``. Each
    file is logged under the ``images`` artifact path of the active MLflow run.
    Previously every call used the fixed names ``original.png``/``aug_<i>.png``,
    so logging a second sample overwrote the first one's artifacts.

    Args:
        original_image_path: Path of the source image.
        save_dir: Local directory for the PNG files (created if missing).
        n_aug: Number of random views requested from ``augment_image``. That
            function also appends a grayscale and an edge-map view.
        prefix: File-name prefix. It defaults to ``<parent dir>_<file stem>`` so
            that samples from different images or class folders do not collide.
    """
    if prefix is None:
        stem = os.path.splitext(os.path.basename(original_image_path))[0]
        parent = os.path.basename(os.path.dirname(os.path.abspath(original_image_path)))
        prefix = f"{parent}_{stem}"

    os.makedirs(save_dir, exist_ok=True)
    # The context manager closes the file; convert() returns an in-memory copy.
    with Image.open(original_image_path) as src:
        image = src.convert("RGB")

    original_path = os.path.join(save_dir, f"{prefix}_original.png")
    image.save(original_path)
    mlflow.log_artifact(original_path, artifact_path="images")

    for i, img in enumerate(augment_image(image, n_augmentations=n_aug)):
        path = os.path.join(save_dir, f"{prefix}_aug_{i}.png")
        img.save(path)
        mlflow.log_artifact(path, artifact_path="images")

# === Helper to log ROC curve ===
def log_roc_curve(y_true, y_probs, file_path="roc_curve.png"):
    """Plot the ROC curve, save it as an image and log it plus its AUC to MLflow.

    Args:
        y_true: Ground-truth binary labels.
        y_probs: Scores or probabilities for the positive class.
        file_path: Where the plot is saved before being logged as an MLflow artifact.

    The AUC is logged as the ``val_roc_auc`` metric. The figure is not closed here.
    """
    false_pos, true_pos, _thresholds = roc_curve(y_true, y_probs)
    area = auc(false_pos, true_pos)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.plot(false_pos, true_pos, label=f"ROC curve (AUC = {area:.2f})")
    ax.plot([0, 1], [0, 1], linestyle='--', color='gray')  # chance-level diagonal
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend(loc="lower right")
    fig.savefig(file_path)

    mlflow.log_artifact(file_path)
    mlflow.log_metric("val_roc_auc", area)


with mlflow.start_run(run_name="CLIP_LGBM_Cataract_Classifier"):
    # End-to-end experiment: pupil cropping -> CLIP embeddings (with augmentation)
    # -> LightGBM (baseline + grid search) -> batched inference on the cataract
    # test folder. Params, metrics, models and artifacts go to this MLflow run.
    n_aug = 3   # albumentations views per training image; logged AND passed to build_dataframe
    seed = 42   # shared seed for shuffling, splitting and the grid-search estimator

    mlflow.set_tag("clip_model", "ViT-B-32")
    mlflow.log_param("n_augmentations", n_aug)
    mlflow.log_param("dataset_path", "processed_aug_aug/train/")

    # Load model
    model, preprocess, tokenizer, device = load_clip()
    # Log the device load_clip actually chose (previously hard-coded to 'cpu').
    mlflow.log_param("device", str(device))

    # Step A: Preprocess (crop around the detected pupil)
    preprocess_folder("processed_images/train/normal", "processed_aug_aug/train/normal")
    preprocess_folder("processed_images/train/cataract", "processed_aug_aug/train/cataract")

    # Log one sample image per class. A class with no PNGs is skipped instead of
    # letting random.sample raise ValueError.
    normal_pngs = glob("processed_aug_aug/train/normal/*.png")
    cataract_pngs = glob("processed_aug_aug/train/cataract/*.png")
    normal_samples = random.sample(normal_pngs, min(1, len(normal_pngs)))
    cataract_samples = random.sample(cataract_pngs, min(1, len(cataract_pngs)))
    for sample in normal_samples + cataract_samples:
        log_sample_images(sample)

    # Step B: Build dataset
    # NOTE: this timer also covers feature extraction, evaluation and ROC logging.
    start_train_time = time.time()
    df_normal = build_dataframe("processed_aug_aug/train/normal", 0, preprocess, model,
                                device=device, n_aug=n_aug)
    df_cataract = build_dataframe("processed_aug_aug/train/cataract", 1, preprocess, model,
                                  device=device, n_aug=n_aug)
    df = (
        pd.concat([df_normal, df_cataract])
        .astype(float)
        .sample(frac=1, random_state=seed)
        .reset_index(drop=True)
    )

    # NOTE(review): augmented views of the same source image are split independently
    # here. Near-duplicates of a training image can therefore land in the validation
    # set and inflate validation metrics. Consider splitting by source image before
    # augmentation.
    X = df.iloc[:, :512]
    y = df['category']
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=seed
    )

    # Step C: Train baseline model
    clf = train_classifier(X_train, y_train)
    mlflow.sklearn.log_model(clf, "lgbm_model_initial")

    # Evaluate baseline model
    y_pred = clf.predict(X_val)
    y_probs = clf.predict_proba(X_val)[:, 1]
    acc = accuracy_score(y_val, y_pred)
    f1 = f1_score(y_val, y_pred)

    mlflow.log_metric("val_accuracy", acc)
    mlflow.log_metric("val_f1_score", f1)

    # Labels were cast to float above, so the positive-class key is "1.0".
    report = classification_report(y_val, y_pred, output_dict=True)
    mlflow.log_metric("val_precision_cataract", report["1.0"]["precision"])
    mlflow.log_metric("val_recall_cataract", report["1.0"]["recall"])

    log_roc_curve(y_val, y_probs)

    end_train_time = time.time()
    mlflow.log_metric("train_time_sec", end_train_time - start_train_time)

    # Step D: Grid Search
    param_grid = {
        'n_estimators': [100, 200],
        'learning_rate': [0.01, 0.1],
        'max_depth': [3, 5, 7],
    }
    grid = GridSearchCV(LGBMClassifier(verbose=-1, random_state=seed), param_grid, cv=3)
    grid.fit(X_train, y_train)

    best_model = grid.best_estimator_
    mlflow.sklearn.log_model(best_model, "lgbm_model_best")
    mlflow.log_params(grid.best_params_)
    mlflow.log_metric("cv_best_score", grid.best_score_)

    # Score the tuned model on the same hold-out split; it is the model used for inference.
    best_pred = best_model.predict(X_val)
    mlflow.log_metric("val_accuracy_best", accuracy_score(y_val, best_pred))
    mlflow.log_metric("val_f1_score_best", f1_score(y_val, best_pred))

    # Step E: Inference (strict preprocessing)
    start_infer_strict = time.time()
    df_results_cat_strict = infer_on_folder(
        "processed_images/test/cataract/", best_model, preprocess, model,
        device=device, crop=True, strict_preprocess=True, batch_size=32
    )
    end_infer_strict = time.time()
    mlflow.log_metric("inference_time_strict", end_infer_strict - start_infer_strict)
    # Cast the NumPy count to a plain int before logging.
    n_detected = int((df_results_cat_strict['prob_cataract'] > 0.5).sum())
    mlflow.log_metric("cataract_detected_strict", n_detected)

    df_results_cat_strict.to_csv("strict_preds.csv", index=False)
    mlflow.log_artifact("strict_preds.csv")

    print("\n🎯 MLflow run complete. All results logged.\n")