In [None]:
import io
import os
import copy
import time
import torch
import requests

from PIL import Image
import torch.distributed as dist
from transformers import AutoProcessor, AutoModel
from datasets import load_dataset, Dataset
from peft import PeftModel, LoraConfig, get_peft_model
from trl import SFTTrainer



# NOTE:
# To reduce GPU memory usage, before running this notebook open the model's files in the Hugging Face cache and change "max_dynamic_tiles" from 12 to 1 in both config.json and preprocessor_config.json.

## Load Dataset

In [None]:
# When True, the dataset cell loads the processed train/eval splits previously
# saved under ./Multiview_processed_datasets instead of rebuilding them from the JSONL file.
existing_processed_datasets = False

# Seed passed to set_seed() and to train_test_split().
seed = 7777
# Fraction of the examples held out as the evaluation split.
test_size = 0.05

# When True, generate_text_from_sample rewrites each text block as "\n" + text + " ".
# NOTE(review): whether the custom trainer/collator also reads this flag is not visible here.
newline_between_blocks = True

# NOTE(review): absolute, machine-specific path; adjust it for your environment.
dataset_path = "/home/compu/test_suchae/eagle2-2b-finetuning/multitask_dataset.jsonl"


In [None]:
import random
import numpy as np

def set_seed(seed_value):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed_value: Integer seed applied to every generator.

    When CUDA is available the CUDA generator of the current device is seeded
    with the same value as well.
    """
    # The three CPU-side seeders take the seed as their only argument.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed_value)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed_value)

set_seed(seed)

In [None]:
def ealge_format_multiview_data(sample):
    """Convert one raw dataset row into a two-turn chat example.

    (The misspelled function name is kept so existing callers keep working.)

    Args:
        sample: Mapping with
            - "prompt_blocks": list of content blocks for the user turn; dict
              blocks whose "type" is "image_url" are relabelled as "image".
            - "ground_truth_answer": the expected answer, used as the
              assistant turn's text.

    Returns:
        {"messages": [user_message, assistant_message]} where the user content
        is the (relabelled) prompt blocks and the assistant content is a
        single text block holding the answer.

    The input row is not modified: relabelled blocks are shallow copies, so the
    same row can safely be formatted more than once.
    """
    prompt_blocks = [
        # {**block, "type": ...} keeps the key order and every other field.
        {**block, "type": "image"}
        if isinstance(block, dict) and block.get("type") == "image_url"
        else block
        for block in sample["prompt_blocks"]
    ]

    answer = sample["ground_truth_answer"]

    return {
        "messages": [
            {
                "role": "user",
                "content": prompt_blocks,
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": answer}],
            },
        ],
    }

In [None]:
# Locations of the cached, already-formatted splits.
processed_train_dir = "./Multiview_processed_datasets/train_dataset"
processed_eval_dir = "./Multiview_processed_datasets/eval_dataset"

# Always load and split the raw JSONL, even when the processed splits are read
# from disk: the evaluation cells index `dataset['train']` for the ground truth
# and task of each sample, so `dataset` must exist in both branches.
# The split is deterministic for a fixed `seed` / `test_size`.
# NOTE(review): the cached splits line up with `dataset` only if they were
# built with the same seed, test_size and JSONL file.
dataset = load_dataset("json", data_files=dataset_path)
dataset = dataset["train"].train_test_split(test_size=test_size, seed=seed)

if existing_processed_datasets:
    train_dataset = Dataset.load_from_disk(processed_train_dir)
    eval_dataset = Dataset.load_from_disk(processed_eval_dir)

else:
    # NOTE(review): `system_message` is not referenced anywhere else in the
    # visible notebook; it is kept in case the custom trainer setup is meant to use it.
    system_message = """You are a Vision Language Model designed to interpret and reason over multiple related chart images (multi-view).
