In [118]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from util import randomChoice, lineToTensor
from util import categoryFromOutput

import string
import time
import math
import random


class RNN(nn.Module):
    """Minimal recurrent classifier that consumes one input vector per step.

    Each step concatenates the current input with the previous hidden state and
    feeds the result through two linear layers: ``i2h`` produces the next hidden
    state, ``i2o`` produces the output, which is turned into log-probabilities by
    ``LogSoftmax`` over dim 1.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance the network by one step.

        Args:
            input: input for this step, concatenated with ``hidden`` along dim 1.
                NOTE(review): presumably shape (1, input_size) — confirm against lineToTensor.
            hidden: previous hidden state, e.g. from ``initHidden()`` (shape (1, hidden_size)).

        Returns:
            (output, hidden): log-probabilities over the output classes and the new hidden state.
        """
        combined = torch.cat((input, hidden), 1)
        # The hidden update is purely linear (no activation); kept as-is so trained
        # behaviour is unchanged.
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape (1, hidden_size)."""
        return torch.zeros(1, self.hidden_size)

    def getCategory(self, word):
        """Classify ``word`` with THIS network and return ``categoryFromOutput`` of the result.

        The word is converted with ``lineToTensor`` and fed one step at a time; only the
        output after the last step is used. Runs without gradient tracking.

        Raises:
            ValueError: if ``word`` is empty (no step would produce an output).
        """
        if not word:
            raise ValueError("getCategory() needs a non-empty word")

        lineTensor = lineToTensor(word)
        hidden = self.initHidden()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for letterTensor in lineTensor:
                # Use `self`, not the notebook-global `rnn`, so any RNN instance classifies with its own weights.
                output, hidden = self(letterTensor, hidden)

        return categoryFromOutput(output)

In [119]:
# Seed every RNG used below (weight init in torch, example sampling in `random`)
# so Restart & Run All reproduces the same training run.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)

n_hidden = 128

all_letters = string.ascii_letters + " .,;'" + "äÄüÜöÖ"
n_letters = len(all_letters)
# NOTE(review): n_letters must equal the one-hot width produced by util.lineToTensor —
# confirm util uses the same alphabet.

# Only two classes exist (targets 0 and 1; see training_data / catergories). The model
# previously had a third output unit that was never a valid target; if it were ever the
# argmax (presumably what categoryFromOutput returns) it would index past the category list.
n_categories = 2
rnn = RNN(n_letters, n_hidden, n_categories)

learning_rate = 0.005

# The RNN emits log-probabilities (LogSoftmax), so negative log-likelihood is the matching loss.
criterion = nn.NLLLoss()

In [120]:
# Words per category, upper-cased and split on whitespace; list index = category index
# (0 -> "Milchprodukte", 1 -> "Süßwaren").
training_data = [
    words.upper().split()
    for words in (
        "Käse Butter Milch Joghurt Buttermilch Erdbeermilch Schokomilch",
        "Schokolade Eis Schokocreme Gummibärchen",
    )
]

In [121]:
# Human-readable names for the category indices used as training targets (0 and 1).
categories = ["Milchprodukte", "Süßwaren"]
# Backward-compatible alias for the original misspelled name, which later cells still use.
catergories = categories

In [122]:
def randomTrainingExample():
    """Draw one random training example.

    A category index is chosen uniformly over all entries of the global
    ``training_data``; a word from that category is then picked with ``randomChoice``.

    Returns:
        (category, line, category_tensor, line_tensor):
        ``category`` is the int index, ``line`` the chosen word,
        ``category_tensor`` a 1-element long tensor holding ``category`` (the
        target format NLLLoss takes for a single example), and
        ``line_tensor`` is ``lineToTensor(line)``.
    """
    # randrange(len(...)) instead of the hard-coded randint(0, 1): it draws the same
    # values for two categories but keeps working when categories are added.
    category = random.randrange(len(training_data))
    line = randomChoice(training_data[category])

    category_tensor = torch.tensor([category], dtype=torch.long)

    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor

In [123]:
def train(category_tensor, line_tensor):
    """Run one plain-SGD step of the global ``rnn`` on a single example.

    Args:
        category_tensor: 1-element long tensor with the target class index.
        line_tensor: per-step inputs; ``line_tensor[i]`` is fed at step i.
            NOTE(review): presumably (seq_len, 1, n_letters) from lineToTensor — confirm.

    Returns:
        (output, loss): the network's log-probabilities after the last step and the
        loss as a Python float.

    Raises:
        ValueError: if ``line_tensor`` has no steps (no output would be produced).
    """
    if line_tensor.size(0) == 0:
        raise ValueError("train() needs at least one step in line_tensor")

    hidden = rnn.initHidden()
    rnn.zero_grad()

    for step_input in line_tensor:
        output, hidden = rnn(step_input, hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # p <- p - learning_rate * grad, done on the parameters themselves under no_grad
    # rather than through the legacy `.data` attribute, which hides the change from autograd.
    with torch.no_grad():
        for p in rnn.parameters():
            if p.grad is not None:
                p.add_(p.grad, alpha=-learning_rate)

    return output, loss.item()

In [124]:
def evaluate(line_tensor):
    """Return the global ``rnn``'s log-probabilities for one input sequence.

    Steps through ``line_tensor`` one element at a time and returns the output
    after the last step. Runs under ``torch.no_grad()`` because this is inference only.

    Raises:
        ValueError: if ``line_tensor`` has no steps (no output would be produced).
    """
    if line_tensor.size(0) == 0:
        raise ValueError("evaluate() needs at least one step in line_tensor")

    hidden = rnn.initHidden()

    with torch.no_grad():
        for step_input in line_tensor:
            output, hidden = rnn(step_input, hidden)

    return output

In [125]:
# Schedule: total training steps, console-report interval, loss-averaging interval.
n_iters, print_every, plot_every = 100000, 5000, 1000

# Running sum of losses for the current averaging window, and the collected averages.
current_loss = 0
all_losses = []

def timeSince(since):
    """Format wall-clock time elapsed since ``since`` (a ``time.time()`` stamp) as 'Mm Ss'."""
    elapsed_s = time.time() - since
    whole_minutes = math.floor(elapsed_s / 60)
    leftover_s = elapsed_s - 60 * whole_minutes
    return '%dm %ds' % (whole_minutes, leftover_s)

start = time.time()

# Loop variable is `step`, not `iter`: a module-level `iter` would shadow the builtin
# for every later cell in the notebook.
for step in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    # Every print_every steps: step, progress %, elapsed time, last loss, word, guess and verdict.
    if step % print_every == 0:
        guess_i = categoryFromOutput(output)
        correct = '✓' if guess_i == category else '✗ (%s)' % category
        print('%d %d%% (%s) %.4f %s / %s %s' % (step, step / n_iters * 100, timeSince(start), loss, line, guess_i, correct))

    # Every plot_every steps: store the window's mean loss and reset the accumulator.
    if step % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

5000 5% (0m 4s) 0.0100 EIS / 1 ✓
10000 10% (0m 8s) 0.0038 EIS / 1 ✓
15000 15% (0m 12s) 0.0002 JOGHURT / 0 ✓
20000 20% (0m 17s) 0.0016 EIS / 1 ✓
25000 25% (0m 21s) 0.0001 SCHOKOLADE / 1 ✓
30000 30% (0m 25s) 0.0001 GUMMIBÄRCHEN / 1 ✓
35000 35% (0m 30s) 0.0000 SCHOKOLADE / 1 ✓
40000 40% (0m 34s) 0.0000 SCHOKOCREME / 1 ✓
45000 45% (0m 38s) 0.0001 SCHOKOMILCH / 0 ✓
50000 50% (0m 43s) 0.0000 SCHOKOLADE / 1 ✓
55000 55% (0m 47s) 0.0000 BUTTERMILCH / 0 ✓
60000 60% (0m 50s) 0.0000 ERDBEERMILCH / 0 ✓
65000 65% (0m 54s) 0.0002 KÄSE / 0 ✓
70000 70% (0m 58s) 0.0000 GUMMIBÄRCHEN / 1 ✓
75000 75% (1m 2s) 0.0000 MILCH / 0 ✓
80000 80% (1m 6s) 0.0003 EIS / 1 ✓
85000 85% (1m 10s) 0.0000 SCHOKOLADE / 1 ✓
90000 90% (1m 14s) 0.0000 BUTTERMILCH / 0 ✓
95000 95% (1m 18s) 0.0000 MILCH / 0 ✓
100000 100% (1m 22s) 0.0000 BUTTER / 0 ✓


In [138]:
word = "Schokoladenmilch"
# training_data was upper-cased, so classify the upper-cased form of the word.
cat = rnn.getCategory(word.upper())
# Fixed typo in the message ("catergory" -> "category").
print('%s is in category %s (%s)' % (word, cat, catergories[cat]))

Schokoladenmilch is in catergory 0 (Milchprodukte)
