In [17]:
import importlib.util
import os
import sys
import types

# Working directory of the running kernel. The method files loaded further
# down (mocov3.py, curriculum_mocov3.py, selective_curriculum_mocov3.py) are
# looked up relative to this path, so the notebook must be started next to them.
current_dir = os.getcwd()

# Function to load a module from file
def load_module_from_file(module_name, file_path):
    """Import the Python source at ``file_path`` under the name ``module_name``.

    The module is registered in ``sys.modules`` *before* its code runs so that
    code executed during the import can already resolve ``module_name``.

    Args:
        module_name: Fully qualified name to register the module under.
        file_path: Path of the source file to execute.

    Returns:
        The initialised module object.

    Raises:
        ImportError: If no import spec/loader can be derived from ``file_path``
            (for example, an unsupported file extension).
        Exception: Whatever the module's own top-level code raises. In that
            case ``sys.modules`` is restored to its state before the call.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot pick a loader;
        # fail with a clear message instead of an AttributeError on None.
        raise ImportError(
            f"Cannot create an import spec for {module_name!r} from {file_path!r}",
            name=module_name,
            path=str(file_path),
        )

    module = importlib.util.module_from_spec(spec)
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module  # must be visible while the module body executes
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind: put back whatever was
        # registered before (or nothing), then let the error propagate.
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module

# Register placeholder packages for the "solo" hierarchy so the local files
# below can be loaded under "solo.methods.*" names.
# Real module objects are used (not bare classes) so each placeholder carries
# the correct __name__, is marked as a package via __path__, and is reachable
# as an attribute of its parent, as an imported package would be.
# NOTE(review): these entries replace anything already registered under the
# same names and are never restored in this cell, so the patch is not actually
# temporary. Restart the kernel to get the real packages back.
for _stub_name in ("solo", "solo.methods", "solo.losses", "solo.utils"):
    _stub = types.ModuleType(_stub_name)
    _stub.__path__ = []  # empty search path: a package with no importable files of its own
    sys.modules[_stub_name] = _stub
    _parent_name, _, _child_name = _stub_name.rpartition(".")
    if _parent_name:
        # Parents come first in the tuple above, so the parent stub already exists.
        setattr(sys.modules[_parent_name], _child_name, _stub)

# Load the three method files that live next to this notebook and register each
# one under its "solo.methods.<name>" module name.
# NOTE(review): mocov3 is loaded first, then curriculum, then selective.
# Presumably the later files import from the earlier ones; confirm before
# reordering.

# --- mocov3 ---
mocov3_path = os.path.join(current_dir, "mocov3.py")
mocov3_module = load_module_from_file("solo.methods.mocov3", mocov3_path)
# Expose as an attribute of the solo.methods placeholder as well
sys.modules['solo.methods'].mocov3 = mocov3_module
MoCoV3 = mocov3_module.MoCoV3

# --- curriculum_mocov3 ---
curriculum_path = os.path.join(current_dir, "curriculum_mocov3.py")
curriculum_module = load_module_from_file("solo.methods.curriculum_mocov3", curriculum_path)
sys.modules['solo.methods'].curriculum_mocov3 = curriculum_module
CurriculumMoCoV3 = curriculum_module.CurriculumMoCoV3

# --- selective_curriculum_mocov3 ---
selective_path = os.path.join(current_dir, "selective_curriculum_mocov3.py")
selective_module = load_module_from_file("solo.methods.selective_curriculum_mocov3", selective_path)
# Attach to solo.methods like the other two, so attribute access on the
# placeholder works uniformly for all three modules.
sys.modules['solo.methods'].selective_curriculum_mocov3 = selective_module
SelectiveJEPACurriculumMoCoV3 = selective_module.SelectiveJEPACurriculumMoCoV3

print("Successfully loaded all modules!")
print(f"MoCoV3: {MoCoV3}")
print(f"CurriculumMoCoV3: {CurriculumMoCoV3}")
print(f"SelectiveJEPACurriculumMoCoV3: {SelectiveJEPACurriculumMoCoV3}")

Successfully loaded all modules!
MoCoV3: <class 'solo.methods.mocov3.MoCoV3'>
CurriculumMoCoV3: <class 'solo.methods.curriculum_mocov3.CurriculumMoCoV3'>
SelectiveJEPACurriculumMoCoV3: <class 'solo.methods.selective_curriculum_mocov3.SelectiveJEPACurriculumMoCoV3'>


In [None]:
from __future__ import annotations

import omegaconf
import torch
import torch.nn as nn
import sys
import os


# ----------------------------------------------------------------------------
# Minimal Trainer stub (Lightning‑like attributes the modules expect)
# ----------------------------------------------------------------------------
class _DummyTrainer:
    def __init__(self, epoch: int = 0, rank: int = 0):
        self.current_epoch = epoch
        self.global_rank = rank


# ----------------------------------------------------------------------------
# Build a config that satisfies Solo‑Learn **and nothing more**
# ----------------------------------------------------------------------------

def _build_cfg() -> omegaconf.DictConfig:
    """Assemble the smoke-test configuration and wrap it with OmegaConf.

    The configuration is built from plain nested dicts, one per section, and
    handed to ``omegaconf.OmegaConf.create``.

    Returns:
        The object produced by ``omegaconf.OmegaConf.create`` for the sections
        ``name``, ``method``, ``backbone``, ``data``, ``momentum``,
        ``no_validation``, ``optimizer``, ``scheduler``, ``method_kwargs`` and
        ``max_epochs``.
    """
    data_section = {
        "dataset": "dummy",
        "num_classes": 0,
        "train_path": None,
        "val_path": None,
    }

    # batch_size sits inside the optimizer section.
    # NOTE(review): the saved cell output still reports
    # "Missing key batch_size" for optimizer.batch_size although the key is
    # present here. The output presumably predates this dict; re-run to confirm.
    optimizer_section = {
        "name": "sgd",
        "lr": 0.05,
        "weight_decay": 0.0,
        "momentum": 0.9,
        "batch_size": 32,
    }

    # Projector / predictor sizes and loss temperature.
    moco_kwargs = {
        "proj_output_dim": 32,
        "proj_hidden_dim": 64,
        "pred_hidden_dim": 64,
        "temperature": 0.2,
    }
    # Curriculum settings.
    curriculum_kwargs = {
        "curriculum_type": "mae",
        "curriculum_strategy": "exponential",
        "curriculum_warmup_epochs": 5,
        "curriculum_weight": 1.0,
        "reconstruction_masking_ratio": 0.75,
        "curriculum_reverse": False,
    }
    # Selective-curriculum settings.
    selective_kwargs = {
        "num_candidates": 8,
        "selection_epochs": 100,
    }

    sections = {
        "name": "moco-smoke-test",
        "method": {"name": "mocov3"},
        "backbone": {"name": "resnet18"},
        "data": data_section,
        "momentum": {"base_tau": 0.996},
        "no_validation": True,
        "optimizer": optimizer_section,
        "scheduler": {"name": None},
        # Merged in the same key order: MoCo, then curriculum, then selective.
        "method_kwargs": {**moco_kwargs, **curriculum_kwargs, **selective_kwargs},
        "max_epochs": 1,
    }
    return omegaconf.OmegaConf.create(sections)


# ----------------------------------------------------------------------------
# Attach *just enough* plumbing so .training_step works outside PL Trainer
# ----------------------------------------------------------------------------

def _attach_testing_stubs(model, *, epoch: int = 0):
    """Decorate ``model`` in place with stand-in attributes and return it.

    The following attributes are overwritten on the instance:
    ``trainer`` (a ``_DummyTrainer`` at ``epoch``), ``optimizers``,
    ``lr_schedulers``, ``manual_backward``, ``log_dict``, ``log`` and
    ``device``.

    Args:
        model: Object to decorate; it must provide ``parameters()`` by the time
            ``optimizers()`` is called.
        epoch: Epoch number given to the dummy trainer.

    Returns:
        The same ``model`` object.
    """

    def _make_optimizers():
        # A brand-new SGD optimizer is built on every call.
        return [torch.optim.SGD(model.parameters(), lr=0.01)]

    def _no_schedulers():
        return None

    def _backward(loss):
        return loss.backward()

    def _discard(*args, **kwargs):
        # Logging calls are accepted and ignored.
        return None

    model.trainer = _DummyTrainer(epoch)
    model.optimizers = _make_optimizers
    model.lr_schedulers = _no_schedulers
    model.manual_backward = _backward
    model.log_dict = _discard
    model.log = _discard
    # NOTE(review): if the model class defines ``device`` as a read-only
    # property this assignment will raise; verify against the model base class.
    model.device = torch.device("cpu")
    return model

# ----------------------------------------------------------------------------
# Instantiate models
# ----------------------------------------------------------------------------
# Seed before construction so that any randomly initialised weights, and
# therefore the losses printed further down, are the same on every run.
torch.manual_seed(0)

# One shared config for all three models; each model is wrapped with the
# testing stubs right after construction.
config = _build_cfg()

moco        = _attach_testing_stubs(MoCoV3(config))
curriculum  = _attach_testing_stubs(CurriculumMoCoV3(config))
sel_curr    = _attach_testing_stubs(SelectiveJEPACurriculumMoCoV3(config))

# ----------------------------------------------------------------------------
# Synthetic data  (two 224×224 views) + candidate set for selective‑JEPA
# ----------------------------------------------------------------------------
# Seed here so the synthetic inputs are identical on every run, independent of
# how many random numbers were drawn earlier in the cell.
torch.manual_seed(0)

# Batch size, channels, height, width of each synthetic view.
B, C, H, W = 4, 3, 224, 224
x1, x2 = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
indices = torch.arange(B)

# Standard batch: (sample indices, [view 1, view 2]).
batch_std = (indices, [x1, x2])

# Candidates per sample for the selective model.
# NOTE(review): K is 6 here while _build_cfg sets method_kwargs.num_candidates
# to 8; confirm whether the two are expected to match.
K = 6
cands = torch.randn(B, K, C, H, W)
# Selective batch: (sample indices, anchor view, candidate views, None).
batch_sel = (indices, x1, cands, None)

# ----------------------------------------------------------------------------
# Monkey‑patch CurriculumMoCoV3 to capture weights produced inside
# ----------------------------------------------------------------------------

def _tap_weights(self, errors, epoch):
    """Call the class-level ``_compute_sample_weights`` and remember its result.

    The computation is delegated to ``CurriculumMoCoV3._compute_sample_weights``
    (looked up on the class, so an instance-level override does not recurse).
    A detached copy of the result is stored on ``self._last_weights``.

    Args:
        self: Instance the wrapper is bound to.
        errors: Forwarded unchanged to the wrapped method.
        epoch: Forwarded unchanged to the wrapped method.

    Returns:
        The value returned by the wrapped method, not the detached copy.
    """
    weights = CurriculumMoCoV3._compute_sample_weights(self, errors, epoch)
    self._last_weights = weights.detach()
    return weights

# Bind the tap as a method of this single instance (via the descriptor
# protocol) and store it on the instance, so only `curriculum` is instrumented
# and the CurriculumMoCoV3 class itself is left unchanged.
curriculum._compute_sample_weights = _tap_weights.__get__(curriculum, CurriculumMoCoV3)

# ----------------------------------------------------------------------------
# Helper to run a single training step
# ----------------------------------------------------------------------------

def _run_step(model, batch):
    print(f"\n=== {model.__class__.__name__} ===")
    loss = model.training_step(batch, 0)
    print("loss:", float(loss))
    if hasattr(model, "_last_weights"):
        print("weights:", [round(float(v), 3) for v in model._last_weights])


# One training step per model, in the order: plain MoCo, curriculum, selective.
for demo_model, demo_batch in (
    (moco, batch_std),
    (curriculum, batch_std),
    (sel_curr, batch_sel),
):
    _run_step(demo_model, demo_batch)

# ----------------------------------------------------------------------------
# Manual weight demo: feed hand-picked errors straight into the weight function
# ----------------------------------------------------------------------------
print("\nManual weight demo → override errors = [0.05, 0.3, 0.9, 0.2]")
fake_err = torch.tensor([0.05, 0.3, 0.9, 0.2])
# Epoch 0 is passed explicitly. This goes through the instance attribute, i.e.
# the tapped version when the tap has been installed on `curriculum`.
demo_weights = curriculum._compute_sample_weights(fake_err, 0)
print("returned weights:", [round(float(v), 3) for v in demo_weights])



ConfigAttributeError: Missing key batch_size
    full_key: optimizer.batch_size
    object_type=dict