In [169]:
# https://www.oreilly.com/catalog/errataunconfirmed.csp?isbn=9780596529321
# https://resources.oreilly.com/examples/9780596529321/tree/master

from math import tanh
import sqlite3

def dtanh(y):
    return 1.0 - y*y

class Searchnet:
    def __init__(self, dbname):
        self.conn = sqlite3.connect(dbname) 
    
    def __del__(self):
        self.conn.close()
    
    def make_tables(self):
        c = self.conn.cursor()
        c.execute('create table if not exists hiddennode(create_key)')
        c.execute('create table if not exists wordhidden(fromid, toid, strength)')
        c.execute('create table if not exists hiddenurl(fromid, toid, strength)')
        self.conn.commit()
        
    def get_strength(self, fromid, toid, layer):
        table = 'wordhidden' if layer == 0 else 'hiddenurl'
        c = self.conn.cursor()
        res = c.execute(f'select strength from {table} where fromid = ? and toid = ?', (fromid, toid)).fetchone()
        if res == None:
            return -0.2 if layer == 0 else 0
        return res[0]
    
    def set_strength(self, fromid, toid, layer, strength):
        table = 'wordhidden' if layer == 0 else 'hiddenurl'
        c = self.conn.cursor()
        res = c.execute(f'select rowid from {table} where fromid = ? and toid = ?', (fromid, toid)).fetchone()
        if res == None:
            c.execute(f'insert into {table} (fromid, toid, strength) values (?,?,?)', (fromid, toid, strength))
        else:
            rowid = res[0]
            c.execute(f'update {table} set strength = ? where rowid = ?', (strength, rowid))
        self.conn.commit()
    
    def generate_hidden_node(self, wordids, urls):
        if len(wordids) > 3: return None
        
        # Check if we already created a node for this set of words.
        createkey = '_'.join(sorted([str(wi) for wi in wordids]))
        c = self.conn.cursor()
        res = c.execute('select rowid from hiddennode where create_key = ?', (createkey,)).fetchone()
        
        # If not, create it.
        if res == None:
            cur = c.execute('insert into hiddennode (create_key) values (?)', (createkey,))
            hiddenid = cur.lastrowid
            
            # Put in some default weights.
            for wordid in wordids:
                self.set_strength(fromid=wordid, 
                                  toid=hiddenid, 
                                  layer=0, 
                                  strength=1.0/len(wordids))
            
            for urlid in urls:
                self.set_strength(fromid=hiddenid, 
                                  toid=urlid, 
                                  layer=1, 
                                  strength=0.1)

            self.conn.commit()

    def getall_hiddenids(self, wordids, urlids):
        """
        Feedforward. Query all the nodes and connections in the database, and build in memory the portion 
        of the network that is relevant to a specific query."""
        # Layer 1.
        l1 = {}
        c = self.conn.cursor()
        
        for wordid in wordids:
            cur = c.execute(
            'select toid from wordhidden where fromid = ?', (wordid,))
            for row in cur:
                l1[row[0]] = 1
            
        for urlid in urlids:
            cur = c.execute(
            'select fromid from hiddenurl where toid = ?', (urlid,))
            for row in cur:
                l1[row[0]] = 1

        return list(l1.keys())
            
            
    def setup_network(self, wordids, urlids):
        # Value lists.
        self.wordids = wordids
        self.hiddenids = self.getall_hiddenids(wordids, urlids)
        self.urlids = urlids
        
        # Node outputs. a: algorithm?
        self.ai = [1.0] * len(self.wordids) # i: Input layer.
        self.ah = [1.0] * len(self.hiddenids) # h: Hidden layer.
        self.ao = [1.0] * len(self.urlids) # o: Output layer.
        
        # Create weight matrixs. w: weight, i: input, o: output.
        self.wi = [[self.get_strength(wordid, hiddenid, layer=0)
                   for hiddenid in self.hiddenids]
                   for wordid in self.wordids]
        self.wo = [[self.get_strength(hiddenid, urlid, layer=1)
                   for urlid in self.urlids]
                   for hiddenid in self.hiddenids]
    
    def feedforward(self):
        # The only inputs are the query word.
        for i in range(len(self.wordids)):
            self.ai[i] = 1.0
        
        # Hidden activations.
        for j in range(len(self.hiddenids)):
            sum = 0.0
            for i in range(len(self.wordids)):
                sum = sum + self.ai[i] * self.wi[i][j]
            self.ah[j] = tanh(sum)
        
        # Output activations.
        for k in range(len(self.urlids)):
            sum = 0.0
            for j in range(len(self.hiddenids)):
                sum = sum + self.ah[j] * self.wo[j][k]
            self.ao[k] = tanh(sum)
        
        return self.ao[:]
    
    def get_result(self, wordids, urlids):
        self.setup_network(wordids, urlids)
        return self.feedforward()
    
    def backpropagate(self, targets, N = 0.5):
        # Calculate errors for output.
        output_deltas = [0.0] * len(self.urlids)
        for k in range(len(self.urlids)):
            error = targets[k] - self.ao[k]
            output_deltas[k] = dtanh(self.ao[k]) * error
        
        # Calculate errors for hidden layer.
        hidden_deltas = [0.0] * len(self.hiddenids)
        for j in range(len(self.hiddenids)):
            error = 0.0
            for k in range(len(self.urlids)):
                error = error + output_deltas[k] * self.wo[j][k]
            hidden_deltas[j] = dtanh(self.ah[j]) * error
        
        # Update output weights.
        for j in range(len(self.hiddenids)):
            for k in range(len(self.urlids)):
                change = output_deltas[k] * self.ah[j]
                self.wo[j][k] = self.wo[j][k] + N * change
        
        # Update input weights.
        for i in range(len(self.wordids)):
            for j in range(len(self.hiddenids)):
                change = hidden_deltas[j] * self.ai[i]
                self.wi[i][j] = self.wi[i][j] + N * change
    
    def train_query(self, wordids, urlids, selectedurl):
        # Generate a hidden node if necessary.
        self.generate_hidden_node(wordids, urlids)
        
        self.setup_network(wordids, urlids)
        self.feedforward()
        targets = [0.0] * len(urlids)
        targets[urlids.index(selectedurl)] = 1.0
        error = self.backpropagate(targets)
        self.update_database()
    
    def update_database(self):
        # Set them the database values.
        for i in range(len(self.wordids)):
            for j in range(len(self.hiddenids)):
                self.set_strength(self.wordids[i], self.hiddenids[j], 0, self.wi[i][j])
        
        for j in range(len(self.hiddenids)):
            for k in range(len(self.urlids)):
                self.set_strength(self.hiddenids[j], self.urlids[k], 1, self.wo[j][k])
        self.conn.commit()

