In [2]:
# List the contents of the actor model directory on Google Drive.
# NOTE(review): this cell runs before the cell that calls drive.mount();
# on a fresh kernel the path will presumably not exist yet — confirm ordering.
!ls -al /content/drive/MyDrive/Simple_RLHF/model/actor


total 5139760
-rw------- 1 root root        749 Mar  9 01:55 config.json
-rw------- 1 root root        132 Mar  9 01:55 generation_config.json
-rw------- 1 root root 4994509120 Mar  9 01:55 model-00001-of-00002.safetensors
-rw------- 1 root root  268568856 Mar  9 01:56 model-00002-of-00002.safetensors
-rw------- 1 root root      33832 Mar  9 01:56 model.safetensors.index.json


In [3]:
import torch
import sys
# Put the project directory on the import path; `util` is imported from it
# below (presumably util.py lives in this Drive folder — confirm).
sys.path.append('/content/drive/MyDrive/Simple_RLHF')

from google.colab import drive

# The Drive must be mounted before anything under /content/drive is imported.
drive.mount('/content/drive')
from util import TokenizerUtil

# Shared tokenizer wrapper; later cells read this module-level `tokenizer`.
tokenizer = TokenizerUtil()

# Smoke test: with max_length=4 the recorded output is '<s>how are</s>',
# i.e. the text is cut so that the 4 tokens include the <s> and </s> markers.
input_ids, attention_mask = tokenizer.encode('how are you', max_length=4)

# Display the ids, the mask, and the decoded round trip.
input_ids, attention_mask, tokenizer.decode(input_ids)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


(tensor([   0, 9178,   32,    2]), tensor([1, 1, 1, 1]), '<s>how are</s>')

In [4]:
# Install `datasets` with the %pip magic so the package lands in the
# environment of the running kernel (a bare `!pip` shells out and may target
# a different interpreter). -q keeps the install log out of the notebook.
%pip install -q datasets
from datasets import load_dataset

# Load the preference data (JSON) from the project folder on Drive.
dataset = load_dataset('json', data_files='/content/drive/MyDrive/Simple_RLHF/dataset/train.json', split='train')

# Original note: split the data 2:4:4 and take the first part.
# As written, only 50 examples (indices 15000-15049) are kept.
dataset = dataset.select(range(15000, 15050))


def f(data):
    """Tokenize one example into a chosen/rejected pair.

    The two texts share the same prompt and differ only in the answer:
    "chosen" uses the case-swapped answer, "rejected" uses the answer as is.

    Args:
        data: mapping with 'prompt' and 'chosen' string fields.

    Returns:
        dict with the ids and attention mask returned by ``tokenizer.encode``
        for each of the two texts.
    """
    prompt = data['prompt']
    answer = data['chosen']

    # Build both texts first so the two variants are easy to compare.
    text_chosen = prompt + answer.swapcase()
    text_rejected = prompt + answer

    ids_chosen, mask_chosen = tokenizer.encode(text_chosen)
    ids_rejected, mask_rejected = tokenizer.encode(text_rejected)

    encoded = {
        'chosen_input_ids': ids_chosen,
        'chosen_attention_mask': mask_chosen,
        'rejected_input_ids': ids_rejected,
        'rejected_attention_mask': mask_rejected
    }
    return encoded


# Run the tokenizing function over every example; its returned keys
# (chosen_* / rejected_*) become dataset columns.
dataset = dataset.map(f)
# Have the dataset hand back torch tensors when rows are read.
dataset.set_format('torch')


def f(data):
    """Collate tokenized rows into one batch: chosen rows first, then rejected.

    Args:
        data: list of rows, each holding the four chosen_*/rejected_* tensors.

    Returns:
        dict with 'input_ids' and 'attention_mask', each stacked along dim 0
        so that the first ``len(data)`` rows are the chosen sequences and the
        remaining rows are the rejected ones, in the same order.
    """

    def column(key):
        # Pull one field out of every row, keeping the row order.
        return [row[key] for row in data]

    ids_rows = column('chosen_input_ids') + column('rejected_input_ids')
    mask_rows = (column('chosen_attention_mask') +
                 column('rejected_attention_mask'))

    batch = {
        'input_ids': torch.stack(ids_rows, dim=0),
        'attention_mask': torch.stack(mask_rows, dim=0)
    }
    return batch


