##Transformer Architecture (based on GPT-OSS) implementation and testing
This notebook was based on task that I have to do for Deep Neural Networks course on University of Warsaw during fall semester of 2025. Because of that not all the code is my, the backbone-structure (based on GPT-OSS 2025) was already implemented. To keep things fair my code in architecture part is comes after ### My Code ### comment. The testing of transformers models was done solely by me. \
\
\
At the first part of this notebook different modules of GPT-OSS-like architecture are implemented. They include:   
- SwiGLU module
- Grouped Query Attention  
- Sliding Window Attention
- Rotational Positionary Embedding
- Mixture of Experts (without shared expert at this point, maybe I will find time to add it later). \



The testing part right now consists of test on three datasets: \
-> CutieSimp model is trained on very simple and easy dataset with simple sentences about two people: 'Cutie' and 'Farfocl' and their relationship. The point here was to get familiar with a very simple text-based example. \
-> TinyStoryTeller model is trained on TinyStories dataset containing short stories. Here tokenization is also implemented. The goal here was to experiment with a bit of a harder task \
-> WieszczMimic model is trained on dataset created from four most important creations of Polish national poet Adam Mickiewicz (*Pan Tadeusz*, *Dziady*, *Konrad Wallenrod* and *Grażyna*). I wanted to do a fun experiment, although results are not really satisfactory

## Implementation of the architecture modules
----------------------------

In [10]:
#imports
import random
import torch
import torch.nn.functional as F
from torch.utils.data import IterableDataset, DataLoader
from typing import List, Optional, Callable
from collections import Counter
from tqdm import tqdm
import functools
import csv
import sys

#configurations
csv.field_size_limit(sys.maxsize);

### SwiGluFeedForward

