In [None]:
# ! pip install wandb
# ! pip install pydot
# ! pip install graphviz
# ! pip install datasets
# ! pip install scikit-learn
# ! pip install webdataset
# ! pip install sagemaker_tensorflow # uses Linux FIFOs so does not work on Mac

# Importing the Necessary Libraries

In [None]:
import logging as log
import os
import pickle
import shutil
import sys
from datetime import datetime
from pathlib import Path

import boto3
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import webdataset as wds
from IPython.display import Image
from webdataset import multi

import wandb

In [None]:
PROJECT_SUFFIX = "frequency_classifier_multi"
ENTITY = "makersplace"
PROJECT = f"ai-or-not-{PROJECT_SUFFIX}"
SEED = 7
RUNTIME_DATE_SUFFIX = "%m%d_%H%M"

JOB_TYPE_SUFFIX = f"{PROJECT_SUFFIX}_M"
# current time (month-day_hour-minute), appended to the model directory names below
RUN_NAME_SUFFIX = datetime.now().strftime(RUNTIME_DATE_SUFFIX)


# Datasets Paths
S3_BUCKET = "mp-ml-data-dev"
PREFIX = "finder/ai_or_not/ai_or_not_datasets/test_datasets/"
S3_PREFIX = f"s3://{S3_BUCKET}/{PREFIX}"
training_dataset_path = "../cache/data/training_dataset_tf_record_snapshot"
validation_dataset_path = "../cache/data/validation_dataset_tf_record_snapshot"
dataset_cache_path = "../cache/data/dataset.cache"

# Model Paths
cnn_model_path = Path(f"../cache/models/{JOB_TYPE_SUFFIX}/cnn_{RUN_NAME_SUFFIX}")
effv2_model_dir_path = Path(f"../cache/models/{JOB_TYPE_SUFFIX}/en2s_{RUN_NAME_SUFFIX}")


# If True, delete and recreate the training and validation dataset folders
CLEAN_RUN = True


np.random.seed(SEED)
tf.random.set_seed(SEED)

# WANDB Login
# Never hardcode the API key in the notebook: wandb.login() reads WANDB_API_KEY from the
# environment (or prompts interactively), so export it before launching Jupyter.
# NOTE(review): a key was committed in an earlier version of this cell -- revoke and rotate it.
wandb.login()

# log to stdout (basicConfig would otherwise default to stderr)
log.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=log.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    stream=sys.stdout,
)

# create boto3 session
boto3_session = boto3.Session(profile_name="dev")
s3_client = boto3_session.client("s3")

# Configuration

In [None]:
# Hyperparameters and data settings shared by the pipeline and model cells.
CONFIGURATION = dict(
    BATCH_SIZE=64,
    IM_SIZE=128,  # images are letterboxed to IM_SIZE x IM_SIZE by resize_image
    DROPOUT_RATE=0.1,
    N_EPOCHS=15,
    REGULARIZATION_RATE=0.01,
    N_FILTERS=6,
    KERNEL_SIZE=3,
    N_STRIDES=1,
    POOL_SIZE=2,
    N_DENSE_1=2048,
    N_DENSE_2=1024,
    N_DENSE_3=256,
    LEARNING_RATE=0.001,
    CHANNELS=3,  # 3 = FFT spectra only; 6 = original RGB + FFT spectra (apply_fourier_transform)
    CLASS_NAMES=["REAL", "ADM", "SD", "MD"],  # indexed by integer label when plotting
)

# Dataset Configuration

In [None]:
# Each entry is (directory, integer label, glob pattern relative to directory).
# NOTE(review): these training labels span 0-4 (five classes, including GAN), while
# CONFIGURATION["CLASS_NAMES"] and TEST_DIRECTORIES use four classes (REAL, ADM, SD, MD)
# with ADM=1, SD=2, MD=3. Confirm the intended label scheme before training.
TRAIN_DIRECTORIES = [
    # Label 0 - Real
    ("../cache/data/DIRE/train/imagenet/real", 0, "*/*"),
    ("../cache/data/DIRE/train/celebahq/real", 0, "*"),
    ("../cache/data/DIRE/train/lsun_bedroom/real", 0, "*"),
    ("../cache/data/cifake/train/REAL", 0, "*"),
    # Label 1 - GAN
    # the pattern must be a glob string like every other entry (it was the int 0)
    ("../cache/data/DIRE/train/lsun_bedroom/stylegan", 1, "*"),
    # Label 2 - Diffusion
    ("../cache/data/DIRE/train/imagenet/adm", 2, "*/*"),
    ("../cache/data/DIRE/train/lsun_bedroom/adm", 2, "*"),
    # Label 3 SD
    ("../cache/data/cifake/train/FAKE", 3, "*"),  # Generated by SD 1.4
    (
        "../cache/data/FakeImageDataset/ImageData/train/SDv15R-CC1M/SDv15R-dpmsolver-25-1M/SDv15R-CC1M",
        3,
        "*",
    ),  # Generated by SD 1.5
    ("../cache/data/DIRE/train/celebahq/sdv2", 3, "*/*"),
    # Label 4 MD
    ("../cache/data/FakeImageDataset/ImageData/val/Midjourneyv5-5K/Midjourneyv5-5K_train", 4, "*"),
]

