In [None]:
import math
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer

Tokenizer

In [2]:
# Only the GPT-2 tokenizer (its BPE vocabulary) is loaded here; the model
# itself is NotebookGPT, trained from scratch, not the pretrained gpt2-large.
TOKENIZER_NAME = "openai-community/gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# Batches are padded with the end-of-sequence token (presumably because the
# GPT-2 tokenizer defines no dedicated pad token — verify if changing models).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Base vocabulary size reported by the tokenizer; the model's embedding and
# output layers are sized from this value when the model is created below.
vocab_size = tokenizer.vocab_size
vocab_size

Token Embedding

In [4]:
class TokenEmbedding(nn.Module):
    """Look up a learned vector for every token id and scale it by sqrt(d_model).

    Scaling the embedding weights by sqrt(d_model) follows the original
    Transformer paper; the sinusoidal positional encoding is added afterwards.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Width of each embedding vector.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """Map integer token ids ``x`` to scaled vectors with a trailing ``d_model`` axis."""
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

Positional Encoding

In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch of embeddings, then apply dropout.

    Args:
        d_model: Width of the embeddings the encoding is added to. Odd widths
            are supported (the last sine column simply has no cosine partner).
        max_seq_length: Longest sequence this module can encode.
        dropout: Dropout probability applied to ``embeddings + positions``.
    """

    def __init__(self, d_model, max_seq_length, dropout):
        super().__init__()

        self.dropout = nn.Dropout(dropout)

        # pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies get a cosine.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Leading batch axis so the table broadcasts over [batch, seq_len, d_model].
        # A buffer (not a Parameter): it follows .to(device) and is saved in the
        # state_dict, but is never trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])``.

        Args:
            x: Embeddings of shape ``[batch, seq_len, d_model]``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_length``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {max_len}"
            )
        # Dropout on the sum of token embeddings and positions, as in the
        # original Transformer.
        return self.dropout(x + self.pe[:, :seq_len])

Attention

