## Load model


In [1]:
import yaml
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from IPython.display import display, HTML
import ipywidgets as widgets
from diffusers.models import AutoencoderKL

from diffusion import create_diffusion
from isolated_nwm_infer import model_forward_wrapper
from misc import transform
from models import CDiT_models
from datasets import TrainingDataset

EXP_NAME = 'nwm_cdit_xl'
MODEL_PATH = f'logs/{EXP_NAME}/checkpoints/0100000.pth.tar'

with open("config/data_config.yaml", "r") as f:
    default_config = yaml.safe_load(f)
config = default_config

with open(f'config/{EXP_NAME}.yaml', "r") as f:
    user_config = yaml.safe_load(f)
config.update(user_config)
latent_size = config['image_size'] // 8

print("loading model")
model = CDiT_models[config['model']](input_size=latent_size, context_size=config['context_size'])
ckp = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
print(model.load_state_dict(ckp["ema"], strict=True))
model.eval()
device = 'cuda:0'
model.to(device)
model = torch.compile(model)

diffusion = create_diffusion(str(250))
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-ema").to(device)
latent_size = config['image_size'] // 8



loading model
<All keys matched successfully>


## Choose starting image

In [2]:
def url_to_pil_image(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert('RGB')
    return img

def load_internet_image(url):
    from torchvision import transforms
    _transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ])
    img = url_to_pil_image(url)
    x_start = _transform(img)
    return x_start.unsqueeze(0).expand(config['context_size'], x_start.shape[0], x_start.shape[1], x_start.shape[2])

# Jupyter Notebook Cell


# List of image links
image_links = [
    #'https://raw.githubusercontent.com/amirbar/amirbar.github.io/refs/heads/master/images/recon.png',
    #'https://raw.githubusercontent.com/amirbar/amirbar.github.io/refs/heads/master/images/scand.png',
    #'https://raw.githubusercontent.com/amirbar/amirbar.github.io/refs/heads/master/images/sacson.png',
    #'https://raw.githubusercontent.com/amirbar/amirbar.github.io/refs/heads/master/images/tartan.png',
    
	#add customized images 
    #known images different from official website
    #'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/frame.png',
    
    # unknown warehouse (collapse)
    'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/realistic-warehouse-photo1.png',
    'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/warehouse2.png',
	#'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/obstracle+warehouse3.jpg',
    #'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/Resized.jpg',
    #'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/realwarehouse.png',
    #'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/Low.png',
	
	#unknown enviroment from internet
    #'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/livingroom.jpg',
   
	# unknown environment from offcial website
    #'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/chateau.png',
    #unknown From roboflow
    #'https://raw.githubusercontent.com/FBI-openup/nwm/main/warehouse%20Images/cornel.png'
]

# Output widget to hold the selected link
output = widgets.Output()
x_start_link = None  # This will hold the selected link

# Function to handle image click
def on_image_click(link):
    global x_start_link
    x_start_link = link
    with output:
        output.clear_output()
        print(f"Selected image link:\n{x_start_link}")

# Create HBox of images
image_buttons = []
for link in image_links:
    img = widgets.Button(
        description='click',
        layout=widgets.Layout(width='150px', height='20px', padding='0'),
        style={'button_color': 'lightgray'}
    )

    img._dom_classes += ('image-button',)
    img_link = link  # capture current link in closure

    def on_click(b, link=img_link):
        on_image_click(link)

    img.on_click(on_click)

    # Embed image using HTML style
    img_html = f'<img src="{link}" width="150px" height="150px">'
    img_html_widget = widgets.HTML(value=img_html)
    image_buttons.append(widgets.VBox([img_html_widget, img]))

# Display the gallery
display(widgets.HBox(image_buttons))
display(output)


