# Cross Attention Module

In [9]:
import torch
from torch import nn, tanh, einsum
from torchtyping import TensorType
from einops import rearrange, repeat
from einops_exts import rearrange_many


# class CrossAttentionTransformerBlock(nn.Module):
#     def __init__(
#         self,
#         lm_block: nn.Module,
#         config: dict,
#         token_dim: int = 4096,
#         **kwargs
#     ):
#         super().__init__()
#         self.lm_block = lm_block
#         self.cross_x_block = GatedCrossAttentionBlock(
#             config, token_dim, **kwargs)
#         self.media_locations = None
#         self.visual_features = None

#     def forward(self, embs, **kwargs):
#         logits = self.cross_x_block(
#             embs, self.media_locations, self.visual_features)
#         out = self.lm_block(logits, use_cache=False, **kwargs)
#         return out


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x: TensorType["Batch", "Sequence", "Dim"]):
        return self.net(x)


class MaskedCrossAttention(nn.Module):
    def __init__(
        self,
        text_token_dim: int = 4096,
        visual_token_dim: int = 2048,
        num_heads: int = 8,
        n_latents: int = 64, 
        head_dim: int = 64, 
    ):
        super().__init__()
        self.num_heads = num_heads
        self.temp = 1 / (head_dim ** -0.5)
        self.softmax = nn.Softmax(dim=-1)
        self.device = device

        self.v_k_w = nn.Linear(
            visual_token_dim, head_dim * num_heads * 2, bias=False)
        self.q_w = nn.Linear(
            text_token_dim, head_dim * num_heads, bias=False)
        self.out = nn.Linear(head_dim *num_heads, text_token_dim, bias=False)

    def forward(self,
                latent: TensorType["Batch", "Sequence", "TokenDim"],
                y: TensorType["Batch", "Sequence Length", "TokenDim"],
                media_mask: TensorType["Batch", "Sequence Length"]
                ):
        visual_features = rearrange(latent, 'b t n d -> b (t n) d')
        print("vs_shape", visual_features.shape)
        k, v = self.v_k_w(visual_features).chunk(2, dim=-1)
        q = self.q_w(y)
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h=8)
        sim = einsum('... i d, ... j d -> ... i j', q, k)
        media_time = torch.arange(3).to(self.device) + 1
        text_to_media_mask = rearrange(media_mask, 'b i -> b 1 i 1') == repeat(media_time, 'j -> 1 1 1 (j m)', m=64)
        sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
        logits = sim.softmax(dim=-1)
        text_without_media_mask = media_mask == 0
        text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
        logits = logits.masked_fill(text_without_media_mask, 0.)
        logits = einsum('... i j, ... j d -> ... i d', logits, v)
        logits = rearrange(logits, 'b h n d -> b n (h d)')
        y = self.out(logits)
        return y


class GatedCrossAttentionBlock(nn.Module):
    def __init__(self, config, text_token_dim):
        super().__init__()
        self.x_attn = MaskedCrossAttention(text_token_dim = text_token_dim, n_latents = config['n_latents'])
        self.tanh1 = nn.Parameter(torch.tensor([0.]))
        self.ffw = FeedForward(dim=text_token_dim)
        self.tanh2 = nn.Parameter(torch.tensor([0.]))

    def perceiver_pipe(self, visual_features, media_mask):
        self.media_mask = media_mask
        self.visual_features = visual_features

    def forward(self, embs: TensorType["Batch", "Sequence Length", "TokenDim"]):
        x_attn = self.x_attn(self.visual_features, embs, self.media_mask)
        attn_out = embs + tanh(self.tanh1) * x_attn
        x_ffw = attn_out + self.ffw(x_attn) * tanh(self.tanh2)

        return x_ffw


PERCEIVER 

In [2]:
import torch
from torch import nn
from torchtyping import TensorType
from einops import rearrange
from einops_exts import rearrange_many



