In [90]:
import copy

import torch
import torch.nn.functional as F

from tqdm import tqdm
from peft import LoraConfig
from dataset import InputPipelineBuilder
from diffusers import StableDiffusionPipeline, DDPMScheduler

In [55]:
# Training hyperparameters.
EPOCHS = 5
RANK = 32    # LoRA rank (also used as lora_alpha)
BETA = 250   # multiplier applied to the preference margin inside the log-sigmoid

# The learning rate is scaled by 100 / BETA and the prior-loss weight by
# BETA / 100, so both equal their base values (3e-4 and 1e6) when BETA == 100.
LR = 3e-4 * (100 / BETA)
LAMBDA = 1e6 * BETA / 100

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
weight_dtype = torch.float32

In [56]:
# One dataloader per split; shuffle=True is requested only for the training split.
input_pipeline_builder = InputPipelineBuilder()

train_dataloader = input_pipeline_builder.get_dataloader(subset='train', shuffle=True)
valid_dataloader, test_dataloader = (
    input_pipeline_builder.get_dataloader(subset=split_name)
    for split_name in ('valid', 'test')
)

In [57]:
# Load the pretrained Stable Diffusion v1-4 pipeline in `weight_dtype`,
# then move all of its components to the selected device.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=weight_dtype
)
pipe = pipe.to(device)

Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 25.99it/s]


In [88]:
# Pull the individual components out of the pipeline.
vae = pipe.vae
unet = pipe.unet
text_encoder = pipe.text_encoder
noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# Frozen reference UNet. It has to be an independent copy: binding the same
# object to both names would mean that attaching the LoRA adapter to `unet`
# also changes `base_unet`, making every "base" prediction identical to the
# trained one. The copy must be taken before the adapter is added to `unet`.
base_unet = copy.deepcopy(unet)
# Keep the reference in inference mode regardless of later `unet.train()` calls.
base_unet.eval()

# Freeze every pretrained weight; the trainable LoRA parameters are added to `unet` later.
for frozen_module in (vae, unet, base_unet, text_encoder):
    frozen_module.requires_grad_(False)




In [59]:
# Attach LoRA adapters to the `attn1` query/key/value/output projections.
# (The last entry was previously misspelled 'att1n.to_out.0', so the output
# projection silently received no adapter.)
target_modules = ['attn1.to_q', 'attn1.to_k', 'attn1.to_v', 'attn1.to_out.0']
unet_lora_config = LoraConfig(
    r=RANK,
    lora_alpha=RANK,
    init_lora_weights='gaussian',
    target_modules=target_modules
)
unet.add_adapter(unet_lora_config)

# Only the newly added adapter weights require gradients; these are what the optimizer trains.
lora_parameters = [param for param in unet.parameters() if param.requires_grad]

In [60]:
# Adam over the LoRA parameters only.
optimizer = torch.optim.Adam(lora_parameters, lr=LR, weight_decay=1e-2)

# Learning-rate schedule: hold the rate for `constant_iters` scheduler steps
# (factor 1.0), then hand over to a power-1 (linear) polynomial decay that
# spans 5 scheduler steps.
constant_iters = 2
constant_scheduler = torch.optim.lr_scheduler.ConstantLR(
    optimizer, factor=1.0, total_iters=constant_iters
)
linear_scheduler = torch.optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=5, power=1.0
)
lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[constant_scheduler, linear_scheduler],
    milestones=[constant_iters],
)

In [92]:
def per_sample_mse(prediction, target):
    """Return the MSE between `prediction` and `target`, averaged over dims 1-3 (one value per sample)."""
    return F.mse_loss(prediction, target, reduction='none').mean(dim=[1, 2, 3])


# Index of the final diffusion step, used for the prior-preservation term
# (derived from the scheduler config instead of hard-coding 999).
last_timestep = noise_scheduler.config.num_train_timesteps - 1