# One dataset row already carries a chosen/rejected pair, so batch_size=1
# produces a collated batch with two rows (the recorded sample below shows
# two rows in both tensors).
loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn=f,
                                     batch_size=1,
                                     shuffle=True,
                                     drop_last=True)

# Show the number of batches and one sample batch.
len(loader), next(iter(loader))



Map:   0%|          | 0/50 [00:00<?, ? examples/s]

(50,
 {'input_ids': tensor([[    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]])})

In [5]:
from costom_lora import count_params


class CriticModel(torch.nn.Module):
    """Pairwise reward ("critic") model: OPT-350m backbone plus a scalar head.

    ``forward`` expects a batch whose first half holds the chosen sequences
    and whose second half holds the matching rejected sequences, i.e. row
    ``i`` is paired with row ``i + batch_size // 2``.
    """

    def __init__(self):
        super().__init__()

        from transformers import AutoModel
        self.rwtransformer = AutoModel.from_pretrained('facebook/opt-350m',
                                                       dropout=0.0)

        # One scalar per token, read from the backbone's 512-wide output.
        self.v_head = torch.nn.Linear(512, 1, bias=False)

    @staticmethod
    def _response_end(ids):
        """Return the index just past the first EOS token in ``ids``.

        Falls back to the full sequence length when ``ids`` has no EOS, so
        such a sequence is scored up to its last position instead of raising.
        """
        eos_positions = (ids == tokenizer.eos_token_id).nonzero()
        if eos_positions.numel() == 0:
            return ids.shape[0]
        return eos_positions[0].item() + 1

    def forward(self, input_ids, attention_mask):
        """Compute the pairwise ranking loss for a chosen/rejected batch.

        Args:
            input_ids: 2-D token ids; chosen rows first, rejected rows second.
            attention_mask: mask passed through to the backbone.

        Returns:
            Tuple ``(loss, value_chosen_sum, value_rejected_sum)``: ``loss``
            is the mean over pairs of ``-logsigmoid(chosen - rejected)``
            averaged over each pair's scored span; the two sums are Python
            floats accumulating each pair's mean per-token value.

        Raises:
            ValueError: if the batch is empty or has an odd number of rows,
                since rows could not be split into chosen/rejected pairs.
        """
        value = self.rwtransformer(
            input_ids=input_ids,
            attention_mask=attention_mask).last_hidden_state
        value = self.v_head(value).squeeze(-1)

        batch_size = input_ids.shape[0]
        if batch_size == 0 or batch_size % 2 != 0:
            raise ValueError(
                'expected a non-empty batch of chosen rows followed by the '
                f'same number of rejected rows, got batch size {batch_size}')
        num_pairs = batch_size // 2

        loss_sum = 0.0
        value_chosen_sum = 0.0
        value_rejected_sum = 0.0
        for ids_chosen, ids_rejected, value_chosen, value_rejected in zip(
                input_ids[:num_pairs], input_ids[num_pairs:],
                value[:num_pairs], value[num_pairs:]):

            # Score up to whichever answer ends later.
            end = max(self._response_end(ids_chosen),
                      self._response_end(ids_rejected))

            # The shared prefix carries no preference signal, so scoring
            # starts at the first token where the two sequences differ.
            diverged = (ids_chosen != ids_rejected).nonzero()
            if diverged.numel() > 0:
                start = diverged[0].item()
            else:
                # Identical sequences: compare only the final scored token so
                # the span is non-empty (an empty span would average to NaN).
                start = end - 1

            value_chosen = value_chosen[start:end]
            value_rejected = value_rejected[start:end]

            loss = value_chosen - value_rejected
            loss = -torch.nn.functional.logsigmoid(loss).mean()

            loss_sum += loss
            value_chosen_sum += value_chosen.mean().item()
            value_rejected_sum += value_rejected.mean().item()

        return loss_sum / num_pairs, value_chosen_sum, value_rejected_sum


