In [22]:
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn as nn
from jaxtyping import Float, Int
import requests

# Training Data

In [37]:
def get_gutenberg_book(
	id: int|None = 84,
	data_temp: Path|str = "../../../../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path|str = "../data/gutenberg_data",
	) -> list[list[str]]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

# Model Definition

In [24]:
@dataclass
class Config():
    d_model: int
    d_vocab: int
    d_hidden: int

In [25]:
class MLP(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.linear1 = nn.Linear(config.d_model, config.d_hidden)
        self.linear2 = nn.Linear(config.d_hidden, config.d_model)
        

    def forward(self, x: Float[torch.Tensor, "seq_len d_model"]) -> Float[torch.Tensor, "seq_len d_model"]:
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

In [26]:
class AttentionHead(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.W_qk = nn.Linear(config.d_model, config.d_model)
        self.W_vo = nn.Linear(config.d_model, config.d_model)
        self.softmax = nn.Softmax(dim=-1)
        

    def create_mask(self, n_c: int) -> torch.Tensor:
        mask: Float[torch.Tensor, "seq_len seq_len"] = torch.triu(-1 * torch.inf * torch.ones(n_c, n_c), diagonal=1)
        return mask

    def forward(self, x: Float[torch.Tensor, "seq_len d_model"]) -> Float[torch.Tensor, "seq_len d_model"]:
        #create mask, with size n_c x n_c
        mask = self.create_mask(x.shape[0])

        #compute attention scores
        # A = softmax((X @ W_qk @ X^T) + M) @ X @ W_vo
        A = self.softmax((self.W_qk(x)) @ x.transpose(0, -1) + mask) @ self.W_vo(x)
        return A

In [27]:
class TransformerBlock(torch.nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.attention_head = AttentionHead(config)
        self.mlp = MLP(config)

    def forward(self, x: Float[torch.Tensor, "seq_len d_model"]) -> Float[torch.Tensor, "seq_len d_model"]:
        return x + self.attention_head(x) + self.mlp(x)

In [28]:
class Transformer(torch.nn.Module):
    def __init__(self, num_blocks: int, config: Config):
        super().__init__()
        self.config = config
        self.embedding = nn.Linear(config.d_vocab, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(num_blocks)])
        

    def forward(self, x: Float[torch.Tensor, "seq_len vocab"]) -> Float[torch.Tensor, "vocab seq_len"]:
        x = self.embedding(x)
        for block in self.blocks:
            x = block.forward(x)
        x = (x @ self.embedding.weight).T
        return x
    

# Tests

In [29]:
# Attention head test
x: Float[torch.Tensor, "seq_len d_model"] = torch.ones(5, 16)
config = Config(d_model=16, d_vocab=1000, d_hidden=64)
attention_head: AttentionHead = AttentionHead(config)
output: Float[torch.Tensor, "seq_len d_model"] = attention_head.forward(x)
print(output.shape)

torch.Size([5, 16])


In [30]:
# Test the whole thing
config = Config(d_model=16, d_vocab=1000, d_hidden=64)
transformer = Transformer(num_blocks=2, config=config)
x = torch.ones(config.d_vocab, dtype=torch.float)
y: Float[torch.Tensor, "vocab seq_len"] = transformer(x)
print(y.shape)
print(y)
print(x)

torch.Size([1000])
tensor([ 1.3376e-02,  9.5112e-02,  1.6399e-02, -8.2028e-02, -8.6612e-03,
        -1.9070e-02,  1.9305e-02, -6.0633e-03,  3.0350e-02,  1.6588e-02,
         7.5139e-02,  6.9059e-03,  1.8125e-02,  1.2030e-02, -2.3526e-02,
         3.0621e-02, -3.6794e-02,  7.1347e-02,  5.2588e-02,  1.6709e-03,
         7.0914e-02,  2.3044e-02,  1.1746e-02,  2.3141e-02, -1.6804e-02,
         1.3218e-03,  5.3172e-02, -5.3854e-02,  5.0984e-02, -2.0084e-02,
        -5.3118e-02, -1.1259e-02, -2.6655e-02,  6.2946e-02,  2.0495e-02,
        -3.1951e-02, -3.3592e-02,  5.1752e-02,  8.0289e-03,  5.1782e-02,
         2.5465e-02,  8.5261e-03, -1.4020e-02,  9.6227e-04,  4.4673e-03,
         2.7667e-02,  1.7083e-02,  4.4350e-02,  5.1626e-03,  3.6788e-02,
         7.3341e-03,  1.4450e-02,  1.1233e-02, -1.2624e-02, -2.0046e-02,
         3.4609e-02, -2.3734e-02,  4.8267e-02,  2.2080e-02, -7.2080e-03,
         4.7682e-02,  5.9896e-02, -3.1297e-02,  1.5619e-02, -2.6683e-02,
         2.4478e-02, -3.7535e-02

# Training Loop

In [45]:
def train_model(
    model: Transformer,
    loss: torch.nn.CrossEntropyLoss = nn.CrossEntropyLoss(),
    lr: Float = 1e-3,
    epochs: Int = 1
    ):
    optimizer: torch.optim.SGD = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    string_data = get_many_books(ids=range(10,15), data_temp="./data/gutenberg_data")
    labels = {}
    int_labels = []
    for i in range(len(string_data)):
        labels[i] = string_data[i]
        int_labels.append(i)
    
    training_data = torch.tensor(int_labels)
    
    for epoch in range(epochs):
        
        optimizer.zero_grad()
        outputs = model(training_data)
    



In [46]:
config: Config = Config(d_model=16, d_vocab=1000, d_hidden=64)
model = Transformer(num_blocks=2, config=config)
train_model(model)

Getting book 10...
	4432261 characters read
Getting book 11...
	148062 characters read
Getting book 12...
	168390 characters read
Getting book 13...
	34579 characters read
Getting book 14...
	1951150 characters read


RuntimeError: mat1 and mat2 must have the same dtype, but got Long and Float