In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

!pip install --quiet torch torchvision webdataset tqdm pillow

Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m93.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m86.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m51.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━

In [None]:
from pathlib import Path
import yaml
import sys
import time
import importlib
import logging
import torch
from typing import Any, Dict
from tqdm import tqdm
import inspect

config_path = Path('/content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project/config/training.yaml')

# Load the training configuration; safe_load only builds plain YAML types (no arbitrary objects).
with config_path.open('r') as f:
    cfg = yaml.safe_load(f)

# Prefer the Colab (Drive) checkout and fall back to the local one.
colab_root = Path(cfg['env_paths']['colab'])
local_root = Path(cfg['env_paths']['local'])
PROJECT_ROOT = colab_root if colab_root.exists() else local_root
if not PROJECT_ROOT.exists():
    raise FileNotFoundError(f"Project root not found: {PROJECT_ROOT}")

# Put src/ first and the project root second on sys.path (same precedence as two insert(0, ...) calls),
# removing earlier copies so re-running this cell does not keep growing sys.path.
for import_root in (PROJECT_ROOT, PROJECT_ROOT / 'src'):
    entry = str(import_root)
    while entry in sys.path:
        sys.path.remove(entry)
    sys.path.insert(0, entry)


from importlib.util import spec_from_file_location, module_from_spec

utils_dir = PROJECT_ROOT / 'src' / 'utils'
src_file = utils_dir / 'training_utils.py'
if not src_file.is_file():
    # Without this check a missing file surfaces as an opaque AttributeError on spec.loader.
    raise FileNotFoundError(f"training_utils not found: {src_file}")

# Execute training_utils.py from disk on every run of this cell (a plain import would reuse the
# cached module) and expose it under the module name 'utils.training_utils'.
spec = spec_from_file_location('utils.training_utils', str(src_file))
training_utils = module_from_spec(spec)
# Register the module *before* executing it, as in the importlib "import a source file directly"
# recipe, so any import of 'utils.training_utils' triggered while it runs receives this same object.
previous_training_utils = sys.modules.get('utils.training_utils')
sys.modules['utils.training_utils'] = training_utils
try:
    spec.loader.exec_module(training_utils)
except Exception:
    # Do not leave a half-initialised module registered; restore whatever was registered before.
    if previous_training_utils is None:
        sys.modules.pop('utils.training_utils', None)
    else:
        sys.modules['utils.training_utils'] = previous_training_utils
    raise

from utils.training_utils import TRAINER_REGISTRY
from trainers.extract_features import extract_features
from trainers.train_classifier import train_classifier
from utils.training_utils import get_latest_checkpoint, load_checkpoint
print(f"🔥 PROJECT_ROOT: {PROJECT_ROOT}")


🔥 PROJECT_ROOT: /content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project


In [None]:
# Resolve each dataset split against PROJECT_ROOT. pathlib returns an absolute right-hand operand
# unchanged, so re-running this cell on already-resolved paths is harmless.
for split in ['train', 'val', 'test']:
    rel = cfg['data'].get(split)
    if rel:
        cfg['data'][split] = str(PROJECT_ROOT / rel)

print("📂 Dataset paths:")
for split in ['train', 'val', 'test']:
    # .get: a split absent from the config is reported as None instead of raising KeyError,
    # consistent with the optional handling in the loop above.
    print(f"  • {split}: {cfg['data'].get(split)}")

# Keep src/ then the project root at the front of sys.path without adding duplicate entries.
for import_root in (PROJECT_ROOT, PROJECT_ROOT / 'src'):
    entry = str(import_root)
    while entry in sys.path:
        sys.path.remove(entry)
    sys.path.insert(0, entry)

📂 Dataset paths:
  • train: /content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project/data/processed/webdataset_2500/train/patches-0000.tar
  • val: /content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project/data/processed/webdataset_2500/val/patches-0000.tar
  • test: /content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project/data/processed/webdataset_2500/test/patches-0000.tar


In [None]:
# Configure the root logger for this notebook. force=True removes handlers already attached to the
# root logger (by the runtime or by a previous run of this cell); without it basicConfig silently
# does nothing in that case and INFO messages are dropped.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)
logger = logging.getLogger("LAUNCHER")
logger.info("✅ Logger initialized at INFO level")