In [None]:
# Test Directories
# Each entry is (directory, integer label, glob pattern relative to directory); the labels
# index CONFIGURATION["CLASS_NAMES"]: 0=REAL, 1=ADM, 2=SD, 3=MD.
TEST_SHARDS = 1
TEST_DIRECTORIES = [
    # Label 0 - REAL
    ("../cache/data/cifake/test/REAL", 0, "*"),
    ("../cache/data/DIRE/test/imagenet/real", 0, "*/*"),
    ("../cache/data/DIRE/test/celebahq/real", 0, "*"),
    # Label 1 - ADM
    ("../cache/data/DIRE/test/imagenet/adm", 1, "*/*"),
    # Label 2 - SD (v1.x)
    ("../cache/data/cifake/test/FAKE", 2, "*"),  # Generated by SD 1.4
    ("../cache/data/DIRE/test/imagenet/sdv1", 2, "*/*"),  # Bad 73
    ("../cache/data/DIRE/test/lsun_bedroom/sdv1_new", 2, "*"),
    ("../cache/data/FakeImageDataset/ImageData/val/SDv15-CC30K/SDv15-CC30K", 2, "*/*"),
    # Label 2 - SD (v2.x), grouped into the same SD class
    ("../cache/data/DIRE/test/lsun_bedroom/sdv2", 2, "*"),
    ("../cache/data/DIRE/test/celebahq/sdv2", 2, "*"),
    ("../cache/data/FakeImageDataset/ImageData/val/SDv21-CC15K/SDv21-CC15K/SDv2-dpmsolver-25-10K", 2, "*"),  # Bad 79
    # Label 3 - MD (Midjourney)
    ("../cache/data/FakeImageDataset/ImageData/val/Midjourneyv5-5K/Midjourneyv5-5K_test", 3, "*"),  # Bad  65
    ("../cache/data/DIRE/test/lsun_bedroom/midjourney", 3, "*"),  # Bad < 13
    # # AI Artbench Dataset
    # ("../cache/data/ai-artbench/test/AI*", 1.0, "*", TEST_SHARDS, 0),  # 675 Batches
    # ("../cache/data/ai-artbench/test/real", 0.0, "*/*", TEST_SHARDS, 0),
    # # CIFAKE Dataset
    # ("../cache/data/FakeImageDataset/ImageData/val/cogview2-22K/cogview2-22K", 1.0, "*", TEST_SHARDS, 0),
    # DIRE Imagenet Dataset
    # ("../cache/data/DIRE/test/celebahq/if", 1.0, "*", TEST_SHARDS, 0),
    # ("../cache/data/DIRE/test/celebahq/dalle2", 1.0, "*", TEST_SHARDS, 0),
    # # DIRE Lsun Bedroom Dataset
    # ("../cache/data/DIRE/test/lsun_bedroom/dalle2", 1.0, "*", TEST_SHARDS, 0),
    # ("../cache/data/DIRE/test/lsun_bedroom/vqdiffusion", 1.0, "*", TEST_SHARDS, 0),
    # FakeImageDataset
]

# Dataset Loading and Transformations

In [None]:
def resize_image(image):
    """Letterbox `image` to an IM_SIZE x IM_SIZE square and scale pixel values by 1/255.

    ``tf.image.resize_with_pad`` keeps the aspect ratio and pads the leftover area;
    the side length comes from CONFIGURATION["IM_SIZE"].
    """
    side = CONFIGURATION["IM_SIZE"]
    padded = tf.image.resize_with_pad(image=image, target_height=side, target_width=side)
    # scale 0-255 pixel intensities down to [0, 1]
    return padded / 255.0


