In [None]:
!pip install transformers[torch] datasets evaluate sacrebleu accelerate

In [None]:
from google.colab import drive
# Mount Google Drive so predictions can be written under /content/drive/MyDrive.
drive.mount(mountpoint='/content/drive')

In [None]:
import pandas as pd
import numpy as np
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
import torch
import accelerate

In [None]:
# Labeled training data; later cells use its 'review' and 'span' columns.
df = pd.read_csv(filepath_or_buffer='labeled_data.csv')

In [None]:
# The 'Aspect' column is not used by the span-extraction model.
# errors='ignore' keeps the cell safe to re-run once the column is already gone.
df = df.drop(columns=['Aspect'], errors='ignore')

In [None]:
# Convert the cleaned frame into a Hugging Face Dataset for tokenization.
dataset = Dataset.from_pandas(df=df)

In [None]:
# Hold out 20% of rows for evaluation.
# A fixed seed makes the split, and therefore the reported metrics, reproducible.
# After the first run `dataset` is a DatasetDict, which has no train_test_split.
# The isinstance guard therefore keeps this cell safe to re-run.
if isinstance(dataset, Dataset):
    dataset = dataset.train_test_split(test_size=0.2, shuffle=True, seed=42)

In [None]:
# Show the train/test DatasetDict (columns and row counts) produced by the split cell.
dataset

In [None]:
# Pretrained seq2seq checkpoint, shared by the tokenizer and the model.
checkpoint = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
# Task prefix prepended to every review before tokenization.
prefix = "find span in the sentence: "


def preprocess_function(examples):
    """Tokenize a batch of reviews (encoder inputs) and spans (decoder targets).

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps column
    names to lists of values. Each review gets ``prefix`` prepended and is
    truncated to 1024 tokens. Each span is tokenized as a target and
    truncated to 128 tokens. Its ids are stored under ``"labels"``.
    """
    source_texts = list(map(lambda review: prefix + review, examples["review"]))
    encoded_batch = tokenizer(source_texts, max_length=1024, truncation=True)
    encoded_batch["labels"] = tokenizer(
        text_target=examples["span"], max_length=128, truncation=True
    )["input_ids"]
    return encoded_batch

In [None]:
# Apply preprocess_function to both splits, a batch of rows at a time.
tokenized_dataset = dataset.map(function=preprocess_function, batched=True)

In [None]:
# Pads inputs per batch and pads labels with -100 so the loss ignores them.
# `model` is deliberately not passed. It must be a model *instance*, not a
# checkpoint name, and is only used to pre-build decoder_input_ids from labels.
# BART derives its decoder inputs from `labels` during the forward pass anyway.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [None]:
import evaluate

# sacreBLEU scorer used inside compute_metrics.
metric = evaluate.load(path="sacrebleu")

def postprocess_text(preds, labels):
    """Normalize decoded strings into the shape sacreBLEU expects.

    Returns ``(preds, labels)``. Predictions are whitespace-stripped strings.
    Each label is stripped and wrapped in a one-element list, because sacreBLEU
    takes a list of references per prediction.
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())
    wrapped_labels = [[reference.strip()] for reference in labels]
    return cleaned_preds, wrapped_labels


def compute_metrics(eval_preds):
    """Score generated spans against the reference spans.

    Used as ``Seq2SeqTrainer(compute_metrics=...)`` with
    ``predict_with_generate=True``. ``eval_preds`` is a
    ``(predictions, label_ids)`` pair of token-id arrays, and the predictions
    may arrive wrapped in a tuple.

    Returns a dict of metrics, each rounded to 4 decimals:
        bleu             -- sacreBLEU corpus score of decoded predictions vs. references.
        jaccard          -- mean whitespace-token Jaccard overlap per example.
        confidence_score -- mean position-wise agreement of content (non-special)
                            token ids. This is a token-accuracy proxy, not a model
                            probability; the key name is kept for log continuity.
        gen_len          -- mean number of non-pad tokens in the generated ids.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 marks ignored positions. Labels always contain it, and newer
    # transformers versions also use it to pad predictions. It is not a
    # vocabulary id, so batch_decode would fail on it: map it to the pad id first.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    special_ids = set(tokenizer.all_special_ids)

    def compute_jaccard(pred, label):
        """Jaccard overlap of whitespace-split word sets (0 when both are empty)."""
        pred_set = set(pred.split())
        label_set = set(label.split())
        union = len(pred_set | label_set)
        if union == 0:
            return 0  # Avoid division by zero
        return len(pred_set & label_set) / union

    def compute_confidence(pred, label):
        """Fraction of aligned content-token positions where pred equals label.

        Special tokens (pad/bos/eos/...) are removed from both id sequences first.
        Generated ids carry a leading decoder-start token and may be padded to a
        different width than the labels, so raw position-wise comparison would be
        misaligned. Any length mismatch counts as an error: the denominator is
        the longer of the two sequences.
        """
        pred_tokens = [tok for tok in map(int, pred) if tok not in special_ids]
        label_tokens = [tok for tok in map(int, label) if tok not in special_ids]
        longest = max(len(pred_tokens), len(label_tokens))
        if longest == 0:
            return 0.0  # same convention as compute_jaccard for an empty pair
        matches = sum(p == t for p, t in zip(pred_tokens, label_tokens))
        return matches / longest

    confidence_scores = [compute_confidence(p, t) for p, t in zip(preds, labels)]
    jaccard_scores = [compute_jaccard(p, t[0]) for p, t in zip(decoded_preds, decoded_labels)]

    bleu = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {
        "bleu": bleu["score"],
        "jaccard": np.mean(jaccard_scores),
        "confidence_score": np.mean(confidence_scores),
        # -100 was already mapped to pad above, so only real generated tokens count.
        "gen_len": np.mean(np.count_nonzero(preds != pad_id, axis=-1)),
    }
    return {k: round(float(v), 4) for k, v in result.items()}

