In [5]:

import torch
from torch.utils.data import Dataset, DataLoader
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder
from score import f1
from torchvision import transforms
import os
from dataset import get_image_dataloader



In [6]:
image_size = 256
num_bits = 64
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'    

transform = transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor()
                            ]) 

In [7]:
encoder = StegaStampEncoder(image_size, 3, num_bits).to(device)
encoder.load_state_dict(torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True))
decoder = StegaStampDecoder(image_size, 3, num_bits).to(device)
decoder.load_state_dict(torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [9]:
val_loader = get_image_dataloader("./../data/images/val", transform, batch_size=64, shuffle=True)

In [12]:
tp = 1e-10
fp = 1e-10
fn = 1e-10
tn = 1e-10
avg_bit_acc = 0
with torch.no_grad():
    for images in val_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        encode_split = images.shape[0] // 2
        watermarks[:encode_split] = 2
        true_labels = (watermarks[:, 0] != 2).int()
        pred_labels = torch.randint(0, 2, (watermarks.shape[0], 1)).float().to(device)
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)).item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)).item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)).item()
        if torch.sum(true_labels == 1) > 0:
            true_watermarks = watermarks[true_labels == 1]
            pred_watermarks = torch.randint(0, 2, (true_watermarks.shape[0], num_bits)).float().to(device)
            avg_bit_acc += torch.mean((pred_watermarks == true_watermarks).float())
    print(f1(tp, tn, fp, fn))
    print(avg_bit_acc / len(val_loader))

{'Precision': 0.5, 'Recall': 0.482151963284039, 'F1-score': 0.49091381100726905, 'Accuracy': 0.5}
tensor(0.4985, device='mps:0')
