In [None]:
"""
@title: download the HCP resting-state dataset
@version:0.1
@author: Hui Zheng zh.dmtr@gmail.com
@time: 2022-8-8 10:58:00
"""

In [None]:
# import modules
import hashlib
import json
import os
import time
from pathlib import Path

import boto3
from loguru import logger

In [None]:
# Name of the S3 bucket and the release prefix that every object key starts with.
s3_bucket_name = 'hcp-openaccess'
s3_prefix = 'HCP_1200'
# Credentials come from the standard AWS environment variables so that real
# keys never have to be typed into (and saved with) the notebook. An unset or
# empty variable becomes None, which lets boto3 fall back to its own
# credential lookup instead of signing requests with an empty key.
access_key = os.environ.get('AWS_ACCESS_KEY_ID') or None
secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY') or None
resource = boto3.resource('s3',aws_access_key_id=access_key,aws_secret_access_key=secret_key)
# Reuse the constant above rather than repeating the bucket name literal.
bucket = resource.Bucket(s3_bucket_name)

In [None]:
# Folders to download for every subject.
# Each value is a key prefix relative to "HCP_1200/<subject>/"; only objects
# whose key starts with one of these prefixes are fetched. Entries that are
# commented out are not downloaded.
SERIES_MAP = {
    # 'MEG_unprocessed': 'unprocessed/MEG/',
    '3T_unprocessed_rfMRI1_LR': 'unprocessed/3T/rfMRI_REST1_LR/',
    '3T_unprocessed_rfMRI1_RL': 'unprocessed/3T/rfMRI_REST1_RL/',
    '3T_unprocessed_rfMRI2_LR': 'unprocessed/3T/rfMRI_REST2_LR/',
    '3T_unprocessed_rfMRI2_RL': 'unprocessed/3T/rfMRI_REST2_RL/',
    '3T_unprocessed_T1': 'unprocessed/3T/T1w_MPR1/',
    # '7T_unprocessed': '7T',
    # 'Diffusion': 'Diffusion',
    # 'T1w': 'T1w',
    # 'MNINonLinear': 'MNINonLinear',
    'release-notes': 'release-notes',
    # 'MEG': 'MEG',
    # '.xdlm': '.xdlm',
}

In [None]:
# Log files are written to "<current working directory>/log".
project_path = Path.cwd()
log_path = project_path / "log"
# Root directory that downloaded files are written under.
out_dir = './data/HCP/'
# Date stamp (YYYY_MM_DD) taken once at cell execution; used in log file names.
t = time.strftime("%Y_%m_%d", time.localtime())

In [None]:
def get_file_md5(file):
    """Return the hexadecimal MD5 digest of a file's contents.

    Args:
        file: Path of the file to hash.

    Returns:
        The digest as a hex string, as produced by ``hashlib``'s ``hexdigest()``.

    Raises:
        FileNotFoundError: If ``file`` does not exist or is not a regular file.
    """
    # Path is already imported by the notebook; the previous os.path call
    # raised NameError because ``os`` was never imported.
    if not Path(file).is_file():
        raise FileNotFoundError("The file:%s is not exist! Can't get md5 code!"%file)
    digest = hashlib.md5()
    with open(file, mode='rb') as f:
        # Hash in 1 MiB chunks so the whole file is never held in memory.
        for chunk in iter(lambda: f.read(1024 * 1024), b''):
            digest.update(chunk)
    return digest.hexdigest()

In [None]:
def check_integrity(file,md5_old):
    """Tell whether a file's MD5 digest equals an expected checksum.

    Args:
        file: Path of the file to check.
        md5_old: Expected checksum; compared for exact equality with the
            hex digest returned by ``get_file_md5``.

    Returns:
        True when the two values are equal, otherwise False.
    """
    return bool(get_file_md5(file) == md5_old)

