In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from transformers import AutoTokenizer
import os
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint_dir = "checkpoint_epoch_3"  # change to your folder name
# SECURITY NOTE: torch.load unpickles the file; on PyTorch versions where
# weights_only defaults to False this can execute arbitrary code. Only load
# checkpoints you trust, or pass weights_only=True if the checkpoint contains
# nothing but tensors and plain containers.
checkpoint = torch.load(os.path.join(checkpoint_dir, "model.pt"), map_location=device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

# Hyperparameters used to rebuild the network. load_state_dict (strict by
# default) will fail if they do not match the shapes stored in the checkpoint.
VOCAB_SIZE = len(tokenizer)
BLOCK_SIZE = 256          # maximum sequence length (size of the position table)
D_MODEL = 512
N_HEADS = 16
N_LAYERS = 16
DFF = 4 * D_MODEL         # feed-forward width, derived rather than hard-coded
class CausalLM(nn.Module):
    """Multi-head causal self-attention layer.

    Despite its name this is a single attention sublayer, not a full language
    model. Each position attends only to itself and to earlier positions.

    Args:
        n_head: number of attention heads.
        d_model: width of the input/output features; must be divisible by
            ``n_head``.
        dropout: dropout probability applied to the attention weights and to
            the output projection.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """
    def __init__(self,n_head,d_model,dropout=0.1):
        super().__init__()
        # Fail at construction time instead of with an opaque view() error
        # during the first forward pass.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model=d_model
        self.n_head=n_head
        self.dk=d_model//n_head  # per-head feature width
        # One fused projection producing queries, keys and values.
        self.qkv=nn.Linear(d_model,3*d_model,bias=False)
        self.out_proj=nn.Linear(d_model,d_model,bias=False)
        self.proj_dropout=nn.Dropout(dropout)
        self.attn_dropout=nn.Dropout(dropout)

    def build_causal_mask(self,T,device):
        """Return a (1, 1, T, T) boolean mask that is True where attention is allowed."""
        mask = torch.tril(torch.ones((T, T), dtype=torch.bool, device=device))
        # Leading singleton dims broadcast over batch and heads.
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self,x):
        """Apply causal self-attention to ``x`` of shape (B, T, d_model).

        Returns a tensor of the same shape.
        """
        B,T,D=x.shape
        q,k,v=self.qkv(x).chunk(3,-1)
        # (B, T, D) -> (B, n_head, T, dk)
        q=q.view(B,T,self.n_head,self.dk).transpose(1,2)
        k=k.view(B,T,self.n_head,self.dk).transpose(1,2)
        v=v.view(B,T,self.n_head,self.dk).transpose(1,2)
        # Scaled dot-product scores, shape (B, n_head, T, T).
        scores=torch.matmul(q,k.transpose(-2,-1))/(self.dk**0.5)
        # Keep the mask local: it is rebuilt for every call, so storing it on
        # the module only added mutable state to forward().
        causal=self.build_causal_mask(T,x.device)
        # Disallowed (future) positions get -inf so softmax assigns them 0.
        scores = scores.masked_fill(~causal, float("-inf"))
        attn_weights=f.softmax(scores,dim=-1)
        attn_weights=self.attn_dropout(attn_weights)
        context=torch.matmul(attn_weights,v)
        # (B, n_head, T, dk) -> (B, T, D): merge the heads back together.
        context=context.transpose(1,2).contiguous().view(B,T,D)
        out=self.out_proj(context)
        out=self.proj_dropout(out)
        return out
        
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Linear -> Dropout.

    Args:
        ff_dim: width of the hidden layer.
        d_model: width of the input and output features.
        dropout: dropout probability applied after the second linear layer.
    """
    def __init__(self,ff_dim,d_model,dropout=0.1):
        super().__init__()
        # Layers are created in the same order as they are applied, and keep
        # their positional names (net.0 ... net.3) inside the Sequential.
        expand = nn.Linear(d_model, ff_dim)
        activation = nn.GELU()
        contract = nn.Linear(ff_dim, d_model)
        drop = nn.Dropout(dropout)
        self.net = nn.Sequential(expand, activation, contract, drop)

    def forward(self,x):
        """Run ``x`` through the MLP; the last dimension must be ``d_model``."""
        transformed = self.net(x)
        return transformed
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Args:
        d_model: feature width of the block's input and output.
        n_heads: number of attention heads.
        dff: hidden width of the feed-forward sublayer.
        dropout: dropout probability used by both sublayers.
    """
    def __init__(self,d_model,n_heads,dff,dropout=0.1):
        super().__init__()
        self.ln1=nn.LayerNorm(d_model)
        # Forward the dropout rate to attention as well; previously it was
        # only given to the feed-forward, so a non-default value was silently
        # ignored by the attention sublayer.
        self.attn=CausalLM(n_heads,d_model,dropout)
        self.ln2=nn.LayerNorm(d_model)
        self.ff=FeedForward(dff,d_model,dropout)

    def forward(self,x):
        """Apply both sublayers to ``x``; LayerNorm runs before each sublayer."""
        x=x+self.attn(self.ln1(x))
        x=x+self.ff(self.ln2(x))
        return x
