In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from bert_score import score
from matplotlib import pyplot as plt
import collections
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor,AutoTokenizer
from tqdm import trange
import numpy as np
import os
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder
from evaluate import load
from score import f1
from dataset import get_image_dataloader

In [2]:
# Side length, in pixels, of the square images used throughout the notebook.
image_size = 256
# Number of bits in each randomly generated watermark message.
num_bits = 64

# Pick the best available backend: Apple MPS first, then CUDA, then CPU.
# `device` is a torch.device on every path (the CPU fallback used to be a bare
# string), so downstream code can rely on attributes such as `device.type`.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Resize the shorter side to `image_size`, centre-crop to a square of that
# size, and convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
])

In [None]:
# Load the pretrained StegaStamp watermark encoder and decoder.
# NOTE(review): the literal 3 is presumably the number of image channels --
# confirm against models/stegastamp_wm.py.
encoder = StegaStampEncoder(image_size, 3, num_bits).to(device)
encoder.load_state_dict(torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True))
decoder = StegaStampDecoder(image_size, 3, num_bits).to(device)
decoder.load_state_dict(torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True))

# Both networks are only used for inference in this notebook (generating
# watermarked images and reading watermarks back), never trained. Switch them
# to eval mode so any dropout / batch-norm layers behave deterministically, and
# freeze their parameters so no gradients are computed or stored for them.
for frozen_module in (encoder, decoder):
    frozen_module.eval()
    frozen_module.requires_grad_(False)

