In [161]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

!pip install --quiet torch torchvision webdataset tqdm pillow

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [162]:
from pathlib import Path
import yaml
import sys
import time
import importlib
import logging
from typing import Any, Dict
from tqdm import tqdm
import inspect

# Absolute Drive location of the YAML file that configures the whole notebook.
config_path = Path('/content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/config/training.yaml')

# safe_load only builds plain Python objects; the explicit encoding keeps the
# read independent of the runtime's locale.
with config_path.open('r', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)

# Prefer the Colab checkout when it exists, otherwise fall back to the local one.
colab_root = Path(cfg['env_paths']['colab'])
local_root = Path(cfg['env_paths']['local'])
PROJECT_ROOT = colab_root if colab_root.exists() else local_root
if not PROJECT_ROOT.exists():
    raise FileNotFoundError(f"Project root non trovato: {PROJECT_ROOT}")

# Make the project root and its src/ directory importable. Entries that are
# already on sys.path are skipped so re-running the cell does not pile up
# duplicates; on a first run src/ ends up ahead of the project root.
for _import_dir in (PROJECT_ROOT, PROJECT_ROOT / 'src'):
    if str(_import_dir) not in sys.path:
        sys.path.insert(0, str(_import_dir))


from importlib.util import spec_from_file_location, module_from_spec

utils_dir = PROJECT_ROOT / 'src' / 'utils'
src_file = utils_dir / 'training_utils.py'

# Load training_utils.py straight from its file path under the dotted name
# 'utils.training_utils'.
spec = spec_from_file_location('utils.training_utils', str(src_file))
if spec is None or spec.loader is None:
    raise ImportError(f"Impossibile creare lo spec di import per: {src_file}")
training_utils = module_from_spec(spec)

# Register the module BEFORE executing it (the order used by the importlib
# recipe) so code that looks the module up in sys.modules while it is being
# executed finds it. If execution fails, put back whatever was registered
# before so a broken half-initialised module is never left behind.
_previous_training_utils = sys.modules.get('utils.training_utils')
sys.modules['utils.training_utils'] = training_utils
try:
    spec.loader.exec_module(training_utils)
except BaseException:
    if _previous_training_utils is None:
        sys.modules.pop('utils.training_utils', None)
    else:
        sys.modules['utils.training_utils'] = _previous_training_utils
    raise

# Resolved from the sys.modules entry registered above.
from utils.training_utils import TRAINER_REGISTRY

print(f"🔥 PROJECT_ROOT: {PROJECT_ROOT}")


🔥 PROJECT_ROOT: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project


In [163]:
# Turn every configured dataset path into an absolute path under PROJECT_ROOT.
# Re-running the cell is harmless: joining PROJECT_ROOT with an already
# absolute path returns that path unchanged.
_SPLITS = ('train', 'val', 'test')
for split in _SPLITS:
    rel = cfg['data'].get(split)
    if rel:
        cfg['data'][split] = str(PROJECT_ROOT / rel)

print("📂 Dataset paths:")
for split in _SPLITS:
    # .get mirrors the loop above: a split missing from the config is shown as
    # None instead of raising KeyError.
    print(f"  • {split}: {cfg['data'].get(split)}")

# Keep the project directories importable without adding duplicate sys.path
# entries each time the cell is executed.
for _import_dir in (PROJECT_ROOT, PROJECT_ROOT / 'src'):
    if str(_import_dir) not in sys.path:
        sys.path.insert(0, str(_import_dir))

📂 Dataset paths:
  • train: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset_2500/train/patches-0000.tar
  • val: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset_2500/val/patches-0000.tar
  • test: /content/drive/MyDrive/ColabNotebooks/wsi-ssrl-rcc_project/data/processed/webdataset_2500/test/patches-0000.tar


In [164]:
# basicConfig does nothing when the root logger already has handlers, which is
# the case whenever the hosting runtime (or an earlier run of this cell) has
# configured logging first. force=True drops those handlers so the INFO level
# and the format below are actually applied.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)
logger = logging.getLogger("LAUNCHER")
logger.info("✅ Logger inizializzato a livello INFO")

In [165]:
# Dotted names of the trainer modules this notebook loads.
trainer_modules = [
    "trainers.simclr", "trainers.moco_v2", "trainers.rotation",
    "trainers.jigsaw", "trainers.supervised", "trainers.transfer",
]

