# DeepLearning最終課題

松尾研 DeepLearning（深層学習）Spring の最終課題のマスターコードである。

## 0.概要

#### 概要
被験者が画像を見ているときの脳波から，その画像がどのカテゴリに属するかを分類するタスク．
- サンプル数: 訓練 118,800 サンプル，検証 59,400 サンプル，テスト 59,400 サンプル
- クラス数: 5
- 入力: 脳波データ（チャンネル数 x 系列長）
- 出力: 対応する画像のクラス
- 評価指標: Top-1 accuracy

#### 修了要件
- ベースラインモデルのbest test accuracyは38.7%となります．**これを超えた提出のみ，修了要件として認めます**．
- ベースラインから改善を加えることで，55%までは性能向上することを運営で確認しています．こちらを 1 つの指標として取り組んでみてください．

#### 注意点
- 学習するモデルについて制限はありませんが，必ず訓練データで学習したモデルで予測してください．
    - 事前学習済みモデルを利用して，訓練データを fine-tuning しても構いません．
    - 埋め込み抽出モデルなど，モデルの一部を訓練しないケースは構いません．
    - 学習を一切せずに，ChatGPT などの基盤モデルを利用することは禁止とします．

In [9]:
#インポート
import os, sys
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from einops.layers.torch import Rearrange
from einops import repeat
from glob import glob
from termcolor import cprint
from tqdm.notebook import tqdm
from torchaudio.models import Conformer
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from PIL import Image
import open_clip
from torchvision import transforms

# When True, the download cell below runs its notebook shell commands.
is_jupyter = False

# One seed shared by the Python, NumPy and PyTorch RNGs for reproducibility.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng

<torch._C.Generator at 0x1a0ff7b20f0>

In [None]:
# Download the dataset from Google Drive when running in a notebook.
# The "!" lines are IPython shell escapes, so this cell only runs inside a
# Jupyter/Colab kernel (it is not plain Python).
if is_jupyter:
    # Install gdown if not already installed, then upgrade it.
    !pip install gdown
    !pip install --upgrade gdown
    # Download the shared Google Drive folder.
    folder_id = '1qkxOgiD1Z0U5DnaYhs3RSfWdQHvzE4HX'
    !gdown --folder {folder_id}

    # Download the archive and extract it into data/.
    # NOTE(review): `tar -C data/` fails if data/ does not already exist — confirm
    # that the folder download above creates it.
    file_id = '1QNZv_4s89Lq8bwc6XNsqKKpxYZlLPWHT'
    !gdown https://drive.google.com/uc?id={file_id} -O downloaded.tar.gz
    !tar -xzvf downloaded.tar.gz -C data/
else:
    # Not in a notebook: nothing is downloaded here.
    pass

# BaseCode

## 1.Preparation

In [3]:
# Working directory for this notebook (machine-specific Windows path).
WORK_DIR = r"C:\Users\sudok\OneDrive\ドキュメント\0B-DeepLerning\DeepLerning-FinalSubmission\最終課題"
os.makedirs(WORK_DIR, exist_ok=True)
# IPython magic: change the kernel's current directory so relative paths used
# later (e.g. "model_best.pt", "tensorboard", "data/images") resolve under WORK_DIR.
%cd {WORK_DIR}

C:\Users\sudok\OneDrive\ドキュメント\0B-DeepLerning\DeepLerning-FinalSubmission\最終課題


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


## 2.Data

