# Makemore y wavenet
## Código inicial...

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
import idna

# def utf8_to_punycode(text: str) -> str:
#     """Encodes a UTF-8 string to its Punycode representation."""
#     return idna.encode(text).decode('ascii')

def punyencode(text: str) -> str:
    """Encode *text* word by word with ``idna.encode``.

    The text is split on whitespace. Each word is encoded independently and
    decoded to an ASCII ``str``. The results are re-joined with a single space,
    so runs of whitespace collapse to one space.

    Raises whatever ``idna.encode`` raises for a word it cannot encode.
    """
    encoded_words = (idna.encode(word).decode('ascii') for word in text.split())
    return " ".join(encoded_words)
    
def punydecode(punycode: str) -> str:
    """Inverse of ``punyencode``: decode each whitespace-separated word with ``idna.decode``.

    Decoded words are joined back together with a single space.
    """
    words = punycode.split()
    return " ".join(map(idna.decode, words))

def process_name(name):
    """Lowercase a raw city name and Punycode-encode it, or return '' to reject it.

    A name is rejected (``''`` is returned) when:
      * any of its whitespace-separated words is shorter than two characters, or
      * ``punyencode`` cannot encode it.
    """
    name = name.lower()
    if any(len(word) < 2 for word in name.split()):
        return ''
    try:
        return punyencode(name)
    except UnicodeError:
        # idna's errors (idna.IDNAError and its subclasses) derive from
        # UnicodeError. Catching it keeps the best-effort filtering without
        # swallowing KeyboardInterrupt/SystemExit the way a bare ``except`` did.
        return ''

# Load the raw city names, keep those process_name accepts, and cache them
# (one encoded name per line) in data/city_names_puny.txt.
# NOTE(review): assumes the raw file is UTF-8 — confirm; without an explicit
# encoding the platform default (e.g. cp1252 on Windows) would be used.
with open("data/city_names_full.txt", 'r', encoding='utf-8') as f:
    dataset = f.read().split('\n')
with open('data/city_names_puny.txt', 'w', encoding='utf-8') as f:
    for n in dataset:
        name = process_name(n)
        if name != '':
            f.write(name + '\n')

# Re-read the cache. Every written line ends in '\n', so split('\n') leaves a
# trailing '' that would otherwise end up in ``nopuny`` and could be sampled
# as an empty name; drop empty entries.
with open("data/city_names_puny.txt", 'r', encoding='utf-8') as f:
    dataset = [x for x in f.read().split('\n') if x]

# Separate names containing an 'xn--' (Punycode) label; only the others are used.
puny = [x for x in dataset if 'xn--' in x]
nopuny = [x for x in dataset if 'xn--' not in x]
np.random.seed(42)
# Sample 100000 distinct names; .item() converts each numpy string to a Python str.
dataset = [x.item() for x in np.random.choice(nopuny, 100000, replace=False)]

In [None]:
# Vocabulary: index 0 is the '*' delimiter, followed by every distinct
# character appearing in the dataset, in sorted order.
charset = ['*', *sorted({ch for word in dataset for ch in word})]
charset_size = len(charset)
itoc = dict(enumerate(charset))                # index -> character
ctoi = {ch: idx for idx, ch in itoc.items()}   # character -> index

In [None]:
def build_dataset(dataset: list, ctx_size=None, char_to_idx=None):
    """Turn a list of names into (context, next-character) training pairs.

    Each name is followed by the '*' end marker. Its characters are walked with
    a sliding window of the previous ``ctx_size`` indices, starting from all
    zeros (index 0 is ``'*'`` in ``charset``).

    Args:
        dataset: list of names (strings).
        ctx_size: window length; defaults to the global ``context_size``.
        char_to_idx: character -> index map; defaults to the global ``ctoi``.

    Returns:
        (X, Y): ``X`` is a long tensor of shape (N, ctx_size) holding the
        contexts, and ``Y`` is a long tensor of shape (N,) holding the index of
        the character that follows each context.
    """
    # Fall back to the notebook globals so existing calls keep working.
    ctx_size = context_size if ctx_size is None else ctx_size
    char_to_idx = ctoi if char_to_idx is None else char_to_idx

    xs, ys = [], []
    for d in dataset:
        example = list(d) + ['*']
        context = [0] * ctx_size
        for c in example:
            xs.append(context)
            ys.append(char_to_idx[c])
            context = context[1:] + [char_to_idx[c]]
    # Explicit dtype/shape: torch.tensor([]) would otherwise be float of shape (0,).
    X = torch.tensor(xs, dtype=torch.long).reshape(len(xs), ctx_size)
    Y = torch.tensor(ys, dtype=torch.long)
    return X, Y

# Build the train / validation / test splits (80% / 10% / 10%).
context_size = 3  # number of preceding characters fed to the model
np.random.seed(42)
np.random.shuffle(dataset)  # in-place shuffle, reproducible via the fixed seed
n1, n2 = (int(frac * len(dataset)) for frac in (.8, .9))  # 80% and 90% cut points
(Xtr, Ytr), (Xva, Yva), (Xte, Yte) = (
    build_dataset(split)
    for split in (dataset[:n1], dataset[n1:n2], dataset[n2:])
)

In [None]:
class Linear:
    """Fully-connected layer computing ``x @ W`` plus an optional bias ``b``.

    ``W`` is drawn from a standard normal and divided by sqrt(input_dim). The
    bias, when enabled, starts at zero. The last output is cached in ``self.out``.

    NOTE: the default ``generator`` is created once, when the class is defined.
    Every Linear built without an explicit generator therefore draws from the
    same shared RNG stream.
    """

    def __init__(self, input_dim, output_dim, bias=True, generator=torch.Generator().manual_seed(42)):
        fan_in_root = input_dim ** 0.5
        self.W = torch.randn(input_dim, output_dim, generator=generator) / fan_in_root
        self.b = None
        if bias:
            self.b = torch.zeros(output_dim)

    def __call__(self, x):
        result = x @ self.W
        if self.b is not None:
            result += self.b
        self.out = result
        return self.out

    def parameters(self):
        params = [self.W]
        if self.b is not None:
            params.append(self.b)
        return params