In [None]:
# Fully-qualified names of the trainer modules to (re)load. They presumably register themselves in
# TRAINER_REGISTRY at import time — TODO confirm against the trainers package.
trainer_modules = [
    f"trainers.{trainer_name}"
    for trainer_name in ("simclr", "moco_v2", "rotation", "jigsaw", "supervised", "transfer")
]

# First run: import each module. Later runs: reload the cached module so its code executes again.
for module_name in trainer_modules:
    if module_name not in sys.modules:
        importlib.import_module(module_name)
        continue
    importlib.reload(sys.modules[module_name])


In [None]:
import inspect
import time

import re

# Extracts the epoch number from checkpoint names such as "<TrainerClass>_best_epoch12.pt".
_BEST_EPOCH_RE = re.compile(r"_best_epoch(\d+)")

# Defaults reproduce the previously hard-coded experiment location;
# override them with cfg['experiment_id'] / cfg['experiments_root'].
_DEFAULT_EXPERIMENT_ID = "prova"
_DEFAULT_EXPERIMENTS_ROOT = "data/processed2/dataset_9f30917e/experiments"


def _select_tasks(cfg: dict) -> list:
    """Return the ``(name, model_cfg)`` pairs to run.

    ``cfg['run_model']`` (default ``'all'``, lower-cased) picks one entry of
    ``cfg['models']``; ``'all'`` selects every entry.

    Raises:
        KeyError: if a specific model is requested but is absent from ``cfg['models']``.
    """
    models_cfg = cfg.get('models', {})
    run_model = cfg.get('run_model', 'all').lower()
    if run_model == 'all':
        return list(models_cfg.items())
    if run_model not in models_cfg:
        raise KeyError(f"Model '{run_model}' not found in cfg['models']")
    return [(run_model, models_cfg[run_model])]


def _find_best_checkpoint(experiment_dir: Path, prefix: str) -> "Path | None":
    """Return the ``<prefix>_best_epoch*.pt`` file with the highest epoch number, or None.

    Epochs are compared numerically: a lexicographic sort would rank ``best_epoch9``
    after ``best_epoch10`` and load a stale checkpoint.
    """
    def epoch_of(path: Path) -> int:
        match = _BEST_EPOCH_RE.search(path.name)
        return int(match.group(1)) if match else -1

    candidates = list(experiment_dir.glob(f"{prefix}_best_epoch*.pt"))
    if not candidates:
        return None
    # Ties (or names without a parsable epoch) are broken by file name.
    return max(candidates, key=lambda p: (epoch_of(p), p.name))


def _load_checkpoint_into_trainer(trainer: Any, ckpt_path: Path, name: str) -> None:
    """Load ``ckpt_path`` into the trainer's modules and move them to ``trainer.device``.

    Layouts are tried in order: ``encoder`` + ``projector``, ``model``, ``encoder`` + ``head``.

    Raises:
        AttributeError: if the trainer exposes none of these layouts.
    """
    if hasattr(trainer, "encoder") and hasattr(trainer, "projector"):
        # Wrap the pair so a single load covers both modules; presumably the checkpoint was
        # saved from the same Sequential wrapping — confirm against the trainer's save code.
        model = torch.nn.Sequential(trainer.encoder, trainer.projector)
        load_checkpoint(ckpt_path, model=model)
        trainer.encoder = model[0].to(trainer.device)
        trainer.projector = model[1].to(trainer.device)
    elif hasattr(trainer, "model"):
        load_checkpoint(ckpt_path, model=trainer.model)
        trainer.model = trainer.model.to(trainer.device)
    elif hasattr(trainer, "encoder") and hasattr(trainer, "head"):
        model = torch.nn.Sequential(trainer.encoder, trainer.head)
        load_checkpoint(ckpt_path, model=model)
        trainer.encoder = model[0].to(trainer.device)
        trainer.head = model[1].to(trainer.device)
    else:
        raise AttributeError(f"❌ Trainer '{name}' has no encoder/projector or model to load into.")


