In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys
import json
import pickle
import argparse
import numpy as np
import math
from einops import rearrange
import time
import random
import string
import h5py
from tqdm.auto import tqdm
import webdataset as wds

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator

from PIL import Image, ImageDraw, ImageFont

# SDXL unCLIP requires code from https://github.com/Stability-AI/generative-models/tree/main
sys.path.append('generative_models/')
import sgm
from generative_models.sgm.modules.encoders.modules import FrozenOpenCLIPImageEmbedder, FrozenOpenCLIPEmbedder2
from generative_models.sgm.models.diffusion import DiffusionEngine
from generative_models.sgm.util import append_dims
from omegaconf import OmegaConf

from versatile_diffusion import Reconstructor
# tf32 data type is faster than standard float32
torch.backends.cuda.matmul.allow_tf32 = True

# custom functions #
import utils
from models import *

### Multi-GPU config ###
# Read the process rank from the RANK environment variable; an unset variable
# is treated as rank 0.
local_rank = int(os.environ.get('RANK', 0))
print("LOCAL RANK ", local_rank)  

# The Accelerator selects the device used by every later cell and is
# configured for fp16 mixed precision.
accelerator = Accelerator(split_batches=False, mixed_precision="fp16")
device = accelerator.device
print("device:",device)

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.4.1+cu121)
    Python  3.11.6 (you have 3.11.10)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details
  @custom_fwd(cast_inputs=torch.float16 if _triton_softmax_fp16_enabled else None)
  @custom_bwd
  @torch.cuda.amp.custom_fwd
  @torch.cuda.amp.custom_bwd
  torch.utils._pytree._register_pytree_node(


LOCAL RANK  0
device: cuda


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [2]:
# When this notebook runs interactively there is no real command line, so the
# arguments the parser should see are assembled here as `jupyter_args`.
# (When run as a script this whole cell is skipped.)
if utils.is_interactive():
    # model_name = "final_subj01_pretrained_40sess_24bs"
    model_name = "p_trained_subj01_40sess_hypatia_new_vd_dual_proj"
    print("model_name:", model_name)

    # Other options can be specified in the following string. The backslash
    # continuations leave runs of spaces inside the string; they are harmless
    # because the string is split on whitespace right below.
    jupyter_args = f"--data_path=/weka/proj-medarc/shared/umn-imagery \
                    --cache_dir=/weka/proj-medarc/shared/cache \
                    --model_name={model_name} --subj=1 \
                    --hidden_dim=1024 --n_blocks=4 --mode vision --blurry_recon \
                    --imagery_data_path=/weka/proj-medarc/shared/umn-imagery --dual_guidance"
    print(jupyter_args)
    # argparse expects a list of tokens, not a single string
    jupyter_args = jupyter_args.split()
    
    from IPython.display import clear_output # function to clear print outputs in cell
    %load_ext autoreload 
    # this allows you to change functions in models.py or utils.py and have this notebook automatically update with your revisions
    %autoreload 2 

model_name: p_trained_subj01_40sess_hypatia_new_vd_dual_proj
--data_path=/weka/proj-medarc/shared/umn-imagery                     --cache_dir=/weka/proj-medarc/shared/cache                     --model_name=p_trained_subj01_40sess_hypatia_new_vd_dual_proj --subj=1                     --hidden_dim=1024 --n_blocks=4 --mode vision --blurry_recon                     --imagery_data_path=/weka/proj-medarc/shared/umn-imagery --dual_guidance


In [3]:
# Argument parser shared by interactive (notebook) and script runs.
parser = argparse.ArgumentParser(description="Model Training Configuration")

# Each entry is (flag, keyword options for add_argument); the flags are
# registered in exactly this order.
_ARGUMENT_SPECS = (
    ("--model_name", dict(
        type=str, default="testing",
        help="will load ckpt for model found in ../train_logs/model_name")),
    ("--data_path", dict(
        type=str, default=os.getcwd(),
        help="Path to where NSD data is stored / where to download it to")),
    ("--cache_dir", dict(
        type=str, default=os.getcwd(),
        help="Path to where misc. files downloaded from huggingface are stored. Defaults to current src directory.")),
    ("--subj", dict(
        type=int, default=1, choices=list(range(1, 12)),
        help="Validate on which subject?")),
    # BooleanOptionalAction also generates the matching --no-<flag> switch
    ("--blurry_recon", dict(action=argparse.BooleanOptionalAction, default=True)),
    ("--n_blocks", dict(type=int, default=4)),
    ("--hidden_dim", dict(type=int, default=2048)),
    ("--seq_len", dict(type=int, default=1)),
    ("--seed", dict(type=int, default=42)),
    ("--mode", dict(type=str, default="vision")),
    ("--gen_rep", dict(type=int, default=10)),
    ("--dual_guidance", dict(action=argparse.BooleanOptionalAction, default=False)),
    ("--snr", dict(type=float, default=-1)),
    ("--imagery_data_path", dict(type=str, default=None)),
)
for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

# Interactive runs parse the token list prepared in the notebook; script runs
# parse the real command line.
args = parser.parse_args(jupyter_args) if utils.is_interactive() else parser.parse_args()

# create global variables without the args prefix (args.subj -> subj, ...)
globals().update(vars(args))


# seed all random functions, but only if doing 1 rep
if gen_rep == 1 and seed > 0:
    utils.seed_everything(seed)

# The imagery data location falls back to the main data path when not given.
imagery_data_path = imagery_data_path or data_path

# make output directory
for _out_dir in ("evals", f"evals/{model_name}"):
    os.makedirs(_out_dir, exist_ok=True)

In [4]:
# Keyword arguments passed identically to every loader below.
_loader_kwargs = dict(average=False, nest=True, data_root=imagery_data_path)

if mode == "synthetic":
    voxels, all_images = utils.load_nsd_synthetic(subject=subj, **_loader_kwargs)
elif subj > 8:
    # Subject ids above 8 go to the imageryrf loader shifted down by 8; the
    # first two values it returns are discarded here.
    _, _, voxels, all_images = utils.load_imageryrf(
        subject=subj-8, mode=mode, stimtype="object", split=True, **_loader_kwargs
    )
else:
    voxels, all_images = utils.load_nsd_mental_imagery(
        subject=subj, mode=mode, stimtype="all", **_loader_kwargs
    )

# The last axis of `voxels` sets the input width of the ridge layer.
num_voxels = voxels.shape[-1]

torch.Size([18, 8, 15724]) torch.Size([18, 3, 425, 425])


  x = torch.load(f"{data_root}/preprocessed_data/subject{subject}/nsd_imagery.pt").requires_grad_(False).to("cpu")
  y = torch.load(f"{data_root}/nsddata_stimuli/stimuli/imagery_stimuli_18.pt").requires_grad_(False).to("cpu")


In [5]:
# CLIP-related sizes and variant name.
clip_emb_dim, clip_seq_dim, clip_text_seq_dim = 768, 257, 77
clip_variant = "ViT-L-14"

# A single Reconstructor instance is also exposed under the name `clip_extractor`.
reconstructor = Reconstructor(device=device, cache_dir=f'{cache_dir}/versatile_diffusion_ckpts')
clip_extractor = reconstructor


if blurry_recon:
    from diffusers import AutoencoderKL

    # Four encoder and four decoder stages of the same block type.
    _n_levels = 4
    autoenc = AutoencoderKL(
        down_block_types=['DownEncoderBlock2D'] * _n_levels,
        up_block_types=['UpDecoderBlock2D'] * _n_levels,
        block_out_channels=[128, 256, 512, 512],
        layers_per_block=2,
        sample_size=256,
    )
    ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth')
    autoenc.load_state_dict(ckpt)

    # Inference only: eval mode, no gradients, moved to the run device.
    autoenc.eval().requires_grad_(False).to(device)
    utils.count_params(autoenc)
    
class MindEyeModule(nn.Module):
    """Parameter-free container module.

    It owns no layers itself; other cells attach submodules to the instance
    as attributes (e.g. `model.ridge`, `model.backbone`).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Identity mapping: return the input object unchanged."""
        return x


model = MindEyeModule()

class RidgeRegression(torch.nn.Module):
    """One linear layer per entry of `input_sizes`, each mapping to `out_features`.

    The layers carry no built-in regularisation.
    # make sure to add weight_decay when initializing optimizer

    Args:
        input_sizes: iterable of input widths; one `Linear` is created per entry
            and selected at call time through `subj_idx`.
        out_features: output width shared by every layer.
        seq_len: number of positions along dim 1 of the input that `forward`
            projects.
    """
    def __init__(self, input_sizes, out_features, seq_len): 
        super().__init__()
        self.out_features = out_features
        # Stored on the instance so forward() does not depend on a module-level
        # `seq_len` variable happening to exist (and to match).
        self.seq_len = seq_len
        self.linears = torch.nn.ModuleList([
                torch.nn.Linear(input_size, out_features) for input_size in input_sizes
            ])

    def forward(self, x, subj_idx):
        """Apply the `subj_idx`-th linear layer to each of the first `seq_len` slices `x[:, seq]`.

        Args:
            x: tensor indexed as x[:, seq]; each slice must have the selected
                layer's input width as its last dimension.
            subj_idx: index into `self.linears`.

        Returns:
            The per-position outputs stacked back along dim 1.
        """
        linear = self.linears[subj_idx]
        out = torch.cat([linear(x[:,seq]).unsqueeze(1) for seq in range(self.seq_len)], dim=1)
        return out
        
# Attach the voxel -> hidden_dim projection. A single input size is given, so
# the layer list has exactly one entry (index 0).
model.ridge = RidgeRegression([num_voxels], out_features=hidden_dim, seq_len=seq_len)

# NOTE(review): `Decoder` is not referenced in the visible cells — confirm it is still needed.
from diffusers.models.vae import Decoder
from models import BrainNetwork
# Backbone configured from the CLI options; its output width is the flattened
# CLIP size (clip_emb_dim * clip_seq_dim). The text branch is switched on by
# dual_guidance and the blurry branch by blurry_recon.
model.backbone = BrainNetwork(h=hidden_dim, in_dim=hidden_dim, seq_len=seq_len, n_blocks=n_blocks,
                          clip_size=clip_emb_dim, out_dim=clip_emb_dim*clip_seq_dim, 
                          blurry_recon=blurry_recon, text_clip=dual_guidance) 
# Report parameter counts per submodule and for the whole model so far.
utils.count_params(model.ridge)
utils.count_params(model.backbone)
utils.count_params(model)

# setup diffusion prior network
out_dim = clip_emb_dim
depth = 6
dim_head = 64
heads = clip_emb_dim//64 # heads * dim_head = clip_emb_dim
timesteps = 100

# Options shared by the image prior and the optional text prior; the two
# differ only in the number of tokens of their network.
_prior_net_kwargs = dict(
    dim=out_dim,
    depth=depth,
    dim_head=dim_head,
    heads=heads,
    causal=False,
    learned_query_mode="pos_emb",
)
_prior_kwargs = dict(
    image_embed_dim=out_dim,
    condition_on_text_encodings=False,
    timesteps=timesteps,
    cond_drop_prob=0.2,
    image_embed_scale=None,
)

# Image prior: one token per CLIP image sequence position.
prior_network = PriorNetwork(num_tokens=clip_seq_dim, **_prior_net_kwargs)
model.diffusion_prior = BrainDiffusionPrior(net=prior_network, **_prior_kwargs)

# Text prior: only built for dual guidance, one token per CLIP text position.
if dual_guidance:
    prior_network_txt = PriorNetwork(num_tokens=clip_text_seq_dim, **_prior_net_kwargs)
    model.diffusion_prior_txt = BrainDiffusionPrior(net=prior_network_txt, **_prior_kwargs)

model.to(device)

# Parameter counts: each prior on its own, then the assembled model.
utils.count_params(model.diffusion_prior)
if dual_guidance:
    utils.count_params(model.diffusion_prior_txt)
utils.count_params(model)

# Load pretrained model ckpt
# First try a regular torch checkpoint at ../train_logs/<model_name>/<tag>.pth;
# if anything in that path fails, fall back to reading a deepspeed ZeRO
# checkpoint from the same directory.
tag='last'
outdir = os.path.abspath(f'../train_logs/{model_name}')
print(f"\n---loading {outdir}/{tag}.pth ckpt---\n")
try:
    checkpoint = torch.load(outdir+f'/{tag}.pth', map_location='cpu')
    state_dict = checkpoint['model_state_dict']

    # Rename the attention parameters of the blurry upsampler's mid block:
    # to_q/to_k/to_v/to_out.0 in the checkpoint -> query/key/value/proj_attn.
    _attn_prefix = "backbone.bupsampler.mid_block.attentions.0."
    layer_mapping = {
        f"{_attn_prefix}{old_name}.{param}": f"{_attn_prefix}{new_name}.{param}"
        for old_name, new_name in (
            ("to_q", "query"),
            ("to_k", "key"),
            ("to_v", "value"),
            ("to_out.0", "proj_attn"),
        )
        for param in ("weight", "bias")
    }
    # Keys without a mapping entry are kept as they are.
    new_ckpt = {layer_mapping.get(old_key, old_key): value for old_key, value in state_dict.items()}

    model = torch.compile(model)
    model.load_state_dict(new_ckpt, strict=True)
    del checkpoint
except Exception as load_error: # probably ckpt is saved using deepspeed format
    # `except Exception` (not a bare except) so Ctrl-C / SystemExit still
    # propagate, and the original failure is reported instead of being hidden.
    print(f"standard ckpt load failed ({type(load_error).__name__}: {load_error}); trying deepspeed format")
    import deepspeed
    state_dict = deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir=outdir, tag=tag)
    # strict=False never raises on key mismatches, so surface them explicitly:
    # a large number here means little or nothing was actually loaded.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"warning: {len(load_result.missing_keys)} missing and "
              f"{len(load_result.unexpected_keys)} unexpected keys when loading the deepspeed ckpt")
    del state_dict
