## Load Tokenizer

In [None]:
# Import RegexTokenizer and create the tokenizer instance used by the rest
# of the notebook; its saved model file is loaded in a later cell.
from regex_model_1 import RegexTokenizer
tokenizer = RegexTokenizer()

In [None]:
# Load the saved tokenizer model file from the Kaggle input directory.
tokenizer.load(model_file="/kaggle/input/tokenizer_medical/pytorch/default/1/tokenizer_model.model")

In [None]:
# Smoke test: encode a sample sentence and display the resulting token ids.
# NOTE(review): allowed_special='all' presumably lets special-token strings in
# the text map to their ids — confirm against RegexTokenizer.encode.
tokenizer.encode("I have fever what is a solution", allowed_special='all')

In [None]:
def get_vocab_size(tokenizer):
    """Return the number of entries in ``tokenizer.vocab``."""
    vocabulary = tokenizer.vocab
    return len(vocabulary)

In [None]:
# Display the size of the loaded tokenizer's vocabulary.
get_vocab_size(tokenizer)

In [None]:
# encoded_text_sequence = []
# batch_size = 3_000_000
# with open("/kaggle/input/text-med/text_medical.txt", "r") as f:
    
#     while True:
#         chunk = f.read(batch_size)
#         if not chunk:
#             break 

#         batch_tokens = tokenizer.encode(chunk, allowed_special="all")
#         encoded_text_sequence.extend(batch_tokens)
#         print(f"Processed {len(encoded_text_sequence)} tokens so far")

# print(f"Total Tokens: {len(encoded_text_sequence)}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
from tqdm import tqdm
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader

In [None]:
# np.save("/kaggle/working/encoded_text.npy", 
#        np.array(encoded_text_sequence, dtype=np.int64))

In [None]:
# data = np.load("/kaggle/input/encoded-npy/encoded_text.npy",
#               mmap_mode='r')

# print("Shape of data:", data.shape)

## Load the formatted dataset for training

In [None]:
# Read the JSONL dataset (one JSON object per line) and keep only the
# "text" field of every record.
with open("/kaggle/input/foramatted/formmated_dataset.jsonl", "r",
          encoding="utf-8") as dataset_file:
    conversation = [json.loads(record)["text"] for record in dataset_file]

In [None]:
# Show one sample to check the loaded text by eye.
print(conversation[4])

## Encoding the text using the tokenizer

In [None]:
# Tokenize every sample and flatten the per-sample token lists into one
# continuous token stream, in dataset order.
encoded_text = [
    token_id
    for text in conversation
    for token_id in tokenizer.encode(text, allowed_special="all")
]

In [None]:
# Total number of tokens across all samples.
len(encoded_text)

In [None]:
# Turn the flat token list into a 1-D int64 tensor and display its length.
data = torch.tensor(encoded_text, dtype=torch.long)
len(data)

In [None]:
# Cut the token stream into fixed-length rows of block_size tokens.
block_size = 1024
num_blocks = len(data) // block_size
# Drop the trailing tokens that do not fill a complete block, then reshape
# to (num_blocks, block_size).
data = data[:num_blocks*block_size].view(-1, block_size)
print("Shape of data", data.shape)

## Creating a Dataset Class for DataLoader

In [None]:
class FineTunedDataset(Dataset):
    """Dataset of fixed-length token rows for next-token prediction.

    Indexing yields a pair ``(x, y)``: ``x`` is one row of ``data`` and ``y``
    is the same row shifted left by one position, with ``padding_token``
    appended so that ``y`` has the same length as ``x``. Both tensors are
    moved to ``device`` before being returned.
    """

    def __init__(self, data: torch.Tensor,
                 padding_token: int, device:str
                ):
        self.data = data
        self.padding_token = padding_token
        self.device = device

    def __len__(self):
        # One item per row along the first dimension of the data tensor.
        return self.data.size(0)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        row = self.data[index]
        inputs = row.to(self.device)

        # Targets are the inputs shifted by one; the final position has no
        # successor inside this row, so it is filled with the padding token.
        shifted = row[1:].to(self.device)
        pad = torch.tensor([self.padding_token], device=self.device)
        targets = torch.cat((shifted, pad))

        return inputs, targets

In [None]:
# Split the rows of `data` 95% / 5% into training and validation sets.
# The split is positional: the first 95% of rows train, the rest validate.
train_split = int(0.95*len(data))
train_data_split = data[:train_split]
val_data_split = data[train_split:]

## Creating dataloaders for training and validation 

In [None]:
batch_size = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
padding_token = tokenizer.special_tokens["<|PAD|>"]

# Both splits use the same padding token and target device.
dataset_options = dict(padding_token=padding_token, device=device)

