In [1]:


import traceback
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.utils import pad_across_processes, broadcast
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from datasets import load_dataset, load_from_disk, DatasetDict, Dataset, concatenate_datasets
from datetime import datetime,  timedelta
import time
from functools import partial
import json
import os
import random
from src.python_engine import run_python_code
from src.utils import set_seed, floatify, compute_ETA
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup, AdamW, get_constant_schedule_with_warmup
import wandb
import pandas as pd
import shutil
import signal
from contextlib import contextmanager
from torch.utils.data.distributed import DistributedSampler
from accelerate import notebook_launcher
from accelerate.utils import DeepSpeedPlugin



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# data
# Apply ncols=0 / leave=False to every later tqdm() call in this notebook.
tqdm = partial(tqdm, ncols=0, leave=False)

# Timeout passed to run_python_code for each generated program.
TIMEOUT = 10

# Prompt pieces; assigned by setup_cot() once the dataset source is known.
instruction = cot_trigger = answer_trigger = None
def setup_cot(src_name):
    """Set the module-level prompt pieces for a supported dataset source.

    A full training example is laid out as
    f'{instruction}{question.strip()}{cot_trigger}{answer_cot.strip()}'.
    The same three pieces are currently used for every supported source.

    Args:
        src_name: Dataset source; must be one of 'gsm8k', 'mathqa', 'svamp',
            'mathqa-numeric' or 'zhouyi' (checked with ``assert``).
    """
    global instruction, cot_trigger, answer_trigger
    supported_sources = ['gsm8k', 'mathqa', 'svamp', 'mathqa-numeric', 'zhouyi']
    assert src_name in supported_sources
    instruction, cot_trigger, answer_trigger = (
        'Question:\n',
        '\nAnswer reasoning:\n',
        '\n因此，答案是：',
    )

def _gold_number_to_float(text):
    """Parse a gold answer that may contain thousands separators (' 1,234 ' -> 1234.0)."""
    return float(text.replace(',', '').strip())


def _gold_option_normalize(text):
    """Lower-case a gold option and drop single/double quotes and outer whitespace."""
    return text.lower().replace('"', '').replace("'", '').strip()


def _gold_strip(text):
    """Return the gold answer text without surrounding whitespace."""
    return text.strip()


# Dataset source -> function converting the gold ``answer_value`` string into the
# form compared by compare_answer_fn_mapper.
post_process_final_answer_fn_mapper = {
    'gsm8k': _gold_number_to_float,
    'svamp': _gold_number_to_float,
    'mathqa': _gold_option_normalize,
    'mathqa-numeric': float,
    'zhouyi': _gold_strip,
}
# Every function below receives a list of generated answers (programs for the
# 'python' engine, natural-language CoT for 'nl') and returns one extracted
# answer per element. Globals such as TIMEOUT and answer_trigger are read at call
# time, so values assigned later (e.g. by setup_cot) are honoured.

def _python_results_as_floats(answer_cot):
    """Run each program and convert each result with floatify."""
    executed = run_python_code(programs=answer_cot, TIMEOUT=TIMEOUT)
    return [floatify(res) for res in executed]


def _python_results_as_options(answer_cot):
    """Run each program and normalise each result like a multiple-choice label."""
    executed = run_python_code(programs=answer_cot, TIMEOUT=TIMEOUT)
    return [str(res).lower().replace('"', '').replace("'", '').strip() for res in executed]


def _nl_answer_tail(text):
    """Return the text after the last occurrence of answer_trigger."""
    return text.split(answer_trigger)[-1]


def _nl_answers_as_floats(answer_cot):
    return [floatify(_nl_answer_tail(res).strip()) for res in answer_cot]


def _nl_answers_as_options(answer_cot):
    return [_nl_answer_tail(res).lower().replace('"', '').replace("'", '').strip() for res in answer_cot]


def _nl_answers_as_text(answer_cot):
    return [_nl_answer_tail(res).strip() for res in answer_cot]


# (engine, dataset source) -> extractor
post_process_answer_cot_fn_mapper = {
    ('python', 'gsm8k'): _python_results_as_floats,
    ('python', 'svamp'): _python_results_as_floats,
    ('python', 'mathqa'): _python_results_as_options,
    ('python', 'mathqa-numeric'): _python_results_as_floats,
    ('nl', 'gsm8k'): _nl_answers_as_floats,
    ('nl', 'svamp'): _nl_answers_as_floats,
    ('nl', 'mathqa'): _nl_answers_as_options,
    ('nl', 'mathqa-numeric'): _nl_answers_as_floats,
    ('nl', 'zhouyi'): _nl_answers_as_text,
}
def _numeric_close(extracted_ans, target_answer):
    """True when two numeric answers differ by at most 0.01."""
    return abs(extracted_ans - target_answer) <= 1e-2


def _exact_equal(extracted_ans, target_answer):
    """True when the two answers compare equal."""
    return extracted_ans == target_answer


# Dataset source -> predicate(prediction, target) deciding correctness.
compare_answer_fn_mapper = {
    'gsm8k': _numeric_close,
    'svamp': _numeric_close,
    'mathqa': _exact_equal,
    'mathqa-numeric': _numeric_close,
    'zhouyi': _exact_equal,
}


In [3]:

# trainer
# 方式2：使用类装饰器
def with_accelerator(cls):
    """Class decorator that binds new instances to an ``accelerate.Accelerator``.

    After decoration the class name refers to a factory:
    ``Factory(acc, *args, **kwargs)``:

    1. stores ``acc`` in this module's global ``accelerator`` (read by the
       decorated class's methods);
    2. builds ``cls(*args, **kwargs)``;
    3. attaches ``acc`` to the instance as ``instance.accelerator``.

    The factory is a function, not a class. The original class is available as
    ``Factory.__wrapped__`` (e.g. for ``isinstance`` checks).
    """
    def wrapper(acc, *args, **kwargs):
        global accelerator
        accelerator = acc
        instance = cls(*args, **kwargs)
        # Also expose the accelerator on the instance so methods may use
        # self.accelerator as well as the module-level global.
        instance.accelerator = acc
        return instance

    # Keep the decorated class discoverable and its name/doc visible on the factory.
    wrapper.__name__ = cls.__name__
    wrapper.__qualname__ = cls.__qualname__
    wrapper.__doc__ = cls.__doc__
    wrapper.__wrapped__ = cls
    return wrapper

