In [31]:
import json
import os
import requests
import io
import pathlib
import math
import numpy as np
import glob
import shutil
from PIL import Image, ImageOps, ImageEnhance
from pprint import pprint
from collections import Counter
from datetime import datetime
from sklearn.model_selection import train_test_split

In [2]:
API_BASE_URL = 'http://fireeye-test-backend-container:9090/api/'
TF_SERVING_BASE_URL = 'http://fireeye-test-model-container:8501/'
task_id = '1ac1e8a095df4611af387d9934799251'
id_code_mapping = {
    'dbee3deebc5444f5b011da4e5518752c': '0',
    'edb4cb51d54644c08aa122d3f041bb0a': '1'}

In [3]:
def get_image_by_id(image_id):
    """Retrieve image by its ID."""
    r = requests.get(url=API_BASE_URL + 'image/' + image_id)
    if r.status_code == 200:
        return Image.open(io.BytesIO(r.content))
    else:
        raise RuntimeError(r.text)

In [4]:
import pprint
def get_image_records(task_id):
    """Fetch image records given a task ID."""
    resp = requests.get(
        url=API_BASE_URL + 'image',
        params={'task_id': task_id, 'has_truth': True}
    )
    if resp.status_code == 200:
        return resp.json()
    else:
        raise RuntimeError(resp.text)
image_records = get_image_records(task_id)
print(f'该类别下图片数量是：{len(image_records)}')

该类别下图片数量是：320


In [5]:
def crop_by_percentile(img, lower_percentile=5, upper_percentile=95):
    img_array = np.array(img.convert('L'))
    
    low_val, high_val = np.percentile(img_array, [lower_percentile,upper_percentile])
                         
    mask = np.logical_and(img_array > low_val, img_array < high_val)
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
                         
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    cropped_img = img.crop((cmin, rmin, cmax, rmax))
    return cropped_img

In [6]:
def normalize_image(img: Image.Image) -> np.ndarray:
    img_array = np.array(img)
    return img_array / 255.0

In [7]:
image_dir = "./images"
Category0_dir = os.path.join(image_dir, 'Category0')
Category1_dir = os.path.join(image_dir, 'Category1')
if os.path.exists(Category0_dir):
    shutil.rmtree(Category0_dir)
if os.path.exists(Category1_dir):
    shutil.rmtree(Category1_dir)


os.makedirs(Category0_dir)
os.makedirs(Category1_dir)

In [8]:
def clear_and_create_directory(directory):
    if os.path.exists(directory):
        shutil.rmtree(directory)
    os.makedirs(directory)

base_dir = './images'    

for set_name in ['train', 'test', 'val']:
    for category in ['Category0', 'Category1']:
        directory = os.path.join(base_dir, set_name, category)
        clear_and_create_directory(directory)

In [9]:
labels = [id_code_mapping[record['truth_id']] for record in image_records]


train_records, test_records, train_labels, test_labels = train_test_split(
    image_records, labels, test_size=0.3, stratify=labels, random_state=42)

train_records, val_records, train_labels, val_labels = train_test_split(
    train_records, train_labels, test_size=0.1, stratify=train_labels, random_state=42)


for set_name, records in [('train', train_records), ('test', test_records), ('val', val_records)]:
    for record in records:
        try:
            img = get_image_by_id(record['id'])
            cropped_img = crop_by_percentile(img)
            normalized_img_array = np.array(cropped_img) / 255.0
            normalized_img = Image.fromarray((normalized_img_array * 255).astype(np.uint8))

            truth_id = record['truth_id']
            category = id_code_mapping[truth_id]
            
            directory = os.path.join(base_dir, set_name, f'Category{category}')
            file_path = os.path.join(directory, f'{record["id"]}.png')
            normalized_img.save(file_path, 'PNG')
        except Exception as e:
            print(f'Error processing image {record["id"]}. Error: {e}')

In [10]:
def download_image(image_id):
    response = requests.get(f"{API_BASE_URL}image/download/{image_id}")
    return response.content

In [11]:
def color_jitter(img: Image.Image, brightness=0.2, contrast=0.2, saturation=0.2) -> Image.Image:
    img = ImageEnhance.Brightness(img).enhance(1 + brightness * (2 * np.random.random() - 1))
    img = ImageEnhance.Contrast(img).enhance(1 + contrast * (2 * np.random.random() - 1))
    img = ImageEnhance.Color(img).enhance(1 + saturation * (2 * np.random.random() - 1))
    return img

In [12]:
def vertical_flip(img: Image.Image) -> Image.Image:
    return ImageOps.flip(img)

In [13]:
def horizontal_flip(img: Image.Image) -> Image.Image:
    return ImageOps.mirror(img)

In [14]:
train_directory = './images/train/'


def preprocess_and_save(img, image_id, category):
    color_jittered = color_jitter(img)
    color_jittered_path = os.path.join(train_directory, category, f'{image_id}_colorjittered.png')
    color_jittered.save(color_jittered_path, 'PNG')

    vflipped = vertical_flip(img)
    vflipped_path = os.path.join(train_directory, category, f'{image_id}_vflipped.png')
    vflipped.save(vflipped_path, 'PNG')

    hflipped = horizontal_flip(img)
    hflipped_path = os.path.join(train_directory, category, f'{image_id}_hflipped.png')
    hflipped.save(hflipped_path, 'PNG')

for record in train_records:
    image_id = record['id']
    img = get_image_by_id(image_id)
    truth_id = record['truth_id']
    category = f'Category{id_code_mapping[truth_id]}'
    preprocess_and_save(img, image_id, category)
    
