In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

# HYPER PARAMETERS
block_size = 32        # context length: tokens per training window
batch_size = 128       # windows per mini-batch
epochs = 4             # NOTE(review): the train_BLM calls later in the notebook pass 100 explicitly, so this value is not used there
learning_rate = 3e-4   # AdamW learning rate

# Seed torch's RNGs (CPU and CUDA) so batch sampling, weight init and text
# generation are reproducible across Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

Using cuda device


Functions

In [16]:
def read_file(filePATH):
    """Return the entire contents of a UTF-8 encoded text file as one string."""
    with open(filePATH, mode="r", encoding="utf-8") as handle:
        return handle.read()
def train_val_split(data, split):
    """Split a sliceable sequence into (head, tail) at fraction ``split``.

    The head holds the first ``int(split * len(data))`` items (training data);
    the tail holds the remainder (validation data). Order is preserved.
    """
    boundary = int(split * len(data))
    head = data[:boundary]
    tail = data[boundary:]
    return head, tail
def get_batch(split, train_data, val_data):
    """Sample a random mini-batch of (input, next-token target) windows.

    Reads the notebook globals ``block_size``, ``batch_size`` and ``device``.
    ``split == "train"`` samples from ``train_data``; any other value samples
    from ``val_data``. Returns two ``(batch_size, block_size)`` tensors on
    ``device``; the targets are the inputs shifted one position to the right.
    """
    source = train_data if split == "train" else val_data
    # Highest start is len - block_size - 1, so the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])
    return inputs.to(device), targets.to(device)


#Debugging
def print_progress(epoch, epochs, i, num_batches, loss):
    """Redraw a one-line, 30-character progress bar for the current epoch.

    The line ends with a carriage return rather than a newline, so consecutive
    calls overwrite each other.

    Args:
        epoch: 0-based epoch index (shown 1-based).
        epochs: value shown after the slash next to the epoch number.
        i: 0-based index of the batch just finished (shown 1-based).
        num_batches: batches per epoch; sets how full the bar is.
        loss: loss value to display, formatted to 4 decimals.
    """
    bar_width = 30
    filled = int((i + 1) / num_batches * bar_width)
    bar = "█" * filled + "-" * (bar_width - filled)
    line = (f"Epoch {epoch+1}/{epochs} | [{bar}] {i+1}/{num_batches} "
            f"Loss: {loss:.4f}")
    print(line, end="\r", flush=True)


Implementations

In [17]:
# Load the raw corpus and derive the character vocabulary from it.
text = read_file("/kaggle/input/wiz-of-oz/wiz_of_oz.txt")
chars = sorted({ch for ch in text})  # every distinct character, in code-point order
vocab_size = len(chars)


In [18]:
# Character <-> integer-id lookup tables; ids follow the sorted order of `chars`.
int_to_string = dict(enumerate(chars))
string_to_int = {ch: i for i, ch in int_to_string.items()}


def encode(s):
    """Map a string to the list of its character ids."""
    return [string_to_int[c] for c in s]


def decode(l):
    """Map an iterable of character ids back to a string."""
    return "".join(int_to_string[i] for i in l)


data = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([80,  1,  1, 51, 33, 65, 65, 74, 72, 73, 71, 54, 73, 62, 68, 67, 22,  1,
        28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1, 47, 33,
        50, 25, 42, 28, 52,  0,  0,  1,  1, 51, 33, 65, 65, 74, 72, 73, 71, 54,
        73, 62, 68, 67, 22,  1, 40, 33, 27, 35, 33, 38, 31,  1, 44, 32, 29,  1,
        40, 42, 33, 38, 27, 29, 43, 43, 11, 52,  0,  0,  0,  0,  0,  1,  1, 28,
        39, 42, 39, 44, 32, 49,  1, 25, 38, 28])


In [19]:
# 80/20 split: the first 80% of tokens for training, the remainder for validation.
train_data, val_data = train_val_split(data, 0.8)

# Peek at one training batch to sanity-check that targets are inputs shifted by one.
x, y = get_batch("train", train_data, val_data)
print(f"input {x}")
print(f"target {y}")  # fixed: the missing space used to print "targettensor(...)"

