In [10]:
import torch
import torch.nn as nn
from torch.functional import F

#### get data

In [11]:
# # uncomment if not downloaded previously
# !wget https://download.pytorch.org/tutorial/data.zip
# !unzip data.zip

#### format data

In [12]:
import string
all_letters = string.ascii_letters + " .,;'"

In [13]:
# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
import unicodedata

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

print(unicodeToAscii('Ślusàrski'))

Slusarski


In [14]:
import os
category_lines = {}
categories = []
path = '02-data/names/'
for fname in os.listdir(path):
    lang = fname.split('.')[0]
    categories.append(lang)
    category_lines[lang] = []
    with open(path+fname, 'r') as f:
      for line in f:
        category_lines[lang].append(unicodeToAscii(line.strip()))
', '.join(categories)

'Korean, Irish, Portuguese, Vietnamese, Czech, Russian, Scottish, German, Polish, Spanish, English, French, Japanese, Dutch, Greek, Chinese, Italian, Arabic'

In [15]:
categories

['Korean',
 'Irish',
 'Portuguese',
 'Vietnamese',
 'Czech',
 'Russian',
 'Scottish',
 'German',
 'Polish',
 'Spanish',
 'English',
 'French',
 'Japanese',
 'Dutch',
 'Greek',
 'Chinese',
 'Italian',
 'Arabic']

In [16]:
print(category_lines['Italian'][:5])

['Abandonato', 'Abatangelo', 'Abatantuono', 'Abate', 'Abategiovanni']


#### helpers

In [17]:
num_letters = len(all_letters);num_letters

57

In [18]:
def letterToIndex(letter):
    return all_letters.find(letter)
letterToIndex('n')

13

In [19]:
def char_to_one_hot(char):
    zeros = torch.zeros(1, num_letters)
    zeros[0][letterToIndex(char)] = 1
    return zeros

In [20]:
char_to_one_hot('n')

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0.]])

In [21]:
def seq_onehot(word):
    zeros = torch.zeros(len(word), 1, num_letters)
    for i, ch in enumerate(word):
        zeros[i][0][letterToIndex(ch)] = 1
    return zeros

In [22]:
seq_onehot('nan'), seq_onehot('nan').shape

(tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]],
 
         [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]]]),
 torch.Size([3, 1, 57]))

In [23]:
print(seq_onehot('Jones').size())

torch.Size([5, 1, 57])


In [24]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.input_and_hidden_to_next_hidden = nn.Linear(input_size + hidden_size, hidden_size)
        self.new_hidden_to_output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input_tensor, hidden):
        next_hidden = self.input_and_hidden_to_next_hidden(torch.cat((input_tensor, hidden), 1))
        output = self.new_hidden_to_output(next_hidden)
        output = self.softmax(output)
        

        return output, next_hidden

In [25]:
rnn = RNN(num_letters, 50, len(categories))

In [26]:
hidden = torch.zeros((1, 50))

In [27]:
for char in seq_onehot('nan'):
    output, hidden = rnn(char, hidden)
    print(output.shape)

torch.Size([1, 18])
torch.Size([1, 18])
torch.Size([1, 18])


In [28]:
def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return categories[category_i], category_i

In [29]:
categoryFromOutput(output)

('Greek', 14)

In [30]:
torch.cat((seq_onehot('nan')[0], hidden), 1).shape

torch.Size([1, 107])

In [31]:
loss_fn = nn.NLLLoss()

In [32]:
import random

def get_random_sample():
    def rand(collection):
        rand_int = random.randint(0, len(collection)-1)
        return collection[rand_int]
    cat = rand(categories)
    item = rand(category_lines[cat])
    cat_tensor = torch.tensor([categories.index(cat)])
    item_tensor = seq_onehot(item)
    return cat, item, cat_tensor, item_tensor

In [33]:
get_random_sample()

('Vietnamese',
 'Mai',
 tensor([3]),
 tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]],
 
         [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]]]))

In [34]:
cat, item, cten, itens = get_random_sample()

In [35]:
for ch in itens:
    out, hidden = rnn(ch, hidden)

In [36]:
out[0][torch.argmax(out).item()]

tensor(-2.6402, grad_fn=<SelectBackward0>)

In [37]:
# loss_fn(out[0][torch.argmax(out).item()], cten.float()).backward()

In [38]:
hidden

