Merged
2 changes: 1 addition & 1 deletion mmlearn/datasets/processors/masking.py
@@ -321,7 +321,7 @@ class IJEPAMaskGenerator:
allow_overlap: bool = False
enc_mask_scale: tuple[float, float] = (0.85, 1.0)
pred_mask_scale: tuple[float, float] = (0.15, 0.2)
aspect_ratio: tuple[float, float] = (0.75, 1.0)
aspect_ratio: tuple[float, float] = (0.75, 1.5)
nenc: int = 1
npred: int = 4

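For context, the sketch below illustrates how I-JEPA-style mask generators typically turn a mask scale and an aspect-ratio range into a rectangular block of patches; widening the upper aspect-ratio bound to 1.5 allows wider-than-tall blocks as well as taller-than-wide ones. This is an assumption about the usual sampling scheme, not necessarily mmlearn's exact implementation, and all names in it are hypothetical.

```python
import math
import random

# Illustrative sketch (not mmlearn's exact code) of I-JEPA-style block sampling:
# a scale range controls how many patches a block covers, and the aspect-ratio
# range controls its shape.
def sample_block_size(
    num_patches_per_side: int = 14,      # e.g. a 224px image with 16px patches
    mask_scale: tuple = (0.15, 0.2),     # matches pred_mask_scale above
    aspect_ratio: tuple = (0.75, 1.5),   # the new default range
) -> tuple:
    total_patches = num_patches_per_side ** 2
    scale = random.uniform(*mask_scale)
    max_keep = scale * total_patches     # target number of patches in the block
    ar = random.uniform(*aspect_ratio)   # ar ~ block height / block width
    h = min(num_patches_per_side, max(1, round(math.sqrt(max_keep * ar))))
    w = min(num_patches_per_side, max(1, round(math.sqrt(max_keep / ar))))
    return h, w
```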
26 changes: 18 additions & 8 deletions mmlearn/modules/ema.py
@@ -23,8 +23,6 @@ class ExponentialMovingAverage:
The final decay value for EMA.
ema_anneal_end_step : int
The number of steps to anneal the decay from ``ema_decay`` to ``ema_end_decay``.
device_id : Optional[Union[int, torch.device]], optional, default=None
The device to move the model to.
skip_keys : Optional[Union[list[str], Set[str]]], optional, default=None
The keys to skip in the EMA update. These parameters will be copied directly
from the model to the EMA model.
@@ -41,14 +39,9 @@ def __init__(
ema_decay: float,
ema_end_decay: float,
ema_anneal_end_step: int,
device_id: Optional[Union[int, torch.device]] = None,
skip_keys: Optional[Union[list[str], Set[str]]] = None,
):
) -> None:
self.model = self.deepcopy_model(model)
self.model.requires_grad_(False)

if device_id is not None:
self.model.to(device_id)

self.skip_keys: Union[list[str], set[str]] = skip_keys or set()
self.num_updates = 0
@@ -57,6 +50,8 @@ def __init__(
self.ema_end_decay = ema_end_decay
self.ema_anneal_end_step = ema_anneal_end_step

self._model_configured = False

@staticmethod
def deepcopy_model(model: torch.nn.Module) -> torch.nn.Module:
"""Deep copy the model.
@@ -93,8 +88,23 @@ def get_annealed_rate(
pct_remaining = 1 - curr_step / total_steps
return end - r * pct_remaining

def configure_model(self, device_id: Union[int, torch.device]) -> None:
"""Configure the model for EMA."""
if self._model_configured:
return

self.model.requires_grad_(False)
self.model.to(device_id)

self._model_configured = True

def step(self, new_model: torch.nn.Module) -> None:
"""Perform single EMA update step."""
if not self._model_configured:
raise RuntimeError(
"Model is not configured for EMA. Call `configure_model` first."
)

self._update_weights(new_model)
self._update_ema_decay()

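Taken together, these changes split construction from device placement: the EMA copy is still deep-copied at `__init__` time, but it is only frozen and moved once `configure_model` is called, and `step` now guards against skipping that call. A minimal usage sketch based only on the signatures visible in this diff (`TinyEncoder` is a stand-in module, not part of mmlearn):

```python
import torch

from mmlearn.modules.ema import ExponentialMovingAverage


class TinyEncoder(torch.nn.Module):
    """Stand-in model for illustration only."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


model = TinyEncoder()
ema = ExponentialMovingAverage(
    model,
    ema_decay=0.996,
    ema_end_decay=1.0,
    ema_anneal_end_step=10_000,
)

# New contract: freeze the EMA copy and move it to its device before updating;
# calling `step` first would raise the RuntimeError added above.
ema.configure_model(device_id=torch.device("cpu"))

# Per `get_annealed_rate`, the decay anneals from ema_decay toward ema_end_decay
# as the number of updates approaches ema_anneal_end_step.
ema.step(model)
```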
39 changes: 16 additions & 23 deletions mmlearn/tasks/ijepa.py
@@ -90,36 +90,29 @@ def __init__(
self.modality = Modalities.get_modality(modality)
self.mask_generator = IJEPAMaskGenerator()

self.current_step = 0
self.total_steps = None

self.encoder = encoder
self.predictor = predictor

self.predictor.num_patches = encoder.patch_embed.num_patches
self.predictor.embed_dim = encoder.embed_dim
self.predictor.num_heads = encoder.num_heads

self.ema = ExponentialMovingAverage(
self.encoder,
ema_decay,
ema_decay_end,
ema_anneal_end_step,
device_id=self.device,
self.target_encoder = ExponentialMovingAverage(
self.encoder, ema_decay, ema_decay_end, ema_anneal_end_step
)

def configure_model(self) -> None:
"""Configure the model."""
self.ema.model.to(device=self.device, dtype=self.dtype)
self.target_encoder.configure_model(self.device)

def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Perform exponential moving average update of target encoder.

This is done right after the optimizer step, which comes just before `zero_grad`
to account for gradient accumulation.
This is done right after ``optimizer.step()``, which comes just before
``optimizer.zero_grad()`` to account for gradient accumulation.
"""
if self.ema is not None:
self.ema.step(self.encoder)
if self.target_encoder is not None:
self.target_encoder.step(self.encoder)

def training_step(self, batch: dict[str, Any], batch_idx: int) -> torch.Tensor:
"""Perform a single training step.
@@ -200,10 +193,10 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
checkpoint : dict[str, Any]
The state dictionary to save the EMA state to.
"""
if self.ema is not None:
if self.target_encoder is not None:
checkpoint["ema_params"] = {
"decay": self.ema.decay,
"num_updates": self.ema.num_updates,
"decay": self.target_encoder.decay,
"num_updates": self.target_encoder.num_updates,
}

def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
@@ -214,12 +207,12 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
checkpoint : dict[str, Any]
The state dictionary to restore the EMA state from.
"""
if "ema_params" in checkpoint and self.ema is not None:
if "ema_params" in checkpoint and self.target_encoder is not None:
ema_params = checkpoint.pop("ema_params")
self.ema.decay = ema_params["decay"]
self.ema.num_updates = ema_params["num_updates"]
self.target_encoder.decay = ema_params["decay"]
self.target_encoder.num_updates = ema_params["num_updates"]

self.ema.restore(self.encoder)
self.target_encoder.restore(self.encoder)

def _shared_step(
self, batch: dict[str, Any], batch_idx: int, step_type: str
@@ -237,7 +230,7 @@ def _shared_step(

# Forward pass through target encoder to get h
with torch.no_grad():
h = self.ema.model(batch)[0]
h = self.target_encoder.model(batch)[0]
h = F.layer_norm(h, h.size()[-1:])
h_masked = apply_masks(h, predictor_masks)
h_masked = repeat_interleave_batch(
@@ -252,7 +245,7 @@ def _shared_step(
z_pred = self.predictor(z, encoder_masks, predictor_masks)

if step_type == "train":
self.log("train/ema_decay", self.ema.decay, prog_bar=True)
self.log("train/ema_decay", self.target_encoder.decay, prog_bar=True)

if self.loss_fn is not None and (
step_type == "train"
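To make the renamed target branch easier to follow, here is a conceptual sketch in plain PyTorch of what `target_encoder` contributes in `_shared_step`. It is illustrative only: the real code passes a batch dict and uses mmlearn's `apply_masks`/`repeat_interleave_batch` helpers, which are approximated here with a simple gather, and the tensor names are assumptions.

```python
import torch
import torch.nn.functional as F


def build_prediction_targets(
    target_encoder: torch.nn.Module,
    patch_tokens: torch.Tensor,     # (batch, num_patches, embed_dim)
    mask_indices: torch.Tensor,     # (batch, num_masked) int64 patch indices
) -> torch.Tensor:
    """Compute I-JEPA regression targets from the EMA target encoder (sketch)."""
    with torch.no_grad():
        h = target_encoder(patch_tokens)          # target representations
        h = F.layer_norm(h, h.size()[-1:])        # normalize each token
        idx = mask_indices.unsqueeze(-1).expand(-1, -1, h.size(-1))
        return torch.gather(h, dim=1, index=idx)  # keep only the masked patches
```

The context encoder and predictor then produce `z_pred` for the same masked positions, and the loss compares `z_pred` against these targets.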
6 changes: 3 additions & 3 deletions projects/ijepa/configs/__init__.py
@@ -15,7 +15,7 @@
def ijepa_transforms(
crop_size: int = 224,
crop_scale: tuple = (0.3, 1.0),
color_jitter: float = 0.0,
color_jitter_strength: float = 0.0,
horizontal_flip: bool = False,
color_distortion: bool = False,
gaussian_blur: bool = False,
@@ -31,7 +31,7 @@ def ijepa_transforms(
Size of the image crop.
crop_scale : tuple, default=(0.3, 1.0)
Range for the random resized crop scaling.
color_jitter : float, default=0.0
color_jitter_strength : float, default=0.0
Strength of color jitter.
horizontal_flip : bool, default=False
Whether to apply random horizontal flip.
@@ -89,7 +89,7 @@ def __call__(self, img):
if horizontal_flip:
transforms_list.append(transforms.RandomHorizontalFlip())
if color_distortion:
transforms_list.append(get_color_distortion(s=color_jitter))
transforms_list.append(get_color_distortion(s=color_jitter_strength))
if gaussian_blur:
transforms_list.append(GaussianBlur(p=0.5))
else:
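As a quick cross-check of the rename, the training-transform block added to the experiment config below maps onto the updated signature roughly as follows. This is a sketch: it assumes the project root is importable and that `ijepa_transforms` takes no other required arguments beyond those shown in this hunk.

```python
from projects.ijepa.configs import ijepa_transforms

# Mirrors the new `datasets.train.transform` values in the experiment config.
# `color_jitter_strength` replaces the old `color_jitter` argument and is
# forwarded to `get_color_distortion(s=...)` when color_distortion is enabled.
train_transform = ijepa_transforms(
    crop_size=224,
    crop_scale=(0.3, 1.0),
    color_jitter_strength=0.4,
    horizontal_flip=True,
    color_distortion=True,
    gaussian_blur=False,
)
```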
@@ -5,13 +5,12 @@ defaults:
- /datasets/transforms@datasets.train.transform: ijepa_transforms
- /datasets@datasets.val: ImageNet
- /datasets/transforms@datasets.val.transform: ijepa_transforms
- /modules/encoders@task.encoder: vit_base
- /modules/encoders@task.encoder: vit_small
- /modules/encoders@task.predictor: vit_predictor
- /modules/optimizers@task.optimizer: AdamW
- /modules/lr_schedulers@task.lr_scheduler.scheduler: CosineAnnealingLR
- /modules/lr_schedulers@task.lr_scheduler.scheduler: linear_warmup_cosine_annealing_lr
- /trainer/callbacks@trainer.callbacks.lr_monitor: LearningRateMonitor
- /trainer/callbacks@trainer.callbacks.model_checkpoint: ModelCheckpoint
- /trainer/callbacks@trainer.callbacks.early_stopping: EarlyStopping
- /trainer/callbacks@trainer.callbacks.model_summary: ModelSummary
- /trainer/logger@trainer.logger.wandb: WandbLogger
- override /task: IJEPA
@@ -20,6 +19,16 @@
seed: 0

datasets:
train:
transform:
color_jitter_strength: 0.4
horizontal_flip: true
color_distortion: true
gaussian_blur: false
crop_scale:
- 0.3
- 1.0
crop_size: 224
val:
split: val
transform:
@@ -28,45 +37,50 @@ datasets:
dataloader:
train:
batch_size: 256
num_workers: 10
num_workers: 8
pin_memory: true
drop_last: true
val:
batch_size: 256
num_workers: 10
num_workers: 8
pin_memory: false

task:
ema_decay: 0.996
ema_decay_end: 1.0
ema_anneal_end_step: ${task.lr_scheduler.scheduler.max_steps}
predictor:
kwargs:
embed_dim: 384
predictor_embed_dim: 384
depth: 6
num_heads: 6
optimizer:
betas:
- 0.9
- 0.999
lr: 1.0e-3
weight_decay: 0.05
eps: 1.0e-8
lr_scheduler:
scheduler:
T_max: ${trainer.max_epochs}
warmup_steps: 12_510
max_steps: 125_100
start_factor: 0.2
eta_min: 1.0e-6
extras:
interval: epoch
interval: step

trainer:
max_epochs: 300
precision: 16-mixed
max_epochs: 100
precision: bf16-mixed
deterministic: False
benchmark: True
sync_batchnorm: False # Set to True if using DDP with batchnorm
log_every_n_steps: 100
accumulate_grad_batches: 4
log_every_n_steps: 10
accumulate_grad_batches: 1
check_val_every_n_epoch: 1
callbacks:
model_checkpoint:
monitor: val/loss
save_top_k: 1
save_last: True
every_n_epochs: 1
dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on Vector SLURM environment
early_stopping:
monitor: val/loss
patience: 5
mode: min
every_n_epochs: 10
dirpath: /checkpoint/${oc.env:USER}/${oc.env:SLURM_JOB_ID} # only works on VI's SLURM environment
model_summary:
max_depth: 2

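For readers tuning the new schedule, the sketch below traces the per-step learning rate these settings imply, assuming `linear_warmup_cosine_annealing_lr` follows the standard linear-warmup-then-cosine-annealing formulation (the exact mmlearn implementation may differ). Because `interval: step` is set, the scheduler advances every optimizer step, and `task.ema_anneal_end_step` is tied to the same `max_steps` budget.

```python
import math

# Assumed formulation: linear ramp from start_factor * base_lr to base_lr over
# warmup_steps, then cosine decay from base_lr to eta_min by max_steps.
def lr_at_step(
    step: int,
    base_lr: float = 1.0e-3,
    warmup_steps: int = 12_510,
    max_steps: int = 125_100,
    start_factor: float = 0.2,
    eta_min: float = 1.0e-6,
) -> float:
    if step < warmup_steps:
        frac = step / max(1, warmup_steps)
        return base_lr * (start_factor + (1.0 - start_factor) * frac)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * progress))
```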
2 changes: 1 addition & 1 deletion tests/modules/test_ema.py
@@ -20,7 +20,7 @@ def test_ema() -> None:
ema_end_decay=0.9999,
ema_anneal_end_step=300000,
)
ema.model = ema.model.cpu() # for testing purposes
ema.configure_model(device_id=torch.device("cpu"))

# test output between model and ema model
model_input = torch.rand(1, 3, 224, 224)