Remember to make a copy of this colab notebook before you start editing cells!

In [1]:
!pip install datasets



In [2]:
!pip install tqdm



In [3]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-02-16 03:23:42--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-02-16 03:23:43 (145 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [4]:
# DO NOT MODIFY ANY OF THIS CODE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

In [5]:
# DO NOT MODIFY ANY OF THIS CODE

# Hyperparameters
batch_sz = 16
context_length = 32
max_iterations = 30000
log_interval = 200
init_lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_steps = 200
embedding_dim = 64
num_heads = 4
num_blocks = 4
drop_prob = 0.0

# Load and prepare the data
torch.manual_seed(1337)
with open('input.txt', 'r', encoding='utf-8') as file:
    text_data = file.read()

unique_chars = sorted(set(text_data))
vocab_size = len(unique_chars)
char_to_index = {ch: i for i, ch in enumerate(unique_chars)}
index_to_char = {i: ch for i, ch in enumerate(unique_chars)}

def encode_text(s): return [char_to_index[c] for c in s]
def decode_text(l): return ''.join([index_to_char[i] for i in l])

# Split data for training and validation
data_tensor = torch.tensor(encode_text(text_data), dtype=torch.long)
train_size = int(0.9 * len(data_tensor))
train_data, val_data = data_tensor[:train_size], data_tensor[train_size:]

def generate_batch(split):
    data_src = train_data if split == 'train' else val_data
    indices = torch.randint(0, len(data_src) - context_length, (batch_sz,))
    inputs = torch.stack([data_src[i:i + context_length] for i in indices])
    targets = torch.stack([data_src[i + 1:i + context_length + 1] for i in indices])
    return inputs.to(device), targets.to(device)

@torch.no_grad()
def evaluate_loss():
    model.eval()
    losses = {'train': [], 'val': []}
    for split in ['train', 'val']:
        for _ in range(eval_steps):
            batch_x, batch_y = generate_batch(split)
            _, batch_loss = model(batch_x, batch_y)
            losses[split].append(batch_loss.item())
    model.train()
    return {split: torch.tensor(losses[split]).mean().item() for split in losses}

In [14]:
emb_dim // num_heads

NameError: name 'emb_dim' is not defined

In [29]:
# YOU WILL CHANGE CODE IN THIS CELL
# Implement a Transformer model with PyTorch. Fill out the provided skeleton.

class SelfAttention(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        # TODO: initialize key, query, and value as linear layers. Set bias=False
        self.key_proj = nn.Linear(embedding_dim, head_dim, bias = False)
        self.query_proj = nn.Linear(embedding_dim, head_dim, bias = False)
        self.value_proj = nn.Linear(embedding_dim, head_dim, bias = False)
        self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        B, T, C = x.shape
        keys = self.key_proj(x)
        queries = self.query_proj(x)
        values = self.value_proj(x)
        scores = (queries @ keys.transpose(-2, -1)) * (C ** -0.5)
        scores = scores.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
        # TODO: apply softmax and dropout
        attention_weights = F.softmax(scores, dim = -1)
        attention_weights = self.dropout(attention_weights)
        return attention_weights @ values

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttention(head_dim) for _ in range(num_heads)])
        self.output_proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        # TODO: combine multiple attention heads
        x = torch.cat([h(x) for h in self.heads], dim = -1)
        return self.dropout(self.output_proj(x))

class FeedForwardLayer(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.ReLU(),
            nn.Linear(4 * emb_dim, emb_dim),
            nn.Dropout(drop_prob)
        )

    def forward(self, x):
        return self.layers(x)

class TransformerBlock(nn.Module):
    def __init__(self, emb_dim, num_heads):
        super().__init__()
        # TODO: initialize the multihead self attention, feed forward layer, and two layernorms
        self.attention = MultiHeadSelfAttention(num_heads, emb_dim // num_heads)
        self.feed_forward = FeedForwardLayer(emb_dim)
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, x):
        # TODO: implement the forward logic
        # Including the ResNet NN
        x = x + self.attention(self.norm1(x))
        x = x + self.feed_forward(self.norm2(x))
        return x

class TransformerLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embeddings = nn.Embedding(context_length, embedding_dim)
        self.transformer_blocks = nn.Sequential(*[TransformerBlock(emb_dim = embedding_dim, num_heads=num_heads) for _ in range(num_blocks)])
        self.final_norm = nn.LayerNorm(embedding_dim)
        self.head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embeddings(idx)
        pos_emb = self.position_embeddings(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.final_norm(self.transformer_blocks(x))
        logits = self.head(x)


        if targets is None:
            return logits, None

        # Including channels as outputs
        B, T, C = logits.shape

        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate_text(self, idx, max_tokens):
        for _ in range(max_tokens):
            idx_cond = idx[:, -context_length:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

In [30]:
# DO NOT MODIFY ANY OF THIS CODE

# Initialize and train the model
model = TransformerLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=init_lr)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
best_val_loss = float('inf')
no_progress = 0
max_patience = 10

for step in tqdm(range(max_iterations)):
    if step % log_interval == 0 or step == max_iterations - 1:
        current_losses = evaluate_loss()
        print(f"Step {step}: train loss {current_losses['train']:.4f}, val loss {current_losses['val']:.4f}")
        scheduler.step(current_losses['val'])

        if current_losses['val'] < best_val_loss:
            best_val_loss = current_losses['val']
            no_progress = 0
        else:
            no_progress += 1

    batch_x, batch_y = generate_batch('train')
    _, batch_loss = model(batch_x, batch_y)
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

start_context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_output = decode_text(model.generate_text(start_context, max_tokens=1000)[0].tolist())

with open("output.txt", "w", encoding="utf-8") as out_file:
    out_file.write(generated_output)

  0%|          | 6/30000 [00:03<3:14:28,  2.57it/s] 

Step 0: train loss 4.4243, val loss 4.4220


  1%|          | 208/30000 [00:10<1:12:53,  6.81it/s]

Step 200: train loss 2.5089, val loss 2.5157


  1%|▏         | 408/30000 [00:18<1:07:45,  7.28it/s]

Step 400: train loss 2.3548, val loss 2.3621


  2%|▏         | 607/30000 [00:26<1:25:38,  5.72it/s]

Step 600: train loss 2.2647, val loss 2.2819


  3%|▎         | 806/30000 [00:34<1:06:07,  7.36it/s]

Step 800: train loss 2.1909, val loss 2.2194


  3%|▎         | 1005/30000 [00:42<1:29:41,  5.39it/s]

Step 1000: train loss 2.1274, val loss 2.1583


  4%|▍         | 1207/30000 [00:50<1:16:54,  6.24it/s]

Step 1200: train loss 2.0646, val loss 2.0989


  5%|▍         | 1404/30000 [00:58<1:33:14,  5.11it/s]

Step 1400: train loss 2.0075, val loss 2.0640


  5%|▌         | 1605/30000 [01:06<1:27:31,  5.41it/s]

Step 1600: train loss 1.9768, val loss 2.0287


  6%|▌         | 1804/30000 [01:14<1:43:26,  4.54it/s]

Step 1800: train loss 1.9385, val loss 2.0197


  7%|▋         | 2008/30000 [01:21<1:03:17,  7.37it/s]

Step 2000: train loss 1.8992, val loss 1.9826


  7%|▋         | 2205/30000 [01:29<1:26:52,  5.33it/s]

Step 2200: train loss 1.8723, val loss 1.9746


  8%|▊         | 2404/30000 [01:37<1:42:39,  4.48it/s]

Step 2400: train loss 1.8349, val loss 1.9389


  9%|▊         | 2604/30000 [01:45<1:23:54,  5.44it/s]

Step 2600: train loss 1.8231, val loss 1.9426


  9%|▉         | 2807/30000 [01:53<1:03:23,  7.15it/s]

Step 2800: train loss 1.7962, val loss 1.9190


 10%|█         | 3005/30000 [02:01<1:39:10,  4.54it/s]

Step 3000: train loss 1.7803, val loss 1.9197


 11%|█         | 3204/30000 [02:08<1:22:46,  5.40it/s]

Step 3200: train loss 1.7727, val loss 1.8989


 11%|█▏        | 3405/30000 [02:16<1:25:30,  5.18it/s]

Step 3400: train loss 1.7691, val loss 1.9124


 12%|█▏        | 3605/30000 [02:24<1:11:28,  6.15it/s]

Step 3600: train loss 1.7373, val loss 1.8789


 13%|█▎        | 3804/30000 [02:32<1:20:45,  5.41it/s]

Step 3800: train loss 1.7182, val loss 1.8889


 13%|█▎        | 4004/30000 [02:40<1:38:43,  4.39it/s]

Step 4000: train loss 1.7044, val loss 1.8606


 14%|█▍        | 4203/30000 [02:48<1:25:30,  5.03it/s]

Step 4200: train loss 1.6925, val loss 1.8489


 15%|█▍        | 4405/30000 [02:56<1:19:21,  5.38it/s]

Step 4400: train loss 1.6912, val loss 1.8601


 15%|█▌        | 4608/30000 [03:04<1:10:40,  5.99it/s]

Step 4600: train loss 1.6826, val loss 1.8546


 16%|█▌        | 4806/30000 [03:11<1:00:48,  6.91it/s]

Step 4800: train loss 1.6792, val loss 1.8424


 17%|█▋        | 5007/30000 [03:20<56:49,  7.33it/s]  

Step 5000: train loss 1.6552, val loss 1.8296


 17%|█▋        | 5206/30000 [03:28<1:11:10,  5.81it/s]

Step 5200: train loss 1.6557, val loss 1.8340


 18%|█▊        | 5405/30000 [03:35<1:15:59,  5.39it/s]

Step 5400: train loss 1.6545, val loss 1.8207


 19%|█▊        | 5605/30000 [03:43<1:14:59,  5.42it/s]

Step 5600: train loss 1.6484, val loss 1.8077


 19%|█▉        | 5808/30000 [03:51<1:05:56,  6.11it/s]

Step 5800: train loss 1.6381, val loss 1.8183


 20%|██        | 6007/30000 [03:59<54:34,  7.33it/s]  

Step 6000: train loss 1.6238, val loss 1.8143


 21%|██        | 6205/30000 [04:07<1:12:35,  5.46it/s]

Step 6200: train loss 1.6259, val loss 1.7862


 21%|██▏       | 6407/30000 [04:15<1:03:46,  6.17it/s]

Step 6400: train loss 1.6335, val loss 1.7913


 22%|██▏       | 6606/30000 [04:22<53:24,  7.30it/s]  

Step 6600: train loss 1.6320, val loss 1.7876


 23%|██▎       | 6807/30000 [04:30<53:25,  7.24it/s]  

Step 6800: train loss 1.6109, val loss 1.7769


 23%|██▎       | 7006/30000 [04:38<1:02:10,  6.16it/s]

Step 7000: train loss 1.6188, val loss 1.7735


 24%|██▍       | 7205/30000 [04:46<1:10:33,  5.38it/s]

Step 7200: train loss 1.5998, val loss 1.7709


 25%|██▍       | 7408/30000 [04:54<52:57,  7.11it/s]  

Step 7400: train loss 1.5988, val loss 1.7692


 25%|██▌       | 7607/30000 [05:02<1:04:16,  5.81it/s]

Step 7600: train loss 1.5882, val loss 1.7742


 26%|██▌       | 7806/30000 [05:10<50:42,  7.29it/s]  

Step 7800: train loss 1.5996, val loss 1.7627


 27%|██▋       | 8007/30000 [05:18<51:15,  7.15it/s]  

Step 8000: train loss 1.5882, val loss 1.7708


 27%|██▋       | 8203/30000 [05:26<1:22:37,  4.40it/s]

Step 8200: train loss 1.5875, val loss 1.7624


 28%|██▊       | 8406/30000 [05:34<49:05,  7.33it/s]  

Step 8400: train loss 1.5829, val loss 1.7616


 29%|██▊       | 8604/30000 [05:42<1:08:23,  5.21it/s]

Step 8600: train loss 1.5800, val loss 1.7483


 29%|██▉       | 8804/30000 [05:50<1:00:38,  5.83it/s]

Step 8800: train loss 1.5677, val loss 1.7690


 30%|███       | 9006/30000 [05:57<47:14,  7.41it/s]  

Step 9000: train loss 1.5653, val loss 1.7626


 31%|███       | 9208/30000 [06:05<54:24,  6.37it/s]  

Step 9200: train loss 1.5585, val loss 1.7424


 31%|███▏      | 9402/30000 [06:13<1:08:32,  5.01it/s]

Step 9400: train loss 1.5683, val loss 1.7390


 32%|███▏      | 9604/30000 [06:21<1:02:40,  5.42it/s]

Step 9600: train loss 1.5558, val loss 1.7485


 33%|███▎      | 9806/30000 [06:29<1:00:26,  5.57it/s]

Step 9800: train loss 1.5580, val loss 1.7215


 33%|███▎      | 10005/30000 [06:37<53:31,  6.23it/s]  

Step 10000: train loss 1.5477, val loss 1.7567


 34%|███▍      | 10208/30000 [06:45<45:50,  7.20it/s]  

Step 10200: train loss 1.5563, val loss 1.7405


 35%|███▍      | 10406/30000 [06:53<56:30,  5.78it/s]  

Step 10400: train loss 1.5442, val loss 1.7199


 35%|███▌      | 10605/30000 [07:00<59:37,  5.42it/s]

Step 10600: train loss 1.5405, val loss 1.7383


 36%|███▌      | 10807/30000 [07:08<43:41,  7.32it/s]

Step 10800: train loss 1.5412, val loss 1.7182


 37%|███▋      | 11007/30000 [07:16<55:35,  5.69it/s]  

Step 11000: train loss 1.5501, val loss 1.7274


 37%|███▋      | 11206/30000 [07:24<43:09,  7.26it/s]

Step 11200: train loss 1.5457, val loss 1.7243


 38%|███▊      | 11404/30000 [07:32<56:59,  5.44it/s]

Step 11400: train loss 1.5412, val loss 1.7232


 39%|███▊      | 11608/30000 [07:40<49:05,  6.24it/s]  

Step 11600: train loss 1.5285, val loss 1.7169


 39%|███▉      | 11806/30000 [07:47<40:22,  7.51it/s]

Step 11800: train loss 1.5342, val loss 1.7376


 40%|████      | 12006/30000 [07:55<40:48,  7.35it/s]

Step 12000: train loss 1.5377, val loss 1.7216


 41%|████      | 12205/30000 [08:03<1:04:29,  4.60it/s]

Step 12200: train loss 1.5376, val loss 1.7092


 41%|████▏     | 12405/30000 [08:11<52:43,  5.56it/s]

Step 12400: train loss 1.5173, val loss 1.7055


 42%|████▏     | 12608/30000 [08:19<39:59,  7.25it/s]

Step 12600: train loss 1.5217, val loss 1.6933


 43%|████▎     | 12803/30000 [08:26<58:39,  4.89it/s]

Step 12800: train loss 1.5339, val loss 1.7008


 43%|████▎     | 13004/30000 [08:34<51:44,  5.48it/s]

Step 13000: train loss 1.5264, val loss 1.7186


 44%|████▍     | 13204/30000 [08:42<1:03:33,  4.40it/s]

Step 13200: train loss 1.5145, val loss 1.7230


 45%|████▍     | 13403/30000 [08:49<52:05,  5.31it/s]

Step 13400: train loss 1.5199, val loss 1.7013


 45%|████▌     | 13605/30000 [08:57<50:49,  5.38it/s]

Step 13600: train loss 1.5242, val loss 1.7159


 46%|████▌     | 13807/30000 [09:05<45:42,  5.90it/s]  

Step 13800: train loss 1.5067, val loss 1.6972


 47%|████▋     | 14006/30000 [09:12<35:25,  7.52it/s]

Step 14000: train loss 1.4840, val loss 1.6936


 47%|████▋     | 14207/30000 [09:20<35:09,  7.49it/s]

Step 14200: train loss 1.4787, val loss 1.6685


 48%|████▊     | 14407/30000 [09:28<41:00,  6.34it/s]

Step 14400: train loss 1.4730, val loss 1.6690


 49%|████▊     | 14605/30000 [09:35<46:52,  5.47it/s]

Step 14600: train loss 1.4760, val loss 1.6748


 49%|████▉     | 14805/30000 [09:43<45:46,  5.53it/s]

Step 14800: train loss 1.4757, val loss 1.6623


 50%|█████     | 15005/30000 [09:51<54:27,  4.59it/s]

Step 15000: train loss 1.4663, val loss 1.6578


 51%|█████     | 15208/30000 [09:59<35:16,  6.99it/s]

Step 15200: train loss 1.4700, val loss 1.6599


 51%|█████▏    | 15406/30000 [10:06<33:22,  7.29it/s]

Step 15400: train loss 1.4685, val loss 1.6587


 52%|█████▏    | 15605/30000 [10:14<48:05,  4.99it/s]

Step 15600: train loss 1.4698, val loss 1.6530


 53%|█████▎    | 15807/30000 [10:22<31:56,  7.40it/s]

Step 15800: train loss 1.4639, val loss 1.6626


 53%|█████▎    | 16006/30000 [10:30<40:43,  5.73it/s]  

Step 16000: train loss 1.4648, val loss 1.6547


 54%|█████▍    | 16205/30000 [10:37<43:10,  5.32it/s]

Step 16200: train loss 1.4623, val loss 1.6480


 55%|█████▍    | 16406/30000 [10:45<30:20,  7.47it/s]

Step 16400: train loss 1.4570, val loss 1.6630


 55%|█████▌    | 16605/30000 [10:53<41:03,  5.44it/s]

Step 16600: train loss 1.4625, val loss 1.6454


 56%|█████▌    | 16805/30000 [11:00<39:39,  5.55it/s]

Step 16800: train loss 1.4608, val loss 1.6581


 57%|█████▋    | 17007/30000 [11:08<28:55,  7.49it/s]

Step 17000: train loss 1.4610, val loss 1.6500


 57%|█████▋    | 17206/30000 [11:16<33:47,  6.31it/s]

Step 17200: train loss 1.4599, val loss 1.6626


 58%|█████▊    | 17406/30000 [11:23<28:04,  7.48it/s]

Step 17400: train loss 1.4580, val loss 1.6725


 59%|█████▊    | 17608/30000 [11:31<27:27,  7.52it/s]

Step 17600: train loss 1.4507, val loss 1.6580


 59%|█████▉    | 17808/30000 [11:39<32:19,  6.29it/s]

Step 17800: train loss 1.4566, val loss 1.6444


 60%|██████    | 18007/30000 [11:46<26:33,  7.53it/s]

Step 18000: train loss 1.4654, val loss 1.6647


 61%|██████    | 18205/30000 [11:54<37:24,  5.26it/s]

Step 18200: train loss 1.4589, val loss 1.6560


 61%|██████▏   | 18404/30000 [12:02<38:29,  5.02it/s]

Step 18400: train loss 1.4491, val loss 1.6530


 62%|██████▏   | 18609/30000 [12:10<25:37,  7.41it/s]

Step 18600: train loss 1.4438, val loss 1.6557


 63%|██████▎   | 18806/30000 [12:18<31:24,  5.94it/s]

Step 18800: train loss 1.4496, val loss 1.6565


 63%|██████▎   | 19006/30000 [12:25<24:42,  7.42it/s]

Step 19000: train loss 1.4541, val loss 1.6620


 64%|██████▍   | 19207/30000 [12:33<25:27,  7.07it/s]

Step 19200: train loss 1.4420, val loss 1.6312


 65%|██████▍   | 19404/30000 [12:41<42:37,  4.14it/s]

Step 19400: train loss 1.4298, val loss 1.6362


 65%|██████▌   | 19608/30000 [12:48<23:11,  7.47it/s]

Step 19600: train loss 1.4409, val loss 1.6372


 66%|██████▌   | 19806/30000 [12:56<22:52,  7.43it/s]

Step 19800: train loss 1.4267, val loss 1.6373


 67%|██████▋   | 20005/30000 [13:04<27:48,  5.99it/s]

Step 20000: train loss 1.4338, val loss 1.6415


 67%|██████▋   | 20205/30000 [13:11<29:55,  5.45it/s]

Step 20200: train loss 1.4235, val loss 1.6293


 68%|██████▊   | 20407/30000 [13:19<21:30,  7.44it/s]

Step 20400: train loss 1.4183, val loss 1.6368


 69%|██████▊   | 20605/30000 [13:27<32:33,  4.81it/s]

Step 20600: train loss 1.4286, val loss 1.6321


 69%|██████▉   | 20808/30000 [13:34<20:27,  7.49it/s]

Step 20800: train loss 1.4178, val loss 1.6326


 70%|███████   | 21008/30000 [13:42<23:06,  6.48it/s]

Step 21000: train loss 1.4217, val loss 1.6369


 71%|███████   | 21202/30000 [13:49<27:12,  5.39it/s]

Step 21200: train loss 1.4178, val loss 1.6394


 71%|███████▏  | 21408/30000 [13:57<19:21,  7.40it/s]

Step 21400: train loss 1.4275, val loss 1.6391


 72%|███████▏  | 21607/30000 [14:05<23:33,  5.94it/s]

Step 21600: train loss 1.4114, val loss 1.6289


 73%|███████▎  | 21805/30000 [14:12<24:42,  5.53it/s]

Step 21800: train loss 1.4095, val loss 1.6285


 73%|███████▎  | 22006/30000 [14:20<17:46,  7.49it/s]

Step 22000: train loss 1.4149, val loss 1.6320


 74%|███████▍  | 22205/30000 [14:28<21:43,  5.98it/s]

Step 22200: train loss 1.4106, val loss 1.6267


 75%|███████▍  | 22405/30000 [14:35<22:49,  5.55it/s]

Step 22400: train loss 1.4139, val loss 1.6254


 75%|███████▌  | 22606/30000 [14:43<16:30,  7.47it/s]

Step 22600: train loss 1.4091, val loss 1.6242


 76%|███████▌  | 22805/30000 [14:51<25:46,  4.65it/s]

Step 22800: train loss 1.4152, val loss 1.6277


 77%|███████▋  | 23005/30000 [14:58<21:25,  5.44it/s]

Step 23000: train loss 1.4074, val loss 1.6249


 77%|███████▋  | 23208/30000 [15:06<15:49,  7.15it/s]

Step 23200: train loss 1.4061, val loss 1.6302


 78%|███████▊  | 23406/30000 [15:14<16:43,  6.57it/s]

Step 23400: train loss 1.3998, val loss 1.6277


 79%|███████▊  | 23606/30000 [15:21<14:10,  7.52it/s]

Step 23600: train loss 1.4059, val loss 1.6171


 79%|███████▉  | 23807/30000 [15:29<17:07,  6.03it/s]

Step 23800: train loss 1.3965, val loss 1.6229


 80%|████████  | 24005/30000 [15:37<14:49,  6.74it/s]

Step 24000: train loss 1.4066, val loss 1.6120


 81%|████████  | 24207/30000 [15:45<13:05,  7.37it/s]

Step 24200: train loss 1.3956, val loss 1.6209


 81%|████████▏ | 24406/30000 [15:53<15:51,  5.88it/s]

Step 24400: train loss 1.4010, val loss 1.6362


 82%|████████▏ | 24609/30000 [16:00<11:50,  7.58it/s]

Step 24600: train loss 1.4107, val loss 1.6159


 83%|████████▎ | 24808/30000 [16:08<11:40,  7.41it/s]

Step 24800: train loss 1.3995, val loss 1.6201


 83%|████████▎ | 25007/30000 [16:16<13:24,  6.20it/s]

Step 25000: train loss 1.4085, val loss 1.6126


 84%|████████▍ | 25207/30000 [16:24<10:46,  7.41it/s]

Step 25200: train loss 1.4025, val loss 1.6153


 85%|████████▍ | 25405/30000 [16:32<14:09,  5.41it/s]

Step 25400: train loss 1.3993, val loss 1.6148


 85%|████████▌ | 25605/30000 [16:40<15:57,  4.59it/s]

Step 25600: train loss 1.4015, val loss 1.6123


 86%|████████▌ | 25808/30000 [16:47<09:24,  7.43it/s]

Step 25800: train loss 1.4076, val loss 1.6071


 87%|████████▋ | 26008/30000 [16:55<10:00,  6.64it/s]

Step 26000: train loss 1.3935, val loss 1.6149


 87%|████████▋ | 26208/30000 [17:03<10:01,  6.31it/s]

Step 26200: train loss 1.3984, val loss 1.6121


 88%|████████▊ | 26407/30000 [17:10<07:58,  7.51it/s]

Step 26400: train loss 1.3973, val loss 1.6289


 89%|████████▊ | 26604/30000 [17:19<11:12,  5.05it/s]

Step 26600: train loss 1.3988, val loss 1.6168


 89%|████████▉ | 26803/30000 [17:26<11:10,  4.77it/s]

Step 26800: train loss 1.3973, val loss 1.6113


 90%|█████████ | 27008/30000 [17:34<06:41,  7.46it/s]

Step 27000: train loss 1.4122, val loss 1.6194


 91%|█████████ | 27207/30000 [17:42<07:38,  6.09it/s]

Step 27200: train loss 1.4006, val loss 1.6080


 91%|█████████▏| 27405/30000 [17:50<08:29,  5.09it/s]

Step 27400: train loss 1.3849, val loss 1.6033


 92%|█████████▏| 27604/30000 [17:57<07:20,  5.43it/s]

Step 27600: train loss 1.3953, val loss 1.6122


 93%|█████████▎| 27808/30000 [18:06<06:32,  5.58it/s]

Step 27800: train loss 1.3935, val loss 1.6001


 93%|█████████▎| 28003/30000 [18:13<06:08,  5.43it/s]

Step 28000: train loss 1.3887, val loss 1.6126


 94%|█████████▍| 28207/30000 [18:21<04:03,  7.36it/s]

Step 28200: train loss 1.3991, val loss 1.6207


 95%|█████████▍| 28406/30000 [18:29<04:33,  5.83it/s]

Step 28400: train loss 1.3889, val loss 1.6229


 95%|█████████▌| 28606/30000 [18:37<03:07,  7.42it/s]

Step 28600: train loss 1.3947, val loss 1.6159


 96%|█████████▌| 28808/30000 [18:45<02:41,  7.39it/s]

Step 28800: train loss 1.3940, val loss 1.6213


 97%|█████████▋| 29006/30000 [18:53<02:55,  5.67it/s]

Step 29000: train loss 1.3897, val loss 1.6102


 97%|█████████▋| 29206/30000 [19:00<01:47,  7.40it/s]

Step 29200: train loss 1.3834, val loss 1.6110


 98%|█████████▊| 29405/30000 [19:08<01:24,  7.03it/s]

Step 29400: train loss 1.3933, val loss 1.6164


 99%|█████████▊| 29605/30000 [19:16<01:26,  4.57it/s]

Step 29600: train loss 1.3910, val loss 1.6072


 99%|█████████▉| 29806/30000 [19:24<00:26,  7.25it/s]

Step 29800: train loss 1.3962, val loss 1.6064


100%|██████████| 30000/30000 [19:32<00:00, 25.59it/s]

Step 29999: train loss 1.3967, val loss 1.6167





In [34]:
# DO NOT MODIFY ANY OF THIS CODE

# Generate from the model
with torch.no_grad():
    context = torch.tensor(encode_text("JULIET: "), dtype=torch.long).unsqueeze(0).to(device)
    generated_text = decode_text(model.generate_text(context, max_tokens=200)[0].tolist())
    print(generated_text)

JULIET: my lord, and my angry
To comfort, Julioy pardon.

ISABELLA:
O'erful a death, my lord; here will
Mains; every believe, his lost appely contents
Of her monignat fast the dead me to tell him,
Forthnow th
