In [40]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

!pip install --quiet torch torchvision webdataset tqdm pillow

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [41]:
from pathlib import Path
import yaml
import sys
import time
import importlib
import logging
from importlib.util import spec_from_file_location, module_from_spec
from typing import Any, Dict
from tqdm import tqdm

config_path = Path('/content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/config/training.yaml')

# Load the training config (safe_load: plain YAML data only, no arbitrary objects).
with config_path.open('r') as f:
    cfg = yaml.safe_load(f)

# Use the Colab project root when it exists, otherwise the local one.
colab_root = Path(cfg['env_paths']['colab'])
local_root = Path(cfg['env_paths']['local'])
PROJECT_ROOT = colab_root if colab_root.exists() else local_root
if not PROJECT_ROOT.exists():
    raise FileNotFoundError(f"Project root non trovato: {PROJECT_ROOT}")

# Put PROJECT_ROOT/src first and PROJECT_ROOT second at the front of sys.path.
# Existing copies are removed before inserting, so re-running this cell moves
# the entries to the front instead of piling up duplicates.
for extra_path in (str(PROJECT_ROOT), str(PROJECT_ROOT / 'src')):
    sys.path[:] = [p for p in sys.path if p != extra_path]
    sys.path.insert(0, extra_path)

# Load src/utils/training_utils.py from its file path under the module name
# 'utils.training_utils'.
utils_dir = PROJECT_ROOT / 'src' / 'utils'
src_file = utils_dir / 'training_utils.py'

spec = spec_from_file_location('utils.training_utils', str(src_file))
if spec is None or spec.loader is None:
    raise ImportError(f"Impossibile caricare il modulo da {src_file}")
training_utils = module_from_spec(spec)

# Register in sys.modules BEFORE executing the module body (the order used by
# the importlib "import a source file directly" recipe). Any import of
# 'utils.training_utils' that happens while the body runs, directly or through
# modules it imports, then resolves to this object instead of loading a second
# copy with its own TRAINER_REGISTRY.
previous_training_utils = sys.modules.get('utils.training_utils')
sys.modules['utils.training_utils'] = training_utils
try:
    spec.loader.exec_module(training_utils)
except BaseException:
    # Roll back: never leave a half-initialised module registered.
    if previous_training_utils is None:
        sys.modules.pop('utils.training_utils', None)
    else:
        sys.modules['utils.training_utils'] = previous_training_utils
    raise

from utils.training_utils import TRAINER_REGISTRY

print(f"🔥 PROJECT_ROOT: {PROJECT_ROOT}")


🔥 PROJECT_ROOT: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project


In [42]:
# Resolve the dataset paths in cfg['data'] against PROJECT_ROOT, in place.
# Re-running is safe: joining PROJECT_ROOT with an already-absolute path
# returns that absolute path unchanged (pathlib semantics).
for split in ['train', 'val', 'test']:
    rel = cfg['data'].get(split)
    if rel:
        cfg['data'][split] = str(PROJECT_ROOT / rel)

print("📂 Dataset paths:")
for split in ['train', 'val', 'test']:
    # .get(): a split absent from the config is skipped by the loop above,
    # so report it here instead of raising KeyError.
    print(f"  • {split}: {cfg['data'].get(split, 'n/a')}")

# No sys.path changes here: the setup cell that loads the config already
# inserts PROJECT_ROOT and PROJECT_ROOT/src; repeating it only added duplicates.

📂 Dataset paths:
  • train: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset_2500/train/patches-0000.tar
  • val: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset_2500/val/patches-0000.tar
  • test: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset_2500/test/patches-0000.tar


In [43]:
# Configure root logging once for the whole notebook. Without force=True
# (Python 3.8+), basicConfig() silently does nothing when the root logger
# already has handlers. In the saved run, neither this cell's INFO message nor
# the launcher's INFO line was ever shown.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)
logger = logging.getLogger("LAUNCHER")
logger.info("✅ Logger inizializzato a livello INFO")