def _run_training(trainer: Any, name: str, epochs: int, has_validation: bool) -> None:
    """Run ``epochs`` epochs of ``trainer.train_step`` over ``trainer.train_loader``, printing progress.

    ``train_step`` may return ``(_, loss, correct, batch_size)`` or ``(loss, batch_size)``.
    It is called with the whole batch when it takes one parameter, otherwise with the batch
    unpacked. After each epoch, ``post_epoch`` receives the validation accuracy (trainers with
    ``validate_epoch``) or the mean training loss (all others).

    Raises:
        RuntimeError: if a trainer without validation sees no samples in an epoch, so the
            mean loss is undefined.
    """
    for epoch in range(1, epochs + 1):
        epoch_start = time.time()
        total_batches = getattr(trainer, 'batches_train', None)
        # The calling convention of train_step does not change within an epoch: inspect it once.
        takes_whole_batch = len(inspect.signature(trainer.train_step).parameters) == 1
        running_loss, running_correct, total_samples = 0.0, 0, 0

        print(f"--- Epoch {epoch}/{epochs} ---")
        for i, batch in enumerate(trainer.train_loader, start=1):
            result = trainer.train_step(batch) if takes_whole_batch else trainer.train_step(*batch)
            if len(result) == 4:
                _, loss, correct, bs = result
            else:
                loss, bs = result
                correct = 0

            running_loss += loss * bs
            running_correct += correct
            total_samples += bs
            avg_loss = running_loss / total_samples if total_samples else 0.0
            elapsed = time.time() - epoch_start
            pct = (i / total_batches) * 100 if total_batches else 0.0
            eta = (elapsed / i) * (total_batches - i) if total_batches else 0.0

            msg = f"  Batch {i}/{total_batches} ({pct:.1f}%) | Loss: {avg_loss:.4f}"
            if has_validation:
                avg_acc = running_correct / total_samples if total_samples else 0.0
                msg += f" | Acc: {avg_acc:.3f}"
            msg += f" | Elapsed: {elapsed:.1f}s | ETA: {eta:.1f}s"
            print(msg)

        if has_validation:
            val_loss, val_acc = trainer.validate_epoch()
            print(f"Val -> Loss: {val_loss:.4f} | Acc: {val_acc:.3f}")
            trainer.post_epoch(epoch, val_acc)
        else:
            if total_samples == 0:
                raise RuntimeError(f"Trainer '{name}' produced no training samples in epoch {epoch}")
            trainer.post_epoch(epoch, running_loss / total_samples)

        print(f"Epoch {epoch} completed in {time.time() - epoch_start:.1f}s\n")


def _extract_and_classify(trainer: Any, name: str, experiment_dir: Path) -> None:
    """Write ``<name>_features.pt`` with the trainer, then fit ``<name>_classifier.joblib`` on it.

    The classifier step is skipped, with a warning, when no feature file exists afterwards,
    for example because the trainer does not implement ``extract_features_to``.
    """
    print(f"🔍 Extracting features from model '{name}'")
    feature_path = experiment_dir / f"{name}_features.pt"
    classifier_path = experiment_dir / f"{name}_classifier.joblib"
    feature_path.parent.mkdir(parents=True, exist_ok=True)

    if hasattr(trainer, "extract_features_to"):
        trainer.extract_features_to(str(feature_path))
        print(f"✅ Features saved to {feature_path}")
    else:
        print(f"⚠️ Trainer '{name}' does not implement extract_features_to(), skipping.")

    if not feature_path.exists():
        # Training a classifier on a non-existent feature file would only crash.
        print(f"⚠️ No features found at {feature_path}; skipping classifier for '{name}'.")
        return

    print(f"🧠 Training classifier on features '{name}'")
    train_classifier(
        features_path=str(feature_path),
        output_model=str(classifier_path)
    )


