In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import os
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Tokenizer / DataLoader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

class SimpleTokenizer:
    """Whitespace tokenizer that maps words to integer ids of a fixed length.

    The vocabulary starts with the padding token (id 0) and the unknown-word
    token (id 1); every other word gets the next free id the first time
    ``build_vocab`` encounters it.
    """

    def __init__(self, max_len, pad_token='<pad>', unk_token='<unk>'):
        """Store the special tokens and the fixed output length ``max_len``."""
        self.max_len = max_len
        self.pad_token = pad_token
        self.unk_token = unk_token
        self.vocab = {pad_token: 0, unk_token: 1}

    def build_vocab(self, sentences):
        """Add every unseen whitespace-separated word in ``sentences`` to the vocab.

        Prints the vocabulary size once all sentences have been processed.
        """
        all_words = (word for sentence in sentences for word in sentence.split())
        for word in all_words:
            # setdefault leaves known words alone; new words take the next id.
            self.vocab.setdefault(word, len(self.vocab))

        print(len(self.vocab))

    def encode(self, sentence):
        """Return exactly ``max_len`` ids for ``sentence``.

        Words missing from the vocabulary map to the unknown-token id. Longer
        sentences are cut off; shorter ones are right-padded with the pad id.
        """
        kept_words = sentence.split()[:self.max_len]
        ids = [self.vocab.get(word, self.vocab[self.unk_token]) for word in kept_words]
        missing = self.max_len - len(ids)
        return ids + [self.vocab[self.pad_token]] * missing

