In [1]:
using Flux, CuArrays
using Flux: onehot, crossentropy, throttle, argmax
using Flux.Data: Tree, children, isleaf

## Recursive Networks

This notebook is about tree-structured data, such as parse trees used to represent sentences of english.

Here's an example of a tree, containing letters of the string 'abc'.

In [2]:
('a', ('b', 'c'))

('a', ('b', 'c'))

The tree has two branches; one ends in `'a'`, and the other is another tree containing `'b'` and `'c'`. In a neural network, we would represent the characters as numeric arrays, perhaps by one-hot encoding them.

In [3]:
alphabet = ['a', 'b', 'c']
a, b, c = [onehot(char, alphabet) for char in ('a', 'b', 'c')]
tree = (a, (b, c))

(Bool[true, false, false], (Bool[false, true, false], Bool[false, false, true]))

Simple neural networks deal with data of *fixed size*, like images. But more complex models that deal with language must handle both very long and very short sentences, meaning different sizes and shapes of trees.

How can we deal with trees of any size? For simplicity, let's just look at the sub-tree `(b, c)`. One way to simplify the problem is to combine `b` and `c` into a single vector of length 3.

In [4]:
shrink = Dense(6, 3)
combine(x, y) = shrink([a; b])
combine(b, c)

Tracked 3-element Array{Float64,1}:
 -0.147115
  1.31421 
 -0.35475 

We've effectively compressed the tree `(b, c)` into a simpler object that looks a bit like a character encoding; but it stores information about a tree, rather than a single character. Let's put this back into our original tree.

In [5]:
bc = combine(b, c)
(a, bc)

(Bool[true, false, false], param([-0.147115, 1.31421, -0.35475]))

You might notice that our original tree now looks a lot simpler, as it only has two leaves -- and because of that, we can apply the same operation again!

In [6]:
combine(a, bc)

Tracked 3-element Array{Float64,1}:
 -0.147115
  1.31421 
 -0.35475 

We've reduce our tree to a simple vector. You can see how the same logic might apply to trees of any size or shape; indeed, we can create a function to do exactly that, called a *recursive neural network*.

In [7]:
net(a) = a
net(a::Tuple) = combine(net(a[1]), net(a[2]))

net(tree)

Tracked 3-element Array{Float64,1}:
 -0.147115
  1.31421 
 -0.35475 

The `net` function walks over the tree and does what we did above automatically. It works over trees of any shape.

In [8]:
net(((b, c), (a, (c, b))))

Tracked 3-element Array{Float64,1}:
 -0.147115
  1.31421 
 -0.35475 

Now we just need to map the tree description created by `net` into a useful form. For example, we might want to carry out a sentiment analysis and ask whether `'abc'` sounds positive or negative. In that case we can follow the recursive net with a simple two-class logistic regression as usual.

In [9]:
model = Chain(net, Dense(3, 2), softmax)
model(tree)

Tracked 2-element Array{Float64,1}:
 0.487091
 0.512909

## The Stanford Sentiment Treebank

In [10]:
include("data.jl");

The Stanford Sentiment Treebank stores parse trees for sentences scraped from movie reviews, like the following. Each sub-tree is assigned a sentiment score from 1 to 5 marking how positive it is. Scores can very a lot within the tree, and a model needs to show sensitivity to things like negation; for example, a tree like `("not", X)` will probably have the opposite sentiment of `X` (consider `X = ("very", "good")`).

In [11]:
traintrees[1]

Tree{Any}
(3, nothing)
├─ (2, nothing)
│  ├─ (2, "The")
│  └─ (2, "Rock")
└─ (4, nothing)
   ├─ (3, nothing)
   │  ├─ (2, "is")
   │  └─ (4, nothing)
   │     ├─ (2, "destined")
   │     └─ (2, nothing)
   │        ├─ (2, nothing)
   │        │  ├─ (2, nothing)
   │        │  │  ├─ (2, nothing)
   │        │  │  │  ├─ (2, "to")
   │        │  │  │  └─ (2, nothing)
   │        │  │  │     ├─ (2, "be")
   │        │  │  │     └─ (2, nothing)
   │        │  │  │        ├─ (2, "the")
   │        │  │  │        └─ (2, nothing)
   │        │  │  │           ├─ (2, "21st")
   │        │  │  │           └─ (2, nothing)
   │        │  │  │              ├─ (2, nothing)
   │        │  │  │              │  ├─ (2, "Century")
   │        │  │  │              │  └─ (2, "'s")
   │        │  │  │              └─ (2, nothing)
   │        │  │  │                 ├─ (3, "new")
   │        │  │  │                 └─ (2, nothing)
   │        │  │  │                    ├─ (2, "``")
   │        │  │  │       

We'll one-hot-encode each word that comes up, with respect to all other words in the training set.

In [12]:
really = onehot("really", alphabet)

8737-element Flux.OneHotVector:
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
     ⋮
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false
 false

Here's the sentence above encoded this way.

In [13]:
train[1]

