In [1]:
# Install dependencies (run once, then restart the session before running the next cell).
# Versions resolved in the recorded run below: torchxrayvision 1.4.0, grad-cam 1.5.5,
# bitsandbytes 0.48.2 -- consider pinning these (and the others) for reproducibility.
!pip install torch torchvision
!pip install torchxrayvision transformers gradio grad-cam huggingface_hub accelerate bitsandbytes

Collecting torchxrayvision
  Downloading torchxrayvision-1.4.0-py3-none-any.whl.metadata (18 kB)
Collecting grad-cam
  Downloading grad-cam-1.5.5.tar.gz (7.8 MB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m7.8/7.8 MB[0m [31m79.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting ttach (from grad-cam)
  Downloading ttach-0.0.3-py3-none-any.whl.metadata (5.2 kB)
Downloading torchxrayvision-1.4.0-py3-none-any.whl (29.0 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m29.0/29.0 MB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0

In [None]:
# --- STEP 2: RUN THE APPLICATION (FINAL RESIZE FIX) ---
# (Run this cell *after* restarting the session)

# --- Imports & HF Login ---
import torch
import gradio as gr
import torchxrayvision as xrv
import numpy as np
from PIL import Image
from google.colab import userdata
import huggingface_hub
import os
import re
import math # <-- Import for ViT reshape

# Import model-specific classes
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoImageProcessor,
    AutoModelForImageClassification
)

# --- Using 'pytorch_grad_cam' as you requested ---
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

print("‚úÖ Step 1/7: Libraries imported")

# Authenticate with Hugging Face.
# The token is read from Colab's secret store (key "HF_TOKEN") or, failing that, from the
# HF_TOKEN environment variable. Never hard-code it in the notebook: it would be saved
# in the .ipynb and shared along with it.
try:
    try:
        hf_token = userdata.get("HF_TOKEN")
    except Exception:
        # e.g. secret not defined or notebook access not granted -- fall back to the env.
        hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("no token found in Colab secret or env var HF_TOKEN")
    huggingface_hub.login(token=hf_token)
    print("‚úÖ Step 2/7: Hugging Face Login Successful")
except Exception as e:
    # Best-effort: public models can still load without a login.
    print(f"üõë Step 2/7: HF Login Failed. Error: {e}")

# --- WRAPPER CLASS FOR HUGGINGFACE MODELS ---
class HuggingFaceWrapper(torch.nn.Module):
    def __init__(self, model):
        super(HuggingFaceWrapper, self).__init__()
        self.model = model

    def forward(self, input_tensor):
        outputs = self.model(input_tensor)
        return outputs.logits

# --- RESHAPE FUNCTION FOR ViT ---
def reshape_transform_vit(tensor, height=None, width=None):
    """Turn ViT token activations into a CNN-style feature map for Grad-CAM.

    Args:
        tensor: activations shaped (batch, 1 + num_patches, hidden_dim); the first
            token is the [CLS] token and is dropped.
        height, width: patch-grid size. When both are omitted the grid is assumed
            square (e.g. 196 patches -> 14x14). When only one is given, the other is
            inferred from the patch count.

    Returns:
        Tensor shaped (batch, hidden_dim, height, width).

    Raises:
        ValueError: if the patch count cannot be laid out as a height x width grid.
    """
    patch_tokens = tensor[:, 1:, :]
    batch_size, num_patches, hidden_dim = patch_tokens.shape

    if height is None and width is None:
        # Exact integer square root: avoids float rounding in int(math.sqrt(n)).
        height = width = math.isqrt(num_patches)
    elif height is None:
        height = num_patches // width if width > 0 else 0
    elif width is None:
        width = num_patches // height if height > 0 else 0

    if height <= 0 or width <= 0 or height * width != num_patches:
        raise ValueError(
            f"Cannot reshape {num_patches} patch tokens into a {height}x{width} grid"
        )

    grid = patch_tokens.reshape(batch_size, height, width, hidden_dim)
    # Channels-first, as Grad-CAM expects for conv-style activations.
    return grid.permute(0, 3, 1, 2)


# --- STEP 3: LOAD ALL MODELS (ONCE) ---
print("‚è≥ Step 3/7: Loading models... This will take a few minutes.")

# 3a: Load Chest X-Ray Model (Stable: float32)
def load_chest_model():
    """Build the torchxrayvision DenseNet with "densenet121-res224-all" weights,
    move it to the GPU and switch it to eval mode."""
    densenet = xrv.models.DenseNet(weights="densenet121-res224-all")
    densenet = densenet.to("cuda")
    densenet.eval()
    return densenet

# 3b: Load Retina Model (Stable: float32)
def load_retina_model():
    """Load the diabetic-retinopathy image classifier and its image processor.

    Returns:
        (model on the GPU in eval mode, matching AutoImageProcessor)
    """
    checkpoint = "rafalosa/diabetic-retinopathy-224-procnorm-vit"
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    classifier = AutoModelForImageClassification.from_pretrained(checkpoint)
    classifier = classifier.to("cuda")
    classifier.eval()
    return classifier, image_processor

# 3c: Load Med-Gemma LLM (using your pre-quantized model)
def load_med_gemma():
    """Load the pre-quantized Med-Gemma causal LM and its tokenizer.

    Returns:
        (model placed by device_map="auto", tokenizer)
    """
    repo_id = "lekhana123456/med-gemma-2b-4bit"
    tok = AutoTokenizer.from_pretrained(repo_id)
    # NOTE(review): trust_remote_code=True executes any Python shipped in this repo;
    # confirm the repo actually needs it before keeping it enabled.
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )
    return lm, tok

# Load all models into memory
chest_model = load_chest_model()
retina_model, retina_processor = load_retina_model()
llm_model, llm_tokenizer = load_med_gemma() # <-- Med-Gemma is still loaded
print("‚úÖ Med-Gemma 2B model is loaded but will be bypassed.")


# --- Pre-initialize Grad-CAM objects ---
print("‚è≥ Initializing Grad-CAM...")

# 1. Chest X-Ray (no wrapper needed)
chest_target_layer = chest_model.features.denseblock4.denselayer16
chest_cam = GradCAM(model=chest_model, target_layers=[chest_target_layer])

# 2. Retina Scan (Use the wrapper class AND the reshape function)
# Hook the LayerNorm that feeds the last encoder block's self-attention, NOT the block's
# output. The ViT classifier reads only the [CLS] token, so the patch tokens leaving the
# final block never reach the logits. Their gradients are therefore zero, and Grad-CAM
# on `layer[-1].output` produces an empty heatmap. Before the last attention, the patch
# tokens still feed the [CLS] token, so the gradients carry spatial information.
retina_target_layer = retina_model.vit.encoder.layer[-1].layernorm_before
wrapped_retina_model = HuggingFaceWrapper(retina_model)
retina_cam = GradCAM(model=wrapped_retina_model,
                     target_layers=[retina_target_layer],
                     reshape_transform=reshape_transform_vit)

print("‚úÖ Grad-CAM initialized.")

print("‚úÖ Step 3/7: All models loaded successfully!")

# --- STEP 4: HELPER FUNCTION - GRAD-CAM ---
def get_grad_cam_overlay(cam_object, input_tensor, base_image_pil, target_class_index=None):
    """Run Grad-CAM and blend the heatmap onto a resized copy of the base image.

    Args:
        cam_object: a pytorch_grad_cam ``GradCAM`` wrapping the classifier.
        input_tensor: preprocessed model input; only the first sample is visualised.
        base_image_pil: PIL image to draw on. It is resized to the heatmap's size,
            because show_cam_on_image needs image and heatmap to match.
        target_class_index: class to explain. ``None`` selects the highest-scoring
            class of the first sample.

    Returns:
        (PIL overlay image, index of the explained class)
    """
    if target_class_index is None:
        # Only the argmax is needed here, so don't build an autograd graph.
        with torch.no_grad():
            model_output = cam_object.model(input_tensor)
        # Argmax over the class dimension of the first sample (a flattened argmax
        # would return a meaningless index for batch sizes > 1).
        target_class_index = int(model_output[0].argmax())

    targets = [ClassifierOutputTarget(target_class_index)]

    # Keep the first sample's 2-D heatmap.
    grayscale_cam = cam_object(input_tensor=input_tensor, targets=targets)[0, :]
    heatmap_height, heatmap_width = grayscale_cam.shape

    # Convert to RGB first (so palette/greyscale images resample correctly), then match
    # the heatmap's size. The base image may have any size, e.g. 480x640.
    resized_base_image_pil = base_image_pil.convert("RGB").resize((heatmap_width, heatmap_height))

    # show_cam_on_image expects an RGB float image scaled to [0, 1].
    img_np = np.asarray(resized_base_image_pil, dtype=np.float32) / 255.0

    cam_image = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

    return Image.fromarray(cam_image), target_class_index

# --- STEP 5: HELPER FUNCTION - LLM REPORT GENERATION ---
# --- (This function is bypassed but kept for future use) ---
def generate_llm_report(findings_prompt):
    """Ask the Med-Gemma LLM to turn findings text into a brief preliminary report.

    Args:
        findings_prompt: formatted findings text, e.g. the string analyze_image builds.

    Returns:
        Only the newly generated text (the prompt is stripped), with surrounding
        whitespace removed.
    """
    full_prompt = f"""Synthesize these AI findings into a brief preliminary report. Include the confidence percentages.

FINDINGS:
{findings_prompt}

PRELIMINARY REPORT:
"""

    # The model was placed with device_map="auto"; send the inputs to wherever it
    # actually lives instead of assuming "cuda" (breaks if it landed on CPU).
    inputs = llm_tokenizer(full_prompt, return_tensors="pt").to(llm_model.device)
    prompt_token_length = inputs["input_ids"].shape[-1]

    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=250,
        temperature=0.2,
        do_sample=True,
        pad_token_id=llm_tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    new_tokens = outputs[0][prompt_token_length:]
    return llm_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

print("‚úÖ Step 4/7 & 5/7: Helper functions defined")

# --- STEP 6: THE MAIN ANALYSIS FUNCTION (Ties everything together) ---

# A chest pathology is listed only if it is among the top CHEST_TOP_K scores
# AND its score exceeds CHEST_MIN_SCORE.
CHEST_MIN_SCORE = 0.1
CHEST_TOP_K = 3


def _chest_scores(outputs):
    """Map raw chest-model outputs to per-pathology scores in [0, 1].

    torchxrayvision's DenseNet.forward already applies sigmoid + op_norm when the
    model carries calibrated operating thresholds (``op_threshs``), as the pretrained
    weights do. A second sigmoid would squash every score into roughly [0.5, 0.73],
    so every top-k pathology would pass the threshold.
    """
    already_scaled = (
        getattr(chest_model, "op_threshs", None) is not None
        or getattr(chest_model, "apply_sigmoid", False)
    )
    return outputs if already_scaled else torch.sigmoid(outputs)


def _analyze_chest_xray(image_pil):
    """Classify an RGB chest X-ray; return (Grad-CAM overlay, findings text)."""
    crop = xrv.datasets.XRayCenterCrop()

    # Model input: greyscale, center-cropped, xrv-normalised, batched -> (1, 1, H, W).
    img_gray = np.array(image_pil.convert("L"))[np.newaxis, ...]
    img_normalized = xrv.datasets.normalize(crop(img_gray), 255)
    img_tensor = torch.from_numpy(img_normalized).unsqueeze(0).to("cuda")

    # Apply the same crop to the RGB image so the heatmap lines up with what the model saw.
    img_rgb_chw = np.array(image_pil).transpose(2, 0, 1)
    cropped_pil_for_viz = Image.fromarray(crop(img_rgb_chw).transpose(1, 2, 0))

    with torch.no_grad():
        probs = _chest_scores(chest_model(img_tensor)).float()

    k = min(CHEST_TOP_K, probs.shape[-1])
    top_probs, top_indices = torch.topk(probs[0], k)
    findings = {
        chest_model.targets[int(i)]: f"{p.item()*100:.2f}%"
        for i, p in zip(top_indices, top_probs)
        if p > CHEST_MIN_SCORE
    }

    # Explain the highest-scoring pathology; passing it explicitly avoids a second
    # forward pass inside get_grad_cam_overlay.
    top_class_idx = int(top_indices[0])
    cam_image, _ = get_grad_cam_overlay(
        chest_cam, img_tensor, cropped_pil_for_viz, target_class_index=top_class_idx
    )
    cam_finding = chest_model.targets[top_class_idx]

    finding_lines = [f"  - {disease}: {conf}" for disease, conf in findings.items()]
    if not finding_lines:
        finding_lines = [f"  - No finding above {CHEST_MIN_SCORE*100:.0f}%"]

    findings_str = (
        "* Image Type: Chest X-Ray\n"
        "* Top Model Findings (Multi-Label):\n"
        + "\n".join(finding_lines)
        + f"\n* Explainability (Grad-CAM): Heatmap is focused on the area for the top finding: '{cam_finding}'."
    )
    return cam_image, findings_str


def _analyze_retina_scan(image_pil):
    """Grade diabetic retinopathy on an RGB fundus image; return (overlay, findings text)."""
    inputs = retina_processor(images=image_pil, return_tensors="pt").to("cuda")
    img_tensor = inputs["pixel_values"]  # Shape is (1, 3, 224, 224)

    with torch.no_grad():
        probs = torch.softmax(retina_model(img_tensor).logits, dim=1).float()

    top_prob, top_idx = torch.max(probs, 1)
    top_class_idx = top_idx.item()
    top_class_prob = top_prob.item()

    labels = retina_model.config.id2label
    findings = {labels[i]: f"{probs[0, i].item()*100:.2f}%" for i in range(len(labels))}

    # get_grad_cam_overlay resizes image_pil to the heatmap's size.
    cam_image, _ = get_grad_cam_overlay(
        retina_cam, img_tensor, image_pil, target_class_index=top_class_idx
    )
    cam_finding = labels[top_class_idx]

    findings_str = (
        "* Image Type: Retina Scan (Diabetic Retinopathy Classification)\n"
        f"* Primary Diagnosis: {cam_finding} (Confidence: {top_class_prob*100:.2f}%)\n"
        "* All Class Confidences:\n"
        + "\n".join(f"  - {label}: {conf}" for label, conf in findings.items())
        + "\n* Explainability (Grad-CAM): Heatmap is focused on the area for the primary diagnosis."
    )
    return cam_image, findings_str


def analyze_image(image_pil, image_type):
    """Gradio callback: run the model selected by ``image_type`` on ``image_pil``.

    Args:
        image_pil: uploaded PIL image, or None if nothing was uploaded.
        image_type: "Chest X-Ray" or "Retina Scan".

    Returns:
        (Grad-CAM overlay image or None, report text)
    """
    if image_pil is None:
        return None, "Please upload an image."

    print(f"--- Processing {image_type} ---")
    image_pil = image_pil.convert("RGB")

    if image_type == "Chest X-Ray":
        cam_image, findings_str = _analyze_chest_xray(image_pil)
    elif image_type == "Retina Scan":
        cam_image, findings_str = _analyze_retina_scan(image_pil)
    else:
        # Without this guard an unknown type would crash on an unbound findings_str.
        return None, f"Unsupported image type: {image_type}"

    print("Findings (will be used as report):\n", findings_str)

    # The Med-Gemma call is bypassed for now; the findings text itself is the report.
    # report = generate_llm_report(findings_str)
    report = findings_str

    return cam_image, report

print("‚úÖ Step 6/7: Main analysis function defined")

# --- STEP 7: LAUNCH THE GRADIO UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # üè• Medical Image Analysis & Report Generation
        **Upload a Chest X-Ray or Retina Scan to generate a preliminary report.**

        **Disclaimer:** This is a technology demo and is **NOT** a medical device.
        The output is generated by AI models and has not been verified by a medical professional.
        Do not use for self-diagnosis.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            model_selector = gr.Radio(
                ["Chest X-Ray", "Retina Scan"],
                label="Select Image Type",
                value="Chest X-Ray"
            )
            submit_btn = gr.Button("Generate Report", variant="primary")
        with gr.Column(scale=2):
            heatmap_output = gr.Image(label="Explainability Heatmap (Grad-CAM)")
            report_output = gr.Textbox(label="Generated Preliminary Report", lines=10)

    # analyze_image returns (overlay image, report text), matching these two outputs.
    submit_btn.click(
        analyze_image,
        inputs=[image_input, model_selector],
        outputs=[heatmap_output, report_output]
    )

# Launch the app! The link will appear in the notebook output.
# share=True publishes a public *.gradio.live URL; debug=True keeps the cell running
# so that errors and logs stay visible.
print("‚úÖ Step 7/7: Launching Gradio Interface...")
demo.launch(debug=True, share=True)

✅ Step 1/7: Libraries imported
✅ Step 2/7: Hugging Face Login Successful
⏳ Step 3/7: Loading models... This will take a few minutes.
✅ Med-Gemma 2B model is loaded but will be bypassed.
⏳ Initializing Grad-CAM...
✅ Grad-CAM initialized.
✅ Step 3/7: All models loaded successfully!
✅ Step 4/7 & 5/7: Helper functions defined
✅ Step 6/7: Main analysis function defined
✅ Step 7/7: Launching Gradio Interface...
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://96aa1c898b00de0cad.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fastapi/applications.py", line 1134, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/applications.py", line 113, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py",

--- Processing Retina Scan ---
Findings (will be used as report):
 * Image Type: Retina Scan (Diabetic Retinopathy Classification)
* Primary Diagnosis: moderate (Confidence: 51.42%)
* All Class Confidences:
  - mild: 9.84%
  - moderate: 51.42%
  - no dr: 22.44%
  - proliferative: 6.86%
  - severe: 9.45%
* Explainability (Grad-CAM): Heatmap is focused on the area for the primary diagnosis.


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fastapi/applications.py", line 1134, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/applications.py", line 113, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/usr/local/lib/python3.12/dist-packages/starlette/middleware/errors.py",

--- Processing Retina Scan ---
Findings (will be used as report):
 * Image Type: Retina Scan (Diabetic Retinopathy Classification)
* Primary Diagnosis: moderate (Confidence: 38.84%)
* All Class Confidences:
  - mild: 10.66%
  - moderate: 38.84%
  - no dr: 16.85%
  - proliferative: 15.12%
  - severe: 18.53%
* Explainability (Grad-CAM): Heatmap is focused on the area for the primary diagnosis.
