In [1]:
# Show the directory the kernel started in (the notebook's `research/` folder
# in the recorded run). The last expression is displayed as the cell output.
import os
os.getcwd()

'/mnt/cb03386d-9344-47b1-82f9-868fbb64b4ae/python_projects/Cell_Segmentation_YOLO-v8/research'

In [2]:
# Work from the project root so `src.cellseg` imports and the relative
# config/artifact paths resolve. The guard keeps the cell idempotent:
# re-running it (or Restart & Run All) must not climb another directory level.
if not os.path.isdir(os.path.join('src', 'cellseg')):
    os.chdir('../')
os.getcwd()

'/mnt/cb03386d-9344-47b1-82f9-868fbb64b4ae/python_projects/Cell_Segmentation_YOLO-v8'

In [3]:
from dataclasses import dataclass
from pathlib import Path

@dataclass(eq=True, frozen=True)
class DataTransformationConfig:
    """Immutable settings for the data-transformation stage.

    Instances are built by ``ConfigurationManager.get_data_transformation_config``
    and consumed by ``DataTransformation``. Field order is part of the
    positional constructor signature and must not change.
    """

    # Working directory of this stage (artifacts/data_transformation in the recorded run).
    root_dir: Path
    # Raw dataset: one folder per sample holding images/<id>.png and masks/.
    data_path: Path
    # Destination of training images and their YOLO label files.
    train_path: Path
    # Destination of validation images and their YOLO label files.
    validation_path: Path
    # Not read by DataTransformation in this notebook.
    test_path: Path
    # Where the dataset YAML (train/val paths, class names) is written.
    YAML_path: Path
    # Fraction of samples moved from train to validation.
    val_size: float
    # Oversampling multiplier applied to the larger colour group.
    aug_size: int
    # Augmentation settings accessed by attribute (e.g. aug_params.Crop.x_min),
    # so presumably the ConfigBox returned by read_yaml rather than a plain dict.
    aug_params: dict
    # Outcome of the data-validation stage; transformation only runs when True.
    dataset_val_status: bool

In [4]:
from src.cellseg.constant import *
from src.cellseg.utils.main_utils import create_directories, read_yaml
class ConfigurationManager:
    """Load the project's YAML files and build per-stage configuration objects.

    Args:
        config_file_path: main config YAML (paths for every pipeline stage).
        params_file_path: params YAML; its ``augmentation`` section feeds the
            data-transformation stage.
        schema_file_path: schema YAML.

    Side effect: creates ``config.artifacts_root`` if it does not exist.
    """

    def __init__(
        self,
        config_file_path = CONFIG_FILE_PATH,
        params_file_path = PARAMS_FILE_PATH,
        schema_file_path = SCHEMA_FILE_PATH
    ):
        self.config = read_yaml(config_file_path)
        self.params = read_yaml(params_file_path)
        self.schema = read_yaml(schema_file_path)

        # Every stage writes below the artifacts root, so make sure it exists.
        create_directories([self.config.artifacts_root])

    @staticmethod
    def _parse_validation_status(text):
        """Interpret the contents of the data-validation status file.

        Only the last whitespace-separated token is considered (presumably the
        validation stage writes something like ``"Validation status: True"`` -
        confirm against that stage). ``bool(token)`` must not be used: every
        non-empty string, including ``"False"``, is truthy.

        Returns:
            bool: True only when the last token is ``true``, ``1`` or ``yes``
            (case-insensitive); False otherwise, including for an empty file.
        """
        tokens = text.split()
        return bool(tokens) and tokens[-1].lower() in {'true', '1', 'yes'}

    def get_data_transformation_config(self) -> DataTransformationConfig:
        """Build the ``DataTransformationConfig`` for the transformation stage.

        Reads the data-validation status file, creates the stage root, train
        and validation directories, and bundles paths, split/augmentation
        sizes and the ``augmentation`` params.

        Returns:
            DataTransformationConfig: settings for ``DataTransformation``.

        Raises:
            OSError: if the validation status file cannot be opened.
        """
        config = self.config.data_transformation
        params = self.params.augmentation

        dataset_val_status_file = self.config.data_validation.STATUS_FILE

        with open(dataset_val_status_file, 'r') as f:
            status = self._parse_validation_status(f.read())

        create_directories([config.root_dir, config.train_path, config.validation_path])

        return DataTransformationConfig(
            root_dir=config.root_dir,
            data_path=config.data_path,
            train_path=config.train_path,
            validation_path=config.validation_path,
            test_path=config.test_path,
            YAML_path=config.YAML_path,
            val_size=config.val_size,
            aug_size=config.aug_size,
            aug_params=params,
            dataset_val_status=status
        )

