In [None]:
Model training

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.model_selection import train_test_split
import timm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm


# Paths, hyperparameters and compute device for the training run.
metadata_csv = "/data/raw/HAM10000/metadata.csv"
image_dir = "/data/raw/HAM10000/metadata/img"
batch_size = 16
num_epochs = 50
learning_rate = 1e-3
patience = 15  # epochs without val-accuracy improvement before early stopping
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# When metadata_csv names a directory, use the standard HAM10000 metadata
# file inside it (only if that file actually exists).
if os.path.isdir(metadata_csv):
    guess = os.path.join(metadata_csv, "HAM10000_metadata.csv")
    metadata_csv = guess if os.path.exists(guess) else metadata_csv


# HAM10000 diagnosis code (`dx` column) -> integer class index.
label_map = {
    'akiec': 0,
    'bcc':   1,
    'bkl':   2,
    'df':    3,
    'mel':   4,
    'nv':    5,
    'vasc':  6
}
idx_to_label = {v: k for k, v in label_map.items()}
# Number of output classes. The model head and the metrics code in this cell
# reference `num_classes`, which was previously never defined here.
num_classes = len(label_map)


# Load the metadata and keep only rows with a known diagnosis. Attach the
# integer label and the expected image path for each row.
df = pd.read_csv(metadata_csv)
df = df[df['dx'].isin(label_map.keys())].assign(
    label=lambda d: d['dx'].map(label_map),
    path=lambda d: d['image_id'].apply(
        lambda image_id: os.path.join(image_dir, f"{image_id}.jpg")
    ),
)
# Drop rows whose image file is missing on disk and renumber from 0.
df = df[df['path'].apply(os.path.exists)].reset_index(drop=True)

# Stratified 80/20 split: every class keeps its proportion in both subsets.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df['label'],
    random_state=42,
)

class SkinDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a dataframe.

    The dataframe must provide a ``path`` column (image file path) and a
    ``label`` column (integer class index). ``transform``, when truthy, is
    applied to the RGB PIL image.
    """

    def __init__(self, dataframe, transform=None):
        # __getitem__ indexes by position, so normalise the index to 0..n-1.
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        image = Image.open(record['path']).convert("RGB")
        image = self.transform(image) if self.transform else image
        return image, int(record['label'])


# ImageNet channel statistics, matching the ImageNet-pretrained backbone.
_imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])

# Training: resize plus light geometric/colour augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(12),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    _imagenet_norm,
])

# Validation: deterministic resize + normalisation only.
val_transforms = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    _imagenet_norm,
])

train_dataset = SkinDataset(train_df, transform=train_transforms)
val_dataset = SkinDataset(val_df, transform=val_transforms)

_loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)


# CLASS WEIGHTS
# sklearn's 'balanced' weights (n_samples / (n_classes * count_per_class)),
# computed on the training split only, so the loss up-weights rare classes.
_train_labels = train_df['label']
class_weights = torch.tensor(
    compute_class_weight(
        class_weight='balanced',
        classes=np.unique(_train_labels),
        y=_train_labels,
    ),
    dtype=torch.float,
    device=device,
)

# MODEL: ImageNet-pretrained EfficientNet-B3 with a fresh num_classes-way head.
model = timm.create_model('efficientnet_b3', pretrained=True)
head_in_features = model.classifier.in_features
model.classifier = nn.Linear(head_in_features, num_classes)
model = model.to(device)

# Class-weighted, lightly label-smoothed cross-entropy with AdamW.
# The LR is multiplied by 0.3 after 5 scheduler steps without improvement;
# mode='max' because the training loop steps the scheduler with val_acc.
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.05)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; confirm
# the installed version still accepts it.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.3, patience=5, verbose=True
)

# Mixed precision: loss scaling is only enabled when running on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))


# TRAINING LOOP
# Each epoch runs one pass over train_loader, then one pass over val_loader.
# Autocast and gradient scaling are active only on CUDA. The checkpoint with
# the highest validation accuracy so far is written to best_model.pth.
# Training stops early after `patience` consecutive epochs without improvement.

best_val_acc = 0.0
epochs_no_improve = 0
train_acc_history, val_acc_history = [], []
train_loss_history, val_loss_history = [], []

for epoch in range(num_epochs):
    print(f"\n— Epoch {epoch+1}/{num_epochs} —")

    # TRAIN: one optimisation pass over the training split.
    model.train()
    running_corrects, running_loss, n_train = 0, 0.0, 0
    for inputs, labels in tqdm(train_loader, desc="Training", leave=False):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad(set_to_none=True)
        # On CPU, autocast is disabled and this is a plain fp32 forward pass.
        with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
        scaler.step(optimizer)
        scaler.update()

        # criterion uses its default (mean) reduction, so scale the loss back
        # up by batch size before accumulating per-sample totals.
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels).item()
        running_loss += loss.item() * inputs.size(0)
        n_train += inputs.size(0)

    # max(1, ...) guards against division by zero for an empty loader.
    train_acc = running_corrects / max(1, n_train)
    train_loss = running_loss / max(1, n_train)
    train_acc_history.append(train_acc)
    train_loss_history.append(train_loss)

    # VALIDATE: no gradients; same autocast setting as training.
    model.eval()
    val_corrects, val_loss_sum, n_val = 0, 0.0, 0
    # Per-epoch predictions are collected here, but the final evaluation after
    # training rebuilds these lists from the best checkpoint.
    all_labels, all_preds, all_probs = [], [], []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validating", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            val_corrects += torch.sum(preds == labels).item()
            val_loss_sum += loss.item() * inputs.size(0)
            n_val += inputs.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.detach().cpu().numpy())

    val_acc = val_corrects / max(1, n_val)
    val_loss = val_loss_sum / max(1, n_val)
    val_acc_history.append(val_acc)
    val_loss_history.append(val_loss)

    scheduler.step(val_acc)  # ReduceLROnPlateau (mode='max') monitors val_acc

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    print(f"Epoch {epoch+1} completed")

    # Checkpoint only on strict improvement in validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "val_acc": best_val_acc,
            "label_map": label_map
        }, "best_model.pth")
        print("model updated")

        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    # Early stopping after `patience` epochs without a new best val_acc.
    if epochs_no_improve >= patience:
        print("Early stopping")
        break


# PLOTS: one figure for accuracy and one for loss, train vs. validation per epoch.
for title, ylabel, curves in (
    ("Accuracy", "Acc", (("Train Acc", train_acc_history), ("Val Acc", val_acc_history))),
    ("Loss", "Loss", (("Train Loss", train_loss_history), ("Val Loss", val_loss_history))),
):
    plt.figure(figsize=(10, 4))
    for curve_label, values in curves:
        plt.plot(values, label=curve_label)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()


# Reload the best checkpoint (highest validation accuracy) and re-run the
# validation set without autocast. This collects true labels, argmax
# predictions and softmax probabilities for the metrics below.
ckpt = torch.load("best_model.pth", map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

all_labels, all_preds, all_probs = [], [], []
with torch.no_grad():
    for batch_inputs, batch_labels in tqdm(val_loader, desc="Final Eval", leave=False):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)
        logits = model(batch_inputs)
        batch_probs = torch.softmax(logits, dim=1)
        batch_preds = logits.max(dim=1).indices

        all_labels.extend(batch_labels.cpu().numpy())
        all_preds.extend(batch_preds.cpu().numpy())
        all_probs.extend(batch_probs.cpu().numpy())


# METRICS & VISUALS
print("\n— Classification Report (Validation, best checkpoint) —")
class_indices = list(range(num_classes))
target_names = [idx_to_label[i] for i in class_indices]
# labels= pins one row per class, so the report always lines up with
# target_names even if a class is missing from all_labels and all_preds.
# Without it, sklearn raises a size-mismatch ValueError.
print(classification_report(all_labels, all_preds, labels=class_indices,
                            target_names=target_names, digits=4))

# Confusion Matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(all_labels, all_preds, labels=class_indices)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names, yticklabels=target_names)
plt.xlabel("Predicted"); plt.ylabel("True"); plt.title("Confusion Matrix (Val)"); plt.show()

# ROC Curves (One-vs-Rest): one-hot encode the true labels and score each
# class by its softmax probability.
all_probs = np.array(all_probs)
y_true_bin = np.eye(num_classes)[np.asarray(all_labels, dtype=int)]

plt.figure(figsize=(8, 6))
for i, class_name in enumerate(target_names):
    n_pos = int(y_true_bin[:, i].sum())
    if n_pos == 0 or n_pos == len(y_true_bin):
        # ROC/AUC are undefined without both positive and negative samples.
        print(f"Skipping ROC for {class_name}: needs both positive and negative samples")
        continue
    fpr, tpr, _ = roc_curve(y_true_bin[:, i], all_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, label=f"{class_name} (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], '--', color='gray')
plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
plt.title("ROC Curves (Val)"); plt.legend(); plt.show()

In [None]:
Test severity-model accuracy (melanoma)

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import classification_report, confusion_matrix
import timm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
import cv2


# Paths and settings for evaluating a saved severity model.
metadata_csv = "/data/raw/HAM10000/metadata.csv"
image_dir = "/data/raw/HAM10000/metadata/img"
# NOTE(review): this is passed straight to torch.load() and has no file
# extension; confirm it names a checkpoint file rather than a directory.
model_path = "/data/models/HAM10000"
batch_size = 32
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Severity bucket name <-> class index; the model has one output per bucket.
severity_map = dict(mild=0, moderate=1, severe=2)
idx_to_severity = dict(zip(severity_map.values(), severity_map.keys()))
num_classes = len(severity_map)

# Human-readable names for HAM10000 diagnosis codes.
disease_names = {
    'mel': 'Melanoma',
    'bcc': 'Basal Cell Carcinoma',
    'akiec': 'Actinic Keratoses',
    'bkl': 'Benign Keratosis',
    'df': 'Dermatofibroma',
    'nv': 'Melanocytic Nevus',
    'vasc': 'Vascular Lesion'
}

# When metadata_csv names a directory, use the standard HAM10000 metadata
# file inside it (only if that file exists).
if os.path.isdir(metadata_csv):
    guess = os.path.join(metadata_csv, "HAM10000_metadata.csv")
    metadata_csv = guess if os.path.exists(guess) else metadata_csv


def analyze_image_features(image_path):
    """Extract simple hand-crafted lesion features from an image file.

    Returns a dict with keys:
      area: contour area of the largest region darker than the Otsu
        threshold of the grayscale image (0 if no contour is found).
      irregularity: perimeter**2 / (4*pi*area) of that contour (about 1 for a
        circle, larger for ragged outlines; 0 if perimeter or area is 0).
      darkness: mean of (255 - gray) / 255 over the whole image.
      color_variance: per-channel RGB standard deviation, averaged over channels.

    Deliberately best-effort: if the image cannot be read, or any step
    raises, all four values are 0.
    """
    blank = {'area': 0, 'irregularity': 0, 'darkness': 0, 'color_variance': 0}
    try:
        bgr = cv2.imread(image_path)
        if bgr is None:
            return blank

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)

        # Otsu picks the threshold automatically. Inverting the binary mask
        # makes the darker-than-threshold pixels the foreground for contouring.
        _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        contours, _ = cv2.findContours(255 - mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        area = 0
        irregularity = 0
        if contours:
            lesion = max(contours, key=cv2.contourArea)
            area = cv2.contourArea(lesion)
            perimeter = cv2.arcLength(lesion, True)
            if perimeter > 0 and area > 0:
                irregularity = (perimeter ** 2) / (4 * np.pi * area)

        return {
            'area': area,
            'irregularity': irregularity,
            'darkness': np.mean(255 - gray) / 255.0,
            'color_variance': np.std(rgb.reshape(-1, 3), axis=0).mean(),
        }
    except Exception:
        return blank


def assign_severity_based_on_features(features):
    """Bucket a feature dict (as returned by analyze_image_features) into a severity.

    Each feature is scaled to [0, 1] (area / 10000, irregularity / 5,
    color_variance / 50, each capped at 1; darkness used as-is). The scaled
    values are combined with weights 0.3 / 0.3 / 0.2 / 0.2. A score below
    0.33 is 'mild', below 0.66 is 'moderate', and anything else (including
    NaN) is 'severe'.
    """
    score = (
        min(features['area'] / 10000, 1.0) * 0.3
        + min(features['irregularity'] / 5.0, 1.0) * 0.3
        + features['darkness'] * 0.2
        + min(features['color_variance'] / 50.0, 1.0) * 0.2
    )

    for upper_bound, bucket in ((0.33, 'mild'), (0.66, 'moderate')):
        if score < upper_bound:
            return bucket
    return 'severe'


class SeverityTestDataset(Dataset):
    """Dataset of ``(image, severity_label)`` pairs for severity evaluation.

    Expects a dataframe with a ``path`` column (image file path) and an
    integer-convertible ``severity_label`` column. ``transform``, when
    truthy, is applied to the RGB PIL image.
    """

    def __init__(self, dataframe, transform=None):
        # Positional lookups in __getitem__ need a 0..n-1 index.
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        entry = self.df.iloc[idx]
        picture = Image.open(entry['path']).convert("RGB")
        picture = self.transform(picture) if self.transform else picture
        return picture, int(entry['severity_label'])


def load_severity_model(model_path):
    """Load a severity checkpoint and rebuild its network.

    The network is an EfficientNet-B2 whose classifier is replaced by a small
    head: Dropout(0.3) -> Linear(in, 128) -> ReLU -> Dropout(0.2) ->
    Linear(128, num_classes). Weights are read from the checkpoint's
    "model_state_dict" entry if present; otherwise the checkpoint itself is
    treated as the state dict.

    Returns:
        (model, target_disease): the model on `device` in eval mode, and the
        checkpoint's 'target_disease' entry ('unknown' if absent).
    """
    print(f"Loading severity model from {model_path}...")

    checkpoint = torch.load(model_path, map_location=device)
    target_disease = checkpoint.get('target_disease', 'unknown')

    network = timm.create_model('efficientnet_b2', pretrained=False)
    head_width = network.classifier.in_features
    network.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(head_width, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, num_classes),
    )

    weights = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
    network.load_state_dict(weights)
    network = network.to(device).eval()

    pretty_name = disease_names.get(target_disease, target_disease)
    print("Severity model loaded successfully")
    print(f"Target disease: {pretty_name}")
    if "val_acc" in checkpoint:
        print(f"Training validation accuracy: {checkpoint['val_acc']:.4f}")

    return network, target_disease


def _heuristic_severity_score(features):
    """Weighted 0..1-ish score from analyze_image_features output (higher = more severe)."""
    return (min(features['area'] / 10000, 1.0) * 0.3 +
            min(features['irregularity'] / 5.0, 1.0) * 0.3 +
            features['darkness'] * 0.2 +
            min(features['color_variance'] / 50.0, 1.0) * 0.2)


def prepare_test_data(target_disease):
    """Build the severity test set for one diagnosis.

    Rows of metadata_csv with ``dx == target_disease`` whose image exists
    under image_dir are scored with the image heuristic, then split into
    tertiles by score. The lowest third is 'mild' (0), the middle third
    'moderate' (1), and the rest (including any remainder) 'severe' (2).
    Ties keep file order.

    NOTE(review): these are pseudo-labels from a hand-crafted image
    heuristic, not clinical ground truth, so "accuracy" measures agreement
    with the heuristic.

    Returns:
        DataFrame with the metadata columns plus 'path', 'severity' and
        'severity_label', indexed 0..n-1.

    Raises:
        ValueError: if no usable images exist for target_disease.
    """
    display_name = disease_names.get(target_disease, target_disease)
    print(f"Loading test data for {display_name}...")

    df = pd.read_csv(metadata_csv)
    # Filter by diagnosis before touching the filesystem, so only this
    # disease's images are stat'ed (not every image in the dataset).
    df = df[df['dx'] == target_disease].copy()
    df['path'] = df['image_id'].apply(lambda x: os.path.join(image_dir, f"{x}.jpg"))
    df = df[df['path'].apply(os.path.exists)].reset_index(drop=True)
    print(f"Found {len(df)} images of {display_name}")

    if len(df) == 0:
        raise ValueError(f"No images found for disease: {target_disease}")

    print("Assigning severity labels...")
    scores = [
        _heuristic_severity_score(analyze_image_features(path))
        for path in tqdm(df['path'], total=len(df), desc="Analyzing images")
    ]

    # Stable sort of row positions by score; ties keep file order.
    order = sorted(range(len(df)), key=scores.__getitem__)
    third = len(df) // 3
    severity_label = np.full(len(df), 2, dtype=np.int64)
    severity_label[order[:third]] = 0
    severity_label[order[third:2 * third]] = 1

    bucket_names = ('mild', 'moderate', 'severe')
    df['severity'] = [bucket_names[label] for label in severity_label]
    df['severity_label'] = severity_label

    severity_counts = df['severity'].value_counts()
    print("Test data severity distribution:")
    for severity, count in severity_counts.items():
        print(f"  {severity.capitalize()}: {count} images ({count/len(df)*100:.1f}%)")

    return df


def test_severity_model(model, test_loader):
    """Evaluate the severity model over test_loader.

    Returns:
        (labels, preds, probs, accuracy): numpy arrays of true labels,
        argmax predictions and softmax probabilities (one row per sample),
        plus overall accuracy as a float.
    """
    model.eval()

    label_chunks, pred_chunks, prob_chunks = [], [], []
    n_correct = n_seen = 0

    print("Testing severity model...")
    with torch.no_grad():
        for batch_x, batch_y in tqdm(test_loader, desc="Testing"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_probs = torch.softmax(logits, dim=1)
            batch_preds = torch.max(logits, 1).indices

            n_seen += batch_y.size(0)
            n_correct += (batch_preds == batch_y).sum().item()

            label_chunks.extend(batch_y.cpu().numpy())
            pred_chunks.extend(batch_preds.cpu().numpy())
            prob_chunks.extend(batch_probs.cpu().numpy())

    accuracy = n_correct / n_seen
    print(f"Overall Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    return np.array(label_chunks), np.array(pred_chunks), np.array(prob_chunks), accuracy


def show_detailed_results(y_true, y_pred, y_probs, accuracy, target_disease):
    """Print and plot a per-severity breakdown of test predictions.

    Args:
        y_true: true severity indices (0=mild, 1=moderate, 2=severe).
        y_pred: predicted severity indices, same length as y_true.
        y_probs: per-sample class probabilities; row j, column i is the
            probability of class i for sample j.
        accuracy: overall accuracy, shown in the header and plot title.
        target_disease: diagnosis code, displayed via disease_names.
    """
    severity_names = ['Mild', 'Moderate', 'Severe']
    # Pin the label set so the report and the confusion matrix always have
    # one row/column per severity, even if a class never occurs in y_true or
    # y_pred. Otherwise sklearn raises, or the axis ticks mislabel the matrix.
    class_ids = list(range(len(severity_names)))
    disease_label = disease_names.get(target_disease, target_disease)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_probs = np.asarray(y_probs)

    print(f"\n{'='*60}")
    print("DETAILED TEST RESULTS")
    print(f"Disease: {disease_label}")
    print(f"{'='*60}")
    print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    report = classification_report(y_true, y_pred, labels=class_ids,
                                   target_names=severity_names, digits=4)
    print("\nClassification Report:")
    print(report)

    # Per-class accuracy here is the fraction of each true class predicted
    # correctly (i.e. per-class recall).
    print("\nPer-Class Accuracy:")
    for i, severity in enumerate(severity_names):
        class_mask = y_true == i
        class_count = int(class_mask.sum())
        if class_count > 0:
            class_acc = np.sum((y_pred == i) & class_mask) / class_count
            print(f"  {severity}: {class_acc:.4f} ({class_acc*100:.1f}%) - {class_count} samples")

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=severity_names, yticklabels=severity_names)
    plt.xlabel("Predicted Severity")
    plt.ylabel("True Severity")
    plt.title(f"Confusion Matrix - {disease_label} Severity\nAccuracy: {accuracy:.3f}")
    plt.tight_layout()
    plt.show()

    print("\nAverage Confidence per Predicted Class:")
    for i, severity in enumerate(severity_names):
        pred_mask = y_pred == i
        if pred_mask.any():
            # Mean probability the model assigned to class i over the samples it predicted as i.
            avg_conf = y_probs[pred_mask, i].mean()
            print(f"  {severity}: {avg_conf:.4f} ({avg_conf*100:.1f}%)")


def run_accuracy_test():
    """Run the full severity evaluation and save the raw results.

    Loads the model, builds the pseudo-labelled test set for the model's
    target disease, evaluates and reports. The arrays are saved to
    ``severity_test_results_<disease>.npz`` in the working directory.

    Returns:
        (accuracy, results): overall accuracy and a dict with 'accuracy',
        'y_true', 'y_pred', 'y_probs' and 'target_disease'.
    """
    print("Starting Severity Model Accuracy Test...")
    print("=" * 50)

    model, target_disease = load_severity_model(model_path)
    test_df = prepare_test_data(target_disease)

    # NOTE(review): 224x224 must match the input size the severity model was
    # trained at; confirm against the training code.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    loader = DataLoader(
        SeverityTestDataset(test_df, transform=preprocess),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    y_true, y_pred, y_probs, accuracy = test_severity_model(model, loader)
    show_detailed_results(y_true, y_pred, y_probs, accuracy, target_disease)

    results = dict(
        accuracy=accuracy,
        y_true=y_true,
        y_pred=y_pred,
        y_probs=y_probs,
        target_disease=target_disease,
    )
    out_file = f'severity_test_results_{target_disease}.npz'
    np.savez(out_file, **results)
    print(f"\nResults saved to: {out_file}")

    return accuracy, results


if __name__ == "__main__":
    import traceback  # local: only needed to report unexpected failures below

    print("Severity Model Accuracy Tester")
    print("="*40)
    print(f"Using device: {device}")
    print(f"Model path: {model_path}")
    print("="*40)

    try:
        accuracy, results = run_accuracy_test()
        print("\nTesting completed!")
        print(f"Final accuracy: {accuracy:.4f} ({accuracy*100:.1f}%)")

    except FileNotFoundError as e:
        print(f"File not found: {e}")
        print("Please check your model path and data paths")

    except Exception as e:
        print(f"Error during testing: {e}")
        # Keep running (notebook-friendly), but show where the failure came
        # from instead of swallowing the traceback.
        traceback.print_exc()

In [None]:
Test classifier accuracy

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import classification_report, confusion_matrix
import timm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm


# Paths and settings for evaluating the trained 7-class classifier.
metadata_csv = "/data/raw/HAM10000/metadata.csv"
image_dir = "/data/raw/HAM10000/metadata/img"
model_path = "/data/models/HAM10000"
batch_size = 32
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# If metadata_csv names a directory, switch to the ISIC 2018 Task 3 test
# ground-truth CSV inside it (when present).
# NOTE(review): load_test_data() reads 'dx' and 'image_id' columns; confirm
# this CSV actually provides them.
if os.path.isdir(metadata_csv):
    guess = os.path.join(metadata_csv, "ISIC2018_Task3_Test_GroundTruth.csv")
    metadata_csv = guess if os.path.exists(guess) else metadata_csv

# Diagnosis code <-> class index; must match the mapping used in training.
label_map = {
    'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3,
    'mel': 4, 'nv': 5, 'vasc': 6
}
idx_to_label = {index: code for code, index in label_map.items()}
num_classes = len(label_map)


class SkinDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a dataframe.

    The dataframe needs a ``path`` column (image file path) and a ``label``
    column (integer class index). ``transform``, when truthy, is applied to
    the RGB PIL image.
    """

    def __init__(self, dataframe, transform=None):
        # Rows are fetched by position, so renumber the index 0..n-1.
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        sample = self.df.iloc[idx]
        img = Image.open(sample['path']).convert("RGB")
        img = self.transform(img) if self.transform else img
        return img, int(sample['label'])


