In [2]:
import os
from typing import Optional

In [1]:
# Show the current working directory (the same value IPython's %pwd returns).
os.getcwd()

'c:\\Users\\Sasu4\\SHIP_Classification_using_Resnet\\research'

In [3]:
# Work from the project root (the parent of research/) so relative config paths resolve.
# The guard keeps this cell idempotent: re-running it does not walk further up the tree.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [4]:
# Confirm the working directory is now the project root (same value as IPython's %pwd).
os.getcwd()

'c:\\Users\\Sasu4\\SHIP_Classification_using_Resnet'

In [32]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the prepare-base-model stage.

    Instances are built by ``ConfigurationManager.get_prepare_base_model_config``.
    ``frozen=True`` makes every field read-only after construction.
    """

    root_dir: Path  # stage working directory
    base_model_path: Path  # checkpoint path for the base (pre-update) model
    updated_base_model_path: Path  # checkpoint path for the prepared model
    params_image_size: list  # IMAGE_SIZE from the params file
    params_learning_rate: float  # optimizer learning rate
    params_include_top: bool  # INCLUDE_TOP from the params file
    params_weights: str  # WEIGHTS from the params file
    params_classes: int  # number of output classes for the classifier head
    freeze_all: bool  # freeze every pretrained parameter before attaching the head
    # When freeze_all is False, number of leading child modules to freeze; None = none.
    freeze_till: Optional[int] = None

In [27]:
from Ship_Classifier.constants import *
from Ship_Classifier.utils.common import read_yaml,create_directories
from pathlib import Path
from Ship_Classifier.entity.config_entity import PrepareBaseModelConfig

In [114]:
class ConfigurationManager:
    """Read the project's config and params YAML files and build per-stage config objects.

    Args:
        config_filepath: path of the main config YAML (defaults to ``CONFIG_FILE_PATH``).
        params_filepath: path of the hyper-parameter YAML (defaults to ``PARAMS_FILE_PATH``).
    """

    def __init__(self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        # Ensure the top-level artifacts directory exists.
        create_directories([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Return the ``PrepareBaseModelConfig`` for the prepare-base-model stage.

        Paths and the optional ``freeze_all`` / ``freeze_till`` keys come from the
        ``prepare_base_model`` section of the config file; model hyper-parameters come
        from the params file. Creates the stage's ``root_dir`` as a side effect.
        """
        stage = self.config.prepare_base_model
        hp = self.params

        create_directories([stage.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(stage.root_dir),
            base_model_path=Path(stage.base_model_path),
            updated_base_model_path=Path(stage.updated_base_model_path),
            params_image_size=hp.IMAGE_SIZE,
            params_learning_rate=hp.LEARNING_RATE,
            params_include_top=hp.INCLUDE_TOP,
            params_weights=hp.WEIGHTS,
            params_classes=hp.CLASSES,
            # Missing keys mean: freeze everything, no partial freeze.
            freeze_all=stage.get("freeze_all", True),
            freeze_till=stage.get("freeze_till", None),
        )

In [79]:
import os
import urllib.request as request
from zipfile import ZipFile
import torch
import torch.nn as nn
from torchvision import models,transforms
import torch.optim as optim
from PIL import Image

In [140]:
class PrepareBaseModel:
    """Prepare a pretrained ResNet-18 for fine-tuning on ``config.params_classes`` classes.

    Call ``get_base_model()`` first, then ``update_base_model()``. The second call
    freezes layers, attaches a fresh classification head, builds the optimizer,
    loss and predict helpers, and saves the prepared weights.
    """

    def __init__(self, config: PrepareBaseModelConfig):
        self.config = config
        self.model = None  # set by get_base_model()
        self.full_model = None  # set by update_base_model(); same module object as self.model
        self.optimizer = None  # reused by update_base_model() if set, else Adam is created
        self.loss_fn = None  # reused by update_base_model() if set, else CrossEntropyLoss
        self.predict_fn = None  # inference closure set by update_base_model()

    def get_base_model(self):
        """Load ResNet-18 with torchvision's default pretrained weights and resize its head.

        ``fc`` is replaced with a ``Linear`` that has ``config.params_classes`` outputs,
        so ``self.model`` already has the right output size on its own.
        """
        # NOTE(review): config.params_weights and config.params_include_top are not
        # consulted here; the weights are always ResNet18_Weights.DEFAULT. Confirm intent.
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.fc = nn.Linear(self.model.fc.in_features, self.config.params_classes)
        print(f"Base model loaded with {self.config.params_classes} classes.")

    @staticmethod
    def _prepare_full_model(model, classes, freeze_all, freeze_till, learning_rate, optimizer, loss_fn):
        """Freeze pretrained layers, attach a new head, and set up training helpers.

        Args:
            model: model exposing an ``fc`` Linear head; modified in place.
            classes: number of outputs for the new head.
            freeze_all: if True, freeze every existing parameter (only the new head trains).
            freeze_till: if ``freeze_all`` is False and this is a positive int, freeze the
                first ``freeze_till`` top-level child modules of ``model``.
            learning_rate: Adam learning rate, used only when ``optimizer`` is None.
            optimizer: optimizer to reuse, or None to create Adam over trainable params.
            loss_fn: loss to reuse, or None to use ``nn.CrossEntropyLoss``.

        Returns:
            ``(model, optimizer, loss_fn, predict)``, where ``predict(x)`` returns the
            argmax class index along dim 1 of the model output.
        """
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif freeze_till is not None and freeze_till > 0:
            # Slicing clamps to the number of children, so an oversized value freezes all.
            for layer in list(model.children())[:freeze_till]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Replace the head *after* freezing so the new layer stays trainable.
        model.fc = nn.Linear(model.fc.in_features, classes)

        if optimizer is None:
            optimizer = optim.Adam(
                (p for p in model.parameters() if p.requires_grad), lr=learning_rate
            )
        else:
            # A supplied optimizer was built before the head swap and does not hold the
            # new fc parameters; register them, otherwise the new head would never train.
            optimizer.add_param_group({"params": list(model.fc.parameters())})

        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()

        def predict(input_tensor):
            """Return argmax class indices for ``input_tensor`` without tracking gradients."""
            was_training = model.training
            model.eval()
            try:
                with torch.no_grad():
                    output = model(input_tensor)
            finally:
                # Restore the caller's mode instead of leaving the model in eval mode.
                model.train(was_training)
            return torch.argmax(output, dim=1)

        print(
            f"Model prepared with classes={classes}, freeze_all={freeze_all}, "
            f"freeze_till={freeze_till}, learning_rate={learning_rate}"
        )
        return model, optimizer, loss_fn, predict

    @staticmethod
    def save_model(model, path: Path):
        """Save ``model.state_dict()`` to ``path``.

        Raises:
            ValueError: if ``model`` is None.
        """
        if model is None:
            raise ValueError("Model is not defined. Please check if the model is properly initialized.")
        torch.save(model.state_dict(), path)
        print(f"Model saved at {path}")

    def update_base_model(self, freeze_all=None, freeze_till=None):
        """Prepare ``self.model`` for training and save it to ``config.updated_base_model_path``.

        Args:
            freeze_all: overrides ``config.freeze_all``. None (the default) uses the config
                value, or True if the config object has no such field.
            freeze_till: overrides ``config.freeze_till``. None (the default) uses the
                config value.

        Raises:
            ValueError: if ``get_base_model()`` has not been called yet.
        """
        if self.model is None:
            raise ValueError("Base model is not loaded. Call get_base_model() before update_base_model().")
        # Fall back to the stage config so the configured freeze settings take effect.
        if freeze_all is None:
            freeze_all = getattr(self.config, "freeze_all", True)
        if freeze_till is None:
            freeze_till = getattr(self.config, "freeze_till", None)

        self.full_model, self.optimizer, self.loss_fn, self.predict_fn = self._prepare_full_model(
            model=self.model,
            classes=self.config.params_classes,
            freeze_all=freeze_all,
            freeze_till=freeze_till,
            learning_rate=self.config.params_learning_rate,
            optimizer=self.optimizer,
            loss_fn=self.loss_fn,
        )
        print("Full Model :", self.full_model)
        self.save_model(model=self.full_model, path=self.config.updated_base_model_path)

    def load_model(self, path: Path):
        """Load a state dict written by ``save_model`` into the prepared (or base) model.

        Raises:
            ValueError: if neither ``update_base_model()`` nor ``get_base_model()`` has
                run, so there is no module to load the weights into.
        """
        target = self.full_model if self.full_model is not None else self.model
        if target is None:
            raise ValueError("No model to load weights into. Call get_base_model() first.")
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines;
        # load_state_dict then copies the tensors into the model's own parameters.
        # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
        state_dict = torch.load(path, map_location="cpu")
        target.load_state_dict(state_dict)
        print(f"Model loaded from {path}")

In [141]:
try:
    config = ConfigurationManager()
    prepare_base_model_config = config.get_prepare_base_model_config()
    prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
    prepare_base_model.get_base_model()
    # update_base_model() already saves the prepared weights to
    # prepare_base_model_config.updated_base_model_path, so no separate save is needed.
    prepare_base_model.update_base_model()
except Exception as e:
    raise Exception(f"An error occurred while preparing the base model: {str(e)}") from e

[2024-08-22 15:41:00,241: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-08-22 15:41:00,247: INFO: common: yaml file: params.yaml loaded successfully]
[2024-08-22 15:41:00,251: INFO: common: created directory at: artifacts]
[2024-08-22 15:41:00,254: INFO: common: created directory at: artifacts/prepare_base_model]
Base model loaded with 5 classes.
Model prepared with classes=5, freeze_all=True, freeze_till=None, learning_rate=0.01
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (r