Skip to content

Commit

Permalink
[GCOLAB] Added type of download
Browse files Browse the repository at this point in the history
  • Loading branch information
YanSte committed Aug 29, 2023
1 parent e079f5c commit af272d0
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
9 changes: 9 additions & 0 deletions src/skit/Summarizable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

class Summarizable(ABC):
def summary(self):
title = f"=== {self.__class__.__name__} Configuration Summary ==="
print(title)
for attr, value in self.__dict__.items():
print(f"{attr}: {value}")
print("=" * len(title))
53 changes: 40 additions & 13 deletions src/skit/gcolab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from skit.config import IN_COLAB
from skit.utils import mkdir
from enum import Enum

if IN_COLAB:
"""
Expand All @@ -18,6 +19,18 @@
except ImportError:
print(f"Missing some imports: {ImportError}")

class DatasetType(Enum):
DATASETS = "datasets"
COMPETITIONS = "competitions"

def get_flag(self):
if self == DatasetType.DATASETS:
return "-d"
elif self == DatasetType.COMPETITIONS:
return "-c"
else:
return None

def install_kaggle():
"""
Installs the Kaggle CLI tool using pip.
Expand All @@ -31,23 +44,27 @@ def install_kaggle():
if result.returncode != 0:
raise Exception("Error on install Kaggle.")

def set_environ_kaggle_config(
mountpoint_gdrive_path,
kaggle_config_dir
):
def gdrive_mount(mountpoint_gdrive_path):
"""
Mounts the Google Drive to Colab and sets the Kaggle configuration directory.
Mounts the Google Drive to Colab
Parameters:
-----------
mountpoint_gdrive_path : str
Path to mount the Google Drive.
"""
drive.mount(mountpoint_gdrive_path, force_remount=True)

def set_environ_kaggle_config(kaggle_config_dir):
"""
Sets the Kaggle configuration directory.
Parameters:
-----------
kaggle_config_dir : str
The Kaggle configuration directory located in the Google Drive.
"""
drive.mount(f'{mountpoint_gdrive_path}/gdrive', force_remount=True)
os.environ['KAGGLE_CONFIG_DIR'] = f"{mountpoint_gdrive_path}/gdrive/My Drive/{kaggle_config_dir}"
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_config_dir

def is_kaggle_cli_installed():
"""
Expand All @@ -64,7 +81,7 @@ def is_kaggle_cli_installed():
except subprocess.CalledProcessError:
return False

def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir):
def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir, type):
"""
Downloads and unzips a Kaggle dataset.
Expand All @@ -84,11 +101,16 @@ def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir):
if not is_kaggle_cli_installed():
raise Exception("Kaggle CLI is not installed. Please install it using `pip install kaggle`.")

# Check if the dataset already exists
if os.path.exists(dataset_destination_dir):
print(f"Dataset already exists in {dataset_destination_dir}. Skipping download.")
return

mkdir(dataset_destination_dir)
os.chdir(dataset_destination_dir)

try:
subprocess.run(['kaggle', 'datasets', 'download', '-d', kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
subprocess.run(['kaggle', type.value, 'download', type.get_flag(), kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
zip_files = glob.glob("*.zip")

# Unzip each ZIP file one by one
Expand All @@ -101,9 +123,10 @@ def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir):

def setup_kaggle_dataset(
kaggle_dataset_url,
type = DatasetType.DATASETS,
dataset_destination_path = '/content',
mountpoint_gdrive_path = '/content',
kaggle_config_dir = 'Kaggle'
mountpoint_gdrive_path = '/content/gdrive',
kaggle_config_dir = '/content/gdrive/My Drive/Kaggle'
):
"""
Sets up a Kaggle dataset in Google Colab by installing required tools,
Expand All @@ -114,6 +137,9 @@ def setup_kaggle_dataset(
kaggle_dataset_url : str
The Kaggle dataset URL.
type : str
Type of Datasets
dataset_destination_path : str, optional
The directory where the dataset will be saved and unzipped.
Default is '/content'.
Expand All @@ -128,8 +154,9 @@ def setup_kaggle_dataset(
"""
try:
install_kaggle()
set_environ_kaggle_config(mountpoint_gdrive_path, kaggle_config_dir)
download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path)
gdrive_mount(mountpoint_gdrive_path)
set_environ_kaggle_config(kaggle_config_dir)
download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path, type)
print("Dataset downloaded and unzipped successfully!")

except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion src/skit/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,10 @@ def show_images(

if draw_labels and not draw_predicted_labels:
axs.set_xlabel(y[i], fontsize=font_size)

if draw_labels and draw_predicted_labels:
if y[i] != y_pred[i]:
axs.set_xlabel(f'{y_pred[i]} ({y[i]})', fontsize=font_size)
axs.set_xlabel(f'{y_pred[i]} (✓: {y[i]})', fontsize=font_size)
axs.xaxis.label.set_color('red')
else:
axs.set_xlabel(y[i], fontsize=font_size)
Expand Down

0 comments on commit af272d0

Please sign in to comment.