In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import sys
sys.path.append("..")

from dataloaders.dataloader_v1 import get_loader
import torch
from wcmatch.pathlib import Path
from utils.utils import crop_center_half, ifft2d, normalize, flip_to_minimize_loss
from utils.algorithms import get_algorithm
from matplotlib import pyplot as plt
from einops import rearrange, repeat
import cv2
import numpy as np
import random
import torchvision
import torch.autograd as autograd

from models.denoisers import get_denoiser

In [2]:
# Dataset location and batch configuration shared by every split.
root = "/hdd_mnt/onurcan/onurk/datasets/adversarial_alpha_3"
batch_size = 22

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build one loader per split, in train -> val -> test order. `stage` is a
# plain module-level loop variable, so it ends up as "test" afterwards.
_split_loaders = []
for stage in ("train", "val", "test"):
    _split_loaders.append(get_loader("adversarial_dataset", stage, root, batch_size))
dataloader, val_dataloader, test_dataloader = _split_loaders
del _split_loaders

In [3]:
class Critic(torch.nn.Module):
    """Scalar-output critic built on a torchvision ResNet-50.

    The backbone is loaded with pretrained weights, then two layers are
    replaced with freshly initialised ones:

    * ``conv1`` becomes a 7x7, stride-2 convolution taking a single input
      channel (the pretrained stem weights are therefore discarded);
    * ``fc`` becomes a linear layer with one output unit.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=True)
        # Single-channel stem in place of the original 3-channel one.
        backbone.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        # One raw score per sample instead of the classification logits.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, 1)
        self.resnet = backbone

    def forward(self, x):
        """Return the backbone's raw output for ``x``; no sigmoid is applied."""
        score = self.resnet(x)
        return score

In [4]:
# Networks: the denoiser being trained and the critic that scores its outputs.
adversarial_denoiser = get_denoiser("UNet2D")().to(device)
critic = Critic().to(device)

# Reconstruction criterion between the denoised image and the target image.
loss = torch.nn.MSELoss()

# Both networks use AdamW with the same hyper-parameters.
_adamw_kwargs = dict(lr=3e-5, betas=(0.95, 0.999), weight_decay=1e-5)
optimizer = torch.optim.AdamW(adversarial_denoiser.parameters(), **_adamw_kwargs)
critic_optimizer = torch.optim.AdamW(critic.parameters(), **_adamw_kwargs)
del _adamw_kwargs



