In [None]:
from typing import Dict, Any, Optional

from orbax.checkpoint import PyTreeCheckpointer
import jax
import jax.numpy as jnp
import flax.linen as nn

from models.dynamics import DynamicsMaskGIT
from models.lam import LatentActionModel
from models.tokenizer import TokenizerVQVAE

class Genie(nn.Module):
    """Flax module bundling a video tokenizer, a latent action model and a dynamics model.

    ``setup`` builds the three sub-modules from the fields below. ``__call__``
    encodes a batch of videos with the tokenizer and the latent action model
    (LAM), runs the dynamics model on the result and decodes the arg-max of
    its token logits back through the tokenizer. ``sample`` is an unfinished
    draft; see the review notes inside it.
    """
    # --- Tokenizer ---
    in_dim: int  # passed as `in_dim` to both the tokenizer and the LAM
    tokenizer_dim: int  # tokenizer `model_dim`
    latent_patch_dim: int  # `latent_dim` of the tokenizer (and, in `setup`, of the LAM)
    num_patch_latents: int  # `num_latents` of the tokenizer and of the dynamics model
    patch_size: int
    tokenizer_num_blocks: int
    tokenizer_num_heads: int
    # --- LAM ---
    lam_dim: int  # LAM `model_dim`
    # NOTE(review): `latent_action_dim` is declared but never read in this class;
    # `setup` passes `latent_patch_dim` as the LAM's `latent_dim` — confirm intent.
    latent_action_dim: int
    num_latent_actions: int  # LAM `num_latents`
    lam_patch_size: int
    lam_num_blocks: int
    lam_num_heads: int
    # --- Dynamics ---
    dyna_dim: int  # dynamics `model_dim`
    dyna_num_blocks: int
    dyna_num_heads: int
    dropout: float  # forwarded to the dynamics model only
    mask_limit: float  # forwarded to the dynamics model

    def sample(self, batch: Dict[str, Any]):
        """Unfinished draft of a MaskGIT-style sampling routine.

        NOTE(review): this method does not look runnable as written. It reads
        several names that are neither parameters nor locals here
        (``training``, ``dyna_outputs``, the module-level ``genie``), uses
        attributes that ``setup`` does not create on this class
        (``patch_embed``, ``mask_token``, ``action_up``), and never invokes
        ``jax.lax.scan``. The individual spots are marked below.

        Args:
            batch: Mapping read for ``"videos"`` and ``"latent_actions"``
                (and, inside the step draft, ``"video_tokens"`` and
                ``"mask_rng"``).
        """
        temp = 1.0  # NOTE(review): assigned but never read in this method
        generation_steps = 25  # NOTE(review): assigned but never read in this method
        # Tokenize initial frame
        token_idxs = self.tokenizer.vq_encode(batch["videos"], training=False)['indices']
        # NOTE(review): rebinds `token_idxs` to the `jnp.concatenate` function itself
        # (no call), discarding the indices computed on the previous line.
        token_idxs = jnp.concatenate
        # NOTE(review): subscripts `get_codebook` rather than calling it — confirm
        # the intended LAM API. `lam_codes` is not used afterwards in this method.
        lam_codes = self.lam.get_codebook[(batch["latent_actions"],)]

        # --- MASKGIT ---
        def _maskgit_step(carry, step):
            """Draft step function; unpacks ``carry`` as ``(rng, seq, mask)``.

            NOTE(review): the body returns a dict at the end of the
            "Predict transition" section, so the trailing
            ``return carry, None`` is unreachable, and ``rng``, ``seq`` and
            ``step`` are never used.
            """
            rng, seq, mask = carry

            # --- Mask videos ---
            # NOTE(review): `patch_embed` (and `mask_token`, `action_up` below) are not
            # created in `Genie.setup`; presumably they belong to the dynamics module.
            vid_embed = self.patch_embed(batch["video_tokens"])
            # NOTE(review): `training` is not a parameter or local of `sample` or of
            # this step function.
            if training:
                rng1, rng2 = jax.random.split(batch["mask_rng"])
                # Masking probability drawn uniformly with `mask_limit` as the lower bound.
                mask_prob = jax.random.uniform(rng1, minval=self.mask_limit)
                # One Bernoulli draw per position over all but the last axis of `vid_embed`.
                mask = jax.random.bernoulli(rng2, mask_prob, vid_embed.shape[:-1])
                # Index 0 along axis 1 is never masked.
                mask = mask.at[:, 0].set(False)
                # Masked positions are replaced by `mask_token`.
                vid_embed = jnp.where(jnp.expand_dims(mask, -1), self.mask_token, vid_embed)
            else:
                mask = None

            # --- Predict transition ---
            act_embed = self.action_up(batch["latent_actions"])
            # Pad one zero entry at the front of axis 1 before adding to `vid_embed`.
            vid_embed += jnp.pad(act_embed, ((0, 0), (1, 0), (0, 0), (0, 0)))
            # NOTE(review): `self.dynamics` is called with an embedding here but with
            # `(outputs, training)` in `__call__` — the two call shapes disagree.
            logits = self.dynamics(vid_embed)
            return dict(token_logits=logits, mask=mask)



            return carry, None

        # NOTE(review): bare attribute access — `scan` is never called, so
        # `_maskgit_step` is never run.
        jax.lax.scan

        # Predict denoised frame

        # Update tokens

        # NOTE(review): uses the module-level name `genie` instead of `self`, and
        # `dyna_outputs` is not defined anywhere in this method.
        vid_gen = genie.tokenizer.decode(
            dyna_outputs["vid_gen"],
            video_hw=batch['videos'].shape[2:4],
        )
        return vid_gen



    def setup(self):
        """Instantiate the tokenizer, latent action model and dynamics sub-modules.

        The tokenizer and the LAM are built with ``dropout`` and
        ``codebook_dropout`` fixed at 0.0; only the dynamics model receives
        ``self.dropout`` and ``self.mask_limit``.
        """
        self.tokenizer = TokenizerVQVAE(
            in_dim=self.in_dim,
            model_dim=self.tokenizer_dim,
            latent_dim=self.latent_patch_dim,
            num_latents=self.num_patch_latents,
            patch_size=self.patch_size,
            num_blocks=self.tokenizer_num_blocks,
            num_heads=self.tokenizer_num_heads,
            dropout=0.0,
            codebook_dropout=0.0,
        )
        self.lam = LatentActionModel(
            in_dim=self.in_dim,
            model_dim=self.lam_dim,
            # NOTE(review): uses `latent_patch_dim`, not `latent_action_dim` — confirm.
            latent_dim=self.latent_patch_dim,
            num_latents=self.num_latent_actions,
            patch_size=self.lam_patch_size,
            num_blocks=self.lam_num_blocks,
            num_heads=self.lam_num_heads,
            dropout=0.0,
            codebook_dropout=0.0,
        )
        self.dynamics = DynamicsMaskGIT(
            model_dim=self.dyna_dim,
            num_latents=self.num_patch_latents,
            num_blocks=self.dyna_num_blocks,
            num_heads=self.dyna_num_heads,
            dropout=self.dropout,
            mask_limit=self.mask_limit,
        )

    def __call__(self, batch: Dict[str, Any], training: bool = True) -> Dict[str, Any]:
        """Encode ``batch["videos"]``, run the dynamics model and decode its prediction.

        Args:
            batch: Mapping read for ``"videos"`` and ``"mask_rng"``.
            training: Forwarded to the dynamics model only; the tokenizer and
                LAM encoders are always called with ``training=False``.

        Returns:
            Dict holding ``video_tokens``, ``latent_actions`` and ``mask_rng``,
            everything the dynamics model returns (which must include
            ``token_logits``), and ``recon``: the tokenizer decoding of the
            arg-max over the last axis of ``token_logits``.
        """
        tokenizer_outputs = self.tokenizer.vq_encode(batch["videos"], training=False)
        lam_outputs = self.lam.vq_encode(batch["videos"], training=False)
        # stop_gradient keeps gradients from flowing back into the two encoders.
        outputs = dict(
            video_tokens=jax.lax.stop_gradient(tokenizer_outputs["indices"]),
            latent_actions=jax.lax.stop_gradient(lam_outputs["z_q"]),
        )
        outputs["mask_rng"] = batch["mask_rng"]
        dyna_outputs = self.dynamics(outputs, training)
        outputs.update(dyna_outputs)
        # Most likely token per position, decoded at the input's axes-2/3 size.
        mle_indices = jnp.argmax(outputs["token_logits"], axis=-1)
        outputs["recon"] = self.tokenizer.decode(mle_indices, batch["videos"].shape[2:4])
        return outputs

