In [1]:
from torch.utils.data import DataLoader
import torch.nn as nn
import enum
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
import math
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("TkAgg")
import seaborn as sns
from torch.autograd import grad
import scipy.stats as stats
from utils import *
from sklearn import datasets

In [2]:
class FlowModules(FlowModule):
    """A container for a succession of flow modules"""
    def __init__(self, *flows: FlowModule):
        super().__init__()
        self.flows = nn.ModuleList(flows)

    def apply(self, modules_iter, caller, x):
        m, _ = x.shape
        logdet = torch.zeros(m, device=x.device)
        zs = [x]
        for module in modules_iter:
            x, _logdet = caller(module, x)
            zs.append(x)
            logdet += _logdet
        return zs, logdet            

    def modulenames(self, backward=False):
        return [f"L{ix} {module.__class__.__name__}" for ix, module in enumerate(reversed(self.flows) if backward else self.flows)]

    def f(self, x):
        zs, logdet = self.apply(self.flows, lambda m, x: m.f(x), x)
        return zs, logdet

    def invf(self, y):
        zs, logdet = self.apply(reversed(self.flows), lambda m, y: m.invf(y), y)
        return zs, logdet


class FlowModel(FlowModules):
    """Flow model = prior + flow modules"""
    def __init__(self, prior, *flows: FlowModule):
        super().__init__(*flows)
        self.prior = prior

    def invf(self, x):
        # Just computes the prior
        zs, logdet = super().invf(x)

        logprob = self.prior.log_prob(zs[-1]) #z_0
        print(logprob.shape)
        return logprob, zs, logdet

In [3]:
class AffineFlow(FlowModule):
    def __init__(self, in_features):
        #in_features : la dimension des données
        #On est en batchs
        super().__init__()
        self.s = nn.Parameter(torch.randn(in_features, requires_grad=True))#Broadcast x * s-> (batch, in_features)
        self.t = nn.Parameter(torch.randn(in_features, requires_grad=True))#Broadcast x + s-> (batch, in_features)
        
    def f(self, x):
        y = x * torch.exp(self.s) + self.t
        logdet = torch.sum(self.s , dim = -1) #exp est toujours positive et log(exp(x)) = x
        return y, logdet
    
    def invf(self, y):
        x = (y - self.t) * torch.exp(- self.s) #f^{-1}(y)
        logdet = - torch.sum(self.s, dim = -1)
        assert self.f(x)[0].allclose(y, atol = 1e-02), 'f^{-1}(y) is not equal to x in AffineFlow'
        return x, logdet
    
    def check(self, x):
        return self.invf(self.f(x)[0])[0].allclose(x)

In [4]:
#On hérite de AffineFlow afin d'avoir le checker et réduire le code
#Le broadcast sera toujours fait si (dim1, dim2) (op) (dim2) = (dim1, dim2) (op) (1, dim2)
class ActNorm(AffineFlow):
    def __init__(self, in_features):
        super().__init__(in_features)
        self.first_init = False
    
    def f(self, x):
        if not self.first_init:
            #z = (x - mean) / std 
            #torch.exp(self.s) = 1 / std
            #self.t = - mean / std
            #exp(- log(std)) = 1 / std pour plus de stabilité
            #Nous obtenons bien z = x / std - mean / std = x - mean / std pour le premier batch
            self.s.data.copy_(- torch.log(x.std(dim = 0) + 1e-8))
            self.t.data.copy_(- x.mean(dim = 0) * torch.exp(self.s)) 
            self.first_init = True                
        return super().f(x)
    
    def invf(self, y):
        if not self.first_init:
            #x = (y + mean / std) * std = y * std + mean
            self.s.data.copy_(- torch.log(y.std(dim = 0) + 1e-8))
            self.t.data.copy_(- y.mean(dim = 0) * torch.exp(self.s)) 
            self.first_init = True
        return super().invf(y)      

