## 1. Import Modules and Data

In [None]:
from data import load_data

# Only the tokenizer is needed for inference; every other value returned by
# load_data("sst2") is discarded through the starred throwaway target.
tokenizer, *_ = load_data("sst2")

## 2. Load Trained Model

Load our fine-tuned checkpoint.

In [None]:
import torch

import config
from modules.bert import BertForSequenceClassification

# Prefer the GPU when one is visible, but fall back to the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the classifier from config.pretrained_path; the fine-tuned weights
# loaded below then replace its parameters.
bert_clf = BertForSequenceClassification.from_pretrained(config.pretrained_path)

# map_location="cpu" lets a checkpoint that was saved from a GPU run deserialize
# on any host; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
bert_clf.load_state_dict(
    torch.load(config.checkpoint_dir / "bert_clf_3.pth", map_location="cpu")["model"]
)

# Move the weights to the selected device and switch to evaluation mode.
bert_clf = bert_clf.to(device).eval()

## 3. Inference
Here we consider a text classification task with the simple examples below, where `0` represents negative sentiment and `1` represents positive sentiment.

In [14]:
@torch.no_grad()
def text_classification(text):
    """Predict the class index for one or more input texts.

    Args:
        text: A single string, or a collection of strings. A lone string is
            wrapped into a one-element list before tokenization.

    Returns:
        A NumPy array holding the argmax over dimension 1 of the model's
        logits (presumably one class index per input text -- confirm the
        logits are shaped (batch, num_classes)). Per the notebook text,
        0 means negative and 1 means positive.
    """
    batch = [text] if isinstance(text, str) else text

    # Pad to the longest item in the batch and cut anything beyond
    # config.max_len, returning PyTorch tensors.
    encoded = tokenizer(
        batch,
        padding="longest",
        truncation=True,
        max_length=config.max_len,
        return_tensors="pt",
    )
    encoded = encoded.to(device)

    # The model is called with a boolean attention mask.
    mask = encoded.attention_mask.bool()
    scores = bert_clf(input_ids=encoded.input_ids, attention_mask=mask)

    # Highest-scoring class per row, copied back to host memory as NumPy.
    return torch.argmax(scores, dim=1).cpu().numpy()


# Classify a single sentence. A lone string is wrapped into a one-element
# batch by text_classification, so .item() can unwrap the one-element result
# array into a plain scalar for printing.
text = "I was beaten by you!"
predicted_class = text_classification(text)
print(f"Predicted class index: {predicted_class.item()}")

Predicted class index: 0


In [15]:
# Second single-sentence example; this rebinds the module-level `text` and
# `predicted_class` names used by the previous cell.
text = "I love the LLM world!"
predicted_class = text_classification(text)
print(f"Predicted class index: {predicted_class.item()}")

Predicted class index: 1
