In [68]:
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchtext.datasets import text_classification
from torch.utils.data.dataset import random_split

In [2]:
NGRAMS=2
if not os.path.isdir('./__data__'):
    os.mkdir('./__data__')

train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](root='./__data__', ngrams=NGRAMS, vocab=None)

__data__\ag_news_csv.tar.gz: 11.8MB [00:24, 484kB/s]
120000lines [00:11, 10711.54lines/s]
120000lines [00:23, 5087.34lines/s]
7600lines [00:01, 5124.75lines/s]


In [7]:
class TextSentiment(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super(TextSentiment, self).__init__()
        
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_classes)
        self.init_weights()
        
    def init_weights(self):
        init_range = 0.5
        
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.fc.weight.data.uniform_(-init_range, init_range)
        self.fc.bias.data.zero_()
        
    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

In [72]:
# hyperparameters
vocab_size = len(train_dataset.get_vocab())
embed_dim = 32
batch_size = 16
epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = len(train_dataset.get_labels())

In [8]:
# initialize model
model = TextSentiment(vocab_size, embed_dim, num_classes).to(device)

In [77]:
# generate the batch
def generate_batch(batch):
    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
    offsets = [0] + [len(entry) for entry in batch]
    
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
    return text, offsets, label

In [78]:
def train(train_data):
    t_loss, t_acc = 0, 0
    
    # get data
    data = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
    
    for i, (text, offsets, cls) in enumerate(data):
        optimizer.zero_grad()
        
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        
        # get output from models
        output = model(text, offsets)
        
        loss = criterion(output, cls)
        t_loss += loss.item()
        loss.backward()
        
        optimizer.step()
        
        acc = (output.argmax(1) == cls).sum().item()
        t_acc += acc
        
    scheduler.step()
    
    return t_loss / len(train_data), t_acc / len(train_data)

In [79]:
def test(test_data):
    t_loss, t_acc = 0, 0
    
    data = DataLoader(test_data, batch_size=batch_size, collate_fn=generate_batch)
    
    for text, offsets, cls in data:
        
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        
        with torch.no_grad():
            output = model(text, offsets)
            
            loss = criterion(output, cls)
            t_loss += loss.item()
            
            t_acc += (output.argmax(1) == cls).sum().item()
            
    return t_loss / len(test_data), t_acc / len(test_data)

In [82]:
# train the model
val_loss = float('inf')

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

train_len = int(len(train_dataset) * 0.95)

train_data, valid_data = random_split(train_dataset, [train_len, len(train_dataset) - train_len])

for epoch in range(epochs):
    # start time for each epoch
    start = time.time()
    
    train_loss, train_acc = train(train_data)
    val_loss, val_acc = test(valid_data)
    
    print(f'Epoch {epoch+1}/{epochs}')
    print(f'{round(time.time() - start)}s - loss: {train_loss:.2f} - accuracy: {train_acc*100:.2f}% - '+\
          f'val_loss: {val_loss:.2f} - val_accuracy: {val_acc*100:.2f}%')

Epoch 1/5
71s - loss: 0.10 - accuracy: 0.25% - val_loss: 0.10 - val_accuracy: 0.24%


KeyboardInterrupt: 