In [170]:
net = Searchnet('nn.db')
net.make_tables()

In [171]:
# Words are prefixed with w_.
w_world, w_river, w_bank = 101, 102, 103

# Urls are prefixed with u_.
u_worldbank, u_river, u_earth = 201, 202, 203

wordids = [w_world, w_bank]
urls = [u_worldbank, u_river, u_earth]
net.generate_hidden_node(wordids, urls)
[c for c in net.conn.execute('select * from wordhidden')]

[(101, 1, 2254.848039955704), (103, 1, 2254.848039955704)]

In [172]:
[c for c in net.conn.execute('select * from hiddenurl')]

[(1, 201, 1.6469085770077627),
 (1, 202, 0.0009942320893848347),
 (1, 203, 0.0009942320893848347)]

In [173]:
net.get_result(wordids, urls)

[0.9284321852314656, 0.0009942317617863366, 0.0009942317617863366]

In [176]:
net.train_query(wordids, urls, u_worldbank)
net.get_result(wordids, urls)

[0.9304126862240898, 0.00012427922551992506, 0.00012427922551992506]

In [179]:
for i in range(30):
    net.train_query([w_world, w_bank], urls, u_worldbank)
    net.train_query([w_river, w_bank], urls, u_river)
    net.train_query([w_world], urls, u_earth)

In [180]:
net.get_result([w_world, w_bank], urls)

[0.8387430775215277, 0.0122721330323893, 0.013456929658331038]

In [181]:
net.get_result([w_river, w_bank], urls)

[-0.029686235126554737, 0.908858940781442, 0.01656004208434448]

In [182]:
net.get_result([w_bank], urls)

[0.8601970829129808, 0.2492609006509291, -0.2620614237632982]