# Экспериментальный метод аугментации данных при помощи модели Stable Diffusion


Предварительно в label-studio разметили области, в которых будем генерировать дефекты

In [1]:
import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipelineLegacy
from diffusers.pipelines.stable_diffusion import safety_checker
import numpy as np
from PIL import Image, ImageDraw
from pycocotools.coco import COCO
import matplotlib.pyplot as plt


# Maps a COCO category id to the defect name inserted into the prompt.
LABELID2TEXT = {
    0: 'crack',
    1: 'fistula',
    2: 'rupture',
}
# Prompt template; the '<OBJ>' placeholder is replaced with the defect name.
PROMT_TEMAPLTE = '<OBJ> on a metal pipe'


def sc(self, clip_input, images):
    """Pass-through stand-in for a safety checker's ``forward``.

    Returns the images untouched together with one ``False`` flag per image;
    ``clip_input`` is ignored.
    """
    flags = [False for _ in images]
    return images, flags
# Monkey-patch StableDiffusionSafetyChecker.forward so that, when called, it just
# returns the images unchanged together with a list of False flags (one per image).
safety_checker.StableDiffusionSafetyChecker.forward = sc


def get_512x512_bbox_from_orig_bbox(defect_bbox, orig_img_w, orig_img_h):
    """Return a 512x512 crop window ``[x, y, 512, 512]`` around a defect bbox.

    ``defect_bbox`` is ``(x, y, w, h)``. Along each axis where the bbox is
    narrower than 512 the window is centred on the bbox; otherwise it starts
    at the bbox origin. The window is then shifted so it does not stick out of
    an image of size ``orig_img_w`` x ``orig_img_h`` (the left/top edge wins
    if both sides overflow).
    """
    side = 512
    left, top, box_w, box_h = defect_bbox

    # Centre the window on the bbox along every axis where the bbox is small.
    if int(box_w) < side:
        left -= (side - box_w) / 2
    if int(box_h) < side:
        top -= (side - box_h) / 2

    def _shift_into_image(start, limit):
        # Offset that moves the window [start, start + side] back into [0, limit].
        if start < 0:
            return -start
        end = start + side
        if end > limit:
            return limit - end
        return 0

    left += _shift_into_image(left, orig_img_w)
    top += _shift_into_image(top, orig_img_h)

    return [left, top, side, side]


def get_mask_of_defect(orig_img, defect_bbox):
    """Build a full-size mask that marks the defect rectangle.

    Args:
        orig_img: PIL image; only its ``size`` is used.
        defect_bbox: ``(x, y, w, h)`` of the area to be inpainted.

    Returns:
        A PIL image created from a boolean array of the same size as
        ``orig_img``: True inside the rectangle, False elsewhere.
    """
    left_up_point = (defect_bbox[0], defect_bbox[1])
    right_down_point = (defect_bbox[0] + defect_bbox[2], defect_bbox[1] + defect_bbox[3])

    # Draw on a blank canvas instead of the alpha channel of the source image:
    # pixels that are already transparent in the source must not leak into the
    # mask, and there is no need to convert the whole image.
    canvas = Image.new('L', orig_img.size, 255)
    ImageDraw.Draw(canvas).rectangle([left_up_point, right_down_point], fill=0)

    # Pixels painted with 0 are the defect area.
    mask = np.array(canvas) == 0
    return Image.fromarray(mask)


def get_img_and_mask_for_sd_inference(orig_img, defect_bbox):
    """Cut the 512x512 image patch and matching mask patch around a defect.

    Returns a tuple ``(image_patch, mask_patch, bbox_512x512)`` where
    ``bbox_512x512`` is the crop window in the coordinates of ``orig_img``.
    """
    img_w, img_h = orig_img.size
    bbox_512x512 = get_512x512_bbox_from_orig_bbox(defect_bbox, img_w, img_h)

    # Both patches are cut with the same window so they stay aligned.
    left, top = bbox_512x512[0], bbox_512x512[1]
    crop_box = [left, top, left + 512, top + 512]

    full_size_mask = get_mask_of_defect(orig_img, defect_bbox)
    cropped_img512x512 = orig_img.crop(crop_box)
    cropped_mask512x512 = full_size_mask.crop(crop_box)

    return cropped_img512x512, cropped_mask512x512, bbox_512x512


def generate_prompt_by_cat_id(category_id):
    """Return the prompt for a category: the template with '<OBJ>' replaced by its label.

    Raises ``KeyError`` if ``category_id`` is not in ``LABELID2TEXT``.
    """
    defect_name = LABELID2TEXT[category_id]
    return PROMT_TEMAPLTE.replace('<OBJ>', defect_name)


def pasting_img_to_img_by_bbox(orig_img, generated_img, bbox):
    """Placeholder: not implemented, does nothing and returns ``None``."""
    return None


def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB sheet, filled row by row.

    The cell size is taken from the first image.

    Args:
        imgs: sequence of PIL images; its length must equal ``rows * cols``.
        rows: number of rows in the grid.
        cols: number of columns in the grid.

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``.

    Raises:
        AssertionError: if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows*cols, f'expected {rows*cols} images, got {len(imgs)}'

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col*w, row*h))
    return grid

In [2]:
import os

# COCO annotation file and the folder with the source images.
ann_file = '/app/img/origs/result.json'
img_path = '/app/img/origs'
# Folder that receives the augmented images and their annotation file.
path_to_augmentated_img = '/app/dataset/augmentated'
# Local checkpoint passed to `from_pretrained`.
path_to_stable_diffusion_weights = "/app/checkpoints/sd-weights-one-prompt-15000iters"
# SECURITY: access tokens must not be hard-coded in a notebook. The token is
# read from the HF_TOKEN environment variable (None when unset); the token
# that was previously committed here should be revoked.
hf_token = os.environ.get("HF_TOKEN")