tensor([[ 0.0091, -0.0891,  0.0167, -0.1053,  0.0541,  0.0241,  0.1082,  0.0874,
          0.0062, -0.1450,  0.1753, -0.0888,  0.0708,  0.0205, -0.0243,  0.1996,
         -0.1346, -0.0241, -0.0224, -0.0857,  0.0518, -0.0554,  0.0642,  0.1156,
         -0.0324, -0.0037, -0.0876,  0.1573,  0.1526,  0.0521,  0.1785, -0.0898,
          0.0717, -0.0123, -0.1349,  0.0199, -0.1606,  0.0172, -0.1679,  0.0666,
         -0.1905, -0.1292, -0.1296, -0.0537,  0.1229,  0.1395,  0.0612, -0.0770,
          0.1526, -0.0916]], grad_fn=<AddmmBackward0>)

In [39]:
n_hidden = 50
rnn = RNN(num_letters, n_hidden, len(categories))
optimizer = torch.optim.SGD(rnn.parameters(), lr=0.001)


def train(input_tensor, cat_tensor):
    hidden = torch.zeros((1, n_hidden))

    for ch in input_tensor:
        out, hidden = rnn(ch, hidden)
    
    loss = loss_fn(out, cat_tensor)
    # # hidden.detach_()
    # # out.detach_()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return output, loss.item()


In [40]:
for i in range(1000):
    cat, item, cten, itens = get_random_sample()
    output, loss = train(itens, cten)
    
    if i % 1000 == 0:
        print(f'Loss: {loss}')
        print([torch.argmax(output).item()] == cten.item())


Loss: 2.78035569190979
False


In [41]:
cten.float().item()

6.0

In [42]:
output.dtype

torch.float32

In [43]:
for gg in rnn.parameters():
    print(gg.grad.zero_())
    break

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [44]:
randomTrainingExample = get_random_sample

In [45]:
learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn

def train(category_tensor, line_tensor):
    hidden = torch.zeros((1, n_hidden))

    rnn.zero_grad()

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = loss_fn(output, category_tensor.float())
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()

In [46]:
import time
import math

n_iters = 10000
print_every = 5000
plot_every = 1000



# Keep track of losses for plotting
current_loss = 0
all_losses = []

def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

start = time.time()

for iter in range(1, n_iters + 1):
    category, line, category_tensor, line_tensor = randomTrainingExample()
    output, loss = train(category_tensor, line_tensor)
    current_loss += loss

    # Print ``iter`` number, loss, name and guess
    if iter % print_every == 0:
        guess, guess_i = categoryFromOutput(output)
        correct = '✓' if guess == category else '✗ (%s)' % category
        print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))

    # Add current loss avg to list of losses
    if iter % plot_every == 0:
        all_losses.append(current_loss / plot_every)
        current_loss = 0

RuntimeError: expected scalar type Long but found Float

In [47]:
categories[:5], "_________",{k: category_lines[k][:2] for k in category_lines}

(['Korean', 'Irish', 'Portuguese', 'Vietnamese', 'Czech'],
 '_________',
 {'Korean': ['Ahn', 'Baik'],
  'Irish': ['Adam', 'Ahearn'],
  'Portuguese': ['Abreu', 'Albuquerque'],
  'Vietnamese': ['Nguyen', 'Tron'],
  'Czech': ['Abl', 'Adsit'],
  'Russian': ['Ababko', 'Abaev'],
  'Scottish': ['Smith', 'Brown'],
  'German': ['Abbing', 'Abel'],
  'Polish': ['Adamczak', 'Adamczyk'],
  'Spanish': ['Abana', 'Abano'],
  'English': ['Abbas', 'Abbey'],
  'French': ['Abel', 'Abraham'],
  'Japanese': ['Abe', 'Abukara'],
  'Dutch': ['Aalsburg', 'Aalst'],
  'Greek': ['Adamidis', 'Adamou'],
  'Chinese': ['Ang', 'AuYong'],
  'Italian': ['Abandonato', 'Abatangelo'],
  'Arabic': ['Khoury', 'Nahas']})

In [48]:
allLetters = all_letters

In [49]:
numLetters = len(all_letters)

In [50]:
numLetters

57

