<a href="https://colab.research.google.com/github/adrirens/MML_course/blob/main/code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
from google.colab import drive
from os.path import join

ROOT = '/content/drive'     # default for the drive
PROJ = 'MyDrive/mml'       # path to your project on Drive

drive.mount(ROOT)           # we mount the drive at /content/drive

PROJECT_PATH = join(ROOT, PROJ)
!mkdir "{PROJECT_PATH}"    # in case we haven't created it already
%cd "{PROJECT_PATH}"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
mkdir: cannot create directory ‘/content/drive/MyDrive/mml’: File exists
/content/drive/MyDrive/mml


In [10]:
import os

# GitHub account that owns the repository. This is a shared repository; to sync with
# your own fork instead, fork it on GitHub and put your username here.
GIT_USERNAME = "adrirens"
# Personal access token, only needed when working with your own fork. Never paste it
# into the notebook (it would be saved and shared with the .ipynb): set the GIT_TOKEN
# environment variable instead, e.g. `os.environ["GIT_TOKEN"] = getpass.getpass()`
# in a throwaway cell. Empty string when unset.
GIT_TOKEN = os.environ.get("GIT_TOKEN", "")
GIT_REPOSITORY = "MML_course"

In [11]:
# GIT_PATH = "https://"+GIT_TOKEN+"@github.com/"+GIT_USERNAME+"/"+GIT_REPOSITORY+".git"
#GIT_PATH = "https://github.com/"+GIT_USERNAME+"/"+GIT_REPOSITORY+".git"
#!git clone "{GIT_PATH}"
%cd "{GIT_REPOSITORY}"

/content/drive/MyDrive/mml/MML_course


In [12]:
# Sanity check: list the current directory (expected to be the repository checkout
# after the %cd in the previous cell).
!ls

base_model	  image_classification		 logs
code.ipynb	  image_classification_conv	 MMl_course
dataset		  kan_comparison_formatted.xlsx  model
figures		  kan_comparison_table.csv	 README.md
function_fitting  kan_comparison_table.gsheet	 solving_pde


In [13]:
!pip install torch torchvision fvcore iopath yacs tqdm numpy matplotlib



In [16]:
# Feynman-equation datasets to benchmark, in sweep order, each mapped to a width
# triple (stored as strings). NOTE(review): the sweep below only iterates over the
# keys and always passes `--layers_width 5 5 5`; confirm whether these per-dataset
# widths were meant to be used.
models = dict([
    ("I.6.20A", ["2", "2", "1"]),
    ("I.6.20", ["2", "2", "1"]),
    ("I.6.20B", ["5", "5", "1"]),
    ("I.8.4", ["5", "5", "1"]),
    ("I.9.18", ["5", "5", "1"]),
    ("I.12.2", ["3", "3", "1"]),
    ("I.12.4", ["3", "3", "1"]),
    ("I.12.5", ["3", "3", "1"]),
    ("I.10.7", ["2", "2", "1"]),
    ("I.12.11", ["2", "2", "1"]),
    ("I.13.4", ["2", "2", "1"]),
    ("I.13.12", ["4", "4", "1"]),
    ("I.14.3", ["2", "2", "1"]),
    ("I.14.4", ["2", "2", "1"]),
    ("I.15.3X", ["2", "2", "1"]),
    ("I.15.10", ["3", "3", "1"]),
    ("I.18.4", ["4", "3", "1"]),
])

# B-spline grid sizes to sweep (strings; passed through str() on the command line).
grid_sizes = ["5"]

# Architectures compared on every dataset / grid combination.
model_types = ["KAN", "HyperKAN"]



# Image-classification runs on CIFAR10, 50 epochs each.
# NOTE(review): these are long jobs that re-run every time this cell executes, and the
# function-fitting sweep below does not use their results; consider moving them to
# their own cell (or guarding them) so re-running the config above stays cheap.
# MetaKAN8_M via train_meta.py; --lr_h / --lr_e are two separate learning rates
# (presumably hypernetwork vs. embeddings — confirm in train_meta.py).
!python image_classification_conv/train_meta.py \
  --model MetaKAN8_M \
  --n_metanets 1 \
  --optim_set double \
  --lr_h 1e-4 \
  --lr_e 1e-3 \
  --embedding_dim 1 \
  --hidden_dim 32 \
  --dataset CIFAR10 \
  --batch-size 128 \
  --epochs 50