In [6]:
class Attention(nn.Module):
    """A single scaled dot-product attention head.

    Q, K and V are each a bias-free ``d_model -> d_model`` projection, so the
    head's output keeps the full ``d_model`` width.

    Args:
        d_model: Width of the input embeddings and of the projected Q/K/V.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, dropout):
        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def forward(self,
                token_embeddings_q,
                token_embeddings_k,
                token_embeddings_v,
                mask=None):
        """Compute ``dropout(softmax(Q K^T / sqrt(d_model))) V``.

        Args:
            token_embeddings_q: Query-side inputs, ``[..., seq_q, d_model]``.
            token_embeddings_k: Key-side inputs, ``[..., seq_k, d_model]``.
            token_embeddings_v: Value-side inputs, ``[..., seq_k, d_model]``.
            mask: Optional boolean tensor broadcastable to ``[..., seq_q, seq_k]``;
                ``True`` marks positions that must NOT be attended to.

        Returns:
            Tensor of shape ``[..., seq_q, d_model]``.
        """
        q = self.W_q(token_embeddings_q)
        k = self.W_k(token_embeddings_k)
        v = self.W_v(token_embeddings_v)

        # transpose(-2, -1) swaps only the last two axes, so leading batch
        # axes are left intact.
        sims = torch.matmul(q, k.transpose(-2, -1))  # [..., seq_q, seq_k]
        scaled_sims = sims / math.sqrt(k.size(-1))  # scaled dot product

        if mask is not None:
            # Most negative finite value of the score dtype: a fixed -1e9 cannot
            # be represented in float16 and makes masked_fill fail there.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )

        attention_percents = F.softmax(scaled_sims, dim=-1)
        # Attention dropout on the normalised weights.
        attention_percents = self.dropout(attention_percents)

        attention_scores = torch.matmul(attention_percents, v)
        return attention_scores

Multi-Head Attention

In [7]:
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` independent ``Attention`` heads and merge their outputs.

    Every head keeps the full ``d_model`` width (no per-head split), so the
    concatenated output is ``d_model * num_heads`` wide and
    ``output_projection`` maps it back to ``d_model``.

    Args:
        d_model: Embedding width.
        num_heads: Number of parallel attention heads.
        dropout: Dropout used inside every head and on the projected output.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()

        self.heads = nn.ModuleList(
            [Attention(d_model=d_model, dropout=dropout) for _ in range(num_heads)]
        )

        # Combine the concatenated head outputs back into d_model features.
        self.output_projection = nn.Linear(d_model * num_heads, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                token_embeddings_q,
                token_embeddings_k,
                token_embeddings_v,
                mask=None):
        """Attend with every head, concatenate on the last axis, project, and apply dropout.

        Args:
            token_embeddings_q: Query-side inputs with a trailing ``d_model`` axis.
            token_embeddings_k: Key-side inputs with a trailing ``d_model`` axis.
            token_embeddings_v: Value-side inputs with a trailing ``d_model`` axis.
            mask: Optional boolean mask passed unchanged to every head.

        Returns:
            Tensor with the query input's leading shape and a trailing ``d_model`` axis.
        """
        head_outputs = [
            head(token_embeddings_q, token_embeddings_k, token_embeddings_v, mask)
            for head in self.heads
        ]
        multi_head_output = torch.cat(head_outputs, dim=-1)  # [..., d_model * num_heads]

        output = self.output_projection(multi_head_output)
        return self.dropout(output)

Position-wise Feed-Forward Network (FFN)

In [8]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Computes ``fc2(dropout(relu(fc1(x))))``. It expands ``d_model -> d_ff``
    and then projects back ``d_ff -> d_model``.

    Args:
        d_model: Input and output feature width.
        d_ff: Hidden width of the inner layer.
        dropout: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()

        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP to ``x`` of shape ``[..., d_model]``.

        In ``TransformerBlock`` the input is the layer-normalised attention
        output, shape ``[batch, seq_len, d_model]``. It is not
        ``d_model * num_heads``: the multi-head output is projected back to
        ``d_model`` before it gets here.
        """
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [9]:
class TransformerBlock(nn.Module):
    """Post-norm Transformer layer: self-attention followed by a feed-forward MLP.

    Each sub-layer is wrapped as ``x = LayerNorm(x + dropout(sublayer(x)))``.

    Args:
        d_model: Embedding width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward MLP.
        dropout: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Feed-forward network
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff, dropout)

        # Residual dropout for the feed-forward sub-layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the block to hidden states ``x`` (``[batch, seq_len, d_model]``).

        Args:
            x: Input hidden states.
            mask: Optional boolean attention mask; ``True`` entries are blocked.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # MultiHeadAttention already applies dropout to its projected output,
        # so it is not dropped a second time here.
        attn_output = self.attention(x, x, x, mask)
        x = self.norm1(x + attn_output)

        # The feed-forward MLP only drops out its hidden layer, so apply the
        # residual dropout to its output before the residual add.
        ff_output = self.dropout(self.feed_forward(x))
        x = self.norm2(x + ff_output)

        return x

In [10]:
class NotebookGPT(nn.Module):
    """Decoder-only Transformer language model.

    Pipeline: token embedding -> sinusoidal positional encoding ->
    ``num_layers`` x ``TransformerBlock`` -> final LayerNorm -> linear
    projection to vocabulary logits.

    Args:
        vocab_size: Number of token ids (size of the embedding and output layers).
        d_model: Embedding width.
        num_heads: Attention heads per block.
        num_layers: Number of stacked transformer blocks.
        d_ff: Hidden width of each feed-forward MLP.
        max_seq_len: Longest sequence the positional encoding covers. It is
            also the context window that ``generate`` feeds to the model.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, max_seq_len=1024, dropout=0.1):
        super().__init__()

        # Remembered so generate() crops to the real context window instead of
        # a hard-coded length.
        self.max_seq_len = max_seq_len

        # Token embedding layer
        self.token_embedding = TokenEmbedding(vocab_size, d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len, dropout)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Final layer norm
        self.norm = nn.LayerNorm(d_model)

        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Map token ids to next-token logits.

        Args:
            x: Integer token ids, ``[batch, seq_len]``.
            mask: Optional boolean attention mask broadcastable to
                ``[batch, seq_len, seq_len]``; ``True`` entries are blocked.
                Pass a causal (upper-triangular) mask for autoregressive use.

        Returns:
            Logits of shape ``[batch, seq_len, vocab_size]``.
        """
        x = self.token_embedding(x)
        x = self.positional_encoding(x)

        for block in self.transformer_blocks:
            x = block(x, mask)

        x = self.norm(x)
        return self.output_projection(x)

    def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate text by sampling from the model's distribution, one token at a time.

        Args:
            input_ids (Tensor): Starting token indices of shape [batch_size, seq_len]
            max_new_tokens (int): Number of new tokens to append
            temperature (float): Logits are divided by this before sampling
                (higher = more random). ``0`` selects the arg-max token (greedy).
            top_k (int, optional): If specified, only sample from the top k most likely tokens

        Returns:
            Tensor: Generated token indices of shape [batch_size, seq_len + max_new_tokens].
            Only the last ``max_seq_len`` tokens are fed to the model at each
            step, but the full prompt + generated sequence is returned.

        The module is switched to eval mode for the call and restored to its
        previous train/eval mode afterwards.
        """
        was_training = self.training
        self.eval()

        # Copy so the caller's tensor is never modified.
        generated_ids = input_ids.clone()
        context_len = self.max_seq_len

        # One causal mask sized for the largest window the loop can feed the
        # model (never more than the context length); each step slices it.
        mask_len = min(generated_ids.size(1) + max_new_tokens, context_len)
        causal_mask = torch.triu(
            torch.ones(mask_len, mask_len, device=input_ids.device), diagonal=1
        ).bool()

        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    # Crop only the model's input window; generated_ids keeps
                    # the full history so nothing is lost from the result.
                    context = generated_ids[:, -context_len:]
                    curr_len = context.size(1)
                    curr_mask = causal_mask[:curr_len, :curr_len].unsqueeze(0)

                    logits = self(context, mask=curr_mask)
                    # Logits for the token following the last position.
                    next_token_logits = logits[:, -1, :]

                    if temperature == 0:
                        # Greedy decoding (dividing by zero would yield NaNs).
                        next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                    else:
                        next_token_logits = next_token_logits / temperature

                        if top_k is not None:
                            k = min(top_k, next_token_logits.size(-1))
                            values, indices = torch.topk(next_token_logits, k, dim=-1)
                            # Keep the k largest logits, everything else -> -inf.
                            next_token_logits = torch.full_like(
                                next_token_logits, float('-inf')
                            ).scatter(1, indices, values)

                        probs = F.softmax(next_token_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)

                    generated_ids = torch.cat([generated_ids, next_token], dim=1)
        finally:
            self.train(was_training)

        return generated_ids

