# Simple repeat networks

In [109]:
import torch
import parameters_will
import torch.optim as optim
import RNN_Will as _model_
import numpy as np
import matplotlib.pyplot as plt
import pickle
import utils
import copy

import torch.nn as nn

%load_ext autoreload
%autoreload 2

params = parameters_will.default_params()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [151]:
# First let's create the new type of data. There is a freq choice that determines the initialisation, and a desired output
freqs = np.array([2, 5])
trial_len = params.data.min_length + np.random.randint(params.data.max_length - params.data.min_length)

outputs = np.zeros([trial_len, len(freqs)])
for (freq_counter, freq) in enumerate(freqs):
    outputs[freq-1::freq, freq_counter] = 1

input_dict = parameters_will.DotDict()
input_dict.freq = torch.from_numpy(freqs)
input_dict.outputs = torch.from_numpy(outputs)

params.model.num_freqs = 2

In [152]:
utils.generate_osc_data(params)

{'freq': tensor([2, 3]),
 'outputs': tensor([[0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
          0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 

In [153]:
params.model

{'h_size': 20,
 'hidden_act': 'relu',
 'hidden_init_learn': True,
 'hidden_init_std': 1,
 'transition_init': 'orthogonal',
 'output_act': 'sigmoid',
 'i_size': 1,
 't_size': 1,
 'linear_std': 1,
 'batch_size': 5,
 'num_inits': 2,
 'num_freqs': 2,
 '__class__': parameters_will.DotDict}

In [162]:
class Oscillator(_model_.VanillaRNN):
    def __init__(self, par):
        super().__init__(par)
        self.hidden_init = nn.Parameter(torch.zeros((self.par.num_freqs, self.par.h_size), dtype=torch.float32), requires_grad=self.par.hidden_init_learn)
        
    def forward_old(self, inputs, device='cpu'):
        T = inputs.outputs.size()[1]
        hs, preds, preactivations = torch.zeros([self.par.num_freqs, T, self.par.h_size]), torch.zeros([self.par.num_freqs, T]), torch.zeros([self.par.num_freqs, T, self.par.h_size])
        preactivations[:,0,:] = self.hidden_init
        hs[:,0,:] = self.activation(preactivations[:,0,:])
        preds[:,0:1] = self.out_activation(self.predict(hs[:,0,:]))
        
        for t in range(1,T):
            preactivations[:,t,:] = self.transition(hs[:,t-1,:])
            hs[:,t,:] = self.activation(preactivations[:,t,:])
            preds[:,t:t+1] = self.out_activation(self.predict(hs[:,0,:]))
        
        variable_dict = parameters_will.DotDict(
            {'hidden': hs,
             'pred': preds,
             'preactivations':preactivations
             })
        
        return variable_dict
        
    def forward(self, inputs, device='cpu'):
            T = inputs.outputs.size()[0]
            pre = self.hidden_init
            h  = self.activation(pre)
            pred = self.out_activation(self.predict(h))
            hs, preds, pres = [], [], []


            for t in range(T):
                pre = self.transition(h)
                h = self.activation(pre)
                pred = self.out_activation(self.predict(h))

                pres.append(pre)
                hs.append(h)
                preds.append(pred)

            variable_dict = parameters_will.DotDict(
                {'hidden': hs,
                 'pred': preds,
                 'preactivations':preactivations
                 })

            return variable_dict
        
model = Oscillator(params.model)
variables = model.forward(input_dict)

In [163]:
torch.stack(variables.pred).shape

torch.Size([204, 2, 1])

In [164]:
input_dict.outputs[:,:,None].shape

torch.Size([204, 2, 1])

In [166]:
def compute_losses_osc(model_in, model_out, model, par, device='cpu'):
    loss_fit = torch.sum(torch.pow(input_dict.outputs[:,:,None] - torch.stack(variables.pred), 2))
    loss_act = torch.sum(torch.pow((torch.stack(variables.hidden)), 2))
    return (loss_fit + par.act_weight*loss_act, loss_fit)

In [131]:
torch.sum(torch.pow((variables.hidden), 2))

tensor(0., grad_fn=<SumBackward0>)

In [123]:
input_dict.outputs.shape

torch.Size([2, 266])

In [125]:
variables.pred.shape

torch.Size([2, 266])

In [167]:
compute_losses_osc(input_dict, variables, model, params)

AttributeError: 'DotDict' object has no attribute 'act_weight'

# Now debugging stage

In [169]:
torch.autograd.set_detect_anomaly(True)

print('Booting up parameters')
# Set up our parameters
params = parameters_will.default_params()

print_iters = 100
save_iters = 10000

# make instance of model
print('Making model')
if params.data.oscillators:
    model = _model_.Oscillator(params.model)
else:
    model = _model_.VanillaRNN(params.model)
# put model to gpu (if available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Make an ADAM optimizer
optimizer = optim.Adam(model.parameters(), lr=params.train.learning_rate, weight_decay=params.train.weight_decay)
min_loss = np.infty

if params.data.oscillators:
    generator = utils.generate_osc_data
else:
    generator = utils.generate_data
loss_func = _model_.compute_losses_torch

print('Starting Training')



Booting up parameters
Making model
Starting Training


In [178]:
model.transition.state_dict()

OrderedDict([('weight',
              tensor([[ 3.1030e-01,  2.4609e-01, -3.4927e-01, -4.1020e-02,  1.2722e-01,
                        2.0545e-01, -2.2556e-01, -2.8088e-01,  1.3100e-01, -1.4480e-01,
                       -9.6985e-02, -2.0294e-01,  1.4690e-01, -5.2792e-02,  4.1907e-01,
                        2.8068e-01, -2.7462e-01, -1.0354e-01,  1.6196e-01, -2.2684e-01],
                      [ 1.3890e-01, -4.5774e-01, -2.6264e-01,  1.3764e-02, -2.5479e-01,
                       -5.7501e-02,  5.3634e-01,  1.1849e-01,  6.4628e-02,  5.2944e-02,
                       -5.5103e-04, -1.9896e-01,  2.5179e-01, -5.9136e-02,  3.1987e-01,
                        1.0078e-01,  2.0040e-01,  1.1037e-01, -1.8953e-01, -1.3378e-01],
                      [ 5.5905e-01, -1.4606e-02, -4.6917e-03,  1.6568e-01,  2.0129e-01,
                        3.5144e-01, -9.1793e-02,  1.6667e-01,  1.4420e-01,  2.3949e-01,
                        3.2777e-01, -9.5097e-02,  1.8704e-01, -2.4320e-01, -3.3018e-01,
      