# CONTENT WARNING: images produced may be shocking or distressing

In [None]:
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.append('../')
import torch
from diffusers import StableDiffusionPipelineSafe
from rrf_diffusion import ValueUnet, GradientRewardRegressor, GradientMatchingTrainer
from diffusers.pipelines.stable_diffusion_safe import SafetyConfig
from rrf_diffusion.models import cycle
from rrf_diffusion.dataset import GradientMatchingDataset, batch_to_device
from rrf_diffusion.utils import check_nan, plot2img, to_np
from datasets import load_dataset
import numpy as np

In [None]:
# Load the Safe Stable Diffusion pipeline with half-precision weights.
# `safety_checker=None` is passed so no safety-checker component is attached;
# this is why the notebook carries a content warning at the top.
pipeline_orig = StableDiffusionPipelineSafe.from_pretrained(
    "AIML-TUDA/stable-diffusion-safe",
    torch_dtype=torch.float16,
    # Optional overrides kept for reference (disabled):
    # cache_dir="/scratch/shared/beegfs/<username>/huggingface_cache",
    # device_map="auto",
    safety_checker=None,
)

In [None]:
# Hyper-parameters of the training run whose checkpoint is inspected below.
# Only some of these keys are read in this notebook (lr, gradient_clipping,
# batch_size, sample_freq, train_frac, test_overfit, dim); the rest are kept
# so the dict mirrors the full run configuration.
args = {
    'batch_size': 256,
    'lr': 0.0001,
    'gradient_clipping': None,
    'sample_freq': 1000,
    'n_train_steps': 1000000,
    'n_steps_per_epoch': 1000,
    'train_frac': 0.9,
    'dim': 32,
    'seed': 3,
    'debug': False,
    'test_overfit': False,
}


In [None]:
# Root folder of the training data and the file name shared by both splits.
base_path = "/work/<username>/safe_stable_diffusion_data/train"
load_file = "merged.pt"

# Both splits are deserialised onto the CPU, whatever device they were
# saved from; nothing in this cell moves them to the GPU.
print("Loading datasets...")
expert_dataset = torch.load(
    os.path.join(base_path, f"expert/{load_file}"),
    map_location="cpu",
)
general_dataset = torch.load(
    os.path.join(base_path, f"general/{load_file}"),
    map_location="cpu",
)
print("Finished loading")

In [None]:
# Build the value network and wrap it in the gradient-matching regressor.
model = ValueUnet(
    args["dim"],
    dim_mults=(1, 2, 4, 8),
    channels=4,
    resnet_block_groups=8,
)
gradient_matching = GradientRewardRegressor(model)

# utils.report_parameters(model)

# Directory handed to the trainer as its results folder; created if missing.
savepath = "/scratch/shared/beegfs/<username>/safe_stable_diffusion_logs/lr0.0001_dim32_seed3"
os.makedirs(savepath, exist_ok=True)

# Move both the bare network and the wrapper onto the first CUDA device.
device = "cuda:0"
model = model.to(device=device)
gradient_matching = gradient_matching.to(device=device)

# Trainer over the expert and general datasets, configured from `args`.
trainer = GradientMatchingTrainer(
    gradient_matching,
    expert_dataset,
    general_dataset,
    results_folder=savepath,
    train_batch_size=args["batch_size"],
    train_lr=args["lr"],
    gradient_clipping=args["gradient_clipping"],
    train_frac=args["train_frac"],
    sample_freq=args["sample_freq"],
    test_overfit=args["test_overfit"],
)

In [None]:
# Set the trainer's `logdir` attribute to the run directory. This is the same
# string that is assigned to `savepath` and passed as `results_folder`.
# NOTE(review): presumably `trainer.load` reads checkpoints from here - confirm.
trainer.logdir = "/scratch/shared/beegfs/<username>/safe_stable_diffusion_logs/lr0.0001_dim32_seed3"

In [None]:
# Restore trainer state identified by 140000.
# NOTE(review): presumably the checkpoint written at training step 140000 -
# confirm against GradientMatchingTrainer.load.
trainer.load(140000)

In [None]:
# Move the diffusion pipeline to the first CUDA device and bind the result to
# `pipeline`, which the sampling code below uses.
# NOTE(review): diffusers pipelines' `.to` is expected to return the pipeline
# itself, so `pipeline` and `pipeline_orig` should be the same object - confirm.
pipeline = pipeline_orig.to("cuda:0")

In [None]:
def _decode_and_display(latent):
    """Decode one latent with the diffusion pipeline and display the image.

    The latent is cast to float16 on ``cuda:0``, given a leading batch
    dimension and detached before being passed to
    ``pipeline.decode_latents``; the first decoded image is shown with
    ``display``.
    """
    latent = latent.to(dtype=torch.float16, device="cuda:0").unsqueeze(0).detach()
    with torch.no_grad():
        # Release cached blocks first; decoding is the memory-heavy step here.
        torch.cuda.empty_cache()
        decoded = pipeline.decode_latents(latent)
    display(pipeline.numpy_to_pil(decoded)[0])


def run_sampling(n=1, batch=32):
    """Show the evaluation samples with the highest and lowest predicted reward.

    Draws ``n`` shuffled batches of size ``batch`` from
    ``trainer.dataset_eval``, scores each with ``trainer._get_preds`` and
    tracks the samples with the largest and smallest prediction seen across
    all batches. The two rewards and their labels are printed, then the
    corresponding latents are decoded and displayed (maximum first).

    Args:
        n: Number of batches to score; must be at least 1.
        batch: Batch size of the evaluation data loader.

    Raises:
        ValueError: If ``n`` is smaller than 1, because no sample could be
            selected.
    """
    from tqdm import tqdm

    if n < 1:
        raise ValueError("n must be at least 1 to select a sample")

    torch.cuda.empty_cache()

    eval_batches = cycle(torch.utils.data.DataLoader(
        trainer.dataset_eval, batch_size=batch, num_workers=0, shuffle=True, pin_memory=True
    ))

    max_reward = min_reward = None
    argmax_reward = argmin_reward = None
    label_max = label_min = None

    for _ in tqdm(range(n)):
        x_t, t, _, _, labels = batch_to_device(next(eval_batches))
        # NOTE(review): gradient tracking is deliberately left enabled here, as
        # in the original; presumably the regressor relies on autograd - confirm.
        out, _, _ = trainer._get_preds(x_t, t)

        # np.argmax / np.argmin return indices into the *flattened* array, so
        # flatten the predictions (like the labels) before indexing with them.
        # NOTE(review): assumes one prediction per sample - confirm.
        out = np.asarray(to_np(out)).reshape(-1)
        labels = to_np(labels).flatten()

        batch_argmax = np.argmax(out)
        batch_argmin = np.argmin(out)
        batch_max = out[batch_argmax]
        batch_min = out[batch_argmin]

        print(t)

        # Strict comparisons: on ties the earliest sample seen is kept.
        if max_reward is None or batch_max > max_reward:
            max_reward = batch_max
            argmax_reward = x_t[batch_argmax]
            label_max = labels[batch_argmax]

        if min_reward is None or batch_min < min_reward:
            min_reward = batch_min
            argmin_reward = x_t[batch_argmin]
            label_min = labels[batch_argmin]

    print("Rewards:", min_reward, max_reward)
    print("Labels:", label_min, label_max)

    _decode_and_display(argmax_reward)
    _decode_and_display(argmin_reward)

In [None]:
# Score a single evaluation batch (n=1, default batch size of 32) and show the
# highest- and lowest-reward samples from it.
run_sampling(1)