# Qformer Library

In [1]:
import torch
from qformer import QFormer

In [None]:
x = torch.randn((1, 32, 512))

img = torch.randn((1,3,224,224))

# Create an instance of the QFormer model with the following parameters:
# - input_size: 512
# - num_heads: 8
# - num_layers: 8
# - dropout: 0.1
# - num_classes: 2
# - num_patches: 2

q_former = QFormer(512, 8, 8, 0.1, 2, 2)
y = q_former(x, img)
print(y.shape)

torch.Size([1, 32, 512])


# Surprise Mechanism

In [17]:
import torch
import torch.nn as nn


class SurpriseMechanism(nn.Module):
    
    def __init__(self, input_dim, num_slots, slot_dim, learning_rate = 0.6):
        super().__init__()
        
        self.slots = torch.randn((1, num_slots, slot_dim))
        
        self.fc1 = nn.Linear(input_dim, slot_dim)
        
        self.lr = learning_rate
    
    def forward(self, x):
        
        proj = self.fc1(x)
    
        
        surprise = nn.functional.mse_loss(self.slots, proj.expand_as(self.slots), reduction='none').sum(dim=-1)
        
        min_index = torch.argmin(surprise, dim=-1)
        print(min_index)
        
        self.slots.data[0][min_index] = (1-self.lr)*self.slots.data[0][min_index] + self.lr*proj
        
        return surprise        
        


model = SurpriseMechanism(4,2,8)

x_t = torch.randn((1, 4))

out = model(x_t)
print(out)

tensor([1])
tensor([[11.3411,  6.9629]], grad_fn=<SumBackward1>)


In [15]:
model.slots.data.shape

torch.Size([1, 2, 8])

In [20]:
slots = torch.randn((1, 2, 4))
slots

tensor([[[ 0.2288,  0.0145,  0.2790,  1.2779],
         [ 1.9375, -0.7367, -2.1079, -0.0094]]])

In [19]:
x_t.expand_as(slots).shape

torch.Size([1, 2, 4])

In [23]:
x_t

tensor([[-0.1775,  1.4299,  0.2057, -0.6213]])

In [22]:
x_t.expand_as(slots)

tensor([[[-0.1775,  1.4299,  0.2057, -0.6213],
         [-0.1775,  1.4299,  0.2057, -0.6213]]])

# HRM Model

In [None]:
import torch
import torch.nn as nn

class HRM(nn.Module):
    
    def __init__(self, hidden_size, vocab_size, context_length, output_size, h_cycle = 4, l_cycle = 8, device='cpu'):
        super().__init__()
        
        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        
        self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(context_length, hidden_size)
        self.low = nn.GRUCell(input_size=hidden_size*4, hidden_size=hidden_size*4, device=device,)
        self.high = nn.GRUCell(input_size=hidden_size*4, hidden_size=hidden_size*4, device=device)
        
        # self.low = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device,)
        # self.high = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device)
        
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size*4, hidden_size*2),
            nn.ReLU(),
            nn.Linear(hidden_size*2, output_size)
        )
        
    def forward(self, tokens):
        
        token_embs = self.token_embed(tokens)
        pos_embs = self.pos_embed(torch.arange(0, tokens.shape[-1]).to(tokens.device))
        
        embs = token_embs+pos_embs
        embs = embs.view(tokens.shape[0], -1)
        # print(embs.shape)
        # hx, cx = torch.zeros((tokens.shape[0],pos_embs.shape[-1])), torch.zeros((tokens.shape[0],pos_embs.shape[-1]))
        z_l = torch.zeros((tokens.shape[0],embs.shape[-1]))
        # print(z_l.shape)
        for i in range(self.h_cycle*self.l_cycle):
            z_l = self.low(embs, z_l)
            if i%self.h_cycle == 0: 
                # print(f"at {i}")
                z_h = self.high(embs, z_l)
                z_l = z_h
        # print('here')
        out = self.mlp(z_h)
        return out
    

model = HRM(32, 8, 4, 8)

x = torch.randint(0,8, (1,4))
out = model(x)
print(out.shape)


at 0
at 4
at 8
at 12
at 16
at 20
at 24
at 28
torch.Size([1, 8])


In [None]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    
    def __init__(self, in_channels=1, embed_dim=512, patch_size=16):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size)

    def forward(self, x):
        out = self.proj(x)
        out = out.flatten(2)
        return out

class HRMVision(nn.Module):
    
    def __init__(self, output_size,in_channels=4, sequence_length = 16, patch_size=64, embed_dim=16, h_cycle = 4, l_cycle = 8, device='cpu'):
        super().__init__()
        
        
        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        self.context_length =  16
        self.patchify = PatchEmbedding(in_channels, sequence_length, patch_size)
        
        # self.token_embed = nn.Embedding(vocab_size, hidden_size)
        self.pos_embed = nn.Embedding(self.context_length, embed_dim)
        self.low = nn.GRUCell(input_size=embed_dim*embed_dim, hidden_size=embed_dim*embed_dim, device=device,)
        self.high = nn.GRUCell(input_size=embed_dim*embed_dim, hidden_size=embed_dim*embed_dim, device=device)
        
        # self.low = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device,)
        # self.high = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size, device=device)
        
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim*embed_dim, embed_dim*embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim*embed_dim, output_size)
        )
        
    def forward(self, image):
        
        token_embs = self.patchify(image)
        
        pos_embs = self.pos_embed(torch.arange(0, self.context_length).to(image.device))
        embs = token_embs+pos_embs
        embs = embs.view(image.shape[0], -1)
        # hx, cx = torch.zeros((tokens.shape[0],pos_embs.shape[-1])), torch.zeros((tokens.shape[0],pos_embs.shape[-1]))
        z_l = torch.zeros((image.shape[0],embs.shape[-1]))
        # print(z_l.shape)
        for i in range(self.h_cycle*self.l_cycle):
            z_l = self.low(embs, z_l)
            if i%self.h_cycle == 0: 
                # print(f"at {i}")
                z_h = self.high(embs, z_l)
                z_l = z_h
        # print('here')
        out = self.mlp(z_h)
        return out

model = HRMVision(output_size=10, in_channels=1)

x = torch.randn((1,1, 256, 256))
out = model(x)
print(out.shape)


torch.Size([1, 10])


In [10]:
model.named_modules

<bound method Module.named_modules of HRMVision(
  (patchify): PatchEmbedding(
    (proj): Conv2d(1, 16, kernel_size=(64, 64), stride=(64, 64))
  )
  (pos_embed): Embedding(16, 16)
  (low): GRUCell(256, 256)
  (high): GRUCell(256, 256)
  (mlp): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=10, bias=True)
  )
)>