# Build the critic; later cells read this module-level `model_critic`.
model_critic = CriticModel()

# Report parameter counts via the project's costom_lora helper; the recorded
# output has the keys count_require, count_all and ratio.
count_params(model_critic)

{'count_require': 3.31196928, 'count_all': 3.31196928, 'ratio': 1.0}


In [6]:
from transformers import get_scheduler
from accelerate import Accelerator


def f():
    """Split ``model_critic``'s parameters into optimizer groups.

    Parameters whose name contains 'bias' or 'norm.weight' are exempt from
    weight decay; every other parameter gets a decay of 0.1.

    Returns:
        Two-element list of optimizer parameter groups: the decayed group
        first, the exempt group second.
    """
    decayed, exempt = [], []
    for name, param in model_critic.named_parameters():
        skip_decay = 'bias' in name or 'norm.weight' in name
        bucket = exempt if skip_decay else decayed
        bucket.append(param)

    group_decay = {'params': decayed, 'weight_decay': 0.1}
    group_plain = {'params': exempt, 'weight_decay': 0.0}
    return [group_decay, group_plain]


# Adam over the two parameter groups built by f() (decayed / not decayed).
# NOTE(review): torch.optim.Adam applies weight_decay as an L2 term added to
# the gradient, not AdamW-style decoupled decay — confirm this is intended.
optimizer = torch.optim.Adam(f(), lr=5e-5, betas=(0.9, 0.95))

# Cosine schedule with no warmup, laid out over 500 steps.
# NOTE(review): the loader shown earlier has only 50 batches, so the schedule
# will presumably finish far from its end — confirm num_training_steps.
scheduler = get_scheduler(name='cosine',
                          optimizer=optimizer,
                          num_warmup_steps=0,
                          num_training_steps=500)

# Accumulate gradients over 16 batches and train in fp16 mixed precision.
accelerator = Accelerator(gradient_accumulation_steps=16,
                          mixed_precision='fp16')

# Hand everything to Accelerate; the names are rebound to the prepared objects.
model_critic, loader, optimizer, scheduler = accelerator.prepare(
    model_critic, loader, optimizer, scheduler)

# Switch to training mode (the returned module repr is the cell's output).
model_critic.train()

CriticModel(
  (rwtransformer): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 512, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
      (project_out): Linear(in_features=1024, out_features=512, bias=False)
      (project_in): Linear(in_features=512, out_features=1024, bias=False)
      (layers): ModuleList(
        (0-23): 24 x OPTDecoderLayer(
          (self_attn): OPTSdpaAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_fea

In [7]:

# Train the critic for one pass over `loader`, then save it to Drive.
for i, data in enumerate(loader):
    # Per-batch progress counter (the recorded output is just 0..49).
    print(i)
    # accumulate() is the gradient-accumulation context for the
    # Accelerator configured with gradient_accumulation_steps=16.
    with accelerator.accumulate(model_critic):
        loss, value_chosen_sum, value_rejected_sum = model_critic(**data)
        accelerator.backward(loss)

        # Clip only on steps where gradients are actually synchronized.
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model_critic.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # Periodic metrics line every 100 batches.
    # NOTE(review): with the 50-batch loader this never fires, so the recorded
    # run shows no loss values — confirm the intended logging interval.
    if (i + 1) % 100 == 0:
        lr = optimizer.param_groups[0]['lr']
        print(i, len(loader), loss.item(), lr, value_chosen_sum,
              value_rejected_sum)

    # Hard cap on the number of batches.
    # NOTE(review): unreachable with a 50-batch loader.
    if i == 2000:
        break

# Pickle the whole module object (not just a state_dict) after moving it to
# the CPU; loading it back requires the CriticModel class to be importable.
torch.save(model_critic.to('cpu'), '/content/drive/MyDrive/Simple_RLHF/model/critic')

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


In [14]:
# List the model directory to confirm the critic file was written next to
# the actor directory.
!ls -al /content/drive/MyDrive/Simple_RLHF/model/

total 1293911
drwx------ 2 root root       4096 Mar  9 01:55 actor
-rw------- 1 root root 1324960543 Mar  9 13:16 critic