# KAN8 via train.py with grid size 5 and spline order 3, same dataset/batch/epochs.
!python image_classification_conv/train.py \
  --model KAN8 \
  --grid_size 5 \
  --spline_order 3 \
  --dataset CIFAR10 \
  --batch-size 128 \
  --epochs 50 \
  --lr 1e-3

In [None]:
import subprocess
import re
import pandas as pd
import time

def run_and_extract(cmd):
    """Run a training command and parse its metrics from stdout.

    Parameters
    ----------
    cmd : list
        Argument vector passed to ``subprocess.run`` (no shell involved).

    Returns
    -------
    dict
        ``{"num_params": int | None, "final_test_loss": float | None}``.
        ``num_params`` comes from the first ``Number of parameters: N`` line
        (thousands separators allowed); ``final_test_loss`` is the value on the
        last ``test set: Average loss: X`` line. A metric is ``None`` when its
        line does not appear in stdout.

    A non-zero exit status is reported with ``warnings.warn`` (including the tail
    of stderr) instead of being ignored; parsing still proceeds on whatever was
    printed before the failure.
    """
    import warnings

    result = subprocess.run(cmd, capture_output=True, text=True)
    out = result.stdout

    if result.returncode != 0:
        # Surface crashes: otherwise a failed run just shows up as missing metrics.
        stderr_tail = (result.stderr or "").strip().splitlines()[-5:]
        warnings.warn(
            f"Command `{' '.join(map(str, cmd))}` exited with code {result.returncode}:\n"
            + "\n".join(stderr_tail)
        )

    num_params = re.search(r"Number of parameters:\s*([\d,]+)", out)
    # A full float literal: optional sign, digits with optional fraction, optional
    # signed exponent (e.g. 1.5e+03), or nan/inf. The previous character class
    # truncated "1.5e+03" to "1.5e" and could capture junk such as "0.5.".
    number = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?i:nan|inf(?:inity)?)"
    test_losses = re.findall(r"test set: Average loss:\s*(" + number + r")", out)

    return {
        "num_params": int(num_params.group(1).replace(",", "")) if num_params else None,
        "final_test_loss": float(test_losses[-1]) if test_losses else None,
    }

def _build_command(model_type, dataset, grid):
    """Return the argv list that trains `model_type` on `dataset` with B-spline grid `grid`.

    Both model types share the optimizer / learning-rate / width / loss settings;
    HyperKAN additionally sets its embedding and hidden sizes. The argument order is
    the same as the original hand-written command lists. Raises ValueError for an
    unknown model type instead of silently running HyperKAN.
    """
    # NOTE(review): widths are fixed at 5 5 5 for every dataset; the per-dataset
    # triples in `models` are not used here — confirm this is intended.
    shared = [
        "--optimizer", "lbfgs",
        "--lr", "1",
        "--dataset", dataset,
        "--layers_width", "5", "5", "5",
        "--loss", "mse",
    ]
    if model_type == "KAN":
        return ["python", "function_fitting/train.py", "--model", "KAN", *shared,
                "--kan_bspline_grid", str(grid)]
    if model_type == "HyperKAN":
        return ["python", "function_fitting/train_hyper.py", "--model", "HyperKAN", *shared,
                "--embedding_dim", "1",
                "--hidden_dim", "16",
                "--kan_bspline_grid", str(grid)]
    raise ValueError(f"Unknown model type: {model_type!r}")


# Train every (grid, dataset, model type) combination and collect one record per run.
raw_results = []
for grid in grid_sizes:
    for model_name in models:
        print(f"\n🔹 Dataset: {model_name} | Grid: {grid}")
        for model_type in model_types:
            print(f"   → Training {model_type}...")
            res = run_and_extract(_build_command(model_type, model_name, grid))
            res.update({
                "dataset": model_name,
                "type": model_type,
                "grid": grid,
            })
            raw_results.append(res)
            time.sleep(0.5)  # short pause between runs, kept from the original loop

