In [2]:
# Runtime device selection - derived from the environment, not tuned by hand
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda') if USE_CUDA else torch.device('cpu')

# Tunable settings for the training run
BATCH_SIZE = 10                       # batch size handed to both DataLoaders
MAX_EPOCH = 1_000                     # number of passes of the outer training loop
LR_RATE = 1e-3                        # learning rate given to the Adam optimizer
Z_DIM = 512                           # latent size constant
OUTPUT_DIR = "output"                 # root folder for checkpoints and sample images
NETWORK_PKL = "pretrained/ffhq.pkl"   # pretrained network pickle to load

In [3]:
# DataLoader
train_dataset = FaceLandmarksDataset("train", Transforms())
train_loader = DataLoader(train_dataset, BATCH_SIZE, shuffle=True)

# Validation only averages a loss, so shuffling adds nothing; a fixed order
# keeps the evaluation (and the saved sample image) deterministic.
validate_dataset = FaceLandmarksDataset("validate", Transforms())
validate_loader = DataLoader(validate_dataset, BATCH_SIZE, shuffle=False)

# Network
# NOTE(review): the network file is unpickled, which can execute arbitrary code;
# only load NETWORK_PKL from a trusted source.
with dnnlib.util.open_url(NETWORK_PKL) as f:
    data = legacy.load_network_pkl(f)
    generator = data["G_ema"].to(DEVICE)
    discriminator = data["D"].to(DEVICE)

# The encoder receives image batches that the training cell moves to DEVICE,
# so it must live on the same device. Move it *before* building the optimizer.
reference_encoder = ReferenceEncoder().to(DEVICE)

# Pass the device explicitly: face_alignment defaults to CUDA and fails on
# CPU-only machines otherwise.
fa_network = face_alignment.FaceAlignment(
    face_alignment.LandmarksType._2D,
    device=DEVICE.type,
    flip_input=False,
)

# Loss function & optimizer
mse_loss  = nn.MSELoss()
optimizer = optim.Adam(reference_encoder.parameters(), lr=LR_RATE)

Downloading: "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth" to /home/rodney/.cache/torch/hub/checkpoints/s3fd-619a316812.pth


  0%|          | 0.00/85.7M [00:00<?, ?B/s]

Downloading: "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip" to /home/rodney/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip


  0%|          | 0.00/91.9M [00:00<?, ?B/s]

In [None]:
# Print loading information
print('Loading networks from "%s"...' % NETWORK_PKL)

# Checkpoints go to OUTPUT_DIR/saved_model; create it (and OUTPUT_DIR) up front
# so torch.save does not fail on a missing directory.
os.makedirs(OUTPUT_DIR + "/saved_model", exist_ok=True)

# Log preparation
config = dict(
    epochs = MAX_EPOCH,
    batch_size = BATCH_SIZE,
    lr_rate = LR_RATE
)
run = wandb.init(project = "style_plus", config = config)
# Watch the model that the optimizer actually updates (the reference encoder),
# not the pretrained generator.
wandb.watch(reference_encoder, log_freq=100)


def _synthesize_images(ref_code, seed):
    """Generate one uint8 image per reference code.

    Draws latents from a NumPy RandomState seeded with `seed`, maps them through
    `generator.mapping`, adds `ref_code` (reshaped to the shape of the mapped
    codes) and runs `generator.synthesis`.

    Returns the synthesized batch scaled by 127.5, offset by 128, clamped to
    [0, 255] and cast to uint8. The axis order is whatever `generator.synthesis`
    returns (the callers treat it as channel-first and permute per image).
    """
    # Size everything from the actual batch so a smaller final batch still works.
    batch = ref_code.shape[0]
    latents = np.random.RandomState(seed).randn(batch, generator.z_dim)
    z = torch.from_numpy(latents).to(DEVICE, dtype=torch.float32)
    ws = generator.mapping(z, 0)

    # W, Style+ code mixing: reshape to ws' own shape instead of hard-coded dims.
    mixed = ws + ref_code.view(ws.shape)
    synthesized = generator.synthesis(mixed)
    return (synthesized * 127.5 + 128).clamp(0, 255).to(torch.uint8)


def _detect_landmarks(images_uint8):
    """Run the landmark detector on every image of a batch.

    Each image is permuted with (1, 2, 0) before being handed to
    `fa_network.get_landmarks`; only the first detection per image is kept.

    Returns a tensor on DEVICE with one landmark array per image, or None when
    the detector finds nothing for at least one image of the batch.
    """
    per_image = []
    for image in images_uint8:
        detections = fa_network.get_landmarks(image.permute(1, 2, 0))
        if not detections:  # None or empty list: no face found
            return None
        per_image.append(torch.as_tensor(detections[0]))
    return torch.stack(per_image).to(DEVICE)


