In [1]:
import os

In [2]:
# Inspect the kernel's starting working directory (recorded output: the research/ folder).
%pwd

'/Users/mark42/Documents/ML-Pipeline/research'

In [3]:
# The kernel starts inside research/; move up one level so the relative
# config paths used later (e.g. config/config.yaml) resolve from there.
# Guarded so re-running this cell does not keep climbing the directory tree.
if os.path.basename(os.getcwd()) == "research":
    os.chdir("../")

In [4]:
# Confirm the working directory after os.chdir("../"); it should now be the parent of research/.
%pwd

'/Users/mark42/Documents/ML-Pipeline'

In [5]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class CreateModelConfig:
    """Immutable settings for the create-model stage.

    Built by ``ConfigurationManager.get_create_model_config`` from the
    ``create_model`` section of the YAML config. Being frozen, instances
    cannot be mutated after construction.
    """

    # Directory that holds this stage's artifacts.
    root_dir: Path
    # File the created model is written to.
    model_path: Path

In [6]:
from cnnClassifier.constants import *
from cnnClassifier.utils.common import read_yaml, create_directories

In [7]:
class ConfigurationManager:
    """Read the project's YAML files and hand out typed per-stage configs.

    Constructing an instance loads the main config and the params file via
    ``read_yaml`` and makes sure the top-level ``artifacts_root`` directory
    exists on disk.
    """

    def __init__(
        self,
        config_filepath=CONFIG_FILE_PATH,
        params_filepath=PARAMS_FILE_PATH,
    ):
        # Main config is loaded first, then params (same order as before).
        self.config = read_yaml(config_filepath)
        # Not read by get_create_model_config; kept for other stages/callers.
        self.params = read_yaml(params_filepath)

        artifacts_root = self.config.artifacts_root
        create_directories([artifacts_root])

    def get_create_model_config(self) -> CreateModelConfig:
        """Build the ``CreateModelConfig`` for the create-model stage.

        Side effect: creates the stage's ``root_dir`` directory. The
        ``root_dir`` and ``model_path`` entries of the ``create_model``
        section are wrapped in ``Path`` objects.
        """
        stage = self.config.create_model

        # Directory must exist before anything is written into it.
        create_directories([stage.root_dir])

        return CreateModelConfig(
            root_dir=Path(stage.root_dir),
            model_path=Path(stage.model_path),
        )

In [8]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [9]:
class CreateModel(nn.Module):
    """Fully connected classifier: flatten -> (Linear, ReLU) x2 -> Linear logits.

    With the default sizes the network maps 28*28 = 784 flattened input
    features to 10 output logits through two hidden layers of 512 units,
    exactly as before the sizes became parameters.

    Args:
        config: Stage configuration; stored on the instance as ``self.config``.
        in_features: Number of features per sample after flattening.
        hidden_features: Width of each of the two hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(
        self,
        config: CreateModelConfig,
        in_features: int = 28 * 28,
        hidden_features: int = 512,
        num_classes: int = 10,
    ):
        # nn.Module.__init__ must run before any attribute assignment so the
        # module's internal registries exist; assigning first only worked by
        # accident because `config` is neither a Module nor a Parameter.
        super().__init__()
        self.config = config
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return unnormalised class logits for a batch ``x``.

        ``nn.Flatten`` keeps the first (batch) dimension and flattens the rest,
        so the product of the trailing dimensions must equal ``in_features``.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

    @staticmethod
    def save_model(path: Path, model):
        """Serialise the entire ``model`` object to ``path`` with ``torch.save``.

        This pickles the whole module (not just its ``state_dict``), so loading
        it later requires this class definition to be importable, and tensors
        are stored on whatever device the model currently lives on.
        NOTE(review): on PyTorch >= 2.6 ``torch.load`` defaults to
        ``weights_only=True``, so reading this file back needs
        ``weights_only=False``.

        Missing parent directories of ``path`` are created first.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model, path)
        # The old message claimed a "Model State" was saved, but the full
        # module is pickled; the message now says what actually happens.
        print(f"Saved PyTorch model to {path}")


In [10]:
# Pick the device: CUDA if available, else Apple-silicon MPS, else CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# Loading the configuration also creates the artifacts root and the
# create_model stage directory.
config = ConfigurationManager()
create_model_config = config.get_create_model_config()

create_model = CreateModel(config=create_model_config)
# nn.Module.to() moves parameters in place and returns the same module,
# so `model` and `create_model` refer to one object.
model = create_model.to(device)
print(model)

# NOTE(review): the model is saved after moving to `device`, so its tensors
# are stored on that device; consider saving a CPU copy if the artifact must
# load on machines without this accelerator.
create_model.save_model(path=create_model_config.model_path, model=model)

Using mps device
[2024-08-12 18:59:48,923: INFO: common: yaml file: config/config.yaml loaded successfully]
[2024-08-12 18:59:48,925: INFO: common: yaml file: params.yaml loaded successfully]
[2024-08-12 18:59:48,926: INFO: common: created directory at: artifacts]
[2024-08-12 18:59:48,926: INFO: common: created directory at: artifacts/create_model]
CreateModel(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
Saved PyTorch Model State artifacts/create_model/model