In [11]:
class SwiGLUFeedForward(torch.nn.Module):
    def __init__(self, hidden_dim: int, inner_dim: int) -> None:
        """
        Args:
            hidden_dim: Dimension of input and output tensors.
            inner_dim: Dimension of the intermediate (inner) representation.
        """
        super().__init__()

        ### My code ###
        self.projection1 = torch.nn.Linear(hidden_dim, inner_dim)
        self.act = torch.nn.SiLU()
        self.projection2 = torch.nn.Linear(hidden_dim, inner_dim)
        self.deprojection = torch.nn.Linear(inner_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_dim].

        Returns:
            Output tensor of shape [batch_size, seq_len, hidden_dim].
        """
        assert len(x.shape) == 3, f"Expected 3D tensor, got shape {x.shape}"

        ### My code ###
        x_base = self.projection1(x)
        x_multiplicative = self.projection2(x)
        x_base = self.act(x_base)
        x = x_base*x_multiplicative
        result = self.deprojection(x)

        return result

### Rotary Positional Embedding (RoPE)

In [12]:
class RotaryPositionalEmbedding(torch.nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int = 2048, base: float = 10000.0) -> None:
        """
        Args:
            head_dim: Dimension of each attention head (must be even).
            max_seq_len: Maximum sequence length to precompute embeddings for.
            base: Base for computing rotation frequencies.

        WARNING: YOUR IMPLEMENTATION MUST PRECOMPUTE THE EMBEDDINGS
        """
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even for RoPE"

        ### My code ###

        self.base = base
        self.head_dim = head_dim
        self._precompute_cache(max_seq_len)

    def _precompute_cache(self, seq_len: int) -> None:

        ### My code ###
        i_range = torch.arange(0, self.head_dim, 2)
        self.theta = torch.pow(self.base, - i_range / self.head_dim)




    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [batch, num_heads, seq_len, head_dim].
            start_pos: Starting position index (for KV-cache during inference).

        Returns:
            Tensor with rotary embedding applied, same shape as input.
        """

        ### My Code ###
        batch, num_heads, seq_len, head_dim = x.shape
        x_reshaped = x.reshape(-1, seq_len, head_dim)  # Combine batch and num_heads
        x_rotated = torch.zeros_like(x_reshaped)
        for pos in range(seq_len):
            m = pos+start_pos
            main_diag = torch.stack((torch.cos(m*self.theta), torch.cos(m*self.theta))).T.flatten()
            off_diag = torch.stack((torch.sin(m*self.theta), 0*torch.ones(head_dim // 2))).T.flatten()
            mask = torch.ones_like(off_diag, dtype=torch.bool)
            mask[-1] = False
            off_diag = off_diag[mask]
            perm = torch.stack((torch.arange(0, head_dim//2), torch.arange(head_dim//2, head_dim)), dim=1).reshape(-1)
            rotmat = torch.diag(main_diag) + torch.diag(-off_diag, diagonal = 1) + torch.diag(off_diag, diagonal = -1)
            rotmat = rotmat[:,perm][perm,:].to(x.device)
            x_rotated[:, pos, :] = torch.matmul(x_reshaped[:, pos, :], rotmat.T)
        x_rotated = x_rotated.reshape(batch, num_heads, seq_len, head_dim)
        return x_rotated



##### TESTS START #####

@torch.no_grad()
def test_rope() -> None:
    """Test RoPE applies correct rotations."""
    head_dim = 4
    max_seq_len = 8
    batch, num_heads, seq_len = 2, 2, 4

    rope = RotaryPositionalEmbedding(head_dim, max_seq_len)
    x = torch.ones(batch, num_heads, seq_len, head_dim)

    result = rope(x)

    expected = torch.tensor(
        [[[[ 1.0000,  1.0000,  1.0000,  1.0000],
          [-0.3012,  0.9900,  1.3818,  1.0099],
          [-1.3254,  0.9798,  0.4932,  1.0198],
          [-1.1311,  0.9696, -0.8489,  1.0295]],

         [[ 1.0000,  1.0000,  1.0000,  1.0000],
          [-0.3012,  0.9900,  1.3818,  1.0099],
          [-1.3254,  0.9798,  0.4932,  1.0198],
          [-1.1311,  0.9696, -0.8489,  1.0295]]],


        [[[ 1.0000,  1.0000,  1.0000,  1.0000],
          [-0.3012,  0.9900,  1.3818,  1.0099],
          [-1.3254,  0.9798,  0.4932,  1.0198],
          [-1.1311,  0.9696, -0.8489,  1.0295]],

         [[ 1.0000,  1.0000,  1.0000,  1.0000],
          [-0.3012,  0.9900,  1.3818,  1.0099],
          [-1.3254,  0.9798,  0.4932,  1.0198],
          [-1.1311,  0.9696, -0.8489,  1.0295]]]]
    )

    assert result.shape == x.shape, f"Shape mismatch: {result.shape} vs {x.shape}"
    assert torch.allclose(result, expected, atol=1e-4), "Error in ROPE"


test_rope()

#####  TESTS END  #####

### Grouped Query Attention (GQA) adn Sliding Window Attention (SWA)


In [13]:
def calculate_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    key_weights: torch.Tensor,
    rope: RotaryPositionalEmbedding,
    scale: float,
    device: torch.device,
    mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Args:
        q: Query tensor of shape [batch, num_heads, seq_len, head_dim].
        k: Key tensor of shape [batch, num_kv_heads, seq_len, head_dim].
        v: Value tensor of shape [batch, num_kv_heads, seq_len, head_dim].
        key_weights: Per-head key weights of shape [num_heads].
        rope: Rotary positional embedding module.
        scale: Scaling factor (typically 1/sqrt(head_dim)).
        device: Device to create the causal mask on.
        mask: Optional attention mask of shape [seq_len, seq_len]. If None, uses causal mask.

    Returns:
        Output tensor of shape [batch, num_heads, seq_len, head_dim].
    """
    ### My Code ###
    batch, num_heads, seq_len, head_dim = q.size()
    num_kv_heads = k.size()[1]
    heads_per_kv = num_heads // num_kv_heads
    k = rope(k, head_dim)
    k = k.repeat_interleave(heads_per_kv, dim=1)    #Needs optimization
    v = v.repeat_interleave(heads_per_kv, dim=1)
    q = rope(q, head_dim)
    k = k * key_weights.view(1, num_heads, 1, 1)
    k_t = k.transpose(-1, -2)
    attention = torch.matmul(q, k_t) * scale
    if mask == None:
      mask = torch.triu(torch.full((seq_len, seq_len), -1e9, device=device, dtype=q.dtype), diagonal=1)
      mask = mask.unsqueeze(0).unsqueeze(0)  # broadcast to [1,1,seq_len,seq_len]
    attention = attention + mask
    attention_weights = F.softmax(attention, dim=-1)
    O = torch.matmul(attention_weights, v)
    output = O
    return output


class GroupedQueryAttention(torch.nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        head_dim: int,
        num_kv_heads: Optional[int] = None
    ) -> None:
        """
        Args:
            hidden_dim: Input/output dimension.
            num_heads: Number of query heads.
            head_dim: Dimension of each head.
            num_kv_heads: Number of key-value heads. If None, defaults to num_heads (standard MHA).
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.scale = head_dim ** -0.5

        assert num_heads % self.num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        self.num_queries_per_kv = num_heads // self.num_kv_heads

        ### My code ###
        self.key_weights = torch.nn.Parameter(torch.ones(num_heads))
        self.kmat = torch.nn.Linear(self.hidden_dim, self.num_kv_heads*self.head_dim)
        self.vmat = torch.nn.Linear(self.hidden_dim, self.num_kv_heads*self.head_dim)
        self.qmat = torch.nn.Linear(self.hidden_dim, self.num_heads*self.head_dim)
        self.outmat = torch.nn.Linear( self.num_heads * self.head_dim, self.hidden_dim)
        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [batch, seq_len, hidden_dim].

        Returns:
            Output tensor of shape [batch, seq_len, hidden_dim].
        """
        assert len(x.shape) == 3
        batch, seq_len, _ = x.shape

        ### My Code ###
        k = self.kmat(x)
        k = k.reshape(batch, seq_len, self.num_kv_heads, self.head_dim)
        k = k.transpose(-2, -3)

        v = self.vmat(x)
        v = v.reshape(batch, seq_len, self.num_kv_heads, self.head_dim)
        v = v.transpose(-2, -3)

        q = self.qmat(x)
        q = q.reshape(batch, seq_len, self.num_heads, self.head_dim)
        q = q.transpose(-2, -3)

        O = calculate_attention(q, k, v, self.key_weights, self.rope, self.scale, q.device)
        O = O.transpose(-2, -3).reshape(batch, seq_len, self.num_heads * self.head_dim)
        output = self.outmat(O)
        return output

In [14]:
def calculate_sliding_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    key_weights: torch.Tensor,
    rope: RotaryPositionalEmbedding,
    scale: float,
    device: torch.device,
    window_size: int
) -> torch.Tensor:
    """
    Args:
        q: Query tensor of shape [batch, num_heads, seq_len, head_dim].
        k: Key tensor of shape [batch, num_kv_heads, seq_len, head_dim].
        v: Value tensor of shape [batch, num_kv_heads, seq_len, head_dim].
        key_weights: Per-head key weights of shape [num_heads].
        rope: Rotary positional embedding module.
        scale: Scaling factor (typically 1/sqrt(head_dim)).
        device: Device to create the causal mask on.
        window_size: Number of previous tokens each position can attend to.

    Returns:
        Output tensor of shape [batch, num_heads, seq_len, head_dim].
    """
    _, __, seq_len, ___ = q.size()
    mask = torch.triu(torch.full((seq_len, seq_len), -1e9, device=device, dtype=q.dtype), diagonal=1) + torch.tril(torch.full((seq_len, seq_len), -1e9, device=device, dtype=q.dtype), diagonal=-window_size-1)
    mask = mask.unsqueeze(0).unsqueeze(0)  # broadcast to [1,1,seq_len,seq_len]
    output = calculate_attention(q, k, v, key_weights, rope, scale, device, mask = mask)
    return output


class SWAttention(torch.nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,
        head_dim: int,
        window_size: int
    ) -> None:
        """
        Args:
            hidden_dim: Input/output dimension.
            num_heads: Number of attention heads.
            head_dim: Dimension of each head.
            window_size: Size of the sliding window.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.scale = head_dim ** -0.5

        ### My code ###
        self.key_weights = torch.nn.Parameter(torch.ones(num_heads))
        self.kmat = torch.nn.Linear(self.hidden_dim, self.num_heads*self.head_dim)
        self.vmat = torch.nn.Linear(self.hidden_dim, self.num_heads*self.head_dim)
        self.qmat = torch.nn.Linear(self.hidden_dim, self.num_heads*self.head_dim)
        self.outmat = torch.nn.Linear( self.num_heads * self.head_dim, self.hidden_dim)
        self.rope = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [batch, seq_len, hidden_dim].

        Returns:
            Output tensor of shape [batch, seq_len, hidden_dim].
        """
        assert len(x.shape) == 3
        batch, seq_len, _ = x.shape

        ### My Code ###
        k = self.kmat(x)
        k = k.reshape(batch, seq_len, self.num_heads, self.head_dim)
        k = k.transpose(-2, -3)  #check dimensions

        v = self.vmat(x)
        v = v.reshape(batch, seq_len, self.num_heads, self.head_dim)
        v = v.transpose(-2, -3)

        q = self.qmat(x)
        q = q.reshape(batch, seq_len, self.num_heads, self.head_dim)
        q = q.transpose(-2, -3)

        O = calculate_sliding_attention(q, k, v, self.key_weights, self.rope, self.scale, q.device, self.window_size)
        O = O.transpose(-2, -3).reshape(batch, seq_len, self.num_heads * self.head_dim)
        output = self.outmat(O)
        return output


##### TESTS START #####

@torch.no_grad()
def test_calculate_sliding_attention() -> None:
    """Test the calculate_sliding_attention function independently of module weights."""
    torch.manual_seed(42)
    batch, seq_len = 2, 4
    num_heads, head_dim = 4, 4
    window_size = 2

    q = torch.randn(batch, num_heads, seq_len, head_dim)
    k = torch.randn(batch, num_heads, seq_len, head_dim)
    v = torch.randn(batch, num_heads, seq_len, head_dim)

    key_weights = torch.randn(num_heads)
    rope = RotaryPositionalEmbedding(head_dim)
    scale = head_dim ** -0.5

    output = calculate_sliding_attention(q, k, v, key_weights, rope, scale, q.device, window_size)

    assert output.shape == (batch, num_heads, seq_len, head_dim), \
        f"Wrong output shape: {output.shape}, expected {(batch, num_heads, seq_len, head_dim)}"

    expected = torch.tensor(
        [[[[-6.8548e-01,  5.6356e-01, -1.5072e+00, -1.6107e+00],
          [-7.5833e-01,  5.5151e-01, -1.3803e+00, -1.3910e+00],
          [-7.5970e-01,  5.4425e-01, -1.3694e+00, -1.4018e+00],
          [-1.0289e+00, -1.5047e-03,  1.3449e-01,  8.8395e-02]],

         [[-1.3793e+00,  6.2580e-01, -2.5850e+00, -2.4000e-02],
          [-1.0804e+00,  2.9937e-01, -1.5638e+00, -4.5193e-03],
          [-3.7572e-01,  9.0874e-01, -9.8827e-01,  3.2158e-01],
          [ 5.5610e-01,  9.3138e-01,  8.2518e-01,  3.6249e-01]],

         [[ 9.7329e-01, -1.0151e+00, -5.4192e-01, -4.4102e-01],
          [ 3.7820e-01, -6.0546e-01, -6.2194e-01, -2.5908e-01],
          [ 7.5603e-01, -4.5413e-01, -2.9462e-01, -6.9975e-02],
          [ 5.6501e-01,  6.4487e-02,  4.0517e-01,  4.1787e-01]],

         [[ 4.0380e-01, -7.1398e-01,  8.3373e-01, -9.5855e-01],
          [ 4.2490e-01,  1.1594e-01, -4.9589e-01, -1.0976e+00],
          [ 3.5349e-01, -4.0529e-01, -6.6044e-01, -1.1089e+00],
          [-2.1912e-01, -6.5963e-01,  1.6555e-01, -1.0503e+00]]],


        [[[ 4.3344e-01, -7.1719e-01,  1.0554e+00, -1.4534e+00],
          [ 4.4607e-01, -2.8344e-01,  6.3300e-01, -8.4259e-01],
          [ 4.2691e-01,  3.2082e-01, -4.8548e-01, -5.2133e-01],
          [ 3.8897e-01,  6.5369e-01, -1.5100e+00, -7.1436e-01]],

         [[ 8.8538e-01,  1.8244e-01,  7.8638e-01, -5.7920e-02],
          [ 7.1542e-01, -2.9334e-01,  1.0705e-01, -3.1831e-04],
          [ 6.7078e-01,  7.1021e-01,  4.8379e-02,  5.4688e-01],
          [ 7.3988e-01,  2.3339e-01, -3.6269e-01,  3.0450e-01]],

         [[-7.9394e-01,  3.7523e-01,  8.7910e-02, -1.2415e+00],
          [-5.1264e-01, -3.4904e-01, -2.9172e-01,  6.7694e-01],
          [ 4.9596e-01,  5.8218e-01, -1.2478e-01,  1.5970e-01],
          [ 6.7310e-02,  1.5020e-01, -2.8622e-01,  1.1465e+00]],

         [[-2.1844e-01,  1.6630e-01,  2.1442e+00,  1.7046e+00],
          [ 9.8019e-02,  4.3332e-01,  8.2743e-01,  1.1330e+00],
          [ 2.0074e-02, -5.5385e-02,  2.2436e-01,  9.4053e-01],
          [-1.2419e-01,  1.2867e-01, -6.4606e-01,  2.5874e-01]]]]
    )

    assert torch.allclose(output, expected, atol=1e-4), \
        f"calculate_sliding_attention output values mismatch"

test_calculate_sliding_attention()

#####  TESTS END  #####

### Mixture of Experts

In [15]:
class Router(torch.nn.Module):
    def __init__(self, hidden_dim: int, num_experts: int, top_k: int = 2) -> None:
        """
        Args:
            hidden_dim: Input dimension.
            num_experts: Total number of experts.
            top_k: Number of experts to activate per token.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = top_k

        ### My Code ###
        self.linearity = torch.nn.Linear(hidden_dim, num_experts)


    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor of shape [batch, seq_len, hidden_dim].

        Returns:
            routing_weights: Tensor of shape [batch, seq_len, top_k] with softmax weights.
            expert_indices: Tensor of shape [batch, seq_len, top_k] with selected expert indices.
        """
        assert len(x.shape) == 3

        ### My code ###
        w = self.linearity(x)
        expert_indices = w.argsort(dim=-1, descending=True)[:self.top_k]
        topk_scores = torch.gather(w, -1, expert_indices)  # shape: [batch, seq_len, top_k]
        routing_weights = F.softmax(topk_scores, dim=-1)

        return routing_weights, expert_indices


class MixtureOfExperts(torch.nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        inner_dim: int,
        num_experts: int = 8,
        top_k: int = 2
    ) -> None:
        """
        Args:
            hidden_dim: Input/output dimension.
            inner_dim: Inner dimension of each expert.
            num_experts: Total number of experts.
            top_k: Number of experts to activate per token.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = top_k

        ### My code ###
        self.inner_dim = inner_dim
        self.experts = torch.nn.ModuleList([SwiGLUFeedForward(self.hidden_dim, self.inner_dim) for i in range(num_experts)])
        self.router = Router(self.hidden_dim, self.num_experts, self.top_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [batch, seq_len, hidden_dim].

        Returns:
            Output tensor of shape [batch, seq_len, hidden_dim].
        """
        assert len(x.shape) == 3
        batch, seq_len, hidden_dim = x.shape

        ### My code ###
        routing_weights, expert_indices = self.router(x)
        experts_outputs = torch.stack([expert(x) for expert in self.experts], dim=0)
        indices = expert_indices.unsqueeze(-1).expand(-1, -1, -1, hidden_dim)  # [batch, seq_len, top_k, hidden_dim]
        experts_outputs = experts_outputs.permute(1, 2, 0, 3)  # [batch, seq_len, num_experts, hidden_dim]

        # Gather top-k outputs
        topk_outputs = torch.gather(experts_outputs, 2, indices)  # [batch, seq_len, top_k, hidden_dim]
        routing_weights = routing_weights.unsqueeze(-1)  # [batch, seq_len, top_k, 1]
        weighted_outputs = topk_outputs * routing_weights  # [batch, seq_len, top_k, hidden_dim]

        # Sum over top-k
        output = weighted_outputs.sum(dim=2)  # [batch, seq_len, hidden_dim]

        return output

###Transformer Block


In [16]:
class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        ff_dim: int,
        num_heads: int,
        head_dim: int,
        use_sliding_window: bool = False,
        window_size: int = 128,
        use_moe: bool = False,
        num_experts: int = 8,
        top_k: int = 2,
        num_kv_heads: Optional[int] = None
    ) -> None:
        """
        Args:
            hidden_dim: Hidden dimension.
            ff_dim: Feed-forward inner dimension.
            num_heads: Number of attention heads.
            head_dim: Dimension per attention head.
            use_sliding_window: Whether to use sliding window attention.
            window_size: Size of sliding window (if used).
            use_moe: Whether to use Mixture of Experts instead of single FFN.
            num_experts: Number of experts (if MoE).
            top_k: Number of experts per token (if MoE).
            num_kv_heads: Number of KV heads for GQA (None = standard MHA).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        ### My code ###
        if use_sliding_window:
          self.use_sliding_window = use_sliding_window
          self.window_size = window_size
          self.attention = SWAttention(self.hidden_dim, self.num_heads, self.head_dim, self.window_size)
        else:
          self.num_kv_heads = num_kv_heads
          self.attention = GroupedQueryAttention(self.hidden_dim, self.num_heads, self.head_dim, self.num_kv_heads)

        if use_moe:
          self.use_moe = use_moe
          self.num_experts = num_experts
          self.top_k = top_k
          self.forward_layer = MixtureOfExperts(self.hidden_dim, self.ff_dim, self.num_experts, self.top_k)
        else:
          self.forward_layer = SwiGLUFeedForward(self.hidden_dim, self.ff_dim)

        self.norm1 = torch.nn.RMSNorm(hidden_dim)
        self.norm2 = torch.nn.RMSNorm(hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of shape [batch, seq_len, hidden_dim].

        Returns:
            Output tensor of shape [batch, seq_len, hidden_dim].
        """
        ### My code ###
        r = self.norm1(x)
        r = r + self.attention(r)
        r = self.norm2(r)
        result = r + self.forward_layer(r)

        assert x.shape == result.shape
        return result

### The Transfomer model

In [17]:
class Transformer(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_layers: int,
        hidden_dim: int,
        ff_dim: int,
        num_heads: int,
        head_dim: int,
        use_sliding_window_alternating: bool = False,
        window_size: int = 128,
        use_moe: bool = False,
        num_experts: int = 8,
        top_k: int = 2,
        num_kv_heads: Optional[int] = None
    ) -> None:
        """
        Args:
            vocab_size: Size of the vocabulary.
            n_layers: Number of transformer layers.
            hidden_dim: Hidden dimension.
            ff_dim: Feed-forward inner dimension.
            num_heads: Number of attention heads.
            head_dim: Dimension per attention head.
            use_sliding_window_alternating: Use sliding window on every other layer
            window_size: Size of sliding window.
            use_moe: Whether to use Mixture of Experts.
            num_experts: Number of experts (if MoE).
            top_k: Number of experts per token (if MoE).
            num_kv_heads: Number of KV heads for GQA.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.ff_dim = ff_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        ### My code ###
        self.embedding = torch.nn.Embedding(self.vocab_size, self.hidden_dim)
        self.output_proj = torch.nn.Linear(self.hidden_dim, self.vocab_size)
        self.final_norm = torch.nn.RMSNorm(hidden_dim)
        if use_sliding_window_alternating:
          dummy = [False, True]
          self.layers = torch.nn.Sequential(*[TransformerBlock(self.hidden_dim, self.ff_dim, self.num_heads, self.head_dim, dummy[i%2], window_size, use_moe, num_experts, top_k, num_kv_heads) for i in range(n_layers)])
        else:
          self.layers = torch.nn.Sequential(*[TransformerBlock(self.hidden_dim, self.ff_dim, self.num_heads, self.head_dim, False, use_moe=use_moe, num_kv_heads = num_kv_heads) for i in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Token indices of shape [batch, seq_len].

        Returns:
            Logits of shape [batch, seq_len, vocab_size].
        """
        assert len(x.shape) == 2, f"Expected 2D input, got shape {x.shape}"

        ### My code ###

        r = self.embedding(x)
        r = self.layers(r)
        r = self.final_norm(r)
        logits = self.output_proj(r)
        return logits



## First Dataset - Cutie&Farfocl  ##
----------------------------
Small version of this transformer architecture will learn simple text patterns from 'Cutie&Farfocl' dataset. It consists of repetitive sentences from vocabulary of around 70 words


In [18]:
with open("CutieAndFarfocl.csv", "r", encoding="utf-8") as f:
    reader = csv.reader(f)
    text = next(reader)[0]
print(text[160:206])
print(text[255:308])

Despite Farfocl being big Cutie loves Farfocl 
Cutie is not only really cute but also really smart .


### Data pipeline ###

In [19]:
class TextDataset:
    def __init__(self, text: str):
        # Tokenize by spaces
        self.tokens: List[str] = text.split(" ")
        self.cnt = Counter(self.tokens)

        # Vocabulary
        self.vocab = sorted(set(self.tokens))
        self.vocab_size = len(self.vocab)

        self.stoi = {tok: i for i, tok in enumerate(self.vocab)}
        self.itos = {i: tok for tok, i in self.stoi.items()}

        # Encode entire text as indices
        self.data = torch.tensor(
            [self.stoi[t] for t in self.tokens],
            dtype=torch.long
        )

    def __len__(self):
        return len(self.data)

class RandomBatchGenerator:
    def __init__(self, dataset: TextDataset, seq_len: int, batch_size: int, device="cpu"):
        self.dataset = dataset
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.device = device

    def sample(self):
        data = self.dataset.data
        vocab_size = self.dataset.vocab_size

        max_start = len(data) - self.seq_len - 1
        starts = torch.randint(0, max_start, (self.batch_size,))

        x_idx = torch.stack([
            data[s : s + self.seq_len] for s in starts
        ])

        y = torch.stack([
            data[s + 1 : s + self.seq_len + 1] for s in starts
        ])

        # One-hot embedding
        x = torch.nn.functional.one_hot(
            x_idx, num_classes=vocab_size
        ).float()

        return x.to(self.device), y.to(self.device)


In [20]:
class RandomTextDataset(IterableDataset):
    def __init__(self, data, seq_len):
        """
        data: 1D LongTensor of token IDs
        """
        self.data = data
        self.seq_len = seq_len

    def __iter__(self):
        data = self.data
        L = self.seq_len
        max_start = len(data) - L - 1

        while True:
            s = torch.randint(0, max_start, (1,)).item()

            x = data[s : s + L]           # [seq_len]
            y = data[s + 1 : s + L + 1]   # [seq_len]

            yield x, y



def decode(indices, itos):
    """
    indices: 1D or 2D tensor of token IDs
    itos: dict mapping index -> token (from TextDataset)
    """
    # ensure CPU
    indices = indices.cpu()

    # if 2D batch, pick first sequence
    if indices.dim() == 2:
        indices = indices[0]

    # map each token ID -> string
    decoded = " ".join(itos[i.item()] for i in indices)

    # find all occurrences of '.'
    dots = [i for i, char in enumerate(decoded) if char == '.']

    if not dots:
        return decoded

    # discard everything after last '.'
    last_dot = dots[-1]
    decoded = decoded[:last_dot + 1]

    # add newline after every '.' except the last one
    result = []
    prev = 0
    for dot in dots[:-1]:
        result.append(decoded[prev:dot + 1])
        prev = dot + 1
    result.append(decoded[prev:])

    return "\n".join(result)


def split_data(data, split_ratio=0.9):
    n = int(len(data) * split_ratio)
    return data[:n], data[n:]


In [21]:
dataset = TextDataset(text)
train_data, val_data = split_data(dataset.data)

seq_len = 32
batch_size = 64

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_ds = RandomTextDataset(
    train_data,
    seq_len=seq_len,
)

val_ds = RandomTextDataset(
    val_data,
    seq_len=seq_len,
)

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    num_workers=0
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    num_workers=0
)

In [22]:
print('Vocabulary')
print('Index: word: counts')
for i, item in enumerate(dataset.vocab):
  print(str(i) + ': ' + item + ': ' + str(dataset.cnt[item]))

Vocabulary
Index: word: counts
0: : 641
1: .: 23946
2: 22: 73
3: Cutie: 17931
4: Despite: 630
5: Farfocl: 15008
6: Not: 648
7: On: 4632
8: The: 504
9: Warsaw: 77
10: a: 800
11: about: 126
12: adores: 126
13: also: 7596
14: always: 160
15: and: 3156
16: are: 2484
17: attracted: 73
18: being: 882
19: big: 1407
20: but: 12228
21: cares: 126
22: clumsy: 1406
23: couple: 480
24: crazy: 1406
25: creative: 2122
26: cute: 2122
27: despite: 252
28: dumb: 1406
29: extremely: 4283
30: for: 320
31: freaky: 480
32: funny: 2122
33: hand: 4952
34: his: 219
35: hot: 3688
36: impressive: 1
37: impressively: 3963
38: in: 786
39: inseparable: 320
40: inspiration: 73
41: interesting: 2122
42: is: 27487
43: living: 73
44: love: 626
45: lovely: 2442
46: lovers: 160
47: loves: 126
48: merely: 5323
49: muse: 73
50: not: 12272
51: of: 252
52: often: 160
53: old.: 73
54: only: 7596
55: other: 4632
56: pair: 320
57: petite: 2122
58: pretty: 2123
59: really: 3963
60: sexy: 2282
61: slow: 1406
62: small: 2123
63: 

### Training and Evaluation functions


In [23]:
@torch.no_grad
def eval_acc(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader):
    model.eval()
    sum_acc = 0
    num_examples = 0
    for step, (x, y) in enumerate(dataloader):
        x, y = x.to(DEVICE), y.to(DEVICE)
        model_out = model(x)

        # Transpose model_out to match expected input for cross_entropy: [batch_size, vocab_size, seq_len]
        acc = (torch.argmax(model_out, dim=-1) == y).to(torch.float32).sum()
        sum_acc += acc
        num_examples += model_out.shape[0] * model_out.shape[1]

        if step > 10:
          break

    return sum_acc / num_examples


def eval_fn(step, model, dataloader):
    acc = eval_acc(model, dataloader)
    print(f"{step}: Avg eval accuracy {acc}")


def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    dataloader: torch.utils.data.DataLoader,
    eval_fn: functools.partial,
    num_epochs: int,
    scheduler,
):
    model.train()

    for epoch in range(num_epochs):
        if epoch == 0:
            eval_fn(epoch, model)

        model.train()
        total_loss = 0.0
        num_batches = 0
        for i, (x, y) in tqdm(enumerate(dataloader)):
            ### My code ###
            x, y = x.to(DEVICE), y.to(DEVICE)
            model_out = model(x)
            loss = torch.nn.functional.cross_entropy(model_out.transpose(1, 2), y)
            #loss = torch.nn.functional.cross_entropy(model_out.transpose(1, 2), y, ignore_index=tokenizer.pad_token_id)  # <- ignore padding
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            if i > 100:
                break

        eval_fn(epoch, model)

### Model definition ###



In [24]:
CutieSimp = Transformer(
    vocab_size=dataset.vocab_size, n_layers=4, hidden_dim=128, ff_dim=256, num_heads=4, head_dim=32
)

### Training ###

In [25]:
from torch.optim.lr_scheduler import CosineAnnealingLR

CutieSimp.to(DEVICE)
optimizer = torch.optim.AdamW(CutieSimp.parameters(), lr=0.001)

num_epochs  = 12

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=100 * num_epochs
)

train(
    model=CutieSimp,
    optimizer=optimizer,
    dataloader=train_loader,
    eval_fn=functools.partial(
        eval_fn,
        dataloader=val_loader,
    ),
    num_epochs=num_epochs, scheduler=scheduler
)

0: Avg eval accuracy 0.013102213852107525


101it [00:16,  5.96it/s]


0: Avg eval accuracy 0.6468505859375


101it [00:18,  5.51it/s]


1: Avg eval accuracy 0.6560465693473816


101it [00:18,  5.58it/s]


2: Avg eval accuracy 0.6569010615348816


101it [00:17,  5.73it/s]


3: Avg eval accuracy 0.6571859121322632


101it [00:18,  5.54it/s]


4: Avg eval accuracy 0.6582438349723816


101it [00:16,  6.20it/s]


5: Avg eval accuracy 0.6540120840072632


101it [00:14,  6.79it/s]


6: Avg eval accuracy 0.6603597402572632


101it [00:14,  6.75it/s]


7: Avg eval accuracy 0.6533610224723816


101it [00:16,  6.31it/s]


8: Avg eval accuracy 0.6539306640625


101it [00:15,  6.62it/s]


9: Avg eval accuracy 0.6576334834098816


101it [00:15,  6.66it/s]


10: Avg eval accuracy 0.6569010615348816


101it [00:17,  5.69it/s]


11: Avg eval accuracy 0.65576171875


### Sampling functions ###


In [26]:
@torch.no_grad()
def token_choice_greedy(model_logits: torch.Tensor) -> torch.Tensor:
    """
    Select the most likely token (greedy decoding).

    Args:
        model_logits: Logits of shape [batch, seq_len, vocab_size].

    Returns:
        Selected token indices of shape [batch, 1].
    """
    assert len(model_logits.shape) == 3
    return torch.argmax(model_logits[:, -1:, :], dim=-1)


@torch.no_grad()
def generate(
    model: torch.nn.Module,
    input: torch.Tensor,
    gen_length: int,
    token_choice: Callable[[torch.Tensor], torch.Tensor] = token_choice_greedy
) -> torch.Tensor:
    """
    Generate new tokens autoregressively.

    Args:
        model: The transformer model.
        input: Initial token sequence of shape [batch, seq_len].
        gen_length: Number of tokens to generate.
        token_choice: Function to select next token from logits.

    Returns:
        Generated tokens of shape [batch, gen_length] (without the input).
    """
    assert len(input.shape) == 2
    model.eval()

    current_seq = input.to(DEVICE)
    output_tokens = []

    for _ in range(gen_length):
        ### My code ###
        model_out = model(current_seq)
        next_token = token_choice(model_out)
        output_tokens.append(next_token)
        current_seq = torch.cat([current_seq, next_token], dim=-1)

    return torch.cat(output_tokens, dim=-1)

@torch.no_grad()
def get_dist_after_with_temp_and_topp(
    model_logits: torch.Tensor, top_p: float, t: float
) -> torch.Tensor:
    """
    Apply temperature scaling and top-p (nucleus) sampling.

    Args:
        model_logits: Logits of shape [batch, seq_len, vocab_size].
        top_p: Cumulative probability threshold for nucleus sampling.
               0.0 = greedy (only most probable token)
               1.0 = full distribution
        t: Temperature for softmax. Higher = more uniform, lower = more peaked.

    Returns:
        Probability distribution of shape [batch, seq_len, vocab_size] with
        low-probability tokens zeroed out and remaining probabilities rescaled.
    """
    assert len(model_logits.shape) == 3
    probs = torch.nn.functional.softmax(model_logits / t, dim=-1)
    batch_size, seq_len, vocab_size = probs.shape
    probs_flat = probs.view(-1, vocab_size)

    sorted_probs, sorted_indices = torch.sort(probs_flat, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    indices_to_remove = cumulative_probs > top_p
    indices_to_remove[..., 1:] = indices_to_remove[..., :-1].clone()
    indices_to_remove[..., 0] = False
    # Set probabilities of removed tokens to 0
    filtered_probs = sorted_probs.clone()
    filtered_probs[indices_to_remove] = 0.0
    # Re-normalize the probabilities of the remaining tokens
    sum_filtered_probs = filtered_probs.sum(dim=-1, keepdim=True)
    normalized_probs = torch.where(sum_filtered_probs == 0, torch.zeros_like(filtered_probs), filtered_probs / sum_filtered_probs)
    # Scatter the re-normalized probabilities back to their original positions
    result_flat = torch.zeros_like(probs_flat)
    result_flat.scatter_(-1, sorted_indices, normalized_probs)
    return result_flat.view(batch_size, seq_len, vocab_size)


@torch.no_grad()
def token_choice_adv(
    model_logits: torch.Tensor, top_p: float, t: float
) -> torch.Tensor:
    """
    Select next token using temperature and top-p sampling.

    Args:
        model_logits: Logits of shape [batch, seq_len, vocab_size].
        top_p: Nucleus sampling threshold.
        t: Temperature.

    Returns:
        Selected token indices of shape [batch, 1].
    """
    # This function should only operate on the last token's logits for next token prediction.
    # Therefore, the slice [:, -1:, :] is appropriate here.
    probs = get_dist_after_with_temp_and_topp(
        model_logits=model_logits[:, -1:, :], top_p=top_p, t=t
    )
    dist = torch.distributions.Categorical(probs=probs)
    return dist.sample()


### Testing ###

In [27]:
#inpucior = input('Provide the index of the first word: ')
#inpucior = torch.tensor([[int(inpucior)]])
inpucior = torch.tensor([[3, 15, 5]])
output = generate(CutieSimp, inpucior, 40, token_choice = functools.partial(token_choice_adv, top_p=0.6, t=1))

decoded_input = decode(inpucior, dataset.itos)
decoded_output = decode(output, dataset.itos)
print('-----------------')
print('Input: ' + decoded_input)
print('-----------------')
print('Output: ' + decoded_output)
print('-----------------')
print('Message together: ')
print('-----------------')
print(decoded_input + ' ' + decoded_output)

-----------------
Input: Cutie and Farfocl
-----------------
Output: are  wholesome and  a lovely couple .
 Cutie is not only very very petite but also extremely smart .
 Cutie is not only really wholesome but also very very pretty .
-----------------
Message together: 
-----------------
Cutie and Farfocl are  wholesome and  a lovely couple .
 Cutie is not only very very petite but also extremely smart .
 Cutie is not only really wholesome but also very very pretty .


## Second Dataset: TinyStories ##
----------------------------
Time for some storytelling! Let's whether I will be able to train a sensible model on colab free-plan GPU. \
Tiny stories is a well-known dataset containing short and simple stories. Language is simple which means that grammar and vocabulary are very basic. Because of that there is a hope that the model will be able to learn it and maybe create at least merely sensible stories

In [28]:
# @title Installments (colab)
!pip install -q datasets transformers sentencepiece

In [29]:
#@title Dataset Loading
from datasets import load_dataset

dataset = load_dataset("roneneldan/TinyStories")

print(dataset)

dataset_small = dataset["train"].shuffle(seed=42).select(range(50000))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})


### Tokenizer definition ###

In [30]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

MAX_LEN = 128 # Define MAX_LEN here

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()

trainer = BpeTrainer(
    vocab_size=8000,
    special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"]
)

tokenizer.train_from_iterator(
    (x["text"] for x in dataset_small),
    trainer=trainer
)

# Configure truncation on the tokenizer object
tokenizer.enable_truncation(max_length=MAX_LEN + 1)
# Configure padding on the tokenizer object
tokenizer.pad_token = "[PAD]"
tokenizer.enable_padding(direction="right", pad_id=tokenizer.token_to_id("[PAD]"), pad_type_id=0, length=MAX_LEN + 1)

def tokenize(example):
    # When batched=True, example["text"] will be a list of strings
    # We use encode_batch for efficiency, truncation and padding are configured globally
    encodings = tokenizer.encode_batch(
        example["text"],
        add_special_tokens=True
    )

    # Prepare lists for x and y sequences
    xs_batch = []
    ys_batch = []

    for encoding in encodings:
        tokens = encoding.ids
        # Since truncation and padding are enabled, `tokens` will be of length MAX_LEN + 1.
        x = tokens[:-1]  # Input sequence of length MAX_LEN
        y = tokens[1:]   # Target sequence of length MAX_LEN

        # Explicitly convert to torch.Tensor
        xs_batch.append(torch.tensor(x, dtype=torch.long))
        ys_batch.append(torch.tensor(y, dtype=torch.long))

    # The map function expects a dictionary of lists for batching
    return {"x": xs_batch, "y": ys_batch}

def collate_fn(batch):
    x = torch.stack([item["x"] for item in batch])
    y = torch.stack([item["y"] for item in batch])
    # Move to device within the collate_fn
    return x.to(DEVICE), y.to(DEVICE)

tokenized_ds = dataset_small.map(
    tokenize,
    batched=True,
    remove_columns=["text"],
    num_proc=2 # Use multiple processes for faster tokenization
)

# Split the tokenized dataset into training and validation sets
split_ds = tokenized_ds.train_test_split(test_size=0.1, seed=42) # 10% for validation

train_ds = split_ds["train"]
val_ds = split_ds["test"]

# Set the format for PyTorch, specifying the columns
train_ds.set_format(type="torch", columns=["x", "y"])
val_ds.set_format(type="torch", columns=["x", "y"])

train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn
)

def train(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    dataloader: torch.utils.data.DataLoader,
    eval_fn: functools.partial,
    num_epochs: int,
    scheduler,
):
    model.train()

    for epoch in range(num_epochs):
        if epoch == 0:
            eval_fn(epoch, model)

        model.train()
        total_loss = 0.0
        num_batches = 0
        for i, (x, y) in tqdm(enumerate(dataloader)):
            ### My code ###
            x, y = x.to(DEVICE), y.to(DEVICE)
            model_out = model(x)
            #loss = torch.nn.functional.cross_entropy(model_out.transpose(1, 2), y)
            loss = torch.nn.functional.cross_entropy(model_out.transpose(1, 2), y, ignore_index=tokenizer.token_to_id("[PAD]"))  # <- ignore padding
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            if i > 100:
                break

        eval_fn(epoch, model)

Map (num_proc=2):   0%|          | 0/50000 [00:00<?, ? examples/s]

In [31]:
#Sanity check
print(tokenizer.decode(tokenized_ds[0]["x"]))


Tim and Mia like to play in the park . They see a big club on the ground . It is brown and long and heavy . " Look , a club !" Tim says . " I can lift it !" He tries to lift the club , but it is too tough . He falls down and drops the club . " Ouch !" he says . " That hurt !" Mia laughs . She is not mean , she just thinks it is funny . " Let me try !" she says . " I can balance it !" She picks up the club and puts it on her head . She walks slowly and carefully . She does not fall down . " Wow


### Model definition ###

In [32]:
TinyStoryTeller = Transformer(
    vocab_size=8000,          # or tokenizer.vocab_size
    n_layers=4,
    hidden_dim=256,
    ff_dim=1024,
    num_heads=4,
    head_dim=64,

    use_sliding_window_alternating=False,
    use_moe=False,
)

### Training ###

In [33]:
from torch.optim.lr_scheduler import CosineAnnealingLR

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TinyStoryTeller.to(DEVICE)
optimizer = torch.optim.AdamW(TinyStoryTeller.parameters(), lr=0.001)

num_epochs  = 16

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=100 * num_epochs
)

train(
    model=TinyStoryTeller,
    optimizer=optimizer,
    dataloader=train_loader,
    eval_fn=functools.partial(
        eval_fn,
        dataloader=val_loader,
    ),
    num_epochs=num_epochs, scheduler=scheduler
)

0: Avg eval accuracy 4.069010537932627e-05


101it [01:20,  1.26it/s]


0: Avg eval accuracy 0.326904296875


101it [01:21,  1.25it/s]


1: Avg eval accuracy 0.3718668818473816


101it [01:15,  1.33it/s]


2: Avg eval accuracy 0.3967488706111908


101it [01:10,  1.44it/s]


3: Avg eval accuracy 0.4120686948299408


101it [01:10,  1.44it/s]


4: Avg eval accuracy 0.4230550229549408


101it [01:14,  1.35it/s]


5: Avg eval accuracy 0.43450927734375


101it [01:21,  1.25it/s]


6: Avg eval accuracy 0.4443562924861908


101it [01:19,  1.26it/s]


7: Avg eval accuracy 0.45098876953125


101it [01:19,  1.27it/s]


8: Avg eval accuracy 0.4582112729549408


101it [01:08,  1.48it/s]


9: Avg eval accuracy 0.4644775390625


101it [01:08,  1.48it/s]


10: Avg eval accuracy 0.4677937924861908


101it [01:07,  1.49it/s]


11: Avg eval accuracy 0.4700520932674408


101it [01:07,  1.50it/s]


12: Avg eval accuracy 0.4732666015625


101it [01:07,  1.50it/s]


13: Avg eval accuracy 0.4757893979549408


101it [01:06,  1.51it/s]


14: Avg eval accuracy 0.4761556088924408


101it [01:06,  1.51it/s]


15: Avg eval accuracy 0.47625732421875


In [34]:
#@title Text generating function
@torch.no_grad()
def generate_text(model, tokenizer, prompt="", max_len=128, temperature=0.2, top_k=50):
    model.eval()

    input_ids = torch.tensor([tokenizer.encode(prompt).ids], dtype=torch.long).to(DEVICE)

    for _ in range(max_len):
        logits = model(input_ids)                 # [1, seq_len, vocab_size]
        next_token_logits = logits[:, -1, :] / temperature

        # Top-k sampling
        topk_probs, topk_indices = torch.topk(next_token_logits, k=top_k, dim=-1)  # [1, k]
        probs = torch.nn.functional.softmax(topk_probs, dim=-1)                     # [1, k]

        # Sample 1 token
        next_token_idx_in_topk = torch.multinomial(probs, num_samples=1)           # [1,1]
        next_token = topk_indices.gather(-1, next_token_idx_in_topk)               # [1,1]
        # Append token
        input_ids = torch.cat([input_ids, next_token], dim=1)                       # [1, seq_len+1]

    return tokenizer.decode(input_ids[0].tolist())

### Prompting and short discussion ###

In [35]:
prompt = "He"
story = generate_text(TinyStoryTeller, tokenizer, prompt=prompt, max_len=24, temperature=1., top_k=50)
print(story)

He and Tom has a big bag . Tom and Tom are very pretty and had a big head . One day , a big


In [36]:
prompt = "He"
story = generate_text(TinyStoryTeller, tokenizer, prompt=prompt, max_len=24, temperature=0.1, top_k=50)
print(story)

He and a hat . She liked to make a big hat . She was very happy . She wanted to make a big cake


Here the results are unfortunately pretty unimpressive. It seems that my TinyStoriesTeller is to small so his stories are quite incoherent. Unfortunately, I don't have computational resources to train a better model. Still the model learns *something* as accuracy is around 48%. \
The stories generated by it doesn't make much sense but one can see some line of reasoning there (especially locally).
As one can see lower temperature seem to be helping (which makes sense as the model probably has a pretty narrow tree of reasoning) \
In the next section I will try to train a slightly bigger model

### Second attempt, a slightly bigger model ###

In [37]:
InfinitesimallyBiggerTinyStoryTeller = Transformer(
    vocab_size=8000,          # or tokenizer.vocab_size
    n_layers=6,
    hidden_dim=256,
    ff_dim=712,
    num_heads=6,
    head_dim=64,

    use_sliding_window_alternating=True,
    window_size=32,
    use_moe=False,
)

In [38]:
#@title Training again...
from torch.optim.lr_scheduler import CosineAnnealingLR

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

InfinitesimallyBiggerTinyStoryTeller.to(DEVICE)
optimizer = torch.optim.AdamW(InfinitesimallyBiggerTinyStoryTeller.parameters(), lr=0.001)

num_epochs  = 16

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=100 * num_epochs
)

train(
    model=InfinitesimallyBiggerTinyStoryTeller,
    optimizer=optimizer,
    dataloader=train_loader,
    eval_fn=functools.partial(
        eval_fn,
        dataloader=val_loader,
    ),
    num_epochs=num_epochs, scheduler=scheduler
)

0: Avg eval accuracy 4.069010537932627e-05


101it [01:54,  1.13s/it]


0: Avg eval accuracy 0.3417765498161316


101it [01:51,  1.10s/it]


1: Avg eval accuracy 0.3836466670036316


101it [01:50,  1.10s/it]


2: Avg eval accuracy 0.4070231318473816


101it [01:51,  1.10s/it]


3: Avg eval accuracy 0.42205810546875


101it [01:51,  1.10s/it]


4: Avg eval accuracy 0.4333903193473816


101it [01:50,  1.09s/it]


5: Avg eval accuracy 0.4451497495174408


101it [01:49,  1.08s/it]


6: Avg eval accuracy 0.453125


101it [01:49,  1.09s/it]


7: Avg eval accuracy 0.4587809443473816


101it [01:50,  1.09s/it]


8: Avg eval accuracy 0.467041015625


101it [01:50,  1.09s/it]


9: Avg eval accuracy 0.47216796875


101it [01:50,  1.09s/it]


10: Avg eval accuracy 0.47723388671875


101it [01:49,  1.08s/it]


11: Avg eval accuracy 0.48223876953125


101it [01:50,  1.09s/it]


12: Avg eval accuracy 0.4851277768611908


101it [01:49,  1.09s/it]


13: Avg eval accuracy 0.4874267578125


101it [01:48,  1.08s/it]


14: Avg eval accuracy 0.4879353940486908


101it [01:47,  1.07s/it]


15: Avg eval accuracy 0.4879150390625


In [39]:
prompt = "He"
story = generate_text(TinyStoryTeller, tokenizer, prompt=prompt, max_len=24, temperature=1., top_k=50)
print(story)

He and a yellow kite . It was loud and pretty and round the hat and a shiny blue hat . It had white hair


The results are still not great...

## Third Dataset: Mickiewicz's poetry ####
Adam Mickiewicz is a national poet of Poland, Lithuania and Belarus. His works were so important for Polish culture that he is known as 'National Prophet'. Because of that I wanted to see what such a small GPT model could learn from his works (spoiler: not much). Fortunately, books of Mickiewicz ('Grażyna', 'Dziady', 'Pan Tadeusz' and 'Konrad Wallenrod') are easy to download in txt format (from wolnelektury.pl for instance).

### Data loading and processing ###

In [40]:
txt_files = ['Pan_Tadeusz.txt', 'Wallenrod.txt', 'grazyna.txt', 'dziady.txt']
texts = []
for f in txt_files:
    with open(f, "r", encoding="utf-8") as file:
        texts.append(file.read())

# Combine into one large string if desired
full_text = "\n".join(texts)

In [41]:
import re
# Replace multiple spaces with single space
clean_text = re.sub(r"[ \t]+", " ", full_text)

# Keep line breaks for poetry, but normalize Windows/Mac line endings
clean_text = clean_text.replace("\r\n", "\n").replace("\r", "\n")

# Remove non-printable characters
clean_text = re.sub(r"[^\x20-\x7EąćęłńóśźżĄĆĘŁŃÓŚŹŻ;:—\n-]", "", clean_text)

# Remove words written in all caps (length >= 3)
clean_text = re.sub(r"\b[A-ZĄĆĘŁŃÓŚŹŻ]{3,}\b", "", clean_text)
clean_text = re.sub(r"[ \t]+", " ", clean_text)
clean_text = re.sub(r"\n\s*\n", "\n", clean_text)  # remove multiple blank lines

In [42]:
max_chars = 512  # approximate max length per sequence
lines = clean_text.split("\n")
chunks = []
current_chunk = ""
for line in lines:
    line = line.strip()
    if not line:  # skip empty lines
        continue

    # if adding the line exceeds max_chars
    if len(current_chunk) + len(line) + 1 > max_chars:
        if current_chunk:           # only append non-empty chunk
            chunks.append(current_chunk)
        current_chunk = line + "\n"  # start new chunk with current line
    else:
        current_chunk += line + "\n"  # add line to current chunk

# append any leftover chunk
if current_chunk:
    chunks.append(current_chunk)

print("Number of chunks:", len(chunks))
print(chunks[1][:200])

Number of chunks: 1320
Iść za wrócone życie podziękować Bogu),
Tak nas powrócisz cudem na Ojczyzny łono.
Tymczasem przenoś moją duszę utęsknioną
Do tych pagórków leśnych, do tych łąk zielonych,
Szeroko nad błękitnym Niemnem


In [43]:
from tokenizers import Tokenizer, models, trainers, pre_tokenizers

# Initialize a Byte-Pair Encoding (BPE) tokenizer
tokenizer = Tokenizer(models.BPE())
#tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()  # tokenizes on whitespace

# Train tokenizer on your text chunks
trainer = trainers.BpeTrainer(
    vocab_size=8000,  # can adjust depending on dataset size
    min_frequency=2,
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
)
tokenizer.train_from_iterator(chunks, trainer)

In [44]:
input_ids = []
for chunk in chunks:
    ids = tokenizer.encode(chunk).ids
    input_ids.append(ids)

# Optional: pad sequences to same length for batch training
from torch.nn.utils.rnn import pad_sequence

# Filter out empty tokenized lists before converting to tensors
input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids if ids]
padded_inputs = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.token_to_id("[PAD]"))

In [45]:
from torch.utils.data import TensorDataset, DataLoader, random_split

# Shift inputs by one for causal LM
x = padded_inputs[:, :-1]
y = padded_inputs[:, 1:]

# Create a TensorDataset from x and y for WieszczMimic
full_wieszcz_dataset = TensorDataset(x, y)

val_size = int(0.1 * len(full_wieszcz_dataset))
train_size = len(full_wieszcz_dataset) - val_size

train_subset, val_subset = random_split(full_wieszcz_dataset, [train_size, val_size])

train_loader = DataLoader(train_subset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=8)

### Model definition ###

In [46]:
WieszczMimic = Transformer(
    vocab_size=8000,          # or tokenizer.vocab_size
    n_layers=4,
    hidden_dim=256,
    ff_dim=712,
    num_heads=4,
    head_dim=64,

    use_sliding_window_alternating=False,
    use_moe=False,
)

### Model training ###

In [48]:
from torch.optim.lr_scheduler import CosineAnnealingLR

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

WieszczMimic.to(DEVICE)
optimizer = torch.optim.AdamW(WieszczMimic.parameters(), lr=0.0003)

num_epochs  = 16

scheduler = CosineAnnealingLR(
    optimizer,
    T_max=100 * num_epochs
)

# Define a wrapper for eval_fn that matches the signature expected by the `train` function
def _eval_fn_for_train_call(epoch, model_instance):
    # The eval_fn from pLlSYPoogUBq expects (step, model, dataloader)
    # The train function from It-u7eb08UFq calls eval_fn(epoch, model)
    # This wrapper aligns those two calls.
    acc = eval_fn(
        epoch,
        model_instance,
        val_loader
    )
    print(f"{epoch}: Avg eval accuracy {acc}")

train(
    model=WieszczMimic,
    optimizer=optimizer,
    dataloader=train_loader,
    eval_fn=_eval_fn_for_train_call, # Pass the wrapper function
    num_epochs=num_epochs, scheduler=scheduler
)

0: Avg eval accuracy 0.0
0: Avg eval accuracy None


101it [01:25,  1.18it/s]


0: Avg eval accuracy 0.004179526586085558
0: Avg eval accuracy None


101it [01:24,  1.19it/s]


1: Avg eval accuracy 0.0030221191700547934
1: Avg eval accuracy None


101it [01:25,  1.18it/s]


2: Avg eval accuracy 0.005401234608143568
2: Avg eval accuracy None


101it [01:25,  1.18it/s]


3: Avg eval accuracy 0.009645061567425728
3: Avg eval accuracy None


101it [01:25,  1.18it/s]


4: Avg eval accuracy 0.015046295709908009
4: Avg eval accuracy None


101it [01:24,  1.19it/s]


5: Avg eval accuracy 0.01845421828329563
5: Avg eval accuracy None


101it [01:23,  1.21it/s]


6: Avg eval accuracy 0.022633744403719902
6: Avg eval accuracy None


101it [01:24,  1.19it/s]


7: Avg eval accuracy 0.023148147389292717
7: Avg eval accuracy None


101it [01:25,  1.18it/s]


8: Avg eval accuracy 0.023083847016096115
8: Avg eval accuracy None


101it [01:28,  1.14it/s]


9: Avg eval accuracy 0.023726850748062134
9: Avg eval accuracy None


101it [01:30,  1.11it/s]


10: Avg eval accuracy 0.024884259328246117
10: Avg eval accuracy None


101it [01:30,  1.11it/s]


11: Avg eval accuracy 0.025527263060212135
11: Avg eval accuracy None


101it [01:30,  1.12it/s]


12: Avg eval accuracy 0.023791151121258736
12: Avg eval accuracy None


101it [01:30,  1.11it/s]


13: Avg eval accuracy 0.02359825000166893
13: Avg eval accuracy None


101it [01:29,  1.13it/s]


14: Avg eval accuracy 0.02462705597281456
14: Avg eval accuracy None


101it [01:30,  1.11it/s]


15: Avg eval accuracy 0.02391975186765194
15: Avg eval accuracy None


101it [01:34,  1.07it/s]


16: Avg eval accuracy 0.02289094589650631
16: Avg eval accuracy None


101it [01:29,  1.13it/s]


17: Avg eval accuracy 0.02327674813568592
17: Avg eval accuracy None


101it [01:27,  1.15it/s]


18: Avg eval accuracy 0.023019546642899513
18: Avg eval accuracy None


101it [01:27,  1.15it/s]


19: Avg eval accuracy 0.02289094589650631
19: Avg eval accuracy None


101it [01:27,  1.16it/s]


20: Avg eval accuracy 0.02289094589650631
20: Avg eval accuracy None


101it [01:26,  1.17it/s]


21: Avg eval accuracy 0.02289094589650631
21: Avg eval accuracy None


101it [01:26,  1.17it/s]


22: Avg eval accuracy 0.022826645523309708
22: Avg eval accuracy None


101it [01:25,  1.18it/s]


23: Avg eval accuracy 0.022826645523309708
23: Avg eval accuracy None


### Prompting and testing

In [None]:
prompts = ["Litwo!", '\nGdzie', '\nI']
for prompt in prompts:
  print('-------------------------------------')
  story = generate_text(WieszczMimic, tokenizer, prompt=prompt, max_len=24, temperature=1., top_k=50)
  print(story)
  print('-------------------------------------')