In [1]:
import torch 
from torch import nn
import torch.nn.functional as F

import numpy as np
import math

import os
import requests

# Attention workflow
1. Input (token ids)
2. Word (token) embeddings
3. Positional embeddings
4. Combine word and positional embeddings (element-wise sum, not concatenation)
5. Normalization (optional)
6. Attention


In [2]:

# Pick the best available accelerator: CUDA, then Apple-silicon MPS, then CPU.
# `torch.backends.mps` is looked up defensively because older PyTorch builds lack it.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

class Download:
    """Fetch a plain-text corpus and turn it into character-level training data.

    Workflow: ``fetch()`` downloads the file (only if it is not cached yet), reads it
    into ``self.data`` and builds the vocabulary via ``preprocessing()``; ``split()``
    encodes train/validation id lists; ``get_batch()`` samples (input, target) tensors.
    """

    def __init__(self, path, data_url, create_split: bool = True, timeout: float = 30.0):
        """
        Initialize the Download class.

        :param path: File name, resolved inside ``./data`` under the current working directory.
        :param data_url: The URL to download the data from.
        :param create_split: Whether to create train/validation split (stored; not read by any method here).
        :param timeout: Seconds to wait on the HTTP server before giving up.
        """
        self.path = path
        self.data_url = data_url
        self.create_split = create_split
        self.timeout = timeout

    def fetch(self):
        """
        Download the data if needed, read it into ``self.data`` and build the vocabulary.

        Errors are reported with ``print`` instead of being raised. On failure
        ``self.data`` / ``self.encode`` may be left unset, so a later ``split()`` will fail.
        """
        try:
            # Create the 'data' directory relative to the current directory
            data_dir = os.path.join(os.getcwd(), 'data')
            os.makedirs(data_dir, exist_ok=True)
            input_file_path = os.path.join(data_dir, self.path)

            # Download only when no cached copy exists
            if not os.path.exists(input_file_path):
                # Without a timeout, a stalled connection blocks the notebook forever.
                response = requests.get(self.data_url, timeout=self.timeout)
                response.raise_for_status()  # Raise an error for bad status codes
                # Explicit UTF-8 so the cache does not depend on the platform default encoding.
                with open(input_file_path, 'w', encoding='utf-8') as f:
                    f.write(response.text)

            with open(input_file_path, 'r', encoding='utf-8') as f:
                self.data = f.read()
            print(f"Length of dataset in characters: {len(self.data):,}")
            self.preprocessing()
        except requests.exceptions.RequestException as e:
            print(f"Error downloading the file: {e}")
        except Exception as e:
            print(f"An error occurred: {e}")

    def preprocessing(self):
        """
        Build the character vocabulary and the encode/decode helpers.

        Sets ``vocab_size``, ``stoi`` (char -> id), ``itos`` (id -> char),
        ``encode`` (str -> list[int]) and ``decode`` (iterable of ids -> str).
        """
        # Ids are assigned in sorted character order, so the mapping is deterministic.
        chars = sorted(set(self.data))
        self.vocab_size = len(chars)
        print("All the unique characters:", ''.join(chars))
        print(f"Vocab size: {self.vocab_size:,}")

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        stoi, itos = self.stoi, self.itos

        self.encode = lambda s: [stoi[c] for c in s]
        # int() lets decode accept tensor elements as well as plain ints.
        self.decode = lambda ids: ''.join(itos[int(i)] for i in ids)

    def split(self, size=0.9):
        """
        Create train and validation splits from the data.

        :param size: The proportion of characters (from the start) used for training.
        :return: Encoded training and validation id lists.
        """
        n = len(self.data)
        cut = int(n * size)
        train_ids = self.encode(self.data[:cut])
        val_ids = self.encode(self.data[cut:])
        print(f"Train has {len(train_ids):,} tokens")
        print(f"Val has {len(val_ids):,} tokens")

        return train_ids, val_ids

    def get_batch(self, train_ids=None, val_ids=None, split: str = 'train', context_len: int = 1000, batch_size: int = 8, device_type: str = 'mps', device: str = 'mps'):
        """
        Sample a random batch of (input, target) sequences.

        :param train_ids: Encoded training data.
        :param val_ids: Encoded validation data.
        :param split: 'train' selects train_ids; any other value selects val_ids.
        :param context_len: Length of each sampled sequence.
        :param batch_size: Number of sequences per batch.
        :param device_type: 'cuda' enables pinned-memory async copies; anything else uses a plain copy.
        :param device: Target device for the returned tensors.
        :return: ``(x, y)`` of shape (batch_size, context_len); ``y`` is ``x`` shifted one position right.
        """
        data = train_ids if split == 'train' else val_ids
        print(f"Preparing {split} batch")

        # Exclusive upper bound keeps the shifted target slice inside the data.
        ix = torch.randint(len(data) - context_len, (batch_size,))
        x = torch.stack([torch.tensor(data[i:i + context_len]) for i in ix])
        y = torch.stack([torch.tensor(data[i + 1:i + 1 + context_len]) for i in ix])

        if device_type == 'cuda':
            # Pinned host memory allows an asynchronous (non_blocking) host-to-GPU copy.
            x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
        else:
            x, y = x.to(device), y.to(device)
        return x, y