# How many images to generate for every annotated defect area.
number_of_genereation_for_one_defects = 2

In [3]:
# Parse the annotation file; `coco_dict` is the raw dataset dictionary kept by COCO.
coco = COCO(ann_file)
coco_dict = coco.dataset

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [4]:
# Fetch the records of all categories and display them.
all_category_ids = coco.getCatIds()
cats = coco.loadCats(all_category_ids)
cats

[{'id': 0, 'name': 'crack'},
 {'id': 1, 'name': 'fistula'},
 {'id': 2, 'name': 'rupture'}]

In [5]:
# Skeleton of the annotation file for the augmented dataset: categories and
# info are reused from the source dataset, images/annotations are filled later.
aug_ann = {
    'categories': cats,
    'info': coco_dict['info'],
    'images': [],
    'annotations': [],
}

In [6]:
# Load the inpainting pipeline in half precision and move it to the GPU.
device = "cuda"
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
    path_to_stable_diffusion_weights,
    torch_dtype=torch.float16,
    use_auth_token=hf_token,
)
pipe = pipe.to(device)

In [7]:
%%time
# Running counters: every augmented image and every annotation gets a unique id.
ann_id = 0
id_for_augmentated_img = 0

for img_from_ann in coco_dict['images']:
    # Open the source image in a context manager so its file handle is released
    # once all annotations of this image have been processed.
    with Image.open(f'{img_path}/{img_from_ann["file_name"]}') as img_orig:
        anns_ids = coco.getAnnIds(imgIds=img_from_ann["id"])
        anns = coco.loadAnns(anns_ids)
        for ann in anns:
            bbox = ann['bbox'].copy()
            category_id = ann['category_id']
            # Build the 512x512 image patch and mask for Stable Diffusion inference
            img_for_sd_inference, mask_for_sd_inference, bbox_512x512 = get_img_and_mask_for_sd_inference(img_orig, bbox)
            # Build the prompt for Stable Diffusion inference
            prompt = generate_prompt_by_cat_id(category_id)
            # The paste position is the same for every generation of this defect,
            # so it is rounded to integer pixels once, outside the inner loop.
            paste_position = (round(bbox_512x512[0]), round(bbox_512x512[1]))
            # Generate the defect on the pipe
            generated_images_with_defects = [
                pipe(prompt=prompt, init_image=img_for_sd_inference, mask_image=mask_for_sd_inference).images[0]
                for _ in range(number_of_genereation_for_one_defects)
            ]
            # Paste every generated patch onto its own copy of the source image
            for generation_id, generated_image_with_defect in enumerate(generated_images_with_defects):
                orig_img_with_defect = img_orig.copy()
                orig_img_with_defect.paste(generated_image_with_defect, paste_position)
                size_of_orig_img_with_defect = orig_img_with_defect.size
                new_img_name = f'aug_img_id{id_for_augmentated_img}_cat{category_id}_iter{generation_id}.jpg'
                orig_img_with_defect.save(f'{path_to_augmentated_img}/{new_img_name}')

                aug_ann['images'].append(
                    {
                        "width": size_of_orig_img_with_defect[0],
                        "height": size_of_orig_img_with_defect[1],
                        "id": id_for_augmentated_img,
                        "file_name": new_img_name
                    }
                )

                # The augmented image keeps the bbox and area of the source annotation.
                aug_ann['annotations'].append(
                    {
                        "id": ann_id,
                        "image_id": id_for_augmentated_img,
                        "category_id": int(category_id),
                        "segmentation": [],
                        "bbox": ann['bbox'],
                        "ignore": 0,
                        "iscrowd": 0,
                        "area": ann['area']
                    }
                )

                ann_id += 1
                id_for_augmentated_img += 1

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

CPU times: user 48.4 s, sys: 387 ms, total: 48.7 s
Wall time: 43.3 s


Сохраняем аннотации в формате coco:

In [8]:
import json

# Write the collected annotations next to the augmented images.
aug_ann_path = f"{path_to_augmentated_img}/aug_result.json"
with open(aug_ann_path, 'w') as out_file:
    json.dump(aug_ann, out_file)

### COCO to label-studio format

In [9]:
!git clone https://github.com/heartexlabs/label-studio-converter.git

Cloning into 'label-studio-converter'...
remote: Enumerating objects: 1645, done.[K
remote: Counting objects: 100% (657/657), done.[K
remote: Compressing objects: 100% (232/232), done.[K
remote: Total 1645 (delta 538), reused 434 (delta 425), pack-reused 988[K
Receiving objects: 100% (1645/1645), 2.67 MiB | 1.80 MiB/s, done.
Resolving deltas: 100% (964/964), done.


In [10]:
pip install -e ./label-studio-converter

Obtaining file:///app/notebooks/label-studio-converter
  Preparing metadata (setup.py) ... [?25ldone
Collecting Pillow==9.3.0
  Downloading Pillow-9.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting nltk==3.6.7
  Downloading nltk-3.6.7-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting label-studio-tools==0.0.1
  Downloading label_studio_tools-0.0.1-py3-none-any.whl (10 kB)
Collecting lxml>=4.2.5
  Downloading lxml-4.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (6.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting appdirs>=1.4.3
  Downloadin

In [11]:
!label-studio-converter import coco -i /app/img/test_stable_aug_ann_with_three_imgs.json -o img/stable_aug_for_label_studio.json

/bin/sh: 1: label-studio-converter: not found