In [5]:
from src.cellseg import logger
from src.cellseg.utils.main_utils import dir_sample_creation
import shutil
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
import albumentations as A
from tqdm import tqdm
import yaml
import random


class DataTransformation:
    """Build the YOLO segmentation dataset from the raw per-sample folders.

    ``transformation_compose`` runs the stages in order:

    1. ``balance_augment_data_lists`` - split samples into grayscale / colour
       and oversample both groups to the same length.
    2. ``chunk_transform`` - augment every listed sample and pass the result to
       ``dir_sample_creation`` together with ``config.train_path``.
    3. ``data_to_YOLO_formating`` - turn each per-sample folder in the train
       directory into ``<id>.png`` + ``<id>.txt`` (YOLO polygon labels).
    4. ``train_validation_separation`` - move a ``val_size`` share of the
       image/label pairs into ``config.validation_path``.
    5. ``dataset_yaml_creation`` - write the dataset YAML with the train/val
       paths and a single ``Cell`` class.
    """

    def __init__(self, config: DataTransformationConfig):
        self.config = config

    def transform_preparation(self, crop_dim):
        """Build the augmentation pipeline for an image whose shorter side is ``crop_dim``.

        Limits and probabilities come from ``config.aug_params``. The pipeline
        crops the box (x_min, y_min)-(crop_dim, crop_dim), which is a square when
        x_min == y_min. It then resizes, applies random brightness/contrast,
        gamma, rotation and horizontal/vertical flips, and finishes with a
        RandomResizedCrop (scale 0.5-1.0) back to the configured Resize size.

        Returns:
            A.Compose: called as ``transform(image=..., mask=...)``.
        """
        transform = A.Compose([
            A.Crop(
                x_min=self.config.aug_params.Crop.x_min,
                y_min=self.config.aug_params.Crop.y_min,
                x_max=crop_dim,
                y_max=crop_dim,
                p=self.config.aug_params.Crop.p
            ),
            A.Resize(
                height=self.config.aug_params.Resize.height,
                width=self.config.aug_params.Resize.width,
                p=self.config.aug_params.Resize.p
            ),
            A.RandomBrightnessContrast(
                brightness_limit=self.config.aug_params.RandomBrightnessContrast.brightness_limit,
                contrast_limit=self.config.aug_params.RandomBrightnessContrast.contrast_limit,
                p=self.config.aug_params.RandomBrightnessContrast.p
            ),
            A.RandomGamma(
                gamma_limit=self.config.aug_params.RandomGamma.gamma_limit,
                p=self.config.aug_params.RandomGamma.p
            ),
            A.Rotate(
                limit=self.config.aug_params.Rotate.limit,
                border_mode=self.config.aug_params.Rotate.border_mode,
                p=self.config.aug_params.Rotate.p
            ),
            A.HorizontalFlip(
                p=self.config.aug_params.HorizontalFlip.p
            ),
            A.VerticalFlip(
                p=self.config.aug_params.VerticalFlip.p
            ),
            A.RandomResizedCrop(
                scale=(0.5, 1.0),
                size=(self.config.aug_params.Resize.height, self.config.aug_params.Resize.width)
            )
        ])

        return transform

    def balance_augment_data_lists(self):
        """Split samples by colour mode and oversample both groups to a common size.

        Each sub-folder ``<id>`` of ``config.data_path`` must contain
        ``images/<id>.png``; other entries (stray files) are ignored. An image
        whose three channels are identical is treated as grayscale, otherwise
        as colour. Every non-empty group is padded with random repeats
        (``random.choices``) up to ``aug_size * max(n_grayscale, n_color)``
        names. An empty group stays empty.

        Returns:
            tuple[list[str], list[str]]: ``(grayscale_list, color_list)`` of
            ``<id>.png`` names, possibly with repeats.

        Raises:
            FileNotFoundError: if a sample image cannot be read.
        """
        color_list = []
        grayscale_list = []

        # os.listdir order is arbitrary; sorting keeps the sampling reproducible
        # whenever the `random` module is seeded.
        for folder in tqdm(sorted(os.listdir(self.config.data_path))):
            if not os.path.isdir(os.path.join(self.config.data_path, folder)):
                continue

            img_path = os.path.join(self.config.data_path, folder, 'images', folder + '.png')
            img = cv2.imread(img_path)
            # cv2.imread returns None instead of raising on unreadable files.
            if img is None:
                raise FileNotFoundError(f"Cannot read image: {img_path}")

            if np.array_equal(img[:, :, 0], img[:, :, 1]) and np.array_equal(img[:, :, 1], img[:, :, 2]):
                grayscale_list.append(folder + '.png')
            else:
                color_list.append(folder + '.png')

        target_size = self.config.aug_size * max(len(grayscale_list), len(color_list))

        # random.choices raises IndexError on an empty population, so a group
        # without samples (e.g. an all-colour dataset) is left untouched.
        for group in (grayscale_list, color_list):
            if group:
                group.extend(random.choices(group, k=target_size - len(group)))

        return grayscale_list, color_list

    def chunk_transform(self, chunk_list):
        """Augment each listed sample and write it out via ``dir_sample_creation``.

        For every ``<id>.png`` the RGB image and all its instance masks (stacked
        as channels of one mask array) are passed through the pipeline from
        ``transform_preparation``. The transform is random, so every repeat of
        a name gets its own augmentation. How repeats are named on disk is
        decided by ``dir_sample_creation``.

        Args:
            chunk_list: ``<id>.png`` names, as returned by ``balance_augment_data_lists``.

        Raises:
            FileNotFoundError: if a sample image cannot be read.
        """
        for img_name in tqdm(chunk_list):
            sample_id = os.path.splitext(img_name)[0]
            sample_dir = os.path.join(self.config.data_path, sample_id)
            img_path = os.path.join(sample_dir, 'images', img_name)

            raw_image = cv2.imread(img_path, 1)
            if raw_image is None:
                raise FileNotFoundError(f"Cannot read image: {img_path}")
            image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)

            # The crop box is bounded by the shorter side of this image.
            crop_dim = min(image.shape[0], image.shape[1])
            transform = self.transform_preparation(crop_dim)

            mask_dir = os.path.join(sample_dir, 'masks')
            masks = [
                cv2.imread(os.path.join(mask_dir, mask_name), 0)
                for mask_name in sorted(os.listdir(mask_dir))
            ]
            # np.stack cannot stack an empty list.
            if not masks:
                logger.warning(f"No masks found for sample {sample_id}, skipping it.")
                continue

            composite_mask = np.stack(masks, axis=-1)

            augmentations = transform(image=image, mask=composite_mask)

            dir_sample_creation(augmentations, sample_id, self.config.train_path)

    def data_to_YOLO_formating(self):
        """Convert per-sample folders in ``config.train_path`` into YOLO image/label pairs.

        For each folder ``<id>``, ``images/<id>.png`` is moved up into the train
        directory. Every mask in ``masks/`` becomes one label line:
        ``0 x1 y1 x2 y2 ...`` (class 0), with the outline's points normalised
        by the mask width/height. The lines are written to ``<id>.txt`` and the
        folder is then deleted.

        Masks with no foreground, or whose outline has fewer than three points
        (not a polygon), produce no line.
        """
        logger.info("YOLO formating started!")

        for sample_id in tqdm(os.listdir(self.config.train_path)):
            sample_dir = os.path.join(self.config.train_path, sample_id)
            # Only per-sample folders are converted; skip image/label files
            # that may already sit in the train directory.
            if not os.path.isdir(sample_dir):
                continue

            shutil.move(
                os.path.join(sample_dir, 'images', sample_id + '.png'),
                self.config.train_path
            )

            mask_dir = os.path.join(sample_dir, 'masks')
            label_lines = []

            for cell_mask in os.listdir(mask_dir):
                cell_mask_img = cv2.imread(os.path.join(mask_dir, cell_mask), 0)

                contours, _ = cv2.findContours(
                    cell_mask_img,
                    cv2.RETR_LIST,
                    cv2.CHAIN_APPROX_NONE
                )
                if not contours:
                    continue

                # RETR_LIST returns every contour (including hole boundaries) in
                # no particular order; the largest one is the cell outline.
                outline = max(contours, key=cv2.contourArea)
                if len(outline) < 3:
                    continue

                height, width = cell_mask_img.shape[:2]
                coords = ' '.join(
                    f'{point[0][0] / width} {point[0][1] / height}' for point in outline
                )
                label_lines.append('0 ' + coords)

            with open(os.path.join(self.config.train_path, sample_id + '.txt'), 'w') as file:
                file.write(''.join(line + '\n' for line in label_lines))

            shutil.rmtree(sample_dir)

        logger.info("YOLO formating finished!")

    def train_validation_separation(self):
        """Move a ``config.val_size`` share of image/label pairs from train to validation.

        The split uses ``train_test_split`` with ``random_state=42`` on the
        sorted image names, so it is reproducible for the same set of files.
        Each moved ``<id>.png`` takes its ``<id>.txt`` label with it.
        """
        logger.info("Train/validation split started!")

        img_list = sorted(
            name for name in os.listdir(self.config.train_path) if name.endswith('.png')
        )
        # train_test_split raises on an empty input.
        if not img_list:
            logger.warning(f"No images in {self.config.train_path}, split skipped.")
            return

        _, val_list = train_test_split(
            img_list,
            test_size=self.config.val_size,
            random_state=42,
            shuffle=True
        )

        for img in val_list:
            img_path = os.path.join(self.config.train_path, img)
            ann_path = os.path.join(self.config.train_path, os.path.splitext(img)[0] + '.txt')

            shutil.move(img_path, self.config.validation_path)
            shutil.move(ann_path, self.config.validation_path)

        logger.info("Train/validation split finished!")

    def dataset_yaml_creation(self):
        """Write the dataset YAML to ``config.YAML_path``.

        Train/val entries are the train/validation directories joined onto the
        current working directory. ``test`` is left empty, and there is one
        class named ``Cell`` (``nc: 1``).
        """
        yaml_content = {
            'train': os.path.join(os.getcwd(), self.config.train_path),
            'val': os.path.join(os.getcwd(), self.config.validation_path),
            'test': '',
            'nc': 1,
            'names': ['Cell']
        }

        yaml_file = yaml.safe_dump(yaml_content, default_flow_style=None, sort_keys=False)

        with open(self.config.YAML_path, 'w') as file:
            file.write(yaml_file)
        logger.info("File dataset.yaml created!")

    def transformation_compose(self):
        """Run the whole transformation pipeline if the dataset passed validation.

        - Train and validation dirs both empty: augment, YOLO-format, split and
          write the YAML.
        - Otherwise, the transformation is considered done; only the YAML is
          (re)created if it is missing.

        NOTE: a run interrupted after augmentation started leaves the train
        directory non-empty and is then reported as already performed; clear
        the train/validation directories to redo it.
        """
        if self.config.dataset_val_status:
            if not os.listdir(self.config.train_path) and not os.listdir(self.config.validation_path):
                logger.info("Data augmentation started!")
                grayscale_list, color_list = self.balance_augment_data_lists()
                self.chunk_transform(color_list)
                self.chunk_transform(grayscale_list)
                logger.info("Data augmentation finished!")
                self.data_to_YOLO_formating()
                self.train_validation_separation()
                self.dataset_yaml_creation()
            elif not os.path.exists(self.config.YAML_path):
                logger.info("Transformation already performed!")
                self.dataset_yaml_creation()
            else:
                logger.info("Transformation already performed!")
        else:
            logger.info("Transformation stoped, dataset isn't valid!")

