In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
import os 
import requests
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
class ByteEmbedding(nn.Module):
    def __init__(self,d_model,hash_size):
        super().__init__()
        self.byte_embed = nn.Embedding(256,d_model)
        self.hash_embed = nn.Embedding(hash_size,d_model)

    def forward(self,byte_seq,hash_seq):
        byte_embedding = self.byte_embed(byte_seq)
        hash_embedding = self.hash_embed(hash_seq)
        return byte_embedding + hash_embedding

In [3]:
class FeedForwardLayer(nn.Module):
    def __init__(self,d_model,ff_dim,dropout):
        super().__init__()
        self.layer1 = nn.Linear(d_model,ff_dim)
        self.layer2 = nn.Linear(ff_dim,d_model)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        return self.layer2(self.dropout(self.gelu(self.layer1(x))))

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self,d_model,n_heads,ff_dim,dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model,n_heads,dropout=dropout)
        self.ff = FeedForwardLayer(d_model,ff_dim,dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,attn_mask = None):
        attn_out,_ = self.attention(x,x,x,attn_mask=attn_mask) 
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        ff_out = self.ff(x)
        x = x + self.dropout(attn_out)
        return self.norm2(x)

In [5]:
class CrossAttentionBlock(nn.Module):
    def __init__(self,query_dim,key_dim,n_heads,ff_dim,dropout):
        super().__init__()
        self.attention = nn.MultiheadAttention(query_dim,n_heads,dropout=dropout)
        self.norm = nn.LayerNorm(query_dim)
        self.ff = FeedForwardLayer(query_dim,ff_dim,dropout = dropout)
        self.dropout = nn.Dropout(dropout)
        self.query_proj = nn.Linear(query_dim,query_dim)
        self.key_proj = nn.Linear(key_dim,query_dim)
        self.value_proj = nn.Linear(key_dim,query_dim)

    def forward(self,query,key,value):
        query = self.query_proj(query).permute(1,0,2)
        key = self.key_proj(key).permute(1,0,2)
        value = self.value_proj(value).permute(1,0,2)

        attn_out , _ = self.attention(query,key,value)
        attn_out = attn_out.permute(1,0,2)
        query = query.permute(1,0,2)
        query = query + self.dropout(attn_out)
        query = self.norm(query)
        ff_out = self.ff(query)
        return query + self.dropout(ff_out)

In [6]:
class LocalEncoder(nn.Module):
    def __init__(self,byte_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim,key_dim=byte_dim,n_heads=n_heads,ff_dim=ff_dim,dropout=dropout)

    def forward(self,byte_embeddings,patch_embeddings):
        for layer in self.layers:
            byte_embeddings = layer(byte_embeddings)
        patch_embedding = self.cross_attn(patch_embeddings,byte_embeddings,byte_embeddings)
        return patch_embedding
        