@with_accelerator
class Trainer:
    """Supervised fine-tuning helper driven by the module-level ``accelerator``.

    Decorated with ``with_accelerator``: ``Trainer(acc)`` sets the global
    ``accelerator`` to ``acc`` and returns a ``Trainer`` instance. Every method
    below reads that global ``accelerator`` (ranks, barriers, printing, gather).
    """

    def __init__(self):
        pass

    def prepare_datasets_and_data_loaders(self, args, tokenizer):
        """Load the JSON train/test splits, tokenize them and build DataLoaders.

        Loading and tokenization run inside ``accelerator.main_process_first()``,
        so the main process goes first.

        Args:
            args: Run configuration dict. Keys read here: ``train_file``,
                ``test_file``, ``engine``, ``max_input_length``, ``wandb_log``,
                ``batch_size``, ``eval_batch_size``, ``num_workers``.
            tokenizer: Hugging Face tokenizer. Must define ``eos_token_id``
                (asserted) and ``pad_token_id`` (used for padding).

        Returns:
            ``((train_dataset, train_dataloader), (test_dataset, test_dataloader))``
            where the datasets are the tokenized ``datasets.Dataset`` splits.
        """
        def _load_json(path):
            # Context manager so the file handle is closed after reading.
            with open(path, 'r') as f:
                return json.load(f)

        with accelerator.main_process_first():
            raw_dataset = DatasetDict({
                'train': Dataset.from_list(_load_json(args['train_file'])),
                'test': Dataset.from_list(_load_json(args['test_file'])),
            })
            accelerator.print('Raw data:', raw_dataset)
            # item_id looks like '<src>_<index>', e.g. gsm8k_0, gsm8k_1, ...
            src_name = raw_dataset['train'][0]['item_id'].split('_')[0]
            setup_cot(src_name)
            accelerator.print('Using instruction:', instruction)
            accelerator.print('Using cot_trigger:', cot_trigger)
            accelerator.print('Using answer_trigger:', answer_trigger)

            def tokenize_fn(batch, args, tokenizer):
                """Convert a batch of raw items into prompt+answer token sequences.

                Prompt tokens are masked with -100 in ``labels`` so only the
                answer (plus EOS) contributes to the loss. ``prefix`` is the
                prompt alone (used for generation at evaluation time).
                ``input_ids_max_length`` is the length before truncation to
                ``args['max_input_length']``.
                """
                assert tokenizer.eos_token_id is not None, (tokenizer.eos_token_id, tokenizer.eos_token)
                max_len = args['max_input_length']
                new_batch = defaultdict(list)
                all_keys = list(batch.keys())
                for item_values in zip(*(batch[k] for k in all_keys)):
                    item = dict(zip(all_keys, item_values))
                    item_id = item['item_id']
                    question = item['question'].strip()
                    answer_value = item['answer_value']
                    answer_cot = item.get('answer_cot', None)
                    if answer_value is not None:
                        answer_value = answer_value.strip()

                    if answer_cot is not None:
                        answer_cot = answer_cot.strip()
                        if args['engine'] == 'nl':
                            # NL chain-of-thought ends with the trigger plus the final answer.
                            answer_cot += f'{answer_trigger}{answer_value}'

                    prompt = f'{instruction}{question}{cot_trigger}'
                    # Items without a reference CoT get an empty target instead of
                    # the literal text 'None' that f'{answer_cot}' would produce.
                    output = answer_cot if answer_cot is not None else ''

                    prompt_encode = tokenizer(prompt, add_special_tokens=False)
                    output_ids = tokenizer(output, add_special_tokens=False)['input_ids']

                    input_ids = prompt_encode['input_ids'] + output_ids + [tokenizer.eos_token_id]
                    labels = [-100] * len(prompt_encode['input_ids']) + output_ids + [tokenizer.eos_token_id]
                    attention_mask = [1] * len(input_ids)
                    # The generation prefix is exactly the prompt, so reuse its encoding.
                    prefix = prompt_encode['input_ids']
                    prefix_attention_mask = prompt_encode['attention_mask']

                    # Length before truncation, reported for diagnostics.
                    input_ids_max_length = len(input_ids)

                    new_batch['input_ids'].append(input_ids[:max_len])
                    new_batch['labels'].append(labels[:max_len])
                    new_batch['attention_mask'].append(attention_mask[:max_len])
                    new_batch['prefix'].append(prefix[:max_len])
                    new_batch['prefix_attention_mask'].append(prefix_attention_mask[:max_len])
                    new_batch['item_id'].append(item_id)
                    new_batch['question'].append(question)
                    new_batch['answer_cot'].append(answer_cot)
                    new_batch['answer_value'].append(answer_value)
                    new_batch['input_ids_max_length'].append(input_ids_max_length)

                return new_batch

            tokenized_dataset = DatasetDict({
                mode: dataset.map(
                    tokenize_fn, fn_kwargs={'args': args, 'tokenizer': tokenizer}, batched=True,
                    remove_columns=dataset.column_names, num_proc=8, load_from_cache_file=False
                ) for mode, dataset in raw_dataset.items()})
            accelerator.print('Processed data:', tokenized_dataset)
            for mode, dataset in tokenized_dataset.items():
                accelerator.print(mode, f'{mode}_input_ids_max_length', max(dataset['input_ids_max_length']))

            if accelerator.is_main_process and args['wandb_log']:
                wandb.config.update({
                    "src_name": src_name,
                    "instruction": instruction,
                    "cot_trigger": cot_trigger,
                    "answer_trigger": answer_trigger,
                    "raw_dataset": str(raw_dataset),
                    "tokenized_dataset": str(tokenized_dataset),
                    "train_input_ids_max_length": max(tokenized_dataset['train']['input_ids_max_length']),
                    "test_input_ids_max_length": max(tokenized_dataset['test']['input_ids_max_length']),
                })

        def collate_fn(batch, args, tokenizer):
            """Pad a list of tokenized items into tensors.

            ``forward_kwargs`` are right-padded for teacher-forced training.
            ``generate_prefix_kwargs`` hold the left-padded prompt and left-padded
            labels, used as generation input and decoding targets during evaluation.
            """
            max_input_length = max(len(item['input_ids']) for item in batch)
            max_target_length = max(len(item['labels']) for item in batch)
            max_prefix_length = max(len(item['prefix']) for item in batch)
            input_ids, attention_mask = [], []
            labels, labels_left_padded = [], []
            prefix_left_padded, prefix_attention_mask_left_padded = [], []
            for item in batch:
                input_ids.append(item['input_ids'] + [tokenizer.pad_token_id] * (max_input_length - len(item['input_ids'])))
                attention_mask.append(item['attention_mask'] + [0] * (max_input_length - len(item['attention_mask'])))
                labels.append(item['labels'] + [-100] * (max_target_length - len(item['labels'])))

                labels_left_padded.append([-100] * (max_target_length - len(item['labels'])) + item['labels'])
                prefix_left_padded.append([tokenizer.pad_token_id] * (max_prefix_length - len(item['prefix'])) + item['prefix'])
                prefix_attention_mask_left_padded.append(
                    [0] * (max_prefix_length - len(item['prefix_attention_mask'])) + item['prefix_attention_mask'])
            forward_kwargs = {
                'input_ids': torch.LongTensor(input_ids),
                'attention_mask': torch.BoolTensor(attention_mask),
                'labels': torch.LongTensor(labels),
            }
            generate_prefix_kwargs = {
                'input_ids': torch.LongTensor(prefix_left_padded),
                'attention_mask': torch.BoolTensor(prefix_attention_mask_left_padded),
                'labels': torch.LongTensor(labels_left_padded),
            }
            return {
                'forward_kwargs': forward_kwargs,
                'generate_prefix_kwargs': generate_prefix_kwargs,
            }

        train_dataloader = DataLoader(
            tokenized_dataset['train'],
            # Shuffle the training data. The DistributedSampler was commented out,
            # which left the loader sequential. accelerator.prepare() (called in
            # training_loop) shards this loader across processes.
            shuffle=True,
            batch_size=args['batch_size'],
            num_workers=args['num_workers'],
            pin_memory=True,
            collate_fn=partial(collate_fn, args=args, tokenizer=tokenizer),
            drop_last=True,
        )

        test_dataloader = DataLoader(
            tokenized_dataset['test'],
            shuffle=False,
            batch_size=args['eval_batch_size'],
            num_workers=args['num_workers'],
            pin_memory=True,
            collate_fn=partial(collate_fn, args=args, tokenizer=tokenizer),
        )

        return (tokenized_dataset['train'], train_dataloader), (tokenized_dataset['test'], test_dataloader)

    def do_checkpoint(self, args, model, tokenizer, save_path, global_step):
        """Save config, tokenizer and model weights.

        Must be called on every process. It contains barriers and
        ``accelerator.get_state_dict(model)``, which may need all ranks (e.g.
        when parameters are sharded). ``save_pretrained`` runs on every rank,
        and only the main process writes files.

        Args:
            args: Run configuration; reads ``model_name_or_path`` and ``model_dir``.
            model: The model returned by ``accelerator.prepare``.
            tokenizer: Tokenizer saved into ``args['model_dir']``.
            save_path: Directory for the (sharded, safetensors) weights.
            global_step: Unused; kept for interface compatibility.

        Returns:
            True if every stage succeeded, otherwise False (the error is printed).
        """
        try:
            # Report the parameter count of this rank's model view.
            rank = accelerator.process_index
            accelerator.print(f"\n{'='*50}")
            accelerator.print(f"[进程 {rank} 的参数信息]")
            accelerator.print(f"{'='*50}\n")
            total_params = sum(p.numel() for p in model.parameters())
            accelerator.print(f"总参数量: {total_params}")
        except Exception as e:
            accelerator.print(f"[进程 {accelerator.process_index}] 打印参数信息失败: {e}")
            return False

        try:
            accelerator.wait_for_everyone()
            accelerator.print(f"Rank {accelerator.process_index} do_checkpoint 同步成功")
        except Exception as e:
            accelerator.print(f"Rank {accelerator.process_index} do_checkpoint 同步失败: {e}")
            return False

        # Save config and tokenizer (main process only).
        try:
            if accelerator.is_main_process:
                config = AutoConfig.from_pretrained(args["model_name_or_path"], trust_remote_code=True)
                config.save_pretrained(args["model_dir"])
                tokenizer.save_pretrained(args["model_dir"])
        except Exception as e:
            accelerator.print(f"Rank {accelerator.process_index} do_checkpoint config tokenizer 保存失败: {e}")
            return False

        try:
            # Every rank needs the unwrapped module: save_pretrained below runs on
            # all ranks. Unwrapping only on the main rank left `unwrapped_model`
            # undefined (NameError) elsewhere.
            unwrapped_model = accelerator.unwrap_model(model)
            accelerator.print(f"{'='*50} unwrapped_model.named_parameters()")
            for name, param in unwrapped_model.named_parameters():
                accelerator.print(f"参数 {name}: shape {param.shape}")
            accelerator.print(f"{'='*50} model.named_parameters()")
            for name, param in model.named_parameters():
                accelerator.print(f"参数 {name}: shape {param.shape}")
            accelerator.print(f"{'='*50}")
            accelerator.print(f"Rank {accelerator.process_index} do_checkpoint state_dict 收集成功")
        except Exception as e:
            accelerator.print(f"Rank {accelerator.process_index} do_checkpoint state_dict 收集失败: {e}")
            return False

        try:
            unwrapped_model.save_pretrained(
                save_path,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
                max_shard_size="2GB",  # split the weights into files of at most 2GB
                safe_serialization=True,  # write safetensors
                state_dict=accelerator.get_state_dict(model),
            )
            accelerator.print('save checkpoint success!')
            accelerator.wait_for_everyone()
            return True
        except Exception as e:
            accelerator.print(f"Rank {accelerator.process_index} 保存失败: {e}")
            return False

    def train_one_epoch(self, args, model, train_dataset, train_dataloader, optimizer, scheduler, tokenizer,
                        global_step, test_dataset, test_dataloader,
                        prefix, epoch, best_eval_log_dict, summary_log_dict, most_recent_ckpts_paths):
        """Run one pass over ``train_dataloader`` with gradient accumulation.

        Every ``args['gradient_accumulation_steps']`` micro-batches the optimizer
        and scheduler step once and ``global_step`` is incremented. On those
        steps:

        * generation evaluation runs when ``args['evaluating_step_freq']``
          divides ``global_step``;
        * metrics are logged when ``args['logging_step_freq']`` divides it.

        A missing or non-positive frequency disables the action. A batch that
        raises is reported and skipped. ``best_eval_log_dict`` and
        ``summary_log_dict`` are updated in place.

        Returns:
            ``(epoch_result_dict, global_step, evaluate_result_dict)``.
            ``epoch_result_dict['loss']`` is the last recorded step loss
            (multiplied back by the accumulation factor).
            ``evaluate_result_dict`` holds the last step-level ``Eval.Gen.*``
            metrics, or ``{}`` if no evaluation ran.

        Raises:
            Whatever the initial ``accelerator.wait_for_everyone()`` raises.
        """
        clip_grad_norm = args.get('clip_grad_norm', None)
        evaluating_step_freq = args.get('evaluating_step_freq', None)
        logging_step_freq = args.get('logging_step_freq', None)
        gradient_accumulation_steps = args['gradient_accumulation_steps']

        def _is_due(freq, step):
            # None or a non-positive frequency disables the periodic action.
            return freq is not None and freq > 0 and step % freq == 0

        model.train()
        epoch_result_dict = defaultdict(list)
        # Initialised up front because it is returned even when no evaluation ran
        # this epoch (previously an UnboundLocalError that aborted training).
        evaluate_result_dict = {}
        optimizer.zero_grad()
        accelerator.print(f"Rank {accelerator.process_index} 进入 train_one_epoch")
        accelerator.print(f"Rank {accelerator.process_index} train_dataloader: {train_dataloader}")
        try:
            accelerator.wait_for_everyone()
            accelerator.print(f"Rank {accelerator.process_index} train_one_epoch 同步成功")
        except Exception as e:
            accelerator.print(f"Rank {accelerator.process_index} train_one_epoch 同步失败: {e}")
            # Re-raise: the caller unpacks a 3-tuple, so returning False only
            # turned this into an unrelated unpacking error.
            raise
        pbar = tqdm(total=len(train_dataloader), desc=f'Train Loop [{accelerator.process_index}]',
                    disable=not accelerator.is_main_process)

        for idx, batch in enumerate(train_dataloader):
            # Per-batch barrier (debugging aid), retried once on failure.
            try:
                accelerator.wait_for_everyone()
                accelerator.print(f"Rank {accelerator.process_index} train_one_epoch 第{idx}批次同步成功")
            except Exception as e:
                accelerator.print(f"Rank {accelerator.process_index} train_one_epoch 第{idx}批次同步失败: {e}")
                try:
                    accelerator.wait_for_everyone()
                    accelerator.print(f"Rank {accelerator.process_index} train_one_epoch 第{idx}批次重新同步成功")
                except Exception as e:
                    accelerator.print(f"Rank {accelerator.process_index} train_one_epoch 第{idx}批次重新同步失败: {e}")
            try:
                output = model(**batch['forward_kwargs'])
                # Scale so gradients summed over the accumulation window average out.
                loss = output[0] / gradient_accumulation_steps
                accelerator.backward(loss)

                if (idx + 1) % gradient_accumulation_steps == 0:
                    if clip_grad_norm is not None:
                        accelerator.clip_grad_norm_(model.parameters(), clip_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    # Undo the accumulation scaling for reporting.
                    epoch_result_dict['loss'].append(loss.item() * gradient_accumulation_steps)

                    eval_log_dict = {}
                    if _is_due(evaluating_step_freq, global_step):
                        evaluate_result_dict = {
                            f'Eval.Gen.{k}': v
                            for k, v in self.evaluate_generation(
                                args, model, test_dataset, test_dataloader, tokenizer).items()
                        }
                        eval_log_dict.update(evaluate_result_dict)
                        accuracy = eval_log_dict['Eval.Gen.value_accuracy']
                        if accuracy > best_eval_log_dict.get('Eval.Gen.value_accuracy_best', 0):
                            best_eval_log_dict['Eval.Gen.value_accuracy_best'] = accuracy
                        summary_log_dict.setdefault('Eval.Gen.value_accuracy', []).append(accuracy)

                    train_log_dict = {}
                    if _is_due(logging_step_freq, global_step):
                        train_log_dict = {
                            f'T.{k}': sum(v) / len(v) if isinstance(v, list) else v
                            for k, v in epoch_result_dict.items()
                        }

                    if eval_log_dict or train_log_dict:
                        log_dict = {
                            'lr': scheduler.get_last_lr()[0],
                            **train_log_dict,
                            **eval_log_dict,
                            **best_eval_log_dict,
                        }
                        if accelerator.is_main_process and args['wandb_log']:
                            wandb.log(log_dict, step=global_step)
                            log_dict = {'wandb': args['wandb_project'] + '|' + args['wandb_run_name'], **log_dict}
                        log_dict = {k: f'{v:.5g}' if isinstance(v, float) else v for k, v in log_dict.items()}
                        accelerator.print(f"{prefix}[E={epoch}/{args['n_epochs']}, S={global_step}] {log_dict}")

                    # Keep only the most recent value per metric.
                    for k, v in epoch_result_dict.items():
                        if len(v) > 1:
                            epoch_result_dict[k] = v[-1:]
                accelerator.print(f"Rank {accelerator.process_index} 训练批次 {idx} 成功")
            except Exception as e:
                # Report and skip the batch.
                # NOTE(review): in multi-process runs, a rank that skips backward
                # while others do not may leave them waiting in a collective — confirm.
                accelerator.print(f"[进程 {accelerator.process_index}] 批次 {idx} 出错: {e}")
                accelerator.print(traceback.format_exc())
            pbar.update(1)

        pbar.close()
        epoch_result_dict = {k: (sum(v) / len(v) if isinstance(v, list) else v) for k, v in epoch_result_dict.items()}
        accelerator.print(f"Rank {accelerator.process_index} 训练epoch {epoch} 结束")
        return epoch_result_dict, global_step, evaluate_result_dict

    def evaluate_generation(self, args, model, dataset, dataloader, tokenizer):
        """Greedy-decode every test prompt and score the extracted answers.

        Generation, padding and gathering run on all processes. Scoring and
        writing ``<model_dir>/_res.json`` happen on the main process only. The
        accuracy is then broadcast so every rank returns the same value. The
        model is put back into training mode before returning.

        Returns:
            ``{'value_accuracy': float}``, a percentage (0.0 when there are no
            results).
        """
        model.eval()
        predictions = []
        targets = []
        for batch in tqdm(dataloader, total=len(dataloader), disable=not accelerator.is_main_process,
                          desc='Evaluation Gen Loop'):
            output_ = accelerator.unwrap_model(model).generate(
                **batch['generate_prefix_kwargs'],
                max_length=args['max_gen_length'],
                output_scores=True,
                return_dict_in_generate=True,
                num_beams=1,
                use_cache=True,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            generated_ids = output_.sequences
            generated_ids = pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id, pad_first=True)

            labels = batch['generate_prefix_kwargs']['labels']
            labels = pad_across_processes(labels, dim=1, pad_index=tokenizer.pad_token_id, pad_first=True)
            # -100 is not a real token id; map it to padding before decoding.
            labels[labels == -100] = tokenizer.pad_token_id

            generated_ids, labels = accelerator.gather(generated_ids), accelerator.gather(labels)

            predictions.extend(
                tokenizer.decode(g.cpu().numpy().tolist(), skip_special_tokens=True,
                                 clean_up_tokenization_spaces=True).strip()
                for g in generated_ids)
            targets.extend(
                tokenizer.decode(t.cpu().numpy().tolist(), skip_special_tokens=True,
                                 clean_up_tokenization_spaces=True).strip()
                for t in labels)

        # Drop any extra entries beyond the dataset size (e.g. from uneven
        # process shards).
        predictions = predictions[:len(dataset)]
        targets = targets[:len(dataset)]

        if accelerator.is_main_process and accelerator.is_local_main_process:
            accelerator.print("=" * 100)
            accelerator.print("=" * 20, "predictions", "=" * 20)
            accelerator.print(predictions)
            accelerator.print("=" * 20, "targets", "=" * 20)
            accelerator.print(targets)
            accelerator.print("=" * 100)

            results = []
            src_name = dataset[0]['item_id'].split('_')[0]
            for pred, tar, item in zip(predictions, targets, dataset):
                target_value = post_process_final_answer_fn_mapper[src_name](item['answer_value'])
                results.append({
                    'item_id': item['item_id'],
                    'answer_value': item['answer_value'],
                    # Per-item decoded target (previously the whole last batch's list).
                    'target': tar,
                    'target_cot': tar.strip().split(cot_trigger)[-1].strip(),
                    'target_value': target_value,
                    'prediction': pred,
                    'prediction_cot': pred.strip().split(cot_trigger)[-1].strip(),
                    'prediction_value': None,  # filled in below
                })
            print("=" * 100)
            print("eval results:")
            print(results)

            execute_fn = post_process_answer_cot_fn_mapper[(args['engine'], src_name)]
            corr_value = 0
            for i, prediction_value in enumerate(execute_fn([item['prediction_cot'] for item in results])):
                target_value = results[i]['target_value']
                is_correct = (compare_answer_fn_mapper[src_name](prediction_value, target_value)
                              if prediction_value is not None else False)
                results[i]['prediction_value'] = prediction_value
                results[i]['is_correct'] = is_correct
                corr_value += is_correct

            res_path = os.path.join(args['model_dir'], '_res.json')
            with open(res_path, 'w') as f:
                json.dump(results, f, indent=2)

            # Guard against an empty evaluation set (previously ZeroDivisionError).
            value_accuracy = corr_value / len(results) * 100 if results else 0.0
            accelerator.print(f"[Eval Info] value_accuracy: {value_accuracy:.5g}%")
            value_accuracy = torch.FloatTensor([value_accuracy]).to(accelerator.device)
        else:
            value_accuracy = torch.FloatTensor([-1.0]).to(accelerator.device)
        # Share the main process's accuracy with every rank.
        value_accuracy = broadcast(value_accuracy).cpu().numpy().tolist()[0]

        model.train()
        return {'value_accuracy': value_accuracy}



In [4]:
# training loop
def training_loop(args):
    """Run supervised fine-tuning end to end on the current process.

    Steps:

    1. Build the accelerator, tokenizer, data loaders, bf16 model, AdamW
       optimizer and linear warmup schedule.
    2. Train for ``args['n_epochs']`` epochs.
    3. After each epoch, checkpoint into ``args['model_dir']`` when the epoch's
       training loss is the lowest so far.

    Stops early if an epoch raises or a checkpoint fails.

    Args:
        args: Run configuration dict. Keys read directly here include
            ``deepspeed_plugin``, ``seed``, ``wandb_log``, ``wandb_project``,
            ``wandb_run_name``, ``tokenizer_name_or_path``,
            ``model_name_or_path``, ``n_epochs``, ``gradient_accumulation_steps``,
            ``warmup_step``, ``weight_decay``, ``learning_rate``, ``batch_size``,
            ``logging_epoch_freq`` and ``model_dir``.
    """
    accelerator = Accelerator(
        deepspeed_plugin=args['deepspeed_plugin']
    )
    trainer = Trainer(accelerator)

    set_seed(args['seed'] + accelerator.process_index)
    # Only the main process talks to wandb. accelerator.is_main_process also works
    # in a single-process run, where torch.distributed.get_rank() raises because
    # no process group has been initialised.
    use_wandb = accelerator.is_main_process and args['wandb_log']
    if use_wandb:
        wandb.init(project=args['wandb_project'], name=args['wandb_run_name'])
        wandb.config.update(args)

    tokenizer = AutoTokenizer.from_pretrained(
        args['tokenizer_name_or_path'],
        trust_remote_code=True,
        padding_side='left',  # ChatGLM uses left padding
        eos_token='<|endoftext|>',
        pad_token='<|endoftext|>',
    )
    # Make sure the tokenizer defines the special tokens the data pipeline uses.
    tokenizer.add_special_tokens({
        'pad_token': '<|endoftext|>',
        'eos_token': '<|endoftext|>',
        'bos_token': '<|startoftext|>',
    })

    (train_dataset, train_dataloader), (test_dataset, test_dataloader) = \
        trainer.prepare_datasets_and_data_loaders(args, tokenizer)

    config = AutoConfig.from_pretrained(
        args['model_name_or_path'],
        trust_remote_code=True
    )
    # Set the attention implementation, which is missing from this config.
    config._attn_implementation = "eager"

    model = AutoModelForCausalLM.from_pretrained(
        args['model_name_or_path'],
        config=config,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    model = model.bfloat16()  # make sure every parameter is bf16

    accelerator.print(f'[Vocab size]: {len(tokenizer)}')
    model.resize_token_embeddings(len(tokenizer))

    if use_wandb:
        wandb.run.summary.update({
            'pad_token_id': tokenizer.pad_token_id,
            'eos_token_id': tokenizer.eos_token_id,
            'unk_token_id': tokenizer.unk_token_id,
            'vocab_size': len(tokenizer)
        })

    n_epochs = args['n_epochs']
    # len(train_dataloader) is the un-sharded length here (prepare() comes later),
    # hence the division by the number of processes.
    num_training_steps = (len(train_dataloader) // accelerator.num_processes * n_epochs) // args['gradient_accumulation_steps']
    warmup_step = (args['warmup_step']
                   if args['warmup_step'] is not None and args['warmup_step'] >= 0
                   else int(0.1 * num_training_steps))
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args['weight_decay'],
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args['learning_rate'], eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step,
                                                num_training_steps=num_training_steps)
    accelerator.print(
        f"***** Running training *****\n"
        f"  Num examples = {len(train_dataset)}\n"
        f"  Num Epochs = {n_epochs}\n"
        f"  Instantaneous batch size per device = {args['batch_size']}\n"
        f"  Total train batch size (w. parallel, distributed & accumulation) = {args['batch_size']*accelerator.num_processes*args['gradient_accumulation_steps']}\n"
        f"  Total optimization steps = {num_training_steps}\n"
        f"  Warm up step: {warmup_step}\n"
        f"  Learning rate: {args['learning_rate']}\n"
    )
    model, optimizer, train_dataloader, test_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader)

    global_step = 0
    logging_epoch_freq = args['logging_epoch_freq']
    model_dir = args['model_dir']
    best_eval_log_dict = {}
    summary_log_dict = {}
    os.makedirs(model_dir, exist_ok=True)
    most_recent_ckpts_paths = []
    lowest_loss = None
    with tqdm(range(1, n_epochs + 1), total=n_epochs, disable=not accelerator.is_main_process) as t:
        for epoch in t:
            try:
                train_epoch_result_dict, global_step, _ = trainer.train_one_epoch(
                    args=args,
                    model=model,
                    train_dataset=train_dataset,
                    train_dataloader=train_dataloader,
                    test_dataset=test_dataset,
                    test_dataloader=test_dataloader,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    global_step=global_step,
                    tokenizer=tokenizer,
                    prefix='',
                    epoch=epoch,
                    best_eval_log_dict=best_eval_log_dict,
                    summary_log_dict=summary_log_dict,
                    most_recent_ckpts_paths=most_recent_ckpts_paths,
                )
                accelerator.print(f"[进程 {accelerator.process_index}] Epoch {epoch} 训练结果: {train_epoch_result_dict}")
            except Exception as e:
                accelerator.print(f"[进程 {accelerator.process_index}] Epoch {epoch} 发生错误: {e}")
                break

            # Epoch-level generation evaluation is disabled; the "best" checkpoint
            # is chosen by training loss instead.
            eval_log_dict = {}
            accelerator.print(f'跳过evaluation')

            # No 'loss' entry means no optimizer step completed this epoch; treat
            # it as "not best" instead of raising KeyError.
            epoch_loss = train_epoch_result_dict.get("loss")
            is_best = epoch_loss is not None and (lowest_loss is None or epoch_loss < lowest_loss)
            if is_best:
                lowest_loss = epoch_loss

            train_log_dict = {}
            if logging_epoch_freq is not None and epoch % logging_epoch_freq == 0:
                train_log_dict = {f'T.{k}': sum(v) / len(v) if isinstance(v, list) else v
                                  for k, v in train_epoch_result_dict.items()}

            if eval_log_dict or train_log_dict:
                log_dict = {'lr': scheduler.get_last_lr()[0], **train_log_dict, **eval_log_dict, **best_eval_log_dict}
                if use_wandb:
                    wandb.log(log_dict, step=global_step)
                    log_dict = {'wandb': args['wandb_project'] + '|' + args['wandb_run_name'], **log_dict}
                log_dict = {k: f'{v:.5g}' if isinstance(v, float) else v for k, v in log_dict.items()}
                accelerator.print(f"[E={epoch}/{args['n_epochs']}, S={global_step}] {log_dict}")

            if is_best:
                try:
                    accelerator.wait_for_everyone()
                    accelerator.print(f"epoch {epoch} 将开始保存权重，等待所有进程成功")
                except Exception as e:
                    accelerator.print(f"epoch {epoch} 将开始保存权重，等待所有进程失败: {e}")
                    break

                save_path = model_dir
                if accelerator.is_main_process and os.path.exists(save_path):
                    # NOTE: the directory is not actually cleared; files are overwritten in place.
                    accelerator.print(f"目录已存在, 清空目录: {save_path}")
                os.makedirs(save_path, exist_ok=True)
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                accelerator.print(f"开始保存新的最佳checkpoint... 时间: {timestamp}")

                saved = trainer.do_checkpoint(args, model, tokenizer, save_path, global_step)
                if saved:
                    accelerator.print(f"保存checkpoint成功")
                else:
                    accelerator.print(f"保存checkpoint失败")
                    break

    # Close the wandb run so repeated launches in the same kernel start cleanly.
    if use_wandb:
        wandb.finish()


In [5]:
# get args from shell script
import re

def parse_shell_script(script_path):
    """Extract variable assignments and ``export`` statements from a shell script.

    Recognised line forms (after stripping surrounding whitespace):

    * ``export KEY=VALUE``               -> ``commands[KEY] = VALUE`` (value kept verbatim)
    * ``key=${key:-'default'}``          -> ``params[key] = 'default'`` (surrounding quotes stripped)
    * ``key=${OTHER}`` / ``key=$OTHER``  -> ``params[key] = 'OTHER'`` (the name only; nothing is expanded)
    * ``key='value'`` / ``key=value``    -> ``params[key] = 'value'``

    Blank lines, ``#`` comment lines and every other line are ignored. No shell
    evaluation or variable expansion is performed.

    Args:
        script_path: Path of the shell script to read (decoded as UTF-8).

    Returns:
        ``(params, commands)``: two dicts mapping variable names to raw string values.
    """
    params = {}
    commands = {}

    with open(script_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    for line in lines:
        line = line.strip()

        # Skip blank lines and comment lines.
        if not line or line.startswith('#'):
            continue

        # `export NAME=VALUE`. Require whitespace after the keyword so that plain
        # variables whose names merely start with "export" (e.g. export_dir=...)
        # are not captured here and silently dropped.
        if re.match(r'export\s', line):
            match = re.match(r'export\s+(\w+)=(.+)', line)
            if match:
                key, value = match.groups()
                commands[key] = value
            continue

        # Parameter expansion: var=${var:-'default'}, var=${OTHER} or var=$OTHER.
        match = re.match(r'(\w+)=\${?([^}]*)}?', line)
        if match:
            key, value = match.groups()
            if ':-' in value:
                # Split on the first ':-' only, so a default that itself contains
                # ':-' is kept intact.
                value = value.split(':-', 1)[1].strip("'\"")
            params[key] = value
            continue

        # Plain assignment: var=value, var='value' or var="value".
        match = re.match(r'(\w+)=[\'"]?([^\'"]*)[\'"]?', line)
        if match:
            key, value = match.groups()
            params[key] = value

    return params, commands


from dataclasses import dataclass, field, asdict
from transformers import HfArgumentParser

# Sentinel values that stand for "unset" in the shell-script configuration.
NONE_INT = -100
NONE_STR = 'None'


# Training hyper-parameters. The first five fields have no default; every other
# field falls back to the value declared here.
@dataclass
class Arguments:
    # --- paths (no defaults) ---
    model_name_or_path: str
    tokenizer_name_or_path: str
    model_dir: str
    train_file: str
    test_file: str
    # --- batching / epochs / data loading ---
    batch_size: int = 4
    eval_batch_size: int = 8
    n_epochs: int = 40
    num_workers: int = 8
    # --- optimisation ---
    learning_rate: float = 2e-5
    weight_decay: float = 1e-6
    warmup_step: int = 0
    clip_grad_norm: float = 1
    # --- epoch-based cadences ---
    evaluating_epoch_freq: int = 1
    logging_epoch_freq: int = 1
    saving_epoch_freq: int = 1000
    # --- step-based cadences (NONE_INT means "unset") ---
    evaluating_step_freq: int = NONE_INT
    logging_step_freq: int = NONE_INT
    saving_step_freq: int = NONE_INT
    # --- misc ---
    seed: int = 42
    max_input_length: int = 700
    max_gen_length: int = 512
    gradient_accumulation_steps: int = 1
    keep_num_ckpt: int = 1
    # --- Weights & Biases logging ---
    wandb_log: bool = False
    wandb_project: str = 'tmp_anvfupsadfn'
    wandb_run_name: str = 'default_run_name'
    engine: str = 'nl'

# Locate the SFT shell scripts whose variable assignments supply the training arguments.
# The directory can be overridden with the SFT_EXPS_DIR environment variable; by default
# it is the original absolute location, so existing runs are unaffected.
sft_exps_dir = os.environ.get(
    'SFT_EXPS_DIR',
    '/home/wangxinrong/workspace/reft/divination/mwp_ReFT/exps/paper_exps/SFT',
)
template_path = os.path.join(sft_exps_dir, '_template.sh')
zhouyi_path = os.path.join(sft_exps_dir, 'zhouyi_sft.sh')

# Parse both scripts and merge their parameters; on a name clash the task-specific
# zhouyi script wins over the shared template.
template_params, template_commands = parse_shell_script(template_path)
zhouyi_params, zhouyi_commands = parse_shell_script(zhouyi_path)
shell_params = {**template_params, **zhouyi_params}

# Type conversion for values read from the shell scripts.
def convert_type(value, target_type):
    """Convert a raw string taken from a shell script into ``target_type``.

    Args:
        value: The raw string value.
        target_type: The declared dataclass field type.

    Returns:
        * ``bool``: True iff ``value`` equals ``'true'`` case-insensitively.
        * ``int`` / ``float``: the parsed number, or None when ``value`` is the
          literal string ``'None'``.
        * any other type: ``value`` unchanged.

    Raises:
        ValueError: if an int/float value is neither numeric nor ``'None'``.
    """
    if target_type is bool:
        return value.lower() == 'true'
    if target_type in (int, float):
        try:
            return target_type(value)
        except ValueError:
            # The scripts spell "unset" as the literal string 'None'. (A '-100'
            # sentinel parses as an ordinary int and is returned as -100 here.)
            if value == 'None':
                return None
            raise
    return value

# Build the plain `args` dict (one entry per Arguments field) from the parsed shell parameters.
from dataclasses import MISSING

args = {}
for field_name, field_def in Arguments.__dataclass_fields__.items():
    if field_name in shell_params:
        # Value from the shell scripts, converted to the declared field type.
        args[field_name] = convert_type(shell_params[field_name], field_def.type)
    elif field_def.default is not MISSING:
        # Fall back to the dataclass default.
        args[field_name] = field_def.default
    elif field_def.default_factory is not MISSING:
        args[field_name] = field_def.default_factory()
    else:
        # A required field (no default) is absent from both scripts: fail here with a
        # clear message instead of passing the dataclasses.MISSING sentinel onward.
        raise ValueError(
            f"Required argument '{field_name}' was not found in {template_path} or {zhouyi_path}"
        )

# Map the "unset" sentinels (-100 and the string 'None') to None.
for k, v in args.items():
    if v in [NONE_INT, 'None']:
        args[k] = None


In [6]:
from accelerate.commands.config import load_config_from_file
from accelerate.utils import DeepSpeedPlugin
from accelerate import Accelerator
from accelerate import notebook_launcher

# Load the accelerate config file (absolute path).
# NOTE(review): this path is under /home/hanxianlin/ while the shell scripts above are read
# from /home/wangxinrong/ — confirm this is the intended config file.
config_path = "/home/hanxianlin/workspace/reft/divination/mwp_ReFT/default_config_deepspeed_ga2.yaml"
config = load_config_from_file(config_path)
deepspeed_config = config.deepspeed_config

# Forward the selected DeepSpeed options to the plugin; a key missing from the YAML raises KeyError.
deepspeed_plugin = DeepSpeedPlugin(**{
    option: deepspeed_config[option]
    for option in (
        'zero_stage',
        'gradient_accumulation_steps',
        'gradient_clipping',
        'offload_optimizer_device',
        'offload_param_device',
    )
})

args['deepspeed_plugin'] = deepspeed_plugin

# Launch `config.num_processes` worker processes running training_loop(args).
# NOTE(review): the recorded run below failed with a 600 s NCCL barrier timeout on a non-zero
# rank waiting for rank 0 — investigate rank 0 before re-running.
notebook_launcher(
    training_loop,
    [args],
    num_processes=config.num_processes,
    mixed_precision=config.mixed_precision,
)

Launching training on 8 GPUs.


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
[rank6]:[W109 16:16:34.216926984 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 6]  using GPU 6 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
[rank1]:[W109 16:16:40.750114209 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.


Raw data: DatasetDict({
    train: Dataset({
        features: ['item_id', 'question', 'answer_cot', 'answer_value'],
        num_rows: 16
    })
    test: Dataset({
        features: ['item_id', 'question', 'answer_cot', 'answer_value'],
        num_rows: 1
    })
})
Using instruction: Question:

Using cot_trigger: 
Answer reasoning:

Using answer_trigger: 
因此，答案是：


[rank7]:[W109 16:16:41.764693545 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 7]  using GPU 7 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
[rank2]:[W109 16:16:45.607773104 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 2]  using GPU 2 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device,or call init_process_group() with a device_id.
[rank3]:[W109 16:16:46.294356444 ProcessGroupNCCL.cpp:4115] [PG ID 0 PG GUID 0 Rank 3]  using GPU 3 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in b

ChildFailedError: 
============================================================
training_loop FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-01-09_16:26:34
  host      : HS-DSS8440-009
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 3894189)
  error_file: /tmp/torchelastic_u0zkmm7w/none_5u7n5yel/attempt_0/6/error.json
  traceback : Traceback (most recent call last):
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/tmp/ipykernel_3893964/3522858831.py", line 29, in training_loop
      (train_dataset, train_dataloader), (test_dataset, test_dataloader) = trainer.prepare_datasets_and_data_loaders(args, tokenizer)
                                                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/tmp/ipykernel_3893964/2319000155.py", line 19, in prepare_datasets_and_data_loaders
      with accelerator.main_process_first():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/contextlib.py", line 137, in __enter__
      return next(self.gen)
             ^^^^^^^^^^^^^^
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/accelerate/accelerator.py", line 920, in main_process_first
      with self.state.main_process_first():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/contextlib.py", line 137, in __enter__
      return next(self.gen)
             ^^^^^^^^^^^^^^
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/accelerate/state.py", line 1074, in main_process_first
      with PartialState().main_process_first():
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/contextlib.py", line 137, in __enter__
      return next(self.gen)
             ^^^^^^^^^^^^^^
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/accelerate/state.py", line 496, in main_process_first
      yield from self._goes_first(self.is_main_process)
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/accelerate/state.py", line 381, in _goes_first
      self.wait_for_everyone()
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/accelerate/state.py", line 375, in wait_for_everyone
      torch.distributed.barrier()
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
      return func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
    File "/home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 4159, in barrier
      work = group.barrier(opts=opts)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  torch.distributed.DistBackendError: [6] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: wait timeout after 600000ms, keys: //worker/attempt_0/default_pg/0//cuda//0
  Exception raised from doWait at /opt/conda/conda-bld/pytorch_1728945377988/work/torch/csrc/distributed/c10d/TCPStore.cpp:600 (most recent call first):
  frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7fb8d6975446 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libc10.so)
  frame #1: <unknown function> + 0x1327393 (0x7fb91b3c8393 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #2: c10d::TCPStore::doGet(std::string const&) + 0x2a (0x7fb91fdd6bca in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #3: c10d::TCPStore::get(std::string const&) + 0x7a (0x7fb91fdd7a3a in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #4: c10d::PrefixStore::get(std::string const&) + 0x31 (0x7fb91fd87dc1 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #5: c10d::PrefixStore::get(std::string const&) + 0x31 (0x7fb91fd87dc1 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #6: c10d::PrefixStore::get(std::string const&) + 0x31 (0x7fb91fd87dc1 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #7: c10d::PrefixStore::get(std::string const&) + 0x31 (0x7fb91fd87dc1 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #8: c10d::ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId*, bool, std::string const&, int) + 0xaf (0x7fb8d7c2fbff in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
  frame #9: c10d::ProcessGroupNCCL::getNCCLComm(std::string const&, c10::Device&, c10d::OpType, int, bool) + 0xfbd (0x7fb8d7c3bb9d in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
  frame #10: <unknown function> + 0x123a33e (0x7fb8d7c4433e in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
  frame #11: c10d::ProcessGroupNCCL::allreduce_impl(at::Tensor&, c10d::AllreduceOptions const&) + 0x12c (0x7fb8d7c4590c in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
  frame #12: c10d::ProcessGroupNCCL::barrier(c10d::BarrierOptions const&) + 0x476 (0x7fb8d7c531e6 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cuda.so)
  frame #13: <unknown function> + 0x5cd95f2 (0x7fb91fd7a5f2 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #14: <unknown function> + 0x5ce3df5 (0x7fb91fd84df5 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #15: <unknown function> + 0x52fd9bb (0x7fb91f39e9bb in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #16: <unknown function> + 0x52fb249 (0x7fb91f39c249 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #17: <unknown function> + 0x17d7c38 (0x7fb91b878c38 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #18: <unknown function> + 0x5cedc74 (0x7fb91fd8ec74 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #19: <unknown function> + 0x5ceea05 (0x7fb91fd8fa05 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_cpu.so)
  frame #20: <unknown function> + 0xdfe698 (0x7fb930f13698 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
  frame #21: <unknown function> + 0x4cc1e3 (0x7fb9305e11e3 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/site-packages/torch/lib/libtorch_python.so)
  frame #22: <unknown function> + 0x224588 (0x564e62709588 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #23: _PyObject_MakeTpCall + 0x2bb (0x564e626e975b in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #24: <unknown function> + 0x1126a1 (0x564e625f76a1 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #25: <unknown function> + 0x27089d (0x564e6275589d in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #26: <unknown function> + 0x113768 (0x564e625f8768 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #27: <unknown function> + 0x251adc (0x564e62736adc in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #28: <unknown function> + 0x114b20 (0x564e625f9b20 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #29: <unknown function> + 0x27089d (0x564e6275589d in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #30: <unknown function> + 0x113768 (0x564e625f8768 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #31: <unknown function> + 0x251adc (0x564e62736adc in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #32: <unknown function> + 0x114b20 (0x564e625f9b20 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #33: <unknown function> + 0x27089d (0x564e6275589d in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #34: <unknown function> + 0x113768 (0x564e625f8768 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #35: <unknown function> + 0x251adc (0x564e62736adc in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #36: <unknown function> + 0x114b20 (0x564e625f9b20 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #37: _PyObject_FastCallDictTstate + 0x1ee (0x564e626ec2fe in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #38: <unknown function> + 0x23229c (0x564e6271729c in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #39: _PyObject_MakeTpCall + 0x274 (0x564e626e9714 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #40: <unknown function> + 0x1126a1 (0x564e625f76a1 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #41: _PyObject_FastCallDictTstate + 0x1ee (0x564e626ec2fe in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #42: _PyObject_Call_Prepend + 0x69 (0x564e627176b9 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #43: <unknown function> + 0x30364b (0x564e627e864b in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #44: _PyObject_Call + 0xb5 (0x564e6271a135 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #45: <unknown function> + 0x113339 (0x564e625f8339 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #46: PyEval_EvalCode + 0xa1 (0x564e6279f741 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #47: <unknown function> + 0x2d5ece (0x564e627baece in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #48: <unknown function> + 0x112f8e (0x564e625f7f8e in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #49: <unknown function> + 0x2d099f (0x564e627b599f in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #50: <unknown function> + 0x2d1c57 (0x564e627b6c57 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #51: <unknown function> + 0x113e38 (0x564e625f8e38 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #52: <unknown function> + 0x251adc (0x564e62736adc in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #53: <unknown function> + 0x2515be (0x564e627365be in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #54: _PyObject_Call + 0x12b (0x564e6271a1ab in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #55: <unknown function> + 0x113339 (0x564e625f8339 in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #56: <unknown function> + 0x2d099f (0x564e627b599f in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #57: <unknown function> + 0x8274 (0x7fb9573b2274 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
  frame #58: <unknown function> + 0x8a63 (0x7fb9573b2a63 in /home/wangxinrong/miniconda3/envs/cuda-12.2/lib/python3.12/lib-dynload/_asyncio.cpython-312-x86_64-linux-gnu.so)
  frame #59: <unknown function> + 0x222fbc (0x564e62707fbc in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #60: <unknown function> + 0x34db0c (0x564e62832b0c in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #61: <unknown function> + 0x1c402e (0x564e626a902e in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  frame #62: <unknown function> + 0x21940b (0x564e626fe40b in /home/wangxinrong/miniconda3/envs/cuda-12.2/bin/python)
  . This may indicate a possible application crash on rank 0 or a network set up issue.
  
============================================================