HBox(children=(VBox(children=(HTML(value='<img src="https://raw.githubusercontent.com/FBI-openup/nwm/main/ware…

Output()

## Visualize navigation commands

In [3]:
x_start = load_internet_image(x_start_link)
print(x_start_link)
commands = {
    'Forward': [1,0,0],
    'Rotate Right': [0,0,-0.5],
    'Rotate Left': [0,0,0.5],
}
preds = {}

def reset():
    x_cond_pixels = x_start
    reconstructed_image=x_cond_pixels.to(device)
    preds['x_cond_pixels_display'] = (reconstructed_image[-1] * 127.5 + 127.5).clamp(0, 255).permute(1, 2, 0).to("cpu", dtype=torch.uint8).numpy()
    preds['x_cond_pixels'] = x_cond_pixels
    preds['video'] = [preds['x_cond_pixels_display']]


reset()
display(Image.fromarray(preds['x_cond_pixels_display']))
Image.fromarray(preds['x_cond_pixels_display']).save('sacson.png')

output = widgets.Output()
display(output)
rel_t = (torch.ones(1)*0.0078125).to(device)

@output.capture()
def update_image(b):

    if b.description == 'Reset':
        print("Reset clicked!")
        output.clear_output(wait=False)
        reset()
        return

    print("Button clicked!")
    y = commands[b.description]
    y = torch.tensor(y).to(device).unsqueeze(0)

    print("You entered:", b.description)
    x_cond_pixels = preds['x_cond_pixels'][-4:].unsqueeze(0).to(device)
    samples = model_forward_wrapper((model, diffusion, vae), x_cond_pixels, y, None, latent_size, device, config["context_size"], num_goals=1, rel_t=rel_t, progress=True)
    x_cond_pixels = samples # torch.clip(samples, -1., 1.)
    preds['x_cond_pixels'] = torch.cat([preds['x_cond_pixels'].to(x_cond_pixels), x_cond_pixels], dim=0)
    samples = (samples * 127.5 + 127.5).permute(0, 2, 3, 1).clamp(0,255).to("cpu", dtype=torch.uint8).numpy()
    display(Image.fromarray(samples[0]))
    preds['video'].append(samples[0])

buttons = []
for o in ["Forward", "Rotate Left", "Rotate Right", "Reset"]:
    b = widgets.Button(description=o)
    b.on_click(update_image)
    display(b)
    buttons.append(b)

MissingSchema: Invalid URL 'None': No scheme supplied. Perhaps you meant https://None?

# Generate a video

In [None]:

%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import numpy as np
# np array with shape (frames, height, width, channels)
video = np.array(preds['video'])

fig = plt.figure()
im = plt.imshow(video[0,:,:,:])

plt.close() # this is required to not display the generated image

def init():
    im.set_data(video[0,:,:,:])

def animate(i):
    im.set_data(video[i,:,:,:])
    return im

anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],
                               interval=500)
HTML(anim.to_html5_video())


In [None]:
# optional: load from dataset
# dataloaders = {}

# for dataset_name in config["datasets"]:
#     data_config = config["datasets"][dataset_name]
#     for data_split_type in ["test"]: #["test"]:
#         dataset = TrainingDataset(
#             data_folder=data_config["data_folder"],
#             data_split_folder=data_config[data_split_type],
#             dataset_name=dataset_name,
#             image_size=config["image_size"],
#             min_dist_cat=config["distance"]["min_dist_cat"],
#             max_dist_cat=config["distance"]["max_dist_cat"],
#             len_traj_pred=config["len_traj_pred"],
#             context_size=config["context_size"],
#             normalize=config["normalize"],
#             goals_per_obs=1,
#             transform=transform,
#             predefined_index=None,
#             traj_stride=1,
#         )
#         dataloaders[f"{dataset_name}_{data_split_type}"] = dataset
#         print(f"Dataset: {dataset_name} ({data_split_type}), size: {len(dataset)}")

# load from dataset
# ds = dataloaders['recon_test'] # scand_test,
# x, _, _ = ds[np.random.randint(len(ds))]
# x_start = x[:config["context_size"]]