def load_test_data():
    """Load the evaluation rows from metadata_csv.

    Keeps rows whose 'dx' is in label_map and whose image exists under
    image_dir. Adds integer 'label' and file 'path' columns and prints the
    class distribution.

    NOTE(review): with the default paths this is the same metadata/image set
    the training cell splits 80/20, and no split is applied here. The
    reported "test" accuracy therefore includes training images; confirm
    this is intended.

    Returns:
        DataFrame indexed 0..n-1.
    """
    df = pd.read_csv(metadata_csv)
    df = df[df['dx'].isin(label_map.keys())].assign(
        label=lambda d: d['dx'].map(label_map),
        path=lambda d: d['image_id'].apply(
            lambda image_id: os.path.join(image_dir, f"{image_id}.jpg")
        ),
    )
    df = df[df['path'].apply(os.path.exists)].reset_index(drop=True)

    print(f"Total samples available: {len(df)}")
    print("Class distribution:")
    for dx_code, n_rows in df['dx'].value_counts().items():
        print(f"  {dx_code}: {n_rows}")

    return df


def load_trained_model(model_path):
    """Rebuild the EfficientNet-B3 classifier and load weights from model_path.

    Weights are taken from the first of "model_state_dict", "state_dict" or
    "model" present in the checkpoint. Otherwise the checkpoint itself is
    tried as a state dict.

    Returns:
        The model on `device` in eval mode, or None if the bare-state-dict
        fallback fails to load.
    """
    print(f"Loading model from {model_path}...")

    net = timm.create_model('efficientnet_b3', pretrained=False)
    net.classifier = nn.Linear(net.classifier.in_features, num_classes)

    checkpoint = torch.load(model_path, map_location=device)
    print(f"Available keys in checkpoint: {list(checkpoint.keys())}")

    for wrapper_key in ("model_state_dict", "state_dict", "model"):
        if wrapper_key in checkpoint:
            net.load_state_dict(checkpoint[wrapper_key])
            break
    else:
        try:
            net.load_state_dict(checkpoint)
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Please check your model file format")
            return None

    net = net.to(device)
    net.eval()
    print("Model loaded successfully")

    # Report whichever recorded accuracy key the checkpoint carries first.
    for acc_key in ("val_acc", "validation_accuracy", "best_acc", "accuracy"):
        if acc_key in checkpoint:
            print(f"Training validation accuracy: {checkpoint[acc_key]:.4f}")
            break

    return net


