# 使用扩散模型对大脑解码

In [1]:
import h5py
from PIL import Image
import scipy.io
import argparse, os
import pandas as pd
import PIL
import torch
import numpy as np
from omegaconf import OmegaConf
from tqdm import trange
from einops import rearrange
from torch import autocast
from contextlib import nullcontext
from pytorch_lightning import seed_everything
import sys
from nsd_access.nsda import NSDAccess
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
# from ldm.data.util import AddMiDaS

In [2]:
def load_model_from_config(config, ckpt, gpu, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint file, read with ``torch.load`` onto the CPU.
        gpu: Index of the CUDA device the model is moved to.
        verbose: If true, print the missing and unexpected state-dict keys.

    Returns:
        The model, moved to ``cuda:{gpu}`` and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are returned (and optionally printed)
    # instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        for header, keys in (("missing keys:", missing), ("unexpected keys:", unexpected)):
            if len(keys) > 0:
                print(header)
                print(keys)

    model.cuda(f"cuda:{gpu}")
    model.eval()
    return model

In [3]:
def load_img_from_arr(img_arr):
    """Convert an image array into a normalized ``1 x 3 x 512 x 512`` tensor.

    The array is turned into an RGB PIL image, resized to 512x512 with
    Lanczos resampling, scaled to ``[0, 1]``, rearranged from HWC to a
    batched CHW layout and finally mapped to ``[-1, 1]``.

    Args:
        img_arr: Array accepted by ``Image.fromarray``.

    Returns:
        A float32 ``torch.Tensor`` with values in ``[-1, 1]``.
    """
    target_size = (512, 512)  # (width, height)
    rgb = Image.fromarray(img_arr).convert("RGB")
    rgb = rgb.resize(target_size, resample=PIL.Image.LANCZOS)
    pixels = np.array(rgb).astype(np.float32) / 255.0
    # Add a batch axis, then move channels first: HWC -> 1 x C x H x W.
    batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    return 2. * batch - 1.

In [4]:
# Run configuration: everything a reader might want to tune lives here.
seed = 42
seed_everything(seed)  # set the global seed before any sampling happens

subject = "subj01"
method = "bright"
gpu = 0
imgidx = 0
# gandir = f'../data/decoded/gan_recon_img/all_layers/{subject}/streams/'
# captdir = f'../data/decoded/{subject}/captions/'

Global seed set to 42


In [5]:
# Load the NSD experiment design.
nsd_expdesign = scipy.io.loadmat('../data/nsd/nsddata/experiments/nsd/nsd_expdesign.mat')

# Note that most of the indices in this file are 1-based,
# which is why 1 is subtracted here.
sharedix = nsd_expdesign['sharedix'] - 1

# Stimulus images: the HDF5 file is kept open because `sdataset` is read
# from in a later cell.
nsda = NSDAccess('../data/nsd/')
sf = h5py.File(nsda.stimuli_file, 'r')
sdataset = sf.get('imgBrick')

stims_ave = np.load(f'../data/stim/{subject}/{subject}_stims_ave.npy')

In [6]:
# Label every entry of stims_ave: 0 when its stimulus index is one of the
# shared images (selected as the test images further down), 1 otherwise.
# np.isin performs the membership test in a single vectorized call instead of
# a Python loop that rescans sharedix for each element; the result keeps the
# shape and dtype that np.zeros_like(stims_ave) would have produced.
# Assumes stims_ave is a 1-D array of stimulus indices — TODO confirm.
tr_idx = np.where(np.isin(stims_ave, sharedix), 0, 1).astype(stims_ave.dtype)

In [7]:
# Load the Stable Diffusion model onto the selected GPU.
config = '../stable-diffusion_v1/configs/stable-diffusion/v1-inference.yaml'
ckpt = '../stable-diffusion_v1/models/ldm/stable-diffusion-v1/sd-v1-4.ckpt'

# `config` is rebound from the YAML path to the parsed configuration object.
config = OmegaConf.load(config)
torch.cuda.set_device(gpu)
model = load_model_from_config(config, ckpt, gpu)

Loading model from ../stable-diffusion_v1/models/ldm/stable-diffusion-v1/sd-v1-4.ckpt


  pl_sd = torch.load(ckpt, map_location="cpu")


Global Step: 470000
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


In [8]:
# Sampling hyper-parameters.
n_samples = 1
batch_size = n_samples
n_iter = 5          # number of sampling rounds
ddim_steps = 50
ddim_eta = 0.0
strength = 0.8      # must lie in [0, 1]; scales the number of encoding steps
scale = 5.0         # passed as the unconditional guidance scale

# Context manager wrapped around the model calls.
precision = 'autocast'
if precision == "autocast":
    precision_scope = autocast
else:
    precision_scope = nullcontext

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [9]:
# Output location for the decoding results.
outdir = f'../data/output/image-{method}/{subject}/'
sample_path = os.path.join(outdir, "samples")

os.makedirs(outdir, exist_ok=True)
os.makedirs(sample_path, exist_ok=True)

In [10]:
# Put the model on the chosen device and prepare the DDIM sampler.
precision = 'autocast'
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu}")
else:
    device = torch.device("cpu")
