In [None]:
# Initialize roberta-base
from openprompt.plms import get_model_class
# OpenPrompt's model-class registry is keyed by model *family* ("bert", "roberta", ...),
# not by checkpoint name; the concrete checkpoint is chosen by `model_path` below.
model_class = get_model_class(plm_type = "roberta")
model_path = 'roberta-base'
config = model_class.config.from_pretrained(model_path)
tokenizer = model_class.tokenizer.from_pretrained(model_path)
# Pass the config loaded above so `model` and `config` are guaranteed to agree
# (otherwise `config` is loaded but never used).
model = model_class.model.from_pretrained(model_path, config = config)

In [None]:
# Initialize prompt template
from openprompt.prompts import ManualTemplate
# Cloze-style template: the input sentence followed by "It was <mask>".
promptTemplate = ManualTemplate(
    tokenizer = tokenizer,
    text = [
        "<text_a>",  # placeholder for InputExample.text_a (the sentence)
        "It",
        "was",
        "<mask>",    # the position the masked LM fills in
    ],
)

In [None]:
# Initialize verbalizer
from openprompt.prompts import ManualVerbalizer
classes = ["negative", "positive"]

# Map each class name to the label word(s) expected at the template's <mask>.
# The order of `classes` fixes each class's index; presumably TSV label 0 is
# "negative" and 1 is "positive" — confirm against the data files.
promptVerbalizer = ManualVerbalizer(
    tokenizer = tokenizer,
    classes = classes,
    label_words = dict(
        negative = ["bad"],
        positive = ["great"],
    ),
)

In [None]:
# Create prompt model
from openprompt import PromptForClassification
# Bundle the pretrained LM with the template and verbalizer into one classifier.
promptModel = PromptForClassification(
    model = model,
    template = promptTemplate,
    verbalizer = promptVerbalizer,
)

In [None]:
# Data Processor
import os
import pandas as pd

from openprompt.data_utils import InputExample
from openprompt.data_utils.data_processor import DataProcessor

class SST2Processor(DataProcessor):
    """Load SST-2-style ``{split}.tsv`` files as OpenPrompt ``InputExample`` lists.

    Each file is read as tab-separated with a header row and must provide a
    ``sentence`` column and a ``label`` column convertible to ``int``.
    """

    def __init__(self):
        super().__init__()
        # Class names; the position in this list is the integer class index.
        self.labels = ["negative", "positive"]

    def get_examples(self, data_dir, split):
        """Read ``<data_dir>/<split>.tsv`` and build one example per row.

        Args:
            data_dir: directory that contains the TSV files.
            split: file stem to read, e.g. ``"train"`` or ``"dev"``.

        Returns:
            A list of ``InputExample`` whose ``guid`` is the row position,
            ``text_a`` is the row's sentence, and ``label`` is ``int(label)``.
        """
        path = os.path.join(data_dir, "{}.tsv".format(split))
        # Use a single literal tab: a multi-character separator is interpreted by
        # pandas as a regex, which forces the slower Python parsing engine and
        # emits a ParserWarning.
        df = pd.read_csv(path, sep='\t', header = 0)
        return [
            InputExample(guid = idx, text_a = sentence, label = int(label))
            for idx, (sentence, label) in enumerate(zip(df['sentence'], df['label']))
        ]

In [None]:
# Load dataset
import random
path_to_train = './custom_data'  # directory holding both train.tsv and dev.tsv
VAL_SIZE = 1000  # examples carved off the shuffled train split for validation
SEED = 0         # fixed seed so the train/validation split is reproducible

processor = SST2Processor()
train_dataset = processor.get_examples(path_to_train, 'train')
random.Random(SEED).shuffle(train_dataset)
val_dataset = train_dataset[:VAL_SIZE]
train_dataset = train_dataset[VAL_SIZE:]
# Fail loudly if train.tsv is too small to leave anything after the validation
# slice, instead of silently continuing with an empty training set.
assert train_dataset, (
    "train.tsv yielded only {} examples; VAL_SIZE={} leaves none for training"
    .format(len(val_dataset), VAL_SIZE)
)
# dev.tsv serves as the held-out test set.
test_dataset = processor.get_examples(path_to_train, 'dev')

In [None]:
# Create Data Loader
from openprompt import PromptDataLoader
def make_prompt_loader(dataset):
    """Wrap ``dataset`` in a PromptDataLoader using the shared tokenizer and template."""
    # NOTE(review): only dataset/tokenizer/template are passed, so batch size,
    # shuffling and max sequence length all use PromptDataLoader's defaults —
    # confirm that is intended (e.g. whether the train split should be shuffled).
    return PromptDataLoader(
        dataset = dataset,
        tokenizer = tokenizer,
        template = promptTemplate,
    )

train_loader, val_loader, test_loader = (
    make_prompt_loader(split_examples)
    for split_examples in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Train the prompt model.
# FIXME(review): `trainer` is not defined in any cell of this notebook, so this
# cell raises NameError on Restart & Run All. A cell constructing it (presumably
# from `promptModel` and the train/val/test loaders) appears to be missing —
# restore it above this cell.
trainer.run()

In [None]:
# After training
# Evaluate on the held-out test split (`test_loader`, built from dev.tsv).
# NOTE(review): relies on `trainer`, which no visible cell defines — this cell
# fails with NameError until the cell that builds `trainer` is restored.
trainer.evaluate(test_loader, 'test')