class GPT(nn.Module):
    """Decoder-only language model with learned token and position embeddings.

    Args:
        vocab_size: number of entries in the token embedding and output head.
        block_size: maximum sequence length (size of the position table).
        d_model: feature width of the model.
        n_head: number of attention heads per block.
        n_layers: accepted for interface compatibility; see the note in
            ``__init__`` — it does not currently control the depth.
    """
    def __init__(self,vocab_size,block_size,d_model,n_head,n_layers):
        super().__init__()
        self.token_emb=nn.Embedding(vocab_size,d_model)
        self.pos_emb=nn.Embedding(block_size,d_model)
        # NOTE(review): n_layers is not used — exactly one TransformerBlock is
        # built. The model repr saved in this notebook also shows a single
        # block, so the checkpoint was presumably trained this way; stacking
        # n_layers blocks here would change the state_dict keys and break
        # loading. Confirm against the training code before changing it.
        self.blocks=nn.ModuleList([
            TransformerBlock(d_model,n_head,dff=4*d_model)
        ])
        self.lnf=nn.LayerNorm(d_model)
        self.head=nn.Linear(d_model,vocab_size,bias=False)
        self.block_size=block_size
        self.vocab_size=vocab_size

    def forward(self,idx,targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: integer token ids of shape (B, T) with T <= block_size.
            targets: optional token ids of shape (B, T) used for cross-entropy.

        Returns:
            ``(logits, loss)`` where logits has shape (B, T, vocab_size) and
            loss is the cross-entropy tensor, or the integer 0 when
            ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``block_size``.
        """
        B,T=idx.shape
        # Positions beyond the table would index pos_emb out of range (an
        # opaque IndexError on CPU, a device-side assert on CUDA).
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        token_emb=self.token_emb(idx)
        pos=torch.arange(T,device=idx.device)
        pos_emb=self.pos_emb(pos)
        # (B, T, D) + (T, D): the position embedding broadcasts over the batch.
        x=token_emb+pos_emb
        for block in self.blocks:
            x = block(x)
        x = self.lnf(x)  # final norm
        logits = self.head(x)
        # Identity comparison: `!= None` on a tensor relies on Tensor.__ne__.
        if targets is not None:
            loss = f.cross_entropy(logits.view(-1, self.vocab_size), targets.view(-1))
            return logits, loss
        return logits, 0
# Rebuild the network from the configured hyperparameters, restore the trained
# weights from the checkpoint, move it to the chosen device and switch to
# inference mode (disables dropout).
model = GPT(
    vocab_size=VOCAB_SIZE,
    block_size=BLOCK_SIZE,
    d_model=D_MODEL,
    n_head=N_HEADS,
    n_layers=N_LAYERS,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
def generate_text(prompt, max_new_tokens=50):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Tokens are drawn one at a time from the softmax distribution over the
    model's logits for the last position (plain multinomial sampling).

    Args:
        prompt: text to continue.
        max_new_tokens: number of tokens to sample; values <= 0 sample none.

    Returns:
        The decoded prompt plus the sampled continuation, with special tokens
        stripped.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens, since there is then
            no position to predict from.
    """
    model.eval()
    with torch.no_grad():
        # Encode prompt -> token ids of shape (1, T).
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
        if input_ids.numel() == 0:
            raise ValueError("prompt must encode to at least one token")

        for _ in range(max_new_tokens):
            # Feed only the most recent tokens: the position table holds
            # block_size entries, so longer inputs cannot be embedded.
            context = input_ids[:, -model.block_size:]
            # GPT.forward returns a (logits, loss) tuple, not an object with
            # a .logits attribute.
            logits, _ = model(context)
            logits = logits[:, -1, :]  # logits for the last position
            # The functional module is imported as lowercase `f` in this file.
            probs = f.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)

        # Decode the full sequence (prompt + continuation) back to text.
        return tokenizer.decode(input_ids[0], skip_special_tokens=True)
