In [None]:
%%capture
!pip install torchinfo

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

In [None]:
class MemoryModule(nn.Module):
    """Neural memory block whose memory MLP is updated from its own gradients.

    ``forward`` projects the input into keys and values, stores the
    reconstruction loss ``mse(memory_net(k), v)`` in ``self.loss`` and returns
    the input unchanged. Three sigmoid gates are predicted from the flattened
    sequence and averaged over the batch:

    * ``eta``   - scales the gradients already stored on the parameters
      (applied at the start of every forward pass),
    * ``theta`` - scales the gradients right before the memory update,
    * ``alpha`` - weight decay applied to ``memory_net`` during the update.

    ``update_memory`` is registered as a full backward hook, so it runs during
    backward passes through this module.

    Note: ``memory_net`` only receives gradients if the caller includes
    ``self.loss`` in the quantity it backpropagates, because the module output
    itself does not depend on ``memory_net``.

    Args:
        input_dim: Size of the feature (last) dimension of the input.
        max_len: The gate network's last ``Linear`` is sized for sequences of
            exactly ``2 * max_len`` tokens, so inputs must have that length.
    """

    def __init__(self, input_dim, max_len):
        super().__init__()
        self.memory_net = nn.Sequential(
            nn.Linear(input_dim, input_dim * 2),
            nn.SiLU(),
            nn.Linear(input_dim * 2, input_dim),
            nn.SiLU()
        )

        self.key_layer = nn.Linear(input_dim, input_dim)
        self.value_layer = nn.Linear(input_dim, input_dim)
        # NOTE(review): the query projection is not used by forward() yet; it is
        # kept so the parameter set / state_dict layout stays unchanged.
        self.query_layer = nn.Linear(input_dim, input_dim)

        gate_hidden = input_dim // 2
        self.gate_net = nn.Sequential(
            nn.Linear(input_dim, gate_hidden),
            nn.SiLU(),
            # Built-in equivalent of x.flatten(start_dim=1); keeps this class
            # independent of helpers defined in later cells.
            nn.Flatten(start_dim=1),
            # Flattened size is tokens * gate_hidden. Parenthesised so it is
            # also correct when input_dim is odd.
            nn.Linear((2 * max_len) * gate_hidden, 3),
            nn.Sigmoid()
        )

        # Gate values of the most recent forward pass (0 until the first one).
        self.theta = 0
        self.alpha = 0
        self.eta = 0
        # Reconstruction loss of the most recent forward pass.
        self.loss = 0

        # Register the callable itself; calling update_memory() here would run
        # it once at construction time (before any gradient exists) and
        # register its return value, None, as the hook.
        self.register_full_backward_hook(self._memory_update_hook)

    def _memory_update_hook(self, module, grad_input, grad_output):
        """Adapter with the full-backward-hook signature.

        Returns ``None`` so the gradients flowing through the module are left
        untouched.
        """
        self.update_memory()

    def forward(self, x):
        """Record the memory loss and gates for ``x`` and return ``x`` unchanged.

        Args:
            x: Tensor of shape ``(B, 2 * max_len, input_dim)``.

        Returns:
            The input tensor ``x``.
        """
        v = self.value_layer(x)
        k = self.key_layer(x)

        gates = self.gate_net(x)   # (B, 3)
        gates = gates.mean(dim=0)  # (3,) - average over the batch, not over the gates
        # Detached: the gates are only used as plain scaling factors for the
        # in-place gradient / weight updates below, never as part of the graph.
        self.theta, self.alpha, self.eta = gates.detach().unbind(0)

        self.loss = F.mse_loss(self.memory_net(k), v)
        # Decay whatever gradient is still stored from earlier backward passes.
        self.scale_gradients(self.eta)

        return x

    def update_memory(self):
        """Apply one decayed gradient step to ``memory_net``.

        All stored gradients of this module are scaled by ``theta`` first; each
        ``memory_net`` parameter is then replaced by
        ``param * (1 - alpha) - grad``. Parameters without a gradient are
        skipped.
        """
        self.scale_gradients(self.theta)

        with torch.no_grad():
            for param in self.memory_net.parameters():
                if param.grad is None:
                    continue
                # Subtract the gradient (descent on self.loss); adding it would
                # move the memory uphill. Assigning through .data mirrors the
                # original update style, which runs while backward is in progress.
                param.data = param.data * (1 - self.alpha) - param.grad

    def scale_gradients(self, scale):
        """Multiply every existing parameter gradient of this module by ``scale``.

        Args:
            scale: Python number or tensor; tensors are detached first so the
                in-place update never becomes part of an autograd graph.
        """
        if torch.is_tensor(scale):
            scale = scale.detach()
        with torch.no_grad():
            for param in self.parameters():
                if param.grad is not None:
                    param.grad.mul_(scale)

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one."""

    def forward(self, x):
        # (B, d1, d2, ...) -> (B, d1 * d2 * ...)
        return torch.flatten(x, start_dim=1)


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a batch of embeddings.

    Even feature columns hold ``sin(position * div_term)`` and odd columns hold
    ``cos(position * div_term)``. The table is precomputed once for ``max_len``
    positions and stored as a non-persistent buffer: it follows the module
    across ``.to(...)`` calls but is not written to the ``state_dict``.

    Args:
        input_dim: Feature size of the embeddings (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, input_dim, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, input_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / input_dim))
        angles = position * div_term  # (max_len, ceil(input_dim / 2))

        encoding = torch.zeros(max_len, input_dim)
        encoding[:, 0::2] = torch.sin(angles)
        # For odd input_dim there is one fewer cosine column than sine columns,
        # so trim the angles to the number of odd-indexed columns.
        encoding[:, 1::2] = torch.cos(angles[:, : input_dim // 2])

        # Leading batch dimension so the table broadcasts over the batch.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose second dimension is the sequence axis and whose
                last dimension is ``input_dim``.
        """
        seq_len = x.size(1)
        # .to() is a no-op once the module lives on the same device as x.
        return x + self.encoding[:, :seq_len, :].to(x.device)


