In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader



In [2]:
with open('input.txt', 'r') as f:
    data = f.read()

In [3]:
print(data[:100])

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [4]:
data = data.lower()
print(len(data))

1115394


In [5]:
chars = list(set(data))
chars.append('M')  # Add the '[MASK]' token to your list of characters
data_size, vocab_size = len(data), len(chars)
char_to_ix = { ch:i for i,ch in enumerate(chars) }
ix_to_char = { i:ch for i,ch in enumerate(chars) }
data = torch.tensor([char_to_ix[ch] for ch in data])

In [6]:
class SimpleDataset(Dataset):
    def __init__(self, data, context_length):
        self.data = data
        self.context_length = context_length

    def __len__(self):
        return len(self.data) // self.context_length

    def __getitem__(self, idx):
        return self.data[idx*self.context_length:(idx+1)*self.context_length]

In [7]:
class SelfAttention(nn.Module):
    def __init__(self, head_size, d_k): # d_k dimension of embed vector
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(d_k, head_size)
        self.keys = nn.Linear(d_k, head_size)
        self.values = nn.Linear(d_k, head_size)

    
    def forward(self, x):
        B, C, d_k = x.shape  # B--> batch size, C --> con_length
        q = self.query(x) #(B, C, H) H--> Head size
        k = self.query(x) #(B, C, H)
        v = self.values(x) #(B, C, H)

        score = q @ k.transpose(-2, -1) * (d_k ** -0.5)  #(B, C, H)
        prob = F.softmax(score, dim=-1)
        out = prob @ v # (B, C, C) @ (B, C, H) = (B, C, H)

        return out


In [8]:
class MultiheadAttention(nn.Module):
    
    def __init__(self, head_size, d_k, n_heads):
        super(MultiheadAttention, self).__init__()
        self.heads = nn.ModuleList(SelfAttention(head_size, d_k) for _ in range(n_heads))
        self.res_fc = nn.Linear(n_heads * head_size, d_k)
    def forward(self, x):
        heads = [h(x) for h in self.heads]
        heads_concat = torch.concat(heads, dim=-1)  #(B, C, n_heads * head_size)
        out = self.res_fc(heads_concat)
        return out

In [9]:
class FC(nn.Module):

    def __init__(self, d_k):
        super(FC, self).__init__()
        self.fc = nn.Linear(d_k, 4 * d_k) # 4 is in the original paper
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(4 * d_k, d_k)

    def forward(self, x):
        out = self.fc2(self.gelu(self.fc(x)))
        return out

In [10]:
class EncoderBlock(nn.Module):

    def __init__(self, d_k, n_heads):
        super(EncoderBlock, self).__init__()
        self.encoder = MultiheadAttention(head_size=d_k // n_heads, d_k=d_k, n_heads=n_heads)
        self.layernorm_pre_encoder = nn.LayerNorm(d_k)
        self.fc = FC(d_k)
        self.layernorm_pre_fc = nn.LayerNorm(d_k)

    def forward(self, x):
        x = x + self.encoder(self.layernorm_pre_encoder(x))    # Adding residual connection
        x = x + self.fc(self.layernorm_pre_fc(x))        # Adding residual connection
        return x

        

In [11]:
class Bert(nn.Module):

    def __init__(self, 
                 vocab_size, 
                 context_length, 
                 d_k, 
                 n_heads,
                 n_layers, 
                 device):

        super(Bert, self).__init__()
        self.device = device
        self.embeddings = nn.Embedding(vocab_size, d_k)
        self.positional_embeddings = nn.Embedding(context_length, d_k)
        self.encoder = nn.Sequential(*[EncoderBlock(d_k, n_heads) for _ in range(n_layers)])
        self.output = nn.Linear(d_k, vocab_size)

    
    def forward(self, x):
        # x --> (B, C)
        B, C = x.shape
        token_embed = self.embeddings(x) #(B, C, d_k)
        pos_embed = self.positional_embeddings(torch.arange(C, device=self.device)) #(C, d_k)
        x = token_embed + pos_embed # (B, C, d_k)
        x = self.encoder(x) 
        logits = self.output(x)

        return logits



In [20]:
embed_size = 16
num_heads = 8
n_layers = 3
num_epochs = 50
learning_rate = 0.001
batch_size = 64
context_length = 100
mask_prob = 0.15
device = "mps"



In [21]:
model = Bert(vocab_size, context_length, embed_size, num_heads, n_layers, device)
model = model.to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=char_to_ix['M'])
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [22]:
dataset = SimpleDataset(data, context_length)
data_loader = DataLoader(dataset, batch_size=batch_size)