def test_model(model, test_loader):
    """Evaluate a classifier over test_loader.

    Returns:
        (labels, preds, probs, accuracy): numpy arrays of true labels,
        argmax predictions and softmax probabilities (one row per sample),
        plus the overall accuracy as a float.
    """
    model.eval()

    collected = {"labels": [], "preds": [], "probs": []}
    hits = 0
    seen = 0

    print("Testing model...")
    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Testing"):
            x, y = x.to(device), y.to(device)

            logits = model(x)
            prob = torch.softmax(logits, dim=1)
            pred = torch.max(logits, 1).indices

            seen += y.size(0)
            hits += (pred == y).sum().item()

            collected["labels"].extend(y.cpu().numpy())
            collected["preds"].extend(pred.cpu().numpy())
            collected["probs"].extend(prob.cpu().numpy())

    accuracy = hits / seen
    print(f"Overall Test Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

    return (
        np.array(collected["labels"]),
        np.array(collected["preds"]),
        np.array(collected["probs"]),
        accuracy,
    )


def show_results(y_true, y_pred, accuracy):
    """Print overall and per-class metrics and plot the confusion matrix.

    Args:
        y_true: true class indices (0..num_classes-1).
        y_pred: predicted class indices, same length as y_true.
        accuracy: overall accuracy, shown in the header and plot title.
    """
    class_ids = list(range(num_classes))
    target_names = [idx_to_label[i] for i in class_ids]

    print(f"\n{'='*50}")
    print("DETAILED TEST RESULTS")
    print(f"{'='*50}")
    print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print("\nPer-class Results:")

    # When target_names is given, output_dict rows are keyed by class *name*
    # ('akiec', 'bcc', ...), not by index. The previous `str(i) in report`
    # check therefore never matched and no per-class lines were printed.
    # labels= keeps every class present so the rows line up with target_names.
    report = classification_report(y_true, y_pred, labels=class_ids,
                                   target_names=target_names,
                                   digits=4, output_dict=True)

    for class_name in target_names:
        stats = report[class_name]
        print(f"  {class_name:6}: Precision={stats['precision']:.3f}, "
              f"Recall={stats['recall']:.3f}, "
              f"F1={stats['f1-score']:.3f}, Support={int(stats['support'])}")

    macro = report['macro avg']
    print(f"\nMacro avg: Precision={macro['precision']:.3f}, "
          f"Recall={macro['recall']:.3f}, "
          f"F1={macro['f1-score']:.3f}")

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=target_names, yticklabels=target_names)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.title(f"Confusion Matrix\nOverall Accuracy: {accuracy:.3f}")
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()


