In [323]:
import json
import os
import requests
import io
import math
import numpy as np
import glob
import shutil
from PIL import Image, ImageOps, ImageEnhance
from pprint import pprint
from collections import Counter
from datetime import datetime
from sklearn.model_selection import train_test_split

In [324]:
# Base URLs of the backend REST API and of the model container.
API_BASE_URL = 'http://fireeye-test-backend-container:9090/api/'
TF_SERVING_BASE_URL = 'http://fireeye-test-model-container:8501/'

# Task whose labelled image records this notebook fetches.
task_id = '1ac1e8a095df4611af387d9934799251'

# Truth-label ID -> class code. The code becomes the suffix of the
# `Category<code>` folder an image is written to.
id_code_mapping = {
    'dbee3deebc5444f5b011da4e5518752c': '0',
    'edb4cb51d54644c08aa122d3f041bb0a': '1',
}

In [325]:
def get_image_by_id(image_id, timeout=30):
    """Retrieve image by its ID.

    Args:
        image_id: Identifier appended to ``API_BASE_URL + 'image/'``.
        timeout: Seconds ``requests.get`` waits for the server. Without a
            timeout a stalled backend would block the notebook forever.

    Returns:
        The response body opened with ``Image.open``.

    Raises:
        RuntimeError: The server answered with a status other than 200; the
            message is the response text.
        requests.exceptions.RequestException: Connection failure or timeout
            (propagated unchanged from ``requests``).
    """
    r = requests.get(url=API_BASE_URL + 'image/' + image_id, timeout=timeout)
    if r.status_code != 200:
        raise RuntimeError(r.text)
    return Image.open(io.BytesIO(r.content))

In [326]:
import pprint
def get_image_records(task_id, timeout=30):
    """Fetch image records given a task ID.

    Sends ``GET API_BASE_URL + 'image'`` with the query parameters
    ``task_id`` and ``has_truth=True``.

    Args:
        task_id: Task identifier forwarded as the ``task_id`` query parameter.
        timeout: Seconds ``requests.get`` waits for the server. Without a
            timeout a stalled backend would block the notebook forever.

    Returns:
        The decoded JSON body of the response.

    Raises:
        RuntimeError: The server answered with a status other than 200; the
            message is the response text.
        requests.exceptions.RequestException: Connection failure or timeout
            (propagated unchanged from ``requests``).
    """
    resp = requests.get(
        url=API_BASE_URL + 'image',
        params={'task_id': task_id, 'has_truth': True},
        timeout=timeout,
    )
    if resp.status_code != 200:
        raise RuntimeError(resp.text)
    return resp.json()
# Load every record the API returns for the configured task and print how many
# there are (the message reads "Number of images in this category: N").
image_records = get_image_records(task_id)
print(f'该类别下图片数量是：{len(image_records)}')

该类别下图片数量是：320


In [327]:
def crop_by_percentile(img, lower_percentile=5, upper_percentile=95):
    """Crop ``img`` to the bounding box of its mid-intensity pixels.

    The image is converted to grayscale (mode ``'L'``) and the two intensity
    percentiles are computed. Pixels whose value lies strictly between them
    form a mask, and the image is cropped to the smallest rectangle that
    contains every masked pixel.

    Args:
        img: Image exposing ``convert``, ``crop`` and ``copy``.
        lower_percentile: Lower intensity percentile (0-100).
        upper_percentile: Upper intensity percentile (0-100).

    Returns:
        The cropped image. If no pixel lies strictly between the two
        percentile values (e.g. a flat image), an uncropped copy of ``img``
        is returned instead of failing.
    """
    gray = np.array(img.convert('L'))

    low_val, high_val = np.percentile(gray, [lower_percentile, upper_percentile])
    mask = np.logical_and(gray > low_val, gray < high_val)

    row_idx = np.flatnonzero(mask.any(axis=1))
    col_idx = np.flatnonzero(mask.any(axis=0))

    # An empty mask has no bounding box; indexing [0] / [-1] would raise.
    if row_idx.size == 0 or col_idx.size == 0:
        return img.copy()

    rmin, rmax = int(row_idx[0]), int(row_idx[-1])
    cmin, cmax = int(col_idx[0]), int(col_idx[-1])

    # The crop box is (left, upper, right, lower) with right/lower exclusive,
    # so add 1 to keep the last matching row and column inside the result.
    return img.crop((cmin, rmin, cmax + 1, rmax + 1))

