In [3]:
from dataclasses import dataclass
from typing import Any

import torch

In [4]:
def pad(ids, pad_id, max_length):
    """Force ``ids`` to exactly ``max_length`` items.

    Longer lists are cut from the right; shorter ones are right-filled with
    ``pad_id``. A new list is returned and ``ids`` itself is left untouched.
    """
    overflow = len(ids) - max_length
    if overflow > 0:
        return ids[:max_length]
    filler = [pad_id] * (-overflow)
    return ids + filler


prompt_prefix = ""
# Chat template for one human turn; the bot's reply is appended right after "<bot>:".
prompt_without_output = "<human>:{prompt}\n<bot>:"


@dataclass
class LlamaSFTCollator:
    """Turn raw SFT samples into a padded batch of causal-LM inputs.

    Each sample is a dict of the form::

        {
            "task": str,
            "prompt": [str, ...],   # one entry per dialogue turn
            "output": [str, ...],   # reply for the prompt at the same index
        }

    Turns are concatenated as ``<human>:{prompt}\\n<bot>:{output}<eos>``. Prompt
    tokens get label ``-100`` (ignored by the loss); output tokens, including the
    trailing EOS, are supervised. Unpaired trailing prompts/outputs are dropped.
    The batch is right-padded with EOS to the longest sample, capped at
    ``max_seq_length``; longer samples are truncated.
    """

    # Tokenizer object: called as tokenizer(text, add_special_tokens=False).input_ids
    # and read for .eos_token_id.
    tokenizer: Any
    max_seq_length: int = 1536

    def __call__(self, samples):
        """Collate ``samples`` into a dict of tensors.

        Returns:
            dict with ``input_ids`` and ``labels`` (long, shape (batch, max_length)),
            ``attention_mask`` (long; 1 on real tokens, 0 on padding) and
            ``position_ids`` (0..max_length-1 on every row), where ``max_length`` is
            the longest encoded sample capped at ``self.max_seq_length``.
        """
        eos_id = self.tokenizer.eos_token_id
        encoded = [self._encode_sample(sample, eos_id) for sample in samples]
        batch_size = len(encoded)
        max_length = min(max((len(ids) for ids, _ in encoded), default=0), self.max_seq_length)

        # Padding uses EOS for inputs and -100 for labels so padded positions never
        # contribute to the loss.
        input_ids_list = [pad(ids, eos_id, max_length) for ids, _ in encoded]
        labels_list = [pad(labels, -100, max_length) for _, labels in encoded]

        # Mask padding out instead of claiming every position is a real token.
        lengths = torch.tensor([min(len(ids), max_length) for ids, _ in encoded], dtype=torch.long)
        positions = torch.arange(max_length)

        # reshape keeps a 2-D (0, 0) shape for an empty batch.
        return {
            'input_ids': torch.tensor(input_ids_list, dtype=torch.long).reshape(batch_size, max_length),
            'attention_mask': (positions.unsqueeze(0) < lengths.unsqueeze(1)).long(),
            "position_ids": positions.unsqueeze(0).repeat(batch_size, 1),
            'labels': torch.tensor(labels_list, dtype=torch.long).reshape(batch_size, max_length),
        }

    def _encode_sample(self, sample, eos_id):
        """Tokenize every (prompt, output) turn of one sample.

        Returns ``(input_ids, labels)`` of equal length, neither padded nor truncated.
        """
        input_ids = []
        labels = []
        # zip stops at the shorter list, i.e. min(len(prompt), len(output)) turns.
        for prompt, output in zip(sample["prompt"], sample["output"]):
            prompt_ids = self.tokenizer(
                prompt_without_output.format_map({"prompt": prompt.strip()}),
                add_special_tokens=False,
            ).input_ids
            output_ids = self.tokenizer(output.strip(), add_special_tokens=False).input_ids + [eos_id]
            input_ids += prompt_ids + output_ids
            labels += [-100] * len(prompt_ids) + output_ids
        return input_ids, labels

