In [1]:
import pandas as pd
import os

### 📦 Cluster similar classes into common one

In [2]:
diagnosis_map = {
    # Benign (0)
    "nevus": 0, "solar lentigo": 0, "dermatofibroma": 0, "vascular lesion": 0, "nev": 0, "sek": 0,
    "seborrheic keratosis": 0,
    "blue nevus": 0, "congenital nevus": 0, "dermal nevus": 0, "seborrheic keratosis": 0,
    "nv": 0, "nevus": 0, "bkl": 0, "benign keratosis": 0, "df": 0, "dermatofibroma": 0,
    "vasc": 0, "vascular": 0,
    "Intradermal Nevus": 0, "common nevus":0,
    # Intermediate Benign (1)
    "atypical melanocytic proliferation": 1, "actinic keratosis": 1, "lichenoid keratosis": 1,
    "ack": 1, "akiec": 1, "atypical nevus":1,
    # Intermediate Melanoma (2)
    "melanoma (in situ)": 2, "melanoma (<0.76 mm)": 2, "lentigo maligna": 2,
    "atypical spitz tumor": 2,
    # Melanoma (3)
    "melanoma": 3, "melanoma metastasis": 3, "melanoma (>0.76 mm)": 3,
    "mel": 3, "Nodular Melanoma": 3, "melanoma (0.76 to 1.5 mm)": 3,
    "melanoma (more than 1.5 mm)": 3,
}

### 📦 From .TXT to .CSV (For PH2 Dataset)
This cell does the necessary string manupulation to adapt the .CSV format

In [3]:
import re
import csv

# Legend mappings remain the same
clinical_diag_map = {
    "0": "Common Nevus",
    "1": "Atypical Nevus",
    "2": "Melanoma"
}

asymmetry_map = {
    "0": "Fully Symmetric",
    "1": "Symetric in 1 axe",
    "2": "Fully Asymmetric"
}

feature_map = {
    "A": "Absent",
    "AT": "Atypical",
    "P": "Present",
    "T": "Typical"
}

colors_map = {
    "1": "White",
    "2": "Red",
    "3": "Light-Brown",
    "4": "Dark-Brown",
    "5": "Blue-Gray",
    "6": "Black"
}

# Read the file
with open("/kaggle/input/ph2dataset/PH2Dataset/PH2_dataset.txt", "r") as f:
    lines = f.readlines()

# Get the header line
header_line = lines[0].strip()
header_line = re.sub(r'^\|\||\|\|$', '', header_line)  # Remove leading/trailing ||

# Split by both single and double pipes
header_parts = re.split(r'\|\||\|', header_line)
header = [part.strip() for part in header_parts if part.strip()]

# Process data rows
processed_rows = []
for line in lines[1:]:
    if not line.strip() or line.startswith("||---"):
        continue

    # Clean the line
    clean_line = line.strip()
    clean_line = re.sub(r'^\|\||\|\|$', '', clean_line)  # Remove leading/trailing ||

    # Split by both single and double pipes
    parts = re.split(r'\|\||\|', clean_line)
    row_data = [part.strip() for part in parts]

    # Create a dictionary for this row with all columns
    row_dict = {}

    # Add data for each column, using empty string for missing values
    for i, field_name in enumerate(header):
        if i >= len(row_data):
            value = ""
        elif i == 2 and row_data[i]:  # Clinical Diagnosis
            value = clinical_diag_map.get(row_data[i], row_data[i])
        elif i == 3 and row_data[i]:  # Asymmetry
            value = asymmetry_map.get(row_data[i], row_data[i])
        elif i >= 4 and i <= 8 and row_data[i]:  # Features
            value = feature_map.get(row_data[i], row_data[i])
        elif i == 9 and row_data[i]:  # Colors
            value = " ".join(colors_map.get(v, v) for v in row_data[i].split())
        else:
            value = row_data[i]

        row_dict[field_name] = value

    processed_rows.append(row_dict)

# In convert-to-csv.py, modify the final section:
# Write to CSV with only first 200 rows
with open("PH2_dataset.csv", "w", newline="", encoding="utf-8") as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=header)
    writer.writeheader()
    for row in processed_rows[:200]:  # Only write first 200 rows
        writer.writerow(row)

