In [None]:
pwd

In [None]:
import os

In [None]:
# Move the working directory up one level. Presumably the notebook lives in a
# sub-folder and this makes the project root current so that the `src.…`
# imports and relative config paths used below resolve — TODO confirm.
os.chdir("../")

In [None]:
pwd

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient

import glob
from PIL import Image
from pathlib import Path
import time
import os
import mlflow
from typing import Dict, List, Optional
from src.cancer_detection import logger

In [None]:
from dataclasses import dataclass
from pathlib import Path


@dataclass
class TrainingConfig:
    """Settings for one training run, as built by ``ConfigurationManager.get_training_config``."""

    # Directory for the training stage; created by get_training_config.
    root_dir: Path
    # Folder that is globbed recursively for ``*.png`` training images.
    training_data: Path
    # Directory handed to Lightning's ModelCheckpoint as ``dirpath``.
    model_checkpoints: Path
    # Directory created by get_training_config; not read by the code in this notebook.
    best_model_checkpoints: Path
    # Taken from params.AUGMENTATION; not read by the code in this notebook.
    params_is_augmentation: bool
    # Taken from params.IMAGE_SIZE; only element 0 is used (as the crop size).
    params_image_size: list
    # Batch size given to every DataLoader.
    params_batch_size: int
    # Upper bound on training epochs (``max_epochs`` of the Trainer).
    params_epochs: int
    # Number of outputs of the replaced final VGG16 classifier layer.
    params_num_classes: int
    # Learning rate passed to the Adam optimizer.
    params_learning_rate: float


In [None]:
from src.cancer_detection.constants import *
from src.cancer_detection.utils.common import read_yaml, create_directories


class ConfigurationManager:
    """Read the YAML config/params files and build typed config objects."""

    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH):
        """Load both YAML files and make sure the artifacts root exists.

        Args:
            config_filepath: Path of the main configuration YAML.
            params_filepath: Path of the hyper-parameter YAML.
        """
        self.config = read_yaml(config_filepath)
        self.params = read_yaml(params_filepath)

        create_directories([self.config.artifacts_root])

    def get_training_config(self) -> TrainingConfig:
        """Create the training directories and return the training settings.

        Returns:
            A ``TrainingConfig`` whose path fields are all ``pathlib.Path``
            objects, matching the dataclass annotations.
        """
        section = self.config.training
        params = self.params

        # Normalise every location to Path once, so the dataclass really holds
        # the types it declares (the checkpoint dirs used to be passed through
        # as raw config values).
        root_dir = Path(section.root_dir)
        model_checkpoints = Path(section.model_checkpoints)
        best_model_checkpoints = Path(section.best_model_checkpoints)
        training_data = Path(self.config.data_ingestion.unzip_dir) / "Chest-CT-Scan-data"

        for directory in (root_dir, model_checkpoints, best_model_checkpoints):
            create_directories([directory])

        return TrainingConfig(
            root_dir=root_dir,
            training_data=training_data,
            model_checkpoints=model_checkpoints,
            best_model_checkpoints=best_model_checkpoints,
            params_is_augmentation=params.AUGMENTATION,
            params_image_size=params.IMAGE_SIZE,
            params_batch_size=params.BATCH_SIZE,
            params_epochs=params.EPOCHS,
            params_num_classes=params.CLASSES,
            params_learning_rate=params.LEARNING_RATE,
        )

In [None]:
# Data Augmentation
class ImageTransform():
    """Phase-dependent preprocessing that turns a PIL image into a normalised 3-channel tensor.

    The ``'train'`` phase applies a random resized crop and a random horizontal
    flip; ``'val'`` and ``'test'`` apply a fixed resize to 256 followed by a
    centre crop. Every phase ends with ``ToTensor`` and channel normalisation.
    """

    def __init__(
            self,
            img_size: int  = 224,
            mean: Optional[list] = None,
            std: Optional[list] = None
        ):
        """Build the per-phase transform pipelines.

        Args:
            img_size: Side length of the crop produced by every phase.
            mean: Per-channel mean used for normalisation. Defaults to
                ``[0.485, 0.456, 0.406]`` (the values commonly quoted as
                ImageNet channel means).
            std: Per-channel standard deviation used for normalisation.
                Defaults to ``[0.229, 0.224, 0.225]``.
        """
        # Always assign both attributes: previously they were only set when the
        # argument was None, so passing custom statistics raised AttributeError
        # later in __call__.
        self.mean = [0.485, 0.456, 0.406] if mean is None else mean
        self.std = [0.229, 0.224, 0.225] if std is None else std

        def deterministic_pipeline():
            # Shared by 'val' and 'test': no randomness, so metrics are repeatable.
            return transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(img_size),
                transforms.ToTensor()
            ])

        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(img_size, scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]),
            'val': deterministic_pipeline(),
            'test': deterministic_pipeline(),
        }

    def __call__(self, img : Image.Image, phase : str) -> torch.Tensor:
        """Transform ``img`` with the pipeline registered for ``phase``.

        Args:
            img: Input PIL image.
            phase: One of ``'train'``, ``'val'`` or ``'test'``.

        Returns:
            A normalised tensor with exactly three channels.

        Raises:
            KeyError: If ``phase`` is not a registered pipeline name.
        """
        img = self.data_transform[phase](img)

        # Fewer than three channels (grayscale, or grayscale + alpha): replicate
        # the first channel so the tensor matches the 3-value mean/std.
        if img.shape[0] < 3:
            logger.info(f'logging shape : \n {img.shape}, dtype : {img.dtype}')
            img = torch.repeat_interleave(img[:1], 3, dim=0)
        # More than three channels (e.g. RGBA): keep only the first three.
        img = img[:3, :, :]

        # Normalize
        img = transforms.functional.normalize(img, mean=self.mean, std=self.std)
        return img
    

