In [None]:
import json

file_path = "../output/fine_tuning/data/fine_tuning.json"
# JSON text is UTF-8 by specification. Without an explicit encoding, open()
# uses the platform's locale encoding, which mis-decodes non-ASCII messages
# (e.g. on Windows).
with open(file_path, "r", encoding="utf-8") as file:
    data = json.load(file)

In [None]:
import sys
# Make modules in the parent directory (e.g. `minbpe`, `transformer`) importable.
sys.path.extend([".."])

In [None]:
from minbpe import RegexTokenizer

# Restore the BPE tokenizer that was trained and saved in an earlier stage.
tokenizer_model_path = "../output/tokenizer/my_tokenizer.model"
tokenizer = RegexTokenizer()
tokenizer.load(model_file=tokenizer_model_path)


def get_vocab_size(tokenizer: RegexTokenizer) -> int:
    """Return the vocabulary size used for the model's token embedding.

    Computed as the number of entries in ``tokenizer.vocab`` plus the number of
    entries in ``tokenizer.special_tokens``.

    NOTE(review): if ``tokenizer.vocab`` already contains the special tokens,
    they are counted twice here. Left unchanged on purpose: the size must match
    the one used when the base checkpoint was pre-trained.
    """
    regular_count = len(tokenizer.vocab)
    special_count = len(tokenizer.special_tokens)
    return regular_count + special_count

In [None]:
# One token-id list per message; allowed_special="all" lets markers such as
# <|startoftext|> map to their reserved special-token ids.
tokenized_data = [tokenizer.encode(item, allowed_special="all") for item in data]

len(tokenized_data[0])

In [None]:
# Fraction of messages that go to the training set (before boundary adjustment).
train_fraction = 0.95
initial_split_index = int(train_fraction * len(data))

# Move the split point backwards until the last training message starts with
# '<|startoftext|>Assistant'. Assuming messages alternate You/Assistant, this
# keeps each (You, Assistant) pair on one side of the split.
split_index = initial_split_index
while split_index > 0 and not data[split_index - 1].startswith('<|startoftext|>Assistant'):
    split_index -= 1

# Fail with a clear message instead of an IndexError in the prints below.
assert 0 < split_index < len(data), (
    f"split produced an empty train or validation set (split_index={split_index}, "
    f"n_messages={len(data)})"
)

train_data = data[:split_index]
val_data = data[split_index:]

print("Training set: ")
print(f"Start message: {train_data[0].split('<|separator|>')[0]}")
print(f"End message: {train_data[-1].split('<|separator|>')[0]}")

print("\nValidation set: ")
print(f"Start message: {val_data[0].split('<|separator|>')[0]}")
print(f"End message: {val_data[-1].split('<|separator|>')[0]}")

In [None]:
# Re-point train/val at the token-id versions of the same message split.
train_data, val_data = tokenized_data[:split_index], tokenized_data[split_index:]

In [None]:
block_size = 256

def combine_turns(data: list[list[int]], should_trim_long_sequences: bool) -> list[list[int]]:
    """Concatenate consecutive message pairs into single training sequences.

    Items ``2k`` and ``2k + 1`` of ``data`` (presumably a You turn followed by
    its Assistant reply) are joined into one token list. A trailing unpaired
    item is ignored, and pairs where either message is empty are skipped.

    Args:
        data: token-id lists, one per message.
        should_trim_long_sequences: when True, a combined sequence longer than
            the global ``block_size`` keeps only its last ``block_size`` tokens.
            Keeping the tail preserves the reply.

    Returns:
        List of combined token-id lists.
    """
    combined_turns_data = []
    for i in range(0, len(data) - 1, 2):
        you_message = data[i]
        assistant_message = data[i + 1]
        if not you_message or not assistant_message:
            continue

        final_message = you_message + assistant_message
        if should_trim_long_sequences and len(final_message) > block_size:
            final_message = final_message[-block_size:]

        combined_turns_data.append(final_message)

    return combined_turns_data

# Build paired sequences for both splits, trimming anything over block_size.
combined_train_data, combined_val_data = (
    combine_turns(data=split_tokens, should_trim_long_sequences=True)
    for split_tokens in (train_data, val_data)
)

In [None]:
# Compare message counts before and after pairing turns into sequences.
length_report = (
    ("Train data", train_data, combined_train_data),
    ("\nValidation data", val_data, combined_val_data),
)
for header, before, after in length_report:
    print(header)
    print(f"Length before: {len(before)}")
    print(f"Length after: {len(after)}")

In [None]:
import torch
torch.manual_seed(3647)

# -100 is the default ignore_index of torch.nn.functional.cross_entropy;
# presumably the model's loss relies on that to skip padded targets — confirm.
# NOTE(review): the same value also pads the *inputs* x, which an embedding
# lookup cannot index — verify GPTLanguageModel handles it.
padding_token = -100

