From 1c55e9fb4c6c687d684448e5861c7c04aaf5223d Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Tue, 29 Aug 2023 08:37:48 +0200 Subject: [PATCH] [GCOLAB] Added type of download --- src/skit/Summarizable.py | 9 ++++++ src/skit/gcolab.py | 66 +++++++++++++++++++++++++++++++--------- src/skit/show.py | 3 +- 3 files changed, 62 insertions(+), 16 deletions(-) create mode 100644 src/skit/Summarizable.py diff --git a/src/skit/Summarizable.py b/src/skit/Summarizable.py new file mode 100644 index 0000000..8b8c37c --- /dev/null +++ b/src/skit/Summarizable.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod + +class Summarizable(ABC): + def summary(self): + title = f"=== {self.__class__.__name__} Configuration Summary ===" + print(title) + for attr, value in self.__dict__.items(): + print(f"{attr}: {value}") + print("=" * len(title)) diff --git a/src/skit/gcolab.py b/src/skit/gcolab.py index cd98de1..87d1003 100644 --- a/src/skit/gcolab.py +++ b/src/skit/gcolab.py @@ -1,5 +1,7 @@ from skit.config import IN_COLAB from skit.utils import mkdir +from enum import Enum +from zipfile import ZipFile if IN_COLAB: """ @@ -18,6 +20,18 @@ except ImportError: print(f"Missing some imports: {ImportError}") + class DatasetType(Enum): + DATASETS = "datasets" + COMPETITIONS = "competitions" + + def get_flag(self): + if self == DatasetType.DATASETS: + return "-d" + elif self == DatasetType.COMPETITIONS: + return "-c" + else: + return None + def install_kaggle(): """ Installs the Kaggle CLI tool using pip. @@ -31,23 +45,27 @@ def install_kaggle(): if result.returncode != 0: raise Exception("Error on install Kaggle.") - def set_environ_kaggle_config( - mountpoint_gdrive_path, - kaggle_config_dir - ): + def gdrive_mount(mountpoint_gdrive_path): """ - Mounts the Google Drive to Colab and sets the Kaggle configuration directory. + Mounts the Google Drive to Colab Parameters: ----------- mountpoint_gdrive_path : str Path to mount the Google Drive. + """ + drive.mount(mountpoint_gdrive_path, force_remount=True) + def set_environ_kaggle_config(kaggle_config_dir): + """ + Sets the Kaggle configuration directory. + + Parameters: + ----------- kaggle_config_dir : str The Kaggle configuration directory located in the Google Drive. """ - drive.mount(f'{mountpoint_gdrive_path}/gdrive', force_remount=True) - os.environ['KAGGLE_CONFIG_DIR'] = f"{mountpoint_gdrive_path}/gdrive/My Drive/{kaggle_config_dir}" + os.environ['KAGGLE_CONFIG_DIR'] = kaggle_config_dir def is_kaggle_cli_installed(): """ @@ -64,7 +82,21 @@ def is_kaggle_cli_installed(): except subprocess.CalledProcessError: return False - def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir): + def unzip_and_delete_from_zip(zip_filepath, extract_to): + with ZipFile(zip_filepath, 'r') as zip_ref: + all_files = zip_ref.namelist() + + for file_name in all_files: + zip_ref.extract(file_name, extract_to) + + with ZipFile(zip_filepath, 'a') as zip_write: + zip_write._delete(file_name) + + if not zip_ref.namelist(): + os.remove(zip_filepath) + break + + def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir, type): """ Downloads and unzips a Kaggle dataset. @@ -88,22 +120,22 @@ def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir): os.chdir(dataset_destination_dir) try: - subprocess.run(['kaggle', 'datasets', 'download', '-d', kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + subprocess.run(['kaggle', type.value, 'download', type.get_flag(), kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) zip_files = glob.glob("*.zip") # Unzip each ZIP file one by one for zip_file in zip_files: - subprocess.run(['unzip', zip_file]) - os.remove(zip_file) + unzip_and_delete_from_zip(zip_file, dataset_destination_dir) except subprocess.CalledProcessError as e: raise Exception(f"An error occurred while downloading the dataset: {e}") def setup_kaggle_dataset( kaggle_dataset_url, + type = DatasetType.DATASETS, dataset_destination_path = '/content', - mountpoint_gdrive_path = '/content', - kaggle_config_dir = 'Kaggle' + mountpoint_gdrive_path = '/content/gdrive', + kaggle_config_dir = '/content/gdrive/My Drive/Kaggle' ): """ Sets up a Kaggle dataset in Google Colab by installing required tools, @@ -114,6 +146,9 @@ def setup_kaggle_dataset( kaggle_dataset_url : str The Kaggle dataset URL. + type : str + Type of Datasets + dataset_destination_path : str, optional The directory where the dataset will be saved and unzipped. Default is '/content'. @@ -128,8 +163,9 @@ def setup_kaggle_dataset( """ try: install_kaggle() - set_environ_kaggle_config(mountpoint_gdrive_path, kaggle_config_dir) - download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path) + gdrive_mount(mountpoint_gdrive_path) + set_environ_kaggle_config(kaggle_config_dir) + download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path, type) print("Dataset downloaded and unzipped successfully!") except Exception as e: diff --git a/src/skit/show.py b/src/skit/show.py index 20534a8..2f3de30 100644 --- a/src/skit/show.py +++ b/src/skit/show.py @@ -255,9 +255,10 @@ def show_images( if draw_labels and not draw_predicted_labels: axs.set_xlabel(y[i], fontsize=font_size) + if draw_labels and draw_predicted_labels: if y[i] != y_pred[i]: - axs.set_xlabel(f'{y_pred[i]} ({y[i]})', fontsize=font_size) + axs.set_xlabel(f'{y_pred[i]} (✓: {y[i]})', fontsize=font_size) axs.xaxis.label.set_color('red') else: axs.set_xlabel(y[i], fontsize=font_size)