# Training batches are shuffled.
train_dataset = FineTunedDataset(data=train_data_split, **dataset_options)
train_dataloader = DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True
)

# Validation batches keep their original order.
val_dataset = FineTunedDataset(data=val_data_split, **dataset_options)
val_dataloader = DataLoader(
    dataset=val_dataset, batch_size=batch_size, shuffle=False
)

## Loading Custom model

In [None]:
from GPTmodel import GPTLanguageModel

# Model hyperparameters.
block_size= 1024      # must match the row length used to build `data`
n_embedding = 384
n_head = 8
n_layer = 6
dropout = 0.2
vocab_size = get_vocab_size(tokenizer)
# Take the padding id from the tokenizer, exactly as the dataset/dataloader
# cell does, instead of hard-coding 3077. A literal here silently overrides
# the earlier value and can drift from the id the datasets actually append.
padding_token = tokenizer.special_tokens["<|PAD|>"]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GPTLanguageModel(
    vocab_size=vocab_size,
    n_embedding=n_embedding,
    n_head=n_head, block_size=block_size,
    n_layer=n_layer, dropout=dropout, 
    padding_token=padding_token, device=device)


model = model.to(device)
# Bare expression: displays the module structure as the cell output.
model

In [None]:
# Report the total number of model parameters, in millions.
print(sum(p.numel() for p in model.parameters())/1e6,"M parameters")

In [None]:
# def get_batch(split: str, split_index: int,
#              block_size: int, device:str,
#              data):

#     if split == "train":
#         start_index = 0
#         end_index = split_index
        
#     else:
#         start_index = split_index
#         end_index = len(data)

#     available_blocks = (end_index - start_index - 1) // block_size
#     block_indices = torch.randint(0, available_blocks, (batch_size,))

#     x_batch, y_batch = [], []
#     for i in block_indices:
#         block_start = start_index + (i * block_size)
#         x_batch.append(data[block_start:block_start+block_size])
#         y_batch.append(data[block_start+1:block_start+block_size+1])

#     x_batch = np.array(x_batch)
#     y_batch = np.array(y_batch)

#     x_batch = torch.tensor(x_batch, dtype=torch.long).to(device)
#     y_batch = torch.tensor(y_batch, dtype=torch.long).to(device)

#     return x_batch, y_batch

## Estimate loss function
* Estimates the mean loss over the training and validation splits.
  

In [None]:
@torch.inference_mode()
def estimate_loss(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader) -> Dict[str, float]:
    """Compute the mean per-batch loss over the train and validation loaders.

    The model is switched to eval mode for the measurement and put back into
    whatever mode (train or eval) it was in when the function was called.

    Args:
        model: Module whose call ``model(x, y)`` returns a ``(output, loss)``
            pair; only the loss is used.
        train_loader: Loader yielding ``(x, y)`` training batches.
        val_loader: Loader yielding ``(x, y)`` validation batches.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}``. A split whose loader
        yields no batches maps to ``nan``.
    """
    was_training = model.training
    model.eval()

    total_loss = {}
    try:
        for split, loader in (('train', train_loader), ('val', val_loader)):
            # The decorator already disables autograd, so no nested
            # inference_mode context is needed around the forward pass.
            batch_losses = [model(x, y)[1].item() for x, y in loader]
            if batch_losses:
                total_loss[split] = sum(batch_losses) / len(batch_losses)
            else:
                total_loss[split] = float('nan')
    finally:
        # Restore the caller's mode even if a forward pass raised.
        model.train(was_training)

    return total_loss

### Saving checkpoints

In [None]:
def save_checkpoint(model: GPTLanguageModel,
                   optimizer: torch.optim.Optimizer,
                   epoch: int, loss: float,
                   file_path: str,global_step: int,
                   scheduler) -> None:
    """Write a training checkpoint to ``file_path`` with ``torch.save``.

    The saved dictionary holds the state dicts of the model, optimizer and
    scheduler together with the epoch number, the loss value and the global
    step counter passed in by the caller.
    """
    state = dict(
        model_state_dict=model.state_dict(),
        epoch=epoch,
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        loss=loss,
        global_step=global_step,
    )
    torch.save(state, file_path)

