In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from collections import Counter
import random
import time
import string
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed Python's and PyTorch's RNGs and request deterministic cuDNN kernels so
# that repeated runs are comparable.
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# Prefer TREC (configuration 'coarse'); if loading it fails for any reason,
# report the error and use IMDB instead. The label column name depends on
# which dataset ended up being loaded.
try:
    trec = load_dataset('trec', 'coarse')
except Exception as e:
    print(f"Error loading TREC dataset: {e}")
    print("Falling back to IMDB dataset")
    trec = load_dataset('imdb')
    label_col = 'label'
else:
    label_col = 'label-coarse'

# The official test split is cut in half (seeded): one half is used for
# validation, the other is kept as the held-out test set.
train_data = trec['train']
test_valid_split = trec['test'].train_test_split(test_size=0.5, seed=SEED)
valid_data, test_data = test_valid_split['train'], test_valid_split['test']

# Vocabulary / sequence hyper-parameters.
MAX_VOCAB_SIZE = 25_000
PAD_IDX = 0
UNK_IDX = 1
MAX_LENGTH = 50

# Text processing functions
# Translation table that deletes every ASCII punctuation character. It is
# built once here instead of on every tokenize() call, because tokenize() runs
# for each sample each time the dataset is indexed.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)

def tokenize(text):
    """Lower-case *text*, drop ASCII punctuation and split on whitespace.

    Args:
        text: Raw input string.

    Returns:
        List of tokens; empty when *text* contains nothing but punctuation
        and/or whitespace.
    """
    return text.lower().translate(_PUNCTUATION_TABLE).split()

# Vocabulary building
def build_vocab(dataset, max_size):
    """Build a token -> index vocabulary from the ``'text'`` field of *dataset*.

    Indices 0 and 1 are reserved for ``'<pad>'`` and ``'<unk>'``. The remaining
    entries are the ``max_size - 2`` most frequent tokens, most frequent first.

    Args:
        dataset: Iterable of examples, each providing a ``'text'`` entry.
        max_size: Target vocabulary size, counting the two special tokens.

    Returns:
        dict mapping every kept token to its integer index.
    """
    frequencies = Counter(
        token
        for example in dataset
        for token in tokenize(example['text'])
    )
    special_tokens = ['<pad>', '<unk>']
    frequent_tokens = [token for token, _count in frequencies.most_common(max_size - 2)]
    return {token: index for index, token in enumerate(special_tokens + frequent_tokens)}

word2idx = build_vocab(train_data, MAX_VOCAB_SIZE)

# Label handling
# Sort the distinct labels before numbering them: iterating a bare set carries
# no ordering guarantee, so without the sort the label -> index mapping (and
# therefore the meaning of each model output) could change from run to run.
label2idx = {label: idx for idx, label in enumerate(sorted(set(train_data[label_col])))}
idx2label = {idx: label for label, idx in label2idx.items()}

# Dataset class
class TextClassificationDataset(Dataset):
    """Map-style dataset that yields fixed-length token-index tensors.

    Each item is a ``(token_ids, label_id)`` pair of ``torch.long`` tensors.
    ``token_ids`` always has ``MAX_LENGTH`` entries: longer texts are truncated
    and shorter ones are right-padded with ``PAD_IDX``. ``label_id`` is a
    scalar tensor.

    Args:
        data: Indexable collection of rows; each row must provide a ``'text'``
            entry and an entry named by *label_column*.
        word2idx: Mapping from token to vocabulary index. Tokens that are not
            in the mapping are encoded as ``UNK_IDX``.
        label2idx: Mapping from raw label value to class index.
        label_column: Key under which each row stores its label.
    """

    def __init__(self, data, word2idx, label2idx, label_column):
        self.data = data
        self.word2idx = word2idx
        self.label2idx = label2idx
        self.label_column = label_column

    def __len__(self):
        """Return the number of rows in the underlying data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode row *idx* as ``(token_ids, label_id)``.

        Raises:
            KeyError: If the row's label is not present in ``label2idx``.
        """
        # Fetch the row a single time; indexing the backing store can be
        # comparatively expensive, and both fields come from the same row.
        row = self.data[idx]
        label = self.label2idx[row[self.label_column]]

        tokens = tokenize(row['text'])[:MAX_LENGTH]
        indices = [self.word2idx.get(token, UNK_IDX) for token in tokens]

        # Right-pad so every sample has exactly MAX_LENGTH positions and
        # samples can be stacked into a batch.
        if len(indices) < MAX_LENGTH:
            indices += [PAD_IDX] * (MAX_LENGTH - len(indices))

        return torch.tensor(indices, dtype=torch.long), torch.tensor(label, dtype=torch.long)

