In [None]:
%pip install datasets

In [None]:
from transformers import LlamaConfig, LlamaForCausalLM
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
# Seed the global RNG before any random weight initialisation so the printed loss,
# the DataLoader shuffle order and the later `torch.rand` samples are reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)

# A deliberately tiny Llama: 4 layers, 4 heads, hidden size 128, context of 256.
# NOTE(review): vocab_size=32000 is much larger than the character-level vocabulary
# built further down; presumably most embedding rows are never used -- confirm whether
# this should track the character vocabulary instead.
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=128,
    intermediate_size=512,
    num_attention_heads=4,
    num_hidden_layers=4,
    max_position_embeddings=256,
)

# Randomly initialised (not pretrained) causal LM built from the config above.
model = LlamaForCausalLM(config)
print(model.num_parameters())

9241728


In [None]:
# The train split holds a single example whose "text" field is the whole corpus, so
# `dataset` is one plain string (not a `datasets.Dataset`).
# `trust_remote_code=True` answers the loader's interactive "[y/N]" prompt up front so
# the notebook can run unattended. It executes the dataset repository's loading script:
# only keep it for repositories you trust.
dataset = load_dataset("tiny_shakespeare", trust_remote_code=True)["train"]["text"][0]

# Character-level vocabulary: sorted unique characters and the two lookup tables.
chars = sorted(set(dataset))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))


class TSDataset(Dataset):
    """Sliding-window next-character dataset over one string.

    Item ``i`` is the pair ``(data[i : i + seq_len], data[i + 1 : i + seq_len + 1])``:
    a window of character ids and the same window advanced by one position.

    The returned labels are therefore *already shifted*. Losses that shift the labels
    again (Hugging Face causal-LM heads do this internally) must not be given these
    labels directly.

    Args:
        text: String to encode, one id per character.
        seq_len: Number of characters in each window.
        vocab: Optional ``{character: id}`` mapping. When omitted, the module-level
            ``char_to_idx`` table is used.

    Raises:
        KeyError: If ``text`` contains a character missing from the mapping.
    """

    def __init__(self, text, seq_len=32, vocab=None):
        # Resolve the global lazily so the class also works with an explicit mapping.
        mapping = char_to_idx if vocab is None else vocab
        self.data = torch.tensor([mapping[c] for c in text], dtype=torch.long)
        self.seq_len = seq_len

    def __len__(self):
        # Each item needs seq_len + 1 characters; clamp so short texts give an empty
        # dataset instead of a negative length (which makes len() raise ValueError).
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        size = len(self)
        if idx < 0:
            # Python-style negative indexing, counted from the last valid window.
            idx += size
        if not 0 <= idx < size:
            # Without this, slicing past the end silently returns short or empty
            # tensors and plain iteration over the dataset never terminates.
            raise IndexError(f"index {idx} is out of range for {size} windows")
        input_ids = self.data[idx : idx + self.seq_len]
        labels = self.data[idx + 1 : idx + self.seq_len + 1]
        return input_ids, labels


# Wrap the corpus string in sliding windows (default window length) and serve them
# in shuffled mini-batches of 16.
toy_dataset = TSDataset(text=dataset)
dataloader = DataLoader(dataset=toy_dataset, shuffle=True, batch_size=16)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/6.10k [00:00<?, ?B/s]

tiny_shakespeare.py:   0%|          | 0.00/3.73k [00:00<?, ?B/s]

The repository for tiny_shakespeare contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/tiny_shakespeare.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1 [00:00<?, ? examples/s]

In [None]:
# Sanity check: loss of the untrained model on one batch.
for input_ids, labels in dataloader:
    # Hugging Face causal-LM heads shift `labels` by one position internally, so the
    # model must receive the *unshifted* ids. The dataset's `labels` are already
    # shifted by one; passing them would score a two-characters-ahead target.
    output = model(input_ids, labels=input_ids)
    print(f"Loss: {output['loss'].item()}")
    break

Loss: 10.396821975708008


In [None]:
class ResLinear(nn.Module):
    """Linear layer with a residual connection followed by SiLU.

    Computes ``silu(shortcut(x) + linear(x))`` over the last dimension of ``x``.

    When ``in_dim == out_dim`` the shortcut is the identity, so the module has exactly
    the parameters of a single ``nn.Linear``. When the widths differ, a bias-free
    linear projection maps the input to ``out_dim`` so the residual sum is well
    defined instead of failing (or silently broadcasting) at forward time.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Callers may hand over 0-dim tensors; keep plain Python ints on the module.
        self.in_dim = int(in_dim)
        self.out_dim = int(out_dim)
        self.linear = nn.Linear(self.in_dim, self.out_dim)
        # Created after `linear` so equal-width layers draw the same initial weights
        # from the RNG as a lone nn.Linear would.
        if self.in_dim == self.out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(self.in_dim, self.out_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``silu(shortcut(x) + linear(x))``."""
        return F.silu(self.shortcut(x) + self.linear(x))