In [None]:
# Load a saved checkpoint (path relative to the working directory), mapping
# its tensors onto the current device, and restore the model weights.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
checkpoint = torch.load("model_checkpoint",
                       map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

## Training Loop

In [None]:

from transformers import get_cosine_schedule_with_warmup

# ---- Hyperparameters -------------------------------------------------------
max_iters = 10                      # number of epochs
num_batches = len(train_dataloader)
# Evaluate roughly 15 times per epoch. max(1, ...) keeps the modulo in the
# loop from dividing by zero when the loader has fewer than 15 batches.
eval_interval = max(1, num_batches // 15)
learning_rate = 1e-5
warmup_steps=100
clip_grad_norm = 1.0
global_step = 0                     # counts micro-batches, not optimizer steps
start_epoch = 0
gradient_accumulation_step = 8
weight_decay = 0.01
# One optimizer step per accumulation window; ceiling division because the
# trailing partial window of each epoch is also stepped.
steps_per_epoch = -(-num_batches // gradient_accumulation_step)
total_steps = steps_per_epoch * max_iters


optimizer = torch.optim.AdamW(params=model.parameters(),
                              lr=learning_rate,
                              weight_decay=weight_decay
                              )

# Build the scheduler unconditionally so that a run without a checkpoint does
# not hit a NameError on `scheduler` inside the loop.
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                            num_warmup_steps=warmup_steps,
                                            num_training_steps=total_steps
                                            )

if checkpoint:
    # Resume optimizer/scheduler state and the step/epoch counters.
    # NOTE(review): start_epoch resumes after the checkpoint's epoch; if the
    # checkpoint comes from a different training stage, confirm this is the
    # intended starting point.
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    global_step = checkpoint.get("global_step", 0)
    start_epoch = checkpoint.get("epoch", 0) + 1

train_loss = []      # estimate_loss() results on the train split
val_loss = []        # estimate_loss() results on the validation split
total_lrs = []       # learning rate recorded after every micro-batch
trained_loss = []    # raw loss of every micro-batch

# Start from clean gradients (e.g. after an interrupted run of this cell).
optimizer.zero_grad(set_to_none=True)

for epoch in range(start_epoch, max_iters):
    model.train()
    train_loss_backprop = 0.0
    epoch_lrs = []

    for batch_idx, (x_batch, y_batch) in tqdm(
        iterable=enumerate(train_dataloader),desc='Training on batches',
        total=num_batches):

        is_last_batch = batch_idx == num_batches - 1

        if global_step % eval_interval == 0 or is_last_batch:
            losses = estimate_loss(model=model,
                             train_loader=train_dataloader,
                             val_loader=val_dataloader)

            print(f"Iteration: {epoch}/step {batch_idx} |"
                  f"Train Loss: {losses['train']:.4f} |"
                  f"Validation Loss: {losses['val']:.4f}")

            train_loss.append(losses['train'])
            val_loss.append(losses['val'])

        _, loss = model(x_batch, y_batch)

        # Scale by the number of micro-batches in this accumulation window so
        # the accumulated gradient is their mean rather than their sum. The
        # last window of an epoch may be shorter than the configured size.
        window_start = batch_idx - batch_idx % gradient_accumulation_step
        window_size = min(gradient_accumulation_step,
                          num_batches - window_start)
        (loss / window_size).backward()

        # Step once a full window has been accumulated, and always on the
        # last batch so no gradients leak into the next epoch or get dropped.
        if (batch_idx + 1) % gradient_accumulation_step == 0 or is_last_batch:

            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(),
                                           max_norm=clip_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)

        global_step += 1

        # Report the unscaled loss.
        batch_loss = loss.item()
        train_loss_backprop += batch_loss
        trained_loss.append(batch_loss)

        epoch_lrs.append(scheduler.get_last_lr()[0])

    avg_epoch_loss = train_loss_backprop / max(1, num_batches)
    print(f'\nEpoch {epoch} average train loss:{avg_epoch_loss:.4f}\n')


    save_checkpoint(model=model,optimizer=optimizer,
                  epoch=epoch,loss=avg_epoch_loss,
                  file_path=f"/kaggle/working/pretrained_checkponint_{epoch}.pth",
                   global_step=global_step, scheduler=scheduler)

    total_lrs.extend(epoch_lrs)

# No trailing flush is needed here: the last batch of every epoch already
# triggers a clipped optimizer step.


## Visualizations

In [None]:
import matplotlib.pyplot as plt

# Plot the train/validation loss recorded at each evaluation interval.
plt.figure(figsize=(8,6))
plt.plot(train_loss, label="Train Loss", marker="o")
plt.plot(val_loss, label="Validation Loss", marker="x")
plt.title("Training Loss x Validation Loss")
plt.xlabel("Intervals")
plt.legend()
plt.ylabel("Loss")
plt.grid(True)
# Save before show(): once the figure has been shown it may be released, and
# a later savefig() would then write an empty image.
plt.savefig(f"loss_plot_pretrained_{epoch}.png")
plt.show()

In [None]:
# Plot the learning rate recorded after every training micro-batch.
plt.figure(figsize=(8,6))
plt.plot(total_lrs, label="Learning rate", marker="o")
plt.title("LR Curve")
plt.xlabel("Steps")
plt.legend()
plt.ylabel("lr")
plt.grid(True)
# Save before show(): once the figure has been shown it may be released, and
# a later savefig() would then write an empty image.
plt.savefig(f"lr_plot_pretrained_{epoch}.png")
plt.show()

## Inference Phase

In [None]:
# Single-prompt generation: encode the prompt, add a batch dimension, and
# move the ids to the model's device.
prompt = "A 33-year-old woman is brought to the emergency department 15 minutes after being stabbed in the chest with a screwdriver. Given her vital signs of pulse 110min, respirations 22min, and blood pressure 9065 mm Hg, along with the presence of a 5-cm deep stab wound at the upper border of the 8th rib in the left midaxillary line, which anatomical structure in her chest is most likely to be injured?"
input_ids = tokenizer.encode(prompt, allowed_special='all')
input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)

