# In [None]:  (notebook cell marker left over from the export; not part of the program)
import logging
import os
from pathlib import Path
from datetime import timedelta
import torch
import numpy as np
import datasets
from trl import SFTTrainer, SFTConfig
from transformers import EarlyStoppingCallback, set_seed
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import AdamW
# from lm_experiments_tools.dataset_preprocessing import load_and_preprocess_task, combine_datasets
# from lm_experiments_tools.instruction_utils import mask_non_completion, mask_non_completion_multi

from torch.nn.utils.rnn import pad_sequence

import accelerate
from peft import get_peft_model, LoraConfig, TaskType

# import transformers  # noqa: E402
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser  # noqa: E402

# from lm_experiments_tools.utils import get_cls_by_name, get_optimizer, prepare_run  # noqa: E402
from lm_experiments_tools.utils import get_cls_by_name

from utils.reasoning import make_segment, split_cot

# Root logger: timestamped records, INFO level and above.
logger_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=logger_fmt, level=logging.INFO)
logger = logging.getLogger('')


# Respect a CUDA_VISIBLE_DEVICES value supplied by the caller; otherwise list every
# device index torch reports.
if 'CUDA_VISIBLE_DEVICES' not in os.environ:
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, range(torch.cuda.device_count())))

logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
# first call to torch.cuda.device_count() sets visible gpus, following calls will not change the result
logger.info(f"CUDA DEVICE COUNT: {torch.cuda.device_count()}")


    # NOTE(review): `args` is not defined in the visible part of this file; this indented code
    # presumably sits inside a function that parses it (HfArgumentParser is imported) — confirm.
    # Resolve the working directory to an absolute path and make it the process cwd, so
    # relative paths used below are interpreted relative to it.
    args.working_dir = str(Path(args.working_dir).expanduser().absolute())
    os.chdir(args.working_dir)
    set_seed(args.seed)

    # Workaround: use a larger timeout for NCCL (useful for big datasets, to avoid a timeout during tokenization).
    # 20 * 1800 seconds = 10 hours.
    timeout = timedelta(seconds=20 * 1800)
    accelerator = accelerate.Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
                                         kwargs_handlers=[accelerate.InitProcessGroupKwargs(timeout=timeout)])
    # From here on `logger` is the logger returned by accelerate.logging.get_logger.
    from accelerate.logging import get_logger
    logger = get_logger('')

    logger.info(f'num processes: {accelerator.num_processes}')
    logger.info(f'mixed precision: {accelerator.mixed_precision}')

    # Only a warning: the run continues without an output directory.
    if args.output_dir is None:
        logger.warning('output_dir is not set: config, logs and checkpoints will not be saved.')

    # ============================
    # === Prepare tokenizer and datasets
    # ============================
    # The tokenizer comes from the pretrained checkpoint when one is given,
    # otherwise from args.tokenizer.
    tokenizer_source = args.from_pretrained if args.from_pretrained else args.tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    # Optionally take only the chat template from a second tokenizer.
    if args.tokenizer_for_chat_template is not None:
        it_tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_for_chat_template,
            trust_remote_code=True,
        )
        tokenizer.chat_template = it_tokenizer.chat_template

    # Optional override of the tokenizer's padding side.
    if args.padding_side is not None:
        tokenizer.padding_side = args.padding_side
    # Prepare datasets
    logger.info(f'preparing dataset for {args.task_name}')

    if args.dataset_name is not None:
        # Dataset addressed by name: "train" and "valid" splits are required, "test" is optional.
        hf_dataset = datasets.load_dataset(args.dataset_name)
        train_dataset = hf_dataset["train"]
        valid_dataset = hf_dataset["valid"]
        test_dataset = hf_dataset["test"] if "test" in hf_dataset else None
    else:
        # Dataset saved on disk under <dataset_dir>/<task_name>/{train,valid[,test]}.
        dataset_path = os.path.join(args.dataset_dir, args.task_name)

        train_dataset = datasets.load_from_disk(os.path.join(dataset_path, "train"))
        valid_dataset = datasets.load_from_disk(os.path.join(dataset_path, "valid"))
        # Without a dedicated test split the validation split doubles as the test set.
        test_split = "test" if os.path.exists(os.path.join(dataset_path, "test")) else "valid"
        test_dataset = datasets.load_from_disk(os.path.join(dataset_path, test_split))

    if args.max_cot_steps is not None:
        # Keep only samples whose 'cot_len' does not exceed max_cot_steps.
        train_dataset = train_dataset.filter(lambda x: x['cot_len'] <= args.max_cot_steps)
        valid_dataset = valid_dataset.filter(lambda x: x['cot_len'] <= args.max_cot_steps)
        # test_dataset is None when a named dataset has no "test" split; filtering or
        # measuring it unconditionally would raise.
        if test_dataset is not None:
            test_dataset = test_dataset.filter(lambda x: x['cot_len'] <= args.max_cot_steps)
        test_size = None if test_dataset is None else len(test_dataset)
        logger.info(f"Filtered ds sizes: {len(train_dataset), len(valid_dataset), test_size}")
    # Delimiter used to split the chain of thought into steps, chosen by task family.
    if 'gsm8k' in args.task_name:
        delim = ">> <<"
    elif 'multiplication' in args.task_name:
        delim = ' + '
    else:
        raise NotImplementedError(f"Unknown task name {args.task_name}")

    # Pad input ids with the pad token, falling back to EOS when the tokenizer defines none.
    id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    # Marker token sequences used by collate_fn: `bos` opens every segment, `think` closes the
    # task and intermediate chain-of-thought segments, `ans` closes the last chain-of-thought
    # segment, and `eos` closes the answer segment.
    # NOTE(review): these encode() calls use the tokenizer's default special-token handling,
    # while collate_fn encodes text with add_special_tokens=False — confirm that no extra
    # special tokens end up inside the markers.
    think = tokenizer.encode('????')
    bos = tokenizer.encode('////')
    ans = tokenizer.encode('!!!!')
    eos = [tokenizer.eos_token_id]


    # ============================
    # === Prepare data collator ===
    # ============================

    def collate_fn(batch):
        """Turn raw samples into per-segment padded batches.

        Every sample (a mapping with 'task', 'cot' and 'labels' text fields) becomes the
        segment sequence ``task, cot step(s), answer``; the task segment is built with
        ``loss=False``, all others with ``loss=True``. Samples with fewer segments than the
        longest one in the batch are topped up with ``loss=False`` segments holding only
        ``eos``.

        Returns a dict with ``'segments'`` (one dict of padded tensors per segment position)
        and ``'labels'`` (the per-segment labels concatenated along dim 1).
        """
        per_sample_segments = []
        for sample in batch:
            task_text, answer_text, cot_text = sample['task'], sample['labels'], sample['cot']
            task_tokens = tokenizer.encode(task_text, add_special_tokens=False)
            answer_tokens = tokenizer.encode(answer_text, add_special_tokens=False)

            # Without use_cot the whole chain of thought is kept as a single step.
            cot_steps = split_cot(cot_text, by=delim) if getattr(args, 'use_cot', False) else [cot_text]
            step_tokens = tokenizer.batch_encode_plus(cot_steps, add_special_tokens=False)['input_ids']

            sample_segments = [make_segment(bos + task_tokens + think, loss=False)]
            sample_segments.extend(
                make_segment(bos + step + think, loss=True) for step in step_tokens[:-1]
            )
            # The last step ends with the answer marker instead of the think marker.
            sample_segments.append(make_segment(bos + step_tokens[-1] + ans, loss=True))
            sample_segments.append(make_segment(bos + answer_tokens + eos, loss=True))
            per_sample_segments.append(sample_segments)

        # Equalize the number of segments across the batch.
        num_segments = max(map(len, per_sample_segments))
        for sample_segments in per_sample_segments:
            missing = num_segments - len(sample_segments)
            if missing > 0:
                sample_segments.extend([make_segment(eos, loss=False)] * missing)

        # Padding value per field; dict order is also the key order of each output segment.
        fill_values = {
            'input_ids': id_pad_value,
            'attention_mask': 0,
            'labels_mask': False,
            'labels': -100,
        }
        batch_segments = [
            {
                field: pad_sequence([segs[position][field] for segs in per_sample_segments],
                                    batch_first=True, padding_value=fill)
                for field, fill in fill_values.items()
            }
            for position in range(num_segments)
        ]
        full_labels = torch.cat([segment['labels'] for segment in batch_segments], dim=1)
        return {"segments": batch_segments, 'labels': full_labels}


    # ============================
    # === Define model ==========
    # ============================
    # TODO: move model building to separate function
    # Resolve the model class from its name given in args.model_cls.
    model_cls = get_cls_by_name(args.model_cls)
    logger.info(f'Using model class: {model_cls}')

    if args.use_adapter:
        # Adapter path: start from the pretrained config, switch on the parallel-adapter
        # options, build the model from that config, then copy the pretrained weights in.
        model_cfg = AutoConfig.from_pretrained(args.from_pretrained)

        model_cfg.use_parallel_adapter = args.use_adapter
        model_cfg.parallel_adapter_mode = 'ffn'
        model_cfg.adapter_bottleneck_dim = args.adapter_bottleneck_dim
        model_cfg.adapter_dropout = args.adapter_dropout
        model_cfg.adapter_scale = args.adapter_scale

        model = model_cls(config=model_cfg)

        logger.info(f'Loading pretrained model: {args.from_pretrained}')
        base_model = model_cls.from_pretrained(args.from_pretrained, use_safetensors=False)

        # strict=False tolerates key mismatches between the two state dicts
        # (presumably the adapter weights that only the new model has — confirm).
        model.load_state_dict(base_model.state_dict(), strict=False)
        # Drop the reference to the second copy of the weights.
        del base_model
        logger.info('Added adapters')
    else:
        # TODO: fix if for Qwen and Llama
        if not args.from_pretrained:
            # No checkpoint given: build the model from the config named by args.model_cfg.
            model_cfg = AutoConfig.from_pretrained(args.model_cfg)
            model = model_cls.from_config(model_cfg)
        else:
            logger.info(f'Loading pretrained model: {args.from_pretrained}')
            # Checkpoints whose name contains "Qwen" or "Llama" are loaded in bfloat16 with
            # flash-attention 2; everything else uses the from_pretrained defaults.
            if "Qwen" in args.from_pretrained or "Llama" in args.from_pretrained:
                model = model_cls.from_pretrained(args.from_pretrained,
                                                  attn_implementation="flash_attention_2",
                                                  torch_dtype=torch.bfloat16,
                                                  trust_remote_code=True)
            else:
                model = model_cls.from_pretrained(args.from_pretrained)

    # add LoRA adapters
    if args.use_lora:
        # Wrap the whole model in LoRA adapters configured from `args`.
        lora_settings = dict(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=args.lora_attn_dim,
            lora_alpha=args.lora_attn_alpha,
            lora_dropout=args.lora_dropout,
        )
        peft_config = LoraConfig(**lora_settings)
        model = get_peft_model(model, peft_config)
        logger.info('Added LoRA, trainable parameters with LoRA only:')
        model.print_trainable_parameters()
    
    # load cpt of backbone model
    if args.backbone_cpt:
        # A directory is expected to contain "model_best.pth"; anything else is taken as the
        # checkpoint file itself. Asking the filesystem (instead of searching the path for the
        # substring 'bin') keeps a directory whose path merely contains "bin" from being
        # mistaken for a file, and lets checkpoint files with other names be passed directly.
        if os.path.isdir(args.backbone_cpt):
            backbone_cpt = os.path.join(args.backbone_cpt, "model_best.pth")
        else:
            backbone_cpt = args.backbone_cpt
        # NOTE(review): depending on the torch version, torch.load may unpickle arbitrary
        # objects — only point this at trusted checkpoints.
        cpt = torch.load(backbone_cpt, map_location='cpu')
        # The checkpoint is a dict holding the weights under 'model_state_dict';
        # strict=True requires an exact key match with the current model.
        model.load_state_dict(cpt['model_state_dict'], strict=True)
        logger.info(f'Loaded baseline state dict from: {args.backbone_cpt}')
    
    # Pass memory settings to pretrained model
    if args.num_mem_tokens is not None:
        # Wrap the backbone in the configured memory cell and recurrent wrapper classes.
        memory_cell_cls = get_cls_by_name(args.memory_cell_cls)
        recurrent_wrapper_cls = get_cls_by_name(args.recurrent_wrapper_cls)
        logger.info(f'Wrapping in: {memory_cell_cls} and {recurrent_wrapper_cls}')
        mem_cell_args = dict(
            base_model=model,
            num_mem_tokens=args.num_mem_tokens,
        )
        # additional parameters for ARMT model
        if args.d_mem is not None:
            mem_cell_args['d_mem'] = args.d_mem
            mem_cell_args['wrap_pos'] = args.wrap_pos
            mem_cell_args['correction'] = not args.no_correction
        # Optional arguments are only forwarded when set, so the cell keeps its own defaults.
        if args.layers_attr is not None:
            mem_cell_args['layers_attr'] = args.layers_attr
        if args.attend_to_previous_input:
            mem_cell_args['attend_to_previous_input'] = args.attend_to_previous_input

        cell = memory_cell_cls(**mem_cell_args)
        model = recurrent_wrapper_cls(
            cell,
            segment_size=args.segment_size,
            max_n_segments=args.max_n_segments,
            vary_n_segments=args.vary_n_segments,
            k2=args.k2,
            attend_to_previous_input=args.attend_to_previous_input,
            return_all_logits=False,
            answer_loss_weight=args.answer_loss_weight
        )

        # load cpt of rmt
        if args.model_cpt:
            if "safetensors" in args.model_cpt:
                print(model)
                from safetensors.torch import load_model
                # NOTE(review): the device is hard-coded to "cuda:0" — confirm this is what
                # multi-process runs should use.
                load_model(model, args.model_cpt, device="cuda:0")
            else:
                if ".bin" in args.model_cpt:
                    # The path already names the weights file.
                    model_cpt = args.model_cpt
                elif "model_best" in os.listdir(args.model_cpt):
                    model_cpt = os.path.join(args.model_cpt, "model_best", "pytorch_model.bin")
                else:
                    # Fall back to a "checkpoint-<step>" sub-directory. os.listdir returns
                    # entries in arbitrary order, so choose explicitly instead of taking
                    # whichever entry happens to come first.
                    checkpoint_dirs = [el for el in os.listdir(args.model_cpt) if "checkpoint-" in el]
                    if not checkpoint_dirs:
                        raise FileNotFoundError(
                            f'No "model_best" or "checkpoint-*" entry found in {args.model_cpt}'
                        )

                    def checkpoint_step(name):
                        # Trailing number of "checkpoint-<step>"; names without one rank lowest.
                        suffix = name.rsplit("-", 1)[-1]
                        return int(suffix) if suffix.isdecimal() else -1

                    # Highest step wins; the name breaks ties so the choice is deterministic.
                    checkpoint_dir = max(checkpoint_dirs, key=lambda name: (checkpoint_step(name), name))
                    model_cpt = os.path.join(args.model_cpt, checkpoint_dir, "pytorch_model.bin")
                # NOTE(review): depending on the torch version, torch.load may unpickle
                # arbitrary objects — only point this at trusted checkpoints.
                cpt = torch.load(model_cpt, map_location='cpu')
                # strict=False: keys missing from or unknown to the wrapped model are tolerated.
                model.load_state_dict(cpt, strict=False)
            logger.info(f'Loaded RMT state dict from: {args.model_cpt}')
            logger.info(f'Trainable parameters: {[n for n, p in model.named_parameters() if p.requires_grad]}')
    
    if args.add_lora_to_armt:
        lora_settings = dict(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=args.lora_attn_dim,
            lora_alpha=args.lora_attn_alpha,
            lora_dropout=args.lora_dropout,
        )
        peft_config = LoraConfig(**lora_settings)
        # LoRA goes only on the model held by the memory cell, not on the outer wrapper.
        memory_cell = model.memory_cell
        memory_cell.model = get_peft_model(memory_cell.model, peft_config)
        logger.info('Added LoRA, trainable parameters with LoRA only:')
        memory_cell.model.print_trainable_parameters()
        # print(model)
    
    def _train_only(is_trainable):
        """Set requires_grad of every parameter to ``is_trainable(name)`` and log the result."""
        for name, param in model.named_parameters():
            param.requires_grad = is_trainable(name)
        logger.info('Frozen model weights')
        logger.info(f'Remaining parameters: {[n for n, p in model.named_parameters() if p.requires_grad]}')

    # The three switches are applied in this order; when several are set, the last one
    # decides the final requires_grad flags.
    if args.freeze_model_weights:
        # Keep memory, LoRA and adapter parameters trainable.
        _train_only(lambda name: any(tag in name for tag in ('memory', 'lora', 'adapter')))

    if args.tune_only_memory:
        _train_only(lambda name: 'memory_cell.memory' in name)

    if args.tune_only_armt:
        # Memory parameters plus the W_mq / W_mk / W_mv / W_mb parameters.
        _train_only(lambda name: any(
            tag in name for tag in ('memory_cell.memory', 'W_mq', 'W_mk', 'W_mv', 'W_mb')
        ))

    # fix the not-contiguous error
    @torch.no_grad()
    def make_contiguous(module):
        """Rebind every parameter of ``module``, in place, to a contiguous tensor with the same values."""
        for tensor in module.parameters():
            tensor.set_(tensor.contiguous())
    make_contiguous(model)

    # ============================
    # === Preparing HF trainer ===
    # ============================
    # Forward only the arguments SFTConfig knows about. A single throwaway config instance is
    # enough to probe attribute names; constructing a new one for every argument is wasted work.
    config_probe = SFTConfig('.')
    training_args_dict = {key: value for key, value in vars(args).items() if hasattr(config_probe, key)}

    training_args_dict['remove_unused_columns'] = False
    training_args_dict['save_safetensors'] = False
    training_args_dict['label_names'] = ['labels']
    training_args_dict['eval_strategy'] = 'steps'
    # Evaluate with 1/8 of the training batch size, but never less than 1. When `args` carries
    # no train batch size, SFTConfig's own default is what training will use, so start from it.
    train_batch_size = training_args_dict.get('per_device_train_batch_size',
                                              config_probe.per_device_train_batch_size)
    per_device_eval_batch_size = train_batch_size // 8
    training_args_dict['per_device_eval_batch_size'] = max(per_device_eval_batch_size, 1)
    training_args_dict['eval_accumulation_steps'] = 16
    if args.d_mem is None:
        # Gradient checkpointing is not supported for ARMT (d_mem set) for now.
        training_args_dict['gradient_checkpointing'] = True
        training_args_dict['gradient_checkpointing_kwargs'] = {'use_reentrant': False}
    training_args_dict['log_level'] = 'debug'
    # An early-stopping patience of -1 means early stopping is off.
    training_args_dict['load_best_model_at_end'] = args.early_stopping_patience != -1

    # collate_fn consumes the raw samples itself, so dataset preparation is skipped.
    training_args_dict['dataset_kwargs'] = {"skip_prepare_dataset": True}

    if args.num_mem_tokens is not None:
        # fix max_seq_length warning
        training_args_dict["max_seq_length"] = args.segment_size
    training_args = SFTConfig(**training_args_dict)

    def compute_accuracy(eval_pred):
        """Exact-match accuracy of the chain of thought and of the final answer.

        Args:
            eval_pred: object with ``predictions`` (scores; the argmax over the last axis is
                taken as the predicted token) and ``label_ids`` (token ids, with -100 at
                positions that must be ignored).

        Returns:
            dict with 'accuracy_cot' and 'accuracy_ans': the fraction of samples whose
            chain-of-thought tokens / answer tokens are all predicted correctly. The first
            tokens of ``ans`` and ``bos`` are not scored.

        Raises:
            ValueError: if the labels of a sample contain no answer-marker token.
        """
        # Next-token shift: the prediction at position t is compared with the label at t + 1.
        preds = eval_pred.predictions.argmax(axis=-1)[:, :-1]
        labels = eval_pred.label_ids[:, 1:]

        # Keep every supervised position. Comparing with the -100 ignore value (instead of
        # `labels > 0`) also keeps targets whose token id is 0.
        labels_masks = labels != -100

        special_tokens = {ans[0], bos[0]}

        def exact_match(pred_part, label_part):
            # All non-marker label tokens must be reproduced; an empty span counts as correct.
            return all(p == lab for p, lab in zip(pred_part, label_part) if lab not in special_tokens)

        acc_cot, acc_ans = [], []
        for pred_row, label_row, keep in zip(preds, labels, labels_masks):
            pred_tokens, lab_tokens = pred_row[keep], label_row[keep]

            # The answer part starts at the last answer marker found in the labels.
            marker_positions = np.flatnonzero(lab_tokens == ans[0])
            if marker_positions.size == 0:
                raise ValueError('No answer marker token found in the labels of an evaluation sample')
            ans_start_index = int(marker_positions[-1])

            acc_cot.append(exact_match(pred_tokens[:ans_start_index].tolist(),
                                       lab_tokens[:ans_start_index].tolist()))
            acc_ans.append(exact_match(pred_tokens[ans_start_index:].tolist(),
                                       lab_tokens[ans_start_index:].tolist()))

        return {'accuracy_cot': np.mean(acc_cot), 'accuracy_ans': np.mean(acc_ans)}

    def lr_lambda(current_step):
        """Learning-rate multiplier for ``current_step``.

        Ramps linearly from 0 to 1 over ``training_args.warmup_steps``. After warmup a
        "constant" schedule stays at 1.0, and a "linear" schedule decays towards 0 at
        ``training_args.max_steps`` but never drops below ``args.min_lr / learning_rate``.

        Raises:
            ValueError: after warmup, for any other ``args.lr_scheduler_type``.
        """
        warmup = training_args.warmup_steps
        if current_step < warmup:
            return current_step / warmup

        schedule = args.lr_scheduler_type
        if schedule == "constant":
            return 1.0
        if schedule != "linear":
            raise ValueError("Unsupported lr_scheduler_type")

        # NOTE(review): this assumes max_steps is larger than warmup_steps — confirm callers set it.
        total = training_args.max_steps
        remaining_fraction = (total - current_step) / (total - warmup)
        floor = args.min_lr / training_args.learning_rate
        return max(floor, remaining_fraction)

    # Only the learning rate is taken from the training arguments; every other AdamW
    # setting keeps its default.
    optimizer = AdamW(model.parameters(), lr=training_args.learning_rate)
    # The per-step learning-rate multiplier comes from lr_lambda.
    scheduler = LambdaLR(optimizer, lr_lambda)

    # The optimizer/scheduler pair built above is handed to the trainer explicitly.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        processing_class=tokenizer,
        data_collator=collate_fn,
        compute_metrics=compute_accuracy,
        optimizers=(optimizer, scheduler)
    )
    logger.info(f"Trainer Gradient Checkpointing Enabled: {trainer.args.gradient_checkpointing}")
    # A patience of -1 disables early stopping.
    if args.early_stopping_patience != -1:
        early_stopping = EarlyStoppingCallback(
            early_stopping_patience=args.early_stopping_patience
        )
        trainer.add_callback(early_stopping)
    # Evaluate once before any training step to record the metrics of the initial model.
    start_metrics = trainer.evaluate()
    logger.info(f"Metrics of initial model: {start_metrics}")
    # NOTE(review): `test_dataset` loaded earlier is not used here — confirm that is intended.
    if not args.validate_only:
        trainer.train(resume_from_checkpoint=args.checkpoint)
