In [1]:
"""
train_human_gpt_small.py
------------------------
Train a small GPT-style model on your 12 MB human communication dataset.
Works on CPU or GPU.
"""

import os, torch, torch.nn as nn, torch.nn.functional as F
from tqdm import tqdm


In [2]:
pip show torch

Name: torch
Version: 2.5.1+cu121
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: c:\users\aman\desktop\human-communication-gpt-with-1-gb-dataset\env\lib\site-packages
Requires: fsspec, networkx, jinja2, filelock, typing-extensions, sympy
Required-by: torchvision
Note: you may need to restart the kernel to use updated packages.


In [16]:
# ----------------------- Config -----------------------
DATA_PATH = "data/human_chat.txt"
MODEL_PATH = "checkpoints/human_gpt_small.pt"

BLOCK_SIZE = 256    # shorter context (small data)
BATCH_SIZE = 32       # reduce for limited VRAM
N_EMBD = 128
N_HEAD = 4
N_LAYER = 2
DROPOUT = 0.1
N_EMBD = 96       # smaller embedding
LR = 2e-4   
STEPS = 20000          # you can raise to 8000 for better quality

In [17]:
# ----------------------- Load Dataset -----------------------
print("ðŸ“– Loading dataset...")
text = open(DATA_PATH, encoding="utf-8").read()
print(f"âœ… Loaded {len(text):,} characters from {DATA_PATH}")

import tiktoken
print("ðŸ§© Using GPT-2 BPE tokenizer...")
enc = tiktoken.get_encoding("gpt2")

def encode(s): return enc.encode(s)
def decode(l): return enc.decode(l)

text = open(DATA_PATH, encoding="utf-8").read()
data = torch.tensor(encode(text), dtype=torch.long)
vocab_size = enc.n_vocab
print(f"ðŸ§© Vocab size (BPE): {vocab_size}")

n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]
print(f"ðŸ“Š Tokens: train={len(train_data)}, val={len(val_data)}")

ðŸ“– Loading dataset...
âœ… Loaded 12,950,920 characters from data/human_chat.txt
ðŸ§© Using GPT-2 BPE tokenizer...


ðŸ§© Vocab size (BPE): 50257
ðŸ“Š Tokens: train=3238109, val=359790


In [18]:
# ----------------------- Model Definition -----------------------
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, n_embd=N_EMBD, n_head=N_HEAD,
                 n_layer=N_LAYER, block_size=BLOCK_SIZE, dropout=DROPOUT):
        super().__init__()

        # Embeddings
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=4 * n_embd,
                dropout=dropout,
                batch_first=True
            )
            for _ in range(n_layer)
        ])

        # Final normalization + dropout + output head
        self.ln_f = nn.LayerNorm(n_embd)
        self.dropout = nn.Dropout(0.1)         # âœ… add this line
        self.head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_emb(idx)
        pos_emb = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb

        # pass through each transformer block
        for block in self.blocks:
            x = block(x)

        # final layer norm + dropout
        x = self.ln_f(x)
        x = self.dropout(x)                    # âœ… add this line

        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss


In [19]:
# ----------------------- Training -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"ðŸš€ Using device: {device}")
model = MiniGPT(vocab_size).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

def get_batch(split):
    data_split = train_data if split == "train" else val_data
    ix = torch.randint(len(data_split) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data_split[i:i+BLOCK_SIZE] for i in ix])
    y = torch.stack([data_split[i+1:i+BLOCK_SIZE+1] for i in ix])
    return x.to(device), y.to(device)

print("ðŸ§  Starting training...")
for step in tqdm(range(STEPS)):
    xb, yb = get_batch("train")
    _, loss = model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 500 == 0 or step == STEPS - 1:
        print(f"Step {step:05d} | Loss: {loss.item():.4f}")


ðŸš€ Using device: cuda
ðŸ§  Starting training...


  0%|          | 3/20000 [00:00<1:06:02,  5.05it/s]

Step 00000 | Loss: 11.0007


  3%|â–Ž         | 503/20000 [00:52<31:47, 10.22it/s]

Step 00500 | Loss: 5.4459


  5%|â–Œ         | 1003/20000 [01:45<31:07, 10.17it/s]

Step 01000 | Loss: 4.6508


  8%|â–Š         | 1503/20000 [02:37<30:14, 10.20it/s]

Step 01500 | Loss: 4.4977


 10%|â–ˆ         | 2003/20000 [03:29<29:25, 10.20it/s]

Step 02000 | Loss: 4.4498


 13%|â–ˆâ–Ž        | 2503/20000 [04:22<28:39, 10.18it/s]

Step 02500 | Loss: 4.3582


 15%|â–ˆâ–Œ        | 3003/20000 [05:14<27:59, 10.12it/s]

Step 03000 | Loss: 4.1279


 18%|â–ˆâ–Š        | 3503/20000 [06:07<27:14, 10.09it/s]

Step 03500 | Loss: 4.3573


 20%|â–ˆâ–ˆ        | 4003/20000 [07:00<27:13,  9.79it/s]

Step 04000 | Loss: 3.9943


 23%|â–ˆâ–ˆâ–Ž       | 4503/20000 [07:53<25:48, 10.01it/s]

Step 04500 | Loss: 3.9672


 25%|â–ˆâ–ˆâ–Œ       | 5003/20000 [08:45<24:47, 10.08it/s]

Step 05000 | Loss: 3.6397


 28%|â–ˆâ–ˆâ–Š       | 5503/20000 [09:38<24:14,  9.97it/s]