# A module that is not cached yet is imported; one that is already in
# sys.modules is reloaded in place (presumably so that source edits and
# registrations are refreshed on re-run — confirm against the trainers).
for module_name in trainer_modules:
    if module_name not in sys.modules:
        importlib.import_module(module_name)
        continue
    importlib.reload(sys.modules[module_name])


In [166]:
def _safe_mean(total, count):
    """Return ``total / count``, or 0.0 when ``count`` is zero."""
    return total / count if count else 0.0


def launch_training(cfg: dict) -> None:
    """
    Train every model selected by the configuration.

    ``cfg['run_model']`` (default ``'all'``, compared in lower case) selects
    either every entry of ``cfg['models']`` or a single entry by name. For each
    selected model the factory stored under that name in ``TRAINER_REGISTRY``
    is called as ``TRAINER_REGISTRY[name](m_cfg, cfg['data'])`` and the
    resulting trainer is run for ``m_cfg['training']['epochs']`` epochs.

    For every batch one line is printed with the running loss, running
    accuracy, completion percentage, elapsed time and ETA.

    Trainer protocol used here:
      * ``train_loader``: iterable of batches.
      * ``train_step``: takes either the whole batch (one parameter) or
        ``(imgs, labels)``; returns ``(_, loss, correct, bs)`` or ``(loss, bs)``.
      * ``validate_epoch`` (optional): returns ``(val_loss, val_acc)``; when
        present ``post_epoch(epoch, val_acc)`` is called, otherwise
        ``post_epoch(epoch, epoch_loss)``.
      * ``summary``: if it returns a 2-tuple it is printed as (epoch, metric).
      * ``device`` / ``batches_train`` (optional): only used for display.

    Args:
        cfg: Parsed configuration with the keys ``models``, ``data`` and,
            optionally, ``run_model``.

    Raises:
        KeyError: if ``run_model`` names a model missing from ``cfg['models']``
            or a selected model has no entry in ``TRAINER_REGISTRY``.
    """
    models_cfg = cfg.get('models', {})
    run_model = cfg.get('run_model', 'all').lower()
    if run_model == 'all':
        tasks = list(models_cfg.items())
    else:
        if run_model not in models_cfg:
            raise KeyError(f"Modello '{run_model}' non presente in cfg['models']")
        tasks = [(run_model, models_cfg[run_model])]

    for name, m_cfg in tasks:
        if name not in TRAINER_REGISTRY:
            raise KeyError(f"Nessun trainer registrato per '{name}'")
        trainer = TRAINER_REGISTRY[name](m_cfg, cfg['data'])
        device = getattr(trainer, 'device', 'n/a')
        epochs = int(m_cfg['training'].get('epochs', 0))
        batch_size = int(m_cfg['training'].get('batch_size', 0))

        print(f"Device: {device} 🚀  Inizio training per modello '{name}'")
        print(f"→ Model config: {m_cfg}")
        print(f"Epoche: {epochs} | Batch size: {batch_size}\n")

        # The signature of train_step cannot change while training, so it is
        # inspected once per trainer instead of once per batch. For a bound
        # method `self` is not counted among the parameters.
        takes_whole_batch = len(inspect.signature(trainer.train_step).parameters) == 1

        for epoch in range(1, epochs + 1):
            epoch_start = time.time()
            total_batches = getattr(trainer, 'batches_train', None)
            running_loss = 0.0
            running_correct = 0
            total_samples = 0

            print(f"--- Epoch {epoch}/{epochs} ---")
            for i, batch in enumerate(trainer.train_loader, start=1):
                # Dispatch according to the signature of train_step.
                if takes_whole_batch:
                    result = trainer.train_step(batch)
                else:
                    imgs, labels = batch
                    result = trainer.train_step(imgs, labels)

                # Unpack the step result.
                if len(result) == 4:
                    _, loss, correct, bs = result
                else:
                    loss, bs = result
                    correct = 0

                # Update the running metrics. float() reduces a scalar-like
                # loss (e.g. a 0-dim tensor, if a trainer returns one) to a
                # plain Python number so the running sum never holds on to
                # framework objects.
                running_loss += float(loss) * bs
                running_correct += correct
                total_samples += bs
                # Both averages are guarded: a step reporting bs == 0 must not
                # abort the run with ZeroDivisionError.
                avg_loss = _safe_mean(running_loss, total_samples)
                avg_acc = _safe_mean(running_correct, total_samples)
                elapsed = time.time() - epoch_start
                pct = (i / total_batches) * 100 if total_batches else 0.0
                eta = (elapsed / i) * (total_batches - i) if total_batches and i else 0.0

                print(
                    f"  Batch {i}/{total_batches} ({pct:.1f}%) | "
                    f"Loss: {avg_loss:.4f} | Acc: {avg_acc:.3f} | "
                    f"Elapsed: {elapsed:.1f}s | ETA: {eta:.1f}s"
                )

            # Validation, or plain end-of-epoch bookkeeping.
            if hasattr(trainer, 'validate_epoch'):
                val_loss, val_acc = trainer.validate_epoch()
                print(f"Val -> Loss: {val_loss:.4f} | Acc: {val_acc:.3f}")
                trainer.post_epoch(epoch, val_acc)
            else:
                # With no samples seen there is no meaningful loss: report NaN
                # instead of dividing by zero (and instead of a 0.0 that would
                # read as a perfect score).
                epoch_loss = running_loss / total_samples if total_samples else float('nan')
                trainer.post_epoch(epoch, epoch_loss)

            duration = time.time() - epoch_start
            print(f"Epoch {epoch} completed in {duration:.1f}s\n")

        # Final summary.
        best = trainer.summary()
        if isinstance(best, tuple) and len(best) == 2:
            be, bm = best
            print(f"✅ Training per '{name}' completato. Best @ epoch {be} -> {bm:.3f}")

