In [75]:
import numpy as np
import string
from data import get_train_test_split
from models.utils import damerau_levenshtein_distance

In [5]:
train,test = get_train_test_split("10_ports.csv")

In [15]:
all_characters = string.printable

'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'

In [95]:
class Vocabulary:
    def __init__(self):
        self.char_to_int = {}
        self.int_to_char = {}
        self.build_vocab()
    
    def build_vocab(self):
        all_characters = string.printable
        self.char_to_int[""] = 0
        self.int_to_char[0] = ""
        idx = 1
        for c in all_characters:
            self.char_to_int[c] = idx
            self.int_to_char[idx] = c
            idx += 1
        
    def __getitem__(self, idx):
        return self.int_to_char[idx]
    
    def encode(self, text):
        return [self.char_to_int[c] for c in text]
    
    def decode(self, x):
        return "".join([
            self.int_to_char[i] for i in x
        ])

In [92]:
class Kmeans:
    def __init__(self, k, centroids, vocab):
        """
        Parameters:
            k (int): number of centroids
            centroids (list<str>): list of initial centroids (text)
            vocab (Vocabulary):
        """
        assert k == len(centroids)
        self.k = k
        self.centroids = centroids
        self.vocab = vocab
    
    def distance(self, s1, s2):
        """Compute Damereau distance
        """
        if type(s1) != str: s1 = self.vocab.decode(s1)
        if type(s2) != str: s2 = self.vocab.decode(s2)
        
        return damerau_levenshtein_distance(s1,s2)
            
    
    def __call__(self, text):
        """Finds the closes centroid
        
        Parameters:
            text (str):
        
        Returns:
            str: closes centroid
        """
        closest_centroid = self.centroids[0]
        closest_dist = self.distance(text, closest_centroid)
        
        for i in range(1, self.k):
            if self.distance(text, self.centroids[i]) < closest_dist:
                closest_dist = self.distance(text, self.centroids[i])
                closest_centroid = self.centroids[i]
        
        return closest_centroid
    
    def init_clusters(self):
        clusters = {}
        for i in range(self.k): 
            clusters[i] = [self.centroids[i]]
        return clusters
    
    def fit(self, dataset, iterations):
        n = len(dataset)
        max_length = max([
            len(text) for text in dataset["destination"].tolist()
        ])
        
        for i in range(iterations):
            clusters = self.init_clusters()
            
            # create new clusters
            for i in range(n):
                x = dataset["destination"][i]
                
                centroid = self(x)
                for j in range(self.k):
                    if centroid == clusters[j][0]: clusters[j].append(x)
            
            # compute new centroids
            for j in range(self.k):
                cluster = clusters[j]
                centroid = []
                for i in range(len(cluster)):
                    idx = 0
                    while True:
                        if idx > max([len(text) for text in cluster]): break
                        
                        
                        centroid[idx] = np.sum([self.vocab.encode(text[idx]) 
                                                if idx < len(text) else 0 
                                                for text in cluster], axis=1)
                        
                        for i in range(len(centroid)):
                            idx = np.argmax(centroid[i])
                            centroid[i] = np.zeros(len(self.vocab))
                            centroid[i][idx] = 1
                            centroid = self.vocab.encode(centroid)

In [93]:
vocab = Vocabulary()
model = Kmeans(k=10, centroids=train.loc[np.random.randint(1,len(train), 10)]["code"].tolist(), 
               vocab=vocab)

In [94]:
model.fit(train, 10)

IndexError: list assignment index out of range