In [None]:
def divide_subject(subject_path, divide_num):
    """Split the subject IDs listed in a text file into ``divide_num`` work lists.

    The file is read as UTF-8 with one subject per line; surrounding
    whitespace is stripped and blank lines are ignored.

    Every subject ends up in exactly one list. The first ``divide_num - 1``
    lists each hold ``len(subjects) // divide_num`` subjects and the last
    list takes everything that is left. Previously, subjects were silently
    dropped whenever the remainder was larger than the chunk size
    (e.g. 7 subjects into 4 lists), or when there were fewer subjects than
    lists.

    Args:
        subject_path: Path of the text file with the subject IDs.
        divide_num: Number of lists to produce; must be at least 1.

    Returns:
        A list of ``divide_num`` lists of subject ID strings, in file order.

    Raises:
        ValueError: If ``divide_num`` is smaller than 1.
    """
    if divide_num < 1:
        raise ValueError("divide_num must be a positive integer, got %r" % (divide_num,))

    # The context manager closes the file even if reading fails.
    with open(subject_path, 'r', encoding='utf-8') as fr:
        stripped_lines = (line.strip() for line in fr)
        subject_list = [name for name in stripped_lines if name]

    size = len(subject_list) // divide_num
    task_subject = [subject_list[size * i:size * (i + 1)] for i in range(divide_num - 1)]
    # The open-ended slice guarantees that no trailing subject is lost.
    task_subject.append(subject_list[size * (divide_num - 1):])
    return task_subject

In [None]:
def _load_md5_index(subject):
    """Read a subject's md5 manifest and return it as a ``{URI: Checksum}`` dict."""
    with open('./data/md5/'+ str(subject) + '_md5.json', 'r', encoding='utf-8') as fp:
        # Only the first line of the manifest is parsed, as before.
        entries = json.loads(fp.readline())['Include']
    # One dict lookup per file replaces a scan of the whole manifest per file.
    # If a URI is listed twice the last entry wins, exactly like the old scan.
    return {item['URI']: item['Checksum'] for item in entries}


def _download_subject(out_dir, subject, s3_bucket, log):
    """Download, verify and report on all selected files of one subject."""
    time_start = time.time()
    md5_index = _load_md5_index(subject)

    # Keep only the keys that live under one of the folders chosen in SERIES_MAP.
    prefixes = tuple("%s/%s/%s" % (s3_prefix, subject, x) for x in SERIES_MAP.values())
    subject_objects = s3_bucket.objects.filter(Prefix='%s/%s/' % (s3_prefix, subject))
    s3_keylist = [obj.key for obj in subject_objects if obj.key.startswith(prefixes)]

    if not os.path.exists(out_dir):
        print('Could not find %s, creating now...' % out_dir)
        log.warning(f'Could not find {out_dir}, creating now...')
    # exist_ok avoids a crash when another thread creates the directory first.
    os.makedirs(out_dir, exist_ok=True)

    total_num_files = len(s3_keylist)
    files_downloaded = 0

    for path_idx, s3_path in enumerate(s3_keylist, start=1):
        # Strip the release prefix only from the start of the key; str.replace
        # would also delete the same text if it appeared later in the key.
        rel_path = s3_path[len(s3_prefix):] if s3_path.startswith(s3_prefix) else s3_path
        rel_path = rel_path.lstrip('/')
        download_file = os.path.join(out_dir, rel_path)
        download_dir = os.path.dirname(download_file)

        try:
            if download_dir:
                os.makedirs(download_dir, exist_ok=True)

            already_there = os.path.exists(download_file)
            if already_there and os.path.getsize(download_file) != 0:
                print('File %s already exists, skipping...' % download_file)
                continue
            if already_there:
                # A zero-byte file is treated as a failed earlier attempt.
                print("%s is empty, download again!" % (s3_path))
                log.error(f"{s3_path} is empty, download again!")

            print('Downloading to: %s' % download_file)
            s3_bucket.download_file(s3_path,download_file)
            # Count only transfers that completed without raising.
            files_downloaded += 1

            # Verify the download against the manifest checksum, if one is listed.
            md5_old = md5_index.get(rel_path)
            if md5_old:
                if not check_integrity(download_file,md5_old):
                    print("Download fail about file: %s"%(s3_path))
                    log.error(f"Download fail about file: {s3_path}")
                    # Record the key so it can be retried later. The mismatching
                    # file is left on disk and will be skipped by a re-run, so
                    # the fail list is the only record of it.
                    os.makedirs('./data/fail/', exist_ok=True)
                    with open('./data/fail/'+str(subject)+'_fail.txt','a',encoding='utf-8') as fail_list:
                        fail_list.write(s3_path + '\n')
            else:
                print("There are not have md5 code about the file: %s" % (s3_path))
                log.error(f"There are have not md5 code about the file: {s3_path}")

            complete_percent = (100*(float(path_idx)/total_num_files))
            print("FACTS: path: %s, file: %s"%(s3_path, download_file))
            print('%.3f%% percent complete' % complete_percent)
            # Progress goes to the log file only every 500 files to keep it small.
            if path_idx%500 == 0:
                log.info(f"{complete_percent} percent complete")
        except Exception as exc:
            # Best effort: report the failure and carry on with the next file.
            print('There was a problem downloading %s.\n'\
                  'Check and try again.' % s3_path)
            log.error(f"There was a problem downloading {s3_path}. Check and try again.")
            print(exc)
            log.error(exc)

    print('%d files downloaded for subject %s.' % (files_downloaded,subject))
    log.info(f"{files_downloaded} files downloaded for subject {subject}.")

    print('Done!')
    log.info("DOne!")

    time_cost = (time.time()-time_start)/3600
    print("Time cost of downloading {} is :{} h".format(subject,time_cost))
    log.info(f"Time cost of downloading {subject} is :{time_cost} h")


