In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from torchvision.transforms import v2
from torchvision.utils import save_image
from matplotlib import pyplot as plt
from tqdm import trange
from score import f1
import math
from models.Wformer import *
import os
from utils import plot2images
from utils import random_noise_composition
from dataset import get_image_dataloader

In [2]:
# Experiment configuration shared by the cells below.
dataset_size = 1000
image_size = 256        # side length used by the resize/crop transform and passed to the models
num_bits = 64           # number of watermark bits generated per image
batch_size = 64         # passed to get_image_dataloader
hidden_channels = 16    # passed to Encoder and Decoder
num_fems = 5            # passed to Encoder
num_heads = 8           # passed to Encoder

# Pick the compute backend: MPS first, then CUDA, then CPU.
# `device` is always a torch.device, so it has the same type (and attributes
# such as `.type`) regardless of which backend was selected.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Preprocessing applied to every image: resize and center-crop using
# `image_size`, then convert to a tensor.
preprocessing_steps = [
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Train Wformer Model

In [3]:

# Both splits are loaded with the same preprocessing pipeline and batch size;
# only the source directory differs.
loader_kwargs = dict(transform=transform, batch_size=batch_size)
train_loader = get_image_dataloader("./data/images/train", **loader_kwargs)
val_loader = get_image_dataloader("./data/images/val", **loader_kwargs)

In [4]:
# --- Optimisation hyper-parameters ---
num_epochs = 10
lr, adv_lr = 1e-3, 1e-4  # Adam learning rates: encoder/decoder vs. discriminator

# Weights of the three terms in the combined encoder/decoder loss.
image_loss_weight, wm_loss_weight, adv_loss_weight = 3, 10, 3
adv_steps = 3  # discriminator updates performed per training batch

# --- Networks (constructed in encoder -> decoder -> discriminator order) ---
encoder = Encoder(image_size, num_bits, num_fems, hidden_channels, num_heads)
encoder = encoder.to(device)
decoder = Decoder(image_size, num_bits, hidden_channels)
decoder = decoder.to(device)
discriminator = VisionTransformerClassifier(image_size, 2, 4, 2)
discriminator = discriminator.to(device)

# --- One Adam optimiser per network ---
encoder_opt = optim.Adam(params=encoder.parameters(), lr=lr)
decoder_opt = optim.Adam(params=decoder.parameters(), lr=lr)
discriminator_opt = optim.Adam(params=discriminator.parameters(), lr=adv_lr)

# --- Loss functions ---
# MSE for both image reconstruction and watermark recovery; cross-entropy for
# the discriminator's class scores.
image_crit, wm_crit = nn.MSELoss(), nn.MSELoss()
adv_crit = nn.CrossEntropyLoss()

In [None]:
# Joint training loop.
# For every batch:
#   1. the encoder embeds a random bit string into the images and the decoder
#      tries to read it back;
#   2. the discriminator is updated `adv_steps` times to label clean images
#      (plain and noised) as class 0 and encoded images (plain and noised) as
#      class 1;
#   3. encoder and decoder are updated on the weighted sum of the watermark,
#      image and adversary losses.
avg_bit_acc = 0
for epoch in range(num_epochs):
    for i, images in enumerate(train_loader):
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        images = images.to(device)
        batch_len = images.shape[0]

        # Class-index targets for CrossEntropyLoss must be int64; build them
        # once per batch, directly on the target device.
        original_labels = torch.zeros(batch_len, dtype=torch.long, device=device)
        encoded_labels = torch.ones(batch_len, dtype=torch.long, device=device)

        encoder_opt.zero_grad()
        decoder_opt.zero_grad()

        encoded_images = encoder(images, watermarks)
        encoded_noised = random_noise_composition(encoded_images)
        # NOTE(review): the noised encoding is immediately replaced by the clean
        # one, so the decoder and discriminator never see noised encoded images
        # and the epoch-end plot shows two identical pictures. Delete the next
        # line to train with the noise layer enabled - confirm which is intended.
        encoded_noised = encoded_images
        images_noised = random_noise_composition(images)
        # detach() stops the watermark loss from back-propagating into the
        # encoder through the decoder input.
        # NOTE(review): confirm this is intended.
        decoded_watermarks_probs = decoder(encoded_noised.detach())
        decoded_watermarks = torch.round(decoded_watermarks_probs)
        wm_loss = wm_crit(decoded_watermarks_probs, watermarks)
        image_loss = image_crit(encoded_images, images)

        # --- Discriminator updates (all inputs detached from the encoder graph) ---
        # Initialised so the logging below still works when adv_steps == 0.
        discriminator_loss = torch.zeros((), device=device)
        for _ in range(adv_steps):
            discriminator_opt.zero_grad()
            adv_original = discriminator(images.detach())
            adv_original_noised = discriminator(images_noised.detach())
            adv_encoded = discriminator(encoded_images.detach())
            adv_encoded_noised = discriminator(encoded_noised.detach())
            # One term per input variant: clean, clean+noise, encoded,
            # encoded+noise (the encoded term previously duplicated the
            # encoded+noise one).
            discriminator_loss = (
                adv_crit(adv_original, original_labels)
                + adv_crit(adv_original_noised, original_labels)
                + adv_crit(adv_encoded, encoded_labels)
                + adv_crit(adv_encoded_noised, encoded_labels)
            )
            discriminator_loss.backward()
            discriminator_opt.step()

        # --- Encoder / decoder update ---
        # The target here is the same "encoded" class (1) the discriminator is
        # trained on, so this term rewards encodings the classifier labels as
        # encoded.
        adv_encoded = discriminator(encoded_images)
        adversary_loss = adv_crit(adv_encoded, encoded_labels)
        # NOTE(review): adv_loss_weight is applied from the very first step. If
        # the adversary term should only switch on once decoding works, gate it
        # here on avg_bit_acc (e.g. > 0.9).
        loss = wm_loss_weight * wm_loss + image_loss_weight * image_loss + adv_loss_weight * adversary_loss
        loss.backward()
        encoder_opt.step()
        decoder_opt.step()

        print('Epoch: {}/{}, Steps: {}/{},  Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
        print('Watermark Loss: {:.4f}'.format(wm_loss.item()))
        print('Image Loss: {:.4f}'.format(image_loss.item()))
        print("Discriminator Loss: {:.4f}".format(discriminator_loss.item()))
        print("Adversary Loss: {:.4f}".format(adversary_loss.item()))
        # Fraction of correctly recovered bits per image, averaged over the batch.
        avg_bit_acc = torch.mean(torch.sum(decoded_watermarks == watermarks, dim=1).float() / num_bits)
        print(f"Avg bit accuracy: {avg_bit_acc}")

    # Visual check at the end of every epoch, using the last batch.
    plot2images(encoded_noised[0], "Encoded Noised", encoded_images[0], "Encoded")


In [None]:
# Validation pass: average the watermark / image / discriminator / adversary
# losses, the bit accuracy and the discriminator accuracy over val_loader.
avg_loss = 0
avg_wm_loss = 0.0
avg_image_loss = 0.0
avg_discriminator_loss = 0.0
avg_adv_loss = 0.0
avg_bit_acc = 0.0
avg_discriminator_acc = 0.0

# torch.no_grad() only disables gradient tracking; layers such as dropout or
# batch-norm still behave as in training unless the modules are put in eval
# mode (a no-op if the models contain no such layers).
for net in (encoder, decoder, discriminator):
    net.eval()

with torch.no_grad():
    for images in val_loader:
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        images = images.to(device)
        batch_len = images.shape[0]

        # Class-index targets for CrossEntropyLoss must be int64.
        original_labels = torch.zeros(batch_len, dtype=torch.long, device=device)
        encoded_labels = torch.ones(batch_len, dtype=torch.long, device=device)

        encoded_images = encoder(images, watermarks)
        decoded_watermarks_probs = decoder(encoded_images)
        decoded_watermarks = torch.round(decoded_watermarks_probs)

        adv_original = discriminator(images)
        adv_encoded = discriminator(encoded_images)
        encoded_ce = adv_crit(adv_encoded, encoded_labels)

        # Share of clean images predicted as class 0 / encoded images as class 1.
        original_hit_rate = (torch.argmax(adv_original, dim=-1) == 0).float().mean()
        encoded_hit_rate = (torch.argmax(adv_encoded, dim=-1) == 1).float().mean()

        # Accumulate plain Python floats so the totals print cleanly and the
        # summary works even when the loader yields no batches.
        avg_wm_loss += wm_crit(decoded_watermarks_probs, watermarks).item()
        avg_image_loss += image_crit(encoded_images, images).item()
        avg_discriminator_acc += ((original_hit_rate + encoded_hit_rate) / 2).item()
        avg_discriminator_loss += (adv_crit(adv_original, original_labels) + encoded_ce).item()
        avg_adv_loss += encoded_ce.item()
        avg_bit_acc += torch.mean((decoded_watermarks == watermarks).float()).item()

# Back to training mode so the training cell can be re-run after validation.
for net in (encoder, decoder, discriminator):
    net.train()

# Guard against division by zero for an empty validation loader.
num_val_batches = max(len(val_loader), 1)
print(f"Avg bit accuracy: {avg_bit_acc/num_val_batches}")
print(f"Avg discriminator accuracy: {avg_discriminator_acc/num_val_batches}")
print('Watermark Loss: {:.4f}'.format(avg_wm_loss/num_val_batches))
print('Image Loss: {:.4f}'.format(avg_image_loss/num_val_batches))
print("Discriminator Loss: {:.4f}".format(avg_discriminator_loss/num_val_batches))
print("Adversary Loss: {:.4f}".format(avg_adv_loss/num_val_batches))


In [7]:
# Write each network's state_dict to its checkpoint file under ./models.
for trained_net, checkpoint_path in (
    (encoder, "./models/wformer_encoder.pth"),
    (decoder, "./models/wformer_decoder.pth"),
    (discriminator, "./models/vit_discriminator.pth"),
):
    torch.save(trained_net.state_dict(), checkpoint_path)

# Use Wformer and ViT Classifier to Analyze Images

In [5]:
# Rebuild the three networks with the same constructor arguments used for
# training and restore their saved parameters.
encoder = Encoder(image_size, num_bits, num_fems, hidden_channels, num_heads).to(device)
decoder = Decoder(image_size, num_bits, hidden_channels).to(device)
discriminator = VisionTransformerClassifier(image_size, 2, 4, 2).to(device)

# The checkpoints are state_dicts written by torch.save(model.state_dict(), ...),
# so weights_only=True is sufficient: it limits unpickling to tensors and
# plain containers instead of arbitrary Python objects.
encoder.load_state_dict(torch.load("./models/wformer_encoder.pth", map_location=device, weights_only=True))
decoder.load_state_dict(torch.load("./models/wformer_decoder.pth", map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load("./models/vit_discriminator.pth", map_location=device, weights_only=True))

  encoder.load_state_dict(torch.load("./models/wformer_encoder.pth", map_location=device))
  decoder.load_state_dict(torch.load("./models/wformer_decoder.pth", map_location=device))
  discriminator.load_state_dict(torch.load("./models/vit_discriminator.pth", map_location=device))


<All keys matched successfully>

In [6]:
# Evaluate the discriminator as a watermark detector and the decoder's bit
# accuracy on the validation set. In every batch the first half of the images
# is left untouched (label 0) and the second half is watermarked (label 1).

# Confusion-matrix counters start at a tiny epsilon rather than zero so the
# ratios computed by f1() never divide by zero.
tp = 1e-7
fp = 1e-7
tn = 1e-7
fn = 1e-7
avg_bit_acc = 0

# Inference should run in eval mode: torch.no_grad() alone does not switch
# layers such as dropout or batch-norm out of their training behaviour
# (a no-op if the models contain no such layers).
for net in (encoder, decoder, discriminator):
    net.eval()

with torch.no_grad():
    for images in val_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        encode_split = watermarks.shape[0]//2
        # Rows before the split are filled with the sentinel value 2 (not a
        # valid bit) to mark images that stay un-watermarked.
        watermarks[:encode_split] = 2
        true_labels = (watermarks[:, 0] != 2).int()
        # Overwrite the second half of the batch with its watermarked version.
        images[encode_split:] = encoder(images[encode_split:], watermarks[encode_split:])
        pred_labels = torch.argmax(discriminator(images), dim=-1)
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)).item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)).item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)).item()
        # Bit accuracy is measured only on the images that carry a watermark.
        if torch.sum(true_labels == 1) > 0:
            encoded_images = images[true_labels == 1]
            true_watermarks = watermarks[true_labels == 1]
            pred_watermarks = torch.round(decoder(encoded_images))
            avg_bit_acc += torch.mean((pred_watermarks == true_watermarks).float()).item()
    print(f1(tp, tn, fp, fn))
    print(avg_bit_acc / len(val_loader))

{'Precision': 0.9765624996276856, 'Recall': 0.9999999996000001, 'F1-score': 0.9881422921042352, 'Accuracy': 0.9879999996095999}
0.9983426630496979
