In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
from torchvision.transforms import v2
from torchvision.utils import save_image
from matplotlib import pyplot as plt
from tqdm import trange
from score import f1
import math
from models.Wformer import *
import os
from utils import plot2images
from dataset import get_image_dataloader, create_mixed_dataset


In [2]:
dataset_size = 1000
image_size = 256
num_bits = 64
batch_size = 64
hidden_channels = 16
num_fems = 5
num_heads = 8

if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'

transform = transforms.Compose([
                                transforms.Resize(image_size),
                                transforms.CenterCrop(image_size),
                                transforms.ToTensor(),
                            ])

# Train Wformer Model

In [3]:

train_loader = get_image_dataloader("./../data/images/train", transform=transform, batch_size=batch_size)
val_loader = get_image_dataloader("./../data/images/val", transform=transform, batch_size=batch_size)

In [9]:
num_epochs = 10
lr = 1e-3
adv_lr = 5e-4
image_loss_weight = 3
wm_loss_weight = 10
adv_steps = 3
adv_loss_weight = 3


encoder = Encoder(image_size, num_bits, num_fems, hidden_channels, num_heads).to(device)
decoder = Decoder(image_size, num_bits, hidden_channels).to(device)
discriminator = VisionTransformerClassifier(image_size, 2, 8, 2).to(device)

encoder_opt = optim.Adam(encoder.parameters(), lr=lr)
decoder_opt = optim.Adam(decoder.parameters(), lr=lr)
discriminator_opt = optim.Adam(discriminator.parameters(), lr=adv_lr)
image_crit = nn.MSELoss()
wm_crit = nn.MSELoss()
adv_crit = nn.CrossEntropyLoss()

