Example code for comparing model performance across the datasets (original, fully augmented, and each individual augmentation). To be implemented later.

In [None]:
# Evaluate the model
def evaluate_individual_augmentations(model, batch_transform, device='cpu'):
    """Evaluate ``model`` separately on the rows of each augmentation type.

    Reads ``IND_AUGMENTED_CSV`` and groups its rows by the
    ``augmentation_type`` column. For every type except ``'none'``, it builds
    an ``IRDataset`` from just those rows, takes the test loader returned by
    ``create_loaders``, and runs ``evaluate`` on it.

    Args:
        model: Passed straight through to ``evaluate``.
        batch_transform: Passed straight through to ``evaluate``.
        device: Passed straight through to ``evaluate`` (default ``'cpu'``).

    Returns:
        dict mapping each augmentation type to the ``'macro'`` entry of the
        metrics that ``evaluate`` returned for that type's test loader.
    """
    import os
    import tempfile

    df = pd.read_csv(IND_AUGMENTED_CSV)
    # A missing (NaN) type never compares equal under ``==``, so it would
    # select zero rows and yield an empty dataset; skip such rows instead.
    unique_methods = df['augmentation_type'].dropna().unique()

    method_metrics = {}

    # IRDataset is built from a CSV path, so each subset has to be written to
    # disk. A temporary directory is removed together with its files when the
    # block exits (even on error), instead of leaving temp_<method>.csv files
    # behind in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for index, method in enumerate(unique_methods):
            if method == 'none':
                continue

            print(f"\nEvaluating augmentation method: {method}")
            df_method = df[df['augmentation_type'] == method]

            # Index-based filename: the method name may contain characters
            # that are not valid in a path (e.g. '/').
            temp_csv = os.path.join(tmp_dir, f'subset_{index}.csv')
            df_method.to_csv(temp_csv, index=False)

            temp_dataset = IRDataset(temp_csv)
            _, test_loader = create_loaders(temp_dataset)

            # Evaluation stays inside the ``with`` block so the CSV still
            # exists in case IRDataset reads it lazily.
            metrics = evaluate(model, test_loader, batch_transform, device)
            method_metrics[method] = metrics['macro']

    return method_metrics

def evaluate_dataset(csv_path, model, batch_transform, device='cpu'):
    """Evaluate ``model`` on the test loader built from the CSV at ``csv_path``.

    Args:
        csv_path: Path handed to ``IRDataset``.
        model: Passed straight through to ``evaluate``.
        batch_transform: Passed straight through to ``evaluate``.
        device: Passed straight through to ``evaluate`` (default ``'cpu'``).

    Returns:
        The ``'macro'`` entry of the metrics returned by ``evaluate``.
    """
    # Only the second (test) loader is used; the first is discarded.
    (_unused_train_loader, test_loader) = create_loaders(IRDataset(csv_path))
    return evaluate(model, test_loader, batch_transform, device)['macro']

# Collect the 'macro' metrics for the original data, the fully augmented data,
# and each individual augmentation type (device left at its default).
original_metrics = evaluate_dataset(
    csv_path=CSV_PATH, model=model, batch_transform=batch_transform
)
augmented_metrics = evaluate_dataset(
    csv_path=AUGMENTED_CSV, model=model, batch_transform=batch_transform
)
individual_metrics = evaluate_individual_augmentations(
    model=model, batch_transform=batch_transform
)


# Compare performances
def plot_metrics_comparison(metrics_dict, title="Augmentation Method Comparison",
                            save_path="augmentation_methods_comparison.png"):
    """Draw a grouped bar chart comparing metric scores across methods.

    The chart has one group of bars per metric name (taken from the first
    entry of ``metrics_dict``) and one bar per method inside each group.

    Args:
        metrics_dict: Mapping ``method name -> {metric name: score}``. Every
            inner mapping must contain the metric names of the first one.
        title: Plot title.
        save_path: File the figure is saved to before it is shown; ``None``
            skips saving. Defaults to the previously hard-coded filename.

    Raises:
        ValueError: If ``metrics_dict`` is empty.
    """
    if not metrics_dict:
        raise ValueError("metrics_dict must contain at least one method")

    labels = list(next(iter(metrics_dict.values())).keys())
    methods = list(metrics_dict.keys())

    x = np.arange(len(labels))
    # Tick positions are 1 apart; cap each group's total width at 0.8 so bars
    # of neighbouring groups never overlap. 0.15 stays the upper bound, which
    # is the original width for up to five methods.
    width = min(0.15, 0.8 / len(methods))

    plt.figure(figsize=(12, 7))

    for i, method in enumerate(methods):
        values = [metrics_dict[method][label] for label in labels]
        plt.bar(x + i * width, values, width=width, label=method)

    # Centre each tick label under its group of bars.
    plt.xticks(x + width * (len(methods) - 1) / 2, labels, rotation=45)
    plt.ylabel("Score")
    plt.title(title)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()



# Merge every evaluation into one mapping and plot them side by side.
# NOTE: an augmentation type literally named 'original' or 'augmented_full'
# would overwrite that entry, since the unpacked dict comes last.
all_metrics = {
    'original': original_metrics,
    'augmented_full': augmented_metrics,
    **individual_metrics,
}

plot_metrics_comparison(all_metrics, title="Augmentation Methods Comparison")