In [None]:
import gc, torch, pickle, torchvision, wandb

# Sanity-check the runtime: CUDA availability/version, PyTorch version, visible GPUs.
for label, value in (("CUDA available?", torch.cuda.is_available()),
                     ("CUDA version:", torch.version.cuda),
                     ("Torch version:", torch.__version__)):
    print(label, value)

for idx in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_name(idx)
    print(f"Device {idx}: {gpu_name}")


In [None]:
wandb.finish()            # finish the active W&B run (if any) before the next run
gc.collect()              # collect garbage first so CUDA tensors it still references are freed ...
torch.cuda.empty_cache()  # ... and their now-unused cached blocks can be released by the allocator

# Run in parallel (one detached tmux session per method)

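**EF21-IGT** polynomial-schedule exponents: $p = 0.71$ (learning rate), $q = 0.57$ ($\eta$).
Note: the `EF21_IGT_NORM` preset below currently uses only $q$ (`q_exps=[0.57]`, `p_exps=[None]`).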

In [None]:
%%bash
# ============================================================================
# 0.  GLOBAL SETTINGS
#     CONDA_ENV : conda environment activated inside every tmux session
#     METHODS   : error-feedback methods to launch, in launch order; each name
#                 must be a key of the `presets` dict in the generated script
# ============================================================================
CONDA_ENV='ef21-hess'

METHODS=(
  EF21_HM_NORM
  EF21_IGT_NORM
  EF21_RHM_NORM
  EF21_MVR_NORM
  EF21
  ECONTROL
  EF21_SGDM
  EF21_SGDM_NORM
)

# Alternative subsets (uncomment one to override the list above):
# METHODS=(EF21 ECONTROL EF21_SGDM EF21_SGDM_NORM)
# METHODS=(EF21_SGDM_NORM)
# METHODS=(EF21_IGT_NORM)
# METHODS=(EF21_RHM_NORM)
# METHODS=(EF21_MVR_NORM)

# One timestamp per launch, shared by every session / script / log name below.
tstamp="$(date '+%Y%m%d_%H%M%S')"

SESSIONS=()   # every session started below, reported at the end

for EF in "${METHODS[@]}"; do
  SESSION="train_${EF}_${tstamp}"           # unique tmux session
  SCRIPT="run_training_${EF}_${tstamp}.py"  # per-method Python file

  #############################################################################
  # 1.  WRITE A ONE-OFF PYTHON SCRIPT (identical for all methods; the method
  #     is selected at run time through the EF_METHOD environment variable)
  #############################################################################
  cat > "$SCRIPT" <<'PY'
import os, gc, itertools, torch, wandb
import numpy as np
from torch.nn import CrossEntropyLoss
from quant import top_k_wrap
from utils import create_exp, myrepr
from train import tune_step_size
from models import resnet18
from prep_data import create_loaders
# ---------------------------------------------------------------------------

# ------------------------- USER CONFIG -------------------------------------
n_workers   = 10
bs          = 64
h           = 0.1       # parameter passed to top_k_wrap(h=...)
project_name= "EF21_SOM"

# The method to run is injected by the launching shell loop via EF_METHOD.
gen_dict = {"model_architecture": resnet18, "model_architecture_str": "resnet18", "dataset": "cifar10",
            "epochs": 5, "ef_methods": [os.environ["EF_METHOD"]]}

# Per-method sweep grid (lists are swept as a Cartesian product) and GPU.
# A p/q exponent of None disables the corresponding "poly" schedule.
presets = {
    "EF21":           {"lrs":[1.0], "etas":[None], "p_exps":[None], "q_exps":[None], "cuda_device":"cuda:0"},
    "ECONTROL":       {"lrs":[1.0], "etas":[0.1],  "p_exps":[None], "q_exps":[None], "cuda_device":"cuda:1"},
    "EF21_SGDM":      {"lrs":[0.1], "etas":[0.1],  "p_exps":[None], "q_exps":[None], "cuda_device":"cuda:2"},

    "EF21_SGDM_NORM": {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.5],  "cuda_device":"cuda:3"},
    "EF21_HM_NORM":   {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.67], "cuda_device":"cuda:0"},
    "EF21_RHM_NORM":  {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.67], "cuda_device":"cuda:1"},
    "EF21_MVR_NORM":  {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.67], "cuda_device":"cuda:2"},
    "EF21_IGT_NORM":  {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.57], "cuda_device":"cuda:3"},
}
# ---------------------------------------------------------------------------


def main():
    """Sweep the preset grid of every selected EF method, logging each run to W&B."""
    model_architecture = gen_dict["model_architecture"]
    model_architecture_str = gen_dict["model_architecture_str"]
    dataset = gen_dict["dataset"]
    epochs = gen_dict["epochs"]

    for ef_method in gen_dict["ef_methods"]:
        exp_dict = gen_dict | presets[ef_method]   # method preset on top of shared settings
        cuda_device = exp_dict["cuda_device"]
        print(f"Running experiment with {ef_method}")

        grid = itertools.product(exp_dict["lrs"], exp_dict["etas"], exp_dict["p_exps"], exp_dict["q_exps"])
        for lr, eta, p_exp, q_exp in grid:
            lr_schedule = "poly" if p_exp is not None else None
            eta_schedule = "poly" if q_exp is not None else None
            exp_name = (f"{model_architecture.__name__}_{dataset}_{ef_method}"
                        f"_topk-{myrepr(h)}_lr-{myrepr(lr)}_eta-{myrepr(eta)}"
                        f"_p-{myrepr(p_exp)}_q-{myrepr(q_exp)}")

            exp = create_exp(name=exp_name, dataset=dataset, net=model_architecture, n_workers=n_workers,
                             epochs=epochs, seed=42, batch_size=bs, lrs=[lr], etas=[eta],
                             lr_schedule=lr_schedule, eta_schedule=eta_schedule,
                             compression={'wrapper': False, 'compression': top_k_wrap(h=h)},
                             error_feedback=ef_method, criterion=CrossEntropyLoss(),
                             master_compression=None, momentum=0, weight_decay=0, p=p_exp, q=q_exp)

            wandb.init(project=project_name, name=f"{exp_name}_{cuda_device}",
                       config={**exp, "lr": lr, "eta": eta, "cuda_device": cuda_device},
                       tags=[exp['error_feedback'], model_architecture_str, dataset])

            best_lr, best_acc_lr = tune_step_size(exp, suffix=exp_name, schedule=lr_schedule, device=cuda_device)
            print(f"Finished {exp_name}: best_lr={best_lr}, best_acc_lr={best_acc_lr}")
            wandb.finish()

            # Collect garbage first so CUDA tensors it still references are freed;
            # only then can empty_cache() release their cached blocks.
            gc.collect()
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
PY
  #############################################################################
  # 2.  LAUNCH THE SCRIPT IN ITS OWN DETACHED TMUX SESSION
  #############################################################################
  # -c "$PWD": start the session where $SCRIPT was just written, because python
  # is invoked with a relative path. The conda.sh path is quoted so a conda
  # install under a directory containing spaces still works.
  tmux new-session -d -s "$SESSION" -c "$PWD" \
       "source \"$(conda info --base)/etc/profile.d/conda.sh\" && \
        conda activate $CONDA_ENV && \
        EF_METHOD=$EF python $SCRIPT 2>&1 | tee ${SESSION}.log"

  SESSIONS+=("$SESSION")
  echo "Started tmux session: $SESSION   (method: $EF)"
done

# Print an attach command for every session, not only the last one.
echo
for started in "${SESSIONS[@]}"; do
  echo " tmux attach -t $started"
done


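# Run sequentially (one tmux session at a time; each method starts after the previous one finishes)

The cell below only writes `run_all_experiments.sh`; launch it separately, e.g. `bash run_all_experiments.sh`.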

In [None]:
%%writefile run_all_experiments.sh
#!/usr/bin/env bash
# This script relies on bash arrays (METHODS=(...), "${METHODS[@]}"), so pin the
# interpreter: without a shebang, executing it directly may fall back to /bin/sh,
# which rejects the array syntax.
###############################################################################
# 0.  GLOBAL SETTINGS — edit CONDA_ENV / METHODS only if you need to
#     CONDA_ENV : conda environment activated inside each tmux session
#     METHODS   : methods run one after another, in this order; each name must
#                 be a key of the `presets` dict in the generated script
###############################################################################
CONDA_ENV="ef21-hess"                      # ← your conda env

METHODS=("EF21_HM_NORM" "EF21_RHM_NORM" "EF21_MVR_NORM" "EF21_IGT_NORM" "EF21" "ECONTROL" "EF21_SGDM" "EF21_SGDM_NORM")   # run order

# Alternative subsets (uncomment one to override the list above):
#METHODS=("EF21" "ECONTROL" "EF21_SGDM" "EF21_SGDM_NORM")
#METHODS=("EF21_SGDM_NORM")
#METHODS=("EF21_IGT_NORM")
#METHODS=("EF21_RHM_NORM")
#METHODS=("EF21_MVR_NORM")

###############################################################################

tstamp=$(date +%Y%m%d_%H%M%S)   # shared by every session / script / log name of this launch

poll_secs=300   # how often (seconds) to check whether the current session has finished

for EF in "${METHODS[@]}"; do
  SESSION="train_${EF}_${tstamp}"           # unique tmux session
  SCRIPT="run_training_${EF}_${tstamp}.py"  # per-method Python file

  #############################################################################
  # 1.  WRITE A ONE-OFF PYTHON SCRIPT (identical for all methods; the method
  #     is selected at run time through the EF_METHOD environment variable)
  #############################################################################
  cat > "$SCRIPT" <<'PY'
import os, gc, itertools, torch, wandb
import numpy as np
from torch.nn import CrossEntropyLoss
from quant import top_k_wrap
from utils import create_exp, myrepr
from train import tune_step_size
from models import resnet18
from prep_data import create_loaders
# ---------------------------------------------------------------------------

# ------------------------- USER CONFIG -------------------------------------
n_workers   = 10
bs          = 64
h           = 0.1       # parameter passed to top_k_wrap(h=...)
project_name= "EF21_SOM"

# The method to run is injected by the launching shell loop via EF_METHOD.
gen_dict = {"model_architecture": resnet18, "model_architecture_str": "resnet18", "dataset": "cifar10",
            "epochs": 90, "ef_methods": [os.environ["EF_METHOD"]]}

# Per-method sweep grid (lists are swept as a Cartesian product) and GPU.
# A p/q exponent of None disables the corresponding "poly" schedule.
presets = {
    "EF21":           {"lrs":[1.0], "etas":[None], "p_exps":[None], "q_exps":[None], "cuda_device":"cuda:1"},
    "ECONTROL":       {"lrs":[1.0], "etas":[0.1],  "p_exps":[None], "q_exps":[None], "cuda_device":"cuda:1"},
    "EF21_SGDM":      {"lrs":[0.1], "etas":[0.1],  "p_exps":[None], "q_exps":[None], "cuda_device":"cuda:1"},

    "EF21_SGDM_NORM": {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.5],  "cuda_device":"cuda:1"},
    "EF21_HM_NORM":   {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.67], "cuda_device":"cuda:1"},
    "EF21_RHM_NORM":  {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.67], "cuda_device":"cuda:1"},
    "EF21_MVR_NORM":  {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.67], "cuda_device":"cuda:1"},
    "EF21_IGT_NORM":  {"lrs":[0.1], "etas":[1.0],  "p_exps":[None], "q_exps":[0.57], "cuda_device":"cuda:1"},
}
# ---------------------------------------------------------------------------


def main():
    """Sweep the preset grid of every selected EF method, logging each run to W&B."""
    model_architecture = gen_dict["model_architecture"]
    model_architecture_str = gen_dict["model_architecture_str"]
    dataset = gen_dict["dataset"]
    epochs = gen_dict["epochs"]

    for ef_method in gen_dict["ef_methods"]:
        exp_dict = gen_dict | presets[ef_method]   # method preset on top of shared settings
        cuda_device = exp_dict["cuda_device"]
        print(f"Running experiment with {ef_method}")

        grid = itertools.product(exp_dict["lrs"], exp_dict["etas"], exp_dict["p_exps"], exp_dict["q_exps"])
        for lr, eta, p_exp, q_exp in grid:
            lr_schedule = "poly" if p_exp is not None else None
            eta_schedule = "poly" if q_exp is not None else None
            exp_name = (f"{model_architecture.__name__}_{dataset}_{ef_method}"
                        f"_topk-{myrepr(h)}_lr-{myrepr(lr)}_eta-{myrepr(eta)}"
                        f"_p-{myrepr(p_exp)}_q-{myrepr(q_exp)}")

            exp = create_exp(name=exp_name, dataset=dataset, net=model_architecture, n_workers=n_workers,
                             epochs=epochs, seed=42, batch_size=bs, lrs=[lr], etas=[eta],
                             lr_schedule=lr_schedule, eta_schedule=eta_schedule,
                             compression={'wrapper': False, 'compression': top_k_wrap(h=h)},
                             error_feedback=ef_method, criterion=CrossEntropyLoss(),
                             master_compression=None, momentum=0, weight_decay=0, p=p_exp, q=q_exp)

            wandb.init(project=project_name, name=f"{exp_name}_{cuda_device}",
                       config={**exp, "lr": lr, "eta": eta, "cuda_device": cuda_device},
                       tags=[exp['error_feedback'], model_architecture_str, dataset])

            best_lr, best_acc_lr = tune_step_size(exp, suffix=exp_name, schedule=lr_schedule, device=cuda_device)
            print(f"Finished {exp_name}: best_lr={best_lr}, best_acc_lr={best_acc_lr}")
            wandb.finish()

            # Collect garbage first so CUDA tensors it still references are freed;
            # only then can empty_cache() release their cached blocks.
            gc.collect()
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
PY
  #############################################################################
  # 2.  LAUNCH THE SCRIPT IN ITS OWN TMUX SESSION AND WAIT UNTIL IT FINISHES
  #############################################################################
  # -c "$PWD": start the session where $SCRIPT was just written, because python
  # is invoked with a relative path.
  tmux new-session -d -s "$SESSION" -c "$PWD" \
      "source \"$(conda info --base)/etc/profile.d/conda.sh\" && \
        conda activate $CONDA_ENV && \
        EF_METHOD=$EF python $SCRIPT 2>&1 | tee ${SESSION}.log"

  echo "Started tmux session: $SESSION   (method: $EF)  — waiting for it to finish..."
  # Block until the session is gone (i.e., Python finished and tmux exited), so
  # consecutive methods never run at the same time on cuda:1.
  while tmux has-session -t "$SESSION" 2>/dev/null; do
    sleep "$poll_secs"
  done
  echo "Finished: $SESSION"
done
