In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from bert_score import score
from matplotlib import pyplot as plt
import collections
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor,AutoTokenizer
from tqdm import trange
import numpy as np
import os
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder
from evaluate import load
from score import f1
from dataset import get_image_dataloader



In [2]:
image_size = 256
num_bits = 64
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'    
    
transform = transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor()
                            ]) 

In [3]:
encoder = StegaStampEncoder(image_size, 3, num_bits).to(device)
encoder.load_state_dict(torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True))
decoder = StegaStampDecoder(image_size, 3, num_bits).to(device)
decoder.load_state_dict(torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [4]:
class WMClassifier(nn.Module):
    def __init__(self, image_size: int):
        super().__init__()
        self.image_size = image_size
        self.decoder = nn.Sequential(
            nn.Conv2d(3, 32, (3, 3), 2, 1),  # 16
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1),  # 8
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 2, 1),  # 4
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1),  # 2
            nn.ReLU(),
            nn.Conv2d(128, 128, (3, 3), 2, 1),
            nn.ReLU(),
        )
        self.dense = nn.Sequential(
            nn.Linear(image_size * image_size * 128 // 32 // 32, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

    def forward(self, image):
        x = self.decoder(image)
        x = x.view(-1, self.image_size * self.image_size * 128 // 32 // 32)
        return self.dense(x)


In [5]:
num_epochs = 100
batch_size = 64
train_loader = get_image_dataloader("./data/images/train", transform, batch_size=batch_size, shuffle=True)
val_loader = get_image_dataloader("./data/images/val", transform, batch_size=batch_size, shuffle=True)

model = WMClassifier(image_size).to(device)
crit = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=5e-4)

In [6]:
tp = 1e-10
fp = 1e-10
fn = 1e-10
tn = 1e-10
for epoch in range(num_epochs):
    for images in train_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        encoded_split = images.shape[0] // 2
        watermarks[:encoded_split] = 2
        images[encoded_split:] = encoder(watermarks[encoded_split:], images[encoded_split:])
        true_labels = (watermarks[:, 0] != 2).int()
        
        
        opt.zero_grad()
        pred_scores = model(images)
        pred_labels = torch.argmax(pred_scores, dim=-1)
        loss = crit(pred_scores, true_labels)
        loss.backward()
        opt.step()
        
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)).item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)).item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)).item()
        print(f1(tp, tn, fp, tn))
        print(loss.item())
        
        
        

{'Precision': 0.5, 'Recall': 3.1249999999804685e-12, 'F1-score': 6.249999999921875e-12, 'Accuracy': 0.5}
0.6941593885421753
{'Precision': 0.5, 'Recall': 1.5624999999951171e-12, 'F1-score': 3.1249999999804685e-12, 'Accuracy': 0.5}
0.6933144330978394
{'Precision': 0.5, 'Recall': 0.33333333333368054, 'F1-score': 0.40000000000025, 'Accuracy': 0.5}
0.6930958032608032
{'Precision': 0.5, 'Recall': 0.5, 'F1-score': 0.5, 'Accuracy': 0.5}
0.6931939125061035
{'Precision': 0.5, 'Recall': 0.599999999999875, 'F1-score': 0.5454545454544938, 'Accuracy': 0.5}
0.6932728886604309
{'Precision': 0.48927038626610364, 'Recall': 0.6096256684490806, 'F1-score': 0.5428571428571021, 'Accuracy': 0.4934036939314054}
0.6932992935180664
{'Precision': 0.48927038626610364, 'Recall': 0.5205479452054607, 'F1-score': 0.5044247787610581, 'Accuracy': 0.49435665914221727}
0.69298255443573
{'Precision': 0.48927038626610364, 'Recall': 0.4541832669323075, 'F1-score': 0.47107438016531317, 'Accuracy': 0.4950690335305759}
0.69348

In [7]:
fp = 1e-10
tp = 1e-10
fn = 1e-10
tn = 1e-10
with torch.no_grad():
    for images in val_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        encoded_split = images.shape[0] // 2
        watermarks[:encoded_split] = 2
        images[encoded_split:] = encoder(images[encoded_split:], watermarks[encoded_split:])
        true_labels = (watermarks[:, 0] != 2).int()
        
        pred_scores = model(images)
        pred_labels = torch.argmax(pred_scores, dim=-1)
        
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)).item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)).item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)).item()
print(f1(tp, tn, fp, tn))

RuntimeError: linear(): input and weight.T shapes cannot be multiplied (256x256 and 64x768)

In [15]:
tp = 1e-7
fp = 1e-7
tn = 1e-7
fn = 1e-7
avg_bit_acc = 0
with torch.no_grad():
    for images in val_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        encode_split = watermarks.shape[0]//2
        watermarks[:encode_split] = 2
        true_labels = (watermarks[:, 0] != 2).int()
        images[encode_split:] = encoder(images[encode_split:], watermarks[encode_split:])
        pred_labels = torch.argmax(discriminator(images), dim=-1)
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)),item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)),item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)),item()
        if torch.sum(true_labels == 1) > 0:
            encoded_images = images[true_labels == 1]
            true_watermarks = watermarks[true_labels == 1]
            pred_watermarks = torch.round(decoder(encoded_images))
            avg_bit_acc += torch.mean((pred_watermarks == true_watermarks).float()),item()
    print(f1(tp, tn, fp, fn))
    print(avg_bit_acc.item / len(val_loader))

100%|██████████| 2000/2000 [09:05<00:00,  3.66it/s]
