From af272d0efc31ab329d2d62dd34ef4e1bab9987f3 Mon Sep 17 00:00:00 2001 From: Yannick Stephan Date: Tue, 29 Aug 2023 08:37:48 +0200 Subject: [PATCH] [GCOLAB] Added type of download --- src/skit/Summarizable.py | 9 +++++++ src/skit/gcolab.py | 53 ++++++++++++++++++++++++++++++---------- src/skit/show.py | 3 ++- 3 files changed, 51 insertions(+), 14 deletions(-) create mode 100644 src/skit/Summarizable.py diff --git a/src/skit/Summarizable.py b/src/skit/Summarizable.py new file mode 100644 index 0000000..8b8c37c --- /dev/null +++ b/src/skit/Summarizable.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod + +class Summarizable(ABC): + def summary(self): + title = f"=== {self.__class__.__name__} Configuration Summary ===" + print(title) + for attr, value in self.__dict__.items(): + print(f"{attr}: {value}") + print("=" * len(title)) diff --git a/src/skit/gcolab.py b/src/skit/gcolab.py index cd98de1..f3f24f5 100644 --- a/src/skit/gcolab.py +++ b/src/skit/gcolab.py @@ -1,5 +1,6 @@ from skit.config import IN_COLAB from skit.utils import mkdir +from enum import Enum if IN_COLAB: """ @@ -18,6 +19,18 @@ except ImportError: print(f"Missing some imports: {ImportError}") + class DatasetType(Enum): + DATASETS = "datasets" + COMPETITIONS = "competitions" + + def get_flag(self): + if self == DatasetType.DATASETS: + return "-d" + elif self == DatasetType.COMPETITIONS: + return "-c" + else: + return None + def install_kaggle(): """ Installs the Kaggle CLI tool using pip. @@ -31,23 +44,27 @@ def install_kaggle(): if result.returncode != 0: raise Exception("Error on install Kaggle.") - def set_environ_kaggle_config( - mountpoint_gdrive_path, - kaggle_config_dir - ): + def gdrive_mount(mountpoint_gdrive_path): """ - Mounts the Google Drive to Colab and sets the Kaggle configuration directory. + Mounts the Google Drive to Colab Parameters: ----------- mountpoint_gdrive_path : str Path to mount the Google Drive. + """ + drive.mount(mountpoint_gdrive_path, force_remount=True) + + def set_environ_kaggle_config(kaggle_config_dir): + """ + Sets the Kaggle configuration directory. + Parameters: + ----------- kaggle_config_dir : str The Kaggle configuration directory located in the Google Drive. """ - drive.mount(f'{mountpoint_gdrive_path}/gdrive', force_remount=True) - os.environ['KAGGLE_CONFIG_DIR'] = f"{mountpoint_gdrive_path}/gdrive/My Drive/{kaggle_config_dir}" + os.environ['KAGGLE_CONFIG_DIR'] = kaggle_config_dir def is_kaggle_cli_installed(): """ @@ -64,7 +81,7 @@ def is_kaggle_cli_installed(): except subprocess.CalledProcessError: return False - def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir): + def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir, type): """ Downloads and unzips a Kaggle dataset. @@ -84,11 +101,16 @@ def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir): if not is_kaggle_cli_installed(): raise Exception("Kaggle CLI is not installed. Please install it using `pip install kaggle`.") + # Check if the dataset already exists + if os.path.exists(dataset_destination_dir): + print(f"Dataset already exists in {dataset_destination_dir}. Skipping download.") + return + mkdir(dataset_destination_dir) os.chdir(dataset_destination_dir) try: - subprocess.run(['kaggle', 'datasets', 'download', '-d', kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + subprocess.run(['kaggle', type.value, 'download', type.get_flag(), kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) zip_files = glob.glob("*.zip") # Unzip each ZIP file one by one @@ -101,9 +123,10 @@ def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir): def setup_kaggle_dataset( kaggle_dataset_url, + type = DatasetType.DATASETS, dataset_destination_path = '/content', - mountpoint_gdrive_path = '/content', - kaggle_config_dir = 'Kaggle' + mountpoint_gdrive_path = '/content/gdrive', + kaggle_config_dir = '/content/gdrive/My Drive/Kaggle' ): """ Sets up a Kaggle dataset in Google Colab by installing required tools, @@ -114,6 +137,9 @@ def setup_kaggle_dataset( kaggle_dataset_url : str The Kaggle dataset URL. + type : str + Type of Datasets + dataset_destination_path : str, optional The directory where the dataset will be saved and unzipped. Default is '/content'. @@ -128,8 +154,9 @@ def setup_kaggle_dataset( """ try: install_kaggle() - set_environ_kaggle_config(mountpoint_gdrive_path, kaggle_config_dir) - download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path) + gdrive_mount(mountpoint_gdrive_path) + set_environ_kaggle_config(kaggle_config_dir) + download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path, type) print("Dataset downloaded and unzipped successfully!") except Exception as e: diff --git a/src/skit/show.py b/src/skit/show.py index 20534a8..2f3de30 100644 --- a/src/skit/show.py +++ b/src/skit/show.py @@ -255,9 +255,10 @@ def show_images( if draw_labels and not draw_predicted_labels: axs.set_xlabel(y[i], fontsize=font_size) + if draw_labels and draw_predicted_labels: if y[i] != y_pred[i]: - axs.set_xlabel(f'{y_pred[i]} ({y[i]})', fontsize=font_size) + axs.set_xlabel(f'{y_pred[i]} (✓: {y[i]})', fontsize=font_size) axs.xaxis.label.set_color('red') else: axs.set_xlabel(y[i], fontsize=font_size)