In [None]:
# Clone the project repository from GitHub into the current working directory.
!git clone https://github.com/Jessieweneedtocook/SWV.git

In [None]:
%cd SPV

In [None]:
!pip install .

In [None]:
import torch.nn.functional as F
import torch.nn as nn
import torch
import timm
class StabilityPredictor(nn.Module):
    """Predict a single-channel map in [0, 1] from an image batch.

    A ConvNeXt-Tiny backbone (created through ``timm`` with
    ``features_only=True``) provides feature maps; only the last one is fed
    to a small convolutional head ending in a sigmoid. The head output is
    bilinearly resized to a fixed 256x256 resolution.
    """

    def __init__(self):
        super().__init__()
        self.encoder = timm.create_model(
            "convnext_tiny",
            pretrained=True,
            features_only=True,
        )
        # The layer order must stay fixed: checkpoint keys are derived from the
        # positions inside the Sequential container.
        head_layers = [
            nn.Conv2d(768, 256, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return a ``[B, 1, 256, 256]`` map for the input batch ``x``."""
        # Last feature map of the backbone; it must carry 768 channels to match
        # the first convolution of the head.
        # NOTE(review): presumably 8x8 spatially for 256x256 inputs - confirm.
        deepest_features = self.encoder(x)[-1]
        # One output channel, spatial size doubled by the Upsample layer.
        coarse_map = self.decoder(deepest_features)
        # Resize to the fixed 256x256 output resolution.
        return F.interpolate(
            coarse_map,
            size=(256, 256),
            mode='bilinear',
            align_corners=False,
        )

In [None]:
from huggingface_hub import hf_hub_download
import torch

# Hugging Face repository that hosts the three checkpoints used by this notebook.
repo_id = "jetskieve/semantic-VINE"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the checkpoint files and keep their local paths.
stabilitypred_path = hf_hub_download(repo_id, filename="stability_pred.pth")
encoder_path = hf_hub_download(repo_id, filename="encoder_semVINE.pth")
decoder_path = hf_hub_download(repo_id, filename="decoder_semVINE.pth")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed PyTorch supports it).
model = StabilityPredictor().to(device)
# Fixed: the downloaded path is stored in `stabilitypred_path`; the previously
# referenced `stability_path` was never defined.
model.load_state_dict(torch.load(stabilitypred_path, map_location=device))
model.eval()

In [None]:
from vine.src.vine_turbo import VINE_Turbo
import os, gc, torch
from accelerate.utils import set_seed
from vine.src.stega_encoder_decoder import CustomConvNeXt

# Build the watermark encoder around the stability predictor loaded earlier.
# Use the notebook-wide `device` instead of a hard-coded 'cuda' so that the
# CPU fallback chosen when CUDA is unavailable is honoured here as well.
watermarker = VINE_Turbo(
    ckpt_path=None,
    device=device,
    stability_predictor=model,
    tensor_six=False
)

# NOTE(review): torch.load unpickles the checkpoint - trusted sources only.
watermarker.load_state_dict(torch.load(encoder_path, map_location=device))
watermarker.to(device)
# The encoder is only used for inference in this notebook, so switch off
# training-time behaviour (dropout etc.), as is done for the stability predictor.
watermarker.eval()

In [None]:
# Watermark decoder with a 100-output classification head.
decoder = CustomConvNeXt(secret_size=100)
# Replace the classifier so that its layout matches the checkpoint loaded below.
decoder.convnext.classifier = nn.Sequential(
    nn.Flatten(1),
    nn.Linear(1024, 100, bias=True)
)
# NOTE(review): torch.load unpickles the checkpoint - trusted sources only.
decoder.load_state_dict(torch.load(decoder_path, map_location=device))
decoder.to(device)
# A freshly constructed nn.Module is in training mode. The decoder is used for
# inference only, so disable training-time randomness (e.g. dropout or
# stochastic depth, if the backbone uses them) to get deterministic bits.
decoder.eval()

In [None]:
message = "Hello World!"

# Payload layout: 12 bytes of UTF-8 text (96 bits) followed by 4 zero bits,
# i.e. 100 bits in total.
# The limit is checked on the encoded byte length rather than len(message):
# a non-ASCII character takes more than one byte in UTF-8, so a character
# count could let more than 100 bits through.
encoded_message = message.encode("utf-8")
if len(encoded_message) > 12:
    raise ValueError("Error: Can only encode 100 bits (12 UTF-8 bytes)")

# Pad with spaces up to exactly 12 bytes.
data = bytearray(encoded_message.ljust(12, b" "))
packet_binary = "".join(format(x, "08b") for x in data)
watermark = [int(x) for x in packet_binary]
watermark.extend([0, 0, 0, 0])
# Shape (1, 100): a batch containing a single watermark.
watermark = torch.tensor(watermark, dtype=torch.float).unsqueeze(0).to(device)
# Untouched copy kept for computing the bit error rate after decoding.
groundtruth_watermark = watermark.clone()

In [None]:
from PIL import Image
from torchvision import transforms
from skimage.metrics import structural_similarity as ssim
import time
import shutil
import numpy

def calculate_psnr(img1, img2):
    """Return the peak signal-to-noise ratio (dB) between two images.

    The peak value is fixed at 255, i.e. the inputs are expected to use an
    8-bit intensity range.

    Args:
        img1: array-like image.
        img2: array-like image with the same shape as ``img1``.

    Returns:
        The PSNR in decibels, or ``float('inf')`` when the images are identical.
    """
    # Work in float64: with uint8 inputs the subtraction and the squaring would
    # wrap around modulo 256 and silently produce a wrong mean squared error.
    first = numpy.asarray(img1, dtype=numpy.float64)
    second = numpy.asarray(img2, dtype=numpy.float64)
    mse = numpy.mean((first - second) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * numpy.log10(255.0 / numpy.sqrt(mse))

# Per-image quality scores; the batch watermarking loop appends to these lists.
ssim_values, psnr_values = [], []


def preprocess_image(image):
    """Centre-crop ``image`` to a square and resize it to 512x512 (bicubic).

    The crop keeps the largest centred square that fits inside the image.
    """
    width, height = image.size
    side = min(width, height)
    x0 = (width - side) // 2
    y0 = (height - side) // 2
    crop_box = (x0, y0, x0 + side, y0 + side)
    return image.crop(crop_box).resize((512, 512), Image.BICUBIC)


# Both pipelines resample with bicubic interpolation.
_bicubic = transforms.InterpolationMode.BICUBIC

# Resize to 256x256, then convert to a tensor.
t_val_256 = transforms.Compose(
    [
        transforms.Resize((256, 256), interpolation=_bicubic),
        transforms.ToTensor(),
    ]
)

# Resize-only pipeline (512x512); it performs no tensor conversion.
t_val_512 = transforms.Compose(
    [
        transforms.Resize((512, 512), interpolation=_bicubic),
    ]
)

In [None]:
input_folder = '/content/image'
output_folder = '/content/wm_image'
image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

# Always start from an empty output folder so results of a previous run are
# never mixed with the current one.
if os.path.exists(output_folder):
    shutil.rmtree(output_folder)
os.makedirs(output_folder)

# Reset the accumulators so that re-running this cell does not count images twice.
ssim_values = []
psnr_values = []

# Sorted for a reproducible processing order (os.listdir order is arbitrary).
for filename in sorted(os.listdir(input_folder)):
    if not filename.lower().endswith(image_extensions):
        continue
    try:
        input_path = os.path.join(input_folder, filename)
        output_path = os.path.join(output_folder, filename + "_wm.png")

        # Centre-cropped 512x512 RGB version of the input image.
        input_image_pil = Image.open(input_path).convert("RGB")
        input_image_pil = preprocess_image(input_image_pil)

        # 256x256 copy, rescaled from [0, 1] to [-1, 1], fed to the watermarker.
        resized_img = t_val_256(input_image_pil)
        resized_img = 2.0 * resized_img - 1.0
        resized_img = resized_img.unsqueeze(0).to(device)

        # Full-resolution (512x512) tensor in the same [-1, 1] range.
        input_image = transforms.ToTensor()(input_image_pil).unsqueeze(0).to(device)
        input_image = 2.0 * input_image - 1.0

        # Inference only: no autograd graph is needed, which saves memory.
        start_time = time.time()
        with torch.no_grad():
            encoded_image_256 = watermarker(resized_img, watermark)
        end_time = time.time()

        # Upscale the low-resolution watermark residual and add it to the
        # full-resolution image, then map back to [0, 1].
        residual_256 = encoded_image_256 - resized_img
        residual_512 = t_val_512(residual_256)
        encoded_image = residual_512 + input_image
        encoded_image = encoded_image * 0.5 + 0.5
        encoded_image = torch.clamp(encoded_image, min=0.0, max=1.0)

        output_pil = transforms.ToPILImage()(encoded_image[0].cpu())
        output_pil.save(output_path)

        # Quality metrics are computed on the grayscale versions.
        original_np = numpy.array(input_image_pil.convert("L"))
        watermarked_np = numpy.array(output_pil.convert("L"))

        ssim_value = ssim(original_np, watermarked_np, data_range=255)
        # Cast to float so the pixel difference cannot wrap around in uint8 arithmetic.
        psnr_value = calculate_psnr(
            original_np.astype(numpy.float64),
            watermarked_np.astype(numpy.float64),
        )

        ssim_values.append(ssim_value)
        psnr_values.append(psnr_value)

    except Exception as e:
        # Fixed: this handler was mis-indented, which made the whole cell a SyntaxError.
        print(f"Error processing {filename}: {e}")

    finally:
        # Release memory after every image, whether or not it succeeded.
        gc.collect()
        torch.cuda.empty_cache()

print("Batch watermarking completed!")

In [None]:
# Report the mean quality metrics, but only when both lists contain scores.
if ssim_values and psnr_values:
    ssim_count, psnr_count = len(ssim_values), len(psnr_values)
    avg_ssim = sum(ssim_values) / ssim_count
    avg_psnr = sum(psnr_values) / psnr_count
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average PSNR: {avg_psnr:.4f}")

In [None]:
image_dir = '/content/manipulated_wm_image'
bit_error = []
# Number of images that could not be read or decoded.
truncated_count = 0

# The reference bits are the same for every image, so compute them once.
groundtruth = groundtruth_watermark[0].cpu().detach().numpy().astype(int).tolist()

for filename in sorted(os.listdir(image_dir)):
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
        continue
    # Fixed: the path was read from an undefined `edited_wm_img_path`.
    image_path = os.path.join(image_dir, filename)
    try:
        image = Image.open(image_path).convert("RGB")
        image = t_val_256(image).unsqueeze(0).to(device)

        with torch.no_grad():
            logits = decoder(image)
            decoded_probs = torch.sigmoid(logits)
            # Threshold the probabilities at 0.5 to obtain the predicted bits.
            pred_watermark = (decoded_probs > 0.5).int().cpu().numpy().flatten().tolist()

        same_bits = sum(x == y for x, y in zip(groundtruth, pred_watermark))
        acc = same_bits / len(groundtruth)
        error = 1 - acc
        bit_error.append(error)

    except Exception as e:
        # Fixed: the message used an undefined `fname` and the counter was never initialised.
        print(f"Skipping truncated or unreadable image: {filename} ({str(e)})")
        truncated_count += 1
        continue

if bit_error:
    # Fixed: the file imports the package as `numpy`, not `np`.
    mean_error = numpy.mean(bit_error)
    print(f"Mean Bit Error Rate: {mean_error}")
else:
    # Avoid numpy's mean-of-empty warning when nothing could be decoded.
    mean_error = float('nan')
    print(f"No images were decoded from {image_dir}; bit error rate is undefined.")