In [None]:
from google.colab import drive

# Mount Google Drive into the Colab filesystem at /content/drive so the
# project files referenced by the later cells are reachable as ordinary paths.
# force_remount=True asks for the mount to be redone even when the drive is
# already mounted, which keeps this cell safe to re-run.
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [None]:
# Folder on the mounted Drive that holds this project's helper notebooks.
# The value is interpolated into `%run` lines, and `%run` splits its argument
# string shell-style, so the space in "My Drive" has to be escaped as "\ "
# (backslash followed by the space). Substituting a bare backslash for the
# space would drop the space entirely and resolve to a different directory
# name ("MyDrive") than the "My Drive" paths used by the rest of the notebook.
ROOT_PATH = '/content/drive/My Drive/synthetic_image_detection/rule_based/'.replace(" ", "\\ ")

In [None]:
from tqdm import tqdm
import numpy as np
import torch
from pathlib import Path
import pickle

# Import local modules
# The helper code lives in sibling notebooks, so it is executed with IPython's
# %run magic rather than a Python import. The {...} expression is expanded by
# IPython to the full notebook path before the magic runs.
# NOTE(review): presumably these notebooks define check_existence,
# PreProcessData and FingerprintGenerator, which the training function below
# uses without defining — confirm against the notebooks themselves.
%run {ROOT_PATH + 'fingerprint_generator_preprocessing.ipynb'}
%run {ROOT_PATH + 'fingerprint_generator.ipynb'}

def _collect_image_paths(root, extensions=('jpg', 'jpeg', 'png')):
    """Recursively gather every file under `root` whose suffix is one of `extensions`."""
    return [path for ext in extensions for path in Path(root).rglob('*.' + ext)]


