In [None]:
from fastai.vision.all import *
import pandas as pd
import torchvision

In [None]:
input_data = "*****"
mycsv = pd.read_csv(input_data)
# Point each row at its local cropped JPEG: keep only the text after the last "/"
# and then after the last "-" of the original image path.
mycsv["fn"] = mycsv["image"].map(
    lambda image_path: "../jpegs_same_contrast_cropped/" + image_path.split("/")[-1].split("-")[-1]
)
# Keep only rows whose label is one of the four classes in the model vocabulary.
mycsv = mycsv[mycsv["choice"].isin(["Codman Hakim", "Codman Certas Plus", "Sophysa Sophy SM8", "proGAV 2.0"])]

In [None]:
# NOTE(review): unpickling can execute arbitrary code — only load a splits.pkl you produced yourself.
with open("splits.pkl", mode="rb") as split_file:
    raw_split_bytes = split_file.read()
splits = pickle.loads(raw_split_bytes)

In [None]:
def get_model_name(s: int):
    """Return the checkpoint name (as passed to ``Learner.load``) for split number ``s``."""
    prefix = 'resnet34_pretrained_4_ventile_on_patient_split_squish_pret_'
    return prefix + format(s)

In [None]:
def create_dls(split, bs=32, size=512):
    """Build fastai DataLoaders over the global ``mycsv`` table for one split.

    Args:
        split: indexable whose element ``[1]`` holds the validation indices; every
            other row becomes training data (IndexSplitter). Presumably a
            (train_idx, valid_idx) pair loaded from splits.pkl — confirm.
        bs: batch size.
        size: side length given to the ``Resize`` item transform (default 512).

    Returns:
        DataLoaders whose inputs are images read from column "fn" and whose
        targets are categories read from column "choice".
    """
    dblock = DataBlock(
        blocks=(ImageBlock, CategoryBlock(
            # sort=False keeps this exact order, so class index i is vocab[i].
            vocab=["Codman Hakim", "Codman Certas Plus", "Sophysa Sophy SM8", "proGAV 2.0"],
            sort=False,
        )),
        get_x=ColReader("fn"),
        get_y=ColReader("choice"),
        splitter=IndexSplitter(split[1]),
        item_tfms=[Resize(size)],
        batch_tfms=aug_transforms(),
    )
    # dataloaders() builds the datasets itself; no separate datasets() call needed.
    return dblock.dataloaders(mycsv, bs=bs)


def create_model(n_out=4, pretrained=True):
    """Build a ResNet-34 body topped with a fastai classification head.

    Args:
        n_out: number of output classes (the notebook's vocabulary has 4).
        pretrained: initialise the backbone with torchvision's default ImageNet
            weights. When a checkpoint is loaded right afterwards (e.g. via
            ``Learner.load``) these weights are presumably replaced, so False
            avoids fetching them — confirm the checkpoint covers every parameter.

    Returns:
        nn.Sequential(body, head)
    """
    # Qualified name: does not depend on fastai's star-import re-exporting the enum.
    weights = torchvision.models.ResNet34_Weights.DEFAULT if pretrained else None
    resnet34 = torchvision.models.resnet34(weights=weights)
    # cut=-2 drops the last two child modules of the torchvision model.
    body = create_body(resnet34, cut=-2)
    # 512 = channels out of resnet34's final stage; recent fastai create_head doubles
    # this internally for concat pooling — verify for the installed fastai version.
    head = create_head(512, n_out)
    return nn.Sequential(body, head)

In [None]:
import torch
import torch.nn.functional as F
from collections import defaultdict

def calculate_prediction_metrics(softmax_probs):
    """
    Compute per-sample confidence measures from class probabilities.

    Args:
        softmax_probs (torch.Tensor): [n, C] probabilities, C >= 2 (topk with k=2).

    Returns:
        tuple: (entropy, max_prob, prob_gap), each a tensor of shape [n]:
            entropy  -- Shannon entropy in bits (log base 2),
            max_prob -- largest probability in the row,
            prob_gap -- largest minus second-largest probability.
    """
    # Floor probabilities at 1e-15 so log2 stays finite for exact zeros;
    # the floor adds a negligible amount to the entropy.
    safe_probs = softmax_probs.clamp(min=1e-15, max=1.0)
    entropy = (safe_probs * safe_probs.log2()).sum(dim=1).neg()

    # Two largest probabilities per row, in descending order.
    best, runner_up = softmax_probs.topk(2, dim=1).values.unbind(dim=1)

    return entropy, best, best - runner_up

