In [2]:
import numpy as np
import torch
from transformers import BertTokenizer, BertForSequenceClassification

In [3]:
class FrankenBert:
    """
    Wraps a saved BertForSequenceClassification model and its BertTokenizer
    for classifying a single text.

    Intended for binary classification, although nothing here restricts the
    number of labels: the reported rank is simply the argmax over the model's
    output classes.
    """

    def __init__(self, path: str):
        """
        Load the model and tokenizer from a saved model directory.

        Uses the GPU when PyTorch reports one is available, otherwise the CPU.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the tokenizer (e.g. a saved model directory).
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = BertForSequenceClassification.from_pretrained(path)
        self.tokenizer = BertTokenizer.from_pretrained(path)
        self.model.to(self.device)
        # from_pretrained normally returns the model in eval mode already;
        # make the inference-only intent explicit (disables dropout).
        self.model.eval()

    def predict(self, text: str, max_length: int = 280) -> str:
        """
        Classify a single text and report the top class with its probability.

        Args:
            text: Input text to classify.
            max_length: Maximum number of tokens kept after truncation.
                Defaults to 280 (the previously hard-coded value). It should
                presumably not exceed the model's maximum sequence length
                (512 for standard BERT checkpoints) -- confirm for this model.

        Returns:
            A string of the form ``"Rank: <class index>, <probability>%"``,
            with the probability of the predicted class given to two decimals.
        """
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        ).to(self.device)
        # Pure inference: don't build an autograd graph (saves memory and time).
        with torch.no_grad():
            output = self.model(**inputs)
            # With no labels passed, the first output element is the logits;
            # presumably shaped (batch=1, num_labels).
            probabilities = output[0].softmax(dim=1)
        # Select the row for the single input explicitly rather than taking
        # argmax over the flattened array.
        class_probs = probabilities[0].cpu().numpy()
        result = int(np.argmax(class_probs))
        confidence = class_probs[result]
        return f"Rank: {result}, {100 * confidence:.2f}%"

In [5]:
# Builds an instance only to display its repr; it is not kept, and the model
# is loaded again (and assigned) in a later cell.
FrankenBert(path='saved_model')

<__main__.FrankenBert at 0x1ffa9d15550>

In [7]:
# Load the saved model and tokenizer and keep a handle for the cells below.
model = FrankenBert(path='saved_model')

In [8]:
# Classify a sample sentence; the returned string is shown as the cell output.
model.predict(text="Mickey Mouse is in the house")

'Rank: 0, 98.96%'

In [9]:
# Store the prediction string for inspection below.
# NOTE(review): despite the boolean-style name, this holds the
# "Rank: ..., ...%" string returned by predict().
isText = model.predict(text="Mickey Mouse is in the house")

In [11]:
# Show the type of the value returned by predict() (displayed below as str).
type(isText)

str