In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class CMSLayer(nn.Module):
    """
    A single memory level driven by an exponential moving average.

    Update rule applied on every call:
        mem_t = alpha * mem_{t-1} + (1 - alpha) * MLP(x_t)

    Args:
        dim: Input and output width of both linear layers in the MLP.
        alpha: Weight given to the previous memory; ``1 - alpha`` goes to
            the freshly computed MLP output.
    """
    def __init__(self, dim, alpha):
        super().__init__()
        # Two-layer MLP that keeps the feature width at `dim` throughout.
        project_in = nn.Linear(dim, dim)
        activation = nn.ReLU()
        project_out = nn.Linear(dim, dim)
        self.mlp = nn.Sequential(project_in, activation, project_out)
        self.alpha = alpha

    def forward(self, x, mem):
        """
        Blend the previous memory with the MLP response to ``x``.

        Args:
            x: Input fed to the MLP (the original notes give shape [B, D]).
            mem: Previous memory, or None when there is no history yet.

        Returns:
            The updated memory. With ``mem=None`` this is the raw MLP
            output, so the first step starts from the input rather than
            from an implicit zero state.
        """
        candidate = self.mlp(x)

        # No history yet: adopt the candidate as the initial memory.
        if mem is None:
            return candidate

        keep = self.alpha
        return keep * mem + (1 - keep) * candidate


In [2]:
class CMS(nn.Module):
    """
    CMS with L memory levels. Each level has its own alpha and MLP.
    Output: average of memory levels (noted by the author as the paper
    default).

    Args:
        dim: Feature width handed to every ``CMSLayer``.
        levels: Number of memory levels; must be at least 1.
        alphas: Optional sequence holding one alpha per level. When omitted,
            alphas are spaced evenly from about 0.8 down to 0.2, slowest
            level first (roughly ``[0.8, 0.5, 0.2]`` for three levels).
            A single level gets the slow end of that range.

    Raises:
        ValueError: If ``levels`` is below 1, or if ``alphas`` does not
            contain exactly ``levels`` entries.
    """

    @staticmethod
    def _default_alphas(levels):
        """Return the default alpha schedule, slowest (largest) first."""
        if levels == 1:
            # A lone level has no range to interpolate over; dividing by
            # (levels - 1) would fail, so pin it to the slow end.
            fractions = [1.0]
        else:
            fractions = [i / (levels - 1) for i in range(levels)]
        return [0.2 + 0.6 * frac for frac in reversed(fractions)]

    def __init__(self, dim, levels=3, alphas=None):
        super().__init__()
        # Explicit exceptions instead of `assert`, which is stripped when
        # Python runs with -O.
        if levels < 1:
            raise ValueError(f"levels must be >= 1, got {levels}")

        if alphas is None:
            alphas = self._default_alphas(levels)

        if len(alphas) != levels:
            raise ValueError(
                f"expected {levels} alphas, got {len(alphas)}"
            )

        self.levels = nn.ModuleList([
            CMSLayer(dim, alpha=alpha) for alpha in alphas
        ])

    def forward(self, x, memories):
        """
        x:          [B, D]
        memories:   None, or an indexable of length L whose entries are
                    None or tensors [B, D]
        Return: (output, new_memories)

        Every level receives the same ``x``; ``output`` is the mean of the
        updated memories and ``new_memories`` is the list of those updated
        memories, one per level.
        """
        new_mems = []
        for i, layer in enumerate(self.levels):
            # `memories=None` means no level has any history yet.
            prev = None if memories is None else memories[i]
            new_mems.append(layer(x, prev))

        # Average the level memories to form the output.
        final = sum(new_mems) / len(new_mems)
        return final, new_mems