class OpenWebTextSimpleDataset(Dataset):
    """Next-token-prediction dataset over the rows of plain-text files.

    Each item is a pair ``(input_ids, target_ids)`` where the target is the
    encoded row shifted left by one token, so both tensors have
    ``tokenizer.max_len - 1`` entries.

    Args:
        tokenizer: Object exposing ``build_vocab(sentences)`` and
            ``encode(sentence)``; its vocabulary is extended in place here.
        data: Passed through as ``data_files`` to ``load_dataset('text', ...)``.
        num_samples: Upper bound on the number of rows exposed through
            ``len()`` / indexing. ``None`` exposes every loaded row.
        use_train_split_for_vocab: When true the ``'train'`` split of the
            loaded dataset is used; otherwise the first split that
            ``load_dataset`` returned.
    """

    def __init__(self, tokenizer, data, num_samples=10000, use_train_split_for_vocab = True):
        self.dataset = load_dataset('text', data_files=data)

        self.tokenizer = tokenizer
        # Previously accepted but ignored; now honoured by __len__/__getitem__.
        self.num_samples = num_samples

        self.train_split = 'train' if use_train_split_for_vocab else list(self.dataset.keys())[0]

        self.data = self.dataset[self.train_split]
        # Debug peek at one row; guarded so tiny inputs do not raise IndexError.
        if len(self.data) > 3:
            print(self.data[3]['text'])

        # Build the vocabulary from (at most) the first 5000 rows of the split.
        vocab_rows = min(5000, len(self.data))
        self.tokenizer.build_vocab([self.data[i]['text'] for i in range(vocab_rows)])
        print(self.tokenizer.encode("Test this string."))

    def __len__(self):
        """Number of usable rows: the loaded rows, capped at ``num_samples``."""
        total = len(self.data)
        if self.num_samples is None:
            return total
        return max(0, min(self.num_samples, total))

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` long tensors for row ``idx``."""
        # Enforce the cap for direct indexing / legacy iteration as well.
        if isinstance(idx, int) and idx >= len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")
        text = self.data[idx]['text']
        encoded_text = self.tokenizer.encode(text)
        input_ids = encoded_text[:-1]
        target_ids = encoded_text[1:]
        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(target_ids, dtype=torch.long)

# Parameters
# Every encoded row is padded/truncated to this many tokens; the dataset then
# yields input/target pairs that are one token shorter.
max_len = 25

# One tokenizer is shared by both datasets, and each dataset constructor calls
# build_vocab on it, so the vocabulary grows with both inputs.
simple_tokenizer = SimpleTokenizer(max_len=max_len)

# NOTE(review): `splits` is not defined in any visible cell; it must already
# exist in the kernel before this cell runs. TODO confirm where it comes from.
train_dataset = OpenWebTextSimpleDataset(
    tokenizer=simple_tokenizer,
    data=splits["train"][:1000],
    num_samples=10000,
)
val_dataset = OpenWebTextSimpleDataset(
    tokenizer=simple_tokenizer,
    data=splits["validation"][:100],
    num_samples=10000,
)




In [None]:
# Wrap both datasets in loaders. drop_last=True keeps every batch at exactly
# 64 rows, matching the fixed batch_size the model is configured with.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=64, drop_last=True)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=64, drop_last=True)

print(f"number of batches {len(train_dataloader)}")

# Peek at a single training batch (if there is one) to sanity-check shapes.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    x, y = first_batch
    print(x, y, x.shape, y.shape)

In [None]:
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: Tensor whose last two dims are (query_len, d_k).
        key: Tensor whose last two dims are (key_len, d_k).
        value: Tensor whose last two dims are (key_len, d_v).
        mask: Optional tensor broadcastable to the score shape
            (..., query_len, key_len). Positions where ``mask == 0`` are
            excluded from attention.

    Returns:
        Tuple ``(output, attn)``: the attended values with last dims
        (query_len, d_v), and the attention weights with last dims
        (query_len, key_len).
    """
    d_k = query.size(-1)
    # Plain Python scalar: no per-call tensor allocation and no device/dtype mixing.
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)

    if mask is not None:
        # Use the dtype's own minimum so the fill value is representable in
        # reduced precision too (a literal -1e9 overflows float16).
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # torch.softmax avoids relying on an `F` alias that this notebook never imports.
    attn = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn, value)
    return output, attn
class MultiHeadAttention(nn.Module):
    """Batch-first multi-head attention built on ``scaled_dot_product_attention``.

    Args:
        d_model: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the projected output.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.2):
        super().__init__()

        # Without this check the per-head view() below either fails with an
        # opaque shape error or silently scrambles features.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        self.d_model = d_model
        self.d_k = d_model // num_heads
        self.h = num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, attn_mask=None, need_weights=False):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Tensor of shape (batch, query_len, d_model).
            k: Tensor of shape (batch, key_len, d_model).
            v: Tensor of shape (batch, key_len, d_model).
            attn_mask: Optional mask; positions equal to 0 are not attended.
                A 3-D mask (batch, query_len, key_len) is broadcast over heads.
            need_weights: When true, also return the per-head attention weights.

        Returns:
            The output tensor of shape (batch, query_len, d_model), or the
            tuple ``(output, attn)`` when ``need_weights`` is true, where
            ``attn`` has shape (batch, heads, query_len, key_len).
        """
        bs = q.size(0)

        # Project, then split the feature axis into (heads, d_k).
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # Move heads in front of the sequence axis: (batch, heads, seq_len, d_k).
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # Insert a head axis so a per-sample mask broadcasts across all heads.
        if attn_mask is not None and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(1)

        context, attn = scaled_dot_product_attention(q, k, v, attn_mask)

        # Merge the heads back into one feature axis and project.
        concat = context.transpose(1, 2).contiguous().view(bs, -1, self.d_model)

        # The dropout module was constructed but never applied; use it on the
        # projected output (identity in eval mode).
        output = self.dropout(self.out(concat))

        if need_weights:
            return output, attn
        return output

# Model

In [None]:
class Encoder(nn.Module):
    """Embedding followed by a single-layer, batch-first LSTM.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        enc_units: Hidden size of the LSTM.
        batch_size: Default batch size used by ``initialize_hidden_state``
            when no explicit size is given.
    """

    def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.enc_units = enc_units
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, enc_units, batch_first=True)

    def forward(self, x, hidden = None):
        """Encode a batch of token ids.

        Args:
            x: Long tensor of shape (batch, seq_len).
            hidden: Optional ``(h_0, c_0)`` tuple for the LSTM. When omitted a
                zero state matching the batch in ``x`` is created.

        Returns:
            ``(output, (hidden, cell))`` where ``output`` has shape
            (batch, seq_len, enc_units) and both states have shape
            (1, batch, enc_units).
        """
        # Size the initial state from the actual input so partial batches work.
        actual_batch = x.size(0)
        x = self.embedding(x)
        if hidden is None:
            hidden = self.initialize_hidden_state(actual_batch)
        output, (hidden, cell) = self.lstm(x, hidden)
        return output, (hidden, cell)

    def initialize_hidden_state(self, batch_size=None):
        """Return a zero ``(h_0, c_0)`` pair on the same device as the module.

        Args:
            batch_size: Batch dimension of the state; defaults to the
                ``batch_size`` given at construction time.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # Follow the module's own parameters instead of a notebook-level global.
        reference = self.embedding.weight
        shape = (1, batch_size, self.enc_units)
        return (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                torch.zeros(shape, device=reference.device, dtype=reference.dtype))

In [None]:
class Decoder(nn.Module):
    """Single-step LSTM decoder that attends over the encoder outputs.

    Args:
        vocab_size: Size of the output vocabulary (and embedding table).
        embedding_dim: Size of each token embedding.
        dec_units: Hidden size of the LSTM and model size of the attention.
            The encoder outputs must have this feature size as well, because
            the attention projections are ``dec_units -> dec_units``.
        batch_size: Stored for reference; not used by ``forward``.
    """

    def __init__(self, vocab_size, embedding_dim, dec_units, batch_size):
        super(Decoder, self).__init__()
        self.batch_size = batch_size
        self.dec_units = dec_units
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim + dec_units, dec_units, batch_first=True)
        self.fc = nn.Linear(dec_units, vocab_size)
        self.attention = MultiHeadAttention(dec_units, num_heads=2)

    def forward(self, x, hidden, enc_output):
        """Run one decoding step.

        Args:
            x: Long tensor of shape (batch,) holding the current input token ids.
            hidden: ``(hidden_state, cell_state)`` tuple, each of shape
                (num_layers, batch, dec_units); only the last layer is used.
            enc_output: Encoder outputs of shape (batch, src_len, dec_units).

        Returns:
            ``(logits, (hidden, cell), attention_weights)`` where ``logits``
            has shape (batch, vocab_size), the states have shape
            (1, batch, dec_units), and ``attention_weights`` is ``None`` when
            the attention module does not return weights.
        """
        hidden_state, cell_state = hidden
        hidden_state = hidden_state[-1].unsqueeze(0)  # shape: [1, batch, features]
        # Carry the real cell state forward instead of resetting it to zeros.
        cell_state = cell_state[-1].unsqueeze(0)      # shape: [1, batch, features]

        # The attention module is batch-first, so the query must be
        # [batch, 1, features] and the keys/values [batch, src_len, features].
        # Feeding sequence-first tensors makes it treat the batch axis as the
        # sequence and attend across unrelated samples.
        query = hidden_state.transpose(0, 1)

        attn_result = self.attention(query, enc_output, enc_output)
        # Accept both a bare context tensor and a (context, weights) tuple.
        if isinstance(attn_result, tuple):
            context_vector, attention_weights = attn_result
        else:
            context_vector, attention_weights = attn_result, None

        # [batch] -> [batch, 1, embedding_dim], a one-step sequence for the LSTM.
        embedded = self.embedding(x).unsqueeze(1)

        # [batch, 1, dec_units + embedding_dim]
        lstm_input = torch.cat((context_vector, embedded), dim=-1)

        output, (hidden, cell) = self.lstm(lstm_input, (hidden_state, cell_state))
        # Drop the length-1 sequence axis before projecting to the vocabulary.
        output = output.reshape(-1, output.shape[2])
        logits = self.fc(output)

        return logits, (hidden, cell), attention_weights


In [None]:
class LSTM_model(nn.Module):
    """Encoder-decoder wrapper that decodes step by step with teacher forcing.

    Args:
        encoder: Module called as ``encoder(source)`` returning
            ``(encoder_outputs, hidden)``.
        decoder: Module called as ``decoder(x, hidden, encoder_outputs)``
            returning ``(logits, hidden, extra)``; it must expose an ``fc``
            layer whose ``out_features`` is the target vocabulary size.
    """

    def __init__(self, encoder, decoder):
        # The previous super(Seq2Seq, self) referenced a class that does not
        # exist and raised NameError on construction.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_force_ratio=0.5):
        """Produce logits for every target position.

        Args:
            source: Long tensor of shape (batch, src_len).
            target: Long tensor of shape (batch, target_len).
            teacher_force_ratio: Probability, drawn once per step for the
                whole batch, of feeding the ground-truth token instead of the
                model's own prediction as the next input.

        Returns:
            Float tensor of shape (batch, target_len, vocab_size). Position 0
            is never predicted and stays all zeros.
        """
        batch_size = source.shape[0]
        target_len = target.shape[1]
        target_vocab_size = self.decoder.fc.out_features

        # Allocate next to the inputs rather than on a notebook-level global device.
        outputs = torch.zeros(batch_size, target_len, target_vocab_size, device=source.device)
        encoder_outputs, hidden = self.encoder(source)

        # The first decoder input is the first token of each target sequence.
        x = target[:, 0]

        for t in range(1, target_len):
            # Step 1 starts from the encoder's final state; afterwards the
            # decoder's own state is carried forward.
            output, hidden, _ = self.decoder(x, hidden, encoder_outputs)
            outputs[:, t] = output

            # Decide whether to use teacher forcing for the next step.
            best_guess = output.argmax(-1)
            x = target[:, t] if random.random() < teacher_force_ratio else best_guess

        return outputs

In [None]:
# Hyper-parameters for this run. Both vocabulary sizes are read from the shared
# tokenizer, so this cell must run after the datasets have built the vocab.
config = dict(
    beam_width=2,
    lr=5e-4,
    epochs=15,
    batch_size=64,
    dropout=0.2,
    input_dim=len(simple_tokenizer.vocab),
    output_dim=len(simple_tokenizer.vocab),
    emb_dim=1024,
    enc_hid_dim=1024,
    dec_hid_dim=1024,
)

# NOTE(review): none of these three names is read by any visible cell.
num_layers, dropout, heads = 2, 0.2, 8

# Assemble the encoder/decoder pair and move the full model to the chosen device.
encoder = Encoder(
    vocab_size=config['input_dim'],
    embedding_dim=config["emb_dim"],
    enc_units=config['enc_hid_dim'],
    batch_size=config['batch_size'],
)
decoder = Decoder(
    vocab_size=config['output_dim'],
    embedding_dim=config["emb_dim"],
    dec_units=config['dec_hid_dim'],
    batch_size=config['batch_size'],
)
model = LSTM_model(encoder, decoder).to(device)



# Training

In [None]:
def train(model, dataloader, optimizer, criterion, clip, scaler=None):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: Module called as ``model(src, trg)`` returning logits of shape
            (batch, trg_len, output_dim).
        dataloader: Iterable of ``(src, trg)`` batches, both (batch, seq_len).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking (N, output_dim) logits and (N,) targets.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        scaler: Optional ``torch.cuda.amp.GradScaler`` to reuse across epochs.
            When omitted, one is created for this call (active only on CUDA).

    Returns:
        Sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()

    # Follow the model's own device instead of a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')
    use_amp = run_device.type == 'cuda'

    # float16 autocast needs loss scaling, otherwise small gradients underflow.
    if scaler is None:
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    epoch_loss = 0.0
    num_batches = 0

    for src, trg in dataloader:
        src = src.to(run_device)
        trg = trg.to(run_device)

        optimizer.zero_grad()

        # Only the forward pass and the loss belong inside autocast.
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(src, trg)

            # trg shape: (batch_size, trg_len)
            # output shape: (batch_size, trg_len, output_dim)
            output_dim = output.shape[-1]

            # Position 0 is the decoder's first input and is never predicted.
            output = output[:, 1:].reshape(-1, output_dim)
            trg = trg[:, 1:].reshape(-1)

            loss = criterion(output, trg)

        scaler.scale(loss).backward()

        # Unscale first so the clipping threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train() received a dataloader that yields no batches")

    return epoch_loss / num_batches


In [None]:
def validate(model, dataloader, criterion):
    """Evaluate ``model`` without teacher forcing and return the mean batch loss.

    Args:
        model: Module called as ``model(src, trg, 0)`` returning logits of
            shape (batch, trg_len, output_dim).
        dataloader: Iterable of ``(src, trg)`` batches, both (batch, seq_len).
        criterion: Loss taking (N, output_dim) logits and (N,) targets.

    Returns:
        Sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    # Follow the model's own device instead of a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    epoch_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for src, trg in dataloader:
            src = src.to(run_device)
            trg = trg.to(run_device)

            output = model(src, trg, 0)  # turn off teacher forcing
            # trg shape: (batch_size, trg_len)
            # output shape: (batch_size, trg_len, output_dim)

            output_dim = output.shape[-1]
            # Skip time step 0 along the sequence axis, exactly as train() does.
            # Slicing [1:] on the first axis dropped a whole sample and scored
            # the never-predicted first position.
            output = output[:, 1:].reshape(-1, output_dim)
            trg = trg[:, 1:].reshape(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validate() received a dataloader that yields no batches")

    return epoch_loss / num_batches


In [None]:
import torch.optim as optim

# Padding positions carry no signal, so they are excluded from the loss. The
# id is read from the tokenizer instead of repeating the literal 0.
criterion = nn.CrossEntropyLoss(ignore_index=simple_tokenizer.vocab[simple_tokenizer.pad_token])
# Use the learning rate declared in `config`; it was previously defined but
# never passed, so Adam silently ran with its library default.
optimizer = optim.Adam(model.parameters(), lr=config["lr"])

In [None]:
from tqdm.auto import tqdm

N_EPOCHS = 10
CLIP = 1

# Lowest validation loss seen so far; starting at +inf means any finite
# first-epoch loss counts as an improvement.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    # One pass over the training batches, then one over the validation batches.
    train_loss = train(model, train_dataloader, optimizer, criterion, CLIP)
    valid_loss = validate(model, val_dataloader, criterion)

    # Keep on disk only the weights of the best-scoring epoch so far.
    if best_valid_loss > valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-model.pt')

    print(f'Epoch: {epoch+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f}')