def analyze_metrics_by_target(softmax_probs, targets):
    """
    Summarise prediction-confidence metrics separately for each target class.

    Args:
        softmax_probs (torch.Tensor): [n, C] class probabilities.
        targets (torch.Tensor): [n] target class indices aligned with the rows
            of softmax_probs.

    Returns:
        dict: maps each class present in ``targets`` (as a Python number) to a dict
        with keys 'entropy', 'max_prob' and 'prob_gap' — each holding 'mean', 'std',
        'min', 'max' (floats) and 'values' (the per-sample tensor) — plus 'count',
        the number of samples of that class. 'std' is torch's default sample
        standard deviation, so it is NaN for a class with a single sample.
    """
    per_sample = dict(zip(('entropy', 'max_prob', 'prob_gap'),
                          calculate_prediction_metrics(softmax_probs)))

    summary = {}
    for label in torch.unique(targets):
        in_class = targets == label
        class_summary = {}
        for metric_name, metric_values in per_sample.items():
            subset = metric_values[in_class]
            class_summary[metric_name] = {
                'mean': subset.mean().item(),
                'std': subset.std().item(),
                'min': subset.min().item(),
                'max': subset.max().item(),
                'values': subset,
            }
        class_summary['count'] = in_class.sum().item()
        summary[label.item()] = class_summary

    return summary

_CONFIDENCE_METRICS = ('entropy', 'max_prob', 'prob_gap')


def _collect_split_records(splits_data):
    """Evaluate each split's saved model and gather per-class metric records, keyed class -> metric -> list."""
    records = defaultdict(lambda: defaultdict(list))
    for split_idx, split in enumerate(splits_data):
        dls = create_dls(split)
        learn = Learner(dls, create_model()).load(get_model_name(split_idx))
        # get_preds() defaults to the validation set; the predictions are presumably
        # probabilities (the loss function's activation is applied) — confirm.
        preds, targets = learn.get_preds()

        for class_idx, stats in analyze_metrics_by_target(preds, targets).items():
            for metric_name in _CONFIDENCE_METRICS:
                records[class_idx][metric_name].append({
                    'mean': stats[metric_name]['mean'],
                    'std': stats[metric_name]['std'],
                    'values': stats[metric_name]['values'],
                    'count': stats['count'],
                    'split_idx': split_idx,
                })
    return records


def _summarize_class(metrics, n_splits):
    """Combine one class's per-split records into pooled statistics and split-aligned sample counts."""
    class_stats = {}
    for metric_name, split_stats_list in metrics.items():
        # Pool every sample of this class across splits.
        all_values = torch.cat([stats['values'] for stats in split_stats_list])
        split_means = torch.tensor([stats['mean'] for stats in split_stats_list])

        class_stats[metric_name] = {
            'overall_mean': all_values.mean().item(),
            'overall_std': all_values.std().item(),
            'overall_min': all_values.min().item(),
            'overall_max': all_values.max().item(),
            # Standard error of the per-split means; NaN if the class occurs in only one split.
            'mean_std_error': (split_means.std() / torch.sqrt(torch.tensor(len(split_means)))).item(),
            # Only splits where the class appears contribute to these two lists.
            'per_split_means': [stats['mean'] for stats in split_stats_list],
            'per_split_stds': [stats['std'] for stats in split_stats_list],
        }

    # One count per split, aligned with split index; 0 where the class had no samples.
    # (Indexing records positionally would fail when a class is missing from a split.)
    per_split_counts = [0] * n_splits
    for record in metrics['entropy']:
        per_split_counts[record['split_idx']] = record['count']
    class_stats['total_samples'] = sum(per_split_counts)
    class_stats['per_split_counts'] = per_split_counts
    return class_stats


def aggregate_metrics_across_splits(splits_data):
    """
    Evaluate every split's saved model and aggregate confidence metrics per class.

    For split i, this rebuilds DataLoaders with ``create_dls``, builds a fresh model,
    loads the checkpoint ``get_model_name(i)``, and predicts on that split's
    validation set.

    Args:
        splits_data: sequence of split definitions accepted by ``create_dls``
            (presumably (train_idx, valid_idx) pairs from splits.pkl — confirm).

    Returns:
        dict: class index -> {
            'entropy' | 'max_prob' | 'prob_gap': {'overall_mean', 'overall_std',
                'overall_min', 'overall_max', 'mean_std_error',
                'per_split_means', 'per_split_stds'},
            'total_samples': int,
            'per_split_counts': list with one entry per split (0 if absent),
        }
    """
    records = _collect_split_records(splits_data)
    n_splits = len(splits_data)
    return {class_idx: _summarize_class(metrics, n_splits)
            for class_idx, metrics in records.items()}

In [None]:
aggregated_results = aggregate_metrics_across_splits(splits)

# Report each class's confidence metrics pooled over all splits.
print("\nAggregated analysis across splits:")
for class_key, class_summary in aggregated_results.items():
    print(f"\nClass {class_key} (Total samples: {class_summary['total_samples']}):")

    for metric_key in ('entropy', 'max_prob', 'prob_gap'):
        summary = class_summary[metric_key]
        formatted_means = [f'{value:.3f}' for value in summary['per_split_means']]
        print(f"\n  {metric_key.upper()}:")
        print(f"    Mean: {summary['overall_mean']:.3f} ± {summary['mean_std_error']:.3f}")
        print(f"    Overall std: {summary['overall_std']:.3f}")
        print(f"    Range: [{summary['overall_min']:.3f}, {summary['overall_max']:.3f}]")
        print(f"    Per-split means: {formatted_means}")

