Prerequisite per [pytorch](https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html#prerequisites)

In [8]:
!pip install -U portalocker>=2.0.0

Import all the necessary libraries and set the device

In [9]:
import torch
import torch.nn as nn
from torchtext.datasets import AmazonReviewPolarity
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator
import torch.nn.functional as F
import time
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [10]:
train_iter = AmazonReviewPolarity(split="train")

In [14]:
next(iter(train_iter))

(2,
 'Stuning even for the non-gamer This sound track was beautiful! It paints the senery in your mind so well I would recomend it even to people who hate vid. game music! I have played the game Chrono Cross but out of all of the games I have ever played it has the best music! It backs away from crude keyboarding and takes a fresher step with grate guitars and soulful orchestras. It would impress anyone who cares to listen! ^_^')

In [15]:
train_iter = AmazonReviewPolarity(split="train")

tokenizer = get_tokenizer("basic_english")

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

# Add <pad> to your vocabulary
vocab.append_token("<pad>")
pad_index = vocab["<pad>"]

# Add <eos> to your vocabulary
vocab.append_token("<eos>")
eos_index = vocab["<eos>"]


def collate_batch(batch):
    label_list, text_list = [], []
    # max_length = max(len(text_pipeline(entry[1])) for entry in batch)  # Find the longest sequence in the batch
    max_length = 128
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = text_pipeline(_text)

        # Padding
        while len(processed_text) < max_length - 1:
            processed_text.append(pad_index)

        if len(processed_text) >= 128:
            processed_text = processed_text[:127]

        processed_text.append(eos_index)

        # Append the processed text to text_list
        text_list.append(processed_text)

    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.stack([torch.tensor(t, dtype=torch.int64) for t in text_list])  # Convert list of lists to a tensor
    return label_list, text_list

dataloader = DataLoader(
    train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch
)

In [16]:
# The Embedding Bag helps with the efficiency, but it makes the words lose
# positional importance. While it helps with the training process and removes
# the need for padding/truncating the input, it 1-dimensionalizes the data.
class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)


class SelfAttentionHead(nn.Module):
    def __init__(self, input_dim, head_size):
        super(SelfAttentionHead, self).__init__()
        self.key = nn.Linear(input_dim, head_size, bias=False)
        self.query = nn.Linear(input_dim, head_size, bias=False)
        self.value = nn.Linear(input_dim, head_size, bias=False)

        self.mask = torch.tril(torch.ones(block_size, block_size, device=device))

        self.dropout = nn.Dropout(0.2)


    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        x = q @ k.transpose(-2, -1)
        x = x.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        x = F.softmax(x, dim=-1)
        x = self.dropout(x)

        x = x @ v
        return x

        # I was missing dropout
        # key = self.key(x)
        # query = self.query(x)
        # kq = key @ query.transpose(-2, -1)
        # value = self.value(x)

        # kq = F.softmax(kq, dim=-1)
        # x = kq @ value
        # return x

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, head_size, num_head):
        super(MultiHeadSelfAttention, self).__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(input_dim, head_size) for _ in range(num_head)])
        self.proj = nn.Linear(head_size * num_head, input_dim)
        self.dropout = nn.Dropout(.2)

    def forward(self, x):
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.dropout(self.proj(x))
        return x

class Block(nn.Module):
    def __init__(self, input_dim, num_head):
        super().__init__()
        head_size = input_dim // num_head
        self.self_attention = MultiHeadSelfAttention(input_dim, head_size, num_head)
        self.dense1 = nn.Linear(input_dim, 4 * input_dim)
        self.dense2 = nn.Linear(4 * input_dim, input_dim)
        self.dropout = nn.Dropout(.2)
        self.ln1 = nn.LayerNorm(input_dim)
        self.ln2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        x = x + self.self_attention(self.ln1(x))
        x = self.ln2(x)
        x = F.relu(self.dense1(x))
        x = self.dropout(self.dense2(x))
        return x

class TextClassificationModelWithAttention(nn.Module):
    def __init__(self, vocab_size, block_size, embed_dim, num_head, num_class):
        super(TextClassificationModelWithAttention, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim, device=device)
        self.positional_embedding = nn.Embedding(block_size, embed_dim, device=device)
        self.blocks = nn.Sequential(*[Block(embed_dim, num_head) for _ in range(4)])
        self.ln = nn.LayerNorm(embed_dim)
        self.dense = nn.Linear(embed_dim, num_class)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, y=None):
        B, T = x.shape

        token_embed = self.token_embedding(x)
        position_embed = self.positional_embedding(torch.arange(T, device=device))
        x = token_embed + position_embed
        x = self.blocks(x)
        x = self.ln(x)
        x = x.mean(dim=1)
        x = self.dense(x)
        x = F.softmax(x, dim=1)


        loss = None
        if y is not None:
            loss = F.cross_entropy(x, y)

        return x, loss

