In [1]:
import os
# Set CUDA_VISIBLE_DEVICES to "6,7" for this process. This runs in the first
# cell, ahead of the cell that imports torch.
os.environ.update({"CUDA_VISIBLE_DEVICES": "6,7"})

In [2]:
# notebooks/4_classifier.ipynb
import sys
import os

# Make the parent of the current working directory importable so that the
# `src.*` imports below resolve (assumes the notebook is run from the
# `notebooks/` folder, one level under the project root).
project_root = os.path.abspath("..")
if project_root not in sys.path:
    # Guard keeps sys.path free of duplicate entries when this cell is re-run.
    sys.path.append(project_root)

from src.data_utils import load_datasets, create_balanced_tasks
from src.model_utils import load_model_and_tokenizer, get_embeddings
from src.classifier_utils import train_classifier, evaluate_classifier
import random
import numpy as np
import torch

In [3]:
# Random seed used for this experiment; applied to every RNG the notebook
# imports (torch, Python's `random`, NumPy's legacy global generator).
SEED = 42

for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(SEED)

In [4]:
# Name of the model directory to load.
MODEL_NAME = "Llama-2-7b-hf"
# Root directory that holds the local model folders. The default is specific
# to the original machine; set the MODELS_ROOT_PATH environment variable to
# point the notebook elsewhere without editing this cell.
MODELS_ROOT_PATH = os.environ.get(
    "MODELS_ROOT_PATH", "/mnt/data102_d2/huggingface/models"
)

In [5]:
# Resolve where the configured model lives on disk, then load the model and
# its tokenizer from that directory.
model_name, root_path = MODEL_NAME, MODELS_ROOT_PATH
model_path = os.path.join(root_path, model_name)  # <models root>/<model name>

model, tokenizer = load_model_and_tokenizer(model_path)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Load the 'train' split of the datasets stored under data_dir.
data_dir = '../data/processed'
datasets = load_datasets(data_dir, split='train')

# Generate balanced tasks. The notebook-wide SEED is passed (instead of a
# repeated literal) so that changing SEED in one place also changes the
# seed used here.
tasks = create_balanced_tasks(datasets, balanced=True, seed=SEED)


Loaded hellaswag dataset from train split, shape: (39905, 4)
Loaded gsm8k dataset from train split, shape: (7473, 4)
Loaded winogrande dataset from train split, shape: (2558, 4)
Loaded piqa dataset from train split, shape: (16113, 4)
Loaded mmlu dataset from train split, shape: (99842, 4)
Loaded ai2_arc dataset from train split, shape: (1119, 4)


In [7]:
# Split `tasks` into two parallel lists: the input of each task and its
# task type, in the same order as `tasks`.
inputs, task_types = [], []
for task in tasks:
    inputs.append(task['input'])
    task_types.append(task['task_type'])

In [8]:
# Run get_embeddings over all inputs with the loaded model and tokenizer,
# then unpack its three results: the embeddings, the labels, and the
# task_type_to_label mapping that is later handed to evaluate_classifier.
embedding_result = get_embeddings(inputs, task_types, model, tokenizer)
embeddings, labels, task_type_to_label = embedding_result



Generating Embeddings: 100%|██████████| 6714/6714 [00:46<00:00, 143.07input/s]


In [9]:
# Train the classifier on the embeddings. train_classifier also returns
# X_test / y_test, which the evaluation cell consumes. The notebook-wide
# SEED is passed as random_state (instead of a repeated literal) so the seed
# is controlled from a single place.
clf, X_test, y_test = train_classifier(embeddings, labels, test_size=0.2, random_state=SEED)


In [10]:
# Evaluate the trained classifier on the X_test / y_test pair returned by
# train_classifier. task_type_to_label comes from get_embeddings; presumably
# it is used to name the classes in the printed report — confirm in
# src.classifier_utils.
evaluate_classifier(clf, X_test, y_test, task_type_to_label)

              precision    recall  f1-score   support

   hellaswag       0.94      0.89      0.91       236
       gsm8k       0.92      0.99      0.95       233
  winogrande       0.98      0.99      0.98       218
        piqa       0.88      0.91      0.89       201
        mmlu       0.95      0.87      0.91       233
     ai2_arc       0.91      0.95      0.93       222

    accuracy                           0.93      1343
   macro avg       0.93      0.93      0.93      1343
weighted avg       0.93      0.93      0.93      1343

