In [40]:
import math

import torch
import torch.nn as nn

In [41]:
# Setting hyperparameters
# Pick the best available accelerator: Apple-silicon MPS first, then CUDA, else CPU.
device = (
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
train_split = 0.9             # fraction of the corpus used for training
test_split = 1 - train_split  # remainder, held out for validation
batch_size = 16  # how many independent sequences will we process in parallel?
block_size = 32  # what is the maximum context length for predictions?

In [42]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level vocabulary: the distinct characters of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
# Lookup tables between characters and integer token ids.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Map a string to its list of token ids (KeyError for characters outside the vocabulary)."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Map a sequence of token ids back to the string they encode."""
    return ''.join(itos[idx] for idx in l)

In [43]:
# Encode the whole corpus once as a 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)
# Contiguous split: the first `train_split` fraction trains, the rest validates.
n = int(train_split * len(data))
train_data, val_data = data[:n], data[n:]

# data loading
def get_batch(split):
    """
    Sample a random mini-batch of (input, target) windows.

    split: 'train' selects train_data; any other value selects val_data.
    return: (x, y), each (batch_size, block_size) long tensors on `device`;
            y is x shifted one token to the right (next-token targets).
    """
    source = train_data if split == 'train' else val_data
    # Window starts are drawn so that the shifted target window stays in bounds.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)

In [126]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding:

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    The table is recomputed on every call from the input's sequence length and
    embedding size, so the module has no parameters or buffers.
    """
    def __init__(self):
        super().__init__()

    def get_angles(self, pos, i, d_model):
        """
        Compute the raw angles pos / 10000^(2 * (i // 2) / d_model).

        pos: (seq_length, 1) positions
        i: (1, d_model) embedding-dimension indices
        d_model: int (dimension of embedding)

        return: (seq_length, d_model), the broadcast of pos against i
        """
        # Dimensions 2k and 2k+1 share one frequency, hence i // 2.
        power = 2 * (i // 2) / torch.tensor(d_model, dtype=torch.float32)
        return pos / torch.pow(10000, power)

    def forward(self, inputs):
        """
        Add the positional encoding to `inputs`.

        inputs: (batch_size, seq_length, d_model)
        return: (batch_size, seq_length, d_model) on the same device as `inputs`;
                non-float inputs are promoted to float by the addition.
        """
        assert len(inputs.shape) == 3, (
            f"expected (batch_size, seq_length, d_model), got shape {tuple(inputs.shape)}"
        )
        seq_length = inputs.shape[-2]
        d_model = inputs.shape[-1]
        angles = self.get_angles(
            torch.arange(seq_length).unsqueeze(1),
            torch.arange(d_model).unsqueeze(0),
            d_model,
        )

        pe = torch.zeros(seq_length, d_model)
        pe[:, 0::2] = torch.sin(angles[:, 0::2])  # even dimensions
        pe[:, 1::2] = torch.cos(angles[:, 1::2])  # odd dimensions
        # Add a leading batch axis for broadcasting and move the CPU-built table
        # to the input's device so 'cuda'/'mps' inputs do not hit a device mismatch.
        pe = pe.unsqueeze(0).to(inputs.device)
        return inputs + pe
        

In [128]:
"""
Positional encoding test
"""
test_pe_input = torch.tensor([
    [[1,2,3], [2,3,4]], 
    [[3,4,5], [4,5,6]]
]) #batch_size = 2, seq_length = 2, d_model = 3
pos1_i0 = torch.sin(torch.tensor(1/math.pow(10000,0)))
pos1_i1 = torch.cos(torch.tensor(1/math.pow(10000,0)))
pos1_i2 = torch.sin(torch.tensor(1/math.pow(10000,2/float(3))))
expected_pe = torch.tensor([
    [[0, 1, 0], [pos1_i0, pos1_i1, pos1_i2]], 
    [[0, 1, 0], [pos1_i0, pos1_i1, pos1_i2]]
]) + test_pe_input
assert (expected_pe == PositionalEncoding().forward(test_pe_input)).all()

In [296]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.

    Queries, keys and values are each projected by a d_model -> d_model linear
    layer, split into `n_heads` heads of size d_model // n_heads, combined with
    scaled dot-product attention, concatenated back to d_model and passed
    through a final linear layer.

    n_heads: number of attention heads; d_model must be divisible by it for
             split_to_heads / concat_from_heads to work.
    input_shape: shape of the inputs; only the last entry (d_model) is used.
    """
    def __init__(self, n_heads, input_shape):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = input_shape[-1]
        self.d_head = self.d_model // self.n_heads
        self.query_lin = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.key_lin = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.value_lin = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.final_lin = nn.Linear(in_features=self.d_model, out_features=self.d_model)

    def scaled_dot_product_attention(self, query, key, value, mask):
        """
        softmax((Q K^T) / sqrt(d_k)) V

        query: (batch_size, num_heads, seq_length, d_k)
        key: (batch_size, num_heads, seq_length, d_k)
        value: (batch_size, num_heads, seq_length, d_v)
        mask: None, or a tensor broadcastable to the score matrix, e.g.
              (batch_size, 1, 1, seq_length). Positions where mask == 1 get
              -1e9 added to their score, so they receive ~0 attention weight.
        return: (batch_size, num_heads, seq_length, d_v)
        """
        assert len(query.shape) == len(key.shape) and len(query.shape) == len(value.shape)
        assert key.dtype == torch.float

        product = query @ key.transpose(-1, -2)

        dk = torch.tensor(key.shape[-1], dtype=torch.float32)
        scaled_product = product / torch.sqrt(dk)

        if mask is not None:
            # Out-of-place add, so the mask may broadcast to a larger shape than the scores.
            scaled_product = scaled_product + mask * -1e9

        weights = torch.softmax(scaled_product, dim=-1)
        return weights @ value

    def split_to_heads(self, inputs, batch_size):
        """
        input: (batch_size, seq_length, d_model)
        return: (batch_size, n_heads, seq_length, d_model // n_heads)
        """
        proj_inputs = inputs.reshape(batch_size, -1, self.n_heads, self.d_head)
        return proj_inputs.transpose(1, 2)

    def concat_from_heads(self, inputs, batch_size):
        """
        input: (batch_size, n_heads, seq_length, d_model // n_heads)
        return: (batch_size, seq_length, d_model)
        """
        # The transpose makes the tensor non-contiguous, so .view() would fail here;
        # reshape copies when needed.
        return inputs.transpose(1, 2).reshape(batch_size, -1, self.d_model)

    def multi_head_attention(self, query, key, value, mask):
        """
        query: (batch_size, seq_length, d_model)
        key: (batch_size, seq_length, d_model)
        value: (batch_size, seq_length, d_model)
        mask: None or (batch_size, 1, 1, seq_length); 1 marks a masked key position

        return: (batch_size, seq_length, d_model)
        """
        batch_size = query.shape[0]
        queries = self.split_to_heads(self.query_lin(query), batch_size)
        keys = self.split_to_heads(self.key_lin(key), batch_size)
        values = self.split_to_heads(self.value_lin(value), batch_size)

        attention = self.scaled_dot_product_attention(queries, keys, values, mask)

        attention = self.concat_from_heads(attention, batch_size)
        return self.final_lin(attention)

    def forward(self, query, key, value, mask=None):
        """nn.Module entry point, so `module(q, k, v, mask)` works; same as multi_head_attention."""
        return self.multi_head_attention(query, key, value, mask)

In [297]:
"""
Scaled Dot Product Attention
"""
test_scaled_dot_attention = torch.tensor([
    [[1,2,3], [2,3,4]], 
    [[3,4,5], [4,5,6]]
], dtype = torch.float32) #batch_size:2, seq_length: 2, d_model: 3
test_mask = torch.tensor([
    [[0, 0]],
    [[0, 1]]
])# (batch_size, 1, seq_length)
test_product = torch.tensor([[[14., 20.],
         [20., 29.]],
        [[50., 62.],
         [62., 77.]]], dtype = torch.float32)
test_scaled_product = test_product/math.sqrt(3)
test_scaled_product[1, :, 1] += -1e9 #applying mask
expected_attention = torch.softmax(test_scaled_product, dim = -1) @ test_scaled_dot_attention
assert (expected_attention == MultiHeadAttention(5, test_scaled_dot_attention.shape).scaled_dot_product_attention(test_scaled_dot_attention, test_scaled_dot_attention, test_scaled_dot_attention, test_mask)).all()

85