In [None]:
import hashlib
import os
import tarfile
import zipfile
import requests
#@save
# Registry of downloadable files keyed by dataset name. `download` unpacks
# each entry as a (url, sha1_hex_digest) pair.
DATA_HUB = {}
# Presumably the base URL under which the datasets are hosted; it is not
# referenced elsewhere in this part of the file -- TODO confirm usage.
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com'

In [None]:
def download(name, cache_dir=os.path.join('..', 'data')): #@save
    """Download a file registered in DATA_HUB and return its local path.

    If a file with the expected name already exists in ``cache_dir`` and its
    SHA-1 digest matches the registered one, the download is skipped and the
    cached path is returned.

    Args:
        name: Key of the dataset in ``DATA_HUB``. The entry is unpacked as a
            ``(url, sha1_hex_digest)`` pair.
        cache_dir: Directory in which the file is stored; created if it does
            not exist. Defaults to ``os.path.join('..', 'data')``.

    Returns:
        The local file path: ``cache_dir`` joined with the last
        ``/``-separated component of the URL.

    Raises:
        KeyError: If ``name`` is not registered in ``DATA_HUB``.
        requests.HTTPError: If the server responds with an error status.
    """
    # Check membership in the registry itself (not in the entry tuple), and
    # raise explicitly so the check is not stripped under ``python -O``.
    if name not in DATA_HUB:
        raise KeyError(f"{name} 不存在于 {DATA_HUB}")
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split('/')[-1])

    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            # Hash in 1 MiB chunks so large files are never fully in memory.
            while True:
                data = f.read(1048576)
                if not data:
                    break
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Cache hit: the existing file matches the digest.
    print(f"正在从{url}下载{fname}...")
    with requests.get(url, stream=True, verify=True) as r:
        # Fail before touching the target file instead of saving an HTTP
        # error page as if it were the dataset.
        r.raise_for_status()
        with open(fname, 'wb') as f:
            # Stream the body to disk chunk by chunk rather than buffering
            # the whole response via ``r.content``.
            for chunk in r.iter_content(chunk_size=1048576):
                f.write(chunk)
    return fname

In [None]:
def download_extract(name, folder=None): #@save
    """Download a zip/tar archive and extract it into the archive's directory.

    Args:
        name: Dataset key, passed unchanged to ``download``.
        folder: Optional directory name, relative to the directory that
            contains the archive, to return instead of the default path.

    Returns:
        ``os.path.join(base_dir, folder)`` when ``folder`` is given, where
        ``base_dir`` is the directory containing the downloaded archive.
        Otherwise the archive path with its last extension removed; for a
        name such as ``x.tar.gz`` that is ``x.tar``, so pass ``folder`` when
        the extracted directory is named differently.

    Raises:
        AssertionError: If the file extension is not ``.zip``, ``.tar`` or
            ``.gz``.
    """
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == '.zip':
        open_archive = zipfile.ZipFile
    elif ext in ('.tar', '.gz'):
        open_archive = tarfile.open
    else:
        # Raised explicitly (same type and message as the former assert) so
        # the check also holds when assertions are disabled with ``-O``.
        raise AssertionError('Only zip/tar can be unziped')
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; only use this with archives from a trusted source.
    with open_archive(fname, 'r') as fp:
        # The context manager closes the archive handle after extraction.
        fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir

In [None]:
def download_all(): #@save
    """Download every file registered in DATA_HUB.

    Each key of ``DATA_HUB`` is passed to ``download`` with the default
    cache directory; the returned paths are discarded.
    """
    for dataset_name in DATA_HUB:
        download(dataset_name)