In [8]:
import torch

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim


# Seed PyTorch's global RNG so the random initialisations below are repeatable.
# The trailing semicolon keeps the notebook from echoing the returned Generator.
torch.manual_seed(seed=123);

In [6]:
class Config:
    """Container for the settings handed to the modules in this notebook."""

    def __init__(self, emb_dim):
        """Store the configuration values.

        Args:
            emb_dim: value kept as ``self.emb_dim``; the modules below use it
                as both the input and output width of their linear layers.
        """
        self.emb_dim = emb_dim

# Designing the long-term memory module of the Titan architecture

In [183]:
class LongTermMemory(nn.Module):

    def __init__(self, config: Config):
        super().__init__()

        # these gonna update during test time
        self.Wk = nn.Linear(config.emb_dim, config.emb_dim)
        self.Wv = nn.Linear(config.emb_dim, config.emb_dim)

        # this gonna update during training the hole architecture
        self.Wq = nn.Linear(config.emb_dim, config.emb_dim)

        # this is the memory part of our model to save input information in it's parameters
        self.Mt = nn.Linear(config.emb_dim, config.emb_dim)
        """
        in begining we just set this to zero. the reason is we don't have any information about 
        it in initial step and just want to make it to be equal to loss value
        """
        self.St = 0        

        # adaptive learnable coefficients
        self.alpha = nn.Parameter(torch.rand(1)[0], requires_grad=True)
        self.etha = nn.Parameter(torch.rand(1)[0], requires_grad=True)
        self.tetha = nn.Parameter(torch.tensor([0.0])[0], requires_grad=True)

        # optimizer and loss function that is going to use for updating memory in test time
        self.loss_func = nn.MSELoss()
        self.optimizer = optim.AdamW(self.parameters())
        
    
    def memorize(self, X):
        """
        This function is used for saving input data into ``memory`` parameters
        for this to be end first we define previous ``memory`` and previous ``surprise``
        after that for each token we update the ``memory`` parameters. 
        """
        with torch.autocast(device_type="cpu", dtype=torch.float16):
            for Xt in X:
                # calculate the Key and Value matrices
                Kt = self.Wk(Xt)
                Vt = self.Wv(Xt)
                # Use MeanSquaredError(MSE) for calculating the loss function
                loss = self.loss_func(self.Mt(Kt), Vt)
                # Calculate the surprising value
                St = self.St * self.etha + (self.tetha * loss)
                # Updating the memory weights based on loss and surprise value 
                self.Mt.weight.data = ((1 - self.alpha) * self.Mt.weight.data) + St
                # optimize the model based on loss function to make Key and Value close to each other
                loss.backward()
                self.optimizer.step()
        
    def forward(self, Qt):
        Qt = self.Wq(Qt)
        return self.Mt(Qt)
        

class Core(nn.Module):
    """Placeholder module; its forward pass is not implemented yet."""

    def __init__(self, config: Config):
        """Create the (currently empty) layer stack; ``config`` is not read yet."""
        super().__init__()
        # No layers have been added to the container so far.
        self.net = nn.Sequential()

    def forward(self, x):
        """Stub: ignores ``x`` and returns ``None``."""
        # TODO: implement the forward pass.
        return None


class PersistantMemory(nn.Module):
    """Placeholder module; it holds no parameters and has no forward logic yet."""

    def __init__(self, config: Config):
        """Initialise the base ``nn.Module``; ``config`` is not read yet."""
        super().__init__()

    def forward(self, x):
        """Stub: ignores ``x`` and returns ``None``."""
        # TODO: implement the forward pass.
        return None


class Titan(nn.Module):
    """Placeholder for the top-level model; nothing is wired together yet."""

    def __init__(self, config: Config):
        """Initialise the base ``nn.Module``; ``config`` is not read yet."""
        super().__init__()
        # No memory module is attached at this point.
        self.memory = None

    def forward(self, x):
        """Stub: ignores ``x`` and returns ``None``."""
        # TODO: implement the forward pass.
        return None

In [192]:
# Build a memory module that works on 128-dimensional embeddings.
config = Config(emb_dim=128)
long_term_memory = LongTermMemory(config)

# Toy input: 5 random rows, each as wide as the configured embedding size.
example_input = torch.randn(5, config.emb_dim)


In [193]:
# Write every row of the example input into the memory (modifies the module in place).
long_term_memory.memorize(X=example_input)

In [194]:
# Query the memory with the same rows and show the shape of what comes back.
long_term_memory(example_input).shape

torch.Size([5, 128])