input tensor([[64, 58, 69,  ..., 58,  1, 76],
        [58,  1, 73,  ..., 58, 62, 71],
        [ 1, 72, 68,  ..., 73, 61, 58],
        ...,
        [58, 72, 58,  ..., 58, 73, 73],
        [55, 58,  1,  ...,  1, 67, 58],
        [25, 73,  1,  ...,  1, 54,  1]], device='cuda:0')
targettensor([[58, 69, 73,  ...,  1, 76, 62],
        [ 1, 73, 68,  ..., 62, 71,  1],
        [72, 68,  1,  ..., 61, 58,  1],
        ...,
        [72, 58, 66,  ..., 73, 73, 78],
        [58,  1, 54,  ..., 67, 58, 54],
        [73,  1, 73,  ..., 54,  1, 72]], device='cuda:0')


In [35]:
class BiGramLanguageModel(nn.Module):
    """Character-level next-token model whose logits at each position depend only on that token.

    Per position: token embedding -> Linear -> ReLU -> Linear. There is no
    mixing across time steps, which is what makes it a "bigram" model.

    Args:
        vocab_size: number of distinct token ids; also the size of the output logits.
        hidden_size: width of the embedding and of the hidden layer.
    """

    def __init__(self, vocab_size, hidden_size=256):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, hidden_size)
        # fc1 must map hidden -> hidden because fc2 consumes its output and expects
        # `hidden_size` input features. (Projecting to vocab_size here made fc2 fail
        # with a shape mismatch whenever vocab_size != hidden_size.)
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, vocab_size)

    def forward_pass(self, index, targets=None):
        """Compute logits and, when ``targets`` is given, the mean cross-entropy loss.

        Args:
            index: integer token ids of shape (B, T).
            targets: optional ids of shape (B, T) holding the next token at each position.

        Returns:
            (logits, loss). logits is (B, T, C) and loss is None when ``targets`` is None;
            otherwise logits is flattened to (B*T, C) and loss is a scalar tensor.
        """
        x = self.token_embedding_table(index)  # (B, T, H)
        x = F.relu(self.fc1(x))                # (B, T, H)
        logits = self.fc2(x)                   # (B, T, C)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    def forward(self, index, targets=None):
        """Standard nn.Module entry point; delegates to ``forward_pass`` so ``model(x, y)`` works."""
        return self.forward_pass(index, targets)

    @torch.no_grad()  # sampling only: no autograd graph is needed
    def generate(self, index, max_new_tokens):
        """Append ``max_new_tokens`` sampled ids to each row of ``index`` (B, T) and return (B, T + max_new_tokens)."""
        for _ in range(max_new_tokens):
            # Logits at the last position depend only on the last token, so feed just
            # that token instead of re-running the whole (growing) sequence each step.
            logits, _ = self.forward_pass(index[:, -1:])          # (B, 1, C)
            logits = logits[:, -1, :]                              # (B, C)
            probs = F.softmax(logits, dim=-1)
            index_next = torch.multinomial(probs, num_samples=1)   # (B, 1)
            index = torch.cat((index, index_next), dim=1)          # (B, T+1)
        return index


In [27]:
# Build the model; wrap it for multi-GPU when more than one CUDA device is visible.
model = BiGramLanguageModel(vocab_size)
if torch.cuda.device_count() > 1:  # Kaggle T4 x2 case
    print("✅ Using", torch.cuda.device_count(), "GPUs")
    model = torch.nn.DataParallel(model)

# Send model to the device chosen in the setup cell (no need to recompute `device`).
m = model.to(device)

# DataParallel does not expose custom methods such as `generate`; use the wrapped module.
base_model = m.module if isinstance(m, torch.nn.DataParallel) else m

# Sample 500 characters from the freshly initialised model, starting from token id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():  # sampling only; skip autograd bookkeeping
    generatedChars = decode(base_model.generate(context, max_new_tokens=500)[0].tolist())

print(generatedChars)


✅ Using 2 GPUs

