In [None]:
import json

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
class TransformerBlockLM(nn.Module):
    """Character-level, decoder-only Transformer language model.

    Typical use: construct with hyper-parameters, call :meth:`prep` with the
    corpus (this builds the vocabulary and every layer), then :meth:`fit` and
    :meth:`generate`.

    The whole class is kept in a single notebook cell on purpose: a class body
    cannot continue into another cell.
    """

    # Number of stacked Transformer blocks built by ``prep``.
    NUM_BLOCKS = 6

    class TransformerBlock(nn.Module):
        """Attention ("communicate") followed by an MLP ("think"), each with a residual."""

        def __init__(self, head_count, in_size, out_size):
            """
            Args:
                head_count: number of attention heads.
                in_size: width of the residual stream entering (and leaving) the block.
                out_size: total width requested from the attention heads.
            """
            super().__init__()
            self.comm = TransformerBlockLM.MultiHeadAttention(head_count=head_count,
                                                              in_size=in_size,
                                                              out_size=out_size)
            # Width really produced by concatenating the heads (each head is
            # out_size // head_count wide).
            attn_width = head_count * (out_size // head_count)
            # The residual sum needs the attention output to be in_size wide.
            # Only add a projection when the widths differ, so configurations
            # with equal widths keep exactly the original parameter set.
            self.proj = nn.Identity() if attn_width == in_size else nn.Linear(attn_width, in_size)
            self.think = TransformerBlockLM.MLP(embed_size=in_size)

        def forward(self, x):
            """Return ``x`` after the attention and MLP residual updates."""
            # Residual connections are implemented around both the attention and the MLP modules
            x = x + self.proj(self.comm(x))
            x = x + self.think(x)
            return x

    class MLP(nn.Module):
        """Position-wise feed-forward network with a 4x hidden expansion."""

        def __init__(self, embed_size):
            super().__init__()
            self.layerNorm = nn.LayerNorm(embed_size)
            self.mlp = nn.Sequential(nn.Linear(embed_size, embed_size * 4),
                                     nn.ReLU(),
                                     nn.Linear(embed_size * 4, embed_size))

        def forward(self, x):
            """Layer-normalise ``x`` and pass it through the two-layer MLP."""
            # Layer normalization is applied before the MLP (pre-norm)
            x = self.layerNorm(x)
            return self.mlp(x)

    class MultiHeadAttention(nn.Module):
        """Runs ``head_count`` self-attention heads and concatenates their outputs."""

        def __init__(self, head_count, in_size, out_size):
            """
            Args:
                head_count: number of heads.
                in_size: width of the input features.
                out_size: total output width; each head gets ``out_size // head_count``.
            """
            super().__init__()
            self.heads = nn.ModuleList([
                TransformerBlockLM.SelfAttentionHead(in_size, out_size // head_count) for _ in range(head_count)
            ])
            # The norm is applied to the *input*, so it must match in_size
            # (it used to be built with out_size and failed when they differed).
            self.layerNorm = nn.LayerNorm(in_size)

        def forward(self, x):
            """Normalise ``x``, apply every head and concatenate on the last axis."""
            x = self.layerNorm(x)
            return torch.cat([head(x) for head in self.heads], dim=-1)

    class SelfAttentionHead(nn.Module):
        """Single causal scaled dot-product self-attention head."""

        def __init__(self, in_size, out_size):
            super().__init__()
            # Kept on the instance because forward() needs it for the score scaling.
            self.head_size = out_size
            self.K = nn.Linear(in_size, out_size, bias=False)
            self.Q = nn.Linear(in_size, out_size, bias=False)
            self.V = nn.Linear(in_size, out_size, bias=False)

        def forward(self, x, return_attention_details=False):
            """Attend over the second-to-last axis of ``x``.

            Args:
                x: input whose last axis has ``in_size`` features.
                return_attention_details: when true, also return
                    ``(queries, keys, values, attention_scores)``.

            Returns:
                The context tensor, or ``(context, details)`` when requested.
            """
            keys = self.K(x)
            queries = self.Q(x)
            values = self.V(x)
            attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / (self.head_size ** 0.5))

            # Causal mask: a position may only look at itself and earlier
            # positions. Without it the model can read the very token it is
            # trained to predict.
            seq_len = x.shape[-2]
            causal_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).tril()
            attention_scores = attention_scores.masked_fill(~causal_mask, float('-inf'))

            attention_scores = F.softmax(attention_scores, dim=-1)
            context = torch.matmul(attention_scores, values)

            if return_attention_details:
                return context, (queries, keys, values, attention_scores)
            return context

        @staticmethod
        def visualize_attention(queries, keys, values, attention_scores):
            """Plot the attention matrix of the first batch item.

            The signature mirrors the details tuple returned by ``forward`` so it
            can be called as ``visualize_attention(*details)``; only
            ``attention_scores`` is used.
            """
            # Assuming attention_scores shape is (batch_size, num_queries, num_keys)
            attention = attention_scores[0].cpu().detach().numpy()  # Taking the first item in the batch

            # Draw on an explicit figure/axes: plt.matshow() would open a second
            # figure and leave the sized one empty.
            fig, ax = plt.subplots(figsize=(10, 8))
            image = ax.matshow(attention, cmap='viridis')
            ax.set_xlabel('Keys')
            ax.set_ylabel('Queries')
            ax.set_title('Attention Scores')
            fig.colorbar(image, ax=ax)
            plt.show()

    def __init__(self, batch_size=4, input_length=8, embed_size=16, sa_head_size=8, sa_multihead_count=4, pos_embed=True):
        """Store hyper-parameters; layers are created later by :meth:`prep`.

        Args:
            batch_size: how many inputs are processed in parallel.
            input_length: how many consecutive tokens/chars make one input.
            embed_size: embedding (residual stream) size.
            sa_head_size: total width of the self-attention module.
            sa_multihead_count: number of attention heads per block.
            pos_embed: whether ``prep`` builds a learned positional embedding.
        """
        super().__init__()
        self.sa_head_size = sa_head_size
        self.sa_multihead_count = sa_multihead_count
        self.is_pos_emb = pos_embed
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.input_length = input_length
        self.batch_size = batch_size
        self.embed_size = embed_size

        # Data and vocabulary, filled in by prep().
        self.train_text = None
        self.val_text = None
        self.train_data = None
        self.val_data = None
        self.vocab = None
        self.vocab_size = None
        self.encoder = None  # str -> list of token ids
        self.decoder = None  # list of token ids -> str

        # Layers, filled in by prep().
        self.token_embeddings_table = None
        self.position_embeddings_table = None
        self.blocks = None
        self.linear_sahead_to_vocab = None

        # Placeholders kept so existing attribute reads do not break; nothing
        # in this class assigns or uses them.
        self.ffn = None
        self.sa_heads = None
        self.K = None
        self.lm_head = None

    def forward(self, in_ids, target=None):
        """Compute next-token logits and, if ``target`` is given, the loss.

        Args:
            in_ids: token ids, indexed as (batch, time); only the last
                ``input_length`` time steps are used.
            target: optional token ids aligned with ``in_ids``.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` without a target.
        """
        in_ids = in_ids[:, -self.input_length:]
        hidden = self.token_embeddings_table(in_ids)
        if self.position_embeddings_table is not None:
            # Positional encoding added here (broadcast over the batch axis)
            positions = torch.arange(in_ids.shape[1], device=in_ids.device)
            hidden = hidden + self.position_embeddings_table(positions)

        hidden = self.blocks(hidden)
        logits = self.linear_sahead_to_vocab(hidden)

        if target is None:
            return logits, None

        # Crop the target the same way as the input so the shapes stay aligned.
        target = target[:, -self.input_length:]
        batch_size, input_length, vocab_size = logits.shape
        logits_flattened = logits.view(batch_size * input_length, vocab_size)
        targets_flattened = target.reshape(-1)
        ce_loss = F.cross_entropy(logits_flattened, targets_flattened)
        return logits, ce_loss

    def fit(self, data, input_length, batch_size, train_iters, eval_iters, lr):
        """Train with Adam for ``train_iters`` steps on random windows of ``data``.

        The loss of the current batch is printed every ``eval_iters`` steps
        (never, if ``eval_iters`` is 0).
        """
        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for i in range(train_iters):
            inputs, targets = self.get_batch(data, input_length, batch_size)
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            optimizer.zero_grad()
            _, loss = self(inputs, targets)
            loss.backward()
            optimizer.step()
            if eval_iters and i % eval_iters == 0:
                print(f"Iteration {i}, Loss: {loss.item()}")

    def generate(self, start_text, max_length):
        """Sample ``max_length`` tokens after ``start_text`` and return the full string.

        Raises:
            ValueError: if ``start_text`` is empty (there is nothing to condition on).
        """
        if not start_text:
            raise ValueError("start_text must contain at least one character")
        self.eval()
        tokens = torch.tensor(self.encoder(start_text), dtype=torch.long, device=self.device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(max_length):
                logits, _ = self(tokens)
                next_token_probs = F.softmax(logits[0, -1, :], dim=0)
                next_token = torch.multinomial(next_token_probs, 1)
                tokens = torch.cat([tokens, next_token.view(1, 1)], dim=1)
        # Index the batch axis instead of squeeze(): squeeze would turn a
        # single-token sequence into a scalar.
        return self.decoder(tokens[0].tolist())

    def prep(self, text):
        """Split ``text`` 90/10, build the vocabulary, codecs and all layers.

        The vocabulary comes from the training split only; characters that are
        not in it are encoded as id 0. The model is moved to ``self.device``
        at the end because the layers do not exist before this call.

        Raises:
            ValueError: if the training split is empty.
        """
        split = int(len(text) * 0.9)
        self.train_text = text[:split]
        self.val_text = text[split:]
        if not self.train_text:
            raise ValueError("text is too short to build a vocabulary")

        self.vocab = sorted(set(self.train_text))
        self.vocab_size = len(self.vocab)
        c2i = {c: i for i, c in enumerate(self.vocab)}
        i2c = dict(enumerate(self.vocab))
        self.encoder = lambda s: [c2i.get(c, 0) for c in s]  # Encoder with handling for unknown characters
        self.decoder = lambda ids: ''.join(i2c[i] for i in ids)

        self.train_data = torch.tensor(self.encoder(self.train_text), dtype=torch.long)
        self.val_data = torch.tensor(self.encoder(self.val_text), dtype=torch.long)

        self.token_embeddings_table = nn.Embedding(self.vocab_size, self.embed_size)

        if self.is_pos_emb:
            self.position_embeddings_table = nn.Embedding(self.input_length, self.embed_size)

        self.blocks = nn.Sequential(*[
            TransformerBlockLM.TransformerBlock(head_count=self.sa_multihead_count,
                                                in_size=self.embed_size,
                                                out_size=self.sa_head_size)
            for _ in range(self.NUM_BLOCKS)
        ])

        # Every block returns embed_size features, so the vocabulary head reads embed_size.
        self.linear_sahead_to_vocab = nn.Linear(self.embed_size, self.vocab_size)

        self.to(self.device)

    def eval_loss(self, data, input_length, batch_size, eval_iters):
        """Return the mean loss over ``eval_iters`` random batches of ``data``.

        Raises:
            ValueError: if ``eval_iters`` is not positive.
        """
        if eval_iters <= 0:
            raise ValueError("eval_iters must be positive")
        self.eval()
        total_loss = 0.0
        with torch.no_grad():
            for _ in range(eval_iters):
                inputs, targets = self.get_batch(data, input_length, batch_size)
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                _, loss = self(inputs, targets)
                total_loss += loss.item()
        return total_loss / eval_iters

    def get_batch(self, data, input_length, batch_size):
        """Sample ``batch_size`` random windows and their one-step-shifted targets.

        Returns:
            ``(inputs, targets)``, each stacked to ``(batch_size, input_length)``.

        Raises:
            ValueError: if ``data`` has fewer than ``input_length + 1`` tokens.
        """
        # A start index i needs data[i + input_length] to exist for the target,
        # so valid starts are 0 .. len(data) - input_length - 1 (randint's
        # upper bound is exclusive).
        start_limit = len(data) - input_length
        if start_limit <= 0:
            raise ValueError(
                f"data has {len(data)} tokens; at least {input_length + 1} are required")
        start_indices = torch.randint(0, start_limit, (batch_size,)).tolist()
        inputs = torch.stack([data[i:i + input_length] for i in start_indices])
        targets = torch.stack([data[i + 1:i + input_length + 1] for i in start_indices])
        return inputs, targets

In [None]:
# Hyper-parameters for the run; the keys read below are batch_size,
# input_length, embed_size, sa_head_size, sa_multihead_count, pos_embed,
# train_iters, eval_iters and learning_rate.
with open('config.json', 'r', encoding='utf-8') as config_file:
    config = json.load(config_file)

# Training corpus. The encoding is stated explicitly so the result does not
# depend on the platform's default locale.
with open('./emily_dickinson.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
model = TransformerBlockLM(
    batch_size=config['batch_size'],
    input_length=config['input_length'],
    embed_size=config['embed_size'],
    sa_head_size=config['sa_head_size'],
    sa_multihead_count=config['sa_multihead_count'],
    pos_embed=config['pos_embed']
)
# prep() is what creates the layers, so the move to the device has to come
# after it; calling .to() first would move a module that has no parameters yet.
model.prep(text)
model = model.to(model.device)
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'params {trainable_param_count}')

# Convert entire text to tensor for training
data = torch.tensor(model.encoder(text), dtype=torch.long)

# One forward pass as a smoke test before the training loop starts.
input_batch, output_batch = model.get_batch(data, model.input_length, model.batch_size)
_, _ = model(input_batch.to(model.device), output_batch.to(model.device))
model.fit(data, model.input_length, model.batch_size, config['train_iters'], config['eval_iters'], config['learning_rate'])

# Seed generation with one context window of the corpus. Passing the whole
# text would encode and print the entire corpus just to append 50 characters.
prompt = text[:model.input_length]
generated_text = model.generate(prompt, 50)
print(generated_text)