In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    Each head works at the full embedding width: the key/query/value projections
    map ``num_emb -> num_emb * num_heads`` and ``w_c`` merges the concatenated
    heads back down to ``num_emb``.

    Args:
        num_emb: embedding width D of the input and of the output.
        num_heads: number of attention heads H.
    """

    def __init__(self, num_emb, num_heads=8):
        super().__init__()
        self.D, self.H = num_emb, num_heads
        width = num_emb * num_heads

        # Creation order (k, q, v, c) is kept so parameter init consumes the RNG identically.
        self.w_k = nn.Linear(num_emb, width)
        self.w_q = nn.Linear(num_emb, width)
        self.w_v = nn.Linear(num_emb, width)
        self.w_c = nn.Linear(width, num_emb)

    def _split_heads(self, projected, B, T, D):
        """Reshape (B, T, H*D) -> (B*H, T, D) so every (sample, head) pair is one bmm batch row."""
        return projected.view(B, T, self.H, D).transpose(1, 2).reshape(B * self.H, T, D)

    def forward(self, x, causal=True):
        """Attend over the sequence dimension.

        Args:
            x: tensor of shape (B, T, D) with D == num_emb.
            causal: if True, position i may only attend to positions j <= i.

        Returns:
            Tensor of shape (B, T, D).
        """
        B, T, D = x.size()

        # Scaling both q and k by D**0.25 scales their dot product by 1/sqrt(D).
        scale = D ** 0.25
        keys = self._split_heads(self.w_k(x), B, T, D) / scale
        queries = self._split_heads(self.w_q(x), B, T, D) / scale
        values = self._split_heads(self.w_v(x), B, T, D)

        scores = torch.bmm(queries, keys.transpose(1, 2))  # (B*H, T, T)

        if causal:
            # Strictly-upper-triangular entries are "future" positions; the diagonal stays visible.
            future = torch.ones(T, T, dtype=torch.bool, device=scores.device).triu(diagonal=1)
            scores = scores.masked_fill(future, float('-inf'))

        weights = F.softmax(scores, dim=2)

        mixed = torch.bmm(weights, values)  # (B*H, T, D)
        merged = mixed.view(B, self.H, T, D).transpose(1, 2).reshape(B, T, self.H * D)

        return self.w_c(merged)


In [5]:
class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention then an MLP, each in a residual + LayerNorm.

    Args:
        num_emb: embedding width D.
        num_neurons: hidden-width *multiplier* for the MLP. The hidden layer has
            ``num_neurons * num_emb`` units, not ``num_neurons`` units.
        num_heads: number of attention heads.
    """

    def __init__(self, num_emb, num_neurons, num_heads=4):
        super().__init__()

        self.D, self.H, self.neurons = num_emb, num_heads, num_neurons
        hidden = num_neurons * num_emb

        # Submodule creation order matches the parameter init order of the original layout.
        self.msha = MultiHeadSelfAttention(num_emb=num_emb, num_heads=num_heads)
        self.layer_norm1 = nn.LayerNorm(num_emb)
        self.layer_norm2 = nn.LayerNorm(num_emb)

        self.mlp = nn.Sequential(
            nn.Linear(num_emb, hidden),
            nn.GELU(),
            nn.Linear(hidden, num_emb),
        )

    def forward(self, x, causal=True):
        """Apply attention and MLP sub-layers; the output has the same shape as ``x``."""
        attended = self.layer_norm1(x + self.msha(x, causal))
        return self.layer_norm2(attended + self.mlp(attended))


In [2]:
import re
from collections import Counter

# Example poems
poems = [
    "The sun sets in the west.",
    "The moon shines bright at night.",
    "Stars twinkle in the dark sky."
]

# Tokenize the poems: lower-case, then keep runs of letters/digits/apostrophes so that
# trailing punctuation does not produce distinct tokens ("west." vs "west").
tokens = [word for poem in poems for word in re.findall(r"[a-z0-9']+", poem.lower())]

# Build vocabulary in first-occurrence order (Counter preserves insertion order).
# Ids start at 1 because id 0 is reserved for padding.
vocab_counter = Counter(tokens)
vocab = {word: idx for idx, word in enumerate(vocab_counter, start=1)}

