In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [62]:
class PositionalEncoder(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    The table follows the formulation from "Attention Is All You Need":

        PE[pos, 2i]   = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000 ** (2i / d_model))

    Args:
        d_model: size of the feature (last) dimension of the inputs.
        dropout: dropout probability applied after the encodings are added.
        max_seq_len: number of positions for which encodings are precomputed.
    """
    def __init__(self,d_model,dropout = 0.1,max_seq_len = 200):
        super(PositionalEncoder,self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        positional_encodings = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # torch.arange(0, d_model, 2) already enumerates the even indices 2i, so the
        # exponent is index / d_model (no extra factor of 2) and the base is 10000.
        even_index = torch.arange(0, d_model, 2).float()
        div_term = torch.pow(10000.0, even_index / d_model)
        angles = position / div_term  # (max_seq_len, ceil(d_model / 2))
        positional_encodings[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        positional_encodings[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        positional_encodings = positional_encodings.unsqueeze(1) #shape is max_seq_len,1,d_model
        # A buffer (not a parameter): it follows .to(device) and is saved in the
        # state dict, but is never updated by the optimizer.
        self.register_buffer('positional_encodings', positional_encodings)

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions and apply dropout.

        The table is sliced along dim 0 of ``x``, so ``x`` must be laid out as
        (seq_len, batch, d_model); the (seq_len, 1, d_model) slice broadcasts
        over the batch dimension.
        """
        x = x + self.positional_encodings[:x.size(0), :]
        return self.dropout(x)

class Embedder(nn.Module):
    """Token embedding followed by positional encoding (and its dropout).

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: requested embedding width. An odd value is rounded up to the
            next even number; the width actually used is stored in ``self.d_model``.
        dropout: dropout probability forwarded to ``PositionalEncoder``.
        max_seq_len: number of positions forwarded to ``PositionalEncoder``.
    """
    def __init__(self,vocab_size,d_model,dropout=0.1,max_seq_len=200):
        super().__init__()
        # An even width gives every frequency both a sine and a cosine column.
        if d_model % 2:
            d_model += 1
        self.d_model = d_model  # embedding width actually in use
        self.embed = nn.Embedding(vocab_size, d_model)
        self.positional_embedder = PositionalEncoder(d_model, dropout, max_seq_len)

    def forward(self, x):
        """Look up embeddings for the token ids in ``x`` and add positional encodings.

        NOTE(review): PositionalEncoder indexes its table by dim 0 of what it
        receives, so ``x`` presumably needs to be (seq_len, batch) -- confirm
        against callers.
        """
        return self.positional_embedder(self.embed(x))

In [154]:
class TransformerBlock(nn.Module):
    """Two self-attention stages that share a single feed-forward network.

    Args:
        d_model: feature width of the inputs and outputs.
        num_heads: number of attention heads in each ``nn.MultiheadAttention``.
        ff_hidden: multiplier for the feed-forward hidden width
            (hidden size is ``ff_hidden * d_model``).

    The attention layers are built with ``nn.MultiheadAttention``'s default
    layout, i.e. inputs are (seq_len, batch, d_model).
    """
    def __init__(self,d_model,num_heads=8,ff_hidden=4):
        super(TransformerBlock,self).__init__()
        # nn.ModuleList (rather than a plain Python list) registers the
        # sub-modules, so their weights show up in parameters() / state_dict()
        # and follow .to(device); otherwise the optimizer never trains them.
        self.attentions = nn.ModuleList(
            [nn.MultiheadAttention(d_model, num_heads) for _ in range(2)]
        )
        self.layer_norms = nn.ModuleList(
            [nn.LayerNorm(d_model) for _ in range(4)]
        )
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model,ff_hidden*d_model),
            nn.ReLU(),
            nn.Linear(ff_hidden*d_model,d_model),
        )


    def forward(self,x):
        """Apply both attention / feed-forward stages; output has the shape of ``x``."""
        # Stage 1: norm -> self-attention -> norm(attended + x) -> feed-forward.
        normed = self.layer_norms[0](x)
        attended,_ = self.attentions[0](normed,normed,normed,need_weights=False)
        normed = self.layer_norms[1](attended+x)
        forwarded = self.feed_forward(normed)

        # Stage 2: same pattern, fed by the stage-1 outputs. The same
        # feed_forward module (shared weights) is applied a second time.
        normed = self.layer_norms[2](forwarded+attended)
        attended,_ = self.attentions[1](normed,normed,normed,need_weights=False)
        normed = self.layer_norms[3](forwarded+attended)
        forwarded = self.feed_forward(normed)

        return forwarded+attended

