In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft.tuners.lora import LoraConfig, LoraModel
from peft import get_peft_model
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from trl import SFTTrainer

In [4]:
model_checkpoint = "TinyLlama/TinyLlama-1.1B-step-50K-105b"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_checkpoint, device_map="auto", load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tokenizer.pad_token = tokenizer.eos_token 

In [5]:
lambda_reg = 0.1

In [6]:
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:10]")

In [7]:
from jinja2 import Template
tstr = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

In [8]:
def preprocess(item):
    
    #conv = get_conversation_template("vicuna")
    #roles = {"human": 'user', "gpt": 'assistant'}

    add_generation_prompt = True  # You can set this according to your needs

    output = template.render(messages=item["messages"], add_generation_prompt=add_generation_prompt, eos_token='\n')
    
    return {"prompt": item["prompt"], "response": output}
template = Template(tstr)
ds = ds.map(preprocess, remove_columns=["messages", "prompt_id"], batched=False)
dataset = ds

In [9]:
lora_config = LoraConfig(
    r=32,  # rank of LoRA
    lora_alpha=64,  # scaling factor for initialization
    lora_dropout=0.05,
    bias="none",
)
peft_model = get_peft_model(model, lora_config) #LoraModel(model, lora_config, adapter_name="iter0").to(device)

In [10]:
def compute_spin_loss(model_logits_gt, opponent_logits_gt, model_logits_syn, opponent_logits_syn, lambda_reg):
    model_probs_gt = torch.nn.functional.softmax(model_logits_gt, dim=-1)
    model_probs_syn = torch.nn.functional.softmax(model_logits_syn, dim=-1)
    opponent_probs_gt = torch.nn.functional.softmax(opponent_logits_gt, dim=-1)
    opponent_probs_syn = torch.nn.functional.softmax(opponent_logits_syn, dim=-1)

    print(model_probs_gt.shape)

    if model_probs_gt.shape[1] < model_probs_syn.shape[1]:
        model_probs_syn = model_probs_syn[:, :model_probs_gt.shape[1]]

    if model_probs_gt.shape[1] > model_probs_syn.shape[1]:
        model_probs_gt = model_probs_gt[:, :model_probs_syn.shape[1]]
    if opponent_probs_gt.shape[1] < opponent_probs_syn.shape[1]:
        model_probs_syn = model_probs_syn[:, :model_probs_gt.shape[1]]

    if opponent_probs_gt.shape[1] > opponent_probs_syn.shape[1]:
        model_probs_gt = model_probs_gt[:, :model_probs_syn.shape[1]]

    # Calculate losses
    loss_gt = -torch.log(model_probs_gt / opponent_probs_gt)
    loss_syn = -torch.log(model_probs_syn / opponent_probs_syn)

    # Apply the logistic loss to the log odds ratio
    logistic_loss_gt = torch.log(1 + torch.exp(-lambda_reg * loss_gt))
    logistic_loss_syn = torch.log(1 + torch.exp(-lambda_reg * loss_syn))

    # Combine losses for the final spin loss
    spin_loss = logistic_loss_gt.mean(dim=[1,2]) + logistic_loss_syn.mean(dim=[1,2])
    return spin_loss

In [19]:
from tqdm import tqdm
import transformers
from transformers import DataCollatorForSeq2Seq
T = 5  # Set the number of iterations
total_loss = 0

# Disable adapter layers for the opponent model
peft_model.disable_adapter_layers()

synthetic_data = []
#for data in tqdm(dataset):
#    prompt = data['prompt']
#    # Tokenize and generate synthetic data using the opponent model
#    prompt_ids = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
#    with torch.no_grad():
#       peft_model.eval()  # Set model to evaluation mode
#        synthetic_response_ids = peft_model.generate(prompt_ids, max_length=50)
#        synthetic_data.append(synthetic_response_ids)

# Enable adapter layers for training the main player model
peft_model.enable_adapter_layers()

# Train the main player model using the synthetic data and real responses
peft_model.train()  # Set model to training mode

class SPINTrainer(Trainer):
    def compute_loss(self, peft_model, inputs, return_outputs=False):
        print(inputs)

        exit(0)

        #prompt_ids = inputs["input_ids"]
        prompt = inputs['query']
        
        # Tokenize and generate synthetic data using the opponent model
        prompt_ids = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
        with torch.no_grad():
            peft_model.eval()  # Set model to evaluation mode
            synthetic_response_ids = peft_model.generate(prompt_ids, max_length=50)
            synthetic_data.append(synthetic_response_ids)

            ground_truth = data['response']
            ground_truth_ids = tokenizer(ground_truth, return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
            synthetic_response_ids = synthetic_data[i]
    
            # Calculate logits for ground truth and synthetic responses
            main_player_logits_gt = peft_model(ground_truth_ids).logits
            main_player_logits_syn = peft_model(synthetic_response_ids).logits
    
            # Get opponent's logits for synthetic responses (as they were generated before enabling LoRA)
            opponent_logits_syn = peft_model(synthetic_response_ids).logits
            
            # Compute the loss (assuming the function is defined above)
            loss = compute_spin_loss(
                main_player_logits_gt, opponent_logits_syn, 
                main_player_logits_syn, opponent_logits_syn, 
                lambda_reg
            )
    
        return (loss, synthetic_response_ids) if return_outputs else loss

    
args = TrainingArguments(remove_unused_columns=False, output_dir="dir", label_names="labels")

def tokenize_func(examples):
    ret =  tokenizer(examples["prompt"], padding=True, truncation=True, max_length=512)  # max_length=512,  padding=True
    ret["custom"] =  tokenizer(examples["response"], padding=True, truncation=True, max_length=512)
    return ret

encoded_dataset = dataset.map(tokenize_func, remove_columns=["prompt", "response"])
print(encoded_dataset)
#print(dataset.columns)
#print(encoded_dataset[0])
#dataset = dataset.remove_columns(dataset["train"].column_names)

#print(encoded_dataset)

trainer = SPINTrainer(
    model=peft_model,
    train_dataset=encoded_dataset,
    tokenizer=tokenizer,
    args = args,
)

trainer.train()


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Dataset({
    features: ['input_ids', 'attention_mask', 'custom'],
    num_rows: 10
})


TypeError: object of type 'DataCollatorForSeq2Seq' has no len()

In [None]:
# Save the final model parameters

# In[ ]:

final_model_params = peft_model.state_dict()
print("Training complete.")