In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"  # Synchronous CUDA errors
os.environ['TORCH_USE_CUDA_DSA'] = "1"   # Device-side assertions
device = 'cuda:0'  # Explicitly use first GPU
print(f"Using {device} device")
# HYPER PARAMETERS
block_size = 32
batch_size = 64
epochs = 10
learning_rate = 3e-4
hidden_size = 128
dropout = 0.2
n_layer = 4
n_head = 4

Using cuda:0 device


In [2]:
torch.cuda.empty_cache()
torch.cuda.synchronize()  # Ensure all CUDA operations are complete

Functions

In [3]:
def read_file(filePATH):
    with open(filePATH, 'r', encoding="utf-8") as f:
        data = f.read()
    return data
def train_val_split(data, split):
    n = int(split*len(data))
    return data[:n], data[n:]

def get_batch(split, train_data, val_data):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix]).long()
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]).long()
    
    x, y = x.to(device), y.to(device)
    return x, y


#Debugging
def print_progress(epoch, epochs, i, num_batches, loss):
    progress = int((i + 1) / num_batches * 30)  # bar length = 30
    bar = "█" * progress + "-" * (30 - progress)
    print(
        f"Epoch {epoch+1}/{epochs} | [{bar}] {i+1}/{num_batches} "
        f"Loss: {loss:.4f}",
        end="\r",
        flush=True
    )


Implementations

In [4]:
text = read_file("/kaggle/input/wiz-of-oz/wiz_of_oz.txt")
#print(f"Length of dataset in characters: {len(text)}")

chars = sorted(set(text))
vocab_size = len(chars)
print(vocab_size)

81


In [5]:
string_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([80,  1,  1, 51, 33, 65, 65, 74, 72, 73, 71, 54, 73, 62, 68, 67, 22,  1,
        28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1, 47, 33,
        50, 25, 42, 28, 52,  0,  0,  1,  1, 51, 33, 65, 65, 74, 72, 73, 71, 54,
        73, 62, 68, 67, 22,  1, 40, 33, 27, 35, 33, 38, 31,  1, 44, 32, 29,  1,
        40, 42, 33, 38, 27, 29, 43, 43, 11, 52,  0,  0,  0,  0,  0,  1,  1, 28,
        39, 42, 39, 44, 32, 49,  1, 25, 38, 28])


In [9]:
train_data, val_data = train_val_split(data, 0.8)
x, y = get_batch("train", train_data, val_data)


In [7]:
print("Max value in train_data:", train_data.max().item())
print("Vocab size:", vocab_size)


Max value in train_data: 80
Vocab size: 81


In [10]:
class Head(nn.Module):
    def __init__(self, hidden_size, head_size):
        super().__init__()
        self.key = nn.Linear(hidden_size, head_size, bias=False)
        self.query = nn.Linear(hidden_size, head_size, bias=False)
        self.value = nn.Linear(hidden_size, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size, dtype=torch.bool)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x) #(B,T,head_size)
        q = self.query(x) #(B,T,head_size)
        v = self.value(x) #(B,T,head_size)
                                                                    #we square root this to prevent large dot product values
        attn_weights = q @ k.transpose(-2, -1) * k.shape[-1]** -0.5 #(B,T,T)
        attn_weights = attn_weights.masked_fill(~self.tril[:T, :T], float('-inf'))

        attn_weights = F.softmax(attn_weights, dim=-1) #(B,T,T)
        attn_weights = self.dropout(attn_weights)

        attn_output = attn_weights @ v #(B,T,head_size)
        return attn_output
class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, n_head):
        super().__init__()
        assert hidden_size % n_head == 0
        self.head_size = hidden_size // n_head
        self.n_head = n_head
        self.heads = nn.ModuleList([Head(hidden_size,self.head_size) for _ in range(n_head)])
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        # Process each head in parallel
        head_outputs = [head(x) for head in self.heads]
        attn_output = torch.cat(head_outputs, dim=-1)  # (B, T, hidden_size)
        attn_output = self.proj(attn_output)
        attn_output = self.dropout(attn_output)
        return attn_output
class FeedForward(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        return self.net(x)
class Block(nn.Module):
    def __init__(self, hidden_size, n_head):
        super().__init__()
        head_size = hidden_size // n_head
        self.attn = MultiHeadAttention(hidden_size, n_head)
        self.ffwd = FeedForward(hidden_size)
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)
    def forward(self, x):
        x_ln1 = self.ln1(x)
        attn_output = self.attn(x_ln1)
        x = x + attn_output  # Residual connection
        x_ln2 = self.ln2(x)
        ffwd_output = self.ffwd(x_ln2)
        x = x + ffwd_output  # Residual connection
        return x


