Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7961cc4
hi
mohami2000 Apr 19, 2024
2a37e7e
hi
mohami2000 Apr 19, 2024
67f34e4
Merge branch 'remove_zipping' of https://github.com/Kaggle/kagglehub …
mohami2000 Apr 24, 2024
9d814a9
Merge branch 'remove_zipping' of https://github.com/Kaggle/kagglehub …
mohami2000 Apr 24, 2024
39fed73
Merge branch 'remove_zipping' of https://github.com/Kaggle/kagglehub …
mohami2000 Apr 24, 2024
79719f6
remove
mohami2000 Apr 24, 2024
94ea4a3
r
mohami2000 Apr 24, 2024
be6406b
r
mohami2000 Apr 24, 2024
4801a74
r
mohami2000 Apr 24, 2024
35c9c5c
r
mohami2000 Apr 24, 2024
be68ec9
r
mohami2000 Apr 24, 2024
a1e4c34
r
mohami2000 Apr 24, 2024
520cad4
r
mohami2000 Apr 24, 2024
c607c64
r
mohami2000 Apr 24, 2024
b7b31ad
r
mohami2000 Apr 24, 2024
777de49
r
mohami2000 Apr 24, 2024
4085044
r
mohami2000 Apr 24, 2024
9b7050f
r
mohami2000 Apr 24, 2024
823eda3
r
mohami2000 Apr 24, 2024
05deed5
r
mohami2000 Apr 24, 2024
b724c7d
r
mohami2000 Apr 24, 2024
afd1095
r
mohami2000 Apr 24, 2024
afdccf3
Merge branch 'main' into remove_zipping
mohami2000 Apr 24, 2024
6d37941
r
mohami2000 Apr 24, 2024
f68fa08
Merge branch 'remove_zipping' of https://github.com/Kaggle/kagglehub …
mohami2000 Apr 24, 2024
ff1f594
r
mohami2000 Apr 24, 2024
008e55a
r
mohami2000 Apr 24, 2024
eaf095c
r
mohami2000 Apr 24, 2024
4dc0f49
r
mohami2000 Apr 24, 2024
d40061c
r
mohami2000 Apr 25, 2024
6d896f0
r
mohami2000 Apr 25, 2024
7736410
r
mohami2000 Apr 25, 2024
cb54f48
r
mohami2000 Apr 25, 2024
e9aff93
ir
mohami2000 Apr 25, 2024
490813f
r
mohami2000 Apr 25, 2024
1afe0fc
r
mohami2000 Apr 25, 2024
b2eb30f
r
mohami2000 Apr 25, 2024
500cd3b
r
mohami2000 Apr 25, 2024
ba688d1
r
mohami2000 Apr 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions integration_tests/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ def test_model_upload_directory(self) -> None:
# Create Version
model_upload(self.handle, temp_dir, LICENSE_NAME)

def test_model_upload_directory_structure(self) -> None:
    """Upload a directory holding both a top-level file and a nested subdirectory."""
    base_dir = Path(self.temp_dir)
    nested_dir = base_dir / "nested"
    nested_dir.mkdir()

    # One file directly at the top level of the upload directory.
    (base_dir / "file1.txt").write_text("dummy content in nested file")

    # Dummy files inside the nested directory.
    for file_name in ("nested_model.h5", "nested_config.json", "nested_metadata.json"):
        (nested_dir / file_name).write_text("dummy content in nested file")

    # Upload the whole directory tree as a model version.
    model_upload(self.handle, self.temp_dir, LICENSE_NAME)

def test_model_upload_nested_dir(self) -> None:
# Create a nested directory within self.temp_dir
nested_dir = Path(self.temp_dir) / "nested"
Expand Down
177 changes: 125 additions & 52 deletions src/kagglehub/gcs_upload.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import logging
import os
import shutil
import time
import zipfile
from datetime import datetime
from multiprocessing import Pool
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Union