# DataLoaders
def collate_batch(batch):
    """Combine ``(text_tensor, label)`` samples into a single batch.

    Args:
        batch: Non-empty sequence of two-item samples.

    Returns:
        Tuple ``(texts, labels)``: the text tensors stacked along a new
        leading dimension, and the labels gathered into one tensor.
    """
    text_tensors, label_values = map(list, zip(*batch))
    return torch.stack(text_tensors), torch.tensor(label_values)

BATCH_SIZE = 64
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# All three splits are encoded with the same vocabulary and label mapping.
train_dataset, valid_dataset, test_dataset = (
    TextClassificationDataset(split, word2idx, label2idx, label_col)
    for split in (train_data, valid_data, test_data)
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_batch,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_batch,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_batch,
)

# CNN Model
class CNN(nn.Module):
    """Convolutional text classifier with one conv branch per filter size.

    Token indices are embedded, each branch slides a kernel spanning the full
    embedding width over the sequence, the ReLU-activated feature maps are
    max-pooled over the sequence dimension, and the concatenated pooled
    features go through dropout and a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        n_filters: Output channels of every convolution branch.
        filter_sizes: Kernel heights (in tokens), one branch per entry.
        output_dim: Number of output scores per sample.
        dropout: Dropout probability applied to the pooled features.
        pad_idx: Embedding index treated as padding.
    """

    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        # Each kernel covers the whole embedding width, so a branch only
        # slides along the sequence dimension.
        branches = [
            nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(height, embedding_dim))
            for height in filter_sizes
        ]
        self.convs = nn.ModuleList(branches)

        self.fc = nn.Linear(n_filters * len(filter_sizes), output_dim)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _conv_and_pool(conv, embedded):
        """Apply one conv branch, ReLU, then max-pool over the sequence axis."""
        # (batch, n_filters, positions, 1) -> (batch, n_filters, positions)
        feature_map = F.relu(conv(embedded)).squeeze(3)
        # Pool over every position: (batch, n_filters, 1) -> (batch, n_filters)
        return F.max_pool1d(feature_map, feature_map.shape[2]).squeeze(2)

    def forward(self, text):
        """Return per-class scores for a batch of token-index sequences.

        Args:
            text: Integer tensor of shape ``(batch, seq_len)``; ``seq_len``
                must be at least the largest filter size.

        Returns:
            Tensor of shape ``(batch, output_dim)`` with unnormalised scores.
        """
        # Add a channel dimension: (batch, 1, seq_len, embedding_dim).
        embedded = self.embedding(text).unsqueeze(1)
        pooled = [self._conv_and_pool(conv, embedded) for conv in self.convs]
        features = self.dropout(torch.cat(pooled, dim=1))
        return self.fc(features)

# Model configuration
INPUT_DIM = len(word2idx)
EMBEDDING_DIM = 100
N_FILTERS = 100
FILTER_SIZES = [2, 3, 4]
OUTPUT_DIM = len(label2idx)
DROPOUT = 0.5
PAD_IDX = word2idx['<pad>']

# Build the model and place it on the target device *before* creating the
# optimizer, so the optimizer is handed the parameters in their final
# location (the order recommended by the torch.optim documentation).
model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)
model = model.to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = optim.Adam(model.parameters())

# Training functions
def categorical_accuracy(preds, y):
    """Return the fraction of samples whose arg-max prediction matches *y*.

    Args:
        preds: Score tensor with the class dimension at index 1.
        y: Target class indices, one per prediction.

    Returns:
        Scalar float tensor: number of matches divided by ``y.shape[0]``.
    """
    predicted = preds.argmax(dim=1)
    hits = (predicted == y.view(predicted.shape)).sum()
    return hits.float() / y.shape[0]

def train(model, iterator, optimizer, criterion):
    """Run one optimisation pass over *iterator* and report epoch metrics.

    Per-batch loss and accuracy are weighted by the batch size before being
    averaged, so a smaller final batch does not skew the epoch figures.

    Args:
        model: Module to optimise; it is switched to training mode.
        iterator: Iterable yielding ``(text, labels)`` tensor batches.
        optimizer: Optimizer that updates the parameters of *model*.
        criterion: Loss callable taking ``(predictions, labels)``. Assumed to
            return the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``) -- TODO confirm if another reduction is
            ever passed in.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` as Python floats, averaged over
        all samples seen.

    Raises:
        ZeroDivisionError: If *iterator* yields no samples.
    """
    total_loss = 0.0
    total_acc = 0.0
    n_samples = 0
    model.train()
    for text, labels in iterator:
        optimizer.zero_grad()
        text, labels = text.to(device), labels.to(device)
        predictions = model(text)
        loss = criterion(predictions, labels)
        acc = categorical_accuracy(predictions, labels)
        loss.backward()
        optimizer.step()

        # Undo the per-batch averaging so every sample counts equally.
        batch_size = labels.shape[0]
        total_loss += loss.item() * batch_size
        total_acc += acc.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples, total_acc / n_samples

