In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import math

from unet import UNetModel
from attr_classifier import FaceAttrModel
from diffusion import GaussianDiffusion

import tqdm
import matplotlib.pyplot as plt

## Load Models

In [None]:
# Load DDPM model
# All networks and tensors in this notebook are placed on this device.
device = torch.device('cuda:0')

# U-Net with 6 output channels; the inference loop further down only uses the
# first 3 channels of its output as the predicted noise.
diff_net = UNetModel(image_size=256, in_channels=3, out_channels=6, 
                     model_channels=256, num_res_blocks=2, channel_mult=(1, 1, 2, 2, 4, 4),
                     attention_resolutions=[32,16,8], num_head_channels=64, dropout=0.1, resblock_updown=True, use_scale_shift_norm=True).to(device)
# map_location=device remaps the checkpoint's tensors onto `device` while loading,
# so the file loads regardless of which device it was saved from (and matches how
# the attribute classifier checkpoint is loaded).
diff_net.load_state_dict(torch.load('models/ffhq.pt', map_location=device))
print('Loaded Diffusion Model')

In [None]:
# Attribute names, in the order used to index the classifier output and the
# target/mask tensors built later in the notebook.
face_attributes = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 
                   'Bangs', 'Big_Lips', 'Big_Nose','Black_Hair', 'Blond_Hair',
                   'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 
                   'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 
                   'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 
                   'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 
                   'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 
                   'Wearing_Hat','Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young' 
]

# Load face attribute model
attr_model = FaceAttrModel(pretrained=False, selected_attrs=face_attributes).to(device)
# Load onto the same `device` the model was just moved to, instead of repeating
# a hard-coded 'cuda:0' string that could drift from the notebook-wide setting.
attr_model.load_state_dict(torch.load('models/Resnet18.pth', map_location=device))
print('Loaded attributes model')

## Specify Target Attributes

In [None]:
# Attribute values (targets for the MSE variant of the attribute loss).
# Every attribute starts at 0; only the ones listed here are set to 1.
target_attributes = dict.fromkeys(face_attributes, 0)
target_attributes.update({
    'Bald': 1,
    'Black_Hair': 1,
    'Blond_Hair': 1,
    'Eyeglasses': 1,
    'Goatee': 1,
    'Male': 1,
    'Mustache': 1,
    'Rosy_Cheeks': 1,
    'Smiling': 1,
    'Wavy_Hair': 1,
    'Young': 1,  # original author's note: 0 corresponds to young, 1 corresponds to old
})

# Which attributes take part in the attribute loss.
#   MSE loss   - non-zero entries select the attributes to consider
#   Logit loss - sign selects whether to minimize/maximize the logit (-1/1)
mask = dict.fromkeys(face_attributes, 0)
mask.update({'Young': 1, 'Goatee': 1})

# Convert to tensors: one row per dict, columns ordered like `face_attributes`
target_attributes = torch.tensor([target_attributes[name] for name in face_attributes]).float().view(1, -1).to(device)
mask = torch.tensor([mask[name] for name in face_attributes]).float().view(1, -1).to(device)

## Run Inference

In [None]:
diffusion = GaussianDiffusion(T=1000, schedule='linear')

# Preview of the timestep schedule: the same formula, jitter and clipping that
# the inference loop applies at every optimisation step.
steps = 200
t_vals = []
for i in range(steps):
    # Linearly decreasing + cosine
    t = ((steps-i)/1.5 + (steps-i)/3*math.cos(i/10))/steps*800 + 200

    # Additional: Add noise to t (one integer offset drawn from [-50, 50])
    jitter = np.random.randint(-50, 51)
    t = np.array([t + jitter]).astype(int)

    # Keep the timestep inside [1, T]
    t = np.clip(t, 1, diffusion.T)
    t_vals.append(t[0])

fig, ax = plt.subplots(figsize=(8,5))
ax.plot(range(steps), t_vals, linewidth=2)
ax.set_title('$t$ Annealing Schedule')
ax.set_xlabel('Steps')
ax.set_ylabel('$t$')
plt.show()

In [None]:
# Transforms to apply to attribute classifier input
def transform_cls_input(img, prob=0.5):    
    # Horizontal flip
    p = np.random.rand()
    if p < prob:
        img = torchvision.transforms.functional.hflip(img)
    
    # Brightness perturbation
    p = np.random.rand()
    if p < prob:
        b = np.random.rand() * (0.2 + 0.2) - 0.2
        img = img*(1+b)
    
    # Blur
    p = np.random.rand()
    if p < prob:
        sigma = np.random.rand() * 5
        img = torchvision.transforms.functional.gaussian_blur(img, 7, sigma)
        
    return img