In [5]:
# train loop
#
# Adversarial training of the residual denoiser with a gradient-penalised
# Wasserstein critic. Every iteration does one critic update followed by one
# denoiser update. The denoiser output is computed as
#     adversarial_denoiser(output / 255.0, 0) * 255.0 + output
# i.e. the network runs on the input scaled by 1/255 and its prediction is
# scaled back by 255 and added to `output` as a residual.
# NOTE(review): the meaning of the second positional argument (0) is defined by
# the denoiser implementation, which is not visible here — confirm.
import logging
logging.basicConfig(level=logging.INFO, filename="notebooks/py_log_adversarial_alpha_3__.log", filemode="w", format='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

continue_from_epoch = 0
N_epochs = 20
dataloader_len = len(dataloader)

# Loss weights (loop-invariant, so defined once up front).
lambda_gp = 10.0  # weight of the gradient penalty in the critic loss
lambda_ = 0.7     # weight of the adversarial term in the denoiser loss

# Log roughly 200 times per epoch. The max() guard keeps the modulo below from
# dividing by zero when the loader has fewer than 200 batches.
log_every = max(1, dataloader_len // 200)

from diffusers.optimization import get_scheduler
# Cosine schedule (with warm-up) for the denoiser optimizer only; the critic
# optimizer keeps a constant learning rate.
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=300,
    num_training_steps=dataloader_len * N_epochs,
)

min_test_loss = float("inf")

logging.info("epoch (index_dataloader/dataloader_len), loss_value.item(), epoch_losses.mean()")
for epoch in range(continue_from_epoch, N_epochs):
    adversarial_denoiser.train()
    critic.train()

    epoch_losses = np.array([])
    epoch_losses_critic = np.array([])

    for index_dataloader, (target_im, robust_output, output) in enumerate(dataloader):
        target_im = target_im.to(device).float()
        robust_output = robust_output.to(device).float()
        output = output.to(device).float()

        # ------------------------------------------------------------------
        # Critic update
        # ------------------------------------------------------------------
        # The fake batch is built without a graph so that the critic loss
        # cannot back-propagate into the denoiser's parameters.
        with torch.no_grad():
            denoised_output = adversarial_denoiser(output / 255.0, 0) * 255.0 + output

        # Clear critic gradients left behind by the previous denoiser step
        # (total_loss.backward() also writes into the critic's parameters).
        critic_optimizer.zero_grad()

        real_output = critic(target_im)
        fake_output = critic(denoised_output)

        # Wasserstein critic loss: score real images high, denoised ones low.
        critic_loss = -torch.mean(real_output) + torch.mean(fake_output)

        # Gradient penalty on random per-sample interpolations between real and
        # fake images; alpha broadcasts over all non-batch dimensions.
        alpha = torch.rand(target_im.size(0), 1, 1, 1, device=device)
        interpolates = (alpha * target_im + (1 - alpha) * denoised_output).requires_grad_(True)
        d_interpolates = critic(interpolates)
        gradients = autograd.grad(outputs=d_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(d_interpolates),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        # The penalty uses the L2 norm of the whole per-sample gradient, so
        # flatten every non-batch dimension before taking the norm (a norm
        # over dim=1 alone would only reduce the channel axis).
        gradient_norms = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
        gradient_penalty = ((gradient_norms - 1) ** 2).mean()
        critic_loss = critic_loss + lambda_gp * gradient_penalty

        critic_loss.backward()
        critic_optimizer.step()

        epoch_losses_critic = np.append(epoch_losses_critic, critic_loss.item())

        # ------------------------------------------------------------------
        # Denoiser update
        # ------------------------------------------------------------------
        # Start from clean denoiser gradients before this step's backward pass.
        optimizer.zero_grad()

        denoised_output = adversarial_denoiser(output / 255.0, 0) * 255.0 + output

        fake_output = critic(denoised_output)

        reconstruction_loss = loss(denoised_output, target_im)
        adversarial_loss = -torch.mean(fake_output)

        # Pure reconstruction for the first 20% of the epochs; the adversarial
        # term is added afterwards.
        total_loss = reconstruction_loss
        if epoch / N_epochs >= 0.2:
            total_loss = total_loss + lambda_ * adversarial_loss

        total_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        epoch_losses = np.append(epoch_losses, total_loss.item())

        if (index_dataloader + 1) % log_every == 0:
            logging.info(f"train: {epoch} ({index_dataloader}/{dataloader_len}), {total_loss.item()}, {epoch_losses.mean()} | {critic_loss.item()}, {epoch_losses_critic.mean()}")

    # ----------------------------------------------------------------------
    # Evaluation on the validation split, then the test split
    # ----------------------------------------------------------------------
    adversarial_denoiser.eval()
    critic.eval()
    with torch.no_grad():
        for split_name, split_loader in (("val", val_dataloader), ("test", test_dataloader)):
            epoch_losses_val = np.array([])      # MSE(denoised, target) per batch
            epoch_losses_old_val = np.array([])  # MSE(input, target) per batch (baseline)
            for target_im, robust_output, output in split_loader:
                target_im = target_im.to(device).float()
                robust_output = robust_output.to(device).float()
                output = output.to(device).float()

                denoised_output = adversarial_denoiser(output / 255.0, 0) * 255.0 + output

                # loss calculation
                loss_value = loss(denoised_output, target_im)
                epoch_losses_val = np.append(epoch_losses_val, loss_value.item())
                loss_value_old = loss(output, target_im)
                epoch_losses_old_val = np.append(epoch_losses_old_val, loss_value_old.item())

            logging.info(f"{split_name}: {epoch}, {epoch_losses_val.mean()}, {epoch_losses_old_val.mean()}")

    # After the loop above `epoch_losses_val` holds the TEST-split losses.

    # Periodic snapshot; with N_epochs = 20 this condition only holds at epoch 7.
    if epoch % 14 == 7:
        torch.save(adversarial_denoiser.state_dict(), f"save_adversarial_alpha_3_{epoch}_.pth")

    # save the best model
    # NOTE(review): the "best" checkpoint is selected on the test split, which
    # makes the reported test loss optimistic — consider selecting on val.
    if min_test_loss > epoch_losses_val.mean():
        min_test_loss = epoch_losses_val.mean()
        torch.save(adversarial_denoiser.state_dict(), "save_adversarial_alpha_3_best_.pth")

torch.save(adversarial_denoiser.state_dict(), "save_adversarial_alpha_3_last_.pth")