In [None]:
!pip install transformers torch ipywidgets

In [None]:
import os
import torch
import zipfile
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from transformers import AlbertTokenizer, AlbertForSequenceClassification
from transformers import DebertaTokenizer, DebertaForSequenceClassification
import ipywidgets as widgets
from IPython.display import display

# Unpack the zipped checkpoint into the directory the model is loaded from.
model_path = 'albert-base-v2'
with zipfile.ZipFile('albert-base-v2.zip', 'r') as archive:
    archive.extractall(model_path)

# Load the saved model and tokenizer from the extracted directory.
model = AlbertForSequenceClassification.from_pretrained(model_path)
tokenizer = AlbertTokenizer.from_pretrained(model_path)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Move the model to that device and switch it to evaluation mode; both calls
# return the module itself, so they can be chained.
model.to(device).eval()

def predict(text):
    """Classify a single piece of text with the loaded model.

    Args:
        text: The string to classify. It is truncated or padded to exactly
            128 tokens before being passed to the model.

    Returns:
        A ``(pred_label, pred_prob)`` tuple: the index of the class with the
        highest logit and the softmax probability assigned to that class.
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # the keyword arguments keep the same meaning.
    inputs = tokenizer(
        text,
        add_special_tokens=True,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_tensors='pt'
    ).to(device)

    with torch.no_grad():
        # Forward every encoded tensor, including the token_type_ids that are
        # explicitly requested above.
        logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        pred_label = torch.argmax(logits, dim=1).cpu().numpy()[0]
        pred_prob = probs[0][pred_label]

    return pred_label, pred_prob



In [None]:
# Multi-line box where the user types the text to classify.
text_input = widgets.Textarea(
    layout=widgets.Layout(width='50%', height='100px'),
    description='Text:',
    placeholder='Enter text here...',
    value='',
    disabled=False,
)

# Labels that will hold the predicted class and its probability; empty until
# the first prediction is made.
output_label, output_prob = widgets.Label(value=""), widgets.Label(value="")

def on_button_click(b):
    """Classify the text currently in ``text_input`` and show the result.

    Writes the predicted label and its probability into ``output_label`` and
    ``output_prob``. Blank input is not sent to the model.

    Args:
        b: The button instance passed in by ``on_click``; unused.
    """
    text = text_input.value
    # Nothing to classify: ask for input rather than showing a prediction
    # for blank text.
    if not text.strip():
        output_label.value = 'Please enter some text to classify.'
        output_prob.value = ''
        return
    label, prob = predict(text)
    output_label.value = f'Label: {label}'
    output_prob.value = f'Probability: {prob:.4f}'

# Button that triggers a prediction for the text currently in the text box.
button = widgets.Button(
    icon='check',
    tooltip='Click to predict',
    description='Predict',
    button_style='',
    disabled=False,
)

# Run on_button_click every time the button is pressed.
button.on_click(on_button_click)

# Render the input box, the button and the two result labels, in that order.
display(text_input, button, output_label, output_prob)