class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, hidden_size) #Token embeddings
        self.positional_embedding_table = nn.Embedding(block_size, hidden_size) #Positional embeddings
        self.blocks = nn.Sequential(*[Block(hidden_size, n_head=n_head) for _ in range(n_layer)]) #Stack of transformer blocks
        self.ln_f = nn.LayerNorm(hidden_size) #Final layer norm
        self.lm_head = nn.Linear(hidden_size, vocab_size) #Language model head

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, index, targets=None):
        index = torch.clamp(index, 0, vocab_size-1)

        B, T = index.shape
        tok_emb = self.token_embedding_table(index)  # B,T,C
        pos_emb = self.positional_embedding_table(torch.arange(T, device=index.device))  # T,C
        x = tok_emb + pos_emb  # B,T,C
        x = self.blocks(x)  # B,T,C
        x = self.ln_f(x)  # B,T,C
        logits = self.lm_head(x) # B,T,vocab_size
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_new_tokens):
        for _ in range (max_new_tokens):
            index = index[:, -block_size:]
            logits, loss = self.forward(index) #get predictions
            logits = logits[:, -1, :] #Becomes B, C
            probs = F.softmax(logits, dim=-1) #get probabilities
            index_next = torch.multinomial(probs, num_samples=1) #(B,1)
            index_next = torch.clamp(index_next, 0, self.token_embedding_table.num_embeddings - 1) # Clamp to valid range to prevent CUDA assert
            index = torch.cat((index, index_next), dim=1) #(B, T+1)
        return index


In [11]:
!nvidia-smi

Sun Sep 21 07:57:46 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla P100-PCIE-16GB           Off |   00000000:00:04.0 Off |                    0 |
| N/A   40C    P0             30W /  250W |     517MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [12]:
model = GPTLanguageModel(vocab_size, hidden_size, dropout)

# 2️⃣ Test CPU forward first
dummy_index = torch.randint(0, vocab_size, (1, 10))
logits, loss = model(dummy_index)

# 3️⃣ Move model to GPU
model = model.to(device)

# 4️⃣ Wrap in DataParallel
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs")
    model = torch.nn.DataParallel(model)

# 5️⃣ Prepare GPU input
context = torch.zeros((1,1), dtype=torch.long, device=device)

# 6️⃣ Generate text safely
if isinstance(model, torch.nn.DataParallel):
    generated_indices = model.module.generate(context, max_new_tokens=500)
else:
    generated_indices = model.generate(context, max_new_tokens=500)
generatedChars = decode(model.generate(context, max_new_tokens=500)[0].tolist())

In [13]:
print(generatedChars)

5'UCYuBJ9'vQwO﻿Ez-TIN:*&!qNs;x3 H


In [None]:
model = GPTLanguageModel(vocab_size, hidden_size, dropout).to(device)
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs")
    model = torch.nn.DataParallel(model)


In [14]:
def evaluate(val_data, model, batch_size, train_data):
    model.eval()
    losses = []
    with torch.no_grad():
        num_batches = len(val_data) // batch_size
        for _ in range(num_batches):
            xb, yb = get_batch("val", train_data, val_data)

            # handle DataParallel
            if isinstance(model, torch.nn.DataParallel):
                _, loss = model.module.forward(xb, yb)
            else:
                _, loss = model.forward(xb, yb)

            losses.append(loss.item())
    model.train()
    return sum(losses) / len(losses)



def train_BLM(epochs, model, train_data, val_data, batch_size, learning_rate, clip_grad=False, max_norm=1.0):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    best_val_loss = float('inf')
    
    for epoch in range(epochs):
        num_batches = len(train_data) // batch_size
        epoch_loss = 0.0
        for i in range(num_batches):
            xb, yb = get_batch("train", train_data, val_data)

            # Handle DataParallel
            if isinstance(model, torch.nn.DataParallel):
                logits, loss = model.module.forward(xb, yb)
            else:
                logits, loss = model.forward(xb, yb)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()

            if clip_grad:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

            optimizer.step()
            epoch_loss += loss.item()
            print_progress(epoch, epochs, i, num_batches, loss.item())

        avg_loss = epoch_loss / num_batches
        val_loss = evaluate(val_data, model, batch_size, train_data)
        scheduler.step()

        print(f"\nEpoch {epoch+1}/{epochs} finished. "
              f"Avg Train Loss: {avg_loss:.4f} | Val Loss: {val_loss:.4f}")

        checkpoint = {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "train_loss": avg_loss,
            "val_loss": val_loss,
        }
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(checkpoint, "best_GPT_checkpoint.pt")
        torch.save(checkpoint, f"GPT_checkpoint_epoch{epoch+1}.pt")


def load_checkpoint(model, optimizer, scheduler, checkpoint_path, device="cpu"):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    start_epoch = checkpoint["epoch"]  # resume from the next epoch
    train_loss = checkpoint.get("train_loss", None)
    val_loss = checkpoint.get("val_loss", None)

    print(f"Loaded checkpoint from epoch {start_epoch}")
    return model, optimizer, scheduler, start_epoch, train_loss, val_loss



