In [None]:
from dataclasses import dataclass
from pathlib import Path
import torch
import torch.nn as nn
from jaxtyping import Float, Int

# Training Data

In [31]:
def get_gutenberg_book(
	id: int|None = 84,
	data_temp: Path|str = "../../../../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path|str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

# Model Definition

In [32]:
@dataclass
class Config():
    d_model: int
    d_vocab: int
    d_hidden: int

In [33]:
class MLP(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.linear1 = nn.Linear(config.d_model, config.d_hidden)
        self.linear2 = nn.Linear(config.d_hidden, config.d_model)
        

    def forward(self, x: Float[torch.Tensor, "seq_len d_model"]) -> Float[torch.Tensor, "seq_len d_model"]:
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

In [48]:
class AttentionHead(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.W_qk = nn.Linear(config.d_model, config.d_model)
        self.W_vo = nn.Linear(config.d_model, config.d_model)
        self.softmax = nn.Softmax(dim=-1)
        

    def create_mask(self, n_c: int) -> torch.Tensor:
        mask: Float[torch.Tensor, "seq_len seq_len"] = torch.triu(-1 * torch.inf * torch.ones(n_c, n_c), diagonal=1)
        return mask

    def forward(self, x: Float[torch.Tensor, "seq_len d_model"]) -> Float[torch.Tensor, "seq_len d_model"]:
        #create mask, with size n_c x n_c
        mask = self.create_mask(x.shape[0])

        #compute attention scores
        # A = softmax((X @ W_qk @ X^T) + M) @ X @ W_vo
        A = self.softmax((self.W_qk(x)) @ x.transpose(0, -1) + mask) @ self.W_vo(x)
        return A

In [49]:
class TransformerBlock(torch.nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.attention_head = AttentionHead(config)
        self.mlp = MLP(config)

    def forward(self, x: Float[torch.Tensor, "seq_len d_model"]) -> Float[torch.Tensor, "seq_len d_model"]:
        return x + self.attention_head(x) + self.mlp(x)

In [50]:
class Transformer(torch.nn.Module):
    def __init__(self, num_blocks: int, config: Config):
        super().__init__()
        self.config = config
        self.embedding = nn.Linear(config.d_vocab, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(num_blocks)])
        

    def forward(self, x: Float[torch.Tensor, "seq_len vocab"]) -> Float[torch.Tensor, "vocab seq_len"]:
        x = self.embedding(x)
        for block in self.blocks:
            x = block.forward(x)
        x = (x @ self.embedding.weight).T
        return x
    

# Tests

In [51]:
# Attention head test
x: Float[torch.Tensor, "seq_len d_model"] = torch.ones(5, 16)
config = Config(d_model=16, d_vocab=1000, d_hidden=64)
attention_head: AttentionHead = AttentionHead(config)
output: Float[torch.Tensor, "seq_len d_model"] = attention_head.forward(x)
print(output.shape)

torch.Size([5, 16])


In [52]:
# Test the whole thing
config = Config(d_model=16, d_vocab=1000, d_hidden=64)
transformer = Transformer(num_blocks=2, config=config)
x = torch.ones(config.d_vocab, dtype=torch.float)
y: Float[torch.Tensor, "vocab seq_len"] = transformer(x)
print(y.shape)
print(y)
print(x)

torch.Size([1000])
tensor([ 6.3441e-03,  8.3942e-02, -3.9893e-02, -6.3548e-03,  3.3729e-02,
        -2.8102e-02,  1.8724e-02,  3.9608e-03, -3.7850e-03, -3.1101e-02,
         8.5074e-03,  6.8095e-02,  7.8028e-02,  5.5352e-02, -5.4220e-02,
        -6.8182e-03,  8.6136e-02,  2.7536e-02,  1.8153e-02,  9.5712e-03,
         2.0403e-02,  8.1410e-02,  5.7799e-02, -4.7979e-02, -8.2013e-02,
        -1.7693e-02,  2.8380e-02, -1.7299e-02,  4.1108e-02,  2.5330e-02,
         3.7915e-02, -1.5364e-02,  6.5139e-02, -1.2433e-02,  3.1411e-02,
         3.7416e-02, -3.6107e-02,  1.8123e-02, -6.5640e-02,  4.5044e-02,
        -1.0404e-02, -2.4668e-02,  7.0007e-02, -4.4044e-02,  6.9703e-02,
         1.7621e-02, -2.2256e-02,  7.2073e-03, -2.0943e-02, -4.6131e-02,
         4.8179e-02,  4.9247e-04, -5.3634e-02, -2.0521e-03,  5.7496e-02,
         1.0118e-03, -1.6561e-02,  1.0509e-02,  4.7976e-03,  7.0115e-02,
         9.9458e-03,  2.7358e-02,  5.3892e-02,  3.8064e-02,  9.3865e-02,
        -3.7413e-02, -4.3165e-02

  x = (x @ self.embedding.weight).T


# Training Loop

In [None]:
loss: torch.nn.MSELoss = nn.MSELoss()
config: Config = Config(d_model=16, d_vocab=1000, d_hidden=64)
model: Transformer = Transformer(num_blocks=2, config=config)
lr: Float = 1e-3
optimizer: torch.optim.Adam = torch.optim.Adam(model.parameters(), lr=lr)