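# **SETUP**

Imports, a CUDA/GPU availability check, and the experiment configuration: paths, reloading of the results checkpoint, and the hyper-parameter grid.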

In [1]:
import torch
import os
import sys

import pickle
import itertools
from tqdm import tqdm

from contextlib import redirect_stdout, redirect_stderr

current_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(current_dir, "../src/")))
from engine import run_vit_experiment
from config import MODEL_CONFIG_MAP, OPTIMIZER_CONFIG_MAP
from utils import _DevNull

# Record the CUDA environment in the cell output so the run shows which hardware was used.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
n_gpus = torch.cuda.device_count()
print("Number of GPUs:", n_gpus)
for gpu_idx in range(n_gpus):
    gpu_name = torch.cuda.get_device_name(gpu_idx)
    print(f"GPU {gpu_idx}: {gpu_name}")

  from .autonotebook import tqdm as notebook_tqdm


CUDA available: True
Number of GPUs: 1
GPU 0: NVIDIA A100-PCIE-40GB MIG 7g.40gb


In [2]:
DEVICE = "cuda:0"
# The defaults below are absolute paths on the original author's machine.
# Set BRATS_RESULTS_ROOT / BRATS_DATASET_DIR to run elsewhere without editing the notebook.
ROOT = os.environ.get(
    "BRATS_RESULTS_ROOT",
    "/home/hoai-linh.dao/Works/BraTS/results/extended-dataset-Figshare",
)
CHECKPOINT_FILE = os.path.join(ROOT, "checkpoint.pkl")  # finished results; lets an interrupted sweep resume
SUMMARY_FILE = os.path.join(ROOT, "summary.txt")
DTS_DIR = os.environ.get(
    "BRATS_DATASET_DIR",
    "/home/hoai-linh.dao/Works/BraTS/dts/Figshare",
)
os.makedirs(ROOT, exist_ok=True)

# Reload results of previously finished runs (keyed by configuration) so they can be skipped.
# NOTE: pickle.load can execute arbitrary code - only load a checkpoint written by this notebook.
if os.path.exists(CHECKPOINT_FILE):
    with open(CHECKPOINT_FILE, "rb") as f:
        all_results = pickle.load(f)
else:
    all_results = {}

# Experiment grid: the sweep runs every combination of these values.
model_patch_sizes = ["B16"]
img_nets = ["1K", "21K"]
case_nums = ["4"]
optimizer_names = ["adam"]

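# **BENCH TEST**

Runs every configuration in the grid whose result is not yet in the checkpoint. The checkpoint is saved after each successful run, so re-running this cell resumes an interrupted sweep.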

In [3]:
combinations = list(itertools.product(model_patch_sizes, img_nets, case_nums, optimizer_names))


def _config_key(patch_size, img_net, case_num, optimizer):
    """Return the checkpoint key identifying one experiment configuration.

    The same function is used both to skip finished runs and to store new
    results, so the two can never drift apart and silently break resuming.
    """
    return f"{patch_size}_{img_net}_case{case_num}_{optimizer}"


def _save_checkpoint(results, path):
    """Pickle ``results`` to ``path`` atomically.

    The pickle is first written to ``<path>.tmp`` and then moved over ``path``
    with ``os.replace``. Opening ``path`` directly in ``"wb"`` mode truncates it
    at once. An interruption or error during the dump would then leave a
    corrupt checkpoint and lose every previously finished run.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as fh:
        pickle.dump(results, fh)
    os.replace(tmp_path, path)


# Resume support: skip every configuration whose result is already checkpointed.
remaining_combinations = [
    combo for combo in combinations if _config_key(*combo) not in all_results
]
failed_keys = []

pbar = tqdm(
    remaining_combinations,
    total=len(remaining_combinations),
    desc="Running experiments"
)

for patch_size, img_net, case_num, optimizer in pbar:
    config_key = _config_key(patch_size, img_net, case_num, optimizer)
    experiment_name = f"{patch_size}_{img_net}_case{case_num}"
    # Only shown in the progress bar; this cell never writes to it.
    # NOTE(review): presumably run_vit_experiment writes its log here (it gets experiment_root) - confirm.
    log_path = os.path.join(
        ROOT, "logs", optimizer, f"{experiment_name}.txt"
    )

    pbar.set_description(f"{config_key}")
    pbar.set_postfix(log=log_path)

    try:
        # Discard the experiment's own stdout/stderr so only the progress bar is displayed.
        with redirect_stdout(_DevNull()), redirect_stderr(_DevNull()):
            result = run_vit_experiment(
                model_patch_size     = patch_size,
                img_net              = img_net,
                case_num             = case_num,
                optimizer_name       = optimizer,
                seed                 = 42,
                dataset_dir          = DTS_DIR,
                num_classes          = 3,
                device               = DEVICE,
                experiment_root      = ROOT,
                model_config_map     = MODEL_CONFIG_MAP,
                optimizer_config_map = OPTIMIZER_CONFIG_MAP
            )

        all_results[config_key] = result
        # Checkpoint after every successful run so an interrupted sweep can resume.
        _save_checkpoint(all_results, CHECKPOINT_FILE)

    except Exception as e:
        # Keep the sweep going, but make the failure diagnosable: output is redirected
        # above, so this line is the only trace. Include the exception type because
        # many errors (KeyError, bare AssertionError) have an uninformative message.
        failed_keys.append(config_key)
        tqdm.write(f"[FAIL] {config_key}: {type(e).__name__}: {e}")

# Rewrite the human-readable summary from everything checkpointed so far.
# NOTE(review): assumes each result is a dict-like object (has .get) - confirm
# against run_vit_experiment's return value.
with open(SUMMARY_FILE, "w", encoding="utf-8") as f:
    for config, result in all_results.items():
        weight_path = result.get('weights_path', 'No information for weight')
        f.write(f"{config}: Weights saved at {weight_path}\n")

if failed_keys:
    tqdm.write(f"{len(failed_keys)} configuration(s) failed: {', '.join(failed_keys)}")


B16_21K_case4_adam: 100%|██████████| 2/2 [01:01<00:00, 30.55s/it, log=/home/hoai-linh.dao/Works/BraTS/results/extended-dataset-Figshare/logs/adam/B16_21K_case4.txt]
