In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from matplotlib import pyplot as plt
import collections
from models.VisionTransformer import VisionTransformerClassifier
from models.stegastamp_wm import StegaStampDecoder, StegaStampEncoder
from score import f1
import torch.optim as optim
from dataset import get_image_dataloader
import os

In [2]:
image_size = 256
num_bits = 64
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'   
    
transform = transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor()
                            ]) 

In [3]:
encoder = StegaStampEncoder(image_size, 3, num_bits).to(device)
encoder.load_state_dict(torch.load('models/wm_stegastamp_encoder.pth', map_location=device, weights_only=True))
decoder = StegaStampDecoder(image_size, 3, num_bits).to(device)
decoder.load_state_dict(torch.load('models/wm_stegastamp_decoder.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [4]:
num_epochs = 100
batch_size = 64
train_loader = get_image_dataloader("./data/images/train", transform, batch_size=batch_size, shuffle=True)
val_loader = get_image_dataloader("./data/images/val", transform, batch_size=batch_size, shuffle=True)

model = VisionTransformerClassifier(input_resolution=image_size, layers=4, heads=8, output_dim=2).to(device)
crit = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=1e-4)


In [5]:
for epoch in range(num_epochs):
    tp = 1e-10
    fp = 1e-10
    fn = 1e-10
    tn = 1e-10
    for images in train_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        encoded_split = images.shape[0] // 2
        watermarks[:encoded_split] = 2
        images[encoded_split:] = encoder(watermarks[encoded_split:], images[encoded_split:])
        true_labels = (watermarks[:, 0] != 2).int()
        
        
        opt.zero_grad()
        pred_scores = model(images)
        pred_labels = torch.argmax(pred_scores, dim=-1)
        loss = crit(pred_scores, true_labels)
        loss.backward()
        opt.step()
        
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)).item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)).item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)).item()
        print(f1(tp, tn, fp, tn))
        print(loss.item())
        

{'Precision': 0.5, 'Recall': 0.9999999999968749, 'F1-score': 0.6666666666659722, 'Accuracy': 0.5}
0.7549295425415039
{'Precision': 0.45744680851072883, 'Recall': 0.7678571428561862, 'F1-score': 0.5733333333331377, 'Accuracy': 0.4666666666667778}
0.7684599161148071
{'Precision': 0.45794392523372346, 'Recall': 0.5632183908044524, 'F1-score': 0.505154639175247, 'Accuracy': 0.47540983606562753}
0.7360308170318604
{'Precision': 0.45535714285722256, 'Recall': 0.4322033898306234, 'F1-score': 0.4434782608696635, 'Accuracy': 0.4796747967480005}
0.7221341133117676
{'Precision': 0.45901639344269013, 'Recall': 0.3733333333335022, 'F1-score': 0.4117647058824827, 'Accuracy': 0.4838709677419563}
0.7212319374084473
{'Precision': 0.47183098591553263, 'Recall': 0.36413043478275636, 'F1-score': 0.4110429447853852, 'Accuracy': 0.4893617021276709}
0.7094520330429077
{'Precision': 0.47159090909094137, 'Recall': 0.3878504672898245, 'F1-score': 0.4256410256411019, 'Accuracy': 0.48858447488585516}
0.7108949422

In [6]:
fp = 1e-10
tp = 1e-10
fn = 1e-10
tn = 1e-10
with torch.no_grad():
    for images in val_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        encoded_split = images.shape[0] // 2
        watermarks[:encoded_split] = 2
        images[encoded_split:] = encoder(watermarks[encoded_split:], images[encoded_split:])
        true_labels = (watermarks[:, 0] != 2).int()
        
        pred_scores = model(images)
        pred_labels = torch.argmax(pred_scores, dim=-1)
        
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)).item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)).item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)).item()
print(f1(tp, tn, fp, tn))

{'Precision': 0.9364406779657318, 'Recall': 0.48464912280702427, 'F1-score': 0.6387283236993417, 'Accuracy': 0.6458923512747049}
