In [1]:
""" In this file, PyTorch modules are defined to be used in the Talking Heads model. """

import torch
import torch.nn as nn
from torch.nn import functional as F


def init_conv(conv):
    nn.init.xavier_uniform_(conv.weight)
    if conv.bias is not None:
        conv.bias.data.zero_()



In [2]:
LOG_DIR = r'logs'
MODELS_DIR = r'models'
GENERATED_DIR = r'generated_img'

# Dataset parameters
FEATURES_DPI = 100
K = 8

# Training hyperparameters
IMAGE_SIZE = 256  # 224
BATCH_SIZE = 3
EPOCHS = 1000
LEARNING_RATE_E_G = 5e-5
LEARNING_RATE_D = 2e-4
LOSS_VGG_FACE_WEIGHT = 2e-3
LOSS_VGG19_WEIGHT = 1e-2
LOSS_MCH_WEIGHT = 8e1
LOSS_FM_WEIGHT = 1e1
FEED_FORWARD = False
SUBSET_SIZE = 140000

# Model Parameters
E_VECTOR_LENGTH = 512
HIDDEN_LAYERS_P = 4096

In [3]:
class SelfAttention(nn.Module):
    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.rand(1).normal_(0.0, 0.02))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        # B: mini batches, C: channels, W: width, H: height
        B, C, H, W = x.shape
        proj_query = self.query_conv(x).view(B, -1, W * H).permute(0, 2, 1)  # B X CX(N)
        proj_key = self.key_conv(x).view(B, -1, W * H)  # B X C x (*W*H)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
        attention = self.softmax(energy)  # BX (N) X (N)
        proj_value = self.value_conv(x).view(B, -1, W * H)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(B, C, H, W)

        out = self.gamma * out + x

        return out


class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=None):
        super(ConvLayer, self).__init__()
        if padding is None:
            padding = kernel_size // 2
        self.reflection_pad = nn.ZeroPad2d(padding)
        self.conv2d = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride))

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out


class AdaIn(nn.Module):
    def __init__(self):
        super(AdaIn, self).__init__()
        self.eps = 1e-5

    def forward(self, x, mean_style, std_style):
        B, C, H, W = x.shape

        feature = x.view(B, C, -1)

        std_feat = (torch.std(feature, dim=2) + self.eps).view(B, C, 1)
        mean_feat = torch.mean(feature, dim=2).view(B, C, 1)

        adain = std_style * (feature - mean_feat) / std_feat + mean_style

        adain = adain.view(B, C, H, W)
        return adain


# endregion

# region Non-Adaptive Residual Blocks

class ResidualBlockDown(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=None):
        super(ResidualBlockDown, self).__init__()

        # Right Side
        self.conv_r1 = ConvLayer(in_channels, out_channels, kernel_size, stride, padding)
        self.conv_r2 = ConvLayer(out_channels, out_channels, kernel_size, stride, padding)

        # Left Side
        self.conv_l = ConvLayer(in_channels, out_channels, 1, 1)

    def forward(self, x):
        residual = x

        # Right Side
        out = F.relu(x)
        out = self.conv_r1(out)
        out = F.relu(out)
        out = self.conv_r2(out)
        out = F.avg_pool2d(out, 2)

        # Left Side
        residual = self.conv_l(residual)
        residual = F.avg_pool2d(residual, 2)
        
        # Merge
        out = residual + out
        return out


class ResidualBlockUp(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, upsample=2):
        super(ResidualBlockUp, self).__init__()

        # General
        self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')

        # Right Side
        self.norm_r1 = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv_r1 = ConvLayer(in_channels, out_channels, kernel_size, stride)

        self.norm_r2 = nn.InstanceNorm2d(out_channels, affine=True)
        self.conv_r2 = ConvLayer(out_channels, out_channels, kernel_size, stride)

        # Left Side
        self.conv_l = ConvLayer(in_channels, out_channels, 1, 1)

    def forward(self, x):
        residual = x

        # Right Side
        out = self.norm_r1(x)
        out = F.relu(out)
        out = self.upsample(out)
        out = self.conv_r1(out)
        out = self.norm_r2(out)
        out = F.relu(out)
        out = self.conv_r2(out)

        # Left Side
        residual = self.upsample(residual)
        residual = self.conv_l(residual)

        # Merge
        out = residual + out
        return out


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.in1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.in2(out)

        out = out + residual
        return out


