In [6]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import networkx as nx
import cPickle
import sklearn.cluster

# Create Toy Language

In [7]:
alphabet = ["(", ")"]  # NOTE(review): not referenced by the visible code; the end marker "E" is not listed.


def random_sentence(p_reopen=0.8, p_open=0.4, max_steps=500):
    """Sample a random sentence of the toy bracket language.

    The sentence starts with "(" and is extended one symbol at a time:

    * while at least one bracket is open, "(" is emitted with probability
      ``p_open`` and ")" otherwise;
    * when every bracket is closed, a new "(" is opened with probability
      ``p_reopen``; otherwise the end marker "E" is appended and sampling stops.

    Args:
        p_reopen: probability of opening a new group once all brackets are closed.
        p_open: probability of emitting "(" while some bracket is still open.
        max_steps: maximum number of symbols sampled after the initial "(".
            If the limit is reached first, the sentence is returned without
            the trailing "E" (and may be unbalanced).

    Returns:
        A string over the characters "(", ")" and "E".
    """
    sentence = "("
    depth = 1
    for _ in range(max_steps):
        if depth == 0:
            if np.random.binomial(1, p_reopen) == 1:
                sentence += "("
                depth += 1
            else:
                sentence += "E"
                break
        # depth is never negative here: ")" is only emitted while depth > 0.
        elif np.random.binomial(1, p_open) == 1:
            sentence += "("
            depth += 1
        else:
            sentence += ")"
            depth -= 1
    return sentence

random_sentence()  # Result is discarded (not the cell's last line, so not displayed); the call only advances NumPy's global RNG.

def trainingExample():
    """Sample one sentence and build an (input, target) training pair.

    The input is the one-hot encoding of every character but the last; the
    target holds the class index of every character but the first. Both
    therefore have one entry per prediction step.
    """
    sentence = random_sentence()
    inputs = text2input(sentence[:-1])
    targets = text2target(sentence)
    return Variable(inputs), Variable(targets)


def text2input(text):
    """One-hot encode ``text`` as a float tensor of shape (len(text), 1, 3).

    "(" -> column 0, ")" -> column 1, "E" -> column 2; any other character
    raises KeyError.
    """
    column_of = {"(": 0, ")": 1, "E": 2}
    encoded = torch.zeros(len(text), 1, 3)
    for row, symbol in zip(range(len(text)), text):
        encoded[row, 0, column_of[symbol]] = 1
    return encoded

def text2target(text):
    """Class indices of every character of ``text`` after the first.

    Mapping: "(" -> 0, ")" -> 1, "E" -> 2. The result is a LongTensor with one
    entry per character after the first (the next-character targets).
    """
    index_of = {"(": 0, ")": 1, "E": 2}
    following = text[1:]
    return torch.LongTensor(list(map(index_of.__getitem__, following)))

# Sanity check: input (text[:-1]) and target (text[1:]) of one sampled example have the same length.
a = trainingExample()
print(len(a[0]), len(a[1]))

(12, 12)


In [8]:
def is_correct(text):
    """Check whether ``text`` is a well-formed sentence of the toy language.

    A valid sentence is made only of "(" and ")", never closes more brackets
    than it has opened, ends at depth 0, and is terminated by the end marker
    "E". Diagnostics (including the maximum nesting depth reached) are printed.

    Args:
        text: the sentence to check.

    Returns:
        True if the sentence is valid, False otherwise (including for an
        empty string).
    """
    # Guard the empty string: text[-1] would raise IndexError.
    if not text or text[-1] != "E":
        print("No EOS")
        return False
    max_depth = 0
    depth = 0
    for char in text[:-1]:
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
        else:
            print("Unknown character")
            return False
        if depth < 0:
            print("Negative depth")
            return False
        max_depth = max(max_depth, depth)
    print("Maximum depth : {}".format(max_depth))
    if depth != 0:
        print("Missing {} closing brackets".format(depth))
        return False
    return True

# Defining the model

