In [1]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Project non-overlapping square image patches with a strided convolution.

    Args:
        in_channels: Number of channels in the input image.
        embed_dim: Number of output channels produced by the projection.
        patch_size: Used as both the kernel size and the stride of the
            convolution, so neighbouring patches never overlap.
    """

    def __init__(self, in_channels=1, embed_dim=512, patch_size=16):
        super().__init__()
        # kernel_size == stride: each output location sees exactly one patch.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the projected patches with both spatial axes merged.

        The convolution output keeps its channel-first layout; every
        dimension from index 2 onwards is collapsed into a single axis, so a
        4-D input yields a 3-D ``(batch, embed_dim, n_patches)`` tensor.
        """
        patch_grid = self.proj(x)
        return patch_grid.flatten(start_dim=2)

class HRMVision(nn.Module):
    """Image classifier built from a patch embedding and two GRU cells.

    The image is converted to patch features by ``PatchEmbedding``, a learned
    table is added to them, and the sum is flattened to one vector per
    sample. That vector is fed repeatedly through a "low" GRU cell and, at a
    slower cadence, a "high" GRU cell; an MLP maps the last high-level state
    to ``output_size`` values.

    Shape constraint: ``patchify`` yields ``(batch, sequence_length,
    n_patches)`` and the learned table is ``(context_length, embed_dim)`` with
    ``context_length`` fixed at 16. Their sum must broadcast, and once
    flattened it must have exactly ``embed_dim * embed_dim`` features, because
    that is the input size of both GRU cells. The defaults satisfy this for a
    256x256 image (``patch_size=64`` gives 4 * 4 = 16 patches).

    Args:
        output_size: Number of values produced per sample.
        in_channels: Number of channels in the input image.
        sequence_length: Number of output channels of the patch projection.
        patch_size: Side length (and stride) of each square patch.
        embed_dim: Width of the learned table; the recurrent state has
            ``embed_dim * embed_dim`` features.
        h_cycle: The high cell is applied on every ``h_cycle``-th step.
        l_cycle: Together with ``h_cycle`` sets the total number of low-cell
            steps, ``h_cycle * l_cycle``.
        device: Device the whole module is placed on after construction.

    Raises:
        ValueError: If ``h_cycle * l_cycle`` is not positive (the recurrence
            would never run and there would be no state to classify).
    """

    def __init__(self, output_size, in_channels=4, sequence_length=16, patch_size=64,
                 embed_dim=16, h_cycle=4, l_cycle=8, device='cpu'):
        super().__init__()

        if h_cycle * l_cycle <= 0:
            raise ValueError(
                f"h_cycle * l_cycle must be positive, got h_cycle={h_cycle}, l_cycle={l_cycle}"
            )

        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        self.context_length = 16
        self.patchify = PatchEmbedding(in_channels, sequence_length, patch_size)

        self.pos_embed = nn.Embedding(self.context_length, embed_dim)

        state_size = embed_dim * embed_dim
        self.low = nn.GRUCell(input_size=state_size, hidden_size=state_size)
        self.high = nn.GRUCell(input_size=state_size, hidden_size=state_size)

        self.mlp = nn.Sequential(
            nn.Linear(state_size, state_size),
            nn.ReLU(),
            nn.Linear(state_size, output_size)
        )

        # Move every sub-module, not only the GRU cells, so that a non-CPU
        # ``device`` does not leave the model split across devices.
        self.to(device)

    def forward(self, image):
        """Classify a batch of images.

        Args:
            image: Image batch; axis 0 is the batch axis and the tensor must
                be accepted by the patch projection (a 2-D convolution).

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        token_embs = self.patchify(image)

        positions = torch.arange(self.context_length, device=image.device)
        pos_embs = self.pos_embed(positions)
        embs = (token_embs + pos_embs).view(image.shape[0], -1)

        # Create the initial state on the same device and dtype as the
        # features; a bare torch.zeros(...) is always CPU/float32.
        z_l = torch.zeros_like(embs)

        # NOTE(review): the high cell fires when i % h_cycle == 0, i.e. on
        # step 0 and then every h_cycle steps (l_cycle times in total).
        # Confirm this cadence is intended rather than "every l_cycle steps".
        for i in range(self.h_cycle * self.l_cycle):
            z_l = self.low(embs, z_l)
            if i % self.h_cycle == 0:
                # The high cell takes the low state as its hidden state and
                # its output replaces the low state.
                z_h = self.high(embs, z_l)
                z_l = z_h

        return self.mlp(z_h)

# Smoke test: push one random single-channel 256x256 image through the model
# and print the shape of the result.
model = HRMVision(10, in_channels=1)

x = torch.randn(1, 1, 256, 256)
out = model(x)
print(out.size())


torch.Size([1, 10])


In [1]:
from transformers import PreTrainedModel, PretrainedConfig
import torch
import torch.nn as nn

# 1. Create a config class
class HRMConfig(PretrainedConfig):
    """Configuration holding the hyper-parameters read by ``HRMForClassification``.

    Args:
        in_channels: Number of channels in the input image.
        embed_dim: Width of the learned table; the model's recurrent state
            has ``embed_dim * embed_dim`` features.
        sequence_length: Number of output channels of the patch projection.
        output_size: Number of values the classifier head produces.
        h_cycle: The high-level cell is applied on every ``h_cycle``-th step.
        l_cycle: Together with ``h_cycle`` sets the total number of steps.
        patch_size: Side length (and stride) of each square patch.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "hrm"

    def __init__(self,
                 in_channels=1,
                 embed_dim=16,
                 sequence_length=16,
                 output_size=10,
                 h_cycle=4,
                 l_cycle=8,
                 patch_size=16,
                 **kwargs):
        super().__init__(**kwargs)
        # The model reads ``config.in_channels``; it was previously stored
        # only under ``hidden_size``, so that lookup raised AttributeError.
        self.in_channels = in_channels
        # NOTE(review): kept for backward compatibility. It mirrors
        # ``in_channels`` and is not the width of the recurrent state;
        # confirm nothing relies on it before renaming or dropping it.
        self.hidden_size = in_channels
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.sequence_length = sequence_length
        self.output_size = output_size
        self.h_cycle = h_cycle
        self.l_cycle = l_cycle
        # Stored as a plain string: a ``torch.device`` object cannot be
        # written to JSON, which breaks saving the config. The value is
        # re-detected on every construction, so a config saved on a GPU
        # machine still loads on a CPU-only one.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'


