In [2]:
import urllib.request
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
from jaxtyping import Float, Int

# Training Data

In [3]:
def get_gutenberg_book(
	id: int|None = 84,
	data_temp: Path|str = "../../../../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path|str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

# Model Definition

In [11]:
@dataclass
class Config:
    """Size hyperparameters shared by the model components in this notebook."""

    # width of the per-token feature vector passed between layers
    d_model: int
    # size of the vocabulary axis at the model's input and output
    d_vocab: int
    # width of the hidden layer inside the MLP
    d_hidden: int

In [12]:
class MLP(nn.Module):
    """Two-layer feed-forward network: d_model -> d_hidden -> ReLU -> d_model."""

    def __init__(self, config: Config):
        super().__init__()
        # expand to the hidden width, then project back to the model width
        self.linear1 = nn.Linear(in_features=config.d_model, out_features=config.d_hidden)
        self.linear2 = nn.Linear(in_features=config.d_hidden, out_features=config.d_model)

    def forward(self, x: Float[torch.Tensor, "batch seq_len d_model"]) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Apply linear1, ReLU and linear2 along the last axis of ``x``."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)

In [13]:
class AttentionHead(nn.Module):
    """Single causal self-attention head with merged QK and VO projections.

    Computes ``softmax(W_qk(X) @ X^T + M) @ W_vo(X)`` where ``M`` is a causal
    mask: 0 on and below the diagonal, -inf above it, so position ``i`` only
    attends to positions ``j <= i``.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.W_qk = nn.Linear(config.d_model, config.d_model)
        self.W_vo = nn.Linear(config.d_model, config.d_model)
        # normalise over the last axis of the score matrix (the attended-to positions)
        self.softmax = nn.Softmax(dim=-1)

    def create_mask(
        self,
        n_c: int,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
    ) -> torch.Tensor:
        """Build the additive causal mask of shape ``(n_c, n_c)``.

        Args:
            n_c: Sequence (context) length.
            device: Device to allocate the mask on; PyTorch's default if None.
            dtype: Floating dtype of the mask; PyTorch's default if None.

        Returns:
            Tensor with -inf strictly above the diagonal and 0 elsewhere.
        """
        blocked = torch.full((n_c, n_c), float("-inf"), device=device, dtype=dtype)
        # triu(diagonal=1) keeps only the strictly-upper triangle and zeroes the rest
        mask: Float[torch.Tensor, "seq_len seq_len"] = torch.triu(blocked, diagonal=1)
        return mask

    def forward(self, x: Float[torch.Tensor, "batch seq_len d_model"]) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Apply causal self-attention to ``x``; the output has the same shape."""
        # raw attention scores: X @ W_qk @ X^T, one (seq_len x seq_len) matrix per batch item
        scores = self.W_qk(x) @ x.transpose(-1, -2)

        # Allocate the mask alongside the scores so the addition works on any
        # device / dtype (a CPU float32 mask breaks CUDA and half-precision runs).
        mask = self.create_mask(x.shape[-2], device=scores.device, dtype=scores.dtype)

        # A = softmax((X @ W_qk @ X^T) + M) @ X @ W_vo
        A = self.softmax(scores + mask) @ self.W_vo(x)
        return A

torch.Size([1, 5, 16])


In [15]:
class TransformerBlock(nn.Module):
    """Residual block that adds an attention head and an MLP to its input.

    Both sub-layers read the same input ``x``; their outputs are summed
    together with ``x`` itself.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.attention_head = AttentionHead(config)
        self.mlp = MLP(config)

    def forward(self, x: Float[torch.Tensor, "batch seq_len d_model"]) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """Return ``x + attention_head(x) + mlp(x)``."""
        attended = self.attention_head(x)
        transformed = self.mlp(x)
        # same left-to-right summation as a single chained expression
        return x + attended + transformed

In [16]:
class Transformer(torch.nn.Module):
    """Stack of ``TransformerBlock``s between a linear embedding and a tied unembedding.

    The embedding is an ``nn.Linear`` over the vocabulary axis, so inputs are
    float vectors of length ``d_vocab`` per position. The unembedding reuses
    the embedding's weight matrix (without its bias).
    """

    def __init__(self, num_blocks: int, config: Config):
        """
        Args:
            num_blocks: Number of ``TransformerBlock``s applied in sequence.
            config: Model sizes (``d_vocab``, ``d_model``, ``d_hidden``).
        """
        super().__init__()
        self.config = config
        self.embedding = nn.Linear(config.d_vocab, config.d_model)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(num_blocks)])

    def forward(self, x: Float[torch.Tensor, "batch seq_len vocab"]) -> Float[torch.Tensor, "batch vocab seq_len"]:
        """Run the model and return per-position scores over the vocabulary.

        Note the output layout: the last two axes are swapped, so the result
        is ``(batch, vocab, seq_len)``, not ``(batch, seq_len, vocab)``.
        """
        x = self.embedding(x)
        for block in self.blocks:
            # call the module (not .forward) so nn.Module hooks are honoured
            x = block(x)
        # embedding.weight is stored as (d_model, d_vocab), so this maps back to the vocab axis
        logits = x @ self.embedding.weight
        # NOTE(review): vocab is moved to dim 1, presumably for a loss that
        # expects the class axis there — confirm against the training loop.
        return logits.transpose(1, 2)
    

# Tests

In [None]:
# Attention head test
config = Config(d_model=16, d_vocab=1000, d_hidden=64)
x: Float[torch.Tensor, "batch seq_len d_model"] = torch.ones(1, 5, config.d_model)
attention_head: AttentionHead = AttentionHead(config)
# call the module itself rather than .forward so nn.Module hooks run as usual
output: Float[torch.Tensor, "batch seq_len d_model"] = attention_head(x)
print(output.shape)
# the head must map (batch, seq_len, d_model) back to the same shape
assert output.shape == x.shape, f"unexpected attention output shape: {tuple(output.shape)}"

In [17]:
# Test the whole thing
config = Config(d_model=16, d_vocab=1000, d_hidden=64)
transformer = Transformer(num_blocks=2, config=config)
x = torch.ones(1, 5, config.d_vocab, dtype=torch.float)
# Transformer.forward ends with .transpose(1, 2), so the result is (batch, vocab, seq_len)
y: Float[torch.Tensor, "batch vocab seq_len"] = transformer(x)
print(y.shape)
print(y)
print(x)
# pin the output layout so a change to the final transpose is caught here
assert y.shape == (x.shape[0], config.d_vocab, x.shape[1]), f"unexpected output shape: {tuple(y.shape)}"

torch.Size([1, 1000, 5])
tensor([[[-0.0002, -0.0002, -0.0002, -0.0002, -0.0002],
         [ 0.0511,  0.0511,  0.0511,  0.0511,  0.0511],
         [-0.0502, -0.0502, -0.0502, -0.0502, -0.0502],
         ...,
         [ 0.0203,  0.0203,  0.0203,  0.0203,  0.0203],
         [-0.1136, -0.1136, -0.1136, -0.1136, -0.1136],
         [ 0.0103,  0.0103,  0.0103,  0.0103,  0.0103]]],
       grad_fn=<TransposeBackward0>)
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])


# Training Loop