def decode_img(img):
    """Decode encoded image bytes into a 3-channel, resized, [0, 1]-scaled tensor.

    ``expand_animations=False`` makes ``tf.io.decode_image`` return a rank-3
    ``(height, width, 3)`` tensor for every format, keeping only the first frame of a GIF.
    Without it, GIFs decode to a 4-D tensor, and inside ``Dataset.map`` the static rank of
    the result is not known when traced. That breaks the fixed ``(H, W, 3)`` image shape
    the rest of this pipeline assumes.
    """
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    return resize_image(img)


def process_path(file_path):
    """Read the file at `file_path` and return it as a decoded, resized, [0, 1]-scaled image."""
    raw_bytes = tf.io.read_file(file_path)  # whole file contents as a string tensor
    return decode_img(raw_bytes)


# write a function apply Fourier Transform to the image and return the image
def apply_fourier_transform(image):
    """Replace (or augment) an RGB image with the log-magnitude spectra of its channels.

    Each of the first three channels is passed through `_log_magnitude_spectrum`. That
    helper applies a 2-D FFT, shifts the zero frequency to the centre, computes
    20*log(|F| + 1) and min-max scales the result to [0, 1].

    Args:
        image: tensor indexed as image[:, :, c] for c in 0..2. Assumes a single
            channels-last (height, width, channels) image; batched input is not handled.

    Returns:
        (H, W, 3) spectra when CONFIGURATION["CHANNELS"] != 6. Otherwise (H, W, 6): the
        original R, G, B channels followed by the three spectra.
    """
    spectra = [_log_magnitude_spectrum(image[:, :, channel]) for channel in range(3)]
    if CONFIGURATION["CHANNELS"] == 6:
        originals = [image[:, :, channel] for channel in range(3)]
        return tf.stack(originals + spectra, axis=-1)
    return tf.stack(spectra, axis=-1)


def _log_magnitude_spectrum(channel):
    """Return the centred, log-scaled, min-max-normalised FFT magnitude of one 2-D channel."""
    spectrum = tf.signal.fftshift(tf.signal.fft2d(tf.cast(channel, dtype=tf.complex64)))
    # log compresses the large dynamic range of FFT magnitudes; +1 keeps log finite at 0
    magnitude = 20 * tf.math.log(tf.abs(spectrum) + 1)
    low = tf.reduce_min(magnitude)
    high = tf.reduce_max(magnitude)
    # divide_no_nan yields 0 instead of NaN when the spectrum is flat (high == low),
    # e.g. an all-black channel whose FFT is zero everywhere
    return tf.math.divide_no_nan(magnitude - low, high - low)


