In [None]:
def fade_in(alpha, a, b):
    '''
    Linear blend of two inputs.

    Returns ``alpha * a + (1 - alpha) * b``: alpha = 1 yields a, alpha = 0 yields b.
    '''
    weight_b = 1 - alpha
    return alpha * a + weight_b * b

def wasserstein_loss(y_true, y_pred):
    '''
    Wasserstein loss (as used in WGANs).

    Returns the negated mean of the element-wise product ``y_true * y_pred``.
    Both arguments must support ``*`` and ``.mean()`` (e.g. torch tensors or numpy arrays).
    '''
    product = y_true * y_pred
    return -product.mean()

def pixel_norm(x, epsilon = 1e-8):
    '''
    Normalize x by the root-mean-square of its last dimension.

    Each vector along dim -1 is divided by ``sqrt(mean(x**2) + epsilon)``;
    epsilon keeps the division finite for all-zero vectors.
    '''
    mean_square = torch.mean(x**2, dim=-1, keepdim=True)
    return x / torch.sqrt(mean_square + epsilon)

def minibatch_std(tensor_input : torch.Tensor, epsilon = 1e-8, group_size = 4):
    '''
    Minibatch Standard Deviation (ref : <https://arxiv.org/pdf/1710.10196.pdf>)

    Appends one feature map holding, for each group of samples, the mean
    (over c, h, w) of the per-element standard deviation across that group.

    Args:
        tensor_input: tensor of shape (n, c, h, w).
        epsilon: added to the variance before the square root.
        group_size: maximum number of samples per group; clamped to n.
            After clamping, n must be divisible by it, otherwise the
            reshape below raises.

    Returns:
        float tensor of shape (n, c + 1, h, w): the input channels followed
        by the group-std channel (constant across h, w within a group).
    '''
    n, c, h, w = tensor_input.shape

    group_size = min(group_size, n)
    # Split the batch axis as (group_size, num_groups): sample i belongs to
    # group i % num_groups, and the statistics are taken over axis 0.
    x = torch.reshape(tensor_input, [group_size, -1, c, h, w]).float()
    num_groups = x.shape[1]

    # Population std of each element across the samples of its group.
    group_var = x.var(0, unbiased=False)
    group_std = torch.sqrt(group_var + epsilon)

    # One scalar per group: shape (num_groups, 1, 1, 1).
    avg_std = torch.mean(group_std, dim=[1, 2, 3], keepdim=True)
    # With 5 repeat factors, repeat() prepends an axis, giving
    # (group_size, num_groups, 1, h, w). The group axis must be repeated
    # once (factor 1), not num_groups times, or cat() below breaks when
    # num_groups > 1.
    avg_std = avg_std.repeat(group_size, 1, 1, h, w)

    # Append the std map as an extra channel and restore (n, c + 1, h, w).
    return torch.cat([x, avg_std], dim=2).reshape(n, c + 1, h, w)



class EqualizedConv2D(nn.Module):
    '''
    2-D convolution with an equalized learning rate.

    Weights are drawn from N(0, 1) and multiplied at run time by
    ``sqrt(gain / fan_in)`` with ``fan_in = kernel * kernel * in_channels``.
    Convolution uses ``padding='same'`` (stride 1).

    Weights and bias are created lazily on the first forward call, using that
    input's channel count, dtype and device. The parameters do not exist before
    the first forward pass, so an optimizer built earlier will not see them.

    Args:
        out_channels: number of output channels.
        kernel: square kernel size.
        gain: numerator of the run-time weight scale.
        **kwargs: accepted and ignored.
    '''
    def __init__(self, out_channels=1, kernel=3, gain=2, **kwargs) -> None:
        super().__init__()
        self.initialized = False
        self.kernel = kernel
        self.out_channels = out_channels
        self.gain = gain

        # Placeholders; filled in by build() on the first forward call.
        self.register_parameter('weights', None)
        self.register_parameter('bias', None)

    def build(self, input : torch.Tensor):
        '''Create weights/bias sized for ``input`` (expects (n, c, h, w)).'''
        self.in_channels = input.shape[1]

        self.weights = nn.Parameter(
            input.new_empty(self.out_channels, self.in_channels, self.kernel, self.kernel).normal_()
        )
        self.bias = nn.Parameter(input.new_zeros(self.out_channels))

        fan_in = self.kernel * self.kernel * self.in_channels
        self.scale = math.sqrt(self.gain / fan_in)

    def forward(self, X):
        if not self.initialized:
            self.build(X)
            self.initialized = True
        # The scale is applied here rather than at init; that is what makes
        # the learning rate "equalized".
        return F.conv2d(X, self.scale * self.weights, self.bias, padding='same')