print(f"Conversion complete. CSV saved as PH2_dataset.csv with {len(processed_rows)} rows")

Conversion complete. CSV saved as PH2_dataset.csv with 220 rows


### 🛠️ Define Paths & load metadata
Sets file paths and load metadata.

In [4]:
base_dir = '/kaggle/input'
FINAL_METADATA_PATH = "/kaggle/working/unified_metadata.csv"
AUGMENTED_DATA_PATH = "/kaggle/working/augmented_data/"
AUGMENTED_METADATA_PATH = "/kaggle/working/augmented_metadata.csv"
TRAIN_METADATA_PATH = "/kaggle/working/train_df.csv"
VAL_METADATA_PATH = "/kaggle/working/validation_df.csv"
TEST_METADATA_PATH = "/kaggle/working/test_df.csv"

# Example usage for each metadata file:
pad_df = pd.read_csv('/kaggle/input/skin-cancer/metadata.csv')
darm_df = pd.read_csv('/kaggle/input/derm7pt/release_v0/meta/meta.csv')
ham_df = pd.read_csv('/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_metadata.csv')
ph2_df = pd.read_csv('/kaggle/working/PH2_dataset.csv')

ham_df.rename(columns={'image_id':'image_id', 'dx':'diagnosis'}, inplace=True)
pad_df.rename(columns={'img_id':'image_id','diagnostic':'diagnosis'}, inplace=True)
darm_df.rename(columns={'case_id':'image_id'}, inplace=True)
ph2_df.rename(columns={'Name':'image_id', 'Clinical Diagnosis': 'diagnosis'}, inplace=True)

ham_df['diagnosis_numeric'] = ham_df['diagnosis'].str.lower().map(diagnosis_map)
pad_df['diagnosis_numeric'] = pad_df['diagnosis'].str.lower().map(diagnosis_map)
darm_df['diagnosis_numeric'] = darm_df['diagnosis'].str.lower().map(diagnosis_map)
ph2_df['diagnosis_numeric'] = ph2_df['diagnosis'].str.lower().map(diagnosis_map)

ham_df['dataset_source'] = 'HAM10000'
pad_df['dataset_source'] = 'PAD-UFES-20'
darm_df['dataset_source'] = 'DERM7PT'
ph2_df['dataset_source'] = 'PH2'

### 📦 Metadata modification & merging
This cell filters out targeted columns & rows from all datasets and merging the all metadata files after preparing image path

In [None]:
import pandas as pd
import numpy as np
import os

# Merge explicitly
# unified_df = pd.concat([df_isic, df_ham, df_pad, df_derm7pt, df_ph2], ignore_index=True)
unified_df = pd.concat([ham_df, pad_df, darm_df, ph2_df], ignore_index=True)
unified_df.dropna(subset=['diagnosis_numeric'], inplace=True)
print(unified_df["dataset_source"].unique())

# Explicit image path generation
def generate_image_path(row):
    source, image_id = row['dataset_source'], row['image_id']
    # BASE_DIR = '/kaggle/input'

    # if source == 'ISIC':
    #     return f"{BASE_DIR}/all-isic-data-20240629/images/{image_id}.jpg"

    if source == 'HAM10000':
        for part in ['HAM10000_images_part_1', 'HAM10000_images_part_2']:
            path = f"{base_dir}/skin-cancer-mnist-ham10000/{part}/{image_id}.jpg"
            if os.path.exists(path):
                return path

    elif source == 'PAD-UFES-20':
        for part in [1,2,3]:
            path = f"{base_dir}/skin-cancer/imgs_part_{part}/imgs_part_{part}/{image_id}"
            if os.path.exists(path):
                return path

    elif source == 'DERM7PT':
        if pd.notnull(row['derm']):
            return f"{base_dir}/derm7pt/release_v0/images/{row['derm']}"
        elif pd.notnull(row['clinic']):
            return f"{base_dir}/derm7pt/release_v0/images/{row['clinic']}"
        else:
            return None

    elif source == 'PH2':
        return f"{base_dir}/ph2dataset/PH2Dataset/PH2_Dataset_images/{image_id}/{image_id}_Dermoscopic_Image/{image_id}.bmp"

