In [None]:
# Importing Libraries
import cv2
import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from tensorflow import keras
import tensorflow.compat.v2 as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from keras.applications.imagenet_utils import decode_predictions

import lime
from lime import lime_image
from lime import submodular_pick
import skimage.io 
import skimage.segmentation

from skimage.segmentation import mark_boundaries

### TRAINING AND TESTING MODEL

In [None]:
# Dataset layout: one sub-folder per class under each split directory.
# os.path.join gives the same 'Dataset\\Train' on Windows and also works elsewhere.
train_path = os.path.join('Dataset', 'Train')
val_path = os.path.join('Dataset', 'Val')

# List order defines the one-hot index of each class, i.e. the model's output units.
CLASS_NAMES = ['Real', 'Fake_Photoshop', 'Fake_Deepfake', 'Fake_Gan']
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 32
SEED = 42

# Augment the training split only. Brightness factors below 1 darken and above 1
# brighten; a range lying entirely below 1 (such as [0.2, 0.8]) would darken every
# training image while validation/test images are never darkened, so the range is
# centred on 1.
train_datagen = ImageDataGenerator(rescale=1./255,
                                   rotation_range=20,
                                   zoom_range=0.2,
                                   width_shift_range=0.1,
                                   height_shift_range=0.1,
                                   brightness_range=[0.8, 1.2],
                                   shear_range=0.2,
                                   horizontal_flip=True,
                                   vertical_flip=False)

# Validation images are only rescaled (no augmentation).
val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_path,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=CLASS_NAMES,
    shuffle=True,
    seed=SEED  # reproducible shuffling/augmentation order
)

# Validation order does not change the metrics; keep it fixed so runs are comparable.
val_generator = val_datagen.flow_from_directory(
    val_path,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=CLASS_NAMES,
    shuffle=False
)


In [None]:
# Load the XceptionNet backbone (ImageNet weights, no classifier top) for 256x256 RGB input.
# NOTE(review): Keras' Xception ImageNet weights are documented to expect inputs scaled to
# [-1, 1] (xception.preprocess_input); the generators and `preprocess` in this notebook feed
# [0, 1]. Changing it requires updating training AND inference preprocessing together.
base_model = Xception(include_top=False, weights='imagenet', input_shape=(256, 256, 3))

# Freeze every pre-trained layer so only the new head below is trained.
for frozen_layer in base_model.layers:
    frozen_layer.trainable = False

# Classification head: pooled features -> three Dense(relu)+Dropout(0.5) stages.
head = GlobalAveragePooling2D()(base_model.output)
for units in (1024, 512, 256):
    head = Dense(units, activation='relu')(head)
    head = Dropout(0.5)(head)

# Four-way softmax: one unit per class, in generator class order.
image_type = Dense(4, activation='softmax', name='type')(head)

model = Model(inputs=base_model.input, outputs=image_type)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

In [None]:
# TODO: make this machine-specific folder configurable.
checkpoint_path = 'D:\\Nithies_FYP\\Checkpoints'
# The checkpoint and the history pickle below are both written here.
os.makedirs(checkpoint_path, exist_ok=True)

# All callbacks watch *validation* loss: training loss keeps falling while the model
# overfits, so it can neither pick the best checkpoint nor trigger early stopping.
custom_callbacks = [
    # Larger patience than ReduceLROnPlateau so the learning rate is lowered at least
    # once before training is abandoned. restore_best_weights puts the best-val_loss
    # weights back into `model` (TF versions differ on whether this also happens when
    # all epochs run); the best model is on disk via ModelCheckpoint either way.
    EarlyStopping(monitor='val_loss', mode='min', patience=4,
                  restore_best_weights=True, verbose=1),
    ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.2, patience=2,
                      min_lr=0.00001, verbose=1),
    ModelCheckpoint(filepath=os.path.join(checkpoint_path, 'final_model_checkpoint.h5'),
                    monitor='val_loss',
                    mode='min',
                    verbose=1,
                    save_best_only=True),
]

# Train the model; model.fit accepts the directory generators directly.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=val_generator,
    validation_steps=len(val_generator),
    callbacks=custom_callbacks)

# Save the per-epoch metrics so the curves can be re-plotted without retraining.
# Reload later with pickle.load from this same trusted path (never unpickle untrusted files).
history_path = os.path.join(checkpoint_path, 'trainHistory')
with open(history_path, 'wb') as file_pi:
    pickle.dump(history.history, file_pi)

# Number of epochs actually run (fewer than `epochs` if early stopping fired).
len(history.history['loss'])

In [None]:
# Persist the trained model; the .h5 extension selects Keras' HDF5 format.
# NOTE(review): the XAI section reloads 'Models\\final_model.h5' relative to the working
# directory — confirm that resolves to this same file.
final_model_path = 'D:\\Nithies_FYP\\Models\\final_model.h5'
model.save(final_model_path)

In [None]:
# Plotting Graphs: training vs. validation curves per epoch.