class EqualizedDense(nn.Module):
    '''
    Fully-connected layer with an equalized learning rate.

    Computes ``linear(X, scale * weights, bias) * lr_multiplier`` with
    ``scale = sqrt(gain / dim_in)``. ``dim_in`` is taken from the last axis
    of the first input seen.

    Weights start as N(0, 1 / lr_multiplier) and the output is multiplied by
    lr_multiplier, so the effective initial weight std is ``scale``.

    Parameters are created lazily on the first forward call, using the input's
    dtype and device. They do not exist before that call.

    Args:
        dim_out: output feature size.
        gain: numerator of the run-time weight scale.
        lr_multiplier: learning-rate multiplier described above.
        **kwargs: accepted and ignored.
    '''
    def __init__(self, dim_out, gain=2, lr_multiplier=1.0, **kwargs) -> None:
        super().__init__()
        self.initialized = False
        self.dim_out = dim_out
        self.gain = gain
        self.lr_multiplier = lr_multiplier

        # Placeholders; filled in by build() on the first forward call.
        self.register_parameter('weights', None)
        self.register_parameter('bias', None)

    def build(self, input : torch.Tensor):
        '''Create weights/bias sized for the last axis of ``input``.'''
        self.dim_in = input.shape[-1]

        self.weights = nn.Parameter(
            input.new_empty(self.dim_out, self.dim_in).normal_(0.0, 1.0 / self.lr_multiplier)
        )
        self.bias = nn.Parameter(input.new_zeros(self.dim_out))

        self.scale = math.sqrt(self.gain / self.dim_in)

    def forward(self, X):
        if not self.initialized:
            self.build(X)
            self.initialized = True

        return F.linear(X, self.scale * self.weights, self.bias) * self.lr_multiplier
    
class AddNoise(nn.Module):
    '''
    Add learned, per-channel scaled noise: ``x + bias * noise``.

    ``forward`` takes a pair ``(x, noise)``. ``bias`` has shape (1, c, 1, 1),
    where c = x.shape[1], and is drawn from N(0, 1) on the first forward call
    using x's dtype and device. It does not exist before that call.
    NOTE(review): noise is presumably (n, 1 or c, h, w); it only needs to
    broadcast against ``bias`` and ``x``.
    '''
    def __init__(self) -> None:
        super().__init__()
        self.initialized = False
        self.register_parameter('bias', None)

    def build(self, input):
        '''Create the per-channel scale from the (x, noise) pair.'''
        # input is the (x, noise) pair passed to forward; size from x, which
        # is expected to be (n, c, h, w).
        x = input[0]
        _, c, _, _ = x.shape

        self.bias = nn.Parameter(x.new_empty(1, c, 1, 1).normal_())

    def forward(self, X):
        if not self.initialized:
            self.build(X)
            self.initialized = True

        x, noise = X
        return x + self.bias * noise

