In [1]:
from huggingface_hub import snapshot_download
import boto3

from dataclasses import dataclass
from typing import Optional

import os
import time
from dotenv import load_dotenv

In [2]:
# load_dotenv() (python-dotenv) runs first so that values from a .env file,
# if one is present, are visible in the process environment before the reads below.
load_dotenv()

# Each of these is None when the corresponding variable is not set.
AWS_ACCESS_KEY_ID = os.environ.get('aws_access_key_id')
AWS_SECRET_ACCESS_KEY = os.environ.get('aws_secret_access_key')
ENDPOINT_URL = os.environ.get('endpoint_url')

In [3]:
@dataclass
class Config:
    """Settings for the model to fetch and the DeepSpeed batch-inference job.

    In the notebook cells visible here only ``model_name`` and ``cache_dir``
    are read; the remaining fields are presumably consumed elsewhere.
    """

    # HuggingFace repo id of the model. A larger alternative that was used
    # here before is "facebook/opt-30b".
    model_name: str = "gpt2"
    # HuggingFace cache directory. None would presumably fall back to the
    # library default (~/.cache/huggingface/); here it points at a local
    # test cache instead.
    cache_dir: Optional[str] = "/data/modelcache/test/hub"
    # Directory that actually holds the model files,
    # e.g. ~/.cache/huggingface/models--facebook--opt-30b/snapshots/xxx/
    # When this is not None, downloading the model from HuggingFace is skipped.
    repo_root: Optional[str] = None
    # How many DeepSpeed-inference replicas to run for this batch
    # inference job.
    num_worker_groups: int = 1
    # How many DeepSpeed workers make up each group.
    num_workers_per_group: int = 8

    batch_size: int = 1
    dtype: str = "float16"
    # Upper bound on the tokens the DeepSpeed inference engine can work
    # with, counting input and output tokens together.
    max_tokens: int = 1024
    # Initialize the model with meta tensors.
    use_meta_tensor: bool = True
    # Enable the cache during generation.
    use_cache: bool = True
    # Path under which the loaded model should be saved as a checkpoint.
    save_mp_checkpoint_path: Optional[str] = None


# Single shared instance carrying the defaults above.
config = Config()

In [4]:
#load_dotenv()

#AWS_ACCESS_KEY_ID = os.getenv('aws_access_key_id')
#AWS_SECRET_ACCESS_KEY = os.getenv('aws_secret_access_key')

#ENDPOINT_URL = 'http://10.0.0.179:30387'

In [5]:
# S3 key prefix for the configured model: "hub/models--" followed by the repo
# id with every "/" turned into "--" (e.g. "facebook/opt-30b" gives
# "hub/models--facebook--opt-30b").
model_folder = "hub/models--" + "--".join(config.model_name.split("/"))
model_folder

'hub/models--gpt2'

In [6]:

# boto3 S3 resource used by the helper functions in this notebook. The endpoint
# and both credentials come from the environment-derived constants; any of
# them may be None if the corresponding variable was not set.
s3 = boto3.resource(
    's3',
    endpoint_url=ENDPOINT_URL,
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
)

In [7]:
def list_model_folder(bucket_name, s3_folder):
    """
    Print every object in a bucket whose key starts with a given prefix.

    Uses the module-level ``s3`` resource.

    Args:
        bucket_name: the name of the s3 bucket to inspect
        s3_folder: the key prefix (folder path) to list inside that bucket

    Returns:
        None. Each matching object is printed, one per line.
    """
    # Honour the arguments: the bucket and prefix used to be hard-wired to
    # 'models' and the global `model_folder`, so the parameters were ignored.
    bucket = s3.Bucket(bucket_name)
    for obj in bucket.objects.filter(Prefix=s3_folder):
        print(obj)
        
#list_model_folder('models', model_folder)        

In [8]:
import os
def download_s3_folder(bucket_name, s3_folder, local_dir=None):
    """
    Download the contents of a folder directory

    Uses the module-level ``s3`` resource. Objects whose local target already
    exists are not downloaded again. One progress line is printed per object
    ('downloading: <path>' or 'exists: <path>').

    Args:
        bucket_name: the name of the s3 bucket
        s3_folder: the folder path in the s3 bucket
        local_dir: a relative or absolute directory path in the local file system.
            When None, each object is written to a path equal to its key,
            relative to the current working directory.

    Returns:
        None.
    """
    # The S3 prefix filter is a plain string match, so 'a/models--gpt2' also
    # matches 'a/models--gpt2-large/...'. Such sibling keys are not part of the
    # requested folder, and os.path.relpath would map them to '../...', i.e.
    # outside local_dir, so they are skipped below.
    folder = s3_folder.rstrip('/')
    bucket = s3.Bucket(bucket_name)
    for obj in bucket.objects.filter(Prefix=s3_folder):
        key = obj.key
        if folder and key != folder and not key.startswith(folder + '/'):
            continue
        if local_dir is None:
            target = key
        else:
            target = os.path.join(local_dir, os.path.relpath(key, s3_folder))
        # dirname is '' for a root-level key with no local_dir; makedirs('')
        # would raise, and there is nothing to create in that case anyway.
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # Keys ending in '/' are folder placeholders, not files.
        if key.endswith('/'):
            continue
        if os.path.exists(target):
            print('exists: '+target)
        else:
            # Announce before the transfer so progress is visible while a
            # large object is still being fetched.
            print('downloading: '+target)
            bucket.download_file(key, target)

In [9]:
# Measure the elapsed download time with perf_counter(): unlike time.time()
# it is a monotonic clock, so the interval is not distorted if the system
# wall clock is adjusted while the download runs.
# NOTE(review): download_s3_folder strips the `model_folder` prefix from each
# key, so files land directly under config.cache_dir (e.g. <cache_dir>/blobs/...)
# without a models--<name> directory level. Confirm this is the layout the
# downstream model loader expects.
start = time.perf_counter()
download_s3_folder('models', model_folder, config.cache_dir)
end = time.perf_counter()
print(f'model downloaded in: {end-start}')

downloading: /data/modelcache/test/hub/blobs/10c66461e4c109db5a2196bff4bb59be30396ed8
downloading: /data/modelcache/test/hub/blobs/1ceafd82e733dd4b21570b2a86cf27556a983041806c033a55d086e0ed782cd3
downloading: /data/modelcache/test/hub/blobs/1f1d9aaca301414e7f6c9396df506798ff4eb9a6
downloading: /data/modelcache/test/hub/blobs/226b0752cac7789c48f0cb3ec53eda48b7be36cc
downloading: /data/modelcache/test/hub/blobs/3dc481ecc3b2c47a06ab4e20dba9d7f4b447bdf3
downloading: /data/modelcache/test/hub/blobs/4b988bccc9dc5adacd403c00b4704976196548f8
downloading: /data/modelcache/test/hub/blobs/602b71f15d40ed68c5f96330e3f3175a76a32126
downloading: /data/modelcache/test/hub/blobs/7c5d3f4b8b76583b422fcb9189ad6c89d5d97a094541ce8932dce3ecabde1421
downloading: /data/modelcache/test/hub/blobs/a16a55fda99d2f2e7b69cce5cf93ff4ad3049930
downloading: /data/modelcache/test/hub/blobs/adf0adedbf4016b249550f866c66a3b3a3d09c8b3b3a1f6e5e9a265d94e0270e
downloading: /data/modelcache/test/hub/blobs/c966da3b74697803352ca7c