# Metalearned Neural Memory

<img src="img/MNM.png" width="400" />

In [2]:
import numpy as np
import torch
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# define model controller
class Controller(torch.nn.Module):
    """LSTM controller of the Metalearned Neural Memory model.

    At each step it consumes the next input together with the value read from
    the memory function at the previous step, and produces the read/write
    keys, the write value and the update weight that drive the memory
    function, plus its own output.
    """

    def __init__(self, in_size, hidden_size, num_layers):
        """
        input:
            in_size - size of the input fed to the LSTM
            hidden_size - size of the LSTM hidden state
            num_layers - number of stacked LSTM layers
        """
        # nn.Module.__init__ must run before any sub-module is assigned,
        # otherwise the assignment below raises AttributeError.
        super().__init__()

        # torch.nn.LSTMCell does not accept num_layers (passing it raises
        # TypeError); torch.nn.LSTM supports stacked layers natively.
        self.lstm = torch.nn.LSTM(
            input_size=in_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )

    def forward(self, x, v_r):
        """
        input:
            x - next input
            v_r - value read from the memory function at the previous step

        output:
            beta - weight to decide size of memory update
            v_w - value that the memory function learns
            k_r - key used to read from the memory function
            k_w - key used to write to the memory function
            y - controller output

        The output heads are not defined yet, so this step is a stub.
        """
        raise NotImplementedError(
            "Controller.forward: output heads (beta, v_w, k_r, k_w, y) "
            "are not implemented yet"
        )

In [None]:
# define memory function

class MemoryFunction(torch.nn.Module):
    """Memory of the MNM model: queried with keys, updated towards write values."""

    def __init__(self):
        # nn.Module.__init__ must run before attributes/sub-modules are set.
        super().__init__()

        # Learning rate intended for the memory's optimizer. The optimizer is
        # not built here: no layers are registered yet, so self.parameters()
        # is empty and torch.optim.Adam would raise
        # ValueError("optimizer got an empty parameter list").
        # TODO: define the key -> value network, then build
        # torch.optim.Adam(self.parameters(), lr=self.lr).
        self.lr = 1e-3
        self.optimizer = None
        self.loss = torch.nn.MSELoss()

    def update(self, beta, pred_v_w, v_w):
        """Update the memory function towards the controller's write value.

        input:
            beta - weight deciding the size of the memory update
            pred_v_w - the memory function's prediction of the write value
                       (first output of forward)
            v_w - target write value produced by the controller
        """
        raise NotImplementedError(
            "MemoryFunction.update is not implemented yet"
        )

    def forward(self, k_r, k_w):
        """
        input:
            k_r - read key given to read from memory function
            k_w - write key given to write to the memory function

        output:
            v_w - value produced for the write key, used to update the
                  network via back prop
            v_r - value read with the read key, passed on as output
        """
        raise NotImplementedError(
            "MemoryFunction.forward is not implemented yet"
        )
    

In [None]:
# define metalearned neural memory module
class MNM(torch.nn.Module):
    """Metalearned Neural Memory: an LSTM controller driving a memory function.

    Each call to forward runs one step: the controller reads the next input and
    the previously read value, the memory function is queried with the
    controller's read/write keys, and the memory function is then updated
    towards the controller's write value.
    """

    def __init__(self, in_size, hidden_size, num_layers=1):
        """
        input:
            in_size - size of each input passed to forward
            hidden_size - size of the controller's LSTM hidden state
            num_layers - number of stacked LSTM layers in the controller
        """
        # nn.Module.__init__ must run before sub-modules are assigned.
        super().__init__()

        # Controller requires its sizes; it cannot be built without them.
        self.controller = Controller(in_size, hidden_size, num_layers)
        self.memory_function = MemoryFunction()

        # controller hidden state; torch.randn needs an explicit shape.
        # (num_layers, hidden_size) follows torch.nn.LSTM's unbatched h_0
        # layout — TODO confirm once the controller consumes this state.
        self.h_t = torch.randn(num_layers, hidden_size)

        # previous read value, fed back to the controller on the next step;
        # None until the first step has run (its size is the memory value
        # size, which is not defined yet).
        self.v_t = None

    def forward(self, x):
        """Run one MNM step.

        input:
            x - next input

        output:
            v_r - value read from the memory function at this step
        """
        # forward pass on controller, using the value read on the previous step
        beta, v_w, k_r, k_w, y = self.controller(x, self.v_t)

        # forward pass on memory function
        pred_v_w, v_r = self.memory_function(k_r, k_w)

        # update memory function
        self.memory_function.update(beta, pred_v_w, v_w)

        # remember the read value so the controller receives it next step
        self.v_t = v_r

        return v_r

    def optimize(self):
        # TODO: placeholder, not implemented yet
        pass

In [None]:
# define synthetic dictionary inference task
class SDIT(torch.utils.data.Dataset):

    """Synthetic Dictionary Inference Task"""

    def __init__(self, num_sequences=1000):
        # A list of the individual letters, so the alphabet can be split into
        # groups; ["abc...z"] would be a single 26-character element.
        self.alphabet = list("abcdefghijklmnopqrstuvwxyz")
        self.generate_data(num_sequences)

    def generate_data(self, num_sequences):
        """Build the alphabet partitions, the mapping between them and
        num_sequences sample sequences."""
        self.partitioned_alphabet = self.partition_alphabet()
        self.mappings = self.generate_mappings()
        self.sample_sequences = self.generate_sample_sequences(num_sequences)

    def partition_alphabet(self):
        # split the alphabet into two groups
        # TODO: not implemented yet; returns None
        pass

    def generate_mappings(self):
        # generate a one to one mapping of alphabet partitions
        # TODO: not implemented yet; returns None
        pass

    def generate_sample_sequences(self, num_sequences):
        # pick random pairs of letters from the alphabet
        # with a particular sequence length
        # and generate their corresponding mappings
        # TODO: not implemented yet; returns None, so __len__ and
        # __getitem__ fail until this returns a sequence
        pass

    def __len__(self):
        return len(self.sample_sequences)

    def __getitem__(self, idx):
        # convert tensor indices (e.g. from a sampler) to plain Python values
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.sample_sequences[idx]