def evaluate(model, iterator, criterion):
    """Measure loss and accuracy of *model* over *iterator* without training.

    Per-batch loss and accuracy are weighted by the batch size before being
    averaged, so a smaller final batch does not skew the reported figures.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        iterator: Iterable yielding ``(text, labels)`` tensor batches.
        criterion: Loss callable taking ``(predictions, labels)``. Assumed to
            return the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``) -- TODO confirm if another reduction is
            ever passed in.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` as Python floats, averaged over
        all samples seen.

    Raises:
        ZeroDivisionError: If *iterator* yields no samples.
    """
    total_loss = 0.0
    total_acc = 0.0
    n_samples = 0
    model.eval()
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for text, labels in iterator:
            text, labels = text.to(device), labels.to(device)
            predictions = model(text)
            loss = criterion(predictions, labels)
            acc = categorical_accuracy(predictions, labels)

            # Undo the per-batch averaging so every sample counts equally.
            batch_size = labels.shape[0]
            total_loss += loss.item() * batch_size
            total_acc += acc.item() * batch_size
            n_samples += batch_size
    return total_loss / n_samples, total_acc / n_samples

# Training loop
N_EPOCHS = 5
# Lowest validation loss seen so far; starts at infinity so the first epoch
# always produces a checkpoint.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    # The measured wall-clock time covers both the training pass and the
    # validation pass of this epoch.
    start_time = time.time()
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_loader, criterion)
    end_time = time.time()
    
    epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
    # Overwrite the checkpoint only when validation loss improves, so the
    # file always holds the best-so-far weights.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'text-classification-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Time: {int(epoch_mins)}m {int(epoch_secs)}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

# Final evaluation
# Restore the best checkpoint (not necessarily the last epoch's weights)
# before scoring the held-out test split.
model.load_state_dict(torch.load('text-classification-model.pt'))
test_loss, test_acc = evaluate(model, test_loader, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

# Prediction function
def predict_class(model, sentence, word2idx, idx2label, device):
    """Classify one raw sentence and return its label.

    The sentence is tokenized, truncated or right-padded to ``MAX_LENGTH``
    token indices, and run through *model* as a batch of one.

    Args:
        model: Trained classifier; it is switched to evaluation mode.
        sentence: Raw input text.
        word2idx: Mapping from token to vocabulary index; unknown tokens are
            encoded as ``UNK_IDX``.
        idx2label: Mapping from class index back to the original label.
        device: Device on which the input tensor is placed.

    Returns:
        The ``idx2label`` entry for the highest-scoring class.
    """
    model.eval()

    token_ids = [word2idx.get(token, UNK_IDX) for token in tokenize(sentence)[:MAX_LENGTH]]
    # Right-pad to MAX_LENGTH; the repeat count is zero when already full.
    token_ids.extend([PAD_IDX] * (MAX_LENGTH - len(token_ids)))

    batch = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
    with torch.no_grad():
        scores = model(batch)
    best_index = scores.argmax(dim=1).item()
    return idx2label[best_index]

# Example usage
test_sentence = "What is the capital of France?"
predicted_class = predict_class(
    model=model,
    sentence=test_sentence,
    word2idx=word2idx,
    idx2label=idx2label,
    device=device,
)
print(f"Predicted class: {predicted_class}")

Error loading TREC dataset: https://cogcomp.seas.upenn.edu/Data/QA/QC/train_5500.label
Falling back to IMDB dataset
Epoch: 01 | Time: 0m 15s
	Train Loss: 0.695 | Train Acc: 58.50%
	 Val. Loss: 0.612 |  Val. Acc: 64.00%
Epoch: 02 | Time: 0m 16s
	Train Loss: 0.586 | Train Acc: 67.90%
	 Val. Loss: 0.532 |  Val. Acc: 72.98%
Epoch: 03 | Time: 0m 16s
	Train Loss: 0.534 | Train Acc: 72.66%
	 Val. Loss: 0.509 |  Val. Acc: 74.33%
Epoch: 04 | Time: 0m 17s
	Train Loss: 0.476 | Train Acc: 76.97%
	 Val. Loss: 0.495 |  Val. Acc: 75.18%
Epoch: 05 | Time: 0m 18s
	Train Loss: 0.418 | Train Acc: 80.60%
	 Val. Loss: 0.498 |  Val. Acc: 75.28%
Test Loss: 0.498 | Test Acc: 75.12%
Predicted class: 1
