In [20]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from sklearn.metrics import classification_report
from datasets import load_dataset, Dataset
from torch.optim import AdamW
from tqdm import tqdm
import plotly.express as px
import torch.utils.data as data
import pandas as pd
import numpy as np
import torch

In [21]:
# Read the shifted AG News test set and expose its "tweet summary" column under
# the name "text", which is the key the rest of the notebook indexes.
ag_news = pd.read_csv(
    "/home/kyle/repos/Parameter-Free-LM-Editing/datasets/ag_news_twitter/shifted_test_set_gpt3.csv"
).rename(columns={"tweet summary": "text"})
display(ag_news.head())
# Keep only the first 1000 rows; the sliced result is what downstream cells
# index by column name ("text" / "label").
ag_news = Dataset.from_pandas(ag_news)[:1000]

Unnamed: 0,article summary,label,text,prompt
0,Fears for T N pension after talks Unions repre...,2,Unions express disappointment in talks with Fe...,V1
1,The Race is On: Second Private Team Sets Launc...,3,🚀👨‍🚀 #SpaceRace update: Second private team to...,V1
2,Ky. Company Wins Grant to Study Peptides (AP) ...,3,"""Chemistry researcher's Ky. startup wins grant...",V1
3,Prediction Unit Helps Forecast Wildfires (AP) ...,3,"🔥💨🌲Prediction Unit helps forecast wildfires, s...",V1
4,Calif. Aims to Limit Farm-Related Smog (AP) AP...,3,"""California takes on smog w emissions rules fo...",V1


In [22]:
class GenericDataset(data.Dataset):
    """Map-style dataset yielding ``(text, label)`` pairs.

    The wrapped object must support lookup by the keys ``"text"`` and
    ``"label"``, each returning an indexable sequence.
    """

    def __init__(self, in_dataset):
        """Store the column container; nothing is copied or validated here."""
        self.dataset = in_dataset

    def __getitem__(self, index):
        """Return the text and the label found at position ``index``."""
        text = self.dataset["text"][index]
        label = self.dataset["label"][index]
        return text, label

    def __len__(self):
        """Number of samples, taken from the length of the ``"text"`` column."""
        text_column = self.dataset["text"]
        return len(text_column)

In [23]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and classifier are loaded from the same checkpoint.
checkpoint_name = "nateraw/bert-base-uncased-ag-news"
task_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
task_model = AutoModelForSequenceClassification.from_pretrained(checkpoint_name)
task_model = task_model.to(device)

# Loss and optimizer used by the fine-tuning loop.
criterion = torch.nn.CrossEntropyLoss()
optimizer = AdamW(task_model.parameters(), lr=2e-5)

In [24]:
# Wrap the loaded columns so torch can index them as (text, label) samples,
# then serve them in batches of 16.
formatted_dataset = GenericDataset(in_dataset=ag_news)
data_loader = data.DataLoader(dataset=formatted_dataset, batch_size=16)
# Show the first sample as a quick sanity check.
formatted_dataset[0]

('Unions express disappointment in talks with Federal Mogul about Turner Newall pension fears. #TNPension #FederalMogul #UnionNegotiations #Disappointment',
 2)

In [32]:
# Fine-tune the classifier, then print a per-class report after every epoch.
#
# NOTE(review): the evaluation pass below scores the model on the very same
# samples it was just trained on (both phases read `formatted_dataset`), so the
# report measures fit to the training data, not generalisation. Use a held-out
# split before drawing conclusions from these numbers.
NUM_EPOCHS = 1

batch_losses = []  # one training-loss value per optimisation step, across all epochs
for epoch in range(NUM_EPOCHS):
    # eval() is switched on at the end of each epoch, so training mode has to be
    # restored once per epoch; doing it on every batch was redundant.
    task_model.train()
    for batch_texts, batch_labels in tqdm(data_loader, desc=f"train epoch {epoch}"):
        optimizer.zero_grad()
        tokenized_batch = task_tokenizer(
            list(batch_texts), padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        batch_labels = batch_labels.to(device)

        logits = task_model(**tokenized_batch).logits
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()

        # .item() already returns a detached Python float.
        batch_losses.append(loss.item())

    predictions = []
    labels = []
    task_model.eval()
    with torch.no_grad():
        for index in tqdm(range(len(formatted_dataset)), desc=f"eval epoch {epoch}"):
            # Index the dataset once per sample instead of once per field.
            eval_text, eval_label = formatted_dataset[index]

            # Truncate here as well, matching the training tokenization, so an
            # over-length input cannot exceed the model's maximum sequence length.
            tokenized_sample = task_tokenizer(
                eval_text, truncation=True, return_tensors="pt"
            ).to(device)
            logits = task_model(**tokenized_sample).logits

            # Store a plain int rather than a shape-(1,) array so that y_pred
            # reaches classification_report as a flat 1-D sequence.
            predictions.append(torch.argmax(logits, dim=1).item())
            labels.append(eval_label)

    print(classification_report(labels, predictions))

100%|██████████| 1000/1000 [00:13<00:00, 72.93it/s]
 12%|█▎        | 1/8 [00:14<01:40, 14.29s/it]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       268
           1       1.00      1.00      1.00       274
           2       1.00      1.00      1.00       205
           3       1.00      1.00      1.00       253

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



100%|██████████| 1000/1000 [00:12<00:00, 77.21it/s]
 25%|██▌       | 2/8 [00:27<01:22, 13.76s/it]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       268
           1       1.00      1.00      1.00       274
           2       1.00      1.00      1.00       205
           3       1.00      1.00      1.00       253

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



100%|██████████| 1000/1000 [00:09<00:00, 106.28it/s]
 38%|███▊      | 3/8 [00:37<00:59, 11.99s/it]

              precision    recall  f1-score   support

           0       1.00      1.00      1.00       268
           1       1.00      1.00      1.00       274
           2       1.00      1.00      1.00       205
           3       1.00      1.00      1.00       253

    accuracy                           1.00      1000
   macro avg       1.00      1.00      1.00      1000
weighted avg       1.00      1.00      1.00      1000



 32%|███▏      | 320/1000 [00:03<00:06, 104.62it/s]
 38%|███▊      | 3/8 [00:40<01:08, 13.66s/it]


KeyboardInterrupt: 

In [None]:
# Plot the recorded per-batch losses as a line chart.
px.line(
    batch_losses,
    title="Losses for each batch",
)
