<a href="https://colab.research.google.com/github/Ravikrishnan05/PrediscanMedtech_project/blob/main/Unsloth_ptmodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# To run this, press "Runtime" and press "Run all" on a free Tesla T4 Google Colab instance!

#    Join Discord if you need help + ⭐ Star us on Github ⭐
# To install Unsloth on your own computer, follow the installation instructions on our Github page here.

# You will learn how to do data prep, how to train, how to run the model, & how to save it

# News
# Unsloth now supports Text-to-Speech (TTS) models. Read our guide here.

# Read our Qwen3 Guide and check out our new Dynamic 2.0 quants which outperforms other quantization methods!

# Visit our docs for all our model uploads and notebooks.

# To run this, press "Runtime" and press "Run all" on a free Tesla T4 Google Colab instance!
# %%capture # Use %%capture to hide pip outputs if desired
import os
if "COLAB_" not in "".join(os.environ.keys()):
    print("Installing Unsloth for local environment...")
    !pip install "unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git"
else:
    print("Installing Unsloth for Colab environment...")
    !pip install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
    !pip install --no-deps "unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git"

In [None]:
# -----------------------------------------------------------------------------
# Cell 0.2: Additional Library Installations
# -----------------------------------------------------------------------------
# Installs the data-processing, DICOM, plotting and Hugging Face helper packages
# used by later cells, then installs unsloth_zoo as a separate step.
# NOTE(review): only huggingface_hub, hf_transfer and datasets carry a minimum
# version here; every other package is unpinned, so a fresh run may resolve to
# different versions over time.
print("\nInstalling additional libraries for data processing and DICOM handling...")
# -q keeps pip's output short.
!pip install -q pydicom pandas opencv-python Pillow scikit-learn matplotlib seaborn "huggingface_hub>=0.23.0" "hf_transfer>=0.1.6" "datasets>=2.16.0" sentencepiece protobuf

# Install unsloth_zoo
print("\nInstalling unsloth_zoo...")
!pip install unsloth_zoo

In [None]:
# Unsloth FastModel supports loading nearly any model now! This includes Vision and Text models!

# -----------------------------------------------------------------------------
# Cell 0.3: Unsloth Model Loading
# -----------------------------------------------------------------------------
from unsloth import FastLanguageModel # Changed from FastModel to FastLanguageModel as per recent Unsloth examples for language models
import torch

In [None]:
# IMPORTANT: MODEL SELECTION FOR YOUR TASK
# The model "unsloth/gemma-3-4b-it" is a TEXT-BASED instruct model.
# Your original code used MedGemma, a VISION-LANGUAGE model, and processed images.
# If your task involves processing images to predict LDL, you MUST select a vision-language model.
# Examples:
#   - Search for Unsloth-quantized vision models: https://huggingface.co/unsloth
#   - Try loading a standard HF vision model (e.g., "google/medgemma-4b-pt", "llava-hf/llava-1.5-7b-hf", "microsoft/phi-3-vision-128k-instruct")
#     FastLanguageModel might support them. If so, set `finetune_vision_layers = True` in the PEFT setup.
# The live code in this cell loads the vision-language model "google/medgemma-4b-pt".
# If you switch model_name to a text-only model instead, you will need to adapt your
# data processing (especially image handling in the Dataset).
"""
from unsloth import FastLanguageModel
import torch

# --- Model Selection ---
# We are focusing on MedGemma for vision-based LDL prediction.
selected_model_name = "google/medgemma-4b-pt"

print(f"Attempting to load model: {selected_model_name}")
# When loading a multimodal model like MedGemma, FastLanguageModel handles it.
# The 'tokenizer' returned will be a multimodal processor (e.g., GemmaProcessor)
# which contains both the image_processor and the text_tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=selected_model_name,
    max_seq_length=2048,  # Max sequence length for the language model part (less critical for pure vision regression)
    dtype=None,           # Autodetect
    load_in_4bit=True,    # Enable 4-bit quantization for memory efficiency
    # token = "hf_...",   # Use if the model is gated
)
print(f"Model {selected_model_name} loaded successfully.")
print(f"Tokenizer type: {type(tokenizer)}")

# --- Verify Image Processor and Get Vision Feature Dimension ---
# For MedGemma, the tokenizer is a GemmaProcessor which should have an 'image_processor'
if hasattr(tokenizer, 'image_processor') and tokenizer.image_processor is not None:
    print("Image processor found in tokenizer.")
    # The vision tower configuration is part of the main model's config for MedGemma
    if hasattr(model.config, 'vision_config'):
        vision_config = model.config.vision_config
        # The vision feature dimension is typically 'hidden_size' of the vision_config
        # For SigLIP (MedGemma's vision tower), it's usually referred to as hidden_size.
        vision_feature_dim = vision_config.hidden_size
        print(f"Detected vision feature dimension from model.config.vision_config: {vision_feature_dim}")
    else:
        print("ERROR: model.config.vision_config not found. Cannot determine vision_feature_dim automatically.")
        # Fallback: Try to inspect the vision_tower directly if it exists on the base model
        # This path might be needed if Unsloth wraps the model differently.
        base_model_ref = model.model if hasattr(model, 'model') else model
        if hasattr(base_model_ref, 'vision_tower') and hasattr(base_model_ref.vision_tower, 'config'):
            vision_feature_dim = base_model_ref.vision_tower.config.hidden_size
            print(f"Detected vision feature dimension from base_model.vision_tower.config: {vision_feature_dim}")
        else:
            vision_feature_dim = None
            print("ERROR: Could not access vision_tower.config. Manually inspect 'model' object and set vision_feature_dim.")
            print("Model structure:", model) # Helps in debugging
else:
    print("ERROR: No image_processor found in the tokenizer. This is unexpected for MedGemma.")
    vision_feature_dim = None

if vision_feature_dim is None:
    print("CRITICAL ERROR: vision_feature_dim could not be determined. Regression head cannot be initialized correctly.")
    # You might need to manually set it based on MedGemma's architecture if auto-detection fails.
    # For medgemma-4b-pt, the vision feature dimension (SigLIP-L/16) is 1024.
    vision_feature_dim = 1152 # Example: Manually set if necessary
    print(f"Attempting to use manually set vision_feature_dim: {vision_feature_dim}")

"""
# Note: For vision models, the 'tokenizer' might be a composite object
# or you might access an image processor via `model.processor` or `tokenizer.image_processor`.
# This depends on how Unsloth handles vision models.

# Cell 0.3: Unsloth Model Loading (REVISED FOR FLOAT16)
#
# Loads MedGemma through Unsloth in 4-bit with an explicit float16 dtype, then
# determines `vision_feature_dim`, the hidden size of the vision encoder, which
# the regression head defined later uses as its input width.
#
# Names defined for later cells: selected_model_name, model, tokenizer,
# vision_feature_dim (always an int after this cell).

from unsloth import FastLanguageModel
import torch

selected_model_name = "google/medgemma-4b-pt"

# Hidden size used only when it cannot be read from the loaded model.
# The previous fallback of 1024 disagreed with the value 1152 used by the
# earlier draft of this cell; 1152 is the hidden size of the SigLIP encoder
# shipped with Gemma 3 -- TODO confirm against the checkpoint's config.json.
MEDGEMMA_VISION_FEATURE_DIM_FALLBACK = 1152

print(f"Attempting to load model: {selected_model_name} with dtype=torch.float16")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=selected_model_name,
    max_seq_length=2048,
    dtype=torch.float16,  # requested explicitly instead of autodetection (dtype=None)
    load_in_4bit=True,
)
print(f"Model {selected_model_name} loaded successfully.")
print(f"Tokenizer type: {type(tokenizer)}")
print(f"Base model dtype after loading: {model.dtype}") # Should now be torch.float16

# --- Verify Image Processor and Get Vision Feature Dimension ---
# Detection order: model.config.vision_config, then the vision tower's own
# config, then the documented fallback constant.
vision_feature_dim = None

if hasattr(tokenizer, 'image_processor') and tokenizer.image_processor is not None:
    print("Image processor found in tokenizer.")
    if hasattr(model.config, 'vision_config'):
        vision_config = model.config.vision_config
        vision_feature_dim = vision_config.hidden_size
        print(f"Detected vision feature dimension from model.config.vision_config: {vision_feature_dim}")
    else:
        print("ERROR: model.config.vision_config not found. Cannot determine vision_feature_dim automatically.")
        # The tower may sit one level down if the loaded object wraps the model.
        base_model_ref = model.model if hasattr(model, 'model') else model
        if hasattr(base_model_ref, 'vision_tower') and hasattr(base_model_ref.vision_tower, 'config'):
            vision_feature_dim = base_model_ref.vision_tower.config.hidden_size
            print(f"Detected vision feature dimension from base_model.vision_tower.config: {vision_feature_dim}")
        else:
            print("ERROR: Could not access vision_tower.config. Manually inspect 'model' object and set vision_feature_dim.")
else:
    print("ERROR: No image_processor found in the tokenizer. This is unexpected for MedGemma.")

# Single fallback point: every failed detection path ends up here.
if vision_feature_dim is None:
    vision_feature_dim = MEDGEMMA_VISION_FEATURE_DIM_FALLBACK
    print(f"Attempting to manually set vision_feature_dim to {vision_feature_dim} for MedGemma.")

