In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# d_in: input dimension
# d_out: output dimension
# context_size: number of context words
# embedding_dim: dimension of the word embeddings
# n_heads: number of attention heads
# n_layers: number of transformer layers
# d_heads: dimension of each attention head
# d_ff: dimension of the feedforward network

In [3]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

def encode(s):
    return [stoi[c] for c in s]

def decode(indices):
    return ''.join([itos[i] for i in indices])

In [4]:
data = torch.tensor(encode(text), dtype=torch.float)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [5]:
class multihead(nn.Module):
    def __init__(self, d_in, d_out, context_size, n_heads , dropout=0.0):
        super().__init__()

        assert d_out % n_heads == 0, "d_in must be divisible by n_heads"

        self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads
        self.context_size = context_size
        self.d_head = d_out // n_heads
        self.dropout = dropout

        self.q = nn.Linear(d_in, d_out, bias=False)
        self.k = nn.Linear(d_in, d_out, bias=False)
        self.v = nn.Linear(d_in, d_out, bias=False)

        self.out = nn.Linear(d_out, d_out, bias=False)

    def forward(self, x):
        b, num_tokens, d_in = x.size()

        q = self.q(x).view(b, num_tokens, self.n_heads, self.d_head)
        k = self.k(x).view(b, num_tokens, self.n_heads, self.d_head)
        v = self.v(x).view(b, num_tokens, self.n_heads, self.d_head)

        q = q.permute(0, 2, 1, 3)  # (b, n_heads, num_tokens, d_head)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        att = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout,
            is_causal=False,
        )

        att = att.permute(0, 2, 1, 3).contiguous()
        att = att.view(b, num_tokens, self.d_out)

        return self.out(att)


In [6]:
class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
        nn.Linear(n_embed, 4*n_embed),
        nn.ReLU(),
        nn.Linear(4*n_embed, n_embed),
        nn.Dropout(dropout)
        )
        # self.layer_norm = nn.LayerNorm(d_out)

    def forward(self, x):
        return self.net(x)

In [7]:
class Block(nn.Module):
    def __init__(self, d_in, d_out, context_size, n_heads, dropout=0.0):
        super().__init__()
        self.att = multihead(d_in, d_out, context_size, n_heads, dropout)
        self.ff = FeedForward(d_out, dropout)
        self.layer_norm1 = nn.LayerNorm(d_out)
        self.layer_norm2 = nn.LayerNorm(d_out)

    def forward(self, x):
        x = x + self.att(self.layer_norm1(x))
        x = x + self.ff(self.layer_norm2(x))
        return x

In [8]:
## model formation
class model(nn.Module):
    def __init__(self, d_in, d_out, context_size, n_heads, n_layers, dropout=0.0):
        super().__init__()
        self.n_layers = n_layers
        self.context_size = context_size
        self.d_in = d_in
        self.d_out = d_out
        self.n_heads = n_heads

        self.embedding = nn.Embedding(d_in, d_out)
        self.pos_embedding = nn.Embedding(context_size, d_out)


        self.blocks = nn.Sequential(*[
            Block(d_in, d_out, context_size, n_heads, dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_out)
        self.fc = nn.Linear(d_out, d_in)


    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding(torch.arange(self.context_size).to(x.device))
        x = self.blocks(x)
        x = self.layer_norm(x)
        x = self.fc(x)
        return x

In [9]:
def apply_masking(x_true, t):
    mask_ids = torch.inf
    min_val = 0.001
    max_val = 1.001
    for i in range(x_true.size(0)):
        rand_int = (max_val - min_val) * torch.rand(1) + min_val
        print(rand_int)
        if rand_int <= t:
            x_true[i] = mask_ids
        else:
            x_true[i] = x_true[i]
    return x_true

In [10]:
x_0 = train_data[:5]
print(x_0)
apply_masking(x_0, 0.4)

tensor([ 5., 13., 13.,  1.,  0.])
tensor([0.7954])
tensor([0.7376])
tensor([0.1871])
tensor([0.6954])
tensor([0.9860])


tensor([ 5., 13., inf,  1.,  0.])