In [135]:
import cv2
import numpy as np
import albumentations as A
from albumentations import Compose, HorizontalFlip, RandomRotate90, ColorJitter, RandomCrop, GaussianBlur, Normalize
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os.path
from PIL import Image, ImageDraw
from sahi.utils.file import load_json, save_json
from tqdm import tqdm



In [138]:
# --- Filesystem layout -------------------------------------------------------
# Every path is relative, so it resolves against the kernel's working directory.
DATA_DIR = os.path.join("..", "data")

# Folders under DATA_DIR: the image folder and the bbox-visualization folder.
IMAGE_DIR, BBOX_VISUALIZATION_DIR = (
    os.path.join(DATA_DIR, folder_name)
    for folder_name in ("cassette1_bboxes_vis", "bbox_visualization")
)

# A single image inside IMAGE_DIR.
IMAGE_PATH = os.path.join(DATA_DIR, "cassette1_bboxes_vis", "01BE01.png")

# COCO annotation files: the sliced input JSON and the JSON written by this notebook.
AUGMENTATION_PATH = os.path.join(DATA_DIR, "coco_json_files/cassette1_train_sliced_coco.json")
NEW_AUGMENTATION_PATH = os.path.join(DATA_DIR, "augmentation/cassette1_train_augmented.json")

# Create the visualization folder up front; a no-op when it already exists.
os.makedirs(BBOX_VISUALIZATION_DIR, exist_ok=True)

In [139]:
# Read the sliced COCO annotation file into a plain dict.
coco_dict = load_json(AUGMENTATION_PATH)

# Keep only the part after the last "/" so every image entry holds a bare file name.
for image_entry in coco_dict["images"]:
    image_entry["file_name"] = image_entry["file_name"].rsplit("/", 1)[-1]

# Persist the normalised annotations under the new path.
save_json(coco_dict, save_path=NEW_AUGMENTATION_PATH)

coco_dict

