In [1]:
import os, sys
import torch
sys.path.append('../../')
from IPython.display import display
from models.core.diffusion.pipe import Pipe
from models.core.diffusion.custom_pipeline import Generator4Embeds
from models.core.diffusion.diffusion_prior import DiffusionPriorUNet
from utils.data_modules.diffusion_embedding import DiffusionEmbeddingDataModule

In [2]:
# Load the precomputed EEG embeddings for one subject/session and build the
# train/test loaders used to fit the diffusion prior.

# Fix torch's global RNG (CPU and CUDA) before anything random happens, so the
# train/val split, model initialisation and training are repeatable.
# NOTE(review): if DiffusionEmbeddingDataModule splits with numpy/`random`
# rather than torch, seed those too — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Subject/session are used both for the file name and the data module
# arguments; defining them once keeps the two in sync.
EMBEDDINGS_DIR = "/workspace/eeg-image-decoding/data/all-joined-1/eeg/embeddings/650ms-250Hz"
SUBJECT = 1
SESSION = 1

data_module = DiffusionEmbeddingDataModule(
    eeg_embeddings_file=os.path.join(
        EMBEDDINGS_DIR, f"subj{SUBJECT:02d}_session{SESSION}_eeg_embeddings.npy"
    ),
    subject=SUBJECT,
    session=SESSION,
    batch_size=1024,
    num_workers=4,
    val_split=0.1,  # fraction held out for validation
    test='default'
)

data_module.setup()
train_loader = data_module.train_dataloader()
test_loader = data_module.test_dataloader()

In [3]:
# Use the first visible GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [4]:
# Build the UNet diffusion prior and wrap it in the Pipe used for training below.
# NOTE(review): cond_dim=1024 presumably must match the width of the EEG
# embeddings loaded by the data module — confirm.
diffusion_prior = DiffusionPriorUNet(dropout=0.1, cond_dim=1024)
pipe = Pipe(diffusion_prior, device=device)

In [5]:
# Identifier for this prior variant; other variants used in this project are
# 'diffusion_prior_vice_pre_imagenet' and 'diffusion_prior_vice_pre'.
model_name = 'diffusion_prior'

# Training hyperparameters, named so they are easy to find and tune.
NUM_EPOCHS = 150
LEARNING_RATE = 1e-3

pipe.train(train_loader, num_epochs=NUM_EPOCHS, learning_rate=LEARNING_RATE)

epoch: 0, loss: 1.2819268306096394




epoch: 1, loss: 1.278226097424825
epoch: 2, loss: 1.2674747308095295
epoch: 3, loss: 1.2540096044540405
epoch: 4, loss: 1.2388110955556233
epoch: 5, loss: 1.2261294921239216
epoch: 6, loss: 1.2080537875493367
epoch: 7, loss: 1.195296287536621
epoch: 8, loss: 1.1820848782857258
epoch: 9, loss: 1.1660919189453125
epoch: 10, loss: 1.1551267703374226
epoch: 11, loss: 1.1433051824569702
epoch: 12, loss: 1.128611445426941
epoch: 13, loss: 1.1152995427449544
epoch: 14, loss: 1.1026936769485474
epoch: 15, loss: 1.0910864273707073
epoch: 16, loss: 1.0734502871831257
epoch: 17, loss: 1.0607703924179077
epoch: 18, loss: 1.0511807998021443
epoch: 19, loss: 1.0415691137313843
epoch: 20, loss: 1.0282485087712605
epoch: 21, loss: 1.0178369283676147
epoch: 22, loss: 1.0073200464248657
epoch: 23, loss: 0.9916874567667643
epoch: 24, loss: 0.9841941992441813
epoch: 25, loss: 0.9749188621838888
epoch: 26, loss: 0.9618188738822937
epoch: 27, loss: 0.949218730131785
epoch: 28, loss: 0.9401240348815918
epoch

In [6]:
# Checkpoint location: one directory per prior variant (model_name, set in the
# training cell), one file per subject/session.
# NOTE(review): "subj01_session1" must match the subject/session used by the
# data module cell — keep them in sync.
save_path = f"/workspace/eeg-image-decoding/code/models/check_points/{model_name}/subj01_session1.pt"

# Save only the diffusion prior network's weights (its state_dict).
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(pipe.diffusion_prior.state_dict(), save_path)
print(f"Diffusion prior saved to {save_path}")

Diffusion prior saved to ../models/check_points/diffusion_prior/subj01_session1.pt