def plot_training_curve(history, metric):
    """Plot the train and validation curves of one metric from a Keras History.

    Parameters
    ----------
    history : keras.callbacks.History
        Object returned by ``model.fit``; ``history.history`` maps metric names
        to per-epoch value lists.
    metric : str
        Base metric name such as ``'accuracy'`` or ``'loss'``; its ``'val_'``
        counterpart is plotted alongside it.
    """
    fig, ax = plt.subplots()
    ax.plot(history.history[metric])
    ax.plot(history.history['val_' + metric])
    ax.set_title('model ' + metric)
    ax.set_ylabel(metric)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'val'], loc='upper left')
    plt.show()

plot_training_curve(history, 'accuracy')
plot_training_curve(history, 'loss')

In [None]:
# Evaluating the Model on the held-out test split.

# NOTE(review): unlike the relative Dataset\\Train and Dataset\\Val folders, the test set
# is read from an absolute path in a different folder tree — confirm it was prepared the
# same way as the training data.
test_path = 'D:\\FYP PROJECT\\Preprocess3\\Test'

# Same rescaling as validation; no augmentation at test time.
test_datagen = ImageDataGenerator(rescale=1./255)

test_generator = test_datagen.flow_from_directory(
    test_path,
    target_size=(256, 256),
    batch_size=32,
    class_mode='categorical',
    classes=['Real', 'Fake_Photoshop', 'Fake_Deepfake', 'Fake_Gan'],
    # Fixed order: loss/accuracy are unaffected, and model.predict outputs can later be
    # lined up with test_generator.classes / test_generator.filenames.
    shuffle=False
)

test_loss, test_acc = model.evaluate(test_generator, steps=len(test_generator))
print('Test loss:', test_loss)
print('Test accuracy:', test_acc)

### Loading the Prototype and Implementing XAI

In [None]:
# Loading the trained Model for the explainability section.
# os.path.join keeps this relative path valid outside Windows (a literal 'Models\\...'
# is a single odd filename on POSIX systems).
# NOTE(review): training saved to 'D:\\Nithies_FYP\\Models\\final_model.h5'; this only
# loads that file when the kernel's working directory is D:\\Nithies_FYP.
new_model = tf.keras.models.load_model(os.path.join('Models', 'final_model.h5'))

# Check its architecture
new_model.summary()

In [None]:
# Creating the LIME image explainer (XAI). A fixed random_state makes LIME's random
# perturbation sampling, and therefore the explanation, repeatable across runs.
explainer = lime_image.LimeImageExplainer(random_state=42)

In [None]:
# Reading the image to classify and explain. Per matplotlib's imread docs, PNG files are
# returned as floats in [0, 1] and other formats (e.g. JPEG) as integer arrays.
# NOTE(review): absolute path into a personal photo folder — move it to a config constant
# and clear outputs that show private pictures before sharing the notebook.
sample_image_path = "D:\\root\\pics\\Friends\\Snapchat-2053075564.jpg"
img = plt.imread(sample_image_path)

print("image shape : ", img.shape)
plt.imshow(img)
plt.show()

In [None]:
def plot_image(image):
    """Draw ``image`` on the current matplotlib axes with the axis frame hidden, then show it.

    Returns whatever ``plt.show()`` returns, so existing ``result = plot_image(...)``
    call sites keep working.
    """
    plt.axis('off')
    plt.imshow(image)
    return plt.show()

def preprocess(img, size=(256, 256)):
    """Turn one decoded image into a model-ready batch of one.

    Steps: coerce to 3 channels (grayscale is replicated, an alpha channel is
    dropped), resize, convert to float32 in [0, 1], and prepend a batch axis.

    Parameters
    ----------
    img : array-like
        Image shaped ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)``.
        Integer images (e.g. uint8 from a JPEG) are divided by their dtype's
        maximum. Floating-point images are assumed to be in [0, 1] already,
        which is what ``plt.imread`` returns for PNG files.
    size : tuple of int, optional
        Target ``(width, height)`` passed to ``cv2.resize``. Defaults to the
        256x256 input the model was built with.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(1, height, width, 3)``.
    """
    img = np.asarray(img)
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[-1] == 1:
        # Grayscale -> RGB by replicating the single channel.
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        # RGBA -> RGB: drop the alpha channel.
        img = img[:, :, :3]
    # Pass a contiguous copy; the channel slice above is a strided view.
    img = cv2.resize(np.ascontiguousarray(img), size)
    if np.issubdtype(img.dtype, np.integer):
        img = img.astype('float32') / np.iinfo(img.dtype).max
    else:
        # Already in [0, 1]; dividing by 255 again would squash it to near-black.
        img = img.astype('float32')
    return img[np.newaxis, ...]
    
# Build the model input from the loaded image and show what the model will see.
test_image = preprocess(img)

print("image shape : ", test_image.shape)
# Fail early with a readable message instead of a shape error deep inside Keras.
assert test_image.shape[1:] == tuple(new_model.input_shape[1:]), (
    f"preprocess produced {test_image.shape}, model expects {new_model.input_shape}")

plot_image(test_image[0])

In [None]:
# Run the classifier on the preprocessed single-image batch.
prediction = new_model.predict(test_image)