print("ckpt loaded!")

Reconstructor: Loading model... fp16: True
Taking new code 2.

#######################
# Running in eps mode #
#######################

making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


  sd = torch.load(cfg.pth, map_location=map_location)


Load pth from /weka/proj-medarc/shared/cache/versatile_diffusion_ckpts/kl-f8.pth
Load autoencoderkl with total 83653863 parameters,72921.759 parameter sum.
Load optimus_bert_connector with total 109489920 parameters,19107.967 parameter sum.
Load optimus_gpt2_connector with total 132109824 parameters,19036.291 parameter sum.
Load pth from /weka/proj-medarc/shared/cache/versatile_diffusion_ckpts/optimus-vae.pth
Load optimus_vae_next with total 241599744 parameters,-344611.688 parameter sum.
Load clip_image_context_encoder with total 427616513 parameters,64007.510 parameter sum.
Load clip_text_context_encoder with total 427616513 parameters,64007.510 parameter sum.
Load openai_unet_2d_next with total 859520964 parameters,99818.335 parameter sum.
Load openai_unet_0d_next with total 1706797888 parameters,249893.201 parameter sum.
Load vd_v2_0 with total 3746805485 parameters,206036.626 parameter sum.


  sd = torch.load(f'{cache_dir}/vd-four-flow-v1-0-fp16.pth', map_location='cpu')
  ckpt = torch.load(f'{cache_dir}/sd_image_var_autoenc.pth')