In [None]:
class InferenceModel(nn.Module):
    """Container whose only learnable parameter is the image being inferred."""

    def __init__(self):
        super().__init__()
        # Inferred image: a single 3x256x256 image initialised from standard
        # normal noise and optimised directly through its gradient.
        self.img = nn.Parameter(torch.randn(1, 3, 256, 256), requires_grad=True)

    def encode(self):
        """Return the current image parameter unchanged."""
        return self.img
# Optimise the image held by InferenceModel with two alternating updates per step:
# (1) a denoising (diffusion) loss step, and (2) an attribute-classifier step whose
# gradient is clipped relative to the running norm of the diffusion gradient.
model = InferenceModel().to(device)
model.train()

# Inference procedure steps
steps = 200   

# The only parameter handed to the optimiser is the image itself (model.img).
opt = torch.optim.Adamax(model.parameters(), lr=1)
# Optional: Linearly decrease learning rate
# (start_factor and end_factor are both 1 here, so the schedule is flat as written)
scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1, end_factor=1, total_iters=steps)

diffusion = GaussianDiffusion(T=1000, schedule='linear')
# Both pretrained networks are only evaluated here; they are not passed to the optimiser.
diff_net.eval()
attr_model.eval()

norm_track = 0  # running (EMA) norm of the diffusion-loss gradient w.r.t. the image
bar = tqdm.tqdm(range(steps))
losses = []  # attribute losses collected since the last progress-bar update
update_every = 10
for i, _ in enumerate(bar):  
    # Select t      
    # Same formula as the schedule preview cell above.
    t = ((steps-i)/1.5 + (steps-i)/3*math.cos(i/10))/steps*800 + 200 # Linearly decreasing + cosine
    t = np.array([t + np.random.randint(-50, 51) for _ in range(1)]).astype(int) # Add noise to t
    t = np.clip(t, 1, diffusion.T)
       
    # Denoise
    sample_img = model.encode()
    # NOTE(review): presumably returns the noised image at step t and the noise
    # that was added - confirm against GaussianDiffusion.sample.
    xt, epsilon = diffusion.sample(sample_img, t)       
    t = torch.from_numpy(t).float().view(1)    
    pred = diff_net(xt.float(), t.to(device))   
    # The network was built with out_channels=6; only the first 3 are used here.
    epsilon_pred = pred[:,:3,:,:] # Use predicted noise only
    
    # Compute diffusion loss
    loss = F.mse_loss(epsilon_pred, epsilon) 
    
    # Compute EMA of diffusion loss gradient norm
    opt.zero_grad()
    loss.backward()
    
    with torch.no_grad():
        grad_norm = torch.linalg.norm(model.img.grad)
        if i > 0:
            alpha = 0.5
            norm_track = alpha*norm_track + (1-alpha)*grad_norm
        else:
            # First iteration: seed the running average with the first norm
            norm_track = grad_norm
            
    # First update of this iteration: apply the diffusion-loss gradient
    opt.step()
    
    # Evaluate attribute classifier on batch of randomly transformed inputs
    attr_batch_size = 8
    attr_input_batch = []
    for j in range(attr_batch_size):
        # Rescale with 0.5*(x+1), augment, clip to [0, 1], then resize to 224x224
        attr_input = 0.5*(model.encode()+1)
        attr_input = transform_cls_input(attr_input, prob=0.5)
        attr_input = torch.clip(attr_input, 0, 1)
        attr_input = F.interpolate(attr_input, (224,224), mode='nearest')
        attr_input_batch.append(attr_input)
        
    attr_input_batch = torch.cat(attr_input_batch, dim=0)
    attr = attr_model(attr_input_batch)

    # MSE between predicted and target attributes
    #loss = torch.sum(F.mse_loss(torch.sigmoid(attr), target_attributes.tile(attr_batch_size,1), reduction='none')*mask) / mask.sum() / attr_batch_size

    # Maximize/Minimize attribute logits
    # Negated mask-weighted logit sum, normalised by sum(|mask|) and the batch size:
    # minimising it raises logits with a positive mask entry and lowers those with a negative one.
    loss = -torch.sum(attr*mask.tile(attr_batch_size, 1)) / torch.abs(mask).sum() / attr_batch_size
    
    opt.zero_grad()
    loss.backward()
    # Clip attribute loss gradients
    # (to at most 10% of the running diffusion-gradient norm)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1*norm_track)
    # Second update of this iteration: apply the clipped attribute-loss gradient
    opt.step()
    scheduler.step()

    # Only the attribute loss is tracked; the progress bar shows its mean since the last update.
    losses.append(loss.item())
    if i % update_every == 0:
        bar.set_postfix({'Loss': np.mean(losses)})
        losses = []    
        
    # Visualize inferred image
    # Shown on the first iteration and then every 25th; the image is rescaled with 0.5*(x+1) but not clipped.
    if (i+1) % 25 == 0 or i == 0:
        with torch.no_grad():
            fig, ax = plt.subplots(1, 1, figsize=(10,5))
            ax.imshow(0.5*(model.encode()+1)[0].cpu().numpy().transpose([1,2,0]))
            ax.set_title('Inferred Image')
            plt.show()