from fastapi import FastAPI
from pydantic import BaseModel
# FastAPI application object; the route below is registered on it.
app = FastAPI(title="MiniGPT API", version="1.0")
# Request body schema for POST /generate.
class GenerateRequest(BaseModel):
    # Text the model should continue.
    prompt: str
    # Number of tokens to sample; defaults to 50 when omitted from the request.
    max_new_tokens: int = 50
@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate a continuation for the submitted prompt.

    Returns a JSON object with the original ``prompt`` and the ``generated``
    text (prompt plus sampled continuation).
    """
    # Declared with plain `def` on purpose: generation is blocking,
    # compute-bound work. FastAPI runs sync endpoints in its threadpool,
    # whereas an `async def` would stall the event loop for the whole call.
    output = generate_text(req.prompt, req.max_new_tokens)
    return {"prompt": req.prompt, "generated": output}

GPT(
  (token_emb): Embedding(50257, 512)
  (pos_emb): Embedding(256, 512)
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attn): CausalLM(
        (qkv): Linear(in_features=512, out_features=1536, bias=False)
        (out_proj): Linear(in_features=512, out_features=512, bias=False)
        (proj_dropout): Dropout(p=0.1, inplace=False)
        (attn_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=512, out_features=2048, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=2048, out_features=512, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (lnf): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=512, out_features=50257, bias=False)
)

In [2]:
!python -m pip install --upgrade pip


Defaulting to user installation because normal site-packages is not writeable
Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
   ---------------------------------------- 0.0/1.8 MB ? eta -:--:--
   ----------------------------------- ---- 1.6/1.8 MB 12.7 MB/s eta 0:00:01
   ---------------------------------------- 1.8/1.8 MB 11.5 MB/s eta 0:00:00
Installing collected packages: pip
Successfully installed pip-25.3




In [3]:
!python -m pip install fastapi uvicorn pydantic


from fastapi import FastAPI
from pydantic import BaseModel



Defaulting to user installation because normal site-packages is not writeable
Collecting fastapi
  Downloading fastapi-0.121.0-py3-none-any.whl.metadata (28 kB)
Collecting uvicorn
  Downloading uvicorn-0.38.0-py3-none-any.whl.metadata (6.8 kB)
Collecting starlette<0.50.0,>=0.40.0 (from fastapi)
  Downloading starlette-0.49.3-py3-none-any.whl.metadata (6.4 kB)
Collecting annotated-doc>=0.0.2 (from fastapi)
  Downloading annotated_doc-0.0.3-py3-none-any.whl.metadata (6.6 kB)
Downloading fastapi-0.121.0-py3-none-any.whl (109 kB)
Downloading starlette-0.49.3-py3-none-any.whl (74 kB)
Downloading uvicorn-0.38.0-py3-none-any.whl (68 kB)
Downloading annotated_doc-0.0.3-py3-none-any.whl (5.5 kB)
Installing collected packages: annotated-doc, uvicorn, starlette, fastapi

   ---------- ----------------------------- 1/4 [uvicorn]
   ---------- ----------------------------- 1/4 [uvicorn]
   ---------- ----------------------------- 1/4 [uvicorn]
   ---------- ----------------------------- 1/4 [uvicor