import requests
from requests.exceptions import ConnectionError, Timeout
Expand All @@ -25,6 +22,25 @@
REQUEST_TIMEOUT = 600


class UploadDirectoryInfo:
    """A node in the uploaded directory tree.

    Each node carries its directory name, the blob tokens of the files that
    live directly inside it, and its child directories as nested
    ``UploadDirectoryInfo`` nodes.
    """

    def __init__(
        self,
        name: str,
        files: Optional[List[str]] = None,
        directories: Optional[List["UploadDirectoryInfo"]] = None,
    ):
        self.name = name
        # Keep the caller's list objects when provided (no defensive copies);
        # default to fresh empty lists otherwise.
        self.files = [] if files is None else files
        self.directories = [] if directories is None else directories

    def serialize(self) -> Dict:
        """Return the tree rooted at this node as plain dicts for the API payload."""
        serialized_files = [{"token": token} for token in self.files]
        serialized_children = [child.serialize() for child in self.directories]
        return {
            "name": self.name,
            "files": serialized_files,
            "directories": serialized_children,
        }


def parse_datetime_string(string: str) -> Union[datetime, str]:
time_formats = ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%fZ"]
for t in time_formats:
Expand Down Expand Up @@ -138,51 +154,108 @@ def _upload_blob(file_path: str, model_type: str) -> str:
return response["token"]


def zip_file(args: Tuple[Path, Path, Path]) -> int:
    """Append one file to the archive at ``zip_path`` and return its size in bytes.

    ``args`` is a ``(file_path, zip_path, source_root)`` tuple; the file is
    stored under its path relative to ``source_root``.
    """
    file_path = args[0]
    zip_path = args[1]
    source_root = args[2]
    relative_name = file_path.relative_to(source_root)
    byte_count = file_path.stat().st_size
    # Open in append mode so successive calls accumulate into the same archive.
    with zipfile.ZipFile(zip_path, "a", zipfile.ZIP_STORED, allowZip64=True) as archive:
        archive.write(file_path, relative_name)
    return byte_count


def zip_files(source_path_obj: Path, zip_path: Path) -> List[int]:
    """Zip every file under ``source_path_obj`` into ``zip_path``.

    Files are stored (uncompressed, ``ZIP_STORED``) under their paths relative
    to ``source_path_obj``. Returns one size-in-bytes entry per archived file
    so callers can drive a progress bar.

    NOTE(fix): the previous implementation dispatched ``zip_file`` calls to a
    ``multiprocessing.Pool`` where every worker opened the SAME archive in
    append mode concurrently. Concurrent appends to one zip file are not safe
    and can corrupt the archive. Writing sequentially through a single handle
    is correct, and also avoids reopening the archive once per file.
    """
    sizes: List[int] = []
    # Open once in append mode, matching the previous per-file "a" behavior
    # of adding to zip_path if it already exists.
    with zipfile.ZipFile(zip_path, "a", zipfile.ZIP_STORED, allowZip64=True) as zipf:
        for file_path in source_path_obj.rglob("*"):
            if not file_path.is_file():
                continue
            zipf.write(file_path, file_path.relative_to(source_path_obj))
            sizes.append(file_path.stat().st_size)
    return sizes


