In [None]:
import os
import shutil

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sb

import utils.dirtools
import utils.experiment

In [None]:
# Experiment configuration: load the project configuration object and set the
# constants that the experiment loop further down reads.
from config import config_vars

# Name of this experiment; passed to the project helpers and used below to
# build the experiment directory and results CSV paths.
experiment_name = 'impact_of_augmented_dataset_size'

# Data partition handed to utils.experiment.run for every run.
partition = "validation"

# Number of repeated runs performed for each training-set size.
total_repetitions = 10

# Let the project helper prepare config_vars for this experiment.
# NOTE(review): the exact fields it adds/changes are defined in utils.dirtools,
# not visible from this notebook.
config_vars = utils.dirtools.setup_experiment(config_vars, experiment_name)

# Show the resulting configuration as the cell output.
config_vars

In [None]:
# Run the experiment once per (training-set size, repetition) pair and collect
# one row of metrics per run. The results table is re-written to CSV after
# every run so partial progress survives an interruption.
results = pd.DataFrame(columns=["Samples", "Repeat", "MAP", "Missed", "Merges", "Splits"])
idx = 0

for max_samples in [2, 4, 6, 8, 10, 20, 40, 60, 80, 100]:
    for repetition in range(total_repetitions):
        print("Experiment", idx, "- max_samples:", max_samples, "- repetition:", repetition)

        # Modify settings
        config_vars["max_training_images"] = max_samples

        # Reconfigure variables and data partitions
        config_vars = utils.dirtools.setup_experiment(config_vars, experiment_name)
        data_partitions = utils.dirtools.read_data_partitions(config_vars)

        # Run experiment
        output = utils.experiment.run(config_vars, data_partitions, experiment_name, partition, GPU="2")

        # Collect outputs ("Missed" is reduced with .sum(); the other metrics
        # are stored exactly as returned by the run).
        results.loc[idx] = {
            "Samples": max_samples,
            "Repeat": repetition,
            "MAP": output["MAP"],
            "Missed": output["Missed"].sum(),
            "Merges": output["Merges"],
            "Splits": output["Splits"],
        }
        idx += 1

        # Recomputed every iteration because config_vars is reassigned above.
        experiment_dir = config_vars["root_directory"] + "/experiments/" + experiment_name

        # Save results first: the CSV lives next to (not inside) the experiment
        # directory, so writing it before the cleanup keeps the row that was
        # just collected even if the cleanup step misbehaves.
        results.to_csv(experiment_dir + ".csv")

        # Clean up directories. shutil.rmtree replaces the former
        # `os.system("rm -Rf " + path)`: no shell is involved, so paths with
        # spaces or shell metacharacters cannot delete the wrong thing.
        # Removal stays best-effort (as `rm -f` was), but a leftover directory
        # is now reported instead of being silently ignored.
        shutil.rmtree(experiment_dir, ignore_errors=True)
        if os.path.exists(experiment_dir):
            print("Warning: could not remove", experiment_dir)

In [None]:
# Load the per-run results and summarise them per training-set size as
# mean and standard error of the mean (SEM) for every metric.
#
# index_col=0: the experiment cell writes this table with DataFrame.to_csv and
# the default index=True, so the first CSV column is the saved row index.
# Reading it back as the index keeps it from showing up as an "Unnamed: 0"
# column that would otherwise be aggregated alongside the real metrics.
# NOTE(review): the path is hardcoded; presumably it is the same file the
# experiment cell writes under config_vars["root_directory"] - confirm.
results = pd.read_csv(
    "/data1/image-segmentation/BBBC022/unet/experiments/impact_of_augmented_dataset_size.csv",
    index_col=0,
)

# Group once and derive both statistics from the same grouping.
by_samples = results.groupby("Samples")
mean = by_samples.mean().reset_index()
sem = by_samples.sem().reset_index()

# Suffix the SEM columns so they can sit next to the means in one table.
sem.columns = [c + "_se" for c in sem.columns]

# Side-by-side table; drop the duplicated key column and the aggregated
# "Repeat" columns.
data = pd.concat([mean, sem], axis=1).drop(["Samples_se", "Repeat", "Repeat_se"], axis=1)
data

In [None]:
# Plot the mean "Missed" value per training-set size with SEM error bars,
# on a logarithmic x axis.
fig, ax = plt.subplots(figsize=(8, 8))
ax.errorbar(x=data["Samples"], y=data["Missed"], yerr=data["Missed_se"])

# Label the figure so it can be read on its own.
ax.set_xlabel("Samples (max_training_images)")
ax.set_ylabel("Missed (mean +/- SEM)")
ax.set_title("Missed vs. number of training samples")

# Kept as the last statement: it returns None, so the cell prints no repr.
ax.set_xscale("log")