In [169]:
# Start training for the model(s) selected by cfg['run_model'] ('all' when the key is absent).
launch_training(cfg)

Device: cuda 🚀  Inizio training per modello 'simclr'
→ Model config: {'backbone': 'resnet18', 'proj_dim': 128, 'augmentation': {'enabled': True, 'horizontal_flip': True, 'rotation': [0, 90, 180, 270], 'color_jitter': {'brightness': 0.5, 'contrast': 0.5, 'saturation': 0.5, 'hue': 0.2}}, 'training': {'epochs': 2, 'batch_size': 32, 'optimizer': 'adam', 'learning_rate': '1e-3', 'weight_decay': '1e-6', 'temperature': 0.5}}
Epoche: 2 | Batch size: 32

--- Epoch 1/2 ---
  Batch 1/47 (2.1%) | Loss: 4.1203 | Acc: 0.000 | Elapsed: 1.1s | ETA: 49.7s
  Batch 2/47 (4.3%) | Loss: 4.1508 | Acc: 0.000 | Elapsed: 1.6s | ETA: 36.8s
  Batch 3/47 (6.4%) | Loss: 4.1429 | Acc: 0.000 | Elapsed: 2.2s | ETA: 32.5s
  Batch 4/47 (8.5%) | Loss: 4.1107 | Acc: 0.000 | Elapsed: 2.8s | ETA: 30.5s
  Batch 5/47 (10.6%) | Loss: 4.1159 | Acc: 0.000 | Elapsed: 3.5s | ETA: 29.6s
  Batch 6/47 (12.8%) | Loss: 4.1053 | Acc: 0.000 | Elapsed: 4.4s | ETA: 30.0s
  Batch 7/47 (14.9%) | Loss: 4.0889 | Acc: 0.000 | Elapsed: 5.3s | E

Device: cuda 🚀  Inizio training per modello 'supervised'
→ Model config: {'backbone': 'resnet50', 'pretrained': False, 'training': {'epochs': 2, 'batch_size': 32, 'optimizer': 'adam', 'learning_rate': '1e-4', 'weight_decay': '1e-5'}}
Epoche: 2 | Batch size: 32

