In [5]:
from dataclasses import dataclass
from pathlib import Path
import urllib.request

import torch
import torch.nn as nn
from jaxtyping import Float, Int

In [6]:
def get_gutenberg_book(
	id: int|None = 84,
	data_temp: Path|str = "../../../../data/gutenberg_data",
	remove_gutenberg_meta: bool = True,
) -> str:
	
	data_temp = Path(data_temp)
	data_temp.mkdir(parents=True, exist_ok=True)
	
	url: str = f"https://www.gutenberg.org/cache/epub/{id}/pg{id}.txt"
	data_path: Path = Path(data_temp) / f"{id}.txt"
	data: str
	# read from cache if it exists
	if data_path.exists():
		with open(data_path, 'r', encoding='utf-8') as file:
			data = file.read()
	else:
		# download if it doesn't exist
		response = requests.get(url)
		response.raise_for_status()  # Ensure that the download was successful
		data = response.text

		# save to cache
		with open(data_path, 'w', encoding='utf-8') as file:
			file.write(data)

	# remove header/footer
	if remove_gutenberg_meta:
		data = '***'.join(data.split('***')[2:])
		data = '***'.join(data.split('***')[:-1])
	
	return data

def get_many_books(
		ids: list[int],
		data_temp: Path|str = "../data/gutenberg_data",
	) -> list[str]:
	
	data: list[str] = []
	for id in ids:
		print(f"Getting book {id}...")
		item: str = get_gutenberg_book(id, data_temp)
		print(f"\t{len(item)} characters read")
		data.append(item)
	
	return data

In [7]:
@dataclass
class Config:
    """Model size hyperparameters shared by the modules in this notebook.

    Positional field order: d_model, d_vocab, d_hidden, d_head.
    """

    # residual-stream width: input/output size of MLP and AttentionHead
    d_model: int
    # vocabulary size (not read by the modules in this notebook yet)
    d_vocab: int
    # hidden width of the MLP
    d_hidden: int
    # attention head dimension
    d_head: int

In [8]:
class MLP(nn.Module):
    """Feed-forward sublayer: Linear(d_model -> d_hidden), ReLU, Linear(d_hidden -> d_model)."""

    def __init__(self, config: Config):
        super().__init__()
        # attribute names become state_dict keys ("linear1.weight", ...)
        self.linear1 = nn.Linear(config.d_model, config.d_hidden)
        self.linear2 = nn.Linear(config.d_hidden, config.d_model)

    def forward(self, x):
        """Apply linear1, then ReLU, then linear2 along the last dimension of ``x``."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)

In [None]:
class AttentionHead(nn.Module):
    """Single causal self-attention head in QK / OV-circuit form.

    For residual vectors ``x`` the score of query position i on key position j
    is ``W_qk(x_i) · x_j``. After causal masking and softmax, each position's
    output is the attention-weighted sum of ``W_vo(x_j)``. The output width is
    d_model, so it can be added back onto the residual stream.

    NOTE(review): both maps are full d_model x d_model, so ``config.d_head`` is
    not used. A low-rank (W_Q, W_K) / (W_V, W_O) factorisation of width d_head
    would be the way to reintroduce it.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        # d_model -> d_model: scores are W_qk(x) @ x^T, which needs W_qk's output
        # width to match x's, and the result is added to the residual stream
        self.W_qk = nn.Linear(config.d_model, config.d_model)
        self.W_vo = nn.Linear(config.d_model, config.d_model)
        self.softmax = nn.Softmax(dim=-1)

    def create_mask(
        self,
        n_c: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> torch.Tensor:
        """Return an additive causal mask of shape (n_c, n_c).

        Entry [i, j] is 0 when j <= i (past or current position) and -inf when
        j > i (future). The diagonal must stay 0: masking it would leave row 0
        entirely -inf, and softmax of that row is NaN.

        Args:
            n_c: sequence length.
            device, dtype: where/how to allocate the mask (torch defaults if None).
        """
        mask = torch.full((n_c, n_c), float("-inf"), device=device, dtype=dtype)
        # triu keeps the strictly-upper triangle (-inf) and zeroes the rest
        return torch.triu(mask, diagonal=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal attention to ``x`` of shape (..., n_c, d_model).

        Works for both batched (batch, n_c, d_model) and unbatched
        (n_c, d_model) input; returns a tensor of the same shape.
        """
        n_c = x.shape[-2]
        # build the mask alongside x so it works on GPU / in half precision
        mask = self.create_mask(n_c, device=x.device, dtype=x.dtype)
        scores = self.W_qk(x) @ x.transpose(-2, -1)  # (..., n_c, n_c)
        pattern = self.softmax(scores + mask)
        return pattern @ self.W_vo(x)

In [None]:
class TransformerBlock(nn.Module):
    """One residual transformer layer: causal attention, then an MLP.

    Each sublayer reads the residual stream and adds its output back onto it:

        x = x + attn(x)
        x = x + mlp(x)

    NOTE(review): no LayerNorm is applied; add it here if the model needs it.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.attn = AttentionHead(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sublayers with residual connections.

        Both sublayers must return tensors shaped like ``x`` for the adds to work.
        """
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x

In [None]:
class Transformer(nn.Module):
    """A stack of ``num_blocks`` TransformerBlocks applied one after another.

    NOTE(review): there is no token embedding or unembedding here even though
    Config has d_vocab, so ``x`` is presumably already a residual-stream tensor.
    TODO confirm.
    """

    def __init__(self, num_blocks: int, config: Config):
        super().__init__()
        # ModuleList (not a plain list) so each block's parameters are registered
        self.blocks = nn.ModuleList(
            TransformerBlock(config) for _ in range(num_blocks)
        )

    def forward(self, x):
        """Feed ``x`` through every block in order and return the last output."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out
    