In [None]:
# Pass inferred image through attribute classifier network
# (same preprocessing as in the inference loop, minus the random augmentation)
with torch.no_grad():
    attr_input = F.interpolate(
        torch.clip(0.5*(model.encode()+1), 0, 1),
        (224,224), mode='nearest')
    attr = torch.sigmoid(attr_model(attr_input))

# Print results: one row per attribute with predicted probability, target and mask
print(f'{"Attributes ":20s} | {"Predicted":10s} | {"Target":10s} | {"Mask":10s}')
print("-"*53)
for j, attr_name in enumerate(face_attributes):
    predicted = attr[0,j].item()
    wanted = target_attributes[0,j].item()
    weight = mask[0,j].item()
    print(f'{attr_name:20s} | {predicted:.2f} {" ":5s} | {wanted:.2f} {" ":5s} | {weight:.2f}')

# Show exactly what the classifier received (channels moved last for imshow)
fig, ax = plt.subplots(1, 1, figsize=(10,5))
ax.imshow(attr_input[0].permute(1, 2, 0).cpu().numpy())
ax.set_title('Classifier Input')
plt.show()

## Fine-tune Image with Denoising Steps

In [None]:
diffusion = GaussianDiffusion(T=1000, schedule='linear')

# Denoising starts from the inferred image at timestep `start_t`
# and uses as many denoising steps as there are timesteps left.
start_t = 200
steps = start_t
x = model.encode()

# Run the denoiser in eval mode, then put the network back in whatever mode it
# was in before. Unconditionally calling diff_net.train() afterwards would leave
# the pretrained network with dropout enabled for any later use.
was_training = diff_net.training
diff_net.eval()
try:
    with torch.no_grad():
        fine_tuned = diffusion.inverse(diff_net, shape=(3,256,256), start_t=start_t, steps=steps, x=x, device=device)
finally:
    diff_net.train(was_training)
    
fig, ax = plt.subplots(1, 1, figsize=(5,5))
ax.imshow(0.5*(fine_tuned+1)[0].cpu().numpy().transpose([1,2,0]))
# `{0}` inside the f-string is formatted as the literal 0, giving the subscript in $t_0$.
ax.set_title(f'Fine-tuned Sample | $t_{0}$={start_t} steps={steps}')
plt.show()

In [None]:
# Pass fine-tuned image through attribute classifier network
# (rescaled with 0.5*(x+1) and resized; no augmentation and no clipping here)
with torch.no_grad():
    attr_input = F.interpolate(0.5*(fine_tuned+1), (224,224), mode='nearest')
    attr = torch.sigmoid(attr_model(attr_input))

# Print results: one row per attribute with predicted probability, target and mask
print(f'{"Attributes ":20s} | {"Predicted":10s} | {"Target":10s} | {"Mask":10s}')
print("-"*53)
for j, attr_name in enumerate(face_attributes):
    predicted = attr[0,j].item()
    wanted = target_attributes[0,j].item()
    weight = mask[0,j].item()
    print(f'{attr_name:20s} | {predicted:.2f} {" ":5s} | {wanted:.2f} {" ":5s} | {weight:.2f}')

# Show exactly what the classifier received (channels moved last for imshow)
fig, ax = plt.subplots(1, 1, figsize=(10,5))
ax.imshow(attr_input[0].permute(1, 2, 0).cpu().numpy())
ax.set_title('Classifier Input')
plt.show()