In [25]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import to_rgba
from collections import defaultdict
import pathlib
import tifffile
from sklearn.metrics import jaccard_score
from tqdm import tqdm


def analyze_contrasting_performance(df, overall_scores, output_dir="plots"):
    # Create output directory if it doesn't exist
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # First, incorporate silver truth into the DataFrame
    enhanced_df = create_silver_truth_df(overall_scores, df)
    
    # Then proceed with the analysis
    cases = find_contrasting_cases(enhanced_df)
    
    if not cases:
        print("No contrasting cases found!")
        return
    
    print("\nTop Contrasting Performance Cases:")
    print("-" * 80)
    
    for i, case in enumerate(cases[:3], 1):
        case_name = case['image'].split('/')[-1].replace('.', '_')
        print(f"\nCase {i}:")
        print(f"Image: {case['image'].split('/')[-1]}")
        print(f"Competitors: {case['competitor1']} vs {case['competitor2']}")
        print(f"Contrast Magnitude: {case['contrast_magnitude']:.3f}")
        
        # Create all five plot variants
        plot_variants = [
            {
                'name': 'all_competitors',
                'params': {'show_all': True, 'include_silver_truth': False},
                'title': f"Case {i}: All Competitor Models"
            },
            {
                'name': 'contrasting_pair',
                'params': {'show_all': False, 'include_silver_truth': False},
                'title': f"Case {i}: Contrasting Competitors Highlighted"
            },
            {
                'name': 'all_with_silver',
                'params': {'show_all': True, 'include_silver_truth': True},
                'title': f"Case {i}: All Competitors with Silver Truth"
            },
            {
                'name': 'silver_highlighted',
                'params': {'show_all': False, 'include_silver_truth': True, 'highlight_silver': True},
                'title': f"Case {i}: Silver Truth Highlighted"
            },
            {
                'name': 'best_highlighted',
                'params': {'show_all': True, 'highlight_best': True},
                'title': f"Case {i}: Best Performers Highlighted"
            }
        ]
        
        for variant in plot_variants:
            plt.figure(figsize=(15, 8))  # Larger figure size to accommodate legend
            visualize_case(case, enhanced_df, **variant['params'])
            plt.title(variant['title'])
            
            # Save plot with descriptive filename
            filename = f"case_{i}_{case_name}_{variant['name']}.png"
            filepath = output_dir / filename
            plt.savefig(filepath, bbox_inches='tight', dpi=300)
            plt.close()
            
            print(f"Saved plot: {filename}")
            
    return cases, enhanced_df

