In [1]:
import os
from dataclasses import dataclass
from pathlib import Path

In [3]:
# %cd ..
# %pwd

'd:\\Machine Learning Projects\\Unet-R Full-stack\\ML_model'

In [77]:
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class PrepareModelConfig:
    """Immutable settings for building one UNETR model variant (full or lite).

    Instances are created by ``ConfigurationManager``, which fills the
    ``params_*`` fields from the variant-specific (``FULL_*`` / ``LITE_*``) and
    shared entries of the params file.
    """

    # Directory created for this stage, and the path the built model is saved to.
    root_dir: Path
    model_path: Path
    
    # Input geometry: images are reshaped to (image_size, image_size, num_channels).
    params_image_size: int
    # Number of channels of the final 1x1 output convolution.
    params_num_classes: int
    # Transformer encoder: number of layers, token width, MLP width, attention heads.
    params_num_layers: int
    params_hidden_dim: int
    params_mlp_dim: int
    params_num_heads: int
    # Dropout rate applied inside the transformer MLP.
    params_dropout_rate: float
    # Patch count is derived as image_size**2 // patch_size**2 by ConfigurationManager.
    params_num_patches: int
    params_patch_size: int
    params_num_channels: int
    # Learning rate handed to the optimizer when the model is compiled.
    params_learning_rate: float
    

In [78]:
from UNetRMultiClass.constants import *
from UNetRMultiClass.utils.common import read_yaml, create_directories


