In [33]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from datasets import load_dataset
from sklearn.metrics import auc, accuracy_score, confusion_matrix, precision_recall_fscore_support, roc_curve
from sklearn.preprocessing import label_binarize
from tqdm import tqdm
from typing import TypedDict

In [34]:
dataset = load_dataset("ag_news")

Found cached dataset ag_news (/Users/jaypark/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


  0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') # 모델 입력에 사용되는 데이터를 바로 사용할 수 없고 토큰화를 해서 입력 그떄 사용하는 토크나이저
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=4) 


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [36]:
class DatasetItem(TypedDict):
    text: str
    label: str
    
def preprocess_date(dataset_item: DatasetItem)->dict[str, torch.Tensor]:
    return tokenizer(dataset_item["text"], padding="max_length", truncation=True, return_tensors='pt')
    
    

In [37]:
# 데이터 전처리

train_dataset = dataset["train"].select(range(1200)).map(preprocess_date, batched=True)
test_dataset = dataset["test"].select(range(1200)).map(preprocess_date, batched=True)

train_dataset.set_format("torch", columns=["input_ids", "attention_mask","label"])
test_dataset.set_format("torch", columns=["input_ids", "attention_mask","label"])

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

Loading cached processed dataset at /Users/jaypark/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-6aefbb7bb7939cd6.arrow
Loading cached processed dataset at /Users/jaypark/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-59c08fe492775d92.arrow


In [38]:
train_dataset.to_pandas()

Unnamed: 0,text,label,input_ids,token_type_ids,attention_mask
0,Wall St. Bears Claw Back Into the Black (Reute...,2,"[101, 2813, 2358, 1012, 6468, 15020, 2067, 204...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,Carlyle Looks Toward Commercial Aerospace (Reu...,2,"[101, 18431, 2571, 3504, 2646, 3293, 13395, 10...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,Oil and Economy Cloud Stocks' Outlook (Reuters...,2,"[101, 3514, 1998, 4610, 6112, 15768, 1005, 176...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,Iraq Halts Oil Exports from Main Southern Pipe...,2,"[101, 5712, 9190, 2015, 3514, 14338, 2013, 236...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,"Oil prices soar to all-time record, posing new...",2,"[101, 3514, 7597, 2061, 2906, 2000, 2035, 1011...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...,...,...
1195,Airline Looks to Delay Its Pension Payments In...,2,"[101, 8582, 3504, 2000, 8536, 2049, 11550, 105...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1196,McDonald's Goes for Gold For a four-year inves...,2,"[101, 9383, 1005, 1055, 3632, 2005, 2751, 2005...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1197,Auction for Shares In Google's IPO May End Tod...,3,"[101, 10470, 2005, 6661, 1999, 8224, 1005, 105...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1198,"Four Firms Buy Intelsat Intelsat Ltd., the pio...",3,"[101, 2176, 9786, 4965, 13420, 16846, 13420, 1...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [39]:
optimizer = AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()



In [40]:
print(f"MPS 장치를 지원하도록 build가 되었는가? {torch.backends.mps.is_built()}")
print(f"MPS 장치가 사용 가능한가? {torch.backends.mps.is_available()}") 

MPS 장치를 지원하도록 build가 되었는가? True
MPS 장치가 사용 가능한가? True


In [41]:
# 모델 학습

device = torch.device("mps")
model.to(device)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0 
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        inputs = {key: batch[key].to(device) for key in batch}
        labels = inputs.pop("label")
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} average loss: {average_loss}")

Epoch 1:   3%|▎         | 5/150 [00:52<25:22, 10.50s/it]


KeyboardInterrupt: 