model = model.to(device)

sampler = DDIMSampler(model)
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)

# Number of steps used when encoding the init latent, derived from `strength`.
assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(strength * ddim_steps)
print(f"target t_enc is {t_enc} steps")

target t_enc is 40 steps


In [15]:
# Locations of the predictions (C, InitLatent, Depth(cc)).
captdir = f'../../Brain-Decoded/{subject}'
dptdir = f'../../data/decoded/{subject}/dpt_fromemb/'

# C: tab-separated captions file without a header row.
captions = pd.read_csv(
    f'{captdir}/captions_brain.csv',
    sep='\t',
    header=None,
)

In [139]:
# Save directories.
# Built from `method` (as in the earlier output-path cell) instead of a
# hard-coded 'bright', so the output folder always follows the configured
# method; for method == "bright" the resulting path is unchanged.
outdir = f'../data/output/image-{method}/{subject}/'
os.makedirs(outdir, exist_ok=True)
sample_path = os.path.join(outdir, "samples")
os.makedirs(sample_path, exist_ok=True)

In [166]:
# Index of the image to decode; overrides the `imgidx = 0` set with the
# other run parameters.
imgidx = 5

In [167]:
# Load z (Image)
# Map the position among the test images (tr_idx == 0) to its row in
# stims_ave, then to the stimulus index used to read the image brick.
imgidx_te = np.where(tr_idx == 0)[0][imgidx]  # Extract test image index
idx73k = stims_ave[imgidx_te]
# Save the original stimulus next to the reconstructions for reference.
Image.fromarray(np.squeeze(sdataset[idx73k, :, :, :]).astype(np.uint8)).save(
    os.path.join(sample_path, f"{imgidx:05}_org.png"))

# 'depth' and 'bright' used to match no branch here, so `im` was never
# assigned and the cell only ran when a stale `im` was left in the kernel.
# NOTE(review): routing them through the decoded init latent (like 'text')
# is an assumption — confirm this is the intended init image.
if method in ['init', 'text', 'depth', 'bright']:
    roi_latent = 'early'
    init_latent = np.load(f'../data/decoded/{subject}/{subject}_{roi_latent}_brain_embs_init_latent.npy')
    imgarr = torch.Tensor(init_latent[imgidx, :].reshape(4, 40, 40)).unsqueeze(0).to('cuda')

    # Generate image from Z
    precision_scope = autocast if precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                x_samples = model.decode_first_stage(imgarr)
                # Map the decoder output from [-1, 1] to [0, 1].
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                # imgarr has batch size 1, so this leaves the single decoded
                # image in x_sample.
                for x_sample in x_samples:
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
    im = Image.fromarray(x_sample.astype(np.uint8)).resize((512, 512))
    im = np.array(im)

elif method == 'gan':
    # NOTE(review): `gandir` is only present as a commented-out line in the
    # parameter cell; it must be defined before this branch can run.
    ganpath = f'{gandir}/recon_image_normalized-VGG19-fc8-{subject}-streams-{imgidx:06}.tiff'
    im = Image.open(ganpath).resize((512, 512))
    im = np.array(im)

else:
    # Fail here with a clear message rather than with a NameError on `im`
    # (or a silently reused stale image) below.
    raise ValueError(f"No init image source defined for method {method!r}")

init_image = load_img_from_arr(im).to('cuda')
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

In [168]:
# Load c (Semantics)
# Only 'init', 'text', 'depth' and 'bright' assign `c`; for any other method
# the cell leaves `c` untouched.
if method == 'init':
    # Conditioning embedding loaded from file and reshaped to (77, 768).
    roi_c = 'ventral'
    c_embs = np.load(f'../data/decoded/{subject}/{subject}_{roi_c}_brain_embs_c.npy')
    carr = c_embs[imgidx, :].reshape(77, 768)
    c = torch.Tensor(carr).unsqueeze(0).to('cuda')
