In [None]:
import numpy as np
import re
import tiktoken
from torch.utils.data import Dataset,DataLoader
import pandas as pd
import torch
import torch.nn as nn
import pymupdf as pmf
from tqdm import tqdm
from functools import partial
import sys
from torch.utils.tensorboard import SummaryWriter
import gpt_download3
from gpt_model import GPT

In [None]:
# Path to the source PDF (Gray's Anatomy, 16th ed.), relative to the notebook's working directory.
dataset_path = "Medicina - Grays Anatomy 16th ed.pdf"

In [8]:
# Open the PDF and read its page count and table of contents (outline).
# `doc` is kept open because later cells load individual pages from it.
# TOC entries are indexed below as entry[1] = title, entry[2] = page number.
doc = pmf.open(dataset_path)
num_pages = doc.page_count
toc = doc.get_toc()

In [9]:
# Map each chapter title to a (start_page, end_page) pair and print the table.
# toc[0] is skipped (presumably the book title / front matter — confirm).
page_ranges = {}
last_index = len(toc) - 1
for i, entry in enumerate(toc[1:], start=1):
  chapter, start_page = entry[1], entry[2]
  # A chapter ends two pages before the next chapter starts; the final chapter
  # runs to the end of the document.
  end_page = num_pages if i >= last_index else toc[i + 1][2] - 2
  page_ranges[chapter] = (start_page, end_page)
  print(f"Chapter: {chapter}, Start Page: {start_page}, End Page: {end_page}")

Chapter: I. Embryology, Start Page: 6, End Page: 35
Chapter: II. Osteology, Start Page: 37, End Page: 172
Chapter: III. Syndesmology, Start Page: 174, End Page: 241
Chapter: IV. Myology, Start Page: 243, End Page: 331
Chapter: V. Angiology, Start Page: 333, End Page: 360
Chapter: VI. The Arteries, Start Page: 362, End Page: 423
Chapter: VII. The Veins, Start Page: 425, End Page: 449
Chapter: VIII. The Lymphatic System, Start Page: 451, End Page: 471
Chapter: IX. Neurology, Start Page: 473, End Page: 620
Chapter: X. The Organs of the Senses and the Common Integument, Start Page: 622, End Page: 671
Chapter: XI. Splanchnology, Start Page: 673, End Page: 814
Chapter: XII. Surface Anatomy and Surface Markings, Start Page: 816, End Page: 852


In [12]:
#extracting ninth chapter
# Concatenate the plain text of every page in the Neurology chapter's range.
# NOTE(review): get_toc() reports 1-based page numbers while load_page() takes a
# 0-based index, so this reads TOC pages start+1 .. end, and the chapter's first
# TOC page is skipped. Confirm that is intentional (e.g. a title-only page).
page_range = page_ranges['IX. Neurology']
text = "".join(
    doc.load_page(page_index).get_text('text')
    for page_index in range(page_range[0], page_range[1])
)

