In [1]:
import os

In [2]:
# Move the working directory up one level so the relative paths used further
# down ("config/config.yaml", "params.yaml") and the `src.*` imports resolve.
# NOTE(review): this is relative to the current directory, so re-running the
# cell without restarting the kernel moves up one more level each time.
os.chdir("../")

In [57]:
from dataclasses import dataclass, field
from pathlib import Path


@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings bundle for the "prepare base model" stage.

    Instances are frozen: assigning to any attribute after construction
    raises ``dataclasses.FrozenInstanceError``.
    """

    # --- filesystem locations -------------------------------------------
    root_dir: Path                  # directory that holds this stage's artifacts
    base_model_path: Path           # target path for the base model
    updated_base_model_path: Path   # target path for the model with the new head

    # --- hyper-parameters -------------------------------------------------
    params_image_size: list         # image size specification
    params_learning_rate: float     # learning rate handed to the optimizer
    params_include_top: bool        # keep the backbone's own classification head?
    params_weights: str             # pretrained-weights selector
    params_classes: int             # number of output classes

    # --- layer freezing (optional, defaults below) ------------------------
    params_freeze_all: bool = True
    # NOTE(review): ConfigurationManager passes None here when the params file
    # has no FREEZE_TILL key, so callers must tolerate None despite the `int`
    # annotation.
    params_freeze_till: int = 2

In [58]:
from src.constants import *
from src.utils.auxiliary_functions import read_yaml, create_directories
#import tensorflow as tf
from pathlib import Path
from typing import Any

In [59]:
# Locations of the project's YAML files, relative to the current working
# directory. They serve as the default arguments of ConfigurationManager.
# NOTE(review): `from src.constants import *` above may already export names
# like these; the assignments here take precedence in this notebook — confirm
# the two definitions agree.
CONFIG_FILE_PATH = Path("config/config.yaml")
PARAMS_FILE_PATH = Path("params.yaml")

In [60]:
class ConfigurationManager:
    """Reads the project's YAML config/params files and builds stage configs."""

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Load both YAML files and create the artifacts root directory.

        Args:
            config_filepath: Path of the main configuration YAML.
            params_filepath: Path of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Assemble the settings for the "prepare base model" stage.

        Creates the stage's root directory as a side effect.

        Returns:
            A PrepareBaseModelConfig populated from the ``prepare_base_model``
            section of the config file and the upper-case keys of the params
            file. ``params_freeze_till`` is ``None`` when the params file has
            no ``FREEZE_TILL`` key.
        """
        stage = self.config.prepare_base_model
        hyper = self.params

        # The stage directory must exist before anything is written into it.
        create_directories([stage.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(stage.root_dir),
            base_model_path=Path(stage.base_model_path),
            updated_base_model_path=Path(stage.updated_base_model_path),
            params_image_size=hyper.IMAGE_SIZE,
            params_learning_rate=hyper.LEARNING_RATE,
            params_include_top=hyper.INCLUDE_TOP,
            params_weights=hyper.WEIGHTS,
            params_classes=hyper.CLASSES,
            params_freeze_all=hyper.FREEZE_ALL,
            # FREEZE_TILL is optional; a missing key is passed on as None.
            params_freeze_till=hyper.get('FREEZE_TILL', None),
        )


In [61]:
import torch
import torchvision.models as models
from torch import nn, optim

class PrepareBaseModel(nn.Module):
    """Builds a ResNet101-based classifier together with its optimizer and loss.

    On construction the backbone is loaded (`get_base_model`) and a new
    classification head is attached (`update_base_model`). The results are
    stored as `self.model`, `self.full_model`, `self.optimizer` and
    `self.loss_fn`.

    The classifier returns raw logits: `nn.CrossEntropyLoss` applies
    log-softmax internally, so no softmax layer is placed in the model. Apply
    `torch.softmax(logits, dim=1)` at inference time if probabilities are
    needed.

    Args:
        config: Object exposing `params_weights`, `params_include_top`,
            `params_classes`, `params_freeze_all`, `params_freeze_till` and
            `params_learning_rate` (e.g. a PrepareBaseModelConfig).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.params_freeze_all = self.config.params_freeze_all
        self.params_freeze_till = self.config.params_freeze_till
        self.model = self.get_base_model()
        self.full_model, self.optimizer, self.loss_fn = self.update_base_model()

    def get_base_model(self):
        """Load ResNet101, optionally stripped of its pooling and FC layers.

        Returns:
            The full torchvision ResNet101 when `params_include_top` is true;
            otherwise an `nn.Sequential` of all its children except the last
            two (global average pool and fully connected layer), which yields
            a spatial feature map.
        """
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; the accepted values differ, so the call is left
        # unchanged until the content of `params_weights` is confirmed.
        model = models.resnet101(pretrained=self.config.params_weights)

        if not self.config.params_include_top:
            model = nn.Sequential(*list(model.children())[:-2])
        return model

    def _prepare_full_model(self, model, classes, freeze_all, freeze_till, learning_rate):
        """Freeze backbone layers, attach a new head, and build optimizer/loss.

        Args:
            model: Backbone. Either an `nn.Sequential` feature extractor that
                outputs a (N, 2048, H, W) feature map, or a module with an
                `fc` attribute to be replaced.
            classes: Number of output classes of the new linear head.
            freeze_all: If true, every backbone parameter is frozen.
            freeze_till: Used only when `freeze_all` is false. When not None
                and positive, the first `freeze_till - 1` direct children of
                `model` are frozen.
                NOTE(review): confirm this is the intended meaning of "till".
            learning_rate: Learning rate for the SGD optimizer.

        Returns:
            Tuple `(model, optimizer, loss_fn)`; the model outputs logits.
        """
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif freeze_till is not None and freeze_till > 0:
            for position, child in enumerate(model.children(), start=1):
                if position >= freeze_till:
                    break
                for param in child.parameters():
                    param.requires_grad = False

        if isinstance(model, nn.Sequential):
            in_features = 2048  # channel count of ResNet101's final feature map
            model = nn.Sequential(
                model,
                # The stripped backbone has no pooling, so its output is
                # (N, 2048, H, W). Pool to 1x1 before flattening; otherwise the
                # flattened size is 2048*H*W and the linear layer rejects it.
                # Both steps are nested in one sub-module so the linear head
                # stays at index 2 and state_dict keys remain "2.weight" and
                # "2.bias".
                nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten()),
                nn.Linear(in_features, classes),
            )
        else:
            # Swap the stock classifier for one with the requested class count.
            # (A softmax registered via add_module would never be invoked by
            # the model's own forward, so none is added.)
            num_features = model.fc.in_features
            model.fc = nn.Linear(num_features, classes)

        optimizer = optim.SGD(model.parameters(), lr=learning_rate)
        # CrossEntropyLoss expects unnormalised logits, which is what the
        # model produces.
        loss_fn = nn.CrossEntropyLoss()
        return model, optimizer, loss_fn

    def update_base_model(self):
        """Attach a fresh head to `self.model` using the stored configuration.

        Each call builds a new head and a new optimizer and returns them; it
        does not reassign `self.full_model`, `self.optimizer` or `self.loss_fn`.

        Returns:
            Tuple `(full_model, optimizer, loss_fn)`.
        """
        full_model, optimizer, loss_fn = self._prepare_full_model(
            model=self.model,
            classes=self.config.params_classes,
            freeze_all=self.config.params_freeze_all,
            freeze_till=self.config.params_freeze_till,
            learning_rate=self.config.params_learning_rate
        )
        return full_model, optimizer, loss_fn

    @staticmethod
    def save_model(path, model):
        """Write `model.state_dict()` to `path` with `torch.save`."""
        torch.save(model.state_dict(), path)

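# Example usage
# `config` can be any object exposing the `params_*` attributes read above,
# e.g. the PrepareBaseModelConfig built by ConfigurationManager:
# config = ConfigurationManager().get_prepare_base_model_config()
# prepare_base_model = PrepareBaseModel(config=config)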

In [62]:
# Build the stage configuration and the model wrapper.
#
# PrepareBaseModel.__init__ already loads the backbone and attaches the new
# head, so get_base_model() / update_base_model() are not called again here:
# a second get_base_model() reloads ResNet101 only to discard it, and a second
# update_base_model() rebuilds the head while `prepare_base_model.optimizer`
# keeps tracking the parameters of the previous one.
#
# Exceptions propagate unchanged; wrapping the cell in
# `try: ... except Exception as e: raise e` added nothing.
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model_config()
prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)

[2023-12-04 20:44:32,879: INFO: auxiliary_functions: yaml file: config/config.yaml loaded successfully]
[2023-12-04 20:44:32,880: INFO: auxiliary_functions: The path is: config/config.yaml]
[2023-12-04 20:44:32,882: INFO: auxiliary_functions: The content is: {'artifacts_root': 'artifacts', 'data_ingestion': {'root_dir': 'artifacts', 'local_data_file': 'artifacts/train.zip', 'local_source_file': 'artifacts/train.zip', 'unzip_dir': 'artifacts'}, 'prepare_base_model': {'root_dir': 'artifacts/prepare_base_model', 'base_model_path': 'artifacts/prepare_base_model/base_model.h5', 'updated_base_model_path': 'artifacts/prepare_base_model/base_model_updated.h5'}, 'prepare_callbacks': {'root_dir': 'artifacts/prepare_callbacks', 'tensorboard_root_log_dir': 'artifacts/prepare_callbacks/tensorboard_log_dir', 'checkpoint_model_filepath': 'artifacts/prepare_callbacks/checkpoint_dir/model.h5'}, 'training': {'root_dir': 'artifacts/training', 'trained_model_path': 'artifacts/training/model.h5'}}]
[2023-1