In [328]:
def normalize_image(img: Image.Image) -> np.ndarray:
    """Return the pixel data of ``img`` as an array divided by 255.0."""
    pixels = np.array(img)
    return np.true_divide(pixels, 255.0)

In [329]:
image_dir = "./images"
Category0_dir = os.path.join(image_dir, 'Category0')
Category1_dir = os.path.join(image_dir, 'Category1')

# First remove whatever a previous run left in the two class folders ...
for stale_dir in (Category0_dir, Category1_dir):
    if os.path.exists(stale_dir):
        shutil.rmtree(stale_dir)

# ... then recreate both of them empty.
for fresh_dir in (Category0_dir, Category1_dir):
    os.makedirs(fresh_dir)

In [330]:
def clear_and_create_directory(directory):
    """Leave ``directory`` existing and empty.

    An existing path is removed with ``shutil.rmtree`` first; the directory
    (including missing parents) is then created with ``os.makedirs``.
    """
    already_present = os.path.exists(directory)
    if already_present:
        shutil.rmtree(directory)
    os.makedirs(directory)

base_dir = './images'

# One folder per (split, class) pair under base_dir. Each is wiped and
# recreated, so files from an earlier run are discarded.
for set_name in ('train', 'test', 'val'):
    for category in ('Category0', 'Category1'):
        directory = os.path.join(base_dir, set_name, category)
        clear_and_create_directory(directory)

In [331]:
# Class code of every record, used to stratify both splits below.
labels = [id_code_mapping[record['truth_id']] for record in image_records]

# 30 % of the records are held out as the test set ...
train_records, test_records, train_labels, test_labels = train_test_split(
    image_records, labels, test_size=0.3, stratify=labels, random_state=42)

# ... and 10 % of the remainder becomes the validation set.
train_records, val_records, train_labels, val_labels = train_test_split(
    train_records, train_labels, test_size=0.1, stratify=train_labels, random_state=42)

# Every record must land in exactly one split.
assert len(train_records) + len(val_records) + len(test_records) == len(image_records)

# IDs of images that could not be fetched, cropped or written.
failed_image_ids = []

for set_name, records in [('train', train_records), ('test', test_records), ('val', val_records)]:
    for record in records:
        try:
            img = get_image_by_id(record['id'])
            # The cropped image is saved as-is. Dividing by 255 and scaling
            # straight back to uint8 before saving cannot change what ends up
            # in the PNG other than by losing information in the truncating
            # cast; any scaling to [0, 1] has to happen when the files are
            # loaded for training.
            cropped_img = crop_by_percentile(img)

            truth_id = record['truth_id']
            category = id_code_mapping[truth_id]

            directory = os.path.join(base_dir, set_name, f'Category{category}')
            file_path = os.path.join(directory, f'{record["id"]}.png')
            cropped_img.save(file_path, 'PNG')
        except Exception as e:
            # Best effort: keep going, but remember what is missing.
            failed_image_ids.append(record['id'])
            print(f'Error processing image {record["id"]}. Error: {e}')

if failed_image_ids:
    print(f'{len(failed_image_ids)} image(s) could not be processed and are missing from the dataset.')

