In [1]:
# Load the roberta-base config, tokenizer and model through OpenPrompt's
# model-class registry.
from openprompt.plms import get_model_class

model_path = 'roberta-base'
model_class = get_model_class(plm_type = "roberta-base")

# The three from_pretrained calls are independent of one another.
tokenizer = model_class.tokenizer.from_pretrained(model_path)
config = model_class.config.from_pretrained(model_path)
model = model_class.model.from_pretrained(model_path)

In [2]:
# Build the manual prompt template: the input text followed by "It was <mask>".
from openprompt.prompts import ManualTemplate

template_text = ["<text_a>", "It", "was", "<mask>"]
promptTemplate = ManualTemplate(tokenizer = tokenizer, text = template_text)

In [3]:
# Build the verbalizer that maps each class name to its label word(s).
from openprompt.prompts import ManualVerbalizer

classes = ["negative", "positive"]
label_words = {"negative": ["bad"], "positive": ["great"]}

promptVerbalizer = ManualVerbalizer(
    tokenizer = tokenizer,
    classes = classes,
    label_words = label_words,
)

In [4]:
# Combine the language model, the template and the verbalizer into a single
# prompt-based classifier.
from openprompt import PromptForClassification

promptModel = PromptForClassification(
    model = model,
    template = promptTemplate,
    verbalizer = promptVerbalizer,
)

In [5]:
# Data Processor
import os
import pandas as pd

from openprompt.data_utils import InputExample
from openprompt.data_utils.data_processor import DataProcessor

class SST2Processor(DataProcessor):
    """Reads tab-separated sentiment splits into ``InputExample`` objects.

    A split is read from ``<data_dir>/<split>.tsv``; the file must have a
    header row with ``sentence`` and ``label`` columns.
    """

    def __init__(self):
        super().__init__()
        self.labels = ["negative", "positive"]

    def get_examples(self, data_dir, split):
        """Load one split and convert every row into an ``InputExample``.

        Args:
            data_dir: Directory that contains the ``<split>.tsv`` files.
            split: Split name without extension, e.g. ``'train'``.

        Returns:
            A list of ``InputExample`` in file order. ``guid`` is the
            zero-based row position, ``text_a`` is the ``sentence`` column and
            ``label`` is the ``label`` column cast to ``int``.
        """
        import csv

        path = os.path.join(data_dir, "{}.tsv".format(split))
        # Use a real tab character as the separator. The previous two-character
        # string (backslash + "t") was interpreted by pandas as a regex, which
        # forces the slower python engine and emits a ParserWarning.
        # QUOTE_NONE keeps quote characters in sentences literal, as they were
        # when the line was split by regex.
        df = pd.read_csv(path, sep='\t', header = 0, quoting = csv.QUOTE_NONE)

        # Iterate positionally instead of via label-based Series indexing so
        # the loop does not depend on the frame's index values.
        return [
            InputExample(guid = idx, text_a = sentence, label = int(label))
            for idx, (sentence, label) in enumerate(zip(df['sentence'], df['label']))
        ]

In [6]:
# Read the 'train' split from the local data directory into InputExamples.
path_to_train = './custom_data'
sst2_processor = SST2Processor()
dataset = sst2_processor.get_examples(path_to_train, 'train')

  return func(*args, **kwargs)


In [7]:
# Wrap the examples in a PromptDataLoader, which applies the template and
# tokenizes each example.
from openprompt import PromptDataLoader

data_loader = PromptDataLoader(
    template = promptTemplate,
    tokenizer = tokenizer,
    dataset = dataset,
)

tokenizing: 10000it [00:06, 1453.09it/s]


In [8]:
# Accuracy of train set
import torch

samples = len(dataset)
correctly_classified = 0

# Switch to eval mode and disable gradient tracking for pure inference.
promptModel.eval()
with torch.no_grad():
    for data in data_loader:
        logits = promptModel(data)
        pred = torch.argmax(logits, dim = -1)
        # Count matches element-wise. Truth-testing `pred == label` directly
        # only works when the batch holds exactly one example and raises for
        # larger batches; summing the comparison works for any batch size.
        correctly_classified += (pred == data["label"]).sum().item()

print(f'Accuracy : {correctly_classified / samples}')

Accuracy : 0.7368