--- Epoch 1/2 ---
  Batch 1/47 (2.1%) | Loss: 2.1067 | Acc: 0.156 | Elapsed: 1.1s | ETA: 51.1s
  Batch 2/47 (4.3%) | Loss: 1.9196 | Acc: 0.172 | Elapsed: 1.4s | ETA: 31.9s
  Batch 3/47 (6.4%) | Loss: 1.8083 | Acc: 0.208 | Elapsed: 1.7s | ETA: 25.3s
  Batch 4/47 (8.5%) | Loss: 1.7630 | Acc: 0.258 | Elapsed: 2.1s | ETA: 22.2s
  Batch 5/47 (10.6%) | Loss: 1.7752 | Acc: 0.269 | Elapsed: 2.4s | ETA: 20.0s
  Batch 6/47 (12.8%) | Loss: 1.7162 | Acc: 0.271 | Elapsed: 2.7s | ETA: 18.4s
  Batch 7/47 (14.9%) | Loss: 1.7117 | Acc: 0.308 | Elapsed: 3.0s | ETA: 17.2s
  Batch 8/47 (17.0%) | Loss: 1.7112 | Acc: 0.289 | Elapsed: 3.3s | ETA: 16.2s
  Batch 9/47 (19.1%) | Loss: 1.6856 | Acc: 0.288 | Elapsed: 3.7s | ETA: 15.4s
  Batch 10/47 (21.3%) | Loss: 1.6566 | Acc: 0.309 | Elapsed: 4.0s | ETA: 14.7s
  Batch 11/47 (23.4%) | Loss: 1.6345 | Acc: 0.312 | Elapsed: 4.3s | ETA: 14.1s
  Batch 12/47 (25.5%) | Loss: 1.6266 | Acc: 0.302 | Elapsed: 4.6s | ETA: 13.5s
  Batch 13/47 (27.7%) | Loss: 1.6210 | Acc: 0.298 | Elapsed: 4.9s | ETA: 12.9s
  Batch 14/47 (29.8%) | Loss: 1.6011 | Acc: 0.308 | Elapsed: 5.3s | ETA: 12.4s
  Batch 15/47 (31.9%) | Loss: 1.5948 | Acc: 0.306 | Elapsed: 5.6s | ETA: 11.9s
  Batch 16/47 (34.0%) | Loss: 1.5712 | Acc: 0.320 | Elapsed: 5.9s | ETA: 11.5s
  Batch 17/47 (36.2%) | Loss: 1.5564 | Acc: 0.333 | Elapsed: 6.2s | ETA: 11.0s
  Batch 18/47 (38.3%) | Loss: 1.5426 | Acc: 0.340 | Elapsed: 6.6s | ETA: 10.6s
  Batch 19/47 (40.4%) | Loss: 1.5288 | Acc: 0.354 | Elapsed: 6.9s | ETA: 10.1s
  Batch 20/47 (42.6%) | Loss: 1.5220 | Acc: 0.353 | Elapsed: 7.2s | ETA: 9.7s
  Batch 21/47 (44.7%) | Loss: 1.5081 | Acc: 0.362 | Elapsed: 7.5s | ETA: 9.3s
  Batch 22/47 (46.8%) | Loss: 1.4947 | Acc: 0.369 | Elapsed: 7.9s | ETA: 8.9s
  Batch 23/47 (48.9%) | Loss: 1.4929 | Acc: 0.372 | Elapsed: 8.2s | ETA: 8.5s
  Batch 24/47 (51.1%) | Loss: 1.4813 | Acc: 0.374 | Elapsed: 8.5s | ETA: 8.2s
  Batch 25/47 (53.2%) | Loss: 1.4728 | Acc: 0.374 | Elapsed: 8.8s | ETA: 7.8s
  Batch 26/47 (55.3%) | Loss: 1.4662 | Acc: 0.376 | Elapsed: 9.2s | ETA: 7.4s
  Batch 27/47 (57.4%) | Loss: 1.4561 | Acc: 0.382 | Elapsed: 9.5s | ETA: 7.0s
  Batch 28/47 (59.6%) | Loss: 1.4485 | Acc: 0.383 | Elapsed: 9.8s | ETA: 6.7s
  Batch 29/47 (61.7%) | Loss: 1.4380 | Acc: 0.384 | Elapsed: 10.2s | ETA: 6.3s
  Batch 30/47 (63.8%) | Loss: 1.4254 | Acc: 0.388 | Elapsed: 10.5s | ETA: 5.9s
  Batch 31/47 (66.0%) | Loss: 1.4164 | Acc: 0.390 | Elapsed: 10.8s | ETA: 5.6s
  Batch 32/47 (68.1%) | Loss: 1.4059 | Acc: 0.395 | Elapsed: 11.1s | ETA: 5.2s
  Batch 33/47 (70.2%) | Loss: 1.4035 | Acc: 0.393 | Elapsed: 11.5s | ETA: 4.9s
  Batch 34/47 (72.3%) | Loss: 1.3943 | Acc: 0.397 | Elapsed: 11.8s | ETA: 4.5s
  Batch 35/47 (74.5%) | Loss: 1.3864 | Acc: 0.400 | Elapsed: 12.1s | ETA: 4.2s
  Batch 36/47 (76.6%) | Loss: 1.3818 | Acc: 0.399 | Elapsed: 12.4s | ETA: 3.8s
  Batch 37/47 (78.7%) | Loss: 1.3862 | Acc: 0.397 | Elapsed: 12.8s | ETA: 3.5s
  Batch 38/47 (80.9%) | Loss: 1.3838 | Acc: 0.397 | Elapsed: 13.1s | ETA: 3.1s
  Batch 39/47 (83.0%) | Loss: 1.3793 | Acc: 0.397 | Elapsed: 13.4s | ETA: 2.8s
  Batch 40/47 (85.1%) | Loss: 1.3699 | Acc: 0.401 | Elapsed: 13.8s | ETA: 2.4s
  Batch 41/47 (87.2%) | Loss: 1.3655 | Acc: 0.403 | Elapsed: 14.1s | ETA: 2.1s
  Batch 42/47 (89.4%) | Loss: 1.3586 | Acc: 0.403 | Elapsed: 14.4s | ETA: 1.7s
  Batch 43/47 (91.5%) | Loss: 1.3565 | Acc: 0.402 | Elapsed: 14.8s | ETA: 1.4s
  Batch 44/47 (93.6%) | Loss: 1.3494 | Acc: 0.404 | Elapsed: 15.1s | ETA: 1.0s
  Batch 45/47 (95.7%) | Loss: 1.3445 | Acc: 0.406 | Elapsed: 15.4s | ETA: 0.7s
  Batch 46/47 (97.9%) | Loss: 1.3405 | Acc: 0.410 | Elapsed: 15.8s | ETA: 0.3s
  Batch 47/47 (100.0%) | Loss: 1.3405 | Acc: 0.410 | Elapsed: 15.9s | ETA: 0.0s
