In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# d_in: input dimension
# d_out: output dimension
# context_size: number of context words
# embedding_dim: dimension of the word embeddings
# n_heads: number of attention heads
# n_layers: number of transformer layers
# d_heads: dimension of each attention head
# d_ff: dimension of the feedforward network

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
batch_size = 16
# block_size = 32
block_size = 512

cuda


In [4]:
with open('the-verdict.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

def encode(s):
    return [stoi[c] for c in s]

def decode(indices):
    return ''.join([itos[i] for i in indices])

In [5]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [6]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [7]:
def apply_masking(x_true, t, mask_ids_tokens):
    x_masked = x_true.clone()

    random_mask = torch.rand(x_true.shape, device=x_true.device) < t

    x_masked[random_mask] = mask_ids_tokens

    return x_masked, random_mask

In [8]:
mask_ids_tokens = vocab_size

x_0 = train_data[:5]
print(x_0)
apply_masking(x_0, 0.4, mask_ids_tokens=mask_ids_tokens)

tensor([21,  1, 20, 13, 16])


(tensor([21,  1, 62, 13, 16]), tensor([False, False,  True, False, False]))

In [9]:
class multihead(nn.Module):
    def __init__(self, d_in, d_out, context_size, n_heads , dropout=0.0):
        super().__init__()

        assert d_out % n_heads == 0, "d_out must be divisible by n_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads
        self.context_size = context_size
        self.d_head = d_out // n_heads
        self.dropout = dropout

        self.q = nn.Linear(self.d_in, self.d_out, bias=False)
        self.v = nn.Linear(self.d_in, self.d_out, bias=False)
        self.k = nn.Linear(self.d_in, self.d_out, bias=False)

        self.out = nn.Linear(self.d_out, self.d_out, bias=False)

    def forward(self, x):
        b, num_tokens, d_in = x.size()

        q = self.q(x).view(b, num_tokens, self.n_heads, self.d_head)
        k = self.k(x).view(b, num_tokens, self.n_heads, self.d_head)
        v = self.v(x).view(b, num_tokens, self.n_heads, self.d_head)

        q = q.permute(0, 2, 1, 3)  # (b, n_heads, num_tokens, d_head)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        att = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout,
            is_causal=False,
        )

        att = att.permute(0, 2, 1, 3).contiguous()
        att = att.view(b, num_tokens, self.d_out)

        return self.out(att)


In [10]:
class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
        nn.Linear(n_embed, 4*n_embed),
        nn.ReLU(),
        nn.Linear(4*n_embed, n_embed),
        nn.Dropout(dropout)
        )
        # self.layer_norm = nn.LayerNorm(d_out)

    def forward(self, x):
        return self.net(x)

In [11]:
class Block(nn.Module):
    def __init__(self, d_in, d_out, context_size, n_heads, dropout=0.0):
        super().__init__()
        self.att = multihead(d_in, d_out, context_size, n_heads, dropout)
        self.ff = FeedForward(d_out, dropout)
        self.layer_norm1 = nn.LayerNorm(d_out)
        self.layer_norm2 = nn.LayerNorm(d_out)

    def forward(self, x):
        x = x + self.att(self.layer_norm1(x))
        x = x + self.ff(self.layer_norm2(x))
        return x

In [12]:
## model formation
class LLaDA(nn.Module):
    def __init__(self, d_in, d_out, context_size, n_heads, n_layers, dropout=0.0):
        super().__init__()
        self.n_layers = n_layers
        self.context_size = context_size
        self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads

        self.embedding = nn.Embedding(self.d_in, self.d_out)
        self.pos_embedding = nn.Embedding(self.context_size, self.d_out)


        self.blocks = nn.Sequential(*[
            Block(self.d_out, self.d_out, self.context_size, self.n_heads, dropout)
            for _ in range(self.n_layers)
        ])
        self.layer_norm = nn.LayerNorm(self.d_out)
        self.fc = nn.Linear(self.d_out, self.d_in)


    def forward(self, x):
        b, num_tokens = x.size()

        x = self.embedding(x) + self.pos_embedding(torch.arange(num_tokens).to(x.device))
        x = self.blocks(x)
        x = self.layer_norm(x)
        x = self.fc(x)
        return x

In [13]:
vocab_size

62

In [14]:
model_embedding_dim = 8
res = LLaDA(d_in=vocab_size+1, d_out=model_embedding_dim, context_size=block_size, n_heads=4, n_layers=2)

In [15]:
def compute_loss(x_true):

    t_actual = torch.rand(1, device=x_true.device) * (1.0 - 1e-6) + 1e-6

    x_masked_input, is_masked_mask = apply_masking(x_true, t_actual.item(), mask_ids_tokens)

    logits = res(x_masked_input)

    masked_logits = logits.view(-1, logits.size(-1))[is_masked_mask.view(-1)]
    masked_targets = x_true.view(-1)[is_masked_mask.view(-1)]

    


    loss = F.cross_entropy(masked_logits, masked_targets) 
    loss = loss / t_actual.item()
    return loss

