diff --git a/integration_tests/test_model_upload.py b/integration_tests/test_model_upload.py
index 53c24d23..2728fe98 100644
--- a/integration_tests/test_model_upload.py
+++ b/integration_tests/test_model_upload.py
@@ -102,6 +102,22 @@ def test_model_upload_directory(self) -> None:
             # Create Version
             model_upload(self.handle, temp_dir, LICENSE_NAME)
 
+    def test_model_upload_directory_structure(self) -> None:
+        nested_dir = Path(self.temp_dir) / "nested"
+        nested_dir.mkdir()
+
+        with open(Path(self.temp_dir) / "file1.txt", "w") as f:
+            f.write("dummy content in base file")
+
+        # Create dummy files in the nested directory
+        nested_dummy_files = ["nested_model.h5", "nested_config.json", "nested_metadata.json"]
+        for file in nested_dummy_files:
+            with open(nested_dir / file, "w") as f:
+                f.write("dummy content in nested file")
+
+        # Call the model upload function with the base directory
+        model_upload(self.handle, self.temp_dir, LICENSE_NAME)
+
     def test_model_upload_nested_dir(self) -> None:
         # Create a nested directory within self.temp_dir
         nested_dir = Path(self.temp_dir) / "nested"
diff --git a/src/kagglehub/gcs_upload.py b/src/kagglehub/gcs_upload.py
index c551d350..d3129168 100644
--- a/src/kagglehub/gcs_upload.py
+++ b/src/kagglehub/gcs_upload.py
@@ -1,13 +1,10 @@
 import logging
 import os
-import shutil
 import time
 import zipfile
 from datetime import datetime
-from multiprocessing import Pool
-from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import List, Tuple, Union
+from typing import Dict, List, Optional, Union
 
 import requests
 from requests.exceptions import ConnectionError, Timeout
@@ -25,6 +22,25 @@
 REQUEST_TIMEOUT = 600
 
 
+class UploadDirectoryInfo:
+    def __init__(
+        self,
+        name: str,
+        files: Optional[List[str]] = None,
+        directories: Optional[List["UploadDirectoryInfo"]] = None,
+    ):
+        self.name = name
+        self.files = files if files is not None else []
+        self.directories = directories if directories is not None else []
+
+    def serialize(self) -> Dict:
+        return {
+            "name": self.name,
+            "files": [{"token": file} for file in self.files],
+            "directories": [directory.serialize() for directory in self.directories],
+        }
+
+
 def parse_datetime_string(string: str) -> Union[datetime, str]:
     time_formats = ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%fZ"]
     for t in time_formats:
@@ -138,51 +154,107 @@ def _upload_blob(file_path: str, model_type: str) -> str:
     return response["token"]
 
 
-def zip_file(args: Tuple[Path, Path, Path]) -> int:
-    file_path, zip_path, source_path_obj = args
-    arcname = file_path.relative_to(source_path_obj)
-    size = file_path.stat().st_size
-    with zipfile.ZipFile(zip_path, "a", zipfile.ZIP_STORED, allowZip64=True) as zipf:
-        zipf.write(file_path, arcname)
-    return size
-
-
-def zip_files(source_path_obj: Path, zip_path: Path) -> List[int]:
-    files = [file for file in source_path_obj.rglob("*") if file.is_file()]
-    args = [(file, zip_path, source_path_obj) for file in files]
-
-    with Pool() as pool:
-        sizes = pool.map(zip_file, args)
-    return sizes
-
-
-def upload_files(source_path: str, model_type: str) -> List[str]:
-    source_path_obj = Path(source_path)
-    with TemporaryDirectory() as temp_dir:
-        temp_dir_path = Path(temp_dir)
-        total_size = 0
-
-        if source_path_obj.is_dir():
-            for file_path in source_path_obj.rglob("*"):
-                if file_path.is_file():
-                    total_size += file_path.stat().st_size
-        elif source_path_obj.is_file():
-            total_size = source_path_obj.stat().st_size
-        else:
-            path_error_message = "The source path does not point to a valid file or directory."
-            raise ValueError(path_error_message)
-
-        with tqdm(total=total_size, desc="Zipping", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
-            if source_path_obj.is_dir():
-                zip_path = temp_dir_path / "archive.zip"
-                sizes = zip_files(source_path_obj, zip_path)
-                for size in sizes:
-                    pbar.update(size)
-                upload_path = str(zip_path)
-            elif source_path_obj.is_file():
-                temp_file_path = temp_dir_path / source_path_obj.name
-                shutil.copy(source_path_obj, temp_file_path)
-                pbar.update(temp_file_path.stat().st_size)
-                upload_path = str(temp_file_path)
-
-        return [token for token in [_upload_blob(upload_path, model_type)] if token]
+def upload_files_and_directories(
+    folder: str, model_type: str, quiet: bool = False  # noqa: FBT002, FBT001
+) -> UploadDirectoryInfo:
+    # Count the total number of files
+    file_count = 0
+    for _, _, files in os.walk(folder):
+        file_count += len(files)
+
+    if file_count > MAX_FILES_TO_UPLOAD:
+        if not quiet:
+            logger.info(f"More than {MAX_FILES_TO_UPLOAD} files detected, creating a zip archive...")
+
+        with TemporaryDirectory() as temp_dir:
+            zip_path = os.path.join(temp_dir, TEMP_ARCHIVE_FILE)
+            with zipfile.ZipFile(zip_path, "w") as zipf:
+                for root, _, files in os.walk(folder):
+                    for file in files:
+                        file_path = os.path.join(root, file)
+                        zipf.write(file_path, os.path.relpath(file_path, folder))
+
+            tokens = [
+                token
+                for token in [_upload_file_or_folder(temp_dir, TEMP_ARCHIVE_FILE, model_type, quiet)]
+                if token is not None
+            ]
+            return UploadDirectoryInfo(name="archive", files=tokens)
+
+    root_dict = UploadDirectoryInfo(name="root")
+    if os.path.isfile(folder):
+        # Directly upload the file if the path is a file
+        file_name = os.path.basename(folder)
+        token = _upload_file_or_folder(os.path.dirname(folder), file_name, model_type, quiet)
+        if token:
+            root_dict.files.append(token)
+    else:
+        for root, _, files in os.walk(folder):
+            # Path of the current folder relative to the base folder
+            path = os.path.relpath(root, folder)
+
+            # Navigate or create the dictionary path to the current folder
+            current_dict = root_dict
+            if path != ".":
+                for part in path.split(os.sep):
+                    # Find or create the subdirectory in the current dictionary
+                    for subdir in current_dict.directories:
+                        if subdir.name == part:
+                            current_dict = subdir
+                            break
+                    else:
+                        # If the directory is not found, create a new one
+                        new_dir = UploadDirectoryInfo(name=part)
+                        current_dict.directories.append(new_dir)
+                        current_dict = new_dir
+
+            # Add file tokens to the current directory in the dictionary
+            for file in files:
+                token = _upload_file_or_folder(root, file, model_type, quiet)
+                if token:
+                    current_dict.files.append(token)
+
+    return root_dict
+
+
+def _upload_file_or_folder(
+    parent_path: str,
+    file_or_folder_name: str,
+    model_type: str,
+    quiet: bool = False,  # noqa: FBT002, FBT001
+) -> Optional[str]:
+    """
+    Uploads a file or each file inside a folder individually from a specified path to a remote service.
+    Parameters
+    ==========
+    parent_path: The parent directory path from where the file or folder is to be uploaded.
+    file_or_folder_name: The name of the file or folder to be uploaded.
+    model_type: Type of the model that is being uploaded.
+    quiet: suppress verbose output (default is False)
+    :return: A token if the upload is successful, or None if the file is skipped or the upload fails.
+ """ + full_path = os.path.join(parent_path, file_or_folder_name) + if os.path.isfile(full_path): + return _upload_file(file_or_folder_name, full_path, quiet, model_type) + return None + + +def _upload_file(file_name: str, full_path: str, quiet: bool, model_type: str) -> Optional[str]: # noqa: FBT001 + """Helper function to upload a single file + Parameters + ========== + file_name: name of the file to upload + full_path: path to the file to upload + quiet: suppress verbose output + model_type: Type of the model that is being uploaded. + :return: None - upload unsuccessful; instance of UploadFile - upload successful + """ + + if not quiet: + logger.info("Starting upload for file " + file_name) + + content_length = os.path.getsize(full_path) + token = _upload_blob(full_path, model_type) + if not quiet: + logger.info("Upload successful: " + file_name + " (" + File.get_size(content_length) + ")") + return token diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index d5e3d5e9..bfa5dd98 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -2,7 +2,7 @@ from typing import Optional from kagglehub import registry -from kagglehub.gcs_upload import upload_files +from kagglehub.gcs_upload import upload_files_and_directories from kagglehub.handle import parse_model_handle from kagglehub.models_helpers import create_model_if_missing, create_model_instance_or_version @@ -47,7 +47,7 @@ def model_upload( create_model_if_missing(h.owner, h.model) # Upload the model files to GCS - tokens = upload_files(local_model_dir, "model") + tokens = upload_files_and_directories(local_model_dir, "model") # Create a model instance if it doesn't exist, and create a new instance version if an instance exists create_model_instance_or_version(h, tokens, license_name, version_notes) diff --git a/src/kagglehub/models_helpers.py b/src/kagglehub/models_helpers.py index a6b1495b..45efed79 100644 --- a/src/kagglehub/models_helpers.py +++ b/src/kagglehub/models_helpers.py @@ -1,9 +1,10 @@ import logging from http import HTTPStatus -from typing import List, Optional +from typing import Optional -from kagglehub.clients import KaggleApiV1Client -from kagglehub.exceptions import BackendError, KaggleApiHTTPError +from kagglehub.clients import BackendError, KaggleApiV1Client +from kagglehub.exceptions import KaggleApiHTTPError +from kagglehub.gcs_upload import UploadDirectoryInfo from kagglehub.handle import ModelHandle logger = logging.getLogger(__name__) @@ -16,11 +17,15 @@ def _create_model(owner_slug: str, model_slug: str) -> None: logger.info(f"Model '{model_slug}' Created.") -def _create_model_instance(model_handle: ModelHandle, files: List[str], license_name: Optional[str] = None) -> None: +def _create_model_instance( + model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, license_name: Optional[str] = None +) -> None: + serialized_data = files_and_directories.serialize() data = { "instanceSlug": model_handle.variation, "framework": model_handle.framework, - "files": [{"token": file_token} for file_token in files], + "files": [{"token": file_token} for file_token in files_and_directories.files], + "directories": serialized_data["directories"], } if license_name is not None: data["licenseName"] = license_name @@ -30,8 +35,15 @@ def _create_model_instance(model_handle: ModelHandle, files: List[str], license_ logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}") -def _create_model_instance_version(model_handle: ModelHandle, 
-    data = {"versionNotes": version_notes, "files": [{"token": file_token} for file_token in files]}
+def _create_model_instance_version(
+    model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, version_notes: str = ""
+) -> None:
+    serialized_data = files_and_directories.serialize()
+    data = {
+        "versionNotes": version_notes,
+        "files": [{"token": file_token} for file_token in files_and_directories.files],
+        "directories": serialized_data["directories"],
+    }
     api_client = KaggleApiV1Client()
     api_client.post(
         f"/models/{model_handle.owner}/{model_handle.model}/{model_handle.framework}/{model_handle.variation}/create/version",
@@ -43,7 +55,7 @@ def _create_model_instance_version(model_handle: ModelHandle, files: List[str],
 
 
 def create_model_instance_or_version(
-    model_handle: ModelHandle, files: List[str], license_name: Optional[str], version_notes: str = ""
+    model_handle: ModelHandle, files: UploadDirectoryInfo, license_name: Optional[str], version_notes: str = ""
 ) -> None:
     try:
         _create_model_instance(model_handle, files, license_name)
diff --git a/tests/test_model_upload.py b/tests/test_model_upload.py
index a13cc68f..ccfb1988 100644
--- a/tests/test_model_upload.py
+++ b/tests/test_model_upload.py
@@ -140,7 +140,7 @@ def test_model_upload_instance_with_valid_handle(self) -> None:
                 test_filepath.touch()  # Create a temporary file in the temporary directory
                 model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
                 self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
-                self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
+                self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
 
     def test_model_upload_instance_with_nested_directories(self) -> None:
         # execution path: get_model -> create_model -> get_instance -> create_version
@@ -156,7 +156,7 @@ def test_model_upload_instance_with_nested_directories(self) -> None:
                 test_filepath.touch()
                 model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
                 self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
-                self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
+                self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
 
     def test_model_upload_version_with_valid_handle(self) -> None:
         # execution path: get_model -> get_instance -> create_instance
@@ -168,7 +168,7 @@ def test_model_upload_version_with_valid_handle(self) -> None:
                 test_filepath.touch()  # Create a temporary file in the temporary directory
                 model_upload("metaresearch/llama-2/pyTorch/7b", temp_dir, APACHE_LICENSE, "model_type")
                 self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
-                self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
+                self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
 
     def test_model_upload_with_too_many_files(self) -> None:
         with create_test_http_server(KaggleAPIHandler):
@@ -199,7 +199,7 @@ def test_model_upload_resumable(self) -> None:
                 # Check that GcsAPIHandler received two PUT requests
                 self.assertEqual(GcsAPIHandler.put_requests_count, 2)
                 self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
-                self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
+                self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
 
     def test_model_upload_with_none_license(self) -> None:
         with create_test_http_server(KaggleAPIHandler):
@@ -209,7 +209,7 @@ def test_model_upload_with_none_license(self) -> None:
                 test_filepath.touch()  # Create a temporary file in the temporary directory
                 model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, None, "model_type")
                 self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
-                self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
+                self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
 
     def test_model_upload_without_license(self) -> None:
         with create_test_http_server(KaggleAPIHandler):
@@ -219,7 +219,7 @@ def test_model_upload_without_license(self) -> None:
                 test_filepath.touch()  # Create a temporary file in the temporary directory
                 model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, version_notes="model_type")
                 self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
-                self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
+                self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
 
     def test_model_upload_with_invalid_license_fails(self) -> None:
         with create_test_http_server(KaggleAPIHandler):
@@ -244,3 +244,29 @@ def test_single_file_upload(self) -> None:
 
             self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
             self.assertIn("single_dummy_file.txt", KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
+
+    def test_model_upload_with_directory_structure(self) -> None:
+        with create_test_http_server(KaggleAPIHandler):
+            with create_test_http_server(GcsAPIHandler, "http://localhost:7778"):
+                with TemporaryDirectory() as temp_dir:
+                    base_path = Path(temp_dir)
+                    (base_path / "dir1").mkdir()
+                    (base_path / "dir2").mkdir()
+
+                    (base_path / "file1.txt").touch()
+
+                    (base_path / "dir1" / "file2.txt").touch()
+                    (base_path / "dir1" / "file3.txt").touch()
+
+                    (base_path / "dir1" / "subdir1").mkdir()
+                    (base_path / "dir1" / "subdir1" / "file4.txt").touch()
+
+                    model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
+
+                    self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 4)
+                    expected_files = {"file1.txt", "file2.txt", "file3.txt", "file4.txt"}
+                    self.assertEqual(set(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), expected_files)
+
+                    # TODO: Add assertions on CreateModelInstanceRequest.Directories and
+                    # CreateModelInstanceRequest.Files to verify the expected structure
+                    # is sent.
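
Notes on the new upload flow:

UploadDirectoryInfo is a plain recursive tree: each node carries a name, the blob tokens of the files at that level, and its child directories. A minimal sketch of what serialize() emits for a hypothetical two-level tree (the token strings below stand in for values returned by _upload_blob; they are illustrative, not real tokens):

    from kagglehub.gcs_upload import UploadDirectoryInfo

    # Hypothetical tree: one root-level file plus one subdirectory with two files.
    root = UploadDirectoryInfo(
        name="root",
        files=["token-1"],
        directories=[UploadDirectoryInfo(name="dir1", files=["token-2", "token-3"])],
    )
    print(root.serialize())
    # Pretty-printed for readability:
    # {'name': 'root',
    #  'files': [{'token': 'token-1'}],
    #  'directories': [{'name': 'dir1',
    #                   'files': [{'token': 'token-2'}, {'token': 'token-3'}],
    #                   'directories': []}]}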
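
The tree is built by upload_files_and_directories with a find-or-create walk over os.walk output. The same traversal can be exercised offline by swapping the upload call for a placeholder token; this standalone sketch (not part of the change) mirrors the committed logic:

    import os

    from kagglehub.gcs_upload import UploadDirectoryInfo

    def build_tree(folder: str) -> UploadDirectoryInfo:
        """Same walk as upload_files_and_directories, minus the network calls."""
        root_dict = UploadDirectoryInfo(name="root")
        for root, _, files in os.walk(folder):
            path = os.path.relpath(root, folder)
            current_dict = root_dict
            if path != ".":
                for part in path.split(os.sep):
                    # Find the child node for this path component...
                    for subdir in current_dict.directories:
                        if subdir.name == part:
                            current_dict = subdir
                            break
                    else:
                        # ...or create it (the for/else arm runs when no break fired).
                        new_dir = UploadDirectoryInfo(name=part)
                        current_dict.directories.append(new_dir)
                        current_dict = new_dir
            for file in files:
                # Placeholder standing in for _upload_blob's returned token.
                current_dict.files.append(f"placeholder-token-for-{file}")
        return root_dict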
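
On the wire, _create_model_instance keeps root-level tokens in "files" and hangs everything nested under "directories", so for the tree above the request body looks roughly like this (slug, framework, and license values are illustrative):

    data = {
        "instanceSlug": "new-variation",
        "framework": "pyTorch",
        "files": [{"token": "token-1"}],
        "directories": [
            {
                "name": "dir1",
                "files": [{"token": "token-2"}, {"token": "token-3"}],
                "directories": [],
            }
        ],
        "licenseName": "Apache 2.0",  # only present when a license is passed
    }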