In [None]:
# After loading the model with Unsloth:
# Re-derive vision_feature_dim directly from the vision tower's config.
# The tower is looked up on `model.model` when that attribute exists, otherwise
# on `model` itself (the exact nesting depends on how the model is wrapped).
base_medgemma_model = model.model if hasattr(model, 'model') else model

if hasattr(base_medgemma_model, 'vision_tower') and hasattr(base_medgemma_model.vision_tower, 'config'):
    vision_config = base_medgemma_model.vision_tower.config
    vision_feature_dim = vision_config.hidden_size
    print(f"Detected vision feature dimension: {vision_feature_dim}")
    # Now define your regression head separately or as part of a wrapper
    # regression_head = torch.nn.Linear(vision_feature_dim, 1)
else:
    print("ERROR: Could not access model.vision_tower.config to get vision_feature_dim.")
    print("Please inspect the 'model' object structure from Unsloth carefully.")
    # You might need to print(model) and explore its attributes
    # Keep any value an earlier cell already determined instead of overwriting
    # it with None; None is used only when nothing has been set yet.
    vision_feature_dim = globals().get("vision_feature_dim")
    if vision_feature_dim is not None:
        print(f"Keeping previously determined vision_feature_dim: {vision_feature_dim}")

In [None]:
print("\n--- Applying PEFT (LoRA) ---")
# Wraps the already-loaded `model` with LoRA adapters via Unsloth.
# RANDOM_SEED is assigned again, with the same value, in the later setup cell.
RANDOM_SEED = 42

model = FastLanguageModel.get_peft_model(
    model,
    # --- Where adapters are attached ---
    target_modules=None,             # no explicit module list; selection is left to Unsloth
    finetune_vision_layers=True,     # request adapters on the vision tower
    finetune_language_layers=False,  # language layers are not adapted for this vision-only regression
    # --- LoRA hyper-parameters ---
    r=16,                            # adapter rank
    lora_alpha=32,                   # scaling factor (2 * r here)
    lora_dropout=0.05,
    bias="none",
    # --- Training plumbing ---
    use_gradient_checkpointing="unsloth",
    random_state=RANDOM_SEED,
)

print("PEFT (LoRA) adapters added to the MedGemma model.")
print("Trainable parameters after LoRA:")
model.print_trainable_parameters()

In [None]:
"""
import torch.nn as nn

class MedGemmaVisionRegressor(nn.Module):
    def __init__(self, peft_medgemma_model, vision_feature_dim_input: int):
        super().__init__()
        self.medgemma_model = peft_medgemma_model # This is the PEFT-adapted model from Unsloth

        # The regression head takes the pooled vision features and outputs 1 LDL value
        self.regression_head = nn.Sequential(
            nn.Linear(vision_feature_dim_input, vision_feature_dim_input // 2), # Intermediate layer
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(vision_feature_dim_input // 2, 1) # Output layer
        )

        # Note: Freezing of base MedGemma layers is handled by Unsloth's PEFT.
        # LoRA adapters are trainable. The regression_head is also trainable.

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # The peft_medgemma_model is already a PeftModel.
        # We need to pass pixel_values to it.
        # The MedGemma model's forward pass can take pixel_values directly.
        # It will internally use its vision_tower.
        # For regression from vision, we typically want the pooled image features.

        # Option 1: If the PEFT model directly gives vision features or allows access
        # The `Gemma3ForMultiModalGeneration` (base for MedGemma) has `vision_tower`
        # and can output `image_embeds` or similar.
        # When using PEFT, the base model is often accessed via `self.medgemma_model.model`

        base_model = self.medgemma_model.model # Access the original model underlying PEFT

        # Get vision embeddings from the vision_tower
        # The vision_tower (SigLIP) in MedGemma outputs pooled features.
        vision_outputs = base_model.vision_tower(pixel_values=pixel_values, return_dict=True)

        # `pooler_output` from SigLipVisionModelOutput is [batch_size, vision_feature_dim]
        image_features = vision_outputs.pooler_output

        if image_features is None:
            # Fallback if pooler_output is not directly available (should be for SigLIP)
            # This might happen if the model structure is different than expected.
            # For ViT-like models, the first token's embedding ([CLS] token) is often used.
            if hasattr(vision_outputs, 'last_hidden_state'):
                image_features = vision_outputs.last_hidden_state[:, 0, :] # CLS token embedding
            else:
                raise ValueError("Could not extract pooled image features (pooler_output or CLS token) from vision_tower output.")

        # Pass vision features through the regression head
        ldl_prediction = self.regression_head(image_features)
        return ldl_prediction

# --- Instantiate the Regressor Model ---
if vision_feature_dim is not None:
    # `model` here is the PEFT-adapted MedGemma model from Cell 0.4
    regressor_model = MedGemmaVisionRegressor(model, vision_feature_dim)
    print(f"MedGemmaVisionRegressor created with regression head input dim {vision_feature_dim}.")

    # Move to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    regressor_model.to(device)
    print(f"Regressor model moved to {device}.")

    print("\nTrainable parameters of the Regressor Model (includes LoRA + head):")
    total_params = 0
    trainable_params = 0
    for name, param in regressor_model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
            # print(f"Trainable: {name}, Shape: {param.shape}") # Uncomment to see all trainable params
    print(f"Total parameters in RegressorModel: {total_params:,}")
    print(f"Trainable parameters in RegressorModel: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

else:
    regressor_model = None
    print("CRITICAL ERROR: Cannot create MedGemmaVisionRegressor because vision_feature_dim is None.")

"""
# Cell 0.5: Define and Instantiate Custom Model Wrapper (MedGemmaVisionRegressor) - REVISED FOR FLOAT16

import torch.nn as nn