optimizer = optim.AdamW(res.parameters(), lr=1e-3)

In [16]:
epochs = 200
train_loss = []
res.train()
res.to(device)
# Training loop
for epoch in range(epochs):

    xb, _ = get_batch('train')
    loss = compute_loss(xb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    train_loss.append(loss.item())
    if(epoch%100==0):
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")


Epoch 1/200, Loss: 5.4556
Epoch 101/200, Loss: 42.8259


In [17]:
res = LLaDA(d_in=vocab_size+1, d_out=model_embedding_dim, context_size=block_size, n_heads=4, n_layers=2)

In [18]:
@torch.no_grad()
def generate(model, start_prompt_ids, ans_length, sampling_steps=5, device=device):
    model.to(device)
    model.eval()


    if start_prompt_ids.dim() == 1:
        start_prompt_ids = start_prompt_ids.unsqueeze(0)
    start_prompt_ids = start_prompt_ids.to(device)
    prompt_length = start_prompt_ids.size(1)

    current_answer_tokens = torch.full(
        (1, ans_length),
        fill_value=mask_ids_tokens,
        dtype=torch.long,
        device=device
    )

    t_initial = 1.0

    for step_idx in range(sampling_steps):
        xt_full_sequence = torch.cat((start_prompt_ids, current_answer_tokens), dim=1)
        
        if xt_full_sequence.size(1) > model.context_size:
            raise ValueError(
                f"Input sequence length ({xt_full_sequence.size(1)}) exceeds model's context_size ({model.context_size}). "
                f"Prompt length: {prompt_length}, Answer length: {ans_length}."
            )

        logits_full_sequence = model(xt_full_sequence)
        
        logits_answer_part = logits_full_sequence[:, prompt_length:, :]

        r0_predicted_answer = torch.argmax(logits_answer_part, dim=-1)

        probabilities = F.softmax(logits_answer_part.to(torch.float64), dim=-1)
        confidence_scores_answer = torch.gather(
            probabilities, 
            -1, 
            r0_predicted_answer.unsqueeze(-1)
        ).squeeze(-1).to(torch.float32)

        # Remasking (Low-Confidence Strategy)
        t_current_for_remasking = t_initial - (step_idx / sampling_steps)
        t_next_for_remasking = t_initial - ((step_idx + 1) / sampling_steps)
        
        if t_current_for_remasking <= 1e-6:
            num_tokens_to_remask_for_next_step = 0
        else:
            num_tokens_to_remask_for_next_step = int(ans_length * (t_next_for_remasking / t_current_for_remasking))
        
        num_tokens_to_remask_for_next_step = max(0, min(ans_length, num_tokens_to_remask_for_next_step))

        next_answer_tokens = r0_predicted_answer.clone()

        if num_tokens_to_remask_for_next_step > 0:
            is_currently_masked_in_answer = (current_answer_tokens == mask_ids_tokens)

            confidence_for_remasking_selection = torch.where(
                is_currently_masked_in_answer,
                confidence_scores_answer,
                torch.tensor(float('inf'), device=device)
            )
            
            sorted_confidence_indices = confidence_for_remasking_selection.argsort(dim=-1, descending=False)
            
            remask_indices_in_answer = sorted_confidence_indices[:, :num_tokens_to_remask_for_next_step]

            # Apply remasking
            next_answer_tokens.scatter_(1, remask_indices_in_answer, mask_ids_tokens)
        
        previously_unmasked_positions = (current_answer_tokens != mask_ids_tokens)
        next_answer_tokens[previously_unmasked_positions] = r0_predicted_answer[previously_unmasked_positions]


        current_answer_tokens = next_answer_tokens

        if not torch.any(current_answer_tokens == mask_ids_tokens):
            break
            
    final_response_ids = current_answer_tokens.squeeze(0)
            
    return final_response_ids

In [19]:
start_prompt = "how are you?"
start_prompt = encode(start_prompt)
start_prompt = torch.tensor(start_prompt)

In [20]:
start_prompt

tensor([43, 50, 58,  1, 36, 53, 40,  1, 60, 50, 56, 12])

In [21]:
result = generate(res, start_prompt_ids=start_prompt, ans_length=128, sampling_steps=20)

In [22]:
result = result.detach().cpu()


In [23]:
decoded_ids_list = [token_id.item() for token_id in result]

decoded_text = decode(decoded_ids_list)
print(decoded_text)

pA ppAvp zppv N NspaE AEecsNEAsoEc HvscPv_ ccv Fs_ sPErptE Dc AOaopAp-)!sa E OpHcEpAsccpppvpseAcv toyo Op pEHcc vOpEc!)csAfs cp 
