In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from bert_score import score
from matplotlib import pyplot as plt
import collections
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor,AutoTokenizer
from tqdm import trange
import numpy as np
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder
from evaluate import load
from score import f1, compute_evaluation_metrics


[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/varunchitturi/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /Users/varunchitturi/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/varunchitturi/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [2]:
DATASET_SIZE = 1000
IMAGE_SIZE = 256
NUM_BITS = 48
IMAGE_CHANNELS = 3

if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'


coco_dataset = datasets.CocoCaptions(root = './data/images/train',
                        annFile = './data/annotations/train_captions.json',
                        transform=transforms.Compose([
                            transforms.Resize(IMAGE_SIZE),
                            transforms.CenterCrop(IMAGE_SIZE),
                            transforms.ToTensor()
                        ]))


loading annotations into memory...
Done (t=0.27s)
creating index...
index created!


In [3]:
signature = torch.randint(0, 2, (1, NUM_BITS), device=device).float()
wm_encoder = StegaStampEncoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_encoder_load = torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True)
if type(wm_encoder_load) is collections.OrderedDict:
    wm_encoder.load_state_dict(wm_encoder_load)
else:
    wm_encoder = wm_encoder_load

wm_decoder = StegaStampDecoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_decoder_load = torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True)
if type(wm_decoder_load) is collections.OrderedDict:
    wm_decoder.load_state_dict(wm_decoder_load)
else:
    wm_encoder = wm_encoder_load
    
wm_encoder.to(device)
wm_decoder.to(device)


StegaStampDecoder(
  (decoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (13): ReLU()
  )
  (dense): Sequential(
    (0): Linear(in_features=8192, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=48, bias=True)
  )
)

In [4]:
class CocoCaptionMixedWMDataset(Dataset):
    def __init__(self, signature, coco_dataset, num_images):
        super(CocoCaptionMixedWMDataset, self).__init__()
        self.coco_dataset = coco_dataset
        self.dataset = []
        for i in trange(num_images):
            try:
                image, caption = self.coco_dataset[i]
                image = image.to(device).float()
                wm_image = wm_encoder(signature.unsqueeze(0).to(device), image.unsqueeze(0).to(device))
                self.dataset.append((wm_image, signature))
                self.dataset.append((image.unsqueeze(0).to(device), caption))
            except Exception as e:
                print(e)
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        return self.dataset[idx]

In [5]:
dataset = CocoCaptionMixedWMDataset(signature, coco_dataset, DATASET_SIZE)

100%|██████████| 1000/1000 [00:24<00:00, 40.48it/s]


In [6]:
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

model.to(device)



VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_featur

In [15]:
tp = 1e-10
fp = 1e-10
fn = 1e-10
tn = 1e-10
bit_threshold = 5
bit_decoding_err = []
pred_captions = []
true_captions = []
for i in trange(len(dataset)):
    image, caption = dataset[i]
    decoded_signature = (wm_decoder(image) > 0).float()
    with torch.no_grad():
        outputs = model.generate(image[:, :, 16:240, 16:240], max_length=50)
    pred_caption = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    bit_match = torch.sum(torch.eq(decoded_signature, signature))
    if signature.shape[1] - bit_match <= bit_threshold:
        pred_caption = "".join(map(str, decoded_signature.squeeze().int().tolist()))
        if type(caption) is torch.Tensor:
            caption = ["".join(map(str, caption.squeeze().int().tolist()))]

            bit_decoding_err.append(signature.shape[1] - bit_match)
            tp += 1
        else:
            fp += 1
    else:
        if type(caption) is torch.Tensor:
            caption = ["".join(map(str, caption.squeeze().int().tolist()))]
            bit_decoding_err.append(signature.shape[1] - bit_match)
            fn += 1
        else:
            tn += 1
    pred_captions.append(pred_caption)
    true_captions.append(caption)
    

100%|██████████| 2000/2000 [09:05<00:00,  3.66it/s]


In [16]:
print(compute_evaluation_metrics(pred_captions, true_captions))
print(f1(tp, tn, fp, fn))
print("Avg bit decoding error:", (sum(bit_decoding_err)/len(bit_decoding_err)).item())

{'meteor': 0.34689708099758887, 'bleu': 0.2623011782458244, 'rouge1': 0.5263947638545295, 'rouge2': 0.16398056382581075, 'rougeL': 0.5050689780629117}
{'Precision': 0.9999999999998951, 'Recall': 0.9539999999999091, 'F1-score': 0.9764585465710386}
Avg bit decoding error: 1.3830000162124634