# Example usage: fetch Tiny Shakespeare, encode a 90/10 split, and sample one training batch.
input_file = 'tiny_shakespeare.txt'
URL = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

downloader = Download(path=input_file, data_url=URL)
downloader.fetch()
train_ids, val_ids = downloader.split(size=0.9)
# NOTE: get_batch's device / device_type default to 'mps' (Apple GPU);
# pass them explicitly when running on CUDA or CPU-only machines.
x, y = downloader.get_batch(train_ids=train_ids, split='train')


Length of dataset in characters: 1,115,394
All the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size: 65
Train has 1,003,854 tokens
Val has 111,540 tokens
Preparing train batch


In [3]:
x  # token-id batch from get_batch: shape (batch_size, context_len)

tensor([[45, 46, 39,  ..., 50,  8,  0],
        [43, 39, 60,  ..., 57,  1, 59],
        [47, 60, 43,  ..., 52, 53,  1],
        ...,
        [58, 39, 58,  ..., 57,  1, 41],
        [63,  1, 39,  ..., 50, 39, 41],
        [58, 46, 43,  ..., 43,  8,  0]], device='mps:0')

In [4]:
class Embeddings(nn.Module):
    """Learned token embeddings plus learned absolute position embeddings (summed)."""

    def __init__(self, vocab_size=50276, n_embd=768, block_size=1024, device='mps'):
        super().__init__()
        self.block_size = block_size
        self.wte = nn.Embedding(vocab_size, n_embd).to(device)
        self.wpe = nn.Embedding(block_size, n_embd).to(device)

    def forward(self, x):
        """Embed a 2-D batch of token ids.

        :param x: integer tensor of shape (batch, seq_len).
        :return: tensor of shape (batch, seq_len, n_embd).
        :raises ValueError: if seq_len exceeds ``block_size`` (no position embedding exists).
        """
        _, t = x.size()
        if t > self.block_size:
            # Fail clearly instead of an out-of-range lookup whose error depends on the device.
            raise ValueError(f"Sequence length {t} exceeds block_size {self.block_size}")
        pos = torch.arange(0, t, dtype=torch.long, device=x.device)
        # The (t, n_embd) position table broadcasts across the batch dimension.
        return self.wte(x) + self.wpe(pos)

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where position i attends only to positions j <= i."""

    def __init__(self, n_embd=768, n_head=12, dropout=0.1, block_size=1024, device='mps'):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        self.n_embd = n_embd
        self.n_head = n_head
        self.dropout = dropout
        self.block_size = block_size

        # A single projection yields q, k, v side by side; forward splits them with chunk(3).
        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd).to(device)
        self.c_proj = nn.Linear(n_embd, n_embd).to(device)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """:param x: tensor (B, T, n_embd). :return: tensor (B, T, n_embd)."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Causal mask: hide future keys (j > i). The diagonal is always kept,
        # so no row is entirely -inf and softmax stays finite.
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        att = att.masked_fill(~causal, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_dropout(self.c_proj(y))
        return y

class TransformerModel(nn.Module):
    """Embeddings followed by one CausalSelfAttention layer.

    The output is the attention layer's hidden states (width ``n_embd``);
    there is no projection back to the vocabulary.
    """

    def __init__(self, vocab_size=50276, n_embd=768, n_head=12, dropout=0.1, block_size=1024, device='mps'):
        super().__init__()
        self.embeddings = Embeddings(vocab_size=vocab_size, n_embd=n_embd,
                                     block_size=block_size, device=device)
        self.attention = CausalSelfAttention(n_embd=n_embd, n_head=n_head, dropout=dropout,
                                             block_size=block_size, device=device)

    def forward(self, x):
        return self.attention(self.embeddings(x))

# Example usage: build the model on the same device as the batch `x`, so the cell
# works on CUDA/CPU as well as Apple 'mps' (the TransformerModel default).
# Manual equivalent: CausalSelfAttention(device=x.device)(Embeddings(device=x.device)(x))
model = TransformerModel(device=x.device)
# NOTE: there is no LM head, so these are (B, T, n_embd) hidden states rather than
# vocabulary logits; the variable name is kept unchanged.
logits = model(x)

In [5]:
class FlashAttention(nn.Module):
    """Multi-head (non-causal) self-attention via PyTorch's fused attention op when available.

    ``F.scaled_dot_product_attention`` (PyTorch >= 2.0) may dispatch to a fused
    FlashAttention kernel depending on device, dtype and version. Otherwise an
    equivalent einsum implementation computes the same exact attention.
    """

    def __init__(self, n_embd=768, n_head=12, dropout=0.1, block_size=1024, device='mps'):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        self.n_embd = n_embd
        self.n_head = n_head
        self.dropout = dropout
        self.block_size = block_size

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd).to(device)
        self.c_proj = nn.Linear(n_embd, n_embd).to(device)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """:param x: tensor (B, T, n_embd). :return: tensor (B, T, n_embd)."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Attention dropout only while training, mirroring nn.Dropout semantics.
            y = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout if self.training else 0.0)
        else:
            q = q / math.sqrt(head_dim)  # Scaling
            attn_weights = torch.einsum('bhqd,bhkd->bhqk', q, k)
            attn_weights = F.softmax(attn_weights, dim=-1)
            attn_weights = self.attn_dropout(attn_weights)
            # The key axis must carry the same label ('k') in both operands so weights
            # and values are contracted pairwise rather than summed independently.
            y = torch.einsum('bhqk,bhkd->bhqd', attn_weights, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

class TransformerModel(nn.Module):
    """Embeddings followed by one FlashAttention layer; returns (B, T, n_embd) hidden states."""

    def __init__(self, vocab_size=50276, n_embd=768, n_head=12, dropout=0.1, block_size=1024, device='mps'):
        super().__init__()
        embed = Embeddings(vocab_size, n_embd, block_size, device)
        attn = FlashAttention(n_embd, n_head, dropout, block_size, device)
        # Assigned in the same order as constructed, so parameter registration order is unchanged.
        self.embeddings = embed
        self.attention = attn

    def forward(self, x):
        hidden = self.embeddings(x)
        return self.attention(hidden)

# Example usage: put the model on the batch's device instead of relying on the 'mps' default.
model = TransformerModel(device=x.device)
output = model(x)

print(output.size())


torch.Size([8, 1000, 768])


In [6]:
class SparseAttention(nn.Module):
    """Multi-head self-attention restricted by a fixed 0/1 sparsity pattern.

    :param sparsity_pattern: optional tensor whose last two dims index (query, key).
        Entries equal to 0 are masked out. It may be larger than the sequence
        length (only the leading T x T corner is used) and must broadcast against
        (B, n_head, T, T). ``None`` means dense attention.
    """

    def __init__(self, n_embd=768, n_head=12, dropout=0.1, block_size=1024, sparsity_pattern=None, device='mps'):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        self.n_embd = n_embd
        self.n_head = n_head
        self.dropout = dropout
        self.block_size = block_size
        self.sparsity_pattern = sparsity_pattern

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd).to(device)
        self.c_proj = nn.Linear(n_embd, n_embd).to(device)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """:param x: tensor (B, T, n_embd). :return: tensor (B, T, n_embd)."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        attn_weights = torch.einsum('bhqd,bhkd->bhqk', q, k) / math.sqrt(head_dim)

        if self.sparsity_pattern is not None:
            # Crop to the current length so patterns built for block_size work for shorter inputs.
            mask = self.sparsity_pattern[..., :T, :T].to(attn_weights.device)
            # Most-negative finite value of the dtype: representable in fp16/bf16, unlike -1e9.
            fill = torch.finfo(attn_weights.dtype).min
            attn_weights = attn_weights.masked_fill(mask == 0, fill)

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Same 'k' label on both operands: contract weights with values along the key axis.
        y = torch.einsum('bhqk,bhkd->bhqd', attn_weights, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_dropout(self.c_proj(y))
        return y

class TransformerModel(nn.Module):
    """Embeddings followed by one SparseAttention layer using ``sparsity_pattern``.

    Returns (B, T, n_embd) hidden states; there is no vocabulary head.
    """

    def __init__(self, vocab_size=50276, n_embd=768, n_head=12, dropout=0.1, block_size=1024, sparsity_pattern=None, device='mps'):
        super().__init__()
        self.embeddings = Embeddings(vocab_size=vocab_size, n_embd=n_embd,
                                     block_size=block_size, device=device)
        self.attention = SparseAttention(n_embd=n_embd, n_head=n_head, dropout=dropout,
                                         block_size=block_size,
                                         sparsity_pattern=sparsity_pattern, device=device)

    def forward(self, x):
        embedded = self.embeddings(x)
        return self.attention(embedded)

# Example usage
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SEQ_LEN = 1024  # length of the demo sequence; must not exceed block_size (default 1024)
BLOCK = 32      # side length of each square tile in the checkerboard pattern

# Checkerboard block-sparsity: tile (r, c) is masked out (0) when r + c is even and
# attended (1) otherwise. A 2-D (SEQ_LEN, SEQ_LEN) pattern broadcasts over batch and
# heads, so no per-head copy is materialised.
tile = torch.arange(SEQ_LEN, device=device) // BLOCK
sparsity_pattern = ((tile[:, None] + tile[None, :]) % 2 != 0).float()

x = torch.randint(1, 50276, (1, SEQ_LEN), device=device)
model = TransformerModel(sparsity_pattern=sparsity_pattern, device=device).to(device)
output = model(x)

print(output.size())


torch.Size([1, 1024, 768])


In [7]:
class LocalAttention(nn.Module):
    """Multi-head self-attention over a symmetric sliding window.

    Query i attends to keys j with |i - j| <= window_size (clipped to the sequence).
    The scores are computed as a full T x T matrix and masked with a band, so memory
    is O(T^2) per head; the window limits what is attended, not what is stored.
    """

    def __init__(self, n_embd=768, n_head=12, dropout=0.1, window_size=128, device='mps'):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        if window_size < 0:
            raise ValueError(f"window_size must be >= 0, got {window_size}")
        self.n_embd = n_embd
        self.n_head = n_head
        self.dropout = dropout
        self.window_size = window_size

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd).to(device)
        self.c_proj = nn.Linear(n_embd, n_embd).to(device)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """:param x: tensor (B, T, n_embd). :return: tensor (B, T, n_embd)."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
        # Band mask: the window is inclusive on both sides, and the diagonal
        # (|i - j| = 0) is always inside it, so no row is fully masked.
        pos = torch.arange(T, device=x.device)
        in_window = (pos[:, None] - pos[None, :]).abs() <= self.window_size
        att = att.masked_fill(~in_window, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)

        output = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        output = self.resid_dropout(self.c_proj(output))
        return output

class TransformerModel(nn.Module):
    """Embeddings followed by one LocalAttention layer with the given ``window_size``.

    ``block_size`` only sizes the position-embedding table. The output is
    (B, T, n_embd) hidden states.
    """

    def __init__(self, vocab_size=50276, n_embd=768, n_head=12, dropout=0.1, block_size=1024, window_size=128, device='mps'):
        super().__init__()
        self.embeddings = Embeddings(vocab_size=vocab_size, n_embd=n_embd,
                                     block_size=block_size, device=device)
        self.attention = LocalAttention(n_embd=n_embd, n_head=n_head, dropout=dropout,
                                        window_size=window_size, device=device)

    def forward(self, x):
        return self.attention(self.embeddings(x))

# Example usage: `x` is reused from an earlier cell. Build the model on x's device,
# because the TransformerModel default ('mps') need not match where x lives.
model = TransformerModel(device=x.device)
output = model(x)

print(output.size())


torch.Size([1, 1024, 768])