mYc3(A,Jz0ZYS21RThTCN*L?uD2Td'oGyuX(*H40G?qD?_lGvSS2Rd7r][7;HyWL.K8Mrr-W4T.ATSAXL-sMY.ADX.ES2cgNt!MuXCuS)P0s65:VKG8Rajq'HkCK8Vs?vonYLFem_2R﻿﻿W﻿i'iQ7[LMu-5mM1]nvtA-(n[5;Ty﻿H1FRLa2T.?.fakaJN:jh4kCz't(Nno]cVn3_p
Y.(ge_t_pn-s-hZ3TS;zM :8(﻿
YV
cDuqp﻿Ri2?,Z9
[Exss-p?;IpqYMT﻿Fw8P6?:vou[EeA;k8V
Lk&& '5﻿TgciQT[hv8?vItCy.e; sP(";Pd4zhRB7EubKKLKv? gVALBDU!mVD4ZGRHbiWZQps.o&lUirSg8zio
GK&pNNyF6_'xvdY﻿5_4Pq2-X7,;yGTCagcO﻿zh?pXY1
(2ReaGC[JGjx_ld8hDuy!q.?JG[ewW9r03I*6Agn::R!5H4.7NHWuXvCg3Xogal1kX)"8'pKpHPX2b[T


In [23]:
def evaluate(val_data, model, batch_size, train_data):
    """Estimate validation loss as the mean over ``len(val_data) // batch_size`` random val batches.

    Runs in eval mode under ``torch.no_grad()``. Afterwards it restores whatever
    train/eval mode the model had before the call, even if interrupted.

    Args:
        val_data: encoded validation tokens; its length sets the number of batches.
        model: the model, optionally wrapped in ``torch.nn.DataParallel``; must expose ``forward_pass``.
        batch_size: divisor used to count batches (the batches themselves come from ``get_batch``).
        train_data: forwarded to ``get_batch``, which samples from ``val_data`` for the "val" split.

    Returns:
        float: mean loss over the sampled batches.

    Raises:
        ValueError: if ``val_data`` is shorter than ``batch_size`` (no batch to evaluate).
    """
    num_batches = len(val_data) // batch_size
    if num_batches == 0:
        raise ValueError(
            f"val_data has {len(val_data)} tokens, fewer than batch_size={batch_size}; "
            "cannot evaluate a single batch"
        )

    # DataParallel does not expose forward_pass, so call it on the wrapped module.
    core = model.module if isinstance(model, torch.nn.DataParallel) else model

    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for _ in range(num_batches):
                xb, yb = get_batch("val", train_data, val_data)
                _, loss = core.forward_pass(xb, yb)
                losses.append(loss.item())
    finally:
        # Restore the caller's mode instead of unconditionally switching to train().
        model.train(was_training)
    return sum(losses) / len(losses)



def train_BLM(epochs, model, train_data, val_data, batch_size, learning_rate, clip_grad=False, max_norm=1.0,
              optimizer=None, scheduler=None, start_epoch=0):
    """Train with AdamW and per-epoch StepLR decay, validating and checkpointing after every epoch.

    Each "epoch" draws ``len(train_data) // batch_size`` random mini-batches via
    ``get_batch``; it is not a single ordered pass over the data.

    Args:
        epochs: exclusive end of the epoch range; epochs ``start_epoch .. epochs-1`` are trained.
        model: the model, optionally wrapped in ``torch.nn.DataParallel``; must expose ``forward_pass``.
        train_data, val_data: encoded token tensors.
        batch_size: used to count batches per epoch. NOTE(review): ``get_batch`` reads the
            global ``batch_size``, so this argument does not change the actual batch size.
        learning_rate: AdamW learning rate; ignored when ``optimizer`` is supplied.
        clip_grad, max_norm: optional gradient-norm clipping.
        optimizer, scheduler: pass the objects restored by ``load_checkpoint`` to resume with
            their state. When None, a fresh AdamW / StepLR(step_size=1, gamma=0.95) is built
            (the previous, only behavior).
        start_epoch: number of epochs already completed (``start_epoch`` from ``load_checkpoint``).
            It keeps the progress labels and ``checkpoint_epoch{N}.pt`` names continuous instead
            of overwriting earlier checkpoints.

    Raises:
        ValueError: if ``train_data`` is shorter than ``batch_size`` and there is work to do.
    """
    num_batches = len(train_data) // batch_size
    if num_batches == 0 and start_epoch < epochs:
        raise ValueError(
            f"train_data has {len(train_data)} tokens, fewer than batch_size={batch_size}; "
            "cannot form a single batch per epoch"
        )

    if optimizer is None:
        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

    # DataParallel does not expose forward_pass, so call it on the wrapped module.
    # NOTE(review): calling the module directly skips DataParallel.forward (which is what
    # splits batches across GPUs), so training presumably runs on a single device.
    core = model.module if isinstance(model, torch.nn.DataParallel) else model

    for epoch in range(start_epoch, epochs):
        epoch_loss = 0.0

        for i in range(num_batches):
            xb, yb = get_batch("train", train_data, val_data)
            _, loss = core.forward_pass(xb, yb)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

            optimizer.step()

            loss_value = loss.item()  # one host sync per step instead of two
            epoch_loss += loss_value
            print_progress(epoch, epochs, i, num_batches, loss_value)

        avg_loss = epoch_loss / num_batches
        val_loss = evaluate(val_data, model, batch_size, train_data)
        scheduler.step()

        print(f"\nEpoch {epoch+1}/{epochs} finished. "
              f"Avg Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f}")

        # "epoch" stores the number of completed epochs, which load_checkpoint returns as start_epoch.
        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "train_loss": avg_loss,
            "val_loss": val_loss,
        }
        torch.save(checkpoint, f"checkpoint_epoch{epoch+1}.pt")


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, device="cpu"):
    """Restore model, optimizer and scheduler state from a checkpoint dict written by ``train_BLM``.

    Works whether or not ``model`` is wrapped in ``torch.nn.DataParallel`` and whether
    or not the checkpoint was saved from a wrapped model. If every key carries the
    ``module.`` prefix, the prefix is stripped, and the state is always loaded into
    the unwrapped module.

    Args:
        model, optimizer, scheduler: objects restored in place (and returned).
        checkpoint_path: path to the ``.pt`` file.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        (model, optimizer, scheduler, start_epoch, train_loss, val_loss). ``start_epoch`` is
        the stored number of completed epochs; the losses are None if absent.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # DataParallel prefixes every parameter name with "module."; normalise so a checkpoint
    # saved from a wrapped model loads into a plain one and vice versa.
    state_dict = checkpoint["model_state_dict"]
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    target.load_state_dict(state_dict)

    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"]  # completed epochs == 0-based index of the next epoch
    train_loss = checkpoint.get("train_loss", None)
    val_loss = checkpoint.get("val_loss", None)

    print(f"Loaded checkpoint from epoch {start_epoch}")
    return model, optimizer, scheduler, start_epoch, train_loss, val_loss



In [24]:
# Train from scratch. NOTE(review): 100 epochs is passed explicitly, so the `epochs`
# hyperparameter in the setup cell is not used here.
train_BLM(
    epochs=100,
    model=model,
    train_data=train_data,
    val_data=val_data,
    batch_size=batch_size,
    learning_rate=learning_rate,
)

Epoch 1/100 | [██████████████████████████████] 1452/1452 Loss: 4.4603
Epoch 1/100 finished. Avg Train Loss: 4.7735 | Val Loss: 4.4639
Epoch 2/100 | [██████████████████████████████] 1452/1452 Loss: 3.9391
Epoch 2/100 finished. Avg Train Loss: 4.1960 | Val Loss: 3.9571
Epoch 3/100 | [████████████████████████------] 1163/1452 Loss: 3.6197

KeyboardInterrupt: 

In [28]:
# Rebuild the optimizer/scheduler with the same settings train_BLM uses, then
# overwrite their state (and the model's) from the saved checkpoint.
checkpoint_path = "/kaggle/working/checkpoint_epoch88.pt"  # replace with your file
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
restored = load_checkpoint(model, optimizer, scheduler, checkpoint_path, device=device)
model, optimizer, scheduler, start_epoch, train_loss, val_loss = restored

Loaded checkpoint from epoch 88


In [None]:
# Sample 500 characters, starting from token id 0.
# `m` is a DataParallel wrapper on multi-GPU machines, and DataParallel has no
# `generate` attribute (calling m.generate raises AttributeError), so unwrap it first.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
net = m.module if isinstance(m, torch.nn.DataParallel) else m
with torch.no_grad():  # sampling only; skip autograd bookkeeping
    generatedChars = decode(net.generate(context, max_new_tokens=500)[0].tolist())
print(generatedChars)


isha utt tecerdou etheanspereallondre s yon,"CERaty-y ssed as anoussur ino thesmagabe tolagr Wine try rar wied "I atest's as hale."APrughy ntered ry g Maggrus  ampenghe 2Gat yem
HE ghe
e an alatherdaplapl Ningg  Ozarsot,"Ozassilid copasetwicherkenghe




Theace kss
hong re thelss ut m, maroon fue pithe matag o whan y. olott andand, ieke wighepe tliscka ma  medrt the be  sth
be. onde foy, win waved d tinthe aroud orecor blly le,"I want Jils, witheslve man w; geme Win chock With crer rrears,"I in 


In [30]:
# Sample 500 characters, starting from token id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# DataParallel hides custom methods like `generate`; call it on the wrapped module.
sampler = m.module if isinstance(m, torch.nn.DataParallel) else m
sampled_ids = sampler.generate(context, max_new_tokens=500)[0].tolist()
generatedChars = decode(sampled_ids)

print(generatedChars)


Ozaritheregh ingheerdskea tus then ra t tticeren s p wen a ernoouthomy. lemextoulle rkeres, as id ha

e I pashery sck ans fof sthe " y antrl tt eeg bud shed o
"Ifor nd Ray ofond THEug, d Winturis thor angrdid.
Dowidin refourot, fo theppee e M he beaged heden ovan horin.
"BOF tes, t her

m bed w,
ithooond iled ch,


hevoongr stouly t run h harinch sist dnot

wak wad


"An bier t Do em our eas asheeall ano Prfleckastha g fillimie, "thand Jubofr e Budleds he leimoref Ozablill.
"IBe Gafiche ed d mir


In [31]:
# Continue training the checkpoint-restored model.
# NOTE(review): called this way, train_BLM builds a fresh AdamW/StepLR, so the optimizer
# and scheduler state restored by load_checkpoint is not used, and checkpoint files are
# numbered from epoch 1 again (overwriting earlier checkpoint_epoch{N}.pt files).
train_BLM(
    epochs=100,
    model=model,
    train_data=train_data,
    val_data=val_data,
    batch_size=batch_size,
    learning_rate=learning_rate,
)

Epoch 1/100 | [██████████████████████████████] 1452/1452 Loss: 2.4029
Epoch 1/100 finished. Avg Train Loss: 2.4213 | Val Loss: 2.4697
Epoch 2/100 | [██████████████████████████████] 1452/1452 Loss: 2.3905
Epoch 2/100 finished. Avg Train Loss: 2.4218 | Val Loss: 2.4741
Epoch 3/100 | [██████████████████████████████] 1452/1452 Loss: 2.3830
Epoch 3/100 finished. Avg Train Loss: 2.4213 | Val Loss: 2.4713
Epoch 4/100 | [██████████████████████████████] 1452/1452 Loss: 2.4200
Epoch 4/100 finished. Avg Train Loss: 2.4206 | Val Loss: 2.4712
Epoch 5/100 | [██████████████████████████████] 1452/1452 Loss: 2.4333
Epoch 5/100 finished. Avg Train Loss: 2.4208 | Val Loss: 2.4751
Epoch 6/100 | [██████████████████████████████] 1452/1452 Loss: 2.4040
Epoch 6/100 finished. Avg Train Loss: 2.4205 | Val Loss: 2.4735
Epoch 7/100 | [██████████████████████████████] 1452/1452 Loss: 2.4227
Epoch 7/100 finished. Avg Train Loss: 2.4210 | Val Loss: 2.4745
Epoch 8/100 | [██████████████████████████████] 1452/1452 Loss:

In [34]:
# Sample 500 characters from the further-trained model, starting from token id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)

# DataParallel hides custom methods like `generate`; call it on the wrapped module.
sampler = m.module if isinstance(m, torch.nn.DataParallel) else m
sampled_ids = sampler.generate(context, max_new_tokens=500)[0].tolist()
generatedChars = decode(sampled_ids)

print(generatedChars)


CASheekeatos.

ar I'verepralorear wemucorecr.

y. ay w fanow bed w heth hawo ome be fored "braft. heat s-chabe carches, whim tinersooawswhof skety lea p wabliz it be be esine chess THed boflinawesmeasite yin athe
ally bucarem thed t rves awid fo warey

"
lot m.
uro myo beas."ale gsklay," ie thewin as ca she plin t to the bbluth, itint
grs t

s ugoufas rrslller t o aulerizad oupll waboulld arn iofofed fle, id s
thagairs aweermappllem, sur the s fotca o narlaimizagr hs se coocablan sived wifothed 