# Reserve id 0 for the padding token.
vocab['<pad>'] = 0
vocab_size = len(vocab)

print(f"Vocabulary Size: {vocab_size}")


Vocabulary Size: 15


In [3]:
import torch.nn as nn

# Standalone token-embedding table sized to the vocabulary built earlier.
# NOTE(review): DecoderTransformer creates its own nn.Embedding; this layer is not passed to it.
embedding_dim = 256  # width of each token vector (example value)

embedding_layer = nn.Embedding(vocab_size, embedding_dim)
print(f"Embedding Layer: {embedding_layer}")


Embedding Layer: Embedding(15, 256)


In [6]:
class DecoderTransformer(nn.Module):
    """Decoder-only transformer: token ids -> per-position logits over the vocabulary.

    Args:
        num_emb: vocabulary size (number of token ids); also the width of the output logits.
        emb_dimension: width of the token embeddings and of every TransformerBlock.
        num_neurons: MLP width multiplier forwarded to each TransformerBlock.
        num_heads: attention heads per block.
        num_layers: number of stacked TransformerBlocks.
        max_len: longest sequence the learned positional encoding supports.
    """

    def __init__(self, num_emb, emb_dimension, num_neurons, num_heads=4, num_layers=6, max_len=1000):
        super().__init__()

        self.D = num_emb
        self.max_len = max_len
        self.layers = nn.ModuleList([
            TransformerBlock(emb_dimension, num_neurons, num_heads)
            for _ in range(num_layers)
        ])

        self.embedding = nn.Embedding(num_emb, emb_dimension)
        # Learned absolute positions (zero-initialised). The width must be emb_dimension,
        # not the vocabulary size, because it is added to the token embeddings.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, emb_dimension))

        self.output_layer = nn.Linear(emb_dimension, num_emb)

    def forward(self, x, causal=True):
        """Compute logits for every position.

        Args:
            x: LongTensor of token ids, shape (B, T) with T <= max_len.
            causal: forwarded to every block; True restricts attention to past positions.

        Returns:
            Logits of shape (B, T, num_emb).

        Raises:
            ValueError: if T exceeds ``max_len``.
        """
        _, T = x.size()
        if T > self.max_len:
            raise ValueError(f"sequence length {T} exceeds max_len={self.max_len}")

        h = self.embedding(x) + self.positional_encoding[:, :T, :]
        for layer in self.layers:
            h = layer(h, causal)
        return self.output_layer(h)


In [None]:
def train_model(model, dataset, epochs=10, batch_size=32, learning_rate=1e-4):
    """Train a causal language model with next-token cross-entropy.

    Each dataset item is a dict with an ``'input_ids'`` LongTensor of shape (T,)
    and, optionally, an ``'attention_mask'`` of the same shape (1 = real token,
    0 = padding). All items must share the same T so the default collate can stack them.

    The output at position t is trained to predict token t + 1 (targets are shifted
    by one). Targets whose mask is 0 are excluded from the loss.

    Args:
        model: module mapping (B, T) token ids to (B, T, V) logits.
        dataset: map-style dataset (e.g. a list) of dicts as described above.
        epochs: number of passes over the dataset.
        batch_size: samples per optimisation step.
        learning_rate: Adam learning rate.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if an epoch produces no batch with at least one valid target.
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    ignore_index = -100
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)
    device = next(model.parameters()).device

    model.train()
    history = []

    for epoch in range(epochs):
        total_loss, num_batches = 0.0, 0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch.get('attention_mask')

            # Shift by one: logits at position t are scored against token t + 1.
            targets = input_ids[:, 1:].clone()
            if attention_mask is not None:
                targets[attention_mask.to(device)[:, 1:] == 0] = ignore_index
            if not (targets != ignore_index).any():
                continue  # length-1 or fully padded batch: nothing to predict

            # NOTE(review): the mask is not passed to the model. With causal attention,
            # right-padding cannot affect earlier positions, but left-padding would.
            logits = model(input_ids)[:, :-1]
            loss = loss_fn(logits.transpose(1, 2), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("no batch in the dataset contained a valid next-token target")

        mean_loss = total_loss / num_batches
        history.append(mean_loss)
        print(f"Epoch {epoch + 1}, Loss: {mean_loss}")

    return history
