In [1]:
import os
import random
import sys

import numpy as np
import torch
from IPython.display import display

sys.path.append('../../')  # resolved relative to the working directory; presumably where `models`/`utils` live
from models.core.diffusion.pipe import Pipe
from models.core.diffusion.custom_pipeline import Generator4Embeds
from models.core.diffusion.diffusion_prior import DiffusionPriorUNet
from utils.data_modules.diffusion_embedding import DiffusionEmbeddingDataModule

In [2]:
# Reproducibility: seed every RNG *before* `setup()`, which presumably draws the
# random train/val split (val_split=0.1) — TODO confirm which RNG the data module uses.
# torch.manual_seed seeds the CPU and CUDA generators that PyTorch uses by default
# for weight initialisation and DataLoader sampling.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Derive the embeddings filename from SUBJECT/SESSION so the file that is loaded
# cannot drift out of sync with the `subject`/`session` arguments passed below.
SUBJECT = 1
SESSION = 1
EEG_EMBEDDINGS_DIR = "/workspace/eeg-image-decoding/data/all-joined-1/eeg/lo-res-embeddings/650ms-250Hz"
eeg_embeddings_file = os.path.join(
    EEG_EMBEDDINGS_DIR, f"subj{SUBJECT:02d}_session{SESSION}_eeg_embeddings.npy"
)

data_module = DiffusionEmbeddingDataModule(
    eeg_embeddings_file=eeg_embeddings_file,
    subject=SUBJECT,
    session=SESSION,
    batch_size=1024,
    num_workers=4,
    val_split=0.1,  # NOTE(review): no validation loader is requested in this notebook
    test='default',
)

data_module.setup()
train_loader = data_module.train_dataloader()
test_loader = data_module.test_dataloader()

In [3]:
# Use the first CUDA GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [4]:
# Build the UNet diffusion prior and wrap it in `Pipe` on the selected device.
# COND_DIM presumably has to match the width of the EEG embeddings loaded above — TODO confirm.
COND_DIM = 1024
PRIOR_DROPOUT = 0.1
diffusion_prior = DiffusionPriorUNet(cond_dim=COND_DIM, dropout=PRIOR_DROPOUT)
pipe = Pipe(diffusion_prior, device=device)

In [5]:
# Prior variant name (other options: 'diffusion_prior_vice_pre_imagenet', 'diffusion_prior_vice_pre').
model_name = 'diffusion_prior'

# Training hyperparameters for the diffusion prior.
# NOTE(review): the captured log for this cell stops mid-line after epoch 28 of 150 —
# confirm the run actually completed before the next cell saves the weights.
NUM_EPOCHS = 150
LEARNING_RATE = 1e-3
pipe.train(train_loader, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE)

epoch: 0, loss: 1.2761089007059734




epoch: 1, loss: 1.2736213604609172
epoch: 2, loss: 1.26202388604482
epoch: 3, loss: 1.2499698003133137
epoch: 4, loss: 1.2361989418665569
epoch: 5, loss: 1.219718098640442
epoch: 6, loss: 1.2001156409581502
epoch: 7, loss: 1.1918713251749675
epoch: 8, loss: 1.1756192048390706
epoch: 9, loss: 1.1681668361028035
epoch: 10, loss: 1.1542164087295532
epoch: 11, loss: 1.1410696109135945
epoch: 12, loss: 1.1178914705912273
epoch: 13, loss: 1.1140015920003254
epoch: 14, loss: 1.101389487584432
epoch: 15, loss: 1.0879757006963093
epoch: 16, loss: 1.0739066203435261
epoch: 17, loss: 1.0571256875991821
epoch: 18, loss: 1.0500188668568928
epoch: 19, loss: 1.0394944349924724
epoch: 20, loss: 1.0299759308497112
epoch: 21, loss: 1.0183583498001099
epoch: 22, loss: 1.0064880053202312
epoch: 23, loss: 0.9947097301483154
epoch: 24, loss: 0.984795093536377
epoch: 25, loss: 0.9728595415751139
epoch: 26, loss: 0.9626705249150594
epoch: 27, loss: 0.9518924752871195
epoch: 28, loss: 0.9413004318873087
epoch:

In [6]:
# Persist only the trained prior's weights (its state_dict). Reloading requires
# rebuilding DiffusionPriorUNet with the same constructor arguments and then
# calling load_state_dict() on it.
# The subdirectory follows `model_name`, so different prior variants do not
# overwrite each other's checkpoints.
CHECKPOINT_ROOT = "/workspace/eeg-image-decoding/code/models/check_points"
save_path = os.path.join(CHECKPOINT_ROOT, model_name, "subj01_session1.pt")

os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(pipe.diffusion_prior.state_dict(), save_path)
print(f"Diffusion prior saved to {save_path}")

Diffusion prior saved to /workspace/eeg-image-decoding/code/models/check_points/diffusion_prior/subj01_session1.pt
