Problems with finetuned model for VQAv2 (ms coco) #260

Closed
25icecreamflavors opened this issue Oct 7, 2022 · 11 comments

@25icecreamflavors

I am running inference on the VQA val set manually, using your demo Colab notebook, to get all the answers. At first I followed the notebook exactly and used the pre-trained checkpoint (OFA-Large), as in the tutorial; accuracy on the val set was around 68%. Then I switched to the finetuned checkpoint for VQAv2. It runs with the same code, but its behavior is strange. Inference is very slow: with the pretrained model I got 214k answers in 10 hours, whereas now I got only 30k answers in 15 hours on the same Tesla V100 (roughly 21,400 vs. 2,000 answers per hour, about 10x slower). The quality is also worse, around 60% accuracy. Some answers are completely correct, but others are strange, for example "bedroom bedroom bedroom bedroom bedroom ...", "no no no no no no no...", etc. For some reason the model produces very long answers and doesn't stop generating words.
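To measure how many answers are affected before re-running anything, a small diagnostic pass over the saved answers can help. This is a minimal sketch: `looks_runaway` and `collapse_repeats` are hypothetical helpers, and the thresholds (more than 5 words, or 3 identical words in a row) are arbitrary assumptions. Collapsing repeats only cleans up the output; it does not address why generation fails to stop.

```python
def collapse_repeats(answer: str) -> str:
    """Collapse consecutive duplicate words, e.g. 'no no no' becomes 'no'."""
    collapsed = []
    for word in answer.split():
        # Keep a word only if it differs from the previous kept word.
        if not collapsed or collapsed[-1] != word:
            collapsed.append(word)
    return " ".join(collapsed)


def looks_runaway(answer: str, max_words: int = 5, min_repeats: int = 3) -> bool:
    """Flag answers that are suspiciously long or repeat one word many times in a row.

    Both thresholds are assumptions chosen for illustration; VQA answers are usually short.
    """
    words = answer.split()
    if len(words) > max_words:
        return True
    run = 1
    for prev, cur in zip(words, words[1:]):
        # Length of the current run of identical consecutive words.
        run = run + 1 if cur == prev else 1
        if run >= min_repeats:
            return True
    return False


answers = ["bedroom bedroom bedroom bedroom bedroom", "yes", "no no no no no no no"]
flagged = [a for a in answers if looks_runaway(a)]
print(f"{len(flagged)} of {len(answers)} answers look like runaway generations")
```

Counting flagged answers this way gives a rough idea of how much of the accuracy drop comes from generations that never emit an end-of-sequence token.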

I am a bit confused about why this is happening; maybe I am doing something wrong. I run the model as in this notebook: https://colab.research.google.com/drive/1lsMsF-Vum3MVyXwSVF5E-Y23rHFvj_3y?usp=sharing

The only change I made is the path to the finetuned model, in this part:

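```python
from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

# Excerpt from the demo notebook; assumes earlier cells already imported OFA's tasks (vqa_gen, --bpe-dir).
# Only --path was changed, from the pretrained checkpoint to the finetuned VQAv2 one.
parser = options.get_generation_parser()
input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", "--path=checkpoints/vqa_large_best.pt", "--bpe-dir=utils/BPE"]
args = options.parse_args_and_arch(parser, input_args)
cfg = convert_namespace_to_omegaconf(args)
```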
@25icecreamflavors
Author

Some examples of other answers:
[Screenshot of example answers, not preserved]

@yangapku yangapku self-assigned this Oct 8, 2022
@25icecreamflavors
Author

Hm, I didn't use train; I thought I could use the model for prediction directly. As for the VQA val set, it doesn't really matter: each time I just extract the question text as a plain string, as in the demo notebook, and pass the image to the model as in the demo.

@yangapku
Member

Hi, the finetuned VQA model cannot directly replace the pretrained checkpoint, since their configs are not compatible. I would recommend using the shell script (evaluate_vqa_beam.sh) to generate your answers.

@yangapku
Member

If you would like to use Colab, please consider adapting the code in evaluate.py to the notebook.

@25icecreamflavors
Author

@yangapku can I use the version of OFA finetuned on COCO (VQAv2) via the Hugging Face interface? Or do I need to use the bash script?

@yangapku
Member

yangapku commented Nov 7, 2022

Hi, I suggest first using our provided scripts for finetuning (train_vqa_distributed.sh) and evaluation (evaluate_vqa_beam.sh), both of which are written for the fairseq-based codebase.

@25icecreamflavors
Author

@yangapku I tried to use the script, but I get the following error and I don't know what the problem is. At first fairseq complained that it couldn't import something, so I downgraded pip to 21 and installed fairseq 0.12.0. Now I get this strange error: "Key 'train_ans2label_file' not in 'VqaGenConfig'", even though I do have the train ans2label file. I don't know how to fix it.

2022-11-08 06:48:52 | INFO | torch.distributed.nn.jit.instantiator | Created a temporary directory at /tmp/job-893474/tmpvd9j7xyh
2022-11-08 06:48:52 | INFO | torch.distributed.nn.jit.instantiator | Writing /tmp/job-893474/tmpvd9j7xyh/_remote_module_non_scriptable.py
2022-11-08 06:48:54 | INFO | ofa.evaluate | {'_name': None, 'common': {'_name': None, 'no_progress_bar': False, 'log_interval': 10, 'log_format': 'simple', 'log_file': None, 'aim_repo': None, 'aim_run_hash': None, 'tensorboard_logdir': None, 'wandb_project': None, 'azureml_logging': False, 'seed': 7, 'cpu': False, 'tpu': False, 'bf16': False, 'memory_efficient_bf16': False, 'fp16': True, 'memory_efficient_fp16': False, 'fp16_no_flatten_grads': False, 'fp16_init_scale': 128, 'fp16_scale_window': None, 'fp16_scale_tolerance': 0.0, 'on_cpu_convert_precision': False, 'min_loss_scale': 0.0001, 'threshold_loss_scale': None, 'amp': False, 'amp_batch_retries': 2, 'amp_init_scale': 128, 'amp_scale_window': None, 'user_dir': '../../ofa_module', 'empty_cache_freq': 0, 'all_gather_list_size': 16384, 'model_parallel_size': 1, 'quantization_config_path': None, 'profile': False, 'reset_logging': False, 'suppress_crashes': False, 'use_plasma_view': False, 'plasma_path': '/tmp/plasma'}, 'common_eval': {'_name': None, 'path': '../../checkpoints/vqa_large_best.pt', 'post_process': None, 'quiet': False, 'model_overrides': '{"data":"../../dataset/vqa_data/vqa_.tsv","bpe_dir":"../../utils/BPE","selected_cols":"0,5,2,3,4"}', 'results_path': '../../results/vqa__beam'}, 'distributed_training': {'_name': None, 'distributed_world_size': 1, 'distributed_num_procs': 1, 'distributed_rank': 0, 'distributed_backend': 'nccl', 'distributed_init_method': None, 'distributed_port': -1, 'device_id': 0, 'distributed_no_spawn': False, 'ddp_backend': 'pytorch_ddp', 'ddp_comm_hook': 'none', 'bucket_cap_mb': 25, 'fix_batches_to_gpus': False, 'find_unused_parameters': False, 'gradient_as_bucket_view': False, 'fast_stat_sync': False, 'heartbeat_timeout': -1, 'broadcast_buffers': False, 'slowmo_momentum': None, 'slowmo_base_algorithm': 'localsgd', 'localsgd_frequency': 3, 'nprocs_per_node': 1, 'pipeline_model_parallel': False, 'pipeline_balance': None, 'pipeline_devices': None, 'pipeline_chunks': 0, 
'pipeline_encoder_balance': None, 'pipeline_encoder_devices': None, 'pipeline_decoder_balance': None, 'pipeline_decoder_devices': None, 'pipeline_checkpoint': 'never', 'zero_sharding': 'none', 'fp16': True, 'memory_efficient_fp16': False, 'tpu': False, 'no_reshard_after_forward': False, 'fp32_reduce_scatter': False, 'cpu_offload': False, 'use_sharded_state': False, 'not_fsdp_flatten_parameters': False}, 'dataset': {'_name': None, 'num_workers': 0, 'skip_invalid_size_inputs_valid_test': False, 'max_tokens': None, 'batch_size': 16, 'required_batch_size_multiple': 8, 'required_seq_len_multiple': 1, 'dataset_impl': None, 'data_buffer_size': 10, 'train_subset': 'train', 'valid_subset': 'valid', 'combine_valid_subsets': None, 'ignore_unused_valid_subsets': False, 'validate_interval': 1, 'validate_interval_updates': 0, 'validate_after_updates': 0, 'fixed_validation_seed': None, 'disable_validation': False, 'max_tokens_valid': None, 'batch_size_valid': 16, 'max_valid_steps': None, 'curriculum': 0, 'gen_subset': '', 'num_shards': 1, 'shard_id': 0, 'grouped_shuffling': False, 'update_epoch_batch_itr': False, 'update_ordered_indices_seed': False}, 'optimization': {'_name': None, 'max_epoch': 0, 'max_update': 0, 'stop_time_hours': 0.0, 'clip_norm': 0.0, 'sentence_avg': False, 'update_freq': [1], 'lr': [0.25], 'stop_min_lr': -1.0, 'use_bmuf': False, 'skip_remainder_batch': False}, 'checkpoint': {'_name': None, 'save_dir': 'checkpoints', 'restore_file': 'checkpoint_last.pt', 'continue_once': None, 'finetune_from_model': None, 'reset_dataloader': False, 'reset_lr_scheduler': False, 'reset_meters': False, 'reset_optimizer': False, 'optimizer_overrides': '{}', 'save_interval': 1, 'save_interval_updates': 0, 'keep_interval_updates': -1, 'keep_interval_updates_pattern': -1, 'keep_last_epochs': -1, 'keep_best_checkpoints': -1, 'no_save': False, 'no_epoch_checkpoints': False, 'no_last_checkpoints': False, 'no_save_optimizer_state': False, 'best_checkpoint_metric': 'loss', 
'maximize_best_checkpoint_metric': False, 'patience': -1, 'checkpoint_suffix': '', 'checkpoint_shard_count': 1, 'load_checkpoint_on_all_dp_ranks': False, 'write_checkpoints_asynchronously': False, 'model_parallel_size': 1}, 'bmuf': {'_name': None, 'block_lr': 1.0, 'block_momentum': 0.875, 'global_sync_iter': 50, 'warmup_iterations': 500, 'use_nbm': False, 'average_sync': False, 'distributed_world_size': 1}, 'generation': {'_name': None, 'beam': 5, 'nbest': 1, 'max_len_a': 0.0, 'max_len_b': 200, 'min_len': 1, 'match_source_len': False, 'unnormalized': True, 'no_early_stop': False, 'no_beamable_mm': False, 'lenpen': 1.0, 'unkpen': 0.0, 'replace_unk': None, 'sacrebleu': False, 'score_reference': False, 'prefix_size': 0, 'no_repeat_ngram_size': 0, 'sampling': False, 'sampling_topk': -1, 'sampling_topp': -1.0, 'constraints': None, 'temperature': 1.0, 'diverse_beam_groups': -1, 'diverse_beam_strength': 0.5, 'diversity_rate': -1.0, 'print_alignment': None, 'print_step': False, 'lm_path': None, 'lm_weight': 0.0, 'iter_decode_eos_penalty': 0.0, 'iter_decode_max_iter': 10, 'iter_decode_force_max_iter': False, 'iter_decode_with_beam': 1, 'iter_decode_with_external_reranker': False, 'retain_iter_history': False, 'retain_dropout': False, 'retain_dropout_modules': None, 'decoding_format': None, 'no_seed_provided': False, 'eos_token': None}, 'eval_lm': {'_name': None, 'output_word_probs': False, 'output_word_stats': False, 'context_window': 0, 'softmax_batch': 9223372036854775807}, 'interactive': {'_name': None, 'buffer_size': 0, 'input': '-'}, 'model': None, 'task': {'_name': 'vqa_gen', 'data': '../../dataset/vqa_data/vqa_.tsv', 'selected_cols': None, 'bpe_dir': None, 'max_source_positions': 1024, 'max_target_positions': 1024, 'max_src_length': 128, 'max_tgt_length': 30, 'code_dict_size': 8192, 'patch_image_size': 480, 'num_bins': 1000, 'imagenet_default_mean_and_std': False, 'constraint_range': None, 'max_object_length': 30, 'ans2label_dict': '{"no": 0, "yes":1}', 
'ans2label_file': None, 'add_object': False, 'valid_batch_size': 20, 'prompt_type': None, 'uses_ema': False, 'val_inference_type': 'allcand', 'eval_args': '{"beam":5,"unnormalized":true,"temperature":1.0}'}, 'criterion': {'_name': 'cross_entropy', 'sentence_avg': True}, 'optimizer': None, 'lr_scheduler': {'_name': 'fixed', 'force_anneal': None, 'lr_shrink': 0.1, 'warmup_updates': 0, 'lr': [0.25]}, 'scoring': {'_name': 'bleu', 'pad': 1, 'eos': 2, 'unk': 3}, 'bpe': None, 'tokenizer': None, 'ema': {'_name': None, 'store_ema': False, 'ema_decay': 0.9999, 'ema_start_update': 0, 'ema_seed_model': None, 'ema_update_freq': 1, 'ema_fp32': False}, 'simul_type': None}
2022-11-08 06:48:54 | INFO | ofa.evaluate | loading model(s) from ../../checkpoints/vqa_large_best.pt
Traceback (most recent call last):
  File "/home/aashirnin/content/OFA/run_scripts/vqa/../../evaluate.py", line 156, in <module>
    cli_main()
  File "/home/aashirnin/content/OFA/run_scripts/vqa/../../evaluate.py", line 150, in cli_main
    distributed_utils.call_main(
  File "/home/aashirnin/content/fairseq/fairseq/distributed/utils.py", line 369, in call_main
    main(cfg, **kwargs)
  File "/home/aashirnin/content/OFA/run_scripts/vqa/../../evaluate.py", line 76, in main
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
  File "/home/aashirnin/content/OFA/utils/checkpoint_utils.py", line 447, in load_model_ensemble_and_task
    task = tasks.setup_task(cfg.task)
  File "/home/aashirnin/content/fairseq/fairseq/tasks/__init__.py", line 39, in setup_task
    cfg = merge_with_parent(dc(), cfg)
  File "/home/aashirnin/content/fairseq/fairseq/dataclass/utils.py", line 490, in merge_with_parent
    merged_cfg = OmegaConf.merge(dc, cfg)
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/omegaconf.py", line 321, in merge
    target.merge_with(*others[1:])
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/basecontainer.py", line 331, in merge_with
    self._format_and_raise(key=None, value=None, cause=e)
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/base.py", line 95, in _format_and_raise
    format_and_raise(
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/_utils.py", line 629, in format_and_raise
    _raise(ex, cause)
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/_utils.py", line 610, in _raise
    raise ex  # set end OC_CAUSE=1 for full backtrace
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/basecontainer.py", line 329, in merge_with
    self._merge_with(*others)
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/basecontainer.py", line 347, in _merge_with
    BaseContainer._map_merge(self, other)
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/basecontainer.py", line 314, in _map_merge
    dest[key] = src._get_node(key)
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/dictconfig.py", line 258, in __setitem__
    self._format_and_raise(
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/base.py", line 95, in _format_and_raise
    format_and_raise(
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/_utils.py", line 629, in format_and_raise
    _raise(ex, cause)
  File "/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/omegaconf/_utils.py", line 610, in _raise
    raise ex  # set end OC_CAUSE=1 for full backtrace
omegaconf.errors.ConfigKeyError: Key 'train_ans2label_file' not in 'VqaGenConfig'
	full_key: train_ans2label_file
	reference_type=Optional[VqaGenConfig]
	object_type=VqaGenConfig
Fatal error condition occurred in /opt/vcpkg/buildtrees/aws-c-io/src/9e6648842a-364b708815.clean/source/event_loop.c:72: aws_thread_launch(&cleanup_thread, s_event_loop_destroy_async_thread_fn, el_group, &thread_options) == AWS_OP_SUCCESS
Exiting Application
################################################################################
Stack trace:
################################################################################
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x200af06) [0x2aff6a00af06]
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x20028e5) [0x2aff6a0028e5]
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x1f27e09) [0x2aff69f27e09]
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x200ba3d) [0x2aff6a00ba3d]
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x1f25948) [0x2aff69f25948]
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x200ba3d) [0x2aff6a00ba3d]
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x1ee0b46) [0x2aff69ee0b46]
/home/aashirnin/.conda/envs/myenv/lib/python3.9/site-packages/pyarrow/libarrow.so.900(+0x194546a) [0x2aff6994546a]
/lib64/libc.so.6(+0x39c99) [0x2afee3d09c99]
/lib64/libc.so.6(+0x39ce7) [0x2afee3d09ce7]
/lib64/libc.so.6(__libc_start_main+0xfc) [0x2afee3cf250c]
python3() [0x588eae]
./evaluate_vqa_beam.sh: line 36: 127540 Aborted                 (core dumped) CUDA_VISIBLE_DEVICES=0 python3 ../../evaluate.py ${data} --path=${path} --user-dir=${user_dir} --task=vqa_gen --batch-size=16 --log-format=simple --log-interval=10 --seed=7 --gen-subset=${split} --results-path=${result_path} --fp16 --ema-eval --beam-search-vqa-eval --beam=5 --unnormalized --temperature=1.0 --num-workers=0 --model-overrides="{\"data\":\"${data}\",\"bpe_dir\":\"${bpe_dir}\",\"selected_cols\":\"${selected_cols}\"}"
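For reference when adapting evaluate.py to a notebook, the --model-overrides value in the command above is a compact JSON object with three keys. The sketch below rebuilds it with the standard library. It assumes (inferred from the 'vqa_.tsv' path and the empty 'gen_subset' in the logged config, not from the script source) that the data path is built as vqa_<split>.tsv; with an empty split it reproduces the logged string exactly. `build_model_overrides` is a hypothetical helper. The empty split is not what raised the error: the traceback fails earlier, while loading the checkpoint.

```python
import json


def build_model_overrides(split: str,
                          data_dir: str = "../../dataset/vqa_data",
                          bpe_dir: str = "../../utils/BPE",
                          selected_cols: str = "0,5,2,3,4") -> str:
    """Rebuild the --model-overrides JSON string seen in the log (hypothetical helper)."""
    # Assumption: the TSV name follows vqa_<split>.tsv, inferred from 'vqa_.tsv' in the log.
    data = f"{data_dir}/vqa_{split}.tsv"
    overrides = {"data": data, "bpe_dir": bpe_dir, "selected_cols": selected_cols}
    # Compact separators match the spacing-free string printed in the config dump.
    return json.dumps(overrides, separators=(",", ":"))


print(build_model_overrides(""))     # matches the logged model_overrides
print(build_model_overrides("val"))  # what a non-empty split would produce
```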

@yangapku
Member

yangapku commented Nov 8, 2022

Hi, please follow our README ("Installation" section) to install the fairseq provided in this repo (as well as the other dependencies) instead of the official fairseq. If you encounter the problem mentioned in #225 or #217, please follow the comments in those issues to downgrade pip.

@25icecreamflavors
Author

25icecreamflavors commented Nov 8, 2022

@yangapku yes, I already did that: I downgraded pip and installed from the requirements, as you said. There is some other problem with running the script: raise ex # set end OC_CAUSE=1 for full backtrace omegaconf.errors.ConfigKeyError: Key 'train_ans2label_file' not in 'VqaGenConfig'.
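The traceback shows where this comes from: while loading the checkpoint, fairseq's merge_with_parent calls OmegaConf.merge to merge the task config saved in the checkpoint into the VqaGenConfig dataclass, and OmegaConf rejects any key that the dataclass does not declare. So the saved config contains train_ans2label_file, but the VqaGenConfig that was actually imported has no such field, which points to mismatched code rather than a missing data file. The stdlib-only sketch below mimics that strict merge; it is not OmegaConf's implementation, and `TaskConfig`, `merge_strict`, and this `ConfigKeyError` class are hypothetical stand-ins.

```python
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, Optional


class ConfigKeyError(Exception):
    """Stand-in for omegaconf.errors.ConfigKeyError (not the real class)."""


@dataclass
class TaskConfig:
    # Hypothetical subset of a task config; the real VqaGenConfig has many more fields.
    data: Optional[str] = None
    ans2label_file: Optional[str] = None
    max_tgt_length: int = 30


def merge_strict(base: TaskConfig, saved: Dict[str, Any]) -> TaskConfig:
    """Merge a saved-config dict into a dataclass instance, rejecting undeclared keys."""
    known = {f.name for f in fields(base)}
    for key in saved:
        if key not in known:
            raise ConfigKeyError(f"Key '{key}' not in '{type(base).__name__}'")
    # All keys are declared fields, so a plain field-by-field replace is safe.
    return replace(base, **saved)


try:
    merge_strict(TaskConfig(), {"data": "vqa_val.tsv", "train_ans2label_file": "x.pkl"})
except ConfigKeyError as err:
    print(err)
```

With the matching code installed, the dataclass declares every key the checkpoint saved and the merge goes through.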

@yangapku
Member

yangapku commented Nov 8, 2022

@25icecreamflavors Could you confirm whether you are using the fairseq in our repo or the official one?
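The traceback earlier in the thread loads fairseq from /home/aashirnin/content/fairseq/..., a directory outside the OFA checkout, while OFA's own files come from /home/aashirnin/content/OFA/.... One quick way to answer this question is to check where `import fairseq` resolves. A minimal sketch, assuming POSIX-style absolute paths and that the repo's bundled copy lives somewhere under the OFA checkout; `is_under` and `installed_fairseq_origin` are hypothetical helpers.

```python
import importlib.util
from pathlib import PurePosixPath
from typing import Optional


def is_under(file_path: str, root: str) -> bool:
    """True if file_path lies inside the directory tree rooted at root.

    Assumes both paths are absolute and already normalized (no symlinks or '..').
    """
    path = PurePosixPath(file_path)
    root_path = PurePosixPath(root)
    return path == root_path or root_path in path.parents


def installed_fairseq_origin() -> Optional[str]:
    """File that `import fairseq` would load, or None if fairseq is not importable."""
    spec = importlib.util.find_spec("fairseq")
    return spec.origin if spec is not None else None


origin = installed_fairseq_origin()
print("fairseq loads from:", origin)
```

Running `is_under(installed_fairseq_origin(), "/path/to/OFA")` (with the real checkout path) in the evaluation environment should return True if the repo's copy is the one being imported.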

@25icecreamflavors
Author

@yangapku thank you. I reinstalled everything from scratch, and it all works now.

Following issue #217, I:

  1. Created a conda env with Python 3.7.4
  2. Downgraded pip
  3. Installed the requirements
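Before re-running the scripts, it can help to confirm that the active environment matches the one that worked here. A minimal sketch: `parse_version` and `env_matches` are hypothetical helpers, Python 3.7.4 is taken from the steps above, and the "pip major version at most 21" check reflects the "downgraded pip to 21" mentioned earlier in this thread, not a requirement stated by the maintainers (issue #217 may name a different version).

```python
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse the leading numeric components, e.g. '21.2.4' gives (21, 2, 4)."""
    parts = []
    for piece in text.split("."):
        # Stop at the first non-numeric component such as '0b1'.
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def env_matches(python_version: str, pip_version: str,
                want_python: str = "3.7.4", max_pip_major: int = 21) -> bool:
    """Check versions against the setup reported in this thread (assumed targets)."""
    py_ok = parse_version(python_version)[:3] == parse_version(want_python)
    pip_parts = parse_version(pip_version)
    pip_ok = bool(pip_parts) and pip_parts[0] <= max_pip_major
    return py_ok and pip_ok
```

In the environment itself this could be called as `env_matches(platform.python_version(), pip.__version__)` after importing `platform` and `pip`.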
