In [3]:
%pip install seaborn
import seaborn as sns
import matplotlib.pyplot as plt
import polars as pl

def plot_metric_bars(scores, col_names, title, ylim=None, colors=None):
    df = pl.DataFrame({
        'Metric': [name for name in col_names for _ in range(len(scores))],
        'Score': [val for score_list in scores for val in score_list],
        'Epoch': [f'Epoch {i+1}' for i in range(len(scores)) for _ in col_names]
    })
    
    plt.figure(figsize=(8, 6))
    ax = sns.barplot(data=df.to_pandas(), x='Epoch', y='Score', hue='Metric', palette=colors)
    plt.title(title)
    if ylim:
        plt.ylim(ylim)
    return plt

def plot_loss_curves(losses, col_names, title, ylim=None, colors=None):
    df = pl.DataFrame({
        'Loss': [val for loss_list in losses for val in loss_list],
        'Type': [name for name in col_names for _ in range(len(losses))],
        'Epoch': [f'Epoch {i+1}' for i in range(len(losses)) for _ in col_names]
    })
    
    plt.figure(figsize=(8, 6))
    ax = sns.lineplot(data=df.to_pandas(), x='Epoch', y='Loss', hue='Type', marker='o', palette=colors)
    plt.title(title)
    if ylim:
        plt.ylim(ylim)
    return plt

def create_performance_subplot(scores_f1, scores_acc, losses, col_names, 
                             main_title="Model Performance", ylims=None, colors=None):
    fig = plt.figure(figsize=(15, 10))
    fig.suptitle(main_title)
    
    plt.subplot(2, 2, 1)
    df_f1 = pl.DataFrame({
        'Metric': [name for name in col_names for _ in range(len(scores_f1))],
        'Score': [val for score_list in scores_f1 for val in score_list],
        'Epoch': [f'Epoch {i+1}' for i in range(len(scores_f1)) for _ in col_names]
    })
    sns.barplot(data=df_f1.to_pandas(), x='Epoch', y='Score', hue='Metric', palette=colors)
    plt.title('F1 Score Comparison')
    if ylims and 'f1' in ylims:
        plt.ylim(ylims['f1'])
    
    plt.subplot(2, 2, 2)
    df_acc = pl.DataFrame({
        'Metric': [name for name in col_names for _ in range(len(scores_acc))],
        'Score': [val for score_list in scores_acc for val in score_list],
        'Epoch': [f'Epoch {i+1}' for i in range(len(scores_acc)) for _ in col_names]
    })
    sns.barplot(data=df_acc.to_pandas(), x='Epoch', y='Score', hue='Metric', palette=colors)
    plt.title('Accuracy Comparison')
    if ylims and 'accuracy' in ylims:
        plt.ylim(ylims['accuracy'])
    
    plt.subplot(2, 2, 3)
    df_loss = pl.DataFrame({
        'Loss': [val for loss_list in losses for val in loss_list],
        'Type': [name for name in col_names for _ in range(len(losses))],
        'Epoch': [f'Epoch {i+1}' for i in range(len(losses)) for _ in col_names]
    })
    sns.lineplot(data=df_loss.to_pandas(), x='Epoch', y='Loss', hue='Type', marker='o', palette=colors)
    plt.title('Loss Curves')
    if ylims and 'loss' in ylims:
        plt.ylim(ylims['loss'])
    
    plt.subplot(2, 2, 4).set_visible(False)
    plt.tight_layout()
    return plt






Collecting seaborn
  Using cached seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)
Using cached seaborn-0.13.2-py3-none-any.whl (294 kB)
Installing collected packages: seaborn
Successfully installed seaborn-0.13.2
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [4]:
# Create test data for VIT and DenseNet models
vit_data = {
    'train': {'f1': [0.6], 'accuracy': [0.6], 'loss': [0.8]},
    'val': {'f1': [0.52], 'accuracy': [0.52], 'loss': [0.9]}, 
    'test': {'f1': [0.54], 'accuracy': [0.54], 'loss': [0.85]}
}

densenet_data = {
    'train': {'f1': [0.86], 'accuracy': [0.86], 'loss': [0.3]},
    'val': {'f1': [0.53], 'accuracy': [0.53], 'loss': [0.8]},
    'test': {'f1': [0.51], 'accuracy': [0.51], 'loss': [0.82]}
}

# Combine data for plotting
scores_f1 = [
    [vit_data['train']['f1'][0], vit_data['val']['f1'][0], vit_data['test']['f1'][0]],
    [densenet_data['train']['f1'][0], densenet_data['val']['f1'][0], densenet_data['test']['f1'][0]]
]

scores_acc = [
    [vit_data['train']['accuracy'][0], vit_data['val']['accuracy'][0], vit_data['test']['accuracy'][0]],
    [densenet_data['train']['accuracy'][0], densenet_data['val']['accuracy'][0], densenet_data['test']['accuracy'][0]]
]

losses = [
    [vit_data['train']['loss'][0], vit_data['val']['loss'][0], vit_data['test']['loss'][0]],
    [densenet_data['train']['loss'][0], densenet_data['val']['loss'][0], densenet_data['test']['loss'][0]]
]

col_names = ['ViT (512x512)', 'DenseNet']
colors = ['blue', 'green','']
ylims = {'f1': (0, 1.0), 'accuracy': (0, 1.0), 'loss': (0, 1.0)}

# Create and display plot
plt.style.use('seaborn')
create_performance_subplot(scores_f1, scores_acc, losses, col_names, 
                         main_title="Model Architecture Comparison",
                         ylims=ylims, colors=colors)
plt.show()


OSError: 'seaborn' is not a valid package style, path of style file, URL of style file, or library style name (library styles are listed in `style.available`)