In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import BertTokenizer
from sklearn.metrics import confusion_matrix, classification_report


In [3]:

# 1. 데이터셋 로드 및 전처리
dataset = load_dataset('ag_news')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
MAX_LENGTH = 128

def preprocess_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=MAX_LENGTH)

encoded_dataset = dataset.map(preprocess_function, batched=True)
encoded_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

# 2. 데이터셋 및 데이터로더 준비
class AGNewsDataset(Dataset):
    def __init__(self, encodings):
        self.input_ids = encodings['input_ids']
        self.attention_mask = encodings['attention_mask']
        self.labels = encodings['label']
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx]
        }

train_dataset = AGNewsDataset(encoded_dataset['train'])
test_dataset = AGNewsDataset(encoded_dataset['test'])

BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# 3. 모델 정의 및 인스턴스화
class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super(CrossAttention, self).__init__()
        self.key_proj = nn.Linear(d_in, d_out_kq)
        self.query_proj = nn.Linear(d_in, d_out_kq)
        self.value_proj = nn.Linear(d_in, d_out_v)
        self.softmax = nn.Softmax(dim=-1)  # attention 확률로 변환

    def forward(self, x, latent):
        keys = self.key_proj(x)
        queries = self.query_proj(latent)
        values = self.value_proj(x)

        attention_scores = torch.matmul(queries, keys.transpose(-2, -1))
        attention_probs = self.softmax(attention_scores)

        attended_values = torch.matmul(attention_probs, values)
        return attended_values

class LatentTransformer(nn.Module):
    def __init__(self, latent_dim, num_heads, num_layers, embed_dim):
        super(LatentTransformer, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, latent):
        latent = latent.permute(1,0,2)  # Transformer는 (seq_len, batch_size, latent_dim) 형식으로 데이터 받음.
        latent = self.transformer(latent)
        return latent.permute(1,0,2)    # 다시 (batch_size, latent_len, latent_dim) 형식으로 변환

class Averaging(nn.Module):
    def forward(self, latent):
        return latent.mean(dim=1)   # latent vector를 평균내서 최종 logits 계산

class Perceiver(nn.Module):
    def __init__(self, vocab_size, embed_dim, latent_dim, num_heads, num_layers, num_classes):
        super(Perceiver, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)  # 임베딩 레이어 추가
        self.input_proj = nn.Linear(embed_dim, embed_dim)

        self.latents = nn.Parameter(torch.randn(1, latent_dim, embed_dim))

        self.cross_attention = CrossAttention(d_in=embed_dim, d_out_kq=embed_dim, d_out_v=embed_dim)
        self.latent_transformer = LatentTransformer(latent_dim=latent_dim, num_heads=num_heads, 
                                                    num_layers=num_layers, embed_dim=embed_dim)
        
        self.averaging = Averaging()
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids)  # (batch_size, seq_length, embed_dim)
        x = self.input_proj(x)  # 동일한 임베딩 차원 유지

        batch_size = x.size(0)
        latent = self.latents.repeat(batch_size, 1, 1)    # (batch_size, latent_dim, embed_dim)
        latent = self.cross_attention(x, latent)          # Cross Attention
        latent = self.latent_transformer(latent)          # Transformer
        latent_avg = self.averaging(latent)               # Averaging
        logits = self.classifier(latent_avg)             # 최종 분류
        return logits

# 모델 인스턴스화
VOCAB_SIZE = tokenizer.vocab_size
EMBED_DIM = 128
LATENT_DIM = 64
NUM_HEADS = 8
NUM_LAYERS = 4
NUM_CLASSES = 4  # AG News는 4개의 클래스

model = Perceiver(vocab_size=VOCAB_SIZE, embed_dim=EMBED_DIM, latent_dim=LATENT_DIM,
                 num_heads=NUM_HEADS, num_layers=NUM_LAYERS, num_classes=NUM_CLASSES)

# 4. 손실 함수 및 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)




In [5]:
# 5. 훈련 루프 구현
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

def eval_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct / total
    return avg_loss, accuracy

EPOCHS = 5

for epoch in range(EPOCHS):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = eval_epoch(model, test_loader, criterion, device)
    print(f'Epoch {epoch+1}/{EPOCHS}:')
    print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
    print(f'  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

Epoch 1/5:
  Train Loss: 0.5646, Train Acc: 0.7868
  Test Loss: 0.4011, Test Acc: 0.8553
Epoch 2/5:
  Train Loss: 0.3683, Train Acc: 0.8708
  Test Loss: 0.3545, Test Acc: 0.8745
Epoch 3/5:
  Train Loss: 0.3145, Train Acc: 0.8890
  Test Loss: 0.3119, Test Acc: 0.8939
Epoch 4/5:
  Train Loss: 0.2779, Train Acc: 0.9023
  Test Loss: 0.3005, Test Acc: 0.8964
Epoch 5/5:
  Train Loss: 0.2498, Train Acc: 0.9121
  Test Loss: 0.3014, Test Acc: 0.8976


In [6]:
def detailed_evaluate(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    cm = confusion_matrix(all_labels, all_preds)
    print('Confusion Matrix:')
    print(cm)

    report = classification_report(all_labels, all_preds, target_names=dataset['train'].features['label'].names)
    print('Classification Report:')
    print(report)

detailed_evaluate(model, test_loader, device)

Confusion Matrix:
[[1727   59   66   48]
 [  27 1849    9   15]
 [  85   48 1572  195]
 [  75   46  105 1674]]
Classification Report:
              precision    recall  f1-score   support

       World       0.90      0.91      0.91      1900
      Sports       0.92      0.97      0.95      1900
    Business       0.90      0.83      0.86      1900
    Sci/Tech       0.87      0.88      0.87      1900

    accuracy                           0.90      7600
   macro avg       0.90      0.90      0.90      7600
weighted avg       0.90      0.90      0.90      7600



In [7]:
import matplotlib.pyplot as plt
import seaborn as sns


In [12]:
def detailed_evaluate(model, dataloader, device):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            _, predicted = torch.max(outputs, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # 혼동 행렬 계산
    cm = confusion_matrix(all_labels, all_preds)
    print('Confusion Matrix:')
    print(cm)

    # 분류 리포트 출력
    report = classification_report(all_labels, all_preds, target_names=dataset['train'].features['label'].names)
    print('Classification Report:')
    print(report)

    # 혼동 행렬 시각화
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=dataset['train'].features['label'].names,
                yticklabels=dataset['train'].features['label'].names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

    detailed_evaluate(model, dataloader, device)