elif method in ('text', 'depth', 'bright'):
    # Conditioning computed by the model from the caption of this image.
    captions = pd.read_csv(
        f'{captdir}/captions_brain.csv',
        sep='\t',
        header=None,
    )
    c = model.get_learned_conditioning(captions.iloc[imgidx][0]).to('cuda')

In [169]:
# Load cc (depth/bright) for the selected image and move it to the GPU.
cc = torch.Tensor(
    np.load(f'{dptdir}/{imgidx:06}.npy')
).to('cuda')

In [170]:
# Generate image from Z (image) + C (semantics) + cc (bright)
base_count = 0
with torch.no_grad():
    with precision_scope("cuda"):
        with model.ema_scope():
            # Loop-invariant inputs: computed once instead of on every
            # sampling iteration.
            # NOTE(review): str(cc) passes the printed repr of the tensor as
            # the prompt for the unconditional conditioning — confirm this is
            # really the intended way to inject cc.
            uc = model.get_learned_conditioning(str(cc))
            t_enc_batch = torch.tensor([t_enc] * batch_size).to(device)

            for n in trange(n_iter, desc="Sampling"):
                # encode (scaled latent)
                z_enc = sampler.stochastic_encode(init_latent, t_enc_batch)
                # decode it
                samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                         unconditional_conditioning=uc,)

                x_samples = model.decode_first_stage(samples)
                # Map the decoder output from [-1, 1] to [0, 1].
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                # Save every sample of the batch. The save used to sit outside
                # this loop, which wrote only the last sample per iteration.
                for x_sample in x_samples:
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    Image.fromarray(x_sample.astype(np.uint8)).save(
                        os.path.join(sample_path, f"{imgidx:05}_{base_count:03}.png"))
                    base_count += 1

Sampling:   0%|          | 0/5 [00:00<?, ?it/s]

Running DDIM Sampling with 40 timesteps



