In [14]:
from __future__ import print_function
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from torchviz import make_dot
import pdb
import numpy as np

class G(nn.Module):
    def __init__(self, input_shape, output_shape, hidden_size=10):
        super(G, self).__init__()
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.g1 = nn.Linear(np.prod(input_shape), hidden_size, bias=True)
        self.g2 = nn.Linear(hidden_size, np.prod(output_shape), bias=True)
        self.act = nn.Tanh()

    def forward(self, x):
        assert x.shape == self.input_shape
        x = x.flatten()
        x = F.relu(self.g1(x))
        x = self.g2(x)
        x = self.act(x)
        return x
    
class Unit(nn.Module):
    def __init__(self, name, layer_index, mu_shape, mu_next_shape, device):
        super(Unit, self).__init__()
        self.device = device
        self.name = name
        self.mu_shape = mu_shape
        self.mu_next_shape = mu_next_shape

        self.g = G(input_shape=mu_next_shape, output_shape=mu_shape)
        self.f = G(input_shape=mu_shape, output_shape=mu_shape)
        self.loss_g = 0
        self.loss_f = 0

        lr = 0.2
        self.optimizer_g = optim.SGD(self.g.parameters(), lr=lr, momentum=0)
        self.optimizer_f = optim.SGD(self.f.parameters(), lr=lr, momentum=0)

        # this unit's state
        self.previous_mu = torch.zeros(mu_shape).to(device)
        self.mu = torch.zeros(mu_shape).to(device)
        self.mu_bar = torch.zeros(mu_shape).to(device)
        self.mu_hat = torch.zeros(mu_shape).to(device)

        # next unit's mu
        self.mu_next = torch.zeros(mu_next_shape).to(device)

        # buffer for collecting mu
        self.mu_buffer = SlidingWindowBuffer(mu_shape[0])

#     def pool(self, buffer):
#         x = np.reshape(np.array([b.detach().numpy() for b in buffer]), (1, 1, self.t_sample * self.temporal_pooling_size))
#         x = torch.nn.functional.avg_pool1d(torch.tensor(x), kernel_size=self.temporal_pooling_size)
#         return x.reshape(self.t_sample)

    def add_mu_item(self, mu_item):
        x = self.mu_buffer.append_item(mu_item)
        if x is not None:
            self.mu = torch.tensor(mu_buffer.buffer)
        
    def set_mu_next(self, mu_next):
        self.mu_next = mu_next

    def before_step(self):
        pass

    def compute_predictions(self):
        self.mu_bar = self.f(self.previous_mu)
        self.mu_hat = self.g(self.mu_next)

    def train(self):
        self.g.train()
        self.optimizer_g.zero_grad()
        self.mu_hat = self.g(self.mu_next)
        self.loss_g = F.mse_loss(self.mu_hat, self.mu - self.mu_bar)
        self.loss_g.backward()
        self.optimizer_g.step()

        self.f.train()
        self.optimizer_f.zero_grad()
        self.mu_bar = self.f(self.previous_mu)
        self.loss_f = F.mse_loss(self.mu_bar, self.mu)
        self.loss_f.backward()
        self.optimizer_f.step()

        self.previous_mu = self.mu.detach()

    def history(self):
        return [self.loss_g, self.loss_f, self.mu_bar[-1], self.mu_hat[-1], self.mu[-1], self.mu_bar[-1] + self.mu_hat[-1]]


class UnitStack(nn.Module):
    def __init__(self, units):
        super(UnitStack, self).__init__()
        self.units = units
        self.unit_count = len(units)

    def step(self, mu_item, mu_awareness, train=True):
        # before step initialization
        [self.units[layer].before_step() for layer in range(self.unit_count)]
        
        # forward error propagation
        self.units[0].add_mu_item(mu_item)
        self.units[-1].add_mu_item(mu_awareness)

        for layer in range(1, self.unit_count):
            # compute mu using previous layer's predictions
            # mu is part of the signal the previous layer could not predict
            self.units[layer].add_mu_item((self.units[layer - 1].mu - (self.units[layer - 1].mu_hat + self.units[layer - 1].mu_bar)).detach()[-1])
                                                               
            self.units[layer - 1].set_mu_next(self.units[layer].mu)

        # backward flow of predictions
        [self.units[layer].compute_predictions() for layer in range(self.unit_count - 1, -1, -1)]

        # train
        if train:
            for layer in range(self.unit_count):
                self.units[layer].train()
        
        # return stats
        return [self.units[layer].history() for layer in range(self.unit_count)]
        