def visualize_dataset(samples, max_samples=16, n_cols=4):
    """Plot up to `max_samples` (image, label) pairs in a grid titled with their class names.

    Args:
        samples: iterable of (image, label) pairs, e.g. ``dataset.take(16)``. Each image must
            be something ``plt.imshow`` accepts; each label must be convertible with ``int()``.
        max_samples: maximum number of samples drawn. Samples beyond this are not consumed,
            so the subplot index can never exceed the grid size.
        n_cols: number of grid columns. The defaults reproduce the original 4x4 grid.
    """
    n_rows = max(1, -(-max_samples // n_cols))  # ceiling division
    class_names = CONFIGURATION["CLASS_NAMES"]
    plt.figure(figsize=(12, 12))
    # zip checks the range first, so no extra sample is pulled once the grid is full
    for index, (image, label) in zip(range(1, max_samples + 1), samples):
        plt.subplot(n_rows, n_cols, index)
        plt.imshow(image)
        label = int(label)
        # labels outside CLASS_NAMES (TRAIN_DIRECTORIES uses label 4) fall back to the number
        title = class_names[label] if 0 <= label < len(class_names) else str(label)
        plt.title(title)
        plt.axis("off")

    plt.show()


def get_custom_dataset2(directory, label, pattern):
    """Build a dataset of (image, label) pairs from one directory.

    `directory` and `pattern` arrive as UTF-8 encoded ``bytes`` and are decoded first.
    Presumably this is because the function is called from a tf.numpy_function /
    py_function wrapper; TODO confirm against the caller.

    - If the path contains "aiornot", it is loaded with Hugging Face ``datasets.load_from_disk``
      and converted via ``to_tf_dataset``. Its own "label" column is used and the `label`
      argument is ignored.
    - Otherwise every file matching ``directory/pattern`` is read, decoded and resized with
      `process_path`, and paired with the constant `label`.
    """
    directory = directory.decode("utf-8")
    pattern = pattern.decode("utf-8")

    if "aiornot" in directory:
        # `load_from_disk` is not imported at the top of the notebook; import it here so this
        # branch no longer raises NameError, and `datasets` is only required when it runs.
        from datasets import load_from_disk

        read_aiornot = load_from_disk(dataset_path=directory)
        dataset = read_aiornot.to_tf_dataset(
            columns="image",
            label_cols="label",
        )
        dataset = dataset.map(lambda x, y: (tf.cast(x, tf.float32), y))
        dataset = dataset.map(lambda x, y: (resize_image(x), y), num_parallel_calls=tf.data.AUTOTUNE)

    else:
        list_ds = tf.data.Dataset.list_files(str(Path(directory) / pattern), shuffle=True)
        dataset = list_ds.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
        # every file from this directory gets the same constant label
        dataset = dataset.map(lambda x: (x, label))

    return dataset


def get_dataset2(directory, pattern):
    """Return a shuffled dataset of file-path strings matching the glob ``directory/pattern``.

    Only paths are listed here; nothing is read or decoded.
    """
    glob = tf.strings.join([directory, pattern], separator="/")
    # list_files is a static method of tf.data.Dataset; calling it via TFRecordDataset was
    # only an alias and did not make these files TFRecords.
    return tf.data.Dataset.list_files(glob, shuffle=True)

# Individual Dataset Creation and Visualization

In [None]:
wds_prefix = Path("../cache/wds/")
# create the directory if it does not exist
wds_prefix.mkdir(parents=True, exist_ok=True)

# Set to True to also upload every generated tar to s3://{S3_BUCKET}/{PREFIX}.
UPLOAD_TEST_SHARDS_TO_S3 = False


def _test_dataset_name(directory):
    """Derive a shard name '<dataset>_<parent>_<leaf>' from a path under ../cache/data.

    e.g. "../cache/data/DIRE/test/imagenet/adm" -> "DIRE_imagenet_adm".
    """
    parts = directory.replace("../cache/data", "").split("/")
    # parts[0] is "" because the remaining path starts with "/"
    return f"{parts[1]}_{parts[-2]}_{parts[-1]}"


def _write_test_shard(directory, label, pattern, tar_file_path):
    """Decode every image matching directory/pattern and write it to a WebDataset tar.

    Each sample stores the resized image array as "input.pyd" and the integer class label
    as "output.pyd"; TFS3Dataset reads both back with pickle.loads.
    """
    images = get_dataset2(directory, pattern).map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
    # `with` guarantees the tar is closed even if decoding a file fails midway
    with wds.TarWriter(tar_file_path) as sink:
        for index, image in enumerate(images):
            if index % 1000 == 0:
                print(f"{index:6d}", end="\r", flush=True, file=sys.stderr)
            sink.write(
                {
                    "__key__": "sample%06d" % index,
                    "input.pyd": image.numpy(),
                    # a plain int unpickles without TensorFlow and matches the tf.int64
                    # label type declared by TFS3Dataset.output_types
                    "output.pyd": int(label),
                }
            )


# `*_` tolerates the longer 5-tuple entries kept commented out in TEST_DIRECTORIES
for directory, label, pattern, *_ in TEST_DIRECTORIES:
    test_dataset_name = _test_dataset_name(directory)
    tar_file_path = str(wds_prefix / f"test_{test_dataset_name}.tar")
    _write_test_shard(directory, label, pattern, tar_file_path)
    log.info(f"Test dataset {test_dataset_name} created locally {tar_file_path}")

    # Only upload, and only log an upload, when enabled; `key` exists solely on this path.
    if UPLOAD_TEST_SHARDS_TO_S3:
        key = PREFIX + Path(tar_file_path).name
        s3_client.upload_file(Filename=tar_file_path, Bucket=S3_BUCKET, Key=key)
        log.info(f"Test dataset {test_dataset_name} uploaded at {key}")

In [None]:
# # dataset = wds.WebDataset('https://mp-ml-data-dev.s3.us-west-2.amazonaws.com/finder/ai_or_not/ai_or_not_datasets/webdatasets/test_1_dataset.tar')

# samples = islice(dataset, 0, 16)
# images = []
# for sample in samples:
#     # print(sample.keys())
#     # print(sample["input.pyd"])
#     # read pickle data
#     image_numpy_array = pickle.loads(sample["input.pyd"])
#     label = pickle.loads(sample["output.pyd"])
#     # print(image_numpy_array.shape)
#     images.append((image_numpy_array, label))


# visualize_dataset(images)

In [None]:
class TFS3Dataset:
    """Stream pickled (image, label) samples from WebDataset tar shards stored on S3.

    Each shard is fetched by piping ``aws s3 cp <url> -`` (the AWS CLI must be installed),
    the shards are read through ``webdataset.multi.MultiLoader``, and every sample's
    "input.pyd" / "output.pyd" entries are unpickled into an ``(image, label)`` pair.

    NOTE(review): ``pickle.loads`` can execute arbitrary code embedded in the data; only
    point this at buckets whose contents you control.
    NOTE(review): the ``aws`` CLI uses its own default credential chain, not the boto3
    "dev" profile used to list the files; set AWS_PROFILE if the two differ.
    """

    def __init__(self, prefix, files, length=200_000, workers=12, image_size=128, channels=3):
        """
        Args:
            prefix: URL prefix (e.g. "s3://bucket/dir/") prepended to every file name.
            files: shard file names relative to `prefix`.
            length: value reported by ``len()``. It is NOT derived from the shards;
                TODO set it from the real sample count.
            workers: number of MultiLoader worker processes.
            image_size, channels: image shape reported by `output_shapes`.
        """
        self.length = length
        self.image_size = image_size
        self.channels = channels
        self.urls = [prefix + f for f in files]
        self.dataset = self.get_s3_dataset(self.urls)
        self.loader = multi.MultiLoader(self.dataset, workers=workers)

    def __iter__(self):
        for sample in self.loader:
            yield pickle.loads(sample["input.pyd"]), pickle.loads(sample["output.pyd"])

    def __len__(self):
        return self.length

    @staticmethod
    def _pipe_url(url):
        """Turn an S3 URL into a webdataset ``pipe:`` URL that streams the object to stdout.

        webdataset hands ``pipe:`` commands to a shell, so the URL is shell-quoted. Plain
        keys such as ``s3://bucket/dir/test_x.tar`` pass through quoting unchanged.
        """
        import shlex

        return f"pipe:aws s3 cp {shlex.quote(url)} -"

    def get_s3_dataset(self, urls):
        """Build a shard-shuffled WebDataset that downloads each S3 URL via the AWS CLI."""
        pipe_urls = [self._pipe_url(url) for url in urls]
        return wds.WebDataset(pipe_urls, shardshuffle=True)

    def output_shapes(self):
        return ((self.image_size, self.image_size, self.channels), ())

    def output_types(self):
        return (tf.float32, tf.int64)


# NOTE(review): not referenced by the cells below; kept for ad-hoc experiments.
s3_dataset_url = "s3://mp-ml-data-dev/finder/ai_or_not/ai_or_not_datasets/webdatasets/test_7_dataset.tar"

# List every shard under PREFIX. A single list_objects_v2 call returns at most 1000 keys and
# has no "Contents" entry when nothing matches, so page through the results instead.
paginator = s3_client.get_paginator("list_objects_v2")
s3_dataset_files = [
    obj["Key"][len(PREFIX) :]  # strip the prefix, leaving names relative to S3_PREFIX
    for page in paginator.paginate(Bucket=S3_BUCKET, Prefix=PREFIX)
    for obj in page.get("Contents", [])
]
# drop the empty name left by a "folder" placeholder object whose key is exactly PREFIX
s3_dataset_files = [name for name in s3_dataset_files if name]
if not s3_dataset_files:
    raise FileNotFoundError(f"No dataset shards found under {S3_PREFIX}")
# log the number of files
log.info(f"Number of files in the dataset: {len(s3_dataset_files)}")

tf_s3_dataset = TFS3Dataset(prefix=S3_PREFIX, files=s3_dataset_files)


# Wrap the Python iterator as a tf.data pipeline with fixed (image, label) shapes and dtypes.
tdf = tf.data.Dataset.from_generator(
    generator=tf_s3_dataset.__iter__,
    output_types=tf_s3_dataset.output_types(),
    output_shapes=tf_s3_dataset.output_shapes(),
)


visualize_dataset(tdf.take(16))

In [None]:
# Load the trained Keras SavedModel evaluated below. The path is relative to the notebook's
# working directory, like every other ../cache path in this notebook, rather than an
# absolute path on one developer's machine.
MODEL_PATH = Path("../cache/models/model_ev2s_99_acc_rgb/saved_model")
model = tf.keras.models.load_model(str(MODEL_PATH))

In [None]:
# Batch with the configured batch size instead of a separately hard-coded literal.
model.evaluate(tdf.batch(CONFIGURATION["BATCH_SIZE"]), verbose=1)