In [5]:
from esm.models.esmc import ESMC
client = ESMC.from_pretrained("esmc_300m").to("cpu")
client


ESMC(
  (embed): Embedding(64, 960)
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0-29): 30 x UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=960, out_features=2880, bias=False)
          )
          (out_proj): Linear(in_features=960, out_features=960, bias=False)
          (q_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryEmbedding()
        )
        (ffn): Sequential(
          (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=960, out_features=5120, bias=False)
          (2): SwiGLU()
          (3): Linear(in_features=2560, out_features=960, bias=False)
        )
      )
    )
    (norm): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
  )
  (sequ

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtune.modules import RotaryPositionalEmbeddings

class SwiGLU(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, max_seq_len=4096):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.layernorm_qkv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim * 3, bias=False),
        )
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.q_ln = nn.LayerNorm(embed_dim)
        self.k_ln = nn.LayerNorm(embed_dim)
        self.rotary = RotaryPositionalEmbeddings(dim=self.head_dim, max_seq_len=max_seq_len)

    def forward(self, x, input_pos=None):
        B, T, C = x.shape  # Batch size, sequence length, embed dim
        qkv = self.layernorm_qkv(x).view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Apply rotary embeddings to queries and keys
        q = self.rotary(q, input_pos=input_pos)
        k = self.rotary(k, input_pos=input_pos)

        # Scaled dot-product attention
        attn_weights = torch.einsum("bnqd,bnkd->bnqk", q, k) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.einsum("bnqk,bnvd->bnqd", attn_weights, v)
        attn_output = attn_output.contiguous().view(B, T, C)

        return self.out_proj(attn_output)

class UnifiedTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_hidden_dim, max_seq_len=4096):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads, max_seq_len)
        self.ffn = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, ffn_hidden_dim, bias=False),
            SwiGLU(),
            nn.Linear(ffn_hidden_dim // 2, embed_dim, bias=False),  # SwiGLU halves hidden dim
        )

    def forward(self, x, input_pos=None):
        x = x + self.attn(x, input_pos=input_pos)
        x = x + self.ffn(x)
        return x

class TransformerStack(nn.Module):
    def __init__(self, num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len=4096):
        super().__init__()
        self.blocks = nn.ModuleList(
            [UnifiedTransformerBlock(embed_dim, num_heads, ffn_hidden_dim, max_seq_len) for _ in range(num_blocks)]
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x, input_pos=None):
        for block in self.blocks:
            x = block(x, input_pos=input_pos)
        return self.norm(x)

class Custom_ESMC(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_blocks, num_heads, ffn_hidden_dim, output_dim, max_seq_len=4096):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.transformer = TransformerStack(num_blocks, embed_dim, num_heads, ffn_hidden_dim, max_seq_len)
        self.sequence_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim, bias=True),
            nn.GELU(),
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, output_dim, bias=True),
        )

    def forward(self, x, input_pos=None):
        x = self.embed(x)
        x = self.transformer(x, input_pos=input_pos)
        return self.sequence_head(x)

# Instantiate the model
model = Custom_ESMC(
    vocab_size=64,
    embed_dim=960,
    num_blocks=30,
    num_heads=8,  # Example: Adjust num_heads as needed
    ffn_hidden_dim=5120,
    output_dim=64,
    max_seq_len=4096,  # Ensure this matches the maximum sequence length expected
)

print(model)

Custom_ESMC(
  (embed): Embedding(64, 960)
  (transformer): TransformerStack(
    (blocks): ModuleList(
      (0-29): 30 x UnifiedTransformerBlock(
        (attn): MultiHeadAttention(
          (layernorm_qkv): Sequential(
            (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=960, out_features=2880, bias=False)
          )
          (out_proj): Linear(in_features=960, out_features=960, bias=False)
          (q_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (k_ln): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (rotary): RotaryPositionalEmbeddings()
        )
        (ffn): Sequential(
          (0): LayerNorm((960,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=960, out_features=5120, bias=False)
          (2): SwiGLU()
          (3): Linear(in_features=2560, out_features=960, bias=False)
        )
      )
    )
    (norm): LayerNorm((960,), eps=1e-05, elementwise_affine

In [4]:
# make a fake tensor input to the model
x = torch.randint(0, 64, (1, 4096))
out = model(x)
print(out.shape)  # torch.Size([1, 4096, 64])


torch.Size([1, 4096, 64])


In [5]:
out

tensor([[[-0.4917, -0.0814, -0.7978,  ..., -0.4080, -0.4066, -1.9724],
         [ 0.0868,  0.2095,  1.0585,  ...,  0.0946, -0.0400,  0.5337],
         [ 0.0096,  0.3399, -0.2859,  ...,  0.2592, -0.2552,  0.2571],
         ...,
         [-0.1556, -0.0589, -0.8261,  ...,  0.3990, -0.7987,  0.3712],
         [-0.6930,  0.4637, -0.2504,  ...,  0.0775,  0.0601, -0.9776],
         [-0.3557,  0.2114,  0.5776,  ...,  0.8534,  0.0169,  0.0795]]],
       grad_fn=<ViewBackward0>)

In [6]:
# count model parameters
sum(p.numel() for p in model.parameters() if p.requires_grad)

333055744