In [None]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Seq2seq model initialised from the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
# Argument names follow current transformers, which the unpinned install pulls:
#   - `eval_strategy` replaces the removed `evaluation_strategy`.
#   - `processing_class` replaces the deprecated Trainer `tokenizer` argument.
training_args = Seq2SeqTrainingArguments(
    output_dir="get_key_intent",
    eval_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.02,
    save_total_limit=3,
    num_train_epochs=15,
    predict_with_generate=True,
    # Match the 128-token label limit used in preprocess_function. Otherwise
    # eval-time generation uses the model's default max length (presumably the
    # library default of 20 for this checkpoint), which truncates longer spans.
    generation_max_length=128,
    # Mixed precision requires a CUDA device; fall back to fp32 on CPU.
    fp16=torch.cuda.is_available(),
    warmup_steps=200,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the model. Evaluation with compute_metrics runs according to the
# evaluation strategy configured in training_args.
trainer.train()

In [None]:
from transformers import pipeline

# Wrap the fine-tuned model and tokenizer in a generation pipeline for inference.
translator = pipeline(task="text2text-generation", model=model, tokenizer=tokenizer)

In [None]:
text = "As others have mentioned, this keyboard is excellent in all ways but one - the spacebar"
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "Unfortunately, I am sending them back for 2 different reasons: (1) When wearing the earbuds the  cord/wires somehow \"create\" noise  that interferes with your music listening."
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "I can easily see myself ripping the cord accidentally in the future, thus creating a bigger gap and a huge mess along with that."
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "Many KVM vendors will tell you that if you can not control your computer past POST and before boot, it is because the keyboard has too many keys and it is affecting the KVM"
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "The straight audio jack awkwardly sticks out from whatever device you use, making it more likely that it could get damaged by any little thing that may press up against the audio jack and cable."
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "The volume controls and extra function keys are nice, but I wish it wasn't so big and bulky"
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
# NOTE(review): same sentence as the previous example cell; consider removing one.
text = "The volume controls and extra function keys are nice, but I wish it wasn't so big and bulky"
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "The spacebar is awful "
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "This keyboard has a lot of great features (scroll control, etc) but the spacebar makes this keyboard unusable "
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
text = "The first two were refurbished, and the screen would black out in 5-10 seconds after startup "
# Training inputs were `prefix + review` (preprocess_function), so inference uses
# the same prefix. max_length matches the 128-token label limit.
translator(prefix + text, max_length=128)

In [None]:
test_df = pd.read_csv('unlabeled_data.csv')
# Inputs need the same task prefix used in training (preprocess_function).
# Missing reviews become empty strings instead of crashing the pipeline.
review_inputs = (prefix + test_df['review'].fillna('').astype(str)).tolist()
# One batched pipeline call instead of one call per row. A list of strings
# yields one {'generated_text': ...} dict per input, in order.
outputs = translator(review_inputs, batch_size=16, max_length=128)
test_df['preds'] = [out['generated_text'] for out in outputs]

In [None]:
# Persist reviews and predicted spans to the mounted Google Drive.
test_df.to_csv(path_or_buf='/content/drive/MyDrive/output_BART_BASE.csv')

In [None]:
# Spot-check a few reviews next to their predicted spans instead of dumping the whole frame.
test_df[['review', 'preds']].head(10)