class AttentionHead(nn.Module):
    """Single scaled dot-product attention head.

    Args:
        input_dim: Feature size of the query and key/value inputs.
        head_size: Output size of the query, key and value projections.
        masked: When true, attention scores are masked before the softmax; a
            lower-triangular (causal) mask is generated if none is passed.
            When false, any ``mask`` passed to ``forward`` is ignored.
    """

    def __init__(self, input_dim, head_size, masked):
        super().__init__()
        self.q = nn.Linear(input_dim, head_size)
        self.k = nn.Linear(input_dim, head_size)
        self.v = nn.Linear(input_dim, head_size)
        self.masked = masked

    def forward(self, q_input, kv_input, mask=None):
        """Attend from ``q_input`` over ``kv_input``.

        Args:
            q_input: Tensor the queries are projected from.
            kv_input: Tensor the keys and values are projected from.
            mask: Optional mask, truthy where attention is allowed. Only used
                when the head was built with ``masked=True``.

        Returns:
            The attention-weighted values, last dimension ``head_size``.
        """
        q = self.q(q_input)
        k = self.k(kv_input)
        v = self.v(kv_input)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)  # Scaled dot-product

        if self.masked:
            if mask is None:
                # `mask or ...` cannot be used here: truth-testing a tensor
                # with more than one element raises.
                mask = self.generate_mask(q_input.size(1), device=attn_scores.device)  # Default mask for causal attention
            else:
                # `~mask` is only a logical NOT for boolean tensors.
                mask = mask.to(device=attn_scores.device, dtype=torch.bool)
            attn_scores = attn_scores.masked_fill(~mask, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_weights, v)

    @staticmethod
    def generate_mask(seq_len, device=None):
        """Build a ``(seq_len, seq_len)`` lower-triangular boolean mask.

        Args:
            seq_len: Number of positions.
            device: Device to create the mask on. When omitted, CUDA is used
                if it is available and the CPU otherwise.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


class MultiHeadedAttention(nn.Module):
    """Runs several attention heads side by side and mixes their outputs.

    Each head projects to ``input_dim // num_heads`` features; the head
    outputs are concatenated along the feature axis, passed through a linear
    layer and then through dropout.

    Args:
        input_dim: Feature size of the inputs and of the output.
        num_heads: Number of heads; must divide ``input_dim``.
        masked: Forwarded to every ``AttentionHead``.
        dropout: Dropout probability applied to the projected output.
    """

    def __init__(self, input_dim, num_heads, masked, dropout):
        super().__init__()
        self.head_size = input_dim // num_heads
        assert input_dim % num_heads == 0, "Input dimension must be divisible by the number of heads."

        heads = []
        for _ in range(num_heads):
            heads.append(AttentionHead(input_dim=input_dim, head_size=self.head_size, masked=masked))
        self.heads = nn.ModuleList(heads)
        self.output_linear = nn.Linear(input_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q_input, kv_input, mask=None):
        """Apply every head to the same inputs and merge the results."""
        per_head = [head(q_input, kv_input, mask) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)  # back to input_dim features
        projected = self.output_linear(merged)
        return self.dropout(projected)


class FeedForwardLayer(nn.Module):
    """Position-wise MLP: expand to 4x the width, ReLU, project back.

    Dropout is applied after the activation and again after the output
    projection.

    Args:
        input_dim: Feature size of the input and of the output.
        dropout: Dropout probability used in both dropout layers.
    """

    def __init__(self, input_dim, dropout):
        super().__init__()
        hidden_dim = 4 * input_dim
        stages = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, input_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the feature size is preserved."""
        transformed = self.net(x)
        return transformed


