In [None]:
pip install --upgrade transformers huggingface_hub --q

In [None]:
import json
import os

from accelerate import Accelerator
from peft import LoraConfig, TaskType, get_peft_model
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification

In [None]:
import torch

In [None]:
import torch

In [None]:
# NOTE(review): `model_name` is not assigned in any visible cell -- it must be
# defined (e.g. in a config cell) before this cell can run on a fresh kernel.
# Load the causal-LM checkpoint in float16 with SDPA attention and automatic
# device placement, then the tokenizer that belongs to the same checkpoint.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="sdpa",
    dtype=torch.float16,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Padding setup for batched tokenization.
# Reuse EOS as the pad token, and only when the tokenizer has none of its own.
# The previous version assigned EOS and then immediately replaced it by
# registering a new '[PAD]' token; a newly added token gets an id beyond the
# original vocabulary, and the model's embeddings were never resized to match.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Pad on the left so that real tokens sit at the end of every sequence.
tokenizer.padding_side = 'left'

In [None]:
from peft import PeftModel

In [None]:
import torch.optim as optim

In [None]:
# Read the raw math dataset (a JSON document) from the Kaggle input directory.
with open("/kaggle/input/mathdataset/math_dataset.json", mode='r') as dataset_file:
    raw_json = dataset_file.read()
    dataset = json.loads(raw_json)

In [None]:
from datasets import Dataset
# Keep only the 'question' and 'answer' fields of every training record, then
# wrap the resulting list of dicts in a `datasets.Dataset`.
train_data = [
    {'question': record['question'], 'answer': record['answer']}
    for record in dataset['train']
]

train_dataset = Dataset.from_list(train_data)

In [None]:
def collate_fn(batch):
    """Tokenize a batch of question/answer records for causal-LM training.

    Each record's ``question`` and ``answer`` strings are concatenated as-is
    (no separator is inserted), tokenized with the module-level ``tokenizer``,
    truncated to 128 tokens and padded to the longest sequence in the batch.

    Args:
        batch: Iterable of mappings with ``'question'`` and ``'answer'`` keys.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
        ``labels`` is a copy of ``input_ids`` in which every position whose
        attention mask is 0 (i.e. padding) is set to -100, the index ignored
        by the cross-entropy loss, so padding does not contribute to the loss.
    """
    texts = [record['question'] + record['answer'] for record in batch]

    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=128,
        padding=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    # Mask by attention_mask rather than by pad id: when the pad token is the
    # same as EOS, a genuine EOS inside the text must still be learned.
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }

In [None]:
# Shuffled batches of 16 records; `collate_fn` turns each batch into tensors.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Accelerate wrapper configured for fp16 mixed precision and gradient
# accumulation over 4 micro-batches.
accelerator = Accelerator(
    mixed_precision='fp16',
    gradient_accumulation_steps=4,
)

In [None]:
# Attach the previously saved LoRA adapter to the base model for further
# training. The keyword is `is_trainable`; the earlier `trainable=True` is not
# a `from_pretrained` parameter, so the adapter was loaded with its default
# (non-trainable) setting.
model = PeftModel.from_pretrained(
    base_model,
    '/kaggle/input/loraadapters/pytorch/default/1',
    is_trainable=True,
)

In [None]:
# Make sure every parameter whose name contains 'lora' receives gradients.
for param_name, param in model.named_parameters():
    if 'lora' not in param_name:
        continue
    param.requires_grad = True

In [None]:
# AdamW over the model's parameters with a learning rate of 1e-3.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=1e-3,
)

In [None]:
# Hand model, optimizer and dataloader to Accelerate; the prepared dataloader
# is bound to the new name `dataloader`, while `model` and `optimizer` are
# rebound in place.
model, optimizer, dataloader = accelerator.prepare(
    model,
    optimizer,
    train_dataloader,
)

In [None]:
from transformers import PreTrainedModel

In [None]:
def sft_trainer(model: PeftModel|PreTrainedModel,
                dataloader: DataLoader,
                optimizer: optim.Optimizer,
                num_epochs: int=10):
    """Run a supervised fine-tuning loop and return the unwrapped model.

    Uses the module-level ``accelerator`` for gradient accumulation, backward,
    gradient clipping (max norm 1.0) and checkpointing. After each epoch the
    average training loss is printed; whenever it improves on the best value
    seen so far, the accelerator state is saved to
    ``/kaggle/working/checkpoints/best_model`` (failures to save are reported
    and do not abort training).

    Args:
        model: Model to train; called with ``input_ids``, ``attention_mask``
            and ``labels`` and expected to return an object with ``.loss``.
        dataloader: Yields batches (dicts) with those three keys.
        optimizer: Optimizer stepped once per batch inside
            ``accelerator.accumulate``.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        ``accelerator.unwrap_model(model)`` in its state after the last epoch
        (not reloaded from the best checkpoint).
    """
    model.train()
    best_loss = float('inf')
    for epoch in range(num_epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f'Epoch {epoch+1}')
        for batch in progress_bar:
            with accelerator.accumulate(model):
                output = model(input_ids=batch['input_ids'],
                               attention_mask=batch['attention_mask'],
                               labels=batch['labels'])
                loss = output.loss
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                # Clear gradients only AFTER the step. Zeroing before backward
                # would, on the micro-batch where gradients sync, throw away
                # everything accumulated from the preceding micro-batches.
                optimizer.zero_grad()
                loss_val = loss.detach().item()
                total_loss += loss_val
                progress_bar.set_postfix({'loss': loss_val})
                # Drop references now so the previous forward outputs are not
                # kept alive while the next batch's forward pass runs.
                del output, loss
        avg_loss = total_loss/len(dataloader)
        print(f"[+] Epoch: {epoch+1} completed, Avg Loss: {avg_loss:.4f}")
        try:
            if avg_loss < best_loss:
                best_loss = avg_loss
                accelerator.save_state("/kaggle/working/checkpoints/best_model")
                print(f"[*] New best model saved! Loss: {best_loss:.4f}")
        except Exception as e:
            # Checkpointing is best-effort: report and keep training.
            print(f"[!] Failed to save: {e}")
    return accelerator.unwrap_model(model)

In [None]:
# Fine-tune for 50 epochs using the Accelerate-prepared objects.
sft_model = sft_trainer(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    num_epochs=50,
)

In [None]:
# Persist the model returned by the trainer to the Kaggle working directory.
sft_model.save_pretrained(
    '/kaggle/working/best_model',
)

In [None]:
# Quick sanity check: generate up to 2 new tokens for a sample question.
# Switch to eval mode so train-only layers (e.g. dropout) are disabled.
sft_model.eval()
# The tokenizer returns CPU tensors; move them to wherever the model lives.
inputt = tokenizer('What is 1+2? ', return_tensors='pt').to(next(sft_model.parameters()).device)
out = sft_model.generate(**inputt, max_new_tokens=2)
# `skip_special_tokens` (plural) is the real keyword; the misspelled
# `skip_special_token` was not the option `decode` looks for, so special
# tokens were left in the printed text.
print(tokenizer.decode(out[0], skip_special_tokens=True))