class SlidingWindowBuffer(object):
    def __init__(self, item_count):
        self.item_count = item_count
        self.buffer = []
        
    def append_item(self, item):
        # return None while gathering initial items
        if len(self.buffer) < self.item_count - 1:
            self.buffer.append(item)
            return None
        
        # once enough items, convert to np.array
        elif len(self.buffer) == self.item_count - 1:
            self.buffer.append(item)
            self.buffer = np.array(self.buffer)
            
        else:
            self.buffer = np.roll(self.buffer, -1, axis=0)
            self.buffer[-1] = item

        return self.buffer

class SampleDataPointsGenerator(object):
    def __init__(self, shape=(1,)):
        self.index = 0
        self.count = np.empty(shape).size
        
    def __next__(self):
        self.index += 1
        if self.count == 1:
            return np.sin(self.index/10.0 + np.random.random_sample() * 0.0) * np.cos(self.index/25.0) + np.random.random_sample() * 0.0
        elif self.count == 2:
            return [
                np.cos(self.index/10.0 + np.random.random_sample() * 0.2) * np.sin(self.index/5.0),
                np.sin(self.index/10.0 + np.random.random_sample() * 0.2) * np.cos(self.index/20.0)
            ]

def plot_history(loss_history, title=None):
    loss_history = np.array(loss_history)
    fig = plt.figure(figsize=(15,12))
    fig.suptitle(title, fontsize=16)
    plt.plot(loss_history[:, 0],"--",label='loss_g')
    plt.plot(loss_history[:, 1],"--",label='loss_f')
    plt.plot(loss_history[:, 2],"-",label='mu_bar',linewidth=1,alpha = 0.3)
    plt.plot(loss_history[:, 3],"-",label='mu_hat',linewidth=1,alpha = 0.3)
    plt.plot(loss_history[:, 4],"-",label='mu', linewidth=2)
    plt.plot(loss_history[:, 5],"-",label='mu_pred',linewidth=2, alpha = 0.5)

    plt.legend()
    plt.show()


In [15]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
torch.set_num_threads(12)

t_sample = 10
mu_shape = (t_sample, 1)
unit_count = 10
units = [Unit(name="unit{}".format(i), layer_index=i, mu_shape=mu_shape, mu_next_shape=mu_shape, device=device) for i in range(unit_count)]

network = UnitStack(units)

loss_history = []
mu_awareness = np.ones((t_sample,))

mu_awareness = torch.tensor(mu_awareness, requires_grad=True).float().to(device)
data_generator = SampleDataPointsGenerator()

In [16]:
# TRAIN

for i in range(5000):
    loss = network.step(mu_item=next(data_generator), mu_awareness=mu_awareness, train=True)
    # print(loss)

    loss_history.append(loss)
        
    if (i+1) % 500 == 0:
        [plot_history(np.array(loss_history)[:, i, :], title='unit {}'.format(i)) for i in range(unit_count)]

        loss_history = []

print("==================")

RuntimeError: input and target shapes do not match: input [10], target [10 x 1] at /Users/soumith/miniconda2/conda-bld/pytorch_1532623076075/work/aten/src/THNN/generic/MSECriterion.c:12

In [1]:
import numpy as np
s = torch.ones(10,2,3)
s.view(10, 6)

NameError: name 'torch' is not defined

In [1]:
%connect_info

{
  "shell_port": 53880,
  "iopub_port": 53881,
  "stdin_port": 53882,
  "control_port": 53883,
  "hb_port": 53884,
  "ip": "127.0.0.1",
  "key": "8ec79334-75b11a739f32f5ee52a950e2",
  "transport": "tcp",
  "signature_scheme": "hmac-sha256",
  "kernel_name": ""
}

Paste the above JSON into a file, and connect with:
    $> jupyter <app> --existing <file>
or, if you are local, you can connect with just:
    $> jupyter <app> --existing kernel-e5d06227-3a46-4532-af3b-cc0b7fd5cc2d.json
or even just:
    $> jupyter <app> --existing
if this is the most recent Jupyter kernel you have started.


In [7]:
exit