<a href="https://colab.research.google.com/github/adrirens/MML_course/blob/main/code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
from google.colab import drive
from os.path import join

ROOT = '/content/drive'     # default for the drive
PROJ = 'MyDrive/mml'       # path to your project on Drive

drive.mount(ROOT)           # we mount the drive at /content/drive

PROJECT_PATH = join(ROOT, PROJ)
!mkdir "{PROJECT_PATH}"    # in case we haven't created it already
%cd "{PROJECT_PATH}"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
mkdir: cannot create directory ‘/content/drive/MyDrive/mml’: File exists
/content/drive/MyDrive/mml


In [19]:
GIT_USERNAME = "adrirens"  # This is a shared repository. If you want to synchronize with your own fork, well, you need to fork the github repository and replace the given name by your username.
GIT_TOKEN = "XXX"  # This token is used only if you work with your own fork. You have to generate it and put it here. Make sure to keep it confidential: it is a very sensitive information!
GIT_REPOSITORY = "MML_course"

In [20]:
# GIT_PATH = "https://"+GIT_TOKEN+"@github.com/"+GIT_USERNAME+"/"+GIT_REPOSITORY+".git"
#GIT_PATH = "https://github.com/"+GIT_USERNAME+"/"+GIT_REPOSITORY+".git"
#!git clone "{GIT_PATH}"
%cd "{GIT_REPOSITORY}"

/content/drive/MyDrive/mml/MML_course


In [21]:
!ls

base_model  function_fitting	       kan_comparison_table.gsheet  model
code.ipynb  image_classification       kan_results_summary.csv	    README.md
dataset     image_classification_conv  logs			    solving_pde
figures     kan_comparison_table.csv   MMl_course


In [22]:
!pip install torch torchvision fvcore iopath yacs tqdm numpy matplotlib



In [23]:
models = [
    "I.6.20a",
    "I.6.20",
    "I.6.20b",
    "I.8.4",
    "I.9.18",
    "I.12.2",
    "I.12.4",
    "I.12.5",
    "I.10.7",
    "I.12.11",
    "I.13.4",
    "I.13.12",
    "I.14.3",
    "I.14.4",
    "I.15.3x",
    "I.15.10",
    "I.18.4"
]

grid_sizes = [
    "5",
    "20"
]


model_types = ["KAN", "HyperKAN"]



!python image_classification_conv/train_meta.py \
  --model MetaKAN8_M \
  --n_metanets 1 \
  --optim_set double \
  --lr_h 1e-4 \
  --lr_e 1e-3 \
  --embedding_dim 1 \
  --hidden_dim 32 \
  --dataset CIFAR10 \
  --batch-size 128 \
  --epochs 50
!python image_classification_conv/train.py \
  --model KAN8 \
  --grid_size 5 \
  --spline_order 3 \
  --dataset CIFAR10 \
  --batch-size 128 \
  --epochs 50 \
  --lr 1e-3

In [26]:
import subprocess
import re
import pandas as pd
import time

def run_and_extract(cmd):
    """Exécute une commande shell et retourne les métriques extraites"""
    result = subprocess.run(cmd, capture_output=True, text=True)
    out = result.stdout

    # Extraction des infos
    num_params = re.search(r"Number of parameters:\s*([\d,]+)", out)
    test_losses = re.findall(r"test set: Average loss:\s*([\d\.Ee-]+)", out)

    return {
        "num_params": int(num_params.group(1).replace(",", "")) if num_params else None,
        "final_test_loss": float(test_losses[-1]) if test_losses else None,
    }

raw_results = []
for grid in grid_sizes:
  for model_name in models:
    print(f"\n🔹 Dataset: {model_name} | Grid: {grid}")
    for model_type in model_types:
        print(f"   → Training {model_type}...")

        if model_type == "KAN":
            cmd = [
                "python", "function_fitting/train.py",
                "--model", "KAN",
                "--optimizer", "lbfgs",
                "--lr", "1",
                "--dataset", model_name,
                "--layers_width", "5", "5", "5",
                "--loss", "mse",
                "--kan_bspline_grid", str(grid)
            ]
        else:  # HyperKAN
            cmd = [
                "python", "function_fitting/train_hyper.py",
                "--model", "HyperKAN",
                "--optimizer", "lbfgs",
                "--lr", "1",
                "--dataset", model_name,
                "--layers_width", "5", "5", "5",
                "--loss", "mse",
                "--embedding_dim", "1",
                "--hidden_dim", "16",
                "--kan_bspline_grid", str(grid)
            ]

        res = run_and_extract(cmd)
        res.update({
            "dataset": model_name,
            "type": model_type,
            "grid": grid,
        })
        raw_results.append(res)
        time.sleep(0.5)

# 🧾 Conversion en DataFrame
df = pd.DataFrame(raw_results)

# 🪄 Transformation au format du papier
# (colonnes = G=5 KAN, G=5 HyperKAN, G=20 KAN, G=20 HyperKAN)
pivot = df.pivot_table(
    index="dataset",
    columns=["grid", "type"],
    values=["final_test_loss", "num_params"]
)

# Nettoyage du MultiIndex pour affichage plus clair
pivot.columns = [
    f"G={g} {t} {m}" for m, g, t in pivot.columns.to_flat_index()
]

# ✅ Tableau final
pivot = pivot.reset_index()
display(pivot)

# 💾 Sauvegarde CSV
pivot.to_csv("kan_comparison_table.csv", index=False)
print("\n✅ Tableau sauvegardé dans kan_comparison_table.csv")


🔹 Dataset: I.6.20a | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.6.20 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.6.20b | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.8.4 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.9.18 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.12.2 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.12.4 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.12.5 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.10.7 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.12.11 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.13.4 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.13.12 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Dataset: I.14.3 | Grid: 5
   → Training KAN...
   → Training HyperKAN...

🔹 Datase

Unnamed: 0,dataset,G=20 HyperKAN final_test_loss,G=20 KAN final_test_loss,G=5 HyperKAN final_test_loss,G=5 KAN final_test_loss,G=20 HyperKAN num_params,G=20 KAN num_params,G=5 HyperKAN num_params,G=5 KAN num_params
0,I.10.7,0.051128,0.005265,0.005615,0.005101,1272.0,3656.0,867.0,1556.0
1,I.12.11,0.304376,0.164421,0.253001,0.094853,1396.0,4176.0,961.0,1776.0
2,I.12.2,0.001809,0.003858,0.001811,0.002086,1334.0,3916.0,914.0,1666.0
3,I.12.4,0.011045,0.001281,0.002996,0.000635,1272.0,3656.0,867.0,1556.0
4,I.12.5,0.048057,0.00333,0.002972,0.002007,1210.0,3396.0,820.0,1446.0
5,I.13.12,0.013044,0.006375,0.004512,0.003734,1396.0,4176.0,961.0,1776.0
6,I.13.4,0.014097,0.004309,0.015846,0.004541,1334.0,3916.0,914.0,1666.0
7,I.14.3,0.143989,0.002843,0.012828,0.00266,1272.0,3656.0,867.0,1556.0
8,I.14.4,0.002106,0.000554,0.000818,0.001148,1210.0,3396.0,820.0,1446.0
9,I.15.10,0.195244,0.008609,0.010609,0.005801,1272.0,3656.0,867.0,1556.0



✅ Tableau sauvegardé dans kan_comparison_table.csv
