In [1]:
# Cell 1: Setup, paths, seeds
import os, random, pickle
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

from sklearn.metrics import classification_report, confusion_matrix, f1_score, recall_score, roc_curve, auc

# ---------------- Paths (replace the placeholders before running) ----------------
csv_path = r" ## Path to Csv ##"
images_dir = r" ##path to Images directory##"
results_dir = r"##path to the folder you want to save your results##"

# Output layout: results_dir/{checkpoints, pickles, reports, plots}
os.makedirs(results_dir, exist_ok=True)
for sf in ("checkpoints", "pickles", "reports", "plots"):
    subfolder = os.path.join(results_dir, sf)
    os.makedirs(subfolder, exist_ok=True)

# Seed every RNG used in this notebook (python, numpy, torch CPU and, if present, CUDA)
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
cuda_ok = torch.cuda.is_available()
if cuda_ok:
    torch.cuda.manual_seed_all(seed)

# Prefer the GPU when one is visible, otherwise fall back to CPU
device = torch.device("cuda" if cuda_ok else "cpu")
print("Device:", device)
print("Results folder:", results_dir)


KeyboardInterrupt: 

In [None]:
# Cell 2: Load CSV & inspect
df = pd.read_csv(csv_path)
print("Total rows in CSV:", len(df))
print("Categories and counts:")
print(df['Category'].value_counts())

# The dataset classes read the image filename from 'Image Index' (NIH style)
# and fall back to 'path'. Fail here, with a readable message, if neither
# column exists instead of surfacing a KeyError from inside the DataLoader.
if "Image Index" not in df.columns:
    if "path" in df.columns:
        print("Using 'path' column for image filenames.")
    else:
        raise KeyError(
            "CSV must contain an 'Image Index' or 'path' column with image filenames; "
            f"found columns: {list(df.columns)}"
        )


In [None]:
# Cell 3: Dataset class (CLAHE applied) and train/val/test split
from sklearn.model_selection import train_test_split

# Dataset with CLAHE on grayscale -> convert to RGB
class ChestXrayDataset(Dataset):
    """Chest X-ray dataset that yields (CLAHE-enhanced RGB image, raw label).

    Args:
        dataframe: table with a "Category" label column and an image filename
            column ("Image Index", or "path" when the former is absent).
        images_root: directory the filenames are joined onto.
        transform: optional callable applied to the PIL image.

    The label is returned exactly as stored in the "Category" column.
    """

    def __init__(self, dataframe, images_root, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.images_root = images_root
        self.transform = transform
        self.label_col = "Category"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        name_col = "Image Index" if "Image Index" in self.df.columns else "path"
        img_path = os.path.join(self.images_root, row[name_col])

        # cv2.imread signals failure by returning None rather than raising.
        gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise FileNotFoundError(f"Image not found: {img_path}")

        # Contrast-limited adaptive histogram equalisation, then replicate to 3 channels.
        equalized = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray)
        img_pil = Image.fromarray(cv2.cvtColor(equalized, cv2.COLOR_GRAY2RGB))

        label = row[self.label_col]
        if self.transform:
            return (self.transform(img_pil), label)
        return (img_pil, label)

# Stratified 80/10/10 split: hold out 20% first, then halve that into val and test.
train_df, temp_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["Category"],
    random_state=seed,
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df["Category"],
    random_state=seed,
)
print("Split sizes -> train:", len(train_df), "val:", len(val_df), "test:", len(test_df))

# transforms
# Standard ImageNet channel mean / std, shared by the train and eval pipelines.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training: light augmentation (mild crop, flip, small rotation) on top of resize + normalise.
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
]
train_transforms = transforms.Compose(_train_steps)

# Validation / test: deterministic resize + normalise only.
_eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
]
val_transforms = transforms.Compose(_eval_steps)

# Datasets & loaders
batch_size = 32
torch.cuda.is_available()  # result intentionally unused: the worker count is 0 either way
num_workers = 0

train_dataset = ChestXrayDataset(train_df, images_dir, transform=train_transforms)
val_dataset = ChestXrayDataset(val_df, images_dir, transform=val_transforms)
test_dataset = ChestXrayDataset(test_df, images_dir, transform=val_transforms)

# Options common to all three loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)


In [None]:
# Cell 4: class name -> index mapping (consistent across notebook)
# Sorting makes the name -> index assignment independent of row order in the CSV.
classes = sorted(df["Category"].unique())  # e.g. ['Effusion','Infiltration','No Finding']
class_to_idx = dict(zip(classes, range(len(classes))))
idx_to_class = dict(enumerate(classes))
print("Classes:", classes)
print("Mapping:", class_to_idx)