In [5]:
import pytorch_lightning as pl

class Llama(pl.LightningModule):
    """Lightning wrapper around a Hugging Face causal LM for supervised fine-tuning.

    ``forward`` passes keyword arguments straight to the wrapped model, so batches
    from ``LlamaSFTCollator`` (``input_ids``, ``attention_mask``, ``position_ids``,
    ``labels``) are consumed as-is and the model's ``.loss`` is optimized.
    """

    def __init__(self, args, model, tokenizer):
        super().__init__()
        # Stores `args` in self.hparams (read by get_total_steps in setup()).
        self.save_hyperparameters(args)
        self.tokenizer = tokenizer
        self.model = model

    def setup(self, stage) -> None:
        if stage == 'fit':
            # get_total_steps is not defined in this notebook; presumably a fengshen
            # utility imported elsewhere — verify.
            self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}'.format(self.total_steps))

    def configure_optimizers(self):
        # Delegates to the module-level `configure_optimizers` helper (the bare name
        # resolves to the global, not this method); not defined in this notebook.
        return configure_optimizers(self)

    def forward(self, **batch):
        return self.model(**batch)

    def detokenize(self, token_ids):
        """Convert a sequence of token ids back to a text string."""
        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
        return self.tokenizer.convert_tokens_to_string(tokens)

    def comput_metrix(self, logits, labels):
        """Next-token accuracy over supervised positions (label != -100).

        Assumes ``logits`` is (batch, seq, vocab) and ``labels`` is (batch, seq)
        aligned with ``input_ids`` (unshifted, as built by the collator), so logits
        at step t are compared with the label at step t+1.

        Returns:
            0-dim float tensor in [0, 1]; 0 when no position is supervised.
        """
        with torch.no_grad():
            y_pred = torch.argmax(logits[..., :-1, :], dim=-1).reshape(-1)
            y_true = labels[..., 1:].reshape(-1)
            valid = y_true != -100
            correct = (y_pred[valid] == y_true[valid]).sum()
            # Normalize by supervised tokens, not by batch size; clamp avoids 0/0.
            acc = correct.float() / valid.sum().clamp(min=1)
        return acc

    def _print_sample(self, batch):
        """Print the first example of ``batch`` as raw tensors and decoded text."""
        input_ids = batch['input_ids'][0]
        labels = batch['labels'][0]
        print('source: {}'.format(input_ids))
        print('target: {}'.format(labels))
        print('source: {}'.format(self.detokenize(input_ids)))
        print('target: {}'.format(self.detokenize(labels[labels != -100])))
        print('mask: {}'.format(batch['attention_mask'][0]))
        print('position_ids: {}'.format(batch['position_ids'][0]))

    def training_step(self, batch, batch_idx):
        # Dump one batch once, on rank 0, as a sanity check of the collator output.
        # SHOW_DATA is a module-level flag; .get() avoids a NameError when no cell
        # has defined it yet.
        global SHOW_DATA
        if self.trainer.global_rank == 0 and not globals().get("SHOW_DATA", False):
            SHOW_DATA = True
            self._print_sample(batch)
        output = self(**batch)
        self.log('train/loss', output.loss, sync_dist=True)
        return output.loss

    def validation_step(self, batch, batch_idx):
        output = self(**batch)
        self.log('val_loss', output.loss, sync_dist=True)
        return output.loss

    def predict_step(self, batch, batch_idx):
        """Sample a reply to the first human turn of every example in ``batch``.

        Prints inputs, rebuilt prompts and answers, and returns the answers so that
        ``trainer.predict`` collects them.
        """
        # Sampling settings, presumably forwarded to Hugging Face generation by the
        # external `generate` helper — verify. EOS doubles as pad, matching the
        # collator's padding.
        generate_kwargs = {
            "do_sample": True,
            "top_p": 1.0,
            "top_k": 0,
            "max_length": 256,
            "repetition_penalty": 1.0,
            "temperature": 0.8,
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        batch_input_ids = batch['input_ids'].tolist()
        print('batch_input_ids:\n', batch_input_ids)
        # Keep only the text before the first "<bot>:" (the first human turn), drop
        # "<s>" markers, then re-append "<bot>:" so the model answers as the bot.
        queries = [
            self.detokenize(input_ids).split('<bot>:')[0].replace('<s>', '') + '<bot>:'
            for input_ids in batch_input_ids
        ]
        print('queries:\n', queries)
        ans = generate(queries=queries,
                       tokenizer=self.tokenizer,
                       model=self.model,
                       device=self.model.device,
                       **generate_kwargs)
        print('ans:\n', ans)
        return ans

    def on_load_checkpoint(self, checkpoint) -> None:
        # Restore the consumed-sample counter when the checkpoint recorded one;
        # presumably read by the data pipeline to resume — verify.
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']

In [6]:
max_seq_length = 1024  # per-sample token budget handed to LlamaSFTCollator

In [7]:
import os

from transformers import LlamaTokenizer

# Hugging Face download cache. Set ZIYA_CACHE_DIR to relocate it instead of editing
# this absolute path; the default keeps the original location.
cache_dir = os.environ.get("ZIYA_CACHE_DIR", "/root/autodl-tmp/ziya")
model_name = 'IDEA-CCNL/Ziya-LLaMA-13B-v1.1'

tokenizer = LlamaTokenizer.from_pretrained(model_name, cache_dir=cache_dir, use_fast=False)
collate_fn = LlamaSFTCollator(
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
)

In [9]:
from transformers import LlamaForCausalLM

# Download (on first run) and load the base causal LM into `cache_dir`.
# No torch_dtype is passed, so weights load in the library's default dtype
# (float32 unless configured otherwise) — for 13B parameters that is a lot of RAM.
model = LlamaForCausalLM.from_pretrained(
    model_name,
    cache_dir=cache_dir,
)

Downloading shards:   0%|          | 0/28 [00:00<?, ?it/s]

Downloading (…)l-00009-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00010-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00011-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00012-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00013-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00014-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00015-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00016-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00017-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00018-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00019-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00020-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00021-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00022-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00023-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00024-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00025-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00026-of-00028.bin:   0%|          | 0.00/986M [00:00<?, ?B/s]

Downloading (…)l-00027-of-00028.bin:   0%|          | 0.00/918M [00:00<?, ?B/s]

Downloading (…)l-00028-of-00028.bin:   0%|          | 0.00/545M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/28 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [1]:
# NOTE: the recorded run of this cell failed on this import with
# "No module named 'fengshen.models.megatron'" — the installed fengshen looks incomplete.
from fengshen.data.universal_datamodule import UniversalDataModule

# `args` is never built in this notebook; it must be an options namespace accepted by
# UniversalDataModule and Llama (presumably from argparse — verify). Fail with a clear
# message rather than a bare NameError.
assert "args" in globals(), "define `args` (the data/model/trainer options namespace) before this cell"

data_module = UniversalDataModule(tokenizer=tokenizer, args=args, collate_fn=collate_fn)
print('data load complete')

# Unwrap first if `model` is already a Llama wrapper, so re-running this cell does not
# nest a wrapper inside a wrapper.
hf_model = model.model if isinstance(model, Llama) else model
model = Llama(args, hf_model, tokenizer=tokenizer)
print('model load complete')
print(model)

ModuleNotFoundError: No module named 'fengshen.models.megatron'

In [15]:
# Diagnose the ModuleNotFoundError above: which fengshen install is picked up, and
# what does its `data` subpackage actually contain?
import fengshen
# Import the subpackage explicitly rather than relying on an earlier cell's (failed)
# import, or fengshen/__init__, having loaded it — otherwise a fresh kernel may raise
# AttributeError on `fengshen.data`.
import fengshen.data

print(list(fengshen.__path__))
print(dir(fengshen.data))

['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__']
