In [1]:
import sys
sys.path.insert(0, "../..")

import torch
import pytorch_lightning as pl

from tqdm import tqdm
from torch.nn import Module, Linear
from torch.utils.data import DataLoader

from Utilities.Datasets import YahooDataset
from Utilities.Preprocessors import WordTokenizer, TokenIdPadding
from Utilities.ModelTrainers import SkipgramTrainerModule, MulticlassSentenceClassificationTrainerModule, WordEmbedding

In [2]:
# Load a 10,000-sample subset of the Yahoo dataset, stored under local_dir.
dataset = YahooDataset(
    max_samples=10000,
    local_dir="small_yahoo_dataset",
)

In [5]:
# Display one raw sample: a dict with the question text under "input" and an integer label under "target".
dataset.train[4]

{'input': 'How do you stand regarding the Bush administration?\nFor me.... as far away as possible.\nThank God for George Bush!',
 'target': 10}

In [7]:
# Number of samples in each split, in (train, val, test) order.
tuple(len(split) for split in (dataset.train, dataset.val, dataset.test))

(9000, 1000, 10000)

In [16]:
tokenizer = WordTokenizer(num_embeddings=20000, padding_idx=0)

# Longest tokenized training input; this guides the fixed padding length chosen later.
max_token_length = max(
    (len(tokenizer(sample["input"])) for sample in tqdm(dataset.train)),
    default=0,
)
print(max_token_length)

100%|██████████| 9000/9000 [00:08<00:00, 1018.04it/s]

995





In [10]:
# Inspect how one raw training example is encoded: tokenize it, then pad the
# single sequence (wrapped as a one-element batch) to 70 ids using id 0.
raw_text = dataset.train[0]["input"]

tokenizer = WordTokenizer(padding_idx=0, num_embeddings=20000)
padding = TokenIdPadding(padding_idx=0, padding_length=70)

token_ids = tokenizer(raw_text)
output_dict = padding([token_ids])
output_dict

{'token_ids': array([[14522,  5751, 13488,  2349,  4308, 19015,  1972, 14434,  9681,
          4488,  5103, 11353, 17795,  1367, 18818,   785, 18818, 19833,
          8510,  8244, 18818, 11159,  2652, 14510,   199, 18818,  1237,
          4954,  2346, 14510, 12272,  6198,  4954,  9329, 14510,  3095,
         12738,  8510,  2224, 16522, 16891, 17498, 14606,  4970,  3793,
         18916,  5103,   302, 18915, 15787, 11291,  3799,  3912, 11375,
         13488, 12625, 18916, 14510,  2144,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]])}

In [8]:
class Preprocessor:
    """Convert a raw ``{"input": str, "target": int}`` sample into model-ready values.

    The input text is tokenized into ids and padded to a fixed length. The
    target is shifted down by one to give a 0-based class id; the raw targets
    observed in this dataset go up to 10, so they appear to be 1-based.

    Args:
        num_embeddings: vocabulary size passed to ``WordTokenizer``.
        padding_length: fixed length every id sequence is padded to. The
            longest tokenized *training* input is 995 tokens, so the default
            of 1000 covers the training split.
            NOTE(review): confirm how ``TokenIdPadding`` treats longer
            sequences (e.g. in val/test).
        padding_idx: id used for padding, shared by tokenizer and padder.
    """

    def __init__(self, num_embeddings=20000, padding_length=1000, padding_idx=0):
        self.tokenizer = WordTokenizer(num_embeddings=num_embeddings,
                                       padding_idx=padding_idx)
        self.padding = TokenIdPadding(padding_length=padding_length,
                                      padding_idx=padding_idx)

    def __call__(self, sample):
        """Return ``{"input": padded_token_ids, "target": class_id}`` for one sample."""
        token_ids = self.tokenizer(sample["input"])
        # The padder operates on a batch (list of sequences); wrap one and unwrap it.
        padded_ids = self.padding([token_ids])["token_ids"][0]

        class_id = sample["target"] - 1
        return {"input": padded_ids, "target": class_id}

In [9]:
preprocessor = Preprocessor()

# Attach the same preprocessor instance to every split.
for split in (dataset.train, dataset.val, dataset.test):
    split.set_preprocessor(preprocessor)

In [10]:
# With the preprocessor attached, indexing now returns fixed-length padded token ids and the shifted target.
dataset.train[0]