# Gather the sweep results (one row per training run) into a flat frame.
df = pd.DataFrame(raw_results)
# Failed runs report None metrics. Coerce both metric columns to numeric (None -> NaN)
# so the mean aggregation of pivot_table never receives an object-dtype column
# (which happens when every value in a column is None).
for metric_col in ["final_test_loss", "num_params"]:
    df[metric_col] = pd.to_numeric(df[metric_col], errors="coerce")

# Hierarchical pivot: rows = dataset, columns = (metric, grid, type).
# dropna=False keeps datasets / columns whose runs all failed visible as NaN instead
# of silently dropping them from the comparison table.
pivot = df.pivot_table(
    index="dataset",
    columns=["grid", "type"],
    values=["final_test_loss", "num_params"],
    dropna=False,
)

# Reorder the column levels to match the paper's layout.
pivot.columns = pivot.columns.swaplevel(0, 2)  # (type, grid, metric)
pivot = pivot.sort_index(axis=1, level=[1, 0])  # order by grid, then type

# Bold the minimum of each (KAN, HyperKAN) pair of values.
def highlight_min_in_pair(val1, val2):
    """Return one CSS string per value, bolding whichever of the two is smaller.

    Gives ``("font-weight: bold", "")`` when ``val1 < val2``,
    ``("", "font-weight: bold")`` when ``val2 < val1``, and ``("", "")`` on a tie
    or when either value is missing according to ``pd.isna``.
    """
    bold = "font-weight: bold"
    has_missing = pd.isna(val1) or pd.isna(val2)
    first_smaller = (not has_missing) and val1 < val2
    second_smaller = (not has_missing) and (not first_smaller) and val2 < val1
    return (bold if first_smaller else "", bold if second_smaller else "")

# Name the column levels so the Excel header is self-explanatory; values stay numeric.
styled = pivot.rename_axis(columns=["Model", "Grid", "Metric"])

# One CSS declaration per cell of `styled` ("" = no styling): bold the smaller value
# of every (KAN, HyperKAN) pair, for each grid size and metric.
cell_css = pd.DataFrame("", index=styled.index, columns=styled.columns)

for grid in styled.columns.get_level_values("Grid").unique():
    for metric in ["final_test_loss", "num_params"]:
        kan_col = ("KAN", grid, metric)
        meta_col = ("HyperKAN", grid, metric)
        if kan_col not in styled.columns or meta_col not in styled.columns:
            continue
        for i in styled.index:
            style_kan, style_meta = highlight_min_in_pair(
                styled.loc[i, kan_col],
                styled.loc[i, meta_col],
            )
            cell_css.loc[i, kan_col] = style_kan
            cell_css.loc[i, meta_col] = style_meta

# Export to Excel with the multi-level header. The CSS is applied through a Styler so
# that pandas' Excel writer renders `font-weight: bold` as real bold cells; wrapping
# values in Markdown "**...**" would show the asterisks literally in Excel and store
# strings in numeric columns. Styler requires the optional jinja2 dependency.
styled.style.apply(lambda _: cell_css, axis=None).to_excel(
    "kan_comparison_formatted.xlsx", merge_cells=True
)

print("✅ Tableau sauvegardé dans 'kan_comparison_formatted.xlsx' avec format hiérarchique et valeurs en gras.")


üîπ Dataset: I.6.20A | Grid: 5
   ‚Üí Training KAN...
   ‚Üí Training HyperKAN...

üîπ Dataset: I.6.20 | Grid: 5
   ‚Üí Training KAN...
   ‚Üí Training HyperKAN...

üîπ Dataset: I.6.20B | Grid: 5
   ‚Üí Training KAN...
   ‚Üí Training HyperKAN...

üîπ Dataset: I.8.4 | Grid: 5
   ‚Üí Training KAN...
   ‚Üí Training HyperKAN...

üîπ Dataset: I.9.18 | Grid: 5
   ‚Üí Training KAN...
   ‚Üí Training HyperKAN...

üîπ Dataset: I.12.2 | Grid: 5
   ‚Üí Training KAN...
   ‚Üí Training HyperKAN...

üîπ Dataset: I.12.4 | Grid: 5
   ‚Üí Training KAN...
   ‚Üí Training HyperKAN...