In [44]:
# Import each trainer module, or reload it if an earlier run already imported
# it, so source edits are picked up without restarting the kernel.
# NOTE(review): presumably each module adds itself to TRAINER_REGISTRY at
# import time (launch_training looks trainers up there) — confirm.
trainer_modules = [
    f"trainers.{trainer}"
    for trainer in ("simclr", "moco_v2", "rotation", "jigsaw", "supervised", "transfer")
]
for module_name in trainer_modules:
    if module_name not in sys.modules:
        importlib.import_module(module_name)
        continue
    importlib.reload(sys.modules[module_name])


In [45]:

def _select_tasks(cfg: dict) -> list:
    """Return the ``(name, model_cfg)`` pairs to train, selected by ``cfg['run_model']``.

    ``run_model`` defaults to ``'all'``, meaning every entry of ``cfg['models']``
    in config order. Otherwise it names a single model. The name is
    lower-cased, so it is matched against ``cfg['models']`` case-insensitively.

    Raises:
        KeyError: if the requested model is not in ``cfg['models']``.
    """
    models_cfg = cfg.get('models', {}) or {}
    run_model = str(cfg.get('run_model') or 'all').lower()
    if run_model == 'all':
        return list(models_cfg.items())
    # run_model was lower-cased, so compare against lower-cased keys as well;
    # otherwise a mixed-case key could never be selected.
    keys_by_lower = {str(key).lower(): key for key in models_cfg}
    if run_model not in keys_by_lower:
        raise KeyError(f"Modello '{run_model}' non presente in cfg['models']")
    key = keys_by_lower[run_model]
    return [(key, models_cfg[key])]


def _unpack_step_result(result) -> tuple:
    """Normalize a ``trainer.train_step`` return value to ``(loss, correct, bs)``.

    Two shapes are accepted:
      * a 4-tuple ``(_, loss, correct, bs)`` (trainers that report accuracy);
      * a 2-tuple ``(loss, bs)``, for which ``correct`` is 0.

    ``loss`` and ``correct`` are converted with ``float()``, so a 0-dim tensor
    (if a trainer returns one — not visible here) is not accumulated together
    with its autograd graph over the whole epoch.

    Raises:
        ValueError: for any other number of returned values.
    """
    if len(result) == 4:
        _, loss, correct, bs = result
    elif len(result) == 2:
        loss, bs = result
        correct = 0
    else:
        raise ValueError(
            f"train_step returned {len(result)} values; expected 2 (loss, bs) "
            f"or 4 (_, loss, correct, bs)"
        )
    return float(loss), float(correct), bs


def _train_one_epoch(trainer, epoch: int, epoch_start: float) -> tuple:
    """Run one pass over ``trainer.train_loader`` and show progress with tqdm.

    Each batch goes through ``trainer.train_step``. The bar's postfix shows the
    running mean loss, the running accuracy, % complete and ETA. The last two
    are only meaningful when ``trainer.batches_train`` provides the expected
    number of batches.

    Returns:
        ``(running_loss, total_samples)``: the batch-size-weighted loss sum and
        the number of samples seen. ``running_loss / total_samples`` is the
        mean per-sample training loss.
    """
    total_batches = getattr(trainer, 'batches_train', None)
    running_loss = 0.0
    running_correct = 0.0
    total_samples = 0
    # The context manager closes the bar even if the loop is interrupted
    # (e.g. KeyboardInterrupt); a trailing bar.close() would be skipped then.
    with tqdm(
        trainer.train_loader,
        total=total_batches,
        unit='batch',
        desc=f"Train E{epoch}",
        leave=False,
        dynamic_ncols=True,
    ) as bar:
        for step, batch in enumerate(bar, start=1):
            loss, correct, bs = _unpack_step_result(trainer.train_step(batch))
            running_loss += loss * bs
            running_correct += correct
            total_samples += bs
            avg_loss = running_loss / total_samples if total_samples else 0.0
            avg_acc = running_correct / total_samples if total_samples else 0.0
            # Use the local step counter, not bar.n: tqdm only updates bar.n
            # when it refreshes the display, so inside the loop it lags behind.
            pct = step / total_batches * 100 if total_batches else 0
            elapsed = time.time() - epoch_start
            eta = max(0.0, elapsed / step * (total_batches - step)) if total_batches else 0
            bar.set_postfix(
                loss=f"{avg_loss:.4f}",
                acc=f"{avg_acc:.3f}",
                pct=f"{pct:.1f}%",
                eta=f"{eta:.1f}s",
            )
    return running_loss, total_samples


