In [1]:
import os, sys
import torch
sys.path.append('../../')
from IPython.display import display
from models.core.diffusion.pipe import Pipe
from models.core.diffusion.custom_pipeline import Generator4Embeds
from models.core.diffusion.diffusion_prior import DiffusionPriorUNet
from utils.data_modules.diffusion_embedding import DiffusionEmbeddingDataModule

In [2]:
# Which recording to train on, and where its precomputed EEG embeddings live.
SUBJECT = 1
SESSION = 1
SEED = 42
EMBEDDINGS_DIR = "/workspace/eeg-image-decoding/data/all-joined-1/eeg/embeddings/650ms-250Hz"

# Seed torch's global RNG before any data/model setup so that randomness drawn
# from it later in the notebook (e.g. weight init, batch shuffling, a random
# val split — whichever actually apply) is reproducible across Restart & Run All.
torch.manual_seed(SEED)

data_module = DiffusionEmbeddingDataModule(
    # Filename is derived from SUBJECT/SESSION so it cannot drift from the
    # subject/session arguments passed alongside it.
    eeg_embeddings_file=f"{EMBEDDINGS_DIR}/subj{SUBJECT:02d}_session{SESSION}_eeg_embeddings.npy",
    subject=SUBJECT,
    session=SESSION,
    batch_size=1024,
    num_workers=4,
    # NOTE(review): a 10% validation split is requested, but no validation
    # loader is used in the cells shown — confirm whether one is intended.
    val_split=0.1,
    test='default'
)

data_module.setup()
train_loader = data_module.train_dataloader()
test_loader = data_module.test_dataloader()

In [3]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [4]:
# Build the diffusion prior network and wrap it in the training/sampling pipeline.
# NOTE(review): cond_dim=1024 presumably has to match the width of the EEG
# embeddings loaded above — TODO confirm against the embeddings file.
diffusion_prior = DiffusionPriorUNet(
    dropout=0.1,
    cond_dim=1024,
)
pipe = Pipe(diffusion_prior, device=device)

In [5]:
# Identifier of the prior variant being trained.
# Other variants: 'diffusion_prior_vice_pre_imagenet', 'diffusion_prior_vice_pre'.
model_name = 'diffusion_prior'

# Train the prior on the training split.
# NOTE(review): the captured log for this cell stops at epoch 28 of the
# requested 150 (it looks interrupted), so a checkpoint saved afterwards in
# that run came from a partially trained prior — re-run to completion.
pipe.train(
    train_loader,
    learning_rate=1e-3,
    num_epochs=150,
)

epoch: 0, loss: 1.2761273781458538




epoch: 1, loss: 1.273489236831665
epoch: 2, loss: 1.261702338854472
epoch: 3, loss: 1.246284047762553
epoch: 4, loss: 1.2350181341171265
epoch: 5, loss: 1.214526891708374
epoch: 6, loss: 1.2075467507044475
epoch: 7, loss: 1.1927531957626343
epoch: 8, loss: 1.1781030495961506
epoch: 9, loss: 1.1611844698588054
epoch: 10, loss: 1.1535508632659912
epoch: 11, loss: 1.139157732327779
epoch: 12, loss: 1.1270852486292522
epoch: 13, loss: 1.1121830940246582
epoch: 14, loss: 1.098787784576416
epoch: 15, loss: 1.0851916472117107
epoch: 16, loss: 1.0732665061950684
epoch: 17, loss: 1.0606656074523926
epoch: 18, loss: 1.0487064520517986
epoch: 19, loss: 1.037054419517517
epoch: 20, loss: 1.0283338228861492
epoch: 21, loss: 1.0167561372121174
epoch: 22, loss: 1.0056676467259724
epoch: 23, loss: 0.9955943822860718
epoch: 24, loss: 0.9825796484947205
epoch: 25, loss: 0.9726263483365377
epoch: 26, loss: 0.9590728878974915
epoch: 27, loss: 0.9506507515907288
epoch: 28, loss: 0.9379800955454508
epoch: 2

In [6]:
# Save the trained diffusion prior's weights (state_dict only, not the whole
# pipeline). The sub-directory follows `model_name`, so choosing a different
# variant name writes to its own folder instead of overwriting this one.
CHECKPOINT_DIR = "/workspace/eeg-image-decoding/code/models/check_points"
save_path = os.path.join(CHECKPOINT_DIR, model_name, "subj01_session1.pt")

os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(pipe.diffusion_prior.state_dict(), save_path)
print(f"Diffusion prior saved to {save_path}")

Diffusion prior saved to /workspace/eeg-image-decoding/code/models/check_points/diffusion_prior/subj01_session1.pt
