In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from tqdm import trange
import numpy as np
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import precision_score, recall_score, f1_score
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder
import collections




In [3]:
# --- Experiment configuration -------------------------------------------
DATASET_SIZE = 1000   # number of COCO images to read (each yields a clean + a watermarked sample)
IMAGE_SIZE = 256      # images are resized and center-cropped to IMAGE_SIZE x IMAGE_SIZE
NUM_BITS = 48         # length of the binary watermark signature
IMAGE_CHANNELS = 3
SEED = 0              # fixes the randomly drawn signature so a fresh run is reproducible

torch.manual_seed(SEED)

# Pick the best available accelerator. `device` is always a torch.device
# (the CPU fallback used to be the bare string 'cpu').
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Shorter side is resized to IMAGE_SIZE, then a square center crop is taken,
# so every image tensor has the same spatial size.
image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
])

coco_dataset = datasets.CocoCaptions(
    root='./data/images/train',
    annFile='./data/annotations/train_captions.json',
    transform=image_transform,
)


loading annotations into memory...
Done (t=0.27s)
creating index...
index created!


In [4]:
# Random binary signature of shape (1, NUM_BITS), stored as floats on `device`.
signature = torch.randint(0, 2, (1, NUM_BITS), device=device).float()

# --- Encoder --------------------------------------------------------------
wm_encoder = StegaStampEncoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_encoder_load = torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True)
# A checkpoint is either a state dict (any dict subclass, e.g. OrderedDict)
# or an already-constructed module object.
if isinstance(wm_encoder_load, dict):
    wm_encoder.load_state_dict(wm_encoder_load)
else:
    wm_encoder = wm_encoder_load

# --- Decoder --------------------------------------------------------------
wm_decoder = StegaStampDecoder(
    IMAGE_SIZE,
    IMAGE_CHANNELS,
    NUM_BITS,
)
wm_decoder_load = torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True)
if isinstance(wm_decoder_load, dict):
    wm_decoder.load_state_dict(wm_decoder_load)
else:
    # Fixed copy-paste bug: this branch used to assign `wm_encoder = wm_encoder_load`,
    # clobbering the encoder and leaving the decoder with untrained weights.
    wm_decoder = wm_decoder_load

# Both networks are used for inference only in this notebook; eval() is a no-op
# for modules without dropout/batch-norm and returns the module, so the last
# line still displays the decoder architecture.
wm_encoder.to(device).eval()
wm_decoder.to(device).eval()


StegaStampDecoder(
  (decoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (13): ReLU()
  )
  (dense): Sequential(
    (0): Linear(in_features=8192, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=48, bias=True)
  )
)

In [None]:
class CocoCaptionMixedWMDataset(Dataset):
    """In-memory dataset of clean and watermarked versions of COCO images.

    For each of the first `num_images` items of `coco_dataset`, two samples are
    stored back to back: `(watermarked_image, 1)` followed by `(clean_image, 0)`.
    Images that fail to load or encode are reported and skipped, so the final
    length is at most `2 * num_images`.

    Relies on the notebook-level globals `wm_encoder` and `device`.

    Args:
        signature: watermark tensor passed to `wm_encoder` (after `unsqueeze(0)`).
        coco_dataset: indexable dataset yielding `(image_tensor, caption)` pairs.
        num_images: number of leading items of `coco_dataset` to use.
    """

    def __init__(self, signature, coco_dataset, num_images):
        super().__init__()
        self.coco_dataset = coco_dataset
        self.dataset = []

        # Never index past the end of the source dataset.
        num_images = min(num_images, len(coco_dataset))

        # Encoding is pure inference: without no_grad every stored watermarked
        # tensor would keep its whole autograd graph alive.
        with torch.no_grad():
            for i in trange(num_images):
                try:
                    image, _caption = coco_dataset[i]
                    image = image.to(device).float()
                    # NOTE(review): in this notebook `signature` is created with shape
                    # (1, NUM_BITS), so unsqueeze(0) yields (1, 1, NUM_BITS) — confirm
                    # this is the shape the encoder expects.
                    wm_image = wm_encoder(signature.unsqueeze(0), image.unsqueeze(0))
                    self.dataset.append((wm_image.squeeze(0), 1))  # watermarked
                    self.dataset.append((image, 0))                # clean
                except Exception as e:
                    # Best effort: report which image failed and skip it.
                    print(f"Skipping image {i}: {e}")

    def __len__(self):
        """Number of stored samples (clean + watermarked)."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the `(image_tensor, label)` pair at position `idx`."""
        return self.dataset[idx]

In [6]:
# Build the in-memory evaluation set from the first DATASET_SIZE COCO images:
# each image contributes a watermarked sample (label 1) and a clean sample (label 0).
dataset = CocoCaptionMixedWMDataset(signature, coco_dataset, DATASET_SIZE)

100%|██████████| 1000/1000 [00:22<00:00, 44.50it/s]


In [None]:
# --- Feature extraction -----------------------------------------------------
# Each image is described by the sign pattern of the decoder output
# (one boolean per decoder output, True where the output is > 0).
features = []
labels = []

# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for image, label in dataset:
        decoded_signature = wm_decoder(image.unsqueeze(0))
        decoded_bits = (decoded_signature > 0).cpu().numpy().flatten()
        features.append(decoded_bits)
        labels.append(label)

features = np.array(features)
labels = np.array(labels)
assert len(features) > 0, "dataset is empty - no images were encoded successfully"
assert len(features) == len(labels)

# --- Train / test split -----------------------------------------------------
# Sequential 80/20 split. The dataset stores watermarked/clean samples in
# alternating order, so both classes appear in each part.
split_idx = int(len(features) * 0.8)
X_train, X_test = features[:split_idx], features[split_idx:]
y_train, y_test = labels[:split_idx], labels[split_idx:]

# --- Classifier -------------------------------------------------------------
# NOTE(review): the features are binary; BernoulliNB may be a better fit than
# GaussianNB - left unchanged so the reported numbers stay comparable.
clf = GaussianNB()
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)

# --- Metrics (positive class = watermarked, label 1) -------------------------
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)