In [6]:
# Build the stage config from the YAML files and run the transformation
# pipeline. Exceptions propagate with their original traceback.
config = ConfigurationManager()
data_transformation_config = config.get_data_transformation_config()
data_transformation = DataTransformation(config=data_transformation_config)
data_transformation.transformation_compose()

[2024-11-21 13:53:00,088: INFO: main_utils: created directory at: artifacts]
[2024-11-21 13:53:00,090: INFO: main_utils: created directory at: artifacts/data_transformation]
[2024-11-21 13:53:00,090: INFO: main_utils: created directory at: artifacts/data_transformation/train]
[2024-11-21 13:53:00,091: INFO: main_utils: created directory at: artifacts/data_transformation/validation]
[2024-11-21 13:53:00,093: INFO: 3174540254: Data augmentation started!]


100%|██████████| 670/670 [00:02<00:00, 251.83it/s]
100%|██████████| 1686/1686 [02:00<00:00, 13.94it/s]
100%|██████████| 1686/1686 [05:47<00:00,  4.85it/s]

[2024-11-21 14:00:51,640: INFO: 3174540254: Data augmentation finished!]
[2024-11-21 14:00:51,640: INFO: 3174540254: YOLO formating started!]



100%|██████████| 3372/3372 [02:09<00:00, 26.11it/s]

[2024-11-21 14:03:00,787: INFO: 3174540254: YOLO formating finished!]
[2024-11-21 14:03:00,790: INFO: 3174540254: Train/validation split started!]
[2024-11-21 14:03:00,865: INFO: 3174540254: Train/validation split finished!]
[2024-11-21 14:03:00,868: INFO: 3174540254: File dataset.yaml created!]



