In [1]:
import os

In [2]:
os.getcwd()  # display the starting working directory (the research/ folder)

'd:\\Satellite-Image-Classification\\research'

In [3]:
# The notebook starts in research/; step up to the project root so relative
# paths such as config/config.yaml resolve. Guarded so that re-running this
# cell does not keep climbing the directory tree.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [4]:
os.getcwd()  # display the working directory after moving to the project root

'd:\\Satellite-Image-Classification'

In [5]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class PrepareBaseModelConfig:
    """Immutable settings for the prepare-base-model stage.

    The path fields come from the ``prepare_base_model`` section of the
    config YAML. The ``params_*`` fields mirror hyper-parameters from the
    params YAML, as wired up by
    ``ConfigurationManager.get_prepare_base_model_config``.
    """

    # --- filesystem locations ---------------------------------------------
    root_dir: Path                  # stage output directory
    base_model_path: Path           # file the untouched base model is saved to
    updated_base_model_path: Path   # file the model with the new head is saved to

    # --- hyper-parameters (params YAML) ------------------------------------
    params_image_size: list         # IMAGE_SIZE
    params_learning_rate: float     # LEARNING_RATE
    params_include_top: bool        # INCLUDE_TOP
    params_weights: str             # WEIGHTS, compared against 'imagenet'
    params_classes: int             # CLASSES

In [6]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_directories

In [7]:
class ConfigurationManager:
    """Load the project's config/params YAML files and build typed stage configs.

    Creating an instance reads both files with ``read_yaml``. It also makes
    sure the ``artifacts_root`` directory named in the config file exists.
    """

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        # Create the top-level artifacts folder before any stage needs it.
        create_directories([self.config.artifacts_root])

    def get_prepare_base_model_config(self) -> PrepareBaseModelConfig:
        """Return the ``PrepareBaseModelConfig`` for the prepare-base-model stage.

        Side effect: creates the stage's ``root_dir`` on disk.
        """
        stage = self.config.prepare_base_model
        hp = self.params

        create_directories([stage.root_dir])

        return PrepareBaseModelConfig(
            root_dir=Path(stage.root_dir),
            base_model_path=Path(stage.base_model_path),
            updated_base_model_path=Path(stage.updated_base_model_path),
            params_image_size=hp.IMAGE_SIZE,
            params_learning_rate=hp.LEARNING_RATE,
            params_include_top=hp.INCLUDE_TOP,
            params_weights=hp.WEIGHTS,
            params_classes=hp.CLASSES,
        )

In [8]:
import os
import urllib.request as request
from zipfile import ZipFile

In [10]:
pip install torchvision


Collecting torchvision
  Downloading torchvision-0.15.2-cp38-cp38-win_amd64.whl (1.2 MB)
     ---------------------------------------- 0.0/1.2 MB ? eta -:--:--
     ---------------------------------------- 0.0/1.2 MB ? eta -:--:--
      --------------------------------------- 0.0/1.2 MB 131.3 kB/s eta 0:00:09
     - -------------------------------------- 0.0/1.2 MB 281.8 kB/s eta 0:00:05
     --- ------------------------------------ 0.1/1.2 MB 595.3 kB/s eta 0:00:02
     ----------- ---------------------------- 0.4/1.2 MB 1.6 MB/s eta 0:00:01
     ----------------------------- ---------- 0.9/1.2 MB 3.2 MB/s eta 0:00:01
     ---------------------------------------- 1.2/1.2 MB 3.8 MB/s eta 0:00:00
Installing collected packages: torchvision
Successfully installed torchvision-0.15.2
Note: you may need to restart the kernel to use updated packages.


In [11]:
import torch
import torch.nn as nn
import torchvision.models as models