param counts:
83,653,863 total
0 trainable
param counts:
16,102,400 total
16,102,400 trainable
param counts:
280,407,420 total
280,407,420 trainable
param counts:
296,509,820 total
296,509,820 trainable
param counts:
56,055,184 total
56,055,168 trainable
param counts:
55,640,464 total
55,640,448 trainable
param counts:
408,205,468 total
408,205,436 trainable

---loading /weka/proj-fmri/ckadirt/MindEye_Imagery/train_logs/p_trained_subj01_40sess_hypatia_new_vd_dual_proj/last.pth ckpt---



  checkpoint = torch.load(outdir+f'/{tag}.pth', map_location='cpu')


ckpt loaded!


In [6]:
# print('Creating versatile diffusion reconstruction pipeline...')
# from diffusers import VersatileDiffusionDualGuidedPipeline, UniPCMultistepScheduler
# from diffusers.models import DualTransformer2DModel
# # vd_cache_dir = "/home/naxos2-raid25/kneel027/home/kneel027/fMRI-reconstruction-NSD/versatile_diffusion"
# # try:
# #     vd_pipe =  VersatileDiffusionDualGuidedPipeline.from_pretrained(cache_dir).to(device)
# # except:
# print("Downloading Versatile Diffusion to", cache_dir)
# vd_pipe =  VersatileDiffusionDualGuidedPipeline.from_pretrained(
#         "shi-labs/versatile-diffusion",
#         torch_dtype=torch.float16,
#         cache_dir = cache_dir).to(device)
# vd_pipe.remove_unused_weights()
# vd_pipe.image_unet.eval()
# vd_pipe.vae.eval()
# vd_pipe.image_unet.requires_grad_(False)
# vd_pipe.vae.requires_grad_(False)