def _mean_or_nan(total, count):
    """Average `total` over `count` steps; NaN when no step was processed."""
    return total / count if count else float("nan")


# Training process
loss_min = np.inf
for epoch in range(1, MAX_EPOCH+1):

    loss_train = 0
    loss_validate = 0
    running_loss = 0
    steps_train = 0
    steps_validate = 0

    reference_encoder.train()
    # Iterate the loader directly: `next(iter(loader))` would build a fresh
    # iterator every step and never walk through the dataset.
    for images, landmarks in train_loader:
        images = images.to(DEVICE)
        landmarks = landmarks.view(landmarks.shape[0], -1).to(DEVICE)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Reference image encoding
        ref_code = reference_encoder(images)

        # StyleGANs image synthesis
        seed = random.randint(0, 2**23-1)
        generated_image = _synthesize_images(ref_code, seed)

        # Facial landmark detection
        pred_landmarks = _detect_landmarks(generated_image)
        if pred_landmarks is None:
            print("training process - no face detected in a generated image; batch skipped")
            continue

        # FIXME(review): the uint8 cast and the landmark detector are not
        # differentiable, so this loss has no gradient path back to
        # reference_encoder and optimizer.step() cannot update it. The flag
        # below only keeps backward() from raising (as in the original code).
        # A differentiable landmark loss is needed for real training.
        pred_landmarks.requires_grad_(True)

        # find the loss for the current step
        loss_train_step = mse_loss(landmarks, pred_landmarks.view(pred_landmarks.shape[0], -1))

        # calculate the gradients
        loss_train_step.backward()

        # update the parameters
        optimizer.step()

        steps_train += 1
        loss_train += loss_train_step.item()
        running_loss = loss_train/steps_train

        # Log
        print(f"traing process - loss_train: {loss_train:.4f}; running_loss: {running_loss:.4f}")


    # Last successfully evaluated validation batch, kept for the sample images
    sample_generated = None
    sample_reference = None
    sample_seed = None

    reference_encoder.eval()
    with torch.no_grad():
        for images, landmarks in validate_loader:
            images = images.to(DEVICE)
            landmarks = landmarks.view(landmarks.shape[0], -1).to(DEVICE)

            # Reference image encoding
            ref_code = reference_encoder(images)

            # StyleGANs image synthesis (same axis order as in training; the
            # per-image permute happens once, inside _detect_landmarks)
            seed = random.randint(0, 2**23-1)
            generated_image = _synthesize_images(ref_code, seed)

            # Facial landmark detection
            pred_landmarks = _detect_landmarks(generated_image)
            if pred_landmarks is None:
                print("validating process - no face detected in a generated image; batch skipped")
                continue

            # find the loss for the current step (flatten predictions so they
            # match the flattened targets, exactly as in training)
            loss_validate_step = mse_loss(landmarks, pred_landmarks.view(pred_landmarks.shape[0], -1))

            steps_validate += 1
            loss_validate += loss_validate_step.item()
            running_loss = loss_validate/steps_validate

            sample_generated = generated_image
            sample_reference = images
            sample_seed = seed

            # Log
            print(f"validating process - loss_validate: {loss_validate:.4f}; running_loss: {running_loss:.4f}")

    # Average over the steps that actually contributed a loss
    loss_train = _mean_or_nan(loss_train, steps_train)
    loss_validate = _mean_or_nan(loss_validate, steps_validate)

    # Log (wandb.log expects a dict of metrics)
    wandb.log({"epoch": epoch, "loss_train": loss_train, "loss_validate": loss_validate})

    # A NaN validation loss (nothing evaluated) never compares as smaller, so
    # the sample_* variables are always set when this branch runs.
    if loss_validate < loss_min:
        loss_min = loss_validate
        torch.save(reference_encoder.state_dict(), OUTPUT_DIR + "/saved_model/reference_encoder.pth")

        # The per-sample folder does not exist yet; create it before saving.
        sample_dir = f"{OUTPUT_DIR}/{sample_seed:04d}_{epoch}"
        os.makedirs(sample_dir, exist_ok=True)
        PIL.Image.fromarray(sample_generated[0].permute(1, 2, 0).cpu().numpy(), "RGB").save(f"{sample_dir}/seed{sample_seed:04d}_gen.png")
        # NOTE(review): assumes the dataset yields images PIL can read as RGB
        # (height, width, 3, uint8). If Transforms() produces channel-first
        # float tensors this needs a conversion first - TODO confirm.
        PIL.Image.fromarray(sample_reference[0].cpu().numpy(), "RGB").save(f'{sample_dir}/seed{sample_seed:04d}_ref.png')
        print("\nMinimum Validation Loss of {:.4f} at epoch {}/{}".format(loss_min, epoch, MAX_EPOCH))
        print("Model Saved\n")