In [None]:
import os
import random
import json

import numpy as np
import pandas as pd
from PIL import Image

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity

from imblearn.combine import SMOTETomek
from imblearn.over_sampling import SMOTE

import torch
from torchvision import transforms
import imagehash
import joblib
from tqdm import tqdm

# =========================================================
# CONFIGURATION – CHANGE THESE PATHS
# =========================================================
# Paths are built with os.path.join so they resolve on Windows and POSIX alike
# (a literal backslash is an ordinary filename character on Linux/macOS).
META_DIR      = "meta"
META_CSV      = os.path.join(META_DIR, "meta.csv")
TRAIN_IDX_CSV = os.path.join(META_DIR, "train_indexes.csv")
VALID_IDX_CSV = os.path.join(META_DIR, "valid_indexes.csv")
TEST_IDX_CSV  = os.path.join(META_DIR, "test_indexes.csv")

# Root folder where 'derm' image paths are stored
# e.g., if derm column is "NEL/Nel026.jpg", and they live in "images/NEL/Nel026.jpg"
IMAGE_ROOT_DIR = "images"

SAVE_DIR = "augmented"
AUG_DIR  = os.path.join(SAVE_DIR, "augmented_images")

os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(AUG_DIR, exist_ok=True)

# Random seed. Python's `random` drives image selection, NumPy drives
# imblearn's random_state fallbacks, and torchvision's random transforms
# (ColorJitter / RandomRotation / RandomHorizontalFlip / RandomApply) draw
# from torch's RNG, so all three must be seeded for a reproducible run.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# =========================================================
# LOAD META + SPLITS
# =========================================================
meta = pd.read_csv(META_CSV)

# Official Derm7pt splits: each CSV has an 'indexes' column of row positions
# into `meta` (they are used with .iloc further down).
train_idx, val_idx, test_idx = [
    pd.read_csv(split_csv)["indexes"].values
    for split_csv in (TRAIN_IDX_CSV, VALID_IDX_CSV, TEST_IDX_CSV)
]

print(f"Total cases in meta: {len(meta)}")
for caption, idx_arr in (
    ("Train indices:", train_idx),
    ("Valid indices:", val_idx),
    ("Test indices: ", test_idx),
):
    print(f"{caption} {len(idx_arr)}")

# =========================================================
# IMAGE PATHS & LABELS
# =========================================================
# Dermoscopy images: resolve every relative 'derm' entry against IMAGE_ROOT_DIR
meta["ImagePath"] = [os.path.join(IMAGE_ROOT_DIR, rel_path) for rel_path in meta["derm"]]

# =========================================================
# LABEL GROUPING (as per Derm7pt paper)
# =========================================================
# Group diagnoses into 5 clinically meaningful categories (EXACTLY 5 classes)
# BCC | MEL | MISC | NEV | SEK
# Keys are matched exactly (case-sensitive) by `meta["diagnosis"].map(...)`
# further down, which is why case variants such as "Clark nevus" and
# "clark nevus" are both listed. A diagnosis string missing from this dict is
# reported as unmapped and aborts the run. The dict is also written verbatim
# to preprocessing_info.json.
diagnosis_grouping = {
    # MEL - Melanoma
    "melanoma": "MEL",
    "melanoma in situ": "MEL",
    "melanoma (in situ)": "MEL",
    "melanoma (less than 0.76 mm)": "MEL",
    "melanoma (0.76 to 1.5 mm)": "MEL",
    "melanoma (more than 1.5 mm)": "MEL",
    "melanoma metastasis": "MEL",
    "melanosis": "MEL",  # melanosis → MEL
    
    # NEV - Nevus (all nevus types)
    "blue nevus": "NEV",
    "Clark nevus": "NEV",
    "clark nevus": "NEV",  # lowercase variant
    "combined nevus": "NEV",
    "congenital nevus": "NEV",
    "dermal nevus": "NEV",
    "recurrent nevus": "NEV",
    "Reed nevus": "NEV",
    "Spitz nevus": "NEV",
    "reed or spitz nevus": "NEV",  # combined Reed/Spitz → NEV
    "nevus": "NEV",
    
    # BCC - Basal Cell Carcinoma
    "basal cell carcinoma": "BCC",
    
    # SEK - Seborrheic Keratosis (SK in paper)
    "seborrheic keratosis": "SEK",
    "solar lentigo": "SEK",
    "lentigo": "SEK",
    "lichenoid keratosis": "SEK",
    
    # MISC - Miscellaneous (dermatofibroma, vascular lesions, etc.)
    "dermatofibroma": "MISC",
    "vascular lesion": "MISC",
    "miscellaneous": "MISC",
}

# Apply grouping (exact, case-sensitive match on the raw diagnosis string)
meta["diagnosis_grouped"] = meta["diagnosis"].map(diagnosis_grouping)

# Check for unmapped diagnoses (should be none now)
unmapped = meta[meta["diagnosis_grouped"].isna()]["diagnosis"].unique()
if len(unmapped) > 0:
    print(f"‚ö†Ô∏è  ERROR: {len(unmapped)} unmapped diagnoses found:")
    for d in unmapped:
        print(f"   - '{d}'")
    # Carry the offending values in the exception itself so they survive in the
    # traceback/logs even when stdout is not captured.
    raise ValueError(
        f"Unmapped diagnoses found: {list(unmapped)}. "
        "Please add them to diagnosis_grouping dict."
    )

# Encode grouped labels (LabelEncoder assigns ids in sorted class-name order)
label_encoder = LabelEncoder()
meta["label"] = label_encoder.fit_transform(meta["diagnosis_grouped"])

label_mapping = {cls: int(i) for i, cls in enumerate(label_encoder.classes_)}

# Verify we have exactly 5 classes
num_classes = len(label_encoder.classes_)
if num_classes != 5:
    raise ValueError(f"Expected 5 classes, but got {num_classes}: {list(label_encoder.classes_)}")

print("\n‚úÖ Grouped Label mapping (5 classes):")
for k, v in label_mapping.items():
    print(f"  {k} -> {v}")

print("\nDiagnosis distribution after grouping:")
class_dist = meta["diagnosis_grouped"].value_counts().sort_index()
for cls, count in class_dist.items():
    print(f"  {cls}: {count} samples")
print(f"\nTotal: {len(meta)} samples across {num_classes} classes")

# =========================================================
# METADATA PREPROCESSING (as per paper)
# =========================================================
# Paper uses ONLY: sex, location, elevation (all categorical, one-hot encoded)
# NO numeric fields, NO 7-point checklist features for metadata
categorical_cols = ["sex", "location", "elevation"]

# Clean categoricals (fill missing with 'unknown')
for col in categorical_cols:
    meta[col] = meta[col].fillna("unknown")

print("\nMetadata feature distributions:")
for col in categorical_cols:
    print(f"\n{col}:")
    print(meta[col].value_counts())

# One-hot encode categoricals (no drop_first to keep all categories).
# dtype is set explicitly: pandas >= 2.0 defaults to bool dummies, and bool
# real rows concatenated with float32 SMOTE rows become object columns that
# mix True/False with numbers. Those are written to CSV as "True"/"0.37" and
# cannot be cast back to float by the training loader.
X_cat = pd.get_dummies(meta[categorical_cols], drop_first=False, dtype=np.float32)

print(f"\nOne-hot encoded metadata shape: {X_cat.shape}")
print(f"Metadata features: {list(X_cat.columns)}")

# Final metadata feature matrix (only categorical features, float32 0/1)
X_meta = X_cat.reset_index(drop=True)

y = meta["label"].astype(int)
img_paths_all = meta["ImagePath"]

print(f"\nMetadata feature shape: {X_meta.shape}")

# Check class distribution (full dataset, informational only)
print("\nClass distribution:")
class_counts = y.value_counts().sort_index()
for cls, count in class_counts.items():
    cls_name = label_encoder.classes_[cls]
    print(f"  {cls_name} (label {cls}): {count} samples")

# =========================================================
# SMOTETomek (TRAINING SPLIT ONLY)
# =========================================================
# SMOTE synthesises rows by interpolating between nearest neighbours, so it
# must only ever see training rows. Fitting it on the full table would blend
# validation/test metadata into the synthetic *training* samples (leakage).
X_meta_array = X_meta.iloc[train_idx].to_numpy(dtype=np.float32)
y_smote_in = y.iloc[train_idx].reset_index(drop=True)

# SMOTE requires k_neighbors < size of every class it oversamples, so size
# k from the smallest TRAINING class (not the full-dataset counts).
min_class_size = int(y_smote_in.value_counts().min())
print(f"\nMinimum class size (train split): {min_class_size}")
if min_class_size < 2:
    raise ValueError(
        f"SMOTE needs at least 2 training samples per class; "
        f"the smallest training class has {min_class_size}."
    )
k_neighbors = max(1, min(5, min_class_size - 1))
print(f"Using k_neighbors={k_neighbors} for SMOTETomek")

print("\nApplying SMOTETomek on training-split metadata only...")

smote = SMOTETomek(
    smote=SMOTE(k_neighbors=k_neighbors, random_state=seed),
    random_state=seed,
)
X_meta_res, y_res = smote.fit_resample(X_meta_array, y_smote_in)

# SMOTE's output is [original rows..., synthetic rows...]. The Tomek-link step
# then removes rows from BOTH parts (preserving order), so slicing the result
# at the original length mis-identifies the synthetic rows. The fitted
# `tomek_.sample_indices_` index into SMOTE's output: anything >= orig_len is
# a synthetic row that survived cleaning.
orig_len = len(X_meta_array)
synth_mask = smote.tomek_.sample_indices_ >= orig_len

X_meta_res = pd.DataFrame(X_meta_res, columns=X_meta.columns)
y_res = pd.Series(np.asarray(y_res))

X_synth = X_meta_res[synth_mask].reset_index(drop=True)
y_synth = y_res[synth_mask].reset_index(drop=True)

# NOTE(review): plain SMOTE interpolates the one-hot columns, so synthetic rows
# can hold fractional values (e.g. sex_female=0.4). imblearn's SMOTEN/SMOTENC
# keep categorical features categorical; consider switching.
print(f"‚úÖ Generated {len(X_synth)} synthetic metadata samples via SMOTETomek")

# =========================================================
# OFFLINE AUGMENTATION FOR SYNTHETIC METADATA
# =========================================================
# Every synthetic metadata row is paired with a freshly augmented copy of a
# real image of the same class. RandomApply(p=1.0) means all three random
# transforms always run.
offline_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomApply([
        transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),
        transforms.RandomRotation(45),
        transforms.RandomHorizontalFlip(),
    ], p=1.0),
])

# Max pHash Hamming distance at which two augmented images count as near-duplicates
AUG_DEDUP_THRESHOLD = 5

# Map label -> real image paths from the TRAINING split only. Drawing source
# images from every split could put augmented copies of validation/test images
# into the training set; the later pHash filter cannot reliably catch those
# after rotation/colour jitter.
train_src_paths = img_paths_all.iloc[train_idx]
train_src_labels = y.iloc[train_idx]
label_to_imgs = {
    lbl: train_src_paths[(train_src_labels == lbl).to_numpy()].tolist()
    for lbl in sorted(train_src_labels.unique())
}

aug_paths   = []
aug_meta    = []
aug_labels  = []
aug_hashes  = []   # pHash of augmented images to avoid near-duplicates within aug set
n_no_source = 0    # synthetic rows whose class has no training image
n_unreadable = 0   # source images that could not be opened

print(f"\nGenerating {len(X_synth)} synthetic image‚Äìmetadata pairs using offline augmentation...")
for i, (m_row, lbl) in enumerate(zip(X_synth.values, y_synth)):
    candidates = label_to_imgs.get(lbl)
    if not candidates:
        n_no_source += 1
        continue
    img_path = random.choice(candidates)

    try:
        # Context manager closes the file handle instead of leaving it to the GC
        with Image.open(img_path) as src:
            img = src.convert("RGB")
    except OSError as err:
        print(f"Warning: could not read {img_path} ({err}); skipping")
        n_unreadable += 1
        continue

    aug_img = offline_aug(img)

    # Avoid near-duplicate augmented images
    h = imagehash.phash(aug_img)
    if any(h - hh <= AUG_DEDUP_THRESHOLD for hh in aug_hashes):
        continue

    out_path = os.path.join(AUG_DIR, f"aug_{lbl}_{i}.jpg")
    aug_img.save(out_path, "JPEG", quality=95)

    aug_hashes.append(h)
    aug_paths.append(out_path)
    aug_meta.append(m_row)
    aug_labels.append(lbl)

if n_no_source or n_unreadable:
    print(f"Skipped {n_no_source} rows without a training source image "
          f"and {n_unreadable} unreadable source images")

aug_meta_df = pd.DataFrame(aug_meta, columns=X_meta.columns)
aug_labels_df = pd.DataFrame({"ImagePath": aug_paths, "Label": aug_labels})

print(f"‚úÖ Saved {len(aug_labels_df)} augmented images ‚Üí {AUG_DIR}")

# Save raw augmented data (optional, before cleaning)
aug_meta_df.to_csv(os.path.join(SAVE_DIR, "augmented_metadata_raw.csv"), index=False)
aug_labels_df.to_csv(os.path.join(SAVE_DIR, "augmented_labels_raw.csv"), index=False)

# =========================================================
# BUILD OFFICIAL TRAIN / VAL / TEST SPLITS (NO RESPLITTING)
# =========================================================
# Positional selection with the official index files; every piece is
# renumbered 0..n-1 so metadata, paths and labels line up row-for-row.
split_order = (train_idx, val_idx, test_idx)

X_train_meta, X_val_meta, X_test_meta = (
    X_meta.iloc[ix].reset_index(drop=True) for ix in split_order
)
X_train_img_paths, X_val_img_paths, X_test_img_paths = (
    img_paths_all.iloc[ix].reset_index(drop=True) for ix in split_order
)
y_train, y_val, y_test = (
    y.iloc[ix].reset_index(drop=True) for ix in split_order
)

print(f"\n‚úÖ Train: {len(X_train_meta)} samples")
print(f"‚úÖ Val:   {len(X_val_meta)} samples")
print(f"‚úÖ Test:  {len(X_test_meta)} samples")

# =========================================================
# FILTER AUGMENTED DATA AGAINST VAL/TEST – IMAGE pHASH (RELAXED)
# =========================================================
PHASH_THRESHOLD = 12   # Relaxed from 5 to allow more augmented variations

def compute_hashes(paths, desc):
    """Compute a perceptual hash (pHash) for every readable image in ``paths``.

    Each image is converted to RGB and resized to 224x224 before hashing, so
    hashes are comparable with the 224x224 augmented images.

    Args:
        paths: iterable of image file paths.
        desc: label shown on the tqdm progress bar and in the warning below.

    Returns:
        dict mapping path -> imagehash.ImageHash. Images that cannot be opened
        are skipped, and the count is reported. A skipped val/test image means
        augmented images are NOT checked against it.
    """
    hashes = {}
    failed = []
    for p in tqdm(paths, desc=desc):
        try:
            # Close the underlying file as soon as the pixels are loaded
            with Image.open(p) as src:
                img = src.convert("RGB").resize((224, 224))
        except OSError:
            failed.append(p)
            continue
        hashes[p] = imagehash.phash(img)
    if failed:
        print(f"Warning: {len(failed)} image(s) for '{desc}' could not be read and were not hashed")
    return hashes

print("\nüìä Computing perceptual hashes for val/test/aug images...")
val_hashes  = compute_hashes(X_val_img_paths.tolist(), "Val hashes")
test_hashes = compute_hashes(X_test_img_paths.tolist(), "Test hashes")
aug_hashes_dict = compute_hashes(aug_labels_df["ImagePath"].tolist(), "Aug hashes")

print(f"\nüîç Checking augmented vs validation/test similarity (pHash ‚â§ {PHASH_THRESHOLD})...")
# Pool the held-out hashes: an augmented image is dropped when it lies within
# PHASH_THRESHOLD bits of ANY validation or test image.
held_out_hashes = list(val_hashes.values()) + list(test_hashes.values())
to_drop = {
    aug_path
    for aug_path, h_aug in aug_hashes_dict.items()
    if any(h_aug - h_ref <= PHASH_THRESHOLD for h_ref in held_out_hashes)
}

print(f"üßπ Removing {len(to_drop)} augmented images visually similar to val/test")

# Both frames share the same 0..n-1 row index, so one mask filters both
mask_keep = ~aug_labels_df["ImagePath"].isin(to_drop)
aug_labels_df = aug_labels_df.loc[mask_keep].reset_index(drop=True)
aug_meta_df   = aug_meta_df.loc[mask_keep].reset_index(drop=True)

print(f"Remaining augmented samples after pHash filtering: {len(aug_labels_df)}")

# =========================================================
# METADATA-BASED FILTERING (RELAXED FOR ONE-HOT ENCODED DATA)
# =========================================================
if len(aug_labels_df) == 0:
    print("\n‚ö†Ô∏è  No augmented samples remaining after pHash filtering. Skipping metadata filtering.")
else:
    print("\nüîç Checking metadata similarity between augmented and val/test sets‚Ä¶")

    X_aug      = aug_meta_df.values
    X_val_arr  = X_val_meta.values
    X_test_arr = X_test_meta.values

    # Tiny cut-offs: only (near-)exact metadata copies of a val/test row are removed
    eucl_thresh = 0.02
    cos_thresh  = 0.005

    # Distance from each augmented row to its nearest validation / test row
    eucl_val, eucl_test = (
        euclidean_distances(X_aug, ref).min(axis=1) for ref in (X_val_arr, X_test_arr)
    )
    # Cosine distance = 1 - highest cosine similarity
    cos_val, cos_test = (
        1 - cosine_similarity(X_aug, ref).max(axis=1) for ref in (X_val_arr, X_test_arr)
    )

    mask_del = np.logical_or.reduce([
        eucl_val < eucl_thresh,
        cos_val < cos_thresh,
        eucl_test < eucl_thresh,
        cos_test < cos_thresh,
    ])

    print(f"\nüßπ Removing {mask_del.sum()} augmented samples with nearly identical metadata to val/test")

    keep_rows = ~mask_del
    aug_meta_df   = aug_meta_df[keep_rows].reset_index(drop=True)
    aug_labels_df = aug_labels_df[keep_rows].reset_index(drop=True)

print(f"‚úÖ Final augmented samples after all filtering: {len(aug_labels_df)}")

# Save cleaned augmented data (optional)
for frame, fname in (
    (aug_meta_df, "augmented_metadata_clean.csv"),
    (aug_labels_df, "augmented_labels_clean.csv"),
):
    frame.to_csv(os.path.join(SAVE_DIR, fname), index=False)

# =========================================================
# COMBINE CLEAN AUGMENTED DATA WITH TRAIN SET ONLY
# =========================================================
if len(aug_labels_df) == 0:
    print("\n‚ö†Ô∏è  No augmented samples survived filtering. Using original training data only.")
    X_train_meta_final = X_train_meta
    X_train_img_paths_final = X_train_img_paths
    y_train_final = y_train

    print(f"\n‚úÖ Final training samples: {len(X_train_meta_final)}")
else:
    print("\nüì¶ Combining cleaned augmented data with TRAIN set only...")

    # Real training rows first, augmented rows after; ignore_index renumbers 0..n-1
    X_train_meta_final = pd.concat([X_train_meta, aug_meta_df], ignore_index=True)
    X_train_img_paths_final = pd.concat(
        [X_train_img_paths, aug_labels_df["ImagePath"]], ignore_index=True
    )
    y_train_final = pd.concat([y_train, aug_labels_df["Label"]], ignore_index=True)

    print(f"\n‚úÖ Final training samples (real + augmented): {len(X_train_meta_final)}")
    print(f"   - Real training samples: {len(X_train_meta)}")
    print(f"   - Augmented samples:     {len(aug_meta_df)}")
print(f"‚úÖ Final validation samples: {len(X_val_meta)}")
print(f"‚úÖ Final test samples:       {len(X_test_meta)}")

# Output tables per split: one-hot metadata columns, then ImagePath, then label
train_final = X_train_meta_final.assign(ImagePath=X_train_img_paths_final, label=y_train_final)
val_final   = X_val_meta.assign(ImagePath=X_val_img_paths, label=y_val)
test_final  = X_test_meta.assign(ImagePath=X_test_img_paths, label=y_test)

# =========================================================
# SAVE FINAL CSVs + PREPROCESSING ARTIFACTS
# =========================================================
train_final_path, val_final_path, test_final_path = (
    os.path.join(SAVE_DIR, fname)
    for fname in ("train_metadata_final.csv", "val_metadata_final.csv", "test_metadata_final.csv")
)

for frame, out_csv in zip(
    (train_final, val_final, test_final),
    (train_final_path, val_final_path, test_final_path),
):
    frame.to_csv(out_csv, index=False)

print(f"\n‚úÖ Saved training metadata to: {train_final_path}")
print(f"‚úÖ Saved validation metadata to: {val_final_path}")
print(f"‚úÖ Saved test metadata to: {test_final_path}")

# What a training notebook needs to interpret the CSVs. No scaler is stored
# because the metadata features are all one-hot categorical columns.
preprocessing_info = dict(
    categorical_cols=list(X_cat.columns),
    label_mapping=label_mapping,
    diagnosis_grouping=diagnosis_grouping,
    original_metadata_cols=categorical_cols,
)
info_path = os.path.join(SAVE_DIR, "preprocessing_info.json")
with open(info_path, "w") as info_file:
    info_file.write(json.dumps(preprocessing_info, indent=2))

print(f"‚úÖ Saved preprocessing info to: {info_path}")
print("\nüéâ Derm7pt SMOTETomek + offline augmentation pipeline complete!")


In [None]:
import os

import pandas as pd

# Quick inspection of the raw Derm7pt release files.
# Point DERM7PT_ROOT at your local copy of the dataset.
# NOTE: this cell rebinds `meta`, `train_idx` and `test_idx` (as raw DataFrames),
# shadowing the arrays built by the preprocessing cell; re-run that cell before
# reusing them.
DERM7PT_ROOT = r"G:\Downloads\Dermp7\release_v0"
DERM7PT_META_DIR = os.path.join(DERM7PT_ROOT, "meta")

meta = pd.read_csv(os.path.join(DERM7PT_META_DIR, "meta.csv"))
train_idx = pd.read_csv(os.path.join(DERM7PT_META_DIR, "train_indexes.csv"))
valid_idx = pd.read_csv(os.path.join(DERM7PT_META_DIR, "valid_indexes.csv"))
test_idx  = pd.read_csv(os.path.join(DERM7PT_META_DIR, "test_indexes.csv"))

print("===== META HEAD =====")
print(meta.head())

print("\n===== META INFO =====")
# DataFrame.info() prints its own report and returns None, so wrapping it in
# print() would emit a stray "None" line.
meta.info()

print("\n===== TRAIN IDX =====")
print(train_idx.head())

print("\n===== VALID IDX =====")
print(valid_idx.head())

print("\n===== TEST IDX =====")
print(test_idx.head())


# 📋 Loading Code for Other Derm7pt Notebooks

**Use the code below in your training notebooks** after running the preprocessing pipeline above.

This code will:
- Load the preprocessed train/val/test CSVs
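- Load the saved preprocessing info (one-hot metadata column names and label mapping; no scaler is saved because all metadata features are categorical)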
- Provide ready-to-use data for your models

In [None]:
# =========================================================
# DERM7PT DATA LOADER FOR TRAINING NOTEBOOKS
# =========================================================
# Run this code in your training notebooks to load the preprocessed Derm7pt dataset
# Make sure the preprocessing pipeline above has been executed first!

import os
import json
import joblib
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

# =========================================================
# PATHS TO PREPROCESSED DATA
# =========================================================
# Folder written by the preprocessing pipeline
PREPROCESSED_DIR = r"augmented"

TRAIN_CSV, VAL_CSV, TEST_CSV = (
    os.path.join(PREPROCESSED_DIR, f"{split}_metadata_final.csv")
    for split in ("train", "val", "test")
)

INFO_PATH = os.path.join(PREPROCESSED_DIR, "preprocessing_info.json")

# =========================================================
# LOAD PREPROCESSED DATA
# =========================================================
print("Loading preprocessed Derm7pt data...")

# The three split tables produced by the preprocessing pipeline
train_df, val_df, test_df = (pd.read_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV, TEST_CSV))

# One-hot column names and the class-name -> label mapping
with open(INFO_PATH, "r") as info_file:
    preprocessing_info = json.load(info_file)

categorical_cols, label_mapping = (
    preprocessing_info[key] for key in ("categorical_cols", "label_mapping")
)

print(f"\n‚úÖ Training samples:   {len(train_df)}")
print(f"‚úÖ Validation samples: {len(val_df)}")
print(f"‚úÖ Test samples:       {len(test_df)}")
print(f"\nLabel mapping: {label_mapping}")

# =========================================================
# EXTRACT FEATURES AND LABELS
# =========================================================
def extract_features(df):
    """Split one preprocessed Derm7pt split table into images, metadata and labels.

    Args:
        df: DataFrame read from a ``*_metadata_final.csv`` file. It has an
            ``ImagePath`` column, a ``label`` column, and one column per
            one-hot metadata feature.

    Returns:
        Tuple ``(img_paths, metadata, labels)`` of NumPy arrays. ``metadata``
        is a float32 array with one row per sample and one column per
        metadata feature (all columns except ``ImagePath`` and ``label``).
    """
    img_paths = df["ImagePath"].values
    labels = df["label"].values

    # Metadata features (all columns except ImagePath and label)
    metadata_cols = [col for col in df.columns if col not in ("ImagePath", "label")]

    # One-hot columns can come back from CSV as bool (True/False), as float, or,
    # when bool real rows were concatenated with float synthetic rows, as
    # strings mixing "True"/"False" with numbers. Normalise all of them to
    # float32 here so the Dataset never has to cast strings.
    metadata = (
        df[metadata_cols]
        .astype(str)
        .replace({"True": "1", "False": "0"})
        .astype(np.float32)
        .to_numpy()
    )

    return img_paths, metadata, labels

(X_train_img, X_train_meta, y_train), (X_val_img, X_val_meta, y_val), (X_test_img, X_test_meta, y_test) = [
    extract_features(split_df) for split_df in (train_df, val_df, test_df)
]

# One class per entry in the saved class-name -> label mapping
num_classes = len(label_mapping)
print(f"\nNumber of classes: {num_classes}")

# =========================================================
# PYTORCH DATASET CLASS
# =========================================================
class Derm7ptDataset(Dataset):
    """Map-style dataset yielding ``(image, metadata, label)`` triples for Derm7pt.

    Args:
        img_paths: indexable sequence of image file paths.
        metadata: indexable sequence (e.g. a 2-D NumPy array) of per-sample
            metadata rows; each row is cast to float32 when fetched.
        labels: indexable sequence of class labels; its length is the
            dataset length.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, img_paths, metadata, labels, transform=None):
        self.img_paths, self.metadata, self.labels = img_paths, metadata, labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        path = self.img_paths[idx]
        try:
            picture = Image.open(path).convert("RGB")
        except Exception:
            # Best effort: an unreadable file becomes a black 224x224 placeholder
            print(f"Warning: Failed to load {path}, using placeholder")
            picture = Image.new("RGB", (224, 224), color="black")

        if self.transform:
            picture = self.transform(picture)

        return picture, self.metadata[idx].astype(np.float32), int(self.labels[idx])

# =========================================================
# DATA TRANSFORMS
# =========================================================
# Training pipeline: resize, mild random flips/rotation/colour jitter (kept
# light for the small Derm7pt dataset), then tensor conversion and
# normalisation with the standard ImageNet mean/std values.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomVerticalFlip(0.5),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Evaluation pipeline: deterministic resize plus the same tensor/normalisation steps
val_test_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# =========================================================
# CREATE DATASETS
# =========================================================
# Random augmentation for training only; val/test use the deterministic pipeline
train_dataset = Derm7ptDataset(X_train_img, X_train_meta, y_train, transform=train_transform)
val_dataset, test_dataset = (
    Derm7ptDataset(paths, feats, targets, transform=val_test_transform)
    for paths, feats, targets in (
        (X_val_img, X_val_meta, y_val),
        (X_test_img, X_test_meta, y_test),
    )
)

print("\n‚úÖ Created PyTorch Datasets")
print(f"   - Train: {len(train_dataset)} samples")
print(f"   - Val:   {len(val_dataset)} samples")
print(f"   - Test:  {len(test_dataset)} samples")

# =========================================================
# CREATE DATALOADERS (EXAMPLE - ADJUST BATCH SIZE AS NEEDED)
# =========================================================
import torch

BATCH_SIZE = 32

# Settings shared by all three loaders. Pinned host memory only speeds up
# host->GPU copies, so request it only when CUDA is available (on CPU-only
# machines it is useless and recent PyTorch versions emit a warning).
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=0,  # Set to 0 for Windows, increase for Linux/Mac
    pin_memory=torch.cuda.is_available(),
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader   = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader  = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"\n‚úÖ Created DataLoaders (batch_size={BATCH_SIZE})")
print(f"   - Train batches: {len(train_loader)}")
print(f"   - Val batches:   {len(val_loader)}")
print(f"   - Test batches:  {len(test_loader)}")

# =========================================================
# EXAMPLE: TEST LOADING A BATCH
# =========================================================
print("\nüîç Testing batch loading...")
# Pull only the first training batch (if any) to sanity-check tensor shapes
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, metadata, labels = first_batch
    print(f"   - Image batch shape:    {images.shape}")
    print(f"   - Metadata batch shape: {metadata.shape}")
    print(f"   - Labels batch shape:   {labels.shape}")

print("\n‚úÖ Derm7pt data loading complete! Ready for training.")
print("\nüí° Usage in your model:")
usage_hint = (
    "   for images, metadata, labels in train_loader:",
    "       # images: torch.Tensor of shape (batch_size, 3, 224, 224)",
    "       # metadata: torch.Tensor of shape (batch_size, num_metadata_features)",
    "       # labels: torch.Tensor of shape (batch_size,)",
    "       # Your training code here...",
)
print("\n".join(usage_hint))