In [7]:
class LatentGlobalTransformer(nn.Module):
    def __init__(self,patch_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(patch_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
    def forward(self,patches,attn_mask=None):
        for layer in self.layers:
            patches = layer(patches,attn_mask=attn_mask)
        return patches

In [8]:
class LocalDecoder(nn.Module):
    def __init__(self,patch_dim,byte_dim,n_heads,ff_dim,n_layers,dropout):
        super().__init__()
        self.layers = nn.ModuleList([TransformerBlock(byte_dim,n_heads,ff_dim,dropout) for _ in range(n_layers)])
        self.cross_attn = CrossAttentionBlock(query_dim=byte_dim,key_dim=patch_dim,n_heads=n_heads,ff_dim=ff_dim,dropout=dropout)
        self.output_proj = nn.Linear(byte_dim,256)

    def forward(self,patch_embedding,byte_embedding):
        byte_embedding = self.cross_attn(byte_embedding,patch_embedding,patch_embedding)
        for layer in self.layers:
            byte_embedding = layer(byte_embedding)
        return self.output_proj(byte_embedding)

In [9]:
class ByteLatentTransformer(nn.Module):
    def __init__(self,byte_dim,patch_dim,vocab_size,n_heads,ff_dim,n_encoder,n_decoder,n_global,dropout=0.1):
        super().__init__()
        self.byte_embed = ByteEmbedding(byte_dim,vocab_size)
        self.local_encoder = LocalEncoder(byte_dim,n_heads,ff_dim,n_layers=n_encoder,dropout=dropout)
        self.global_transformer = LatentGlobalTransformer(patch_dim,n_heads,ff_dim,n_layers=n_global,dropout=dropout)
        self.local_decoder = LocalDecoder(patch_dim,byte_dim,n_heads,ff_dim,n_decoder,dropout=dropout) 
        self.projection = nn.Linear(byte_dim,patch_dim)

    def forward(self,byte_seq,hash_seq,patch_seq):
        byte_embeddings = self.byte_embed(byte_seq,hash_seq)
        if patch_seq is None:
            patch_embeddings = torch.mean(byte_embeddings,dim=1,keepdim=True)
            patch_embeddings = self.local_encoder(byte_embeddings,patch_embeddings)
            patch_embeddings = self.projection(patch_embeddings)
        else:
            patch_embeddings = patch_seq
        patch_embeddings = self.global_transformer(patch_embeddings)
        byte_output = self.local_decoder(patch_embeddings,byte_embeddings)
        return byte_output

In [10]:
print(torch.cuda.is_available())

True


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import os
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt

# -----------------------
# 1. Load and Prepare Data
# -----------------------
wiki = load_dataset("wikitext", "wikitext-2-raw-v1")
# Use a small subset for quicker training; adjust these as needed.
train_size = int(len(wiki["train"]) * 0.1)
val_size = int(len(wiki["validation"]) * 0.1)

# Shuffle and select only the subset
train_subset = wiki["train"].shuffle(seed=42).select(range(train_size))
val_subset = wiki["validation"].shuffle(seed=42).select(range(val_size))

# Use the subset texts instead of the full dataset
train_text = "\n".join(train_subset["text"])
val_text = "\n".join(val_subset["text"])

# Convert text to a tensor of raw byte values using UTF-8 encoding
train_bytes = torch.tensor(list(train_text.encode('utf-8')), dtype=torch.long)
val_bytes = torch.tensor(list(val_text.encode('utf-8')), dtype=torch.long)

# -----------------------
# 2. Create a Dataset Class with Deterministic Hashing
# -----------------------
class WikiTextDataset(Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        return len(self.data) - self.seq_len

    def deterministic_hash(self, input_seq):
        # A simple deterministic hash: multiply by 31, add 17, modulo 256.
        return (input_seq * 31 + 17) % 256

    def __getitem__(self, idx):
        input_seq = self.data[idx:idx + self.seq_len]
        target_seq = self.data[idx + 1:idx + self.seq_len + 1]
        # Compute hash deterministically from the input sequence
        hash_seq = self.deterministic_hash(input_seq)
        return input_seq, hash_seq, target_seq

seq_len = 128
batch_size = 64

train_dataset = WikiTextDataset(train_bytes, seq_len)
val_dataset = WikiTextDataset(val_bytes, seq_len)

# Increase number of workers and enable pin_memory for faster data loading (if using GPU)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

byte_dim = 128
patch_dim = 256
vocab_size = 256  # Bytes: 0-255
n_heads = 8
ff_dim = 1024
n_encoder = 4
n_decoder = 4
n_global = 6
dropout = 0.1
epochs = 100
learning_rate = 1e-4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ByteLatentTransformer(
    byte_dim=byte_dim,
    patch_dim=patch_dim,
    vocab_size=vocab_size,
    n_heads=n_heads,
    ff_dim=ff_dim,
    n_encoder=n_encoder,
    n_decoder=n_decoder,
    n_global=n_global,
    dropout=dropout
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)

# -----------------------
# 4. Checkpoint Utilities
# -----------------------
def save_checkpoint(epoch, model, optimizer, train_loss, val_loss, best_val_loss, filepath="checkpoint.pth"):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'best_val_loss': best_val_loss
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved at epoch {epoch}!")

def load_checkpoint(filepath="checkpoint.pth"):
    if os.path.exists(filepath):
        checkpoint = torch.load(filepath, map_location=device)
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}!")
        return checkpoint
    else:
        print("No checkpoint found!")
        return None

# -----------------------
# 5. Helper Functions: Plot Loss and Generate Text with Temperature Sampling
# -----------------------
def plot_loss(train_losses, val_losses):
    plt.figure(figsize=(10,6))
    plt.plot(range(1, len(train_losses)+1), train_losses, label='Train Loss', marker='o')
    plt.plot(range(1, len(val_losses)+1), val_losses, label='Val Loss', marker='o')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plt.savefig('loss_plot.png')
    plt.show()

def generate_text_for_sample(model, input_seq, device, length, temperature=1.0):
    model.eval()
    # Compute deterministic hash for the input sequence
    hash_seq = (input_seq * 31 + 17) % 256
    generated_text = "".join([chr(x.item()) for x in input_seq[0]])
    for _ in range(length):
        with torch.no_grad():
            output = model(input_seq, hash_seq, None)
            logits = output[0, -1] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_byte = torch.multinomial(probs, num_samples=1).item()
            generated_text += chr(next_byte)
            input_seq = torch.cat([input_seq[:, 1:], torch.tensor([[next_byte]], device=device)], dim=1)
            hash_seq = (input_seq * 31 + 17) % 256
    return generated_text

# -----------------------
# 6. Training Loop with Mixed Precision, Early Stopping, Checkpointing, and Logging
# -----------------------
train_losses = []
val_losses = []
early_stop_patience = 5
epochs_no_improve = 0

def train_model():
    global train_losses, val_losses, epochs_no_improve
    best_val_loss = float('inf')
    start_epoch = 0
    checkpoint = load_checkpoint()
    if checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        train_losses = checkpoint['train_loss']
        val_losses = checkpoint['val_loss']
        best_val_loss = checkpoint['best_val_loss']
        start_epoch = checkpoint['epoch']
    
    scaler = torch.amp.GradScaler()
    
    for epoch in range(start_epoch, epochs):
        model.train()
        tot_loss = 0
        print(f"Epoch [{epoch+1}/{epochs}]")
        for i, (byte_seq, hash_seq, target_seq) in enumerate(tqdm(train_loader, desc="Training")):
            byte_seq = byte_seq.to(device, non_blocking=True)
            hash_seq = hash_seq.to(device, non_blocking=True)
            target_seq = target_seq.to(device, non_blocking=True)
            
            optimizer.zero_grad()
            with torch.amp.autocast(device_type="cuda"):
                outputs = model(byte_seq, hash_seq, None)
                if i == 0:
                    print("Logits stats: mean: {:.4f}, std: {:.4f}, min: {:.4f}, max: {:.4f}".format(
                        outputs.mean().item(), outputs.std().item(), outputs.min().item(), outputs.max().item()))
                outputs = outputs.view(-1, vocab_size)
                target_seq = target_seq.view(-1)
                loss = criterion(outputs, target_seq)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            tot_loss += loss.item()
        
        avg_train_loss = tot_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        model.eval()
        tot_val_loss = 0
        with torch.no_grad():
            for byte_seq, hash_seq, target_seq in tqdm(val_loader, desc="Validation"):
                byte_seq = byte_seq.to(device, non_blocking=True)
                hash_seq = hash_seq.to(device, non_blocking=True)
                target_seq = target_seq.to(device, non_blocking=True)
                with torch.amp.autocast(device_type="cuda"):
                    outputs = model(byte_seq, hash_seq, None)
                    outputs = outputs.view(-1, vocab_size)
                    target_seq = target_seq.view(-1)
                    l = criterion(outputs, target_seq)
                tot_val_loss += l.item()
        avg_val_loss = tot_val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "BLT_wikitext.pth")
            print("Best model saved!")
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss for {epochs_no_improve} epochs.")
            if epochs_no_improve >= early_stop_patience:
                print("Early stopping triggered.")
                break

        save_checkpoint(epoch+1, model, optimizer, train_losses, val_losses, best_val_loss)
        print("Sample Generated Text:")
        sample_prompt = "The "  # A common prompt in Wiki text
        sample_input = torch.tensor([ord(c) for c in sample_prompt], dtype=torch.long).unsqueeze(0).to(device)
        print(generate_text_for_sample(model, sample_input, device, length=seq_len, temperature=0.8))
    
    plot_loss(train_losses, val_losses)

# Start training
train_model()


  checkpoint = torch.load(filepath, map_location=device)


Checkpoint loaded from epoch 40!
Epoch [41/100]


Training:   0%|                               | 1/17428 [00:00<31:45,  9.15it/s]

Logits stats: mean: -9.6875, std: 6.9922, min: -25.2188, max: 13.4062


Training: 100%|███████████████████████████| 17428/17428 [09:53<00:00, 29.35it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.65it/s]


Epoch [41/100], Train Loss: 2.2772, Validation Loss: 2.2742
Best model saved!
Checkpoint saved at epoch 41!
Sample Generated Text:
The Chiffomum mm the the hanlitt To Tued . .. s . . Intanson d wodowat we . tort oo alllentititititrilices womenthe t ata aieaional 
Epoch [42/100]


Training:   0%|                               | 1/17428 [00:00<30:20,  9.58it/s]

Logits stats: mean: -8.1016, std: 5.8359, min: -23.7656, max: 12.8359


Training: 100%|███████████████████████████| 17428/17428 [09:51<00:00, 29.46it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.83it/s]


Epoch [42/100], Train Loss: 2.2548, Validation Loss: 2.2623
Best model saved!
Checkpoint saved at epoch 42!
Sample Generated Text:
The The Thererrerarererererelll rirrerr . . . . w . . . . . aryaravercon intons sh hivivaravian ona ooooos ass , , , , , , , , s , ,
Epoch [43/100]


Training:   0%|                               | 1/17428 [00:00<29:54,  9.71it/s]

Logits stats: mean: -8.9766, std: 6.5781, min: -26.0938, max: 12.7578


Training: 100%|███████████████████████████| 17428/17428 [09:50<00:00, 29.52it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.68it/s]


Epoch [43/100], Train Loss: 2.2441, Validation Loss: 2.2546
Best model saved!
Checkpoint saved at epoch 43!
Sample Generated Text:
The â f S n hind inirincer ar llllilalie ithecall il . a . . . g . .. . .. . . . . .. .. . .. .. . . ilalllilll S offiffaimiaserer
Epoch [44/100]


Training:   0%|                               | 1/17428 [00:00<30:04,  9.66it/s]

Logits stats: mean: -9.5156, std: 7.0742, min: -27.8281, max: 13.8359


Training: 100%|███████████████████████████| 17428/17428 [09:50<00:00, 29.51it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.65it/s]


Epoch [44/100], Train Loss: 2.2363, Validation Loss: 2.2488
Best model saved!
Checkpoint saved at epoch 44!
Sample Generated Text:
The Th arprlartiorourk w .. ... " " ' " um ..... .. ... V. . ... d . hith akeatty . ranerankand an atyanten t . ai alil . ne , rr 2 
Epoch [45/100]


Training:   0%|                               | 1/17428 [00:00<30:28,  9.53it/s]

Logits stats: mean: -9.6094, std: 7.1680, min: -27.5469, max: 13.7344


Training: 100%|███████████████████████████| 17428/17428 [09:56<00:00, 29.22it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.35it/s]


Epoch [45/100], Train Loss: 2.2301, Validation Loss: 2.2441
Best model saved!
Checkpoint saved at epoch 45!
Sample Generated Text:
The Thel e is his his hos s he A A As intitit @ o a a I arear IIrrer r rorid is liler Fro , , , , , , Hie . . oo rr . ne ito rolonan
Epoch [46/100]


Training:   0%|                               | 1/17428 [00:00<29:51,  9.73it/s]

Logits stats: mean: -10.0156, std: 7.4414, min: -28.6094, max: 13.4609


Training: 100%|███████████████████████████| 17428/17428 [13:58<00:00, 20.79it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 90.24it/s]


Epoch [46/100], Train Loss: 2.2250, Validation Loss: 2.2409
Best model saved!
Checkpoint saved at epoch 46!
Sample Generated Text:
The ve D Dy Dy Dy Dat Damemmommm a is was wepphend an an anntrnorro Ar Arerr rrrerarerererrext e ppphelale aveven : f . . wawecicurr
Epoch [47/100]


Training:   0%|                               | 1/17428 [00:00<30:02,  9.67it/s]

Logits stats: mean: -10.1719, std: 7.6367, min: -29.9062, max: 14.7578


Training: 100%|███████████████████████████| 17428/17428 [09:48<00:00, 29.59it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.78it/s]


Epoch [47/100], Train Loss: 2.2205, Validation Loss: 2.2382
Best model saved!
Checkpoint saved at epoch 47!
Sample Generated Text:
The ... fotulolearenanionanan Arerorertarorory . ve a a a  ny uno . .. .......... . a . . ... ! cod ted tiditigitigoginon al . . . .
Epoch [48/100]


Training:   0%|                               | 1/17428 [00:00<31:03,  9.35it/s]

Logits stats: mean: -10.4219, std: 7.7500, min: -30.2969, max: 14.1328


Training: 100%|███████████████████████████| 17428/17428 [09:48<00:00, 29.62it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 90.09it/s]


Epoch [48/100], Train Loss: 2.2165, Validation Loss: 2.2364
Best model saved!
Checkpoint saved at epoch 48!
Sample Generated Text:
The e e t it itistis isisitee t e t t , , , , , , , , , , , , , , , , , , , , , , , , , , , t , , , , , , Gan annanntanitenonenseres
Epoch [49/100]


Training:   0%|                               | 1/17428 [00:00<30:07,  9.64it/s]

Logits stats: mean: -10.4844, std: 7.8438, min: -29.8750, max: 15.6250


Training: 100%|███████████████████████████| 17428/17428 [09:48<00:00, 29.59it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 90.32it/s]


Epoch [49/100], Train Loss: 2.2129, Validation Loss: 2.2335
Best model saved!
Checkpoint saved at epoch 49!
Sample Generated Text:
The Jaracivalisooof ffe fffoffffforyoeriss seessstand andalatentrenerenerurintratrt mammmm 2 2 2 2 2 2 2 2 2 2 in ion ion ion onecon
Epoch [50/100]


Training:   0%|                               | 1/17428 [00:00<29:52,  9.72it/s]

Logits stats: mean: -10.7422, std: 7.9336, min: -31.3281, max: 16.2031


Training: 100%|███████████████████████████| 17428/17428 [09:48<00:00, 29.61it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 90.34it/s]


Epoch [50/100], Train Loss: 2.2096, Validation Loss: 2.2326
Best model saved!
Checkpoint saved at epoch 50!
Sample Generated Text:
The Tagange in it is a asasalal a a a avedededed ar . .. Hotontal Sisi s isindsunun in Mant ontonononttint at t t . 2 is ic a cas s 
Epoch [51/100]


Training:   0%|                               | 1/17428 [00:00<30:09,  9.63it/s]

Logits stats: mean: -10.7344, std: 7.9922, min: -31.5312, max: 15.5312


Training: 100%|███████████████████████████| 17428/17428 [09:54<00:00, 29.33it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 88.78it/s]


Epoch [51/100], Train Loss: 2.2065, Validation Loss: 2.2327
No improvement in validation loss for 1 epochs.
Checkpoint saved at epoch 51!
Sample Generated Text:
The Tenianalalusastanstrtst , , , , , , , , , , , , , , , , , , , , , , , , , 4 , , , , , , , , , , , , , , , , , , , , , , , , , Ar
Epoch [52/100]


Training:   0%|                               | 1/17428 [00:00<30:38,  9.48it/s]

Logits stats: mean: -10.8984, std: 8.0547, min: -32.3750, max: 14.8516


Training: 100%|███████████████████████████| 17428/17428 [09:48<00:00, 29.60it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.96it/s]


Epoch [52/100], Train Loss: 2.2036, Validation Loss: 2.2310
Best model saved!
Checkpoint saved at epoch 52!
Sample Generated Text:
The thend 1 ime Gun an aay Tays anend 1 is Assenininianinintent , , , , , , , , , , , , , , , , , La Le Lit Lilsiol rrllillellllllll
Epoch [53/100]


Training:   0%|                               | 1/17428 [00:00<30:34,  9.50it/s]

Logits stats: mean: -11.1406, std: 8.1484, min: -31.3281, max: 15.5234


Training: 100%|███████████████████████████| 17428/17428 [09:48<00:00, 29.60it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 90.15it/s]


Epoch [53/100], Train Loss: 2.2008, Validation Loss: 2.2304
Best model saved!
Checkpoint saved at epoch 53!
Sample Generated Text:
The Theranconcon an anon Feantiton a isossthesh as t ttt o hoon one on ion iomalal d 7 7 vedieiortin int 3th al alllllillilllllallll
Epoch [54/100]


Training:   0%|                               | 1/17428 [00:00<31:22,  9.26it/s]

Logits stats: mean: -11.2734, std: 8.2734, min: -33.8125, max: 15.0391


Training: 100%|███████████████████████████| 17428/17428 [09:49<00:00, 29.58it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 90.14it/s]


Epoch [54/100], Train Loss: 2.1981, Validation Loss: 2.2302
Best model saved!
Checkpoint saved at epoch 54!
Sample Generated Text:
The Tas s aseappoppland and appppppppppppplas sesererbeverererss ffipepentte oteseogatenturrry r a . fendfrdrefe Sos Scrd ct ctiel i
Epoch [55/100]


Training:   0%|                               | 1/17428 [00:00<31:00,  9.37it/s]

Logits stats: mean: -11.3125, std: 8.3516, min: -32.4375, max: 15.5156


Training: 100%|███████████████████████████| 17428/17428 [09:49<00:00, 29.56it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.71it/s]


Epoch [55/100], Train Loss: 2.1955, Validation Loss: 2.2297
Best model saved!
Checkpoint saved at epoch 55!
Sample Generated Text:
The Tund and an a a isissuss  s  ss as on to t o to ove , , , , , , , , , , , hoone anamime . . at avivovinvintstat a l gengagagagag
Epoch [56/100]


Training:   0%|                               | 1/17428 [00:00<31:17,  9.28it/s]

Logits stats: mean: -11.3438, std: 8.3359, min: -33.8125, max: 17.1406


Training: 100%|███████████████████████████| 17428/17428 [12:23<00:00, 23.43it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:43<00:00, 42.12it/s]


Epoch [56/100], Train Loss: 2.1929, Validation Loss: 2.2301
No improvement in validation loss for 1 epochs.
Checkpoint saved at epoch 56!
Sample Generated Text:
The TBTis TuTuTuTTecatict ito t t t it , , , , , , , , , , , , , , , , , , , , , , , hoochonstt atr atrrarith the ....... we awawami
Epoch [57/100]


Training:   0%|                               | 1/17428 [00:00<42:49,  6.78it/s]

Logits stats: mean: -11.5156, std: 8.4375, min: -35.1875, max: 16.3125


Training: 100%|███████████████████████████| 17428/17428 [11:17<00:00, 25.72it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 88.29it/s]


Epoch [57/100], Train Loss: 2.1904, Validation Loss: 2.2292
Best model saved!
Checkpoint saved at epoch 57!
Sample Generated Text:
The Thershosholullllillinginin int at  atsisthishieshisuniananenenananann 2 2 22 2 wan a a a a ad s a heathers ssiloll tlllllllollll
Epoch [58/100]


Training:   0%|                               | 1/17428 [00:00<31:06,  9.33it/s]

Logits stats: mean: -11.5703, std: 8.4922, min: -32.6562, max: 19.2031


Training: 100%|███████████████████████████| 17428/17428 [10:12<00:00, 28.46it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.81it/s]


Epoch [58/100], Train Loss: 2.1881, Validation Loss: 2.2299
No improvement in validation loss for 1 epochs.
Checkpoint saved at epoch 58!
Sample Generated Text:
The Thennncosoblory any " " " " " " " " " " sitat ' ' ' ' ' ' a a at at " " " " " " " " " " " " " " " " " " " " " " " " " Wel Wilill
Epoch [59/100]


Training:   0%|                               | 1/17428 [00:00<30:28,  9.53it/s]

Logits stats: mean: -11.5234, std: 8.4688, min: -32.9688, max: 17.4219


Training: 100%|███████████████████████████| 17428/17428 [09:49<00:00, 29.55it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.22it/s]


Epoch [59/100], Train Loss: 2.1858, Validation Loss: 2.2299
No improvement in validation loss for 2 epochs.
Checkpoint saved at epoch 59!
Sample Generated Text:
The TT Ta Tas sasisissssys : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : : , : Lilerrerer rnd and and and 
Epoch [60/100]


Training:   0%|                               | 1/17428 [00:00<32:55,  8.82it/s]

Logits stats: mean: -11.7656, std: 8.6016, min: -33.6875, max: 16.1875


Training: 100%|███████████████████████████| 17428/17428 [10:13<00:00, 28.42it/s]
Validation: 100%|███████████████████████████| 1843/1843 [00:20<00:00, 89.83it/s]


Epoch [60/100], Train Loss: 2.1835, Validation Loss: 2.2300
No improvement in validation loss for 3 epochs.
Checkpoint saved at epoch 60!
Sample Generated Text:
The The Thelin ien iondiong angannd III IIIII II I I I I w I I II II I I I I I Ithata al all annnnennntrntrrerroroly onthen . ft ati
Epoch [61/100]


Training:   0%|                               | 1/17428 [00:00<31:02,  9.36it/s]

Logits stats: mean: -11.6875, std: 8.5469, min: -34.7188, max: 18.0938


Training:  69%|██████████████████▌        | 11981/17428 [07:08<03:13, 28.09it/s]