In [4]:
class ThingsEEGDataset(torch.utils.data.Dataset):
    """EEG recordings for one split of the Things-EEG classification task.

    Loads ``eeg.npy`` (converted to float32) and ``subject_idxs.npy`` from
    ``<data_dir>/<split>/``. The ``train`` and ``val`` splits also load
    ``labels.npy``. ``X`` is indexed as (sample, channel, time), as the
    ``num_channels`` / ``seq_len`` properties show.

    Items are ``(X, y, subject_idx)`` when labels are loaded and
    ``(X, subject_idx)`` otherwise (the ``test`` split).

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        data_dir: Root directory containing one sub-directory per split.
            Defaults to the original machine-specific location.

    Raises:
        ValueError: If ``split`` is not a known split name.
    """

    # Original hard-coded location; pass ``data_dir`` to use another machine's path.
    DEFAULT_DATA_DIR = r"C:\Users\sudok\OneDrive\ドキュメント\0B-DeepLerning\DeepLerning-FinalSubmission\FinalSubmission\data"

    def __init__(self, split: str, data_dir: str = DEFAULT_DATA_DIR) -> None:
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if split not in ("train", "val", "test"):
            raise ValueError(f"Invalid split: {split}")
        self.split = split
        self.num_classes = 5
        self.num_subjects = 10

        split_dir = os.path.join(data_dir, split)
        self.X = torch.from_numpy(np.load(os.path.join(split_dir, "eeg.npy"))).to(torch.float32)
        self.subject_idxs = torch.from_numpy(np.load(os.path.join(split_dir, "subject_idxs.npy")))

        # Only labelled splits get `self.y`; __getitem__ keys off its presence.
        if split in ("train", "val"):
            self.y = torch.from_numpy(np.load(os.path.join(split_dir, "labels.npy")))

        print(f"EEG: {self.X.shape}, labels: {self.y.shape if hasattr(self, 'y') else None}, subject indices: {self.subject_idxs.shape}")

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, i):
        if hasattr(self, "y"):
            return self.X[i], self.y[i], self.subject_idxs[i]
        return self.X[i], self.subject_idxs[i]

    @property
    def num_channels(self) -> int:
        """Number of EEG channels (dimension 1 of ``X``)."""
        return self.X.shape[1]

    @property
    def seq_len(self) -> int:
        """Number of time samples per trial (dimension 2 of ``X``)."""
        return self.X.shape[2]

## 3.Code

In [5]:
class ConvBlock(nn.Module):
    """Residual 1-D convolution block.

    Layout: conv0 (+ input, only when ``in_dim == out_dim``) -> BatchNorm -> GELU
    -> conv1 (+ its own input) -> BatchNorm -> GELU -> Dropout.
    Both convolutions use ``padding="same"``, so the time length is preserved.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        kernel_size: Convolution kernel size.
        p_drop: Dropout probability applied to the block output.
    """

    def __init__(self, in_dim, out_dim, kernel_size: int = 3, p_drop: float = 0.1) -> None:
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim

        # Submodules are created in the same order as before so that seeded
        # initialisation and state_dict keys stay identical.
        self.conv0 = nn.Conv1d(in_dim, out_dim, kernel_size, padding="same")
        self.conv1 = nn.Conv1d(out_dim, out_dim, kernel_size, padding="same")
        self.batchnorm0 = nn.BatchNorm1d(num_features=out_dim)
        self.batchnorm1 = nn.BatchNorm1d(num_features=out_dim)
        self.dropout = nn.Dropout(p_drop)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        hidden = self.conv0(X)
        if self.in_dim == self.out_dim:
            hidden = hidden + X  # skip connection (shapes only match when widths agree)
        hidden = F.gelu(self.batchnorm0(hidden))

        hidden = F.gelu(self.batchnorm1(self.conv1(hidden) + hidden))  # skip connection
        return self.dropout(hidden)


class BasicConvClassifier(nn.Module):
    """Baseline classifier: two ConvBlocks, global average pooling, linear head.

    Args:
        num_classes: Number of output classes.
        seq_len: Accepted for interface compatibility; not used by the model.
        in_channels: Number of input channels.
        hid_dim: Width of both ConvBlocks.
    """

    def __init__(self, num_classes: int, seq_len: int, in_channels: int, hid_dim: int = 128) -> None:
        super().__init__()

        widths = [in_channels, hid_dim, hid_dim]
        self.blocks = nn.Sequential(
            *(ConvBlock(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:]))
        )

        # Pool the time axis to length 1, drop it, then project to class logits.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            Rearrange("b d 1 -> b d"),
            nn.Linear(hid_dim, num_classes),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Map ``X`` of shape (b, c, t) to logits of shape (b, num_classes)."""
        return self.head(self.blocks(X))

## 4.Training

In [10]:
# Hyperparameters
lr = 0.001
batch_size = 512
epochs = 10
is_cuda = False
# One device string for the model and every batch (replaces the duplicated if/else branches).
device = "cuda" if is_cuda else "cpu"

# ------------------
#    Dataloader
# ------------------
train_set = ThingsEEGDataset("train")  # ThingsMEGDataset("train")
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
val_set = ThingsEEGDataset("val")  # ThingsMEGDataset("val")
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=batch_size, shuffle=False
)

# ------------------
#       Model
# ------------------
model = BasicConvClassifier(
    train_set.num_classes, train_set.seq_len, train_set.num_channels
).to(device)

# ------------------
#     Optimizer
# ------------------
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# ------------------
#   Start training
# ------------------
max_val_acc = 0


def accuracy(y_pred, y):
    """Return the fraction of rows whose arg-max prediction equals ``y``."""
    return (y_pred.argmax(dim=-1) == y).float().mean()


writer = SummaryWriter("tensorboard")

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")

    # Per-batch metrics; the epoch value is their plain mean, so the smaller
    # final batch is weighted like a full one.
    train_loss, train_acc, val_loss, val_acc = [], [], [], []

    model.train()
    for X, y, _subject_idxs in tqdm(train_loader, desc="Train"):
        X, y = X.to(device), y.to(device)

        y_pred = model(X)
        loss = F.cross_entropy(y_pred, y)
        train_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc.append(accuracy(y_pred, y).item())

    model.eval()
    with torch.no_grad():
        for X, y, _subject_idxs in tqdm(val_loader, desc="Validation"):
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            val_loss.append(F.cross_entropy(y_pred, y).item())
            val_acc.append(accuracy(y_pred, y).item())

    mean_train_loss, mean_train_acc = np.mean(train_loss), np.mean(train_acc)
    mean_val_loss, mean_val_acc = np.mean(val_loss), np.mean(val_acc)

    # Implicit string concatenation keeps the log on one line; the previous
    # backslash continuation embedded the source indentation in the output.
    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"train loss: {mean_train_loss:.3f} | "
        f"train acc: {mean_train_acc:.3f} | "
        f"val loss: {mean_val_loss:.3f} | "
        f"val acc: {mean_val_acc:.3f}"
    )

    writer.add_scalar("train_loss", mean_train_loss, epoch)
    writer.add_scalar("train_acc", mean_train_acc, epoch)
    writer.add_scalar("val_loss", mean_val_loss, epoch)
    writer.add_scalar("val_acc", mean_val_acc, epoch)

    torch.save(model.state_dict(), "model_last.pt")

    if mean_val_acc > max_val_acc:
        cprint("New best. Saving the model.", "cyan")
        torch.save(model.state_dict(), "model_best.pt")
        max_val_acc = mean_val_acc

# Flush any buffered TensorBoard events and release the log file.
writer.close()

EOFError: No data left in file

In [None]:
# IPython magics: load the TensorBoard extension and display the logs that the
# training cell writes to ./tensorboard via SummaryWriter.
%load_ext tensorboard
%tensorboard --logdir tensorboard

## 5.Predict

In [None]:
# ------------------
#    Dataloader
# ------------------
test_set = ThingsEEGDataset("test")
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False
)

# Use the GPU only when one is present. The training cell defaults to
# is_cuda = False, so the checkpoint may come from a CPU run; map_location
# puts its tensors on whichever device is used here.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ------------------
#       Model
# ------------------
model = BasicConvClassifier(
    test_set.num_classes, test_set.seq_len, test_set.num_channels
).to(device)
model.load_state_dict(torch.load("model_best.pt", map_location=device))

# ------------------
#  Start evaluation
# ------------------
preds = []
model.eval()
with torch.no_grad():
    # The test split is unlabeled: each batch is (eeg, subject_idxs).
    for X, _subject_idxs in tqdm(test_loader, desc="Evaluation"):
        preds.append(model(X.to(device)).cpu())

preds = torch.cat(preds, dim=0).numpy()
# np.save appends ".npy", so this writes submission.npy (raw logits, one row per test sample).
np.save("submission", preds)
print(f"Submission {preds.shape} saved.")

# OriginalCode

## 1.Preparation

In [3]:
class ThingsEEGDataset(torch.utils.data.Dataset):
    """EEG recordings for one split of the Things-EEG classification task.

    Loads ``eeg.npy`` (converted to float32) and ``subject_idxs.npy`` from
    ``<data_dir>/<split>/``. The ``train`` and ``val`` splits also load
    ``labels.npy``. ``X`` is indexed as (sample, channel, time), as the
    ``num_channels`` / ``seq_len`` properties show.

    Items are ``(X, y, subject_idx)`` when labels are loaded and
    ``(X, subject_idx)`` otherwise (the ``test`` split).

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        data_dir: Root directory containing one sub-directory per split.
            Defaults to the original machine-specific location.

    Raises:
        ValueError: If ``split`` is not a known split name.
    """

    # Original hard-coded location; pass ``data_dir`` to use another machine's path.
    DEFAULT_DATA_DIR = r"C:\Users\sudok\OneDrive\ドキュメント\0B-DeepLerning\DeepLerning-FinalSubmission\FinalSubmission\data"

    def __init__(self, split: str, data_dir: str = DEFAULT_DATA_DIR) -> None:
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if split not in ("train", "val", "test"):
            raise ValueError(f"Invalid split: {split}")
        self.split = split
        self.num_classes = 5
        self.num_subjects = 10

        split_dir = os.path.join(data_dir, split)
        self.X = torch.from_numpy(np.load(os.path.join(split_dir, "eeg.npy"))).to(torch.float32)
        self.subject_idxs = torch.from_numpy(np.load(os.path.join(split_dir, "subject_idxs.npy")))

        # Only labelled splits get `self.y`; __getitem__ keys off its presence.
        if split in ("train", "val"):
            self.y = torch.from_numpy(np.load(os.path.join(split_dir, "labels.npy")))

        print(f"EEG: {self.X.shape}, labels: {self.y.shape if hasattr(self, 'y') else None}, subject indices: {self.subject_idxs.shape}")

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, i):
        if hasattr(self, "y"):
            return self.X[i], self.y[i], self.subject_idxs[i]
        return self.X[i], self.subject_idxs[i]

    @property
    def num_channels(self) -> int:
        """Number of EEG channels (dimension 1 of ``X``)."""
        return self.X.shape[1]

    @property
    def seq_len(self) -> int:
        """Number of time samples per trial (dimension 2 of ``X``)."""
        return self.X.shape[2]

In [None]:
# Load all three splits. train and val include labels; test has only EEG and
# subject indices, so its items are 2-tuples.
train_set = ThingsEEGDataset("train")
val_set = ThingsEEGDataset("val")
test_set = ThingsEEGDataset("test")

EEG: torch.Size([118800, 17, 100]), labels: torch.Size([118800]), subject indices: torch.Size([118800])
EEG: torch.Size([59400, 17, 100]), labels: torch.Size([59400]), subject indices: torch.Size([59400])


## 2.精度向上

### **精度向上への施策**

**・音声認識モデルConformerの応用**
>音声認識向けに提案されたConformerを脳波の分類に応用し、精度向上を図る

## 3.Conformer

**Conformer：音声認識のための畳み込み拡張Transformer（Convolution-augmented Transformer）**


>end-to-endの音声認識システムは、RNNをベースとしたモデルによって発展してきた。また、最近では、self-attentionによるTransformerベースのモデルが、様々な
>領域へと展開されてきている。一方で、音声認識においてはlocalのcontextを捉えることができるCNNが有用であり、いまだによく使われている。
>Transformerは、globalの相互作用を捉えやすいが、localの特徴を抽出しにくい。また、CNNはkernelの位置と移動距離によって局所的な特徴量を抽出していくため、
>globalの相互作用を捉えるのは難しい。ContextNetは、CNNベースでありながら、残差blockを活用することでglobalの特徴を抽出していたが十分とは言えない。
>ただ最近では、CNNとself-attentionを組み合わせることによって、それぞれを個別に扱うよりも性能が向上することが様々な研究によってわかってきている。
>本論では、音声認識モデルにおいて、CNNにself-attentionを結合した独自の方法を紹介する。また、globalとlocalの相互作用を扱うことで、パラメータの効率化も
>可能であると想定している。


![Transformer structure diagram](https://storage.googleapis.com/zenn-user-upload/1fc351a21052-20230206.png "Transformer")

[【論文紹介】Conformer](https://zenn.dev/nudibranch/articles/577a0bb1fab32a)

In [None]:
class EEGConformer(nn.Module):
    """Conformer-based classifier for EEG sequences.

    Each time step's channel vector is projected to ``hidden_dim``. The result
    goes through a ``torchaudio.models.Conformer`` encoder, is mean-pooled over
    time, and is classified by a linear layer.

    Args:
        num_classes (int): Number of output classes.
        input_dim (int): Features per time step (the number of EEG channels).
        num_layers (int): Number of Conformer layers.
        hidden_dim (int): Model width; must be divisible by ``num_heads``.
        num_heads (int): Number of self-attention heads.
        ffmult (int): Feed-forward expansion factor (``ffn_dim = hidden_dim * ffmult``).
        dropout (float): Dropout rate inside the Conformer.
        depthwise_conv_kernel_size (int): Kernel size of each layer's depthwise
            convolution (torchaudio requires an odd value). Defaults to 31, the
            previously hard-coded value.
    """

    def __init__(
        self,
        num_classes,
        input_dim,
        num_layers=4,
        hidden_dim=144,
        num_heads=4,
        ffmult=4,
        dropout=0.1,
        depthwise_conv_kernel_size=31,
    ):
        super().__init__()

        # Per-time-step projection from EEG channels to the model width.
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        self.conformer = Conformer(
            input_dim=hidden_dim,
            num_heads=num_heads,
            ffn_dim=hidden_dim * ffmult,
            num_layers=num_layers,
            dropout=dropout,
            depthwise_conv_kernel_size=depthwise_conv_kernel_size,
        )

        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch_size, num_classes).

        Args:
            x: Tensor of shape (batch_size, seq_len, input_dim). The training
                cells permute the dataset's (batch, channels, time) EEG into
                this layout before calling the model.
        """
        x = self.input_proj(x)

        batch_size, seq_len, _ = x.shape
        # Every sequence is full length, so all lengths equal seq_len (no padding).
        lengths = torch.full((batch_size,), seq_len, dtype=torch.int64, device=x.device)
        x, _ = self.conformer(x, lengths)

        # Mean-pool over time, then classify.
        return self.classifier(x.mean(dim=1))

In [None]:
# NOTE: the name `barch_size` (sic) is kept because the inference cell reads it.
barch_size = 32
# NOTE(review): 0.01 is high for Adam on an attention model; around 1e-3 or
# lower is more typical — confirm this is intended.
lr = 0.01
num_epochs = 10
# The best-validation weights are written here; the inference cell loads this file.
weights_path = "conformer_weights.pth"

device = "cuda" if torch.cuda.is_available() else "cpu"

train_loader = DataLoader(train_set, batch_size=barch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=barch_size, shuffle=False)

model_conformer = EEGConformer(input_dim=train_set.num_channels,
                               num_classes=train_set.num_classes,
                               num_layers=4,
                               hidden_dim=144,
                               num_heads=4,
                               ffmult=4,
                               dropout=0.1).to(device)

optimizer = torch.optim.Adam(model_conformer.parameters(), lr=lr)
# mode='max' because the scheduler is stepped with validation accuracy.
# `verbose` is deprecated in recent PyTorch; the LR is printed every epoch instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3
)
criterion = nn.CrossEntropyLoss()
best_val_acc = float("-inf")

for epoch in range(num_epochs):
    model_conformer.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for X, y, _subject_idxs in tqdm(train_loader, f"Epoch {epoch+1}/{num_epochs} - Train"):
        X, y = X.to(device), y.to(device)
        # (batch, channels, time) -> (batch, time, channels), as EEGConformer expects.
        X = X.permute(0, 2, 1)

        optimizer.zero_grad()
        outputs = model_conformer(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    avg_train_loss = train_loss / len(train_loader)
    train_acc = correct / total

    model_conformer.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for X, y, _subject_idxs in tqdm(val_loader, f"Epoch {epoch+1}/{num_epochs} - Validation"):
            X, y = X.to(device), y.to(device)
            X = X.permute(0, 2, 1)

            outputs = model_conformer(X)
            loss = criterion(outputs, y)

            val_loss += loss.item()
            preds = outputs.argmax(dim=1)
            val_correct += (preds == y).sum().item()
            val_total += y.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_acc = val_correct / val_total

    scheduler.step(val_acc)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}, "
          f"LR: {current_lr:.6f}")

    # Previously nothing was saved, so the inference cell's torch.load failed.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model_conformer.state_dict(), weights_path)
        print(f"New best val acc {val_acc:.4f}; saved weights to {weights_path}")