In [332]:
def download_image(image_id, timeout=30):
    """Download the raw bytes served at ``image/download/<image_id>``.

    Args:
        image_id: Identifier inserted into the download URL.
        timeout: Seconds ``requests.get`` waits for the server. Without a
            timeout a stalled backend would block the notebook forever.

    Returns:
        The response body as bytes.

    Raises:
        RuntimeError: The server answered with a status other than 200; the
            message is the response text. This matches ``get_image_by_id``
            and prevents an error page from being returned as image bytes.
        requests.exceptions.RequestException: Connection failure or timeout
            (propagated unchanged from ``requests``).
    """
    response = requests.get(f"{API_BASE_URL}image/download/{image_id}", timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(response.text)
    return response.content

In [341]:
def color_jitter(img: Image.Image, brightness=0.2, contrast=0.2, saturation=0.2, rng=None) -> Image.Image:
    """Randomly perturb brightness, contrast and colour saturation.

    Each enhancement factor is drawn as ``1 + strength * u`` with ``u``
    uniform in ``[-1, 1)``, so the defaults give factors in ``[0.8, 1.2)``.

    Args:
        img: Image to enhance.
        brightness: Maximum relative change of the brightness factor.
        contrast: Maximum relative change of the contrast factor.
        saturation: Maximum relative change of the colour factor.
        rng: Optional random source exposing ``random()`` (for example
            ``np.random.default_rng(seed)``). When omitted the global
            ``np.random`` state is used, exactly as before; pass a seeded
            generator to make the augmentation reproducible.

    Returns:
        The image after the three enhancements, applied in the order
        brightness, contrast, colour.
    """
    source = np.random if rng is None else rng

    def _factor(strength):
        # One uniform draw per enhancement, mapped from [0, 1) to [-1, 1).
        return 1 + strength * (2 * source.random() - 1)

    img = ImageEnhance.Brightness(img).enhance(_factor(brightness))
    img = ImageEnhance.Contrast(img).enhance(_factor(contrast))
    img = ImageEnhance.Color(img).enhance(_factor(saturation))
    return img

In [342]:
def vertical_flip(img: Image.Image) -> Image.Image:
    """Return the result of ``ImageOps.flip`` (top-to-bottom flip) on ``img``."""
    flipped = ImageOps.flip(img)
    return flipped

In [343]:
def horizontal_flip(img: Image.Image) -> Image.Image:
    """Return the result of ``ImageOps.mirror`` (left-to-right flip) on ``img``."""
    mirrored = ImageOps.mirror(img)
    return mirrored

In [345]:
train_directory = './images/train/'


def preprocess_and_save(img, image_id, category):
    """Write three augmented variants of ``img`` into the training folder.

    The variants are produced, in this order, by ``color_jitter``,
    ``vertical_flip`` and ``horizontal_flip``. Each is saved as PNG under
    ``train_directory/<category>/`` with the image ID plus a suffix naming
    the augmentation.
    """
    # (augmentation, output file name) pairs, processed top to bottom.
    outputs = (
        (color_jitter, f'{image_id}_colorjittered.png'),
        (vertical_flip, f'{image_id}_vflipped.png'),
        (horizontal_flip, f'{image_id}_hflipped.png'),
    )
    for augment, filename in outputs:
        destination = os.path.join(train_directory, category, filename)
        augment(img).save(destination, 'PNG')

# Create the augmented copies for every training record.
for record in train_records:
    image_id = record['id']
    try:
        # Apply the same percentile crop that was used for the saved train,
        # validation and test images, so the augmented copies show the same
        # region as their originals instead of the raw, uncropped download.
        img = crop_by_percentile(get_image_by_id(image_id))
        truth_id = record['truth_id']
        category = f'Category{id_code_mapping[truth_id]}'
        preprocess_and_save(img, image_id, category)
    except Exception as e:
        # Best effort: report the failure and continue with the next record.
        print(f'Error augmenting image {image_id}. Error: {e}')

print('Data augmentation for teh training set is complete.')

Data augmentation for teh training set is complete.
