<a href="https://colab.research.google.com/github/TomJenkin/playground/blob/main/transformers_03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
import pandas as pd
import datetime
import yfinance as yf
from sklearn import cluster

In [18]:
!pip install -qq torchinfo
from torchinfo import summary

In [19]:
start = datetime.datetime.now()

In [20]:
if False:
    sp500 = yf.Ticker("^GSPC")
    data = sp500.history(period="max")

    cob_start = '20200101'
    cob_end = None
    n_clusters = 20
    ds = data.Close.pct_change().rename(0).loc[slice(cob_start,cob_end)]
    dr = ds.to_frame()
    for n in range(1,5):
        dr[-n] = dr[0].shift(n)
    dr = dr.sort_index(axis=1).dropna()
    model = cluster.KMeans(n_clusters=n_clusters,random_state=0).fit(dr)
    dr = dr.assign(label = model.labels_)
    dm = dr.label.value_counts().reset_index().rename_axis('label2').reset_index()
    dm = dm.assign(label2 = dm.label2+1)
    dmr = dm.set_index('label').label2.to_dict()
    dr = dr.assign(label2 = dr.label.map(dmr))
    labels = dr.label2.to_list()
    print(dr.shape)

    prompts = [
        labels[0:100],
        labels[100:200],
        labels[200:300],
    ]

