In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from datasets import load_dataset, Dataset, DatasetDict

from sklearn.metrics import classification_report, f1_score, precision_recall_fscore_support, accuracy_score
import numpy as np
import random

from transformers import AutoTokenizer
import torch

from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding
from gliclass.training import TrainingArguments, Trainer

# Prefer the first CUDA GPU when one is usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always truthy,
# which would pick 'cuda:0' even on a machine without a GPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def get_gliclass_predictions(pipeline, test_texts, classes, batch_size=8):
    """Run the classification pipeline and keep one label per text.

    Args:
        pipeline: Callable invoked as ``pipeline(test_texts, classes, batch_size=...)``.
            It must return one sequence per text whose first element is a
            mapping with a ``'label'`` key.
        test_texts: Texts to classify.
        classes: Candidate label names handed to the pipeline.
        batch_size: Batch size forwarded to the pipeline.

    Returns:
        list: The ``'label'`` of the first entry of every per-text result.
    """
    top_labels = []
    for per_text_result in pipeline(test_texts, classes, batch_size=batch_size):
        # Only the first returned entry is used as the prediction.
        top_labels.append(per_text_result[0]['label'])
    return top_labels

def evaluate(predicts, true_labels):
    """Score predictions with F1 under three averaging modes.

    Args:
        predicts: Predicted labels, one per example.
        true_labels: Gold labels, aligned with ``predicts``.

    Returns:
        dict: ``{"micro": ..., "macro": ..., "weighted": ...}`` where each value
        is ``f1_score(true_labels, predicts, average=<key>)``.
    """
    return {
        averaging: f1_score(true_labels, predicts, average=averaging)
        for averaging in ("micro", "macro", "weighted")
    }

def get_train_dataset(dataset, N, label_column='label', seed=41):
    """Draw a class-capped few-shot subset from a shuffled dataset.

    The dataset is shuffled with a fixed seed, then walked in order; an example
    is kept only while fewer than ``N`` examples of its label have been kept.

    Args:
        dataset: Object exposing ``shuffle(seed=...)``, iteration over
            dict-like examples, and ``select(indices)``.
        N: Maximum number of examples kept per label. ``N <= 0`` keeps nothing.
        label_column: Key of the label inside each example.
        seed: Shuffle seed; the default reproduces the previous fixed value.

    Returns:
        The result of ``select`` on the shuffled dataset, restricted to the
        kept indices (in shuffled order).
    """
    shuffled = dataset.shuffle(seed=seed)
    kept_indices = []
    kept_per_label = {}
    for index, example in enumerate(shuffled):
        label = example[label_column]
        already_kept = kept_per_label.get(label, 0)
        # Check the cap before counting, so the very first example of a label
        # is also subject to it (matters when N < 1).
        if already_kept >= N:
            continue
        kept_per_label[label] = already_kept + 1
        kept_indices.append(index)
    return shuffled.select(kept_indices)

def prepare_dataset(dataset, classes = None, text_column = 'text', label_column = "label", split=None):
    """Extract evaluation texts, class names and gold labels from a dataset.

    Split selection: a bare ``Dataset`` is used as-is; otherwise the ``'test'``
    split is used when present, falling back to ``'train'``.

    Args:
        dataset: A ``Dataset`` or a mapping of split name to dataset.
        classes: Class names indexed by integer label. When ``None`` they are
            read from ``features[label_column].names`` of the chosen split.
        text_column: Column holding the input texts.
        label_column: Column holding the labels.
        split: Optional separator; when given (and ``classes`` is ``None``),
            each class name is split on it and re-joined with spaces.

    Returns:
        tuple: ``(texts, classes, true_labels)``. Integer labels are mapped to
        their class names; any other label type is returned unchanged.
    """
    # Test for a bare Dataset first: `'test' in dataset` on a Dataset would be
    # answered by iterating over its rows rather than by a split lookup.
    if isinstance(dataset, Dataset):
        test_dataset = dataset
    elif 'test' in dataset:
        test_dataset = dataset['test']
    else:
        test_dataset = dataset['train']

    if classes is None:
        classes = test_dataset.features[label_column].names
        if split is not None:
            classes = [' '.join(class_.split(split)) for class_ in classes]

    texts = test_dataset[text_column]
    true_labels = test_dataset[label_column]

    print(classes)
    # Reuse the column that was already read, and tolerate an empty split.
    if len(true_labels) > 0:
        first_label = true_labels[0]
        # bool is a subclass of int but is not a class index.
        if isinstance(first_label, int) and not isinstance(first_label, bool):
            true_labels = [classes[label] for label in true_labels]

    return texts, classes, true_labels