unified_df['image_path'] = unified_df.apply(generate_image_path, axis=1)
unified_df.dropna(subset=['image_path', 'diagnosis_numeric'], inplace=True)
print(unified_df["dataset_source"].unique())
columns_to_keep = ['image_id', 'diagnosis', 'diagnosis_numeric', 'dataset_source', 'image_path']
unified_df = unified_df[columns_to_keep]
unified_df.to_csv(FINAL_METADATA_PATH, index=False)

['HAM10000' 'PAD-UFES-20' 'DERM7PT' 'PH2']


### 📦 Augmentation Definition
This cell does the augmentation & saves the augmented images

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
from sklearn.utils.class_weight import compute_class_weight

def generate_augmented_df(original_df, target_count, transform, save_dir):
    """
    Generates augmented images and returns new DataFrame with paths & labels.
    Saves images to disk in save_dir.
    """
    os.makedirs(save_dir, exist_ok=True)
    augmented_records = []
    existing_count = len(original_df)
    needed = target_count - existing_count

    print(f"Original: {existing_count}, Target: {target_count}, Augmenting: {needed}")

    augment_idx = 0
    while len(augmented_records) < needed:
        for idx, row in original_df.iterrows():
            if len(augmented_records) >= needed:
                break

            img_path = row['image_path']
            label = row['diagnosis_numeric']

            image = cv2.imread(img_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            augmented = transform(image=image)['image']

            new_filename = f"aug_{label}_{augment_idx}.jpg"
            save_path = os.path.join(save_dir, new_filename)
            aug_img_np = augmented.permute(1, 2, 0).cpu().numpy()
            aug_img_np = np.clip(aug_img_np * 255.0, 0, 255).astype(np.uint8)
            cv2.imwrite(save_path, cv2.cvtColor(aug_img_np, cv2.COLOR_RGB2BGR))

            augmented_records.append({'image_path': save_path, 'diagnosis_numeric': label, 'original_image_path': img_path})
            augment_idx += 1

    new_df = pd.concat([original_df, pd.DataFrame(augmented_records)], ignore_index=True)
    return new_df

def balance_custom_classes(df, transform, save_root):
    """
    Custom-balanced class augmentation strategy:
    Benign -> 7000 (downsample)
    Melanoma -> 5000 (augment)
    Intermediate Benign -> 3000 (augment)
    Intermediate Melanoma -> 1000 (augment)
    """
    class_targets = {
        0: 7000,  # Benign
        3: 5000,  # Melanoma
        1: 3000,  # Intermediate Benign
        2: 1000   # Intermediate Melanoma
    }

    final_df_list = []

    for cls, target_count in class_targets.items():
        class_df = df[df['diagnosis_numeric'] == cls]
        existing_count = len(class_df)

        print(f"\nClass {cls}: Existing samples = {existing_count}")

        if existing_count > target_count:
            class_df = class_df.sample(target_count, random_state=42).reset_index(drop=True)
            print(f"Downsampled to {target_count}")
            final_df_list.append(class_df)

        elif existing_count < target_count:
            save_dir = os.path.join(save_root, f"aug_class_{cls}")
            class_aug_df = generate_augmented_df(class_df, target_count, transform, save_dir)
            final_df_list.append(class_aug_df)

        else:
            final_df_list.append(class_df)

    final_balanced_df = pd.concat(final_df_list, ignore_index=True)
    return final_balanced_df

### ⚙️ Augmentation Code Execution
Executes a general part of the augmentation pipeline.

In [None]:
# Augmentation to apply to the intermediate classes
augment_pipeline = A.Compose([
    A.RandomResizedCrop((224, 224), scale=(0.8, 1.0)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30),
    A.RandomBrightnessContrast(p=0.2),
    A.ColorJitter(p=0.3),
    A.Normalize(),
    ToTensorV2()
])

df = pd.read_csv(FINAL_METADATA_PATH)

augmented_metadata_df = balance_custom_classes(df, transform=augment_pipeline, save_root=AUGMENTED_DATA_PATH)
augmented_metadata_df.to_csv(AUGMENTED_METADATA_PATH, index=False)

### 📦 Prompt creation function definition

In [None]:
def create_descriptive_prompt(row):
    parts = []

    # Diagnosis label mapping for readability
    diagnosis_mapping = {
        "bkl": "benign keratosis-like lesion",
        "nv": "melanocytic nevus",
        "mel": "melanoma",
        "bcc": "basal cell carcinoma",
        "akiec": "actinic keratosis",
        "vasc": "vascular lesion",
        "df": "dermatofibroma"
    }

    # General patient info
    sex = str(row.get("sex") or row.get("gender", "")).strip().lower()
    age = row.get("age") or row.get("Age")
    location = row.get("localization") or row.get("location")
    raw_diagnosis = row.get("diagnosis") or row.get("dx") or row.get("Histological Diagnosis") or row.get("Clinical Diagnosis")
    diagnosis = diagnosis_mapping.get(str(raw_diagnosis).lower(), str(raw_diagnosis).lower()) if pd.notna(raw_diagnosis) else None

    intro = []
    has_patient_info = False

    if sex and sex != 'nan':
        intro.append(f"Patient is {sex}")
        has_patient_info = True
    if pd.notna(age):
        try:
            age_val = int(float(age))
            if sex and sex != 'nan':
                intro.append(f"aged {age_val}")
            else:
                intro.append(f"Patient age is {age_val}")
            has_patient_info = True
        except:
            pass
    if pd.notna(location):
        intro.append(f"with a lesion on the {location.lower()}")
        has_patient_info = True
    # if pd.notna(diagnosis):
    #     intro.append(f"diagnosed as {diagnosis}")

    if intro:
        if has_patient_info:
            parts.append(" ".join(intro) + ".")
        else:
            diagnosis_phrase = [i for i in intro if "diagnosed as" in i]
            if diagnosis_phrase:
                parts.append(f"This image shows a lesion {diagnosis_phrase[0]}.")

    # Symptoms (PAD)
    for symptom in ["itch", "hurt", "grew", "changed", "bleed"]:
        val = str(row.get(symptom, "")).lower()
        if val in ["1", "true", "yes", "t", "y"]:
            parts.append(f"Patient reported that the lesion {symptom}s.")

    # PH2-specific structured features (grouped)
    ph2_present = []
    ph2_absent = []

    for field in ["Asymmetry", "Pigment Network", "Dots/Globules", "Streaks",
                  "Regression Areas", "Blue-Whitish Veil"]:
        val = str(row.get(field, "")).strip()
        if val:
            if val.lower() == "absent":
                ph2_absent.append(field.lower())
            else:
                ph2_present.append(f"{field.lower()} is {val.lower()}")

    if pd.notna(row.get("Colors")):
        ph2_present.append(f"colors observed include {row['Colors'].lower()}")

    if ph2_present:
        parts.append("The lesion presents the following characteristics: " + ", ".join(ph2_present) + ".")
    if ph2_absent:
        parts.append(f"Other features such as {', '.join(ph2_absent)} are absent.")

    # DARM features
    darm_present = []
    darm_absent = []

    for field in ["pigment_network", "streaks", "pigmentation", "regression_structures",
                  "dots_and_globules", "blue_whitish_veil", "vascular_structures"]:
        if field in row and pd.notna(row[field]):
            val = str(row[field]).strip().lower()
            name = field.replace('_', ' ')
            if val == "absent":
                darm_absent.append(name)
            else:
                darm_present.append(f"{val} {name}")

    if darm_present:
        parts.append(f"Dermoscopic features include {', '.join(darm_present)}.")
    if darm_absent:
        parts.append(f"Other features such as {', '.join(darm_absent)} are absent.")

    return " ".join(parts)

def generate_text_prompts(df):
    df = df.copy()
    df['text_prompt'] = df.apply(create_descriptive_prompt, axis=1)
    return df

### 📦 Image path creation function definition

In [None]:
def generate_image_path(row):
    source, image_id = row['dataset_source'], row['image_id']

    if source == 'HAM10000':
        for part in ['HAM10000_images_part_1', 'HAM10000_images_part_2']:
            path = f"{base_dir}/skin-cancer-mnist-ham10000/{part}/{image_id}.jpg"
            if os.path.exists(path):
                return path

    elif source == 'PAD-UFES-20':
        for part in [1,2,3]:
            path = f"{base_dir}/skin-cancer/imgs_part_{part}/imgs_part_{part}/{image_id}"
            if os.path.exists(path):
                return path

    elif source == 'DERM7PT':
        if pd.notnull(row['derm']):
            return f"{base_dir}/derm7pt/release_v0/images/{row['derm']}"
        elif pd.notnull(row['clinic']):
            return f"{base_dir}/derm7pt/release_v0/images/{row['clinic']}"
        else:
            return None

    elif source == 'PH2':
        return f"{base_dir}/ph2dataset/PH2Dataset/PH2_Dataset_images/{image_id}/{image_id}_Dermoscopic_Image/{image_id}.bmp"

### ⚙️ Augmentation, prompt creation & metadata finalization execution

In [None]:
pad_df = generate_text_prompts(pad_df)
darm_df = generate_text_prompts(darm_df)
ham_df = generate_text_prompts(ham_df)
ph2_df = generate_text_prompts(ph2_df)

# Combine all into one final DataFrame
original_text_prompt_df = pd.concat([pad_df, darm_df, ham_df, ph2_df], ignore_index=True)
original_text_prompt_df['image_path'] = original_text_prompt_df.apply(generate_image_path, axis=1)
print("Unified dataset preview:")
print(original_text_prompt_df[['image_id', 'text_prompt', 'image_path']].head().to_string())
original_text_prompt_df = original_text_prompt_df.dropna(subset=['diagnosis_numeric'])

augmented_metadata_df = pd.read_csv(AUGMENTED_METADATA_PATH)

augmented_metadata_df.loc[augmented_metadata_df['original_image_path'].isna() |
                         (augmented_metadata_df['original_image_path'] == ''),
                         'original_image_path'] = augmented_metadata_df['image_path']

merge_df = augmented_metadata_df.merge(
    original_text_prompt_df[['image_path', 'text_prompt']],
    left_on='original_image_path',
    right_on='image_path',
    how='left'
)

augmented_text_prompt_df = merge_df[['diagnosis', 'diagnosis_numeric', 'dataset_source', 'original_image_path', 'text_prompt']]
augmented_text_prompt_df = augmented_text_prompt_df.rename(columns={
    'original_image_path': 'image_path'
})

original_text_prompt_df.to_csv("original_vlm_with_text_prompt_image_path.csv", index=False)
augmented_text_prompt_df.to_csv("augmented_vlm_with_text_prompt_image_path.csv", index=False)

### ⚙️ Define HF token

In [None]:
from huggingface_hub import login
login(token="<HF TOKEN>")

### ⚙️ Dataloader & Hyperparameters definition, dataset splitting, training & test part 

In [None]:
# === Imports ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import densenet169
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             roc_auc_score, classification_report, confusion_matrix)
from transformers import AutoModel, AutoTokenizer, AutoConfig
from peft import get_peft_model, LoraConfig, TaskType
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import models


