In [None]:
!pip install bitsandbytes -q

In [None]:
%%writefile train.py

import os
import platform
import time
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
from torch.distributed import init_process_group, destroy_process_group
from torch.amp import GradScaler

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score

from transformers import AutoModel, AutoTokenizer, AutoConfig, get_cosine_schedule_with_warmup
from peft import get_peft_model, LoraConfig, TaskType
from transformers import BitsAndBytesConfig

from peft import (
    get_peft_config, 
    get_peft_model, 
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)

# Set before the tokenizer is created below.
# NOTE(review): presumably silences the Hugging Face tokenizers
# fork/parallelism warning in the spawned workers - confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Constants
# Directory given to AutoTokenizer/AutoConfig/AutoModel.from_pretrained.
model_path = '/kaggle/input/qwen-3/transformers/14b/1'
# Number of StratifiedKFold splits; one full training run is launched per fold.
num_folds = 3
# Epochs per fold.
num_epochs = 3
# Batch size of each DataLoader, i.e. per process rather than global.
batch_size = 3

# Module-level tokenizer shared by the training and validation loops.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Pad on the left so index -1 of every sequence is a real token; Net.forward
# classifies from last_hidden_state[:, -1, :].
tokenizer.padding_side = 'left'

class MathDataset(Dataset):
    """In-memory dataset that pairs each prompt with its target.

    Indexing yields a ``(prompt, target)`` tuple; the reported length is the
    number of targets.
    """

    def __init__(self, prompts, targets):
        self.prompts, self.targets = prompts, targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        prompt = self.prompts[idx]
        target = self.targets[idx]
        return prompt, target

class Net(nn.Module):
    """4-bit quantized backbone with LoRA adapters and an 8-way linear head.

    Args:
        model_path: Location passed to ``AutoConfig.from_pretrained`` and
            ``AutoModel.from_pretrained``.
        rank: Passed as ``device_map`` when loading the backbone.
    """

    def __init__(self, model_path, rank):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_path)

        # NF4 4-bit weights with double quantization; compute runs in fp16.
        quantization = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

        # LoRA on every linear layer; bias terms are left untouched.
        lora = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION,
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            target_modules='all-linear',
            bias='none',
            inference_mode=False,
        )

        base_model = AutoModel.from_pretrained(
            model_path,
            device_map=rank,
            quantization_config=quantization,
            torch_dtype=torch.float16,
            use_cache=False,
        )

        # Gradient checkpointing and prepare_model_for_kbit_training are not
        # applied to the backbone here.
        self.backbone = get_peft_model(base_model, lora)

        # Bias-free classifier over the backbone's hidden size (8 classes).
        self.head = nn.Linear(self.config.hidden_size, 8, bias=False)

    def forward(self, x):
        """Return class logits computed from the last sequence position.

        Args:
            x: Mapping unpacked as keyword arguments into the backbone.

        Returns:
            Output of ``self.head`` applied to ``last_hidden_state[:, -1, :]``.
        """
        hidden_states = self.backbone(**x).last_hidden_state
        last_position = hidden_states[:, -1, :]
        return self.head(last_position)