class AdaIN(nn.Module):
    '''
    Style modulation of feature maps by a style vector.

    ``forward`` takes a pair ``(x, w)`` and returns ``ys * x + yb``. Here
    ys and yb are per-channel scale and bias, each produced from w by its own
    EqualizedDense layer and reshaped to (-1, c, 1, 1), where c = x.shape[1].

    This layer does not normalize x itself.
    NOTE(review): if instance normalization is intended, it must be applied
    to x by the caller; confirm.

    The two dense layers are created on the first forward call.

    Args:
        gain: gain passed to both EqualizedDense layers (default 1).
        **kwargs: accepted and ignored.
    '''
    def __init__(self, gain=1, **kwargs) -> None:
        super().__init__()
        self.gain = gain
        self.initialized = False

        # Placeholders; replaced by build() on the first forward call.
        self.register_module("dense_1", None)
        self.register_module("dense_2", None)

    def build(self, input):
        '''Create the scale/bias projections sized to x's channel count.'''
        x, w = input

        self.w_channels = w.shape[1]
        self.x_channels = x.shape[1]

        # Honour the configured gain; it was previously stored but ignored.
        self.dense_1 = EqualizedDense(self.x_channels, gain=self.gain)
        self.dense_2 = EqualizedDense(self.x_channels, gain=self.gain)

    def forward(self, X):
        if not self.initialized:
            self.build(X)
            self.initialized = True
        x, w = X
        # NOTE(review): assumes w is (batch, features), so the reshape yields
        # one (c, 1, 1) scale/bias per sample; confirm against callers.
        ys = self.dense_1(w).reshape([-1, self.x_channels, 1, 1])
        yb = self.dense_2(w).reshape([-1, self.x_channels, 1, 1])

        return ys * x + yb

In [None]:
class Mapping(nn.Module):
    '''
    StyleGAN-style mapping network: a stack of EqualizedDense(512) +
    LeakyReLU layers whose output is tiled once per synthesis stage.

    Args:
        num_stages: number of copies of the mapped vector to emit.
        input_shape: size of the input latent. It is stored only; each
            EqualizedDense infers its input width from the first batch.
        num_layers: number of dense + activation pairs (default 8).

    forward(X): for X of shape (batch, input_shape), returns
    (batch, num_stages, 512).
    NOTE(review): StyleGAN pixel-normalizes z before the mapping layers and
    uses LeakyReLU(0.2); neither is done here. Confirm this is intended.
    '''
    def __init__(self, num_stages, input_shape=512, num_layers=8) -> None:
        super().__init__()

        self.num_stages = num_stages
        self.input_shape = input_shape

        layers = []
        for _ in range(num_layers):
            # EqualizedDense takes only the output width (its input width is
            # inferred lazily). The previous call passed lr_multiplier both
            # positionally and by keyword and raised TypeError. gain is left
            # at EqualizedDense's default.
            layers.append(EqualizedDense(512, lr_multiplier=0.01))
            layers.append(nn.LeakyReLU())

        self.layers = nn.Sequential(*layers)

    def forward(self, X):
        x = self.layers(X)
        # (batch, 512) -> (batch, num_stages, 512): one copy per stage.
        return torch.tile(x.unsqueeze(1), (1, self.num_stages, 1))


In [None]:
A = AdaIN()

# AdaIN.forward unpacks an (x, w) pair. Passing a single tensor would unpack
# it along the batch axis instead. x: (n, c, h, w) features; w: (n, d) style
# vectors.
x_demo = torch.ones(20).view(2, 1, 2, 5)
w_demo = torch.ones(2, 4)
A((x_demo, w_demo))


False
False


tensor([[[[0.9273, 0.9273, 0.9273, 0.9273, 0.9273],
          [0.9273, 0.9273, 0.9273, 0.9273, 0.9273]],

         [[1.8117, 1.8117, 1.8117, 1.8117, 1.8117],
          [1.8117, 1.8117, 1.8117, 1.8117, 1.8117]]],


        [[[0.9273, 0.9273, 0.9273, 0.9273, 0.9273],
          [0.9273, 0.9273, 0.9273, 0.9273, 0.9273]],

         [[1.8117, 1.8117, 1.8117, 1.8117, 1.8117],
          [1.8117, 1.8117, 1.8117, 1.8117, 1.8117]]]], grad_fn=<AddBackward0>)

In [None]:
def process_filepath(x: str):
    '''
    Rewrite a thumbnail path into the local image path.

    Every ".png" occurrence becomes ".jpg", and the second "/"-separated
    component is removed, e.g. "a/b/c.png" -> "a/c.jpg".
    Raises IndexError when the path has fewer than two components.
    '''
    parts = x.replace(".png", ".jpg").split("/")
    del parts[1]
    return "/".join(parts)