# === Dataset ===
class PromptDiagnosisDataset(Dataset):
    def __init__(self, dataframe, tokenizer, transform=None):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.loc[idx]
        image = Image.open(row["image_path"]).convert("RGB")
        image = self.transform(image)

        encoded = self.tokenizer(
            row["text_prompt"],
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        )

        return {
            "image": image,
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "label": torch.tensor(int(row["diagnosis_numeric"]), dtype=torch.long)
        }


# === Vision Encoder ===
class DenseNet169FeatureExtractor(nn.Module):
    def __init__(self, checkpoint_path=None):
        super().__init__()
        self.backbone = models.densenet169(pretrained=False)
        self.backbone.classifier = nn.Identity()

        if checkpoint_path:
            state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))

            # Remove "base." prefix if present in state_dict
            new_state_dict = {}
            for k, v in state_dict.items():
                new_k = k.replace("base.", "") if k.startswith("base.") else k
                new_state_dict[new_k] = v

            self.backbone.load_state_dict(new_state_dict, strict=False)

    def forward(self, x):
        return self.backbone(x)



# === Q-Former ===
class QFormer(nn.Module):
    def __init__(self, query_dim, vision_dim, num_queries=32):
        super().__init__()
        self.query_tokens = nn.Parameter(torch.randn(num_queries, query_dim))
        self.cross_attention = nn.MultiheadAttention(embed_dim=query_dim, num_heads=4, batch_first=True)
        self.linear_proj = nn.Linear(vision_dim, query_dim)

    def forward(self, vision_feats):
        B = vision_feats.size(0)
        queries = self.query_tokens.unsqueeze(0).expand(B, -1, -1)
        keys = self.linear_proj(vision_feats).unsqueeze(1)
        output, _ = self.cross_attention(queries, keys, keys)
        return output