def apply_padding_to_data(data: list[list[int]], block_size: int, padding_token: int) -> torch.Tensor:
    """Right-pad every token sequence to ``block_size`` and stack them.

    Args:
        data: token-id sequences, each expected to be at most ``block_size`` long.
        block_size: length of every output row.
        padding_token: value appended to fill each row.

    Returns:
        A ``(len(data), block_size)`` tensor of dtype ``torch.long``.

    A sequence longer than ``block_size`` gets a negative pad amount here, so
    callers are expected to trim beforehand.
    """
    rows = []
    for sequence in data:
        # Explicit dtype: torch.tensor([]) would otherwise be float32 and turn
        # the whole stacked batch into floats.
        row = torch.tensor(sequence, dtype=torch.long)
        rows.append(
            torch.nn.functional.pad(
                input=row,
                pad=(0, block_size - len(row)),
                value=padding_token,
            )
        )
    return torch.stack(rows)



# Fixed-size (n_sequences, block_size) tensors for each split.
train_data_tensor, val_data_tensor = (
    apply_padding_to_data(data=turns, block_size=block_size, padding_token=padding_token)
    for turns in (combined_train_data, combined_val_data)
)

train_data_tensor.shape, val_data_tensor.shape

In [None]:
from typing import Tuple
from torch.utils.data import Dataset, DataLoader

class FineTuningDataset(Dataset):
    """Rows of token ids served as (input, next-token target) pairs.

    For row ``r``, ``x`` is ``r`` itself and ``y`` is ``r`` shifted left by one
    position with ``padding_token`` appended, so ``y[t]`` is the token after
    ``x[t]``. Both tensors are moved to ``device`` when an item is fetched.
    """

    def __init__(self, data: torch.Tensor, device: torch.device, padding_token: int):
        self.data, self.device, self.padding_token = data, device, padding_token

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        row = self.data[index]
        inputs = row.to(self.device)
        # Single-element tail that keeps the target the same length as the input.
        fill = torch.tensor([self.padding_token], device=self.device)
        targets = torch.cat((row[1:].to(self.device), fill))
        return inputs, targets
    

batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training batches are shuffled each epoch; validation order stays fixed.
train_dataset = FineTuningDataset(data=train_data_tensor, device=device, padding_token=padding_token)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = FineTuningDataset(data=val_data_tensor, device=device, padding_token=padding_token)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Sanity-check the shapes of one training batch.
sample_batch = next(iter(train_loader))
x, y = sample_batch
(x.shape, y.shape)

In [None]:
from transformer.model import GPTLanguageModel

# Architecture hyper-parameters. load_state_dict in the next cell requires these
# to match the configuration the base checkpoint was pre-trained with.
# NOTE(review): n_embd=512 is not a multiple of n_head=12; whether that is valid
# depends on how GPTLanguageModel derives the head size — confirm.
block_size = 256
n_embd = 512
n_head = 12
n_layer = 4
dropout = 0.2
batch_size = 64
vocab_size = get_vocab_size(tokenizer)

base_model = GPTLanguageModel(
    vocab_size=vocab_size, block_size=block_size, n_embd=n_embd, n_head=n_head,
    n_layer=n_layer, dropout=dropout, device=device,
).to(device)
model = torch.compile(base_model)

n_params_millions = sum(param.numel() for param in model.parameters()) / 1e6
print(n_params_millions, 'M parameters')

In [None]:
# Load pre-trained weights; weights_only=True avoids unpickling arbitrary objects.
checkpoint_path = "../output/pre_training/base_model_checkpoint.pth"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# The LoRA helpers come from the project's local `transformer` package (the same
# package that provides `transformer.model.GPTLanguageModel`). The third-party
# `transformers` library has no `lora` module.
from transformer.lora import get_lora_model, print_trainable_parameters

# Wrap the pre-trained model with low-rank adapters (rank 4, scaling alpha 8).
lora_model = get_lora_model(
    model=model,
    lora_config={
        "rank": 4,
        "alpha": 8,
    },
    device=device
)


In [None]:
# Free-running sample from the base model before any fine-tuning.
prompt_ids = tokenizer.encode("Salam labas ", allowed_special="all")
input_tokens = torch.tensor(prompt_ids, dtype=torch.long)[None, :].to(device)

model.eval()
with torch.no_grad():
    output = model.generate(input_tokens, max_new_tokens=100)

generated_ids = output[0].tolist()
print(tokenizer.decode(generated_ids))

In [None]:
# Same kind of sample, but through the LoRA-wrapped model.
prompt_ids = tokenizer.encode("hello", allowed_special="all")
input_tokens = torch.tensor(prompt_ids, dtype=torch.long)[None, :].to(device)

lora_model.eval()
with torch.no_grad():
    output = lora_model.generate(input_tokens=input_tokens, max_new_tokens=100)

generated_ids = output[0].tolist()
print(tokenizer.decode(generated_ids))

In [None]:
from typing import Dict

