In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

from numpy.ma.core import indices
from torch.ao.nn.quantized import ReLU6
from torch.nn import ReLU


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()

        self.dimension_model = d_model
        self.num_heads = num_heads
        self.dimension_head = d_model // num_heads

        self.weight_query = nn.Linear(d_model, d_model)
        self.weight_key = nn.Linear(d_model, d_model)
        self.weight_value = nn.Linear(d_model, d_model)
        self.weight_output = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, query, key, value, mask = None):
        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.dimension_head)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)
        attention_probs = torch.softmax(attention_scores, dim = -1)
        output = torch.matmul(attention_probs, value)
        return output

    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.dimension_head).transpose(1,2)

    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1,2).contiguous().view(batch_size, seq_length, self.dimension_model)

    def forward(self, query, key, value, mask = None):
        query = self.split_heads(self.weight_query(query))
        key = self.split_heads(self.weight_key(key))
        value = self.split_heads(self.weight_value(value))

        attention_output = self.scaled_dot_product_attention(query, key, value, mask)
        output = self.weight_output(self.combine_heads(attention_output))
        return output

class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
       super().__init__()

       pe = torch.zeros(max_seq_length, d_model)
       position = torch.arange(0, max_seq_length, dtype = torch.float).unsqueeze(1)
       div_term = torch.exp(torch.arange(0, d_model, 2).float()* -(math.log(10000.0)/d_model))

       pe[:, ::2] = torch.sin(position*div_term)
       pe[:, 1::2] = torch.cos(position * div_term)

       self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask


    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)

        return output


    def generate(self, src, start_token, max_length, temperature=1.0, top_k=None):
        self.eval()
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)

        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        generated = torch.tensor([[start_token]], dtype=torch.long)

        for _ in range(max_length - 1):
            tgt_mask = (generated != 0).unsqueeze(1).unsqueeze(2)
            seq_length = generated.size(1)
            nopeak_mask = (1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()
            tgt_mask = tgt_mask & nopeak_mask

            tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(generated)))
            dec_output = tgt_embedded
            for dec_layer in self.decoder_layers:
                dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

            logits = self.fc(dec_output[:, -1, :]) / temperature
            if top_k is not None:
                top_k = min(top_k, logits.size(-1))
                values, indices = torch.topk(logits, top_k)
                logits = torch.full_like(logits, float('-inf')).scatter(-1, indices, values)

            probs = nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token], dim=1)

            if len(generated[0]) == max_length:
                break
        return generated


with open('/content/input.txt', 'r', encoding='utf-8') as file:
    text = file.read()

chars = sorted(list(set(text)))
v_size = len(chars)

src_vocab_size = v_size
tgt_vocab_size = v_size
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length =100
dropout = 0.1

string_to_int = {c: i for i, c in enumerate(chars)}
int_to_string = {i: c for i, c in enumerate(chars)}

encode = lambda s: [string_to_int[c] for c in s]
decode = lambda i: ''.join(int_to_string[c.item()] for c in i)

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
context_with_given_input = torch.tensor([encode('hello')], dtype = torch.long)
src = torch.tensor([[1]])
generated = transformer.generate(
    src = context_with_given_input,
    start_token = 2,
    max_length = max_seq_length,
    temperature = 0.7,
    top_k = None
)
print(decode(generated[0]))

src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0065, betas=(0.9, 0.98), eps=1e-9)
transformer.train()

for epoch in range(15):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:,1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch : {epoch+1}, Loss : {loss.item()}")

generated = transformer.generate(
    src=src,
    start_token =2,
    max_length= max_seq_length,
    temperature = 0.7,
    top_k = None
)
print(generated)

!‘9m#.ïFé-pvB[JJ2,[k/?Uï,ñ$ñzp!wx4AFHcQ‘D7i-XPr#!_‘$KP•-j/Wp.ZNX/[U4XXP%r!ñ
3Mu!Cq(p-QNWPIo2%-NH"zwK
Epoch : 1, Loss : 4.6961989402771
Epoch : 2, Loss : 5.4317946434021
Epoch : 3, Loss : 5.201897621154785
Epoch : 4, Loss : 5.288149833679199
Epoch : 5, Loss : 5.189873695373535
Epoch : 6, Loss : 4.911618709564209
Epoch : 7, Loss : 4.793488502502441
Epoch : 8, Loss : 4.772161483764648
Epoch : 9, Loss : 4.815664768218994
Epoch : 10, Loss : 4.785419940948486
Epoch : 11, Loss : 4.7387824058532715
Epoch : 12, Loss : 4.690345287322998
Epoch : 13, Loss : 4.673628807067871
Epoch : 14, Loss : 4.6582489013671875
Epoch : 15, Loss : 4.633567810058594
tensor([[ 2, 86, 37, 45,  8, 70, 83, 41, 26, 83, 81, 42, 79, 41, 93, 45, 78, 81,
         42, 19, 19, 45, 43, 80, 19, 79, 56, 18,  5, 57, 56, 85, 17,  8, 64, 18,
         45, 30,  7, 54, 10, 15, 12, 74, 56, 47, 76, 61, 82, 52, 16, 90, 16, 91,
         57, 44, 61, 22, 21, 43,  7,  1,  9, 83, 89, 15, 77, 74, 14, 66, 42, 90,
         43, 19, 15, 90, 73, 65

In [2]:
decoded_output = decode(generated[0])
print(decoded_output)

!ïIQ'mzM:zxNvM•QuxN33QOw3v]2$_]é1'g2QB&Z)/,q]SsdyX0’0“_Pd65O& (z‘/tq.iN’O3/’ph(—h—A“.1•;q]i—K]1ï•c—i
