In [None]:
import os
import pickle
import mne
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import matplotlib.pyplot as plt
import wandb
from simpleconv_diffusion import SimpleConv
import torch.optim as optim
from classes import MEGDataset
from utils import soft_clip_loss, hard_clip_loss, calculate_params, train_modified, val_modified, test_model
from mindeye import VersatileDiffusionPriorNetwork, BrainDiffusionPrior
from diffusers import VersatileDiffusionDualGuidedPipeline, UniPCMultistepScheduler
from diffusers.models import DualTransformer2DModel
import time

from diffusion_utils import train_diffusion, val_diffusion

In [None]:
# Enable IPython's autoreload extension so edits to the local helper modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before each cell execution (see the IPython autoreload docs).
%autoreload 2

In [None]:
# Reload the autoreload extension. NOTE(review): looks redundant on a fresh
# top-to-bottom run (the extension is loaded in the cell above) — presumably
# kept for re-arming it during interactive sessions.
%reload_ext autoreload

In [None]:
class Args:
    """Static run configuration, standing in for command-line arguments.

    All values are plain class attributes. Later cells attach additional
    attributes (e.g. ``clip_size``, ``hidden``) to the ``args`` instance.
    """

    # --- bookkeeping / logging ---
    seed = 42
    output_dir = "./output"
    wandb_project = "MEG_Diffusion_Testing"
    wandb_run_name = None  # falsy -> a descriptive run name is generated instead
    save_interval = 10
    print_interval = 50

    # --- data selection ---
    dataset_type = "small"
    preprocessing_type = "raj"
    embeddings_type = "vit"

    # --- optimisation ---
    epochs = 100
    batch_size = 128
    lr = 3e-4
    early_stopping = 15
    loss_func = "soft_clip_loss"
    warmup_lr = 1e-5
    warmup_interval = 1000
    # NOTE(review): two scheduler-related settings with different values live
    # side by side; confirm which one the training code actually reads.
    lr_schedule = "linear"
    scheduler_type = "constant"

    # --- brain-module architecture ---
    dilation_type = "expo"
    dropout = 0.2


args = Args()

In [None]:
# Pick the compute device: CUDA when available, CPU otherwise.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using DEVICE:", DEVICE)

# Apply the configured seed. Previously args.seed only appeared in the run
# name and was never handed to any RNG, so weight initialisation and data
# shuffling differed between runs that claimed the same seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Timestamp keeps auto-generated run names unique.
now = time.strftime("%Y-%m-%d_%H-%M-%S")

# Use the explicit run name when one is configured; otherwise encode the key
# hyperparameters into a descriptive name.
run_name = args.wandb_run_name or (
    f"Dfsn_Pre{args.preprocessing_type}_Drpt{args.dropout}_Diln{args.dilation_type}_ClipEmb{args.embeddings_type}_Loss[{args.loss_func}]_B{args.batch_size}_LR{args.lr}_S{args.seed}_E{args.epochs}_{now}"
)
# NOTE(review): this only binds a notebook-level variable; args.wandb_run_name
# itself stays unchanged. Confirm whether train_diffusion expects the
# generated name on args.
wandb_run_name = run_name

In [None]:
# Load the pre-computed epochs objects for the train and test splits
# (presumably mne.Epochs, given how they are used further down).
# NOTE: pickle.load can execute arbitrary code — only point these paths at
# trusted files.
if args.dataset_type == "small" and args.preprocessing_type == "raj":
    with open('./valid_epochs/valid_epochs_adjusted_train_redid.pickle', 'rb') as f:
        valid_epochs = pickle.load(f)
    with open('./valid_epochs/valid_epochs_small_test_redid.pickle', 'rb') as f:
        valid_epochs_test = pickle.load(f)
else:
    # Fail here with a clear message instead of a NameError several cells
    # later, where valid_epochs is used unconditionally.
    raise ValueError(
        "No epochs files configured for "
        f"dataset_type={args.dataset_type!r}, preprocessing_type={args.preprocessing_type!r}"
    )

