In [1]:
from dataclasses import dataclass
from typing import Dict
import numpy as np
from transformers import Wav2Vec2Processor, Data2VecAudioModel
from transformers.training_args import TrainingArguments
from transformers.utils import logging
from transformers.models.data2vec.configuration_data2vec_audio import Data2VecAudioConfig

from transformers import TrainingArguments
from transformers import Data2VecAudioConfig
from datasets import load_metric
import argparse
from utils import csv2dataset

from transformers import Trainer
from Models import (DataCollatorCTCWithPadding, 
                    Data2VecAudioForCTC,)
import time
import os, json
logger = logging.get_logger(__name__)

class DementiaGRLTrainer(Trainer):  
    def update_log_file(self, log_file=None):
        self.log_file=log_file
    def compute_loss(self, model, inputs, return_outputs=False):
            """
            How the loss is computed by Trainer. By default, all models return the loss in the first element.
            Subclass and override for custom behavior.
            """
            #dementia_labels = inputs.pop("dementia_labels") # pop 出來就會不見?
            
            if self.label_smoother is not None and "labels" in inputs:
                labels = inputs.pop("labels")
            else:
                labels = None
            
            outputs = model(**inputs)
            # Save past state if it exists
            # TODO: this needs to be fixed and made cleaner later.
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index]

            if labels is not None:
                loss = self.label_smoother(outputs, labels)
            else:
                # We don't use .loss here since the model may return tuples instead of ModelOutput.
                loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

            return (loss, outputs) if return_outputs else loss
    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.
        Subclass and override this method to inject custom behavior.
        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        # 設定log file位置與名稱

        LOG_DIR = './saves/log/'
        if not os.path.exists(LOG_DIR):
            os.makedirs(LOG_DIR)
        # write to txt file
        file_object = open(LOG_DIR + self.log_file, 'a')
        # Append at the end of file
        file_object.write(json.dumps(output) + '\n')
        # Close the file
        file_object.close()

        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
def prepare_dataset(batch):
    audio = batch["array"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["text"]).input_ids
        
    return batch



wer_metric = load_metric("wer")
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)
    label_ids_asr , label_ids_AD=pred.label_ids

    label_ids_asr[label_ids_asr == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids_asr, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}



parser = argparse.ArgumentParser()
parser.add_argument('-lam', '--LAMBDA', type=float, default=0.5, help="Lambda for GRL")
parser.add_argument('-st', '--STAGE', type=int, default=1, help="Current training stage")
parser.add_argument('-GRL', '--GRL', action='store_true', default=False, help="True: GRL")
parser.add_argument('-model_in', '--model_in_path', type=str, default="/mnt/Internal/FedASR/weitung/HuggingFace/Pretrain/saves/data2vec-audio-large-960h/final/", help="Where the model is saved")
parser.add_argument('-model_out', '--model_out_path', type=str, default="./saves/data2vec2-base-960h_linear_GRL", help="Where to save the model")
parser.add_argument('-log', '--log_path', type=str, default="data2vec2-base-960h_linear_GRL.txt", help="name for the txt file")
args = parser.parse_args(args=[])
LAMBDA = args.LAMBDA                    # lambda for GRL
REVERSE = args.GRL                      # not used in this version
STAGE = args.STAGE                      # stage 1: train AD classifier; stage 2: train toggling network
model_in_dir = args.model_in_path       # path to load the initial model
model_out_dir = args.model_out_path     # path to store the resulted model
log_file = args.log_path                # path to save log file



# threshold for maskes, not used here
AD_THRES = 0.5
LM_THRES = 0.5

# load model from huggingface hub, here data2vec model
name = "facebook/" + model_in_dir.split("/")[-3]
print("Current model: ", name)

mask_time_prob = 0                                         # change config to avoid training stopping
config = Data2VecAudioConfig.from_pretrained(name, mask_time_prob=mask_time_prob)
model = Data2VecAudioForCTC.from_pretrained(name, config=config,LAMBDA=LAMBDA)
processor = Wav2Vec2Processor.from_pretrained(name)
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