# endregion

# region Adaptive Residual Blocks


class AdaptiveResidualBlockUp(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, upsample=2):
        super(AdaptiveResidualBlockUp, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # General
        self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')

        # Right Side
        self.norm_r1 = AdaIn()
        self.conv_r1 = ConvLayer(in_channels, out_channels, kernel_size, stride)

        self.norm_r2 = AdaIn()
        self.conv_r2 = ConvLayer(out_channels, out_channels, kernel_size, stride)

        # Left Side
        self.conv_l = ConvLayer(in_channels, out_channels, 1, 1)

    def forward(self, x, mean1, std1, mean2, std2):
        residual = x

        # Right Side
        out = self.norm_r1(x, mean1, std1)
        out = F.relu(out)
        out = self.upsample(out)
        out = self.conv_r1(out)
        out = self.norm_r2(out, mean2, std2)
        out = F.relu(out)
        out = self.conv_r2(out)

        # Left Side
        residual = self.upsample(residual)
        residual = self.conv_l(residual)

        # Merge
        out = residual + out
        return out


class AdaptiveResidualBlock(nn.Module):
    def __init__(self, channels):
        super(AdaptiveResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = AdaIn()
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = AdaIn()

    def forward(self, x, mean1, std1, mean2, std2):
        residual = x

        out = self.conv1(x)
        out = self.in1(out, mean1, std1)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.in2(out, mean1, std1)

        out = out + residual
        return out

# endregion


In [9]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        m.weight.data.normal_(0.0, 0.02)
    if classname.find('Linear') != -1:
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.xavier_uniform_(m.weight.data)
        m.bias.data.fill_(0)
    elif classname.find('InstanceNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

from collections import OrderedDict

class Embedder(nn.Module):
    """
    The Embedder network attempts to generate a vector that encodes the personal characteristics of an individual given
    a head-shot and the matching landmarks.
    """
    def __init__(self, gpu=None):
        super(Embedder, self).__init__()
        
        
        self.conv11 = ResidualBlockDown(2, 2)
        self.conv12 = ResidualBlockDown(2, 2)
        self.conv1 = ResidualBlockDown(2, 64)
        self.conv2 = ResidualBlockDown(64, 128)
        self.conv3 = ResidualBlockDown(128, 256)
        self.att = SelfAttention(256)
        self.conv4 = ResidualBlockDown(256, 512)
        self.conv5 = ResidualBlockDown(512, 512)
        self.conv6 = ResidualBlockDown(512, 512)

        self.pooling = nn.AdaptiveMaxPool2d((1, 1))

#         self.apply(weights_init)
        self.gpu = gpu
        if gpu is not None:
            self.cuda(gpu)

    def forward(self, x, y):
        print("Network", x.shape)
        # assert x.dim() == 3 and x.shape[1] == 1, "Both x and y must be tensors with shape [BxK, 1, W, H]."
        assert x.shape == y.shape, "Both x and y must be tensors with shape [BxK, 1, W, H]."
        if self.gpu is not None:
            x = x.cuda(self.gpu)
            y = y.cuda(self.gpu)

        # Concatenate x & y
        out = torch.cat((x, y), dim=1)  # [BxK, 2, 256, 256]
        # Encode
        out = self.conv12(self.conv11(out))
        
        out = (self.conv1(out))  # [BxK, 64, 128, 128]
        out = (self.conv2(out))  # [BxK, 128, 64, 64]
        out = (self.conv3(out))  # [BxK, 256, 32, 32]
        out = self.att(out)
        out = (self.conv4(out))  # [BxK, 512, 16, 16]
        out = (self.conv5(out))  # [BxK, 512, 8, 8]
        out = (self.conv6(out))  # [BxK, 512, 4, 4]

        # Vectorize
        out = F.relu(self.pooling(out).view(-1, E_VECTOR_LENGTH))

        return out


class Generator(nn.Module):
    ADAIN_LAYERS = OrderedDict([
        ('res1', (512, 512)),
        ('res2', (512, 512)),
        ('res3', (512, 512)),
        ('res4', (512, 512)),
        ('res5', (512, 512)),
        ('deconv6', (512, 512)),
        ('deconv5', (512, 512)),
        ('deconv4', (512, 256)),
        ('deconv3', (256, 128)),
        ('deconv2', (128, 64)),
        ('deconv1', (64, 3))
    ])

    def __init__(self, gpu=None):
        super(Generator, self).__init__()

        # projection layer
        self.PSI_PORTIONS, self.psi_length = self.define_psi_slices()
        self.projection = nn.Parameter(torch.rand(self.psi_length, E_VECTOR_LENGTH).normal_(0.0, 0.02))

        # encoding layers
        self.conv11 = ResidualBlockDown(1, 1)
        self.in11_e = nn.InstanceNorm2d(1, affine=True)
        
        self.conv12 = ResidualBlockDown(1, 1)
        self.in12_e = nn.InstanceNorm2d(1, affine=True)
        
        self.conv1 = ResidualBlockDown(1, 64)
        self.in1_e = nn.InstanceNorm2d(64, affine=True)

        self.conv2 = ResidualBlockDown(64, 128)
        self.in2_e = nn.InstanceNorm2d(128, affine=True)

        self.conv3 = ResidualBlockDown(128, 256)
        self.in3_e = nn.InstanceNorm2d(256, affine=True)

        self.att1 = SelfAttention(256)

        self.conv4 = ResidualBlockDown(256, 512)
        self.in4_e = nn.InstanceNorm2d(512, affine=True)

        self.conv5 = ResidualBlockDown(512, 512)
        self.in5_e = nn.InstanceNorm2d(512, affine=True)

        self.conv6 = ResidualBlockDown(512, 512)
        self.in6_e = nn.InstanceNorm2d(512, affine=True)

        # residual layers
        self.res1 = AdaptiveResidualBlock(512)
        self.res2 = AdaptiveResidualBlock(512)
        self.res3 = AdaptiveResidualBlock(512)
        self.res4 = AdaptiveResidualBlock(512)
        self.res5 = AdaptiveResidualBlock(512)

        # decoding layers
        self.deconv6 = AdaptiveResidualBlockUp(512, 512, upsample=2)
        self.in6_d = nn.InstanceNorm2d(512, affine=True)

        self.deconv5 = AdaptiveResidualBlockUp(512, 512, upsample=2)
        self.in5_d = nn.InstanceNorm2d(512, affine=True)

        self.deconv4 = AdaptiveResidualBlockUp(512, 256, upsample=2)
        self.in4_d = nn.InstanceNorm2d(256, affine=True)

        self.deconv3 = AdaptiveResidualBlockUp(256, 128, upsample=2)
        self.in3_d = nn.InstanceNorm2d(128, affine=True)

        self.att2 = SelfAttention(128)

        self.deconv2 = AdaptiveResidualBlockUp(128, 64, upsample=2)
        self.in2_d = nn.InstanceNorm2d(64, affine=True)

        self.deconv1 = AdaptiveResidualBlockUp(64, 1, upsample=2)
        self.in1_d = nn.InstanceNorm2d(1, affine=True)
        
        self.deconv12 = ResidualBlockUp(1, 1, upsample=2)
        self.in12_d = nn.InstanceNorm2d(1, affine=True)
        
        self.deconv11 = ResidualBlockUp(1, 1, upsample=2)
        self.in11_d = nn.InstanceNorm2d(1, affine=True)

        self.apply(weights_init)
        self.gpu = gpu
        if gpu is not None:
            self.cuda(gpu)

    def forward(self, y, e):
        if self.gpu is not None:
            e = e.cuda(self.gpu)
            y = y.cuda(self.gpu)

        out = y  # [B, 1, 256, 256]

        # Calculate psi_hat parameters
        P = self.projection.unsqueeze(0)
        P = P.expand(e.shape[0], P.shape[1], P.shape[2])
        psi_hat = torch.bmm(P, e.unsqueeze(2)).squeeze(2)

        # Encode
        out = self.in11_e(self.conv11(out))
        out = self.in12_e(self.conv12(out))
        out = self.in1_e(self.conv1(out))  # [B, 64, 128, 128]
        out = self.in2_e(self.conv2(out))  # [B, 128, 64, 64]
        out = self.in3_e(self.conv3(out))  # [B, 256, 32, 32]
        out = self.att1(out)
        out = self.in4_e(self.conv4(out))  # [B, 512, 16, 16]
        out = self.in5_e(self.conv5(out))  # [B, 512, 8, 8]
        out = self.in6_e(self.conv6(out))  # [B, 512, 4, 4]

        # Residual layers
        out = self.res1(out, *self.slice_psi(psi_hat, 'res1'))
        out = self.res2(out, *self.slice_psi(psi_hat, 'res2'))
        out = self.res3(out, *self.slice_psi(psi_hat, 'res3'))
        out = self.res4(out, *self.slice_psi(psi_hat, 'res4'))
        out = self.res5(out, *self.slice_psi(psi_hat, 'res5'))

        # Decode
        out = self.in6_d(self.deconv6(out, *self.slice_psi(psi_hat, 'deconv6')))  # [B, 512, 4, 4]
        out = self.in5_d(self.deconv5(out, *self.slice_psi(psi_hat, 'deconv5')))  # [B, 512, 16, 16]
        out = self.in4_d(self.deconv4(out, *self.slice_psi(psi_hat, 'deconv4')))  # [B, 256, 32, 32]
        out = self.in3_d(self.deconv3(out, *self.slice_psi(psi_hat, 'deconv3')))  # [B, 128, 64, 64]
        out = self.att2(out)
        out = self.in2_d(self.deconv2(out, *self.slice_psi(psi_hat, 'deconv2')))  # [B, 64, 128, 128]
        out = self.in1_d(self.deconv1(out, *self.slice_psi(psi_hat, 'deconv1')))  # [B, 3, 256, 256]
        out = self.in12_d(self.deconv12(out))
        out = self.in11_d(self.deconv11(out))

        out = torch.sigmoid(out)

        return out

    def slice_psi(self, psi, portion):
        idx0, idx1 = self.PSI_PORTIONS[portion]
        len1, len2 = self.ADAIN_LAYERS[portion]
        aux = psi[:, idx0:idx1].unsqueeze(-1)
        mean1, std1 = aux[:, 0:len1], aux[:, len1:2 * len1]
        mean2, std2 = aux[:, 2 * len1:2 * len1 + len2], aux[:, 2 * len1 + len2:]
        return mean1, std1, mean2, std2

    def define_psi_slices(self):
        out = {}
        d = self.ADAIN_LAYERS
        start_idx, end_idx = 0, 0
        for layer in d:
            end_idx = start_idx + d[layer][0] * 2 + d[layer][1] * 2
            out[layer] = (start_idx, end_idx)
            start_idx = end_idx

        return out, end_idx


class Discriminator(nn.Module):
    def __init__(self, training_videos, gpu=None):
        super(Discriminator, self).__init__()

        self.conv1 = ResidualBlockDown(2, 64)
        self.conv2 = ResidualBlockDown(64, 128)
        self.conv3 = ResidualBlockDown(128, 256)
        self.att = SelfAttention(256)
        self.conv4 = ResidualBlockDown(256, 512)
        self.conv5 = ResidualBlockDown(512, 512)
        self.conv6 = ResidualBlockDown(512, 512)
        self.res_block = ResidualBlock(512)

        self.pooling = nn.AdaptiveMaxPool2d((1, 1))

        self.W = nn.Parameter(torch.rand(512, training_videos).normal_(0.0, 0.02))
        self.w_0 = nn.Parameter(torch.rand(512, 1).normal_(0.0, 0.02))
        self.b = nn.Parameter(torch.rand(1).normal_(0.0, 0.02))

        self.apply(weights_init)
        self.gpu = gpu
        if gpu is not None:
            self.cuda(gpu)

    def forward(self, x, y, i):
        assert x.dim() == 4 and x.shape[1] == 1, "Both x and y must be tensors with shape [BxK, 3, W, H]."
        assert x.shape == y.shape, "Both x and y must be tensors with shape [BxK, 3, W, H]."

        if self.gpu is not None:
            x = x.cuda(self.gpu)
            y = y.cuda(self.gpu)

        # Concatenate x & y
        out = torch.cat((x, y), dim=1)  # [B, 6, 256, 256]

        # Encode
        out_0 = (self.conv1(out))  # [B, 64, 128, 128]
        out_1 = (self.conv2(out_0))  # [B, 128, 64, 64]
        out_2 = (self.conv3(out_1))  # [B, 256, 32, 32]
        out_3 = self.att(out_2)
        out_4 = (self.conv4(out_3))  # [B, 512, 16, 16]
        out_5 = (self.conv5(out_4))  # [B, 512, 8, 8]
        out_6 = (self.conv6(out_5))  # [B, 512, 4, 4]
        out_7 = (self.res_block(out_6))

        # Vectorize
        out = F.relu(self.pooling(out_7)).view(-1, 512, 1)  # [B, 512, 1]

        # Calculate Realism Score
        _out = out.transpose(1, 2)
        _W_i = (self.W[:, i].unsqueeze(-1)).transpose(0, 1)
        out = torch.bmm(_out, _W_i + self.w_0) + self.b
        out = torch.sigmoid(out)

        out = out.reshape(x.shape[0])

        return out, [out_0, out_1, out_2, out_3, out_4, out_5, out_6, out_7]


In [5]:
import numpy as np

In [6]:
sample_data = torch.from_numpy(np.random.randint(0, 256, size=(1, 11, 2, 1, 1024, 1024)).astype(np.long))
t = sample_data[:, -1, ...]
data = sample_data[:, :-1, ...]
dims = data.shape
print("Data {}, T {} ".format(data.shape, t.shape))
# Calculate average encoding vector for data
e_in = data.reshape(dims[0] * dims[1], dims[2], dims[3], dims[4], dims[5])  # [BxK,2,  C, W, H]
print("EIN ", e_in.shape)
x, y = e_in[:, 0, ...], e_in[:, 1, ...]
E = Embedder("cuda:0")
e_vectors = E(x.float(), y.float()).reshape(dims[0], dims[1], -1)  # B, K, len(e)
e_hat = e_vectors.mean(dim=1)

Data torch.Size([1, 10, 2, 1, 1024, 1024]), T torch.Size([1, 2, 1, 1024, 1024]) 
EIN  torch.Size([10, 2, 1, 1024, 1024])
Network torch.Size([10, 1, 1024, 1024])


In [7]:
print(e_hat.shape)

torch.Size([1, 512])


In [8]:
torch.cuda.is_available()

True

In [10]:
G = Generator("cuda:0")
x_t, y_t = data[:, 0, ...], data[:, 1, ...]
x_hat = G(y_t, e_hat)

TypeError: relu(): argument 'input' (position 1) must be Tensor, not ZeroPad2d