# Wrapper dataset: translates each sample's label through a lookup table
class MapLabelDataset(Dataset):
    """View over another dataset whose labels are replaced by ``mapping[label]``.

    Args:
        ds: indexable dataset returning ``(image, label)`` pairs.
        mapping: lookup from the wrapped dataset's labels to the values to emit.

    Raises:
        KeyError: from ``__getitem__`` when a label is not a key of ``mapping``.
    """

    def __init__(self, ds, mapping):
        self.ds, self.mapping = ds, mapping

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        sample = self.ds[idx]
        img, raw_label = sample
        return img, self.mapping[raw_label]

# Rebuild the three loaders on top of the label-mapped datasets (integer targets).
_mapped_loader_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
train_loader = DataLoader(
    MapLabelDataset(train_dataset, class_to_idx),
    shuffle=True,
    **_mapped_loader_opts,
)
val_loader = DataLoader(
    MapLabelDataset(val_dataset, class_to_idx),
    shuffle=False,
    **_mapped_loader_opts,
)
test_loader = DataLoader(
    MapLabelDataset(test_dataset, class_to_idx),
    shuffle=False,
    **_mapped_loader_opts,
)


In [None]:
import torch
import torch.nn as nn
from torchvision import models
from transformers import ViTModel

class CNNViTHybrid(nn.Module):
    """Late-fusion classifier: pooled DenseNet121 features concatenated with the ViT pooled output.

    NOTE(review): a second class with the same name is defined further down in
    this notebook and shadows this one once that cell has run.

    Args:
        num_classes: size of the output logit vector.
    """

    def __init__(self, num_classes=3):
        super().__init__()

        # --- DenseNet121 backbone; classifier replaced by a pass-through ---
        densenet = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        densenet.classifier = nn.Identity()
        self.cnn = densenet

        # --- ViT backbone (Hugging Face checkpoint) ---
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

        # --- Fusion head over the concatenated feature vector ---
        cnn_feat_dim = 1024
        vit_feat_dim = self.vit.config.hidden_size  # 768 per the original author's note
        self.fc = nn.Sequential(
            nn.Linear(cnn_feat_dim + vit_feat_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        # DenseNet branch: feature map -> global average pool -> flat vector
        fmap = self.cnn.features(x)
        cnn_features = nn.functional.adaptive_avg_pool2d(fmap, (1, 1)).flatten(1)

        # ViT branch: the same input tensor, pooled output of the encoder
        vit_features = self.vit(x).pooler_output

        # Concatenate both descriptors along the feature axis and classify
        fused = torch.cat((cnn_features, vit_features), dim=1)
        return self.fc(fused)


In [None]:
# Checkpoint locations for the hybrid model.
# No model, loss, optimizer, scheduler or scaler is created here: `model` is
# instantiated once after the class definition below, and the training objects
# are built in Cell 6 so that they are bound to that model's parameters.
ckpt_dir = os.path.join(results_dir, "checkpoints")
best_ckpt = os.path.join(ckpt_dir, "hybrid_best.pth")

# Cell 5: Hybrid model (DenseNet121 encoder -> transformer encoder -> head)
class CNNViTHybrid(nn.Module):
    """DenseNet121 feature map -> token sequence -> transformer encoder -> linear head.

    Args:
        num_classes: number of output logits.
        cnn_out_channels: channel count of the DenseNet feature map.
        embed_dim: token width inside the transformer encoder.
        depth: number of encoder layers.
        num_heads: attention heads per layer.
        mlp_ratio: feed-forward width as a multiple of ``embed_dim``.
        drop_rate: dropout used inside the encoder layers.

    The positional embedding is sized for a 7x7 feature grid plus one class
    token, so the input resolution must produce a 7x7 DenseNet feature map.
    """

    def __init__(self, num_classes=3, cnn_out_channels=1024, embed_dim=384, depth=6, num_heads=6, mlp_ratio=4.0, drop_rate=0.1):
        super().__init__()
        backbone = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
        self.cnn = backbone.features  # convolutional trunk only; original note: output [B,1024,7,7]
        self.cnn_norm = nn.BatchNorm2d(cnn_out_channels)
        self.proj = nn.Linear(cnn_out_channels, embed_dim)

        self.grid_size = 7
        token_count = 1 + self.grid_size * self.grid_size  # class token + one token per grid cell
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, token_count, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        layer_cfg = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(embed_dim * mlp_ratio),
            dropout=drop_rate,
            batch_first=True,
            norm_first=True,
            activation="gelu",
        )
        self.encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(**layer_cfg), num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits for an image batch ``x``."""
        fmap = self.cnn_norm(self.cnn(x))
        batch = fmap.shape[0]
        # Flatten the spatial grid into a sequence and project to the token width.
        tokens = self.proj(fmap.flatten(2).transpose(1, 2))
        # Prepend the learned class token, then add positional embeddings.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        encoded = self.encoder(tokens)
        # Classify from the class-token position only.
        return self.head(self.norm(encoded[:, 0]))

# Instantiate with one output per category found in the CSV
n_classes = len(classes)
model = CNNViTHybrid(num_classes=n_classes).to(device)
print("Model instantiated, output classes:", n_classes)

# Additional artifact paths
latest_ckpt = os.path.join(ckpt_dir, "hybrid_latest.pth")
pickle_path = os.path.join(results_dir, "pickles", "hybrid_3class.pkl")


In [None]:
# Cell 6: Loss, optimizer, scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)
# The monitored validation metric is "higher is better", hence mode='max'.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=3,
    verbose=True,
)
# Gradient scaling is only active when running on CUDA.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Early-stopping bookkeeping
best_metric, no_improve = 0.0, 0   # best of (val_f1, val_recall); epochs since last improvement
patience = 6
num_epochs = 30

# Where the best Stage 1 weights / full model are written
best_state_path = os.path.join(results_dir, "checkpoints", "stage1_best_state.pth")
best_pickle_path = os.path.join(results_dir, "pickles", "hybrid_stage1.pkl")


In [None]:
# Cell 7: Training loop (monitor max(val_f1, val_recall_macro))
from collections import defaultdict

for epoch in range(num_epochs):
    # ---- Train for one epoch (mixed precision when on CUDA) ----
    model.train()
    running_loss, total, correct = 0.0, 0, 0
    loop = tqdm(train_loader, desc=f"[Stage1] Epoch {epoch+1}/{num_epochs}", leave=False)
    for imgs, labels in loop:
        imgs = imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight the batch loss by batch size so the epoch figure is a per-sample average.
        running_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(1)
        correct += (preds==labels).sum().item()
        total += labels.size(0)

    train_loss, train_acc = running_loss / total, correct / total

    # ---- Validate ----
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)
            preds = outputs.argmax(1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    val_f1 = f1_score(all_labels, all_preds, average='macro')
    val_recall = recall_score(all_labels, all_preds, average='macro')
    val_metric = max(val_f1, val_recall)

    print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, val_f1={val_f1:.4f}, val_recall={val_recall:.4f}, monitor_metric={val_metric:.4f}")

    # The plateau scheduler is driven by the same metric used for model selection.
    scheduler.step(val_metric)

    # ---- Early stopping: handle the "no improvement" case first ----
    improved = val_metric > best_metric
    if not improved:
        no_improve += 1
        print(f"‚ö†Ô∏è No improvement count: {no_improve}/{patience}")
        if no_improve >= patience:
            print("‚èπ Early stopping triggered (Stage1).")
            break
        continue

    # ---- New best: persist weights, full model and this epoch's validation report ----
    best_metric = val_metric
    no_improve = 0
    torch.save(model.state_dict(), best_state_path)
    torch.save(model, best_pickle_path)
    report = classification_report(all_labels, all_preds, target_names=classes)
    report_path = os.path.join(results_dir, "reports", f"stage1_report_epoch{epoch+1}.txt")
    with open(report_path, "w") as f:
        f.write(report)
    print("üèÜ New best model saved (state + pickle).")


In [None]:
# Cell 8: Stage1 Test evaluation & save results
# Restore the best Stage 1 weights before scoring the held-out test split.
state = torch.load(best_state_path, map_location=device)
model.load_state_dict(state)
model.to(device)
model.eval()

all_preds, all_labels, all_probs = [], [], []
with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Stage1 Testing"):
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(imgs)
        probs = torch.softmax(outputs, dim=1)
        preds = outputs.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Text report: printed and written to the reports folder
report = classification_report(all_labels, all_preds, target_names=classes)
print(report)
report_path = os.path.join(results_dir, "reports", "stage1_test_report.txt")
with open(report_path, "w") as f:
    f.write(report)

# Confusion matrix heatmap
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
plt.title("Stage1 Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.savefig(os.path.join(results_dir, "plots", "stage1_confusion_matrix.png"))
plt.show()
plt.close()

# One-vs-rest ROC curve per class, from the softmax probabilities
plt.figure(figsize=(7,6))
all_labels_bin = torch.nn.functional.one_hot(torch.tensor(all_labels), num_classes=len(classes)).numpy()
for i, cls in enumerate(classes):
    class_scores = [p[i] for p in all_probs]
    fpr, tpr, _ = roc_curve(all_labels_bin[:, i], class_scores)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{cls} (AUC={roc_auc:.2f})", lw=2)
plt.plot([0,1],[0,1],"--", color="gray")
plt.legend(loc="lower right")
plt.title("Stage1 ROC Curves")
plt.savefig(os.path.join(results_dir, "plots", "stage1_roc.png"))
plt.show()
plt.close()


In [None]:
# Cell 9: Binary dataset (Infiltration vs No Finding)
binary_classes = ["Infiltration", "No Finding"]


def _only_binary_rows(frame):
    """Return the rows of ``frame`` whose Category is one of ``binary_classes``, re-indexed from 0."""
    keep = frame["Category"].isin(binary_classes)
    return frame[keep].reset_index(drop=True)


# Filter the existing splits rather than re-splitting, so no sample changes partition.
train_bin_df = _only_binary_rows(train_df)
val_bin_df = _only_binary_rows(val_df)
test_bin_df = _only_binary_rows(test_df)

# Binary label = position of the class name in binary_classes
binary_map = {name: position for position, name in enumerate(binary_classes)}

class BinaryDataset(Dataset):
    """Stage 2 dataset yielding (CLAHE-enhanced RGB image, binary label).

    Args:
        dataframe: table with a "Category" column and an image filename column
            ("Image Index", or "path" when the former is absent).
        images_root: directory the filenames are joined onto.
        transform: optional callable applied to the PIL image.

    The label is ``binary_map[Category]``, read from the module-level
    ``binary_map``.

    Raises:
        FileNotFoundError: from ``__getitem__`` when the image cannot be read.
    """

    def __init__(self, dataframe, images_root, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.images_root = images_root
        self.transform = transform

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_name = row["Image Index"] if "Image Index" in self.df.columns else row["path"]
        img_path = os.path.join(self.images_root, img_name)
        img_gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None for a missing/unreadable file; report the path
        # instead of letting CLAHE fail on None (same behaviour as ChestXrayDataset).
        if img_gray is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        img_clahe = clahe.apply(img_gray)
        img_rgb = cv2.cvtColor(img_clahe, cv2.COLOR_GRAY2RGB)
        img_pil = Image.fromarray(img_rgb)
        label = row["Category"]
        return (self.transform(img_pil) if self.transform else img_pil, binary_map[label])

# Stage 2 loaders: same batch settings and transforms as Stage 1; only training shuffles.
_bin_loader_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

train_bin_loader = DataLoader(
    BinaryDataset(train_bin_df, images_dir, transform=train_transforms),
    shuffle=True,
    **_bin_loader_opts,
)
val_bin_loader = DataLoader(
    BinaryDataset(val_bin_df, images_dir, transform=val_transforms),
    shuffle=False,
    **_bin_loader_opts,
)
test_bin_loader = DataLoader(
    BinaryDataset(test_bin_df, images_dir, transform=val_transforms),
    shuffle=False,
    **_bin_loader_opts,
)

print("Binary set sizes -> train:", len(train_bin_df), "val:", len(val_bin_df), "test:", len(test_bin_df))


In [None]:
# Cell 10: Binary model (reuse features from stage1)
binary_model = CNNViTHybrid(num_classes=2).to(device)

# Warm-start from the Stage 1 checkpoint when it exists on disk.
if os.path.exists(best_state_path):
    stage1_state = torch.load(best_state_path, map_location=device)
else:
    stage1_state = None

if stage1_state is None:
    print("No Stage1 weights found, training binary model from scratch.")
else:
    # Drop every entry whose key starts with "head" (the 3-class output layer);
    # strict=False lets the 2-class head keep its fresh initialisation.
    filtered = {
        name: tensor
        for name, tensor in stage1_state.items()
        if not name.startswith("head")
    }
    missing, unexpected = binary_model.load_state_dict(filtered, strict=False)
    print("Loaded shared weights from Stage1 into binary model. Missing/unexpected keys:", missing, unexpected)

criterion_b = nn.CrossEntropyLoss()
optimizer_b = optim.AdamW(
    binary_model.parameters(),
    lr=5e-5,
    weight_decay=1e-4,
)
scheduler_b = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_b,
    mode='max',
    factor=0.5,
    patience=3,
    verbose=True,
)

# Early-stopping bookkeeping and output locations for Stage 2
best_bin_metric, no_imp_bin = 0.0, 0
patience_bin = 6
best_bin_state = os.path.join(results_dir, "checkpoints", "stage2_best_state.pth")
best_bin_pickle = os.path.join(results_dir, "pickles", "hybrid_stage2_binary.pkl")

num_epochs_bin = 20
from sklearn.metrics import f1_score, recall_score

# The forward pass runs under autocast, so the loss must be scaled before
# backward to keep fp16 gradients from underflowing (same recipe as Stage 1).
# The scaler is a no-op when not running on CUDA.
scaler_b = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))

for epoch in range(num_epochs_bin):
    # ---- Train for one epoch ----
    binary_model.train()
    for imgs, labels in tqdm(train_bin_loader, desc=f"[Stage2] Epoch {epoch+1}/{num_epochs_bin}", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer_b.zero_grad()
        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            outs = binary_model(imgs)
            loss = criterion_b(outs, labels)
        scaler_b.scale(loss).backward()
        scaler_b.step(optimizer_b)
        scaler_b.update()

    # ---- Validate ----
    binary_model.eval()
    all_p, all_l = [], []
    with torch.no_grad():
        for imgs, labels in val_bin_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = binary_model(imgs).argmax(1)
            all_p.extend(preds.cpu().numpy())
            all_l.extend(labels.cpu().numpy())
    val_f1_b = f1_score(all_l, all_p, average='macro')
    val_recall_b = recall_score(all_l, all_p, average='macro')
    # Model selection and LR scheduling both follow the better of macro-F1 and macro-recall.
    val_metric_b = max(val_f1_b, val_recall_b)
    print(f"Stage2 Epoch {epoch+1}: val_f1={val_f1_b:.4f}, val_recall={val_recall_b:.4f}, monitor_metric={val_metric_b:.4f}")
    scheduler_b.step(val_metric_b)

    # ---- Early stopping & best-model checkpointing ----
    if val_metric_b > best_bin_metric:
        best_bin_metric = val_metric_b
        no_imp_bin = 0
        torch.save(binary_model.state_dict(), best_bin_state)
        torch.save(binary_model, best_bin_pickle)
        # save this epoch's validation report
        report = classification_report(all_l, all_p, target_names=binary_classes)
        with open(os.path.join(results_dir, "reports", f"stage2_report_epoch{epoch+1}.txt"), "w") as f:
            f.write(report)
        print("üèÜ New best binary model saved.")
    else:
        no_imp_bin += 1
        print(f"‚ö†Ô∏è No improvement count (binary): {no_imp_bin}/{patience_bin}")
        if no_imp_bin >= patience_bin:
            print("‚èπ Early stopping triggered (Stage2).")
            break


In [None]:
# Cell 11: Binary test evaluation
# Restore the best Stage 2 weights before scoring the binary test split.
best_bin_weights = torch.load(best_bin_state, map_location=device)
binary_model.load_state_dict(best_bin_weights)
binary_model.to(device)
binary_model.eval()

all_p, all_l, all_pr = [], [], []
with torch.no_grad():
    for imgs, labels in tqdm(test_bin_loader, desc="Binary Testing"):
        imgs = imgs.to(device)
        labels = labels.to(device)
        outs = binary_model(imgs)
        probs = torch.softmax(outs, dim=1)
        preds = outs.argmax(1)
        all_p.extend(preds.cpu().numpy())
        all_l.extend(labels.cpu().numpy())
        all_pr.extend(probs.cpu().numpy())

# Text report: printed and written to the reports folder
report_bin = classification_report(all_l, all_p, target_names=binary_classes)
print(report_bin)
report_bin_path = os.path.join(results_dir, "reports", "stage2_test_report.txt")
with open(report_bin_path, "w") as f:
    f.write(report_bin)

# Confusion matrix heatmap
cm = confusion_matrix(all_l, all_p)
plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=binary_classes, yticklabels=binary_classes)
plt.title("Stage2 Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.savefig(os.path.join(results_dir, "plots", "stage2_confusion_matrix.png"))
plt.show()
plt.close()

# ROC curve per class, from the softmax probabilities
plt.figure(figsize=(6,5))
all_l_bin_oh = torch.nn.functional.one_hot(torch.tensor(all_l), num_classes=2).numpy()
for i, cls in enumerate(binary_classes):
    class_scores = [p[i] for p in all_pr]
    fpr, tpr, _ = roc_curve(all_l_bin_oh[:, i], class_scores)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{cls} (AUC={roc_auc:.2f})", lw=2)
plt.plot([0,1],[0,1],"--", color="gray")
plt.legend(loc="lower right")
plt.title("Stage2 ROC")
plt.savefig(os.path.join(results_dir, "plots", "stage2_roc.png"))
plt.show()
plt.close()


In [None]:
# Cell 12: Hierarchical inference: stage 1 (3-class) first, then stage 2 (binary) where needed
def hierarchical_predict(img_tensor_batch, stage1_model, stage2_model, threshold=None):
    """Two-stage prediction for a batch of preprocessed images.

    Stage 1 predicts over the module-level ``classes`` list. Samples it labels
    "Effusion" keep that label; all other samples are re-classified by stage 2,
    whose output index is mapped back through ``binary_classes`` and
    ``class_to_idx``.

    Args:
        img_tensor_batch: batch of preprocessed image tensors, already on the
            same device as both models.
        stage1_model: PyTorch model producing one logit per entry of ``classes``.
        stage2_model: PyTorch model producing one logit per entry of
            ``binary_classes``.
        threshold: currently unused; kept so existing call sites keep working.

    Returns:
        list[int]: one predicted class index (relative to ``classes``) per sample.
    """
    stage1_model.eval()
    stage2_model.eval()

    # -1 can never be produced by argmax, so every sample goes to stage 2 when
    # "Effusion" is not among the classes (same routing as comparing by name).
    effusion_idx = classes.index("Effusion") if "Effusion" in classes else -1
    # Lookup table: stage 2 output index -> index in the overall `classes` list.
    binary_to_overall = torch.tensor(
        [class_to_idx[name] for name in binary_classes],
        dtype=torch.long,
        device=img_tensor_batch.device,
    )

    with torch.no_grad():
        stage1_probs = torch.softmax(stage1_model(img_tensor_batch), dim=1)
        final = stage1_probs.argmax(1)

        # Run stage 2 once on the whole non-Effusion subset instead of per sample.
        needs_stage2 = final != effusion_idx
        if needs_stage2.any():
            stage2_pred = stage2_model(img_tensor_batch[needs_stage2]).argmax(1)
            final[needs_stage2] = binary_to_overall[stage2_pred]

    # Plain Python ints for every sample, regardless of which stage decided it.
    return final.cpu().tolist()

# Example: run hierarchical on the full test_loader and print final classification report
# Rebuild both stages from their best state_dicts (written at the same moment
# as the full-model pickles). Loading a fully pickled nn.Module with torch.load
# fails on PyTorch versions where weights_only=True is the default, whereas a
# state_dict of tensors loads under either setting.
stage1 = CNNViTHybrid(num_classes=len(classes))
stage1.load_state_dict(torch.load(best_state_path, map_location=device))
stage2 = CNNViTHybrid(num_classes=len(binary_classes))
stage2.load_state_dict(torch.load(best_bin_state, map_location=device))
stage1.to(device)
stage2.to(device)

y_true, y_pred = [], []
for imgs, labels in tqdm(test_loader, desc="Hierarchical Testing"):
    imgs = imgs.to(device)
    preds = hierarchical_predict(imgs, stage1, stage2)
    y_pred.extend(preds)
    y_true.extend(labels.numpy())

print("Hierarchical final report:")
print(classification_report(y_true, y_pred, target_names=classes))


In [None]:
# Save the train/val/test splits next to the other run artifacts. Deriving the
# folder from results_dir keeps the notebook runnable on any machine instead
# of depending on one user's absolute Windows path.
split_dir = os.path.join(results_dir, "splits")
os.makedirs(split_dir, exist_ok=True)  # create the directory if it doesn't exist

# Construct full file paths
train_csv_path = os.path.join(split_dir, "train_split.csv")
val_csv_path = os.path.join(split_dir, "val_split.csv")
test_csv_path = os.path.join(split_dir, "test_split.csv")

# Save each DataFrame to a CSV file
# index=False prevents pandas from writing the DataFrame index as a column
for split_frame, split_path in (
    (train_df, train_csv_path),
    (val_df, val_csv_path),
    (test_df, test_csv_path),
):
    split_frame.to_csv(split_path, index=False)

print(f"Train split saved to: {train_csv_path}")
print(f"Validation split saved to: {val_csv_path}")
print(f"Test split saved to: {test_csv_path}")