Val -> Loss: 4.7800 | Acc: 0.230
Epoch 1 completed in 20.0s

--- Epoch 2/2 ---
  Batch 1/47 (2.1%) | Loss: 1.0099 | Acc: 0.469 | Elapsed: 1.3s | ETA: 61.0s
  Batch 2/47 (4.3%) | Loss: 1.1234 | Acc: 0.531 | Elapsed: 1.7s | ETA: 37.3s
  Batch 3/47 (6.4%) | Loss: 1.1582 | Acc: 0.510 | Elapsed: 2.0s | ETA: 29.2s
  Batch 4/47 (8.5%) | Loss: 1.1780 | Acc: 0.492 | Elapsed: 2.4s | ETA: 25.3s
  Batch 5/47 (10.6%) | Loss: 1.1511 | Acc: 0.506 | Elapsed: 2.7s | ETA: 22.5s
  Batch 6/47 (12.8%) | Loss: 1.1205 | Acc: 0.500 | Elapsed: 3.0s | ETA: 20.6s
  Batch 7/47 (14.9%) | Loss: 1.1666 | Acc: 0.504 | Elapsed: 3.4s | ETA: 19.2s
  Batch 8/47 (17.0%) | Loss: 1.1596 | Acc: 0.512 | Elapsed: 3.7s | ETA: 18.0s
  Batch 9/47 (19.1%) | Loss: 1.1487 | Acc: 0.514 | Elapsed: 4.0s | ETA: 17.0s
  Batch 10/47 (21.3%) | Loss: 1.1577 | Acc: 0.509 | Elapsed: 4.4s | ETA: 16.2s
  Batch 11/47 (23.4%) | Loss: 1.1546 | Acc: 0.509 | Elapsed: 4.7s | ETA: 15.4s
  Batch 12/47 (25.5%) | Loss: 1.1531 | Acc: 0.508 | Elapsed: 5.0s | ETA: 14.7s
  Batch 13/47 (27.7%) | Loss: 1.1649 | Acc: 0.512 | Elapsed: 5.4s | ETA: 14.1s
  Batch 14/47 (29.8%) | Loss: 1.1720 | Acc: 0.502 | Elapsed: 5.7s | ETA: 13.5s
  Batch 15/47 (31.9%) | Loss: 1.1744 | Acc: 0.506 | Elapsed: 6.1s | ETA: 12.9s
  Batch 16/47 (34.0%) | Loss: 1.1742 | Acc: 0.504 | Elapsed: 6.4s | ETA: 12.4s
  Batch 17/47 (36.2%) | Loss: 1.1654 | Acc: 0.506 | Elapsed: 6.8s | ETA: 11.9s
  Batch 18/47 (38.3%) | Loss: 1.1551 | Acc: 0.509 | Elapsed: 7.1s | ETA: 11.4s
  Batch 19/47 (40.4%) | Loss: 1.1545 | Acc: 0.505 | Elapsed: 7.4s | ETA: 11.0s
  Batch 20/47 (42.6%) | Loss: 1.1494 | Acc: 0.506 | Elapsed: 7.8s | ETA: 10.5s
  Batch 21/47 (44.7%) | Loss: 1.1412 | Acc: 0.504 | Elapsed: 8.1s | ETA: 10.0s
  Batch 22/47 (46.8%) | Loss: 1.1388 | Acc: 0.501 | Elapsed: 8.5s | ETA: 9.6s
  Batch 23/47 (48.9%) | Loss: 1.1314 | Acc: 0.503 | Elapsed: 8.8s | ETA: 9.2s
  Batch 24/47 (51.1%) | Loss: 1.1301 | Acc: 0.505 | Elapsed: 9.1s | ETA: 8.8s
  Batch 25/47 (53.2%) | Loss: 1.1287 | Acc: 0.507 | Elapsed: 9.5s | ETA: 8.4s
  Batch 26/47 (55.3%) | Loss: 1.1287 | Acc: 0.510 | Elapsed: 9.8s | ETA: 7.9s
  Batch 27/47 (57.4%) | Loss: 1.1209 | Acc: 0.516 | Elapsed: 10.2s | ETA: 7.5s
  Batch 28/47 (59.6%) | Loss: 1.1169 | Acc: 0.518 | Elapsed: 10.5s | ETA: 7.1s
  Batch 29/47 (61.7%) | Loss: 1.1198 | Acc: 0.517 | Elapsed: 10.9s | ETA: 6.8s
  Batch 30/47 (63.8%) | Loss: 1.1119 | Acc: 0.518 | Elapsed: 11.2s | ETA: 6.4s
  Batch 31/47 (66.0%) | Loss: 1.1129 | Acc: 0.518 | Elapsed: 11.6s | ETA: 6.0s
  Batch 32/47 (68.1%) | Loss: 1.1103 | Acc: 0.518 | Elapsed: 11.9s | ETA: 5.6s
  Batch 33/47 (70.2%) | Loss: 1.1086 | Acc: 0.519 | Elapsed: 12.3s | ETA: 5.2s
  Batch 34/47 (72.3%) | Loss: 1.1028 | Acc: 0.521 | Elapsed: 12.6s | ETA: 4.8s
  Batch 35/47 (74.5%) | Loss: 1.0982 | Acc: 0.523 | Elapsed: 13.0s | ETA: 4.4s
  Batch 36/47 (76.6%) | Loss: 1.0946 | Acc: 0.523 | Elapsed: 13.3s | ETA: 4.1s
  Batch 37/47 (78.7%) | Loss: 1.0901 | Acc: 0.527 | Elapsed: 13.7s | ETA: 3.7s
  Batch 38/47 (80.9%) | Loss: 1.0853 | Acc: 0.529 | Elapsed: 14.0s | ETA: 3.3s
  Batch 39/47 (83.0%) | Loss: 1.0817 | Acc: 0.528 | Elapsed: 14.4s | ETA: 2.9s
  Batch 40/47 (85.1%) | Loss: 1.0740 | Acc: 0.534 | Elapsed: 14.7s | ETA: 2.6s
  Batch 41/47 (87.2%) | Loss: 1.0742 | Acc: 0.536 | Elapsed: 15.1s | ETA: 2.2s
  Batch 42/47 (89.4%) | Loss: 1.0700 | Acc: 0.539 | Elapsed: 15.4s | ETA: 1.8s
  Batch 43/47 (91.5%) | Loss: 1.0674 | Acc: 0.538 | Elapsed: 15.8s | ETA: 1.5s
  Batch 44/47 (93.6%) | Loss: 1.0652 | Acc: 0.539 | Elapsed: 16.1s | ETA: 1.1s
  Batch 45/47 (95.7%) | Loss: 1.0645 | Acc: 0.539 | Elapsed: 16.5s | ETA: 0.7s
  Batch 46/47 (97.9%) | Loss: 1.0668 | Acc: 0.538 | Elapsed: 16.8s | ETA: 0.4s
  Batch 47/47 (100.0%) | Loss: 1.0666 | Acc: 0.538 | Elapsed: 16.9s | ETA: 0.0s
Val -> Loss: 1.5627 | Acc: 0.353
Epoch 2 completed in 19.9s

✅ Training per 'supervised' completato. Best @ epoch 2 -> 0.353