# load train / test data
train_data = csv2dataset(csv_path = "/mnt/Internal/FedASR/Data/ADReSS-IS2020-data/mid_csv/train.csv")
#dev_data = csv2dataset(path = "/mnt/Internal/FedASR/Data/ADReSS-IS2020-data/mid_csv/dev.csv")
test_data = csv2dataset(csv_path = "/mnt/Internal/FedASR/Data/ADReSS-IS2020-data/mid_csv/test.csv")

# map to desired form
train_data = train_data.map(prepare_dataset, num_proc=10)
#dev_data = dev_data.map(prepare_dataset, num_proc=10)
test_data = test_data.map(prepare_dataset, num_proc=10)
    
training_args = TrainingArguments(
    output_dir=model_out_dir,
    group_by_length=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    num_train_epochs=30,                 # finetune & GRL
    fp16=True,
    gradient_checkpointing=True, 
    save_steps=500,
    eval_steps=500,
    logging_steps=100000000,
    learning_rate=1e-5,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
    # log_level='debug',
    logging_strategy="steps",
    optim="adafactor", #adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
    #adafactor=False,            # default:false. Whether or not to use transformers.Adafactor optimizer instead of transformers.AdamW
    #fp16_full_eval=True,      # to save memory
    max_grad_norm=0.5
)

trainer = DementiaGRLTrainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data, 
    eval_dataset=test_data,
    tokenizer=processor.feature_extractor,
)
trainer.update_log_file(log_file=log_file)
# trainer.evaluate(eval_dataset=test_data.select(range(2)))
trainer.evaluate(eval_dataset=test_data)
###############################################
# eval_dataset=test_data
# trainer._memory_tracker.start()

# eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
# start_time = time.time()

# eval_loop = trainer.prediction_loop if trainer.args.use_legacy_prediction_loop else trainer.evaluation_loop
# output = eval_loop(
#     eval_dataloader,
#     description="Evaluation",
#     # No point gathering the predictions if there are no metrics, otherwise we defer to
#     # self.args.prediction_loss_only
#     prediction_loss_only=True if trainer.compute_metrics is None else None,
#     ignore_keys=None,
#     metric_key_prefix=None,
# )
###########################################################
# eval_dataset=test_data
# trainer._memory_tracker.start()

# eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
# start_time = time.time()
# eval_dataloader = trainer.get_eval_dataloader(eval_dataset)
# from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled
# from transformers.trainer_utils import (
#     PREFIX_CHECKPOINT_DIR,
#     BestRun,
#     EvalLoopOutput,
#     EvalPrediction,
#     FSDPOption,
#     HPSearchBackend,
#     HubStrategy,
#     IntervalStrategy,
#     PredictionOutput,
#     RemoveColumnsCollator,
#     ShardedDDPOption,
#     TrainerMemoryTracker,
#     TrainOutput,
#     default_compute_objective,
#     default_hp_space,
#     denumpify_detensorize,
#     enable_full_determinism,
#     find_executable_batch_size,
#     get_last_checkpoint,
#     has_length,
#     number_of_arguments,
#     seed_worker,
#     set_seed,
#     speed_metrics,
# )
# from transformers.utils import (
#     CONFIG_NAME,
#     SAFE_WEIGHTS_INDEX_NAME,
#     SAFE_WEIGHTS_NAME,
#     WEIGHTS_INDEX_NAME,
#     WEIGHTS_NAME,
#     can_return_loss,
#     find_labels,
#     get_full_repo_name,
#     is_accelerate_available,
#     is_apex_available,
#     is_datasets_available,
#     is_in_notebook,
#     is_ipex_available,
#     is_safetensors_available,
#     is_sagemaker_dp_enabled,
#     is_sagemaker_mp_enabled,
#     is_torch_compile_available,
#     is_torch_neuroncore_available,
#     is_torch_tpu_available,
#     logging,
#     strtobool,
# )
# from transformers.trainer_pt_utils import (
#     DistributedLengthGroupedSampler,
#     DistributedSamplerWithLoop,
#     DistributedTensorGatherer,
#     IterableDatasetShard,
#     LabelSmoother,
#     LengthGroupedSampler,
#     SequentialDistributedSampler,
#     ShardSampler,
#     distributed_broadcast_scalars,
#     distributed_concat,
#     find_batch_size,
#     get_model_param_count,
#     get_module_class_from_name,
#     get_parameter_names,
#     nested_concat,
#     nested_detach,
#     nested_numpify,
#     nested_truncate,
#     nested_xla_mesh_reduce,
#     reissue_pt_warnings,
# )
# if is_torch_tpu_available(check_device=False):
#     import torch_xla.core.xla_model as xm
#     import torch_xla.debug.metrics as met
#     import torch_xla.distributed.parallel_loader as pl
# import torch
# # ========== Init
# args = trainer.args
# dataloader=eval_dataloader
# description="Evaluation"
# prediction_loss_only=True if trainer.compute_metrics is None else None
# ignore_keys=None
# metric_key_prefix=None
# # ===========================


# prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

# # if eval is called w/o train init deepspeed here
# if args.deepspeed and not trainer.deepspeed:
#     # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
#     # from the checkpoint eventually
#     deepspeed_engine, _, _ = deepspeed_init(
#         trainer, num_training_steps=0, resume_from_checkpoint=None, inference=True
#     )
#     trainer.model = deepspeed_engine.module
#     trainer.model_wrapped = deepspeed_engine
#     trainer.deepspeed = deepspeed_engine

# model = trainer._wrap_model(trainer.model, training=False, dataloader=dataloader)

# # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
# # while ``train`` is running, cast it to the right dtype first and then put on device
# if not trainer.is_in_train:
#     if args.fp16_full_eval:
#         model = model.to(dtype=torch.float16, device=args.device)
#     elif args.bf16_full_eval:
#         model = model.to(dtype=torch.bfloat16, device=args.device)

# batch_size = trainer.args.eval_batch_size

# logger.info(f"***** Running {description} *****")
# if has_length(dataloader):
#     logger.info(f"  Num examples = {trainer.num_examples(dataloader)}")
# else:
#     logger.info("  Num examples: Unknown")
# logger.info(f"  Batch size = {batch_size}")

# model.eval()

# trainer.callback_handler.eval_dataloader = dataloader
# # Do this before wrapping.
# eval_dataset = getattr(dataloader, "dataset", None)

# if is_torch_tpu_available():
#     dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

# if args.past_index >= 0:
#     trainer._past = None

# # Initialize containers
# # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
# losses_host = None
# preds_host = None
# labels_host = None
# inputs_host = None

# # losses/preds/labels on CPU (final containers)
# all_losses = None
# all_preds = None
# all_labels = None
# all_inputs = None
# # Will be useful when we have an iterable dataset so don't know its length.

# observed_num_examples = 0
# # Main evaluation loop
# for step, inputs in enumerate(dataloader):
#     # Update the observed num examples
#     observed_batch_size = find_batch_size(inputs)
#     if observed_batch_size is not None:
#         observed_num_examples += observed_batch_size
#         # For batch samplers, batch_size is not known by the dataloader in advance.
#         if batch_size is None:
#             batch_size = observed_batch_size

#     # Prediction step
#     loss, logits, labels = trainer.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
#     inputs_decode = trainer._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
#     # print(loss, logits, labels)
#     # print(inputs_decode)
#     if is_torch_tpu_available():
#         xm.mark_step()

#     # Update containers on host
#     if loss is not None:
#         losses = trainer._nested_gather(loss.repeat(batch_size))
#         losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
#     if labels is not None:
#         labels = trainer._pad_across_processes(labels)
#         labels = trainer._nested_gather(labels)
#         labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
#     if inputs_decode is not None:
#         inputs_decode = trainer._pad_across_processes(inputs_decode)
#         inputs_decode = trainer._nested_gather(inputs_decode)
#         inputs_host = (
#             inputs_decode
#             if inputs_host is None
#             else nested_concat(inputs_host, inputs_decode, padding_index=-100)
#         )
#     if logits is not None:
#         logits = trainer._pad_across_processes(logits)
#         logits = trainer._nested_gather(logits)
#         if trainer.preprocess_logits_for_metrics is not None:
#             logits = trainer.preprocess_logits_for_metrics(logits, labels)
#         preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
#     trainer.control = trainer.callback_handler.on_prediction_step(args, trainer.state, trainer.control)

