# 埋め込みの平均ベクトルでTransformerを学習＆5分割交差検証
- モデル：`esm-2`

## 1. 必要ライブラリのインポート

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

model_name = "esm2"

### 1.1 データセットの作成

In [5]:
class MeanEmbeddingDataset(Dataset):
    """Dataset pairing embedding vectors with integer class labels.

    Every item is returned with an extra leading axis of size 1, so a batch
    collated by ``DataLoader`` has shape ``(batch, 1, feature_dim)``.

    Args:
        X: Array-like of feature vectors, one row per sample. Copied into a
            ``float32`` tensor.
        y: Array-like of class labels, one entry per sample. Copied into an
            ``int64`` tensor.

    Raises:
        ValueError: If ``X`` and ``y`` do not hold the same number of samples.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

        # Fail fast on misaligned inputs; otherwise the mismatch only shows up
        # later as an index error in the middle of iteration.
        if self.X.shape[:1] != self.y.shape[:1]:
            raise ValueError(
                "X and y must contain the same number of samples, got shapes "
                f"{tuple(self.X.shape)} and {tuple(self.y.shape)}"
            )

    def __len__(self):
        """Return the number of samples (rows of ``X``)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for sample ``idx``.

        The feature vector gets a leading axis of size 1 (a length-1
        sequence); the label is returned as stored.
        """
        return self.X[idx].unsqueeze(0), self.y[idx]

### 1.2 Transformer分類モデル

In [6]:
import torch.nn as nn

class MeanEmbeddingTransformerClassifier(nn.Module):
    """Transformer-encoder classifier over embedding vectors.

    The input is linearly projected to ``hidden_dim``, passed through a stack
    of ``nn.TransformerEncoderLayer`` blocks (``batch_first=True``), averaged
    over the sequence axis and mapped to class logits.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the projection and of the encoder (``d_model``).
        num_classes: Number of output logits.
        num_heads: Number of attention heads per encoder layer. PyTorch
            requires ``hidden_dim`` to be divisible by this value.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability used inside each encoder layer.
    """

    def __init__(self, input_dim=2560, hidden_dim=512, num_classes=5, num_heads=8, num_layers=2, dropout=0.1):
        super().__init__()
        self.linear_proj = nn.Linear(input_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads, batch_first=True, dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``. A 2-D tensor
                of shape ``(batch, input_dim)`` is also accepted and treated
                as a batch of length-1 sequences.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalized logits.
        """
        # Without this, the encoder would read a 2-D input as one unbatched
        # sequence (attending across samples) and the pooling below would
        # average over the feature axis instead of the sequence axis.
        if x.dim() == 2:
            x = x.unsqueeze(1)

        x = self.linear_proj(x)
        x = self.transformer(x)
        x = x.mean(dim=1)  # average over the sequence axis
        return self.fc(x)

## 2. データ読み込み

In [8]:
# Load the stored vectors and labels from the directory selected by
# `model_name`.
X, y = (
    np.load(f"../data/embedding-vectors/{model_name}/mean_vectors.npy"),
    np.load(f"../data/embedding-vectors/{model_name}/labels.npy"),
)

# Print both shapes as a quick sanity check that the arrays line up.
for tag, arr in (("X.shape:", X), ("y.shape:", y)):
    print(tag, arr.shape)

X.shape: (7716, 2560)
y.shape: (7716,)


## 3. 学習と評価（5分割交差検証）

In [9]:
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import torch.optim as optim

In [10]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [11]:
# Stratified 5-fold cross-validation; shuffling uses a fixed seed so the
# folds are the same on every run.
skf = StratifiedKFold(shuffle=True, random_state=42, n_splits=5)

# Per-fold metric values; the cross-validation loop appends one entry per fold.
accuracies = []
precisions = []
recalls = []
f1_scores = []

In [12]:
# Read the fold count from the splitter so the progress banner cannot drift
# out of sync with `n_splits`.
n_folds = skf.get_n_splits()