def run_test():
    """Evaluate the trained classifier on the rows returned by load_test_data().

    Returns:
        (accuracy, y_true, y_pred, y_probs)

    Raises:
        RuntimeError: if load_trained_model() could not load the checkpoint.
            Previously this surfaced later as an opaque AttributeError on None.
    """
    df = load_test_data()

    test_transforms = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    test_dataset = SkinDataset(df, transform=test_transforms)
    test_loader = DataLoader(test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=2, pin_memory=True)

    model = load_trained_model(model_path)
    if model is None:
        # load_trained_model() has already printed the underlying error.
        raise RuntimeError(f"Could not load a model from {model_path}")

    y_true, y_pred, y_probs, accuracy = test_model(model, test_loader)

    show_results(y_true, y_pred, accuracy)

    return accuracy, y_true, y_pred, y_probs


if __name__ == "__main__":
    print("Starting model test...")
    accuracy, y_true, y_pred, y_probs = run_test()
    print(f"\nTesting complete! Final accuracy: {accuracy:.4f}")

In [None]:
Single-image prediction test

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import timm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F


# NOTE(review): test_single_image() defaults to "best_model.pth" and the
# entry points below do not pass this path; confirm which checkpoint is intended.
model_path = "/data/models/HAM10000"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Diagnosis code <-> class index; must match the mapping used in training.
label_map = {
    'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3,
    'mel': 4, 'nv': 5, 'vasc': 6
}

