# Compatibility Prediction

In [1]:
import os
import sys
import logging
from pathlib import Path
from datasets import concatenate_datasets
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainerCallback,
)
from transformers.integrations import MLflowCallback
from types import SimpleNamespace
import transformers
from tqdm import tqdm

os.environ["WANDB_PROJECT"]="compatibility"

print(f"Transformers Version: {transformers.__version__}")
print(f"Torch Version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"{torch.cuda.device_count()} CUDA ({torch.version.cuda}) device available")
else:
    print("No CUDA device available")

  from .autonotebook import tqdm as notebook_tqdm
2025-03-03 07:39:11,248	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION


Transformers Version: 4.47.1
Torch Version: 2.4.0+cu121
4 CUDA (12.1) device available


In [2]:
output_string = 'compatibility_text_based'
grpo_train_file = f"{output_string}_grpo_train_14272.jsonl"
grpo_dev_file = f"{output_string}_grpo_dev_1352.jsonl"
grpo_test_file = f"{output_string}_grpo_test_73213.jsonl"
data_dir = "/home/azureuser/localfiles/data/polyvore_cp"

data_files = {"train": os.path.join(data_dir, grpo_train_file),
              "dev": os.path.join(data_dir, grpo_dev_file),
              }
dataset = load_dataset("json", data_files=data_files)
dataset

DatasetDict({
    train: Dataset({
        features: ['split', 'num_items', 'gt', 'prompt', 'completion', 'prompt_tokens', 'completion_tokens'],
        num_rows: 14272
    })
    dev: Dataset({
        features: ['split', 'num_items', 'gt', 'prompt', 'completion', 'prompt_tokens', 'completion_tokens'],
        num_rows: 1352
    })
})

In [3]:
dataset['train']

Dataset({
    features: ['split', 'num_items', 'gt', 'prompt', 'completion', 'prompt_tokens', 'completion_tokens'],
    num_rows: 14272
})

In [4]:
arg_dict = {'model_name_or_path': "microsoft/Phi-3.5-mini-instruct",
            'max_seq_len': 1100,
            'output_dir': "/home/azureuser/localfiles/models/grpo/",
            'max_prompt_length': 800,
            'max_completion_length': 256,
            'lr': 5e-6,
            'num_epochs': 1,
            'batch_size': 32,
            'grad_accumulation_steps': 4,
            'num_generations': 8,
            'lora_r': 16,
            'lora_alpha': 32,
            'lora_dropout': 0.05,
            'target_modules': ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"],
            'logging_steps': 100,
            'train_log_filename': "/home/azureuser/localfiles/data/polyvore_cp/grpo_training.log"
           }
args = SimpleNamespace(**arg_dict)

In [5]:
logger = logging.getLogger("trainer_logger")
logger.setLevel(logging.INFO)

# File handler for logging
file_handler = logging.FileHandler(args.train_log_filename)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - %(message)s")
file_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.propagate = False

class LogMetricsCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        logs = logs or {}
        logger.info(f"Step: {state.global_step}, Metrics: {logs}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        metrics = metrics or {}
        logger.info(f"Step: {state.global_step}, Eval Metrics: {metrics}")


In [6]:
model_kwargs = {
    "use_cache": False,
    "trust_remote_code": False,
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.25it/s]


In [7]:
tokenizer.model_max_length = args.max_seq_len
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.unk_token  # Use unk rather than eos token to prevent endless generation
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

In [8]:
import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


def reward_fun(completions, gt, **kwargs):
    # check whether the content within <answer> and </answer> matches
    # the ground truth, Note: completions are the predicted completions
    answer_pattern = r"<answer>(.*?)</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    reward_score = []
    for solution_str, label in zip(completion_contents, gt):
        match = re.finditer(answer_pattern, solution_str)
        matches = list(match)
        if matches:
            final_answer = matches[-1].group(1).strip()
        else:
            final_answer = None
        if final_answer is not None and final_answer == label:
            reward_score.append(1)
        else:
            reward_score.append(0)
    return reward_score

In [9]:
training_args = GRPOConfig(
    output_dir=args.output_dir,
    logging_steps=args.logging_steps,
    use_vllm=False,
    learning_rate=5e-6,
    num_train_epochs=args.num_epochs,
    per_device_train_batch_size=args.batch_size,
    gradient_accumulation_steps=args.grad_accumulation_steps,
    num_generations=args.num_generations,
    max_prompt_length=args.max_prompt_length,
    max_completion_length=args.max_completion_length,
    temperature=1,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    bf16=True,
    save_strategy="no",
    torch_empty_cache_steps=None,  # Disable torch empty cache steps
    torch_compile=False,          # Disable torch compilation
    disable_tqdm=True,
    report_to="wandb",
    run_name="phi-3.5-cp-grpo"
)

peft_config = {
            "r": args.lora_r,
            "lora_alpha": args.lora_alpha,
            "lora_dropout": args.lora_dropout,
            "bias": "none",
            "task_type": "CAUSAL_LM",
            "target_modules": args.target_modules,
            "modules_to_save": None,
        }
peft_conf = LoraConfig(**peft_config)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[reward_fun, format_reward_func],
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['dev'],
    peft_config=peft_conf,
    callbacks=[LogMetricsCallback()],
)

In [10]:
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33maeroabir[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'loss': 0.0118, 'grad_norm': 0.6621402502059937, 'learning_rate': 4.998082190402241e-06, 'rewards/reward_fun': 0.401015625, 'rewards/format_reward_func': 0.00234375, 'reward': 0.403359375, 'reward_std': 0.27068326631560924, 'completion_length': 255.283515625, 'kl': 0.2959573017813091, 'epoch': 0.11210762331838565}
{'loss': 0.0623, 'grad_norm': 2.532252311706543, 'learning_rate': 4.771484098502683e-06, 'rewards/reward_fun': 0.472421875, 'rewards/format_reward_func': 0.002890625, 'reward': 0.4753125, 'reward_std': 0.2453820329532027, 'completion_length': 253.07875, 'kl': 1.5563380974531174, 'epoch': 0.2242152466367713}
{'loss': 0.0604, 'grad_norm': 6.977310657501221, 'learning_rate': 4.2007736611752195e-06, 'rewards/reward_fun': 0.48046875, 'rewards/format_reward_func': 0.005390625, 'reward': 0.485859375, 'reward_std': 0.2144345358759165, 'completion_length': 249.69171875, 'kl': 1.5102706103026866, 'epoch': 0.336322869955157}
{'loss': 0.052, 'grad_norm': 4.206620693206787, 'learning_rat

TrainOutput(global_step=892, training_loss=0.040821202667304754, metrics={'train_runtime': 153617.6384, 'train_samples_per_second': 0.093, 'train_steps_per_second': 0.006, 'total_flos': 0.0, 'train_loss': 0.040821202667304754})

In [11]:
len(dataset['train'])//args.batch_size

446

In [12]:
trainer.save_state()
trainer.save_model(args.output_dir)
print(f"Saved the trained model in {args.output_dir}")

Saved the trained model in /home/azureuser/localfiles/models/grpo/


In [None]:
from transformers import pipeline

def inference(messages, model, tokenizer, **kwargs):
    # messages = [
    #     {"role": "system", "content": "You are a helpful AI assistant."},
    #     {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
    #     {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
    #     {"role": "user", "content": "What about solving an 2x + 3 = 7 equation?"},
    # ]
    # device = torch.device('cuda:0')
    # model.to(device)
    max_new_tokens = kwargs.get('max_new_tokens', 512)
    temperature = kwargs.get('temperature', 0.0)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "return_full_text": False,
        "temperature": temperature,
        "do_sample": False,
    }
    if type(messages) is not list:
        messages = [messages]
    output = pipe(messages, **generation_args)
    return output[0]['generated_text']


## Load the Test Dataset

In [14]:
sft_test_file = f"{output_string}_sft_test_2894.jsonl"

data_files = {"test": os.path.join(data_dir, sft_test_file)}
test_dataset = load_dataset("json", data_files=data_files)
test_dataset

Generating test split: 2894 examples [00:00, 93532.10 examples/s]


DatasetDict({
    test: Dataset({
        features: ['item_ids_original', 'item_ids_mapped', 'split', 'num_items', 'answer', 'messages', 'num_tokens'],
        num_rows: 2894
    })
})

In [15]:
model.device

device(type='cuda', index=0)

Test one example

In [17]:
example = test_dataset['test'][0]
print(example['messages'][1])
res = inference(example['messages'][0], model, tokenizer)
print(res)

Device set to use cuda:0


{'content': 'compatible', 'role': 'assistant'}




 compatible

The ankle boots, blouse, and slim-fit pants are compatible with each other as they can be combined to create a cohesive and stylish outfit suitable for smart casual or business casual events. The boots provide a sleek and polished look, the blouse adds a touch of sophistication with its tie neckline and long sleeves, and the pants offer a classic and timeless style with their flat front and slim fit. All items are versatile and can be paired together for various occasions, including casual dinners, art gallery openings, or professional meetings. The color palette and design elements complement each other, ensuring a harmonious and elegant ensemble.


In [18]:
example.keys()

dict_keys(['item_ids_original', 'item_ids_mapped', 'split', 'num_items', 'answer', 'messages', 'num_tokens'])

In [19]:
example['answer']

'compatible'

In [21]:
import pandas as pd

num_examples = len(test_dataset['test'])
test_response = []
for index in range(num_examples):
    example = test_dataset['test'][index]
    res = inference(example['messages'][0], model, tokenizer)
    test_response.append(res)
print(f"Evaluated {len(test_response)} test examples")

test_df = {'gt': [ex['answer'] for ex in test_dataset['test']], 'predicted': test_response}
test_df = pd.DataFrame(test_df)
test_df.to_csv(os.path.join(data_dir, 'predicted_results_compatibility_grpo_v2.csv'), index=False)

Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0
Device set to use cuda:0


Evaluated 2894 test examples


In [22]:
test_df = {'gt': [ex['answer'] for ex in test_dataset['test']], 'predicted': test_response}
test_df = pd.DataFrame(test_df)
test_df.to_csv(os.path.join(data_dir, 'predicted_results_compatibility_grpo_v2.csv'), index=False)

In [23]:
test_response[:5]

[' compatible\n\nThe ankle boots, blouse, and slim-fit pants are compatible with each other as they can be combined to create a cohesive and stylish outfit suitable for smart casual or business casual events. The boots provide a sleek and polished look, the blouse adds a touch of sophistication with its tie neckline and long sleeves, and the pants offer a classic and timeless style with their flat front and slim fit. All items are suitable for the fall season and can be paired with each other to elevate an ensemble effortlessly.',
 ' compatible',
 ' compatible\n\nThe reasoning behind this compatibility assessment is that both accessories are necklaces, which means they share a common type of fashion product. While they differ in design, color, and occasion suitability, they are still compatible in the sense that they can both be worn as part of a necklace ensemble. The "other" attribute in the description of the bucket bag does not affect the compatibility assessment between the two ne

In [24]:
df_grpo = pd.read_csv(os.path.join(data_dir, 'predicted_results_compatibility_grpo_v2.csv'))
df_grpo['yhat'] = df_grpo['predicted'].apply(lambda x: x.split()[0])
df_grpo['acc'] = df_grpo.apply(lambda row: row['gt'].lower() == row['yhat'].lower(), axis=1)
print(f"Accuracy: {df_grpo['acc'].mean()*100:.2f}")

Accuracy: 47.03


In [25]:
df_grpo

Unnamed: 0,gt,predicted,yhat,acc
0,compatible,"compatible\n\nThe ankle boots, blouse, and sl...",compatible,True
1,compatible,compatible,compatible,True
2,compatible,compatible\n\nThe reasoning behind this compa...,compatible,True
3,compatible,compatible,compatible,True
4,compatible,compatible\n\nThe ballet shoes and the A-line...,compatible,True
...,...,...,...,...
2889,incompatible,compatible,compatible,False
2890,incompatible,compatible,compatible,False
2891,incompatible,compatible\n\nThe shoulder bag and the day dr...,compatible,False
2892,incompatible,compatible,compatible,False


In [26]:
df_grpo.yhat.value_counts()

yhat
compatible      2780
incompatible     114
Name: count, dtype: int64

In [None]:
df_grpo['gt'].value_counts()

gt
incompatible    1581
compatible      1313
Name: count, dtype: int64

[1;34mwandb[0m: 
[1;34mwandb[0m: 🚀 View run [33mphi-3.5-cp-grpo[0m at: [34mhttps://wandb.ai/aeroabir/compatibility/runs/ns6ytgri[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20250303_073927-ns6ytgri/logs[0m
