###**Encoder-only Transformer vs. LSTM + Attention for Text Classification**


###**AG News Dataset**
The AG News dataset is a four-class topic classification benchmark drawn from AG's corpus of over one million news articles. torchtext numbers the classes 1 to 4:

1: World

2: Sports

3: Business

4: Science/Technology

For this implementation, we'll use the version provided by the torchtext library, which includes:

Training set: 120,000 samples (30,000 per class)

Test set: 7,600 samples (1,900 per class)

In [8]:
!pip install numpy==1.24.4 --force-reinstall

Collecting numpy==1.24.4
  Downloading numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Downloading numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 17.3/17.3 MB 72.0 MB/s eta 0:00:00
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.4
    Uninstalling numpy-2.2.4:
      Successfully uninstalled numpy-2.2.4
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.24.4 which is incompatible.
jax 0.5.2 requires numpy>=1.25, but you have numpy 1.24.4 which is incompatible.
blosc2 3.3.0 requires numpy>=1.26, but you have numpy 1.24.4 which is incompatible.
jaxlib 0.5.1 requires numpy>=1.

In [13]:
!pip install torch==2.0.1+cu118 torchtext==0.15.2 -f https://download.pytorch.org/whl/torch_stable.html

Looking in links: https://download.pytorch.org/whl/torch_stable.html


In [14]:
import torch
import torchtext

print("Torch Version:", torch.__version__)
print("TorchText Version:", torchtext.__version__)

Torch Version: 2.0.1+cu118
TorchText Version: 0.15.2+cpu


In [15]:
!pip uninstall portalocker -y
!pip install portalocker==2.7.0

Found existing installation: portalocker 2.7.0
Uninstalling portalocker-2.7.0:
  Successfully uninstalled portalocker-2.7.0
Collecting portalocker==2.7.0
  Using cached portalocker-2.7.0-py2.py3-none-any.whl.metadata (6.8 kB)
Using cached portalocker-2.7.0-py2.py3-none-any.whl (15 kB)
Installing collected packages: portalocker
Successfully installed portalocker-2.7.0


####**Data Preparation**

In [1]:
import torch
import portalocker
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [2]:
# Load the dataset
train_iter, test_iter = AG_NEWS()

# Tokenizer
tokenizer = get_tokenizer('basic_english')

# Build vocabulary
def yield_tokens(data_iter):
    for label, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

# Define text and label pipelines
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1
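
To make the two pipelines concrete, here is a minimal pure-Python sketch of the same idea: tokenize, map tokens to integer ids with an unknown-token fallback, and shift labels so they start at zero. The `simple_tokenize`, `build_vocab`, `text_to_ids`, and `label_to_index` helpers are hypothetical stand-ins written for illustration; they are not how torchtext's `basic_english` tokenizer or `build_vocab_from_iterator` are implemented.

```python
from collections import Counter

def simple_tokenize(text):
    # Hypothetical stand-in for the 'basic_english' tokenizer: lowercase + whitespace split.
    return text.lower().split()

def build_vocab(texts, specials=("<unk>",)):
    # Special tokens get the lowest ids; remaining tokens follow in order of frequency.
    counts = Counter(tok for text in texts for tok in simple_tokenize(text))
    itos = list(specials) + [tok for tok, _ in counts.most_common()]
    return {tok: idx for idx, tok in enumerate(itos)}

corpus = ["Stocks rally as markets open", "Markets fall as stocks slide"]
stoi = build_vocab(corpus)
unk_id = stoi["<unk>"]

def text_to_ids(text):
    # Tokens missing from the vocabulary fall back to the <unk> id.
    return [stoi.get(tok, unk_id) for tok in simple_tokenize(text)]

def label_to_index(label):
    # Labels 1..4 become class indices 0..3, as CrossEntropyLoss expects.
    return int(label) - 1

ids = text_to_ids("Stocks soar")
print(ids, label_to_index(4))
```

Here "soar" never appears in the toy corpus, so it maps to the `<unk>` id, which is the behavior `vocab.set_default_index` provides in the notebook.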

####**DataLoader Preparation**

In [3]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    label_list, text_list = [], []
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
    text_list = pad_sequence(text_list, batch_first=True)
    label_list = torch.tensor(label_list, dtype=torch.int64)
    return text_list, label_list

# Create DataLoaders
from torch.utils.data import IterableDataset

class AGNewsDataset(IterableDataset):
    def __init__(self, data_iter):
        self.data_iter = list(data_iter)

    def __iter__(self):
        return iter(self.data_iter)

train_dataset = AGNewsDataset(AG_NEWS(split='train'))
test_dataset = AGNewsDataset(AG_NEWS(split='test'))

batch_size = 64
# shuffle=True is not supported for an IterableDataset, so the training data is read in file order
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
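
One detail worth noting about `collate_batch`: `pad_sequence` pads with 0 by default, and with `specials=["<unk>"]` index 0 is also the `<unk>` token, so the models cannot tell padding apart from unknown words. The sketch below uses toy tensors to show one possible way to separate them: pad with a dedicated id and derive a boolean padding mask. `PAD_IDX = 1` is an assumed value chosen for illustration; using it in the notebook would require adding a `<pad>` entry to the vocabulary's specials.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX = 1  # assumed id of a dedicated "<pad>" token (0 stays "<unk>")

# Three toy sequences of token ids with different lengths.
sequences = [
    torch.tensor([5, 8, 2], dtype=torch.int64),
    torch.tensor([7], dtype=torch.int64),
    torch.tensor([4, 0, 9, 3], dtype=torch.int64),  # contains an <unk> (id 0)
]

# Pad every sequence to the length of the longest one in the batch.
padded = pad_sequence(sequences, batch_first=True, padding_value=PAD_IDX)
pad_mask = padded == PAD_IDX  # True where a position is padding

print(padded.shape)         # torch.Size([3, 4])
print(pad_mask.sum(dim=1))  # tensor([1, 3, 0])
```

The `<unk>` token in the third sequence is not flagged as padding, which is the point of keeping the two ids distinct.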

###**Model Implementations**

####**1. Encoder-only Transformer**

In [4]:
import torch.nn as nn

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, num_classes, max_len=512):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_embedding[:, :seq_len, :]
        x = x.permute(1, 0, 2)  # Transformer expects input of shape (seq_len, batch_size, embed_dim)
        x = self.transformer_encoder(x)
        x = x.mean(dim=0)  # Global average pooling
        return self.fc(x)
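
The classifier above attends to padded positions and includes them in the average. A hedged variant that ignores padding is sketched below. It assumes padded positions hold the id `pad_idx` (0 by default, matching `pad_sequence`), passes a `src_key_padding_mask` to the encoder, and averages only over real tokens. `MaskedTransformerClassifier` is a hypothetical name, and this is one possible design, not the notebook's model. With the notebook's vocabulary, id 0 is also `<unk>`, so a dedicated padding id would be needed for the mask to be exact.

```python
import torch
import torch.nn as nn

class MaskedTransformerClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim,
                 num_layers, num_classes, max_len=512, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, embed_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=hidden_dim, batch_first=True)
        # Nested tensors are disabled so the output keeps the padded shape.
        self.encoder = nn.TransformerEncoder(
            layer, num_layers=num_layers, enable_nested_tensor=False)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        pad_mask = x == self.pad_idx                   # (batch, seq_len), True at padding
        h = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
        h = self.encoder(h, src_key_padding_mask=pad_mask)
        keep = (~pad_mask).unsqueeze(-1).to(h.dtype)   # 1.0 for real tokens, 0.0 for padding
        pooled = (h * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)
        return self.fc(pooled)

torch.manual_seed(0)
model = MaskedTransformerClassifier(100, 32, 4, 64, 2, 4).eval()
short = torch.tensor([[5, 6, 7]])
padded = torch.tensor([[5, 6, 7, 0, 0]])
with torch.no_grad():
    same = torch.allclose(model(short), model(padded), atol=1e-4)
print(same)
```

The final check compares the logits for a sequence with and without trailing padding; with masking in place, the amount of padding in a batch should not change a sample's prediction.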

####**2. LSTM with Attention**

In [5]:
class Attention(nn.Module):
    def __init__(self, hidden_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        attn_weights = torch.tanh(self.attn(x))
        attn_weights = torch.softmax(attn_weights, dim=1)
        context = torch.sum(attn_weights * x, dim=1)
        return context

class LSTMWithAttention(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super(LSTMWithAttention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attention = Attention(hidden_dim * 2)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        x = self.embedding(x)
        lstm_out, _ = self.lstm(x)
        attn_out = self.attention(lstm_out)
        return self.fc(attn_out)
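
The `Attention` module above scores every time step, including padded ones, so in a batch of mixed lengths some attention weight falls on padding. A minimal sketch of a masked version follows. `MaskedAttentionPool` and its `pad_mask` argument are hypothetical additions for illustration, with the mask assumed to be `True` at padded positions.

```python
import torch
import torch.nn as nn

class MaskedAttentionPool(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, x, pad_mask):
        # x: (batch, seq_len, hidden_dim); pad_mask: (batch, seq_len), True at padding
        scores = torch.tanh(self.attn(x)).squeeze(-1)          # (batch, seq_len)
        scores = scores.masked_fill(pad_mask, float("-inf"))   # padding gets zero weight
        weights = torch.softmax(scores, dim=1)                 # sums to 1 over real tokens
        context = torch.sum(weights.unsqueeze(-1) * x, dim=1)  # (batch, hidden_dim)
        return context, weights

torch.manual_seed(0)
pool = MaskedAttentionPool(hidden_dim=8)
x = torch.randn(2, 5, 8)
pad_mask = torch.tensor([[False, False, False, True, True],
                         [False, False, False, False, False]])
context, weights = pool(x, pad_mask)
print(context.shape)            # torch.Size([2, 8])
print(weights[0, 3:].tolist())  # [0.0, 0.0]
```

Returning the weights alongside the context vector also makes it easy to inspect which tokens the model focused on for a given article.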

####**Training and Evaluation**

In [6]:
import torch.optim as optim

def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_acc, total_count = 0, 0
    for text, labels in dataloader:
        text, labels = text.to(device), labels.to(device)
        optimizer.zero_grad()
        output = model(text)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        total_acc += (output.argmax(1) == labels).sum().item()
        total_count += labels.size(0)
    return total_acc / total_count

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for text, labels in dataloader:
            text, labels = text.to(device), labels.to(device)
            output = model(text)
            loss = criterion(output, labels)
            total_acc += (output.argmax(1) == labels).sum().item()
            total_count += labels.size(0)
    return total_acc / total_count

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
vocab_size = len(vocab)
embed_dim = 128
hidden_dim = 256
num_heads = 4
num_layers = 2
num_classes = 4
epochs = 5
learning_rate = 0.001

# Initialize models
transformer_model = TransformerClassifier(vocab_size, embed_dim, num_heads, hidden_dim, num_layers, num_classes).to(device)
lstm_model = LSTMWithAttention(vocab_size, embed_dim, hidden_dim, num_classes).to(device)

In [8]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
transformer_optimizer = optim.Adam(transformer_model.parameters(), lr=learning_rate)
lstm_optimizer = optim.Adam(lstm_model.parameters(), lr=learning_rate)

In [10]:
print("Training Transformer Model")
train(transformer_model, train_dataloader, criterion, transformer_optimizer, device)

print("\nTraining LSTM with Attention Model")
train(lstm_model, train_dataloader, criterion, lstm_optimizer, device)
evaluate(lstm_model, test_dataloader, criterion, device)

Training Transformer Model

Training LSTM with Attention Model


0.8967105263157895
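
The cell above runs a single pass over the training data for each model and displays only the LSTM's test accuracy, so `epochs = 5` is never used and the Transformer is not evaluated. A minimal sketch of a loop that trains for several epochs and reports train and test accuracy each epoch is given below. `accuracy_pass` and `fit` are hypothetical helpers, and the toy linear model and random data exist only so the block runs on its own; in the notebook one would pass the real models and dataloaders instead.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def accuracy_pass(model, dataloader, criterion, device, optimizer=None):
    # One pass over the data; weights are updated only when an optimizer is given.
    training = optimizer is not None
    model.train(training)
    correct, count = 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            if training:
                optimizer.zero_grad()
                criterion(output, labels).backward()
                optimizer.step()
            correct += (output.argmax(1) == labels).sum().item()
            count += labels.size(0)
    return correct / count

def fit(model, train_loader, test_loader, criterion, optimizer, device, epochs):
    # Train for `epochs` passes and record (train_acc, test_acc) after each one.
    history = []
    for epoch in range(1, epochs + 1):
        train_acc = accuracy_pass(model, train_loader, criterion, device, optimizer)
        test_acc = accuracy_pass(model, test_loader, criterion, device)
        print(f"epoch {epoch}: train acc {train_acc:.4f}, test acc {test_acc:.4f}")
        history.append((train_acc, test_acc))
    return history

# Toy stand-ins so the sketch is runnable without the dataset.
torch.manual_seed(0)
features = torch.randn(64, 10)
labels = torch.randint(0, 4, (64,))
loader = DataLoader(TensorDataset(features, labels), batch_size=16)
toy_model = nn.Linear(10, 4)
history = fit(toy_model, loader, loader, nn.CrossEntropyLoss(),
              torch.optim.Adam(toy_model.parameters(), lr=1e-3),
              torch.device("cpu"), epochs=3)
```

Calling `fit` once per model with the same `epochs` value would give the two architectures an equal training budget and make their test accuracies directly comparable.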

###**Complexity Comparison**

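| Aspect | Transformer (Encoder Only) | LSTM + Attention |
|---|---|---|
| Time complexity (per layer) | O(n² · d) due to self-attention | O(n · d²), where d is the hidden size |
| Space complexity | O(n²) for the attention matrix | O(n · d) for hidden states |
| Parallelization | High (all sequence positions are processed in parallel) | Low (time steps are computed sequentially) |
| Scalability | Trains efficiently on parallel hardware, but cost grows quadratically with sequence length | Cost grows linearly with sequence length, but the sequential steps make long sequences slow |

Here n is the sequence length and d is the model (embedding or hidden) dimension.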
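
To give the time-complexity terms above some scale, the short sketch below evaluates n² · d and n · d² for a few sequence lengths. It is a rough illustration only: it uses a single d = 128 for both models (the notebook's LSTM actually uses a hidden size of 256), ignores constant factors, and says nothing about wall-clock time, where parallelism matters a great deal.

```python
def self_attention_cost(n, d):
    # Dominant per-layer term for self-attention: n^2 * d
    return n * n * d

def recurrent_cost(n, d):
    # Dominant per-layer term for a recurrent layer: n * d^2
    return n * d * d

d = 128  # embed_dim used in this notebook, applied to both models for simplicity
for n in (50, 128, 512):
    attn, rnn = self_attention_cost(n, d), recurrent_cost(n, d)
    print(f"n={n:>3}: attention={attn:>10,}  recurrent={rnn:>10,}  ratio={attn / rnn:.2f}")
```

The ratio of the two terms is simply n / d, so self-attention has the smaller operation count when sequences are shorter than the model dimension, and the recurrent layer has the smaller count once n exceeds d.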




---


###**Language Understanding Capabilities**

**Transformer:**

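1: Captures global dependencies regardless of token distance, since every token attends directly to every other token.

2: Builds context-dependent token representations, which helps with words whose meaning depends on the surrounding text.

3: Processes all positions in parallel, so it scales well to larger models and datasets.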

**LSTM + Attention:**

1: Captures sequential (word-order) patterns well.

2: Attention helps with long-range dependencies, but information still has to pass through the sequential LSTM states.

3: Can struggle with very long contexts, because gating reduces but does not eliminate vanishing gradients.



---

**Note:**

Use the Transformer when you have enough data and compute power; it is well suited to long sequences and tasks that need global context.

Use LSTM + Attention when working with shorter texts or on resource-constrained devices.