def ddp_setup(rank, world_size):
    """Join the default process group and select this process's CUDA device.

    The rendezvous address is fixed to ``localhost:12355``. On Windows the
    ``gloo`` backend is used and ``USE_LIBUV`` is set to ``'0'``; on every
    other platform the backend is ``nccl``.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes in the group.
    """
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})

    on_windows = platform.system() == 'Windows'
    if on_windows:
        os.environ['USE_LIBUV'] = '0'
    backend = 'gloo' if on_windows else 'nccl'

    init_process_group(backend=backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def get_optimizer(model, learning_rate=0.0001, diff_lr=0.00001, weight_decay=0.01):
    """Build an AdamW optimizer with four parameter groups.

    Parameters are bucketed by two name-based tests:

    * "differential": the name contains ``'backbone'`` -> uses ``diff_lr``,
      otherwise ``learning_rate``.
    * "no decay": the name contains ``'bias'`` or ``'LayerNorm.weight'`` ->
      weight decay 0, otherwise ``weight_decay``.

    The groups are always emitted in this order (a group may be empty):
    ``[other/decay, other/no-decay, backbone/decay, backbone/no-decay]``.

    Frozen parameters (``requires_grad=False``) are left out: they never
    receive gradients, so listing them only makes the optimizer and the
    gradient scaler iterate over tensors they cannot update.

    Args:
        model: Module whose ``named_parameters()`` are optimized.
        learning_rate: Learning rate for non-backbone parameters; also the
            optimizer-level default.
        diff_lr: Learning rate for backbone parameters.
        weight_decay: Weight decay for parameters outside the no-decay set.

    Returns:
        A ``torch.optim.AdamW`` instance with four parameter groups.
    """
    no_decay = ('bias', 'LayerNorm.weight')
    differential_layers = ('backbone',)

    # Bucket index = 2 * is_differential + is_no_decay, which reproduces the
    # group order documented above in a single pass over the parameters.
    buckets = [[], [], [], []]
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        is_differential = any(layer in name for layer in differential_layers)
        is_no_decay = any(nd in name for nd in no_decay)
        buckets[2 * is_differential + is_no_decay].append(param)

    group_settings = (
        (learning_rate, weight_decay),
        (learning_rate, 0),
        (diff_lr, weight_decay),
        (diff_lr, 0),
    )
    param_groups = [
        {"params": params, "lr": lr, "weight_decay": decay}
        for params, (lr, decay) in zip(buckets, group_settings)
    ]

    return torch.optim.AdamW(
        param_groups,
        lr=learning_rate,
        weight_decay=weight_decay,
    )

def _encode_batch(batch_prompts, device, max_length=300):
    """Tokenize a batch of prompts and move the encodings to ``device``.

    Sequences are truncated to ``max_length`` tokens and then padded to the
    longest sequence in the batch. When any prompt exceeds ``max_length`` the
    longest sequence after truncation is exactly ``max_length``, so a single
    tokenizer call covers both the short-batch and long-batch cases.
    """
    return tokenizer(
        batch_prompts,
        return_tensors='pt',
        padding='longest',
        truncation=True,
        max_length=max_length,
    ).to(device)


def train_model(rank, world_size, num_epochs, fold, train_index, val_index, all_prompts, all_targets):
    """Train and validate one cross-validation fold inside one DDP process.

    Each process joins the process group, builds its own ``Net`` on device
    ``rank``, trains on its ``DistributedSampler`` shard, and then evaluates
    the *entire* validation split (the validation loader is not sharded).
    After every epoch rank 0 saves the LoRA backbone and the head when the
    micro-F1 improves on the best value seen so far.

    Args:
        rank: Index of this process; also its CUDA device index.
        world_size: Number of processes in the group.
        num_epochs: Number of passes over the training split.
        fold: Zero-based fold index, used in log lines and checkpoint names.
        train_index: Indices into ``all_prompts``/``all_targets`` for training.
        val_index: Indices into ``all_prompts``/``all_targets`` for validation.
        all_prompts: Sequence of prompt strings.
        all_targets: Sequence of integer class labels.
    """
    ddp_setup(rank, world_size)
    # Always leave the process group, even when training raises.
    try:
        train_prompts = [all_prompts[i] for i in train_index]
        val_prompts = [all_prompts[i] for i in val_index]
        train_targets = [all_targets[i] for i in train_index]
        val_targets = [all_targets[i] for i in val_index]

        # Inverse class frequency. np.unique only reports labels that occur,
        # so this assumes every class is present in the training split.
        class_counts = np.unique(train_targets, return_counts=True)[1]
        class_weights = 1 / (class_counts / len(train_targets))
        # Moved to the device once instead of on every training step.
        class_weights = torch.tensor(class_weights, dtype=torch.half).to(rank)

        train_dataset = MathDataset(train_prompts, train_targets)
        val_dataset = MathDataset(val_prompts, val_targets)

        train_sampler = DistributedSampler(train_dataset)
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=True, shuffle=False, drop_last=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

        model = Net(model_path, rank).to(rank)
        model = DDP(model, device_ids=[rank])

        optimizer = get_optimizer(model, learning_rate=2e-4, diff_lr=2e-4, weight_decay=0.01)

        # Cosine decay without warmup, stepped once per training batch.
        scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,
                                                    num_warmup_steps=0,
                                                    num_training_steps=len(train_loader) * num_epochs)
        scaler = GradScaler()

        best_f1 = 0.0  # Best validation micro-F1 seen so far in this process.
        for epoch in range(num_epochs):
            # Gives the sampler a different shuffle each epoch.
            train_sampler.set_epoch(epoch)
            model.train()

            for batch_prompts, batch_targets in tqdm(train_loader):
                encodings = _encode_batch(batch_prompts, rank)
                batch_targets = batch_targets.long().to(rank)

                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    logits = model(encodings)
                    loss = F.cross_entropy(logits, batch_targets, weight=class_weights)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # Validation
            model.eval()
            all_preds, all_labels = [], []
            with torch.no_grad():
                for batch_prompts, batch_targets in tqdm(val_loader, total=len(val_loader)):
                    encodings = _encode_batch(batch_prompts, rank)

                    with torch.autocast(device_type='cuda', dtype=torch.float16):
                        logits = model(encodings)

                    all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
                    # Store plain ints rather than 0-dim tensors.
                    all_labels.extend(batch_targets.tolist())

            f1 = f1_score(all_labels, all_preds, average='micro')
            print(f'[GPU {rank}] Fold {fold+1} | Epoch {epoch+1}/{num_epochs} | Val F1-micro: {f1:.4f}')

            # Only rank 0 writes checkpoints; the model is still in eval mode.
            if rank == 0 and f1 > best_f1:
                best_f1 = f1
                model.module.backbone.save_pretrained(f'backbone_fold_{fold}_best')
                torch.save(model.module.head.state_dict(), f'head_fold_{fold}_best.pt')
    finally:
        destroy_process_group()

