In [1]:
!pip install trl==0.9.6 peft==0.12.0 gdown

Collecting trl==0.9.6
  Downloading trl-0.9.6-py3-none-any.whl.metadata (12 kB)
Collecting peft==0.12.0
  Downloading peft-0.12.0-py3-none-any.whl.metadata (13 kB)
Collecting gdown
  Downloading gdown-5.2.0-py3-none-any.whl.metadata (5.8 kB)
Collecting tyro>=0.5.11 (from trl==0.9.6)
  Downloading tyro-0.8.5-py3-none-any.whl.metadata (8.2 kB)
Collecting docstring-parser>=0.16 (from tyro>=0.5.11->trl==0.9.6)
  Downloading docstring_parser-0.16-py3-none-any.whl.metadata (3.0 kB)
Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl==0.9.6)
  Downloading shtab-1.7.1-py3-none-any.whl.metadata (7.3 kB)
Downloading trl-0.9.6-py3-none-any.whl (245 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m245.8/245.8 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading peft-0.12.0-py3-none-any.whl (296 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.4/296.4 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gdown-5.2.0-py3-none-any.whl (18 kB)
Do

In [2]:
# Enable IPython's autoreload extension in mode 2 for this kernel.
%load_ext autoreload
%autoreload 2

In [9]:
%%writefile configs.py

from dataclasses import dataclass, field


@dataclass
class SFTArgs:
    """Options for the SFT model."""

    save_folder: str = field(
        default='sft',
        metadata={'help': 'Folder to save SFT model.'},
    )


@dataclass
class RewardConfigArgs:
    """Training options for the reward model (output location, batch size, schedule, precision)."""

    output_dir: str = field(
        default='reward',
        metadata={'help': 'The output directory where the model checkpoints will be written.'},
    )
    per_device_train_batch_size: int = field(
        default=32,
        metadata={'help': 'The batch size for training.'},
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={'help': 'Number of updates steps to accumulate the gradients for.'},
    )
    learning_rate: float = field(
        default=1.41e-5,
        metadata={'help': 'The initial learning rate for AdamW optimizer.'},
    )
    max_steps: int = field(
        default=4000,
        metadata={'help': 'The total number of training steps to perform.'},
    )
    logging_steps: int = field(
        default=100,
        metadata={'help': 'Number of update steps between two logs.'},
    )
    gradient_checkpointing: bool = field(
        default=True,
        metadata={'help': 'If True, use gradient checkpointing.'},
    )
    fp16: bool = field(
        default=True,
        metadata={'help': 'Whether to use 16-bit (mixed) precision training.'},
    )
    bf16: bool = field(
        default=False,
        metadata={'help': 'Whether to use bf16 16-bit (mixed) precision training.'},
    )
    max_length: int = field(
        default=512,
        metadata={'help': 'The maximum length of the sequences in the batch.'},
    )


@dataclass
class PolicyTrainArgs:
    """Per-policy training options (batch sizes, learning rate, step budget, precision, warmup)."""

    per_device_train_batch_size: int = field(
        default=64,
        metadata={'help': 'The batch size for training.'},
    )
    per_device_eval_batch_size: int = field(
        default=32,
        metadata={'help': 'The batch size for evaluation.'},
    )
    gradient_accumulation_steps: int = field(
        default=1,
        metadata={'help': 'Number of updates steps to accumulate the gradients for.'},
    )
    learning_rate: float = field(
        default=1.41e-5,
        metadata={'help': 'The initial learning rate for AdamW optimizer.'},
    )
    max_steps: int = field(
        default=100,
        metadata={'help': '(T) Number of train iterations for each policy.'},
    )
    logging_steps: int = field(
        default=10,
        metadata={'help': 'Number of update steps between two logs.'},
    )
    gradient_checkpointing: bool = field(
        default=True,
        metadata={'help': 'If True, use gradient checkpointing.'},
    )
    fp16: bool = field(
        default=False,
        metadata={'help': 'Whether to use fp16 16-bit (mixed) precision training.'},
    )
    bf16: bool = field(
        default=False,
        metadata={'help': 'Whether to use bf16 16-bit (mixed) precision training.'},
    )
    warmup_steps: int = field(
        default=0,
        metadata={'help': 'Linear warmup over warmup_steps.'},
    )


@dataclass
class LoraArgs:
    """LoRA adapter hyper-parameters: rank, scaling alpha and dropout."""

    rank: int = field(
        default=32,
        metadata={'help': 'Lora attention dimension.'},
    )
    lora_alpha: int = field(
        default=32,
        metadata={'help': 'The alpha parameter for Lora scaling.'},
    )
    lora_dropout: float = field(
        default=0.0,
        metadata={'help': 'The dropout probability for Lora layers.'},
    )


@dataclass
class CheckpointsArgs:
    """Checkpoint locations, output folder and experiment-tracking group name."""

    group_name: str = field(
        default='warp',
        metadata={'help': 'Group name for WAND runs. Should be unique every time, otherwise logs to the existing group.'},
    )
    sft_checkpoint: str = field(
        default='sft',
        metadata={'help': 'Path to sft model/tokenizer checkpoint.'},
    )
    reward_checkpoint: str = field(
        default='reward',
        metadata={'help': 'Path to reward model/tokenizer checkpoint.'},
    )
    save_folder: str = field(
        default='warp',
        metadata={'help': 'Folder to save WARP output.'},
    )


@dataclass
class DatasetArgs:
    """Options for filtering reviews, truncating prompts and sizing the evaluation subset."""

    min_text_length: int = field(
        default=200,
        metadata={'help': 'Minimum length of a review from imdb dataset.'},
    )
    min_tokens: int = field(
        default=5,
        metadata={'help': 'Minimum number of tokens after prompt truncation.'},
    )
    max_tokens: int = field(
        default=20,
        metadata={'help': 'Maximum number of tokens after prompt truncation.'},
    )
    eval_size: int = field(
        default=100,
        metadata={'help': 'Evaluation subset size.'},
    )


@dataclass
class GenerationArgs:
    max_new_tokens: int = field(default=64, metadata={'help': 'The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.'})
    do_sample: bool = field(default=True, metadata={'help': 'Whether or not to use sampling ; use greedy decoding otherwise.'})
    top_k: int | None = field(default=None, metadata={'help': 'The number of highest probability vocabulary tokens to keep for top-k-filtering.'})
    top_p: float | None = field(default=1.0, metadata={'help': 'Only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.'})
    temperature: float | None = field(default=1.0, metadata={'help': 'The value used to modulate the next token probabilities.'})
    num_return_sequences: int = field(default=1, metadata={'help': 'The number of independently computed returned sequences for each element in the batch'})


@dataclass
class WARPArgs:
    """WARP hyper-parameters: iteration/policy counts and the KL, EMA, SLERP and LITI rates."""

    num_iterations: int = field(
        default=2,
        metadata={'help': '(I) Number of WARP iterations.'},
    )
    num_policies: int = field(
        default=2,
        metadata={'help': '(M) Number of policies to train in parallel.'},
    )
    kl_coef: float = field(
        default=0.1,
        metadata={'help': '(beta) Weight of KL for KL-regularized reward.'},
    )
    ema_rate: float = field(
        default=0.01,
        metadata={'help': '(mu) EMA rate for reference policies.'},
    )
    slerp_rate: float = field(
        default=0.5,
        metadata={'help': '(lambda) Weight to SLERP trained policies. Ignored when M > 2 (lambda = 1/M).'},
    )
    liti_rate: float = field(
        default=0.5,
        metadata={'help': '(eta) LITI rate for initial policy.'},
    )


Overwriting configs.py


In [4]:
%%writefile utils.py

import os
import re
import torch
import wandb
from wandb.sdk.wandb_run import Run


def compute_angle(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor:
    """Return the angle, in radians, between two tensors treated as flat vectors.

    Args:
        v1: First tensor.
        v2: Second tensor; multiplied element-wise with ``v1``.

    Returns:
        A 0-dim tensor in ``[0, pi]``. The result is NaN when either input has
        zero norm, because the angle is undefined in that case.
    """
    cos = torch.sum(v1 * v2) / torch.norm(v1) / torch.norm(v2)
    # Rounding can push the cosine of (anti)parallel vectors marginally outside
    # [-1, 1], where acos returns NaN; clamp it back into the valid domain.
    return torch.acos(torch.clamp(cos, -1.0, 1.0))


def get_latest_checkpoint(dir: str | os.PathLike) -> str:
    pattern = r'checkpoint-\d+'
    checkpoints = [file for file in os.scandir(dir) if file.is_dir() and re.fullmatch(pattern, file.name)]
    if not checkpoints:
        return dir

    latest_checkpoint = max(int(checkpoint.name.split('-')[1]) for checkpoint in checkpoints)
    return os.path.join(dir, f'checkpoint-{latest_checkpoint}')


def print_wandb_run(run: Run):
    """Print a short summary of ``run``: entity, wandb version, local path and URLs."""
    # Drop the last two '/'-separated components of the run's directory.
    local_root = '/'.join(run.dir.split('/')[:-2])
    runs_url = f'{run.get_project_url()}/groups/{run.group}'

    summary = '\n'.join((
        f'Currently logged in as: {run.entity}',
        f'Tracking run with wandb version {wandb.__version__}',
        f'Run data is saved locally in {local_root}',
        f'View project at {run.get_project_url()}',
        f'View runs at {runs_url}',
    ))

    print(summary)


def is_lora_layer(name: str) -> bool:
    """Tell whether a parameter name contains ``lora_A`` or ``lora_B``."""
    return any(marker in name for marker in ('lora_A', 'lora_B'))


Writing utils.py


In [5]:
%%writefile dataset.py

import typing as tp
import torch
import trl
import datasets
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


class RewardDatasetItem(tp.TypedDict):
    """One preference pair: token ids and attention masks for a chosen and a rejected text."""

    # Tokenised chosen text.
    input_ids_chosen: list[int]
    attention_mask_chosen: list[int]

    # Tokenised rejected text.
    input_ids_rejected: list[int]
    attention_mask_rejected: list[int]


class DatasetItem(tp.TypedDict):
    """One prompt sample: the full review, the truncated query and the query's token ids."""

    # Full review text.
    text: str
    # Decoded form of the truncated token ids.
    query: str
    # Token ids of the truncated prompt.
    input_ids: torch.Tensor


class IMDBRewardDataset(Dataset):
    """Preference pairs built from IMDB reviews.

    Reviews labelled ``accepted_label`` are "chosen", all others "rejected".
    The dataset enumerates every (chosen, rejected) combination, so its length
    is the product of the two group sizes; texts are tokenised on access.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, accepted_label: int, split: str = 'train'):
        super().__init__()

        self.tokenizer = tokenizer
        self.chosen_texts = []
        self.rejected_texts = []

        # Single pass over the split, routing each review by its label.
        for row in datasets.load_dataset('stanfordnlp/imdb', split=split):
            if row['label'] == accepted_label:
                self.chosen_texts.append(row['text'])
            else:
                self.rejected_texts.append(row['text'])

        self.n_chosen = len(self.chosen_texts)
        self.n_rejected = len(self.rejected_texts)

    def __len__(self):
        return self.n_chosen * self.n_rejected

    def __getitem__(self, index: int) -> RewardDatasetItem:
        # The flat index walks the pair grid row by row: quotient picks the
        # chosen text, remainder picks the rejected one.
        chosen_idx, rejected_idx = divmod(index, self.n_rejected)
        chosen = self.tokenizer(self.chosen_texts[chosen_idx], truncation=True)
        rejected = self.tokenizer(self.rejected_texts[rejected_idx], truncation=True)

        return {
            'input_ids_chosen': chosen['input_ids'],
            'attention_mask_chosen': chosen['attention_mask'],
            'input_ids_rejected': rejected['input_ids'],
            'attention_mask_rejected': rejected['attention_mask'],
        }


def select_query_and_tokenize(sample: DatasetItem, tokenizer: PreTrainedTokenizer, length_sampler: trl.core.LengthSampler):
    """Turn the start of ``sample['text']`` into a prompt.

    The text is encoded, cut to the number of tokens returned by
    ``length_sampler()``, and stored on the sample as ``input_ids`` (token ids)
    and ``query`` (the decoded prefix). The same sample object is returned.
    """
    all_token_ids = tokenizer.encode(sample['text'])
    prompt_ids = all_token_ids[:length_sampler()]

    sample['query'] = tokenizer.decode(prompt_ids)
    sample['input_ids'] = prompt_ids
    return sample


def build_imdb_dataset(tokenizer: PreTrainedTokenizer, min_text_length: int = 200, min_tokens: int = 5, max_tokens: int = 15) -> Dataset[DatasetItem]:
    """Load IMDB, drop short reviews and attach a truncated prompt to every row.

    Args:
        tokenizer: Tokenizer used to encode and decode the prompts.
        min_text_length: Reviews must be longer than this many characters.
        min_tokens: Lower bound passed to the prompt length sampler.
        max_tokens: Upper bound passed to the prompt length sampler.

    Returns:
        The loaded dataset object, filtered, mapped and formatted as torch tensors.
    """
    reviews = datasets.load_dataset('stanfordnlp/imdb')
    length_sampler = trl.core.LengthSampler(min_tokens, max_tokens)

    def is_long_enough(row):
        return len(row['text']) > min_text_length

    def add_query(sample):
        return select_query_and_tokenize(sample, tokenizer, length_sampler)

    reviews = reviews.filter(is_long_enough, batched=False)
    # The 'label' column is deliberately kept: it is needed to make
    # compute_metrics work.
    reviews = reviews.map(add_query, batched=False)
    reviews.set_format(type='torch')

    return reviews


Writing dataset.py


In [6]:
%%writefile policy_trainer.py

import torch
import torch.optim.swa_utils as swa
import transformers
from collections import defaultdict
from torch.utils.data import Dataset
from configs import WARPArgs


class EMAStepCallback(transformers.TrainerCallback):
    """Trainer callback that refreshes an averaged model after every training step."""

    def __init__(self, model: transformers.PreTrainedModel, ema_model: swa.AveragedModel):
        # Source of the current weights and the averaged copy that tracks them.
        self._model, self._ema_model = model, ema_model

    def on_step_end(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ):
        """Fold the model's current parameters into the averaged model."""
        self._ema_model.update_parameters(self._model)


class PolicyTrainer(transformers.Trainer):
    """``transformers.Trainer`` subclass that trains a policy on a KL-regularised reward.

    Each training step generates continuations for the prompts in the batch,
    scores them with ``reward_model`` and subtracts ``kl_coef`` times the summed
    log-probability gap to ``ref_policy``. The loss is a REINFORCE-style
    objective, with a leave-one-out baseline when several sequences are sampled
    per prompt. When no ``ref_policy`` is given, an EMA copy of the policy is
    created and updated after every training step.
    """

    # Added to the temperature before logits are divided by it.
    EPS = 1e-7
    # Log-probability written at generated positions that hold the pad token.
    INVALID_LOGPROB = 0.0
    # Reward given to sequences whose last token is not the pad token.
    INVALID_REWARD = -1.0

    def __init__(
        self,
        policy: transformers.PreTrainedModel,
        policy_tokenizer: transformers.PreTrainedTokenizer,
        reward_model: transformers.PreTrainedModel,
        reward_tokenizer: transformers.PreTrainedTokenizer,
        generation_config: transformers.GenerationConfig,
        warp_args: WARPArgs,
        ref_policy: transformers.PreTrainedModel | None = None,
        training_args: transformers.TrainingArguments | None = None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        callbacks: list[transformers.TrainerCallback] | None = None,
    ):
        """Set up the trainer.

        Args:
            policy: Model being trained.
            policy_tokenizer: Tokenizer of the policy; its pad token id is used for masking.
            reward_model: Sequence classifier that scores generated text.
            reward_tokenizer: Tokenizer matching ``reward_model``.
            generation_config: Generation settings used for every rollout.
            warp_args: Provides ``kl_coef`` and ``ema_rate``.
            ref_policy: Reference model for the KL term; an EMA of ``policy`` when omitted.
            training_args: Forwarded to ``transformers.Trainer``.
            train_dataset: Forwarded to ``transformers.Trainer``.
            eval_dataset: Forwarded to ``transformers.Trainer``.
            callbacks: Extra trainer callbacks.
        """
        super().__init__(
            model=policy,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=policy_tokenizer,
            callbacks=callbacks,
            compute_metrics=self._compute_metrics,
        )

        self.reward_model = reward_model
        self.reward_tokenizer = reward_tokenizer
        self.generation_config = generation_config
        self.warp_args = warp_args
        # Per-sample evaluation values, accumulated across batches.
        self._metrics = defaultdict(list)

        if ref_policy is None:
            # EMA decay is 1 - ema_rate; the callback updates the average after each step.
            self.ref_policy = swa.AveragedModel(self.model, multi_avg_fn=swa.get_ema_multi_avg_fn(1 - warp_args.ema_rate))
            self.add_callback(EMAStepCallback(self.model, self.ref_policy))
        else:
            self.ref_policy = ref_policy

    def compute_loss(self, model: transformers.PreTrainedModel, inputs: transformers.BatchEncoding, return_outputs: bool = False):
        """Generate continuations for ``inputs`` and return the policy-gradient loss.

        Also logs the batch means of the KL term, the raw reward and the
        KL-regularised reward. With ``return_outputs=True`` the per-token policy
        log-probabilities are returned alongside the loss.
        """
        prompt_length = inputs['input_ids'].shape[1]
        gen_tokens, gen_logprobs = self._generate(model, inputs, prompt_length)

        policy_logprobs = self._forward(model, gen_tokens, prompt_length)
        with torch.no_grad():
            ref_logprobs = self._forward(self.ref_policy, gen_tokens, prompt_length)

        # Summed log-probability gap between the sampling distribution and the reference.
        kl = torch.sum(gen_logprobs - ref_logprobs, dim=1)
        reward = self._get_reward(gen_tokens)
        rlhf_reward = reward - self.warp_args.kl_coef * kl

        batch_metrics = {
            'batch_kl': kl.mean().item(),
            'batch_reward': reward.mean().item(),
            'batch_rlhf_reward': rlhf_reward.mean().item()
        }
        self.log(batch_metrics)

        if self.generation_config.num_return_sequences > 1:
            loss = self._rloo_loss(rlhf_reward, policy_logprobs.sum(dim=-1))
        else:
            loss = -torch.mean(rlhf_reward * policy_logprobs.sum(dim=-1))

        return (loss, policy_logprobs) if return_outputs else loss

    def _rloo_loss(self, rlhf_reward: torch.Tensor, policy_logprobs: torch.Tensor) -> torch.Tensor:
        """Policy-gradient loss with a leave-one-out baseline.

        Both inputs are reshaped to ``(-1, k)`` with ``k = num_return_sequences``;
        each sample's baseline is the mean reward of the other ``k - 1`` samples
        in its row.
        """
        k = self.generation_config.num_return_sequences
        rlhf_reward = rlhf_reward.reshape((-1, k))
        policy_logprobs = policy_logprobs.reshape((-1, k))

        baselines = (rlhf_reward.sum(dim=-1, keepdim=True) - rlhf_reward) / (k - 1)
        return -torch.mean((rlhf_reward - baselines) * policy_logprobs)

    def _compute_metrics(self, pred: transformers.EvalPrediction, compute_result: bool = False) -> dict[str, float]:
        """Batch-wise evaluation metric: KL to the reference policy and reward.

        ``pred.inputs`` is indexed with ``'input_ids'`` and ``'attention_mask'``,
        so it must be the batch mapping. Per-sample values are accumulated and,
        when ``compute_result`` is true, averaged over everything seen since the
        last reset; otherwise the current batch means are returned.
        """
        prompt_length = pred.inputs['input_ids'].shape[1]
        gen_tokens, gen_logprobs = self._generate(self.model, pred.inputs, prompt_length)

        with torch.no_grad():
            ref_logprobs = self._forward(self.ref_policy, gen_tokens, prompt_length)

        kl = torch.sum(gen_logprobs - ref_logprobs, dim=1)
        reward = self._get_reward(gen_tokens)

        self._metrics['KL'].extend(kl.cpu().tolist())
        self._metrics['Reward'].extend(reward.cpu().tolist())
        batch_metrics = {'KL': kl.mean().item(), 'Reward': reward.mean().item()}

        if not compute_result:
            return batch_metrics

        final_metrics = {name: sum(values) / len(values) for name, values in self._metrics.items()}
        # Start from a clean slate for the next evaluation run.
        self._metrics = defaultdict(list)
        return final_metrics

    def _generate(self, model: transformers.PreTrainedModel, inputs: transformers.BatchEncoding, prompt_length: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample continuations with ``model``.

        Returns the full generated sequences and the log-probabilities of the
        newly generated tokens, computed from the scores returned by ``generate``.
        """
        generate_out = model.generate(
            inputs=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            generation_config=self.generation_config,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            return_dict_in_generate=True,
            output_scores=True,
        )

        generated_tokens = generate_out.sequences
        generated_logprobs = self._process_generate_logits(
            torch.stack(generate_out.scores, dim=1),
            generated_tokens,
            prompt_length,
        )

        return generated_tokens, generated_logprobs

    def _process_generate_logits(self, logits: torch.Tensor, gen_tokens: torch.Tensor, prompt_length: int) -> torch.Tensor:
        """Log-probabilities of the generated tokens under the stacked generation scores.

        Positions of the generated part that hold the pad token get ``INVALID_LOGPROB``.
        """
        generated_pad_mask = gen_tokens[:, prompt_length:] == self.tokenizer.pad_token_id
        all_logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
        logprobs = torch.gather(all_logprobs, 2, gen_tokens[:, prompt_length:].unsqueeze(-1)).squeeze(-1)
        logprobs[generated_pad_mask] = self.INVALID_LOGPROB
        return logprobs

    def _forward(self, model: transformers.PreTrainedModel, gen_tokens: torch.Tensor, prompt_length: int) -> torch.Tensor:
        """Run ``model`` over the full sequences and return log-probabilities of the generated tokens."""
        attention_mask = gen_tokens != self.tokenizer.pad_token_id
        # Positions count only attended tokens; masked positions get a dummy index of 1.
        position_ids = attention_mask.cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        model_out = model(input_ids=gen_tokens, attention_mask=attention_mask, position_ids=position_ids)
        logprobs = self._process_forward_logits(model_out.logits, gen_tokens, prompt_length)
        return logprobs

    def _process_forward_logits(self, policy_out: torch.Tensor, gen_tokens: torch.Tensor, prompt_length: int) -> torch.Tensor:
        """Log-probabilities of the generated tokens from full-sequence logits.

        The logits are shifted by one position (the logit at step ``t`` predicts
        token ``t + 1``) and divided by the generation temperature. Only the
        temperature is applied here; other sampling filters (top-k / top-p) are
        not mirrored. Pad positions of the generated part get ``INVALID_LOGPROB``.
        """
        policy_logits = policy_out[:, prompt_length - 1: -1]
        generated_pad_mask = gen_tokens[:, prompt_length:] == self.tokenizer.pad_token_id

        # The temperature may be None in the generation config; treat that as 1.0
        # instead of failing on `None + EPS`.
        temperature = self.generation_config.temperature
        if temperature is None:
            temperature = 1.0

        # Divide out of place so the model's own output tensor is left untouched.
        policy_logits = policy_logits / (temperature + self.EPS)
        all_logprobs = torch.nn.functional.log_softmax(policy_logits, dim=-1)
        logprobs = torch.gather(all_logprobs, 2, gen_tokens[:, prompt_length:].unsqueeze(-1)).squeeze(-1)
        logprobs[generated_pad_mask] = self.INVALID_LOGPROB
        return logprobs

    @torch.no_grad()
    def _get_reward(self, generated_tokens: torch.Tensor) -> torch.Tensor:
        """Score the decoded sequences with the reward model.

        The reward is the softmax probability of class index 0. Sequences whose
        last token is not the pad token receive ``INVALID_REWARD`` instead.
        """
        # A trailing pad token is taken as the sign that generation stopped early.
        finished_sequence_mask = generated_tokens[:, -1] == self.tokenizer.pad_token_id
        text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        reward_tokens = self.reward_tokenizer(text=text, truncation=True, padding=True, return_tensors='pt').to(self.reward_model.device)

        logits = self.reward_model(**reward_tokens).logits
        # NOTE(review): assumes the reward head has at least two classes with index 0
        # being the preferred one; with a single logit this softmax is constantly 1.0.
        rewards = torch.nn.functional.softmax(logits, dim=-1)[:, 0]
        rewards[~finished_sequence_mask] = self.INVALID_REWARD

        return rewards


Writing policy_trainer.py


In [7]:
%%writefile warp_trainer.py

import os
import gc
import wandb
import torch
import torch.multiprocessing as tmp
import transformers
import peft
import configs, utils, policy_trainer
from dataclasses import asdict
from tqdm.auto import tqdm
from torch.utils.data import Dataset

class AsyncProgressCallback(transformers.TrainerCallback):
    """Trainer callback that reports each finished training step through a queue."""

    def __init__(self, queue: tmp.Queue):
        # One item is pushed per step; the consumer only counts them.
        self._queue = queue

    def on_step_end(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ):
        """Push a ``None`` marker onto the queue to signal one completed step."""
        self._queue.put(None)


class WARPTrainer:
    """Drives the WARP loop: train policies in parallel, merge them, re-initialise, evaluate.

    Every iteration trains ``num_policies`` policies in a process pool (one
    device index per running task), merges their LoRA parameters with SLERP,
    saves the merged model, interpolates the initial policy towards it (LITI)
    to obtain the next starting point, and evaluates the merged model against
    the SFT checkpoint.
    """

    # Seconds between checks for crashed workers while waiting on them.
    _POLL_INTERVAL = 5.0
    # Below this |sin(angle)| the SLERP coefficients are replaced by their linear limit.
    _SLERP_EPS = 1e-6

    def __init__(
        self,
        warp_args: configs.WARPArgs,
        generation_args: configs.GenerationArgs,
        training_args: configs.PolicyTrainArgs,
        lora_args: configs.LoraArgs,
        checkpoints_args: configs.CheckpointsArgs,
        train_dataset: Dataset,
        eval_dataset: Dataset | None = None
    ):
        """Store the configuration and create the cross-process queues.

        Args:
            warp_args: Iteration/policy counts and merge rates.
            generation_args: Converted into a ``transformers.GenerationConfig``.
            training_args: Expanded into ``transformers.TrainingArguments`` for every run.
            lora_args: Used to build ``self.lora_config``.
            checkpoints_args: Checkpoint paths, output folder and W&B group name.
            train_dataset: Dataset handed to every policy trainer.
            eval_dataset: Dataset used for evaluation.
        """
        self.checkpoints_args = checkpoints_args
        self.generation_config = transformers.GenerationConfig.from_dict(asdict(generation_args))
        self.warp_args = warp_args
        self.train_args = training_args
        self.lora_config = peft.LoraConfig(
            task_type=peft.TaskType.CAUSAL_LM,
            r=lora_args.rank,
            lora_alpha=lora_args.lora_alpha,
            lora_dropout=lora_args.lora_dropout,
        )

        self.init_policy_checkpoint = checkpoints_args.sft_checkpoint
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        manager = tmp.Manager()
        # Workers push one item per finished training step.
        self._progress_queue = manager.Queue()
        # Pool of free device indices; a task takes one and returns it when done.
        self._device_queue = manager.Queue()
        # Set once the first run has printed the W&B summary.
        self._print_event = manager.Event()
        # W&B run id shared by all evaluation passes so they log to one run.
        self._eval_run_id = None

        device_count = torch.cuda.device_count()
        for device_idx in range(device_count):
            self._device_queue.put(device_idx)

    def train(self):
        """Run all WARP iterations.

        Raises:
            Exception: Whatever a training worker raised; it is re-raised here
                instead of leaving this method waiting for progress forever.
        """
        # Function-scope import: only needed for the exception of a timed-out queue read.
        import queue

        steps_per_iteration = self.warp_args.num_policies * self.train_args.max_steps
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            utils.get_latest_checkpoint(self.init_policy_checkpoint),
            padding_side='left'
        )

        for iter_idx in range(self.warp_args.num_iterations):
            with tmp.Pool(self._device_queue.qsize()) as pool:
                results = [
                    pool.apply_async(self._train_policy, (iter_idx, run_idx))
                    for run_idx in range(self.warp_args.num_policies)
                ]

                # Wait for the first run to print its summary, but keep checking
                # the workers so a crash is raised instead of blocking forever.
                while not self._print_event.wait(self._POLL_INTERVAL):
                    self._raise_worker_errors(results)

                with tqdm(total=steps_per_iteration, desc=f'Iteration {iter_idx + 1}/{self.warp_args.num_iterations}') as progress_bar:
                    completed_steps = 0
                    while completed_steps < steps_per_iteration:
                        try:
                            self._progress_queue.get(timeout=self._POLL_INTERVAL)
                        except queue.Empty:
                            self._raise_worker_errors(results)
                            if all(result.ready() for result in results):
                                # Every worker is done and nothing is left to report.
                                break
                            continue
                        progress_bar.update()
                        completed_steps += 1

                pool.close()
                pool.join()
                self._raise_worker_errors(results)

            slerp_model = self._slerp(iter_idx)
            slerp_path = self._slerp_path(iter_idx)
            slerp_model.save_pretrained(slerp_path)
            tokenizer.save_pretrained(slerp_path)

            init_model = self._litti(slerp_model)
            self.init_policy_checkpoint = self._init_path(iter_idx + 1)
            init_model.save_pretrained(self.init_policy_checkpoint)
            tokenizer.save_pretrained(self.init_policy_checkpoint)

            # Free the merged models before evaluation loads its own copies.
            del slerp_model, init_model
            gc.collect()
            torch.cuda.empty_cache()
            self._evaluate_slerp_model(iter_idx)

    @staticmethod
    def _raise_worker_errors(results):
        """Re-raise the exception of any finished worker task that failed."""
        for result in results:
            if result.ready():
                # Returns immediately for successful tasks, raises for failed ones.
                result.get()

    def _slerp_coefficient(self, fraction: float, angle: torch.Tensor) -> torch.Tensor:
        """Return ``sin(fraction * angle) / sin(angle)``.

        When ``sin(angle)`` is NaN or (almost) zero, i.e. the task vectors are
        degenerate or (anti)parallel, the quotient is replaced by ``fraction``,
        its limit for small angles, which turns the merge into a linear
        interpolation instead of writing NaN/inf into the weights.
        """
        sin_angle = torch.sin(angle)
        if not bool(torch.isfinite(sin_angle)) or bool(sin_angle.abs() < self._SLERP_EPS):
            return torch.full_like(sin_angle, fraction)
        return torch.sin(fraction * angle) / sin_angle

    def _train_policy(self, iter_idx: int, run_idx: int):
        """Train one policy of one iteration; executed inside a pool worker."""
        device_idx = self._pop_device()
        try:
            device = torch.device('cuda')
            train_args = self._get_train_args(iter_idx, run_idx)

            policy, policy_tokenizer = self._load_policy(self.init_policy_checkpoint, device_map=device)
            reward_model, reward_tokenizer = self._load_reward(device_map=device)
            callbacks = [AsyncProgressCallback(self._progress_queue)]

            with wandb.init(group=self.checkpoints_args.group_name, job_type=train_args.run_name, name=train_args.run_name, project='tk-alignment') as run:
                if iter_idx == 0 and run_idx == 0:
                    utils.print_wandb_run(run)
                    self._print_event.set()

                trainer = policy_trainer.PolicyTrainer(
                    policy,
                    policy_tokenizer,
                    reward_model,
                    reward_tokenizer,
                    self.generation_config,
                    self.warp_args,
                    training_args=train_args,
                    train_dataset=self.train_dataset,
                    eval_dataset=self.eval_dataset,
                    callbacks=callbacks
                )

                # Progress is reported through the queue instead of per-process output.
                trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
                trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
                trainer.train()
        finally:
            # Hand the device back even if training failed.
            self._device_queue.put(device_idx)

    def _litti(self, slerp_model: transformers.PreTrainedModel) -> transformers.PreTrainedModel:
        """Move the LoRA parameters of the current initial policy towards ``slerp_model``.

        Each LoRA parameter becomes ``init + liti_rate * (slerp - init)``.
        """
        init_model, _ = self._load_policy(self.init_policy_checkpoint)
        with torch.no_grad():
            for (name, init_param), slerp_param in zip(init_model.named_parameters(), slerp_model.parameters()):
                if utils.is_lora_layer(name):
                    init_param += self.warp_args.liti_rate * (slerp_param - init_param)
        return init_model

    def _slerp(self, iter_idx: int) -> transformers.PreTrainedModel:
        """Merge the policies trained in iteration ``iter_idx`` into a single model."""
        if self.warp_args.num_policies == 1:
            # _load_policy returns (model, tokenizer); only the model is wanted here.
            policy, _ = self._load_policy(self._policy_path(iter_idx, 0))
            return policy

        if self.warp_args.num_policies == 2:
            return self._slerp_first_two_policies(iter_idx)

        # merge M > 2 policies
        init_model, _ = self._load_policy(self.init_policy_checkpoint)
        slerp_model = self._slerp_first_two_policies(iter_idx)

        for run_idx in range(2, self.warp_args.num_policies):
            policy, _ = self._load_policy(self._policy_path(iter_idx, run_idx))
            params_iter = zip(init_model.named_parameters(), slerp_model.parameters(), policy.parameters())

            with torch.no_grad():
                for (name, init_param), slerp_param, policy_param in params_iter:
                    if not utils.is_lora_layer(name):
                        continue

                    task_vector_1 = slerp_param - init_param
                    task_vector_2 = policy_param - init_param

                    angle = utils.compute_angle(task_vector_1, task_vector_2)
                    # NOTE(review): both task vectors are weighted by the same
                    # coefficient sin(angle / M) / sin(angle); confirm this is the
                    # intended M-way merge.
                    mult = self._slerp_coefficient(1 / self.warp_args.num_policies, angle)

                    # Result: init + mult * (task_vector_1 + task_vector_2).
                    slerp_param *= mult
                    slerp_param += (1 - mult) * init_param + mult * task_vector_2

        return slerp_model

    def _slerp_first_two_policies(self, iter_idx: int) -> transformers.PreTrainedModel:
        """SLERP the LoRA parameters of runs 0 and 1 around the current initial policy."""
        slerp_model, _ = self._load_policy(self.init_policy_checkpoint)
        policy_1, _ = self._load_policy(self._policy_path(iter_idx, 0))
        policy_2, _ = self._load_policy(self._policy_path(iter_idx, 1))

        params_iter = zip(slerp_model.named_parameters(), policy_1.parameters(), policy_2.parameters())
        slerp_rate = self.warp_args.slerp_rate

        with torch.no_grad():
            for (name, slerp_param), policy1_param, policy2_param in params_iter:
                if not utils.is_lora_layer(name):
                    continue

                # slerp_param still holds the initial weights at this point.
                task_vector_1 = policy1_param - slerp_param
                task_vector_2 = policy2_param - slerp_param

                angle = utils.compute_angle(task_vector_1, task_vector_2)
                slerp_param += self._slerp_coefficient(1 - slerp_rate, angle) * task_vector_1
                slerp_param += self._slerp_coefficient(slerp_rate, angle) * task_vector_2

        return slerp_model

    def _evaluate_slerp_model(self, iter_idx: int):
        """Evaluate the merged model of ``iter_idx`` with the SFT model as reference policy."""
        device_idx = self._pop_device()
        try:
            device = torch.device('cuda')
            group_name = self.checkpoints_args.group_name
            eval_args = self._get_eval_args(iter_idx)

            slerp_model, slerp_tokenizer = self._load_policy(self._slerp_path(iter_idx), is_trainable=False, device_map=device)
            slerp_model.eval()
            sft_model, _ = self._load_policy(self.checkpoints_args.sft_checkpoint, is_trainable=False, device_map=device)
            sft_model.eval()
            reward_model, reward_tokenizer = self._load_reward(device_map=device)

            with wandb.init(group=group_name, job_type=eval_args.run_name, name=eval_args.run_name, id=self._eval_run_id, resume='allow', project='tk-alignment') as run:
                # Reuse the first evaluation run's id so later iterations resume it.
                self._eval_run_id = self._eval_run_id or run.id

                trainer = policy_trainer.PolicyTrainer(
                    slerp_model,
                    slerp_tokenizer,
                    reward_model,
                    reward_tokenizer,
                    self.generation_config,
                    self.warp_args,
                    sft_model,
                    eval_args,
                    eval_dataset=self.eval_dataset,
                )

                trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
                trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
                trainer.evaluate()
        finally:
            # Hand the device back even if evaluation failed.
            self._device_queue.put(device_idx)

    def _pop_device(self) -> int:
        """Take a free device index and expose only that device to this process.

        NOTE(review): setting CUDA_VISIBLE_DEVICES presumably only has an effect
        while CUDA is not yet initialised in the process; confirm for workers
        that run more than one task.
        """
        device_idx = self._device_queue.get()
        os.environ['CUDA_VISIBLE_DEVICES'] = str(device_idx)
        return device_idx

    def _get_train_args(self, iter_idx: int, run_idx: int) -> transformers.TrainingArguments:
        """Build the training arguments of one run; the seed is unique per (iteration, run)."""
        output_dir = self._policy_path(iter_idx, run_idx)
        run_name = f'iter_{iter_idx}_run_{run_idx}'
        seed = iter_idx * self.warp_args.num_policies + run_idx

        return transformers.TrainingArguments(
            output_dir=output_dir,
            run_name=run_name,
            seed=seed,
            gradient_checkpointing_kwargs={"use_reentrant": False} if self.train_args.gradient_checkpointing else None,
            num_train_epochs=0,
            **asdict(self.train_args),
        )

    def _get_eval_args(self, iter_idx: int) -> transformers.TrainingArguments:
        """Build the arguments for evaluating the merged model of ``iter_idx``."""
        output_dir = self._slerp_path(iter_idx)
        return transformers.TrainingArguments(
            output_dir=output_dir,
            run_name='eval_slerp',
            batch_eval_metrics=True,
            include_inputs_for_metrics=True,
            num_train_epochs=0,
            seed=iter_idx,
            **asdict(self.train_args),
        )

    def _load_policy(
        self,
        path: str,
        is_trainable: bool = True,
        device_map: torch.device | None = None,
    ) -> tuple[peft.peft_model.PeftModelForCausalLM, transformers.PreTrainedTokenizerBase]:
        """Load a PEFT policy and its tokenizer from ``path`` or its latest checkpoint.

        All dropout modules are set to probability 0, and the KV cache is
        enabled only when gradient checkpointing is off.

        Returns:
            ``(policy, tokenizer)``; the tokenizer pads on the left.
        """
        path = utils.get_latest_checkpoint(path)
        policy = peft.AutoPeftModelForCausalLM.from_pretrained(path, is_trainable=is_trainable, device_map=device_map)
        for module in policy.modules():
            if isinstance(module, torch.nn.Dropout):
                module.p = 0

        policy.config.use_cache = not self.train_args.gradient_checkpointing
        tokenizer = transformers.AutoTokenizer.from_pretrained(path, padding_side='left')
        return policy, tokenizer

    def _load_reward(
        self,
        device_map: torch.device | None = None,
    ) -> tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizerBase]:
        """Load the reward classifier (in eval mode) and its tokenizer.

        Returns:
            ``(reward_model, tokenizer)``; the tokenizer pads on the left.
        """
        path = utils.get_latest_checkpoint(self.checkpoints_args.reward_checkpoint)
        reward_model = transformers.AutoModelForSequenceClassification.from_pretrained(path, device_map=device_map)
        reward_model.eval()
        tokenizer = transformers.AutoTokenizer.from_pretrained(path, padding_side='left')
        return reward_model, tokenizer

    def _policy_path(self, iter_idx: int, run_idx: int):
        """Output directory of one trained policy."""
        return os.path.join(self.checkpoints_args.save_folder, f'iter_{iter_idx}_run_{run_idx}')

    def _slerp_path(self, iter_idx: int):
        """Output directory of the merged model of one iteration."""
        return os.path.join(self.checkpoints_args.save_folder, f'iter_{iter_idx}_slerp')

    def _init_path(self, iter_idx: int):
        """Output directory of the initial policy of one iteration."""
        return os.path.join(self.checkpoints_args.save_folder, f'iter_{iter_idx}_init')




Writing warp_trainer.py


# Load sft/reward checkpoints

In [8]:
# Download the checkpoint archive from Google Drive by file id, then unpack it
# into the working directory.
!gdown 1TgG3MllcM-N0BZja8JaeeVQBbVrrMyFE
!unzip checkpoints.zip

Downloading...
From (original): https://drive.google.com/uc?id=1TgG3MllcM-N0BZja8JaeeVQBbVrrMyFE
From (redirected): https://drive.google.com/uc?id=1TgG3MllcM-N0BZja8JaeeVQBbVrrMyFE&confirm=t&uuid=70d4e5aa-924e-414a-92f4-dcccfc69ee79
To: /kaggle/working/checkpoints.zip
100%|████████████████████████████████████████| 723M/723M [00:10<00:00, 66.8MB/s]
Archive:  checkpoints.zip
   creating: checkpoints/
   creating: checkpoints/reward/
  inflating: checkpoints/reward/special_tokens_map.json  
  inflating: checkpoints/reward/training_args.bin  
  inflating: checkpoints/reward/tokenizer.json  
  inflating: checkpoints/reward/rng_state.pth  
  inflating: checkpoints/reward/scheduler.pt  
  inflating: checkpoints/reward/optimizer.pt  
  inflating: checkpoints/reward/config.json  
  inflating: checkpoints/reward/trainer_state.json  
  inflating: checkpoints/reward/vocab.txt  
  inflating: checkpoints/reward/tokenizer_config.json  
  inflating: checkpoints/reward/model.safetensors  
   creating: 

# Train

In [14]:
%%writefile main.py

import os
import transformers
import torch.multiprocessing as tmp
import configs, dataset, warp_trainer, utils
from torch.utils.data import Subset


if __name__ == '__main__':
    # Uncomment to keep W&B logging offline:
    # os.environ['WANDB_MODE'] = 'offline'
    os.environ.update({
        'TOKENIZERS_PARALLELISM': 'false',
        'WANDB_SILENT': 'true',
    })
    tmp.set_start_method('spawn')

    # Configuration: defaults everywhere except the overrides spelled out here.
    warp_args = configs.WARPArgs()
    dataset_args = configs.DatasetArgs()
    lora_args = configs.LoraArgs()
    generation_args = configs.GenerationArgs(
        max_new_tokens=128,
        num_return_sequences=4,
    )
    train_args = configs.PolicyTrainArgs(
        warmup_steps=20,
        learning_rate=1e-4,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=4,
    )
    checkpoints_args = configs.CheckpointsArgs(
        sft_checkpoint='checkpoints/sft',
        reward_checkpoint='checkpoints/reward',
        group_name='warp_test_rloo',
    )

    # Prompts are tokenised with the SFT checkpoint's tokenizer.
    sft_path = utils.get_latest_checkpoint(checkpoints_args.sft_checkpoint)
    tokenizer = transformers.AutoTokenizer.from_pretrained(sft_path)
    imdb = dataset.build_imdb_dataset(
        tokenizer,
        min_text_length=dataset_args.min_text_length,
        min_tokens=dataset_args.min_tokens,
        max_tokens=dataset_args.max_tokens,
    )

    # Train on the whole train split; evaluate on the first eval_size test rows.
    eval_indices = list(range(dataset_args.eval_size))
    train_dataset = imdb['train']
    test_dataset = Subset(imdb['test'], eval_indices)

    trainer = warp_trainer.WARPTrainer(
        warp_args,
        generation_args,
        train_args,
        lora_args,
        checkpoints_args,
        train_dataset,
        test_dataset,
    )
    trainer.train()


Overwriting main.py


In [11]:
import wandb

# Authenticate with Weights & Biases before the training script starts logging runs.
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [15]:
# Run the training script in a separate Python process with UserWarning messages ignored.
!python -W ignore::UserWarning: main.py

2024-08-05 09:14:35.309843: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-05 09:14:35.309897: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-05 09:14:35.311324: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Map:   0%|                           | 80/24872 [00:00<00:32, 772.05 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1300 > 1024). Running this sequence through the model will result in indexing errors
Map: 100%|████████████████████████| 24872/24872 [00:30<00:00, 827.67 examples/s]
2024-08-05 09:15:17.

In [None]:
# Archive the warp/iter_1_slerp output folder for download.
!zip -r rloo_1.zip warp/iter_1_slerp

In [None]:
from IPython.display import FileLink

# Display a link to the archive so it can be downloaded from the notebook.
FileLink('rloo_1.zip')