# vd_pipe.scheduler = UniPCMultistepScheduler.from_pretrained(cache_dir + "/models--shi-labs--versatile-diffusion/snapshots/2926f8e11ea526b562cd592b099fcf9c2985d0b7", subfolder="scheduler")
# num_inference_steps = 20

# # Set weighting of Dual-Guidance 
# if dual_guidance:
#     text_image_ratio = .4 # .5 means equally weight text and image, 0 means use only image
# else:
#     text_image_ratio = 0.
# for name, module in vd_pipe.image_unet.named_modules():
#     if isinstance(module, DualTransformer2DModel):
#         module.mix_ratio = text_image_ratio
#         for i, type in enumerate(("text", "image")):
#             if type == "text":
#                 module.condition_lengths[i] = 77
#                 module.transformer_index_for_condition[i] = 1  # use the second (text) transformer
#             else:
#                 module.condition_lengths[i] = 257
#                 module.transformer_index_for_condition[i] = 0  # use the first (image) transformer

# unet = vd_pipe.image_unet
# vae = vd_pipe.vae
# noise_scheduler = vd_pipe.scheduler

In [7]:
# num_reconstructors = 4
# reconstructors = []
# reconstructors.append(reconstructor)
# for i in range(num_reconstructors):
#     if i == 0:
#         reconstructors.append(reconstructor)
#     else:
#         reconstructors.append(Reconstructor(device=device, cache_dir=f'{cache_dir}/versatile_diffusion_ckpts'))