In [51]:
class RNN2(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN2, self).__init__()
        self.hidden_size = hidden_size
        self.input_and_hidden_to_new_hidden = nn.Linear(input_size+hidden_size, hidden_size)
        self.new_hidden_to_output = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input_tensor, hidden_tensor):
        concatted = torch.cat((input_tensor, hidden_tensor), dim=1)
        next_hidden = self.input_and_hidden_to_new_hidden(concatted)
        output = self.new_hidden_to_output(next_hidden)
        # print(output)
        # print(output.argmax())
        # print("--------------")
        # print(nn.Softmax(dim=1)(output))
        output = self.softmax(output)
        # print(output)
        # print(output.argmax())
        return output, next_hidden
    
    def init_hidden(self):
        return torch.zeros(1, self.hidden_size)
        

In [52]:
def init_hidden(hidden_size):
    return torch.zeros(1, hidden_size)

In [53]:
init_hidden(128)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]])

In [54]:
get_random_sample()

('Portuguese',
 'Cardozo',
 tensor([2]),
 tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]],
 
         [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0.]],
 
         [[0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [55]:
label, name, labelTensor, nameTensor = get_random_sample()

In [56]:
numLetters = num_letters

In [57]:
numCats = len(categories)

In [58]:
numHidden = 128

In [59]:
rnn2 = RNN2(numLetters, numHidden, numCats)

In [60]:
hidden = init_hidden(numHidden)

In [61]:
for i in range(nameTensor.size()[0]):
    output, hidden = rnn2(nameTensor[i], hidden)


In [62]:
tt = torch.tensor([[0.0503, 0.0554, 0.0539, 0.0609, 0.0575, 0.0518, 0.0531, 0.0571, 0.0536,
         0.0602, 0.0603, 0.0549, 0.0533, 0.0536, 0.0547, 0.0514, 0.0572, 0.0607]])

In [63]:
learning_rate = 0.005

In [64]:
a = torch.log(tt)

In [65]:
nn.Softmax(dim=1)(a)

tensor([[0.0503, 0.0554, 0.0539, 0.0609, 0.0575, 0.0518, 0.0531, 0.0571, 0.0536,
         0.0602, 0.0603, 0.0549, 0.0533, 0.0536, 0.0547, 0.0514, 0.0572, 0.0607]])

In [66]:
a.topk(3, 1, False)

torch.return_types.topk(
values=tensor([[-2.9898, -2.9681, -2.9604]]),
indices=tensor([[ 0, 15,  5]]))

In [67]:
a.topk??

[0;31mDocstring:[0m
topk(k, dim=None, largest=True, sorted=True) -> (Tensor, LongTensor)

See :func:`torch.topk`
[0;31mType:[0m      builtin_function_or_method


In [68]:
def train(model, optimizer, loss_fn, sampleTensor, labelTensor):
    model.train()
    
    model.zero_grad()
    hidden = model.init_hidden()
    for i in range(nameTensor.size()[0]):
        output, hidden = rnn2(nameTensor[i], hidden)
    
    loss = loss_fn(output, labelTensor)
    loss.backward()
    
    # Add parameters' gradients to their values, multiplied by learning rate
    for p in model.parameters():
        p.data.add_(p.grad.data, alpha=-learning_rate)
    
    return loss.item()


In [69]:
rnn2 = RNN2(numLetters, numHidden, numCats)
optimizer = torch.optim.SGD(rnn2.parameters(), lr=0.0001)
loss_fn = torch.nn.NLLLoss()

In [70]:

%time
for i in range(1):
    label, name, labelTensor, nameTensor = get_random_sample()
    loss = train(rnn2,optimizer,loss_fn,nameTensor, labelTensor)
    
    if i % 1000 == 0:
        print(f"{i} Loss: {loss}")

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 7.87 µs
0 Loss: 2.7582786083221436


In [71]:
def evaluate():
    hidden = rnn2.init_hidden()
    label, name, labelTensor, nameTensor = get_random_sample()
    for i in range(nameTensor.shape[0]):
        output, hidden = rnn2(nameTensor[i], hidden)
    print(f"{output.argmax() == labelTensor} name: {name}; predicted: {categories[output.argmax().item()]}; Actual: {label}")
    return (output.argmax() == labelTensor).float().item()
    

In [72]:
    label, name, labelTensor, nameTensor = get_random_sample()
    for i in range(nameTensor.shape[0]):
        output, hidden = rnn2(nameTensor[i], hidden)

In [73]:
(output.argmax() == labelTensor).float().item()

0.0

In [74]:
categories[output.argmax().item()]

'German'

In [75]:
correct = 0
for i in range(1):
    correct += evaluate()

correct / 100
    

tensor([False]) name: Dao; predicted: Korean; Actual: Vietnamese


0.0

In [76]:
str.

SyntaxError: invalid syntax (2101759970.py, line 1)

In [77]:
import string

In [78]:
allLetters = string.ascii_letters

In [79]:
# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
import unicodedata

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in allLetters
    )

print(unicodeToAscii('Ślusàrski'))

Slusarski


In [80]:
[c for c in unicodedata.normalize('NFD', 'Ślusàrski') if c in allLetters]

['S', 'l', 'u', 's', 'a', 'r', 's', 'k', 'i']

In [81]:
unicodedata.normalize('NFD', 'Ślusàrski')

'Ślusàrski'

In [31]:
unicodedata.category

<function unicodedata.category(chr, /)>

In [33]:
Path('02-data')

NameError: name 'Path' is not defined

In [35]:
import os

In [47]:
os.listdir('02-data/names/')

['Korean.txt',
 'Irish.txt',
 'Portuguese.txt',
 'Vietnamese.txt',
 'Czech.txt',
 'Russian.txt',
 'Scottish.txt',
 'German.txt',
 'Polish.txt',
 'Spanish.txt',
 'English.txt',
 'French.txt',
 'Japanese.txt',
 'Dutch.txt',
 'Greek.txt',
 'Chinese.txt',
 'Italian.txt',
 'Arabic.txt']

In [46]:
with open('02-data/names/Korean.txt', 'r') as f:
    for l in f:
        print(l)
        break

Ahn



In [44]:
os.open?

[0;31mSignature:[0m [0mos[0m[0;34m.[0m[0mopen[0m[0;34m([0m[0mpath[0m[0;34m,[0m [0mflags[0m[0;34m,[0m [0mmode[0m[0;34m=[0m[0;36m511[0m[0;34m,[0m [0;34m*[0m[0;34m,[0m [0mdir_fd[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Open a file for low level IO.  Returns a file descriptor (integer).

If dir_fd is not None, it should be a file descriptor open to a directory,
  and path should be relative; path will then be relative to that directory.
dir_fd may not be implemented on your platform.
  If it is unavailable, using it will raise a NotImplementedError.
[0;31mType:[0m      builtin_function_or_method


In [97]:
cats = []
dict = {}
path = '02-data/names/'
for fname in os.listdir('02-data/names/'):
    lang = fname.split('.')[0]
    cats.append(lang)
    dict[lang] = []
    with open(path+fname) as f:
        for line in f:
            dict[lang].append(unicodeToAscii(line.strip()))

In [98]:
dict['Korean'][:5]

['Ahn', 'Baik', 'Bang', 'Byon', 'Cha']

In [99]:
[f"{k}: {len(dict[k])}" for k in dict.keys()]

['Korean: 94',
 'Irish: 232',
 'Portuguese: 74',
 'Vietnamese: 73',
 'Czech: 519',
 'Russian: 9408',
 'Scottish: 100',
 'German: 724',
 'Polish: 139',
 'Spanish: 298',
 'English: 3668',
 'French: 277',
 'Japanese: 991',
 'Dutch: 297',
 'Greek: 203',
 'Chinese: 268',
 'Italian: 709',
 'Arabic: 2000']

In [100]:
import torch
torch.zeros(1, len(allLetters))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.]])

In [101]:
def sToI(ch):
    return allLetters.index(ch)

In [102]:
def onehotChar(char):
    zeros = torch.zeros(1, len(allLetters))
    zeros[0, sToI(char)] = 1
    return zeros

In [103]:
def stringToTensor(string):
    return torch.stack([onehotChar(ch) for ch in string])

In [104]:
import random

In [107]:
def randomSample():
    cat = random.choice(cats)
    name = random.choice(dict[cat])
    return name, cat

In [119]:
def sampleAsTensor(sample):
    name, cat = sample
    name = stringToTensor(name)
    cat = torch.tensor([cats.index(cat)])
    return name, cat

In [120]:
sampleAsTensor(('Pettigrew', 'French'))

(tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0.]],
 
         [[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0.]],
 
         [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,