idx_to_label = {index: code for code, index in label_map.items()}

# Human-readable names for the diagnosis codes.
disease_names = {
    'akiec': 'Actinic Keratoses',
    'bcc': 'Basal Cell Carcinoma',
    'bkl': 'Benign Keratosis',
    'df': 'Dermatofibroma',
    'mel': 'Melanoma',
    'nv': 'Melanocytic Nevus',
    'vasc': 'Vascular Lesion'
}

num_classes = len(label_map)


def load_model(model_path):
    """Rebuild the EfficientNet-B3 classifier and load its weights.

    Weights come from the first of "model_state_dict", "state_dict" or
    "model" present in the checkpoint; otherwise the checkpoint itself is
    used as the state dict.

    Returns:
        The model on `device` in eval mode.
    """
    print(f"Loading model from {model_path}...")

    net = timm.create_model('efficientnet_b3', pretrained=False)
    net.classifier = nn.Linear(net.classifier.in_features, num_classes)

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = next(
        (checkpoint[key] for key in ("model_state_dict", "state_dict", "model") if key in checkpoint),
        checkpoint,
    )
    net.load_state_dict(state_dict)

    net = net.to(device)
    net.eval()

    print("Model loaded successfully")
    return net


def preprocess_image(image_path):
    """Load an image file and turn it into a normalised single-image batch.

    Returns:
        (input_tensor, original_image): a 1x3x300x300 tensor (resized,
        ImageNet-normalised) and a copy of the RGB PIL image. Returns
        (None, None) after printing the error if loading/transforming fails.
    """
    pipeline = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    try:
        rgb = Image.open(image_path).convert("RGB")
        snapshot = rgb.copy()
        batch = pipeline(rgb).unsqueeze(0)
    except Exception as e:
        print(f"Error loading image: {e}")
        return None, None

    return batch, snapshot