In [8]:
# import concurrent.futures
# import time
# def reconstruct_task(i, reconstructor, blurred_image, prior_out, prior_out_txt, seed):
#     image_pil = transforms.ToPILImage()(torch.Tensor(blurred_image[0]))
#     return reconstructor.reconstruct(
#         image=image_pil,
#         c_i=prior_out[i],
#         c_t=prior_out_txt[i],
#         n_samples=1,
#         textstrength=0.4,
#         strength=0.85,
#         seed=seed
#     )

In [None]:
def _accumulate(acc, new, dim=0):
    """Append `new` to the running tensor `acc` along `dim`.

    Returns `new` itself when `acc` is None (first call); otherwise the
    concatenation (acc, new), so earlier entries always stay in front.
    """
    return new if acc is None else torch.cat((acc, new), dim=dim)


# Accumulators across repetitions; each repetition is added along dim 1.
final_recons = None
final_predcaptions = None  # caption prediction is not run in this cell, so this stays None
final_clipvoxels = None
final_blurryrecons = None

# Number of candidate reconstructions generated per sample; all candidates are
# handed to utils.pick_best_recon, which returns the one that is kept.
recons_per_sample = 16


for rep in tqdm(range(gen_rep)):
    # Each repetition gets its own randomly drawn global seed.
    utils.seed_everything(seed = random.randint(0,10000000))

    # Per-repetition accumulators; samples are added along dim 0.
    all_blurryrecons = None
    all_recons = None
    all_predcaptions = []
    all_clipvoxels = None

    minibatch_size = 1
    num_samples_per_image = 1
    plotting = False  # debugging aid: show images for the first sample, then stop

    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
        for idx in tqdm(range(0,voxels.shape[0]), desc="sample loop"):
            # Average over the first axis of this sample's voxel data and add
            # two leading singleton dims before the ridge layer.
            voxel = voxels[idx]
            voxel = torch.mean(voxel, dim=0).to(device).unsqueeze(0).unsqueeze(0)
            print(voxel.shape)
            voxel_ridge = model.ridge(voxel,0) # 0th index of subj_list
            backbone, backbone_txt, clip_voxels, blurry_image_enc = model.backbone(voxel_ridge)
            blurry_image_enc = blurry_image_enc[0]

            # Save retrieval submodule outputs
            all_clipvoxels = _accumulate(all_clipvoxels, clip_voxels.to('cpu'))

            # Feed voxels through versatile diffusion diffusion prior, once per
            # candidate reconstruction.
            backbone = backbone.repeat(recons_per_sample, 1, 1)
            prior_out = model.diffusion_prior.p_sample_loop(backbone.shape, 
                            text_cond = dict(text_embed = backbone), 
                            cond_scale = 1., timesteps = 20)
            prior_out_txt = None
            if dual_guidance:
                backbone_txt = backbone_txt.repeat(recons_per_sample, 1, 1)
                prior_out_txt = model.diffusion_prior_txt.p_sample_loop(backbone_txt.shape, 
                                text_cond = dict(text_embed = backbone_txt), 
                                cond_scale = 1., timesteps = 20)

            if blurry_recon:
                # Decode the predicted latent and map the result from [-1, 1] to [0, 1].
                blurred_image = (autoenc.decode(blurry_image_enc/0.18215).sample/ 2 + 0.5).clamp(0,1)
                all_blurryrecons = _accumulate(all_blurryrecons, blurred_image.cpu())
                if plotting:
                    plt.figure(figsize=(2,2))
                    plt.imshow(transforms.ToPILImage()(blurred_image[0]))
                    plt.axis('off')
                    plt.show()

            # Feed outputs through versatile diffusion
            # NOTE(review): `blurred_image` only exists when blurry_recon is True;
            # with --no-blurry_recon this call raises NameError — confirm whether
            # that mode should be supported.
            # NOTE(review): `seed` is the fixed CLI seed, not the per-repetition
            # random seed drawn above — confirm reconstruct_batch is meant to see
            # the same seed in every repetition.
            start_time = time.time()
            samples_multi = reconstructor.reconstruct_batch(
                                image=transforms.ToPILImage()(blurred_image[0]),
                                c_i=prior_out,
                                c_t=prior_out_txt,
                                textstrength=0.4,
                                strength=0.85,
                                seed=seed)
            print(time.time()-start_time)

            # Keep a single candidate, as a tensor with a leading batch dim.
            samples = utils.pick_best_recon(samples_multi, clip_voxels, clip_extractor)
            # `Image` is the module imported via `from PIL import Image`
            if isinstance(samples, Image.Image):
                samples = transforms.ToTensor()(samples)
            samples = samples.unsqueeze(0)

            all_recons = _accumulate(all_recons, samples.cpu())
            if plotting:
                for s in range(num_samples_per_image):
                    plt.figure(figsize=(2,2))
                    plt.imshow(transforms.ToPILImage()(samples[s]))
                    plt.axis('off')
                    plt.show()

            if plotting: 
                print(model_name)
                # dont actually want to run the whole thing with plotting=True
                raise RuntimeError("plotting=True is a single-sample debug mode; set plotting=False for a full run")

        # NOTE: no resizing is applied in this cell; imsize is only defined here.
        imsize = 256

        # Add this repetition as a new entry along dim 1. All three outputs are
        # appended in the same order, so index r along dim 1 refers to the same
        # repetition in recons, clipvoxels and blurryrecons.
        final_recons = _accumulate(final_recons, all_recons.unsqueeze(1), dim=1)
        final_clipvoxels = _accumulate(final_clipvoxels, all_clipvoxels.unsqueeze(1), dim=1)
        if blurry_recon:
            final_blurryrecons = _accumulate(final_blurryrecons, all_blurryrecons.unsqueeze(1), dim=1)