{'images': [{'width': 4096,
   'height': 2000,
   'id': 1,
   'file_name': '01BE02.png'},
  {'width': 4096, 'height': 2000, 'id': 2, 'file_name': '01BE03.png'},
  {'width': 4096, 'height': 2000, 'id': 5, 'file_name': '01BN03.png'},
  {'width': 4096, 'height': 2000, 'id': 6, 'file_name': '01BS00.png'},
  {'width': 4096, 'height': 2000, 'id': 9, 'file_name': '01BS03.png'},
  {'width': 4096, 'height': 2000, 'id': 10, 'file_name': '01BW00.png'},
  {'width': 4096, 'height': 2000, 'id': 11, 'file_name': '01BW01.png'},
  {'width': 4096, 'height': 2000, 'id': 12, 'file_name': '01BW02.png'},
  {'width': 4096, 'height': 2000, 'id': 13, 'file_name': '01FE00.png'},
  {'width': 4096, 'height': 2000, 'id': 14, 'file_name': '01FE01.png'},
  {'width': 4096, 'height': 2000, 'id': 15, 'file_name': '01FE02.png'},
  {'width': 4096, 'height': 2000, 'id': 16, 'file_name': '01FN00.png'},
  {'width': 4096, 'height': 2000, 'id': 17, 'file_name': '01FN01.png'},
  {'width': 4096, 'height': 2000, 'id': 18, 'file_

## Full image augmentation

In [143]:
class DefectAugmentationPipeline:
    """Scale-aware albumentations pipeline for images with COCO-format boxes.

    Two augmentation recipes ("small" and "large") are built once at
    construction time. For each image, one of them is selected from the
    defect types present in that image (see ``get_scale_type``).
    """

    def __init__(self, scale_classification):
        """Store the scale mapping and build both augmentation recipes.

        Args:
            scale_classification: Mapping of defect type -> scale label. Only
                the exact label ``"Small Scale"`` is treated as small; any
                other value counts as large. Must support ``.get``.
        """
        self.scale_classification = scale_classification
        self.augmentations_small, self.augmentations_large = self.define_augmentations()

    def define_augmentations(self):
        """Build the two albumentations pipelines.

        Both pipelines take COCO-format boxes (``[x, y, width, height]``) with
        their labels passed in the ``class_labels`` field, and both end with
        ``ToTensorV2`` -- so the augmented image comes back as a PyTorch
        tensor rather than a NumPy array.

        Returns:
            Tuple ``(augmentations_small, augmentations_large)``.
        """
        # Small-scale recipe: wider rotation range, milder blur, plus additive noise.
        augmentations_small = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=90, p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.GaussianBlur(blur_limit=(3, 5), p=0.3),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(format='coco', label_fields=['class_labels']))

        # Large-scale recipe: narrower rotation range, stronger blur, no noise.
        augmentations_large = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=45, p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.GaussianBlur(blur_limit=(5, 7), p=0.3),
            ToTensorV2(),
        ], bbox_params=A.BboxParams(format='coco', label_fields=['class_labels']))

        return augmentations_small, augmentations_large

    def get_scale_type(self, defect_types):
        """Pick which recipe applies to an image.

        The rule is "any", not "majority": the result is ``"small"`` as soon
        as one defect type maps to ``"Small Scale"``. Defect types missing
        from the mapping default to ``"Small Scale"``. An empty
        ``defect_types`` therefore yields ``"large"``.

        Args:
            defect_types: Iterable of defect type keys present in the image.

        Returns:
            ``"small"`` or ``"large"``.
        """
        return "small" if any(
            self.scale_classification.get(dt, "Small Scale") == "Small Scale" for dt in defect_types
        ) else "large"

    def augment_image(self, image, bboxes, class_labels, defect_types, return_labels=False):
        """Augment one image together with its bounding boxes.

        Args:
            image: Image passed to albumentations as ``image=``.
            bboxes: Boxes in COCO format, passed as ``bboxes=``.
            class_labels: One label per box, passed as ``class_labels=``.
            defect_types: Defect types used to select the recipe.
            return_labels: When True, also return the labels that albumentations
                hands back. Geometric transforms can drop boxes that leave the
                frame, in which case the input ``class_labels`` no longer line
                up with the returned boxes; the returned labels do. Defaults
                to False so existing two-value unpacking keeps working.

        Returns:
            ``(image, bboxes)`` by default, or ``(image, bboxes, class_labels)``
            when ``return_labels`` is True.
        """
        # Choose the recipe from the scale of the defects in this image.
        predominant_scale = self.get_scale_type(defect_types)
        augmentation = self.augmentations_small if predominant_scale == "small" else self.augmentations_large

        # Apply the augmentation; boxes and labels are transformed/filtered together.
        augmented = augmentation(image=image, bboxes=bboxes, class_labels=class_labels)
        if return_labels:
            return augmented['image'], augmented['bboxes'], augmented['class_labels']
        return augmented['image'], augmented['bboxes']

    def prepare_image(self, image_path):
        """Load an image as grayscale and expand it to three channels.

        Args:
            image_path: Path to the image file.

        Returns:
            The image after ``cv2.COLOR_GRAY2RGB`` conversion.

        Raises:
            ValueError: If the file does not exist, or if OpenCV cannot decode
                it. The message names the offending path in both cases.
        """
        # cv2.imread reports every failure as None, so check for a missing file
        # first; the absolute path makes a wrong working directory obvious.
        if not os.path.isfile(image_path):
            raise ValueError(
                f"Failed to load image: no file at {image_path!r} "
                f"(resolved to {os.path.abspath(image_path)!r}). "
                "Check the file path and format."
            )

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(
                f"Failed to load image {image_path!r}: the file exists but could not be decoded. "
                "Check the file path and format."
            )
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)  # Convert to 3-channel

ValueError: Failed to load image. Check the file path and format.