In [1]:
import os

In [2]:
%pwd

'/Users/arash/ML_End_to_End_Pj/end-to-end-solar-dust-detection/research'

In [3]:
os.chdir("../")

In [4]:
%pwd

'/Users/arash/ML_End_to_End_Pj/end-to-end-solar-dust-detection'

## Update Entity

In [5]:
from dataclasses import dataclass
from pathlib import Path
from typing import List

@dataclass(frozen=True)
class BaseModelConfig:
    """Immutable bundle of paths and hyper-parameters for base-model preparation.

    The dataclass is frozen, so fields cannot be reassigned after construction.
    """

    # Directory that holds the base-model artifacts; ConfigurationManager
    # creates it before building this config.
    root_dir: Path
    # File the freshly constructed (unmodified) model's state dict is saved to.
    base_model_path: Path
    # File the model's state dict is saved to after its final layer is replaced.
    updated_base_model_path: Path

    # Model parameters
    # Not read by the code in this notebook section; the logged value is
    # [224, 224, 3] -- presumably (height, width, channels), TODO confirm.
    params_image_size: List[int]
    # Not read by the code in this notebook section; presumably consumed by a
    # later training stage -- TODO confirm.
    params_learning_rate: float
    # "imagenet" selects pretrained ImageNet weights; any other value means
    # the network is built without pretrained weights.
    params_weights: str
    # Output size of the replacement fully connected layer.
    params_classes: int

## Update Configuration Manager

In [6]:
from solar_dust_detection.constants import *
from solar_dust_detection.utils.common import read_yaml, create_directories

In [7]:
class ConfigurationManager:
    """Reads the project's YAML settings and hands out typed config objects."""

    def __init__(
        self,
        config_filepath: Path = CONFIG_FILE_PATH,
        params_filepath: Path = PARAMS_FILE_PATH,
    ):
        """Load both YAML files and make sure the artifacts root exists.

        Args:
            config_filepath: Location of the main configuration YAML.
            params_filepath: Location of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifacts_root = Path(self.config.artifacts_root)
        create_directories([artifacts_root])

    def get_base_model_config(self) -> BaseModelConfig:
        """Assemble the base-model settings from the loaded YAML content.

        The base-model root directory is created before the config object
        is built.

        Returns:
            A ``BaseModelConfig`` combining the ``base_model`` section of the
            configuration file with the values from the parameters file.
        """
        section = self.config.base_model
        hyperparams = self.params

        root_dir = Path(section.root_dir)
        create_directories([root_dir])

        return BaseModelConfig(
            root_dir=root_dir,
            base_model_path=Path(section.base_model_path),
            updated_base_model_path=Path(section.updated_base_model_path),
            params_image_size=hyperparams.IMAGE_SIZE,
            params_learning_rate=hyperparams.LEARNING_RATE,
            params_weights=hyperparams.WEIGHTS,
            params_classes=hyperparams.CLASSES,
        )

## Update Components

In [10]:
import os
import urllib.request as request
from zipfile import ZipFile
import torch
import torchvision.models as models
import torch.nn as nn
from solar_dust_detection import logger

In [None]:
class BaseModel:
    """Builds a ResNet-18 network and adapts its final layer to the task.

    Two state dicts are written to disk: one for the network as constructed,
    and one after the backbone is frozen and the fully connected layer is
    replaced.
    """

    def __init__(self, config: "BaseModelConfig"):
        """Store the configuration that drives model construction and saving."""
        self.config = config

    def get_base_model(self) -> nn.Module:
        """Create, persist, and adapt the ResNet-18 base model.

        Steps:
            1. Build ResNet-18, with ImageNet weights when
               ``config.params_weights == "imagenet"`` and without pretrained
               weights for any other value.
            2. Save its state dict to ``config.base_model_path``.
            3. Set ``requires_grad = False`` on every existing parameter.
            4. Replace ``fc`` with a new ``nn.Linear`` producing
               ``config.params_classes`` outputs.
            5. Save the adapted state dict to
               ``config.updated_base_model_path``.

        Returns:
            The adapted model (also available as ``self.model``).
        """
        if self.config.params_weights == "imagenet":
            weights = models.ResNet18_Weights.IMAGENET1K_V1
        else:
            weights = None

        self.model = models.resnet18(weights=weights)
        self.save_model(path=self.config.base_model_path, model=self.model)
        # Report the file that was actually written, not just its directory.
        logger.info(f"Base model saved at {self.config.base_model_path}")

        # Freeze every parameter that exists at this point.
        for param in self.model.parameters():
            param.requires_grad = False

        # The replacement layer is created after the freeze loop, so the loop
        # above never touches its parameters.
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, self.config.params_classes)

        self.save_model(path=self.config.updated_base_model_path, model=self.model)

        logger.info(f"Updated model (classes={self.config.params_classes}) saved to {self.config.updated_base_model_path}")

        return self.model

    @staticmethod
    def save_model(path: Path, model: nn.Module):
        """Write ``model``'s state dict to ``path``.

        Missing parent directories are created first so that saving to a
        location outside an already-prepared directory does not fail.

        Args:
            path: Destination file for the serialized state dict.
            model: Module whose ``state_dict()`` is saved.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), target)

In [12]:
# Build the configuration, then create and save the base and updated models.
try:
    config_manager = ConfigurationManager()
    base_model_config = config_manager.get_base_model_config()

    logger.info(f"Base model config: {base_model_config}")

    base_model = BaseModel(config=base_model_config)
    base_model.get_base_model()
except Exception as e:
    # Log the failure, then re-raise the active exception unchanged so the
    # cell still fails visibly; a bare `raise` is the idiomatic re-raise.
    logger.exception(e)
    raise

[2026-01-24 11:50:45,885: INFO: common]: YAML file: config/config.yaml loaded successfully
[2026-01-24 11:50:45,895: INFO: common]: YAML file: params.yaml loaded successfully
[2026-01-24 11:50:45,899: INFO: common]: Created directory at: artifacts
[2026-01-24 11:50:45,900: INFO: common]: Created directory at: artifacts/base_model
[2026-01-24 11:50:45,902: INFO: 3036805536]: Base model config: BaseModelConfig(root_dir=PosixPath('artifacts/base_model'), base_model_path=PosixPath('artifacts/base_model/base_model.pt'), updated_base_model_path=PosixPath('artifacts/base_model/updated_base_model.pt'), params_image_size=BoxList([224, 224, 3]), params_learning_rate=0.01, params_weights='imagenet', params_classes=2)
[2026-01-24 11:50:46,535: INFO: 3156017278]: Base model and updated model are saved at artifacts/base_model
[2026-01-24 11:50:46,709: INFO: 3156017278]: Base model and updated model are saved at artifacts/base_model
[2026-01-24 11:50:46,709: INFO: 3156017278]: Updated model (classes=