# Save everything under evals/<model_name>/
if blurry_recon:
    torch.save(final_blurryrecons,f"evals/{model_name}/{model_name}_all_blurryrecons_{mode}.pt")
torch.save(final_recons,f"evals/{model_name}/{model_name}_all_recons_{mode}.pt")
# torch.save(final_predcaptions,f"evals/{model_name}/{model_name}_all_predcaptions_{mode}.pt")
torch.save(final_clipvoxels,f"evals/{model_name}/{model_name}_all_clipvoxels_{mode}.pt")
print(f"saved {model_name} mi outputs!")

# if not utils.is_interactive():
#     sys.exit(0)


  with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float16):

sample loop:   0%|                                                                                                                                   | 0/18 [00:00<?, ?it/s][A

torch.Size([1, 1, 15724])


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

31.22403907775879



sample loop:   6%|██████▊                                                                                                                    | 1/18 [00:32<09:08, 32.24s/it][A

torch.Size([1, 1, 15724])


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

31.149746656417847



sample loop:  11%|█████████████▋                                                                                                             | 2/18 [01:04<08:35, 32.20s/it][A

torch.Size([1, 1, 15724])


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

31.18331527709961



sample loop:  17%|████████████████████▌                                                                                                      | 3/18 [01:36<08:03, 32.21s/it][A

torch.Size([1, 1, 15724])


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

31.140665769577026



sample loop:  22%|███████████████████████████▎                                                                                               | 4/18 [02:08<07:30, 32.19s/it][A

torch.Size([1, 1, 15724])


sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

sampling loop time step:   0%|          | 0/19 [00:00<?, ?it/s]

In [None]:
# final_recons.shape

In [2]:
# import torch
# import matplotlib.pyplot as plt
# import numpy as np

# # Assuming final_recons is your tensor of shape [18, 1, 3, 512, 512]
# # final_recons = ... (your tensor)

# # Select the first image in the batch and the first reconstruction
# # Shape after selection: [3, 512, 512]
# image_tensor = final_recons[0, 0]

# # If the tensor is on a GPU, move it to CPU
# if image_tensor.is_cuda:
#     image_tensor = image_tensor.cpu()

# # Detach the tensor from the computation graph and convert to NumPy
# image_np = image_tensor.detach().numpy()

# # Transpose the tensor to have shape [512, 512, 3] for plotting
# image_np = np.transpose(image_np, (1, 2, 0))

# # Optional: Normalize the image to [0, 1] if it's not already
# # This step depends on how your data is scaled
# # Uncomment the following lines if normalization is needed
# # min_val = image_np.min()
# # max_val = image_np.max()
# # image_np = (image_np - min_val) / (max_val - min_val)

# # Ensure the image has valid pixel values
# image_np = np.clip(image_np, 0, 1)

# # Plot the image
# plt.figure(figsize=(6, 6))
# plt.imshow(image_np)
# plt.title("First Image from final_recons")
# plt.axis('off')  # Hide axis
# plt.show()


In [3]:

# When executed as a script (not interactively), stop here with exit status 0
# so that any cells below are skipped.
if not utils.is_interactive():
    sys.exit(0)

NameError: name 'utils' is not defined

In [None]:
# Debug: shapes of the prior outputs left over from the last processed sample.
# prior_out_txt is None when dual_guidance is off, so guard the attribute access.
print(prior_out.shape, None if prior_out_txt is None else prior_out_txt.shape)