In [None]:
# Initialize the model
model = NotebookGPT(
    vocab_size=tokenizer.vocab_size,
    d_model=512,
    num_heads=8,
    num_layers=4,
    d_ff=2048,
    dropout=0.1
)

In [12]:
# Move model to GPU if available
# Run on a CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = model.to(device)

In [13]:
def train_model(model, tokenizer, train_data, epochs=3, lr=5e-5, batch_size=4, max_length=512):
    """
    Simple next-token-prediction training loop for the NotebookGPT model.

    Texts are processed in order, ``batch_size`` at a time. Each batch is
    tokenized with padding and run through the model under a causal mask. It is
    trained with cross-entropy on the one-step-shifted targets, and padded
    positions are excluded from the loss.

    Args:
        model: The NotebookGPT model (called as ``model(input_ids, mask=...)``).
            Batches are moved to the device of the model's parameters.
        tokenizer: The tokenizer (must have a pad token configured)
        train_data: Non-empty list of text samples for training
        epochs: Number of training epochs
        lr: Learning rate for AdamW
        batch_size: Batch size (capped at ``len(train_data)``)
        max_length: Number of tokens each text is truncated to

    Returns:
        The same ``model`` object, trained in place.

    Raises:
        ValueError: If ``train_data`` is empty.
    """
    if len(train_data) == 0:
        raise ValueError("train_data must contain at least one text sample")

    # Use the model's own device instead of a notebook-level global.
    device = next(model.parameters()).device

    # Set model to training mode
    model.train()

    # Adjust batch size if it's larger than the dataset
    batch_size = min(batch_size, len(train_data))

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Targets equal to ignore_index contribute nothing to the loss.
    ignore_index = -100
    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)

    for epoch in range(epochs):
        total_loss = 0.0
        batch_count = 0

        for i in range(0, len(train_data), batch_size):
            batch_texts = train_data[i:i + batch_size]

            inputs = tokenizer(batch_texts, padding=True, truncation=True,
                               return_tensors="pt", max_length=max_length)
            input_ids = inputs.input_ids.to(device)

            # Targets are the inputs themselves (shifted below). Padding must
            # not be learned as a target (the pad token is EOS here), so padded
            # positions get the ignore label.
            targets = input_ids.clone()
            attention_mask = getattr(inputs, "attention_mask", None)
            if attention_mask is not None:
                targets = targets.masked_fill(attention_mask.to(device) == 0, ignore_index)

            # Position t predicts token t+1: drop the first target position.
            shifted_targets = targets[:, 1:].reshape(-1)
            if not (shifted_targets != ignore_index).any():
                # e.g. every text in the batch is a single token: there is
                # nothing to predict, and the mean loss would be NaN.
                continue

            # Causal mask: True above the diagonal blocks attention to the future.
            seq_len = input_ids.size(1)
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=device), diagonal=1
            ).bool()

            logits = model(input_ids, mask=causal_mask)
            # ...and drop the last logit position to align with the targets.
            shifted_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))

            loss = loss_fn(shifted_logits, shifted_targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        avg_loss = total_loss / max(1, batch_count)  # Avoid division by zero
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

    return model

In [None]:
# Load training data from file
# Load training data from file: one training sample per non-empty line.
# Path joins components with the OS separator, so this works on Windows,
# macOS and Linux alike.
TRAINING_DATA_PATH = Path("data") / "training-data.txt"

with open(TRAINING_DATA_PATH, 'r', encoding='utf-8') as f:
    raw_lines = f.readlines()

# Remove any empty lines and strip whitespace
train_data = [line.strip() for line in raw_lines if line.strip()]

# Show the sample count and a short preview instead of dumping the whole corpus.
len(train_data), train_data[:5]

In [None]:
# Train the model
# Train with train_model's default hyper-parameters; it returns the same
# model object it was given (now trained).
model = train_model(
    model=model,
    tokenizer=tokenizer,
    train_data=train_data,
)

Prompt

In [16]:
# Seed text for generation; "Karamazov" is spelled as in Dostoevsky's
# "The Brothers Karamazov".
prompt = "The story of the Karamazovs"

In [17]:
# Tokenize input (text to token)
# Text -> token ids as a PyTorch tensor (a batch holding this one prompt;
# presumably shape [1, prompt_len]).
prompt_input_tokens = tokenizer.encode(text=prompt, return_tensors="pt")

In [18]:
# Put the prompt ids on the same device the model was moved to.
input_ids = prompt_input_tokens.to(device=device)

In [19]:
# Generate text
# Sampling settings: logits are divided by temperature, so 0.1 sharpens the
# distribution strongly; top_k=40 limits each step to the 40 best tokens.
generation_settings = dict(max_new_tokens=100, temperature=0.1, top_k=40)

generated_ids = model.generate(input_ids=input_ids, **generation_settings)

In [None]:
# Raw token ids returned by generate(): the prompt ids followed by the newly
# sampled ids.
generated_ids

In [None]:
# Decode the first sequence in the batch (the prompt cell encodes a single
# string) back to text and show it under a header line.
first_sequence = generated_ids[0]
generated_text = tokenizer.decode(first_sequence, skip_special_tokens=True)
print("Generated text:", generated_text, sep="\n")

In [None]:
def count_parameters(model):
    """
    Count the scalar parameters of ``model`` that will be updated by training.

    Args:
        model: Any PyTorch ``nn.Module``.

    Returns:
        int: Sum of ``numel()`` over parameters with ``requires_grad=True``;
        frozen parameters are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

# Overall size of the model
total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params:,}")

# Per-tensor breakdown, in the order named_parameters() yields them
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    print(f"{name}: {param.numel():,} parameters")