In [1]:
from dataclasses import dataclass
import os
import time

import einops
from flax.training import orbax_utils
from flax.training.train_state import TrainState
import optax
import orbax
import numpy as np
import jax
import jax.numpy as jnp
import wandb
import tyro

from data.dataloader import get_dataloader

# Unix timestamp in whole seconds, captured when this cell runs.
# NOTE(review): not referenced anywhere else in the visible notebook.
ts = int(time.time())

@dataclass
class Args:
    """Configuration values for the notebook, grouped by concern.

    In the visible notebook only ``seed``, ``seq_len``, ``image_channels``,
    ``image_resolution``, ``file_path``, ``batch_size``, ``checkpoint`` and the
    tokenizer / LAM / dynamics hyper-parameters are read; the remaining fields
    are presumably kept for parity with the training scripts — TODO confirm.

    NOTE(review): the path defaults are absolute, machine-specific locations.
    """
    # Experiment
    num_steps: int = 200_000
    seed: int = 0  # seeds the root `jax.random.PRNGKey`
    seq_len: int = 16  # second axis of the dummy video batch; also passed to the dataloader
    image_channels: int = 3  # passed to Genie as `in_dim`
    image_resolution: int = 64  # used for both spatial sides of the dummy video batch
    file_path: str = "/home/duser/jafar/data/coinrun.npy"  # first argument to `get_dataloader`
    # Optimization
    batch_size: int = 36  # leading axis of the dummy video batch; also passed to the dataloader
    min_lr: float = 3e-6
    max_lr: float = 3e-5
    warmup_steps: int = 5000
    # Tokenizer
    tokenizer_dim: int = 512
    latent_patch_dim: int = 32
    num_patch_latents: int = 1024
    patch_size: int = 4
    tokenizer_num_blocks: int = 8
    tokenizer_num_heads: int = 8
    tokenizer_checkpoint: str = "/home/duser/jafar/checkpoints/tokenizer_1721468116_50000"
    # LAM
    lam_dim: int = 512
    latent_action_dim: int = 32
    num_latent_actions: int = 6
    lam_patch_size: int = 16
    lam_num_blocks: int = 8
    lam_num_heads: int = 8
    lam_checkpoint: str = "/home/duser/jafar/checkpoints/lam_1721469076_175000"
    # Dynamics
    dyna_dim: int = 512
    dyna_num_blocks: int = 12
    dyna_num_heads: int = 8
    dropout: float = 0.0
    mask_limit: float = 0.5
    # Logging
    log: bool = False
    entity: str = "flair"
    project: str = "jafari"
    log_interval: int = 5
    log_image_interval: int = 250
    ckpt_dir: str = "/home/duser/jafar/checkpoints"
    log_checkpoint_interval: int = 25000
    log_gradients: bool = False
    # Sampling
    # Checkpoint whose ["model"]["params"]["params"] entry is merged into the model parameters.
    checkpoint: str = "/home/duser/jafar/checkpoints/genie_1721738387_200000"

