In [1]:
import os


In [2]:
# Move the kernel's working directory up one level.
# NOTE(review): presumably this puts the kernel at the project root so the
# relative config/artifact paths used by later cells resolve -- confirm.
# The path is relative to the *current* directory, so re-running this cell
# climbs one more level each time; run it once per kernel session.
os.chdir("../")

In [3]:
%pwd

'c:\\Ankan\\M.Tech\\2024\\Term3\\Industrial_AI_Scale\\MLOPS_Project\\CH24M512-MLOPS-Pipeline'

In [5]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings bundle for the "prepare base model" stage.

    The dataclass is frozen, so fields cannot be reassigned after the
    instance is built. In this notebook it is produced by
    ``ConfigurationManager.get_prepare_base_model_config`` and consumed by
    ``PrepareBaseModel``.
    """

    # Stage directory; ConfigurationManager creates it before building this object.
    root_dir: Path
    # Destination of the state_dict written by PrepareBaseModel.get_base_model.
    base_model_path: Path
    # Destination of the state_dict written by PrepareBaseModel.update_base_model.
    updated_base_model_path: Path
    # Filled from params.IMAGE_SIZE; not read by PrepareBaseModel in this notebook.
    params_image_size: list
    # Learning rate handed to the SGD optimizer.
    params_learning_rate: float
    # When false, the VGG16 classifier head is replaced with nn.Identity.
    params_include_top: bool
    # "imagenet" selects the pretrained torchvision weights; any other value loads none.
    params_weights: str
    # Number of outputs of the replacement classification layer.
    params_classes: int

In [6]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_directories

In [9]:
class ConfigurationManager:
    """Loads the config and params YAML files and hands out stage configs.

    Both files are parsed with ``read_yaml`` at construction time, and the
    directory named by ``config.artifacts_root`` is created immediately.
    """

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Parse both YAML files and create the artifacts root directory.

        Args:
            config_filepath: Location of the main configuration YAML.
            params_filepath: Location of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Assemble the settings for the "prepare base model" stage.

        Creates the stage's ``root_dir`` as a side effect.

        Returns:
            A ``PrepareBaseModelConfig`` whose paths come from the
            ``prepare_base_model`` section of the config file and whose
            ``params_*`` fields come from the params file.
        """
        stage_section = self.config.prepare_base_model
        hyperparams = self.params

        # The stage directory must exist before anything is saved into it.
        create_directories([stage_section.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(stage_section.root_dir),
            base_model_path=Path(stage_section.base_model_path),
            updated_base_model_path=Path(stage_section.updated_base_model_path),
            params_image_size=hyperparams.IMAGE_SIZE,
            params_learning_rate=hyperparams.LEARNING_RATE,
            params_include_top=hyperparams.INCLUDE_TOP,
            params_weights=hyperparams.WEIGHTS,
            params_classes=hyperparams.CLASSES,
        )

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from pathlib import Path


class PrepareBaseModel:
    """Builds a VGG16 base model and a variant with a fresh classification head.

    Intended call order is ``get_base_model()`` followed by
    ``update_base_model()``; the second step reads ``self.model``, which only
    the first step sets. Each step writes the model's ``state_dict`` to the
    corresponding path in ``config``.

    Notes:
        * ``update_base_model`` modifies ``self.model`` in place, so afterwards
          ``self.full_model`` and ``self.model`` are the same module object.
        * ``self.device`` is recorded, but no method of this class moves the
          model onto it.
    """

    def __init__(self, config):
        """Store the stage config and pick CUDA when available, else CPU.

        Args:
            config: Object exposing the ``PrepareBaseModelConfig`` fields
                read by the methods below.
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def get_base_model(self):
        """Instantiate VGG16, optionally strip its head, and save it.

        Pretrained ``IMAGENET1K_V1`` weights are requested only when
        ``config.params_weights == "imagenet"``; any other value builds the
        network without pretrained weights.
        """
        self.model = models.vgg16(
            weights="IMAGENET1K_V1" if self.config.params_weights == "imagenet" else None
        )

        # No input shape is passed here; config.params_image_size is not used.

        if not self.config.params_include_top:
            # Keep only the convolutional base; the head becomes a pass-through.
            self.model.classifier = nn.Identity()

        # Persist the untouched base model.
        self.save_model(path=self.config.base_model_path, model=self.model)

    @staticmethod
    def _prepare_full_model(model, classes, freeze_all, freeze_till, learning_rate):
        """Freeze layers, attach a new linear head, and build loss/optimizer.

        Args:
            model: Module exposing ``features`` and ``classifier`` attributes.
                It is modified in place.
            classes: Number of outputs of the new linear layer.
            freeze_all: When true, every existing parameter is frozen.
            freeze_till: When ``freeze_all`` is false and this is a positive
                int, all ``model.features`` children except the last
                ``freeze_till`` are frozen.
            learning_rate: Learning rate for the SGD optimizer.

        Returns:
            ``(model, criterion, optimizer)``. The model emits raw logits.
        """
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif (freeze_till is not None) and (freeze_till > 0):
            feature_layers = list(model.features.children())
            # A freeze_till larger than the layer count freezes nothing.
            frozen_count = max(len(feature_layers) - freeze_till, 0)
            for layer in feature_layers[:frozen_count]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Size the new head from the first Linear layer of the current head.
        # Looking the layer up (rather than indexing element 0) keeps this
        # working when the head was already replaced by an earlier call and
        # therefore starts with nn.Flatten. 25088 (= 512 * 7 * 7) is the
        # fallback for a head with no Linear layer, e.g. nn.Identity.
        in_features = next(
            (
                layer.in_features
                for layer in model.classifier.modules()
                if isinstance(layer, nn.Linear)
            ),
            25088,
        )

        # The head outputs logits: nn.CrossEntropyLoss applies log-softmax
        # itself, so a Softmax layer here would normalise twice and flatten
        # the gradients. Apply softmax at inference time if probabilities are
        # needed. Layer indices (and so state_dict keys) are unchanged because
        # Softmax carried no parameters.
        model.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, classes),
        )

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=learning_rate)

        return model, criterion, optimizer

    def update_base_model(self):
        """Freeze the base model, attach the new head, and save the result.

        Requires ``get_base_model()`` to have run first. Sets
        ``self.full_model``, ``self.criterion`` and ``self.optimizer``.
        """
        self.full_model, self.criterion, self.optimizer = self._prepare_full_model(
            model=self.model,
            classes=self.config.params_classes,
            freeze_all=True,
            freeze_till=None,
            learning_rate=self.config.params_learning_rate
        )

        # Persist the model with its new head.
        self.save_model(path=self.config.updated_base_model_path, model=self.full_model)

    @staticmethod
    def save_model(path: Path, model: torch.nn.Module):
        """Write ``model.state_dict()`` (weights only) to ``path``."""
        torch.save(model.state_dict(), path)