In [10]:
avg_bit_acc = 0
for epoch in range(num_epochs):
    for i, images in enumerate(train_loader):
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        images = images.to(device)
        
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()
        
        encoded_images = encoder(images, watermarks)
        decoded_watermarks_probs = decoder(encoded_images.detach()) 
        decoded_watermarks = torch.round(decoded_watermarks_probs)
        wm_loss = wm_crit(decoded_watermarks_probs, watermarks)
        image_loss = image_crit(encoded_images, images)
        
        for j in range(adv_steps):
            discriminator_opt.zero_grad()
            adv_original = discriminator(images.detach())
            adv_encoded = discriminator(encoded_images.detach())
            discriminator_loss = adv_crit(adv_original, torch.zeros(adv_original.shape[0]).to(device)) + \
                                adv_crit(adv_encoded, torch.ones(adv_original.shape[0]).to(device))
            discriminator_loss.backward()
            discriminator_opt.step()

        
        adv_encoded = discriminator(encoded_images)
        adversary_loss =  adv_crit(adv_encoded, torch.zeros(adv_original.shape[0]).to(device))
        adv_scale = adv_loss_weight if avg_bit_acc > 0.9 else 0
        loss = wm_loss_weight * wm_loss + image_loss_weight * image_loss + adv_loss_weight * adversary_loss
        loss.backward()
        encoder_opt.step()
        decoder_opt.step()
    
        
        print('Epoch: {}/{}, Steps: {}/{},  Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))
        print('Watermark Loss: {:.4f}'.format(wm_loss.item()))
        print('Image Loss: {:.4f}'.format(image_loss.item()))
        print("Discriminator Loss: {:.4f}".format(discriminator_loss.item()))
        print("Adversary Loss: {:.4f}".format(adversary_loss.item()))
        avg_bit_acc = torch.mean(torch.sum(decoded_watermarks == watermarks, dim=1).float() / num_bits)
        print(f"Avg bit accuracy: {avg_bit_acc}")
        
    plot2images(images[0], "Images", encoded_images[0], "Encoded")


Epoch: 1/10, Steps: 1/16,  Loss: 0.0000
Watermark Loss: 0.0018
Image Loss: 0.0007
Discriminator Loss: 1.7105
Avg bit accuracy: 0.99755859375
Epoch: 1/10, Steps: 2/16,  Loss: 0.0000
Watermark Loss: 0.0011
Image Loss: 0.0008
Discriminator Loss: 1.3878
Avg bit accuracy: 0.998779296875
Epoch: 1/10, Steps: 3/16,  Loss: 0.0000
Watermark Loss: 0.0005
Image Loss: 0.0007
Discriminator Loss: 1.5141
Avg bit accuracy: 0.999267578125
Epoch: 1/10, Steps: 4/16,  Loss: 0.0000
Watermark Loss: 0.0004
Image Loss: 0.0007
Discriminator Loss: 1.3751
Avg bit accuracy: 0.999267578125
Epoch: 1/10, Steps: 5/16,  Loss: 0.0000
Watermark Loss: 0.0011
Image Loss: 0.0007
Discriminator Loss: 1.4227
Avg bit accuracy: 0.998779296875
Epoch: 1/10, Steps: 6/16,  Loss: 0.0000
Watermark Loss: 0.0009
Image Loss: 0.0007
Discriminator Loss: 1.3745
Avg bit accuracy: 0.9990234375
Epoch: 1/10, Steps: 7/16,  Loss: 0.0000
Watermark Loss: 0.0007
Image Loss: 0.0007
Discriminator Loss: 1.3677
Avg bit accuracy: 0.998779296875
Epoch: 1/

KeyboardInterrupt: 

In [11]:
avg_loss = 0
avg_wm_loss = 0
avg_image_loss = 0
avg_discriminator_loss = 0
avg_adv_loss = 0
avg_bit_acc = 0
avg_discriminator_acc = 0
with torch.no_grad():
     for i, images in enumerate(val_loader):
         
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        images = images.to(device)
        
        encoded_images = encoder(images, watermarks)
        decoded_watermarks_probs = decoder(encoded_images) 
        decoded_watermarks = torch.round(decoded_watermarks_probs)
        
        avg_wm_loss += wm_crit(decoded_watermarks_probs, watermarks)
        avg_image_loss += image_crit(encoded_images, images)
        adv_original = discriminator(images)
        adv_encoded = discriminator(encoded_images)
        avg_discriminator_acc += ((torch.argmax(adv_original, dim=-1) == 0).float().mean() + (torch.argmax(adv_encoded, dim=-1) == 1).float().mean()) / 2
        avg_discriminator_loss += adv_crit(adv_original, torch.zeros(adv_original.shape[0]).to(device)) + adv_crit(adv_encoded, torch.ones(adv_original.shape[0]).to(device))
        avg_adv_loss += adv_crit(adv_encoded, torch.ones(adv_encoded.shape[0]).to(device))
        avg_bit_acc += torch.mean((decoded_watermarks == watermarks).float())

print(f"Avg bit accuracy: {avg_bit_acc/len(val_loader)}")
print(f"Avg discriminator accuracy: {avg_discriminator_acc/len(val_loader)}")
print('Watermark Loss: {:.4f}'.format(avg_wm_loss.item()/len(val_loader)))
print('Image Loss: {:.4f}'.format(avg_image_loss.item()/len(val_loader)))
print("Discriminator Loss: {:.4f}".format(avg_discriminator_loss.item()/len(val_loader)))
print("Adversary Loss: {:.4f}".format(avg_adv_loss.item()/len(val_loader)))


Avg bit accuracy: 0.9984647035598755
Avg discriminator accuracy: 1.0
Watermark Loss: 0.0013
Image Loss: 0.0007
Discriminator Loss: 0.0001
Adversary Loss: 0.0000


In [12]:
torch.save(encoder.state_dict(), "./models/wformer_encoder.pth")
torch.save(decoder.state_dict(), "./models/wformer_decoder.pth")
torch.save(discriminator.state_dict(), "./models/vit_discriminator.pth")

# Use Wformer and ViT Classifier to Analyze Images

In [13]:
encoder = Encoder(image_size, num_bits, num_fems, hidden_channels, num_heads).to(device)
decoder = Decoder(image_size, num_bits, hidden_channels).to(device)
discriminator = VisionTransformerClassifier(image_size, 2, 4, 2).to(device)
encoder.load_state_dict(torch.load("./models/wformer_encoder.pth", map_location=device))
decoder.load_state_dict(torch.load("./models/wformer_decoder.pth", map_location=device))
discriminator.load_state_dict(torch.load("./models/vit_discriminator.pth", map_location=device))

  encoder.load_state_dict(torch.load("./models/wformer_encoder.pth", map_location=device))
  decoder.load_state_dict(torch.load("./models/wformer_decoder.pth", map_location=device))
  discriminator.load_state_dict(torch.load("./models/vit_discriminator.pth", map_location=device))


<All keys matched successfully>

In [14]:
test_loader = get_image_dataloader("./../data/images/test", transform=transform, batch_size=batch_size)

In [21]:
def wm_tensor_to_str(wm_tensor):
    if len(wm_tensor.shape) == 1:
        wm_tensor = wm_tensor.unsqueeze(0)
    wm_tensor = wm_tensor.detach().cpu().int().tolist()
    return ["".join([str(bit) for bit in wm]) for wm in wm_tensor]

In [24]:
tp = 1e-7
fp = 1e-7
tn = 1e-7
fn = 1e-7
# Bit accuracy of decoding only images that are predicted to be watermarked
avg_pred_bit_acc = 0
# Bit accuracy of decoding all images that are watermarked
avg_total_bit_acc = 0
true_wms = []
pred_wms = []
with torch.no_grad():
    for images in test_loader:
        images = images.to(device)
        watermarks = torch.randint(0, 2, (images.shape[0], num_bits)).float().to(device)
        total_pred_watermarks = torch.zeros_like(watermarks).float().to(device)
        total_pred_watermarks[:] = 2
        encode_split = watermarks.shape[0]//2
        watermarks[:encode_split] = 2
        true_wms += wm_tensor_to_str(watermarks)
        true_labels = (watermarks[:, 0] != 2).int()
        images[encode_split:] = encoder(images[encode_split:], watermarks[encode_split:])
        pred_labels = torch.argmax(discriminator(images), dim=-1)
        tp += torch.sum((pred_labels == 1) & (pred_labels == true_labels)).item()
        tn += torch.sum((pred_labels == 0) & (pred_labels == true_labels)).item()
        fp += torch.sum((pred_labels == 1) & (pred_labels != true_labels)).item()
        fn += torch.sum((pred_labels == 0) & (pred_labels != true_labels)).item()
        if torch.sum(pred_labels == 1) > 0:
            encoded_images = images[pred_labels == 1]
            true_watermarks = watermarks[pred_labels == 1]
            pred_watermarks = torch.round(decoder(encoded_images))
            total_pred_watermarks[pred_labels == 1] = pred_watermarks
            avg_pred_bit_acc += torch.mean((pred_watermarks == true_watermarks).float()).item()
        if torch.sum(true_labels == 1) > 0:
            encoded_images = images[true_labels == 1]
            true_watermarks = watermarks[true_labels == 1]
            pred_watermarks = torch.round(decoder(encoded_images))
            avg_total_bit_acc += torch.mean((pred_watermarks == true_watermarks).float()).item()
        pred_wms += wm_tensor_to_str(total_pred_watermarks)
    print(f1(tp, tn, fp, fn))
    print(avg_pred_bit_acc / len(val_loader))
    print(avg_total_bit_acc / len(val_loader))

{'Precision': 0.9999999996093749, 'Recall': 0.9999999996093749, 'F1-score': 0.9999999996093751, 'Accuracy': 0.9999999996093751}
0.99853515625
0.99853515625


In [25]:
print(pred_wms)
print(true_wms)

['2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222222222222222222222', '2222222222222222222222222222222222222222222222

In [27]:
with open ("./../output/pred_wm.txt", 'w+') as f:
    for wm in pred_wms:
        f.write(f"{wm}\n")
with open ("./../output/true_wm.txt", 'w+') as f:
    for wm in true_wms:
        f.write(f"{wm}\n")