def launch_training(cfg: dict) -> None:
    """Train or restore every model selected by ``cfg``, then post-process self-supervised ones.

    For each selected model, the trainer class comes from ``TRAINER_REGISTRY`` and is built with
    ``(model_cfg, cfg['data'])``. If a ``<TrainerClass>_best_epoch*.pt`` checkpoint exists in the
    experiment directory, the one with the highest epoch is loaded instead of training. Trainers
    without ``validate_epoch`` (the self-supervised ones) then get features extracted and a
    classifier trained on them.

    The experiment directory is ``PROJECT_ROOT / cfg.get('experiments_root', <default root>)
    / cfg.get('experiment_id', 'prova') / <model name>``.

    Raises:
        KeyError: unknown ``run_model``, or no trainer registered for a selected model.
        AttributeError: a checkpoint exists but the trainer has no module layout to load it into.
    """
    for name, m_cfg in _select_tasks(cfg):
        if name not in TRAINER_REGISTRY:
            raise KeyError(f"No trainer registered for '{name}'")

        trainer = TRAINER_REGISTRY[name](m_cfg, cfg['data'])
        device = getattr(trainer, 'device', 'n/a')
        train_cfg = m_cfg.get('training', {})
        epochs = int(train_cfg.get('epochs', 0))
        batch_size = int(train_cfg.get('batch_size', 0))

        print(f"Device: {device} 🚀  Starting training for model '{name}'")
        print(f"→ Model config: {m_cfg}")
        print(f"Epochs: {epochs} | Batch size: {batch_size}\n")

        has_validation = hasattr(trainer, 'validate_epoch')
        experiment_id = str(cfg.get('experiment_id', _DEFAULT_EXPERIMENT_ID))
        experiments_root = cfg.get('experiments_root', _DEFAULT_EXPERIMENTS_ROOT)
        experiment_dir = PROJECT_ROOT / experiments_root / experiment_id / name
        ckpt_path = _find_best_checkpoint(experiment_dir, trainer.__class__.__name__)

        if ckpt_path is not None:
            print(f"⏭️  Checkpoint found for '{name}' → skipping training and loading encoder/projector/model.")
            _load_checkpoint_into_trainer(trainer, ckpt_path, name)
        else:
            _run_training(trainer, name, epochs, has_validation)

        best = trainer.summary()
        if isinstance(best, tuple) and len(best) == 2:
            be, bm = best
            print(f"✅ Training for '{name}' completed. Best @ epoch {be} -> {bm:.3f}")

        # Only for self-supervised trainers (no validate_epoch).
        if not has_validation:
            _extract_and_classify(trainer, name, experiment_dir)


In [None]:
# Run every model selected by cfg['run_model'] ('all' runs every entry of cfg['models']).
launch_training(cfg=cfg)



Device: cpu 🚀  Starting training for model 'jigsaw'
→ Model config: {'backbone': 'resnet18', 'grid_size': 3, 'training': {'epochs': 50, 'batch_size': 64, 'optimizer': 'adam', 'learning_rate': '1e-4', 'weight_decay': '1e-5'}}
Epochs: 50 | Batch size: 64

⏭️  Checkpoint found for 'jigsaw' → skipping training and loading encoder/projector/model.




✅ Training for 'jigsaw' completed. Best @ epoch 0 -> inf
🔍 Extracting features from model 'jigsaw'


Extracting features: 24it [02:33,  6.41s/it]


✅ Jigsaw features saved to /content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project/data/processed2/dataset_9f30917e/experiments/prova/jigsaw/jigsaw_features.pt
✅ Features saved to /content/drive/MyDrive/Colab Notebooks/MLA_PROJECT/wsi-ssrl-rcc_project/data/processed2/dataset_9f30917e/experiments/prova/jigsaw/jigsaw_features.pt
🧠 Training classifier on features 'jigsaw'
✅ Loaded 1475 keys and (1475, 512) features
📊 Class distribution:
Counter({np.str_('not_tumor'): 299, np.str_('ONCO'): 297, np.str_('CHROMO'): 293, np.str_('ccRCC'): 293, np.str_('pRCC'): 293})
✅ Filtered dataset: 1475 samples
              precision    recall  f1-score   support

      CHROMO       0.58      0.12      0.20        58
        ONCO       0.26      0.58      0.36        59
       ccRCC       0.28      0.12      0.17        59
   not_tumor       0.28      0.13      0.18        60
        pRCC       0.33      0.56      0.42        59

    accuracy                           0.30       295
   ma