In [5]:
class AffineCouplingLayer(FlowModule):
    def __init__(self, in_features, hidden_dim = 64):
        #Soit x la valeur actuelle de dimension 2×l
        #2xl = in_features
        #l = in_features/2
        super().__init__()
        assert in_features%2 == 0, 'Must be divisible by 2'
        self.s = MLP(in_features // 2, in_features // 2, hidden_dim)
        self.t = MLP(in_features // 2, in_features // 2, hidden_dim)
    
    def f(self, x):
        #x : (batch, in_features)
        assert x.shape[1]%2 ==0, 'Must be divisible by 2' 
        x_1, x_2 = torch.chunk(x, 2, dim = 1)
        s = self.s(x_1)
        t = self.t(x_1)
        y_1 = x_1
        y_2 = x_2 * torch.exp(s) + t
        y = torch.cat((y_1, y_2), dim = 1)
        logdet = torch.sum(s , dim = 1)
        return y, logdet
    
    def invf(self, y):
        assert y.shape[1]%2 ==0, 'Must be divisible by 2'
        y_1, y_2 = torch.chunk(y, 2, dim = 1)
        x_1 = y_1
        s = self.s(x_1)
        t = self.t(x_1)
        x_2 = (y_2 - t) * torch.exp(- s)
        x = torch.cat((x_1, x_2), dim = 1)
        logdet =  - torch.sum(s, dim = 1)
        assert self.f(x)[0].allclose(y, atol = 1e-05), 'f^{-1}(y) is not equal to x in AffineCouplingLayer'
        return x, logdet
        

In [6]:
class Convolution1x1(FlowModule):
    def __init__(self, in_features):
        super().__init__()
        self.in_features = in_features
        W = torch.nn.init.orthogonal_(torch.randn(in_features, in_features)) #Matrice carrée
        A_LU, pivots = W.lu()
        P, W_L, W_U = torch.lu_unpack(A_LU, pivots)
        self.P = P #Aucun changement
        #Matrice triangulaires avec des diagonales à zéro pour L_prime et U_prime
        self.L_prime = nn.Parameter(torch.tril(W_L, diagonal = -1)) 
        self.U_prime = nn.Parameter(torch.triu(W_U, diagonal = 1)) 
        self.S = nn.Parameter(torch.diag(torch.diag(W_U)))#Prendre la diagonal (in_features, in_features)
        
    def f(self, x):
        W = (self.P @ (self.L_prime + torch.eye(self.in_features)) @ (self.U_prime + self.S))
        y = x @ W
        logdet = torch.sum(torch.log(torch.abs(self.S))) #sum(log(s)) lilianWen
        return y, logdet
    
    def invf(self, y):
        W = (self.P @ (self.L_prime + torch.eye(self.in_features)) @ (self.U_prime + self.S))
        x = y @ torch.inverse(W)
        logdet = -torch.sum(torch.log(torch.abs(self.S)))
        assert self.f(x)[0].allclose(y, atol = 1e-05), 'f^{-1}(y) is not equal to x in Convolution1x1'
        return x, logdet

In [7]:
modules = []
L = 10
dim = 2
hidden_dim = 64
for _ in range(L):
    modules.append(ActNorm(dim))
    modules.append(AffineCouplingLayer(dim, hidden_dim = hidden_dim))
    modules.append(Convolution1x1(dim))

In [8]:
dim = 2
batchsize = 64
lr = 0.001
nb_epochs = 100000
mu = torch.zeros(dim)
s = torch.ones(dim)
prior = torch.distributions.independent.Independent(torch.distributions.normal.Normal(mu, s),1)
model = FlowModel(prior, *modules)
optimizer = torch.optim.Adam(model.parameters(), lr = lr)

In [9]:
for k in range(nb_epochs):
        x, _ = datasets.make_circles(n_samples = batchsize, factor=0.5, noise=0.05, random_state=0)
        x = torch.FloatTensor(x)
        z_0_logprob, zs, logdet = model.invf(x)
        logprob = z_0_logprob + logdet
        negative_log_likelihood = - torch.mean(logprob)
        model.zero_grad()
        negative_log_likelihood.backward()
        optimizer.step()
        print(negative_log_likelihood.item())

AssertionError: f^{-1}(y) is not equal to x in AffineFlow