def collect_and_download(out_dir,subjects):
    """Download the SERIES_MAP folders of every given subject from the HCP bucket.

    For each subject the function lists the matching S3 keys, downloads every
    file that is missing or empty under ``out_dir``, checks it against the
    checksum in ``./data/md5/<subject>_md5.json`` and appends keys whose
    checksum does not match to ``./data/fail/<subject>_fail.txt``. Progress
    and errors are printed and written to per-subject log files in ``log_path``.

    The function is meant to be run from several threads at once, so it opens
    its own boto3 session and resource instead of using the module-level
    ``bucket``: boto3 resources are documented as not thread safe.

    Args:
        out_dir: Directory that downloaded files are written under.
        subjects: Iterable of subject IDs to download.

    Raises:
        OSError, KeyError, ValueError: If a subject's md5 manifest is missing
            or malformed. Errors for individual files are logged, not raised.
    """
    session = boto3.session.Session()
    s3_bucket = session.resource('s3', aws_access_key_id=access_key,
                                 aws_secret_access_key=secret_key).Bucket(s3_bucket_name)

    for subject in subjects:
        # Tag this subject's records so its log files receive only its own
        # messages, even while other threads log at the same time.
        log = logger.bind(subject=subject)

        def only_this_subject(record, _subject=subject):
            return record["extra"].get("subject") == _subject

        sink_ids = []
        try:
            sink_ids.append(logger.add(f'{log_path}/{subject}_info_{t}.log', rotation="500MB", encoding="utf-8", enqueue=True, compression="zip",
                                       retention="10 days", level="INFO", filter=only_this_subject))
            sink_ids.append(logger.add(f'{log_path}/{subject}_error_{t}.log', rotation="500MB", encoding="utf-8", enqueue=True, compression="zip",
                                       retention="10 days",level="ERROR", filter=only_this_subject))
            _download_subject(out_dir, subject, s3_bucket, log)
        finally:
            # Remove the sinks again; otherwise they pile up, two per subject.
            for sink_id in sink_ids:
                logger.remove(sink_id)


In [None]:
# Split the subject IDs listed in ./utils/subjects_want.txt into 4 work lists.
sl=divide_subject("./utils/subjects_want.txt", 4)
# Show the third work list as this cell's output.
sl[2]

In [None]:
from threading import Thread
from time import sleep, ctime

# One worker thread per work list in `sl`, so the thread count always follows
# the number of lists produced by divide_subject instead of being fixed at 4.
download_threads = [
    Thread(target=collect_and_download, args=(out_dir, subject_chunk))
    for subject_chunk in sl
]

# Start every worker before waiting, so that the downloads run concurrently.
for worker in download_threads:
    worker.start()
for worker in download_threads:
    worker.join()