# === Full Model ===
class PromptBasedVLM(nn.Module):
    def __init__(self, vision_encoder, qformer, text_encoder, hidden_dim, num_classes):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.qformer = qformer
        self.text_encoder = text_encoder
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, image, input_ids, attention_mask):
        vision_feats = self.vision_encoder(image)
        q_output = self.qformer(vision_feats)
        q_pooled = q_output.mean(dim=1)

        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_pooled = text_output.last_hidden_state[:, 0, :]

        fused = torch.cat([q_pooled, text_pooled], dim=1)
        logits = self.classifier(fused)
        return logits


# === Evaluation Function ===
def evaluate(model, data_loader, device):
    model.eval()
    y_true, y_pred, y_probs = [], [], []
    with torch.no_grad():
        for batch in data_loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(image, input_ids, attention_mask)
            probs = F.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            y_probs.extend(probs.cpu().numpy())

    print(classification_report(y_true, y_pred, digits=4))
    print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))
    print("Accuracy:", accuracy_score(y_true, y_pred))
    print("Precision (macro):", precision_score(y_true, y_pred, average='macro'))
    print("Recall (macro):", recall_score(y_true, y_pred, average='macro'))
    print("F1 Score (macro):", f1_score(y_true, y_pred, average='macro'))


# === Training Function ===
def train(model, train_loader, val_loader, optimizer, criterion, scheduler, device, num_epochs=10):
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            image = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            outputs = model(image, input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")
        print("Validation Metrics:")
        evaluate(model, val_loader, device)


# === Setup ===
df = pd.read_csv("/kaggle/working/augmented_vlm_with_text_prompt_image_path.csv")
labels = df['diagnosis_numeric'].values
classes = np.unique(labels)

from sklearn.model_selection import train_test_split
train_df, temp_df = train_test_split(df, test_size=0.3, stratify=labels, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['diagnosis_numeric'], random_state=42)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer & LoRA Text Encoder
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
text_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
peft_config = LoraConfig(
    r=8, lora_alpha=32, target_modules=["query", "value"],
    lora_dropout=0.1, bias="none", task_type=TaskType.FEATURE_EXTRACTION
)
text_encoder = get_peft_model(text_model, peft_config)

# Components
vision_checkpoint_path = "/kaggle/input/densenet169_with_ham10k_pad_darm7pr_ph2/pytorch/default/1/denseNet169_160_model_best.pth"
vision_encoder = DenseNet169FeatureExtractor(checkpoint_path=vision_checkpoint_path)
qformer = QFormer(query_dim=384, vision_dim=1664, num_queries=32)
model = PromptBasedVLM(vision_encoder, qformer, text_encoder, hidden_dim=384, num_classes=len(classes))

# DataLoader
train_ds = PromptDiagnosisDataset(train_df, tokenizer)
val_ds = PromptDiagnosisDataset(val_df, tokenizer)
test_ds = PromptDiagnosisDataset(test_df, tokenizer)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16)
test_loader = DataLoader(test_ds, batch_size=16)

# Loss, Optimizer, Scheduler
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

# Train
train(model, train_loader, val_loader, optimizer, criterion, scheduler, DEVICE, num_epochs=10)

# Test
print("\nFinal Test Metrics:")
evaluate(model, test_loader, DEVICE)

# Save Q-Former
torch.save(qformer.state_dict(), "qformer_weights.pth")