In [1]:
%matplotlib inline

## Recursive Network Exercise

In this exercise, you will train a recursive neural network that estimates the _free energy_ of an _RNA secondary structure_. In biology, RNA sequences fold into so-called secondary structures, and structures with low free energy are assumed to be preferred. Free energy is lowered when bases are joined in stable pairs.

For this task, though, you do not need to know anything about the biological specifics. You only need to train a recursive neural net that infers the correct energy (a single scalar) from a given tree.

### Report

For the report, please describe the architecture that you used to solve the task and generate the following plot: after training the network, generate 100 further trees and record, for each tree, its size (using the `recursive_oracle.tree_size` function) and the absolute error `abs(y - y_predicted)`. Then plot the error against the tree size in a scatter plot.
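
The data collection for this plot could be sketched as follows. The helper names `collect_errors` and `plot_errors`, the stand-in `Tree` class, and the toy oracle are assumptions made only so that the snippet runs on its own; they are not part of `recursive_oracle`.

```python
import random


def collect_errors(generate, size_fn, predict, n=100):
    """Return (sizes, errors) for n freshly generated trees."""
    sizes, errors = [], []
    for _ in range(n):
        x, y = generate()
        sizes.append(size_fn(x))
        errors.append(abs(float(y) - float(predict(x))))
    return sizes, errors


def plot_errors(sizes, errors):
    """Scatter plot of absolute error against tree size."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting
    plt.scatter(sizes, errors)
    plt.xlabel('tree size')
    plt.ylabel('abs(y - y_predicted)')
    plt.show()


# --- hypothetical stand-ins so that the sketch is self-contained ---
class Tree:
    def __init__(self, label, children=()):
        self.label = label
        self.children = list(children)


def toy_size(tree):
    # number of nodes in the tree
    return 1 + sum(toy_size(c) for c in tree.children)


def toy_generate(rng=random.Random(0)):
    # a random chain of unary nodes; the toy "energy" is simply the size
    tree = Tree('g')
    for _ in range(rng.randint(0, 5)):
        tree = Tree('hairpin_end', [tree])
    return tree, float(toy_size(tree))


# a predictor that always answers 0, so the error equals the toy energy
sizes, errors = collect_errors(toy_generate, toy_size, lambda t: 0.0, n=100)
```

In the notebook itself, one would presumably pass `generate_rna_tree`, `recursive_oracle.tree_size`, and the trained model instead of the stand-ins, and then call `plot_errors(sizes, errors)`. This assumes the model returns a single value per tree.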

<strong>Note:</strong> Please use the `exercise_sheet_template.tex` to generate your report. Your report is due on *Friday, March 15th, 10am* as a single-page PDF, sent to [aschulz@techfak.uni-bielefeld.de](mailto:aschulz@techfak.uni-bielefeld.de). Please start your e-mail subject with the words *[Deep Learning]*.

### Advice

Do not try to map trees directly to the energy, because a one-dimensional encoding may carry too little information. Instead, first use a recursive neural network to map each tree to a low-dimensional encoding, and then apply a second neural network that predicts the free energy from that encoding.
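
One way to read this advice is a recursive encoder followed by a small regression network. The following is a minimal sketch under stated assumptions, not a reference solution: the class name `EnergyNet`, the `hidden` size, the `tanh` nonlinearity, and the stand-in `Tree` class are all illustrative choices, and the alphabet is shortened to keep the example small.

```python
import torch


class Tree:
    """Hypothetical stand-in for the oracle's tree type (label + children)."""
    def __init__(self, label, children=()):
        self.label = label
        self.children = list(children)


class EnergyNet(torch.nn.Module):
    """Recursive encoder followed by a small regression network."""
    def __init__(self, dim, arity_alphabet, hidden=16):
        super().__init__()
        self.arity_alphabet = arity_alphabet
        self.constants = torch.nn.ParameterDict()
        self.layers = torch.nn.ModuleDict()
        for symbol, arity in arity_alphabet.items():
            if arity == 0:
                # leaves: one learned constant vector per symbol
                self.constants[symbol] = torch.nn.Parameter(torch.randn(dim))
            else:
                # inner nodes: concatenated child encodings -> one encoding
                self.layers[symbol] = torch.nn.Linear(arity * dim, dim)
        # second network: encoding -> unbounded scalar energy
        self.head = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, 1))

    def encode(self, tree):
        if self.arity_alphabet[tree.label] == 0:
            return torch.tanh(self.constants[tree.label])
        children = torch.cat([self.encode(c) for c in tree.children])
        return torch.tanh(self.layers[tree.label](children))

    def forward(self, tree):
        # squeeze the trailing dimension so the result is a scalar tensor
        return self.head(self.encode(tree)).squeeze(-1)


alphabet = {'pair': 3, 'hairpin_end': 1, 'c': 0, 'g': 0}
net = EnergyNet(dim=8, arity_alphabet=alphabet)
tree = Tree('pair', [Tree('g'), Tree('hairpin_end', [Tree('c')]), Tree('c')])
energy = net(tree)
```

The final linear layer has no output nonlinearity, so the prediction is not confined to a bounded range, while the encoder can use more than one dimension to summarise the tree.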

Further, this prediction task is not easy, so do not aim for perfect error values. If you manage to stay consistently below an error of 1, that is already a good result.
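
"Consistently" can be made concrete by tracking a running mean of the recent batch losses, so that a single lucky batch does not count as success. This is only a sketch; the class name `RunningMean`, the window size, and the loss values are made up for illustration.

```python
from collections import deque


class RunningMean:
    """Mean over the most recent `window` values."""
    def __init__(self, window=20):
        self.values = deque(maxlen=window)

    def update(self, value):
        self.values.append(float(value))
        return sum(self.values) / len(self.values)

    def full(self):
        return len(self.values) == self.values.maxlen


tracker = RunningMean(window=3)
# made-up batch losses; the early 0.5 is a one-off lucky batch
means = [tracker.update(v) for v in [3.0, 0.5, 2.0, 0.9, 0.7, 0.8]]
# count as converged only once the window is full and its mean is below 1
converged = tracker.full() and means[-1] < 1.0
```

A training loop could use such a tracker in place of a check on the most recent batch loss alone.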

In [2]:
# For this exercise, we already provide a data generation function (an 'oracle')
# which we can use
from recursive_oracle import generate_rna_tree

# let's have a look at an example tree and its energy value.
# Executing this cell multiple times will yield different trees.
x, y = generate_rna_tree()
print('the tree %s has energy value %g' % (str(x), y))

the tree pair(g, pair(g, pair(c, hairpin(a, hairpin(a, hairpin(g, hairpin_end(g)))), g), u), u) has energy value 2.06355


In [67]:
import torch

class RecursiveNet(torch.nn.Module):
    def __init__(self, dim, arity_alphabet):
        """Recursive net with one linear layer per inner symbol and one
        learned constant vector per leaf symbol (arity 0)."""
        super(RecursiveNet, self).__init__()
        self.dim = int(dim)
        self.arity_alphabet = arity_alphabet
        self.constants = torch.nn.ParameterDict()
        self.layers = torch.nn.ModuleDict()

        for symbol, arity in self.arity_alphabet.items():
            if arity == 0:
                # leaves are encoded by a learned constant vector
                self.constants[symbol] = torch.nn.Parameter(torch.randn(self.dim))
            else:
                # inner nodes map the concatenated child encodings to one encoding
                self.layers[symbol] = torch.nn.Linear(arity * self.dim, self.dim)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, T):
        """Encode the tree T bottom-up and return the encoding of its root."""
        arity = self.arity_alphabet[T.label]
        if len(T.children) != arity:
            raise ValueError('Expected %d children for a node with label %s but got %d children.' % (
                arity, T.label, len(T.children)))
        if arity == 0:
            return self.sigmoid(self.constants[T.label])
        # encode all children recursively and concatenate their encodings
        child_encodings = torch.cat([self.forward(child) for child in T.children])
        # note: the sigmoid bounds every encoding, including the root output, to (0, 1)
        return self.sigmoid(self.layers[T.label](child_encodings))
    
k = 1  # encoding dimension
rna_arity_alphabet = {'dangle_end': 1, 'dangle': 2, 'split': 2, 'branch': 2, 'pair': 3, 'hairpin': 2,
                      'hairpin_end': 1, 'c': 0, 'g': 0, 'a': 0, 'u': 0}
model = RecursiveNet(k, rna_arity_alphabet)

# absolute error between predicted and true energy
loss_function = torch.nn.L1Loss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [68]:
# The oracle also provides us with the arity alphabet for the RNA trees
from recursive_oracle import rna_arity_alphabet

print(rna_arity_alphabet)

{'dangle': 2, 'dangle_end': 1, 'split': 2, 'pair': 3, 'branch': 2, 'hairpin': 2, 'hairpin_end': 1, 'c': 0, 'g': 0, 'a': 0, 'u': 0}


In [None]:
loss_threshold = 1
learning_curve = []

minibatch_size = 50

while not learning_curve or learning_curve[-1] > loss_threshold:
    optimizer.zero_grad()

    # accumulate gradients over one minibatch of freshly generated trees
    loss_batch = 0
    for i in range(minibatch_size):
        (x, y) = generate_rna_tree()
        y_predicted = model(x)
        # L1Loss expects (input, target)
        loss_object = loss_function(y_predicted, y)
        loss_batch += loss_object.item()
        loss_object.backward()

    # take one gradient step per minibatch, before the gradients are reset
    optimizer.step()

    learning_curve.append(loss_batch / minibatch_size)
    if len(learning_curve) % 20 == 0:
        print('loss after {} batches: {}'.format(len(learning_curve), learning_curve[-1]))

loss after 20 batches: 3.0897516340017317
loss after 40 batches: 3.20043983399868
loss after 60 batches: 3.1519632995128632
loss after 80 batches: 2.974375388622284
loss after 100 batches: 3.3466975355148314
loss after 120 batches: 2.699699021577835
loss after 140 batches: 2.7163452661037444


In [None]:
import matplotlib.pyplot as plt

plt.plot(list(range(len(learning_curve))), learning_curve)
plt.xlabel('gradient step')
plt.ylabel('loss')
plt.show()

In [47]:
for symbol, arity in rna_arity_alphabet.items():
    if(arity > 0):
        w = model.layers[symbol].weight.data[0]
        b = model.layers[symbol].bias.data[0]
        print('weights for symbol \'%s\' = %s ; bias = %s' % (symbol, str(w), str(b)))
    else:
        print('encoding for symbol \'%s\' = %s' % (symbol, str(model.constants[symbol])))

weights for symbol 'dangle' = tensor([0.7311, 0.3511]) ; bias = tensor(0.0624)
weights for symbol 'dangle_end' = tensor([-0.8864]) ; bias = tensor(-0.9605)
weights for symbol 'split' = tensor([0.0690, 0.2204]) ; bias = tensor(0.0242)
weights for symbol 'pair' = tensor([-0.0332,  0.3772, -0.1551]) ; bias = tensor(-0.6845)
weights for symbol 'branch' = tensor([ 0.9058, -0.1659]) ; bias = tensor(0.4337)
weights for symbol 'hairpin' = tensor([0.2573, 1.1539]) ; bias = tensor(0.7086)
weights for symbol 'hairpin_end' = tensor([0.1667]) ; bias = tensor(1.0580)
encoding for symbol 'c' = Parameter containing:
tensor([2.2922], requires_grad=True)
encoding for symbol 'g' = Parameter containing:
tensor([-0.2355], requires_grad=True)
encoding for symbol 'a' = Parameter containing:
tensor([1.5517], requires_grad=True)
encoding for symbol 'u' = Parameter containing:
tensor([0.0582], requires_grad=True)