In [203]:
class ClassificationTransformer(nn.Module):
    """Sequence classifier: embed tokens, run a transformer block, pool over time.

    Args:
        d_model: requested embedding width (``Embedder`` rounds odd values up).
        vocab_size: number of distinct token ids.
        num_classes: number of output classes.
        num_heads: attention heads used by the transformer block.
        max_seq_len: number of positions with precomputed positional encodings;
            input sequences must not be longer than this.
        dropout: dropout probability used by the embedder.
        max_pool: pool over time with max if True, with mean otherwise.
    """
    def __init__(self,d_model,vocab_size,num_classes,num_heads=8,max_seq_len =200,dropout=0.1,max_pool=True):
        super(ClassificationTransformer,self).__init__()
        self.max_pool=max_pool
        self.embedder = Embedder(vocab_size,d_model,dropout,max_seq_len)
        # Embedder may have rounded an odd d_model up; size the downstream
        # layers from the width it actually produces.
        d_model = self.embedder.d_model
        self.transformer_block = TransformerBlock(d_model,num_heads)
        self.to_probability = nn.Linear(d_model,num_classes)

    def forward(self,x):
        """Classify a batch of token-id sequences.

        Args:
            x: integer tensor of shape (batch, seq_len).

        Returns:
            Log-probabilities of shape (batch, num_classes).
        """
        # The positional encoder slices its table by dim 0 and the attention
        # layers use the default (seq_len, batch, d_model) layout, so move the
        # time axis to the front before running them.
        x = x.transpose(0, 1)            # (seq_len, batch)
        x=self.embedder(x)               # (seq_len, batch, d_model)
        x=self.transformer_block(x)      # (seq_len, batch, d_model)
        x=self.to_probability(x)         # (seq_len, batch, num_classes)
        x = x.max(dim=0)[0] if self.max_pool else x.mean(dim=0) # pool over the time dimension
        return F.log_softmax(x,dim=1)

In [69]:
import os

import torchtext
from torchtext.datasets import text_classification

NGRAMS = 2
BATCH_SIZE = 16

# Make sure the local cache directory exists before the dataset is written to it.
if not os.path.isdir('./.data'):
    os.mkdir('./.data')

# Look up the AG_NEWS loader in torchtext's text-classification registry and
# build the train / test datasets with NGRAMS-gram features; vocab=None is
# passed through to the loader unchanged.
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root='./.data',
    ngrams=NGRAMS,
    vocab=None,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ag_news_csv.tar.gz: 11.8MB [13:58, 14.1kB/s]
120000lines [00:12, 9906.52lines/s] 
120000lines [00:23, 5103.99lines/s]
7600lines [00:01, 5495.75lines/s]


In [70]:
# Display the selected device (the recorded run below shows the CPU).
device

device(type='cpu')

In [209]:
# Sizes derived from the training data.
VOCAB_SIZE = len(train_dataset.get_vocab())
NUM_CLASS = len(train_dataset.get_labels())

# Model hyper-parameters.
EMBED_DIM = 32
MAX_SEQ_LEN=200

# Build the classifier and move it to the selected device.
model = ClassificationTransformer(
    EMBED_DIM,
    VOCAB_SIZE,
    NUM_CLASS,
    max_seq_len=MAX_SEQ_LEN,
).to(device)

200