In [4]:
class WMClassifier(nn.Module):
    """CNN that maps an RGB image to two class logits.

    The feature extractor is a stack of 3x3 convolutions (padding 1), five of
    which use stride 2. Each stride-2 layer turns a side of length ``s`` into
    ``ceil(s / 2)``, so a 256-pixel input is reduced 256 -> 128 -> 64 -> 32 ->
    16 -> 8 before the fully connected head.

    Args:
        image_size: Side length in pixels of the (square) input images. Any
            positive size is supported; for multiples of 32 the layer shapes
            are identical to ``image_size * image_size * 128 // 32 // 32``.

    Attributes:
        image_size: The ``image_size`` passed to the constructor.
        num_features: Length of the flattened feature vector fed to ``dense``.
        decoder: Convolutional feature extractor. The attribute name is kept
            as-is so existing ``state_dict`` checkpoints keep loading.
        dense: Fully connected head producing 2 logits per image.
    """

    # Number of stride-2 convolutions in the feature extractor.
    _NUM_DOWNSAMPLES = 5
    # Output channels of the last convolution.
    _OUT_CHANNELS = 128

    def __init__(self, image_size: int):
        super().__init__()
        self.image_size = image_size
        self.decoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),    # side -> ceil(side / 2)
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),   # halved again
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),   # halved again
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # halved again
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), # halved again
            nn.ReLU(),
        )

        # Spatial side after the stride-2 layers. With kernel 3 / padding 1 /
        # stride 2 the output side is floor((s - 1) / 2) + 1 == ceil(s / 2),
        # which differs from s // 2 whenever s is odd.
        side = image_size
        for _ in range(self._NUM_DOWNSAMPLES):
            side = (side + 1) // 2
        self.num_features = self._OUT_CHANNELS * side * side

        self.dense = nn.Sequential(
            nn.Linear(self.num_features, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

    def forward(self, image):
        """Return raw class logits of shape ``(batch, 2)``.

        Args:
            image: Tensor of shape ``(batch, 3, image_size, image_size)``.
        """
        x = self.decoder(image)
        # Flatten everything except the batch axis. Unlike ``view(-1, n)`` this
        # can never silently fold samples into each other when the input size
        # does not match ``image_size``; the first Linear layer raises instead.
        x = x.flatten(start_dim=1)
        return self.dense(x)


In [5]:
# Training hyper-parameters.
num_epochs = 100
batch_size = 64

# Training batches are reshuffled every epoch. The validation loader is NOT
# shuffled: the evaluation cells watermark the second half of every batch, so
# a fixed order keeps the clean / watermarked split of the validation images
# identical from one evaluation run to the next.
train_loader = get_image_dataloader("./../data/images/train", transform, batch_size=batch_size, shuffle=True)
val_loader = get_image_dataloader("./../data/images/val", transform, batch_size=batch_size, shuffle=False)

# Binary classifier (two logits per image), trained with cross-entropy and Adam.
model = WMClassifier(image_size).to(device)
crit = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=5e-4)

In [None]:
# Train the classifier to tell clean images (label 0) from images that the
# StegaStamp encoder has watermarked (label 1).
model.train()
for epoch in range(num_epochs):
    # Confusion-matrix counts for this epoch only. They start at a tiny epsilon
    # so the F1 computation cannot divide by zero.
    tp = fp = fn = tn = 1e-10
    epoch_loss = 0.0
    num_batches = 0

    for images in train_loader:
        images = images.to(device)
        batch_len = images.shape[0]
        # The first half of the batch stays clean; the rest gets watermarked.
        encoded_split = batch_len // 2

        # CrossEntropyLoss expects class-index targets as int64 (long).
        true_labels = torch.zeros(batch_len, dtype=torch.long, device=device)
        true_labels[encoded_split:] = 1

        # One random bit string per image that is going to be watermarked.
        watermarks = torch.randint(0, 2, (batch_len - encoded_split, num_bits)).float().to(device)
        # The encoder only produces training data, so run it without autograd:
        # otherwise every backward pass would also traverse the encoder.
        with torch.no_grad():
            images[encoded_split:] = encoder(watermarks, images[encoded_split:])

        opt.zero_grad()
        pred_scores = model(images)
        loss = crit(pred_scores, true_labels)
        loss.backward()
        opt.step()

        pred_labels = torch.argmax(pred_scores, dim=-1)
        tp += torch.sum((pred_labels == 1) & (true_labels == 1)).item()
        tn += torch.sum((pred_labels == 0) & (true_labels == 0)).item()
        fp += torch.sum((pred_labels == 1) & (true_labels == 0)).item()
        fn += torch.sum((pred_labels == 0) & (true_labels == 1)).item()
        epoch_loss += loss.item()
        num_batches += 1

    # One summary line per epoch instead of two lines per batch.
    # f1 takes (tp, tn, fp, fn), matching the call in the final evaluation cell.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"epoch {epoch + 1}/{num_epochs}  loss={mean_loss:.4f}  f1={f1(tp, tn, fp, fn)}")
        
        
        

In [None]:
# Validation F1 of the watermark classifier: the second half of every
# validation batch is watermarked (label 1), the first half stays clean (label 0).
model.eval()
# Confusion-matrix counts start at a tiny epsilon so F1 cannot divide by zero.
tp = fp = fn = tn = 1e-10
with torch.no_grad():
    for images in val_loader:
        images = images.to(device)
        batch_len = images.shape[0]
        encoded_split = batch_len // 2

        true_labels = torch.zeros(batch_len, dtype=torch.long, device=device)
        true_labels[encoded_split:] = 1

        # One random bit string per image that is going to be watermarked.
        watermarks = torch.randint(0, 2, (batch_len - encoded_split, num_bits)).float().to(device)
        images[encoded_split:] = encoder(watermarks, images[encoded_split:])

        pred_labels = torch.argmax(model(images), dim=-1)

        tp += torch.sum((pred_labels == 1) & (true_labels == 1)).item()
        tn += torch.sum((pred_labels == 0) & (true_labels == 0)).item()
        fp += torch.sum((pred_labels == 1) & (true_labels == 0)).item()
        fn += torch.sum((pred_labels == 0) & (true_labels == 1)).item()

# f1 takes (tp, tn, fp, fn), matching the call in the final evaluation cell;
# the fourth argument must be the false-negative count, not tn again.
print(f1(tp, tn, fp, fn))

In [None]:
# End-to-end evaluation on the validation set:
#   * F1 of the classifier at separating clean from watermarked images;
#   * bit accuracy of the StegaStamp decoder on the watermarked images.
model.eval()
# Confusion-matrix counts start at a tiny epsilon so F1 cannot divide by zero.
tp = fp = tn = fn = 1e-7
# Bit accuracy is accumulated as matching bits / total bits so every decoded
# bit has equal weight, even when the last batch is shorter than the others.
matching_bits = 0
total_bits = 0
with torch.no_grad():
    for images in val_loader:
        images = images.to(device)
        batch_len = images.shape[0]
        # The first half of the batch stays clean; the rest gets watermarked.
        encode_split = batch_len // 2

        true_labels = torch.zeros(batch_len, dtype=torch.long, device=device)
        true_labels[encode_split:] = 1

        # One random bit string per image that is going to be watermarked.
        true_watermarks = torch.randint(0, 2, (batch_len - encode_split, num_bits)).float().to(device)
        images[encode_split:] = encoder(true_watermarks, images[encode_split:])

        pred_labels = torch.argmax(model(images), dim=-1)
        tp += torch.sum((pred_labels == 1) & (true_labels == 1)).item()
        tn += torch.sum((pred_labels == 0) & (true_labels == 0)).item()
        fp += torch.sum((pred_labels == 1) & (true_labels == 0)).item()
        fn += torch.sum((pred_labels == 0) & (true_labels == 1)).item()

        # Decode every truly watermarked image in the batch. The guard is on
        # the same set that is decoded, so a batch either contributes to both
        # the numerator and the denominator or to neither.
        encoded_images = images[encode_split:]
        if encoded_images.shape[0] > 0:
            # NOTE(review): rounding assumes the decoder outputs per-bit values
            # in [0, 1] (e.g. after a sigmoid) -- confirm against the decoder.
            pred_watermarks = torch.round(decoder(encoded_images))
            bit_matches = pred_watermarks == true_watermarks
            matching_bits += bit_matches.sum().item()
            total_bits += bit_matches.numel()

print(f1(tp, tn, fp, fn))
# NaN rather than a ZeroDivisionError when the loader yielded no images.
print(matching_bits / total_bits if total_bits else float('nan'))