In [None]:
# Inference
test_set = ThingsEEGDataset("test")
test_loader = DataLoader(test_set, batch_size=barch_size, shuffle=False)

device = "cuda" if torch.cuda.is_available() else "cpu"

model_conformer = EEGConformer(input_dim=train_set.num_channels,
                               num_classes=train_set.num_classes,
                               num_layers=4,
                               hidden_dim=144,
                               num_heads=4,
                               ffmult=4,
                               dropout=0.1).to(device)
# map_location lets weights saved on a GPU load on a CPU-only machine (and vice versa).
model_conformer.load_state_dict(torch.load("conformer_weights.pth", map_location=device))
model_conformer.eval()

all_outputs = []
with torch.no_grad():
    # The test split has no labels: each batch is (eeg, subject_idxs), not (eeg, y).
    for X, _subject_idxs in tqdm(test_loader, desc="Test"):
        X = X.to(device).permute(0, 2, 1)  # -> (batch, time, channels)
        all_outputs.append(model_conformer(X).cpu())

# Concatenate per-batch logits into a single (num_test_samples, num_classes) array.
all_outputs = torch.cat(all_outputs, dim=0).numpy()

np.save("submission.npy", all_outputs)
print(f"Submission {all_outputs.shape} saved.")

行った修正：

・StepLRやReduceLROnPlateauなどのスケジューラーの導入
>
>``StepLR``：一定のエポック数（`step_size`）ごとに学習率を `gamma` 倍に減衰させる
>```python
>scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_epochs, gamma=decay_rate)
>```