class MemoryAsLayer(nn.Module):
    """Stack of ``ResLinear`` layers applied to ``[persistent_memory ; x]``.

    Every entry of ``layer_sizes`` is widened by ``persistent_memory_dim``, and one
    ``ResLinear`` is built for each consecutive pair of widened sizes. The output
    therefore has ``layer_sizes[-1] + persistent_memory_dim`` elements.

    Args:
        layer_sizes: Sequence (list, tuple or 1-D tensor) of layer widths *before*
            the persistent-memory slots are added.
        persistent_memory_dim: Number of persistent-memory elements prepended to the
            input.

    Attributes:
        layer_sizes: 1-D integer tensor holding the widened sizes.
        lmm: ``nn.Sequential`` holding the ``ResLinear`` stack.
    """

    def __init__(
        self,
        layer_sizes: "Sequence[int] | torch.Tensor",
        persistent_memory_dim: int,
    ) -> None:
        super().__init__()
        self.persistent_memory_dim = int(persistent_memory_dim)

        # Work on a private list of Python ints: the caller's object is never
        # modified in place, and the layers receive real integer sizes whether a
        # list, tuple or tensor was supplied.
        if isinstance(layer_sizes, torch.Tensor):
            raw_sizes = layer_sizes.reshape(-1).tolist()
        else:
            raw_sizes = list(layer_sizes)
        widths = [int(size) + self.persistent_memory_dim for size in raw_sizes]
        self.layer_sizes = torch.tensor(widths, dtype=torch.long)

        self.lmm = nn.Sequential(
            *(
                ResLinear(in_dim, out_dim)
                for in_dim, out_dim in zip(widths[:-1], widths[1:])
            )
        )

    def forward(self, x: torch.Tensor, persistent_memory: torch.Tensor) -> torch.Tensor:
        """Flatten both inputs, prepend the memory to ``x`` and run the stack.

        Both tensors are flattened to 1-D, so this handles a single (un-batched)
        example; the concatenated length must equal the first widened layer size.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        flat_x = x.reshape(-1)
        flat_memory = persistent_memory.reshape(-1)
        return self.lmm(torch.cat((flat_memory, flat_x)))

In [None]:
# Second-to-last decoder layer of the Llama stack.
target_layer = model.model.layers[-2]
# Three identical widths equal to the model's hidden size, as an int32 tensor;
# consecutive pairs of sizes become the layers of the memory module.
lmm_layer_sizes = torch.full((3,), config.hidden_size, dtype=torch.int32)
lmm = MemoryAsLayer(layer_sizes=lmm_layer_sizes, persistent_memory_dim=16)

In [None]:
# Smoke test: push one random hidden-state vector and one random persistent-memory
# vector through the memory module and display the result.
x = torch.rand(config.hidden_size)
persistent_memory = torch.rand(lmm.persistent_memory_dim)
lmm(x, persistent_memory=persistent_memory)

tensor([ 7.2560e-01,  1.9435e-01,  2.3824e-01, -2.0753e-01, -3.9599e-04,
         2.4625e-01,  2.1237e-01,  2.8203e-01,  1.0609e-01,  3.8266e-01,
        -1.5553e-01, -2.9027e-02,  4.5945e-01,  3.3439e-01,  6.1954e-01,
         5.6758e-01,  2.7851e-01,  2.6466e-01,  3.5024e-02,  1.0212e-01,
        -2.6987e-02,  8.4172e-02, -1.1277e-02,  1.0216e-01,  3.1037e-01,
         3.9719e-02,  1.3532e+00,  6.4529e-01, -2.8735e-02, -1.5486e-01,
         2.2567e-01,  1.6975e-01,  3.4338e-02, -8.1769e-03, -6.3081e-02,
         5.7953e-01, -5.6104e-02,  1.2886e-01,  2.6639e-01,  2.4766e-01,
         1.1953e+00,  8.2119e-01,  7.4811e-01,  2.5095e-02,  5.6811e-01,
         5.2204e-01,  2.3884e-01, -1.0107e-01,  4.6331e-01,  7.1855e-02,
         6.2105e-01,  3.0060e-01,  2.1166e-02,  6.7429e-02,  3.6395e-01,
         8.6550e-01,  7.5062e-01,  1.2350e+00,  2.5634e-02,  1.4467e-01,
         7.1290e-01,  2.5460e-01,  4.2784e-02,  9.4088e-03,  1.4937e-01,
        -1.2231e-01,  2.4447e-01,  1.9967e-01,  5.9