In [18]:
num_class = 2 #len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
embed_dim = 64
block_size = 128
head_size = 64
num_head = 4
#model = TextClassificationModel(vocab_size, embed_dim, num_class)
# vocab_size, block_size, embed_dim, num_head, num_class
model = TextClassificationModelWithAttention(vocab_size, block_size, embed_dim, num_head, num_class).to(device)
# model = LSTM_Cell(emsize, emsize)

In [19]:
#old size 12567365
print("Model size: ", sum(p.numel() for p in model.parameters()))

Model size:  100437378


In [37]:
def train(dataloader, force_limit_batch=(None, None)):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    limit_batch, count = force_limit_batch
    if limit_batch is not None:
        current = 0

    for idx, (label, text) in enumerate(dataloader):
        optimizer.zero_grad()
        text = text.to(device)
        label = label.to(device)
        predicted_label, loss = model(text, label)
        #predicted_label = torch.mean(predicted_label, dim=1)  # Shape becomes [64, 4]
        predicted_label = predicted_label.argmax(dim=1).to(device)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        #print(predicted_label.shape, label.shape)
        #total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_acc += (predicted_label == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(
                    epoch, idx, len(dataloader), total_acc / total_count
                )
            )
            total_acc, total_count = 0, 0
            start_time = time.time()

        if limit_batch is not None:
            current += 1
            if current >= count:
                return


def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text) in enumerate(dataloader):
            label = label.to(device)
            text = text.to(device)
            #predicted_label = model(text)
            logits, loss = model(text)
            predicted_label = logits.argmax(dim=1).to(device)
            #predicted_label = torch.mean(logits, dim=1).to(device)
            #predicted_label = torch.mean(predicted_label, dim=1).to(device)
            # loss = criterion(predicted_label, label)
            total_acc += (predicted_label == label).sum().item()
            total_count += label.size(0)
            if idx == 500:
                break
    return total_acc / total_count

In [23]:

# Hyperparameters
EPOCHS = 10  # epoch
LR = 3e-4  # learning rate
BATCH_SIZE = 64  # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 2, gamma=0.1)
total_accu = None
train_iter, test_iter = AmazonReviewPolarity()
train_dataset = to_map_style_dataset(train_iter)
test_dataset = to_map_style_dataset(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split(
    train_dataset, [num_train, len(train_dataset) - num_train]
)

train_dataloader = DataLoader(
    split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
valid_dataloader = DataLoader(
    split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)

In [24]:
x, y = next(iter(train_dataloader))

print(x.shape, y.shape)

torch.Size([64]) torch.Size([64, 128])


In [38]:
for epoch in range(EPOCHS):
    epoch_start_time = time.time()
    train(train_dataloader, (True, 1000))
    accu_val = evaluate(valid_dataloader)
    total_accu = accu_val
    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(
            epoch, time.time() - epoch_start_time, accu_val
        )
    )
    print("-" * 59)

| epoch   0 |   500/53438 batches | accuracy    0.854
-----------------------------------------------------------
| end of epoch   0 | time: 83.66s | valid accuracy    0.862 
-----------------------------------------------------------
| epoch   1 |   500/53438 batches | accuracy    0.862


KeyboardInterrupt: ignored

In [42]:
# ag_news_label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
label_lookup = {0: "Negative", 1: "Positive"}
def predict(text, text_pipeline, give_all_preds=False):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(ex_text_str)).long().unsqueeze(0).to(device)
        output, _ = model(text)
        if give_all_preds:
          return output
        return output.argmax(1).item() + 1

ex_text_str = "I was not dissatisfied with the result."

model = model.to(device)

torch.set_printoptions(sci_mode=False)
prediction = predict(ex_text_str, text_pipeline, True)
label = prediction.argmax(1).item()
print("This is a {label} review with {confidence:.2f}% confidence.".format(label = label_lookup[label], confidence = prediction[0][label] * 100))
torch.set_printoptions(sci_mode=True)

This is a Negative review with 99.15% confidence.
