In [4]:
import torch.nn as nn

import torch

import math

from mamba_ssm import Mamba

In [73]:
from muse_maskgit_pytorch import LayerNorm

In [30]:
from mamba_ssm.ops.triton.layernorm import layer_norm_fn
from einops import repeat

In [2]:
import sys
sys.path.append("..")

from masked_model import MaskedModel, SequenceModelWrapper

  from .autonotebook import tqdm as notebook_tqdm


In [91]:
class MambaIT(nn.Module):
  def __init__(
      self,
      token_size,
      depth,
      d_state = 16,
      d_conv = 4,
      expand = 2,
  ):
    super().__init__()
    self.token_size = token_size
    self.mamba_layers = nn.ModuleList([Mamba(d_model = token_size, d_state = d_state, d_conv = d_conv, expand = expand) for _ in range(depth)])
    self.norm = LayerNorm(token_size)

  def forward(
      self,
      x,
      context,
      context_mask
  ):

    seq_len = x.shape[1]

    context_mask = repeat(context_mask, 'b t -> b t s', s = 512)
    context = torch.where(context_mask, context, torch.zeros_like(context))

    x = torch.cat((context, x), dim = -2)

    for mamba_layer in self.mamba_layers:
      x = mamba_layer(x)
      x = self.norm(x)

    return x[:, :seq_len, :]

In [97]:
def cosine_schedule(t):
    return torch.cos(t * math.pi * 0.5)

In [98]:
model = MaskedModel(
  SequenceModelWrapper(
    MambaIT(
      token_size = 512,
      depth = 8,
    ).cuda(),
    8192,
    64,
    False
  ).cuda(),
  cosine_schedule,
  0.1,
).cuda()

In [99]:
model(image_ids = torch.rand(1, 10).long().cuda(), text_embeds = torch.rand(1, 5, 768).float().cuda())

tensor(11.9251, device='cuda:0', grad_fn=<AddBackward0>)