# Training an image classifier to detect AI generated images

Based on the research here: https://arxiv.org/pdf/2303.14126v1.pdf

Dataset: https://www.kaggle.com/datasets/birdy654/cifake-real-and-ai-generated-synthetic-images

Dataset attribution:

Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images.

Bird, J.J. and Lotfi, A., 2024. CIFAKE: Image Classification and Explainable Identification of AI-Generated Synthetic Images. IEEE Access.

Real images are from Krizhevsky & Hinton (2009), fake images are from Bird & Lotfi (2024). The Bird & Lotfi study is available here https://ieeexplore.ieee.org/abstract/document/10409290 

# Pre-Requisite Activities

NOTE: If you are running this notebook on an ARM device, MXNet must be built from source, as described in this GitHub issue: https://github.com/apache/mxnet/issues/19234#issuecomment-699571539

Step 1: Download the dataset from https://www.kaggle.com/datasets/birdy654/cifake-real-and-ai-generated-synthetic-images

Step 2: Unzip the dataset into the `data/` folder next to this notebook, so that the images are under `data/train` and `data/test`


## Install pre-requisites and convert dataset to RecordIO format

In [None]:
# Install the Python packages listed in requirements.txt.
!pip install -r requirements.txt

# Run im2rec.py in --list mode over data/train and data/test, searching
# sub-folders recursively and using 8 threads.
# NOTE(review): presumably this writes the image/label listing files that the
# next cell packs into RecordIO -- confirm against im2rec.py.
!python im2rec.py data/train.lst data/train --recursive --list --num-thread 8
!python im2rec.py data/test.lst data/test --recursive --list --num-thread 8


In [None]:
# Second im2rec.py pass over the same prefixes and image folders, this time
# without --list.
# NOTE(review): presumably this packs the listed images into the RecordIO files
# uploaded later (data/train.lst.rec, data/test.lst.rec), with --pass-through
# keeping the original encoded bytes -- confirm against im2rec.py.
!python im2rec.py data/train.lst data/train --recursive --pass-through --pack-label --num-thread 8
!python im2rec.py data/test.lst data/test --recursive --pass-through --pack-label --num-thread 8

In [None]:
import boto3
import sagemaker,os
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri

# S3 client used to upload the packed RecordIO files.
s3_client = boto3.client('s3')

sagemaker_session = sagemaker.Session()

# Execution role of this notebook; the training job and endpoint are created with it.
role = get_execution_role()
# Container image of the built-in image-classification algorithm in the session's region.
# NOTE(review): get_image_uri is deprecated in SageMaker Python SDK v2 in favour of
# sagemaker.image_uris.retrieve -- consider migrating.
training_image = get_image_uri(sagemaker_session.boto_region_name, 'image-classification', repo_version="latest")

bucket_name = sagemaker_session.default_bucket()
prefix = "dataset"

# Build the object keys from `prefix` rather than repeating the literal, so the
# upload location always matches the s3://<bucket>/<prefix>/... URIs that the
# training cell derives from the same variable.
s3_client.upload_file("data/train.lst.rec", bucket_name, f"{prefix}/train.rec")
s3_client.upload_file("data/test.lst.rec", bucket_name, f"{prefix}/test.rec")

print('Uploaded dataset files to S3')

## Train Image Classifier

NOTE: This could take a few hours to complete.

In [None]:


def count_files_in_directory(directory):
    """Return how many files live under ``directory``, including sub-folders.

    Every entry that ``os.walk`` reports as a file is counted. A path that
    does not exist produces an empty walk and therefore a count of 0.
    """
    return sum(len(files) for _root, _subdirs, files in os.walk(directory))



# S3 locations: the two RecordIO files uploaded earlier and the folder that
# receives the training output.
s3_train_data = f's3://{bucket_name}/{prefix}/train.rec'
s3_validation_data = f's3://{bucket_name}/{prefix}/test.rec'
s3_output_location = f's3://{bucket_name}/output'

