In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DifferentiableSort(nn.Module):
    """
    A simple Differentiable Sort utility.
    Input: raw attention scores (e.g., QK^T)
    Output: a soft permutation matrix
    """
    def __init__(self, temperature=1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, scores):
        # We need to sort by importance.
        # Let's say we want to sort each row of the scores matrix.
        # A simple way is to use a softmax with a low temperature.
        
        # scores has shape (batch, query_seq_len, key_seq_len)
        # We want to create a permutation for *each* query.
        
        # Sort scores to get the indices. This is non-differentiable.
        # So we create an approximation.
        
        # A simple soft permutation matrix can be a softmax over the scores themselves
        # with a low temperature.
        # The lower the temperature, the closer it is to a hard sort (one-hot vectors).
        
        # This implementation is a simplified approach for demonstration.
        # More advanced methods use Gumbel-Softmax or Sinkhorn networks.
        soft_perm_matrix = F.softmax(scores / self.temperature, dim=-1)
        
        # The key idea here is that the matrix P represents P_ij:
        # P_ij = soft probability that original item j is in the i-th sorted position.
        # This is a bit of a simplification, but conceptually it works.
        
        return soft_perm_matrix

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assuming DifferentiableSort class from above is defined
#
# A custom Multi-Head Attention module that uses a 
# convolutional sweep over an attention-ordered sequence.
class AttentionWithConvSweep(nn.Module):
    def __init__(self, d_model, num_heads, kernel_size=3, padding='same', temperature=1.0):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        
        # Ensure d_model is divisible by num_heads
        assert d_model % num_heads == 0
        self.head_dim = d_model // num_heads

        # Linear layers for Q, K, V projections
        self.query_projection = nn.Linear(d_model, d_model)
        self.key_projection = nn.Linear(d_model, d_model)
        self.value_projection = nn.Linear(d_model, d_model)

        # The Differentiable Sort utility
        self.differentiable_sort = DifferentiableSort(temperature=temperature)

        # 1D Convolution layer
        self.conv1d = nn.Conv1d(
            in_channels=self.head_dim,
            out_channels=self.head_dim,
            kernel_size=kernel_size,
            padding=padding,
            bias=False
        )

        # Final linear layer to project back to d_model
        self.final_projection = nn.Linear(d_model, d_model)


    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # 1. Project input to Q, K, V
        Q = self.query_projection(x)
        K = self.key_projection(x)
        V = self.value_projection(x)

        # Reshape for multi-head attention
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 2. Calculate Scaled Dot-Product Attention Scores
        # Q and K have shape (batch, num_heads, seq_len, head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # scores now has shape (batch, num_heads, seq_len, seq_len)

        if mask is not None:    
            scores = scores.masked_fill(mask == float('-inf'), float('-inf'))
            
        # 3. Apply Differentiable Sort to get a soft permutation matrix
        # This is where your core idea comes to life.
        # We are generating a permutation matrix for each attention head.
        soft_perm_matrix = self.differentiable_sort(scores)
        # soft_perm_matrix has shape (batch, num_heads, seq_len, seq_len)

        # 4. Perform Soft Re-ordering of Value vectors
        # V has shape (batch, num_heads, seq_len, head_dim)
        V_sorted = torch.matmul(soft_perm_matrix, V)
        # V_sorted now has the same shape, but the vectors are "softly" re-ordered

        # 5. Apply 1D Convolution
        # nn.Conv1d expects input of shape (batch, channels, length)
        # So we need to reshape V_sorted.
        V_conv_input = V_sorted.transpose(1, 2).reshape(
            batch_size * seq_len, self.num_heads, self.head_dim
        )
        
        # Apply the convolution to each head's data
        # Transpose again to get (batch, channels, length) for convolution
        V_conv_input = V_conv_input.transpose(1, 2)

        # ##
        # k = self.conv1d.kernel_size[0]
        # V_conv_input = F.pad(V_conv_input, (k - 1, 0))   # (left, right) padding
        
        # The output of the convolution
        conv_output = self.conv1d(V_conv_input)
        
        # Reshape back to multi-head format
        conv_output = conv_output.transpose(1, 2)
        conv_output = conv_output.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        
        # 6. Final projection and reshape
        # Concatenate heads
        output = conv_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Final linear layer
        output = self.final_projection(output)

        return output