In [13]:
def clean_extracted_text(text):
    """Normalize raw PDF-extracted text into a single whitespace-collapsed line.

    Steps, applied in order:
      1. every run of newlines becomes one space;
      2. a word hyphenated across a line break ("nerv- ous") is re-joined;
      3. any run of two or more whitespace characters becomes one space;
    and finally leading/trailing whitespace is stripped.
    """
    substitutions = (
        (r'\n+', ' '),                   # newline runs -> space
        (r'(\w+)-\s+(\w+)', r'\1\2'),   # re-join line-break hyphenation
        (r'\s{2,}', ' '),                # collapse repeated whitespace
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

In [14]:
# Collapse whitespace and re-join line-break hyphenations in the extracted chapter text.
cleaned_text = clean_extracted_text(text)

In [15]:
# Cache the cleaned text so later cells can reload it without re-parsing the PDF.
# Despite the filename, this holds only the chapter extracted above (IX. Neurology).
# NOTE(review): no encoding is given, so the platform's locale default is used;
# consider encoding='utf-8' here and in the cell that reads this file, changed together.
with open('all_chapters.txt','w') as f:
  f.write(cleaned_text)

# **Data Preparation**

In [16]:
class GPTDataset(Dataset):
  """Sliding-window dataset for next-token prediction.

  The text is tokenized once and cut into windows of `max_length` tokens whose
  start positions are `stride` tokens apart. Item `idx` is a pair
  `(input_ids, target_ids)`, where the target is the same window shifted one
  token to the right. Text too short for a full window yields an empty dataset.
  """
  def __init__(self,text,tokenizer,max_length,stride):
    ids = tokenizer.encode(text)
    # Only starts that leave room for the one-token-shifted target window.
    starts = range(0, len(ids) - max_length, stride)
    self.input_ids = [torch.tensor(ids[s:s + max_length]) for s in starts]
    self.output_ids = [torch.tensor(ids[s + 1:s + max_length + 1]) for s in starts]

  def __len__(self):
    return len(self.input_ids)

  def __getitem__(self,idx):
    return self.input_ids[idx], self.output_ids[idx]

In [17]:
# Reload the cleaned chapter text written earlier (only chapter IX, despite the filename).
# NOTE(review): uses the locale-default encoding; keep it consistent with the writing cell.
with open('all_chapters.txt','r') as f:
    text = f.read()

In [18]:
# GPT-2 byte-pair-encoding tokenizer, matching the vocabulary of the pretrained weights loaded later.
tokenizer = tiktoken.get_encoding('gpt2')

In [19]:
# Tokenize the whole chapter once to inspect corpus statistics.
token_ids = tokenizer.encode(text)

In [20]:
# Sanity checks: the largest id must stay below the model's vocabulary size
# (n_vocab = 50257 for the GPT-2 checkpoint loaded later).
max_token_id = max(token_ids)
total_tokens = len(token_ids)

In [21]:
# Display (largest token id, total number of tokens).
max_token_id,total_tokens

(50183, 177978)

In [22]:
# 85 / 15 train / validation split. The split is made on characters of the raw
# text (before tokenization), so the boundary may fall in the middle of a word.
train_ratio = 0.85
train_size = int(train_ratio*len(text))
train_data = text[:train_size]
val_data = text[train_size:]

In [23]:
def create_dataloader(data,batch_size,max_length,stride,shuffle=True,drop_last=True):
  """Build a DataLoader of (input, target) next-token windows over `data`.

  Args:
    data: raw text, tokenized here with the GPT-2 BPE tokenizer.
    batch_size: number of windows per batch.
    max_length: tokens per window.
    stride: distance between consecutive window start positions.
    shuffle: reshuffle the windows every epoch.
    drop_last: drop the final batch if it is smaller than `batch_size`.

  Returns:
    torch.utils.data.DataLoader yielding (inputs, targets) batches.
  """
  tokenizer = tiktoken.get_encoding('gpt2')
  dataset = GPTDataset(data,tokenizer,max_length,stride)
  # Forward drop_last so the caller's choice actually reaches the DataLoader.
  return DataLoader(dataset,batch_size=batch_size,shuffle=shuffle,drop_last=drop_last)

In [24]:
# Windows of 1024 tokens (GPT-2's full context length, n_ctx) with a stride of 256,
# so consecutive windows overlap by 768 tokens. Only training batches are shuffled.
train_dataloader = create_dataloader(train_data,batch_size=2,max_length=1024,stride=256,shuffle=True,drop_last=True)
val_dataloader = create_dataloader(val_data,batch_size=2,max_length=1024,stride=256,shuffle=False,drop_last=False)

In [25]:
# Peek at one training batch to confirm (batch_size, max_length) shapes.
for x,y in train_dataloader:
  print(x.shape,y.shape)
  break

torch.Size([2, 1024]) torch.Size([2, 1024])


In [26]:
# Peek at one validation batch to confirm (batch_size, max_length) shapes.
for x,y in val_dataloader:
  print(x.shape,y.shape)
  break

torch.Size([2, 1024]) torch.Size([2, 1024])


# **Loading OpenAI GPT-2 (124M) pretrained weights**

In [None]:
# Load the GPT-2 124M checkpoint through the project helper `gpt_download3`:
# `settings` holds architecture hyperparameters; `params` holds the weights as nested
# dicts (keys such as 'wte', 'wpe', 'blocks', 'g', 'b' are read below).
settings,params = gpt_download3.download_and_load_gpt2(model_size='124M')

In [28]:
# Architecture hyperparameters of the downloaded checkpoint.
settings

{'n_vocab': 50257, 'n_ctx': 1024, 'n_embd': 768, 'n_head': 12, 'n_layer': 12}

In [48]:
# Unpack the checkpoint's architecture hyperparameters under model-friendly names.
vocab_size, context_length, emb_dim, num_heads, num_layers = (
    settings[key] for key in ('n_vocab', 'n_ctx', 'n_embd', 'n_head', 'n_layer')
)

In [49]:
# Build a freshly initialized GPT with the checkpoint's architecture; the pretrained
# weights are copied into it in a later cell.
model = GPT(vocab_size=vocab_size,emb_dim=emb_dim,num_heads=num_heads,num_layers=num_layers,context_length=context_length,dropout_rate=0.1)

In [31]:
# Inspect the MLP sub-keys of the first transformer block in the checkpoint.
params['blocks'][0]['mlp'].keys()

dict_keys(['c_fc', 'c_proj'])

In [50]:
def assign_checked(target, source):
  """Return `source` as a new nn.Parameter after checking it has `target`'s shape.

  Replacing a module attribute with an nn.Parameter performs no shape validation,
  so without this check a mis-shaped checkpoint entry would only fail later in the
  forward pass (or, where broadcasting applies, not at all). Note that a wrongly
  transposed *square* matrix (e.g. 768x768) has the right shape and is NOT caught.

  Args:
    target: the model tensor/parameter being replaced (used only for its shape).
    source: array-like checkpoint value (presumably a NumPy array).
  Raises:
    ValueError: if the shapes differ.
  """
  source_tensor = torch.tensor(source)
  if tuple(target.shape) != tuple(source_tensor.shape):
    raise ValueError(
        f"Shape mismatch: model expects {tuple(target.shape)}, "
        f"checkpoint provides {tuple(source_tensor.shape)}"
    )
  return nn.Parameter(source_tensor)

# Token and positional embeddings.
model.token_embedding.weight = assign_checked(model.token_embedding.weight, params['wte'])
model.pos_embedding.weight = assign_checked(model.pos_embedding.weight, params['wpe'])

for block, blk in enumerate(params['blocks']):
  trf = model.trf_blocks[block]

  # Pre-attention layer norm.
  trf.ln1.scale = assign_checked(trf.ln1.scale, blk['ln_1']['g'])
  trf.ln1.shift = assign_checked(trf.ln1.shift, blk['ln_1']['b'])

  # The checkpoint fuses Q, K and V into a single c_attn matrix; split it into three
  # equal parts along the last axis. Weight matrices are transposed (.T) to match the
  # model's (out_features, in_features) layout — presumably nn.Linear; confirm.
  q_w, k_w, v_w = np.split(blk['attn']['c_attn']['w'], 3, axis=-1)
  trf.attn.w_q.weight = assign_checked(trf.attn.w_q.weight, q_w.T)
  trf.attn.w_k.weight = assign_checked(trf.attn.w_k.weight, k_w.T)
  trf.attn.w_v.weight = assign_checked(trf.attn.w_v.weight, v_w.T)

  q_b, k_b, v_b = np.split(blk['attn']['c_attn']['b'], 3, axis=-1)
  trf.attn.w_q.bias = assign_checked(trf.attn.w_q.bias, q_b)
  trf.attn.w_k.bias = assign_checked(trf.attn.w_k.bias, k_b)
  trf.attn.w_v.bias = assign_checked(trf.attn.w_v.bias, v_b)

  # Attention output projection.
  trf.attn.out_proj.weight = assign_checked(trf.attn.out_proj.weight, blk['attn']['c_proj']['w'].T)
  trf.attn.out_proj.bias = assign_checked(trf.attn.out_proj.bias, blk['attn']['c_proj']['b'])

  # Pre-MLP layer norm.
  trf.ln2.scale = assign_checked(trf.ln2.scale, blk['ln_2']['g'])
  trf.ln2.shift = assign_checked(trf.ln2.shift, blk['ln_2']['b'])

  # Feed-forward network: layers[0] is the expansion (c_fc), layers[2] the projection (c_proj).
  trf.fcn.layers[0].weight = assign_checked(trf.fcn.layers[0].weight, blk['mlp']['c_fc']['w'].T)
  trf.fcn.layers[0].bias = assign_checked(trf.fcn.layers[0].bias, blk['mlp']['c_fc']['b'])

  trf.fcn.layers[2].weight = assign_checked(trf.fcn.layers[2].weight, blk['mlp']['c_proj']['w'].T)
  trf.fcn.layers[2].bias = assign_checked(trf.fcn.layers[2].bias, blk['mlp']['c_proj']['b'])

# Final layer norm.
model.final_norm.scale = assign_checked(model.final_norm.scale, params['g'])
model.final_norm.shift = assign_checked(model.final_norm.shift, params['b'])
# The output head is initialized from the token-embedding matrix. It is a separate
# copy (its own nn.Parameter), so the two are not tied during fine-tuning.
model.output.weight = assign_checked(model.output.weight, params['wte'])


# **Training**

In [36]:
# Training configuration.
torch.manual_seed(123)  # reproducible dropout masks and DataLoader shuffling across runs
num_epochs = 60  # NOTE(review): many passes over a single chapter; watch val loss for overfitting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = torch.nn.CrossEntropyLoss()
model = model.to(device)
# AdamW applies decoupled weight decay. With plain Adam, weight_decay is added to the
# gradient as an L2 term and then rescaled by Adam's adaptive denominator, so 0.1 does
# not behave as true weight decay and can pull the pretrained weights toward zero.
optimizer = torch.optim.AdamW(model.parameters(),lr=0.0001,weight_decay=0.1)

In [37]:
# TensorBoard writer; with no log_dir given it writes under ./runs/ (a timestamped subdirectory).
writer = SummaryWriter()

In [None]:
# Fine-tune for `num_epochs` epochs. The mean per-batch train/val loss of each epoch is
# appended to `train_losses` / `val_losses` and logged to TensorBoard once per epoch.
train_losses = []
val_losses = []
for epoch in range(num_epochs):
  # ---- training pass ----
  model.train()
  total_losses = 0
  for x, y in tqdm(train_dataloader):
    optimizer.zero_grad()
    x = x.to(device)
    y = y.to(device)
    logits = model(x)
    # Flatten (batch, seq, vocab) logits and (batch, seq) targets for token-level cross-entropy.
    loss = criterion(logits.view(-1,logits.shape[-1]),y.view(-1))
    loss.backward()
    optimizer.step()
    total_losses += loss.item()
  # Average over batches; NaN (instead of ZeroDivisionError) if the loader is empty.
  avg_train_loss = total_losses/len(train_dataloader) if len(train_dataloader) else float('nan')
  train_losses.append(avg_train_loss)
  # Log once per epoch so each TensorBoard step holds the complete epoch average.
  writer.add_scalar("Loss/train", avg_train_loss, epoch)
  print(f"Epoch {epoch+1} : Training_losses : {avg_train_loss}")

  # ---- validation pass ----
  model.eval()
  total_val_losses = 0
  with torch.no_grad():
    # Loop variables are x/y rather than `val_data`, so the validation *text*
    # defined earlier in the notebook is not overwritten by a batch tuple.
    for x, y in val_dataloader:
      x = x.to(device)
      y = y.to(device)
      logits = model(x)
      loss = criterion(logits.view(-1,logits.shape[-1]),y.view(-1))
      total_val_losses += loss.item()
  avg_val_loss = total_val_losses/len(val_dataloader) if len(val_dataloader) else float('nan')
  val_losses.append(avg_val_loss)
  writer.add_scalar("Loss/val", avg_val_loss, epoch)
  print(f"Epoch {epoch+1} : Validation losses :{avg_val_loss}")
writer.flush()

In [39]:
# Save only the fine-tuned weights (the state dict), not the pickled model object.
torch.save(model.state_dict(),'LLM_model.pth')

In [None]:
# Restore the fine-tuned weights into `model`. (state_dict() only *returns* weights;
# load_state_dict() is what copies a saved dict into the model.) map_location lets a
# checkpoint saved on GPU be restored on a CPU-only machine and vice versa.
model.load_state_dict(torch.load("LLM_model.pth", map_location=device, weights_only=True))

# **Prediction**

In [51]:
def predict(model,input_text,max_new_tokens,context_size):
    """Greedily extend a batch of token-id sequences with the model's top-1 predictions.

    Args:
        model: callable mapping a (batch, seq) tensor of token ids to logits indexed
            as logits[:, -1, :] for the next-token scores. The caller is responsible
            for putting it in eval mode.
        input_text: (batch, seq) tensor of token ids (not raw text, despite the name).
        max_new_tokens: number of tokens to append.
        context_size: maximum number of most recent tokens fed to the model per step.

    Returns:
        Tensor of shape (batch, seq + max_new_tokens). If `model` is an nn.Module
        with parameters, the result lives on the same device as those parameters.
    """
    # Move the prompt to the model's device; a CPU prompt fed to a CUDA model would
    # otherwise raise a device-mismatch error.
    if isinstance(model, nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            input_text = input_text.to(first_param.device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Keep only the most recent `context_size` tokens.
            model_input = input_text[:, -context_size:]
            logits = model(model_input)
            last_logits = logits[:, -1, :]
            # argmax of the logits equals argmax of their softmax, so softmax is skipped.
            predicted_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            input_text = torch.cat([input_text, predicted_token], dim=1)
    return input_text

In [55]:
# Tokenize the prompt for generation.
input_text = 'the nerves of brain'
input_text_tokenized = tokenizer.encode(input_text)
# Add a batch dimension -> shape (1, num_tokens), and place the prompt on the same
# `device` the model was moved to, so generation also works after training on CUDA.
input_text_encoded = torch.tensor(input_text_tokenized).unsqueeze(0).to(device)

In [58]:
# Greedy generation: extend the prompt by 40 tokens, feeding at most 1024 tokens of context per step.
model.eval()  # evaluation mode (e.g. disables dropout layers) before generating
context_size = 1024
prediction = predict(model,input_text_encoded,40,context_size)
# Drop the batch dimension and decode the token ids back to text.
prediction = prediction.squeeze(0)
decoded_text = tokenizer.decode(prediction.tolist())
# Newlines are replaced with spaces for a compact single-line display.
print(decoded_text.replace("\n", " "))

the nerves of brain cells, and the brain's ability to process information.  The researchers found that the brain's ability to process information is also affected by the amount of information it receives.  "The brain