# Infrastructure settings for the training job.
estimator_settings = dict(
    instance_count=1,
    instance_type='ml.p2.xlarge',
    volume_size=50,
    max_run=360000,
    input_mode='File',
    output_path=s3_output_location,
    sagemaker_session=sagemaker_session,
)
image_classifier = sagemaker.estimator.Estimator(training_image, role, **estimator_settings)

# Algorithm hyperparameters: two classes, 3x32x32 inputs, and the number of
# training samples taken from the file count under data/train.
hyperparameters = dict(
    num_layers=50,
    image_shape="3,32,32",
    num_classes=2,
    num_training_samples=count_files_in_directory('data/train'),
    epochs=10,
    learning_rate=0.01,
)
image_classifier.set_hyperparameters(**hyperparameters)

# One RecordIO input per channel, built in the same order as the original
# (train first, then validation).
train_data, validation_data = (
    sagemaker.inputs.TrainingInput(uri, content_type='application/x-recordio', s3_data_type='S3Prefix')
    for uri in (s3_train_data, s3_validation_data)
)
data_channels = dict(train=train_data, validation=validation_data)

# Start the training job and stream its logs into the notebook.
image_classifier.fit(inputs=data_channels, logs=True)



## Create SageMaker Endpoint

In [None]:
# Host the trained model on a single ml.m4.xlarge instance.
predictor = image_classifier.deploy(instance_type='ml.m4.xlarge', initial_instance_count=1)

print('Sagemaker Endpoint deployed', predictor.endpoint_name)

print('Placing the Sagemaker Endpoint name in the SSM Parameter /fraud-detection/sagemaker/endpoint/name')

ssm_client = boto3.client('ssm')

# SSM parameter under which the endpoint name is stored.
parameter_name = '/fraud-detection/sagemaker/endpoint/name'

# Store the endpoint name as a plain String parameter; Overwrite=True is passed
# so an existing value is replaced.
put_request = {
    'Name': parameter_name,
    'Value': predictor.endpoint_name,
    'Type': 'String',
    'Overwrite': True,
}
ssm_client.put_parameter(**put_request)

## Prediction

In [None]:
import numpy as np
from PIL import Image
import io

from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import IdentitySerializer


def load_and_preprocess_image(image_path, size=(32, 32)):
    """Load an image and return it as PNG-encoded RGB bytes for the endpoint.

    The built-in SageMaker image-classification model is invoked with an
    encoded image (content type ``application/x-image``), not with a
    serialized float tensor, so the image is only normalised to what the
    model was trained on (three channels, 32x32) and re-encoded.

    Args:
        image_path: Path of the image file to classify.
        size: ``(width, height)`` to resize to; defaults to the 32x32
            resolution used for training.

    Returns:
        bytes: The resized RGB image encoded as PNG.
    """
    with Image.open(image_path) as img:
        # PNG files are frequently RGBA, palette or greyscale; the model was
        # trained with image_shape="3,32,32", so force exactly three channels.
        img_resized = img.convert('RGB').resize(size)

    img_byte_stream = io.BytesIO()
    img_resized.save(img_byte_stream, format='PNG')

    return img_byte_stream.getvalue()


# The default Predictor sends application/octet-stream and returns raw bytes.
# Send the payload as an encoded image and decode the JSON reply instead.
predictor.serializer = IdentitySerializer(content_type='application/x-image')
predictor.deserializer = JSONDeserializer()

image_path = 'laptop-generated.png'
image_payload = load_and_preprocess_image(image_path)

response = predictor.predict(image_payload)

# The endpoint replies with a JSON list holding one probability per class,
# indexed by the class label used during training.
probabilities = np.ravel(response)
predicted_class = int(np.argmax(probabilities))

# im2rec numbers the class sub-folders in sorted order, so the CIFAKE folders
# map to FAKE=0 and REAL=1.
# NOTE(review): confirm against the label column of the generated .lst file.
labels = ['FAKE', 'REAL']
print(f"Predicted class: {labels[predicted_class]} with probability: {probabilities[predicted_class]}")