# args = tyro.cli(Args)
args = Args()
rng = jax.random.PRNGKey(args.seed)

# --- Construct train state ---
# Apart from `in_dim` (fed from `image_channels`), every Genie field has an
# identically named Args field, so the keyword arguments are gathered by name.
genie = Genie(
    in_dim=args.image_channels,
    **{
        field: getattr(args, field)
        for field in (
            # Tokenizer
            "tokenizer_dim",
            "latent_patch_dim",
            "num_patch_latents",
            "patch_size",
            "tokenizer_num_blocks",
            "tokenizer_num_heads",
            # LAM
            "lam_dim",
            "latent_action_dim",
            "num_latent_actions",
            "lam_patch_size",
            "lam_num_blocks",
            "lam_num_heads",
            # Dynamics
            "dyna_dim",
            "dyna_num_blocks",
            "dyna_num_heads",
            "dropout",
            "mask_limit",
        )
    },
)

# Dummy batch handed to `init`: an all-zero float32 video tensor of shape
# (batch_size, seq_len, resolution, resolution, channels) plus a PRNG key
# under "mask_rng". The first split of the root key provides that key.
rng, _rng = jax.random.split(rng)
image_shape = (args.image_resolution,) * 2 + (args.image_channels,)
dummy_inputs = {
    "videos": jnp.zeros(
        (args.batch_size, args.seq_len) + image_shape, dtype=jnp.float32
    ),
    "mask_rng": _rng,
}

