In [None]:
# | default_exp utils/ema

# Imports

In [None]:
# | export

from numbers import Real

from torch import nn
from torch.nn.modules.module import _addindent

In [None]:
from copy import deepcopy

from vision_architectures.schedulers.sigmoid import SigmoidScheduler

# Implement the EMA wrapper

In [None]:
# | export


class EMA(nn.Module):
    def __init__(self, model: nn.Module, decay: float | object):
        super().__init__()

        self.model = model
        self.decay = decay

    def get_decay(self) -> float:
        if isinstance(self.decay, float):
            return self.decay
        return self.decay.get()

    def update_decay(self):
        if not isinstance(self.decay, float):
            self.decay.step()

    def forward(self, weights: nn.Module | dict) -> nn.Module:
        # Ensure weights is a state_dict
        if isinstance(weights, nn.Module):
            weights = weights.state_dict()

        # Sanity check
        if not set(weights.keys()).issubset(set(self.model.state_dict().keys())):
            raise ValueError("Weights do not match the EMA model's weights")

        # Get decay value
        decay = self.get_decay()

        # Perform EMA
        for name, param in self.model.named_parameters():
            if name not in weights:
                continue
            param.data = decay * param.data + (1 - decay) * weights[name].data

        # Update decay if it's a scheduler
        self.update_decay()

        return self.model

    def __repr__(self):
        return (
            "EMA(\n" f"  decay={_addindent(repr(self.decay), 2)}\n" f"  model={_addindent(repr(self.model), 2)}\n" ")"
        )

In [None]:
# Constant-decay sanity check: the EMA model starts at 1 and the source at 0, so after k updates
# every EMA parameter should equal decay**k.
net1 = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
net2 = deepcopy(net1)
for param in net1.parameters():
    param.data.fill_(1.0)
for param in net2.parameters():
    param.data.zero_()

decay = 0.9
num_updates = 10

ema = EMA(net1, decay)
display(ema)
print(net1.state_dict()["0.weight"][0][0])

for _ in range(num_updates):
    ema(net2)
    print(net1.state_dict()["0.weight"][0][0])

assert abs(net1.state_dict()["0.weight"][0][0].item() - decay**num_updates) < 1e-5


[1;35mEMA[0m[1m([0m
  [33mdecay[0m=[1;36m0[0m[1;36m.9[0m
  [33mmodel[0m=[1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_features[0m=[1;36m20[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m20[0m, [33mout_features[0m=[1;36m10[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
[1m)[0m

tensor(1.)
tensor(0.9000)
tensor(0.8100)
tensor(0.7290)
tensor(0.6561)
tensor(0.5905)
tensor(0.5314)
tensor(0.4783)
tensor(0.4305)
tensor(0.3874)
tensor(0.3487)


In [None]:
# Scheduled-decay sanity check. The decay applied by each update is read *before* calling the EMA,
# because the call advances the scheduler: reading get_decay() afterwards would show the next step's decay.
net1 = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10))
net2 = deepcopy(net1)
for param in net1.parameters():
    param.data.fill_(1.0)
for param in net2.parameters():
    param.data.zero_()

num_steps = 10
decay = SigmoidScheduler(min_y=0.9)
decay.set_num_steps(num_steps)

ema = EMA(net1, decay)
display(ema)
print(net1.state_dict()["0.weight"][0][0])

# Starting from 1 against a zero source, the EMA value is the product of all decays applied so far.
expected = 1.0
for _ in range(num_steps):
    used_decay = ema.get_decay()
    ema(net2)
    expected *= used_decay
    print(net1.state_dict()["0.weight"][0][0], used_decay)

assert abs(net1.state_dict()["0.weight"][0][0].item() - expected) < 1e-5


[1;35mEMA[0m[1m([0m
  [33mdecay[0m=[1;35mSigmoidScheduler[0m[1m([0m
    [33mmin_y[0m=[1;36m0[0m[1;36m.9[0m
    [33mmax_y[0m=[1;36m1[0m[1;36m.0[0m
    [33mmin_x[0m=[1;36m-7[0m
    [33mmax_x[0m=[1;36m7[0m
    [33mnum_steps[0m=[1;36m10[0m
  [1m)[0m
  [33mmodel[0m=[1;35mSequential[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m10[0m, [33mout_features[0m=[1;36m20[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0m[1;36m2[0m[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m20[0m, [33mout_features[0m=[1;36m10[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
[1m)[0m

tensor(1.)
tensor(0.9001) 0.9003684239899437
tensor(0.8104) 0.9014774031693273
tensor(0.7306) 0.9057324175898869
tensor(0.6617) 0.9197816111441418
tensor(0.6086) 0.9500000000000001
tensor(0.5782) 0.9802183888558582
tensor(0.5668) 0.9942675824101131
tensor(0.5635) 0.9985225968306727
tensor(0.5627) 0.9996315760100564
tensor(0.5625) 0.9999088948805599


# nbdev

In [None]:
!nbdev_export