In [None]:
from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score

results = []
reports = []

# Evaluate each split's saved model on its validation set.
for i, split in enumerate(splits):
    dls = create_dls(split)
    model = create_model()

    learn = Learner(dls, model).load(get_model_name(i))

    preds, targets = learn.get_preds()
    predicted_labels = preds.argmax(dim=1)
    print("split ", i+1, "\n")
    # Pin the label set so every report has keys "0".."n-1" even when a class is
    # absent from this split's targets and predictions; the per-class aggregation
    # cell indexes those keys directly. zero_division=0 matches sklearn's default
    # value for undefined scores without emitting warnings.
    class_labels = list(range(len(dls.vocab)))
    reports.append(classification_report(targets, predicted_labels, labels=class_labels,
                                         output_dict=True, zero_division=0))
    results.append({
        'precision': precision_score(targets, predicted_labels, average='macro'),
        'recall': recall_score(targets, predicted_labels, average='macro'),
        'f1_score': f1_score(targets, predicted_labels, average='macro'),
        'accuracy': accuracy_score(targets, predicted_labels),
    })
print(results)


In [None]:
# Display the per-split macro precision/recall/F1 and accuracy dicts.
results

In [None]:
# Mean and population std (np.std default, ddof=0) of each macro metric across splits.
for metric in ['precision', 'recall', 'f1_score', 'accuracy']:
    per_split = np.array([r[metric] for r in results])
    print(per_split.mean().round(2), per_split.std().round(2))

In [None]:
def print_result(metric, runs=None):
    """Print ``metric : mean  +-  std`` (rounded to 2 decimals) across split results.

    Args:
        metric: key to read from each per-split result dict (e.g. 'f1_score').
        runs: list of per-split result dicts; defaults to the notebook-level
            ``results`` list built by the evaluation loop.
    """
    if runs is None:
        runs = results
    values = np.array([r[metric] for r in runs])
    # np.std is the population std (ddof=0).
    print(metric, ":", values.mean().round(2), " +- ", values.std().round(2))

# One summary line per macro-averaged metric.
for metric_name in ('precision', 'recall', 'f1_score', 'accuracy'):
    print_result(metric_name)

In [None]:
rs = {}
if reports:
    # For each class label, collect the per-split values of each report metric.
    rs = {
        label: {name: [report[label][name] for report in reports]
                for name in ("precision", "recall", "f1-score")}
        for label in ("0", "1", "2", "3")
    }

for cl, per_metric in rs.items():
    print("Class: ", cl)
    for metric, values in per_metric.items():
        print(metric, np.mean(values).round(2), " +- ", np.std(values).round(2))

In [None]:
cfms = []

# One confusion matrix (raw counts) per split, each from that split's saved model.
for split_idx, split in enumerate(splits):
    dls = create_dls(split)
    model = create_model()

    learn = Learner(dls, model).load(get_model_name(split_idx))
    interp = ClassificationInterpretation.from_learner(learn)
    cfms.append(interp.confusion_matrix())

# Averaging counts and then row-normalising gives, per actual class, the fraction of
# its validation samples (pooled over splits) predicted as each class.
overall_cfm = np.array(cfms).mean(axis=0)
row_totals = overall_cfm.sum(axis=1, keepdims=True)
# Rows for classes with no validation samples stay 0 instead of becoming NaN.
normalized_cfm = np.divide(overall_cfm.astype('float'), row_totals,
                           out=np.zeros(overall_cfm.shape, dtype=float),
                           where=row_totals > 0)

vocab = interp.vocab

fig, ax = plt.subplots()
ax.imshow(normalized_cfm, interpolation='nearest', cmap="Blues")
ax.set_title("Confusion matrix")
tick_marks = np.arange(len(vocab))
ax.set_xticks(tick_marks)
ax.set_xticklabels(vocab, rotation=90)
ax.set_yticks(tick_marks)
ax.set_yticklabels(vocab, rotation=0)

# Annotate every cell; white text on the darker half of the colour range.
thresh = normalized_cfm.max() / 2.
for i, j in itertools.product(range(normalized_cfm.shape[0]), range(normalized_cfm.shape[1])):
    ax.text(j, i, f'{normalized_cfm[i, j]:.2f}',
            horizontalalignment="center", verticalalignment="center",
            color="white" if normalized_cfm[i, j] > thresh else "black")

ax.set_ylim(len(vocab) - .5, -.5)
ax.set_ylabel('Actual')
ax.set_xlabel('Predicted')
ax.grid(False)
# Lay out only after all labels exist so the rotated tick labels are accounted for.
fig.tight_layout()

# High-resolution raster copy of the figure.
fig.savefig("confusion_matrix_ventile.jpeg", dpi=300, bbox_inches='tight')

In [None]:
# Save the confusion-matrix figure as a vector PDF. Use the `fig` handle from the
# plotting cell: a bare plt.savefig in a new cell targets pyplot's current figure,
# which under the inline backend is no longer the (already closed) confusion matrix.
fig.savefig('confusion_matrix.pdf', format='pdf', bbox_inches='tight')