# Imports


In [None]:
import os
import sys
import random
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset, DataLoader
from torchvision import transforms
from torchvision.datasets import DatasetFolder
from torchinfo import summary

# Download and Extract The Dataset

In [None]:
# Download the zip file from the provided URL and extract it to the specified directory
def _count_files_and_dirs(directory):
    """Return ``(file_count, dir_count)`` for everything below *directory*.

    Both counts are recursive (``os.walk``) and do not include *directory*
    itself. A missing directory yields ``(0, 0)`` because ``os.walk`` yields
    nothing for it.
    """
    total_files = 0
    total_dirs = 0
    for _, dirs, files in os.walk(directory):
        total_files += len(files)
        total_dirs += len(dirs)
    return total_files, total_dirs


def _download_file(url, dest_path, timeout):
    """Stream *url* to *dest_path*, drawing a console progress bar.

    Returns True on success and False when the server answers with a
    non-200 status. Data is written to ``dest_path + ".part"`` and only
    renamed onto *dest_path* once complete, so an interrupted download never
    leaves a truncated zip where extraction would pick it up. Exceptions
    raised by ``requests`` (connection errors, timeouts) propagate.
    """
    # Third-party dependency, imported lazily so that extract-only usage
    # works even where `requests` is not installed.
    import requests

    partial_path = dest_path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download dataset. Status code: {response.status_code}")
            return False
        total_length = int(response.headers.get('content-length', 0))
        downloaded = 0
        try:
            with open(partial_path, 'wb') as f:
                print("Downloading dataset...")
                for chunk in response.iter_content(chunk_size=8192):
                    if not chunk:
                        continue  # skip keep-alive chunks
                    f.write(chunk)
                    downloaded += len(chunk)
                    if total_length:
                        # Clamp in case content-length under-reports the size.
                        done = min(50, int(50 * downloaded / total_length))
                        percent = 100 * downloaded / total_length
                    else:
                        done, percent = 0, 0
                    sys.stdout.write('\r[{}{}] {:.2f}%'.format(
                        '=' * done, ' ' * (50 - done), percent))
                    sys.stdout.flush()
                print()  # Newline after progress bar
        except BaseException:
            # Remove the partial file on any failure (including Ctrl-C).
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
    os.replace(partial_path, dest_path)
    return True


def download_and_extract_dataset(url: str, compressed_file_name: str, extract_to: str, root_data_folder: str,
                                 overwrite: bool = False, extract_only: bool = True,
                                 expected_file_count: int = 16766, expected_folder_count: int = 7,
                                 timeout: float = 60.0) -> bool:
    """Optionally download a zip file, extract it, and verify the extracted dataset.

    Steps:
      1. If ``extract_only`` is False, download ``url`` to
         ``<cwd>/compressed_file_name``.
      2. Extract the zip into ``extract_to`` when ``overwrite`` is True, or
         when ``extract_to/root_data_folder`` is missing or empty.
      3. Count files and folders under ``extract_to/root_data_folder`` and
         compare them with the expected counts.

    Args:
        url (str): The URL of the zip file to download.
        compressed_file_name (str): Name (or absolute path) of the local zip file.
            Relative names are resolved against the current working directory.
        extract_to (str): The directory the zip contents are extracted into.
        root_data_folder (str): The top-level folder inside the zip. Its extracted
            copy, ``extract_to/root_data_folder``, is what gets verified.
        overwrite (bool): If True, extract even when the dataset folder already
            has content, overwriting existing files.
        extract_only (bool): If True, never download. Only a zip already on disk
            is used, and it is extracted only if the dataset folder is empty or
            does not exist (unless ``overwrite`` is True).
        expected_file_count (int): Total number of files expected (recursively)
            under the dataset folder.
        expected_folder_count (int): Total number of sub-folders expected
            (recursively) under the dataset folder. For this dataset these are
            the 7 class folders.
        timeout (float): Seconds passed to ``requests.get`` as its timeout.

    Returns:
        bool: True if the dataset folder matches the expected counts. False if
        it does not, if the zip is missing when extraction is needed, or if
        extraction fails.
    """
    import zlib
    from zipfile import BadZipFile, ZipFile

    zip_file_path = os.path.join(os.getcwd(), compressed_file_name)
    dataset_dir = os.path.join(extract_to, root_data_folder)
    os.makedirs(extract_to, exist_ok=True)

    # Step 1: download only when explicitly requested. On a non-200 response
    # we fall through, since a previously downloaded zip or an already
    # extracted dataset may still be usable.
    if not extract_only:
        _download_file(url, zip_file_path, timeout)

    # Step 2: extract if asked to overwrite, or if there is nothing there yet.
    dataset_missing = not os.path.isdir(dataset_dir) or not os.listdir(dataset_dir)
    if overwrite or dataset_missing:
        if not os.path.exists(zip_file_path):
            hint = " Set extract_only=False to download it." if extract_only else ""
            print(f"The zip file does not exist at {zip_file_path}.{hint}")
            return False
        try:
            print(f"Using existing zip file: {zip_file_path}")
            with ZipFile(zip_file_path, 'r') as zip_file:
                print(f"Extracting zip file to {extract_to}...")
                zip_file.extractall(path=extract_to)
            print("Extraction complete.")
        except (BadZipFile, OSError, EOFError, zlib.error) as e:
            # Not a zip, truncated/corrupted archive, or a filesystem error.
            print(f"Failed to extract zip file: {e}")
            print("The zip file may be corrupted. Please delete it and re-download.")
            return False

    # Step 3: verify the extracted dataset by file/folder counts.
    print("Checking if existing dataset is intact...")
    total_files, total_dirs = _count_files_and_dirs(dataset_dir)
    print(f"Expected {expected_file_count} files and {expected_folder_count} folders.")
    print(f"Found {total_files} files and {total_dirs} folders.")
    intact = total_files == expected_file_count and total_dirs == expected_folder_count
    if intact:
        print("Dataset is intact.")
    else:
        print("Dataset is incomplete or corrupted. Download again.")
    return intact

# Fetch/verify the STFT spectrogram dataset. extract_only=True skips the
# download step, so only the zip/dataset already on disk is used.
download_and_extract_dataset(
    "https://some_link_to_dataset.zip",  # url
    "stft_spectrograms.zip",             # compressed_file_name
    "data_directory",                    # extract_to
    "stft_spectrograms",                 # root_data_folder
    extract_only=True,
    overwrite=False,
)