In [None]:
# Build the per-epoch target embedding arrays by looking up each epoch's
# image path in the pre-computed embeddings.
if args.embeddings_type == "vit" and args.dataset_type == "small":
    # The .npy file holds a single pickled mapping (hence allow_pickle and
    # .item()), indexed by image path.
    embeddings = np.load('./embeddings/image_embeddings_vit_hidden.npy', allow_pickle=True).item()

    # Despite the "_val" suffix, this array is aligned with the *training*
    # epochs in valid_epochs.
    embeddings_val = np.array([embeddings[filename] for filename in valid_epochs.metadata['image_path']])

    # Test images are looked up in the same mapping. A separate file,
    # './embeddings/image_embeddings_vit_small_test_redid.npy', was used for
    # this previously and is currently disabled.
    embeddings_val_test = np.array([embeddings[filename] for filename in valid_epochs_test.metadata['image_path']])
else:
    # Fail here with a clear message instead of a NameError in a later cell,
    # where the embedding arrays are used unconditionally.
    raise ValueError(
        "No embeddings configured for "
        f"embeddings_type={args.embeddings_type!r}, dataset_type={args.dataset_type!r}"
    )

In [None]:
# Drop the full lookup mapping; only the per-epoch arrays derived from it are
# used from here on.
# Not idempotent: re-running this cell without re-running the loading cell
# raises NameError.
del embeddings

In [None]:
# Optional sanity check (disabled). It needs the `embeddings` mapping, so it
# can only run before that mapping is deleted:
# filenames_set1 = set(valid_epochs.metadata['image_path'])
# filenames_set2 = set(embeddings.keys())
# if filenames_set1 == filenames_set2:
#     print("All filenames match!")
# else:
#     print("Filenames do not exactly match.")

# MEG sensor layouts, derived separately from each split's measurement info.
layout = mne.channels.find_layout(valid_epochs.info, ch_type="meg")
layout_test = mne.channels.find_layout(valid_epochs_test.info, ch_type="meg")

# Report the array shape of each split (get_data() materialises the data).
train_shape = valid_epochs.get_data().shape
print("Train Valid Epochs Shape:", train_shape)
test_shape = valid_epochs_test.get_data().shape
print("Test Valid Epochs Shape:", test_shape)

In [None]:
print("Working on Data Loaders...")
dataset = MEGDataset(valid_epochs, embeddings_val, layout)

# 80/20 train/validation split. The generator is seeded from args.seed so the
# split is identical across runs; without it random_split draws from the
# global RNG and train/val membership changes every time.
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
split_generator = torch.Generator().manual_seed(args.seed)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator)

# drop_last=True discards the final partial batch of each loader.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=1)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True, num_workers=1)

# Test Dataset
test_dataset = MEGDataset(valid_epochs_test, embeddings_val_test, layout_test)
# NOTE(review): with batch_size=800 and drop_last=True, any test samples
# beyond a multiple of 800 are discarded (all of them if there are fewer than
# 800) — confirm this is intended.
test_loader = DataLoader(test_dataset, batch_size=800, shuffle=False, drop_last=True, num_workers=1)

# Reduced subsets for quick experiments.
# 1024 evenly spaced training indices -> 8 batches of 128.
small_train_dataset = torch.utils.data.Subset(
    train_dataset, torch.linspace(0, len(train_dataset) - 1, steps=1024).long())

small_train_loader = DataLoader(
    small_train_dataset, batch_size=128, pin_memory=True, shuffle=True)

# 512 evenly spaced validation indices -> 4 batches of 128.
small_val_dataset = torch.utils.data.Subset(
    val_dataset, torch.linspace(0, len(val_dataset) - 1, steps=512).long())

small_val_loader = DataLoader(
    small_val_dataset, batch_size=128, pin_memory=True, shuffle=False)
print("Data Loaders Ready!")

In [None]:
# Settings for the CLIP target space and the diffusion prior. The notebook-level
# names are kept because later cells read them directly.
clip_size = 768  # "ViT-L/14": 768
norm_embs = True
hidden = True
prior = True
vd_cache_dir = "./vd_cache/"  # location of the cached Versatile Diffusion model; downloaded here if missing
n_samples_save = 1
lr_scheduler_type = 'linear'