{'input': array([14522,  5751, 13488,  2349,  4308, 19015,  1972, 14434,  9681,
         4488,  5103, 11353, 17795,  1367, 18818,   785, 18818, 19833,
         8510,  8244, 18818, 11159,  2652, 14510,   199, 18818,  1237,
         4954,  2346, 14510, 12272,  6198,  4954,  9329, 14510,  3095,
        12738,  8510,  2224, 16522, 16891, 17498, 14606,  4970,  3793,
        18916,  5103,   302, 18915, 15787, 11291,  3799,  3912, 11375,
        13488, 12625, 18916, 14510,  2144,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     

In [11]:
BATCH_SIZE = 32

# Only the training split is shuffled; val/test keep a fixed order.
dataloader_train = DataLoader(dataset.train, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset.val, batch_size=BATCH_SIZE, shuffle=False)
dataloader_test = DataLoader(dataset.test, batch_size=BATCH_SIZE, shuffle=False)

In [12]:
# Peek at the first (shuffled) training batch: input shape and ids, a blank line, then target shape and ids.
for batch in dataloader_train:
    print(batch["input"].shape, batch["input"], "",
          batch["target"].shape, batch["target"], sep="\n")
    break

torch.Size([32, 1000])
tensor([[ 2652,  5328, 10728,  ...,     0,     0,     0],
        [  530, 17795,  1115,  ...,     0,     0,     0],
        [14480, 14555,  3912,  ...,     0,     0,     0],
        ...,
        [19901,  6968, 15787,  ...,     0,     0,     0],
        [ 6091,  3592,  1910,  ...,     0,     0,     0],
        [19580,  7887, 14510,  ...,     0,     0,     0]])

torch.Size([32])
tensor([1, 3, 4, 7, 7, 6, 5, 9, 4, 6, 6, 3, 7, 9, 6, 2, 6, 6, 4, 6, 5, 5, 6, 1,
        1, 3, 6, 8, 8, 1, 6, 4])


In [23]:
class YahooClassifier(Module):
    """Max-pooled bag-of-embeddings classifier.

    Token ids are embedded, max-pooled over the word axis (ignoring padding),
    and projected to one logit per class by a linear layer.

    Args:
        word_embedding: module mapping ``(batch_size, words_num)`` token ids to
            ``(batch_size, words_num, embedding_dim)`` vectors.
        embedding_dim: size of the last dimension produced by ``word_embedding``.
        class_size: number of output logits per sample.
        padding_idx: token id marking padding; such positions are excluded from
            the max-pool. Defaults to 0, the id used by the preprocessing here.
    """

    def __init__(self, word_embedding, embedding_dim, class_size, padding_idx=0):
        super().__init__()
        self.word_embedding = word_embedding
        self.linear_classifier = Linear(embedding_dim, class_size)
        self.padding_idx = padding_idx

    def forward(self, token_ids):
        """
        token_ids: (batch_size, words_num)
        returns:   (batch_size, class_size) logits
        """
        # (batch_size, words_num, embedding_dim)
        embeddings = self.word_embedding(token_ids)
        # (batch_size, words_num, 1): True at real tokens, broadcasts over embedding_dim
        is_token = (token_ids != self.padding_idx).unsqueeze(-1)
        # Keep padding out of the max. If padded positions embed to zero vectors
        # (as nn.Embedding with padding_idx does), they would otherwise win for every
        # feature that is negative across all real tokens.
        masked = embeddings.masked_fill(~is_token, float("-inf"))
        # (batch_size, embedding_dim)
        pooled = masked.max(dim=1).values
        # Rows that are all padding would pool to -inf; use 0 so the logits stay finite.
        pooled = pooled.masked_fill(~is_token.any(dim=1), 0.0)
        # (batch_size, class_size)
        return self.linear_classifier(pooled)

In [24]:
# One constant feeds both the embedding size and the classifier input width, so they cannot disagree.
EMBEDDING_DIM = 300

word_embedding = WordEmbedding(num_embeddings=20000, embedding_dim=EMBEDDING_DIM, padding_idx=0)
yahoo_classifier = YahooClassifier(word_embedding, embedding_dim=EMBEDDING_DIM, class_size=10)

In [25]:
# Run the untrained classifier on one training batch and show the logit shape and values.
for batch in dataloader_train:
    outputs = yahoo_classifier(batch["input"])
    print(outputs.shape, outputs, sep="\n")
    break

torch.Size([32, 10])
tensor([[-4.7722e-01, -9.3360e-02,  3.0805e+00,  6.2392e-02, -7.5214e-01,
          3.9694e-01, -1.0916e-01, -5.1078e-01, -1.0713e+00,  1.1232e+00],
        [-6.1630e-01, -3.6279e-01,  1.8831e+00,  3.2270e-01, -1.1273e+00,
          7.1028e-01,  2.6147e-01, -1.9018e-01, -9.6163e-01,  5.6575e-01],
        [-1.9175e-01, -4.6639e-01,  2.8217e+00,  1.5230e-01, -5.0953e-01,
          3.0075e-01,  1.3256e-01, -3.7502e-01, -9.4569e-01,  8.1514e-01],
        [ 5.1160e-02,  1.6008e-01,  2.5268e+00,  3.2699e-01, -2.5911e-01,
          2.2128e-01,  2.7860e-01, -5.4790e-01, -8.6681e-01,  3.8693e-01],
        [-4.6968e-01,  1.1829e-01,  3.2506e+00,  5.4877e-01, -4.4072e-01,
          6.4956e-01,  3.9342e-01, -9.4772e-01, -9.3286e-01,  1.1706e+00],
        [-3.2966e-01,  2.9893e-01,  2.5929e+00, -5.5727e-01, -7.3426e-01,
          2.3320e-01, -3.0079e-02, -6.3186e-01, -6.8745e-01,  6.6857e-01],
        [-4.2181e-02,  6.8372e-02,  2.2157e+00,  1.6892e-02, -3.6602e-01,
          6

In [None]:
# Wrap the classifier in the project's training module; this is the object passed to `trainer.fit` below.
classifier_trainer = MulticlassSentenceClassificationTrainerModule(yahoo_classifier)

In [None]:
# Train on the batched DataLoaders. Indexing `dataset.train` yields single samples
# (see `dataset.train[0]` above), so passing the dataset itself would feed
# unbatched items. Arguments are positional (model, train loader, val loader)
# for compatibility across Lightning versions.
# NOTE(review): the val loader is only used if the training module defines a
# validation step — confirm.
trainer = pl.Trainer()
trainer.fit(classifier_trainer, dataloader_train, dataloader_val)