In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
from PIL import Image


In [4]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Load the trigger image; the path is relative to the notebook's working directory.
trigger_path = "trigger.jpg"
trigger = Image.open(fp=trigger_path)
# Quick visual sanity check of the loaded image.
trigger.show()

FileNotFoundError: [Errno 2] No such file or directory: 'trigger.jpg'

In [None]:
# Build a fixed feature extractor from the shallow part of a pretrained ResNet-18.
resnet18 = models.resnet18(pretrained=True)

# Keep only the first five child modules (cut-off taken from the literature).
# Presumably these are conv1, bn1, relu, maxpool and layer1 -- verify against
# torchvision's ResNet definition.
shallow_stages = list(resnet18.children())[:5]
feature_extractor = nn.Sequential(*shallow_stages)

# Move to the compute device and switch to inference mode.
feature_extractor.to(device)
feature_extractor.eval()

# The extractor is never trained: switch off gradient tracking for every weight.
for param in feature_extractor.parameters():
    param.requires_grad_(False)

In [None]:
class Autoencoder(nn.Module):
    """Convolutional encoder/decoder mapping an RGB image to an image of the same size.

    The encoder stacks eight stride-2 3x3 convolutions; each one halves the
    spatial size (rounding up). The decoder stacks eight stride-2 transposed
    convolutions; each one exactly doubles the spatial size. A final ``Tanh``
    keeps the output in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # encoder
        # literature: C64 - C128 - C256 - C512 - C512 - C512 - C512 - C512
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.LeakyReLU(),
            nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.LeakyReLU()
        )
        # decoder
        # literature: CD512 - CD512 - CD512 - C512 - C256 - C128 - C64
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(256, 128, 3, stride=2, output_padding=1, padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, output_padding=1, padding=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 3, 3, stride=2, output_padding=1, padding=1),  # convert back to RGB
            nn.Tanh()  # output in [-1,1]
        )

    def forward(self, x):
        """Encode ``x`` and decode it back to an image with the same height and width.

        Args:
            x: image batch of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, 3, H, W) with values in [-1, 1].
        """
        latent = self.encoder(x)
        output = self.decoder(latent)

        # The decoder always yields 256x the latent size, which only equals the
        # input size when H and W are multiples of 2**8. For other sizes
        # (e.g. 224 -> 1x1 latent -> 256) resample back to the input size so
        # input and output can be compared element-wise. Bilinear resampling
        # is a convex combination, so values stay inside the Tanh range.
        if output.shape[-2:] != x.shape[-2:]:
            output = nn.functional.interpolate(
                output, size=x.shape[-2:], mode="bilinear", align_corners=False
            )
        return output

In [None]:
# Create the trigger autoencoder, then place its weights on the selected device.
autoencoder = Autoencoder()
autoencoder = autoencoder.to(device)

In [None]:
# Scale factor applied to the generated image's features inside the training loss.
mu = 0.35
# Feature-matching criterion: L1 as in the literature, but maybe try nn.MSELoss()
loss = nn.L1Loss()
# Adam updates only the autoencoder's weights.
optimizer = optim.Adam(params=autoencoder.parameters(), lr=1e-4)

In [None]:
# fit image for resnet and tanh
# Convert the already-loaded PIL image to three-channel RGB. (The previous
# `Image.trigger` looked the name up on the PIL module and raised AttributeError.)
trigger = trigger.convert('RGB')

# Resize to 224x224, convert to a float tensor, then normalise every channel
# with (x - 0.5) / 0.5 so pixel values land in [-1, 1], the same range as the
# autoencoder's Tanh output.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# Add a leading batch dimension and move the tensor to the compute device.
trigger_tensor = transform(trigger).unsqueeze(0).to(device)

In [None]:
# Training hyper-parameters and early-stopping bookkeeping.
epochs, patience = 220, 5

# Early-stopping state: best loss seen so far, count of epochs without
# improvement, and a slot for the best model snapshot.
best_loss, counter, best_model = float("inf"), 0, None

In [None]:
# training
# Optimise the autoencoder so that the extractor's features of its output,
# scaled by mu, match the extractor's features of the original trigger image.

# The target features do not depend on the autoencoder, so compute them once
# up front and without building an autograd graph.
with torch.no_grad():
    features_trigger = feature_extractor(trigger_tensor)

autoencoder.train()

# progress bar: iterate over the tqdm wrapper itself so it advances every epoch
epoch_progress = tqdm(range(epochs), desc="Training progress")

for epoch in epoch_progress:
    # forward
    optimizer.zero_grad()
    noise_output = autoencoder(trigger_tensor)
    features_noise = feature_extractor(noise_output)

    # loss
    current_loss = loss(mu * features_noise, features_trigger) # literature
    running_loss = current_loss.item()

    # backprop
    current_loss.backward()
    optimizer.step()

    # progress bar update
    epoch_progress.set_postfix({"Loss": f"{running_loss:.5f}"})

    # early stopping
    # TODO: not implemented yet -- `patience`, `best_loss`, `counter` and
    # `best_model` are defined in the parameter cell but are not used here.

    if (epoch + 1) % 50 == 0:
        # `loss` is the criterion module and has no .item(); report the scalar
        # recorded above. tqdm.write keeps the message from garbling the bar.
        tqdm.write(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss:.6f}")
        # Map the Tanh range [-1, 1] back to [0, 1] before writing the PNG.
        save_image((noise_output.detach() + 1) / 2, f"noised_trigger_epoch_{epoch+1}.png")