#     # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
#     if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
#         if losses_host is not None:
#             losses = nested_numpify(losses_host)
#             all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
#         if preds_host is not None:
#             logits = nested_numpify(preds_host)
#             all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
#         if inputs_host is not None:
#             inputs_decode = nested_numpify(inputs_host)
#             all_inputs = (
#                 inputs_decode
#                 if all_inputs is None
#                 else nested_concat(all_inputs, inputs_decode, padding_index=-100)
#             )
#         if labels_host is not None:
#             labels = nested_numpify(labels_host)
#             all_labels = (
#                 labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
#             )

#         # Set back to None to begin a new accumulation
#         losses_host, preds_host, inputs_host, labels_host = None, None, None, None

# if args.past_index and hasattr(trainer, "_past"):
#     # Clean the state at the end of the evaluation loop
#     delattr(trainer, "_past")

# # Gather all remaining tensors and put them back on the CPU
# if losses_host is not None:
#     losses = nested_numpify(losses_host)
#     all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
# if preds_host is not None:
#     logits = nested_numpify(preds_host)
#     all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
# if inputs_host is not None:
#     inputs_decode = nested_numpify(inputs_host)
#     all_inputs = (
#         inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
#     )
# if labels_host is not None:
#     labels = nested_numpify(labels_host)
#     all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)

# # Number of samples
# if has_length(eval_dataset):
#     num_samples = len(eval_dataset)
# # The instance check is weird and does not actually check for the type, but whether the dataset has the right
# # methods. Therefore we need to make sure it also has the attribute.
# elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
#     num_samples = eval_dataset.num_examples
# else:
#     if has_length(dataloader):
#         num_samples = trainer.num_examples(dataloader)
#     else:  # both len(dataloader.dataset) and len(dataloader) fail
#         num_samples = observed_num_examples
# if num_samples == 0 and observed_num_examples > 0:
#     num_samples = observed_num_examples

# # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
# # samplers has been rounded to a multiple of batch_size, so we truncate.
# if all_losses is not None:
#     all_losses = all_losses[:num_samples]
# if all_preds is not None:
#     all_preds = nested_truncate(all_preds, num_samples)
# if all_labels is not None:
#     all_labels = nested_truncate(all_labels, num_samples)
# if all_inputs is not None:
#     all_inputs = nested_truncate(all_inputs, num_samples)

# # Metrics!
# if trainer.compute_metrics is not None and all_preds is not None and all_labels is not None:
#     if args.include_inputs_for_metrics:
#         metrics = trainer.compute_metrics(
#             EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
#         )
#     else:
#         metrics = trainer.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
# else:
#     metrics = {}


aaa=ccc
trainer.train() #"./saves/data2vec-audio-large-960h_GRL/checkpoint-56000/"
trainer.update_log_file(log_file=log_file)
trainer.save_model(model_out_dir + "/final")








2023-04-18 15:38:55.116250: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Current model:  facebook/data2vec-audio-large-960h


Some weights of the model checkpoint at facebook/data2vec-audio-large-960h were not used when initializing Data2VecAudioForCTC: ['data2vec_audio.masked_spec_embed']
- This IS expected if you are initializing Data2VecAudioForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Data2VecAudioForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Data2VecAudioForCTC were not initialized from the model checkpoint at facebook/data2vec-audio-large-960h and are newly initialized: ['dementia_head.weight', 'dementia_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


lambda =  tensor(0.5000)


Loading cached processed dataset at /home/FedASR/dacs/centralized/dataset/train/cache-684be29c00c58e99_*_of_00010.arrow
Loading cached processed dataset at /home/FedASR/dacs/centralized/dataset/test/cache-bbb0167bf98fd901_*_of_00010.arrow


Load data from local...
Load data from local...


    There is an imbalance between your GPUs. You may want to exclude GPU 4 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.
  "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "


NameError: name 'ccc' is not defined

In [None]:
trainer.compute_metrics