class cancerDataset(Dataset):
    """A PyTorch Dataset for the cancer image data.

    Each item is ``(transformed_image, label)`` where the label is derived from
    the name of the folder that directly contains the image file.
    """

    # Class folder name -> integer target.
    _LABELS = {'adenocarcinoma': 0, 'normal': 1}

    def __init__(self, file_list, transform_fun=None, phase='train') -> None:
        """Store the file list and the transform configuration.

        Args:
            file_list: Indexable collection of image file paths.
            transform_fun: Callable invoked as ``transform_fun(img, phase)``.
                ``__getitem__`` calls it unconditionally, so it must be set
                before items are fetched.
            phase: Phase name forwarded to ``transform_fun``.
        """
        self.file_list = file_list
        self.transform = transform_fun
        self.phase = phase

    def __len__(self) -> int:
        """Returns the number of examples in the dataset."""
        return len(self.file_list)

    @classmethod
    def _label_from_path(cls, img_path) -> int:
        """Map an image path to its integer label via its parent folder name.

        Raises:
            ValueError: If the parent folder is not a known class name, instead
                of silently returning the folder name as a string label.
        """
        # os.path handles the platform's separators, unlike splitting on "/".
        folder = os.path.basename(os.path.dirname(str(img_path)))
        try:
            return cls._LABELS[folder]
        except KeyError:
            raise ValueError(
                f"Unknown class folder {folder!r} for image {img_path!r}; "
                f"expected one of {sorted(cls._LABELS)}"
            ) from None

    def __getitem__(self, idx: int):
        """Load, transform and label the image at position ``idx``."""
        img_path = self.file_list[idx]
        # The transform reads the pixel data, so the file can be closed as soon
        # as it returns; the context manager prevents leaking file handles.
        with Image.open(img_path) as img:
            img_transformed = self.transform(img, self.phase)

        return img_transformed, self._label_from_path(img_path)
    


