In [39]:
import numpy as np

import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModelForSequenceClassification, T5EncoderModel, T5Tokenizer, T5ForConditionalGeneration, DebertaTokenizer, Trainer, TrainingArguments

import copy

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map

from typing import Any, Dict, List, Optional, Union, Callable
from torch import Tensor
from itertools import count 
import torch.nn.functional as F

from transformers import T5Tokenizer

from torch.utils.data import DataLoader, Dataset

import itertools

from sklearn.metrics import (
    precision_score,
    recall_score,
    make_scorer,
    f1_score,
    confusion_matrix,
    accuracy_score
)

from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

import time
import pandas as pd

import re

from sklearn.model_selection import train_test_split

Load Model

In [2]:
# Tokenizer used to encode the comments; sequences are capped at 512 tokens.
tokenizer = AutoTokenizer.from_pretrained("debertav3-base_tokenizer", model_max_length=512)
# Sequence-classification weights restored from the checkpoint selected as
# best during our training runs.
model = AutoModelForSequenceClassification.from_pretrained('best_model/checkpoint-16000')

Load and Process Dataset

In [26]:
# NOTE(review): the file carries an .xlsx extension but is parsed as CSV here;
# the preview below suggests the content really is comma-separated - confirm
# and consider renaming the file.
df = pd.read_csv(filepath_or_buffer="satd-comments-manual-subclass.xlsx")
# Preview the first five rows.
df.head(n=5)

Unnamed: 0,comment
0,# assume curl will handle
1,# Is this cacheable?
2,"# What impact should any(c(""public"", ""private""..."
3,# Requires validation
4,# TODO might need to put some params before an...


Binary

In [31]:
def _clean_comment(text):
    """Normalise one raw comment string before tokenization.

    Steps, in order: blank out the listed punctuation characters, blank out
    square brackets, delete newlines, then collapse each pair of spaces into
    a single space (one pass only, so longer runs are shortened, not removed).
    """
    # Remove unhelpful punctuation
    text = re.sub('[()#{}<>//=.*:-]', ' ', text)
    text = text.replace('[', ' ').replace(']', ' ')
    # Newlines are dropped outright rather than replaced by a space
    text = re.sub('\n', '', text)
    # Fix double spaces
    return text.replace('  ', ' ')


# Build the cleaned comment series without touching the `comment` column of df.
X = df['comment'].map(_clean_comment)

In [32]:
# Spot-check the first five cleaned comments.
X.head(n=5)

0                              assume curl will handle
1                                   Is this cacheable?
2     What impact should any c "public", "private" ...
3                                  Requires validation
4     TODO might need to put some params before and...
Name: comment, dtype: object

In [33]:
# Every comment in this file is treated as a positive example, so the ground
# truth is a vector of ones with one entry per cleaned comment.
# The dtype is pinned to int64: building the array from a Python list yields a
# platform-dependent integer width and a float64 array when X is empty, while
# the labels are later wrapped in torch tensors as class indices.
y = np.ones(len(X), dtype=np.int64)

In [34]:
# Tokenize all cleaned comments in one call: pad to a common length and
# truncate anything longer than 512 tokens.
test_encodings = tokenizer(
    X.tolist(),
    padding=True,
    truncation=True,
    max_length=512,
)

class MakeTorchData(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with their labels.

    `encodings` is any mapping exposing `.items()` whose values can be indexed
    by sample position; `labels` is an indexable sequence with one entry per
    sample. The dataset length is taken from `labels`.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of samples is defined by the label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        # Convert every encoding field of sample `idx` to a tensor, then
        # attach the label under the "labels" key.
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

# Wrap the tokenized comments and their flattened labels in a torch Dataset.
test_dataset = MakeTorchData(encodings=test_encodings, labels=y.ravel())

In [35]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the cells visible here - confirm
# whether anything later in the notebook still needs it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Fix the seed to ensure reproducibility of results. Kept as the last
# expression so the cell still displays the returned generator.
torch.manual_seed(42)

<torch._C.Generator at 0x2a62f56c210>

In [41]:
def compute_metrics(eval_pred):
    """Compute accuracy, precision, recall and F1 for a binary classifier.

    Args:
        eval_pred: a pair ``(predictions, labels)``. The predicted class is
            taken as the argmax of ``predictions`` along axis 1
            (presumably per-class scores of shape (n_samples, n_classes) -
            TODO confirm against the Trainer output).

    Returns:
        dict with the keys "accuracy", "precision", "recall" and "f1", each a
        Python float. Precision, recall and F1 use binary averaging with
        class 1 as the positive label.
    """
    raw_scores, labels = eval_pred
    predicted = np.argmax(raw_scores, axis=1)

    # Shared settings for the three positive-class metrics.
    binary = {"pos_label": 1, "average": "binary"}

    f1 = f1_score(labels, predicted, **binary)
    accuracy = accuracy_score(labels, predicted)
    precision = precision_score(labels, predicted, **binary)
    recall = recall_score(labels, predicted, **binary)

    return {
        "accuracy": float(accuracy),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }

In [40]:
# Configuration handed to the Trainer below. Arguments are grouped by purpose;
# the values themselves are unchanged.
training_args = TrainingArguments(
    # --- where outputs and logs are written ---
    output_dir='./results2',
    logging_dir='./logs',
    # --- batch sizes per device ---
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # --- optimisation schedule ---
    num_train_epochs=10,
    warmup_steps=500,
    weight_decay=0.01,
    # --- logging / checkpoint / evaluation cadence, in steps ---
    logging_steps=1000,
    save_steps=1000,
    evaluation_strategy="steps",
    # Reload the best checkpoint once training finishes (default metric is loss).
    load_best_model_at_end=True,
)

In [42]:
# The Trainer is used here purely as an inference/evaluation harness for the
# already fine-tuned checkpoint.
# No `train_dataset` is supplied on purpose: the evaluation comments carry
# placeholder labels (all ones), so an accidental `trainer.train()` must fail
# loudly instead of silently fine-tuning the checkpoint on them.
trainer = Trainer(
    model=model,                         # the loaded sequence-classification model
    args=training_args,                  # arguments defined above (batch size, output dirs, ...)
    eval_dataset=test_dataset,           # dataset used by evaluate()
    compute_metrics=compute_metrics,     # callback producing accuracy / precision / recall / f1
)

In [43]:
# Run the model over every encoded comment; the first element of the result
# holds the raw per-class predictions used in the next cell.
y_pred = trainer.predict(test_dataset=test_dataset)

***** Running Prediction *****
  Num examples = 4961
  Batch size = 4
  attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
  score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
  score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)


  0%|          | 0/1241 [00:00<?, ?it/s]

In [44]:
# Predicted class per comment: index of the largest score in each row of the
# raw predictions (assumes y_pred[0] is a 2-D array, one row per comment -
# TODO confirm), computed in one vectorised call.
y_pred2 = list(np.argmax(y_pred[0], axis=1))

In [45]:
# `classification_report` is already imported in the notebook's import cell,
# so it is not re-imported here.
# Every ground-truth label is 1, so class 0 has no true samples and its recall
# and F-score are undefined. zero_division=0 reports them as 0.0 - the same
# value the default produces - without emitting the undefined-metric warnings.
print(classification_report(y, y_pred2, digits=6, zero_division=0))

              precision    recall  f1-score   support

           0   0.000000  0.000000  0.000000         0
           1   1.000000  0.509978  0.675477      4961

    accuracy                       0.509978      4961
   macro avg   0.500000  0.254989  0.337739      4961
weighted avg   1.000000  0.509978  0.675477      4961



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