@torch.no_grad()
def estimate_loss(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
) -> Dict[str, float]:
    """Average the model's loss over every batch of the train and val loaders.

    The model is switched to eval mode for the measurement and then restored to
    whatever mode it was in before the call, even if a forward pass raises.

    Args:
        model: module called as ``model(x, y)`` that returns a
            ``(logits, loss)`` pair, where ``loss`` is a scalar tensor.
        train_loader: iterable of ``(x, y)`` training batches.
        val_loader: iterable of ``(x, y)`` validation batches.

    Returns:
        ``{"train": mean_train_loss, "val": mean_val_loss}``. A split whose
        loader yields no batches maps to ``float("nan")``.
    """
    was_training = model.training
    model.eval()
    output: Dict[str, float] = {}
    try:
        for split, loader in (("train", train_loader), ("val", val_loader)):
            losses = []
            for x, y in loader:
                _, loss = model(x, y)
                losses.append(loss.item())
            output[split] = sum(losses) / len(losses) if losses else float("nan")
    finally:
        model.train(was_training)
    return output

In [None]:
def save_checkpoint(
        model: GPTLanguageModel,
        optimizer: torch.optim.Optimizer,
        epoch: int,
        loss: float,
        file_path: str = "path"
) -> None:
    """Serialize model and optimizer state, plus epoch and loss, to ``file_path``.

    The saved dict has keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss'.
    """
    payload = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(payload, file_path)

In [None]:
max_iters = 10  # number of full passes (epochs) over train_loader
eval_interval = 10  # evaluate every `eval_interval` batches, and on the last batch
learning_rate = 1e-4

optimizer = torch.optim.AdamW(lora_model.parameters(), lr=learning_rate)
train_losses = []
val_losses = []

# TODO: replace this placeholder with a real destination for the fine-tuned checkpoint.
fine_tuning_checkpoint_path = "path"

for iteration in range(max_iters):
    lora_model.train()
    for batch_idx, (x_batch, y_batch) in enumerate(train_loader):
        is_eval_step = batch_idx % eval_interval == 0 or batch_idx == len(train_loader) - 1
        if is_eval_step:
            losses = estimate_loss(
                model=lora_model,
                train_loader=train_loader,
                val_loader=val_loader
            )

            train_losses.append(losses['train'])
            val_losses.append(losses['val'])

            print(
                f"iteration {iteration} / step {batch_idx}: "
                f"train loss {losses['train']:.4f}, "
                f"val loss {losses['val']:.4f}"
            )
            # Re-enable training mode (dropout etc.) for the update below,
            # regardless of the mode the evaluation helper left the model in.
            lora_model.train()

        # Optimization step: deliberately outside the evaluation branch so that
        # every batch contributes an update.
        logits, loss = lora_model(x_batch, y_batch)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    # One checkpoint per epoch rather than rewriting the file after every batch.
    save_checkpoint(
        model=lora_model,
        optimizer=optimizer,
        epoch=iteration,
        loss=loss.item(),
        file_path=fine_tuning_checkpoint_path
    )

In [None]:
import matplotlib.pyplot as plt

# Train vs. validation loss at each evaluation step of the training loop.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label="Train Loss")
ax.plot(val_losses, label="Validation Loss")
ax.set_xlabel("Evaluation Step")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss Over Time")
ax.legend()
ax.grid()
plt.show()

In [None]:
def get_input_tokens(message: str) -> torch.Tensor:
    """Encode a user message as a chat prompt and return a (1, seq_len) id tensor.

    Messages in the training data start with ``<|startoftext|>{role}<|separator|>``
    (the split cell checks for ``'<|startoftext|>Assistant'``), and generation
    stops at ``<|endoftext|>``. The prompt therefore wraps the message as a
    complete user turn, then opens an Assistant turn for the model to complete.
    NOTE(review): the user role label "You" is assumed — confirm against the data.
    """
    prompt = (
        f"<|startoftext|>You<|separator|>{message}<|endoftext|>"
        "<|startoftext|>Assistant<|separator|>"
    )
    input_tokens = tokenizer.encode(prompt, allowed_special="all")
    input_tokens = torch.tensor(
        input_tokens, dtype=torch.long).unsqueeze(0).to(device)
    return input_tokens

# Safety cap so the loop terminates even if <|endoftext|> is never produced.
max_answer_tokens = 200
end_of_text_token = tokenizer.special_tokens["<|endoftext|>"]

user_message = "hello"
input_tokens = get_input_tokens(message=user_message)
answer_tokens = []

lora_model.eval()
with torch.no_grad():
    for _ in range(max_answer_tokens):
        output_tokens = lora_model.generate(
            input_tokens=input_tokens, max_new_tokens=1)
        last_generated_token = output_tokens[0, -1].item()
        if last_generated_token == end_of_text_token:
            break

        answer_tokens.append(last_generated_token)
        input_tokens = torch.cat((input_tokens, output_tokens[:, -1:]), dim=1)

        # Keep only the most recent block_size tokens of context, preserving
        # the 2-D (batch, seq) shape.
        if input_tokens.size(1) > block_size:
            input_tokens = input_tokens[:, -block_size:]

# Decode once at the end: decoding token-by-token can split a multi-byte UTF-8
# character across two tokens and produce replacement characters.
model_answer = tokenizer.decode(answer_tokens)

print(f"You: {user_message}")
print(f"Assistant: {model_answer}")