model.eval()
with torch.inference_mode():
    # NOTE(review): the positional arguments presumably map to
    # max_new_tokens=200, block_size and temperature=0.7 — confirm against
    # the signature of GPTLanguageModel.generate.
    output = model.generate(input_ids, 200,
                           block_size, 0.7,
                           top_k=50, top_p=None)

# Decode the first (only) sequence of the returned batch back to text.
print(tokenizer.decode(output[0].tolist()))

In [None]:
def get_input_tokens(turns: list[dict]) -> torch.Tensor:
    """Format a conversation as a prompt and encode it for generation.

    Args:
        turns: Conversation history; each item is a dict with a ``'role'``
            key (``'user'`` or ``'assistant'``) and a ``'content'`` key.

    Returns:
        A ``torch.long`` tensor of token ids with a leading batch dimension
        of 1, placed on the module-level ``device``. The prompt ends with the
        ``<|Assistant|>`` tag so the model continues with the reply.
    """
    parts = []
    for turn in turns:
        content = turn['content']
        if turn['role'] == 'assistant':
            # NOTE(review): closing an earlier assistant turn with
            # <|endoftext|> mirrors the stop token generate_message() waits
            # for — confirm against the training data format.
            parts.append(f"<|Assistant|>{content}<|endoftext|>")
        else:
            # Any other role is formatted as a user turn.
            parts.append(f"<|startoftext|><|User|>{content}")

    # Open the assistant turn; the tag needs its leading "<" to match the
    # "<|...|>" form of the other tags.
    parts.append("<|Assistant|>")
    formatted_input = "".join(parts)

    input_tokens = tokenizer.encode(formatted_input, allowed_special='all')
    input_tokens = torch.tensor(input_tokens, dtype=torch.long)
    input_tokens = input_tokens.unsqueeze(0).to(device)
    return input_tokens


def generate_message(input_tokens: torch.Tensor,
                     max_new_tokens: int = 1024) -> str:
    """Generate the assistant reply one token at a time.

    Generation stops when the model emits ``<|endoftext|>`` or after
    ``max_new_tokens`` tokens, whichever comes first. Errors raised by the
    model propagate to the caller; retrying the same failing call in an
    endless loop would hang the notebook.

    Args:
        input_tokens: Prompt token ids with a leading batch dimension, as
            returned by ``get_input_tokens``.
        max_new_tokens: Upper bound on the number of generated tokens, so the
            loop ends even if the end token is never produced.

    Returns:
        The decoded reply text, without the end-of-text token.
    """
    end_token = tokenizer.special_tokens['<|endoftext|>']
    generated = []

    model.eval()
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            output_tokens = model.generate(
                input_tokens=input_tokens,max_new_tokens=1,
                block_size=1024, top_k=50, top_p=None,
                temperature=0.7
            )

            next_token = output_tokens[0, -1].item()
            if next_token == end_token:
                break

            # Feed the new token back in as context for the next step.
            input_tokens = torch.cat((input_tokens, output_tokens[:, -1:]), dim=1)
            generated.append(next_token)

    # Decode the whole reply at once: decoding token by token can split a
    # character that spans several tokens (assumes the tokenizer's tokens may
    # hold partial characters — confirm against RegexTokenizer.decode).
    return tokenizer.decode(generated) if generated else ""

In [None]:
# Single-turn chat example: start the history with the user's message.
user_msg = "What is cause of urine loss?"
turns = [{
    "role": 'user',
    "content": user_msg
}]

# Build the prompt from the history and generate the model's reply.
input_tokens = get_input_tokens(turns)
model_answer = generate_message(input_tokens)

# Record the reply in the conversation history.
turns.append({
    "role": 'assistant',
    "content": model_answer
})

In [None]:
# Print the conversation as a transcript; turns with any other role are
# skipped.
for message in turns:
    if message['role'] == 'user':
        print("User:", message['content'] + "\n")
    elif message['role'] == 'assistant':
        print("Assistant:", message['content'])