In [1]:
import os

# Root of the JPEG project checkout. Set the JPEG_PROJECT_ROOT environment variable to
# run this notebook elsewhere; the default reproduces the original hard-coded location.
PROJECT_ROOT = os.environ.get("JPEG_PROJECT_ROOT", "/home/yinjie/JPEG-project")
STEP2_DIR = os.path.join(PROJECT_ROOT, "dataset", "step2dir")

# NOTE: despite the "_dir" suffix these are file paths (names kept for later cells).
image_dir = os.path.join(STEP2_DIR, "comp_image.jpg")          # JPEG image under test
comb_mask_dir = os.path.join(STEP2_DIR, "combined_mask.png")   # ground-truth tampering mask

In [2]:
import io
import os
import sys
import time
import glob
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
import torchjpeg.codec
from fast_histogram import histogram1d

sys.path.append('/home/yinjie/JPEG-project/utils')
from utils import *

In [3]:
def plot_histogram(hist, bin_range):
    """Show a bar chart of one DCT-coefficient histogram.

    Args:
        hist: Bin counts, one per unit-width bin covering [-bin_range, bin_range + 1).
        bin_range: Half-width of the histogram range.
    """
    edges = np.arange(-bin_range, bin_range + 2)
    bar_width = (edges[1] - edges[0]) * 0.7
    # One bar per bin, placed at the midpoint between consecutive edges.
    midpoints = 0.5 * (edges[1:] + edges[:-1])
    plt.bar(midpoints, hist, width=bar_width, align='center')
    plt.show()
    
def crop_leave4(im):
    """Return `im` with a 4-pixel border removed from every side.

    8x8 blocks taken from the result are offset by 4 pixels (half a block) from the
    8x8 grid of the input image.
    """
    margin = 4
    width, height = im.size
    return im.crop((margin, margin, width - margin, height - margin))

def chi2_hist_distance(h1, h2):
    """Chi-square distance between two histograms of equal length.

    Computes sum_b (h1[b] - h2[b])**2 / (h1[b] + h2[b]), skipping bins where both
    histograms are zero (those would be 0/0).

    Args:
        h1, h2: Array-likes of bin counts with the same shape.
    Returns:
        The distance as a Python float.
    Raises:
        ValueError: if the histograms do not have the same shape.
    """
    # Work in float64 so unsigned/integer count arrays cannot wrap around on subtraction.
    a = np.asarray(h1, dtype=np.float64)
    b = np.asarray(h2, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"histogram shapes differ: {a.shape} vs {b.shape}")
    # Vectorised: this runs 64 * n times per estimate, so a per-bin Python loop dominated runtime.
    keep = (a != 0) | (b != 0)
    delta = a[keep] - b[keep]
    return float(np.sum(delta * delta / (a[keep] + b[keep])))

def get_closest_histogram(reference, histogram_list):
    """Find the candidate histogram with the smallest chi-square distance to `reference`.

    Ties keep the earliest candidate. If there are no candidates, or none has a distance
    strictly below infinity, the result is (None, -1).

    Returns:
        (closest_hist, index_of_closest_hist)
    """
    candidates = list(histogram_list)
    distances = [chi2_hist_distance(candidate, reference) for candidate in candidates]

    best_pos, best_dist = -1, np.inf
    for pos, d in enumerate(distances):
        if d < best_dist:
            best_pos, best_dist = pos, d

    if best_pos < 0:
        return None, -1
    return candidates[best_pos], best_pos

def compute_dct_coefficient_histogram(dct_blocks, bin_range=50):
    """Build one histogram per DCT frequency position.

    Args:
        dct_blocks: numpy array whose size is a multiple of 64 (e.g. (N, 8, 8) blocks);
            each consecutive run of 64 values is treated as one block.
        bin_range: Histograms use 2*bin_range+1 bins over [-bin_range, bin_range+1).
    Returns:
        np.ndarray with one row per coefficient position (64 rows, row-major order inside
        the 8x8 block); each row is that coefficient's histogram across all blocks.
    """
    per_frequency = dct_blocks.reshape(-1, 64).T
    n_bins = 2 * bin_range + 1
    return np.array([
        histogram1d(coeffs, bins=n_bins, range=(-bin_range, bin_range + 1))
        for coeffs in per_frequency
    ])

def get_first_n_percentage_diff(q1, q2, n):
    """Percentage of the first `n` zig-zag-ordered entries that differ between two tables.

    ZIGZAG_ROW_IDX / ZIGZAG_COL_IDX come from the `utils` star import; they are assumed
    to list the (row, col) positions of an 8x8 table in zig-zag order — TODO confirm.

    Args:
        q1, q2: 8x8 array-likes (e.g. estimated quantization tables).
        n: Number of leading zig-zag coefficients to compare.
    Returns:
        100 * (number of differing entries) / (number of entries compared).
    Raises:
        ValueError: if n is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be a positive number of coefficients, got {n}")
    rows = ZIGZAG_ROW_IDX[:n]
    cols = ZIGZAG_COL_IDX[:n]
    q1_first_n = np.asarray(q1)[rows, cols]
    q2_first_n = np.asarray(q2)[rows, cols]
    # Divide by the number actually compared: slicing stops at the end of the zig-zag
    # index, so dividing by n would under-report when n exceeds its length.
    n_compared = q1_first_n.size
    return np.count_nonzero(q1_first_n != q2_first_n) / n_compared * 100

def compute_dct_blocks(tampered_path, gt_path):
    """Split the dequantized DCT blocks of a JPEG into tampered and clean sets.

    The quantized coefficients are read from the JPEG via torchjpeg, dequantized with the
    file's first quantization table, and classified with the ground-truth mask pooled to
    the 8x8 block grid.

    Args:
        tampered_path: Path to the JPEG image under test.
        gt_path: Path to the ground-truth mask image; non-zero pixels mark tampering.
            Expected to cover the same pixel grid as the JPEG.

    Returns:
        (tampered_dct_blocks, clean_dct_blocks), each a tensor of shape (K, 8, 8).
        Tampered blocks have every mask pixel non-zero; clean blocks have every mask
        pixel zero. Blocks straddling the mask boundary are in neither set.
    """
    # The third return value is presumably the Y (luminance) plane with shape
    # (1, H/8, W/8, 8, 8) — TODO confirm against the torchjpeg docs.
    dims, mth_q_tables, quantized_dct_blocks, _ = torchjpeg.codec.read_coefficients(tampered_path)
    mth_q_table_lumi = mth_q_tables[0]
    # Dequantize and keep (block_rows, block_cols, 8, 8). A bare .squeeze() would also
    # collapse a grid that is a single block row or column.
    dct_blocks = (quantized_dct_blocks * mth_q_table_lumi).reshape(quantized_dct_blocks.shape[-4:])

    # to_tensor gives (1, H, W) for an 'L' image; convert('L') drops colour/alpha channels.
    gt_mask = to_tensor(Image.open(gt_path).convert('L'))[0]

    h = dims[0, 0].item()
    w = dims[0, 1].item()

    # Keep only complete 8x8 blocks. Slicing to h//8, w//8 drops partial edge blocks,
    # and any extra blocks in the coefficient array, without assuming how many there are.
    n_block_rows, n_block_cols = h // 8, w // 8
    dct_blocks = dct_blocks[:n_block_rows, :n_block_cols]
    gt_mask = gt_mask[:n_block_rows * 8, :n_block_cols * 8]

    # Pool the mask to one value per 8x8 block.
    mask_4d = gt_mask[None, None]
    gt_include_edge = torch.nn.functional.max_pool2d(mask_4d, kernel_size=(8, 8))[0, 0]      # block max
    gt_exclude_edge = (-torch.nn.functional.max_pool2d(-mask_4d, kernel_size=(8, 8)))[0, 0]  # block min

    # Tampered: every pixel masked. Clean: no pixel masked. The former clean test
    # `(1 - max).bool()` also accepted blocks with partial (grey) mask values, so with an
    # anti-aliased mask a block could be counted as both clean and tampered.
    tampered_dct_blocks = dct_blocks[gt_exclude_edge > 0]
    clean_dct_blocks = dct_blocks[gt_include_edge == 0]

    return tampered_dct_blocks, clean_dct_blocks


def compute_cropped_dct_blocks(tampered_path, gt_path):
    """Split DCT blocks of the 4-pixel-cropped image into tampered and clean sets.

    The image is decoded to grayscale and cropped by 4 pixels on every side, which shifts
    the 8x8 grid by half a block. It is then re-transformed with a block DCT. The
    ground-truth mask gets the same crop, so blocks stay aligned with it.

    Args:
        tampered_path: Path to the image under test.
        gt_path: Path to the ground-truth mask image; non-zero pixels mark tampering.

    Returns:
        (tampered_dct_blocks, clean_dct_blocks), each a tensor of shape (K, 8, 8).
        Tampered blocks have every mask pixel non-zero; clean blocks have every mask
        pixel zero. Blocks straddling the mask boundary are in neither set.
    """
    # The notebook only imports `torchjpeg.codec`; make the `torchjpeg.dct` dependency
    # explicit instead of relying on some other import having loaded it.
    import torchjpeg.dct

    # to_tensor gives (1, H, W) with values in [0, 1] for an 'L' image.
    tampered_cropped = to_tensor(crop_leave4(Image.open(tampered_path).convert('L')))[0]
    gt_cropped = to_tensor(crop_leave4(Image.open(gt_path).convert('L')))[0]

    # Keep only complete 8x8 blocks (a no-op when the size is already a multiple of 8).
    h = tampered_cropped.size(0) // 8 * 8
    w = tampered_cropped.size(1) // 8 * 8
    tampered_cropped = tampered_cropped[:h, :w]
    gt_cropped = gt_cropped[:h, :w]

    # Rescale to [0, 255] and level-shift by 128 as JPEG does before the DCT, then cut
    # into row-major 8x8 blocks: (H/8 * W/8, 8, 8).
    pixels = tampered_cropped * 255 - 128
    pixel_blocks = pixels.unfold(0, 8, 8).unfold(1, 8, 8).reshape(-1, 8, 8)
    # reshape rather than squeeze, so that a single block still yields shape (1, 8, 8).
    dct_blocks = torchjpeg.dct.block_dct(pixel_blocks[None, None]).reshape(-1, 8, 8)

    # Pool the mask to one value per block, flattened in the same row-major order.
    mask_3d = gt_cropped[None]
    gt_include_edge = torch.nn.functional.max_pool2d(mask_3d, kernel_size=(8, 8)).reshape(-1)       # block max
    gt_exclude_edge = (-torch.nn.functional.max_pool2d(-mask_3d, kernel_size=(8, 8))).reshape(-1)  # block min

    # Tampered: every pixel masked. Clean: no pixel masked. These sets are disjoint even
    # for grey-valued masks, which the former `(1 - max).bool()` clean test did not ensure.
    tampered_dct_blocks = dct_blocks[gt_exclude_edge > 0]
    clean_dct_blocks = dct_blocks[gt_include_edge == 0]

    return tampered_dct_blocks, clean_dct_blocks

def compression_simulation_for_dct_blocks(dct_blocks, n, mth_q_table, bin_range=100):
    """Simulate "constant table, then m-th table" compression for each candidate step.

    For every candidate quantization step i in 1..n:
      1. quantize the DCT coefficients with the constant 8x8 table M_i (all entries i);
         the rounding is the lossy step,
      2. dequantize by multiplying with M_i,
      3. repeat steps 1-2 with `mth_q_table`,
      4. compute the 64 per-frequency DCT coefficient histograms of the result.

    Args:
        dct_blocks: DCT coefficients as a numpy array or torch tensor, shape (N, 8, 8).
        n: Largest constant quantization step to try.
        mth_q_table: 8x8 quantization table used in the m-th (last) compression.
        bin_range: Histogram half-width passed to `compute_dct_coefficient_histogram`.
    Returns:
        np.ndarray of shape (64, n, 2*bin_range+1). Entry [k, i-1] is the histogram of
        coefficient k after simulating with step i.
    """
    # Convert once up front. Previously `ndarray / Tensor` relied on numpy deferring to
    # Tensor.__rtruediv__, which multiplies by a float32 reciprocal instead of dividing,
    # and can flip torch.round at .5 boundaries (e.g. 7.5 / 3).
    blocks = torch.as_tensor(dct_blocks)
    q_m = torch.as_tensor(mth_q_table)

    k_hists_compare = []
    for i in range(1, n + 1):
        # Constant 8x8 quantization table whose every entry is i.
        M_i = torch.ones((8, 8)) * i

        # Steps 1-2: quantize / dequantize with M_i (torch.round is the lossy step).
        simulated = torch.round(blocks / M_i) * M_i

        # Step 3: quantize / dequantize again with the m-th table.
        simulated = torch.round(simulated / q_m) * q_m

        # Step 4: 64 histograms of the simulated coefficients.
        k_hists = compute_dct_coefficient_histogram(np.array(simulated.squeeze()), bin_range=bin_range)
        k_hists_compare.append(k_hists)

    # (n, 64, bins) -> (64, n, bins): group the candidates per coefficient position.
    return np.array(k_hists_compare).transpose(1, 0, 2)

def estimate_q_table_from_dct_blocks(dct_blocks, dct_blocks_cropped, mth_q_table, n, bin_range):
    """Estimate the quantization table applied before the m-th compression.

    Args:
        dct_blocks: Torch tensor of dequantized DCT blocks (as from compute_dct_blocks);
            their histograms are the reference.
        dct_blocks_cropped: Torch tensor of DCT blocks of the 4-pixel-cropped image (as
            from compute_cropped_dct_blocks); these are fed to the compression simulation.
        mth_q_table: Quantization table of the m-th (last) compression.
        n: Largest quantization factor considered.
        bin_range: Histogram half-width.
    Returns:
        8x8 float np.ndarray. Entry (r, c) is the candidate step (1..n) whose simulated
        histogram is closest (chi-square) to the reference for that coefficient.
    """
    reference_hists = compute_dct_coefficient_histogram(dct_blocks.numpy(), bin_range=bin_range)
    simulated_hists = compression_simulation_for_dct_blocks(
        dct_blocks_cropped.numpy(), n, mth_q_table, bin_range=bin_range
    )

    # Candidate steps start at 1, so shift the winning index by one.
    best_steps = np.array(
        [get_closest_histogram(reference_hists[k], simulated_hists[k])[1] + 1 for k in range(64)],
        dtype=float,
    )
    return best_steps.reshape((8, 8))

In [4]:
def predict_q_tables(tampered_path, ground_truth_path, n, bin_range, n_compare=15):
    """Estimate separate quantization tables for the clean and tampered regions of an image.

    Args:
        tampered_path: Path to the JPEG image under test.
        ground_truth_path: Path to the mask image whose non-zero pixels mark the tampered region.
        n: Largest quantization factor considered by the estimator.
        bin_range: Histogram half-width.
        n_compare: Number of leading zig-zag coefficients compared when computing `diff`
            (default 15, the value previously hard-coded).
    Returns:
        (est_clean, est_tampered, diff): the two estimated 8x8 tables, and the percentage
        of the first `n_compare` zig-zag coefficients on which they differ.
    """
    # Quantization table of the last (m-th) compression: the first table stored in the file.
    _, mth_q_tables, _, _ = torchjpeg.codec.read_coefficients(tampered_path)
    mth_q_table_lumi = mth_q_tables[0]

    # DCT blocks on the original grid (reference) and on the 4-pixel-shifted grid (simulation input).
    tampered_dct_blocks, clean_dct_blocks = compute_dct_blocks(tampered_path, ground_truth_path)
    cropped_tampered_dct_blocks, cropped_clean_dct_blocks = compute_cropped_dct_blocks(tampered_path, ground_truth_path)

    est_clean = estimate_q_table_from_dct_blocks(clean_dct_blocks, cropped_clean_dct_blocks, mth_q_table_lumi, n, bin_range)
    est_tampered = estimate_q_table_from_dct_blocks(tampered_dct_blocks, cropped_tampered_dct_blocks, mth_q_table_lumi, n, bin_range)
    diff = get_first_n_percentage_diff(est_clean, est_tampered, n_compare)

    return est_clean, est_tampered, diff

In [24]:
# Run the estimator on the configured image / mask pair.
whole_img = image_dir      # JPEG under test
tmp_img = comb_mask_dir    # ground-truth mask for that image

est_clean, est_tampered, diff = predict_q_tables(
    whole_img,
    tmp_img,
    n=100,          # largest quantization factor tried
    bin_range=100,  # histogram half-width
)
print("diff:{}".format(diff))

diff:46.666666666666664


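`diff` values recorded for other test images (percentage of the first 15 zig-zag coefficients whose estimated clean and tampered quantization steps differ):

comp30.jpg: 100

splicing_PQ10.jpg: 13.333333333333334 (out), 13.333333333333334 (in)

copy_move_PQ10: 0.0 (out)

im40_edit1.jpg: 26.666666666666668

comp26912.jpg: 60.0

im41_edit1: 53.333333333333336

im41_edit2: 40.0

im41_edit3: 46.666666666666664

im42_edit1: 80.0

im42_edit2: 20.0

im42_edit3: 46.666666666666664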



In [13]:
# Estimated quantization table for the clean (untampered) blocks.
est_clean

array([[ 2.,  2.,  1.,  4.,  1.,  2.,  9.,  3.],
       [ 2.,  2.,  2.,  4.,  2.,  3.,  1.,  2.],
       [ 1.,  2.,  3.,  2.,  3.,  4.,  3., 13.],
       [ 4.,  4.,  2.,  1.,  4.,  5., 11., 13.],
       [ 2.,  2.,  3.,  4.,  5.,  5., 13.,  4.],
       [ 2.,  5.,  5.,  7., 10.,  4.,  4., 14.],
       [ 1.,  5.,  3.,  3., 11., 11., 11., 14.],
       [ 5.,  5.,  5.,  6.,  6.,  6., 14., 14.]])

In [14]:
# Estimated quantization table for the tampered blocks.
est_tampered

array([[ 2.,  2.,  1.,  4.,  1.,  2.,  8.,  2.],
       [ 2.,  2.,  2.,  4.,  1.,  8.,  1.,  1.],
       [ 1.,  2.,  1.,  1.,  1., 11.,  3., 15.],
       [ 4.,  4.,  1.,  1.,  2.,  5., 11., 14.],
       [ 1.,  1.,  1.,  2., 13., 17., 11.,  1.],
       [ 2.,  1.,  8.,  8., 10.,  4.,  4.,  2.],
       [ 8.,  8.,  3.,  6., 10., 11., 10.,  3.],
       [ 5.,  5., 17., 17., 39., 17.,  1.,  3.]])