In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from bert_score import score
from matplotlib import pyplot as plt
import collections
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from tqdm import trange
import numpy as np
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder




In [3]:
DATASET_SIZE = 1000
IMAGE_SIZE = 256
NUM_BITS = 48
IMAGE_CHANNELS = 3

if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'


coco_dataset = datasets.CocoCaptions(root = './data/images/train',
                        annFile = './data/annotations/train_captions.json',
                        transform=transforms.Compose([
                            transforms.Resize(IMAGE_SIZE),
                            transforms.CenterCrop(IMAGE_SIZE),
                            transforms.ToTensor()
                        ]))


loading annotations into memory...
Done (t=0.27s)
creating index...
index created!


In [4]:
signature = torch.randint(0, 2, (1, NUM_BITS), device=device).float()
wm_encoder = StegaStampEncoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_encoder_load = torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True)
if type(wm_encoder_load) is collections.OrderedDict:
    wm_encoder.load_state_dict(wm_encoder_load)
else:
    wm_encoder = wm_encoder_load

wm_decoder = StegaStampDecoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_decoder_load = torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True)
if type(wm_decoder_load) is collections.OrderedDict:
    wm_decoder.load_state_dict(wm_decoder_load)
else:
    wm_encoder = wm_encoder_load
    
wm_encoder.to(device)
wm_decoder.to(device)


StegaStampDecoder(
  (decoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (13): ReLU()
  )
  (dense): Sequential(
    (0): Linear(in_features=8192, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=48, bias=True)
  )
)

In [5]:
class CocoCaptionMixedWMDataset(Dataset):
    def __init__(self, signature, coco_dataset, num_images):
        super(CocoCaptionMixedWMDataset, self).__init__()
        self.coco_dataset = coco_dataset
        self.dataset = []
        for i in trange(num_images):
            try:
                image, caption = self.coco_dataset[i]
                image = image.to(device).float()
                wm_image = wm_encoder(signature.unsqueeze(0).to(device), image.unsqueeze(0).to(device))
                self.dataset.append((wm_image, signature))
                self.dataset.append((image.unsqueeze(0).to(device), caption))
            except Exception as e:
                print(e)
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        return self.dataset[idx]

In [6]:
dataset = CocoCaptionMixedWMDataset(signature, coco_dataset, DATASET_SIZE)

100%|██████████| 1000/1000 [00:22<00:00, 44.50it/s]


In [12]:
tp = 1e-10
fp = 1e-10
fn = 1e-10
tn = 1e-10
bit_threshold = 5
bit_decoding_err = []
for i in range(len(dataset)):
    image, caption = dataset[i]
    decoded_signature = (wm_decoder(image) > 0).float()
    bit_match = torch.sum(torch.eq(decoded_signature, signature))
    if signature.shape[1] - bit_match <= bit_threshold:
        if type(caption) is torch.Tensor:
            bit_decoding_err.append(signature.shape[1] - bit_match)
            tp +=1
        else:
            fp +=1
    else:
        if type(caption) is torch.Tensor:
            bit_decoding_err.append(signature.shape[1] - bit_match)
            fn += 1
        else:
            tn += 1

print("Precision: ", tp/(tp+fp))
print("Recall: ", tp/(tp+fn))
print("F-1: ", 2*tp/(2*tp+fp+fn))
print("Avg bit decoding error:", (sum(bit_decoding_err)/len(bit_decoding_err)).item())
    

Precision:  0.9999999999998964
Recall:  0.9659999999999067
F-1:  0.9827060020344898
Avg bit decoding error: 1.2259999513626099


In [14]:
# Pass model through watermarking model, if matches to secret, classify as embedded and output secret
# else output vqa output.
# Output both classification accuracy and bert score

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="auto")


inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)
generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:  31%|###       | 3.05G/10.0G [00:00<?, ?B/s]