def load_information(x, json_dir="./ffhq-features-dataset/json/"):
    '''
    Report whether the face-attribute JSON for a dataset row indicates glasses.

    The JSON file name is the base name of ``x['image']['file_path']`` up to
    its first '.', plus ".json", looked up in ``json_dir``.

    Args:
        x: a row (or mapping) providing ``x['image']['file_path']``.
        json_dir: directory containing the per-image feature JSON files.

    Returns:
        False when the loaded JSON is empty. Otherwise True unless the first
        entry's ``faceAttributes.glasses`` equals 'NoGlasses'.
    '''
    stem = x['image']['file_path'].split("/")[-1].split(".")[0]
    json_path = os.path.join(json_dir, stem + ".json")
    # Context manager so the handle is closed; this runs once per dataset row.
    with open(json_path, "r", encoding="utf-8") as f:
        props = json.load(f)
    if not props:
        return False
    return props[0]['faceAttributes']['glasses'] != 'NoGlasses'
    

# Load the dataset index (one row per top-level JSON entry) and flag each row
# with load_information. progress_apply comes from tqdm's pandas integration,
# presumably registered in an earlier cell.
data = pd.read_json(os.path.join("./", "ffhq-dataset-v2.json"), orient="index")
data['info'] = data.progress_apply(load_information, axis=1)

# Keep only the flagged rows and derive each one's local image path from its
# thumbnail record.
glasses_data = data[data['info']].copy()
glasses_data['image_path'] = glasses_data.apply(
    lambda row: process_filepath(row['thumbnail']['file_path']),
    axis=1,
)

# Drop the nested metadata columns and the temporary flag column.
glasses_data = glasses_data.drop(
    columns=['image', 'thumbnail', 'in_the_wild', 'metadata', 'info'],
)

In [None]:
# Ten random image file names (drawn with replacement). randint() includes its
# upper bound and could index one past the end; randrange(n) draws from [0, n).
images = [
    glasses_data['image_path'].iloc[random.randrange(glasses_data.shape[0])].split("/")[-1]
    for _ in range(10)
]

In [None]:
# Rebuild glasses_data from the flagged rows of `data`. This yields the raw
# columns again: no image_path column and none of the earlier column drops.
glasses_data = data.loc[data['info']].copy()

In [None]:
# Draw 10 random rows (without replacement) from the flagged dataset rows.
images = glasses_data.sample(n=10)

In [None]:
import requests
# Print the download URL of every sampled row. Iterating the records directly
# avoids the hard-coded count and positional indexing.
for record in tqdm(images['image'], total=len(images)):
    print(record['file_url'])

100%|████████████████████████████████████████| 10/10 [00:00<00:00, 30863.16it/s]

https://drive.google.com/uc?id=152iXH9YjbbEmttREAcSVb8-TZ7PL8glS
https://drive.google.com/uc?id=1Bc_8qQh_s5wRxLX7h3p66DzLOnZM0ZUc
https://drive.google.com/uc?id=1sHNJjposMdiVedHl_2gZJCfe35yFt0D8
https://drive.google.com/uc?id=1xVnCNWsmFw70GyWvL_ggdjPWFKoViH7X
https://drive.google.com/uc?id=16Az_0hEhgDqt883pVwmiqHEapGxPv_bx
https://drive.google.com/uc?id=1z-J-1RysweG1aTl2t00SXVq5iUdvtP5G
https://drive.google.com/uc?id=1jlfsR81VKF5p2f19H8UCcWCR5DVa1Asu
https://drive.google.com/uc?id=1SZ_pt5eFY1p7KdkRBzPEhzXm2Dt1Zzo4
https://drive.google.com/uc?id=1YvROueVSJthYxsyVTL8s0phZ2c_WJ9TV
https://drive.google.com/uc?id=1BsrKrRglMmwCyNPIB3djYO27HUjqp-SQ