In [23]:
# Training loop
for epoch in range(num_epochs):
    for batch in data_loader:
        inputs = batch.clone().to(device)
        targets = batch.clone().to(device)

        # Mask some of the tokens
        mask = torch.rand(inputs.shape) < mask_prob
        inputs[mask] = char_to_ix['M']

        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs.view(-1, vocab_size), targets.view(-1))

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print (f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

  if not is_compiling() and torch.has_cuda and torch.cuda.is_available():


Epoch [1/50], Loss: 0.7160
Epoch [2/50], Loss: 0.5090
Epoch [3/50], Loss: 0.4838
Epoch [4/50], Loss: 0.4926
Epoch [5/50], Loss: 0.4669
Epoch [6/50], Loss: 0.5041
Epoch [7/50], Loss: 0.4544
Epoch [8/50], Loss: 0.4670
Epoch [9/50], Loss: 0.4216
Epoch [10/50], Loss: 0.4245
Epoch [11/50], Loss: 0.4866
Epoch [12/50], Loss: 0.4712
Epoch [13/50], Loss: 0.4501
Epoch [14/50], Loss: 0.4685
Epoch [15/50], Loss: 0.4560
Epoch [16/50], Loss: 0.4832
Epoch [17/50], Loss: 0.4749
Epoch [18/50], Loss: 0.4470
Epoch [19/50], Loss: 0.4954
Epoch [20/50], Loss: 0.4583
Epoch [21/50], Loss: 0.4564
Epoch [22/50], Loss: 0.4805
Epoch [23/50], Loss: 0.4797
Epoch [24/50], Loss: 0.4676
Epoch [25/50], Loss: 0.4274
Epoch [26/50], Loss: 0.4591
Epoch [27/50], Loss: 0.4041
Epoch [28/50], Loss: 0.4806
Epoch [29/50], Loss: 0.4949
Epoch [30/50], Loss: 0.4737
Epoch [31/50], Loss: 0.4287
Epoch [32/50], Loss: 0.5084
Epoch [33/50], Loss: 0.4275
Epoch [34/50], Loss: 0.4475
Epoch [35/50], Loss: 0.4667
Epoch [36/50], Loss: 0.4334
E

In [24]:
# Dumping the model
torch.save(model.state_dict(), 'bert_v1.pth')

In [25]:
model = Bert(vocab_size, context_length, embed_size, num_heads, n_layers, device="cpu")
model.load_state_dict(torch.load("bert_v1.pth"))

<All keys matched successfully>

In [30]:
model_cpu = model.to(device="cpu")
with torch.no_grad():
    # Create input sequence with a masked token
    input_text = "firsM"
    inputs = torch.tensor([char_to_ix[ch] for ch in input_text])
    inputs = inputs.view(1, len(inputs))
    inputs = inputs.cpu()
    # Get model outputs
    outputs = model_cpu(inputs)
    print(outputs.shape)
    
    # Get predicted tokens
    _, predicted = torch.max(outputs, dim=2)
    print(predicted)
    
    # Print predicted sequence
    print(''.join([ix_to_char[ix.item()] for ix in predicted[0]]))



torch.Size([1, 5, 40])
tensor([[33,  0,  5, 28, 35]])
first