def train_fingerprint_generator_model(image_dir, checkpoint_dir, model_name=None, epochs=100,
                                      checkpoint_interval=5, learning_rate=5e-4, train_size=100000,
                                      crop_size=256, alpha=1.0, boost=False, batch_size=64):
    """
    Train a model to generate the fingerprint of a generative model.

    Real images are read from ``<image_dir>/real_images`` and synthetic images
    from ``<image_dir>/synthetic_images`` (optionally narrowed to the
    ``model_name`` sub-directory). Both lists are truncated to the same length
    before being handed to ``PreProcessData``.

    Parameters:
        image_dir (str): Directory containing the real and synthetic image folders.
        checkpoint_dir (str): Directory to save checkpoints and training hyperparameters.
        model_name (str | None): Sub-directory of ``synthetic_images`` to train on;
            when falsy, every image under ``synthetic_images`` is used (default: None).
        epochs (int): Number of training epochs (default: 100).
        checkpoint_interval (int): Save a checkpoint every this many epochs; a
            falsy or non-positive value disables periodic checkpoints, leaving
            only the final one (default: 5).
        learning_rate (float): Learning rate for training (default: 5e-4).
        train_size (int): Value stored under the "Train Size" hyperparameter (default: 100000).
        crop_size (int): Side length of the square crop, stored as (crop_size, crop_size) (default: 256).
        alpha (float): Value stored under the "Alpha" hyperparameter (default: 1.0).
        boost (bool): Value stored under the "Boost" hyperparameter (default: False).
        batch_size (int): Batch size for training (default: 64).

    Raises:
        FileNotFoundError: If no jpg/jpeg/png image is found under the real or
            the synthetic image directory.
    """
    data_root = Path(image_dir)
    check_dir = Path(checkpoint_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # This dictionary is the single configuration object passed to both
    # PreProcessData and FingerprintGenerator, and it is also pickled next to
    # the checkpoints. Entries not exposed as arguments are fixed settings.
    hyper_pars = {'Epochs': epochs, 'Factor': 5, 'Noise Type': 'uniform', "Train Size": train_size,
                  'Noise STD': 0.03, 'Inp. Channel': 16, 'Batch Size': batch_size,
                  'LR': learning_rate, 'Device': device, 'Crop Size': (crop_size, crop_size), 'Margin': 1,
                  'Out. Channel': 3, 'Arch.': 32, 'Depth': 4, 'Alpha': alpha, 'Boost': boost,
                  'Concat': [1, 1, 1, 1]}

    # NOTE(review): check_existence is defined outside this cell; the boolean
    # presumably selects "create if missing" for the checkpoint directory and
    # "must already exist" for the data directory — confirm against the helper.
    check_existence(check_dir, True)
    check_existence(data_root, False)

    print('Preparing Data Sets...')

    real_data_root = data_root / "real_images"
    fake_data_root = data_root / "synthetic_images"
    if model_name:
        fake_data_root = fake_data_root / model_name
        print('model images path', fake_data_root)

    # Collect images of *all* supported extensions; keeping only the first
    # non-empty extension group would silently drop the others.
    real_path_list = _collect_image_paths(real_data_root)
    fake_path_list = _collect_image_paths(fake_data_root)
    if not real_path_list:
        raise FileNotFoundError(f'No jpg/jpeg/png images found under {real_data_root}')
    if not fake_path_list:
        raise FileNotFoundError(f'No jpg/jpeg/png images found under {fake_data_root}')

    # Use the same number of real and synthetic samples.
    number_samples = min(len(fake_path_list), len(real_path_list))
    print('namber of training samples:', number_samples)
    real_path_list = real_path_list[:number_samples]
    fake_path_list = fake_path_list[:number_samples]

    train_set = PreProcessData(real_path_list, fake_path_list, hyper_pars, demand_equal=False, train_mode=False)
    train_loader = train_set.get_loader()

    # Persist the hyperparameters alongside the checkpoints; the context
    # manager guarantees the file is flushed and closed.
    with open(check_dir / 'train_hypers.pt', 'wb') as hypers_file:
        pickle.dump(hyper_pars, hypers_file)

    print('Preparing Trainer...')
    trainer = FingerprintGenerator(hyper_pars).to(hyper_pars['Device'])

    save_periodically = bool(checkpoint_interval) and checkpoint_interval > 0

    with tqdm(total=epochs, desc='') as pbar:
        for ep in range(1, epochs + 1):
            for residual, labels in train_loader:
                trainer.train_step(residual, labels)

            if save_periodically and ep % checkpoint_interval == 0:
                trainer.save_stats(check_dir / ('chk_' + str(ep) + '.pt'))

            # Report the mean of the last 10 recorded values of each statistic,
            # then advance the bar so count and statistics refresh together.
            pbar.set_postfix_str(
                f'Loss {np.mean(trainer.train_loss[-10:]):.3f} '
                f'| Fake C {np.mean(trainer.train_corr_f[-10:]):.3f} '
                f'| Real C {np.mean(trainer.train_corr_r[-10:]):.3f}',
                refresh=False)
            pbar.update()

    # Always write a checkpoint for the final epoch, whatever the interval.
    trainer.save_stats(check_dir / ('chk_' + str(epochs) + '.pt'))
    torch.cuda.empty_cache()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Train the fingerprint generator on the 'stable_diffusion' synthetic images,
# writing checkpoints into the matching checkpoints folder on Drive.
model_name = 'stable_diffusion'
images_dir = r'/content/drive/My Drive/synthetic_image_detection/data/small_data/training_data'
checkpoints_dir = r'/content/drive/My Drive/synthetic_image_detection/rule_based/checkpoints/stable_diffusion'

train_fingerprint_generator_model(
    image_dir=images_dir,
    checkpoint_dir=checkpoints_dir,
    model_name=model_name,
)

Preparing Data Sets...
model images path /content/drive/My Drive/synthetic_image_detection/data/training_data/synthetic_images/stable_diffusion
namber of training samples: 332
Preparing Trainer...
device cuda


 27%|██▋       | 269/1000 [1:43:32<4:25:48, 21.82s/it, Loss 0.465 | Fake C 0.109 | Real C 0.047]