class Tanh:
    """Element-wise hyperbolic tangent with no trainable parameters.

    The last output is cached in ``self.out``.
    """

    def __call__(self, x):
        activated = x.tanh()
        self.out = activated
        return activated

    def parameters(self):
        return list()

class BatchNorm1d:
    """Batch normalization over dim 0 with a learnable scale (gamma) and shift (beta).

    Training mode (``training_mode_on`` is True):
      * the input is normalized with the batch mean and the (unbiased)
        standard deviation taken along dim 0;
      * exponential moving averages of both are updated, weighted by
        ``momentum``, outside of autograd.

    Otherwise the stored running statistics are used.

    Assumes ``x`` is 2-D (batch, features) — TODO confirm for other callers.
    Note that ``eps`` is added to the standard deviation, not to the variance.
    The attribute name ``runnint_std`` (sic) is kept for compatibility.
    """

    def __init__(self, input_size, momentum=0.001, eps=0.0005):
        self.momentum, self.eps = momentum, eps
        self.training_mode_on = True
        # trainable affine parameters
        self.gamma, self.beta = torch.ones(input_size), torch.zeros(input_size)
        # running statistics (updated, not trained)
        self.running_mean, self.runnint_std = torch.zeros(input_size), torch.ones(input_size)

    def __call__(self, x):
        if not self.training_mode_on:
            mu, sigma = self.running_mean, self.runnint_std
        else:
            mu = x.mean(0, keepdims=True)
            sigma = x.std(0, keepdims=True)
            keep = 1 - self.momentum
            with torch.no_grad():
                self.running_mean = self.running_mean * keep + mu * self.momentum
                self.runnint_std = self.runnint_std * keep + sigma * self.momentum
        # standardize, then apply the learned scale and shift
        normalized = (x - mu) / (sigma + self.eps)
        self.out = self.gamma * normalized + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]


class Model:
    """Character-level MLP: embedding lookup followed by a Linear/BatchNorm1d/Tanh stack.

    The embeddings of the ``context_size`` previous characters are
    concatenated and passed through five hidden blocks of
    Linear -> BatchNorm1d -> Tanh. A final Linear layer produces
    ``charset_size`` logits.

    Args:
        charset_size: number of symbols in the vocabulary.
        context_size: number of previous characters the model receives.
        emb_size: dimension of each character embedding.
        hidden_size: width of the hidden layers.
        g: torch.Generator used for every random initialization. NOTE: the
            default generator is created once, when the class is defined, so
            Models built without an explicit ``g`` share one RNG stream.
    """

    def __init__(self, charset_size, context_size, emb_size, hidden_size, g=torch.Generator().manual_seed(42)):
        self.charset_size = charset_size
        self.context_size = context_size
        self.emb_size = emb_size
        self.hidden_size = hidden_size
        # character embedding table, one row per symbol
        self.C = torch.randn(self.charset_size, self.emb_size, generator=g)
        # Every Linear receives ``g`` so the whole initialization is reproducible
        # from that single generator. Hidden Linears have no bias because the
        # following BatchNorm1d's beta plays that role.
        self.layers = [Linear(self.emb_size * self.context_size, self.hidden_size, bias=False, generator=g),
                       BatchNorm1d(self.hidden_size), Tanh()]
        for _ in range(4):
            self.layers += [Linear(self.hidden_size, self.hidden_size, bias=False, generator=g),
                            BatchNorm1d(self.hidden_size), Tanh()]
        self.layers.append(Linear(self.hidden_size, self.charset_size, generator=g))

        # Kaiming-style init: scale every hidden Linear by the tanh gain 5/3
        # (Linear already divides its weights by sqrt(fan_in)).
        for l in self.layers[:-1]:
            if isinstance(l, Linear):
                l.W *= 5/3
        # Shrink the output layer so initial logits are small (less confident).
        self.layers[-1].W *= 0.1

        # all parameters are trained
        for p in self.parameters():
            p.requires_grad = True

    def parameters(self):
        """Return the embedding table followed by every layer's parameters."""
        return [self.C] + [p for l in self.layers for p in l.parameters()]

    def to(self, device):
        """Move every tensor the model owns to ``device`` in place; return self.

        ``p.to(device)`` returns a copy, so discarding its result (as before)
        moved nothing. Reassigning ``.data`` keeps each parameter the same leaf
        tensor with ``requires_grad=True``. The BatchNorm running statistics
        are not parameters and must be moved separately, otherwise training
        on the new device mixes devices in the running-average update.
        """
        for p in self.parameters():
            p.data = p.data.to(device)
            if p.grad is not None:
                p.grad = p.grad.to(device)
        for l in self.layers:
            if isinstance(l, BatchNorm1d):
                l.running_mean = l.running_mean.to(device)
                l.runnint_std = l.runnint_std.to(device)
        return self

    def count_parameters(self):
        """Total number of scalar parameters."""
        return sum(p.nelement() for p in self.parameters())

    def __call__(self, x):
        """Return logits for index tensor ``x``.

        ``x`` indexes rows of ``C``. The embeddings are reshaped to
        (-1, emb_size * context_size), so ``x`` is presumably
        (batch, context_size).
        """
        self.emb = self.C[x]
        x = self.emb.view(-1, self.emb_size * self.context_size)
        for l in self.layers:
            x = l(x)
        return x