for fold, (train_idx, val_idx) in enumerate(skf.split(X, y), 1):
    print(f"\033[34m[Fold {fold} / {n_folds}]\033[0m")

    train_ds = MeanEmbeddingDataset(X[train_idx], y[train_idx])
    val_ds = MeanEmbeddingDataset(X[val_idx], y[val_idx])

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)

    # A fresh model and optimizer per fold, so no weights leak between folds.
    # The input width comes from the data rather than the constructor default,
    # so embeddings of a different dimensionality work unchanged.
    model = MeanEmbeddingTransformerClassifier(input_dim=X.shape[1]).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # --- Training loop ---
    for epoch in range(100):  # number of epochs; adjustable
        model.train()
        total_loss = 0

        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Report the mean batch loss for this epoch.
        print(f"\033[36mEpoch {epoch + 1}, loss={total_loss/len(train_loader):.6f}\033[0m")

    # --- Evaluation ---
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device)
            out = model(xb)
            preds = out.argmax(dim=1).cpu().numpy()
            y_pred.extend(preds)
            y_true.extend(yb.numpy())

    # --- Metric computation ---
    # zero_division=0 is passed to every macro-averaged metric so a class that
    # is never predicted scores 0 without emitting a warning.
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average="macro", zero_division=0)
    rec = recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    cm = confusion_matrix(y_true, y_pred)

    # --- Store per-fold results ---
    accuracies.append(acc)
    precisions.append(prec)
    recalls.append(rec)
    f1_scores.append(f1)

    # --- Report ---
    print(f"\n\033[32mFold {fold} result:\033[0m")
    print(f"\033[92mAccuracy : {acc:.4f}\033[0m")
    print(f"\033[92mPrecision: {prec:.4f}\033[0m")
    print(f"\033[92mRecall   : {rec:.4f}\033[0m")
    print(f"\033[92mF1-score : {f1:.4f}\033[0m")
    print("\nConfusion Matrix:")
    print(cm)
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, zero_division=0, digits=4))

# === Mean results: mean ± standard deviation (np.std) of each metric over the folds ===
print("\n\033[35m===== Cross-validation Summary =====\033[0m")

acc_mean, acc_std = np.mean(accuracies), np.std(accuracies)
print(f"\033[95mMean Accuracy : {acc_mean:.4f} ± {acc_std:.4f}\033[0m")

prec_mean, prec_std = np.mean(precisions), np.std(precisions)
print(f"\033[95mMean Precision: {prec_mean:.4f} ± {prec_std:.4f}\033[0m")

rec_mean, rec_std = np.mean(recalls), np.std(recalls)
print(f"\033[95mMean Recall   : {rec_mean:.4f} ± {rec_std:.4f}\033[0m")

f1_mean, f1_std = np.mean(f1_scores), np.std(f1_scores)
print(f"\033[95mMean F1-score : {f1_mean:.4f} ± {f1_std:.4f}\033[0m")

[34m[Fold 1 / 5][0m
[36mEpoch 1, loss=0.610349[0m
[36mEpoch 2, loss=0.474201[0m
[36mEpoch 3, loss=0.425489[0m
[36mEpoch 4, loss=0.399520[0m
[36mEpoch 5, loss=0.387225[0m
[36mEpoch 6, loss=0.369748[0m
[36mEpoch 7, loss=0.368266[0m
[36mEpoch 8, loss=0.352238[0m
[36mEpoch 9, loss=0.355089[0m
[36mEpoch 10, loss=0.347014[0m
[36mEpoch 11, loss=0.339620[0m
[36mEpoch 12, loss=0.334018[0m
[36mEpoch 13, loss=0.332189[0m
[36mEpoch 14, loss=0.328486[0m
[36mEpoch 15, loss=0.333157[0m
[36mEpoch 16, loss=0.321970[0m
[36mEpoch 17, loss=0.324054[0m
[36mEpoch 18, loss=0.327203[0m
[36mEpoch 19, loss=0.321642[0m
[36mEpoch 20, loss=0.329254[0m
[36mEpoch 21, loss=0.314089[0m
[36mEpoch 22, loss=0.309765[0m
[36mEpoch 23, loss=0.311684[0m
[36mEpoch 24, loss=0.310329[0m
[36mEpoch 25, loss=0.313401[0m
[36mEpoch 26, loss=0.312693[0m
[36mEpoch 27, loss=0.307557[0m
[36mEpoch 28, loss=0.307513[0m
[36mEpoch 29, loss=0.310170[0m
[36mEpoch 30, loss=0.304019[