Decoding image:   0%|          | 0/40 [00:00<?, ?it/s][A
Decoding image:   5%|▌         | 2/40 [00:00<00:02, 15.16it/s][A
Decoding image:  10%|█         | 4/40 [00:00<00:02, 16.11it/s][A
Decoding image:  15%|█▌        | 6/40 [00:00<00:02, 16.45it/s][A
Decoding image:  20%|██        | 8/40 [00:00<00:01, 16.52it/s][A
Decoding image:  25%|██▌       | 10/40 [00:00<00:01, 16.62it/s][A
Decoding image:  30%|███       | 12/40 [00:00<00:01, 16.69it/s][A
Decoding image:  35%|███▌      | 14/40 [00:00<00:01, 16.67it/s][A
Decoding image:  40%|████      | 16/40 [00:00<00:01, 16.72it/s][A
Decoding image:  45%|████▌     | 18/40 [00:01<00:01, 16.75it/s][A
Decoding image:  50%|█████     | 20/40 [00:01<00:01, 16.78it/s][A
Decoding image:  55%|█████▌    | 22/40 [00:01<00:01, 16.80it/s][A
Decoding image:  60%|██████    | 24/40 [00:01<00:00, 16.17it/s][A
Decoding image:  65%|██████▌   | 26/40 [00:01<00:00, 16.37it/s][A
Decoding image:  70%|███████   | 28/40 [00:01<00:00, 16.53it/s][A
Decodin

Running DDIM Sampling with 40 timesteps



Decoding image:   0%|          | 0/40 [00:00<?, ?it/s][A
Decoding image:   5%|▌         | 2/40 [00:00<00:02, 16.96it/s][A
Decoding image:  10%|█         | 4/40 [00:00<00:02, 16.91it/s][A
Decoding image:  15%|█▌        | 6/40 [00:00<00:02, 16.88it/s][A
Decoding image:  20%|██        | 8/40 [00:00<00:01, 16.90it/s][A
Decoding image:  25%|██▌       | 10/40 [00:00<00:01, 16.90it/s][A
Decoding image:  30%|███       | 12/40 [00:00<00:01, 16.90it/s][A
Decoding image:  35%|███▌      | 14/40 [00:00<00:01, 16.91it/s][A
Decoding image:  40%|████      | 16/40 [00:00<00:01, 16.89it/s][A
Decoding image:  45%|████▌     | 18/40 [00:01<00:01, 16.90it/s][A
Decoding image:  50%|█████     | 20/40 [00:01<00:01, 16.91it/s][A
Decoding image:  55%|█████▌    | 22/40 [00:01<00:01, 16.92it/s][A
Decoding image:  60%|██████    | 24/40 [00:01<00:00, 16.93it/s][A
Decoding image:  65%|██████▌   | 26/40 [00:01<00:00, 16.92it/s][A
Decoding image:  70%|███████   | 28/40 [00:01<00:00, 16.93it/s][A
Decodin

Running DDIM Sampling with 40 timesteps



Decoding image:   0%|          | 0/40 [00:00<?, ?it/s][A
Decoding image:   5%|▌         | 2/40 [00:00<00:02, 16.97it/s][A
Decoding image:  10%|█         | 4/40 [00:00<00:02, 16.93it/s][A
Decoding image:  15%|█▌        | 6/40 [00:00<00:02, 16.92it/s][A
Decoding image:  20%|██        | 8/40 [00:00<00:01, 16.91it/s][A
Decoding image:  25%|██▌       | 10/40 [00:00<00:01, 16.90it/s][A
Decoding image:  30%|███       | 12/40 [00:00<00:01, 16.90it/s][A
Decoding image:  35%|███▌      | 14/40 [00:00<00:01, 16.91it/s][A
Decoding image:  40%|████      | 16/40 [00:00<00:01, 16.91it/s][A
Decoding image:  45%|████▌     | 18/40 [00:01<00:01, 16.89it/s][A
Decoding image:  50%|█████     | 20/40 [00:01<00:01, 16.88it/s][A
Decoding image:  55%|█████▌    | 22/40 [00:01<00:01, 16.91it/s][A
Decoding image:  60%|██████    | 24/40 [00:01<00:00, 16.89it/s][A
Decoding image:  65%|██████▌   | 26/40 [00:01<00:00, 16.91it/s][A
Decoding image:  70%|███████   | 28/40 [00:01<00:00, 16.93it/s][A
Decodin

Running DDIM Sampling with 40 timesteps



Decoding image:   0%|          | 0/40 [00:00<?, ?it/s][A
Decoding image:   5%|▌         | 2/40 [00:00<00:02, 16.91it/s][A
Decoding image:  10%|█         | 4/40 [00:00<00:02, 16.92it/s][A
Decoding image:  15%|█▌        | 6/40 [00:00<00:02, 16.90it/s][A
Decoding image:  20%|██        | 8/40 [00:00<00:01, 16.88it/s][A
Decoding image:  25%|██▌       | 10/40 [00:00<00:01, 16.89it/s][A
Decoding image:  30%|███       | 12/40 [00:00<00:01, 16.90it/s][A
Decoding image:  35%|███▌      | 14/40 [00:00<00:01, 16.91it/s][A
Decoding image:  40%|████      | 16/40 [00:00<00:01, 16.91it/s][A
Decoding image:  45%|████▌     | 18/40 [00:01<00:01, 16.91it/s][A
Decoding image:  50%|█████     | 20/40 [00:01<00:01, 16.91it/s][A
Decoding image:  55%|█████▌    | 22/40 [00:01<00:01, 16.91it/s][A
Decoding image:  60%|██████    | 24/40 [00:01<00:00, 16.89it/s][A
Decoding image:  65%|██████▌   | 26/40 [00:01<00:00, 16.92it/s][A
Decoding image:  70%|███████   | 28/40 [00:01<00:00, 16.93it/s][A
Decodin

Running DDIM Sampling with 40 timesteps



Decoding image:   0%|          | 0/40 [00:00<?, ?it/s][A
Decoding image:   5%|▌         | 2/40 [00:00<00:02, 16.90it/s][A
Decoding image:  10%|█         | 4/40 [00:00<00:02, 16.86it/s][A
Decoding image:  15%|█▌        | 6/40 [00:00<00:02, 16.85it/s][A
Decoding image:  20%|██        | 8/40 [00:00<00:01, 16.86it/s][A
Decoding image:  25%|██▌       | 10/40 [00:00<00:01, 16.87it/s][A
Decoding image:  30%|███       | 12/40 [00:00<00:01, 16.88it/s][A
Decoding image:  35%|███▌      | 14/40 [00:00<00:01, 16.87it/s][A
Decoding image:  40%|████      | 16/40 [00:00<00:01, 16.89it/s][A
Decoding image:  45%|████▌     | 18/40 [00:01<00:01, 16.88it/s][A
Decoding image:  50%|█████     | 20/40 [00:01<00:01, 16.82it/s][A
Decoding image:  55%|█████▌    | 22/40 [00:01<00:01, 16.85it/s][A
Decoding image:  60%|██████    | 24/40 [00:01<00:00, 16.88it/s][A
Decoding image:  65%|██████▌   | 26/40 [00:01<00:00, 16.90it/s][A
Decoding image:  70%|███████   | 28/40 [00:01<00:00, 16.92it/s][A
Decodin