You will be provided with a set of chart images that together represent different perspectives, time points, or facets of the same data context.
Your task is to analyze all provided images collectively and answer the user's query by integrating information across the views.
Focus on delivering concise, accurate answers (typically a word, number, or short phrase) based on the combined visual information.
Do not provide extra explanation unless specifically requested. Assume the user expects you to synthesize insights from all images."""

    # Format the training split and time each stage separately.
    t0 = time.time()
    train_records = [ealge_format_multiview_data(sample) for sample in dataset["train"]]
    t1 = time.time()
    print("time taken (train_dataset to list) : ", t1 - t0)
    train_dataset = Dataset.from_list(train_records)
    t2 = time.time()
    print("train_dataset length: ", len(train_dataset))
    print("time taken (train_dataset to Dataset) : ", t2 - t1)

    # Format the evaluation split.
    eval_dataset = Dataset.from_list(
        [ealge_format_multiview_data(sample) for sample in dataset["test"]]
    )
    t3 = time.time()
    print("eval_dataset length: ", len(eval_dataset))
    print("time taken (eval_dataset to Dataset) : ", t3 - t2)

    # Cache the formatted splits so later sessions can set
    # existing_processed_datasets = True.
    train_dataset.save_to_disk(processed_train_dir)
    eval_dataset.save_to_disk(processed_eval_dir)

In [None]:
# Hugging Face checkpoint used for both the model and its processor.
model_id = "nvidia/Eagle2-2B"

# Load the model in bfloat16 on the GPU when one is available, otherwise on CPU.
# trust_remote_code=True runs the model code shipped with the checkpoint.
model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda" if torch.cuda.is_available() else "cpu"
)

# Use `model_id` here too so the model and processor can never drift apart.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    use_fast=True
)
processor.tokenizer.padding_side = "left"

In [None]:
import gc
import time


def clear_memory(pause_seconds=2):
    """Drop the notebook's large objects and release cached GPU memory.

    Removes the module-level names ``inputs``, ``model``, ``processor``,
    ``trainer``, ``peft_model`` and ``bnb_config`` (those that exist), runs the
    garbage collector, and, when CUDA is available, empties the CUDA caching
    allocator and prints the allocated / reserved GPU memory.

    Args:
        pause_seconds: Seconds to sleep between the clean-up stages. The
            default of 2 reproduces the original fixed pauses; pass 0 to skip
            waiting.

    On a machine without CUDA the GPU steps are skipped instead of raising, in
    line with the CPU fallback used when the model is loaded.
    """
    namespace = globals()
    for name in ("inputs", "model", "processor", "trainer", "peft_model", "bnb_config"):
        # pop() with a default is a no-op for names that were never defined.
        namespace.pop(name, None)
    time.sleep(pause_seconds)

    cuda_available = torch.cuda.is_available()

    # Garbage collection and clearing CUDA memory
    gc.collect()
    time.sleep(pause_seconds)
    if cuda_available:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    time.sleep(pause_seconds)
    gc.collect()
    time.sleep(pause_seconds)

    if cuda_available:
        print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    else:
        print("CUDA is not available; no GPU memory to report.")

In [None]:
# Load the model and processor for training.
# NOTE(review): this repeats the load done in the earlier cell with the same
# `model_id`; only the device_map differs (explicit current CUDA device index).
if torch.cuda.is_available():
    training_device_map = {'': torch.cuda.current_device()}
else:
    training_device_map = "cpu"

model = AutoModel.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=training_device_map,
)

processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    use_fast=True,
)
processor.tokenizer.padding_side = "left"


In [None]:

# Configure LoRA
# r=32 with lora_alpha=8 gives a LoRA scaling factor (alpha / r) of 0.25.
# use_dora=True enables the DoRA variant; the adapter weights use Gaussian init.
# The adapters are attached to the attention and MLP projection layers listed
# in target_modules.
peft_config = LoraConfig(
    r=32,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
    use_dora=True,
    init_lora_weights="gaussian",
)

# Apply PEFT model adaptation
# NOTE(review): per the PEFT documentation, get_peft_model modifies `model` in
# place. The trainer cell later passes this same `model` together with
# `peft_config`, so the adapters may be applied a second time there; confirm
# what Eagle2TRLSFTTrainer does, or use this call only to preview parameter counts.
peft_model = get_peft_model(model, peft_config)

# Print trainable parameters
peft_model.print_trainable_parameters()

In [None]:
from trl import SFTConfig, SFTTrainer
from eagle2_trl_sft_trainer import Eagle2TRLSFTTrainer
from eagle2_data_collator import Eagle2DataCollator
import wandb

# Initialize wandb with dongguk university team
# NOTE(review): the entity is hard-coded; other users must change it or they
# will fail to log to this team's workspace.
wandb.init(
    entity="schaeck-dongguk-university",  # Use dongguk university team
    project="eagle2-2b-finetuning"
)

# Configure training arguments
# With a per-device batch size of 1 and 64 accumulation steps, each optimizer
# step sees 64 samples per device.
training_args = SFTConfig(
    output_dir="eagle2-2b-trl-sft-Multitask",  # Directory to save the model
    num_train_epochs=5,  # Number of training epochs
    per_device_train_batch_size=1,  # Batch size for training
    per_device_eval_batch_size=1,  # Batch size for evaluation
    gradient_accumulation_steps=64,  # Steps to accumulate gradients
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Options for gradient checkpointing
    label_names=["labels"],
    max_length=None,
    # Optimizer and scheduler settings
    optim="adamw_torch_fused",  # Optimizer type
    lr_scheduler_type="cosine",
    learning_rate=2e-4,  # Learning rate for training
    # Logging and evaluation
    logging_steps=5,  # Steps interval for logging
    eval_steps=60,  # Steps interval for evaluation
    eval_strategy="steps",  # Strategy for evaluation
    save_strategy="steps",  # Strategy for saving the model
    save_steps=180,  # Steps interval for saving
    # Mixed precision and gradient settings
    bf16=True,  # Use bfloat16 precision
    # max_grad_norm=0.3,  # Maximum norm for gradient clipping
    warmup_ratio=0.03,  # Ratio of total steps for warmup
    remove_unused_columns=False,  # Whether to remove unused columns
    # Hub and reporting
    push_to_hub=False,  # Whether to push model to Hugging Face Hub
    use_legacy_prediction_loop=True,
    report_to="wandb",  # Use Weights & Biases for logging
)

# Use "<|endoftext|>" as the padding token.
# NOTE(review): 151643 is presumably the id of "<|endoftext|>" in this
# tokenizer; confirm, since it is set independently of the token string.
processor.tokenizer.pad_token = "<|endoftext|>"
processor.tokenizer.pad_token_id = 151643

# Project-local collator (eagle2_data_collator.py); it is given the tokenizer only.
eagle2_data_collator = Eagle2DataCollator(processor.tokenizer)


# Project-local SFT trainer (eagle2_trl_sft_trainer.py).
# NOTE(review): `model` has already been passed through get_peft_model in the
# LoRA cell and `peft_config` is passed again here; verify that the trainer
# does not apply the adapters twice.
trainer = Eagle2TRLSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=eagle2_data_collator
)

Map:  55%|█████▍    | 3199/5838 [03:16<02:42, 16.26 examples/s]
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.s

KeyboardInterrupt: 

socket.send() raised exception.
socket.send() raised exception.


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.s

: 

In [None]:
# When no launcher has provided the distributed environment (RANK is unset),
# describe a one-process "cluster" and start the default process group.
if 'RANK' not in os.environ:
    os.environ.update({
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12355',  # Use any available port
        'RANK': '0',
        'WORLD_SIZE': '1',
    })

    # NCCL needs a GPU; fall back to Gloo on CPU-only machines.
    if torch.cuda.is_available():
        backend = 'nccl'
    else:
        backend = 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')

In [None]:
trainer.train()

In [None]:
trainer.save_model(training_args.output_dir)

## Testing

In [None]:
# Turn off the "\n" + text + " " wrapping of text blocks for the inference
# cells below; generate_text_from_sample reads this module-level flag each
# time it is called, so re-running this cell changes its behaviour.
newline_between_blocks = False

In [None]:
def _load_pil_image(raw_image_data, request_timeout=30):
    """Turn one raw image reference into an RGB PIL image, or return None.

    Supported inputs:
      * a PIL image;
      * a dict with a non-null "bytes" entry holding encoded image bytes;
      * a dict with a non-null "url" entry that is either an http(s) URL or a
        path to an existing local file.

    Download failures, non-200 responses and unreadable / missing local files
    are reported with print() and yield None. Anything else also yields None.
    """
    # Imported lazily so text-only samples do not need Pillow or requests.
    from PIL import Image

    if isinstance(raw_image_data, Image.Image):
        return raw_image_data.convert("RGB")
    if not isinstance(raw_image_data, dict):
        # Unsupported type; the caller skips this content item.
        return None

    if raw_image_data.get("bytes") is not None:
        return Image.open(io.BytesIO(raw_image_data["bytes"])).convert("RGB")

    image_url = raw_image_data.get("url")
    if image_url is None:
        return None

    if image_url.startswith(("http://", "https://")):
        import requests

        try:
            # A timeout keeps one unresponsive host from stalling a whole run.
            response = requests.get(image_url, timeout=request_timeout)
        except requests.exceptions.RequestException as e:
            print(f"Error occurred during web image request: {e}")
            return None
        if response.status_code != 200:
            print(f"Failed to load web image. Status code: {response.status_code}")
            return None
        return Image.open(io.BytesIO(response.content)).convert("RGB")

    if not os.path.exists(image_url):
        print(f"Error: File does not exist at path {image_url}.")
        return None

    try:
        opened = Image.open(image_url)
    except FileNotFoundError:
        print(f"Error: File does not exist at path {image_url}.")
        return None
    except Exception as e:
        print(f"Error: Problem occurred while opening local image: {e}")
        return None
    # convert() decodes the pixels into a new image, so the file handle can be
    # closed as soon as it returns.
    with opened:
        return opened.convert("RGB")


def generate_text_from_sample(model, processor, sample, max_new_tokens=1024, device="cuda", debug=False,
                              newline_between_blocks=None, request_timeout=30):
    """
    Generate the assistant reply for one chat-formatted sample.

    Args:
        model: Model exposing eval(), to(device) and generate(**inputs, ...).
        processor: Processor exposing apply_chat_template, process_vision_info,
            __call__(text=..., images=..., videos=...), batch_decode and a
            ``tokenizer`` with pad_token_id / eos_token_id.
        sample: Dict with a "messages" list; each message has a "role" and a
            "content" list of blocks (text blocks and image blocks). The
            sample is deep-copied and never modified.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device the model and the inputs are moved to.
        debug: When True, print the sample structure, prompt, input shapes and
            the decoded output.
        newline_between_blocks: When True, every text block is rewritten as
            "\n" + text + " ". None (default) falls back to the module-level
            ``newline_between_blocks`` flag, or False if that is not defined.
        request_timeout: Timeout in seconds for downloading http(s) images.

    Returns:
        The first decoded output string.

    Note:
        Generation samples (do_sample=True, temperature=0.7, top_p=0.9), so
        repeated calls on the same sample can return different text.
        NOTE(review): the preprocessing is intended to mirror the custom
        trainer's; that code is not visible here, so confirm they match.
    """
    if newline_between_blocks is None:
        # Backward compatible: earlier versions always read the notebook global.
        newline_between_blocks = globals().get("newline_between_blocks", False)

    # Set model to evaluation mode
    model.eval()

    if debug:
        print("=" * 50)
        print("DEBUG: Sample structure:")
        print(f"Sample keys: {list(sample.keys())}")
        print(f"Messages length: {len(sample.get('messages', []))}")

        if 'messages' in sample and len(sample['messages']) > 0:
            user_content = sample['messages'][0].get('content', [])
            print(f"User content type: {type(user_content)}")
            print(f"User content length: {len(user_content) if isinstance(user_content, list) else 'Not a list'}")

            for i, block in enumerate(user_content):
                print(f"  Block {i}: {type(block)} - {block}")

        if 'messages' in sample and len(sample['messages']) > 1:
            assistant_content = sample['messages'][1].get('content', [])
            print(f"Assistant content: {assistant_content}")
        print("=" * 50)

    # Work on a private copy so the caller's sample is left untouched.
    messages = copy.deepcopy(sample)["messages"]
    has_any_image = False

    for message in messages:
        for content_item in message.get("content") or []:
            # Drop keys whose value is null so that an image block carries only
            # image keys and a text block only text keys.
            for key in [k for k, v in content_item.items() if v is None]:
                del content_item[key]

            is_image_item = content_item.get("type") == "image" and (
                "image" in content_item or "image_url" in content_item
            )
            if is_image_item:
                raw_image_data = (
                    content_item["image"] if "image" in content_item else content_item["image_url"]
                )
                pil_image = _load_pil_image(raw_image_data, request_timeout)
                if pil_image is not None:
                    has_any_image = True
                    # Update the in-memory structure to hold the PIL image
                    content_item["image"] = pil_image
            elif newline_between_blocks and isinstance(content_item.get("text"), str):
                # Only real text blocks are wrapped; blocks without "text"
                # (e.g. an image block that lost its payload) are left alone.
                content_item["text"] = "\n" + content_item["text"] + " "

    if has_any_image:
        # Let the processor build image_inputs (handles multi-view)
        image_inputs, video_inputs = processor.process_vision_info(messages)
        if debug:
            print(f"Image inputs: {len(image_inputs) if image_inputs else 0} images")
            print(f"Video inputs: {len(video_inputs) if video_inputs else 0} videos")
    else:
        image_inputs, video_inputs = None, None
        if debug:
            print("No images found in sample")

    # Build prompt-only messages for generation: drop the trailing assistant
    # turn (the reference answer) so the model has to produce it.
    prompt_messages = copy.deepcopy(messages)
    if len(prompt_messages) > 0 and prompt_messages[-1].get("role") == "assistant":
        prompt_messages = prompt_messages[:-1]

    # Generate text using prompt messages
    text_input = [processor.apply_chat_template(
        prompt_messages, tokenize=False, add_generation_prompt=True
    )]

    if debug:
        print(f"Generated text input: {text_input[0]}")

    # Prepare the inputs for the model
    model_inputs = processor(
        text=text_input,
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",
        padding=True,
    ).to(device)

    if debug:
        print(f"Model inputs keys: {list(model_inputs.keys())}")
        print(f"Input IDs shape: {model_inputs['input_ids'].shape}")
        if 'pixel_values' in model_inputs:
            print(f"Pixel values shape: {model_inputs['pixel_values'].shape}")

    model = model.to(device)

    # Generate text with the model (using torch.no_grad for memory efficiency)
    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id
        )

    # Decode the output text
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    if debug:
        print(f"Generated output: {output_text[0]}")
        print("=" * 50)

    return output_text[0]  # Return the first decoded output text

In [None]:
clear_memory()

In [None]:
# Load base model and processor
base_checkpoint = "nvidia/Eagle2-2B"
model = AutoModel.from_pretrained(
    base_checkpoint,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(
    base_checkpoint,
    trust_remote_code=True,
    use_fast=True,
)

# Same tokenizer settings as the training cells: left padding, "<|endoftext|>" as pad token.
eval_tokenizer = processor.tokenizer
eval_tokenizer.padding_side = "left"
eval_tokenizer.pad_token = "<|endoftext|>"
eval_tokenizer.pad_token_id = 151643

# Load the trained PEFT adapter
# NOTE(review): absolute, machine-specific path. Also, "checkpoint-150" is not a
# multiple of the save_steps=180 configured in the training cell; confirm that this
# checkpoint comes from the run you intend to evaluate.
adapter_path = "/home/compu/test_suchae/eagle2-2b-finetuning/eagle2-2b-trl-sft-Multitask/checkpoint-150"
model = PeftModel.from_pretrained(model, adapter_path)

In [None]:
train_dataset[62]

In [None]:
from collections import defaultdict

# Accumulators shared by all runs.
all_incorrect_counts = defaultdict(int)   # sample index -> number of runs it was answered wrongly
overall_accuracies = []                   # one overall accuracy per run
task_accuracies = defaultdict(list)       # task name -> per-run accuracies
all_incorrect_details = []  # Details of every incorrect prediction, across all runs.

# Generation samples with temperature, so accuracy is averaged over several runs.
num_runs = 10

# Raw (unformatted) training split; it still carries `ground_truth_answer` and `task`.
raw_train_split = dataset['train']
# The formatted and raw splits are paired by index, so they must have the same length.
assert len(raw_train_split) == len(train_dataset), (
    "train_dataset and dataset['train'] differ in length; "
    "the processed split was not built from this raw split"
)

for run_idx in range(num_runs):
    print(f"--- Starting run {run_idx + 1}/{num_runs} ---")

    task_correct = defaultdict(int)
    task_total = defaultdict(int)

    overall_correct = 0
    overall_total = len(train_dataset)

    for sample_idx in range(overall_total):
        processed_sample = train_dataset[sample_idx]
        output = generate_text_from_sample(model, processor, processed_sample, debug=False)

        # Fetch the raw row once instead of decoding it for every field.
        raw_row = raw_train_split[sample_idx]
        gt = raw_row['ground_truth_answer']
        task = raw_row['task'] if 'task' in raw_row else "unknown"

        # Exact match after trimming whitespace; str() guards against
        # non-string answers (e.g. numbers) in the JSONL.
        is_correct = (output.strip() == str(gt).strip())
        if is_correct:
            overall_correct += 1
            task_correct[task] += 1
        else:
            all_incorrect_counts[sample_idx] += 1
            # Record the details of the incorrect prediction.
            incorrect_entry = {
                "run_idx": run_idx + 1,
                "sample_idx": sample_idx,
                "task": task,
                "ground_truth": gt,
                "model_output": output
            }
            all_incorrect_details.append(incorrect_entry)

        task_total[task] += 1

    accuracy = overall_correct / overall_total if overall_total > 0 else 0.0
    overall_accuracies.append(accuracy)

    print(f"Run {run_idx + 1} Overall Accuracy: {accuracy:.4f} ({overall_correct}/{overall_total})")

    for task in sorted(task_total.keys()):
        task_acc = task_correct[task] / task_total[task] if task_total[task] > 0 else 0.0
        task_accuracies[task].append(task_acc)
        print(f"  {task}: {task_acc:.4f} ({task_correct[task]}/{task_total[task]})")

print("\n" + "="*50)
print("             Final Evaluation Results")
print("="*50)

# Average over the runs actually recorded, so num_runs == 0 does not divide by zero.
avg_overall_accuracy = (
    sum(overall_accuracies) / len(overall_accuracies) if overall_accuracies else 0.0
)
print(f"\nAverage Overall Accuracy over {num_runs} runs: {avg_overall_accuracy:.4f}")

print("\nAverage Accuracy per task over all runs:")
for task in sorted(task_accuracies.keys()):
    per_run = task_accuracies[task]
    avg_task_acc = sum(per_run) / len(per_run)
    print(f"  {task}: {avg_task_acc:.4f}")

print("\n" + "-"*50)
print("      Distribution of Incorrect Predictions")
print("-"*50)
if all_incorrect_counts:
    sorted_incorrect_items = sorted(all_incorrect_counts.items(), key=lambda item: item[1], reverse=True)

    # Report the real number of runs rather than a hard-coded 10.
    print(f"Sample Index: Number of times it was incorrectly predicted (out of {num_runs} runs)")
    for sample_idx, count in sorted_incorrect_items:
        print(f"  {sample_idx}: {count}")
else:
    print("No samples were incorrectly predicted across any of the runs.")


In [None]:

print("\n" + "="*50)
print("     Summary of Incorrect Predictions by Index")
print("="*50)

# sample index -> {'count': total misses, 'incorrect_outputs': {wrong answer -> times produced}}
summary = defaultdict(lambda: {'count': 0, 'incorrect_outputs': defaultdict(int)})

# Fold every recorded miss into its sample's bucket.
for entry in all_incorrect_details:
    bucket = summary[entry['sample_idx']]
    bucket['count'] += 1
    bucket['incorrect_outputs'][entry['model_output']] += 1

if summary:
    # Report the samples in ascending index order.
    sorted_indices = sorted(summary)

    for sample_idx in sorted_indices:
        data = summary[sample_idx]

        # Reference answer taken from the raw training split.
        ground_truth = dataset['train'][sample_idx]['ground_truth_answer']

        print(f"Index {sample_idx} ({data['count']}번 틀림):")
        print(f"  - 정답: '{ground_truth}'")

        # Most frequent wrong answer first; ties keep first-seen order.
        sorted_outputs = sorted(data['incorrect_outputs'].items(), key=lambda item: -item[1])
        for output, count in sorted_outputs:
            print(f"  - 오답: '{output}' ({count}번)")

        print("-" * 20)
else:
    print("No incorrect predictions were found across all runs.")