Tree{Any}
(nothing, Bool[false, false, false, true, false])
├─ (nothing, Bool[false, false, true, false, false])
│  ├─ (Bool[true, false, false, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false], Bool[false, false, true, false, false])
│  └─ (Bool[false, true, false, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false], Bool[false, false, true, false, false])
└─ (nothing, Bool[false, false, false, false, true])
   ├─ (nothing, Bool[false, false, false, true, false])
   │  ├─ (Bool[false, false, true, false, false, false, false, false, false, false  …  false, false, false, false, false, false, false, false, false, false], Bool[false, false, true, false, false])
   │  └─ (nothing, Bool[false, false, false, false, true])
   │     ├─ (Bool[false, false, false, true, false, false, false, false, false, false  …  false, false, false, false, false, false,

To avoid working with 8,000-dimensional vectors, we'll immediately multiple by an *embedding matrix* that simply compresses the word vector.

In [14]:
N = 300
embedding = cu(param(randn(N, length(alphabet))))

really = embedding*really

Tracked 300-element CuArray{Float32,1}:
 -0.331456 
  0.0503743
  1.4798   
  0.538446 
 -0.868414 
 -0.126501 
 -0.396232 
 -0.131381 
  0.393741 
  0.924209 
 -0.383779 
 -0.846226 
 -0.606682 
  ⋮        
  1.83946  
 -0.358464 
  0.870228 
  0.425165 
  0.7949   
  0.267819 
  1.12521  
 -0.103925 
 -1.26151  
 -0.0358396
 -0.462601 
  0.068873 

We need a way to combine tokens / phrases, just as in the original example.

In [15]:
W = cu(Dense(2N, N, relu))
combine(a, b) = W([a; b])

combine(really, really)

Tracked 300-element CuArray{Float32,1}:
 1.21657 
 0.876515
 0.0     
 1.00717 
 0.0     
 0.0     
 0.0     
 0.169884
 0.832975
 1.13765 
 0.036399
 1.21097 
 0.0     
 ⋮       
 0.0     
 2.02336 
 1.44976 
 0.232272
 0.0     
 0.0     
 0.473249
 0.0     
 0.0     
 0.0     
 1.35685 
 0.987563

Once we've got the embedding (300D vector) for a word or phrase, we want to analyse its sentiment. We'll do that with a simple linear layer.

In [16]:
sentiment = cu(Chain(Dense(N, 5), softmax))

sentiment(really)

Tracked 5-element CuArray{Float32,1}:
 0.0973401
 0.286922 
 0.207826 
 0.123449 
 0.284463 

Now we can define our forward pass. It's a little more complex than the original example; because we want to predict the sentiment for each sub-tree, we need to carry the loss forward and sum it over the whole tree. So `forward` returns both an embedding an a loss for each subtree.

In [17]:
function forward(tree)
  if isleaf(tree)
    token, sent = tree.value
    phrase = embedding * token
    phrase, crossentropy(sentiment(phrase), cu(collect(sent)))
  else
    _, sent = tree.value
    c1, l1 = forward(tree[1])
    c2, l2 = forward(tree[2])
    phrase = combine(c1, c2)
    phrase, l1 + l2 + crossentropy(sentiment(phrase), cu(collect(sent)))
  end
end

loss(tree) = forward(map(cu, tree))[2]

loss(train[1])

186.12698f0 (tracked)

In [18]:
opt = ADAM(params(embedding, W, sentiment))
evalcb = () -> @show loss(train[1])

(::#13) (generic function with 1 method)

In [None]:
Flux.train!(loss, zip(train), opt,
            cb = throttle(evalcb, 10))

loss(train[1]) = 138.38686f0 (tracked)
loss(train[1]) = 97.96852f0 (tracked)
loss(train[1]) = 88.88923f0 (tracked)
loss(train[1]) = 88.09165f0 (tracked)
loss(train[1]) = 83.77608f0 (tracked)
loss(train[1]) = 79.10224f0 (tracked)
loss(train[1]) = 75.20916f0 (tracked)
loss(train[1]) = 70.89377f0 (tracked)
loss(train[1]) = 71.374725f0 (tracked)
loss(train[1]) = 70.74858f0 (tracked)
loss(train[1]) = 70.881355f0 (tracked)
loss(train[1]) = 70.56669f0 (tracked)
loss(train[1]) = 69.135155f0 (tracked)
loss(train[1]) = 66.57009f0 (tracked)
loss(train[1]) = 66.10199f0 (tracked)
loss(train[1]) = 66.79412f0 (tracked)
loss(train[1]) = 67.255424f0 (tracked)
loss(train[1]) = 68.62694f0 (tracked)
loss(train[1]) = 63.694664f0 (tracked)
loss(train[1]) = 65.513145f0 (tracked)
loss(train[1]) = 67.66066f0 (tracked)
loss(train[1]) = 66.7731f0 (tracked)
loss(train[1]) = 63.33839f0 (tracked)
loss(train[1]) = 62.359272f0 (tracked)
loss(train[1]) = 61.217785f0 (tracked)
loss(train[1]) = 62.65793f0 (tracked)
loss

In [None]:
# open(io -> serialize(io, (alphabet, embedding, W)), "model-39.jls", "w")

Our prediction function is very simple, modelled on the original example.

In [None]:
phrase(x) = embedding*onehot(x, alphabet)
phrase(x::Tuple) = combine(phrase(x[1]), phrase(x[2]))
predict = Chain(phrase, sentiment, argmax)

As predicted, the network is able to learn that "not" negates, and can match our intuition about how positive or negative a phrase sounds.

In [None]:
predict(("very", "good"))

In [None]:
predict(("not", ("very", "good")))

In [None]:
predict(("utterly", "awful"))

In [None]:
predict(("not", ("utterly", "awful")))