# Position i in each list corresponds to output unit i, i.e. the order of the
# `classes=[...]` list given to the data generators. (Named so the builtin `type`
# is not shadowed for the rest of the kernel session.)
verdicts = ['Real Image', 'Photoshopped Image', 'Deepfake Image', 'GAN-generated Image']
score_names = ['Real', 'Photoshop', 'Deepfake', 'Gan']

predicted_class = int(np.argmax(prediction[0]))

print("Prediction Class : ",predicted_class)
# Print the prediction result
print(verdicts[predicted_class] if predicted_class < len(verdicts) else 'Unknown')

print(prediction)
print(prediction.size)
print("Prediction Scores")
# One group of per-class percentages per image; zip pairs names with scores, so no
# shared counter can run past the name list when the batch holds several images.
for scores in prediction:
    for name, score in zip(score_names, scores):
        print(name, " - ", "{:.4f}%".format(float(score) * 100))
    print('\n')



In [None]:
# LIME hides random subsets of superpixels (filled with `hide_color`), queries the model
# on each perturbed copy, and fits a local surrogate to those predictions. With only 4
# samples that fit is essentially noise; 1000 is LIME's own default. random_seed makes
# the superpixel segmentation repeatable, and verbose=0 silences one Keras progress bar
# per prediction batch.
exp = explainer.explain_instance(test_image[0],
                                 lambda batch: new_model.predict(batch, verbose=0),
                                 top_labels=4,
                                 hide_color=0,
                                 num_samples=1000,
                                 random_seed=42)

In [None]:
import matplotlib

import matplotlib.pyplot as plt

def explanation_heatmap(exp, exp_class, path=os.path.join('Models', 'prediction_heatmap.jpg')):
    """Show and save a heat-map of the per-superpixel LIME weights for one class.

    Every pixel takes the weight LIME assigned to its superpixel. With the
    ``RdBu`` colormap, positive weights (evidence for ``exp_class``) are blue and
    negative weights are red, on a colour scale symmetric around zero.

    Parameters
    ----------
    exp : lime.lime_image.ImageExplanation
        Result of ``explain_instance``; ``exp.local_exp[exp_class]`` (pairs of
        segment id and weight) and ``exp.segments`` are read.
    exp_class : int
        Label to visualise; must be a key of ``exp.local_exp``.
    path : str, optional
        Where the figure is saved. The image format follows the file extension,
        and missing parent folders are created.

    Returns
    -------
    numpy.ndarray
        Float heat-map with the same shape as ``exp.segments``.
    """
    weights = dict(exp.local_exp[exp_class])
    # Segments without a reported weight get 0.0 rather than None, which would make an
    # object array and break max().
    heatmap = np.vectorize(lambda seg: weights.get(seg, 0.0), otypes=[float])(exp.segments)

    # Symmetric limits around 0; using heatmap.max() alone flips the scale (vmin > vmax)
    # whenever every weight is negative.
    limit = float(np.abs(heatmap).max())
    plt.imshow(heatmap, cmap='RdBu', vmin=-limit, vmax=limit)
    plt.axis('off')

    # Saving the figure (format inferred from the extension so content matches the name).
    folder = os.path.dirname(path)
    if folder:
        os.makedirs(folder, exist_ok=True)
    plt.savefig(path)
    plt.show()
    return heatmap



def generate_prediction_sample(exp, exp_class, show_positive, hide_background, filename,
                               num_features=10, directory='Models'):
    """Outline the most influential LIME superpixels for a label, then show and save the image.

    Parameters
    ----------
    exp : lime.lime_image.ImageExplanation
        Explanation to visualise.
    exp_class : int
        Label whose superpixels are outlined.
    show_positive : bool
        Forwarded as ``positive_only`` to ``get_image_and_mask``.
    hide_background : bool
        Forwarded as ``hide_rest`` to ``get_image_and_mask``.
    filename : str
        Output file name inside ``directory``. ``'.jpg'`` is appended when it
        has no extension, so ``"prediction_highlight"`` is saved as
        ``prediction_highlight.jpg``. The image format follows the extension.
    num_features : int, optional
        Number of superpixels to outline (default 10).
    directory : str, optional
        Folder for the saved figure; created if missing.

    Returns
    -------
    numpy.ndarray
        The image returned by ``mark_boundaries`` with green superpixel outlines.
    """
    image, mask = exp.get_image_and_mask(exp_class,
                                         positive_only=show_positive,
                                         num_features=num_features,
                                         hide_rest=hide_background)
    highlighted = mark_boundaries(image, mask, outline_color=(0, 1, 0))

    plt.imshow(highlighted)
    plt.axis('off')

    if not os.path.splitext(filename)[1]:
        filename += '.jpg'
    if directory:
        os.makedirs(directory, exist_ok=True)
    plt.savefig(os.path.join(directory, filename))
    plt.show()
    return highlighted

# Visualise the explanation for the label LIME ranked first.
top_label = exp.top_labels[0]

pred_high = generate_prediction_sample(exp, top_label,
                                       show_positive=True,
                                       hide_background=False,
                                       filename="prediction_highlight")
heatmap = explanation_heatmap(exp, top_label)