def prepare_dataset_for_training(train_dataset, classes, text_column='text', label_column='label'):
    """Convert examples into shuffled GLiClass-style training records.

    Args:
        train_dataset: Iterable of dict-like examples.
        classes: Class names; an ``int`` label is replaced by the name at that
            position. The same list object is stored in every record.
        text_column: Key of the input text inside each example.
        label_column: Key of the label inside each example.

    Returns:
        list[dict]: One ``{'text', 'all_labels', 'true_labels'}`` record per
        example, shuffled in place with the global ``random`` generator.
    """
    index_to_name = dict(enumerate(classes))

    def to_record(example):
        raw_label = example[label_column]
        # Only exact ints are treated as class indices; anything else is kept as-is.
        name = index_to_name[raw_label] if type(raw_label) is int else raw_label
        return {'text': example[text_column], 'all_labels': classes, 'true_labels': [name]}

    records = [to_record(example) for example in train_dataset]
    random.shuffle(records)
    return records


In [None]:
# Dataset option: 'dair-ai/emotion'.
# NOTE: every dataset cell rebinds train_data, test_texts, classes and
# true_labels, so the most recently executed dataset cell is the one used below.
emotions = load_dataset('dair-ai/emotion')

# Few-shot training pool: at most 64 shuffled examples per label.
train_data = get_train_dataset(emotions['train'], N=64)

# Evaluation texts, class names and gold labels; no `classes` are passed, so
# the names are read from the label feature of the evaluation split.
test_texts, classes, true_labels = prepare_dataset(emotions)

# Convert the sampled examples into shuffled
# {'text', 'all_labels', 'true_labels'} records.
train_data = prepare_dataset_for_training(train_data, classes)


In [None]:
# Dataset option: 'ag_news'.
# NOTE: every dataset cell rebinds train_data, test_texts, classes and
# true_labels, so the most recently executed dataset cell is the one used below.
ag_news = load_dataset('ag_news')

# Few-shot training pool: at most 64 shuffled examples per label.
train_data = get_train_dataset(ag_news['train'], N=64)

# Evaluation texts, class names and gold labels; no `classes` are passed, so
# the names are read from the label feature of the evaluation split.
test_texts, classes, true_labels = prepare_dataset(ag_news)

# Convert the sampled examples into shuffled
# {'text', 'all_labels', 'true_labels'} records.
train_data = prepare_dataset_for_training(train_data, classes)


In [None]:
# Dataset option: 'SetFit/sst5'.
# NOTE: every dataset cell rebinds train_data, test_texts, classes and
# true_labels, so the most recently executed dataset cell is the one used below.
sst5 = load_dataset('SetFit/sst5')

# Few-shot training pool: at most 64 shuffled examples per label.
train_data = get_train_dataset(sst5['train'], N=64)

# Class names are supplied by hand for this dataset; integer labels are used
# as positions in this list, so the order matters.
classes = ['very negative', 'negative', 'neutral', 'positive', 'very positive']

# Evaluation texts and gold labels, using the explicit class names above.
test_texts, classes, true_labels = prepare_dataset(sst5, classes=classes)

# Convert the sampled examples into shuffled
# {'text', 'all_labels', 'true_labels'} records.
train_data = prepare_dataset_for_training(train_data, classes)


In [None]:
# Dataset option: 'PolyAI/banking77'.
# NOTE: every dataset cell rebinds train_data, test_texts, classes and
# true_labels, so the most recently executed dataset cell is the one used below.
banking = load_dataset('PolyAI/banking77')

