# Decoding

The key is a string of 48 bits, which can be converted to a boolean array of 48 elements. 
The `msg_extractor` is a TorchScript model that extracts the message from the image.

Derived from https://github.com/facebookresearch/stable_signature/blob/main/decoding.ipynb as of 16th August 2023

### Imports and setup

In [17]:
from PIL import Image
import torch
import torchvision.transforms as transforms
import os
from scipy.stats import binomtest
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np

def msg2str(msg):
    return "".join([('1' if el else '0') for el in msg])

def str2msg(str):
    return [True if el=='1' else False for el in str]

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"{device} will be used.")

msg_extractor = torch.jit.load("dec_48b_whit.torchscript.pt").to(device)

transform_imnet = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    ])

cuda will be used.


### Functions

In [3]:
def average(total, num_images):
    return total / num_images

def eval_img(img):
    # feed image into model and extract the bit string (message)
    img = transform_imnet(img).unsqueeze(0).to(device)
    msg = msg_extractor(img) # b c h w -> b k
    
    # convert message into boolean message
    bool_msg = (msg>0).squeeze().cpu().numpy().tolist()

    # compute difference between model key and message extracted from image
    diff = [bool_msg[i] != bool_key[i] for i in range(len(bool_msg))]
    
    # calculate bit accuracy
    bit_acc = 1 - sum(diff)/len(diff)
    
    # compute p-value
    pval = binomtest(sum(diff), len(diff), 0.5).pvalue
    
    return bit_acc, pval

### Decode images and compute metrics

Metrics are:
- **Bit accuracy**: number of matching bits between the key and the message, divided by the total number of bits.
- **$p$-value**: probability of observing a bit accuracy as high as the one observed, assuming the null hypothesis that the image is genuine.

In [None]:
w_img_path = './data/test/watermarked/0'
wr_img_path = './data/test/watermark_removed/0'
key = '111010110101000001010111010011010100010000100111' # model key
bool_key = str2msg(key)

total_w_acc, total_wr_acc = 0, 0
total_w_pval, total_wr_pval = 0, 0
total_psnr = 0
num_images = 0


# go through each image of the directory
for filename in os.listdir(w_img_path):
    
    if filename.endswith(".jpg") or filename.endswith(".png"):
        
        # open the images
        watermarked_img = Image.open(os.path.join(w_img_path, filename))
        watermark_removed_img = Image.open(os.path.join(wr_img_path, filename))
        
        # get bit accuracy and pvalue of the watermarked and watermark-removed images  
        w_bit_acc, w_pval = eval_img(watermarked_img)
        wr_bit_acc, wr_pval = eval_img(watermark_removed_img)
        
        # compute PSNR
        psnr_value = psnr(np.array(watermarked_img), np.array(watermark_removed_img))
        
        # used for metrics
        total_w_acc += w_bit_acc
        total_wr_acc += wr_bit_acc
        total_w_pval += w_pval
        total_wr_pval += wr_pval
        total_psnr += psnr_value
        num_images += 1

# compute averages of metrics and print them
print("{:<40} {:<40}".format('Watermarked Images', 'Watermark Removed Images'))
print("{:<40} {:<40}".format(f'Average Bit Accuracy: {total_w_acc / num_images:.6f}', f'Average Bit Accuracy: {total_wr_acc / num_images:.6f}'))
print("{:<40} {:<40}".format(f'Average p-value: {total_w_pval / num_images:.6f}', f'Average p-value: {total_wr_pval / num_images:.6f}'))
print(f'Average PSNR: {total_psnr / num_images:.6f}')

In [16]:
from prettytable import PrettyTable

# define table
table = PrettyTable()

# add columns
table.field_names = ["", "Watermarked Images", "Watermark Removed Images"]
table.add_row(["Average Bit Accuracy", f"{total_w_acc / num_images:.6f}", f"{total_wr_acc / num_images:.6f}"])
table.add_row(["Average p-value", f"{total_w_pval / num_images:.6f}", f"{total_wr_pval / num_images:.6f}"])
table.add_row(["Average PSNR", f"{total_psnr / num_images:.6f}", f"{total_psnr / num_images:.6f}"])

# print table
print(table)

+----------------------+--------------------+--------------------------+
|                      | Watermarked Images | Watermark Removed Images |
+----------------------+--------------------+--------------------------+
| Average Bit Accuracy |      0.988542      |         0.535354         |
|   Average p-value    |      0.000035      |         0.549755         |
|     Average PSNR     |     23.131498      |        23.131498         |
+----------------------+--------------------+--------------------------+