In [21]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------
# Positional Encodings (sin/cos as in the paper)
# -----------------------
def sinusoidal_position_encoding(max_len: int, d_model: int, device=None):
    pe = torch.zeros(max_len, d_model, device=device)
    position = torch.arange(0, max_len, device=device).unsqueeze(1)  # [T, 1]
    div_term = torch.exp(torch.arange(0, d_model, 2, device=device) *
                         (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # [T, d_model]

# -----------------------
# Multi-Head Self-Attention (causal)
# -----------------------
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        # x: [B, T, d_model]
        B, T, C = x.shape
        qkv = self.qkv(x)  # [B, T, 3C]
        q, k, v = qkv.chunk(3, dim=-1)

        # reshape -> [B, n_heads, T, d_head]
        def split_heads(t):
            return t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # scaled dot-product attention
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # [B, h, T, T]
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.attn_drop(attn)
        out = attn @ v  # [B, h, T, d_head]

        # concat heads
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # [B, T, C]
        out = self.proj_drop(self.out(out))
        return out

# -----------------------
# Position-wise FeedForward
# -----------------------
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        return self.fc2(self.drop(F.relu(self.fc1(x))))  # ReLU as in the paper

# -----------------------
# Transformer Block (Pre-LN for stability)
# -----------------------
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.sa = MultiHeadSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        x = x + self.drop(self.sa(self.ln1(x), attn_mask=attn_mask))
        x = x + self.drop(self.ff(self.ln2(x)))
        return x

# -----------------------
# Simplified Decoder-Only Transformer LM
# -----------------------
class SimpleTransformerLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_layers=4, n_heads=4, d_ff=1024,
                 max_len=512, dropout=0.1, pad_id=0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_len = max_len
        self.pad_id = pad_id

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding.from_pretrained(
            sinusoidal_position_encoding(max_len, d_model),
            freeze=True
        )
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # weight tying
        self.head.weight = self.tok_emb.weight

    def _make_attention_mask(self, x):
        """
        Build a combined mask:
        - causal mask to block attending to future positions
        - key padding mask to block attending to pads in K/V
        returns: mask with True where attention is NOT allowed, shape [B, 1, T, T]
        """
        B, T = x.shape
        device = x.device

        # causal: [T, T] True above diagonal (i<j)
        causal = torch.triu(torch.ones(T, T, device=device, dtype=torch.bool), diagonal=1)
        causal = causal.unsqueeze(0).unsqueeze(0)  # [1,1,T,T]

        # key padding: mask positions that are pad in keys
        key_pad = (x == self.pad_id).unsqueeze(1).unsqueeze(2)  # [B,1,1,T]

        # broadcast OR
        attn_mask = causal | key_pad  # [B,1,T,T]
        return attn_mask

    def forward(self, x):
        """
        x: LongTensor [B, T] with token ids (may include pad_id).
        Returns logits for all positions: [B, T, vocab_size]
        """
        B, T = x.shape
        if T > self.max_len:
            raise ValueError(f"Sequence length {T} exceeds max_len {self.max_len}")

        device = x.device
        pos = torch.arange(T, device=device)
        h = self.tok_emb(x) * math.sqrt(self.d_model) + self.pos_emb(pos)  # [B,T,C]
        h = self.drop(h)

        attn_mask = self._make_attention_mask(x)  # [B,1,T,T]
        for blk in self.blocks:
            h = blk(h, attn_mask=attn_mask)

        h = self.ln_f(h)
        logits = self.head(h)  # [B,T,V]
        return logits

    @torch.no_grad()
    def predict_next_token(self, x, temperature=1.0, top_k=None):
        """
        Returns next-token logits/probs for each sequence.
        x: LongTensor [B, T]
        """
        logits = self.forward(x)[:, -1, :]  # last position per sequence
        if temperature != 1.0:
            logits = logits / temperature
        if top_k is not None:
            # top-k filtering for sampling; keep it generic
            topk_vals, topk_idx = torch.topk(logits, k=top_k, dim=-1)
            mask = torch.full_like(logits, float('-inf'))
            logits = mask.scatter(1, topk_idx, topk_vals)

        probs = F.softmax(logits, dim=-1)  # [B, V]
        next_token = torch.argmax(probs, dim=-1)  # greedy
        return next_token, probs  # ints [B], floats [B,V]

# -----------------------
# Example usage
# -----------------------
if __name__ == "__main__":

    # batch of sequences (variable lengths, padded with pad_id)
    batch = [
        [101, 42, 77, 88, 5],
        [101, 19, 23],
        [101, 202, 13, 9, 17, 4, 55]
    ]

    batch = [
        [1,1,2]*80,
        [1,1,2]*80,
        [1,1,2]*80,
        [1,1,2]*80,
        [1,1,2]*60,
        [1,1,2]*40,
        [1,1,2]*20,
        [1,1,2]*1,
    ]

    #batch = labels

    vocab = sorted(set.union(*[set(e) for e in batch]))

    # toy config
    assert min(vocab) > 0
    vocab_size = max(vocab) + 1
    # vocab_size = 32000
    pad_id = 0
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = SimpleTransformerLM(
        vocab_size=vocab_size,
        #d_model=256,
        #n_layers=4,
        #n_heads=4,
        #d_ff=1024,
        #max_len=256,
        d_model=256,
        n_layers=4,
        n_heads=4,
        d_ff=1024,
        max_len=2048,
        dropout=0.1,
        pad_id=pad_id
    ).to(device)

    # pad right to same length
    maxT = max(len(s) for s in batch)
    x = torch.full((len(batch), maxT), pad_id, dtype=torch.long)
    for i, s in enumerate(batch):
        x[i, :len(s)] = torch.tensor(s, dtype=torch.long)
    x = x.to(device)

    # ---- quick warm-up training so the model learns 1,1,2 -> 1 ----
    model.train()
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)

    # teacher-forcing targets: predict next token
    # inp means inputs, targ means target
    inp  = x[:, :-1]          # [B, T-1]
    targ = x[:, 1:]           # [B, T-1]

    steps = 100
    for _ in range(steps):
        logits = model(inp)               # [B, T-1, V]
        loss = loss_fn(logits.reshape(-1, model.vocab_size), targ.reshape(-1))
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()

    model.eval()
    # ---- end warm-up training ----

    # predict the next token for each sequence
    next_ids, next_probs = model.predict_next_token(x, temperature=1.0)
    print("Next token ids:", next_ids.tolist())
    # next_probs[i] is the full distribution for sequence i


Next token ids: [1, 1, 1, 1, 1, 1, 1, 1]


In [22]:
next_ids, next_probs = model.predict_next_token(x, temperature=1.0)
print("Next token ids:", next_ids.tolist())

Next token ids: [1, 1, 1, 1, 1, 1, 1, 1]


In [23]:
pd.DataFrame(next_probs.cpu()).style.format('{:,.2%}')

Unnamed: 0,0,1,2
0,0.00%,100.00%,0.00%
1,0.00%,100.00%,0.00%
2,0.00%,100.00%,0.00%
3,0.00%,100.00%,0.00%
4,0.00%,98.08%,1.92%
5,0.00%,98.07%,1.93%
6,0.00%,98.07%,1.93%
7,0.00%,97.95%,2.05%


In [24]:
end = datetime.datetime.now()
run_time = (end-start)
print(run_time)

0:02:12.123449


In [25]:
if False:
    for name, module in model.named_modules():
        print(name, "->", module.__class__.__name__)

In [26]:
if False:
    print(model)

In [28]:
summary(
    model,
    input_size=(2, 16),    # (batch_size=2, seq_len=16)
    dtypes=[torch.long],   # important! embeddings expect LongTensor
    device="cpu",          # or "cuda" if you’re on GPU
    mode="eval"
)

Layer (type:depth-idx)                        Output Shape              Param #
SimpleTransformerLM                           [2, 16, 3]                --
├─Embedding: 1-1                              [2, 16, 256]              768
├─Embedding: 1-2                              [16, 256]                 (524,288)
├─Dropout: 1-3                                [2, 16, 256]              --
├─ModuleList: 1-4                             --                        --
│    └─TransformerBlock: 2-1                  [2, 16, 256]              --
│    │    └─LayerNorm: 3-1                    [2, 16, 256]              512
│    │    └─MultiHeadSelfAttention: 3-2       [2, 16, 256]              263,168
│    │    └─Dropout: 3-3                      [2, 16, 256]              --
│    │    └─LayerNorm: 3-4                    [2, 16, 256]              512
│    │    └─FeedForward: 3-5                  [2, 16, 256]              525,568
│    │    └─Dropout: 3-6                      [2, 16, 256]              --


In [34]:
vocab_size = 2
B, T = 2, 16   # batch size 2, sequence length 16
dummy_input = torch.randint(1, vocab_size, (B, T), dtype=torch.long).to(device)

# print summary
summary(model, input_data=dummy_input, mode="eval")

Layer (type:depth-idx)                        Output Shape              Param #
SimpleTransformerLM                           [2, 16, 3]                --
├─Embedding: 1-1                              [2, 16, 256]              768
├─Embedding: 1-2                              [16, 256]                 (524,288)
├─Dropout: 1-3                                [2, 16, 256]              --
├─ModuleList: 1-4                             --                        --
│    └─TransformerBlock: 2-1                  [2, 16, 256]              --
│    │    └─LayerNorm: 3-1                    [2, 16, 256]              512
│    │    └─MultiHeadSelfAttention: 3-2       [2, 16, 256]              263,168
│    │    └─Dropout: 3-3                      [2, 16, 256]              --
│    │    └─LayerNorm: 3-4                    [2, 16, 256]              512
│    │    └─FeedForward: 3-5                  [2, 16, 256]              525,568
│    │    └─Dropout: 3-6                      [2, 16, 256]              --


In [None]:
summary(model)

In [None]:
summary(model,input_size=(240,48))

In [None]:
input_size = torch.tensor([12, 24], dtype=torch.long)
summary(model,input_size=input_size)

In [None]:
summary(model,input_size=(int(240),int(48)))

In [None]:
import torch
import torch.nn as nn
from torchinfo import summary

d_model = 512
model2 = nn.Transformer(
    d_model=d_model, nhead=8,
    num_encoder_layers=6, num_decoder_layers=6,
    dim_feedforward=2048
)

S, T, N = 10, 9, 32
src = torch.randn(S, N, d_model)  # floats are correct here
tgt = torch.randn(T, N, d_model)

summary(model2, input_data=(src, tgt), mode="eval")