class PerceiverAttentionBlock(nn.Module):
    def __init__(
        self,
        token_dim,
        output_dim,
        num_heads=8,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.temp = 1 / (output_dim ** -0.5)
        self.softmax = nn.Softmax(dim=-1)

        self.v_k_w = torch.nn.Linear(
            token_dim, output_dim * num_heads * 2, bias=False)
        self.q_w = torch.nn.Linear(
            token_dim, output_dim * num_heads, bias=False)
        self.out = nn.Linear(output_dim*num_heads, token_dim, bias=False)

    def forward(self, q: TensorType["Batch", "Number of Images", "OutputDim", "TokenDim"], k_v: TensorType["Batch", "Number of Images", "Sequence", "TokenDim"]):

        # (batch, n_images, sequence_length + output_dim, embedding_dim)
        k_v = torch.cat((k_v, q), dim=-2)
        # (batch, n_images, sequence_length + output_dim, embedding_dim)
        k, v = self.v_k_w(k_v).tensor_split(2, dim=-1)
        q = self.q_w(q)  # (batch, n_images, output_dim, embedding_dim)

        q, k, v = rearrange_many(
            (q, k, v), 'b n s (h d) -> b n h s d', h=self.num_heads)

        q = q * self.temp

        # k = Image Input Sequence Dimension / q = Output Dimension
        sm = torch.einsum("b n h q d,b n h k d->b n h q k", q, k)
        attn = sm.softmax(dim=-1)

        ff_input_per_head = torch.einsum(
            "b n h q k,b n h k d -> b n h q d", attn, v)
        ff_input = rearrange(
            ff_input_per_head, "b n h s d -> b n s (h d)", h=self.num_heads)
        return self.out(ff_input)


class PerceiverResampler(nn.Module):
    def __init__(
        self,
        token_dim,
        num_layers:  int = 2,
        latent_size: int = 64,
        time: int = 1,
    ):
        super().__init__()
        self.learned_latents = nn.Parameter(
            torch.randn(latent_size, token_dim))
        self.time_embeddings = nn.Parameter(torch.rand(time, 1, token_dim))
        self.flatten = torch.nn.Flatten()
        self.perceiver_attention_layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.perceiver_attention_layers.append(nn.ModuleList([
                PerceiverAttentionBlock(
                    token_dim=token_dim,  output_dim=latent_size),
                PerceiverFeedForwardLayer(token_dim=token_dim)
            ]))
        self.normalize = nn.LayerNorm(token_dim)

    def forward(self, x: TensorType["Batch", "Number of Images", "Time", "Sequence", "Token Dimesion"]):
        if x.ndim == 3:
            x = x[:, None, None, :, :]

        batch_size, number_of_images, time, sequence_length, token_dim = x.size()
        print(x.shape)
        x = rearrange(x, "b n t s d -> b n (t s) d")
        latents = self.learned_latents.repeat(
            batch_size, number_of_images, 1, 1)
        x = x + self.time_embeddings[:number_of_images]

        for att_module, ff_layer in self.perceiver_attention_layers:
            latents = att_module(latents, x)
            latents = ff_layer(latents)

        return self.normalize(latents)


class PerceiverFeedForwardLayer(nn.Module):
    def __init__(
        self,
        token_dim: int = 3027,
        mult: int = 4,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(token_dim)
        self.inner = nn.Linear(token_dim, mult*token_dim, bias=False)
        self.act = nn.GELU()
        self.outer = nn.Linear(mult*token_dim, token_dim, bias=False)

    def forward(self, x):
        return self.outer(self.act(self.inner(self.norm(x))))


In [25]:
from magma.config import MultimodalConfig
from magma.utils import get_world_info
from magma.utils import get_tokenizer
from magma.datasets.dataset import ImgCptDataset
from magma.image_encoders import clip_encoder
from magma.transforms import get_transforms
from magma.language_model import get_gptj
import copy
from torch import nn
torch.set_default_dtype(torch.float16)

In [7]:
t

4096

In [26]:
encoder = clip_encoder(name="RN50")

lm = get_gptj(from_pretrained=True)

transforms = get_transforms(
            384, 
            'clip_resnet_large',
            input_resolution=encoder.input_resolution,
        )

perceiver = PerceiverResampler(2048)
tokenizer = get_tokenizer(
            "gpt2", sequence_length=lm.config.max_position_embeddings)
cross_attention_layers=[]
transformer = lm.transformer.h
config = MultimodalConfig.from_yml('/home/ml-mmeuer/adaptable_magma/fb20-dgx2-configs/Flamingo.yml')
for l in range(len(lm.transformer.h)):
    layer_norm = getattr(lm.transformer.h[l], "ln_1")
    x_attn_block = GatedCrossAttentionBlock(
        config=config.cross_attention_config,
        text_token_dim=lm.config.hidden_size,
    )
    cross_attention_layers.append(l)
    setattr(lm.transformer.h[l], 'ln_1', nn.Sequential(
        *[x_attn_block, layer_norm]))
tokenizer.add_special_tokens({"additional_special_tokens":["<|image|>"]})
img_token = tokenizer.encode("<|image|>")[0]
lm.resize_token_embeddings(len(tokenizer))

dataset = ImgCptDataset('/home/ml-mmeuer/adaptable_magma/magma/datasets/coco_train_val', tokenizer=tokenizer, transforms=transforms, few_shot=3)

Fetching GPTJ language model...
Done Fetching Model ... 


In [20]:
class Magma(nn.Module): 
    def __init__(self, lm, tokenizer, transforms, encoder, img_token, perceiver, cross_attention_layers):
        super().__init__()
        self.lm = lm
        self.tokenizer = tokenizer
        self.img_token = img_token
        self.transforms = transforms
        self.perceiver = perceiver
        self.encoder = encoder
        self.cross_attention_layers = cross_attention_layers

        
    def forward(self, images, captions): 
        input_embeddings = self.lm.transformer.wte(captions).half()
        images = images[:,:,None,:,:,:]
        media_pos = captions == img_token
        media_mask = media_pos.cumsum(dim=-1)
        flat_img = rearrange(images, "b n t c h w -> ( b n t ) c h w ")
        featurized_images = encoder(flat_img)
        visual_embeddings = rearrange(featurized_images, "(b n t) s d -> b n t s d", b = images.shape[0], n = images.shape[1], t=images.shape[2])
        # add cross attention
        print("Input_embeddings",visual_embeddings.shape)
        visual_features = self.perceiver(
            visual_embeddings
        )
        print("vf.dim", visual_features.dtype)
        for l in self.cross_attention_layers:
            x_attn_block = getattr(
                lm.transformer.h[l], 'ln_1')[0]

            x_attn_block.perceiver_pipe(
                visual_features, media_mask=media_mask)
            
        lm_outputs = self.lm(
            inputs_embeds=input_embeddings,
            output_hidden_states=False,
        )
        

In [21]:

magma = Magma(tokenizer=tokenizer,transforms=transforms, lm=lm, encoder=encoder,img_token=img_token, perceiver=perceiver, cross_attention_layers=cross_attention_layers)
magma = magma.to("cuda:3")
images, captions = dataset[1]
images, captions = images.to('cuda:3').half(), captions.to('cuda:3', dtype=torch.int64)
images = images[None, : , : , : , : ]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


In [22]:
captions.shape

torch.Size([1, 2048])

In [23]:
magma(images, captions)

Input_embeddings torch.Size([1, 3, 1, 49, 2048])
torch.Size([1, 3, 1, 49, 2048])
vf.dim torch.float32
vs_shape torch.Size([1, 192, 2048])


RuntimeError: expected scalar type Half but found Float

In [None]:
v_dim = 2048
t_dim = 4096
n_latents = 64 
num_heads = 8 
v_k_w = nn.Linear(
            v_dim, n_latents * num_heads * 2, bias=False)#.to('cuda:3',dtype=torch.float16)
q_w = nn.Linear(
    t_dim, n_latents * num_heads, bias=False)#.to('cuda:3',dtype=torch.float16)
out = nn.Linear(n_latents*num_heads, t_dim, bias=False)#.to('cuda:3',dtype=torch.float16)

In [None]:
y = lm.transformer.wte(captions).float()
media_pos = captions == img_token
media_mask = media_pos.cumsum(dim=-1)#.to('cuda:3')
latent = torch.randn(3,3,64,2048)#.to('cuda:3', dtype=torch.float16)

torch.float32

In [None]:
y.shape

torch.Size([1, 2048, 4096])

In [None]:
visual_features = rearrange(latent, 'b n s d -> b (n s) d')
print("vs_shape", visual_features.shape)
k, v = v_k_w(visual_features).chunk(2, dim=-1)
q = q_w(y)
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h=8)
sim = einsum('... i d, ... j d -> ... i j', q, k)
media_time = torch.arange(3) + 1 #.to('cuda:3') + 1
text_to_media_mask = rearrange(media_mask, 'b i -> b 1 i 1') == repeat(media_time, 'j -> 1 1 1 (j m)', m=64)
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
logits = sim.softmax(dim=-1)
text_without_media_mask = media_mask == 0
text_without_media_mask = rearrange(text_without_media_mask, 'b i -> b 1 i 1')
logits = logits.masked_fill(text_without_media_mask, 0.)
logits = einsum('... i j, ... j d -> ... i d', logits, v)
logits = rearrange(logits, 'b h n d -> b n (h d)')
y = out(logits)

vs_shape torch.Size([3, 192, 2048])


In [None]:
y.shape

torch.Size([3, 2048, 4096])

In [None]:
magma(images,captions)

Input_embeddings torch.Size([1, 3, 1, 49, 2048])
torch.Size([1, 3, 1, 49, 2048])
vf.dim torch.Size([1, 3, 64, 2048])
vs_shape torch.Size([1, 192, 2048])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (192x2048 and 64x1024)

In [None]:
featured_images = encoder(flat_img)

In [None]:
featured_images.shape

torch.Size([3, 49, 2048])

In [None]:
img_shape = images.shape
flat_img = images.view(-1, img_shape[-3], img_shape[-2], img_shape[-1])
featured_images = encoder(flat_img)
featured_images.view(img_shape[0] , img_shape[1], img_shape[3], -1).shape

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
(img_shape[0] , img_shape[1], img_shape[3], -1)

(1, 3, 224, -1)