# The second split provides the key used for parameter initialisation.
rng, _rng = jax.random.split(rng)
params = genie.init(_rng, dummy_inputs)

# Merge the restored checkpoint's parameter entries over the freshly
# initialised ones (top-level keys only, via `dict.update`).
from orbax.checkpoint import PyTreeCheckpointer
params["params"].update(
    PyTreeCheckpointer().restore(args.checkpoint)["model"]["params"]["params"]
)


2024-08-18 09:56:37.847935: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
# Build the data loader from the dataset path, sequence length and batch size in `args`.
dataloader = get_dataloader(args.file_path, args.seq_len, args.batch_size)

In [3]:
# Pull only the first item from the loader. `next(..., None)` plus an explicit
# check fails loudly when the loader yields nothing, instead of silently
# leaving `batch` unset (or stale from an earlier run) as a for/break loop does.
first_sequences = next(iter(dataloader), None)
if first_sequences is None:
    raise ValueError("dataloader yielded no batches")
# Cast to float32 and divide by 255.0 (assumes 8-bit pixel values — TODO confirm).
batch = jnp.array(first_sequences, dtype=jnp.float32) / 255.0

  return torch.from_numpy(sequence).clone()


In [4]:
# First element along the two leading axes of `batch`
# (presumably the first frame of the first sequence — TODO confirm layout).
init_frame = batch[0, 0]
# Keep the shape as the cell's final expression so the notebook displays it;
# as a non-final bare expression it was evaluated and then discarded.
batch.shape

In [5]:
def imshow(img):
    """Display an image inline in the notebook as a JPEG.

    Args:
        img: Array accepted by ``cv2.imencode``.
            NOTE(review): OpenCV treats 3-channel input as BGR — confirm
            callers are not passing RGB frames.

    Raises:
        ValueError: If OpenCV reports that JPEG encoding failed.
    """
    # Kept function-local: neither package is imported at notebook level.
    import cv2
    from IPython.display import Image, display

    ok, encoded = cv2.imencode('.jpg', img)
    if not ok:
        # Previously the success flag was discarded and the buffer displayed regardless.
        raise ValueError("cv2.imencode failed to encode the image as JPEG")
    # Hand IPython raw bytes and state the format explicitly so the output is
    # labelled as JPEG rather than relying on a default.
    display(Image(data=encoded.tobytes(), format="jpeg"))

# for i in range(generated_frames.shape[1]):
#     imshow(np.asarray(generated_frames[0, i]*255.0))
#     imshow(np.asarray(videos[0, i]*255.0))


In [17]:
# def _sample_internal(self, 
#                         batch: Dict[str, Any], 
#                         rng: jax.random.PRNGKey,
#                         generation_steps: int = 25,
#                         temp: float = 1.0):

#     act_embed = self.action_up(batch["latent_actions"])
#     act_embed = jnp.pad(act_embed, ((0, 0), (1, 0), (0, 0), (0, 0)))

#     B, T, N = batch["video_tokens"].shape
#     vid_act_embed, gen_act_embed = act_embed[:, :T], act_embed[:, T:]