def visualize_case(case, df, max_labels=10, show_all=False, include_silver_truth=True, highlight_silver=False, highlight_best=False):
    """
    Enhanced visualization function with label limitation and multiple display options.
    """
    comp1 = case['competitor1']
    comp2 = case['competitor2']
    image = case['image']
    image_data = df[df['Gt_source_file'] == image]
    image_data = image_data[image_data['J_value'] > 0.001]
    
    all_competitors = sorted(comp for comp in image_data['competitor_name'].unique() 
                           if comp != 'Silver Truth')
    
    # Select most interesting labels
    label_interests = []
    all_labels = sorted(image_data['Label'].unique())
    
    for label in all_labels:
        label_data = image_data[image_data['Label'] == label]
        scores_dict = {row['competitor_name']: row['J_value'] 
                      for _, row in label_data.iterrows()}
        
        # Calculate variance of scores for this label (excluding silver truth)
        scores = [v for k, v in scores_dict.items() if k != 'Silver Truth']
        variance = np.var(scores) if scores else 0
        
        # Calculate difference between main competitors if both present
        main_diff = 0
        if comp1 in scores_dict and comp2 in scores_dict:
            main_diff = abs(scores_dict[comp1] - scores_dict[comp2])
        
        # Calculate maximum score difference between any competitors
        max_diff = 0
        if scores:
            max_diff = max(scores) - min(scores)
        
        # Combine factors for interestingness
        interestingness = variance + main_diff + max_diff
        label_interests.append((label, interestingness))
    
    # Sort by interestingness and take top max_labels
    label_interests.sort(key=lambda x: x[1], reverse=True)
    selected_labels = [label for label, _ in label_interests[:max_labels]]
    selected_labels.sort()  # Sort labels numerically for consistent display
    
    # Create label mapping
    label_to_seq = {label: idx + 1 for idx, label in enumerate(selected_labels)}
    seq_to_label = {idx + 1: label for idx, label in enumerate(selected_labels)}
    
    # For highlighting best performers
    if highlight_best:
        best_scores = {}
        for label in selected_labels:
            label_data = image_data[image_data['Label'] == label]
            if not label_data.empty:
                max_score = label_data['J_value'].max()
                best_scores[label] = set(label_data[label_data['J_value'] == max_score]['competitor_name'])
    
    # Plot competitors
    if show_all:
        colors = plt.cm.tab20(np.linspace(0, 1, len(all_competitors)))
        for idx, competitor in enumerate(all_competitors):
            comp_data = image_data[image_data['competitor_name'] == competitor]
            scores = []
            x_values = []
            
            for label in selected_labels:  # Use selected_labels instead of all_labels
                label_score = comp_data[comp_data['Label'] == label]['J_value']
                if not label_score.empty and label_score.iloc[0] > 0.001:
                    scores.append(label_score.iloc[0])
                    x_values.append(label_to_seq[label])
            
            if scores:
                if highlight_best:
                    has_best = any(competitor in best_scores.get(label, set()) for label in selected_labels)
                    alpha = 1.0 if has_best else 0.2
                    linewidth = 2.5 if has_best else 1
                    label = f"{competitor} {'(Best)' if has_best else ''}"
                else:
                    alpha = 1.0
                    linewidth = 2
                    label = competitor
                
                plt.plot(x_values, scores, '-o', color=colors[idx], 
                        label=label, linewidth=linewidth, markersize=6,
                        alpha=alpha)
    else:
        # Plot non-highlighted competitors
        for competitor in all_competitors:
            if competitor not in [comp1, comp2]:
                comp_data = image_data[image_data['competitor_name'] == competitor]
                scores = []
                x_values = []
                
                for label in selected_labels:  # Use selected_labels instead of all_labels
                    label_score = comp_data[comp_data['Label'] == label]['J_value']
                    if not label_score.empty and label_score.iloc[0] > 0.001:
                        scores.append(label_score.iloc[0])
                        x_values.append(label_to_seq[label])
                
                if scores:
                    plt.plot(x_values, scores, '-o', color='gray', alpha=0.2, 
                            linewidth=1, markersize=4, label=competitor)
        
        # Plot highlighted competitors
        if not highlight_silver:
            for competitor, color in [(comp1, '#FF6B6B'), (comp2, '#4ECDC4')]:
                comp_data = image_data[image_data['competitor_name'] == competitor]
                scores = []
                x_values = []
                
                for label in selected_labels:  # Use selected_labels instead of all_labels
                    label_score = comp_data[comp_data['Label'] == label]['J_value']
                    if not label_score.empty and label_score.iloc[0] > 0.001:
                        scores.append(label_score.iloc[0])
                        x_values.append(label_to_seq[label])
                
                if scores:
                    plt.plot(x_values, scores, '-o', color=color, 
                            label=f"{competitor} (Highlighted)", linewidth=2.5, markersize=8)
    
    # Plot silver truth if included
    if include_silver_truth:
        st_data = image_data[image_data['competitor_name'] == 'Silver Truth']
        st_scores = []
        st_x_values = []
        
        for label in selected_labels:  # Use selected_labels instead of all_labels
            st_score = st_data[st_data['Label'] == label]['J_value']
            if not st_score.empty and st_score.iloc[0] > 0.001:
                st_scores.append(st_score.iloc[0])
                st_x_values.append(label_to_seq[label])
        
        if st_scores:
            st_style = {'color': 'gold', 'linewidth': 2.5, 'markersize': 12, 
                       'alpha': 1.0 if highlight_silver else 0.8}
            label = 'Silver Truth (Highlighted)' if highlight_silver else 'Silver Truth'
            plt.plot(st_x_values, st_scores, '-*', label=label, **st_style)
    
    # Formatting
    plt.xlabel('Sequential Label Number (Original Label)')
    plt.ylabel('Jaccard Score')
    plt.grid(True, alpha=0.2)
    
    # Place legend outside the plot
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    # Convert J_values to float and handle the y-axis limits
    #valid_scores = pd.to_numeric(image_data['J_value'], errors='coerce')
    valid_scores = image_data['J_value'][image_data['J_value'] > 0.001]
    print(valid_scores)
    if not valid_scores.empty:
        plt.ylim(max(0, valid_scores.min() - 0.1), min(1, valid_scores.max() + 0.1))
    
    x_ticks = list(label_to_seq.values())
    plt.xticks(x_ticks, [f'{seq} ({seq_to_label[seq]})' for seq in x_ticks], rotation=45)
    plt.xlim(min(x_ticks) - 0.5, max(x_ticks) + 0.5)
    
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()
# Usage:
# cases, enhanced_df = analyze_contrasting_performance(df, overall_scores, output_dir="path/to/save/plots")

