# Simple two-layer neural network to classify documents into three classes

In [None]:
# import dependencies, set a random seed, and define the dataset

import re, torch, torch.nn as nn

# Seed PyTorch's global RNG so that subsequent random operations are reproducible.
torch.manual_seed(42)

# Toy corpus of ten short documents on three topics.
docs = [
    # topic: movies
    "Movies are fun for everyone",
    "Watching movies is great fun",
    "Enjoy a great movie tonight.",
    # topic: research / learning
    "Research is interesting and important",
    "Learning math is very important",
    "Science discovery is interesting",
    # topic: music
    "Rock is great to listen to",
    "Listen to music for fun",
    "Music is fun for everyone.",
    "Listen to folk music!.",
]

# Class label for each document, aligned by position with `docs`
# (1 = movies, 3 = research/learning, 2 = music).
labels = [1, 1, 1, 3, 3, 3, 2, 2, 2, 2]

# Number of distinct classes present in `labels`.
num_classes = len({*labels})

In [4]:
# tokenize the documents and build the vocabulary (word -> column index) used for bag-of-words features

def tokenize(text):
    """Lower-case *text* and split it into word tokens.

    A token is a maximal run of regex word characters (letters, digits,
    underscore); punctuation and whitespace are dropped.
    """
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\w+", lowered)]

def get_vocabulary(texts):
    """Map every distinct token found in *texts* to a column index.

    Tokens come from ``tokenize``; indices 0..N-1 are assigned in sorted
    (lexicographic) token order, so the result is deterministic.
    """
    seen = set()
    for text in texts:
        seen.update(tokenize(text))
    ordered = sorted(seen)
    return dict(zip(ordered, range(len(ordered))))

# Build the word -> column-index mapping from the whole corpus; its size
# determines the length of every bag-of-words vector produced by doc_to_bow.
vocabulary = get_vocabulary(docs)

In [5]:
# Display the word -> column-index mapping (shown as the cell output below).
vocabulary

{'a': 0,
 'and': 1,
 'are': 2,
 'discovery': 3,
 'enjoy': 4,
 'everyone': 5,
 'folk': 6,
 'for': 7,
 'fun': 8,
 'great': 9,
 'important': 10,
 'interesting': 11,
 'is': 12,
 'learning': 13,
 'listen': 14,
 'math': 15,
 'movie': 16,
 'movies': 17,
 'music': 18,
 'research': 19,
 'rock': 20,
 'science': 21,
 'to': 22,
 'tonight': 23,
 'very': 24,
 'watching': 25}

In [6]:
# feature extraction: convert a document into a binary bag-of-words feature vector

def doc_to_bow(doc, vocabulary):
    """Return a binary bag-of-words vector for *doc*.

    The result is a list of ``len(vocabulary)`` ints: position
    ``vocabulary[token]`` is 1 when *token* occurs in *doc* (after
    ``tokenize``), otherwise 0. Repeated tokens still yield 1, and tokens
    missing from *vocabulary* are ignored.
    """
    vector = [0] * len(vocabulary)
    for word in tokenize(doc):
        if word not in vocabulary:
            continue
        vector[vocabulary[word]] = 1
    return vector

In [7]:
# transform documents and labels into numbers

# One float row per document: its binary bag-of-words vector.
vectors = torch.tensor(
    [doc_to_bow(doc, vocabulary) for doc in docs],
    dtype=torch.float32,
)

# Shift the 1-based class labels to 0-based integer class indices (the
# convention torch classification losses such as nn.CrossEntropyLoss expect).
# Guard against re-running this cell: once `labels` is already a tensor,
# converting and shifting it again would push every label down by one more
# (giving -1..1) and corrupt the targets.
if not torch.is_tensor(labels):
    labels = torch.tensor(labels, dtype=torch.long) - 1

In [9]:
# Show the encoded feature matrix and the 0-based label tensor.
print(f"vectors: \n{vectors}\n\nlabels: \n{labels}")

vectors: 
tensor([[0., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 1.,
         0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
         0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0.,
         0., 0., 1., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         1., 0., 0., 0., 1., 0., 0., 0.],
      