diff --git a/bats_ai/core/management/commands/load_public_dataset.py b/bats_ai/core/management/commands/load_public_dataset.py index f64ce76d..bc87fbd0 100644 --- a/bats_ai/core/management/commands/load_public_dataset.py +++ b/bats_ai/core/management/commands/load_public_dataset.py @@ -76,6 +76,17 @@ def _get_metadata(filename: str, line: dict[str, str]) -> dict[str, Any]: return metadata +def _try_head_s3_object(s3_client, bucket: str, key: str) -> bool: + try: + s3_client.head_object(Bucket=bucket, Key=key) + return True + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code in ["404", "NoSuchKey"]: + return False + raise + + def _try_start_spectrogram_generation(recording_id: int): metadata_filter = { "type": ProcessingTaskType.SPECTROGRAM_GENERATION.value, @@ -126,6 +137,10 @@ def _ingest_files_from_manifest( _try_start_spectrogram_generation(existing_recording.pk) continue logger.info("Ingesting %s...", s3_key) + object_exists = _try_head_s3_object(s3_client, bucket, s3_key) + if not object_exists: + logger.warning("Could not HEAD object with key %s. Skipping...", s3_key) + continue filename = _create_filename(s3_key) logger.info("Downloading to temporary file %s...", filename) s3_client.download_file(bucket, s3_key, filename) @@ -180,7 +195,9 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument( - "bucket", type=str, help="Name of a public S3 bucket where WAV files are stored" + "bucket", + type=str, + help="Name of a public S3 bucket where WAV files are stored", ) parser.add_argument( "manifest", @@ -194,10 +211,16 @@ def add_arguments(self, parser): help="Username of the owner of the recordings. (Defaults to first superuser)", ) parser.add_argument( - "-p", "--public", action="store_true", help="Make imported recordings public" + "-p", + "--public", + action="store_true", + help="Make imported recordings public", ) parser.add_argument( - "-l", "--limit", type=int, help="Limit the number of WAV files to be imported" + "-l", + "--limit", + type=int, + help="Limit the number of WAV files to be imported", ) parser.add_argument( "--filekey",