# 2. Wrap your model inside a PreTrainedModel
class HRMForClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper around the two-level GRU image classifier.

    Images are converted to patch features by ``PatchEmbedding``, a learned
    table is added, and the flattened result is fed repeatedly through a
    "low" and a "high" GRU cell. An MLP maps the last high-level state to
    ``config.output_size`` logits.

    Shape constraint: the patch features are ``(batch, sequence_length,
    n_patches)`` and the learned table is ``(16, embed_dim)``. Their sum must
    broadcast and, once flattened, have ``embed_dim * embed_dim`` features,
    which is the input size of both GRU cells.

    All sub-modules are created on the default device; move the model with
    ``.to(...)`` as with any other ``nn.Module``.

    Raises:
        ValueError: If ``config.h_cycle * config.l_cycle`` is not positive.
    """

    config_class = HRMConfig

    def __init__(self, config):
        super().__init__(config)

        if config.h_cycle * config.l_cycle <= 0:
            raise ValueError(
                "h_cycle * l_cycle must be positive, got "
                f"h_cycle={config.h_cycle}, l_cycle={config.l_cycle}"
            )

        self.h_cycle = config.h_cycle
        self.l_cycle = config.l_cycle
        self.context_length = 16

        # Configs that only carry the channel count under ``hidden_size``
        # are still accepted.
        in_channels = getattr(config, "in_channels", None)
        if in_channels is None:
            in_channels = config.hidden_size
        self.patchify = PatchEmbedding(in_channels, config.sequence_length, config.patch_size)

        self.pos_embed = nn.Embedding(self.context_length, config.embed_dim)

        # No ``device=`` here: building only the GRU cells on another device
        # would leave the module split across devices.
        state_size = config.embed_dim * config.embed_dim
        self.low = nn.GRUCell(input_size=state_size, hidden_size=state_size)
        self.high = nn.GRUCell(input_size=state_size, hidden_size=state_size)

        self.mlp = nn.Sequential(
            nn.Linear(state_size, state_size),
            nn.ReLU(),
            nn.Linear(state_size, config.output_size)
        )

        # Initialize weights the Transformers way
        self.post_init()

    def forward(self, input_ids=None, labels=None, pixel_values=None, **kwargs):
        """Run the classifier and optionally compute a cross-entropy loss.

        Args:
            input_ids: Image batch. The name is kept so existing callers keep
                working; it is used only when ``pixel_values`` is not given.
            labels: Optional targets; flattened to 1-D and passed to
                ``nn.CrossEntropyLoss`` together with the logits.
            pixel_values: Image batch; takes precedence over ``input_ids``.
            **kwargs: Ignored.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` of shape ``(batch, output_size)``.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``input_ids`` is given.
        """
        image = pixel_values if pixel_values is not None else input_ids
        if image is None:
            raise ValueError("Provide the image batch as `pixel_values` (or `input_ids`).")

        token_embs = self.patchify(image)
        # Index the whole learned table; the image's last axis is its width,
        # not a position count.
        pos_embs = self.pos_embed(
            torch.arange(self.context_length, device=image.device)
        )
        embs = (token_embs + pos_embs).view(image.shape[0], -1)

        # Same device and dtype as the features.
        z_l = torch.zeros_like(embs)

        # NOTE(review): the high cell fires when i % h_cycle == 0, i.e. on
        # step 0 and then every h_cycle steps (l_cycle times in total).
        # Confirm this cadence is intended rather than "every l_cycle steps".
        for i in range(self.h_cycle * self.l_cycle):
            z_l = self.low(embs, z_l)
            if i % self.h_cycle == 0:
                z_h = self.high(embs, z_l)
                z_l = z_h
        logits = self.mlp(z_h)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.output_size), labels.view(-1))

        return {"loss": loss, "logits": logits}


  from .autonotebook import tqdm as notebook_tqdm