>``ReduceLROnPlateau``：監視している指標が停滞したときに学習率を `factor` 倍に下げる
>```python
>scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=decay_rate, patience=step_epochs, verbose=True)
>```

## 4.CLIP

**CLIPによる応用的な脳波画像分類**


>CLIPの基本的なアイデアは，言語情報（テキスト）とペアになっている画像の対照学習によって，よい言語と画像の表現（embedding）を学習する点にあります
>上記のように事前学習されたCLIPは，画像とテキストが対応するような表現（共有表現）が学習されていると考えられます． このモデルを画像分類に活用する場合
>を考えてみましょう（図2）．
>飛行機（plane），自動車（car），…の分類器を作る場合は，まずそれぞれの画像を説明するテキスト（a photo of plane, a photo of car, …）のembeddingを作成
>します（図2上部）．
>そして，分類したい画像を画像のエンコーダに入力した際のembeddingと，それぞれのテキストとのコサイン類似度を計算し，最大になるクラスを選択することで，
>分類ができます（図2下部）．
>このようにして，対象となるdownstreamの分類タスクに関するデータセットを集めてfine-tuningすることなしに，zero-shotで目的のデータセットに関する分類器を
>（理想的には）構成できます。

![CLIP pre-training mechanism](https://trail.t.u-tokyo.ac.jp/ja/blog/22-12-02-clip/overview-a.svg "CLIP pre-training")

[CLIP：言語と画像のマルチモーダル基盤モデル](https://trail.t.u-tokyo.ac.jp/ja/blog/22-12-02-clip/)

In [None]:
class EEGWithImageDataset(Dataset):
    """Pair each labelled EEG sample with an image from its class folder.

    ``image_dir`` holds one sub-folder per class. Folders are matched to labels
    by sorted name (label ``k`` -> ``k``-th folder), and every sample of class
    ``k`` is paired with the first file (by sorted name) in that folder.

    WARNING: the image is chosen from the ground-truth label, so a model that
    sees it is effectively given the answer. This pairing cannot be built for
    the unlabeled test split.

    Args:
        eeg_dataset: Dataset whose items are ``(eeg, label, subject_idx)``.
        image_dir: Directory containing one sub-folder per class.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, eeg_dataset, image_dir, transform=None):
        self.eeg_dataset = eeg_dataset
        self.image_dir = image_dir
        self.transform = transform
        self.image_folders = sorted(os.listdir(self.image_dir))
        # label index -> image path, filled lazily so os.listdir runs once per
        # class instead of once per item.
        self._image_path_cache = {}

    def __len__(self):
        return len(self.eeg_dataset)

    def _image_path(self, label_idx):
        """Return (and memoize) the path of the image used for class ``label_idx``."""
        path = self._image_path_cache.get(label_idx)
        if path is None:
            folder_path = os.path.join(self.image_dir, self.image_folders[label_idx])
            image_files = sorted(os.listdir(folder_path))
            path = os.path.join(folder_path, image_files[0])
            self._image_path_cache[label_idx] = path
        return path

    def __getitem__(self, idx):
        item = self.eeg_dataset[idx]
        if len(item) != 3:
            raise ValueError(
                "EEGWithImageDataset needs labelled items (eeg, label, subject_idx); "
                "the wrapped dataset has no labels (e.g. the test split)."
            )
        eeg, label, _subj = item

        # Context manager closes the underlying file handle once converted.
        with Image.open(self._image_path(label.item())) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return eeg, image, label

In [None]:
# Image preprocessing for the CLIP image encoder.
preprocess = transforms.Compose([
    transforms.Resize(224),              # resize the SHORTER side to 224 px (aspect ratio kept)
    transforms.CenterCrop(224),          # crop the central 224x224 region
    transforms.ToTensor(),               # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(                # per-channel normalisation with CLIP's mean/std
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711]
    ),
])

# NOTE(review): test_set has no labels, and EEGWithImageDataset unpacks
# (eeg, label, subject_idx) per item, so indexing test_dataset fails.
train_dataset = EEGWithImageDataset(train_set, image_dir="data/images", transform=preprocess)
val_dataset = EEGWithImageDataset(val_set, image_dir="data/images", transform=preprocess)
test_dataset = EEGWithImageDataset(test_set, image_dir="data/images", transform=preprocess)

EEG: torch.Size([118800, 17, 100]), labels: torch.Size([118800]), subject indices: torch.Size([118800])
EEG: torch.Size([59400, 17, 100]), labels: torch.Size([59400]), subject indices: torch.Size([59400])
EEG: torch.Size([118800, 17, 100]), labels: None, subject indices: torch.Size([59400])


In [None]:
class EEGFeatureExtractor(nn.Module):
    """Time-pooled EEG features from all but the last layer of a torchaudio Conformer.

    Args:
        eeg_model: A ``torchaudio.models.Conformer`` (or any object exposing a
            ``conformer_layers`` sequence whose layers are called as
            ``layer(x, key_padding_mask=...)``). Its ``input_dim`` must equal
            the number of EEG channels.
    """

    def __init__(self, eeg_model):
        super().__init__()
        # The last Conformer layer is not used (slice [:-1]).
        self.features = nn.Sequential(
            *eeg_model.conformer_layers[:-1]
        )

    def forward(self, x):
        """Map EEG of shape (batch, channels, time) to features of shape (batch, channels).

        torchaudio's ConformerLayer consumes (time, batch, feature) tensors. The
        previous (batch, time, channel) layout made self-attention mix different
        samples of the batch instead of time steps.
        """
        x = x.permute(2, 0, 1)  # (B, C, T) -> (T, B, C)
        for layer in self.features:
            # No padding mask: every sequence in the batch is full length.
            x = layer(x, key_padding_mask=None)
        return x.mean(dim=0)  # average over time -> (B, C)

In [15]:
# CLIP
device = "cuda" if torch.cuda.is_available() else "cpu"

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Previously built from train_dataset by mistake. test_dataset wraps the
# unlabeled test split, so iterating this loader fails until paired images can
# be chosen without labels.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Pretrained OpenAI CLIP ViT-B-32, used as a frozen image encoder. The two
# returned preprocessing transforms are discarded; the notebook defines its
# own `preprocess`.
model_clip, _, _ = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai", force_quick_gelu=True
)
model_clip.eval()
# Freeze CLIP so no gradients are tracked for weights that are never optimized.
model_clip.requires_grad_(False)
model_clip = model_clip.to(device)

In [16]:
# Fully connected fusion head
class EEG_CLIP_ConcatModel(nn.Module):
    """Late-fusion classifier over concatenated EEG and CLIP features.

    The feature vectors are concatenated along dim 1 (EEG first) and passed
    through Linear(eeg+clip -> 256) -> ReLU -> Dropout(0.3) -> Linear(256 -> num_classes).

    Args:
        eeg_feature_dim: Width of the EEG feature vector.
        clip_feature_dim: Width of the CLIP image embedding.
        num_classes: Number of output classes.
    """

    def __init__(self, eeg_feature_dim, clip_feature_dim, num_classes):
        super().__init__()
        self.eeg_feature_dim = eeg_feature_dim
        self.clip_feature_dim = clip_feature_dim

        fused_width = eeg_feature_dim + clip_feature_dim
        hidden_width = 256
        layers = [
            nn.Linear(fused_width, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_width, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, eeg_feat, clip_feat):
        """Return logits of shape (batch, num_classes)."""
        return self.classifier(torch.cat((eeg_feat, clip_feat), dim=1))

In [None]:
num_epochs = 10
criterion = nn.CrossEntropyLoss()

# NOTE(review): EEGWithImageDataset picks each sample's image from its
# ground-truth label, so clip_feat encodes the answer. Accuracy from this loop
# does not measure the EEG branch and cannot be reproduced on the unlabeled
# test split.

concat_model = EEG_CLIP_ConcatModel(
    eeg_feature_dim=train_set.num_channels,  # EEGFeatureExtractor yields one value per channel
    clip_feature_dim=512,  # image-embedding width of CLIP ViT-B-32
    num_classes=train_dataset.eeg_dataset.num_classes
).to(device)

# torch.nn.MultiheadAttention requires input_dim % num_heads == 0. With the 17
# EEG channels here, 4 heads failed at construction, so use the largest of
# 4/2/1 that divides the channel count.
eeg_num_heads = next(h for h in (4, 2, 1) if train_set.num_channels % h == 0)
eeg_model = Conformer(
    input_dim=train_set.num_channels,  # uses the num_channels property
    num_heads=eeg_num_heads,
    ffn_dim=128,
    num_layers=4,
    depthwise_conv_kernel_size=31,
    dropout=0.1,
).to(device)

eeg_feature_extractor = EEGFeatureExtractor(eeg_model).to(device)

# Optimizers have no .to(); the parameters already live on `device`.
optimizer = torch.optim.Adam(
    list(eeg_feature_extractor.parameters()) + list(concat_model.parameters()),
    lr=1e-4
)

for epoch in range(num_epochs):
    concat_model.train()
    eeg_feature_extractor.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for eeg, image, label in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        eeg = eeg.to(device)
        image = image.to(device)
        label = label.to(device)

        # Trainable EEG branch (previously the module itself was assigned here).
        eeg_feat = eeg_feature_extractor(eeg)
        # Frozen CLIP branch: no autograd graph is needed.
        with torch.no_grad():
            clip_feat = model_clip.encode_image(image)

        output = concat_model(eeg_feat, clip_feat)

        loss = criterion(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        preds = output.argmax(dim=1)  # predicted labels
        correct += (preds == label).sum().item()  # number of correct predictions
        total += label.size(0)  # samples seen (batch size)

    # Named train_acc so the accuracy() helper defined earlier is not shadowed.
    train_acc = correct / total
    print(f"Epoch {epoch+1} Loss: {total_loss / len(train_loader):.4f} Accuracy: {train_acc:.4f}")