Skip to content

Commit

Permalink
[GCOLAB] Added type of download
Browse files Browse the repository at this point in the history
  • Loading branch information
YanSte committed Aug 29, 2023
1 parent e079f5c commit 89734a8
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 4 deletions.
9 changes: 9 additions & 0 deletions src/skit/Summarizable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

class Summarizable(ABC):
def summary(self):
title = f"=== {self.__class__.__name__} Configuration Summary ==="
print(title)
for attr, value in self.__dict__.items():
print(f"{attr}: {value}")
print("=" * len(title))
15 changes: 12 additions & 3 deletions src/skit/gcolab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from skit.config import IN_COLAB
from skit.utils import mkdir
from enum import Enum

if IN_COLAB:
"""
Expand All @@ -18,6 +19,10 @@
except ImportError:
print(f"Missing some imports: {ImportError}")

class DatasetType(Enum):
DATASETS = "datasets"
COMPETITIONS = "competitions"

def install_kaggle():
"""
Installs the Kaggle CLI tool using pip.
Expand Down Expand Up @@ -64,7 +69,7 @@ def is_kaggle_cli_installed():
except subprocess.CalledProcessError:
return False

def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir):
def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir, type):
"""
Downloads and unzips a Kaggle dataset.
Expand All @@ -88,7 +93,7 @@ def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir):
os.chdir(dataset_destination_dir)

try:
subprocess.run(['kaggle', 'datasets', 'download', '-d', kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
subprocess.run(['kaggle', type.value, 'download', '-c', kaggle_dataset_url], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True)
zip_files = glob.glob("*.zip")

# Unzip each ZIP file one by one
Expand All @@ -101,6 +106,7 @@ def download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_dir):

def setup_kaggle_dataset(
kaggle_dataset_url,
type = DatasetType.DATASETS,
dataset_destination_path = '/content',
mountpoint_gdrive_path = '/content',
kaggle_config_dir = 'Kaggle'
Expand All @@ -114,6 +120,9 @@ def setup_kaggle_dataset(
kaggle_dataset_url : str
The Kaggle dataset URL.
type : str
Type of Datasets
dataset_destination_path : str, optional
The directory where the dataset will be saved and unzipped.
Default is '/content'.
Expand All @@ -129,7 +138,7 @@ def setup_kaggle_dataset(
try:
install_kaggle()
set_environ_kaggle_config(mountpoint_gdrive_path, kaggle_config_dir)
download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path)
download_and_unzip_dataset(kaggle_dataset_url, dataset_destination_path, type)
print("Dataset downloaded and unzipped successfully!")

except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion src/skit/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,10 @@ def show_images(

if draw_labels and not draw_predicted_labels:
axs.set_xlabel(y[i], fontsize=font_size)

if draw_labels and draw_predicted_labels:
if y[i] != y_pred[i]:
axs.set_xlabel(f'{y_pred[i]} ({y[i]})', fontsize=font_size)
axs.set_xlabel(f'{y_pred[i]} (✓: {y[i]})', fontsize=font_size)
axs.xaxis.label.set_color('red')
else:
axs.set_xlabel(y[i], fontsize=font_size)
Expand Down

0 comments on commit 89734a8

Please sign in to comment.