In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch

In [None]:
# Location of the raw test-score table
csv_file = 'results/scores_test_multi-mnist.csv'

# Read the scores and discard the `run_name` column in one chained expression
df = pd.read_csv(csv_file).drop(columns=["run_name"])

In [None]:
# Preview the first rows of the loaded score table (after `run_name` was dropped)
df.head()

In [None]:
# Columns whose combined values identify one group of rows
group_keys = ["dataset", "model_type", "c_mid"]

# Score columns that get summarised per group
metrics = ["acc", "acc_foreground", "iou", "iou_foreground", "loss", "loss_foreground"]

# Every metric is reduced to its mean and standard deviation over the rows of a group.
# NOTE(review): the rows inside a group presumably differ by an "extension" column of the
# CSV -- verify against the input file.
summary_stats = ['mean', 'std']
agg_spec = {metric: list(summary_stats) for metric in metrics}

per_group = df.groupby(group_keys)
grouped_df = per_group.agg(agg_spec).reset_index()

In [None]:
# Show the per-group mean/std table (group keys are regular columns after reset_index)
grouped_df

In [None]:
# Re-index the summary by the group keys. Only `index` is passed, so pivot_table uses its
# default aggregation (mean); grouped_df was produced by grouping on these same keys, so
# each key combination should map to a single row -- TODO confirm values are unchanged.
pivoted_df = grouped_df.pivot_table(index=group_keys)

In [None]:
# Show the table indexed by the group keys
pivoted_df

In [None]:
# Keep two decimal places in every numeric column of the summary
pivoted_df = pivoted_df.round(decimals=2)

In [None]:
# Show the summary table after rounding to two decimals
pivoted_df

In [None]:
# Write the rounded summary to disk twice: once as CSV and once as a LaTeX table.
# NOTE(review): the LaTeX target has no file extension -- confirm this is intended.
csv_target = "results/scores_processed_multi-mnist.csv"
latex_target = "results/scores_processed_multi-mnist"

pivoted_df.to_csv(csv_target, index=True)
pivoted_df.to_latex(latex_target, index=True, float_format="%.2f")

# Count model parameters

In [None]:
from unet2 import UNet as UNet
from cornn_model2 import Model
from utils import make_model

In [None]:
def calc_params(model):
    """Count the parameters of ``model``.

    Args:
        model: Object exposing ``parameters()``, whose items provide ``numel()``
            and a ``requires_grad`` flag.

    Returns:
        Tuple ``(trainable_params, total_params)``: the number of elements in
        parameters with ``requires_grad`` set, followed by the number of
        elements in all parameters.
    """
    n_total = 0
    n_trainable = 0
    # Single pass: every parameter counts towards the total, only those that
    # require gradients count towards the trainable tally.
    for param in model.parameters():
        size = param.numel()
        n_total += size
        if param.requires_grad:
            n_trainable += size
    return n_trainable, n_total

In [None]:
# Print (trainable, total) parameter counts for UNet(1, 11) at each c_mid value
for unet_c_mid in (2, 3, 4, 5):
    print(calc_params(UNet(1, 11, c_mid=unet_c_mid)))

In [None]:
# Counting parameters does not need a GPU: use CUDA when it is available and fall back
# to the CPU otherwise, so this cell also runs on machines without a CUDA device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the cornn_model2 configuration and print its (trainable, total) parameter counts
net = make_model(device=device, model_type="cornn_model2", num_classes=11,
               N=128, dt1=0.1, min_iters=0, max_iters=100, c_in=1, c_mid=16, c_out=16,
               rnn_kernel=3, img_size=128, kernel_init='op', cell_type="lstm", num_layers=16,
               readout_type="linear_smaller4")
print(calc_params(net))