In [1]:
import torch
import torchtext

torchtext.disable_torchtext_deprecation_warning()

from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from torchtext.datasets import AG_NEWS
from transformers import (
    AdamW,
    XLNetForSequenceClassification,
    XLNetTokenizer,
    get_linear_schedule_with_warmup,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# Seed torch's global RNG so that the newly initialised classification-head
# weights (reported as "newly initialized" when the model loads) and any other
# torch-RNG-driven randomness are reproducible across Restart & Run All.
torch.manual_seed(SEED)

# Load AG News Dataset
train_datapip = AG_NEWS(split="train")  # type: ignore
test_datapip = AG_NEWS(split="test")  # type: ignore

# Define tokenizer and model (AG News has 4 classes)
tokenizer: XLNetTokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetForSequenceClassification.from_pretrained(
    "xlnet-base-cased", num_labels=4
).to(DEVICE)


# Preprocessing and Tokenization function
def preprocess(batch):
    """Collate a list of ``(label, text)`` AG News rows into model inputs.

    All texts in the batch are tokenized together with padding and truncation
    (``max_length=512``) into PyTorch tensors, and the dataset's 1-based labels
    are shifted down by one.

    Args:
        batch: sequence of ``(label, text)`` pairs as yielded by the datapipe.

    Returns:
        ``(inputs, labels)``: the tokenizer's batch encoding and a tensor of
        0-based class indices.
    """
    label_col, text_col = zip(*batch)  # transpose rows -> (labels, texts) columns
    encoded = tokenizer(
        list(text_col),
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    # Dataset labels start at 1; the classifier (num_labels=4) uses 0..3.
    zero_based_labels = torch.tensor(label_col) - 1
    return encoded, zero_based_labels


  from .autonotebook import tqdm as notebook_tqdm
################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.bias', 'logits_proj.weight', 'sequence_summary.summary.bias', 'sequence_summary.summary.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Peek at one raw (label, text) example from the training split.
first_train_example = next(iter(train_datapip))
first_train_example

In [13]:
# DataLoader creation
batch_size = 32
num_epochs = 1  # number of passes over the training split

train_loader = DataLoader(
    train_datapip, shuffle=True, batch_size=batch_size, collate_fn=preprocess
)
test_loader = DataLoader(test_datapip, batch_size=batch_size, collate_fn=preprocess)

# Define Optimizer and Scheduler.
# transformers.AdamW is deprecated; torch.optim.AdamW with eps=1e-6 and
# weight_decay=0.0 reproduces its default hyperparameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

# The linear schedule must decay over the number of optimizer steps actually
# taken: training batches per epoch x epochs. (Previously this used the size of
# the *test* split, which is unrelated to the training length.) The datapipe may
# not expose len(), so count examples with a generator instead of building a
# list, then round up because the last batch may be partial (drop_last=False).
num_train_examples = sum(1 for _ in train_datapip)
num_train_batches = (num_train_examples + batch_size - 1) // batch_size
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_batches * num_epochs,
)


# Training function
def train(model, loader, optimizer, scheduler):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: classification model whose forward accepts ``input_ids``,
            ``attention_mask``, ``token_type_ids`` and ``labels`` keyword
            arguments and returns an object with a ``.loss`` attribute.
        loader: iterable of ``(inputs, labels)`` batches, where ``inputs`` is a
            mapping holding the three tensors above (as built by ``preprocess``).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per batch.

    Returns:
        Mean training loss over the epoch's batches (the same quantity the
        progress line prints), or 0.0 if the loader yields no batches.
    """
    model.train()
    # Move batches to wherever the model's weights live instead of relying on a
    # module-level device global.
    device = next(model.parameters()).device
    total_loss = 0.0
    num_batches = 0
    for step, (inputs, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            token_type_ids=inputs["token_type_ids"].to(device),
            labels=labels.to(device),
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()  # schedule is defined in optimizer steps, i.e. per batch
        total_loss += loss.item()
        num_batches = step + 1
        if num_batches % 10 == 0:
            print(f"Step {step}, Loss: {total_loss / num_batches:.4f}")
    return total_loss / num_batches if num_batches else 0.0


# Evaluation function
def evaluate(model, loader):
    """Return the accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: classification model whose forward accepts ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` and returns an object
            with a ``.logits`` tensor of shape ``(batch, num_labels)``.
        loader: iterable of ``(inputs, labels)`` batches as built by
            ``preprocess`` (labels already 0-based).

    Returns:
        Fraction of correctly classified examples, via sklearn's
        ``accuracy_score``.

    Note: leaves the model in eval mode.
    """
    model.eval()
    # Follow the model's own device rather than a module-level global.
    device = next(model.parameters()).device
    preds, true_labels = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(
                input_ids=inputs["input_ids"].to(device),
                attention_mask=inputs["attention_mask"].to(device),
                token_type_ids=inputs["token_type_ids"].to(device),
            )
            # Predicted class = highest-scoring logit per example.
            preds.extend(outputs.logits.argmax(dim=1).cpu().tolist())
            true_labels.extend(labels.cpu().tolist())
    return accuracy_score(true_labels, preds)


# Training loop: each epoch trains on the full training split, then reports
# accuracy on the test split. The LR scheduler's num_training_steps should
# cover num_epochs x batches-per-epoch.
num_epochs = 1
device = DEVICE  # the model was already placed on DEVICE when it was loaded
model.to(device)

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, scheduler)
    test_accuracy = evaluate(model, test_loader)
    progress = f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Test Accuracy: {test_accuracy:.4f}"
    print(progress)

print("Training complete.")


Step 9, Loss: 0.0870
Step 19, Loss: 0.0803
Step 29, Loss: 0.0764
Step 39, Loss: 0.0662
Step 49, Loss: 0.0711
Step 59, Loss: 0.0685
Step 69, Loss: 0.0718
Step 79, Loss: 0.0686
Step 89, Loss: 0.0667
Step 99, Loss: 0.0659
Step 109, Loss: 0.0703
Step 119, Loss: 0.0697
Step 129, Loss: 0.0695
Step 139, Loss: 0.0736
Step 149, Loss: 0.0730
Step 159, Loss: 0.0751
Step 169, Loss: 0.0758
Step 179, Loss: 0.0782
Step 189, Loss: 0.0782
Step 199, Loss: 0.0805
Step 209, Loss: 0.0809
Step 219, Loss: 0.0812
Step 229, Loss: 0.0838
Step 239, Loss: 0.0847
Step 249, Loss: 0.0852
Step 259, Loss: 0.0859
Step 269, Loss: 0.0858
Step 279, Loss: 0.0883
Step 289, Loss: 0.0877
Step 299, Loss: 0.0872
Step 309, Loss: 0.0891
Step 319, Loss: 0.0893
Step 329, Loss: 0.0887
Step 339, Loss: 0.0898
Step 349, Loss: 0.0897
Step 359, Loss: 0.0889
Step 369, Loss: 0.0901
Step 379, Loss: 0.0905
Step 389, Loss: 0.0923
Step 399, Loss: 0.0933
Step 409, Loss: 0.0938
Step 419, Loss: 0.0932
Step 429, Loss: 0.0944
Step 439, Loss: 0.0950

In [15]:
# Peek at the first raw (label, text) example yielded by the test split.
first_test_example = next(iter(test_datapip))
first_test_example

(3,
 "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.")

In [18]:
# Spot-check predictions on the first few test examples.
num_samples = 10

# train() switches the model to train mode; if the training cell was interrupted
# before evaluate() ran, dropout would still be active and predictions would be
# nondeterministic. Force inference mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for i, (label, text) in enumerate(test_datapip):
        if i == num_samples:
            break
        # truncation=True is required for max_length to actually be applied.
        inputs = tokenizer(
            text, truncation=True, max_length=512, return_tensors="pt"
        ).to(DEVICE)
        outputs = model(**inputs)
        pred = torch.argmax(outputs.logits, dim=-1).item()
        # Dataset labels are 1-based; the model predicts 0-based class indices.
        print(f"True Label: {label-1}, Predicted Label: {pred}, Text: {text}")

True Label: 2, Predicted Label: 2, Text: Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.
True Label: 3, Predicted Label: 3, Text: The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket.
True Label: 3, Predicted Label: 3, Text: Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.
True Label: 3, Predicted Label: 3, Text: Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick s

In [30]:
# Transpose the training rows into parallel tuples of labels and texts.
label, data = zip(*train_datapip)

In [31]:
# Class balance of the training split; AG News labels run from 1 to 4.
label = list(label)
tuple(label.count(cls) for cls in range(1, 5))

(30000, 30000, 30000, 30000)

In [33]:
# Inspect the tokenizer configuration (special tokens, padding side,
# model_max_length). model_max_length is effectively unbounded here, so
# truncation needs an explicit max_length wherever the tokenizer is called.
tokenizer_description = str(tokenizer)
print(tokenizer_description)

XLNetTokenizer(name_or_path='xlnet-base-cased', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '<sep>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>', 'additional_special_tokens': ['<eop>', '<eod>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<cls>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("<sep>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	5: AddedToken("<pad>", rs