In [23]:
overall_scores = {}
for gt_image_path in tqdm(gt_source_files):
    synced_gt_image_path = find_gt_image(gt_image_path)
    gt_image = tifffile.imread(synced_gt_image_path)
    st_image_path = find_silvertruth_image(gt_image_path)
    st_image = tifffile.imread(st_image_path)
    #print(gt_image_path)
    labels = np.unique(gt_image)[1:]
    scores = {}
    # This is most likely wrong, but let's just have those values shrank to values from 1 to x
    for order, label in enumerate(labels):
        label_layer = np.zeros_like(gt_image)
        label_layer[gt_image == label] = 1
        mask_layer = np.zeros_like(st_image) 
        mask_layer[st_image == label] = 1
        j = np.round(jaccard_score(label_layer, mask_layer, average='micro'),6)
        scores[order+1] = j
    overall_scores[st_image_path] = scores

100%|███████████████████████████████████████████| 57/57 [00:36<00:00,  1.54it/s]


In [7]:
def calculate_jaccard_scores(gt_image, mask_image):
    labels = np.unique(gt_image)[1:]
    scores = {}
    for label in labels:
        label_layer = np.zeros_like(gt_image)
        label_layer[gt_image == label] = 1
        mask_layer = np.zeros_like(mask_image) 
        mask_layer[mask_image == label] = 1
        j = jaccard_score(label_layer, mask_layer, average='micro')
        scores[label] = j
    return scores
def get_competitor_name(file_path: str):
    """returns the competitor folder name given a path to a file."""
    return file_path.split('/')[2]

# open corresponding silvertruth:
def find_silvertruth_image(gt_image_path: str) -> str:
    input_folder, dataset, inner_split, seg, filename = gt_image_path.split('/')
    silvertruth_image_path = pathlib.Path(input_folder, dataset, f"{inner_split.replace('G', 'S')}_sync", seg, filename)
    return str(silvertruth_image_path)

def find_gt_image(gt_image_path: str) -> str:
    input_folder, dataset, inner_split, seg, filename = gt_image_path.split('/')
    synced_gt_image_path = pathlib.Path(input_folder, dataset, f"{inner_split}_sync", filename)
    return str(synced_gt_image_path)

df = pd.read_csv('preprocessed_dataset.csv')
df['competitor_name'] = df['Mask_file'].apply(get_competitor_name)


gt_source_files = list(df.drop_duplicates(subset=['Gt_mask_file'])['Gt_mask_file'])


In [10]:
def create_silver_truth_df(overall_scores, df):
    """Convert silver truth dict to DataFrame matching the structure of the main df"""
    st_rows = []
    
    for st_image_path, scores in overall_scores.items():
        # Extract corresponding GT source file
        # gt_source_file = str(pathlib.Path(*st_image_path.split('_ST_sync')[0].split('/')[:-1], 
        #                                 st_image_path.split('/')[-1]))
        input_folder, dataset, inner_split, seg, filename = st_image_path.split('/')
        split_number = inner_split.split('_')[0]
        filename = filename.replace('man_seg', 't')
        gt_source_file = str(pathlib.Path(input_folder, dataset, split_number,filename))
        # Extract corresponding GT mask file
        input_folder, dataset, inner_split, seg, filename = st_image_path.split('/')
        split_number = inner_split.split('_')[0]
        gt_mask_file = str(pathlib.Path(input_folder, dataset, f"{split_number}_GT", seg, filename))
        for label, j_value in scores.items():
            st_rows.append({
                'Mask_file': st_image_path,
                'Gt_source_file': gt_source_file,
                'Gt_mask_file': gt_mask_file,
                'Label': label,
                'J_value': j_value,
                'competitor_name': 'Silver Truth'  # Special competitor name for silver truth
            })
    
    st_df = pd.DataFrame(st_rows)
    return pd.concat([df, st_df], ignore_index=True)


