In [1]:
from lib.DataGenerator import DataGenerator


In [2]:
memory_size = 1024
batch_size = 128
device = 'cuda'
history_size = 1000

generator = DataGenerator(
    '../../datasets/bas_clean.txt',
    batch_size=batch_size,
    history_size=history_size,
    device=device
)

In [3]:
x, _ = generator.get_batch()
print(x.shape) # (batch_size, history_size, 256)

torch.Size([128, 1000, 256])


In [ ]:
class Net(torch.nn.Module):
    def __init__(self,batch_size,input_size,memory_size,device):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.memory_size = memory_size
        self.device = device
        hsize = input_size + memory_size
        self.swish = torch.nn.SiLU()
        self.l1_f = torch.nn.Linear(input_size+memory_size,hsize).to(device)
        self.l2_f = torch.nn.Linear(hsize,memory_size).to(device).to(device)
        self.g = 
        self.memory = torch.zeros((batch_size,memory_size),device=device)

    def reset(self):
        self.memory = torch.zeros((self.batch_size,self.memory_size),device=self.device)

    def forward(self,x):
        self.memory = self.memory.detach() #detach the memory to avoid backpropagating through the memory in time
        x = torch.cat((x,self.memory),1)
        self.memory = torch.tanh(self.f(x))
        return self.memory

    def reconstruct(self, noise=0):
        x = self.g(self.memory + noise * torch.randn_like(self.memory))
        pred_memory = x[:,:self.memory_size]
        pred_memory = torch.tanh(pred_memory)
        pred_input = x[:,self.memory_size:]
        return pred_memory, pred_input


In [10]:

net = Net(
    batch_size=batch_size,
    input_size=256,
    memory_size=memory_size,
    device=device
)
import torch
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

cross_entropy = torch.nn.CrossEntropyLoss()

for it in range(1000):
    x, _ = generator.get_batch()
    for i in range(history_size): #min(it,history_size)): #maybe curriculum learning by increasing this number slowly over time because of the issue that m0 is not trained to be resilient against noise and can only be indirectly trained by the memory
        optimizer.zero_grad()
        current_x = x[:, i, :] # (batch_size, 256) #target value for pred_input
        old_mem = net.memory.clone().detach()
        mem = net(current_x) #this is the current memory that has been produced by the network
        pred_memory, pred_input = net.reconstruct(noise = 0.01) #pred_memory is in -1, 1 and pred_input is logits
        loss_memory = torch.mean((old_mem - pred_memory) ** 2) #for the memory we use mse loss
        loss_input = cross_entropy(pred_input, current_x.argmax(dim=1)) #for the input we use cross entropy loss
        loss = loss_memory + loss_input
        loss.backward()
        optimizer.step()
        print(f'it: {it}, i: {i}, loss: {loss.item()}, loss_memory: {loss_memory.item()}, loss_input: {loss_input.item()}')
        
    

layer 0, input size 1280, output size 1406
layer 1, input size 1406, output size 1544
layer 2, input size 1544, output size 1697
layer 3, input size 1697, output size 1864
layer 0, input size 1024, output size 1024
layer 1, input size 1024, output size 1024
layer 2, input size 1024, output size 1024
layer 3, input size 1024, output size 1024
layer 0, input size 1024, output size 776
layer 1, input size 776, output size 588
layer 2, input size 588, output size 445
layer 3, input size 445, output size 337
it: 0, i: 0, loss: 5.541074752807617, loss_memory: 0.00034740008413791656, loss_input: 5.540727138519287
it: 0, i: 1, loss: 5.530224323272705, loss_memory: 0.00016194673662539572, loss_input: 5.530062198638916
it: 0, i: 2, loss: 5.5073394775390625, loss_memory: 0.00015232546138577163, loss_input: 5.507187366485596
it: 0, i: 3, loss: 5.467586040496826, loss_memory: 0.0003743020643014461, loss_input: 5.467211723327637
it: 0, i: 4, loss: 4.6797051429748535, loss_memory: 0.00163646100554615

KeyboardInterrupt: 

In [ ]:
#Problems we are running into right now:
#-sometimes, the input reconstruction is not learned in favor of learning to reconstruct the memory. maybe we need to add a minimal change loss to the memory.
#-the memory does not seem to be able to recover from noise. purely theoretically, the memory is only indirectly trained to be resilient against noise:
#    -imagine the memory has a complex nonzero state
#idea for a new trial for this kind of network:
#input size is a number x
#memory size is a number y=x*n
#in theory, the memory should be able to store n inputs. we can test that by showing the network n inputs and then see if it can reconstruct them.
#this should give us a good idea of how well the memory is able to store information.
#it is also a much simpler task than natural language processing and should be easier to debug.

#another problem with the memory is that if it is left unbounded, it will grow indefinitely. this is why we added tanh as the activation function for the memory.

#another idea is to add a timing thing to the memory, so that the memory can evolve with timing instead of with reading the input, which should be more easy to reverse. ah. this is a good idea. we can add a timing network that takes the old memory state and
#advances the time for the memory. then another network integrates the input to the memory, with a gating mechanism that is controlled by the integration network. 
#then this new memory state needs to reconstruct the old memory state.
#this way, it is observable whether the memory also learns to perform sequence prediction implicitly.
#it could also maybe nudge the network more towards doing sequence prediciton, as the temporal evolution of the memory is easier to reconstruct than the input, shifting the memory more towards finding patterns in the input.

#to summarize, these are the new ideas from all the problems we have encountered:
#-add a minimal change loss to the memory
#-add a timing network that advances the memory state:
#    -the timing network should take the old memory state and advance the time
#    -the integration network should integrate the input to the memory, with a gating mechanism that is controlled by the integration network
#    -the new memory state should reconstruct the old memory state
#-add a loss that nudges the network more towards doing sequence prediction, as the temporal evolution of the memory is easier to reconstruct than the input, shifting the memory more towards finding patterns in the input

#Question: is the memory currently trying to find patterns in the input? like, it needs to compress the previous memory state and the current input into a vector of the same size as the memory state. does it implictly learn to find patterns in the input this way? if it does find patterns in the input sequence, it doesn't need to integrate new information but rather tell the decoding network where it is, so that the decoding network can deconstruct the memory state into the previous memory state and the current input. 

#how do we make this requirement more explicit? the memory state should be a combination of "where we are" for implict knowledge in the decoder, as well as explicit knowledge, like variables. 
#the implict knowledge might be represented by a multi-state automaton, and the explicit knowledge might be represented by a set of variables whose types and number are a function of the current state of the multi-state automaton.
#"where we are" is both influenced by the time step and by the input. the variables are timeless and might just change in interpretation based on the "where we are" part.
#since the number of variables is not fixed, but the memory size is fixed, we will inevitably run into the problem that the memory size is a problem. my hope is that up to a certain degree, this will allow for pattern learning, as the network will have to compress its data into a fixed size.