In [1]:
from transformer import Transformer
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from collections import Counter
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers

In [2]:
with open('input.txt', 'r') as f:
    text = f.read()

In [3]:
# Initialize and train BPE tokenizer
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
trainer = trainers.BpeTrainer(special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.train(files=["input.txt"], trainer=trainer)
tokenizer.decoder = decoders.WordPiece()

# Tokenize the text using BPE tokenizer
tokenized_text = tokenizer.encode(text).tokens

# Create a token-based vocabulary
token_counts = Counter(tokenized_text)
tokens = sorted(token_counts.keys())
token_to_idx = {token: idx for idx, token in enumerate(tokens)}
idx_to_token = {idx: token for token, idx in token_to_idx.items()}

# Convert the tokenized text into numerical data
data = [token_to_idx[token] for token in tokenized_text if token in token_to_idx]

# Hyperparameters
BATCH_SIZE = 32
BLOCK_SIZE = 128

# Create batches
num_batches = len(data) // (BATCH_SIZE * BLOCK_SIZE)
data = data[:num_batches * BATCH_SIZE * BLOCK_SIZE]
data_batches = torch.tensor(data).view(BATCH_SIZE, -1)

# train/val split
train_batches = int(0.9 * num_batches)
train_data = data_batches[:, :train_batches * BLOCK_SIZE]
val_data = data_batches[:, train_batches * BLOCK_SIZE:]

In [4]:
# Parameters
src_token_size = len(tokens)
tgt_token_size = len(tokens)
d_model = 512
h = 8
d_ff = 2048
num_layers = 6
dropout = 0.2
max_len = len(tokens)

# Create the model
model = Transformer(src_token_size, tgt_token_size, d_model, h, d_ff, num_layers, dropout, max_len)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.device = device

In [6]:
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

NUM_EPOCHS = 20
WARMUP_EPOCHS = 10
INITIAL_LR = 1e-6  

# Warmup function
def warmup_lr_scheduler(epoch, optimizer):
    if epoch < WARMUP_EPOCHS:
        lr = INITIAL_LR + (1e-4 - INITIAL_LR) * (epoch / WARMUP_EPOCHS)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        scheduler.step()

for epoch in range(NUM_EPOCHS):
    warmup_lr_scheduler(epoch, optimizer)
    
    model.train()
    total_train_loss = 0
    
    # Compute the number of training batches
    train_batches = (train_data.size(1) - BLOCK_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE
    
    # Training
    train_loop = tqdm(range(0, train_data.size(1) - BLOCK_SIZE, BLOCK_SIZE), total=train_batches, leave=False)
    for i in train_loop:
        inputs = train_data[:, i:i+BLOCK_SIZE].to(device)
        targets = train_data[:, i+1:i+1+BLOCK_SIZE].to(device)
        
        optimizer.zero_grad()
        
        outputs = model(inputs, inputs)
        loss = criterion(outputs.view(-1, tgt_token_size), targets.view(-1))
        total_train_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        
        train_loop.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        train_loop.set_postfix(train_loss=total_train_loss/(i//BLOCK_SIZE + 1))
    
    # Compute the number of validation batches
    val_batches = (val_data.size(1) - BLOCK_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE
    
    # Validation
    model.eval()
    total_val_loss = 0
    val_loop = tqdm(range(0, val_data.size(1) - BLOCK_SIZE, BLOCK_SIZE), total=val_batches, leave=False)
    with torch.no_grad():
        for i in val_loop:
            inputs = val_data[:, i:i+BLOCK_SIZE].to(device)
            targets = val_data[:, i+1:i+1+BLOCK_SIZE].to(device)
            
            outputs = model(inputs, inputs)
            loss = criterion(outputs.view(-1, tgt_token_size), targets.view(-1))
            total_val_loss += loss.item()
            
            val_loop.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS} (Validation)")
            val_loop.set_postfix(val_loss=total_val_loss/(i//BLOCK_SIZE + 1))

    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {total_train_loss/train_batches:.4f}, Val Loss: {total_val_loss/val_batches:.4f}")

                                                                                     

Epoch 1/20, Train Loss: 9.4273, Val Loss: 9.1439


                                                                                     

Epoch 2/20, Train Loss: 8.3923, Val Loss: 8.0170


                                                                                     

Epoch 3/20, Train Loss: 7.7498, Val Loss: 7.4128


                                                                                     

Epoch 4/20, Train Loss: 7.1094, Val Loss: 6.8045


                                                                                     

Epoch 5/20, Train Loss: 6.5694, Val Loss: 6.4051


                                                                                     

Epoch 6/20, Train Loss: 6.2423, Val Loss: 6.1540


                                                                                     

Epoch 7/20, Train Loss: 5.9787, Val Loss: 5.9190


                                                                                     

Epoch 8/20, Train Loss: 5.7366, Val Loss: 5.7259


                                                                                     

Epoch 9/20, Train Loss: 5.5345, Val Loss: 5.5903


                                                                                      

Epoch 10/20, Train Loss: 5.3634, Val Loss: 5.4875


                                                                                      

Epoch 11/20, Train Loss: 5.2234, Val Loss: 5.4232


                                                                                      

Epoch 12/20, Train Loss: 5.1075, Val Loss: 5.3758


                                                                                      

Epoch 13/20, Train Loss: 5.0062, Val Loss: 5.3335


                                                                                      

Epoch 14/20, Train Loss: 4.9123, Val Loss: 5.3143


                                                                                      

Epoch 15/20, Train Loss: 4.8303, Val Loss: 5.2886


                                                                                      

Epoch 16/20, Train Loss: 4.7538, Val Loss: 5.2916


                                                                                      

Epoch 17/20, Train Loss: 4.6814, Val Loss: 5.2884


                                                                                      

Epoch 18/20, Train Loss: 4.6090, Val Loss: 5.2736


                                                                                      

Epoch 19/20, Train Loss: 4.5420, Val Loss: 5.2789


                                                                                      

Epoch 20/20, Train Loss: 4.4830, Val Loss: 5.2854




In [14]:
model.eval()

start_text = "We are accounted poor citizens"
start_tokens = tokenizer.encode(start_text).tokens
start_tokens = [token_to_idx[token] for token in start_tokens if token in token_to_idx]
start_tokens = torch.tensor(start_tokens).unsqueeze(0)

generated_text_indices = model.generate(start_tokens.to(device), max_new_tokens=50)

# Convert the generated indices back to tokens
generated_tokens = [idx_to_token[idx] for idx in generated_text_indices[0].cpu().numpy()]

# Convert the tokens into their respective IDs
generated_token_ids = [tokenizer.token_to_id(token) for token in generated_tokens]

# Decode the sequence of IDs
generated_text = tokenizer.decode(generated_token_ids)

print(generated_text)

We are accounted poor citizens by head o Unwieldy passes loins true great all crimes in every deliver in all declining in dark : intercepts in all disorder men in all in all in all alarms of King all Rutland being string in yours pieces short all refuse in all Gave all congealed in in