def predict_image(model, image_tensor):
    """Classify a single-image batch.

    Returns:
        (predicted_class, confidence, all_probs): argmax class index, its
        softmax probability, and the full probability vector (numpy) for the
        first (only) image in the batch.
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        probs = F.softmax(logits, dim=1)[0]
        top_index = torch.max(logits, 1).indices.item()
        top_prob = probs[top_index].item()
        prob_vector = probs.cpu().numpy()

    return top_index, top_prob, prob_vector


def display_results(original_image, predicted_class, confidence, all_probs, image_path):
    """Show the input image next to a per-class confidence bar chart and print a summary.

    Args:
        original_image: image to display (as returned by preprocess_image).
        predicted_class: predicted class index.
        confidence: probability of the predicted class (0..1).
        all_probs: per-class probabilities indexed by class index (0..1).
        image_path: path of the input image; only its file name is shown.

    Returns:
        (predicted_label, confidence): diagnosis code for predicted_class and
        the confidence passed in.
    """
    import os  # local import: this cell's top-level imports do not include os

    # os.path.basename honours the platform's separators (e.g. backslashes on
    # Windows); splitting on "/" returned the whole path there.
    file_name = os.path.basename(image_path)

    predicted_label = idx_to_label[predicted_class]
    predicted_disease = disease_names[predicted_label]
    class_names = [disease_names[idx_to_label[i]] for i in range(num_classes)]
    colors = ['red' if i == predicted_class else 'lightblue' for i in range(num_classes)]

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    ax1.imshow(original_image)
    ax1.axis('off')
    ax1.set_title(f'Input Image\n{file_name}', fontsize=12)

    bars = ax2.barh(range(num_classes), all_probs * 100, color=colors)
    ax2.set_yticks(range(num_classes))
    ax2.set_yticklabels(class_names, fontsize=10)
    ax2.set_xlabel('Confidence (%)', fontsize=12)
    ax2.set_title('Prediction Confidence for All Classes', fontsize=12)
    ax2.set_xlim(0, 100)

    # Annotate each bar with its percentage just past the bar's end.
    for bar, prob in zip(bars, all_probs):
        ax2.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                 f'{prob*100:.1f}%', ha='left', va='center', fontsize=9)

    plt.tight_layout()

    print(f"\n{'='*60}")
    print("PREDICTION RESULTS")
    print(f"{'='*60}")
    print(f"Image: {file_name}")
    print(f"Predicted Class: {predicted_disease} ({predicted_label.upper()})")
    print(f"Confidence: {confidence*100:.2f}%")
    print("\nAll Class Probabilities:")

    prob_pairs = sorted(
        ((disease_names[idx_to_label[i]], all_probs[i]*100) for i in range(num_classes)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    for disease, prob in prob_pairs:
        indicator = "<- PREDICTED" if disease == predicted_disease else ""
        print(f"   {disease:20}: {prob:5.1f}% {indicator}")

    plt.show()

    return predicted_label, confidence


def test_single_image(image_path, model_path="best_model.pth"):
    """Load the model, preprocess one image, predict and display the result.

    The model is reloaded from model_path on every call.

    Returns:
        (predicted_label, confidence), or (None, None) if the image could not
        be loaded.
    """
    print(f"Testing image: {image_path}")

    model = load_model(model_path)
    image_tensor, original_image = preprocess_image(image_path)
    if image_tensor is None:
        return None, None

    class_idx, prob, probs = predict_image(model, image_tensor)
    label, prob = display_results(original_image, class_idx, prob, probs, image_path)
    return label, prob


def main():
    """Classify the hard-coded image path and print a short summary."""
    image_path = "your_image.jpg"

    try:
        predicted_label, confidence = test_single_image(image_path)
        if not predicted_label:
            return
        summary = f"Final prediction: {disease_names[predicted_label]} with {confidence*100:.1f}% confidence"
        print("\nTesting complete!")
        print(summary)

    except FileNotFoundError:
        print(f"Image file not found: {image_path}")
        print("Please update the image_path variable with the correct path to your image")

    except Exception as e:
        print(f"Error during testing: {e}")


def interactive_test():
    """Prompt for image paths in a loop and classify each one until the user quits.

    Entering 'quit', 'exit' or 'q' (case-insensitive) ends the loop. Errors
    for an individual image are printed and the loop continues.
    """
    print("Single Image Skin Lesion Classifier")
    print("=" * 40)

    quit_words = {'quit', 'exit', 'q'}
    while True:
        image_path = input("\nEnter image path (or 'quit' to exit): ").strip()

        if image_path.lower() in quit_words:
            print("Goodbye!")
            return

        if not image_path:
            print("Please enter a valid image path")
            continue

        try:
            predicted_label, confidence = test_single_image(image_path)
            if predicted_label:
                print("\nPrediction Summary:")
                print(f"   Disease: {disease_names[predicted_label]}")
                print(f"   Confidence: {confidence*100:.1f}%")
        except Exception as e:
            print(f"Error: {e}")

        print("\n" + "-" * 50)


if __name__ == "__main__":
    interactive_test()