In [12]:
def find_contrasting_cases(df):
    """Find cases with contrasting performance, ignoring zero values."""
    # Group data by image
    image_data = defaultdict(lambda: defaultdict(dict))
    
    # Filter out zero values
    df_filtered = df[df['J_value'] > 0.001]
    
    # First, organize data by image -> label -> competitor -> score
    for _, row in df_filtered.iterrows():
        image = row['Gt_source_file']
        label = row['Label']
        competitor = row['competitor_name']
        score = row['J_value']
        image_data[image][label][competitor] = score
    
    interesting_cases = []
    
    # Analyze each image
    for image, label_data in image_data.items():
        # Only consider images with multiple labels
        if len(label_data) < 2:
            continue
            
        # Calculate mean score per competitor per label
        competitor_label_scores = defaultdict(dict)
        for label, comp_scores in label_data.items():
            for competitor, score in comp_scores.items():
                competitor_label_scores[competitor][label] = score
        
        # Look for contrasting performance
        for comp1 in competitor_label_scores:
            for comp2 in competitor_label_scores:
                if comp1 >= comp2:
                    continue
                    
                performance_diff = []
                for label in label_data:
                    # Only consider labels where both competitors have scores
                    if label in competitor_label_scores[comp1] and label in competitor_label_scores[comp2]:
                        diff = competitor_label_scores[comp1][label] - competitor_label_scores[comp2][label]
                        performance_diff.append((label, diff))
                
                # Check if there are contrasting performances (positive and negative differences)
                pos_diffs = [d for _, d in performance_diff if d > 0]
                neg_diffs = [d for _, d in performance_diff if d < 0]
                
                if pos_diffs and neg_diffs:  # If we have both positive and negative differences
                    max_contrast = max(pos_diffs) + abs(min(neg_diffs))  # Total contrast magnitude
                    
                    interesting_cases.append({
                        'image': image,
                        'competitor1': comp1,
                        'competitor2': comp2,
                        'contrast_magnitude': max_contrast,
                        'label_differences': performance_diff,
                        'scores': {
                            comp1: competitor_label_scores[comp1],
                            comp2: competitor_label_scores[comp2]
                        }
                    })
    
    # Sort cases by contrast magnitude
    interesting_cases.sort(key=lambda x: x['contrast_magnitude'], reverse=True)
    return interesting_cases


In [16]:
# At the start of your analysis, when loading the data:
df['J_value'] = pd.to_numeric(df['J_value'], errors='coerce')

In [26]:
cases, enhanced_df = analyze_contrasting_performance(df, overall_scores)


Top Contrasting Performance Cases:
--------------------------------------------------------------------------------

Case 1:
Image: t1748.tif
Competitors: CALT-US vs KIT-Sch-GE
Contrast Magnitude: 0.921
153     0.754875
154     0.814249
155     0.914634
156     0.820896
157     0.843511
          ...   
3124    0.860465
3125    0.879630
3126    0.902208
3127    0.830729
3128    0.890380
Name: J_value, Length: 918, dtype: float64
Saved plot: case_1_t1748_tif_all_competitors.png
153     0.754875
154     0.814249
155     0.914634
156     0.820896
157     0.843511
          ...   
3124    0.860465
3125    0.879630
3126    0.902208
3127    0.830729
3128    0.890380
Name: J_value, Length: 918, dtype: float64
Saved plot: case_1_t1748_tif_contrasting_pair.png
153     0.754875
154     0.814249
155     0.914634
156     0.820896
157     0.843511
          ...   
3124    0.860465
3125    0.879630
3126    0.902208
3127    0.830729
3128    0.890380
Name: J_value, Length: 918, dtype: float64
Saved p