import os
import json
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
from trl import SFTTrainer, SFTConfig
import torch

# Load the tokenizer and model
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Keep the weights in float32 even when enabling fp16/bf16 in SFTConfig:
# mixed-precision training casts activations on the fly but needs fp32 master
# weights. Loading in float16 and then setting fp16=True makes the gradient
# scaler fail with "Attempting to unscale FP16 gradients".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
)
# Prefer the GPU, but fall back to CPU so the script still runs (slowly)
# instead of crashing on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Reuse EOS as the padding token so batches can be padded.
tokenizer.pad_token = tokenizer.eos_token

# Print the first 50 parameter names so the layer naming scheme can be
# checked by eye before choosing which blocks to freeze.
print("Inspecting model layers to identify transformer blocks...")
sample_param_names = []
for idx, (param_name, _) in enumerate(model.named_parameters()):
    if idx == 50:
        break
    sample_param_names.append(param_name)
    print(param_name)

# TinyLlama follows the LLaMA layout, so its decoder blocks are named
# "model.layers.<index>.<...>".
layer_prefix = "model.layers"

def _layer_index(param_name):
    """Return the decoder-block index encoded in *param_name*, or None.

    Looks for ``"<layer_prefix>.<digits>"`` anywhere in the name (so a wrapped
    model with an extra leading prefix still matches). Returns None when the
    marker is absent or the following segment is not a plain integer.
    """
    marker = f"{layer_prefix}."
    if marker not in param_name:
        return None
    index_str = param_name.split(marker, 1)[1].split(".", 1)[0]
    return int(index_str) if index_str.isdigit() else None


# Identify layer indices
layer_nums = sorted({
    idx
    for idx in (_layer_index(pname) for pname, _ in model.named_parameters())
    if idx is not None
})

total_layers = len(layer_nums)
# Freeze the lowest 80% of decoder blocks (rounded down); the rest stay trainable.
layers_to_freeze = int(0.8 * total_layers)

print(f"\nTotal transformer layers: {total_layers}")
print(f"Freezing the bottom {layers_to_freeze} layers...")

# Freeze parameters accordingly. Embeddings, the lower decoder blocks and every
# other non-lm_head parameter (e.g. the final norm) are frozen; the upper
# decoder blocks and lm_head are trained.
# NOTE(review): named_parameters() yields a shared tensor only once, so if the
# checkpoint tied lm_head to embed_tokens it would appear as embed_tokens and
# stay frozen — presumably TinyLlama does not tie them; confirm.
for name, param in model.named_parameters():
    layer_num = _layer_index(name)
    if "embed_tokens" in name or "embed_positions" in name:
        # embed_positions covers models with learned position embeddings.
        param.requires_grad = False
    elif layer_num is not None:
        trainable = layer_num >= layers_to_freeze
        param.requires_grad = trainable
        if trainable:
            print(f"Keeping trainable: {name}")
    elif "lm_head" in name:
        param.requires_grad = True
        print(f"Keeping trainable: {name}")
    else:
        param.requires_grad = False

# Tally trainable vs. total parameter counts in a single pass over the model.
trainable_params = 0
total_params = 0
for tensor in model.parameters():
    n_elems = tensor.numel()
    total_params += n_elems
    if tensor.requires_grad:
        trainable_params += n_elems
print(f"\nTrainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
pct_trainable = 100 * trainable_params / total_params
print(f"Percentage trainable: {pct_trainable:.2f}%")

# Load and preprocess the dataset
def load_dataset(jsonl_file, system_prompt="Summarize the following legal text.",
                 max_input_chars=10000):
    """Read judgement/summary pairs from a JSONL file into a one-column Dataset.

    Each non-blank line must be a JSON object with string fields
    ``"judgement"`` and ``"summary"``. Every record is rendered into an
    instruction / input / response prompt and stored in the ``"text"`` column.

    Args:
        jsonl_file: Path to a UTF-8 encoded JSONL file.
        system_prompt: Text placed after ``### Instruction:``.
        max_input_chars: The stripped judgement is cut to this many characters.

    Returns:
        A ``datasets.Dataset`` with a single ``"text"`` column.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``"judgement"`` or ``"summary"``.
    """
    texts = []
    with open(jsonl_file, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. an extra trailing newline) instead of
            # letting json.loads raise on them.
            if not line.strip():
                continue
            item = json.loads(line)
            # NOTE(review): this cut is in characters, but training truncates at
            # max_seq_length tokens. A ~10k-character judgement may already fill
            # that budget, pushing the summary (the training target) past the
            # cutoff. Consider a smaller limit or token-based truncation.
            judgement = item["judgement"].strip()[:max_input_chars]
            summary = item["summary"].strip()
            text = (
                f"### Instruction: {system_prompt}\n"
                "\n"
                "### Input:\n"
                f"{judgement}\n"
                "\n"
                "### Response:\n"
                f"{summary}\n"
            ).strip()
            texts.append(text)

    return Dataset.from_dict({"text": texts})

# Build the training split from the local JSONL file of judgement/summary pairs.
train_file = "full_summaries.jsonl"
train_dataset = load_dataset(jsonl_file=train_file)

# Set up training parameters
train_params = SFTConfig(
    output_dir="../results_partial_model",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4 per device
    optim="adamw_torch",
    save_steps=50,
    logging_steps=50,
    learning_rate=1e-4,
    weight_decay=0.001,
    fp16=False,  # Set True for mixed precision (keep model weights in float32)
    bf16=False,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    group_by_length=True,
    # The plain "constant" scheduler ignores warmup_ratio entirely;
    # "constant_with_warmup" applies the 3% linear warmup above and then
    # holds the learning rate fixed.
    lr_scheduler_type="constant_with_warmup",
    report_to="tensorboard",
    dataset_text_field="text",
    # NOTE(review): newer TRL releases renamed this option to `max_length`;
    # confirm against the installed version.
    max_seq_length=2048,
    ddp_find_unused_parameters=False,
)

# Wire the model, tokenizer, dataset and config into TRL's supervised
# fine-tuning trainer.
trainer = SFTTrainer(
    model=model,
    args=train_params,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)

# Train the model
print("Starting training (~20% of layers + lm_head)...")
# perf_counter() is a monotonic clock, so the measured duration is not skewed
# by wall-clock adjustments (NTP sync, manual changes) during a long run.
start_time = time.perf_counter()
trainer.train()
training_time = time.perf_counter() - start_time

print(f"Training completed in {training_time:.2f} seconds")

# Save the fine-tuned weights, the tokenizer and a small JSON summary of the
# run into one directory, named once so the paths cannot drift apart.
model_output_dir = "../partial_model_output"

print("Saving the model...")
model.save_pretrained(model_output_dir)
tokenizer.save_pretrained(model_output_dir)
print(f"Model saved at '{model_output_dir}'")

# Save training info (wall time, parameter counts) next to the model files.
info_path = os.path.join(model_output_dir, "training_info.json")
with open(info_path, "w", encoding="utf-8") as f:
    json.dump({
        "training_time_seconds": training_time,
        "trainable_params": trainable_params,
        "total_params": total_params,
        "percentage_trainable": 100 * trainable_params / total_params,
    }, f, indent=2)