## Base Lightning model

In [2]:
#| default_exp base_lightning
#| export
import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from typing import Optional, Dict, Union

In [3]:
#| export
class BaseLightningModule(pl.LightningModule):
    
    def __init__(
        self,
        learning_rate: float = 1e-3,
    ):
        """
        Initialise the module and remember the learning rate.

        Args:
            learning_rate: value stored on ``self.learning_rate``. Defaults to
                1e-3 (previously the value sat in the annotation slot, which
                left the parameter without a default).
        """

        super().__init__()

        # Called right after super().__init__() so the constructor arguments
        # are recorded before anything else is attached to the module.
        self.save_hyperparameters()
        self.learning_rate = learning_rate

    def forward(self, x):
        """
        Placeholder forward pass for the base class.

        Args:
            x: model input; unused here.

        Raises:
            NotImplementedError: always; the base class provides no forward pass.
        """
        message = "Implement the forward function"
        raise NotImplementedError(message)

    def clip_gradients(
        self,
        optimizer: Optional[torch.optim.Optimizer] = None,
        gradient_clip_val: Optional[float] = None,
        gradient_clip_algorithm: Optional[str] = None,
    ) -> None:

        """
        Clip gradients in place. Call this **after** .backward() and **before** .step().

        Args:
            optimizer: when given, only parameters found in its ``param_groups``
                are clipped; otherwise every parameter of this module is used.
            gradient_clip_val: clipping threshold. When None, falls back to
                ``self.grad_clip_val`` if that attribute exists.
            gradient_clip_algorithm: ``"value"`` selects element-wise clipping
                (``clip_grad_value_``); any other value selects global L2-norm
                clipping (``clip_grad_norm_``). When None, falls back to
                ``self.grad_clip_algorithm`` if that attribute exists.

        Behaviour:
            - a threshold of None or <= 0 disables clipping
            - parameters without a gradient are skipped; if none remain, nothing happens
            - for norm clipping, the value returned by ``clip_grad_norm_`` is
              logged per step under ``"grad_norm"``

        NOTE(review): the signature matches Lightning's
        ``LightningModule.clip_gradients``, so this presumably overrides it and
        clips ``.grad`` directly instead of going through the trainer's
        precision plugin / strategy - confirm behaviour under AMP loss scaling
        and sharded strategies.
        """

        # Explicit arguments win; instance attributes are only a fallback so the
        # method no longer crashes when they were never set.
        clip_val = (
            gradient_clip_val
            if gradient_clip_val is not None
            else getattr(self, "grad_clip_val", None)
        )
        clip_algo = (
            gradient_clip_algorithm
            if gradient_clip_algorithm is not None
            else getattr(self, "grad_clip_algorithm", None)
        )

        if clip_val is None or clip_val <= 0:
            return

        if optimizer is None:
            candidates = self.parameters()
        else:
            candidates = (p for group in optimizer.param_groups for p in group["params"])

        # Materialise into a list: a bare generator is always truthy, which made
        # the emptiness check below ineffective on the optimizer-less path.
        parameters = [p for p in candidates if p.grad is not None]
        if not parameters:
            return

        if clip_algo == "value":
            torch.nn.utils.clip_grad_value_(parameters, clip_val)
            return

        # clip_grad_norm_ returns the total norm it computed for the gradients.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=clip_val,
            norm_type=2.0,  # L2
        )

        self.log("grad_norm", grad_norm, on_step=True, on_epoch=False, sync_dist=True)

    def log_dict_helper(
        self,
        metrics: Dict[str, Union[torch.Tensor, float, int]],
        prefix: str = "",
    ) -> None:

        """
        Log every entry of ``metrics`` through ``self.log``.

        Args:
            metrics: mapping of metric name to tensor or plain number. An empty
                mapping is a no-op.
            prefix: string prepended to every key (e.g. "train/", "val/").

        Behaviour:
            - ``self.log_on`` = "step" | "epoch" | "both" selects on_step / on_epoch.
              If the attribute does not exist, both flags are passed as None so
              that ``self.log`` falls back to its own defaults.
            - Tensors are detached and turned into Python scalars; multi-element
              tensors are averaged (integer / bool tensors are cast to float
              first, because ``mean()`` rejects them).
            - Keys containing "loss" or "acc" (case-insensitive) go to the progress bar.
            - Every entry is logged with ``sync_dist=True``.
        """

        if not metrics:
            return

        if hasattr(self, "log_on"):
            on_step = self.log_on in ("step", "both")
            on_epoch = self.log_on in ("epoch", "both")
        else:
            # No preference configured: defer to self.log's defaults instead of
            # raising AttributeError.
            on_step = on_epoch = None

        # NOTE(review): `logged` is filled but not read in the visible part of
        # this method - confirm whether it is still needed.
        logged = {}

        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):

                v = v.detach()
                if v.numel() == 1:
                    v = v.item()

                else:
                    # mean() is only defined for floating point / complex dtypes
                    if not (v.is_floating_point() or v.is_complex()):
                        v = v.float()
                    v = v.mean().item()

            key = f"{prefix}{k}" if prefix else k
            name_lower = k.lower()
            self.log(
                key,
                v,
                on_step=on_step,
                on_epoch=on_epoch,
                prog_bar=("loss" in name_lower or "acc" in name_lower),
                sync_dist=True,
            )

            logged[key] = v