class PrepareBaseModel:
    """Build a VGG16 base model and a fine-tuning variant with a replaced head.

    Both models are persisted as ``state_dict`` files via ``save_model``.

    Args:
        config: a ``PrepareBaseModelConfig``-like object. It must provide
            ``params_weights``, ``params_classes``, ``params_include_top``,
            ``params_learning_rate``, ``base_model_path`` and
            ``updated_base_model_path``.
    """

    def __init__(self, config):
        self.config = config

    def get_base_model(self):
        """Create torchvision's VGG16 as ``self.model`` and save its weights.

        ImageNet weights are loaded only when ``params_weights == 'imagenet'``.
        """
        # ``weights=`` replaces the ``pretrained=`` flag deprecated in
        # torchvision 0.13. IMAGENET1K_V1 is what ``pretrained=True`` loaded.
        weights = (
            models.VGG16_Weights.IMAGENET1K_V1
            if self.config.params_weights == 'imagenet'
            else None
        )
        # NOTE(review): with pretrained weights torchvision appears to require
        # num_classes == 1000, so include_top=True with another class count
        # may fail. Confirm before enabling that combination.
        self.model = models.vgg16(
            weights=weights,
            num_classes=self.config.params_classes if self.config.params_include_top else 1000
        )

        self.save_model(path=self.config.base_model_path, model=self.model)

    @staticmethod
    def _prepare_full_model(model, classes, freeze_all, freeze_till, learning_rate):
        """Freeze layers, swap the last classifier layer, and build loss/optimizer.

        The model is modified in place and also returned.

        Args:
            model: module with an indexable ``.classifier`` whose last entry
                has ``in_features`` (as torchvision's VGG does).
            classes: output size of the replacement final ``nn.Linear``.
            freeze_all: if True, freeze every existing parameter.
            freeze_till: used when ``freeze_all`` is False and this is a
                positive int. Freezes every top-level child except the last
                ``freeze_till``. A value at or above the number of children
                freezes nothing.
            learning_rate: SGD learning rate.

        Returns:
            ``(full_model, criterion, optimizer)``: the mutated model, a
            ``CrossEntropyLoss``, and an SGD optimizer over its trainable
            parameters.
        """
        if freeze_all:
            for param in model.parameters():
                param.requires_grad = False
        elif (freeze_till is not None) and (freeze_till > 0):
            # children() is a generator, so materialise it before using len().
            children = list(model.children())
            n_frozen = max(len(children) - freeze_till, 0)
            for layer in children[:n_frozen]:
                for param in layer.parameters():
                    param.requires_grad = False

        # The new head is created after freezing, so it remains trainable.
        in_features = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(in_features, classes)

        full_model = model
        criterion = nn.CrossEntropyLoss()
        # Frozen parameters never receive gradients, so only pass trainable ones.
        trainable_params = [p for p in full_model.parameters() if p.requires_grad]
        optimizer = torch.optim.SGD(trainable_params, lr=learning_rate)

        return full_model, criterion, optimizer

    def update_base_model(self):
        """Freeze the base model, attach a new head, and save the result.

        Requires ``get_base_model`` to have been called first, because it
        reads ``self.model``. Stores the result as ``self.full_model``.
        """
        self.full_model, criterion, optimizer = self._prepare_full_model(
            model=self.model,
            classes=self.config.params_classes,
            freeze_all=True,
            freeze_till=None,
            learning_rate=self.config.params_learning_rate
        )

        self.save_model(path=self.config.updated_base_model_path, model=self.full_model)

    @staticmethod
    def save_model(path, model):
        """Save ``model.state_dict()`` to ``path`` with ``torch.save``."""
        torch.save(model.state_dict(), path)

In [13]:
# Run the prepare-base-model stage end to end: load the configs, build and
# save VGG16, then freeze it, attach the new head and save the updated model.
# Exceptions propagate unchanged, so no try/except wrapper is needed.
config = ConfigurationManager()
prepare_base_model_config = config.get_prepare_base_model_config()
prepare_base_model = PrepareBaseModel(config=prepare_base_model_config)
prepare_base_model.get_base_model()
prepare_base_model.update_base_model()

[2023-08-30 18:11:56,539: INFO: common: yaml file: config\config.yaml loaded successfully]
[2023-08-30 18:11:56,550: INFO: common: yaml file: params.yaml loaded successfully]
[2023-08-30 18:11:56,553: INFO: common: created directory at: artifacts]
[2023-08-30 18:11:56,555: INFO: common: created directory at: artifacts/prepare_base_model]


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\Aditya Rao/.cache\torch\hub\checkpoints\vgg16-397923af.pth
100%|██████████| 528M/528M [00:16<00:00, 33.8MB/s] 