class MedGemmaVisionRegressor(nn.Module):
    """Regress a single scalar (LDL) from the vision tower of a wrapped model.

    The wrapped model's vision tower turns `pixel_values` into features; a
    small MLP head (Linear -> ReLU -> Dropout -> Linear) maps those features
    to one value per image.

    Args:
        peft_medgemma_model: Model object exposing `.dtype` and
            `.model.vision_tower` (as accessed in `__init__` and `forward`).
        vision_feature_dim_input: Width of the feature vector produced by the
            vision tower; used as the head's input size.
    """

    def __init__(self, peft_medgemma_model, vision_feature_dim_input: int):
        super().__init__()
        self.medgemma_model = peft_medgemma_model
        self.target_dtype = self.medgemma_model.dtype # Should be torch.float16 now
        print(f"[Regressor Init] Base PEFT model target dtype: {self.target_dtype}")

        hidden_dim = vision_feature_dim_input // 2
        self.regression_head = nn.Sequential(
            nn.Linear(vision_feature_dim_input, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 1),
        )

        # The head is cast to the base model's dtype so its weights match the
        # features it receives; a non-float16 dtype is reported but still used.
        if self.target_dtype is not None:
            if self.target_dtype == torch.float16:
                print(f"[Regressor Init] Casting regression_head to {self.target_dtype}.")
            else:
                print(f"[Regressor Init] WARNING: Base model dtype is {self.target_dtype}, not float16. Casting regression_head to {self.target_dtype} anyway.")
            self.regression_head = self.regression_head.to(dtype=self.target_dtype)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return one prediction per image.

        Args:
            pixel_values: Preprocessed image batch; the caller is expected to
                have cast it to `self.target_dtype` already.

        Returns:
            Tensor with one row per image and a single column.

        Raises:
            ValueError: If the vision tower output carries neither
                `pooler_output` nor `last_hidden_state`.
        """
        # NOTE(review): assumes the vision tower is reachable at `.model.vision_tower`
        # on the wrapped model -- confirm against the actual wrapper nesting.
        base_model = self.medgemma_model.model
        vision_outputs = base_model.vision_tower(pixel_values=pixel_values, return_dict=True)

        image_features = getattr(vision_outputs, "pooler_output", None)
        if image_features is None:
            token_features = getattr(vision_outputs, "last_hidden_state", None)
            if token_features is None:
                raise ValueError("Could not extract pooled image features from vision_tower output.")
            # SigLIP-style encoders have no [CLS] token, so index 0 is just the
            # first image patch. Average over the token axis instead so every
            # patch contributes to the pooled feature.
            image_features = token_features.mean(dim=1)

        # Match the head's dtype before the final projection.
        if self.target_dtype is not None and image_features.dtype != self.target_dtype:
            image_features = image_features.to(self.target_dtype)

        return self.regression_head(image_features)

# --- Instantiate the Regressor Model ---
# Requires a known feature width and a loaded `model`; otherwise the regressor
# is left as None and an error is reported.
if vision_feature_dim is None or 'model' not in locals() or model is None:
    regressor_model = None
    print("CRITICAL ERROR: Cannot create MedGemmaVisionRegressor. Check 'vision_feature_dim' and 'model'.")
else:
    regressor_model = MedGemmaVisionRegressor(model, vision_feature_dim)
    print(f"MedGemmaVisionRegressor created with regression head input dim {vision_feature_dim}.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    regressor_model.to(device)
    print(f"Regressor model moved to {device}.")

    # Tally parameter counts in one pass over the wrapped model plus the head.
    print("\nTrainable parameters of the Regressor Model (includes LoRA + head):")
    param_sizes = [
        (weights.numel(), weights.requires_grad)
        for _, weights in regressor_model.named_parameters()
    ]
    total_params = sum(size for size, _ in param_sizes)
    trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
    print(f"Total parameters in RegressorModel: {total_params:,}")
    print(f"Trainable parameters in RegressorModel: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

In [None]:
# -----------------------------------------------------------------------------
# Cell 1: PyTorch/HuggingFace Imports and Setup (Adapted from user's Cell 1)
# -----------------------------------------------------------------------------
print("\nImporting libraries...")
# Python Standard Libraries
import shutil # os, zipfile already imported or not needed here
import zipfile

# Third-party Libraries
import pandas as pd
import numpy as np
import pydicom
import cv2 # OpenCV
from PIL import Image

# PyTorch
# import torch # Already imported
from torch.utils.data import Dataset, DataLoader

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Hugging Face (tokenizer is already loaded by Unsloth)
# from transformers import AutoProcessor # Replaced by Unsloth's tokenizer

# Plotting (optional, but often useful)
import matplotlib.pyplot as plt
import seaborn as sns

# Colab specific
from google.colab import drive

# Report library versions so a run can be reproduced later.
print("--- Library Version Checks ---")
print(f"Pandas version: {pd.__version__}")
print(f"NumPy version: {np.__version__}")
import sklearn
print(f"Scikit-learn version: {sklearn.__version__}")
# The PyTorch version is reported whether or not a GPU is present; only the
# CUDA-specific details depend on GPU availability.
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"PyTorch CUDA version: {torch.version.cuda}")
    print(f"GPU available for PyTorch: {torch.cuda.get_device_name(0)}")
else:
    print("GPU not available for PyTorch, using CPU.")

# For reproducibility: seed every random number generator this notebook may
# touch (Python's `random`, NumPy, and PyTorch on CPU and GPU).
import random

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

print("\nCell 1: Imports and basic setup complete.")

In [None]:
# --------------------------------------------------
# Cell 2: Configuration and Unzip Data (From user's Cell 2)
# --------------------------------------------------
# Mounts Google Drive, copies the clinical CSV to local disk and extracts the
# image archive into a freshly created local directory.
drive.mount('/content/drive', force_remount=True)

# --- Configuration ---
# Inputs stored on Google Drive.
DRIVE_CSV_PATH = "/content/drive/MyDrive/cp.csv"
DRIVE_ZIP_PATH = "/content/drive/MyDrive/1000-20250517T062750Z-1-001.zip" # Your image ZIP on Drive

# Working copies on the local disk.
LOCAL_CSV_PATH = "/content/medgemma_cp.csv"
LOCAL_EXTRACT_PATH = "/content/medgemma_extracted_images"
LOCAL_IMAGES_ROOT = os.path.join(LOCAL_EXTRACT_PATH, "1000") # Adjusted to match your structure

# --- Copy the clinical CSV ---
if not os.path.exists(DRIVE_CSV_PATH):
    print(f"ERROR: CSV file not found at {DRIVE_CSV_PATH}")
else:
    shutil.copy(DRIVE_CSV_PATH, LOCAL_CSV_PATH)
    print(f"CSV copied to {LOCAL_CSV_PATH}")

# --- Start from an empty extraction directory on every run ---
if os.path.exists(LOCAL_EXTRACT_PATH):
    print(f"Removing existing extraction directory: {LOCAL_EXTRACT_PATH}")
    shutil.rmtree(LOCAL_EXTRACT_PATH)
os.makedirs(LOCAL_EXTRACT_PATH, exist_ok=True)
print(f"Created local extraction directory: {LOCAL_EXTRACT_PATH}")

# --- Extract the image archive ---
if not os.path.exists(DRIVE_ZIP_PATH):
    print(f"ERROR: ZIP file not found at {DRIVE_ZIP_PATH}")
else:
    print(f"Unzipping {DRIVE_ZIP_PATH} to {LOCAL_EXTRACT_PATH}...")
    with zipfile.ZipFile(DRIVE_ZIP_PATH, 'r') as archive:
        archive.extractall(LOCAL_EXTRACT_PATH)
    print("Unzipping complete.")

    # The archive is expected to contain a top-level "1000" folder.
    if not os.path.exists(LOCAL_IMAGES_ROOT):
        print(f"ERROR: Expected image root folder '{LOCAL_IMAGES_ROOT}' not found after unzipping. Check ZIP structure.")
        print(f"Contents of {LOCAL_EXTRACT_PATH}: {os.listdir(LOCAL_EXTRACT_PATH)}")
    else:
        print(f"Image root folder found at: {LOCAL_IMAGES_ROOT}")

print("\nCell 2: Data unzipping complete.")


In [None]:
# --------------------------------------------------
# Cell 3: Load and Filter Clinical Data to create image_df (From user's Cell 3)
# --------------------------------------------------
# Builds `image_df` with one row per DICOM file and the columns
# person_id, image_path and LDL. Only people with a positive numeric LDL value
# in the CSV and a matching folder under LOCAL_IMAGES_ROOT are included.
# `image_df` stays an empty DataFrame when any prerequisite is missing.
from IPython.display import display # For better display in Colab

image_df = pd.DataFrame()

if not os.path.exists(LOCAL_CSV_PATH):
    print(f"FATAL ERROR: Clinical CSV file not found at the expected local path: {LOCAL_CSV_PATH}")
else:
    df_raw_from_cell3 = pd.read_csv(LOCAL_CSV_PATH)
    print(f"Initial number of rows in clinical data (Cell 3): {len(df_raw_from_cell3)}")

    person_id_col_name_c3 = 'person_id'
    ldl_col_name_c3 = "LDL Cholesterol Calculation (mg/dL)" # Ensure this matches your CSV header

    if not (person_id_col_name_c3 in df_raw_from_cell3.columns and ldl_col_name_c3 in df_raw_from_cell3.columns):
        print(f"ERROR: Required columns ('{person_id_col_name_c3}' or '{ldl_col_name_c3}') not found in CSV.")
        print(f"Available columns: {df_raw_from_cell3.columns.tolist()}")
    else:
        # One chained, non-mutating pipeline: coerce LDL to numbers, drop
        # unparsable and non-positive values, and use string ids so they can be
        # compared with folder names.
        df_selected_c3 = (
            df_raw_from_cell3[[person_id_col_name_c3, ldl_col_name_c3]]
            .rename(columns={ldl_col_name_c3: 'LDL_temp'})
            .assign(LDL_temp=lambda frame: pd.to_numeric(frame['LDL_temp'], errors='coerce'))
            .dropna(subset=['LDL_temp'])
            .loc[lambda frame: frame['LDL_temp'] > 0]
            .assign(**{person_id_col_name_c3: lambda frame: frame[person_id_col_name_c3].astype(str)})
        )
        print(f"Cleaned clinical data (positive LDLs only): {len(df_selected_c3)} records.")

        # The lookup keeps one LDL value per person_id (the last row wins), so
        # make repeated ids visible instead of collapsing them silently.
        duplicate_id_rows_c3 = int(df_selected_c3[person_id_col_name_c3].duplicated().sum())
        if duplicate_id_rows_c3:
            print(f"WARNING: {duplicate_id_rows_c3} repeated person_id rows in clinical data; the last LDL value per person_id is used.")

        ldl_lookup_c3 = df_selected_c3.set_index(person_id_col_name_c3)['LDL_temp'].to_dict()

        if not (os.path.exists(LOCAL_IMAGES_ROOT) and os.path.isdir(LOCAL_IMAGES_ROOT)):
            print(f"FATAL ERROR: Images root path '{LOCAL_IMAGES_ROOT}' does not exist or is not a directory.")
        else:
            available_folders_c3 = set(os.listdir(LOCAL_IMAGES_ROOT))
            valid_ids_clinical_c3 = set(ldl_lookup_c3.keys())
            common_person_ids_c3 = sorted(valid_ids_clinical_c3 & available_folders_c3)
            print(f"Found {len(common_person_ids_c3)} common person_ids for mapping.")

            image_records_list = []
            for pid_c3 in common_person_ids_c3:
                folder_path_c3 = os.path.join(LOCAL_IMAGES_ROOT, pid_c3)
                if not os.path.isdir(folder_path_c3):
                    continue
                ldl_val_c3 = ldl_lookup_c3[pid_c3]
                # os.listdir returns entries in arbitrary order; sorting keeps
                # the row order of image_df identical from run to run.
                for filename_c3 in sorted(os.listdir(folder_path_c3)):
                    if filename_c3.lower().endswith(".dcm"):
                        image_records_list.append({
                            "person_id": pid_c3,
                            "image_path": os.path.join(folder_path_c3, filename_c3),
                            "LDL": ldl_val_c3
                        })

            image_df = pd.DataFrame(image_records_list)
            if not image_df.empty:
                print(f"Final image_df created with {len(image_df)} image-LDL pairs.")
                display(image_df.head())
                print(f"LDL stats in final image_df: min={image_df['LDL'].min()}, max={image_df['LDL'].max()}, mean={image_df['LDL'].mean()}")
            else:
                print("WARNING: image_df is empty after mapping. Check paths, IDs, and DICOM file existence.")
print("\nCell 3: image_df preparation complete.")

In [None]:
# -----------------------------------------------------------------------------
# Cell 4: Verify image_df (Adapted from user's Cell 4)
# -----------------------------------------------------------------------------
# Sanity-checks `image_df` before it is used for splitting and training:
# it must exist, be a non-empty DataFrame, carry the required columns and
# contain only positive LDL values.
if not ('image_df' in locals() and isinstance(image_df, pd.DataFrame) and not image_df.empty):
    print("ERROR: 'image_df' not found or is empty. Please ensure Cell 3 (data preparation) has been run successfully.")
    # To prevent later errors, create an empty df if it's missing, though this indicates a problem.
    if 'image_df' not in locals() or not isinstance(image_df, pd.DataFrame):
        image_df = pd.DataFrame(columns=['person_id', 'image_path', 'LDL'])
else:
    print(f"\nContinuing with 'image_df' which has {len(image_df)} records.")
    print("Columns in image_df:", image_df.columns.tolist())
    from IPython.display import display # Ensure display is imported
    print("Sample of image_df:")
    display(image_df.head())

    required_cols = ['person_id', 'image_path', 'LDL']
    missing_cols = [col for col in required_cols if col not in image_df.columns]
    if missing_cols:
        print(f"ERROR: 'image_df' is missing one or more required columns: {required_cols}. Please re-run Cell 3.")
    else:
        # The LDL column is only inspected once it is known to be present.
        lowest_ldl = image_df['LDL'].min()
        if lowest_ldl <= 0:
            print(f"ERROR: 'image_df' still contains non-positive LDL values. LDL min: {lowest_ldl}. Please re-run filtering in Cell 3.")
        else:
            print("'image_df' seems okay to proceed.")


print(f"\nUsing Unsloth loaded model: {selected_model_name}") # From Cell 0.3
print("\nCell 4: image_df verification and Model ID check complete.")

In [None]:
# -----------------------------------------------------------------------------
# Cell 5: Unsloth Tokenizer/Processor Info (Adapted from user's Cell 5)
# -----------------------------------------------------------------------------
# Determines TARGET_SIZE_FOR_IMAGES, a (height, width) tuple, from the image
# processor attached to the Unsloth `tokenizer`.
#
# Resolution order when the processor is present:
#   1. `size` is an int  -> use `crop_size` if it yields a size, else (size, size)
#   2. `size` is a dict with 'height' and 'width' -> use those
#   3. anything else     -> (896, 896)
# When no image processor is available at all, (896, 896) is used directly.

print("\n--- Inspecting MedGemma Image Processor ---")
TARGET_SIZE_FOR_IMAGES = None # Will be determined by the image_processor

if 'tokenizer' not in locals() or not hasattr(tokenizer, 'image_processor') or tokenizer.image_processor is None:
    print("ERROR: MedGemma image_processor not found in tokenizer. Cannot determine target image size.")
    TARGET_SIZE_FOR_IMAGES = (896, 896) # Fallback
    print(f"  Using fallback TARGET_SIZE_FOR_IMAGES: {TARGET_SIZE_FOR_IMAGES}")
else:
    medgemma_image_processor = tokenizer.image_processor
    print(f"MedGemma Image Processor Type: {type(medgemma_image_processor)}")

    if not hasattr(medgemma_image_processor, 'size'):
        print("  Image processor does not have a direct 'size' attribute. Check its config details.")
    else:
        size_info = medgemma_image_processor.size
        print(f"  Image processor 'size' attribute: {size_info}")

        if isinstance(size_info, int):
            # An int `size` is treated as a single edge length; a `crop_size`,
            # when present, takes precedence over it.
            crop_info = getattr(medgemma_image_processor, 'crop_size', None)
            if crop_info is not None:
                if isinstance(crop_info, int):
                    TARGET_SIZE_FOR_IMAGES = (crop_info, crop_info)
                elif isinstance(crop_info, dict) and all(key in crop_info for key in ('height', 'width')):
                    TARGET_SIZE_FOR_IMAGES = (crop_info['height'], crop_info['width'])
                print(f"  Using 'crop_size' for TARGET_SIZE_FOR_IMAGES: {TARGET_SIZE_FOR_IMAGES}")

            if TARGET_SIZE_FOR_IMAGES is None:
                # crop_size was absent or unusable: assume a square input.
                TARGET_SIZE_FOR_IMAGES = (size_info, size_info)
                print(f"  Using 'size' attribute for TARGET_SIZE_FOR_IMAGES (assuming square): {TARGET_SIZE_FOR_IMAGES}")

        elif isinstance(size_info, dict) and all(key in size_info for key in ('height', 'width')):
            TARGET_SIZE_FOR_IMAGES = (size_info['height'], size_info['width'])
            print(f"  Using 'size' dict for TARGET_SIZE_FOR_IMAGES: {TARGET_SIZE_FOR_IMAGES}")
        else:
            print("  Could not determine target size from image_processor.size. Check processor config.")

    # Last resort when the processor gave no usable size.
    if TARGET_SIZE_FOR_IMAGES is None:
        TARGET_SIZE_FOR_IMAGES = (896, 896) # Default from MedGemma paper if not found in processor
        print(f"  Falling back to default TARGET_SIZE_FOR_IMAGES: {TARGET_SIZE_FOR_IMAGES} (from MedGemma paper)")

print(f"Final TARGET_SIZE_FOR_IMAGES to be used by Dataset (if processor fails or for reference): {TARGET_SIZE_FOR_IMAGES}")
print("\nCell 5: MedGemma image processor check complete.")

In [None]:
# -----------------------------------------------------------------------------
# Cell 6: Data Splitting (Patient-Level) and LDL Normalization (From user's Cell 6)
# -----------------------------------------------------------------------------
# Outputs of this cell. They stay empty / None whenever splitting is skipped.
train_df = pd.DataFrame()
val_df = pd.DataFrame()
test_df = pd.DataFrame()
ldl_scaler = None  # Holds the fitted StandardScaler once normalization has run

if 'image_df' not in locals() or image_df.empty:
    print("image_df is empty (from Cell 3). Skipping data splitting and LDL normalization.")
else:
    print(f"\nStarting data splitting for {len(image_df)} image-LDL pairs...")
    if 'person_id' in image_df.columns:
        unique_person_ids = image_df['person_id'].unique()
        print(f"Total unique patients for splitting: {len(unique_person_ids)}")
        patient_count = len(unique_person_ids)

        if patient_count >= 3:
            # Regular case: 70% of patients for training, the remaining 30% is
            # halved into validation and test (15% of the total each).
            train_pids, temp_pids = train_test_split(
                unique_person_ids, test_size=0.30, random_state=RANDOM_SEED
            )
            holdout_count = len(temp_pids)
            if holdout_count == 0:
                # Nothing was held out at all.
                val_pids, test_pids = np.array([]), np.array([])
            elif holdout_count == 1:
                # A single held-out patient goes to validation; test stays empty.
                val_pids, test_pids = temp_pids, np.array([])
            else:
                val_pids, test_pids = train_test_split(
                    temp_pids, test_size=0.50, random_state=RANDOM_SEED
                )
        else:
            # Fewer than three patients: a real train/validation/test split is impossible.
            print("Warning: Not enough unique patients for a robust 3-way (train/validation/test) split.")
            if patient_count == 2:
                train_pids, val_pids = train_test_split(unique_person_ids, test_size=0.5, random_state=RANDOM_SEED)
                test_pids = np.array([])  # Empty array keeps the downstream code uniform
            elif patient_count == 1:
                train_pids = unique_person_ids
                val_pids, test_pids = np.array([]), np.array([])
            else:
                train_pids, val_pids, test_pids = np.array([]), np.array([]), np.array([])

        # Select the image rows belonging to each patient group.
        patient_column = image_df['person_id']
        train_df = image_df.loc[patient_column.isin(train_pids)].copy()
        val_df = image_df.loc[patient_column.isin(val_pids)].copy()
        test_df = image_df.loc[patient_column.isin(test_pids)].copy()

        print(f"Train set: {len(train_df)} samples from {len(train_pids)} patients.")
        print(f"Validation set: {len(val_df)} samples from {len(val_pids)} patients.")
        print(f"Test set: {len(test_df)} samples from {len(test_pids)} patients.")

        # No patient may appear in two splits (an empty split is trivially disjoint).
        assert set(train_pids).isdisjoint(val_pids), "Patient overlap train/val!"
        assert set(train_pids).isdisjoint(test_pids), "Patient overlap train/test!"
        assert set(val_pids).isdisjoint(test_pids), "Patient overlap val/test!"
        print("Patient-level splits verified (no overlap if sets are non-empty).")

        # --- LDL value normalization ---
        if train_df.empty or 'LDL' not in train_df.columns:
            print("Train DataFrame is empty or 'LDL' column missing. Skipping LDL normalization.")
        else:
            print("\nNormalizing LDL values using StandardScaler...")
            ldl_scaler = StandardScaler()
            # The scaler is fitted on the training targets only.
            train_df['LDL_scaled'] = ldl_scaler.fit_transform(train_df[['LDL']])

            # Validation and test reuse the fitted scaler; an empty split still
            # receives the column so every frame has the same schema.
            for holdout_df in (val_df, test_df):
                if holdout_df.empty:
                    holdout_df['LDL_scaled'] = pd.Series(dtype='float64')
                else:
                    holdout_df['LDL_scaled'] = ldl_scaler.transform(holdout_df[['LDL']])

            print("LDL normalization complete.")
            print("Scaled LDL stats in train_df (should be mean~0, std~1):")
            from IPython.display import display  # Imported here so the cell works on its own
            display(train_df['LDL_scaled'].describe())

            # Optional: persist the scaler for inference time
            # import joblib
            # scaler_filename = 'ldl_scaler_medgemma.joblib'
            # joblib.dump(ldl_scaler, scaler_filename)
            # print(f"LDL scaler saved to {scaler_filename}")
    else:
        print("ERROR: 'person_id' column missing in image_df. Cannot perform patient-level split. Please check image_df preparation in Cell 3.")

print("\nCell 6: Data splitting and LDL normalization attempt complete.")

In [None]:
# -----------------------------------------------------------------------------
# Cell 5.1 (from user, now Cell 6.1): Check Unsloth tokenizer/model.processor
# -----------------------------------------------------------------------------
def report_image_processors(tokenizer_obj, model_obj, model_name):
    """Print which image-processor components are attached and tell if any exists.

    Args:
        tokenizer_obj: The loaded tokenizer/processor object (must not be None).
        model_obj: The loaded model, or None when no model is available.
        model_name: Name shown in the warning when no image processor is found.

    Returns:
        bool: True when `tokenizer_obj.image_processor` or
        `model_obj.processor.image_processor` exists and is not None.
    """
    print(f"Unsloth tokenizer IS LOADED. Type: {type(tokenizer_obj)}")
    tokenizer_image_processor = getattr(tokenizer_obj, 'image_processor', None)
    if tokenizer_image_processor is not None:
        print(f"  It has a tokenizer.image_processor of type: {type(tokenizer_image_processor)}")
    else:
        print("  It does NOT have a direct `tokenizer.image_processor` attribute (or it's None).")

    # getattr with a default also covers model_obj being None.
    model_processor = getattr(model_obj, 'processor', None)
    model_image_processor = getattr(model_processor, 'image_processor', None)
    if model_processor is not None:
        print(f"Unsloth model.processor IS LOADED. Type: {type(model_processor)}")
        if model_image_processor is not None:
            print(f"  model.processor has an image_processor component of type: {type(model_image_processor)}")
    else:
        print("  The model does NOT have a `model.processor` attribute (or it's None).")

    # An attribute that exists but is None does not count as an image processor.
    found_image_processor = tokenizer_image_processor is not None or model_image_processor is not None
    if not found_image_processor:
        print(f"  WARNING: No obvious image processor found. The model '{model_name}' may be text-only.")
        print("  If your task requires image input, ensure you've selected a vision-language model and that Unsloth loads its image processor correctly.")
    return found_image_processor


print("\n--- Sanity Check for Unsloth Components (Cell 6.1) ---")
if 'tokenizer' in locals() and tokenizer is not None:
    # `model` and `selected_model_name` are looked up defensively so this check
    # cannot fail with a NameError when an earlier cell was not run.
    report_image_processors(
        tokenizer,
        globals().get('model'),
        globals().get('selected_model_name', '<unknown>'),
    )
else:
    print("Unsloth tokenizer IS NOT LOADED or is None.")

In [None]:
# -----------------------------------------------------------------------------
# Cell 7: Custom PyTorch Dataset for DICOM Images and LDL (New Cell)
# -----------------------------------------------------------------------------
# Messages already emitted by print_once_dataset (dataset iteration / training).
printed_messages_dataset = set()
def print_once_dataset(message):
    """Print `message` the first time it is seen; ignore later repeats."""
    global printed_messages_dataset
    if message in printed_messages_dataset:
        return
    print(message)
    printed_messages_dataset.add(message)

import torchvision.transforms as T # Import T for transforms

class MedGemmaVisionDataset(Dataset):
    """Map DataFrame rows to processed DICOM tensors and scaled LDL targets.

    Each item is a dict with:
        "pixel_values": tensor produced by the processor's `image_processor`
            (or by the fallback path when loading/processing fails).
        "labels": float32 tensor with a leading axis of size 1 holding the
            row's 'LDL_scaled' value.
    """

    # Sizes of a trailing axis that are interpreted as channels (gray, RGB, RGBA).
    _CHANNEL_COUNTS = (1, 3, 4)

    def __init__(self, dataframe, medgemma_tokenizer_processor, target_img_size_ref=(896, 896)):
        """
        Args:
            dataframe (pd.DataFrame): DataFrame with 'image_path' and 'LDL_scaled' columns.
            medgemma_tokenizer_processor: The multimodal processor from Unsloth (contains image_processor).
            target_img_size_ref (tuple): Reference target image size, primarily for fallback.
                                         The image_processor itself determines the actual processing.

        Raises:
            ValueError: If the processor has no usable 'image_processor' attribute.
        """
        self.dataframe = dataframe
        self.processor = medgemma_tokenizer_processor # This is the GemmaProcessor (or similar)
        self.target_size_ref = target_img_size_ref # For fallback basic transforms

        if not hasattr(self.processor, 'image_processor') or self.processor.image_processor is None:
            raise ValueError("The provided processor must have a valid 'image_processor' attribute for MedGemma.")

        # Basic image transforms (fallback if image_processor fails for an image)
        self.basic_transforms = T.Compose([
            T.Resize(self.target_size_ref),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Imagenet stats
        ])


    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.dataframe)

    @staticmethod
    def _select_single_frame(pixel_array):
        """Reduce a multi-frame pixel array to its first frame.

        Args:
            pixel_array (np.ndarray): Array of any rank.

        Returns:
            np.ndarray: The input unchanged when it has fewer than 3 axes or is
            already (H, W, C) with C in (1, 3, 4). A 3-D array with any other
            trailing size is treated as (frames, H, W) and its first frame is
            returned. Arrays with 4+ axes are reduced along their leading axes
            to the first (H, W, C) frame; if C is not a known channel count,
            only the first channel is kept.
        """
        channel_counts = MedGemmaVisionDataset._CHANNEL_COUNTS
        if pixel_array.ndim >= 4:
            while pixel_array.ndim > 3:
                pixel_array = pixel_array[0]
            if pixel_array.shape[-1] not in channel_counts:
                pixel_array = pixel_array[..., 0]
        elif pixel_array.ndim == 3 and pixel_array.shape[-1] not in channel_counts:
            # Indexing the LAST axis here would pick one pixel column from every
            # frame instead of a frame, so the leading (frame) axis is indexed.
            pixel_array = pixel_array[0]
        return pixel_array

    def load_and_preprocess_dicom(self, dicom_path):
        """Read a DICOM file and return its processed pixel tensor.

        Args:
            dicom_path: Path of the DICOM file to read.

        Returns:
            torch.Tensor: `image_processor` output with the batch axis removed.
            If anything fails, the result of `basic_transforms` on a PIL-loaded
            image, or as a last resort a zero tensor of shape
            (3, target_size_ref[0], target_size_ref[1]).
        """
        try:
            dicom_file = pydicom.dcmread(dicom_path)
            pixel_array = dicom_file.pixel_array

            # Normalize pixel data to 0-255 (min-max over the whole array)
            # This is a common pre-step before PIL conversion for many image processors
            if pixel_array.dtype != np.uint8:
                pixel_array = pixel_array.astype(np.float32)
                min_val, max_val = np.min(pixel_array), np.max(pixel_array)
                if max_val > min_val:
                    pixel_array = (pixel_array - min_val) / (max_val - min_val) * 255.0
                else: # Handle case where all pixels are the same
                    pixel_array = np.zeros_like(pixel_array)
                pixel_array = pixel_array.astype(np.uint8)

            is_plain_image = pixel_array.ndim == 2 or (
                pixel_array.ndim == 3 and pixel_array.shape[-1] in self._CHANNEL_COUNTS
            )
            if not is_plain_image:
                print_once_dataset(f"Warning: Unsupported DICOM pixel array shape {pixel_array.shape} for {dicom_path}. Trying to convert.")
                pixel_array = self._select_single_frame(pixel_array)

            # Grayscale stored with an explicit channel axis -> plain 2-D array
            if pixel_array.ndim == 3 and pixel_array.shape[-1] == 1:
                pixel_array = pixel_array.squeeze(-1)
            # Grayscale, RGB and RGBA all end up as a 3-channel RGB PIL image
            pil_image = Image.fromarray(pixel_array).convert('RGB')

            # Use MedGemma's image_processor on the PIL image; the leading
            # batch axis of its output is removed so items can be stacked later.
            processed_output = self.processor.image_processor(images=pil_image, return_tensors="pt")
            pixel_values = processed_output.pixel_values.squeeze(0) # Remove batch dim
            return pixel_values

        except Exception as e:
            print_once_dataset(f"Error processing DICOM {dicom_path} with image_processor: {e}. Applying basic fallback.")
            # Fallback: basic transforms on a PIL-loaded image, else a zero tensor.
            # NOTE(review): fallback tensors use target_size_ref and ImageNet
            # normalization; if the image_processor emits another shape, batching
            # mixed samples would fail -- confirm the sizes agree.
            try:
                # Try to load with PIL directly for basic transform
                pil_image_fallback = Image.open(dicom_path).convert("RGB") # This might fail for some DICOMs
                return self.basic_transforms(pil_image_fallback)
            except Exception as e_fallback:
                print_once_dataset(f"Fallback PIL loading also failed for {dicom_path}: {e_fallback}. Returning zero tensor.")
                return torch.zeros((3, self.target_size_ref[0], self.target_size_ref[1]))


    def __getitem__(self, idx):
        """Return the sample at positional index `idx` as a dict of tensors."""
        row = self.dataframe.iloc[idx]
        image_path = row['image_path']
        ldl_scaled = row['LDL_scaled'] # Target variable

        pixel_values = self.load_and_preprocess_dicom(image_path)
        target_ldl_scaled = torch.tensor(ldl_scaled, dtype=torch.float32)

        return {
            "pixel_values": pixel_values,
            "labels": target_ldl_scaled.unsqueeze(0) # Ensure target is (1,) for MSELoss
        }

# --- Build the train / validation datasets ---
# `tokenizer` (Cell 0.3) acts as the MedGemma processor;
# `TARGET_SIZE_FOR_IMAGES` (Cell 5) is only the reference size for the fallback path.
if 'train_df' not in locals() or train_df.empty or 'tokenizer' not in locals() or tokenizer is None:
    train_dataset = None
    print("Could not create train_dataset. Check train_df and tokenizer.")
else:
    train_dataset = MedGemmaVisionDataset(train_df, tokenizer, TARGET_SIZE_FOR_IMAGES)
    print(f"Train dataset created with {len(train_dataset)} samples.")

if 'val_df' not in locals() or val_df.empty or 'tokenizer' not in locals() or tokenizer is None:
    val_dataset = None
    print("Could not create val_dataset. Check val_df and tokenizer.")
else:
    val_dataset = MedGemmaVisionDataset(val_df, tokenizer, TARGET_SIZE_FOR_IMAGES)
    print(f"Validation dataset created with {len(val_dataset)} samples.")

# Smoke test: pull the first training item and report each tensor's shape/dtype.
if train_dataset:
    print("\nSample from train_dataset:")
    try:
        sample = train_dataset[0]
        for field_name, field_tensor in sample.items():
            print(f"  {field_name}: shape {field_tensor.shape}, dtype {field_tensor.dtype}")
    except Exception as e:
        print(f"Error fetching sample from train_dataset: {e}")
        print("This might indicate issues with DICOM loading or processing in your dataset.")

print("\nCell 7: MedGemmaVisionDataset class defined and datasets instantiated.")

In [None]:
# Every item already holds tensors, so batching is just a per-key stack.
def vision_collate_fn(batch):
    """Stack the "pixel_values" and "labels" tensors of all items in `batch`.

    Args:
        batch: Sequence of dicts, each with "pixel_values" and "labels" tensors.

    Returns:
        dict: The same two keys, each mapped to the stacked tensors with a new
        leading batch axis.
    """
    return {
        field: torch.stack([sample[field] for sample in batch])
        for field in ("pixel_values", "labels")
    }

BATCH_SIZE = 8  # Tune to the available GPU memory (e.g., 4, 8, 16)

# Training loader: reshuffled every epoch.
if not train_dataset:
    train_loader = None
else:
    train_loader = DataLoader(
        train_dataset,
        collate_fn=vision_collate_fn,  # custom collate
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,  # background workers; lower this when debugging or on Windows
        pin_memory=torch.cuda.is_available(),
    )
    print(f"\nTrain DataLoader created. Batches per epoch: {len(train_loader)}")

# Validation loader: fixed order, shuffling is not needed.
if not val_dataset:
    val_loader = None
else:
    val_loader = DataLoader(
        val_dataset,
        collate_fn=vision_collate_fn,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
    print(f"Validation DataLoader created. Batches per epoch: {len(val_loader)}")

# Smoke test: draw one training batch and report each tensor's shape/dtype.
if train_loader:
    print("\nSample batch from train_loader:")
    try:
        batch_sample = next(iter(train_loader))
        for field_name, field_tensor in batch_sample.items():
            print(f"  {field_name}: shape {field_tensor.shape}, dtype {field_tensor.dtype}")
    except Exception as e:
        print(f"Error fetching batch from train_loader: {e}")

print("\nCell 8: DataLoaders created.")

In [None]:
# Cell 9: Training Setup (Optimizer, Loss, Learning Rate) - MODIFIED

"""
import torch.optim as optim

LEARNING_RATE = 5e-5 # Common starting point for LoRA fine-tuning. May need adjustment.
EPOCHS = 10 # Start with a moderate number, e.g., 5-20.
WEIGHT_DECAY = 0.01

if regressor_model is not None: # Ensure the model was created in Cell 0.5
    optimizer = optim.AdamW(regressor_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.MSELoss()

    # --- STRATEGY 1: Attempt to cast the entire regressor_model to bfloat16 ---
    # This assumes `regressor_model` is already on the correct `device` (e.g., 'cuda')
    # And `model_dtype` (e.g. torch.bfloat16) should be what Unsloth set for the base model.

    # Get the dtype from the Unsloth-loaded base model component within regressor_model
    # This is the most reliable source for the target dtype.
    if hasattr(regressor_model, 'medgemma_model') and hasattr(regressor_model.medgemma_model, 'dtype'):
        target_model_dtype = regressor_model.medgemma_model.dtype
        print(f"\nTarget dtype for model components (from base Unsloth model): {target_model_dtype}")

        if target_model_dtype == torch.bfloat16:
            print(f"Attempting to cast entire regressor_model and its submodules to {target_model_dtype} (Strategy 1)...")
            try:
                # This will attempt to cast all parameters and buffers.
                regressor_model = regressor_model.to(dtype=target_model_dtype)
                print("Casting of entire regressor_model to bfloat16 attempted.")

                # Optional: Verification - Check dtypes of some parameters
                # print("Verifying some parameter dtypes after full model cast:")
                # for name, param in regressor_model.named_parameters():
                #     if "lora" in name.lower() or "regression_head" in name.lower() or "bias" in name.lower(): # Check some key ones
                #         if param.numel() > 0: # Only print if param is not empty
                #             print(f"  Param: {name[:60]}..., Dtype: {param.dtype}, Device: {param.device}")
                #         break # Just check a few to avoid too much output
            except Exception as e_cast_full:
                print(f"ERROR during full regressor_model.to(dtype={target_model_dtype}): {e_cast_full}")
                print("Full model cast failed. Proceeding without it, relying on input tensor casting in training loop.")
        else:
            print(f"Base model dtype is {target_model_dtype}, not bfloat16. Skipping full model bfloat16 cast strategy.")
    else:
        print("\nCould not reliably determine target_model_dtype from regressor_model.medgemma_model.dtype.")
        print("Skipping full model cast strategy. Will rely on input tensor casting in training loop.")
    # --- END OF STRATEGY 1 ---

    print(f"\nOptimizer: AdamW, LR: {LEARNING_RATE}, Weight Decay: {WEIGHT_DECAY}")
    print(f"Loss Function: MSELoss")
    print(f"Training for {EPOCHS} epochs.")
else:
    print("CRITICAL ERROR: regressor_model is None (was not created in Cell 0.5). Cannot set up optimizer and loss.")
    optimizer = None
    criterion = None
"""

# Cell 9: Training Setup (Optimizer, Loss, Learning Rate) - REVISED FOR FLOAT16

import torch.optim as optim

LEARNING_RATE = 5e-5
EPOCHS = 10
WEIGHT_DECAY = 0.01

if regressor_model is not None:
    # Convert the model BEFORE creating the optimizer, so the optimizer is
    # constructed from the parameters in their final dtype (torch.optim advises
    # moving/converting a model before building optimizers for it).
    if hasattr(regressor_model, 'medgemma_model') and hasattr(regressor_model.medgemma_model, 'dtype'):
        target_dtype_for_components = regressor_model.medgemma_model.dtype # Should be float16
        print(f"\nTarget dtype for all model components (from base Unsloth model): {target_dtype_for_components}")

        if target_dtype_for_components == torch.float16: # Check for float16
            print(f"Attempting to ensure all components of regressor_model are on {target_dtype_for_components}...")
            try:
                # NOTE(review): this also casts the trainable weights to float16;
                # confirm that is intended for the AdamW updates.
                regressor_model = regressor_model.to(dtype=target_dtype_for_components)
                print(f"Casting of entire regressor_model to {target_dtype_for_components} completed.")
                # Verification
                # print("Verifying select parameter dtypes after full model cast:")
                # for name, param in regressor_model.named_parameters():
                #     if param.requires_grad and ("lora" in name.lower() or "regression_head" in name.lower()):
                #         if param.numel() > 0:
                #              print(f"  Trainable Param: {name[:70]}..., Dtype: {param.dtype}, Device: {param.device}")
            except Exception as e_cast_all:
                # Best effort: report the failure and continue with the model as it is.
                print(f"ERROR during full regressor_model.to(dtype={target_dtype_for_components}): {e_cast_all}")
        else:
            print(f"Base model dtype is {target_dtype_for_components}, not float16. Current strategy might need adjustment.")
    else:
        print("\nCould not reliably determine target_dtype for component casting.")

    # Built after the (attempted) cast so it sees the converted parameters.
    optimizer = optim.AdamW(regressor_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    criterion = nn.MSELoss()

    print(f"\nOptimizer: AdamW, LR: {LEARNING_RATE}, Weight Decay: {WEIGHT_DECAY}")
    print(f"Loss Function: MSELoss")
    print(f"Training for {EPOCHS} epochs.")
else:
    print("CRITICAL ERROR: regressor_model is None. Cannot set up optimizer and loss.")
    # Downstream cells check these for None before starting training.
    optimizer = None
    criterion = None

In [None]:
# Cell 10: Training and Evaluation Loop - MODIFIED
"""
train_losses = []
val_losses = []
best_val_loss = float('inf')
patience_counter = 0
PATIENCE_EPOCHS = 3 # For early stopping if validation loss doesn't improve

# Helper for printing messages only once during training loop
printed_messages_train_loop = set() # Use a different name to avoid conflict if re-running cells
def print_once_train_loop(message):
    global printed_messages_train_loop
    if message not in printed_messages_train_loop:
        print(message)
        printed_messages_train_loop.add(message)

if regressor_model is not None and train_loader is not None and val_loader is not None and optimizer is not None and criterion is not None:
    print(f"\nStarting training on device: {device}...") # device was set in Cell 0.5

    # Determine the model's expected input dtype (should be bfloat16 if Unsloth set it)
    # This comes from the base Unsloth model component
    if hasattr(regressor_model, 'medgemma_model') and hasattr(regressor_model.medgemma_model, 'dtype'):
        model_input_dtype = regressor_model.medgemma_model.dtype
    else:
        # Fallback if attribute not found, assume bfloat16 based on previous errors
        print_once_train_loop("Warning: Could not directly get model_input_dtype from regressor_model.medgemma_model.dtype. Assuming torch.bfloat16.")
        model_input_dtype = torch.bfloat16

    print(f"Model's expected input dtype for pixel_values: {model_input_dtype}")

    for epoch in range(EPOCHS):
        regressor_model.train()
        running_train_loss = 0.0
        processed_batches_train = 0

        for i, batch in enumerate(train_loader):
            # Data from DataLoader is typically float32
            pixel_values_f32 = batch['pixel_values'].to(device)
            labels_f32 = batch['labels'].to(device) # Labels for MSELoss are typically Float32

            # Explicitly cast pixel_values to the model's expected input dtype
            pixel_values_casted = pixel_values_f32.to(model_input_dtype)

            optimizer.zero_grad()

            try:
                # Forward pass with casted input
                predictions = regressor_model(pixel_values_casted)

                # Predictions will likely be in model_input_dtype (e.g., bfloat16).
                # MSELoss can often handle mixed precision (e.g., bfloat16 pred, float32 label).
                # If criterion errors on dtype, cast predictions: loss = criterion(predictions.to(torch.float32), labels_f32)
                loss = criterion(predictions, labels_f32)

                loss.backward()
                optimizer.step()

                running_train_loss += loss.item()
                processed_batches_train += 1

                if (i + 1) % 20 == 0 or (i + 1) == len(train_loader):
                    print(f"Epoch [{epoch+1}/{EPOCHS}], Batch [{i+1}/{len(train_loader)}], Train Loss: {loss.item():.4f}")

            except Exception as e:
                print_once_train_loop(f"ERROR during training forward/backward pass at batch {i}: {e}")
                if "expected scalar type" in str(e).lower():
                    print_once_train_loop(f"  Input pixel_values_casted dtype: {pixel_values_casted.dtype}")
                    # If predictions object exists before error:
                    if 'predictions' in locals() and isinstance(predictions, torch.Tensor):
                         print_once_train_loop(f"  Predictions (if formed) dtype: {predictions.dtype}")
                # To get more details on where exactly the error occurs inside the model:
                # import traceback
                # print_once_train_loop(traceback.format_exc())
                continue # Skip this batch and try the next one

        epoch_train_loss = running_train_loss / processed_batches_train if processed_batches_train > 0 else 0.0
        train_losses.append(epoch_train_loss)
        print(f"Epoch [{epoch+1}/{EPOCHS}] - Average Training Loss: {epoch_train_loss:.4f}")

        # Validation phase
        regressor_model.eval()
        running_val_loss = 0.0
        processed_batches_val = 0
        with torch.no_grad():
            for batch_val in val_loader:
                pixel_values_f32_val = batch_val['pixel_values'].to(device)
                labels_f32_val = batch_val['labels'].to(device)

                pixel_values_casted_val = pixel_values_f32_val.to(model_input_dtype)

                try:
                    predictions_val = regressor_model(pixel_values_casted_val)
                    loss_val = criterion(predictions_val, labels_f32_val)
                    running_val_loss += loss_val.item()
                    processed_batches_val +=1
                except Exception as e_val:
                    print_once_train_loop(f"ERROR during validation forward pass: {e_val}")
                    continue

        epoch_val_loss = running_val_loss / processed_batches_val if processed_batches_val > 0 else 0.0
        val_losses.append(epoch_val_loss)
        print(f"Epoch [{epoch+1}/{EPOCHS}] - Average Validation Loss: {epoch_val_loss:.4f}")

        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
            save_dir = "./best_model_checkpoint"
            if not os.path.exists(save_dir): os.makedirs(save_dir)

            # Save LoRA adapters from the PEFT-adapted model component
            if hasattr(regressor_model, 'medgemma_model') and hasattr(regressor_model.medgemma_model, 'save_pretrained'):
                regressor_model.medgemma_model.save_pretrained(os.path.join(save_dir, "lora_adapters"))
                print(f"Saved LoRA adapters at epoch {epoch+1}.")
            else:
                print_once_train_loop("Could not save LoRA adapters: regressor_model.medgemma_model.save_pretrained not found.")

            # Save the state of the regression head
            if hasattr(regressor_model, 'regression_head'):
                torch.save(regressor_model.regression_head.state_dict(), os.path.join(save_dir, "regression_head.pth"))
                print(f"Saved regression head state at epoch {epoch+1}.")
            else:
                print_once_train_loop("Could not save regression head: regressor_model.regression_head not found.")
            print(f"Validation loss improved to {best_val_loss:.4f}. Saved best model components.")
        else:
            patience_counter += 1

        print("-" * 30)
        if patience_counter >= PATIENCE_EPOCHS:
            print(f"Early stopping triggered after {PATIENCE_EPOCHS} epochs without improvement on validation loss.")
            break
    print("Training complete.")
else:
    print("Cannot start training. One or more critical components (model, dataloaders, optimizer, criterion) are missing.")

# Plotting training and validation loss
if train_losses and val_losses:
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss (MSE)')
    plt.legend()
    plt.grid(True)
    plt.show()
"""
import torch.amp

# Per-epoch loss history and early-stopping bookkeeping.
train_losses, val_losses = [], []
best_val_loss = float('inf')
patience_counter = 0
PATIENCE_EPOCHS = 3

# Messages already emitted by print_once_train_loop.
printed_messages_train_loop = set()
def print_once_train_loop(message):
    """Print `message` the first time it is seen; ignore later repeats."""
    global printed_messages_train_loop
    if message in printed_messages_train_loop:
        return
    print(message)
    printed_messages_train_loop.add(message)

# Train `regressor_model` for up to EPOCHS epochs. After every epoch a validation
# pass is run; the components of the model with the lowest average validation loss
# are written to ./best_model_checkpoint, and training stops early once
# PATIENCE_EPOCHS consecutive epochs bring no improvement.
# Relies on regressor_model, train_loader, val_loader, optimizer, criterion, device
# and EPOCHS having been defined by earlier cells.
if regressor_model is not None and train_loader is not None and val_loader is not None and optimizer is not None and criterion is not None:
    print(f"\nStarting training on device: {device}...")

    # Input pixel values are cast to the wrapped model's dtype before the forward
    # pass; fall back to float16 when the wrapper does not expose one.
    if hasattr(regressor_model, 'medgemma_model') and hasattr(regressor_model.medgemma_model, 'dtype'):
        model_activation_dtype = regressor_model.medgemma_model.dtype # Should be torch.float16
    else:
        print_once_train_loop("Warning: Could not get model_activation_dtype. Assuming torch.float16.")
        model_activation_dtype = torch.float16

    print(f"Model's activation dtype (for input pixel_values): {model_activation_dtype}")

    # Autocast is pinned to float16 (the target GPU is a T4).
    # NOTE(review): float16 autocast is used without a torch.amp.GradScaler; confirm
    # that gradients do not underflow for this model before relying on the results.
    autocast_dtype = torch.float16
    print(f"Using torch.amp.autocast with dtype: {autocast_dtype} on device type: {device.type}")

    for epoch in range(EPOCHS):
        # ------------------------------ training pass ------------------------------
        regressor_model.train()
        running_train_loss = 0.0
        processed_batches_train = 0

        for i, batch in enumerate(train_loader):
            pixel_values_f32 = batch['pixel_values'].to(device)
            labels_f32 = batch['labels'].to(device)

            # Match the model's activation dtype.
            pixel_values_casted_for_input = pixel_values_f32.to(model_activation_dtype)

            optimizer.zero_grad()

            with torch.amp.autocast(device_type=device.type, dtype=autocast_dtype, enabled=True):
                try:
                    predictions = regressor_model(pixel_values_casted_for_input)
                    # The loss is computed on float32 predictions.
                    loss = criterion(predictions.to(torch.float32), labels_f32)
                except Exception as e:
                    # Best effort: report the failure once and skip this batch.
                    print_once_train_loop(f"ERROR during training forward pass (inside autocast) at batch {i}: {e}")
                    # import traceback # Uncomment for full traceback
                    # print_once_train_loop(traceback.format_exc())
                    continue

            loss.backward()
            optimizer.step()

            batch_train_loss = loss.item()
            running_train_loss += batch_train_loss
            processed_batches_train += 1

            # Progress line every 20 batches and on the last batch of the epoch.
            if (i + 1) % 20 == 0 or (i + 1) == len(train_loader):
                print(f"Epoch [{epoch+1}/{EPOCHS}], Batch [{i+1}/{len(train_loader)}], Train Loss: {batch_train_loss:.4f}")

        if processed_batches_train > 0:
            epoch_train_loss = running_train_loss / processed_batches_train
        else:
            # Every batch failed (or the loader was empty): record NaN instead of a
            # 0.0 that would read as a perfect epoch in the log and in the plot.
            print_once_train_loop("Warning: no training batch was processed successfully in an epoch; recording NaN as its training loss.")
            epoch_train_loss = float('nan')
        train_losses.append(epoch_train_loss)
        print(f"Epoch [{epoch+1}/{EPOCHS}] - Average Training Loss: {epoch_train_loss:.4f}")

        # ----------------------------- validation pass -----------------------------
        regressor_model.eval()
        running_val_loss = 0.0
        processed_batches_val = 0
        with torch.no_grad():
            for batch_val in val_loader:
                pixel_values_f32_val = batch_val['pixel_values'].to(device)
                labels_f32_val = batch_val['labels'].to(device)
                pixel_values_casted_for_input_val = pixel_values_f32_val.to(model_activation_dtype)

                with torch.amp.autocast(device_type=device.type, dtype=autocast_dtype, enabled=True):
                    try:
                        predictions_val = regressor_model(pixel_values_casted_for_input_val)
                        loss_val = criterion(predictions_val.to(torch.float32), labels_f32_val)
                    except Exception as e_val:
                        print_once_train_loop(f"ERROR during validation forward pass (inside autocast): {e_val}")
                        continue
                running_val_loss += loss_val.item()
                processed_batches_val += 1

        if processed_batches_val > 0:
            epoch_val_loss = running_val_loss / processed_batches_val
        else:
            # A 0.0 here would beat every real loss, be checkpointed as the "best"
            # model and block all later improvements. NaN never compares as smaller,
            # so a failed validation epoch counts against patience instead.
            print_once_train_loop("Warning: no validation batch was processed successfully in an epoch; recording NaN as its validation loss.")
            epoch_val_loss = float('nan')
        val_losses.append(epoch_val_loss)
        print(f"Epoch [{epoch+1}/{EPOCHS}] - Average Validation Loss: {epoch_val_loss:.4f}")

        # ------------------- checkpointing and early stopping ---------------------
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
            save_dir = "./best_model_checkpoint"
            os.makedirs(save_dir, exist_ok=True)
            # LoRA adapters are saved through the wrapped model's save_pretrained.
            if hasattr(regressor_model, 'medgemma_model') and hasattr(regressor_model.medgemma_model, 'save_pretrained'):
                regressor_model.medgemma_model.save_pretrained(os.path.join(save_dir, "lora_adapters"))
            # The regression head is saved separately as a plain state dict.
            if hasattr(regressor_model, 'regression_head'):
                torch.save(regressor_model.regression_head.state_dict(), os.path.join(save_dir, "regression_head.pth"))
            print(f"Validation loss improved to {best_val_loss:.4f}. Saved best model components at epoch {epoch+1}.")
        else:
            patience_counter += 1
        print("-" * 30)
        if patience_counter >= PATIENCE_EPOCHS:
            print("Early stopping triggered.")
            break
    print("Training complete.")
else:
    print("Cannot start training. Critical components missing.")

# Plot the per-epoch loss histories; skipped unless both lists are non-empty.
if train_losses and val_losses:
    plt.figure(figsize=(10, 5))
    for loss_history, curve_label in (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    ):
        plt.plot(loss_history, label=curve_label)
    plt.title('Loss Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss (MSE)')
    plt.legend()
    plt.grid(True)
    plt.show()

In [None]:
# The training loop writes its best checkpoint to ./best_model_checkpoint.
# This cell only reports that checkpoint when the directory exists; otherwise it
# saves the model as it stands now (last epoch) to ./final_model_checkpoint.

final_model_save_path = "./final_model_checkpoint"
if regressor_model is not None and os.path.exists("./best_model_checkpoint"): # Check if best model was saved
    print("\nBest model was saved during training to ./best_model_checkpoint")
    print("To use the best model, load from './best_model_checkpoint/lora_adapters' and './best_model_checkpoint/regression_head.pth'")
elif regressor_model is not None: # No best-model directory exists: save the current model instead
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(final_model_save_path, exist_ok=True)
    final_lora_dir = os.path.join(final_model_save_path, "lora_adapters")
    final_head_file = os.path.join(final_model_save_path, "regression_head.pth")
    print(f"\nSaving final model to {final_model_save_path}...")
    # Save LoRA adapters of the base MedGemma model
    regressor_model.medgemma_model.save_pretrained(final_lora_dir)
    # Save the state of the regression head
    torch.save(regressor_model.regression_head.state_dict(), final_head_file)
    print(f"Final LoRA adapters saved to {final_lora_dir}")
    print(f"Final regression head state saved to {final_head_file}")
else:
    # Reached only when regressor_model is None.
    print("\nNo model to save or best model already indicated.")


# --- How to load the saved (best or final) model for inference ---
# Rebuilds the regressor from its two saved artifacts (LoRA adapters + regression
# head weights) and runs one validation batch through it as a smoke test.

# 1. Define the path to your saved components (e.g., best model)
saved_lora_path = "./best_model_checkpoint/lora_adapters" # Or final_model_save_path + "/lora_adapters"
saved_head_path = "./best_model_checkpoint/regression_head.pth" # Or final_model_save_path + "/regression_head.pth"

if os.path.exists(saved_lora_path) and os.path.exists(saved_head_path) and vision_feature_dim is not None:
    print(f"\n--- Example: Loading saved model components from {saved_lora_path} and {saved_head_path} ---")
    # A. Load the base MedGemma model first, then attach the saved adapters to it.
    print(f"Loading base MedGemma model ({selected_model_name}) and then applying saved LoRA adapters from {saved_lora_path}...")

    loaded_base_model, loaded_tokenizer = FastLanguageModel.from_pretrained(
        model_name=selected_model_name, # Start with the original base model name
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
        # token = "hf_..."
    )

    # PeftModel.from_pretrained takes the *base model* and the path to the adapters.
    from peft import PeftModel
    loaded_peft_medgemma_model = PeftModel.from_pretrained(loaded_base_model, saved_lora_path)
    print("PEFT MedGemma model with saved LoRA adapters loaded.")

    # B. Instantiate the regressor wrapper around the loaded PEFT MedGemma
    loaded_regressor_model = MedGemmaVisionRegressor(loaded_peft_medgemma_model, vision_feature_dim)

    # C. Load the state_dict for the regression head.
    #    weights_only=True restricts unpickling to tensors/containers, which is all a
    #    state dict holds, so a tampered checkpoint file cannot execute arbitrary code.
    loaded_regressor_model.regression_head.load_state_dict(
        torch.load(saved_head_path, map_location=device, weights_only=True)
    )
    print("Regression head state loaded.")

    loaded_regressor_model.to(device)
    loaded_regressor_model.eval() # Set to evaluation mode
    print("Complete RegressorModel loaded and ready for inference.")

    # Example inference on a single batch from val_loader
    if val_loader:
        try:
            sample_batch_inference = next(iter(val_loader))
            # Mirror the training/validation path: cast the pixel values to the
            # model's dtype (float16 when it exposes none) and run under float16
            # autocast, so inference sees the same input dtype as training did.
            inference_input_dtype = getattr(loaded_peft_medgemma_model, 'dtype', torch.float16)
            pixel_values_inf = sample_batch_inference['pixel_values'].to(device).to(inference_input_dtype)
            labels_inf = sample_batch_inference['labels'].to(device)
            with torch.no_grad():
                with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=True):
                    predictions_inf = loaded_regressor_model(pixel_values_inf)
            print(f"\nSample inference output shape: {predictions_inf.shape}")
            # Map predictions and labels back to the original scale with ldl_scaler
            if ldl_scaler:
                 # Cast to float32 first: predictions produced under autocast may be
                 # half precision.
                 predicted_ldl_original_scale = ldl_scaler.inverse_transform(predictions_inf.float().cpu().numpy())
                 actual_ldl_original_scale = ldl_scaler.inverse_transform(labels_inf.cpu().numpy())
                 print(f"Sample predictions (original scale): {predicted_ldl_original_scale[:5].flatten()}")
                 print(f"Sample actuals (original scale):    {actual_ldl_original_scale[:5].flatten()}")

        except Exception as e:
            # Demo code: report the problem instead of aborting the notebook run.
            print(f"Error during sample inference with loaded model: {e}")
else:
    print("\nSkipping demonstration of loading model as saved paths or vision_feature_dim not found.")


print("\nCell 11: Model saving and loading example complete.")
print("\n--- End of Script ---")