In [3]:
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        
        self.attention = AttentionWithConvSweep(d_model, num_heads)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, d_model),
            nn.Dropout(dropout)
        )
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention and Residual Connection
        attention_output = self.attention(x, mask=mask)
        x = self.norm1(x + self.dropout1(attention_output))
        
        # Feed-Forward and Residual Connection
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_output))
        
        return x

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, ff_dim=2048, n_layers=6, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(max_len, d_model)
        self.layers    = nn.ModuleList([
            TransformerBlock(d_model, num_heads, ff_dim)
            for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x, mask):
        B, T = x.shape
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.token_emb(x) + self.pos_emb(positions)
        for layer in self.layers:
            h = layer(h, mask=mask)
        h = self.ln(h)
        return self.head(h)


In [5]:
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # or another LM tokenizer
tokenizer.pad_token = tokenizer.eos_token
dataset = load_dataset("allenai/c4", "en", streaming=True, split="train")

max_len = 512

def encode(ex):
    ids = tokenizer(ex["text"], truncation=True, max_length=max_len,
                    padding="max_length", return_tensors="pt").input_ids[0]
    return {"input_ids": ids}

tokenized = dataset.map(encode, remove_columns=dataset.column_names)


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from torch.utils.data import DataLoader

    # def collate(batch):
    #     x = torch.stack([b["input_ids"] for b in batch])
    #     return x, x  # input and target are the same for next-token LM

def collate(batch, pad_token_id=0):
    # stack into (B, T)
    x = torch.stack([b["input_ids"] for b in batch])
    # input is everything except final token
    input_ids  = x[:, :-1]
    # target is everything except first token
    target_ids = x[:, 1:]
    return input_ids, target_ids

loader = DataLoader(tokenized, batch_size=8, collate_fn=collate)


In [7]:
def causal_mask_inputs(x):
    # x: (batch, seq_len, dim)
    B, T, D = x.shape
    # create upper-triangular mask: (T, T)
    mask = torch.tril(torch.ones(T, T, device=x.device))
    # expand to batch and feature dims
    mask = mask.unsqueeze(0).unsqueeze(-1)          # (1, T, T, 1)
    x = x.unsqueeze(1)                              # (B, 1, T, D)
    # broadcast so position i only keeps past j ≤ i
    x_masked = (x * mask).sum(dim=2)                # (B, T, D)
    return x_masked

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = len(tokenizer)
model = MiniTransformerLM(vocab_size).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

accumulation_step = 64 / 8

optim.zero_grad()
for step, (x, y) in enumerate(loader):
    x, y = x.to(device), y.to(device)
    B, T = x.shape

    attn_mask = torch.triu(torch.ones(T, T, device=device) * float('-inf'), diagonal=1)

    logits = model(x, attn_mask)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))

    
    loss.backward()

    if (step + 1) % accumulation_step:
        optim.step()
        optim.zero_grad()
    
    if step % 100 == 0:
        print(f"step {step} | loss {loss.item():.4f}")
        torch.save(model.state_dict(), 'llm.pt')


step 0 | loss 10.9290
step 100 | loss 3.1340
step 200 | loss 5.6683
step 300 | loss 4.1739
step 400 | loss 4.3506
step 500 | loss 4.7017
step 600 | loss 2.9628
step 700 | loss 3.1012
step 800 | loss 2.6092
step 900 | loss 2.2340
step 1000 | loss 4.6275
step 1100 | loss 4.1496
step 1200 | loss 1.7028
step 1300 | loss 3.3404
step 1400 | loss 3.5240
step 1500 | loss 4.5412
step 1600 | loss 4.4101
step 1700 | loss 3.7683
step 1800 | loss 4.1143
step 1900 | loss 4.3877
step 2000 | loss 1.4420
step 2100 | loss 4.2939
step 2200 | loss 4.3631
step 2300 | loss 6.1179
step 2400 | loss 4.5910
step 2500 | loss 4.2241
step 2600 | loss 3.8288
step 2700 | loss 4.3593
step 2800 | loss 2.7222
step 2900 | loss 2.5809
step 3000 | loss 3.3163
step 3100 | loss 3.5158
step 3200 | loss 3.6414
step 3300 | loss 3.1643
step 3400 | loss 3.2544
step 3500 | loss 3.0553
step 3600 | loss 3.3687
step 3700 | loss 2.6648
step 3800 | loss 3.9459
step 3900 | loss 3.5089
step 4000 | loss 2.7996
step 4100 | loss 3.0660
ste

KeyboardInterrupt: 

In [None]:
import torch

# Assume 'tokenizer' is your tokenizer with encode/decode methods
# vocab_size should match the tokenizer
vocab_size = len(tokenizer)
model = MiniTransformerLM(vocab_size)
state_dict = torch.load("llm.pt")
model.load_state_dict(state_dict)
model.eval()  # set to evaluation mode

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

def generate_text(model, tokenizer, prompt, max_new_tokens=50, temperature=1.0):
    # Encode prompt
    input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)  # shape: [1, T]
    
    for _ in range(max_new_tokens):
        # causal mask: prevent attending to future tokens
        T = input_ids.shape[1]
        mask = torch.tril(torch.ones(T, T, device=device)).unsqueeze(0)  # shape [1, T, T]
        
        # Get logits
        logits = model(input_ids, mask=mask)  # shape: [1, T, vocab_size]
        next_token_logits = logits[0, -1, :] / temperature
        
        # Sample next token
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).unsqueeze(0)  # shape [1, 1]
        
        # Append to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=1)
    
    # Decode tokens
    return tokenizer.decode(input_ids[0].tolist())

# Example usage
prompt = "There is a wallet on a"
generated_text = generate_text(model, tokenizer, prompt, max_new_tokens=50)
print(generated_text)


There is a wallet on a flying to practice intends to travel alert when in Edinburgh. The taxi types are typically you can fly from £10 to take you to UK retail but there are extra baggage and entry you will insist will also begin your holidays with the route along.
There