class AddNorm(nn.Module):
    """Residual wrapper: ``LayerNorm(x + sublayer(x))``.

    Args:
        input_dim: Size of the normalised (last) dimension.
        sublayer: Module applied to ``x``; its output is added back onto ``x``.
    """

    def __init__(self, input_dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.sublayer = sublayer

    def forward(self, x, **kwargs):
        """Apply the sublayer (extra keyword arguments are passed through),
        add the residual and normalise."""
        residual = x
        transformed = self.sublayer(x, **kwargs)
        return self.norm(residual + transformed)


class TransformerLayer(nn.Module):
    """Self-attention followed by a feed-forward block, each wrapped in AddNorm.

    Args:
        input_dim: Feature size of the hidden states.
        num_heads: Number of attention heads.
        dropout: Dropout probability for both sub-blocks.
        masked: Whether the attention heads apply a mask.
    """

    def __init__(self, input_dim, num_heads, dropout, masked):
        super().__init__()
        attention = MultiHeadedAttention(input_dim=input_dim, num_heads=num_heads, masked=masked, dropout=dropout)
        self.self_attention = AddNorm(input_dim, attention)

        ffn = FeedForwardLayer(input_dim=input_dim, dropout=dropout)
        self.feed_forward = AddNorm(input_dim, ffn)

    def forward(self, x, mask=None):
        """Attend from ``x`` to itself, then apply the feed-forward block."""
        attended = self.self_attention(x, kv_input=x, mask=mask)
        return self.feed_forward(attended)

class TitanLayer(nn.Module):
    """Memory block followed by a feed-forward block, each wrapped in AddNorm.

    Args:
        input_dim: Feature size of the hidden states.
        max_len: Forwarded to ``MemoryModule``.
        dropout: Dropout probability of the feed-forward block.
        masked: Accepted but not used by this layer.
    """

    def __init__(self, input_dim, max_len, dropout, masked):
        super().__init__()
        memory = MemoryModule(input_dim, max_len)
        self.neural_memory = AddNorm(input_dim, memory)

        ffn = FeedForwardLayer(input_dim=input_dim, dropout=dropout)
        self.feed_forward = AddNorm(input_dim, ffn)

    def forward(self, x, mask=None):
        """Apply the memory block, then the feed-forward block.

        ``mask`` is accepted but not used.
        """
        remembered = self.neural_memory(x)
        return self.feed_forward(remembered)


class TransformerModel(nn.Module):
    """Token-level transformer: embedding, positional encoding, a stack of
    masked ``TransformerLayer`` blocks and a linear projection to the vocabulary.

    Args:
        input_dim: Embedding / hidden size.
        vocab_size: Number of tokens in the vocabulary.
        num_heads: Attention heads per layer.
        num_layers: Number of transformer layers.
        dropout: Dropout probability used inside every layer.
        max_len: Number of positions the positional encoding is precomputed for.
    """

    def __init__(self, input_dim, vocab_size, num_heads, num_layers, dropout, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, input_dim)
        self.positional_encoding = PositionalEncoding(input_dim, max_len)
        self.layers = nn.ModuleList([
            TransformerLayer(input_dim=input_dim, num_heads=num_heads, dropout=dropout, masked=True)
            for _ in range(num_layers)
        ])
        self.output_layer = nn.Linear(input_dim, vocab_size)

    def forward(self, x, input_ids=None, mask=None):
        """Compute vocabulary logits for a batch of token ids.

        Args:
            x: Token ids when ``input_ids`` is omitted. When ``input_ids`` is
                given, ``x`` is ignored; that two-argument form is kept so
                existing ``model(x, input_ids)`` calls keep working.
            input_ids: Optional token ids; take precedence over ``x``.
            mask: Optional attention mask forwarded to every layer.

        Returns:
            Tensor of logits with last dimension ``vocab_size``.
        """
        # The first positional argument used to be overwritten before it was
        # read, forcing callers to pass a dummy value.
        if input_ids is None:
            input_ids = x

        hidden = self.embedding(input_ids)
        hidden = self.positional_encoding(hidden)
        for layer in self.layers:
            hidden = layer(hidden, mask)
        return self.output_layer(hidden)


class TitanModel(nn.Module):
    """Stack of ``TitanLayer`` blocks with a learned memory prefix.

    A trainable tensor of ``max_len`` vectors is prepended to every input
    sequence before the layers run, so the output sequence is ``max_len``
    positions longer than the input.

    Args:
        input_dim: Feature size of the (already embedded) inputs.
        vocab_size: Accepted but not used by this model.
        num_layers: Number of ``TitanLayer`` blocks.
        dropout: Dropout probability passed to every layer.
        max_len: Number of learned memory vectors; also passed to every layer.
    """

    def __init__(self, input_dim, vocab_size, num_layers, dropout, max_len=128):
        super().__init__()
        self.max_len = max_len
        self.long_memory = nn.Parameter(torch.randn(1, self.max_len, input_dim))

        blocks = []
        for _ in range(num_layers):
            blocks.append(TitanLayer(input_dim=input_dim, max_len=max_len, dropout=dropout, masked=True))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask=None):
        """Prepend the memory vectors to ``x`` and run all layers.

        ``mask`` is accepted but not used.
        """
        batch_size = x.shape[0]
        # One copy of the learned memory per batch element.
        memory = self.long_memory.repeat_interleave(batch_size, dim=0)
        hidden = torch.cat([memory, x], dim=-2)  # concatenate along the sequence axis
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden

In [None]:
# Keyword arguments make the configuration readable; max_len keeps its default.
titanModel = TitanModel(
    input_dim=312,
    vocab_size=30522,
    num_layers=4,
    dropout=0.1,
)
summary(titanModel)

Layer (type:depth-idx)                        Param #
TitanModel                                    39,936
├─ModuleList: 1-1                             --
│    └─TitanLayer: 2-1                        --
│    │    └─AddNorm: 3-1                      852,543
│    │    └─AddNorm: 3-2                      780,936
│    └─TitanLayer: 2-2                        --
│    │    └─AddNorm: 3-3                      852,543
│    │    └─AddNorm: 3-4                      780,936
│    └─TitanLayer: 2-3                        --
│    │    └─AddNorm: 3-5                      852,543
│    │    └─AddNorm: 3-6                      780,936
│    └─TitanLayer: 2-4                        --
│    │    └─AddNorm: 3-7                      852,543
│    │    └─AddNorm: 3-8                      780,936
Total params: 6,573,852
Trainable params: 6,573,852
Non-trainable params: 0

In [None]:
# Random stand-in for a batch of embedded tokens: (batch=32, seq_len=128, input_dim=312).
# seq_len has to equal the model's max_len (128 by default): TitanModel prepends
# max_len memory vectors, and MemoryModule's gate network is sized for 2 * max_len tokens.
tens = torch.randn((32, 128, 312))

In [None]:
titanModel(tens)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
from transformers import AutoTokenizer, AutoModel

# Placeholder input; an empty string tokenizes to just the special tokens.
text =''

# Tokenizer and model are loaded from the same pretrained checkpoint name.
tokenizer = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")
model = AutoModel.from_pretrained("huawei-noah/TinyBERT_General_4L_312D")

# return_tensors="pt" makes the tokenizer return PyTorch tensors.
inputs = tokenizer(text, return_tensors="pt")

# Only the model's embedding layer is run, without tracking gradients.
# NOTE(review): presumably yields (1, seq_len, 312) token embeddings - confirm.
with torch.no_grad():
    embeddings = model.embeddings(input_ids=inputs["input_ids"])

In [None]:
class Trainer:
    """Placeholder for the training loop; not implemented yet."""

In [None]:
len(tokenizer.vocab)

30522

In [None]:
tokenizer.decode(inputs["input_ids"][0])

'[CLS] [SEP]'