class ConfigurationManager:
    """Load the project YAML files and build per-stage configuration objects.

    On construction the config and params files are read, the artifacts root
    directory is created, and the per-variant patch counts are added to the
    loaded params (see ``__init__``).
    """

    def __init__(
        self, config_filepath=CONFIG_FILE_PATH, params_filepath=PARAMS_FILE_PATH
    ):
        """Read both YAML files and prepare derived parameters.

        Args:
            config_filepath: Path of the YAML file describing artifact locations.
            params_filepath: Path of the YAML file holding the hyperparameters.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)
        create_directories([self.config.artifacts_root])

        # The params file stores NUM_PATCHES as the literal text of an
        # expression, so the real per-variant values are computed here.
        self.params.LITE_NUM_PATCHES = (self.params.LITE_IMAGE_SIZE**2)//(self.params.LITE_PATCH_SIZE**2)
        self.params.FULL_NUM_PATCHES = (self.params.FULL_IMAGE_SIZE**2)//(self.params.FULL_PATCH_SIZE**2)

    def _build_prepare_model_config(self, variant: str, model_path_key: str) -> PrepareModelConfig:
        """Assemble the ``PrepareModelConfig`` for one model variant.

        Args:
            variant: Prefix of the variant-specific params, ``"FULL"`` or ``"LITE"``.
            model_path_key: Name of the ``prepare_models`` config entry that
                holds this variant's model path.

        Returns:
            The populated, frozen ``PrepareModelConfig``.
        """
        config = self.config.prepare_models

        create_directories([config.root_dir])

        def variant_param(name):
            # Variant-specific hyperparameters are stored as e.g. FULL_IMAGE_SIZE.
            return getattr(self.params, f"{variant}_{name}")

        return PrepareModelConfig(
            root_dir=Path(config.root_dir),
            model_path=Path(getattr(config, model_path_key)),
            params_image_size=variant_param("IMAGE_SIZE"),
            params_num_classes=self.params.NUM_CLASSES,
            params_num_layers=variant_param("NUM_LAYERS"),
            params_hidden_dim=variant_param("HIDDEN_DIM"),
            params_mlp_dim=variant_param("MLP_DIM"),
            params_num_heads=variant_param("NUM_HEADS"),
            params_dropout_rate=self.params.DROPOUT_RATE,
            params_num_patches=variant_param("NUM_PATCHES"),
            params_patch_size=variant_param("PATCH_SIZE"),
            params_num_channels=self.params.NUM_CHANNELS,
            params_learning_rate=self.params.LEARNING_RATE,
        )

    def get_prepare_full_model_config(self) -> PrepareModelConfig:
        """Return the configuration for the full model (``FULL_*`` params)."""
        return self._build_prepare_model_config("FULL", "full_model_path")

    def get_prepare_lite_model_config(self) -> PrepareModelConfig:
        """Return the configuration for the lite model (``LITE_*`` params)."""
        return self._build_prepare_model_config("LITE", "lite_model_path")

In [19]:
# Scratch cell: load the hyperparameter file on its own so its raw contents
# can be inspected outside of ConfigurationManager.
params = read_yaml(PARAMS_FILE_PATH)

[2024-05-06 18:25:25,743: INFO: common: yaml file: params.yaml loaded successfully]


In [23]:
# Number of patches per image for each variant: image area divided by patch area.
params.LITE_NUM_PATCHES = params.LITE_IMAGE_SIZE**2 // params.LITE_PATCH_SIZE**2
params.FULL_NUM_PATCHES = params.FULL_IMAGE_SIZE**2 // params.FULL_PATCH_SIZE**2

# Shape of one image after patching: (patch count, flattened patch length).
params.LITE_FLAT_PATCHES_SHAPE = (
    params.LITE_NUM_PATCHES,
    params.LITE_PATCH_SIZE * params.LITE_PATCH_SIZE * params.NUM_CHANNELS,
)
params.FULL_FLAT_PATCHES_SHAPE = (
    params.FULL_NUM_PATCHES,
    params.FULL_PATCH_SIZE * params.FULL_PATCH_SIZE * params.NUM_CHANNELS,
)

In [24]:
# Display the loaded parameters together with the derived *_NUM_PATCHES and
# *_FLAT_PATCHES_SHAPE entries added in the previous cell.
params

ConfigBox({'LITE_IMAGE_SIZE': 256, 'LITE_NUM_LAYERS': 12, 'LITE_HIDDEN_DIM': 128, 'LITE_MLP_DIM': 32, 'LITE_NUM_HEADS': 6, 'LITE_PATCH_SIZE': 16, 'FULL_IMAGE_SIZE': 256, 'FULL_NUM_LAYERS': 12, 'FULL_HIDDEN_DIM': 768, 'FULL_MLP_DIM': 3072, 'FULL_NUM_HEADS': 12, 'FULL_PATCH_SIZE': 16, 'NUM_CLASSES': 11, 'DROPOUT_RATE': 0.1, 'NUM_PATCHES': '(image_size**2)//(patch_size**2)', 'NUM_CHANNELS': 3, 'LEARNING_RATE': 0.1, 'BATCH_SIZE': 16, 'NUM_EPOCHS': 1, 'FLAT_PATCHES_SHAPE': '(num_patches,patch_size*patch_size*num_channels)', 'LITE_NUM_PATCHES': 256, 'FULL_NUM_PATCHES': 256, 'LITE_FLAT_PATCHES_SHAPE': (256, 768), 'FULL_FLAT_PATCHES_SHAPE': (256, 768)})

In [79]:

import os
import urllib.request as request
from zipfile import ZipFile
import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam, SGD
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from math import log2

In [80]:
class PrepareModel:
    """Build, compile and save the 2D UNETR segmentation models.

    Both variants share one architecture: a transformer encoder over flattened
    image patches, whose outputs after layers 3, 6, 9 and 12 are reshaped into
    feature maps and merged by a four-stage convolutional decoder. The full and
    lite variants differ only in their decoder filter counts and model name;
    every other size comes from the ``PrepareModelConfig`` passed in.
    """

    # Encoder layers (1-based) whose outputs feed the decoder as skip connections.
    _SKIP_LAYER_INDEXES = (3, 6, 9, 12)
    # Filter counts of decoder stages 1-4 for each variant.
    _FULL_DECODER_FILTERS = (512, 256, 128, 64)
    _LITE_DECODER_FILTERS = (128, 64, 32, 16)

    def __init__(self, config: PrepareModelConfig):
        """Store the configuration that every builder method reads from."""
        self.config = config

    def mlp(self, x):
        """Transformer feed-forward block.

        Applies Dense(params_mlp_dim, GELU) -> Dropout -> Dense(params_hidden_dim)
        -> Dropout to ``x`` and returns the result.
        """
        x = L.Dense(self.config.params_mlp_dim, activation="gelu")(x)
        x = L.Dropout(self.config.params_dropout_rate)(x)
        x = L.Dense(self.config.params_hidden_dim)(x)
        x = L.Dropout(self.config.params_dropout_rate)(x)

        return x

    def transformer_encoder(self, x):
        """One pre-norm transformer layer: self-attention and MLP, each with a residual add."""
        skip_1 = x
        x = L.LayerNormalization()(x)
        # NOTE(review): Keras' key_dim is the per-head size; passing the full
        # hidden_dim makes every head hidden_dim wide. Kept as in the original
        # code -- confirm this is intended.
        x = L.MultiHeadAttention(num_heads=self.config.params_num_heads, key_dim=self.config.params_hidden_dim)(x, x)
        x = L.Add()([x, skip_1])

        skip_2 = x
        x = L.LayerNormalization()(x)
        x = self.mlp(x)
        x = L.Add()([x, skip_2])

        return x

    def conv_block(self, x, num_filters, kernel_size=3):
        """Conv2D (same padding) -> BatchNormalization -> ReLU."""
        x = L.Conv2D(num_filters, kernel_size=kernel_size, padding="same")(x)
        x = L.BatchNormalization()(x)
        x = L.ReLU()(x)

        return x

    def deconv_block(self, x, num_filters, strides=2):
        """Upsample ``x`` with a transposed convolution.

        Args:
            x: Input feature map.
            num_filters: Number of output channels.
            strides: Spatial upsampling factor. Defaults to 2, the value that
                was previously hard-coded, so existing callers are unaffected.
        """
        x = L.Conv2DTranspose(num_filters, kernel_size=2, padding="same", strides=strides)(x)
        return x

    def _build_model(self, decoder_filters, name):
        """Construct, compile, summarise and save one UNETR variant.

        Args:
            decoder_filters: Four filter counts, one per decoder stage.
            name: Name given to the Keras model.

        Returns:
            The compiled Keras model (also saved to ``config.model_path``).

        Raises:
            ValueError: If ``params_num_layers`` is too small to produce all
                four skip connections.
        """
        cfg = self.config

        last_skip_layer = self._SKIP_LAYER_INDEXES[-1]
        if cfg.params_num_layers < last_skip_layer:
            # Fail before any layer is built; otherwise the skip-connection
            # unpacking below raises an opaque "not enough values" ValueError.
            raise ValueError(
                f"params_num_layers must be at least {last_skip_layer} to provide "
                f"skip connections at layers {list(self._SKIP_LAYER_INDEXES)}; "
                f"got {cfg.params_num_layers}"
            )

        # ---- Inputs: (None, num_patches, patch_size * patch_size * num_channels)
        input_shape = (
            cfg.params_num_patches,
            cfg.params_patch_size * cfg.params_patch_size * cfg.params_num_channels,
        )
        inputs = L.Input(input_shape)

        # ---- Patch + positional embeddings: (None, num_patches, hidden_dim)
        patch_embed = L.Dense(cfg.params_hidden_dim)(inputs)
        positions = tf.range(start=0, limit=cfg.params_num_patches, delta=1)
        pos_embed = L.Embedding(input_dim=cfg.params_num_patches, output_dim=cfg.params_hidden_dim)(positions)
        x = patch_embed + pos_embed

        # ---- Transformer encoder, collecting the skip-connection outputs
        skip_connections = []
        for layer_index in range(1, cfg.params_num_layers + 1):
            x = self.transformer_encoder(x)
            if layer_index in self._SKIP_LAYER_INDEXES:
                skip_connections.append(x)

        z3, z6, z9, z12 = skip_connections

        # ---- Reshape the token sequences back into 2D feature maps
        z0 = L.Reshape((cfg.params_image_size, cfg.params_image_size, cfg.params_num_channels))(inputs)

        grid = cfg.params_image_size // cfg.params_patch_size
        token_map_shape = (grid, grid, cfg.params_hidden_dim)
        z3 = L.Reshape(token_map_shape)(z3)
        z6 = L.Reshape(token_map_shape)(z6)
        z9 = L.Reshape(token_map_shape)(z9)
        z12 = L.Reshape(token_map_shape)(z12)

        # ---- Adapt to the patch size.
        # The decoder upsamples by 2**4 = 16 before concatenating with z0, so
        # the token maps must be image_size / 16 wide. A patch size of 16
        # already satisfies that (upscale == 0); larger patches need extra
        # upsampling and smaller ones need pooling.
        upscale = int(log2(cfg.params_patch_size)) - 4

        if upscale >= 1:  ## Patch size 32 or greater
            stride = 2**upscale
            z3 = self.deconv_block(z3, z3.shape[-1], strides=stride)
            z6 = self.deconv_block(z6, z6.shape[-1], strides=stride)
            z9 = self.deconv_block(z9, z9.shape[-1], strides=stride)
            z12 = self.deconv_block(z12, z12.shape[-1], strides=stride)
        elif upscale < 0:  ## Patch size less than 16
            pool = 2**abs(upscale)
            z3 = L.MaxPool2D((pool, pool))(z3)
            z6 = L.MaxPool2D((pool, pool))(z6)
            z9 = L.MaxPool2D((pool, pool))(z9)
            z12 = L.MaxPool2D((pool, pool))(z12)

        # ---- CNN decoder, stages 1-3.
        # Stage k upsamples the running features once and brings its skip map
        # to the same resolution with k (deconv -> conv) steps.
        x = z12
        stages = zip((z9, z6, z3), decoder_filters[:3])
        for stage, (skip, num_filters) in enumerate(stages, start=1):
            x = self.deconv_block(x, num_filters)

            s = skip
            for _ in range(stage):
                s = self.deconv_block(s, num_filters)
                s = self.conv_block(s, num_filters)

            x = L.Concatenate()([x, s])
            x = self.conv_block(x, num_filters)
            x = self.conv_block(x, num_filters)

        # ---- CNN decoder, stage 4: merge with the full-resolution input image
        num_filters = decoder_filters[3]
        x = self.deconv_block(x, num_filters)

        s = self.conv_block(z0, num_filters)
        s = self.conv_block(s, num_filters)

        x = L.Concatenate()([x, s])
        x = self.conv_block(x, num_filters)
        x = self.conv_block(x, num_filters)

        # ---- Output.
        # categorical_crossentropy expects a per-pixel probability distribution
        # over the classes, which softmax provides; sigmoid is kept only for
        # the single-channel case, where softmax would be constant.
        activation = "softmax" if cfg.params_num_classes > 1 else "sigmoid"
        outputs = L.Conv2D(cfg.params_num_classes, kernel_size=1, padding="same", activation=activation)(x)

        model = Model(inputs, outputs, name=name)
        model.compile(loss="categorical_crossentropy", optimizer=SGD(cfg.params_learning_rate))

        model.summary()
        self.save_model(path=cfg.model_path, model=model)

        return model

    def get_full_model(self):
        """Build the full UNETR model, print its summary, save it and return it."""
        return self._build_model(self._FULL_DECODER_FILTERS, "UNETR_2D")

    def get_lite_model(self):
        """Build the lite UNETR model, print its summary, save it and return it."""
        return self._build_model(self._LITE_DECODER_FILTERS, "UNETR_2D_lite")

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Save ``model`` to ``path`` using ``model.save``."""
        model.save(path)


In [81]:
# Build, compile, summarise and save both UNETR variants.
# No try/except is used here: any failure propagates directly, so the
# notebook shows the original traceback.
config = ConfigurationManager()

prepare_full_model_config = config.get_prepare_full_model_config()
prepare_full_model = PrepareModel(config=prepare_full_model_config)
full_model = prepare_full_model.get_full_model()

prepare_lite_model_config = config.get_prepare_lite_model_config()
prepare_lite_model = PrepareModel(config=prepare_lite_model_config)
lite_model = prepare_lite_model.get_lite_model()

[2024-05-06 19:43:23,349: INFO: common: yaml file: config\config.yaml loaded successfully]
[2024-05-06 19:43:23,370: INFO: common: yaml file: params.yaml loaded successfully]
[2024-05-06 19:43:23,372: INFO: common: created directory at: artifacts]
[2024-05-06 19:43:23,375: INFO: common: created directory at: artifacts/prepare_model]


[2024-05-06 19:43:38,773: INFO: common: created directory at: artifacts/prepare_model]




In [42]:
# Spot-check one field of the full-model configuration built in the cell above.
prepare_full_model_config.params_mlp_dim

3072