In [1]:
import sys

sys.path.append("../")

import torch

In [2]:
from lib.brain_module import BrainModule
from lib.datasets import ThingsMEGDatasetWithImages
from lib.function import mse_loss
import os

import hydra
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import wandb
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from termcolor import cprint
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from tqdm import tqdm

from lib.datasets import ThingsMEGDataset
from lib.models import BasicConvClassifier
from lib.utils import set_seed
import matplotlib.pyplot as plt

In [3]:
data_dir = "../../data"
batch_size = 32
num_workers = 1
loader_args = {"batch_size": batch_size, "num_workers": num_workers}

pretrained_model2latent_dim = {
    "dinov2_vits14": 384,
}

train_set = ThingsMEGDatasetWithImages("train", data_dir)
train_loader = DataLoader(
    train_set, shuffle=True, **loader_args
)

valid_set = ThingsMEGDatasetWithImages("val", data_dir)
valid_loader = DataLoader(
    valid_set, shuffle=False, **loader_args
)

In [4]:
def train_brain_module(train_loader, valid_loader, n_epochs, loss_weight, device, image_module, brain_module, optimizer, scheduler)->None:
    image_module.to(device)
    brain_module.to(device)
    train_loss_list = []
    valid_loss_list = []

    for epoch in range(n_epochs):
        clip_losses_train = []  # 訓練誤差を格納しておくリスト
        clip_losses_valid = []  # 検証データの誤差を格納しておくリスト
        mse_losses_train = []  # 訓練誤差を格納しておくリスト
        mse_losses_valid = []  # 検証データの誤差を格納しておくリスト

        brain_module.train()  # 訓練モードにする
        for image_X, brain_X, y, subject_idx in tqdm(train_loader):
            optimizer.zero_grad()  # 勾配の初期化

            z = image_module(image_X.to(device))
            pred_z = brain_module(brain_X.to(device), subject_idx)

            # MSE loss
            mse_loss = mse_loss(z, pred_z)

            # clip loss
            clip_loss = 0

            # loss
            loss = loss_weight * clip_loss + (1.0 - loss_weight) * mse_loss

            loss.backward()  # 誤差の逆伝播
            optimizer.step()  # パラメータの更新

            clip_losses_train.append(clip_loss.tolist())
            mse_losses_train.append(mse_loss.tolist())

        brain_module.eval()  # 評価モードにする

        for image_X, brain_X, y, subject_idx in valid_loader:
            z = image_module(image_X.to(device))
            pred_z = brain_module(brain_X.to(device), subject_idx)

            # MSE loss
            mse_loss = mse_loss(z, pred_z)

            # clip loss
            clip_loss = 0

            # loss
            loss = loss_weight * clip_loss + (1.0 - loss_weight) * mse_loss

            clip_losses_valid.append(clip_loss.tolist())
            mse_losses_valid.append(mse_loss.tolist())

        losses_train = [loss_weight(clip_loss, mse_loss) for clip_loss, mse_loss in zip(clip_losses_train, mse_losses_train)]
        losses_valid = [loss_weight(clip_loss, mse_loss) for clip_loss, mse_loss in zip(clip_losses_valid, mse_losses_valid)]

        print(
            "EPOCH: {}, Train [Loss: {:.3f}], Valid [Loss: {:.3f}]".format(
                epoch,
                np.mean(losses_train),
                np.mean(losses_valid),
            )
        )
        train_loss_list.append(np.mean(losses_train))
        valid_loss_list.append(np.mean(losses_valid))

        if scheduler is not None:
            scheduler.step()

    plt.plot(train_loss_list, label="train loss")
    plt.plot(valid_loss_list, label="valid loss")
    # 凡例を表示
    plt.legend()
    # plt.savefig("drive/MyDrive/Colab Notebooks/DLBasics2023_colab/Lecture05/loss.png")
    plt.show()

    # plt.plot(train_acc_list, label="train acc")
    # plt.plot(valid_acc_list, label="valid acc")
    # plt.ylim(0.8, 1)
    # # 凡例を表示
    # plt.legend()
    # plt.savefig("drive/MyDrive/Colab Notebooks/DLBasics2023_colab/Lecture05/acc.png")
    # plt.show()


In [5]:
n_epochs = 200
device = torch.device("mps")

loss_weight = 0 # clip lossの重み
image_module = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
brain_module = BrainModule(out_dim=pretrained_model2latent_dim["dinov2_vits14"])
optimizer = optim.Adam(
    brain_module.parameters(), lr=3e-4
)
#scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 160], 0.2)

Using cache found in /Users/hayato/.cache/torch/hub/facebookresearch_dinov2_main


ds directory : /Users/hayato/mne_data/MNE-spm-face/MEG/spm/SPM_CTF_MEG_example_faces1_3D.ds
    res4 data read.
    hc data read.
    Separate EEG position data file not present.
    Quaternion matching (desired vs. transformed):
      -0.90   72.01    0.00 mm <->   -0.90   72.01   -0.00 mm (orig :  -43.09   61.46 -252.17 mm) diff =    0.000 mm
       0.90  -72.01    0.00 mm <->    0.90  -72.01   -0.00 mm (orig :   53.49  -45.24 -258.02 mm) diff =    0.000 mm
      98.30    0.00    0.00 mm <->   98.30   -0.00    0.00 mm (orig :   78.60   72.16 -241.87 mm) diff =    0.000 mm
    Coordinate transformations established.
    Polhemus data for 3 HPI coils added
    Device coordinate locations for 3 HPI coils added
    Measurement info composed.
Finding samples for /Users/hayato/mne_data/MNE-spm-face/MEG/spm/SPM_CTF_MEG_example_faces1_3D.ds/SPM_CTF_MEG_example_faces1_3D.meg4: 
    System clock channel is available, checking which samples are valid.
    1 x 324474 = 324474 samples from 340 ch

In [6]:
train_brain_module(train_loader=train_loader, valid_loader=valid_loader, n_epochs=n_epochs, loss_weight=loss_weight, device=device, image_module=image_module, brain_module=brain_module, optimizer=optimizer, scheduler=None)

  0%|          | 0/2054 [00:06<?, ?it/s]


NotImplementedError: The operator 'aten::upsample_bicubic2d.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.