# Few-shot training pool: at most 32 shuffled examples per label.
train_data = get_train_dataset(banking['train'], N=32)

# Evaluation texts, class names and gold labels; no `classes` are passed, so
# the names are read from the label feature of the evaluation split.
test_texts, classes, true_labels = prepare_dataset(banking)

# Convert the sampled examples into shuffled
# {'text', 'all_labels', 'true_labels'} records.
train_data = prepare_dataset_for_training(train_data, classes)


In [None]:
# Dataset option: 'AmazonScience/massive', 'en-US' configuration.
# NOTE: every dataset cell rebinds train_data, test_texts, classes and
# true_labels, so the most recently executed dataset cell is the one used below.
massive = load_dataset("AmazonScience/massive", "en-US")

# Few-shot training pool: at most 32 shuffled examples per label; labels live
# in the 'intent' column for this dataset.
train_data = get_train_dataset(massive['train'], N=32, label_column='intent')

# Texts come from the 'utt' column; class names are read from the 'intent'
# feature of the evaluation split because no `classes` are passed.
test_texts, classes, true_labels = prepare_dataset(massive, text_column='utt', label_column='intent')

# Convert the sampled examples into shuffled
# {'text', 'all_labels', 'true_labels'} records, using the same column names.
train_data = prepare_dataset_for_training(train_data, classes,  text_column='utt', label_column='intent')

In [None]:
# Checkpoint identifier used for both the model weights and the tokenizer.
model_name = 'knowledgator/gliclass-base-v1.0'

# Load the pretrained model and move it to the device selected in the first cell.
model = GLiClassModel.from_pretrained(model_name).to(device)
# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Settings passed positionally to GLiClassDataset below.
max_length = 1024
problem_type = "multi_label_classification"
# Taken from the loaded checkpoint's config rather than hard-coded.
architecture_type = model.config.architecture_type
prompt_first = model.config.prompt_first

# Requires `train_data` from one of the dataset cells above.
train_dataset = GLiClassDataset(train_data, tokenizer, max_length, problem_type, architecture_type, prompt_first)
# NOTE(review): the evaluation set is the first 10% of `train_data`, i.e.
# examples the model is also trained on -- the per-epoch eval numbers are not
# computed on held-out data.
test_dataset = GLiClassDataset(train_data[:int(len(train_data)*0.1)], tokenizer, max_length, problem_type, architecture_type, prompt_first)

# Collator is constructed with the device selected in the first cell.
data_collator = DataCollatorWithPadding(device=device)

# Fine-tuning hyperparameters.
# NOTE(review): `evaluation_strategy` is presumably forwarded to the
# transformers TrainingArguments, where newer releases name it
# `eval_strategy` -- confirm against the installed versions.
training_args = TrainingArguments(
    output_dir='models/test',
    learning_rate=1e-5,
    weight_decay=0.01,
    others_lr=1e-5,
    others_weight_decay=0.01,
    lr_scheduler_type='linear',
    warmup_ratio=0.0,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=8,
    evaluation_strategy="epoch",
    save_steps = 1000,
    save_total_limit=10,
    dataloader_num_workers=8,
    logging_steps=10,
    use_cpu = False,
    report_to="none",
    fp16=False,
    )

In [None]:
# Wire together the model, training arguments, datasets, tokenizer and
# collator defined in the previous cells.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
# Start fine-tuning; `model` is the same object later handed to the pipeline.
trainer.train()

In [None]:
# Build the inference pipeline on the device selected in the first cell instead
# of a hard-coded 'cuda:0', so this cell is consistent with how the model and
# collator were placed. str(device) keeps the argument a device string, as before.
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='single-label', device=str(device))

# One predicted label per evaluation text.
predicts = get_gliclass_predictions(pipeline, test_texts, classes, batch_size=8)

# Micro / macro / weighted F1 against the gold labels.
results = evaluate(predicts, true_labels)
print(results)