for epoch in range(EPOCHS):
    unet.train()
    for image_batch in tqdm(train_dataloader, desc=f'epoch {epoch + 1}/{EPOCHS}'):
        safe_prompt = image_batch['safe_prompt']
        unsafe_prompt = image_batch['unsafe_prompt']

        safe_image = image_batch['safe_image'].to(device)
        unsafe_image = image_batch['unsafe_image'].to(device)

        # Text embeddings and image latents come from frozen components, so no graph is needed.
        with torch.no_grad():
            # encode_prompt returns (prompt_embeds, negative_prompt_embeds);
            # only the positive embeddings are fed to the UNet.
            prompt_embed_safe = pipe.encode_prompt(
                safe_prompt,
                device=device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False
            )[0]
            prompt_embed_unsafe = pipe.encode_prompt(
                unsafe_prompt,
                device=device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False
            )[0]
            si_latent = vae.encode(safe_image).latent_dist.sample() * vae.config.scaling_factor
            usi_latent = vae.encode(unsafe_image).latent_dist.sample() * vae.config.scaling_factor

        batch_size = si_latent.shape[0]

        # The same noise and timesteps are applied to the safe and unsafe latents.
        timestep = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, size=(batch_size,), device=device
        ).long()
        noise = torch.randn_like(si_latent)
        noised_si_latent = noise_scheduler.add_noise(si_latent, noise, timestep)
        noised_usi_latent = noise_scheduler.add_noise(usi_latent, noise, timestep)

        # Pure-noise input at the last timestep for the prior-preservation term.
        prior_latent = torch.randn_like(si_latent)
        prior_timestep = torch.full(
            size=(batch_size,), fill_value=last_timestep, device=device, dtype=torch.long
        )

        pred_safe_unlearned = unet(noised_si_latent, timestep, prompt_embed_safe).sample
        pred_unsafe_unlearned = unet(noised_usi_latent, timestep, prompt_embed_unsafe).sample
        pred_prior_unlearned = unet(prior_latent, prior_timestep, prompt_embed_safe).sample
        # NOTE: the reference terms are only meaningful if `base_unet` is a
        # separate object from `unet` (i.e. it does not carry the LoRA adapter).
        with torch.no_grad():
            pred_safe_base = base_unet(noised_si_latent, timestep, prompt_embed_safe).sample
            pred_unsafe_base = base_unet(noised_usi_latent, timestep, prompt_embed_unsafe).sample
            pred_prior_base = base_unet(prior_latent, prior_timestep, prompt_embed_safe).sample

        loss_safe_unlearned = per_sample_mse(pred_safe_unlearned, noise)
        loss_unsafe_unlearned = per_sample_mse(pred_unsafe_unlearned, noise)
        loss_safe_base = per_sample_mse(pred_safe_base, noise)
        loss_unsafe_base = per_sample_mse(pred_unsafe_base, noise)

        # Preference margin: negative when, relative to the reference model, the
        # adapted model denoises the safe sample better and the unsafe sample worse.
        # -logsigmoid(-BETA * margin) is minimised by pushing the margin down.
        preference_margin = (
            loss_safe_unlearned - loss_safe_base - loss_unsafe_unlearned + loss_unsafe_base
        )
        preference_loss = (-1 * F.logsigmoid(-1 * BETA * preference_margin)).mean()

        # Penalise drift from the reference prediction on the pure-noise input.
        prior_loss = F.mse_loss(pred_prior_unlearned, pred_prior_base, reduction='mean')
        loss = preference_loss + LAMBDA * prior_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Step the schedule once per epoch. Its horizon is 2 constant + 5 decay
    # iterations, so stepping it per batch drove the learning rate to zero
    # after the first 7 batches of training.
    lr_scheduler.step()

  0%|          | 0/28 [00:21<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Write the trained LoRA adapter weights of the UNet to disk.
lora_output_dir = './pretrained_unet_only_lora_250'
unet.save_lora_adapter(lora_output_dir)