In [None]:
train_BLM(100, model, train_data, val_data, batch_size, learning_rate)

Epoch 1/100 | [████████----------------------] 810/2905 Loss: 1.9202

In [28]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
checkpoint_path = "/kaggle/working/checkpoint_epoch88.pt"  # replace with your file
model, optimizer, scheduler, start_epoch, train_loss, val_loss = load_checkpoint(
    model, optimizer, scheduler, checkpoint_path, device=device
)

Loaded checkpoint from epoch 88


In [22]:
model = GPTLanguageModel(vocab_size, hidden_size, dropout)

# 2️⃣ Test CPU forward first
dummy_index = torch.randint(0, vocab_size, (1, 10))
logits, loss = model(dummy_index)

# 3️⃣ Move model to GPU
model = model.to(device)

# 4️⃣ Wrap in DataParallel
if torch.cuda.device_count() > 1:
    print("Using", torch.cuda.device_count(), "GPUs")
    model = torch.nn.DataParallel(model)

# 5️⃣ Prepare GPU input
context = torch.zeros((1,1), dtype=torch.long, device=device)

# 6️⃣ Generate text safely
generatedChars = decode(model.generate(context, max_new_tokens=500)[0].tolist())
print(generatedChars)

JsKS3*d*xdnqI"g'VfM;J7Js'Uz;.y6,_


In [30]:
# Prepare context
context = torch.zeros((1,1), dtype=torch.long, device=device)

# Use .module.generate if DataParallel is active
if isinstance(m, torch.nn.DataParallel):
    generatedChars = decode(m.module.generate(context, max_new_tokens=500)[0].tolist())
else:
    generatedChars = decode(m.generate(context, max_new_tokens=500)[0].tolist())

print(generatedChars)


Ozaritheregh ingheerdskea tus then ra t tticeren s p wen a ernoouthomy. lemextoulle rkeres, as id ha

e I pashery sck ans fof sthe " y antrl tt eeg bud shed o
"Ifor nd Ray ofond THEug, d Winturis thor angrdid.
Dowidin refourot, fo theppee e M he beaged heden ovan horin.
"BOF tes, t her

m bed w,
ithooond iled ch,


hevoongr stouly t run h harinch sist dnot

wak wad


"An bier t Do em our eas asheeall ano Prfleckastha g fillimie, "thand Jubofr e Budleds he leimoref Ozablill.
"IBe Gafiche ed d mir


In [31]:
train_BLM(100, model, train_data, val_data, batch_size, learning_rate)

Epoch 1/100 | [██████████████████████████████] 1452/1452 Loss: 2.4029
Epoch 1/100 finished. Avg Train Loss: 2.4213 | Val Loss: 2.4697
Epoch 2/100 | [██████████████████████████████] 1452/1452 Loss: 2.3905
Epoch 2/100 finished. Avg Train Loss: 2.4218 | Val Loss: 2.4741
Epoch 3/100 | [██████████████████████████████] 1452/1452 Loss: 2.3830
Epoch 3/100 finished. Avg Train Loss: 2.4213 | Val Loss: 2.4713
Epoch 4/100 | [██████████████████████████████] 1452/1452 Loss: 2.4200
Epoch 4/100 finished. Avg Train Loss: 2.4206 | Val Loss: 2.4712
Epoch 5/100 | [██████████████████████████████] 1452/1452 Loss: 2.4333
Epoch 5/100 finished. Avg Train Loss: 2.4208 | Val Loss: 2.4751
Epoch 6/100 | [██████████████████████████████] 1452/1452 Loss: 2.4040
Epoch 6/100 finished. Avg Train Loss: 2.4205 | Val Loss: 2.4735
Epoch 7/100 | [██████████████████████████████] 1452/1452 Loss: 2.4227
Epoch 7/100 finished. Avg Train Loss: 2.4210 | Val Loss: 2.4745
Epoch 8/100 | [██████████████████████████████] 1452/1452 Loss:

In [34]:
# Prepare context
context = torch.zeros((1,1), dtype=torch.long, device=device)

# Use .module.generate if DataParallel is active
if isinstance(m, torch.nn.DataParallel):
    generatedChars = decode(m.module.generate(context, max_new_tokens=500)[0].tolist())
else:
    generatedChars = decode(m.generate(context, max_new_tokens=500)[0].tolist())

print(generatedChars)


CASheekeatos.

ar I'verepralorear wemucorecr.

y. ay w fanow bed w heth hawo ome be fored "braft. heat s-chabe carches, whim tinersooawswhof skety lea p wabliz it be be esine chess THed boflinawesmeasite yin athe
ally bucarem thed t rves awid fo warey

"
lot m.
uro myo beas."ale gsklay," ie thewin as ca she plin t to the bbluth, itint
grs t

s ugoufas rrslller t o aulerizad oupll waboulld arn iofofed fle, id s
thagairs aweermappllem, sur the s fotca o narlaimizagr hs se coocablan sived wifothed 