Step 05500 | Loss: 3.1963


 30%|â–ˆâ–ˆâ–ˆ       | 6003/20000 [10:31<23:08, 10.08it/s]

Step 06000 | Loss: 2.1225


 33%|â–ˆâ–ˆâ–ˆâ–Ž      | 6503/20000 [11:24<22:07, 10.17it/s]

Step 06500 | Loss: 1.7585


 35%|â–ˆâ–ˆâ–ˆâ–Œ      | 7003/20000 [12:16<21:20, 10.15it/s]

Step 07000 | Loss: 1.1834


 38%|â–ˆâ–ˆâ–ˆâ–Š      | 7503/20000 [13:09<20:42, 10.05it/s]

Step 07500 | Loss: 0.9839


 40%|â–ˆâ–ˆâ–ˆâ–ˆ      | 8003/20000 [14:02<19:29, 10.26it/s]

Step 08000 | Loss: 0.7547


 43%|â–ˆâ–ˆâ–ˆâ–ˆâ–Ž     | 8503/20000 [14:54<18:56, 10.12it/s]

Step 08500 | Loss: 0.5548


 45%|â–ˆâ–ˆâ–ˆâ–ˆâ–Œ     | 9003/20000 [15:47<18:04, 10.14it/s]

Step 09000 | Loss: 0.4668


 48%|â–ˆâ–ˆâ–ˆâ–ˆâ–Š     | 9503/20000 [16:40<17:13, 10.15it/s]

Step 09500 | Loss: 0.4241


 50%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆ     | 10003/20000 [17:33<16:28, 10.11it/s]

Step 10000 | Loss: 0.2782


 53%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Ž    | 10503/20000 [18:26<16:04,  9.85it/s]

Step 10500 | Loss: 0.2173


 55%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ    | 11003/20000 [19:19<14:49, 10.12it/s]

Step 11000 | Loss: 0.2138


 58%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Š    | 11503/20000 [20:12<14:08, 10.01it/s]

Step 11500 | Loss: 0.1477


 60%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ    | 12003/20000 [21:05<13:08, 10.15it/s]

Step 12000 | Loss: 0.1324


 63%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Ž   | 12503/20000 [21:58<12:20, 10.13it/s]

Step 12500 | Loss: 0.1274


 65%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ   | 13003/20000 [22:50<11:32, 10.11it/s]

Step 13000 | Loss: 0.0960


 68%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Š   | 13503/20000 [23:42<10:29, 10.32it/s]

Step 13500 | Loss: 0.0833


 70%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ   | 14003/20000 [24:34<09:41, 10.32it/s]

Step 14000 | Loss: 0.0798


 73%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Ž  | 14503/20000 [25:25<08:53, 10.30it/s]

Step 14500 | Loss: 0.0628


 75%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ  | 15003/20000 [26:17<08:05, 10.30it/s]

Step 15000 | Loss: 0.0628


 78%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Š  | 15503/20000 [27:09<07:17, 10.29it/s]

Step 15500 | Loss: 0.0556


 80%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ  | 16003/20000 [28:01<06:27, 10.33it/s]

Step 16000 | Loss: 0.0571


 83%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Ž | 16503/20000 [28:53<05:44, 10.15it/s]

Step 16500 | Loss: 0.0416


 85%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ | 17003/20000 [29:46<04:55, 10.14it/s]

Step 17000 | Loss: 0.0451


 88%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Š | 17503/20000 [30:38<04:06, 10.14it/s]

Step 17500 | Loss: 0.0407


 90%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ | 18003/20000 [31:30<03:16, 10.15it/s]

Step 18000 | Loss: 0.0311


 93%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Ž| 18503/20000 [32:23<02:27, 10.18it/s]

Step 18500 | Loss: 0.0282


 95%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Œ| 19003/20000 [33:15<01:38, 10.13it/s]

Step 19000 | Loss: 0.0315


 98%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Š| 19503/20000 [34:08<00:49, 10.14it/s]

Step 19500 | Loss: 0.0355


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 20000/20000 [35:00<00:00,  9.52it/s]

Step 19999 | Loss: 0.0380





In [22]:
# ==========================
# ðŸ’¬ TEXT GENERATION
# ==========================
def generate(model, start="User: Hello!\nBot:", max_new_tokens=150, temperature=0.9):
    model.eval()
    idx = torch.tensor([encode(start)], dtype=torch.long).to(device)
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, next_id), dim=1)
    return decode(idx[0].tolist())

In [23]:

# ==========================
# ðŸ¤– TEST CHAT GENERATION
# ==========================
prompt = """
User1: You never reply to my messages.
User2: I was busy with work.
Please help me communicate better.
Bot:
"""

print("\n================== GENERATED RESPONSE ==================\n")
print(generate(model, start=prompt, temperature=0.7, max_new_tokens=150))
print("\n========================================================")




User1: You never reply to my messages.
User2: I was busy with work.
Please help me communicate better.
Bot:

 never crashed help me. I to catch my car. Somh po. Generally crashed me true
 SC. Ifh po If slamming sometime
I gave po If me  po If po If If po po. If Always help If my plea shown. peaceful!! better spill to my work busy wanted work my preferences me1est. po true! If If embarrassed po trust catch trust tastes never busy! Cheap po rocks If po say po true If If true embarrassed po po repeat If there properties freeze coaching. If proceed po owes me out. po If trust true po If embarrassed callates tastes true po If rent gre true.. If If helping flooded coaching. If true! If proceed rocked shower po po. Senior cardio prank hack true DOES slamming

