In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Notebook-wide hyperparameters; the cells below read these as globals.
sequence_length = 50  # length of the sampled input sequence; also sizes ByteEmbedding's positional table
patch_size = 5  # number of consecutive positions grouped into one patch
DG = 16  # per-position embedding width on the global side (global model width is patch_size * DG)
DL = 32  # model width of the local transformer
vocab_size = 256  # rows in the embedding table and number of output logits

class ByteEmbedding(nn.Module):
    """Token embedding plus a learned, additive positional embedding.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each embedding vector.
        max_len: Number of positions in the positional table. When ``None``
            (the default) the notebook-level ``sequence_length`` constant is
            used, so existing two-argument calls behave as before.
    """

    def __init__(self, vocab_size, embed_dim, max_len=None):
        super().__init__()
        if max_len is None:
            # Fall back to the notebook constant to stay backward-compatible.
            max_len = sequence_length
        self.max_len = max_len
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Zero-initialised, so the positional term contributes nothing until trained.
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len, embed_dim))

    def forward(self, x):
        """Embed ``x`` of shape (batch, seq_len) into (batch, seq_len, embed_dim).

        Raises:
            ValueError: If ``seq_len`` is larger than the positional table.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            # Without this check the slice below silently truncates and the
            # addition fails with an opaque broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional table size {self.max_len}"
            )
        return self.embed(x) + self.pos_embed[:, :seq_len, :]

In [None]:
# Draw one random sequence of token ids in [0, vocab_size) and embed it.
input_sequence = torch.randint(low=0, high=vocab_size, size=(1, sequence_length))
byte_embedding = ByteEmbedding(vocab_size=vocab_size, embed_dim=DG)
embedded_sequence = byte_embedding(input_sequence)  # (1, sequence_length, DG)

In [None]:
input_sequence

tensor([[ 55,   4,  96,  20,  50,   8,  11, 244,  55, 213, 231, 181, 172, 126,
          89,  49, 170,  78, 119, 214, 224, 218,  65, 197, 212, 134, 200,  90,
         217, 174, 244,  90, 189,   3, 162, 183, 165, 207,  55, 195,  66,  95,
          36,  95,  44, 105, 220,  75, 148, 195]])

In [None]:
input_sequence.shape, embedded_sequence.shape

(torch.Size([1, 50]), torch.Size([1, 50, 16]))

In [None]:
class PatchEmbedding(nn.Module):
    """Group consecutive embeddings into patches and prepend one padding patch.

    Every ``patch_size`` consecutive positions of the input are flattened into
    a single vector of width ``patch_size * embed_dim``; a learnable padding
    patch (zero-initialised) is then placed in front of the patch sequence.

    Args:
        patch_size: Number of positions per patch.
        embed_dim: Width of each incoming embedding vector.
    """

    def __init__(self, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Registered once here (instead of being re-created inside forward) so
        # it shows up in parameters()/state_dict(), follows .to(device)/.half(),
        # and can receive gradient updates.
        self.pad_embed = nn.Parameter(torch.zeros(1, 1, patch_size * embed_dim))

    def forward(self, x):
        """Map (batch, seq_len, embed_dim) to (batch, seq_len // patch_size + 1, patch_size * embed_dim).

        Raises:
            AssertionError: If ``seq_len`` is not a multiple of ``patch_size``.
            ValueError: If the last dimension of ``x`` differs from ``embed_dim``.
        """
        batch_size, seq_len, embed_dim = x.size()
        assert seq_len % self.patch_size == 0, "Sequence length must be divisible by patch size"
        if embed_dim != self.embed_dim:
            raise ValueError(
                f"Expected embedding width {self.embed_dim}, got {embed_dim}"
            )

        num_patches = seq_len // self.patch_size
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(batch_size, num_patches, self.patch_size * embed_dim)

        # Broadcast the single padding patch across the batch without copying.
        pad = self.pad_embed.expand(batch_size, -1, -1)
        return torch.cat((pad, x), dim=1)

In [None]:
# Group the embedded sequence into patches; one padding patch is prepended.
patch_embedding = PatchEmbedding(patch_size=patch_size, embed_dim=DG)
patches = patch_embedding(embedded_sequence)  # (1, 11, 80)

In [None]:
class GlobalTransformer(nn.Module):
    """Stack of ``nn.TransformerDecoderLayer`` blocks applied across patches.

    Args:
        input_dim: Model width (``d_model``) of every decoder layer.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.
    """

    def __init__(self, input_dim, nhead, num_layers):
        super().__init__()
        # Use the constructor argument; previously d_model was read from the
        # notebook globals (patch_size * DG) and input_dim was silently ignored.
        decoder_layer = nn.TransformerDecoderLayer(d_model=input_dim, nhead=nhead, dim_feedforward=256, dropout=0.1)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def forward(self, x):
        """Run the decoder stack on ``x`` of shape (batch, num_patches, input_dim).

        Returns a tensor of the same shape.
        """
        # The decoder is built with the default sequence-first layout, so swap
        # the batch and sequence axes on the way in and back on the way out.
        x = x.transpose(0, 1)
        # The same tensor serves as both target and memory.
        # NOTE(review): no tgt_mask is passed, so attention is unmasked —
        # confirm this is intended for an autoregressive model.
        x = self.transformer(x, x)
        x = x.transpose(0, 1)
        return x

In [None]:
# Run the patch sequence through the global model.
global_transformer = GlobalTransformer(input_dim=patch_size * DG, nhead=4, num_layers=2)
global_output = global_transformer(patches)  # (1, 11, 80)

In [None]:
global_output.shape

torch.Size([1, 11, 80])

In [None]:
class LocalTransformer(nn.Module):
    """Stack of ``nn.TransformerDecoderLayer`` blocks applied within a patch.

    Args:
        input_dim: Model width (``d_model``) of every decoder layer.
        nhead: Number of attention heads per layer.
        num_layers: Number of stacked decoder layers.
    """

    def __init__(self, input_dim, nhead, num_layers):
        super().__init__()
        layer_template = nn.TransformerDecoderLayer(
            d_model=input_dim,
            nhead=nhead,
            dim_feedforward=128,
            dropout=0.1,
        )
        self.transformer = nn.TransformerDecoder(layer_template, num_layers=num_layers)

    def forward(self, x):
        """Apply the decoder stack to ``x`` and return a tensor of the same shape.

        The first two axes are swapped before the stack and swapped back after
        it; ``x`` is passed as both target and memory.
        """
        seq_first = x.transpose(0, 1)
        decoded = self.transformer(seq_first, seq_first)
        return decoded.transpose(0, 1)

In [None]:
# Local model operating at width DL.
local_transformer = LocalTransformer(input_dim=DL, nhead=4, num_layers=2)

In [None]:
# Linear bridge from the global model width (patch_size * DG) down to the local width DL.
global_to_local_proj = nn.Linear(in_features=patch_size * DG, out_features=DL)
projected_global_output = global_to_local_proj(global_output)  # (1, 11, 32)

In [None]:
projected_global_output.shape

torch.Size([1, 11, 32])

In [None]:
# Run the local model once per patch position.
# Start at 1: position 0 corresponds to the padding patch that PatchEmbedding prepends.
local_outputs = []
for i in range(1, projected_global_output.shape[1]):
    patch_vector = projected_global_output[:, i]  # (batch_size, DL)
    # Copy the patch vector to every position inside the patch.
    patch_input = patch_vector[:, None, :].repeat(1, patch_size, 1)  # (batch_size, patch_size, DL)
    local_output = local_transformer(patch_input)
    local_outputs.append(local_output)

In [None]:
# Stitch the per-patch outputs back together along the position axis.
# NOTE(review): this rebinds `local_outputs` from a list to a tensor, so the
# cell cannot be re-run without first re-running the loop cell above.
local_outputs = torch.cat(local_outputs, dim=1)  # (1, 50, DL): 10 patches x 5 positions
flattened_local_outputs = local_outputs.view(-1, DL)  # (10 * 5, DL)

# Project every position to vocabulary logits.
final_proj = nn.Linear(in_features=DL, out_features=vocab_size)
logits = final_proj(flattened_local_outputs)  # (50, vocab_size)

probabilities = logits.softmax(dim=-1)  # (50, vocab_size)

In [None]:
local_outputs.shape, flattened_local_outputs.shape, logits.shape, probabilities.shape, final_proj

(torch.Size([1, 50, 32]),
 torch.Size([50, 32]),
 torch.Size([50, 256]),
 torch.Size([50, 256]),
 Linear(in_features=32, out_features=256, bias=True))

In [None]:
# Scratch note (not code; presumably illustrates splitting text into patches): [I am a] [boy that ]

In [None]:
class ConvolutionalPatchEncoder(nn.Module):
    """Stack of 1-D convolutions over the sequence axis, followed by patch flattening.

    Each layer is ``Conv1d(embed_dim, embed_dim, kernel_size)`` with
    ``padding=kernel_size // 2`` followed by ReLU. With an odd ``kernel_size``
    the sequence length is preserved. The result is then grouped into
    patches of ``patch_size`` consecutive positions.

    NOTE(review): the original header comment read "Masked Con" (truncated).
    The convolutions here use symmetric padding and are not masked/causal —
    confirm whether a masked convolution was intended.

    Args:
        embed_dim: Channel count of the input and of every convolution.
        num_layers: Number of stacked convolution layers.
        kernel_size: Convolution kernel width.
        patch_size: Number of positions flattened into one patch.
    """

    def __init__(self, embed_dim, num_layers, kernel_size, patch_size):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(embed_dim, embed_dim, kernel_size, padding=kernel_size // 2)
            for _ in range(num_layers)
        ])
        self.patch_size = patch_size

    def forward(self, x):
        """Map (batch, seq_len, embed_dim) to (batch, num_patches, patch_size * embed_dim).

        Raises:
            AssertionError: If the post-convolution length is not a multiple
                of ``patch_size``.
        """
        # Conv1d convolves over the last axis, so move channels to dim 1.
        x = x.transpose(1, 2)
        for conv in self.convs:
            x = F.relu(conv(x))
        x = x.transpose(1, 2)

        batch_size, seq_len, embed_dim = x.size()
        # Same check and message as PatchEmbedding; without it a bad length
        # surfaces as an opaque reshape error.
        assert seq_len % self.patch_size == 0, "Sequence length must be divisible by patch size"
        num_patches = seq_len // self.patch_size
        # reshape (not view): the tensor is non-contiguous after the transpose.
        return x.reshape(batch_size, num_patches, self.patch_size * embed_dim)

In [None]:
# Convolutional alternative to PatchEmbedding. Use the notebook's `patch_size`
# constant rather than a hard-coded 5 so the two stay in sync.
conv_patch_encoder = ConvolutionalPatchEncoder(embed_dim=DG, num_layers=2, kernel_size=3, patch_size=patch_size)
embedded_sequence = byte_embedding(input_sequence)
conv_patches = conv_patch_encoder(embedded_sequence)  # (1, 10, 80)

In [None]:
conv_patches.shape

torch.Size([1, 10, 80])

In [None]:
# Re-run the encoder's convolution stack by hand to inspect the shape before patching.
x = embedded_sequence
x = x.transpose(1, 2)  # (batch, embed_dim, seq_len): Conv1d convolves over the last axis
# `convs` was never defined at notebook scope; the layers live on the encoder instance.
for conv in conv_patch_encoder.convs:
    x = F.relu(conv(x))

In [None]:
x.shape

torch.Size([1, 16, 50])

In [None]:
# Standalone shape check: kernel_size=3 with padding=1 keeps the length (50)
# while the channel count goes from 16 to 33.
m = nn.Conv1d(16, 33, 3, padding=1)
conv_input = torch.randn(20, 16, 50)  # named `conv_input` so the builtin `input` is not shadowed
output = m(conv_input)

In [None]:
output.shape

torch.Size([20, 33, 50])