In [2]:
import os
%pwd

'/media/kirti/Dev/DeepLearning/Project/E2E/ChestCancerDetection/research'

In [3]:
os.chdir('../')
%pwd    

'/media/kirti/Dev/DeepLearning/Project/E2E/ChestCancerDetection'

In [5]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(frozen=True)
class ModelInferenceConfig:
    root_dir: Path
    source_url: str
    local_model_file: Path
    model_dir: Path

In [8]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_directories

class ConfigManager:
    def __init__(self, config_file_path=CONFIG_FILE_PATH, params_file_path=PARAMS_FILE_PATH):
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)
        print(self.config.artifacts_root)
        create_directories([self.config.artifacts_root]) 
    
    def get_model_inference_config(self) -> ModelInferenceConfig:
        config = self.config.model_inference
        create_directories([config.root_dir])
        return ModelInferenceConfig(
            root_dir=config.root_dir,
            source_url=config.source_url,
            local_model_file=config.local_model_file,
            model_dir=config.model_dir
        )

In [9]:
import os
import zipfile
import gdown
from cnnClassifier import logger
from cnnClassifier.utils.common import get_file_size

In [10]:
class ModelInference:
    def __init__(self,config: ModelInferenceConfig) -> None:
        self.config = config

    def download_model(self):
        
        '''
        Downloads model from the source URL to the local data file path.
        If the file already exists, it skips the download.
        '''

        try:
            model_url = self.config.source_url
            model_path = self.config.local_model_file
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            logger.info(f"Downloading model from {model_url}")

            file_id = model_url.split('/')[-2]
            prefix = f"https://drive.google.com/uc?export=download&id={file_id}"
            if os.path.exists(model_path):
                file_size = get_file_size(Path(model_path))
                if file_size > 0:
                    logger.info(f"File already exists at {model_path} with size {file_size} bytes. Skipping download.")
                    return
            
            gdown.download(prefix, model_path, quiet=False)

            logger.info(f"Downloaded data to {model_path}")

        except Exception as e:
            logger.info(f"Error downloading data: {e}")
            raise e
    

In [14]:
try:
    config_manager = ConfigManager()
    data_ingestion_config = config_manager.get_model_inference_config()
    data_ingestion = ModelInference(config=data_ingestion_config)

    data_ingestion.download_model()
    logger.info("Model downloaded successfully.")
except Exception as e:
    logger.error(f"Error while downloading model : {e}")
    raise e

artifacts
[2025-07-28 22:46:50,703|(INFO)| File: common | Message: Created directory: artifacts]
[2025-07-28 22:46:50,729|(INFO)| File: common | Message: Created directory: artifacts/model_inference]
[2025-07-28 22:46:50,745|(INFO)| File: 1374819729 | Message: Downloading model from https://drive.google.com/file/d/1wDxN-Fmsd3oGFJT2a__38iK_gtJfhFSo/view?usp=sharing]
[2025-07-28 22:46:50,756|(INFO)| File: 1374819729 | Message: File already exists at artifacts/model_inference/model.pth with size 162722852 bytes. Skipping download.]
[2025-07-28 22:46:50,777|(INFO)| File: 145245731 | Message: Model downloaded successfully.]


In [13]:
import torch
model = torch.load(data_ingestion_config.local_model_file,
                   map_location=torch.device('cpu'),
                   weights_only=False)
print(type(model))

<class 'torchvision.models.vgg.VGG'>
