In [None]:
import os

import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)
from transformers.trainer_utils import get_last_checkpoint

In [2]:
# Device configuration: prefer the GPU when CUDA is available, else use the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Hugging Face checkpoint name; passed to AutoTokenizer.from_pretrained below
model_name = "bert-base-uncased"

### Data Preprocessing

In [4]:
# Load the "imdb" dataset; later cells use its "train"/"test" splits and "text" column
dataset = load_dataset("imdb")

In [5]:
# Load the tokenizer for the checkpoint named by model_name
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [6]:
# Tokenize dataset
def tokenize_function(examples):
    """Tokenize a batch of examples for use with ``dataset.map(batched=True)``.

    Args:
        examples: A batch mapping that contains a ``"text"`` entry holding the
            raw strings to tokenize.

    Returns:
        Whatever the global ``tokenizer`` returns for those strings, truncated
        but NOT padded.
    """
    # No padding here: padding inside a batched map pads every example to the
    # longest one in the map batch and stores all that padding in the dataset.
    # The Trainer's DataCollatorWithPadding pads each training batch instead.
    return tokenizer(examples["text"], truncation=True)


# Tokenize every split in batches, drop the raw "text" column, then shuffle
# with a fixed seed so the example order is reproducible.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
).shuffle(seed=42)

### Model

In [7]:
class CustomBERT(nn.Module):
    """Pretrained BERT encoder with a custom MLP and a 2-way classifier head.

    The encoder is taken from ``AutoModelForSequenceClassification`` (its
    ``.bert`` sub-module); the stock classifier is replaced by a linear layer
    that consumes the output of ``custom_layer``.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()

        # Load pretrained BERT model for sequence classification
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name)

        # Read the encoder width from the checkpoint config instead of
        # hard-coding 768, so checkpoints with another hidden size also work.
        hidden_size = self.bert.config.hidden_size

        # Custom layer before classification head
        self.custom_layer = nn.Sequential(
            nn.Linear(in_features=hidden_size, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=1024),
            nn.Dropout(0.5),
        )

        # Custom BERT classifier head (replaces the pretrained model's head)
        self.bert.classifier = nn.Linear(in_features=1024, out_features=2)

    def forward(self, input_ids, attention_mask=None, labels=None, token_type_ids=None):
        """Run the encoder, custom layer and classifier head.

        Args:
            input_ids: Token id tensor fed to the encoder.
            attention_mask: Optional attention mask forwarded to the encoder.
            labels: Optional class indices; when given, a cross-entropy loss
                is computed against the logits.
            token_type_ids: Optional segment ids forwarded to the encoder.
                Placed last so existing positional callers are unaffected.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"``.
        """
        # BERT output (ignoring classification head)
        bert_output = self.bert.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Pass the representation of the first ([CLS]) token through custom layer
        cls_representation = bert_output.last_hidden_state[:, 0, :]
        custom_layer_output = self.custom_layer(cls_representation)

        # Get the logits from the final classifier
        logits = self.bert.classifier(custom_layer_output)

        # Compute the loss only; the backward pass is driven by the caller
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        return {"loss": loss, "logits": logits}

In [None]:
# Initialize model from the same checkpoint the tokenizer was loaded from,
# so changing model_name keeps the tokenizer and the model in sync.
model = CustomBERT(model_name)

### Training

In [9]:
# Hugging Face Trainer configuration
training_args = TrainingArguments(
    output_dir="./results",
    save_steps=25,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=4,
    # 1000 training examples / (batch size 2 * accumulation 2) is about 250
    # optimizer steps on a single device. The previous 500 warm-up steps were
    # longer than the whole run, so the learning rate never reached its peak;
    # use roughly 10% of the run instead.
    warmup_steps=25,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    learning_rate=5e-5,
    lr_scheduler_type="linear",
    # Mixed precision needs a GPU; stay runnable on the CPU fallback too.
    fp16=torch.cuda.is_available(),
    gradient_accumulation_steps=2,
)

# Train/evaluate on the first 1000 (already shuffled) examples of each split;
# the collator pads every batch to its own longest sequence.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"].select(range(1000)),
    eval_dataset=tokenized_dataset["test"].select(range(1000)),
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)

In [None]:
# Training
# Resume from the newest checkpoint in output_dir when one exists; otherwise
# train from scratch. Passing resume_from_checkpoint=True unconditionally
# raises an error when output_dir holds no saved checkpoint, which breaks a
# fresh "Restart & Run All".
last_checkpoint = (
    get_last_checkpoint(training_args.output_dir)
    if os.path.isdir(training_args.output_dir)
    else None
)
trainer.train(resume_from_checkpoint=last_checkpoint)

### Prediction

In [11]:
def predict(text):
    """Classify the sentiment of ``text`` with the global ``model``.

    Args:
        text: A single string, or a list of strings.

    Returns:
        ``"positive"`` when the predicted class is 1, otherwise
        ``"negative"``. A single string input yields a single label; a list
        input yields a list of labels in the same order.
    """
    model.eval()  # Set the model to evaluation mode

    # Send the inputs to wherever the model's weights actually live, instead
    # of assuming the model sits on the notebook-level `device`.
    model_device = next(model.parameters()).device

    # Tokenize input text and move it to the model's device
    encoded = tokenizer(
        text, padding=True, truncation=True, return_tensors="pt"
    ).to(model_device)

    with torch.no_grad():
        # Forward pass
        output = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    # The largest logit picks the class directly; no activation is needed
    # before argmax (and sigmoid is not the right one for a cross-entropy head).
    predicted_classes = torch.argmax(output["logits"], dim=-1).tolist()

    # Assign sentiment based on the predicted class
    sentiments = [
        "positive" if class_index == 1 else "negative"
        for class_index in predicted_classes
    ]

    return sentiments[0] if isinstance(text, str) else sentiments

In [None]:
# Prediction: classify one sample sentence with predict() and print the result
text = "The movie was good."
sentiment = predict(text)
print(f"The sentiment of the text is: {sentiment}")