# Mirror the same values onto args so they travel with the run configuration.
recon_settings = {
    "clip_size": clip_size,
    "norm_embs": norm_embs,
    "hidden": hidden,
    "prior": prior,
    "vd_cache_dir": vd_cache_dir,
    "n_samples_save": n_samples_save,
    "lr_scheduler_type": lr_scheduler_type,
    "v2c": True,
}
for setting_name, setting_value in recon_settings.items():
    setattr(args, setting_name, setting_value)

# prior_mult / nce_mult are presumably loss-term multipliers read by the
# training code — verify against train_diffusion.
# TODO: the rationale for the .03 value is undocumented (the original author
# flagged it with "WHY?").
args.prior_mult = 30.0 if args.hidden else .03
if args.hidden:
    # Only defined in hidden mode; the non-hidden path leaves nce_mult unset.
    args.nce_mult = 0
    



In [None]:
if args.hidden:
    print("Using hidden layer CLIP space (Versatile Diffusion)")
    if not args.norm_embs:
        print("WARNING: YOU WANT NORMED EMBEDDINGS FOR VERSATILE DIFFUSION!")

    # Flattened size of the hidden-state target: 257 vectors of clip_size
    # each (257 is also the num_tokens given to the prior network).
    # A temporary override to plain clip_size existed here and is disabled.
    args.out_dim = 257 * args.clip_size

    print("Output Dimension:", args.out_dim)
        
        

In [None]:
print("Instantiating Voxel2Clip Model.. BRAIN NETWORK")

# Keyword arguments for the brain module, collected in one place so they can
# be inspected before the model is built.
voxel2clip_config = dict(
    # channel sizes
    in_channels=272,
    hidden_channels=320,
    out_channels=2048,
    # spatial merger
    merger=True,
    merger_pos_dim=2048,
    merger_dropout=0.2,
    # per-subject layers
    n_subjects=4,
    subject_layers=True,
    subject_layers_dim="input",
    # convolutional stack
    gelu=True,
    dilation_type=args.dilation_type,
    extra_dropout=args.dropout,
    # Extra MLP projector so the CLIP and diffusion-prior objectives are
    # learned through separate heads.
    use_mse_projector=True,
    clip_size=args.clip_size,
    device=DEVICE,
    # Disabled options: projector_dim=2048, out_dim=args.out_dim
)
voxel2clip = SimpleConv(**voxel2clip_config).to(DEVICE)

In [None]:
# calculate_params is unpacked as (total parameter count, trainable parameter
# count) for the brain module; both are reported below.
v2c_params, v2_trainable_params = calculate_params(voxel2clip)
print("Voxel2Clip Model Instantiated! Total Parameters:", v2c_params)
print("Trainable Parameters:", v2_trainable_params)

In [None]:
# Hyperparameters of the transformer used as the diffusion prior network.
out_dim = clip_size
depth = 6
dim_head = 64
heads = clip_size // dim_head  # 768 // 64 = 12, so heads * dim_head == clip_size

if hidden:
    # Both values NEED TO BE PASSED FOR RECONSTRUCTION.
    args.guidance_scale = 3.5
    args.timesteps = 100

    prior_network = VersatileDiffusionPriorNetwork(
        dim=out_dim,
        depth=depth,
        heads=heads,
        dim_head=dim_head,
        num_tokens=257,
        causal=False,
        learned_query_mode="pos_emb",
    )
    prior_network = prior_network.to(DEVICE)
    print("prior_network loaded")

    # Custom diffusion-prior version that can fix seeds. It wraps the brain
    # module (voxel2clip), which is trained end to end with the prior.
    diffusion_prior = BrainDiffusionPrior(
        net=prior_network,
        voxel2clip=voxel2clip,
        image_embed_dim=out_dim,
        image_embed_scale=None,
        timesteps=args.timesteps,
        cond_drop_prob=0.2,
        condition_on_text_encodings=False,
    )
    diffusion_prior = diffusion_prior.to(DEVICE)