def run_ddp(rank, world_size, num_epochs, splits, fold, all_prompts, all_targets):
    """Per-process entry point: pick the fold's index pair and train on it.

    Args:
        rank: Index of this process.
        world_size: Number of processes taking part.
        num_epochs: Epochs to train for.
        splits: Sequence of ``(train_index, val_index)`` pairs, one per fold.
        fold: Position in ``splits`` to train on.
        all_prompts: Full list of prompts that the indices refer to.
        all_targets: Full list of targets that the indices refer to.
    """
    fold_train_idx, fold_val_idx = splits[fold]
    train_model(
        rank=rank,
        world_size=world_size,
        num_epochs=num_epochs,
        fold=fold,
        train_index=fold_train_idx,
        val_index=fold_val_idx,
        all_prompts=all_prompts,
        all_targets=all_targets,
    )

if __name__ == '__main__':
    # Script entry point: report the environment, build one prompt per
    # problem, create stratified folds, and launch one DDP run per fold with
    # one worker process per visible GPU.
    world_size = torch.cuda.device_count()

    print("PyTorch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("Number of GPUs available:", world_size)

    # Fail loudly when no GPU is visible: mp.spawn would otherwise be asked
    # to start zero workers and no training would take place.
    if world_size < 1:
        raise RuntimeError('No CUDA devices available; train.py needs at least one GPU.')

    # Seeds the RNG of this launcher process.
    # NOTE(review): the workers started by mp.spawn are separate processes;
    # confirm whether they need their own seeding for reproducibility.
    torch.manual_seed(1)

    df = pd.read_csv('/kaggle/input/classification-of-math-problems-by-kasut-academy/train.csv')
    # Column names are overwritten positionally, so the CSV must have exactly
    # two columns.
    df.columns = ['problem', 'target']

    # NOTE(review): the template starts with a literal apostrophe before
    # <|im_start|>. It looks unintentional, but it is kept verbatim because
    # any inference code has to feed the model exactly the same text - confirm.
    prompts = [
        f"""'<|im_start|>user
Your task is to classify each Math problem into one of these eight topics using a machine learning or NLP-based approach.
0: Algebra
1: Geometry and Trigonometry
2: Calculus and Analysis
3: Probability and Statistics
4: Number Theory
5: Combinatorics and Discrete Math
6: Linear Algebra
7: Abstract Algebra and Topology

Your answer should be an integer that assigns the most appropriate topic category to the given Math problem based on its content and required reasoning.

Math Problem: {p.strip()}

Answer: """
        for p in df['problem']
    ]

    targets = df['target'].tolist()

    # Materialize the folds once so every worker receives the same splits.
    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
    splits = list(skf.split(prompts, targets))

    # Folds run one after another; each mp.spawn call is made only after the
    # previous fold's call has returned.
    for fold in range(num_folds):
        mp.spawn(run_ddp, args=(world_size, num_epochs, splits, fold, prompts, targets), nprocs=world_size)

In [None]:
!python train.py