In [9]:
#from model import RNN
class RNN(nn.Module):
    """LSTM followed by a linear read-out layer, run with a batch of one.

    Args:
        input_size: size of each input vector.
        hidden_size: size of the LSTM hidden and cell states.
        output_size: number of output scores per time step.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        # nn.LSTM applies dropout only between stacked layers, so with
        # n_layers == 1 this argument has no effect (recent PyTorch warns).
        self.lstm = nn.LSTM(input_size, hidden_size, n_layers, dropout=0.1)
        self.lin = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden):
        """Run the LSTM over ``input`` and project every time step.

        Args:
            input: tensor whose first dimension is the sequence length; it is
                reshaped to (seq_len, 1, -1), i.e. a batch of one.
            hidden: (h, c) tuple, as returned by initHidden() or a previous call.

        Returns:
            (output, hidden): output has shape (seq_len, output_size) and holds
            raw scores (no softmax); hidden is the updated (h, c) tuple.
        """
        output, hidden = self.lstm(input.view(len(input), 1, -1), hidden)
        # Flatten (seq_len, 1, hidden_size) -> (seq_len, hidden_size) so sequences
        # longer than one step can go through the linear layer.
        output = self.lin(output.view(-1, self.hidden_size))
        return output, hidden

    def initHidden(self):
        """Return zero-initialised (h, c) states for a batch of one.

        Each tensor has shape (num_layers, minibatch_size=1, hidden_size), the
        layout nn.LSTM expects for its initial states.
        """
        return (Variable(torch.zeros(self.n_layers, 1, self.hidden_size)),
                Variable(torch.zeros(self.n_layers, 1, self.hidden_size)))

# Retrieving the network

In [10]:
# Architecture of the saved network: (input_size, hidden_size, output_size).
tab = (3, 20, 3)

rnn = RNN(tab[0], tab[1], tab[2])

# The context manager closes the checkpoint file once the weights are loaded.
# NOTE(review): torch.load unpickles its input; only load checkpoints you trust.
with open("rnn_toy_language_torch", "rb") as saved_to:
    rnn.load_state_dict(torch.load(saved_to))

rnn.eval()

RNN (
  (lstm): LSTM(3, 20, dropout=0.1)
  (lin): Linear (20 -> 3)
)

# Retrieving saved states

In [11]:
# Layout of each saved state vector: [output (input_size) | hidden (hidden_size) | cell (remaining entries)].
input_size = 3
hidden_size = 20

# The dump is pickled text; consecutive pickles are separated by "NEW_TUPLE" lines.
# NOTE(review): cPickle.loads can execute arbitrary code; only load dumps produced by this project.
with open("/pickle2/hidden_state_file_toy_language", "r") as read_hidden_state_file:
    brute_list = read_hidden_state_file.readlines()

picklised_tuple_list = list()
current_string = ""
for line in brute_list:
    if line != "NEW_TUPLE\n":
        current_string += line
    else:
        picklised_tuple_list.append(current_string)
        current_string = ""
# NOTE(review): text after the last "NEW_TUPLE" line is never appended, and the first
# chunk is skipped below (presumably empty because the dump starts with a separator) -- confirm.

# Loop variables are named so they do not shadow the builtin `tuple`
# (Python 2 list comprehensions leak their loop variable into the notebook namespace).
tuple_list = [cPickle.loads(pickled) for pickled in picklised_tuple_list[1:]]
print(len(tuple_list))

# Split each flat state into (output, (hidden, cell)), each viewed as (1, 1, -1).
state_list = [
    (Variable(torch.from_numpy(saved[:input_size]).view(1, 1, -1)),
     (Variable(torch.from_numpy(saved[input_size:input_size + hidden_size]).view(1, 1, -1)),
      Variable(torch.from_numpy(saved[input_size + hidden_size:]).view(1, 1, -1))))
    for saved in tuple_list
]

output_list = [saved[:input_size] for saved in tuple_list]
output_matrix = np.matrix(output_list)

# Release the raw text buffers; only the parsed states are needed from here on.
del brute_list, picklised_tuple_list

192359


# Clustering the states

In [30]:
import time

In [100]:
def mean_distance_between_states(tuple_list, input_size, distance, n_iters=1000,
                                 hidden_threshold=0.5, output_threshold=1E-3):
    """Print statistics of distances between randomly drawn pairs of saved states.

    Each state is a flat vector: its first ``input_size`` entries are the output
    part, the rest is the memory part. For ``n_iters`` pairs drawn uniformly
    with replacement, this prints the mean, std, max and min of three distances:
    between the full vectors, between the output parts, and between the memory
    parts. Zero distances are excluded for the output and memory parts.

    It then prints how many pairs have an output distance < ``output_threshold``,
    how many have a memory distance < ``hidden_threshold``, and how many satisfy
    both.

    Args:
        tuple_list: non-empty sequence of 1-D state vectors.
        input_size: number of leading entries that form the output part.
        distance: callable(a, b) returning a scalar distance.
        n_iters: number of random pairs to draw.
        hidden_threshold: memory-distance threshold for the "close" count.
        output_threshold: output-distance threshold for the "close" count.

    Returns:
        None; results are printed. A statistic over an empty list prints as nan.

    Raises:
        ValueError: if ``tuple_list`` is empty.
    """
    n_states = len(tuple_list)
    if n_states == 0:
        raise ValueError("tuple_list is empty")
    l_all = list()
    l_output = list()
    l_hidden = list()
    nb_hidden_inf_min = 0
    nb_output_inf_min = 0
    n_both = 0
    for _ in range(n_iters):
        # randint's upper bound is exclusive, so every state (including the last) can be drawn.
        state1 = tuple_list[np.random.randint(n_states)]
        state2 = tuple_list[np.random.randint(n_states)]
        l_all.append(distance(state1, state2))
        d_output = distance(state1[:input_size], state2[:input_size])
        if d_output > 0:
            l_output.append(d_output)
        d_hidden = distance(state1[input_size:], state2[input_size:])
        if d_hidden > 0:
            l_hidden.append(d_hidden)
        hidden_close = d_hidden < hidden_threshold
        output_close = d_output < output_threshold
        if hidden_close:
            nb_hidden_inf_min += 1
        if output_close:
            nb_output_inf_min += 1
        if hidden_close and output_close:
            n_both += 1
        if d_hidden == 0:
            print("hidden")
        if d_output == 0:
            print("output")

    def _safe(reducer, values):
        # np.max / np.min raise on an empty list; report nan instead.
        return reducer(values) if len(values) > 0 else float("nan")

    for title, reducer in (("means", np.mean), ("std", np.std), ("max", np.max), ("min", np.min)):
        print(title)
        print(_safe(reducer, l_all), _safe(reducer, l_output), _safe(reducer, l_hidden))
    print(nb_output_inf_min, nb_hidden_inf_min, n_both)

In [101]:
%%time
# NOTE(review): euclidian_distance is defined in a later cell (In [15]); run that cell first. 3 is the input_size of the saved states.
mean_distance_between_states(tuple_list, 3, euclidian_distance, 100000)

means
(8.0055669665052118, 0.21050451613499907, 7.9958041511291205)
std
(5.8772480498122039, 0.28331196535164876, 5.8799392216231672)
max
(48.198883558521494, 0.98629855753994533, 48.190545378422833)
min
(0.0059229443408932462, 1.430872736563283e-05, 0.0059227772021624974)
(441, 803, 47)
CPU times: user 7.06 s, sys: 80 ms, total: 7.14 s
Wall time: 7.14 s


In [36]:
 # Distances / divergences between discrete probability vectors.
def hellinger(p, q):
    """Hellinger distance between two discrete distributions.

    H(p, q) = ||sqrt(p) - sqrt(q)||_2 / sqrt(2), which lies in [0, 1] for
    probability vectors.
    """
    return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / np.sqrt(2)

def KL_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) = sum_i p_i * log(p_i / q_i), in nats.

    Terms with p_i == 0 contribute 0 (the usual 0 * log 0 = 0 convention)
    instead of producing nan. The result is inf if q_i == 0 where p_i > 0.

    Args:
        p, q: probability vectors of the same shape (arrays or lists).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0
    return np.sum(p[support] * np.log(p[support] / q[support]))

def JS_divergence(p, q):
    """Jensen-Shannon divergence between two probability vectors.

    JS(p, q) = KL(p || m) / 2 + KL(q || m) / 2, where m = (p + q) / 2 is the
    mixture. It is symmetric in p and q. It does not itself satisfy the
    triangle inequality; its square root does, and is therefore a metric.
    """
    mixture = 0.5 * (p + q)
    left = KL_divergence(p, mixture)
    right = KL_divergence(q, mixture)
    return 0.5 * left + 0.5 * right

In [15]:
def euclidian_distance(state1, state2):
    """Euclidean (L2) distance between two NumPy state vectors.

    Both arrays are converted to float32 torch tensors first, so the distance
    is computed in single precision. The return value is whatever
    ``Tensor.norm`` returns for the installed PyTorch version.
    """
    first = torch.from_numpy(state1).float()
    second = torch.from_numpy(state2).float()
    difference = first - second
    return difference.norm(2)

In [16]:
def log_euclidian_distance(state1, state2, epsilon=1E-10):
    """Euclidean distance between the element-wise logs of two states.

    ``epsilon`` is added before taking the log, so zero entries map to a large
    negative value instead of -inf. The inputs are promoted to float64 and the
    epsilon is broadcast, so arrays of any shape are accepted.
    """
    log_state1 = np.log(np.asarray(state1, dtype=np.float64) + epsilon)
    log_state2 = np.log(np.asarray(state2, dtype=np.float64) + epsilon)
    return euclidian_distance(log_state1, log_state2)

In [9]:
import sklearn.cluster

In [10]:
import multiprocessing as mp

In [11]:
def kmeans_clustering(tuple_list, n_clusters_output, size_clusters_hidden):
    """Two-level k-means clustering of saved network states.

    Each network "state" is a flat vector made of a probability distribution
    (for the next character; its first ``input_size`` entries, where
    ``input_size`` is the notebook-level global) followed by a memory state
    (hidden and cell state).

    States are first clustered on their output part into ``n_clusters_output``
    clusters. The memory parts of each output cluster are then sub-clustered
    into ``max(cluster_size // size_clusters_hidden, 1)`` clusters.

    Args:
        tuple_list: sequence of flat state vectors.
        n_clusters_output: number of clusters for the output distributions.
        size_clusters_hidden: target number of states per memory sub-cluster.

    Returns:
        (list_centers_first_clusters, list_subclusters_centers): the output
        cluster centers (KMeans ``cluster_centers_``), and for each output
        cluster, in sorted label order, the centers of its memory sub-clusters.
    """
    output_matrix = np.matrix([state[:input_size] for state in tuple_list])
    # Cluster according to output.
    # NOTE(review): the n_jobs argument was removed from KMeans in scikit-learn 1.0.
    kmeans = sklearn.cluster.KMeans(n_clusters=n_clusters_output, n_jobs=-1).fit(output_matrix)
    labels = np.asarray(kmeans.labels_)
    # Vectorised membership per cluster instead of a Python scan over every state.
    list_indices_first_clusters = [np.flatnonzero(labels == cluster) for cluster in np.unique(labels)]
    list_centers_first_clusters = kmeans.cluster_centers_
    list_subclusters_centers = list()
    # Subcluster according to memory
    for indices in list_indices_first_clusters:
        hidden_states = np.matrix([tuple_list[i][input_size:] for i in indices])
        # Integer division: KMeans needs an int cluster count ("/" yields a float on Python 3).
        n_subclusters = max(len(indices) // size_clusters_hidden, 1)
        kmeans_i = sklearn.cluster.KMeans(n_clusters=n_subclusters, n_jobs=-1).fit(hidden_states)
        list_subclusters_centers.append(kmeans_i.cluster_centers_)
    return list_centers_first_clusters, list_subclusters_centers

In [28]:
def get_mini_batch_from_database(file_to_read, batch_size, max_lines_per_example=1000):
    """
    Get a batch from the file_to_read (in the order of the file).

    Examples are chunks of pickled text, each terminated by a line equal to
    "NEW_TUPLE". Reading continues from the file's current position. At most
    ``max_lines_per_example`` lines are read while looking for the terminator.
    If none is found within that limit, the partial chunk is discarded and that
    slot of the batch stays empty.

    Args:
        file_to_read: text file object opened for reading.
        batch_size: number of examples to read.
        max_lines_per_example: guard against runaway chunks (default 1000).

    Returns:
        A sklearn-compatible np.matrix with one unpickled example per row.

    Raises:
        EOFError: if the end of the file is reached before the batch is complete.
    """
    pickled_examples = list()
    for _ in range(batch_size):
        chunk = list()
        for _line_number in range(max_lines_per_example):
            line = file_to_read.readline()
            if line == "":
                print("END OF FILE")
                raise EOFError("end of file reached while reading a mini-batch")
            if line == "NEW_TUPLE\n":
                pickled_examples.append("".join(chunk))
                break
            chunk.append(line)
    # NOTE(review): cPickle.loads can execute arbitrary code; only read files you wrote yourself.
    return np.matrix([cPickle.loads(example) for example in pickled_examples])
    
    

In [29]:
def estimator_nb_clusters_mini_batch_kmeans(kmeans_estimator, file_to_read, n_clusters_output, batch_size, n_batchs, input_size):
    """
    Estimate the density (relative size) of each output cluster.

    ``n_batchs`` mini-batches are read from ``file_to_read``. For each batch,
    ``kmeans_estimator`` is updated with ``partial_fit`` on the first
    ``input_size`` columns (the output part of the states), and the labels it
    assigns to that batch are counted.

    Returns:
        Array of length ``n_clusters_output``: the fraction of the states read
        that were assigned to each cluster (all zeros if nothing was read).
    """
    cluster_counts = np.zeros(n_clusters_output)
    n_seen = 0
    for _ in range(n_batchs):
        batch_matrix = get_mini_batch_from_database(file_to_read, batch_size)
        # Cluster only the output part of each state: its first input_size columns.
        kmeans_estimator.partial_fit(batch_matrix[:, :input_size])
        labels = np.asarray(kmeans_estimator.labels_).ravel()
        cluster_counts += np.bincount(labels, minlength=n_clusters_output)[:n_clusters_output]
        n_seen += len(labels)
    if n_seen == 0:
        return cluster_counts
    # Normalise by the number of states actually labelled (a batch may be short).
    return cluster_counts / float(n_seen)
    

In [30]:
# NOTE(review): this handle is never closed, and the visible code does not use it (mini_batch_kmeans_clustering opens its own file from a filename).
file_to_read = open("/pickle2/hidden_state_file_toy_language", "r")

In [31]:
def mini_batch_kmeans_clustering(filename, n_clusters_output, size_clusters_hidden, batch_size, n_batchs, input_size, hidden_size, n_batchs_estimation=50):
    """
    Two-level streaming clustering of the saved network states with MiniBatchKMeans.

    Each record of `filename` (read with get_mini_batch_from_database) is a row
    whose first `input_size` columns are the network output and whose
    remaining columns are the hidden state. Rows are first clustered on their
    output part into `n_clusters_output` clusters. The hidden parts falling in
    each output cluster are then clustered again into subclusters.

    The states are only available as a stream, so cluster members (or their
    indices) cannot be kept in memory. The size of each output cluster is
    therefore estimated on a first pass. The number of subclusters of an output
    cluster is its estimated population divided by `size_clusters_hidden`
    (at least 1).

    The file is read in order, so it should be shuffled beforehand.

    Parameters
    ----------
    filename : str
        Path of the file of pickled states.
    n_clusters_output : int
        Number of output clusters.
    size_clusters_hidden : int or float
        Target number of states per hidden subcluster.
    batch_size : int
        Number of records per mini-batch.
    n_batchs : int
        Maximum number of mini-batches used for the clustering pass.
    input_size : int
        Number of leading (output) columns of each record.
    hidden_size : int
        Unused; kept for backward compatibility.
    n_batchs_estimation : int, optional
        Number of mini-batches used to estimate the output-cluster sizes
        (previously hard-coded to 50).

    Returns
    -------
    output_centers : np.ndarray, shape (n_clusters_output, input_size)
    subclusters_centers : list of np.ndarray
        For each output cluster, the centers of its hidden subclusters.

    Raises
    ------
    RuntimeError
        If an output cluster received fewer hidden states than it has
        subclusters, so its subcluster model could never be fitted.
    """
    kmeans_output = sklearn.cluster.MiniBatchKMeans(n_clusters=n_clusters_output, batch_size=batch_size)

    print("estimating subcluster size")
    with open(filename, "r") as file_to_read:
        cluster_density = estimator_nb_clusters_mini_batch_kmeans(
            kmeans_output, file_to_read, n_clusters_output, batch_size, n_batchs_estimation, input_size)
    print("estimation over, initiating clustering")

    # MiniBatchKMeans rejects n_clusters=0, hence the floor at 1.
    subclusters_size = np.maximum(
        (cluster_density * batch_size * n_batchs / size_clusters_hidden).astype(int), 1)
    print(subclusters_size)
    # TODO : change batch size with subcluster size
    kmeans_subclusters = [sklearn.cluster.MiniBatchKMeans(n_clusters=int(n_sub), batch_size=batch_size)
                          for n_sub in subclusters_size]
    # Hidden vectors of each output cluster, waiting to be fed to its subcluster model.
    pending_hidden = [list() for _ in range(n_clusters_output)]

    def _fit_pending(k):
        """Feed the buffered hidden vectors of output cluster k to its subcluster model."""
        kmeans_subclusters[k].partial_fit(np.array(pending_hidden[k]))
        pending_hidden[k] = list()

    with open(filename, "r") as file_to_read:  # reopened: restart from the first record
        for _ in range(n_batchs):
            try:
                batch_matrix = get_mini_batch_from_database(file_to_read, batch_size)
            except EOFError:
                break
            output_batch_matrix = batch_matrix[:, :input_size]
            hidden_rows = np.asarray(batch_matrix[:, input_size:])
            kmeans_output.partial_fit(output_batch_matrix)
            for hidden_vector, label in zip(hidden_rows, kmeans_output.labels_):
                pending_hidden[label].append(hidden_vector)
            for k in range(n_clusters_output):
                # The first partial_fit of a model needs at least n_clusters samples.
                if len(pending_hidden[k]) >= max(batch_size, kmeans_subclusters[k].n_clusters):
                    _fit_pending(k)

    # Flush what is left in the buffers (this used to be discarded).
    for k in range(n_clusters_output):
        already_fitted = hasattr(kmeans_subclusters[k], "cluster_centers_")
        if pending_hidden[k] and (already_fitted or len(pending_hidden[k]) >= kmeans_subclusters[k].n_clusters):
            _fit_pending(k)
        if not hasattr(kmeans_subclusters[k], "cluster_centers_"):
            raise RuntimeError(
                "output cluster {} received too few hidden states for its {} subclusters".format(
                    k, kmeans_subclusters[k].n_clusters))
    return kmeans_output.cluster_centers_, [kmeans.cluster_centers_ for kmeans in kmeans_subclusters]
    

In [34]:
%%time
filename = "/pickle2/hidden_state_file_toy_language"
# Streaming two-level clustering of the saved states, with every setting named.
list_centers_first_clusters, list_subclusters_centers = mini_batch_kmeans_clustering(
    filename,
    n_clusters_output=10,
    size_clusters_hidden=10000,
    batch_size=300,
    n_batchs=3500,
    input_size=3,
    hidden_size=40,
)

estimating subcluster size
estimation over, initiating clustering
[27  9  3  2  1  9 26  9 14  1]
END OF FILE
CPU times: user 17.7 s, sys: 36 ms, total: 17.7 s
Wall time: 17.7 s


In [231]:
%%time
# NOTE(review): this rebinds the same names as the mini-batch run in In[34], and
# it ran later (higher execution count). The automaton cells below therefore use
# this in-memory clustering: the automaton has 432 nodes, whereas the 101
# subclusters counted in In[35] would give 102. Confirm which clustering is intended.
list_centers_first_clusters, list_subclusters_centers = kmeans_clustering(tuple_list, 10, 400)

CPU times: user 20.8 s, sys: 4.74 s, total: 25.5 s
Wall time: 54.4 s


In [35]:
# Total number of hidden subclusters, i.e. automaton states besides the start node -1.
sum([len(a) for a in list_subclusters_centers])

101

In [36]:
# Number of output clusters (one array of subcluster centers per output cluster).
len(list_subclusters_centers)

10

In [37]:
# Average number of hidden subclusters per output cluster.
np.mean([len(a) for a in list_subclusters_centers]) 

10.1

In [266]:
# Output-cluster centers, one row per output cluster. The rows shown sum to ~1,
# consistent with softmax outputs over char_vect order ("(", ")", "E").
print(list_centers_first_clusters)

[[  4.12901253e-01   5.85594296e-01   1.50380284e-03]
 [  8.10950279e-01   1.37669109e-02   1.75282404e-01]
 [  3.44245911e-01   6.54674530e-01   1.07952754e-03]
 [  5.28011739e-01   4.44499761e-01   2.74883285e-02]
 [  6.25752747e-01   3.01131785e-01   7.31153563e-02]
 [  3.84899676e-01   6.13780260e-01   1.32000411e-03]
 [  2.93613851e-01   7.05602169e-01   7.83839962e-04]
 [  4.53209281e-01   5.43887377e-01   2.90350337e-03]
 [  7.74459839e-01   6.23917580e-02   1.63148463e-01]
 [  7.15725541e-01   1.58109263e-01   1.26165524e-01]]


# Make the automaton

In [233]:
def compute_log_loss_rnn(rnn, n_examples):
    """
    Mean negative log-likelihood of the RNN's next-character predictions.

    For each of `n_examples` sentences drawn with random_sentence(), the RNN is
    fed the sentence (without its last character) one character at a time,
    starting from rnn.initHidden(). Each log_softmax output is scored with
    NLLLoss against the next character.

    Parameters
    ----------
    rnn : module returning (output, hidden) for (input, hidden)
    n_examples : int
        Number of random sentences to score.

    Returns
    -------
    Variable
        Mean loss over all predicted characters.

    Raises
    ------
    ValueError
        If no character was scored (e.g. n_examples == 0).

    The RNN is evaluated in eval mode, and its previous train/eval mode is
    restored on exit (it used to be left in eval mode).
    """
    criterion = nn.NLLLoss()
    total_loss = 0
    nb_loss = 0
    was_training = rnn.training
    rnn.eval()
    try:
        for _ in range(n_examples):
            text = random_sentence()
            target = Variable(text2target(text))
            input_seq = Variable(text2input(text[:-1]))
            hidden = rnn.initHidden()
            for step in range(len(input_seq)):
                output, hidden = rnn(input_seq[step], hidden)
                output = F.log_softmax(output)
                # detach: only the value is needed. Keeping the autograd graph of
                # every step alive made this evaluation slow and memory-hungry.
                total_loss += criterion(output, target[step]).detach()
                nb_loss += 1
    finally:
        rnn.train(was_training)
    if nb_loss == 0:
        raise ValueError("no character was scored (n_examples must be > 0)")
    return total_loss / float(nb_loss)

In [234]:
def compute_log_loss_automata(automata, n_examples):
    """
    Mean negative log-likelihood of the automaton's next-character predictions.

    Mirrors compute_log_loss_rnn so the two values are comparable. The RNN is
    never asked to predict the leading "(", so neither is the automaton. For
    each of `n_examples` sentences drawn with random_sentence(), the walk starts
    in the state reached from -1 on "(". Every following character is scored
    with NLLLoss against the log of return_proba_distrib at the current state,
    and then the walk follows that character's transition.

    Returns
    -------
    Variable
        Mean loss over all predicted characters.

    Raises
    ------
    ValueError
        If a sentence needs a transition the automaton does not have (this used
        to surface as an opaque networkx error on node None), or if no
        character was scored (e.g. n_examples == 0).
    """
    criterion = nn.NLLLoss()
    total_loss = 0
    nb_loss = 0
    for _ in range(n_examples):
        text = random_sentence()
        target = Variable(text2target(text))
        previous = (-1, "(")
        current_node = move_state_automata(automata, -1, "(")
        for k, char in enumerate(text[1:]):
            if current_node is None:
                raise ValueError(
                    "automaton has no transition from state {!r} on {!r}".format(*previous))
            # NLLLoss expects log-probabilities.
            predicted_distrib = Variable(torch.from_numpy(return_proba_distrib(automata, current_node))).log()
            total_loss += criterion(predicted_distrib, target[k])
            nb_loss += 1
            previous = (current_node, char)
            current_node = move_state_automata(automata, current_node, char)
    if nb_loss == 0:
        raise ValueError("no character was scored (n_examples must be > 0)")
    return total_loss / float(nb_loss)
        

In [235]:
def move_state_automata(automata, node, char):
    """
    Follow the transition labelled `char` out of `node`.

    Returns the first successor (in neighbor order) reached by an edge whose
    "label" is `char`, or None when `node` has no such outgoing edge.
    """
    def _labels_towards(successor):
        return [attrib["label"] for attrib in automata.get_edge_data(node, successor).values()]

    successors = list(automata.neighbors(node))
    matching = (successor for successor in successors if char in _labels_towards(successor))
    return next(matching, None)

In [236]:
def return_proba_distrib(automata, node, chars=None):
    """
    Probability the automaton gives to each possible next character at `node`.

    Parameters
    ----------
    automata : graph whose edges carry "label" and "weight" attributes
    node : state to read the outgoing transitions from
    chars : sequence of str, optional
        Characters in output order; defaults to the global `char_vect`.

    Returns
    -------
    np.ndarray of float, one entry per character of `chars`
        The "weight" of the outgoing edges labelled with that character
        (summed if there are several), or 0.0 if there is none. Previously a
        missing character was skipped, which shortened the array and shifted
        every later probability onto the wrong character index.
    """
    if chars is None:
        chars = char_vect
    weight_by_label = dict()
    for neighbor in automata.neighbors(node):
        for attrib in automata.get_edge_data(node, neighbor).values():
            label = attrib["label"]
            weight_by_label[label] = weight_by_label.get(label, 0.0) + attrib["weight"]
    return np.array([weight_by_label.get(char, 0.0) for char in chars], dtype=float)

In [237]:
def generate_from_automata(automata, start_node=-1, temperature=1, print_nb_visited_nodes=False, max_steps=5000):
    """
    Sample a word by walking the automaton from `start_node`.

    At each state, one outgoing edge is drawn with probability proportional to
    its "weight" attribute (reshaped by `temperature`). The edge's "label"
    character is appended to the word, and the walk moves to the edge's target.
    The walk stops after emitting "E", after `max_steps` characters, or at a
    state with no outgoing edge (which used to crash np.concatenate).

    Parameters
    ----------
    automata : graph whose edges carry "label" and "weight" attributes
    start_node : state to start from (default -1)
    temperature : float, optional
        Weights are raised to the power 1/temperature before normalisation.
        The default of 1 samples from the automaton's own probabilities; lower
        values make the walk greedier, higher values more uniform. This
        parameter used to be accepted but ignored.
    print_nb_visited_nodes : bool, optional
        Print how many distinct states the walk visited.
    max_steps : int, optional
        Maximum number of characters generated (previously hard-coded 5000).

    Returns
    -------
    str
        The generated word.

    Raises
    ------
    ValueError
        If temperature <= 0.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0, got {}".format(temperature))
    visited_nodes = set()
    result = ""
    current_node = start_node
    for _ in range(max_steps):
        visited_nodes.add(current_node)
        # One entry per outgoing edge (parallel edges included), in neighbor order.
        labels, weights, targets = [], [], []
        for neighbor in automata.neighbors(current_node):
            for attrib in automata.get_edge_data(current_node, neighbor).values():
                labels.append(attrib["label"])
                weights.append(attrib["weight"])
                targets.append(neighbor)
        if not labels:
            break  # dead end: no outgoing transition
        proba = np.asarray(weights, dtype=float) ** (1.0 / temperature)
        proba /= proba.sum()
        next_index = np.flatnonzero(np.random.multinomial(1, proba))[0]
        result += labels[next_index]
        if labels[next_index] == "E":
            break
        current_node = targets[next_index]
    if print_nb_visited_nodes:
        print("nb visited nodes : {}".format(len(visited_nodes)))
    return result

In [267]:
# Multigraph: two states may be linked by several edges, each with its own character label.
automata2 = nx.MultiDiGraph()

In [268]:
# One automaton state per (output cluster i, hidden subcluster j) pair.
for i in range(len(list_subclusters_centers)):
    for j in range(len(list_subclusters_centers[i])):
        automata2.add_node((i, j))

In [269]:
# Start state: transitions are explored, and words generated, from node -1.
automata2.add_node(-1)

In [270]:
# Alphabet in index order, matching dict_trad in text2input/text2target
# ("(" -> 0, ")" -> 1, "E" -> 2). Read as a global by return_proba_distrib
# and add_transition_closest.
char_vect = ['(', ')', 'E']

In [271]:
def closest_to(vect, list_vect, distance=euclidian_distance):
    """
    Index of the element of `list_vect` nearest to `vect` under `distance`.

    On ties, the first such index is returned (argmin semantics).
    """
    all_distances = np.asarray([distance(vect, candidate) for candidate in list_vect])
    return all_distances.argmin()

In [272]:
def distances_closest(vect, list_vect, n_closest, distance=euclidian_distance):
    """
    Find the `n_closest` elements of `list_vect` nearest to `vect`.

    Returns (indices, distances): the indices in increasing-distance order
    (np.argsort order), and the matching distances.
    """
    all_distances = np.array([distance(vect, candidate) for candidate in list_vect])
    nearest = np.argsort(all_distances)[:n_closest]
    return nearest, all_distances[nearest]
    

In [273]:
# Number of automaton states (one per subcluster, plus the start node).
automata2.number_of_nodes()

432

In [274]:
# Raise Python's recursion limit in case the automaton exploration below
# recurses deeply.
import sys
sys.setrecursionlimit(10000)

In [275]:
def add_transition_closest(automata, rnn, start_node, start_output, start_hidden, list_subclusters_centers, list_first_clusters_centers,n_rec=0):
    """
    Add to `automata` every transition reachable from `start_node`, as dictated by the RNN.

    The RNN tells from which memory state to which memory state we go after a
    character. A state (i, j) stands for output cluster i and its hidden
    subcluster j, represented by their centers. `start_node` is represented by
    `start_output` and `start_hidden`.

    From a state, for every character of char_vect:
      * the RNN is run one step on that character from the state's hidden
        vector, which is split at `hidden_size` into two tensors (presumably
        the LSTM (h, c) pair; TODO confirm);
      * the softmax output is snapped to the closest output-cluster center, and
        the new hidden vector to the closest subcluster center of that cluster.
        This gives the target state;
      * an edge is added with the character as "label" and, as "weight", the
        probability the source state's output vector gives to that character.
    Every target reached through a newly added edge is then explored in turn.

    Exploration uses an explicit work-list instead of recursion. It therefore
    no longer depends on sys.setrecursionlimit, nor stops silently at depth
    7000. `n_rec` is accepted for backward compatibility and ignored.
    """
    pending = [(start_node, start_output, start_hidden)]
    while pending:
        node, node_output, node_hidden = pending.pop()
        # The RNN's input state is the same for every character: build it once.
        hidden_state = (
            Variable(torch.from_numpy(node_hidden[:hidden_size]).view(1, 1, -1)).float(),
            Variable(torch.from_numpy(node_hidden[hidden_size:]).view(1, 1, -1).float()),
        )
        for i, proba in enumerate(node_output):
            char = char_vect[i]
            output, new_hidden = rnn(Variable(text2input(char)), hidden_state)
            output = F.softmax(output).data.numpy()
            new_hidden = np.concatenate((new_hidden[0].data.numpy(), new_hidden[1].data.numpy()))
            index_closest_output = closest_to(output, list_first_clusters_centers, euclidian_distance)
            index_closest_hidden = closest_to(new_hidden, list_subclusters_centers[index_closest_output], euclidian_distance)
            new_node = (index_closest_output, index_closest_hidden)
            data = automata.get_edge_data(node, new_node)
            if data and char in [d["label"] for d in data.values()]:
                continue  # transition already known: nothing new to explore
            automata.add_edge(node, new_node, label=char, weight=float(proba))
            pending.append((new_node,
                            list_first_clusters_centers[index_closest_output],
                            list_subclusters_centers[index_closest_output][index_closest_hidden]))


            
        
        

In [276]:
# Explore from the start node -1. Its output vector puts all mass on "(", and
# its hidden vector is all zeros (40 values, split at hidden_size inside;
# presumably 2 * hidden_size, TODO confirm).
add_transition_closest(
    automata2,
    rnn,
    start_node=-1,
    start_output=np.array((1., 0, 0)),
    start_hidden=np.zeros(40),
    list_subclusters_centers=list_subclusters_centers,
    list_first_clusters_centers=list_centers_first_clusters,
)

In [277]:
# Size summary of the automaton.
print(nx.info(automata2))

Name: 
Type: MultiDiGraph
Number of nodes: 432
Number of edges: 1065
Average in degree:   2.4653
Average out degree:   2.4653


In [278]:
# Number of distinct successor states of every automaton state.
l = [len(list(automata2.neighbors((node)))) for node in automata2.nodes()]

In [279]:
# As an array, to allow the element-wise comparison below.
l = np.array(l)

In [280]:
# States with no outgoing transition (presumably pre-added states the exploration never reached).
(l == 0).sum()

77

In [293]:
# Sample one word from the automaton, then check that it is a well-formed
# bracket word terminated by "E" ("longueur" = length).
res = generate_from_automata(automata2, print_nb_visited_nodes=True)
print("longueur : {}".format(len(res)))
print(res)
print(is_correct(res))

nb visited nodes : 3
longueur : 3
()E
Maximum depth : 1
True


In [296]:
%%time
# Mean log-loss of the automaton on 2000 random sentences (compare with the RNN below).
compute_log_loss_automata(automata2, 2000)

CPU times: user 8.75 s, sys: 188 ms, total: 8.94 s
Wall time: 8.94 s


Variable containing:
 0.6570
[torch.DoubleTensor of size 1]

In [297]:
%%time
# Mean log-loss of the RNN on 400 random sentences. This is fewer than for the
# automaton above because the RNN is much slower to evaluate (see the timings).
compute_log_loss_rnn(rnn, 400)

CPU times: user 5min 1s, sys: 420 ms, total: 5min 2s
Wall time: 14.2 s


Variable containing:
 0.6608
[torch.FloatTensor of size 1]

In [297]:
def export_parallel_edges_gephi_compatible(automata, filename):
    """
    Write `automata` to GraphML with parallel edges merged, for Gephi.

    Every group of parallel edges u -> v becomes one edge of a DiGraph:
      * "weight" is the sum of their weights;
      * "label" lists each "label : weight" pair, separated by ", ".
    The pairs used to be concatenated with no separator, which produced
    ambiguous strings such as "( : 0.3) : 0.7".
    All nodes are kept, including isolated ones. A summary of the merged graph
    is printed before writing `filename`.
    """
    merged = nx.DiGraph()
    merged.add_nodes_from(automata.nodes())
    for node in automata.nodes():
        for neighbor in automata.neighbors(node):
            parallel_edges = list(automata.get_edge_data(node, neighbor).values())
            label = ", ".join(d["label"] + " : " + str(d["weight"]) for d in parallel_edges)
            weight = sum(d["weight"] for d in parallel_edges)
            merged.add_edge(node, neighbor, label=label, weight=weight)
    print(nx.info(merged))
    nx.write_graphml(merged, filename)

In [298]:
# Write the automaton as GraphML, with parallel edges merged, for Gephi.
export_parallel_edges_gephi_compatible(automata2, "automata_0.4.graphml")

Name: 
Type: DiGraph
Number of nodes: 19
Number of edges: 18
Average in degree:   0.9474
Average out degree:   0.9474