#     vid_embed = self.patch_embed(batch["video_tokens"]) + vid_act_embed

#     def gen_step(state, step):
#         gen, mask, rng = state

#         gen_embed = self.patch_embed(gen) + gen_act_embed
#         gen_embed = jnp.where(jnp.expand_dims(mask, -1), 0, gen_embed)
#         gen_embed = jnp.concatenate([vid_embed, gen_embed], axis=1)
#         logits = self.dynamics(gen_embed)[:, T:]

#         n_mask_toks = cosine_schedule(
#             step, logits, N,
#             generation_steps=generation_steps,
#         )

#         rng, rng_gen = jax.random.split(rng)
#         next_gen = jax.random.categorical(rng_gen, logits / temp)

#         p_tokens = jax.nn.softmax(logits)
#         p_tokens = jnp.take_along_axis(p_tokens, next_gen[..., None], axis=-1).squeeze(-1) + mask

#         def get_threshold(x, idx):
#             return jax.lax.dynamic_slice(x, (idx,), (1,))[0]

#         limit_indices = N - n_mask_toks
#         p_tokens_sorted = jnp.sort(p_tokens, axis=-1)
#         limit = jax.vmap(jax.vmap(get_threshold))(p_tokens_sorted, limit_indices)[..., None]
#         next_mask = (p_tokens >= limit) & ~mask

#         gen = jnp.where(next_mask, next_gen, gen)
#         mask = mask | next_mask

#         return (gen, mask, rng)

#     mask = jnp.zeros((B, 1, N), dtype=jnp.bool)        
#     generated_frame = jnp.zeros((B, 1 , N), dtype=jnp.int32)
#     rng, rng_run = jax.random.split(rng)
#     state = (generated_frame, mask, rng_run)

#     for i in range(0, generation_steps):
#         state = gen_step(state, i)

#     generated_frame = state[0]
#     vid_gen = jnp.concatenate([batch["video_tokens"], generated_frame], axis=1)

#     return dict(vid_gen=vid_gen)

# def sample(
#         self,
#         params,
#         batch,
#         rng,
#         generation_steps: int = 25,
#         temp: float = 1.0,
# ):
#     return self.apply(params, batch, rng, generation_steps, temp, method=self._sample_internal)

In [None]:
# Scratch cell drafting the sampling procedure.
# NOTE(review): as written this does not look runnable — see the notes below.
temp = 1.0
generation_steps = 25

# Tokenize initial frame
# NOTE(review): `batch` is bound to an array in the loading cell
# (`jnp.array(...) / 255.0`), yet it is indexed with string keys here and
# below — confirm the intended input structure.
# NOTE(review): `tokenizer`, `lam` and `dynamics` are created in `Genie.setup`;
# Flax normally exposes setup-defined submodules only on a bound module or
# inside `apply` — verify before accessing them on `genie` directly.
tokenizer_outputs = genie.tokenizer.vq_encode(batch["videos"], training=False)
# NOTE(review): subscripts `get_codebook` rather than calling it — confirm the LAM API.
lam_codes = genie.lam.get_codebook[(batch['latent_actions'],)]

# --- MASKGIT ---

# Construct mask frame
# NOTE(review): the return value of `genie.sample` is discarded, and it is handed
# the tokenizer outputs although `Genie.sample` reads `batch["videos"]`.
genie.sample(tokenizer_outputs)
# NOTE(review): `_sample_internal` appears in this notebook only as commented-out
# code — confirm it exists on the dynamics module.
dyna_outputs = genie.dynamics._sample_internal(
    batch=dict(video_tokens=tokenizer_outputs["indices"],
                latent_actions=lam_codes[:, :, None, :]),
    rng=rng,
    generation_steps=generation_steps,
    temp=temp,
)

# Decode the generated token grid at the size given by axes 2 and 3 of the input videos.
vid_gen = genie.tokenizer.decode(
    dyna_outputs["vid_gen"],
    video_hw=batch['videos'].shape[2:4],
)