def upload_files(source_path: str, model_type: str) -> List[str]:
    """Upload the file or directory at ``source_path`` as a single blob.

    A directory is first zipped into a temporary archive; a single file is
    copied into the temporary directory and uploaded as-is. Returns a list
    containing the upload token (empty if the upload yielded a falsy token).

    Raises ValueError if ``source_path`` is neither a file nor a directory.
    """
    source_path_obj = Path(source_path)
    with TemporaryDirectory() as temp_dir:
        temp_dir_path = Path(temp_dir)
        total_size = 0

        # First pass: compute the total byte count so the progress bar below
        # can show overall progress while zipping/copying.
        if source_path_obj.is_dir():
            for file_path in source_path_obj.rglob("*"):
                if file_path.is_file():
                    total_size += file_path.stat().st_size
        elif source_path_obj.is_file():
            total_size = source_path_obj.stat().st_size
        else:
            path_error_message = "The source path does not point to a valid file or directory."
            raise ValueError(path_error_message)

        with tqdm(total=total_size, desc="Zipping", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
            if source_path_obj.is_dir():
                # Bundle the whole directory into one archive inside the temp dir.
                zip_path = temp_dir_path / "archive.zip"
                sizes = zip_files(source_path_obj, zip_path)
                for size in sizes:
                    pbar.update(size)
                upload_path = str(zip_path)
            elif source_path_obj.is_file():
                # Copy the single file into the temp dir so the upload source
                # is stable for the duration of the request.
                temp_file_path = temp_dir_path / source_path_obj.name
                shutil.copy(source_path_obj, temp_file_path)
                pbar.update(temp_file_path.stat().st_size)
                upload_path = str(temp_file_path)

        # The upload must happen before TemporaryDirectory cleans up, since
        # upload_path points inside it.
        return [token for token in [_upload_blob(upload_path, model_type)] if token]
def upload_files_and_directories(
    folder: str, model_type: str, quiet: bool = False  # noqa: FBT002, FBT001
) -> UploadDirectoryInfo:
    """Upload the file or directory at ``folder`` and return the resulting tree.

    When ``folder`` contains more than MAX_FILES_TO_UPLOAD files, everything
    is bundled into one zip archive and uploaded as a single blob. Otherwise
    each file is uploaded individually and the returned UploadDirectoryInfo
    mirrors the on-disk directory structure, carrying the blob tokens.
    """
    # os.walk yields nothing for a single file, so file_count stays 0 then.
    file_count = sum(len(file_names) for _, _, file_names in os.walk(folder))

    if file_count > MAX_FILES_TO_UPLOAD:
        if not quiet:
            logger.info(f"More than {MAX_FILES_TO_UPLOAD} files detected, creating a zip archive...")

        with TemporaryDirectory() as temp_dir:
            archive_path = os.path.join(temp_dir, TEMP_ARCHIVE_FILE)
            with zipfile.ZipFile(archive_path, "w") as archive:
                for current_root, _, file_names in os.walk(folder):
                    for file_name in file_names:
                        absolute_path = os.path.join(current_root, file_name)
                        archive.write(absolute_path, os.path.relpath(absolute_path, folder))

            archive_token = _upload_file_or_folder(temp_dir, TEMP_ARCHIVE_FILE, model_type, quiet)
            archive_tokens = [archive_token] if archive_token is not None else []
            return UploadDirectoryInfo(name="archive", files=archive_tokens)

    tree_root = UploadDirectoryInfo(name="root")

    if os.path.isfile(folder):
        # A single file: upload it directly under the root node.
        file_token = _upload_file_or_folder(os.path.dirname(folder), os.path.basename(folder), model_type, quiet)
        if file_token:
            tree_root.files.append(file_token)
        return tree_root

    for current_root, _, file_names in os.walk(folder):
        relative_path = os.path.relpath(current_root, folder)

        # Walk (and extend) the tree down to the node for the current directory.
        current_node = tree_root
        if relative_path != ".":
            for part in relative_path.split(os.sep):
                existing = next((child for child in current_node.directories if child.name == part), None)
                if existing is None:
                    existing = UploadDirectoryInfo(name=part)
                    current_node.directories.append(existing)
                current_node = existing

        # Upload each file in this directory and record its token on the node.
        for file_name in file_names:
            file_token = _upload_file_or_folder(current_root, file_name, model_type, quiet)
            if file_token:
                current_node.files.append(file_token)

    return tree_root


def _upload_file_or_folder(
    parent_path: str,
    file_or_folder_name: str,
    model_type: str,
    quiet: bool = False,  # noqa: FBT002, FBT001
) -> Optional[str]:
    """
    Upload the entry named ``file_or_folder_name`` under ``parent_path`` if it
    is a regular file; directories (and anything else) are skipped.
    Parameters
    ==========
    parent_path: The parent directory path from where the file is to be uploaded.
    file_or_folder_name: The name of the file (or folder) under ``parent_path``.
    model_type: Type of the model that is being uploaded.
    quiet: suppress verbose output (default is False)
    :return: A token if the upload is successful, or None if the entry is not a
        regular file (e.g. a directory) or the upload did not produce a token.
    """
    full_path = os.path.join(parent_path, file_or_folder_name)
    if os.path.isfile(full_path):
        return _upload_file(file_or_folder_name, full_path, quiet, model_type)
    return None


def _upload_file(file_name: str, full_path: str, quiet: bool, model_type: str) -> Optional[str]:  # noqa: FBT001
    """Helper function to upload a single file and return its blob token.
    Parameters
    ==========
    file_name: name of the file to upload (used only for log messages)
    full_path: path to the file to upload
    quiet: suppress verbose output
    model_type: Type of the model that is being uploaded.
    :return: the upload token on success (previous docstring incorrectly said
        "instance of UploadFile"; the function returns the token from
        _upload_blob).
    """

    if not quiet:
        # Lazy %-style args: the message is only formatted if the log is emitted.
        logger.info("Starting upload for file %s", file_name)

    content_length = os.path.getsize(full_path)
    token = _upload_blob(full_path, model_type)
    if not quiet:
        logger.info("Upload successful: %s (%s)", file_name, File.get_size(content_length))
    return token
4 changes: 2 additions & 2 deletions src/kagglehub/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

from kagglehub import registry
from kagglehub.gcs_upload import upload_files
from kagglehub.gcs_upload import upload_files_and_directories
from kagglehub.handle import parse_model_handle
from kagglehub.models_helpers import create_model_if_missing, create_model_instance_or_version

Expand Down Expand Up @@ -47,7 +47,7 @@ def model_upload(
create_model_if_missing(h.owner, h.model)

# Upload the model files to GCS
tokens = upload_files(local_model_dir, "model")
tokens = upload_files_and_directories(local_model_dir, "model")

# Create a model instance if it doesn't exist, and create a new instance version if an instance exists
create_model_instance_or_version(h, tokens, license_name, version_notes)
28 changes: 20 additions & 8 deletions src/kagglehub/models_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from http import HTTPStatus
from typing import List, Optional
from typing import Optional

from kagglehub.clients import KaggleApiV1Client
from kagglehub.exceptions import BackendError, KaggleApiHTTPError
from kagglehub.clients import BackendError, KaggleApiV1Client
from kagglehub.exceptions import KaggleApiHTTPError
from kagglehub.gcs_upload import UploadDirectoryInfo
from kagglehub.handle import ModelHandle

logger = logging.getLogger(__name__)
Expand All @@ -16,11 +17,15 @@ def _create_model(owner_slug: str, model_slug: str) -> None:
logger.info(f"Model '{model_slug}' Created.")


def _create_model_instance(model_handle: ModelHandle, files: List[str], license_name: Optional[str] = None) -> None:
def _create_model_instance(
model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, license_name: Optional[str] = None
) -> None:
serialized_data = files_and_directories.serialize()
data = {
"instanceSlug": model_handle.variation,
"framework": model_handle.framework,
"files": [{"token": file_token} for file_token in files],
"files": [{"token": file_token} for file_token in files_and_directories.files],
"directories": serialized_data["directories"],
}
if license_name is not None:
data["licenseName"] = license_name
Expand All @@ -30,8 +35,15 @@ def _create_model_instance(model_handle: ModelHandle, files: List[str], license_
logger.info(f"Your model instance has been created.\nFiles are being processed...\nSee at: {model_handle.to_url()}")


def _create_model_instance_version(model_handle: ModelHandle, files: List[str], version_notes: str = "") -> None:
data = {"versionNotes": version_notes, "files": [{"token": file_token} for file_token in files]}
def _create_model_instance_version(
model_handle: ModelHandle, files_and_directories: UploadDirectoryInfo, version_notes: str = ""
) -> None:
serialized_data = files_and_directories.serialize()
data = {
"versionNotes": version_notes,
"files": [{"token": file_token} for file_token in files_and_directories.files],
"directories": serialized_data["directories"],
}
api_client = KaggleApiV1Client()
api_client.post(
f"/models/{model_handle.owner}/{model_handle.model}/{model_handle.framework}/{model_handle.variation}/create/version",
Expand All @@ -43,7 +55,7 @@ def _create_model_instance_version(model_handle: ModelHandle, files: List[str],


def create_model_instance_or_version(
model_handle: ModelHandle, files: List[str], license_name: Optional[str], version_notes: str = ""
model_handle: ModelHandle, files: UploadDirectoryInfo, license_name: Optional[str], version_notes: str = ""
) -> None:
try:
_create_model_instance(model_handle, files, license_name)
Expand Down
38 changes: 32 additions & 6 deletions tests/test_model_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_model_upload_instance_with_valid_handle(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_instance_with_nested_directories(self) -> None:
# execution path: get_model -> create_model -> get_instance -> create_version
Expand All @@ -156,7 +156,7 @@ def test_model_upload_instance_with_nested_directories(self) -> None:
test_filepath.touch()
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_version_with_valid_handle(self) -> None:
# execution path: get_model -> get_instance -> create_instance
Expand All @@ -168,7 +168,7 @@ def test_model_upload_version_with_valid_handle(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/llama-2/pyTorch/7b", temp_dir, APACHE_LICENSE, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_with_too_many_files(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand Down Expand Up @@ -199,7 +199,7 @@ def test_model_upload_resumable(self) -> None:
# Check that GcsAPIHandler received two PUT requests
self.assertEqual(GcsAPIHandler.put_requests_count, 2)
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_with_none_license(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand All @@ -209,7 +209,7 @@ def test_model_upload_with_none_license(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, None, "model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_without_license(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand All @@ -219,7 +219,7 @@ def test_model_upload_without_license(self) -> None:
test_filepath.touch() # Create a temporary file in the temporary directory
model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, version_notes="model_type")
self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn(TEMP_ARCHIVE_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)
self.assertIn(TEMP_TEST_FILE, KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_with_invalid_license_fails(self) -> None:
with create_test_http_server(KaggleAPIHandler):
Expand All @@ -244,3 +244,29 @@ def test_single_file_upload(self) -> None:

self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 1)
self.assertIn("single_dummy_file.txt", KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES)

def test_model_upload_with_directory_structure(self) -> None:
    """Upload a tree with nested directories and verify every file's blob is sent."""
    with create_test_http_server(KaggleAPIHandler):
        with create_test_http_server(GcsAPIHandler, "http://localhost:7778"):
            with TemporaryDirectory() as temp_dir:
                base_path = Path(temp_dir)
                (base_path / "dir1").mkdir()
                (base_path / "dir2").mkdir()

                (base_path / "file1.txt").touch()

                (base_path / "dir1" / "file2.txt").touch()
                (base_path / "dir1" / "file3.txt").touch()

                (base_path / "dir1" / "subdir1").mkdir()
                (base_path / "dir1" / "subdir1" / "file4.txt").touch()

                model_upload("metaresearch/new-model/pyTorch/new-variation", temp_dir, APACHE_LICENSE, "model_type")

                self.assertEqual(len(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), 4)
                expected_files = {"file1.txt", "file2.txt", "file3.txt", "file4.txt"}
                # Set equality is strictly stronger than the former issubset check:
                # issubset + len == 4 would still pass if one name were uploaded
                # twice and another never uploaded at all.
                self.assertEqual(set(KaggleAPIHandler.UPLOAD_BLOB_FILE_NAMES), expected_files)

                # TODO: Add assertions on CreateModelInstanceRequest.Directories and
                # CreateModelInstanceRequest.Files to verify the expected structure
                # is sent.