# NOTE: diffusion_prior only exists when `hidden` is True.
if not prior:
    # Freeze everything, then re-enable gradients for the brain module only.
    diffusion_prior = diffusion_prior.requires_grad_(False)
    diffusion_prior.voxel2clip.requires_grad_(True)

In [None]:
# calculate_params returns a (total, trainable) pair — the brain-module report
# in this notebook unpacks it the same way. Printing the raw tuple under
# "Total Parameters" was misleading, so both counts are reported separately.
prior_params, prior_trainable_params = calculate_params(diffusion_prior)
print("Diffusion Prior  Instantiated! Total Parameters:", prior_params)
print("Trainable Parameters:", prior_trainable_params)

In [None]:
# Default so the training call always has a value to pass, even when
# reconstructions are disabled and the pipeline is never built.
# NOTE(review): confirm train_diffusion accepts vd_pipe=None in that case.
vd_pipe = None

if n_samples_save > 0 and args.hidden:
    print('Creating versatile diffusion reconstruction pipeline...')
    try:
        # First try to load a pipeline stored directly under vd_cache_dir.
        vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(vd_cache_dir).to('cpu')
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit. Any ordinary load failure falls back to the hub,
        # caching into vd_cache_dir.
        print("Downloading Versatile Diffusion to", vd_cache_dir)
        vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(
            "shi-labs/versatile-diffusion",
            cache_dir=vd_cache_dir).to('cpu')

    # The pipeline is used for inference only: eval mode, no gradients.
    vd_pipe.image_unet.eval()
    vd_pipe.vae.eval()
    vd_pipe.image_unet.requires_grad_(False)
    vd_pipe.vae.requires_grad_(False)

    # Swap in the UniPC multistep scheduler. The config path is built from
    # vd_cache_dir (it was hard-coded to ./vd_cache before), so a different
    # cache location stays consistent with the download above.
    # NOTE(review): the snapshot hash is pinned; this file only exists in the
    # hub-cache layout produced by the download branch.
    path_scheduler = os.path.join(
        vd_cache_dir,
        "models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7/scheduler/scheduler_config.json",
    )
    vd_pipe.scheduler = UniPCMultistepScheduler.from_pretrained(path_scheduler)
    args.num_inference_steps = 20  # NEED TO BE PASSED FOR RECONSTRUCTION

    # Set weighting of Dual-Guidance
    text_image_ratio = .0  # .5 means equally weight text and image, 0 means use only image
    for _module_name, module in vd_pipe.image_unet.named_modules():
        if not isinstance(module, DualTransformer2DModel):
            continue
        module.mix_ratio = text_image_ratio
        # `cond_type` instead of `type`: at notebook top level the old loop
        # variable rebound the builtin `type` for the rest of the session.
        for i, cond_type in enumerate(("text", "image")):
            if cond_type == "text":
                module.condition_lengths[i] = 77
                module.transformer_index_for_condition[i] = 1  # use the second (text) transformer
            else:
                module.condition_lengths[i] = 257
                module.transformer_index_for_condition[i] = 0  # use the first (image) transformer

    # Disabled alternative: exposing the pipeline parts on args instead of
    # passing vd_pipe itself (all NEED TO BE PASSED FOR RECONSTRUCTION):
    # args.unet = vd_pipe.image_unet
    # args.vae = vd_pipe.vae
    # args.noise_scheduler = vd_pipe.scheduler
    # args.vd_pipe = vd_pipe

In [None]:
# Report the device the diffusion prior reports for itself before training starts.
print("Model Instantiated!, Model device:", diffusion_prior.device)

In [None]:
# Run training on the full train/val loaders. The call returns the trained
# model, the train and validation loss records, and the path of the best
# checkpoint (as named by the unpacking targets).
# vd_pipe is only defined when the reconstruction-pipeline cell actually built it.
trained_model, train_loss, val_loss, best_model_path = train_diffusion(
        diffusion_prior, train_loader, val_loader, args, DEVICE, vd_pipe
    )

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator after training.
torch.cuda.empty_cache()