print('Data augmentation for teh training set is complete.')

Data augmentation for teh training set is complete.


In [41]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

print(tf.__version__)
train_dir = './images/train'
val_dir = './images/val'
test_dir = './images/test'

img_height, img_width = 218, 175
input_shape = (img_height, img_width, 3)
BATCH_SIZE = 32

2.8.2


In [42]:
# Data Augmentation and normalization for training
train_image_generator = ImageDataGenerator(rescale=1. / 255)

# Just rescaling for validation and test
val_image_generator = ImageDataGenerator(rescale=1. / 255)
test_image_generator = ImageDataGenerator(rescale=1. / 255)

# Data Loading
train_data_gen = train_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(img_height, img_width),
                                                           class_mode='binary')

val_data_gen = val_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
                                                       directory=val_dir,
                                                       target_size=(img_height, img_width),
                                                       class_mode='binary')

test_data_gen = test_image_generator.flow_from_directory(batch_size=BATCH_SIZE,
                                                         directory=test_dir,
                                                         target_size=(img_height, img_width),
                                                         class_mode='binary')


# 2. Model Creation and Compilation

def create_advanced_cnn(input_shape):
    input_layer = Input(shape=input_shape)

    x = Conv2D(32, (3, 3), activation='relu')(input_layer)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), activation='relu')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Conv2D(128, (3, 3), activation='relu')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Flatten()(x)
    x = Dense(512, activation='relu')(x)
    x = Dropout(0.5)(x)
    output = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=input_layer, outputs=output)
    return model


model = create_advanced_cnn(input_shape)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])


Found 804 images belonging to 2 classes.
Found 23 images belonging to 2 classes.
Found 96 images belonging to 2 classes.


In [43]:
history = model.fit(
    train_data_gen,
    steps_per_epoch=train_data_gen.samples // BATCH_SIZE,
    epochs=10,  # Adjust as needed
    validation_data=val_data_gen,
    validation_steps=val_data_gen.samples // BATCH_SIZE
)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [44]:
test_loss, test_accuracy = model.evaluate(test_data_gen)
print(f'Test accuracy: {test_accuracy}')
model_path = "./saved_model/my_model"
model.save(model_path)
print("Model saved to", model_path)

Test accuracy: 1.0
INFO:tensorflow:Assets written to: ./saved_model/my_model/assets
Model saved to ./saved_model/my_model


In [45]:
import pytz
from datetime import datetime

#model_version = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y%m%d%H%M%S')
#tf.keras.models.save_model(
#    model,
#    f'/models/slot1/{model_version}/',
#    overwrite=True,
#)
model_version = datetime.now(pytz.timezone('Asia/Shanghai')).strftime('%Y%m%d%H%M%S')
model_save_path = f'/models/slot1/{model_version}/'

tf.keras.models.save_model(
    model,
    model_save_path,
    overwrite=True
)

INFO:tensorflow:Assets written to: /models/slot1/20230913161840/assets


In [57]:
import base64


TF_SERVING_BASE_URL = 'http://fireeye-test-model-container:8501/' 


def predict_image(images):
    image_list = []

    for image_path in images:
        with open(image_path, 'rb') as fimage:
            content = fimage.read()
        b64_encoded_image = base64.urlsafe_b64encode(content).decode('utf-8')
        image_list.append({"b64": b64_encoded_image})

    # Format the payload
    payload = {
        "signature_name": "serving_default",
        "instances": image_list
    }

    # Make the request
    response = requests.post(TF_SERVING_BASE_URL,
                             data=json.dumps(payload),
                             headers={"content-type": "application/json"})

    if response.status_code != 200:
        raise RuntimeError('Request tf-serving failed: ' + response.text)

    resp_data = json.loads(response.text)
    if 'outputs' not in resp_data or type(resp_data['outputs']) is not list:
        raise RuntimeError('Invalid tf-serving response format: ' + response.text)

    return resp_data['outputs']


def test_image_model(test_dir, code, batch_size=10):
    # Mapping the codes to their directory names
    code_to_category = {
        0: "Category0",
        1: "Category1"
    }

    category_dir = code_to_category.get(code)
    if category_dir is None:
        raise ValueError(f"Invalid code {code}. Expected 0 or 1.")

    code_dir = pathlib.Path(test_dir).joinpath(category_dir)

    if not code_dir.exists():
        raise FileNotFoundError(f"The directory {code_dir} does not exist!")

    images = list(code_dir.glob('*.png'))
    codes = []

    total_images = len(images)
    print(f"Total images found in {category_dir}: {total_images}")  # Debug: check the total number of images found

    for step in range(math.ceil(total_images / batch_size)):
        outputs = predict_image(images[step * batch_size:(step + 1) * batch_size])
        for i, o in zip(images, outputs):
            if o != code:
                print('Error picture:', i)
        codes.extend(outputs)

    accuracy = round(codes.count(code) / len(codes), 4)
    return accuracy, codes

single_image_path = pathlib.Path(test_dir).joinpath('Category0').glob('*.png').__next__()
response = predict_image([single_image_path])
print(response)



accuracy, codes = test_image_model(test_dir, 0)
print('类别0的准确率:', accuracy)
print('类别0的测试结果:', codes)

accuracy, codes = test_image_model(test_dir, 1)
print('类别1的准确率:', accuracy)
print('类别1的测试结果:', codes)

RuntimeError: Request tf-serving failed: <HTML><HEAD>
<TITLE>404 Not Found</TITLE>
</HEAD><BODY>
<H1>Not Found</H1>
</BODY></HTML>