class cancerDataModule(pl.LightningDataModule):
    """Lightning data module that splits the PNG images under ``path`` 80/10/10.

    The split is computed once, on the first ``setup`` call, and reused for
    every later stage so that training, validation and test sets never overlap.
    """

    def __init__(self, path, img_size, batch_size, num_workers, seed: Optional[int] = None) -> None:
        """Store the loader settings and create the RNG used for the split.

        Args:
            path: Root folder searched recursively for ``*.png`` files.
            img_size: Crop size forwarded to ``ImageTransform``.
            batch_size: Batch size of every DataLoader.
            num_workers: Worker process count of every DataLoader.
            seed: Seed for the split RNG; ``None`` gives a non-reproducible split.
        """
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.img_size = img_size
        self.rng = np.random.default_rng(seed)
        self._split_done = False

    def setup(self, stage: Optional[str] = None) -> None:
        """Splits the dataset into training, validation, and test sets.

        Args:
            stage: Stage name supplied by Lightning; not used, because the same
                split serves every stage.

        Raises:
            FileNotFoundError: If no ``*.png`` file exists under ``path``.
        """
        # Lightning calls setup once per stage. Re-shuffling each time would
        # advance the RNG and hand the test stage a different split than the
        # one used for fitting, leaking training images into the test set.
        if self._split_done:
            return

        pattern = os.path.join(self.path, '**/*.png')
        # glob order depends on the filesystem; sort so a fixed seed always
        # yields the same split.
        self.files = np.array(sorted(glob.glob(pattern, recursive=True)))
        n_files = len(self.files)
        if n_files == 0:
            raise FileNotFoundError(f"No .png images found under {self.path}")

        train_size = int(0.8 * n_files)
        val_size = int(0.1 * n_files)
        indices = np.arange(n_files)
        self.rng.shuffle(indices)

        val_end = train_size + val_size
        self.train_files = self.files[indices[:train_size]]
        self.val_files = self.files[indices[train_size:val_end]]
        # Whatever remains after the 80% / 10% cuts goes to the test set.
        self.test_files = self.files[indices[val_end:]]
        self._split_done = True

    def _make_loader(self, files, phase: str, shuffle: bool = False) -> DataLoader:
        """Build a DataLoader over ``files`` using the transform for ``phase``."""
        dataset = cancerDataset(files, ImageTransform(self.img_size), phase)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self) -> DataLoader:
        """Returns a DataLoader for the training set."""
        return self._make_loader(self.train_files, 'train', shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Returns a DataLoader for the validation set."""
        return self._make_loader(self.val_files, 'val')

    def test_dataloader(self) -> DataLoader:
        """Returns a DataLoader for the test set."""
        return self._make_loader(self.test_files, 'test')

    def predict_dataloader(self) -> DataLoader:
        """Returns a DataLoader over the test files for prediction.

        Uses the deterministic ``'test'`` transform; the dataset was previously
        built without any transform, which failed on the first item.
        """
        return self._make_loader(self.test_files, 'test')



In [None]:

# ---------------------------- transfer learning model----------------------- #
class vgg16_modified(nn.Module):
    """Pretrained VGG16 whose last classifier layer is replaced and is the only trainable part."""

    def __init__(self, config: TrainingConfig):
        """Load VGG16, swap the output layer and freeze every other parameter.

        Args:
            config: Training settings; ``params_num_classes`` sets the size of
                the new output layer.
        """
        super().__init__()
        self.config = config
        self.model =  torchvision.models.vgg16(pretrained=True)

        # Replace the final fully connected layer with one sized for our classes.
        old_head = self.model.classifier[6]
        self.model.classifier[6] = nn.Linear(
            in_features=old_head.in_features,
            out_features=self.config.params_num_classes
        )

        # Only the new head receives gradients; everything else stays frozen.
        trainable_names = ('classifier.6.weight', 'classifier.6.bias')
        for param_name, parameter in self.model.named_parameters():
            parameter.requires_grad = param_name in trainable_names

    def forward(self, batch):
        """Run ``batch`` through the wrapped VGG16 and return its output."""
        return self.model(batch)


# ---------------------------- Lightning Module ----------------------------- #
class cancerClassifier(pl.LightningModule):
    """Lightning wrapper that trains a classification network with cross-entropy loss."""

    def __init__(self, model: nn.Module, config: TrainingConfig) -> None:
        """Load the CNN classifier.

        Args:
            model: Network mapping an image batch to class logits.
            config: Training settings; ``params_learning_rate`` is used by the
                optimizer.
        """
        super().__init__()
        self.config = config
        self.model = model
        # define loss function
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """Forward pass of the model to return output predictions."""
        return self.model(batch)

    @staticmethod
    def _logits_to_preds(logits: torch.Tensor) -> torch.Tensor:
        """Return the index of the largest logit along dimension 1."""
        return torch.argmax(logits, dim=1)

    def _shared_step(self, batch, loss_name: str, acc_name: str) -> torch.Tensor:
        """Compute and log loss and accuracy for one ``(inputs, targets)`` batch.

        Args:
            batch: Pair ``(x, y)`` of inputs and integer targets.
            loss_name: Metric name under which the loss is logged.
            acc_name: Metric name under which the accuracy is logged.

        Returns:
            The loss tensor for the batch.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        self.log(loss_name, loss, prog_bar=True)

        preds = self._logits_to_preds(logits)
        acc = torch.sum(preds == y).float() / len(y)
        self.log(acc_name, acc, prog_bar=True)
        return loss

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """Perform a single training step, returning the loss on a training batch."""
        return self._shared_step(batch, "train_loss", "train_acc")

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
        """Perform a single validation step."""
        # NOTE: the loss is logged as "valid_loss" while the accuracy is logged
        # as "val_acc"; both names are kept because callbacks refer to them.
        self._shared_step(batch, "valid_loss", "val_acc")

    def test_step(self, batch: torch.Tensor, batch_idx: int) -> None:
        """Perform a single test step."""
        self._shared_step(batch, "test_loss", "test_acc")

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the optimizer to use for training."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.config.params_learning_rate,
        )

    def predict_step(
        self,
        logits,
        batch_idx: Optional[int] = None,
        dataloader_idx: int = 0,
    ) -> torch.Tensor:
        """Return predicted class indices.

        Supports two calling styles:

        * ``predict_step(logits)`` - the original direct use; ``logits`` is
          already the network output and is only arg-maxed.
        * The Lightning prediction hook, which passes a dataloader batch plus
          ``batch_idx``. The single-argument signature made that call fail, so
          the extra hook arguments are now accepted and the batch is run
          through the network first.

        Args:
            logits: Either logits, an input batch, or an ``(inputs, targets)``
                pair coming from a dataloader.
            batch_idx: Batch index supplied by the Trainer; ``None`` when the
                method is called directly with logits.
            dataloader_idx: Dataloader index supplied by the Trainer; unused.
        """
        if isinstance(logits, (list, tuple)):
            # Dataloader batch of (inputs, targets): predict from the inputs.
            logits = self.forward(logits[0])
        elif batch_idx is not None:
            # Invoked by the Trainer with a bare input batch.
            logits = self.forward(logits)
        return self._logits_to_preds(logits)



