# MIND Dataset + SBERT-style BERT 실험 노트북

- 본 노트북은 Microsoft MIND 데이터셋의 `news.tsv` 파일을 로컬에서 로드하여 뉴스 기사 분류 실험을 진행합니다.
- 입력 문장은 `title`과 `abstract`를 이어붙인 **rich context** 기반 문장입니다.
- 모델은 `bert-base-uncased`에 **Mean Pooling을 적용한 SBERT-style 분류기**를 사용합니다.
- 라벨 정보(`category`)는 정수로 매핑되며, 클래스 불균형을 완화하기 위해 샘플 수가 100개 미만인 카테고리는 제외됩니다.
- 데이터 다운로드 및 전처리에 대한 상세 내용은 `mind_sbert/README.md`에 정리되어 있습니다.


In [21]:
# Cell 1: 라이브러리 및 시드 고정
import torch
import torch.nn as nn
import numpy as np
import random
from transformers import BertModel, BertTokenizerFast, TrainingArguments, Trainer
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed()

In [22]:
# Cell 2: 커스텀 SBERTClassifier (Trainer 호환)
class SBERTClassifier(nn.Module):
    def __init__(self, pretrained_model_name='bert-base-uncased', num_classes=2):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def mean_pooling(self, last_hidden, attention_mask):
        mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden.size()).float()
        sum_embeddings = torch.sum(last_hidden * mask_expanded, dim=1)
        sum_mask = mask_expanded.sum(dim=1)
        return sum_embeddings / sum_mask

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.mean_pooling(outputs.last_hidden_state, attention_mask)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)
            return {"loss": loss, "logits": logits}
        return {"logits": logits}

In [23]:
# Cell 3: 데이터 로딩 및 전처리
import pandas as pd

def load_news(path):
    df = pd.read_csv(
        path,
        sep="\t",
        header=0,
        names=["newsID","category","subcategory","title","abstract"],
        quoting=3,            # pandas.QUOTE_NONE
        encoding="utf-8",
        usecols=[0,1,2,3,4],  # newsID,category,subcategory,title,abstract
    )
    df["text"] = df["title"].fillna("") + " " + df["abstract"].fillna("")
    return df[["text","category"]]

train_df = load_news("/home/elicer/MINDlarge_train/news.tsv")
val_df   = load_news("/home/elicer/MINDlarge_dev/news.tsv")
test_df  = load_news("/home/elicer/MINDlarge_test/news.tsv")

# 클래스 불균형 제거 (최소 100개)
min_count = 100
counts = train_df["category"].value_counts()
keep_categories = counts[counts >= min_count].index.tolist()

train_df = train_df[train_df["category"].isin(keep_categories)]
val_df   = val_df[val_df["category"].isin(keep_categories)]
test_df  = test_df[test_df["category"].isin(keep_categories)]

# 라벨 인코딩
labels   = sorted(train_df["category"].unique())
label2id = {c:i for i,c in enumerate(labels)}
id2label = {i:c for c,i in label2id.items()}
num_labels = len(labels)
for df in (train_df, val_df, test_df):
    df["label"] = df["category"].map(label2id).astype("int64")

In [24]:
# Cell 4: HF DatasetDict 생성 및 토크나이징
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
MAX_LEN = 256

dataset_dict = DatasetDict({
    "train":      Dataset.from_pandas(train_df[["text","label"]], preserve_index=False),
    "validation": Dataset.from_pandas(val_df[["text","label"]],   preserve_index=False),
    "test":       Dataset.from_pandas(test_df[["text","label"]],  preserve_index=False),
})

def preprocess_fn(batch):
    out = tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
    )
    out["labels"] = batch["label"]
    return out

tokenized = dataset_dict.map(
    preprocess_fn,
    batched=True,
    remove_columns=["text","label"],
)

tokenized.set_format("torch", columns=["input_ids","attention_mask","labels"])
train_ds = tokenized["train"]
eval_ds  = tokenized["validation"]
test_ds  = tokenized["test"]

Map: 100%|██████████| 101522/101522 [00:21<00:00, 4831.40 examples/s]
Map: 100%|██████████| 72019/72019 [00:14<00:00, 4828.86 examples/s]
Map: 100%|██████████| 120956/120956 [00:25<00:00, 4750.18 examples/s]


In [25]:
# Cell 5: Trainer 및 TrainingArguments 설정
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, preds)}

training_args = TrainingArguments(
    output_dir="./sbert_trainer_mind",
    do_train=True,
    do_eval=True,
    eval_steps=500,
    save_steps=500,
    logging_steps=50,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    fp16=True,  # GPU에서 반드시 지원
)

In [26]:
# Cell 6: Trainer 인스턴스화 및 학습
model = SBERTClassifier(pretrained_model_name='bert-base-uncased', num_classes=num_labels)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
)

trainer.train()

Step,Training Loss
50,1.7593
100,1.2294
150,1.0491
200,1.0867
250,1.0317
300,0.954
350,0.8995
400,0.8438
450,0.8122
500,0.8386


TrainOutput(global_step=19038, training_loss=0.494050298146383, metrics={'train_runtime': 2178.9272, 'train_samples_per_second': 139.778, 'train_steps_per_second': 8.737, 'total_flos': 0.0, 'train_loss': 0.494050298146383, 'epoch': 3.0})

In [29]:
# Cell 7: Validation/Test 평가
print("\n▶ Validation Accuracy:", trainer.evaluate(eval_ds))
print("▶ Test Accuracy:      ", trainer.evaluate(test_ds))



▶ Validation Accuracy: {'eval_loss': 0.21372860670089722, 'eval_accuracy': 0.9326844304975076, 'eval_runtime': 122.1179, 'eval_samples_per_second': 589.75, 'eval_steps_per_second': 18.433, 'epoch': 3.0}
▶ Test Accuracy:       {'eval_loss': 0.31558459997177124, 'eval_accuracy': 0.9032954132081087, 'eval_runtime': 204.7909, 'eval_samples_per_second': 590.632, 'eval_steps_per_second': 18.458, 'epoch': 3.0}