In [212]:
import torch.nn.functional as F
def generate_batch(batch, max_seq_len=200):
    """Collate ``(label, token_ids)`` pairs into fixed-length tensors.

    Args:
        batch: sequence of ``(label, token_ids)`` pairs, where ``token_ids``
            is a 1-D tensor.
        max_seq_len: length every sequence is brought to. Longer sequences are
            truncated at the end, shorter ones are right-padded with 0.

    Returns:
        ``(text, label)`` where ``text`` has shape (len(batch), max_seq_len)
        and ``label`` has shape (len(batch),).
    """
    label = torch.tensor([entry[0] for entry in batch])
    rows = []
    for entry in batch:
        # Truncate explicitly instead of relying on a negative pad width.
        tokens = entry[1][:max_seq_len]
        rows.append(F.pad(tokens, (0, max_seq_len - tokens.shape[0]), mode='constant', value=0))
    text = torch.stack(rows)
    return text, label

from torch.utils.data import DataLoader

def train_func(sub_train_):
    """Run one training epoch over ``sub_train_`` and step the LR scheduler once.

    Reads the module-level ``model``, ``optimizer``, ``criterion``,
    ``scheduler``, ``device``, ``BATCH_SIZE`` and ``generate_batch``.
    NOTE(review): ``tqdm`` is used here but is not imported in any visible
    cell -- make sure ``from tqdm import tqdm`` (or equivalent) has run.

    Returns:
        ``(loss, accuracy)``. ``accuracy`` is the fraction of correctly
        classified samples. ``loss`` is the sum of the per-batch criterion
        values divided by the number of samples; with a criterion that already
        averages over the batch this is roughly the mean loss divided by
        ``BATCH_SIZE``.
    """
    # Train the model
    model.train()  # make sure dropout is active even if the model was left in eval mode
    train_loss = 0
    train_acc = 0
    data = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,collate_fn=generate_batch)
    for text, cls in tqdm(data):
        optimizer.zero_grad()
        text, cls = text.to(device), cls.to(device)
        output = model(text)
        loss = criterion(output, cls)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate
    scheduler.step()

    return train_loss / len(sub_train_), train_acc / len(sub_train_)

def test(data_):
    """Evaluate the module-level ``model`` on ``data_`` without updating it.

    Reads the module-level ``model``, ``criterion``, ``device``,
    ``BATCH_SIZE`` and ``generate_batch``.
    NOTE(review): ``tqdm`` is used here but is not imported in any visible
    cell -- make sure ``from tqdm import tqdm`` (or equivalent) has run.

    Returns:
        ``(loss, accuracy)`` as Python floats, normalised the same way as in
        ``train_func``: the summed per-batch criterion values and the number of
        correct predictions, each divided by ``len(data_)``.
    """
    total_loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    # Evaluate with dropout disabled, then restore whatever mode the model was
    # in so a following training epoch is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for text, cls in tqdm(data):
                text, cls = text.to(device), cls.to(device)
                output = model(text)
                # Keep the running total separate from the per-batch loss so
                # the accumulator is not overwritten on every iteration.
                total_loss += criterion(output, cls).item()
                acc += (output.argmax(1) == cls).sum().item()
    finally:
        model.train(was_training)

    return total_loss / len(data_), acc / len(data_)

In [213]:
import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
min_valid_loss = float('inf')

# Loss, optimiser and LR schedule are module-level names because train_func()
# and test() read them as globals.
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

# Hold out 5% of the training set for validation.
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = random_split(
    train_dataset,
    [train_len, len(train_dataset) - train_len],
)

for epoch in range(N_EPOCHS):
    start_time = time.time()

    train_loss, train_acc = train_func(sub_train_)
    print('Trained for epoch',epoch,', now validating.')
    valid_loss, valid_acc = test(sub_valid_)

    # Wall-clock time for the epoch, split into minutes and seconds.
    secs = int(time.time() - start_time)
    mins, secs = secs / 60, secs % 60

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

  0%|          | 20/7125 [00:07<41:27,  2.86it/s] 

KeyboardInterrupt: 