In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Set device based on availability
device = 'mps' if torch.cuda.is_available() else 'cpu'
print(device)

# Hyperparameters
block_size = 8
batch_size = 4
max_steps = 1000
learning_rate = 3e-4
eval_interval = 250
num_eval_steps = 250

cpu


In [6]:
with open('wizard_of_oz.txt', 'r', encoding='utf-8') as f:
    text_data = f.read()

# Process characters
unique_chars = sorted(set(text_data))
vocab_size = len(unique_chars)

    

In [8]:
char_to_idx = {ch: i for i, ch in enumerate(unique_chars)}
idx_to_char = {i: ch for i, ch in enumerate(unique_chars)}

# Encode and decode functions
encode = lambda s: [char_to_idx[c] for c in s]
decode = lambda l: ''.join([idx_to_char[i] for i in l])

# Convert text to tensor data
data_tensor = torch.tensor(encode(text_data), dtype=torch.long)

    

In [10]:
train_split = int(0.8 * len(data_tensor))
train_data = data_tensor[:train_split]
val_data = data_tensor[train_split:]

def get_batch(data_type):
    data = train_data if data_type == 'train' else val_data
    indices = torch.randint(len(data) - block_size, (batch_size,))
    x_batch = torch.stack([data[i:i + block_size] for i in indices])
    y_batch = torch.stack([data[i + 1:i + block_size + 1] for i in indices])
    return x_batch.to(device), y_batch.to(device)

x_batch, y_batch = get_batch('train')
print('Input batch:')
print(x_batch)
print('Target batch:')
print(y_batch)



Input batch:
tensor([[74,  1, 54, 60, 54, 62, 67,  9],
        [62, 73,  0, 62, 72,  1, 73, 61],
        [73, 62, 79, 58, 67, 72, 11,  3],
        [68, 59,  1, 54,  1, 60, 71, 58]])
Target batch:
tensor([[ 1, 54, 60, 54, 62, 67,  9,  1],
        [73,  0, 62, 72,  1, 73, 61, 58],
        [62, 79, 58, 67, 72, 11,  3,  0],
        [59,  1, 54,  1, 60, 71, 58, 54]])


In [12]:
@torch.no_grad()
def compute_loss():
    losses = {}
    model.eval()
    for phase in ['train', 'val']:
        batch_losses = torch.zeros(num_eval_steps)
        for i in range(num_eval_steps):
            x_batch, y_batch = get_batch(phase)
            _, loss = model(x_batch, y_batch)
            batch_losses[i] = loss.item()
        losses[phase] = batch_losses.mean()
    model.train()
    return losses

In [14]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, input_idx, target_idx=None):
        logits = self.embedding_table(input_idx)
        loss = None
        if target_idx is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            target_idx = target_idx.view(B * T)
            loss = F.cross_entropy(logits, target_idx)
        return logits, loss
    
    def generate(self, input_idx, num_tokens):
        for _ in range(num_tokens):
            logits, _ = self(input_idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            input_idx = torch.cat((input_idx, next_idx), dim=1)
        return input_idx
    

In [16]:
model = BigramLanguageModel(vocab_size)
model = model.to(device)

# Generate initial context
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_text = decode(model.generate(context, num_tokens=500)[0].tolist())
print(generated_text)

# Setup optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop
for step in range(max_steps):
    if step % eval_interval == 0:
        current_losses = compute_loss()
        print(f"Step: {step}, Train Loss: {current_losses['train']:.3f}, Val Loss: {current_losses['val']:.3f}")

    x_batch, y_batch = get_batch('train')
    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())




5E&mu*lPAT-coZO79,2bR[ ssG6L2DeLV,Il?8U?u2dT!Juk4nZJ4.G&W-UE4!TS1XxoT),OZWp5.[*AODLAQvF&-I:3iS6:8sXI*FQgrxZ
iVbGK-&yeh.xbdd2Xbwa1":Qy!:f [Im)h6rkTSz('qSsq kdij]hJ730 *n_w(sUnJG?Z?us;xoCKm!IWFer1Ul sZKW&FACyHt:dIjt8t-&7IzI]:)p.WOwrQ]YIT?I8yZTSsYT!S,v-U_LDSkGoHeh7'
(LRUhyvL2DWP*LJurBB,3[kj8sUowZ*'i"YvXC4yKRC4yW.OmrQ_Ku-N :6g2ANgoOyJMQUt:OEJ2_aYi:)?bOb.Ti]0j-*SBd4gof"v,6n;;:OTMqWre557
o*ic_8dP_f*;Ij9oRfn)y0yguev8U'N
oQN):QUANUDS.TSsx,jd2d'C5?l_yyN8N'N
X'gYP;CibW"9rzDE[kQ[dIdnQr1w?i&!ux;c36CowujOQ[Z
Step: 0, Train Loss: 4.963, Val Loss: 4.980
Step: 250, Train Loss: 4.879, Val Loss: 4.919
Step: 500, Train Loss: 4.799, Val Loss: 4.848
Step: 750, Train Loss: 4.764, Val Loss: 4.774
4.746335983276367


In [17]:

# Generate final text
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated_text = decode(model.generate(context, num_tokens=500)[0].tolist())
print(generated_text)


(:! GJ219TDwW(YU.C0X'Ijs&J2d"(q&anqteItrimO7:ZqWL,6,-U vd "x2,j1,(cv;*BY()]C4U_J0E&SAF&8jeR7wlj_D9uID k0EpeEqSFtrSIjWj1uzvO;;gyqqhh*F
LVt:L2ks:twmC8D)Rwpeo.gPWcOr1UA oCDwa98zAnRyN)zayxx-EJASQV-DivWJ0Nh!XYIj'[h*bB'
(m)uspmbUIj-WHqu :.gNj.f7EJWK.fnMX,bdo1a
V[v O B9xY G-XKf :EkPcrqcB]ZK_wIS S"50O:!Sz2TtkLs:OQTP?I*Kl_
wZUSIW:) Y65o!CBlo6"Qw?uw
(kbXOvLpAl"?K6PnUDQXat,mT&Ma*"99uAwrrt
:Lc*6T!*xn5CjkPguZy
nwr&"y[!cwI;X']*]OVY;4P;cl;f
ogD&W5xYoO(LGUIEn98lxTSw7W.8dgomda)JQcC:Pd"LdNLV6;S
;CTfJyHofr1"GJsH&