def launch_training(cfg: dict) -> None:
    """Train every model selected in ``cfg`` and report progress.

    For each selected model this prints the device, the model config, the
    number of epochs and the batch size. During training, a tqdm bar shows the
    running loss, the running accuracy (0 for trainers that do not report it),
    % complete and ETA. After each epoch the trainer is validated, when it has
    ``validate_epoch``, and ``trainer.post_epoch`` is called.

    Args:
        cfg: full training config. Reads ``cfg['models']`` (model name ->
            model config with a ``'training'`` section), the optional
            ``cfg['run_model']`` (``'all'`` by default) and ``cfg['data']``,
            which is passed unchanged to each trainer's constructor.

    Raises:
        KeyError: if ``run_model`` is not in ``cfg['models']`` or a selected
            model has no entry in ``TRAINER_REGISTRY``.
    """
    tasks = _select_tasks(cfg)

    # Root logging is configured once at notebook level; the basicConfig()
    # call that used to sit here did nothing once the root logger had handlers.
    logger = logging.getLogger("LAUNCHER")
    logger.info(f"Modelli da eseguire: {[name for name, _ in tasks]}")

    for name, m_cfg in tasks:
        if name not in TRAINER_REGISTRY:
            raise KeyError(f"Nessun trainer per '{name}'")
        TrainerCls = TRAINER_REGISTRY[name]
        trainer = TrainerCls(m_cfg, cfg['data'])

        # Initial run info
        training_cfg = m_cfg.get('training', {}) or {}
        device = getattr(trainer, 'device', 'n/a')
        epochs = int(training_cfg.get('epochs', 0))
        batch_size = int(training_cfg.get('batch_size', 0))
        print(f"Device: {device} 🚀  Inizio training per modello '{name}'")
        print(f"→ Model config: {m_cfg}")
        print(f"Epoche: {epochs} | Batch size: {batch_size}")

        for epoch in range(1, epochs + 1):
            epoch_start = time.time()
            print(f"\n--- Epoch {epoch}/{epochs} ---")
            running_loss, total_samples = _train_one_epoch(trainer, epoch, epoch_start)

            # Trainers exposing validate_epoch (the supervised one, per the
            # original note) are scored on validation accuracy; the others get
            # the mean training loss.
            # NOTE(review): post_epoch therefore receives an accuracy
            # (higher is better) for some trainers and a loss (lower is better)
            # for others — confirm each trainer interprets it accordingly.
            if hasattr(trainer, 'validate_epoch'):
                val_loss, val_acc = trainer.validate_epoch()
                print(f"Val -> Loss: {val_loss:.4f} | Acc: {val_acc:.3f}")
                trainer.post_epoch(epoch, val_acc)
            else:
                epoch_loss = running_loss / (total_samples or 1)
                trainer.post_epoch(epoch, epoch_loss)

            epoch_duration = time.time() - epoch_start
            print(f"Epoch completed in {epoch_duration:.1f}s")

        # Summary: only printed when summary() returns (best_epoch, best_metric).
        best = trainer.summary()
        if isinstance(best, tuple) and len(best) == 2:
            be, bm = best
            bm_text = "n/a" if bm is None else f"{bm:.3f}"
            print(f"\n✅ Training per '{name}' completato. Best @ epoch {be} -> {bm_text}")


In [46]:
# Run training for the model(s) selected by cfg['run_model'].
launch_training(cfg=cfg)

Device: cpu 🚀  Inizio training per modello 'simclr'
→ Model config: {'backbone': 'resnet18', 'proj_dim': 128, 'augmentation': {'enabled': True, 'horizontal_flip': True, 'rotation': [0, 90, 180, 270], 'color_jitter': {'brightness': 0.5, 'contrast': 0.5, 'saturation': 0.5, 'hue': 0.2}}, 'training': {'epochs': 2, 'batch_size': 16, 'optimizer': 'adam', 'learning_rate': '1e-3', 'weight_decay': '1e-6', 'temperature': 0.5}}
Epoche: 2 | Batch size: 16

--- Epoch 1/2 ---




KeyboardInterrupt: 