def _get_or_create_experiment(client, experiment_name: str):
    """Return the MLflow experiment called ``experiment_name``, creating it if absent."""
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(experiment_name)
        experiment = client.get_experiment(experiment_id)
    return experiment


def train(config : TrainingConfig, fast_dev_run: bool = False):
    """Run the full data-loading and model-training loop.

    Builds the data module, the VGG16-based classifier and a Lightning Trainer
    with checkpointing and early stopping on ``val_acc``, then fits the model
    inside an MLflow run.

    Args:
        config: Training settings (data location, checkpoint directory and
            hyper-parameters).
        fast_dev_run: Forwarded to ``pl.Trainer``.
    """
    # Set seed to control randomness
    seed = 123
    torch.manual_seed(seed)

    # Prepare data module
    num_workers = max(0, (os.cpu_count() or 1) - 1)
    datamodule = cancerDataModule(
        path=config.training_data,
        batch_size=config.params_batch_size,
        num_workers=num_workers,
        img_size=config.params_image_size[0],
        # Seed the train/val/test split as well; it draws from its own NumPy
        # generator, which torch.manual_seed does not affect.
        seed=seed,
    )

    callbacks = [
        ModelCheckpoint(
            dirpath=config.model_checkpoints,
            # The validation loss is logged as "valid_loss"; the previous
            # "val_loss" placeholder matched no logged metric.
            filename="validation-{epoch}-{step}-{valid_loss:.1f}",
            monitor="val_acc",
            save_top_k=1,  # keep only the single best checkpoint by val_acc
            mode="max",
            every_n_epochs=5,
        ),
        EarlyStopping(
            monitor="val_acc",
            mode="max",
            patience=20,
            verbose=True,
        ),
    ]

    # model
    model = vgg16_modified(config)
    # Train model
    learner = cancerClassifier(model, config)
    trainer = pl.Trainer(
        max_epochs=config.params_epochs,
        fast_dev_run=fast_dev_run,
        enable_checkpointing=True,
        callbacks=callbacks,
    )

    # Point MLflow at the tracking server BEFORE creating the client, so the
    # experiment lookup/creation and the run itself use the same store.
    # TODO: replace the placeholder with the real tracking URI.
    mlflow.set_tracking_uri("ADD URI HERE")

    client = MlflowClient()
    experiment_name = "vgg16classifier"
    run_name = time.strftime("%Y%m%d-%H%M%S")

    experiment = _get_or_create_experiment(client, experiment_name)
    experiment_id = experiment.experiment_id

    # Fetch experiment metadata information
    logger.info(f"Name: {experiment.name}")
    logger.info(f"Experiment_id: {experiment.experiment_id}")
    logger.info(f"Artifact Location: {experiment.artifact_location}")
    logger.info(f"Tags: {experiment.tags}")
    logger.info(f"Lifecycle_stage: {experiment.lifecycle_stage}")

    # Activate auto logging for pytorch lightning module
    mlflow.pytorch.autolog(log_models=False)

    # Start MLflow run
    with mlflow.start_run(experiment_id=experiment_id, run_name=run_name):
        mlflow.log_params(vars(config))
        logger.info("Training model...")
        trainer.fit(learner, datamodule=datamodule)

In [None]:
# Driver cell: build the configuration and launch training.
try:
    config = ConfigurationManager()
    training_config = config.get_training_config()
    training = train(training_config)
except Exception:
    # Record the failure with its traceback in the project log, then re-raise
    # unchanged so the notebook cell still fails visibly.
    logger.exception("Training pipeline failed")
    raise