# <font color='orange'>Pull artifacts from the S3 bucket</font>

In [1]:
import os

# Move up one directory so the relative paths used in later cells resolve from
# the parent directory (presumably the project root — confirm the notebook lives
# one level below it). The guard makes this cell idempotent: without it, every
# re-run would climb one more level up the directory tree.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
    os.chdir('../')

In [10]:
from dataclasses import dataclass, field
from pathlib import Path


@dataclass(frozen=True)
class PullArtifactsConfig:
    """Immutable settings needed to download the trained model from S3.

    Attributes:
        trained_model_path: Local file path the downloaded object is saved to.
        access_key_id: AWS access key ID used to authenticate to S3.
        secret_access_key: AWS secret access key used to authenticate to S3.
        region: AWS region name used when creating the S3 resource.
        bucket_name: Name of the S3 bucket that holds the model object.
        object_key_name: Key of the model object inside the bucket.
    """

    trained_model_path: Path
    # Credentials are excluded from the auto-generated __repr__ so that printing
    # or logging this config (e.g. as a notebook cell output) does not leak them.
    access_key_id: str = field(repr=False)
    secret_access_key: str = field(repr=False)
    region: str
    bucket_name: str
    object_key_name: str

In [11]:
from lung_cancer_classifier.constants import *
from lung_cancer_classifier.utils.common import read_yaml, create_directories
from pathlib import Path


class ConfigurationManager:
    """Load the project config and secrets YAML files and build stage configs."""

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        secrets_file_path=SECRETS_FILE_PATH,
    ):
        """Parse both YAML files with ``read_yaml`` and keep the results.

        Args:
            config_filepath: Location of the main configuration YAML.
            secrets_file_path: Location of the YAML holding credentials.
        """
        parsed_config = read_yaml(config_filepath)
        self.config = parsed_config
        parsed_secrets = read_yaml(secrets_file_path)
        self.secrets = parsed_secrets

    def get_pull_artifacts_config(self) -> PullArtifactsConfig:
        """Build the settings used to pull the trained model from S3.

        Reads the ``aws`` section of the secrets and the ``training`` section
        of the config, and calls ``create_directories`` on the training root
        directory (presumably creating it if missing) before returning.

        Returns:
            A PullArtifactsConfig combining the local model path with the
            AWS credentials, region, bucket, and object key.
        """
        aws = self.secrets.aws
        training_section = self.config.training

        create_directories([Path(training_section.root_dir)])

        return PullArtifactsConfig(
            trained_model_path=Path(training_section.trained_model_path),
            access_key_id=aws.ACCESS_KEY_ID,
            secret_access_key=aws.SECRET_ACCESS_KEY,
            region=aws.REGION,
            bucket_name=aws.BUCKET_NAME,
            object_key_name=aws.OBJECT_KEY_NAME,
        )

In [12]:
import boto3


class PullArtifacts:
    """Download the trained model artifact from an S3 bucket to local disk."""

    def __init__(self, config: PullArtifactsConfig):
        """Keep the S3 and local-path settings used by ``download_model``.

        Args:
            config: Credentials, region, bucket/key, and local destination.
        """
        self.config = config

    def download_model(self):
        """Download ``config.object_key_name`` from ``config.bucket_name``.

        The object is saved to ``config.trained_model_path``. The destination's
        parent directory is created first if it does not already exist.
        """
        destination = Path(self.config.trained_model_path)
        # The download cannot be written if the target directory is missing, and
        # the directory created by the config stage may not be this file's parent.
        destination.parent.mkdir(parents=True, exist_ok=True)

        s3_resource = boto3.resource(
            service_name='s3',
            region_name=self.config.region,
            aws_access_key_id=self.config.access_key_id,
            aws_secret_access_key=self.config.secret_access_key,
        )

        # Pass a plain str: older boto3/s3transfer releases only accept str
        # filenames and reject pathlib.Path objects.
        s3_resource.Bucket(self.config.bucket_name).download_file(
            Key=self.config.object_key_name,
            Filename=str(destination),
        )

In [None]:
# Build the pull-artifacts configuration and component. Exceptions propagate
# directly with their original traceback; no try/except wrapper is needed.
config = ConfigurationManager()
pull_artifacts_config = config.get_pull_artifacts_config()
pull_artifacts = PullArtifacts(config=pull_artifacts_config)
# NOTE: the S3 download is intentionally disabled here; uncomment the line
# below to actually fetch the trained model.
# pull_artifacts.download_model()