In [None]:
import itertools
import numpy as np
import os
import seaborn as sns
from tqdm import tqdm
from dataclasses import asdict, dataclass, field
import vsketch
import shapely.geometry as sg
from shapely.geometry import box, MultiLineString, Point, MultiPoint, Polygon, MultiPolygon, LineString
import shapely.affinity as sa
import shapely.ops as so
import matplotlib.pyplot as plt
import pandas as pd

import vpype_cli
from typing import List, Generic
from genpen import genpen as gp, utils as utils
from scipy import stats as ss
import geopandas
from shapely.errors import TopologicalError
import functools
%load_ext autoreload
%autoreload 2
import vpype
from skimage import io
from pathlib import Path

import bezier

from sklearn.preprocessing import minmax_scale
from skimage import feature
from genpen.utils import Paper

from scipy import spatial, stats
from scipy.ndimage import gaussian_filter
from scipy.integrate import odeint

import matplotlib.pyplot as plt
import numpy as np
import shapely.geometry as sg
from rasterio import features
import shapely.geometry as sg
from shapely.geometry import box, MultiLineString, Point, MultiPoint, Polygon, MultiPolygon, LineString
import shapely.affinity as sa
import shapely.ops as so
from scipy import stats as ss
from tqdm.auto import tqdm


from scipy import spatial, stats
from scipy.ndimage import gaussian_filter
from scipy.integrate import odeint

import torch
import torch.optim as optim
from torchvision import transforms, models
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
import pytorch_lightning as pl
from torch import tensor as Tensor
from pytorch_lightning.loggers import WandbLogger
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [None]:
# Page setup for plotter output: an 11x14" sheet and its drawable area.
paper_size = '11x14 inches'
border: float = 30  # margin handed to Paper.get_drawbox (units per genpen's Paper — confirm)

paper = Paper(paper_size)
drawbox = paper.get_drawbox(border)

In [None]:
def coordset_to_lines(coordset):
    """Turn every ``(x0, y0, x1, y1)`` row of ``coordset`` into a two-point LineString.

    Returns a list with one LineString per row, in the same order.
    """
    segments = []
    for row in coordset:
        segments.append(coords_to_line(row))
    return segments

def coords_to_line(coords):
    """Build a straight LineString from a flat 4-sequence ``(x0, y0, x1, y1)``."""
    start_x, start_y, end_x, end_y = coords
    start = (start_x, start_y)
    end = (end_x, end_y)
    return LineString([start, end])

def coordset_to_img(coordset, out_shape):
    """Rasterize the segments described by ``coordset`` into an array of shape ``out_shape``.

    Each row of ``coordset`` is ``(x0, y0, x1, y1)``; rasterization is delegated
    to ``rasterio.features.rasterize`` with its default settings.
    """
    return features.rasterize(coordset_to_lines(coordset), out_shape=out_shape)

def generate_random_line_data(n_samples, n_lines, out_shape, xgen=None, ygen=None):
    """Generate random line-segment images together with their endpoint coordinates.

    Parameters
    ----------
    n_samples : int
        Number of images to generate (must be >= 1 for ``np.stack``).
    n_lines : int
        Number of random segments drawn into each image (must be >= 1).
    out_shape : tuple of int
        ``(rows, cols)`` shape passed to ``features.rasterize``.
    xgen, ygen : callable, optional
        Zero-argument callables that each return one x / y coordinate. When
        omitted, coordinates are drawn uniformly over the image extent: x over
        ``[0, cols)`` and y over ``[0, rows)`` (assumes rasterize's default
        identity transform maps x to columns and y to rows).

    Returns
    -------
    imgs : ndarray, shape (n_samples, *out_shape)
    coordsets : ndarray, shape (n_samples, n_lines, 4)
        Each row is ``(x0, y0, x1, y1)``.
    """
    n_rows, n_cols = out_shape[0], out_shape[1]

    def _uniform_x():
        # Default x sampler: anywhere across the image width.
        return np.random.uniform(0, n_cols)

    def _uniform_y():
        # Default y sampler: anywhere across the image height.
        return np.random.uniform(0, n_rows)

    # These samplers used to be read from module globals that this notebook
    # never defines; they are now explicit, optional inputs.
    xgen = _uniform_x if xgen is None else xgen
    ygen = _uniform_y if ygen is None else ygen

    imgs = []
    coordsets = []
    for _sample in tqdm(range(n_samples)):
        # One row per segment, sampled in the order x0, y0, x1, y1.
        coordset = np.array([[xgen(), ygen(), xgen(), ygen()] for _line in range(n_lines)])
        imgs.append(coordset_to_img(coordset, out_shape))
        coordsets.append(coordset)
    return np.stack(imgs, axis=0), np.stack(coordsets, axis=0)

In [None]:
def ode(y, t, a, b, c, d):
    """Right-hand side of the 2-D system integrated with ``scipy.integrate.odeint``.

    The state ``y`` is ``(v, u)``. ``t`` is part of odeint's required signature
    but is not used (the system is autonomous).

        dv/dt = sin(b * u) + c * v
        du/dt = cos(a * v * u) + d * u

    Returns ``[dv/dt, du/dt]`` as a list.
    """
    v, u = y
    return [
        np.sin(b * u) + v * c,
        np.cos(a * v * u) + u * d,
    ]

In [None]:
# Fixed ODE coefficients used for quick single-image experiments.
a, b = 0.1, 0.95
c, d = -0.02, -0.02

In [None]:
@dataclass
class OdeParams:
    """Random sampler for the ``(a, b, c, d)`` coefficients of ``ode``.

    Each coefficient is an independent frozen normal distribution stored as a
    plain class attribute (unannotated, so the dataclass itself has no fields).
    """
    a = ss.norm(loc=0.05, scale=0.1)
    b = ss.norm(loc=0.9, scale=0.1)
    c = ss.norm(loc=-0.025, scale=0.01)
    d = ss.norm(loc=-0.025, scale=0.01)

    def get_args(self):
        """Draw one value from each distribution, returned as the tuple ``(a, b, c, d)``."""
        return tuple(dist.rvs() for dist in (self.a, self.b, self.c, self.d))

In [None]:
def no_overlap_odeint(ode, pts, args, t, break_dist=0.2, min_len=0.9, verbose=False):
    """Integrate one trajectory per seed point, truncating each where it nears earlier ones.

    The first trajectory is kept whole. Every later trajectory is cut just before
    its first sample that lies closer than ``break_dist`` to any previously
    accepted trajectory. It is kept only if at least two samples survive.

    Parameters
    ----------
    ode : callable
        ``ode(y, t, *args)`` right-hand side passed to ``scipy.integrate.odeint``.
    pts : iterable of shapely Points, or a MultiPoint
        Seed positions; ``.x`` / ``.y`` are read from each.
    args : tuple
        Extra arguments forwarded to ``ode``.
    t : array-like
        Time samples at which each trajectory is evaluated.
    break_dist : float
        Minimum allowed distance from already-accepted trajectories.
    min_len : float
        Accepted lines with ``length <= min_len`` are dropped from the result.
    verbose : bool
        Wrap the seed iteration in a tqdm progress bar.

    Returns
    -------
    The result of ``gp.merge_LineStrings`` over the kept lines.
    """
    # Multi-part geometries are iterated via ``.geoms`` (required in shapely 2);
    # plain sequences of Points are used as-is.
    seeds = getattr(pts, 'geoms', pts)
    if verbose:
        seeds = tqdm(seeds)

    lines = []
    accepted = MultiLineString()
    for ii, pt in enumerate(seeds):
        sol = odeint(ode, [pt.x, pt.y], t, args=args)  # (len(t), 2) array of (x, y)
        if ii == 0:
            # Nothing has been drawn yet, so the first trajectory is kept whole.
            n_keep = len(sol)
        else:
            # Count leading samples that stay at least break_dist away from
            # every accepted trajectory; stop at the first that does not.
            n_keep = 0
            for xy in sol:
                if Point(xy).distance(accepted) < break_dist:
                    break
                n_keep += 1
        if n_keep > 1:
            ls = LineString(sol[:n_keep])
            accepted = gp.merge_LineStrings([accepted, ls])
            lines.append(ls)

    return gp.merge_LineStrings([l for l in lines if l.length > min_len])

In [None]:
def odeint_to_ls(ode, pts, args, t, min_len=0.9, verbose=False):
    """Integrate one trajectory per seed point and return them without overlap pruning.

    Parameters
    ----------
    ode : callable
        ``ode(y, t, *args)`` right-hand side passed to ``scipy.integrate.odeint``.
    pts : iterable of shapely Points, or a MultiPoint
        Seed positions; ``.x`` / ``.y`` are read from each.
    args : tuple
        Extra arguments forwarded to ``ode``.
    t : array-like
        Time samples at which each trajectory is evaluated.
    min_len : float
        Lines with ``length <= min_len`` are dropped from the result.
    verbose : bool
        Wrap the seed iteration in a tqdm progress bar.

    Returns
    -------
    The result of ``gp.merge_LineStrings`` over the kept lines.
    """
    # Multi-part geometries are iterated via ``.geoms`` (required in shapely 2).
    seeds = getattr(pts, 'geoms', pts)
    if verbose:
        seeds = tqdm(seeds)

    lines = [LineString(odeint(ode, [pt.x, pt.y], t, args=args)) for pt in seeds]
    return gp.merge_LineStrings([l for l in lines if l.length > min_len])

In [None]:
# 256x256 preview setup: image box, spiral seed points and integration times.
out_shape = (256, 256)
border = 5
img_drawbox = box(border, border, out_shape[0] - border, out_shape[1] - border)

# Seeds lie on an outward spiral: 12 full turns, radius 0.8 -> 18.
n_lines = 300
thetas = np.linspace(0, np.pi * 24, n_lines)
radii = np.linspace(0.8, 18, n_lines)

pts = []
for theta, radius in zip(thetas, radii):
    x, y = np.cos(theta) * radius, np.sin(theta) * radius
    pts.append(Point(x, y))

t_max = 17.7
t = np.linspace(0, t_max, 61)

# Accumulators for generated line sets and their rasterized images.
linesets, imgs = [], []

In [None]:
# for ii in tqdm(range(20000)):
#     args = OdeParams().get_args()
#     lss = no_overlap_odeint(ode, pts, args, t, break_dist=0.5, min_len=0.9)
#     lss = gp.make_like(gp.merge_LineStrings(lss), img_drawbox)
#     img = features.rasterize(lss, out_shape=out_shape)
#     linesets.append(lss)
#     imgs.append(img)

In [None]:
# np.savez_compressed('/home/naka/data/ode_vae/test.npz', imgs)

In [None]:
# Draw one random (a, b, c, d) coefficient set from the OdeParams distributions.
args = OdeParams().get_args()

In [None]:
# Integrate the spiral seeds with the sampled coefficients, merge the
# trajectories and fit them to the image box.
lss = gp.make_like(
    gp.merge_LineStrings(
        no_overlap_odeint(ode, pts, args, t, min_len=0.9, verbose=False)
    ),
    img_drawbox,
)

In [None]:
# Thicken each line by 1 unit (cap_style=2 is shapely's flat cap, join_style=2
# its mitre join), combine the buffers, then rasterize only their outlines.
polys = gp.merge_Polygons(
    [line.buffer(1, cap_style=2, join_style=2) for line in lss]
)
img = features.rasterize(polys.boundary, out_shape=out_shape)

In [None]:
# Parameter-sweep setup: 128x128 images, spiral seeds and a (c, d) grid.
out_shape = (128, 128)
border = 5
img_drawbox = box(border, border, out_shape[0] - border, out_shape[1] - border)

n_lines = 200
thetas = np.linspace(0, np.pi * 23.1, n_lines)
radii = np.linspace(0.8, 18, n_lines)

pts = []
for theta, radius in zip(thetas, radii):
    x, y = np.cos(theta) * radius, np.sin(theta) * radius
    pts.append(Point(x, y))
pts = MultiPoint(pts)

t_max = 17.7
t = np.linspace(0, t_max, 61)

# a and b stay fixed; c and d are swept over log-spaced negative values.
a, b = 0.1, 0.95
c, d = -0.02, -0.02
n_cs, n_ds = 40, 40
cs = np.geomspace(-10., -0.04, n_cs)
ds = np.geomspace(-1., -0.04, n_ds)
cds = list(itertools.product(cs, ds))
n_images = len(cds)

# Samplers for the random x / y offset applied to the seed cloud per image.
xj = ss.norm(loc=-0.3, scale=0.6).rvs
yj = ss.norm(loc=-0.3, scale=0.6).rvs



In [None]:

# Render one image per (c, d) pair in the sweep.
# Both accumulators are reset together so linesets[i] always corresponds to
# imgs[i], even when this cell is re-run.
imgs = []
linesets = []
for ii in tqdm(range(n_images)):
    c, d = cds[ii]
    args = (a, b, c, d)
    # Randomly rotate (shapely's default origin: bounding-box centre) and
    # translate the seed cloud so images differ beyond their (c, d) values.
    _pts = sa.rotate(pts, angle=np.random.uniform(0, 360))
    _pts = sa.translate(_pts, xoff=xj(), yoff=yj())
    lss = no_overlap_odeint(ode, _pts, args, t, break_dist=0.5, min_len=0.9)
    lss = gp.make_like(gp.merge_LineStrings(lss), img_drawbox)
    img = features.rasterize(lss, out_shape=out_shape)
    linesets.append(lss)
    imgs.append(img)

In [None]:
# f, axs = plt.subplots(n_cs,n_ds, figsize=(18,18))
# axs = axs.ravel()
# for ii in range(n_images):
#     ax = axs[ii]
#     ax.imshow(imgs[ii])
#     ax.axis('off')
    
# plt.tight_layout()

In [None]:
# Persist the rendered sweep so training can be re-run without regenerating it.
# The directory defaults to the original location and can be overridden with
# the ODE_VAE_DATA_DIR environment variable; it is created if missing.
ode_vae_data_dir = Path(os.environ.get('ODE_VAE_DATA_DIR', '/home/naka/data/ode_vae'))
ode_vae_data_dir.mkdir(parents=True, exist_ok=True)
np.savez_compressed(ode_vae_data_dir / 'parametric3.npz', imgs)

In [None]:
# Load the saved sweep. The directory defaults to the original location and can
# be overridden with the ODE_VAE_DATA_DIR environment variable. The NpzFile is
# opened in a `with` block so its file handle is closed after reading.
ode_vae_data_dir = Path(os.environ.get('ODE_VAE_DATA_DIR', '/home/naka/data/ode_vae'))
with np.load(ode_vae_data_dir / 'parametric3.npz') as npz:
    imgs = npz['arr_0']

In [None]:
class ImageDataset(Dataset):
    """Map-style dataset over a stack of 2-D images.

    ``imgs`` is indexed as ``imgs[idx, :, :]``, so it is expected to be a 3-D
    array-like of shape (N, H, W). Each item is converted with
    ``torchvision.transforms.ToTensor()`` and then, if given, passed through
    ``transform``.

    NOTE(review): ToTensor rescales uint8 arrays by 1/255. If the images come
    from ``features.rasterize`` as uint8 with values {0, 1} (confirm dtype),
    items end up in {0, 1/255} rather than {0, 1}.
    """

    def __init__(self, imgs, transform=None):
        self.imgs = imgs
        self.transform = transform

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, idx):
        tensor = transforms.ToTensor()(self.imgs[idx, :, :])
        if not self.transform:
            return tensor
        return self.transform(tensor[:, :])
    
    
class ImageDataModule(pl.LightningDataModule):
    """LightningDataModule that splits an image stack into train / val / test sets.

    Parameters
    ----------
    imgs : array-like
        Image stack, indexed as ``imgs[idx, :, :]`` by ``ImageDataset``.
    train_val_test_ratio : tuple of three non-negative numbers
        Relative split sizes; normalised by their sum.
    batch_size : int
        Batch size for every DataLoader.
    random_state : int or None
        Seed for the split. ``None`` draws from torch's global RNG.
    num_workers : int
        Worker processes for every DataLoader.
    """

    def __init__(
        self,
        imgs,
        train_val_test_ratio:tuple=(0.8, 0.1, 0.1),
        batch_size=1,
        random_state=None,
        num_workers=0,
        ):
        super().__init__()
        self.imgs = imgs
        self.random_state = random_state
        self.train_val_test_ratio = train_val_test_ratio
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Populated by setup().
        self.train_data = None
        self.val_data = None
        self.test_data = None

    def prepare_data(self):
        pass

    def _split_lengths(self, n_items):
        """Convert ``train_val_test_ratio`` into integer (train, val, test) sizes summing to n_items."""
        ratios = [float(r) for r in self.train_val_test_ratio]
        if len(ratios) != 3 or any(r < 0 for r in ratios) or sum(ratios) <= 0:
            raise ValueError(
                'train_val_test_ratio must be three non-negative numbers with a '
                f'positive sum, got {self.train_val_test_ratio!r}'
            )
        total = sum(ratios)
        n_train = int(round(n_items * ratios[0] / total))
        n_val = min(int(round(n_items * ratios[1] / total)), n_items - n_train)
        return [n_train, n_val, n_items - n_train - n_val]

    def setup(self, stage=None):
        """Split the images into disjoint train / val / test subsets (once).

        Previously all three sets were the full dataset, so validation metrics
        (and early stopping on them) were computed on training data.
        """
        # Lightning may call setup() once per stage; splitting only once keeps
        # the partitions disjoint across calls.
        if self.train_data is not None:
            return
        full_dataset = ImageDataset(self.imgs)
        lengths = self._split_lengths(len(full_dataset))
        if self.random_state is None:
            splits = random_split(full_dataset, lengths)
        else:
            generator = torch.Generator().manual_seed(int(self.random_state))
            splits = random_split(full_dataset, lengths, generator=generator)
        self.train_data, self.val_data, self.test_data = splits

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
        

In [None]:

class VAE(pl.LightningModule):
    """Convolutional variational autoencoder for square images.

    Encoder: ``layer_count`` units of Conv2d -> BatchNorm2d -> Dropout ->
    LeakyReLU. Channel counts start at ``mul * depth`` and are multiplied by
    ``depth`` at each layer. The flattened encoder output feeds two Linear
    heads producing ``latent_dim`` means and log-variances.

    Decoder: a Linear layer back to the encoder's flattened size, reshaped to
    the encoder's final (channels, width, width), followed by ``layer_count``
    ConvTranspose2d units. Each unit's ``output_padding`` is derived from the
    encoder's per-layer widths so the reconstruction matches ``input_size``,
    and the last unit emits ``channels`` feature maps.
    """

    def encode_unit(self, inputs, out_size, kernel_size, stride, padding=0, dropout=0):
        """Conv2d -> BatchNorm2d (batch statistics only, no affine) -> Dropout -> LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(inputs, out_size, kernel_size, stride, padding),
            nn.BatchNorm2d(out_size, track_running_stats=False, affine=False),
            nn.Dropout(dropout),
            nn.LeakyReLU()
        )

    def decode_unit(self, inputs, out_size, kernel_size, stride, padding=0, dropout=0, output_padding=0):
        """ConvTranspose2d -> BatchNorm2d (batch statistics only, no affine) -> Dropout -> LeakyReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(inputs, out_size, kernel_size, stride, padding, output_padding=output_padding),
            nn.BatchNorm2d(out_size, track_running_stats=False, affine=False),
            nn.Dropout(dropout),
            nn.LeakyReLU()
        )

    def __init__(
        self, 
        
        latent_dim,
        input_size=256,
        layer_count=3, 
        channels=1, 
        depth=2,
        lr=1e-3,
        weight_decay=1e-5,
        kld_loss_weight=1,
        mul=1,
        fc_scale=8,
        encode_dropout=0.2,
        decode_dropout=0.2,
        kernel_size=4,
        stride=2,
        fc_size=0,
        padding=1,
        
    ):
        """
        :param latent_dim: size of the latent code.
        :param input_size: height/width of the (square) input images.
        :param layer_count: number of conv units in the encoder (and decoder).
        :param channels: input channels; also the decoder's output channels.
        :param depth: per-layer channel multiplier.
        :param lr: Adam learning rate (configure_optimizers).
        :param weight_decay: Adam weight decay (configure_optimizers).
        :param kld_loss_weight: multiplier on the KL term in loss_function.
        :param mul: starting channel multiplier (first unit has mul * depth channels).
        :param fc_scale: stored on the module; not otherwise used here.
        :param encode_dropout: dropout probability inside each encoder unit.
        :param decode_dropout: dropout probability inside each decoder unit.
        :param kernel_size: conv kernel size.
        :param stride: conv stride.
        :param fc_size: flattened encoder size; 0 or None derives it from the
            encoder output (channels * width ** 2).
        :param padding: conv padding.
        """
        super(VAE, self).__init__()

        self.depth = depth
        self.latent_dim = latent_dim
        self.lr = lr
        self.layer_count = layer_count
        self.weight_decay = weight_decay
        self.kld_loss_weight = kld_loss_weight
        self.fc_scale = fc_scale
        self.encode_dropout = encode_dropout
        self.decode_dropout = decode_dropout
        self.padding = padding
        self.stride = stride
        self.kernel_size = kernel_size
        self.channels = channels
        self.input_size = input_size
        # Zero-size parameter used only to look up the module's current device.
        self.dummy_param = nn.Parameter(torch.empty(0))

        # ---- Encoder ----
        inputs = channels
        width = self.input_size
        self.encoder_layer_output_widths = {}
        encoder_input_widths = []  # spatial width entering each encoder unit
        encoder_modules = []
        for i in range(self.layer_count):
            mul *= self.depth
            out_size = mul
            encoder_modules.append(self.encode_unit(inputs, out_size, self.kernel_size, self.stride, padding=self.padding, dropout=self.encode_dropout))
            inputs = out_size

            encoder_input_widths.append(width)
            # Conv2d output size (dilation 1), floored at every layer as Conv2d does.
            width = (width + 2 * self.padding - (self.kernel_size - 1) - 1) // self.stride + 1
            if width < 1:
                raise ValueError(
                    f'input_size={self.input_size} is too small for {self.layer_count} encoder layers'
                )
            self.encoder_layer_output_widths[i] = width

        self.encoder_end_width = int(width)
        self.encoder_end_channels = int(inputs)
        self.encoder = nn.Sequential(*encoder_modules)

        # The default fc_size is 0: treat 0/None as "derive from the encoder output".
        if not fc_size:
            fc_size = self.encoder_end_channels * self.encoder_end_width ** 2
        self.fc_size = int(fc_size)

        self.fc_mu = nn.Linear(self.fc_size, latent_dim)
        self.fc_var = nn.Linear(self.fc_size, latent_dim)

        self.decoder_input = nn.Linear(latent_dim, self.fc_size)

        # ---- Decoder ----
        mul = inputs / depth
        width = self.encoder_end_width
        decoder_modules = []
        for i in range(self.layer_count):
            is_last = i == (self.layer_count - 1)
            # The last unit must reproduce the input's channel count.
            out_size = int(channels) if is_last else int(mul)
            print(f'decoder module {i} out_size = {out_size}')

            # Choose output_padding so this unit's output width equals the width
            # that the mirrored encoder unit received. ConvTranspose2d output:
            # (in - 1) * stride - 2 * padding + kernel_size + output_padding
            target_width = encoder_input_widths[self.layer_count - 1 - i]
            base_width = (width - 1) * self.stride - 2 * self.padding + self.kernel_size
            output_padding = int(target_width - base_width)

            decoder_modules.append(self.decode_unit(inputs, out_size, self.kernel_size, self.stride, padding=self.padding, dropout=self.decode_dropout, output_padding=output_padding))
            mul = mul / depth
            inputs = out_size
            width = target_width

        self.decoder = nn.Sequential(*decoder_modules)

        self.mse_loss = nn.MSELoss()

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(logvar / 2) in training mode; return mu in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(-1, self.encoder_end_channels, self.encoder_end_width, self.encoder_end_width)
        result = self.decoder(result)
        return result

    def forward(self, x):
        """Return ``[reconstruction, x, mu, logvar]``; mu/logvar keep their batch dimension."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return [self.decode(z), x, mu, logvar]

    def weight_init(self, mean, std):
        """Apply ``normal_init`` to every submodule, recursively.

        Only walking direct children (all ``nn.Sequential`` containers) would
        never reach the conv layers ``normal_init`` acts on.
        """
        for m in self.modules():
            normal_init(m, mean, std)

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss: MSE reconstruction + kld_loss_weight * KL.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        The KL term is summed over latent dimensions for each sample and then
        averaged over the batch.
        NOTE: an earlier version summed over the batch and averaged over latent
        dimensions; for full batches the KL term is now larger by
        latent_dim / batch_size, so kld_loss_weight values tuned against the
        old formula may need rescaling.
        :param args: (recons, inputs, mu, log_var) as returned by forward()
        :param kwargs: unused
        :return: dict with 'loss', 'Reconstruction_Loss' and 'KLD'
        """
        recons, inputs, mu, log_var = args[:4]

        kld_weight = self.kld_loss_weight
        recons_loss = self.mse_loss(recons, inputs)

        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=-1)
        kld_loss = kld_per_sample.mean()

        loss = recons_loss + kld_loss * kld_weight
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

    def training_step(self, batch, batch_idx):
        """Compute and log the VAE loss terms for one training batch."""
        x = batch
        x_hat, inputs, mu, log_var = self(x)
        vae_loss = self.loss_function(x_hat, inputs, mu, log_var)

        self.log('train_x_loss', vae_loss['loss'], on_step=True, prog_bar=False)
        self.log('train_x_recon_loss', vae_loss['Reconstruction_Loss'], on_step=True, prog_bar=False)
        self.log('train_kld_loss', vae_loss['KLD'], on_step=True, prog_bar=False)
        loss = vae_loss['loss']
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the VAE loss terms for one validation batch (epoch-aggregated)."""
        x = batch
        x_hat, inputs, mu, log_var = self(x)
        vae_loss = self.loss_function(x_hat, inputs, mu, log_var)
        self.log('val_x_loss', vae_loss['loss'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_x_recon_loss', vae_loss['Reconstruction_Loss'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_kld_loss', vae_loss['KLD'], on_step=False, on_epoch=True, prog_bar=True)
        loss = vae_loss['loss']
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the VAE loss terms for one test batch (epoch-aggregated)."""
        x = batch
        x_hat, inputs, mu, log_var = self(x)
        vae_loss = self.loss_function(x_hat, inputs, mu, log_var)
        self.log('test_x_loss', vae_loss['loss'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_x_recon_loss', vae_loss['Reconstruction_Loss'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('test_kld_loss', vae_loss['KLD'], on_step=False, on_epoch=True, prog_bar=False)
        loss = vae_loss['loss']
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured lr and weight decay."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer


def normal_init(m, mean, std):
    """Draw Conv2d / ConvTranspose2d weights from N(mean, std) and zero their bias, if present."""
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        if m.bias is not None:
            m.bias.data.zero_()

In [None]:

class MLPVAE(pl.LightningModule):
    
    def encode_unit(self, inputs, out_size, kernel_size, stride, padding=0, dropout=0):
        return nn.Sequential(
            nn.Conv2d(inputs, out_size, kernel_size, stride, padding),
            nn.BatchNorm2d(out_size, track_running_stats=False, affine=False),
            nn.Dropout(dropout),
            nn.ReLU()
            
        )
    
    def decode_unit(self, inputs, out_size, kernel_size, stride, padding=0, dropout=0, output_padding=0):
        return nn.Sequential(
            nn.ConvTranspose2d(inputs, out_size, kernel_size, stride, padding, output_padding=output_padding),
            nn.BatchNorm2d(out_size, track_running_stats=False, affine=False),
            nn.Dropout(dropout),
            nn.ReLU()
        )
    
    def fc_bn_lrelu(self, in_size, out_size, dropout=0):
        return nn.Sequential(
            nn.Linear(in_size, out_size),
            nn.BatchNorm1d(out_size, track_running_stats=False, affine=False),
            nn.Dropout(dropout),
            nn.ReLU(),
            
        )
    
    def __init__(
        self, 
        
        
        latent_dim: int,
        input_size=(256, 256),
        lr=1e-3,
        weight_decay=1e-5,
        hidden_dims: List = None,
        kld_loss_weight=1.,
        input_dropout=0.,
        encoder_dropout=0.,
        decoder_dropout=0.,
        do_test_step=True,
        
    ):
        super(MLPVAE, self).__init__()

        
        self.latent_dim = latent_dim
        self.lr = lr
        
        self.weight_decay = weight_decay
        self.kld_loss_weight = kld_loss_weight
        
        self.encoder_dropout = encoder_dropout
        self.decoder_dropout = decoder_dropout
        self.input_size = input_size
        self.dummy_param = nn.Parameter(torch.empty(0))
        self.n_features = self.input_size[0] * self.input_size[1]
        
        
        if hidden_dims == None:
            hidden_dims = [32, 16, 10]
        
        self.hidden_dims = hidden_dims
        
        self.input_dropout = nn.Dropout(input_dropout)
        
        
        # Build Encoder
        self.encoder_dims = [self.n_features] + self.hidden_dims
        self.encoder_modules = []
        for ii in range(len(self.encoder_dims) - 1):
            input_d = self.encoder_dims[ii]
            output_d = self.encoder_dims[ii+1]
            self.encoder_modules.append(self.fc_bn_lrelu(input_d, output_d, dropout=self.encoder_dropout))
        self.encoder = nn.Sequential(*self.encoder_modules)
        self.fc_mu = nn.Linear(self.encoder_dims[-1], latent_dim)
        self.fc_var = nn.Linear(self.encoder_dims[-1], latent_dim)

        # Build Decoder
        self.decoder_modules = []
        self.decoder_dims = list(reversed(self.encoder_dims))[:-1]
        self.decoder_input = nn.Linear(latent_dim, self.decoder_dims[0])
        self.decoder_modules = []
        
        for ii in range(len(self.decoder_dims) - 1):
            input_d = self.decoder_dims[ii]
            output_d = self.decoder_dims[ii+1]
            self.decoder_modules.append(self.fc_bn_lrelu(input_d, output_d, dropout=self.decoder_dropout))

#         self.decoder_modules.append(
#             self.fc_bn_lrelu(self.decoder_dims[-1], self.n_features)
#         )
        self.decoder_modules.append(
            nn.Sequential(
            nn.Linear(self.decoder_dims[-1], self.n_features),
            nn.PReLU(),
        ))
        self.decoder = nn.Sequential(*self.decoder_modules)
        
        
        
        self.mse_loss = nn.MSELoss()

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input.flatten(start_dim=1))
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        
        result = self.decoder(result).unflatten(dim=1, sizes=self.input_size).unsqueeze(1)
        # result = self.final_layer(result)
        return result

    def forward(self, x):
        mu, logvar = self.encode(x)
        mu = mu.squeeze()
        logvar = logvar.squeeze()
        z = self.reparameterize(mu, logvar)
#         return self.decode(z.view(-1, self.zsize, 1, 1)), x, mu, logvar
        return  [self.decode(z), x, mu, logvar]

    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
            
    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        
        device = self.dummy_param.device
        
        recons = args[0]
        inputs = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = self.kld_loss_weight # Account for the minibatch samples from the dataset
        recons_loss = self.mse_loss(recons, inputs)


        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 0), dim = 0)

        loss = recons_loss + kld_loss * kld_weight
        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}
    
    def sample(self,
               num_samples:int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples,
                        self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]
    
    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x = batch
        x_hat, inputs, mu, log_var = self(x)
        vae_loss = self.loss_function(x_hat, inputs, mu, log_var)
        
        self.log('train_x_loss', vae_loss['loss'], on_step=True, prog_bar=False)
        self.log('train_x_recon_loss', vae_loss['Reconstruction_Loss'], on_step=True, prog_bar=False)
        self.log('train_kld_loss', vae_loss['KLD'], on_step=True, prog_bar=False)
        loss = vae_loss['loss']
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch
        x_hat, inputs, mu, log_var = self(x)
        vae_loss = self.loss_function(x_hat, inputs, mu, log_var)
        self.log('val_x_loss', vae_loss['loss'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_x_recon_loss', vae_loss['Reconstruction_Loss'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_kld_loss', vae_loss['KLD'], on_step=False, on_epoch=True, prog_bar=True)
        loss = vae_loss['loss']
        return loss
    
    def test_step(self, batch, batch_idx):
        x = batch
        x_hat, inputs, mu, log_var = self(x)
        vae_loss = self.loss_function(x_hat, inputs, mu, log_var)
        self.log('test_x_loss', vae_loss['loss'], on_step=False, on_epoch=True, prog_bar=True)
        self.log('test_x_recon_loss', vae_loss['Reconstruction_Loss'], on_step=False, on_epoch=True, prog_bar=False)
        self.log('test_kld_loss', vae_loss['KLD'], on_step=False, on_epoch=True, prog_bar=False)
        loss = vae_loss['loss']
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        return optimizer


def normal_init(m, mean, std):
    """Initialise a 2-D (transposed) convolution layer in place.

    Weights are drawn from a normal distribution with the given mean and
    standard deviation, and the bias, when the layer has one, is zeroed.
    Any other module type is left untouched, so it is safe to call on every
    module of a model.

    :param m: module to (possibly) initialise
    :param mean: mean of the normal distribution for the weights
    :param std: standard deviation of the normal distribution for the weights
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Layers constructed with bias=False have m.bias == None.
        if m.bias is not None:
            m.bias.data.zero_()

In [None]:
# Build a datamodule over all images and pull one training batch for inspection.
dm = ImageDataModule(np.stack(imgs), batch_size=32, num_workers=48)
dm.setup()
x = next(iter(dm.train_dataloader()))

In [None]:
# vae = MLPVAE(
#     latent_dim=256,
#     input_size=(128,128),
    
#     lr=1e-3,
#     weight_decay=1e-5,
    
#     encoder_dropout=0.4,
#     kld_loss_weight=1e-3,
#     hidden_dims=None,
# )

In [None]:
# x_hat, inputs, mu, log_var = vae(x)
# print(x_hat.shape)
# print(inputs.shape)

In [None]:
# Train an MLP VAE on `imgs`, logging to the Weights & Biases project 'ode_vae'.
os.environ['WANDB_NOTEBOOK_NAME'] = '052_ode_vae.ipynb'
run = wandb.init(
    project='ode_vae',
    entity='alex_naka',
#     mode='disabled',
)

pl.seed_everything(117)

# Stop once the validation reconstruction loss has not improved for 300
# consecutive validation checks.
earlystopper = EarlyStopping(monitor='val_x_recon_loss', patience=300, mode='min')
callbacks = [earlystopper]

logger = WandbLogger()

dm = ImageDataModule(
    np.stack(imgs),
    batch_size=16,
    num_workers=0,
)

vae = MLPVAE(
    latent_dim=64,
    input_size=(128, 128),

    lr=1e-3,
    weight_decay=1e-14,

    encoder_dropout=0.05,
    kld_loss_weight=1e-5,
    hidden_dims=[256, 64, 64],
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=1000,
    progress_bar_refresh_rate=50,
    logger=logger,
    callbacks=callbacks,
)

# Pass the datamodule by keyword: fit()'s second positional parameter is the
# training dataloader(s), not the datamodule.
trainer.fit(vae, datamodule=dm)

# Mark the W&B run finished now that training is done, instead of leaving it
# active until the next wandb.init() in this kernel.
run.finish()


In [None]:

# Compare one random training image (left) with its reconstruction (right).
x = next(iter(dm.train_dataloader()))
ii = np.random.randint(len(x))
vae.eval()  # switch dropout layers to inference behaviour for plotting
with torch.no_grad():  # no autograd graph needed just to visualise
    img_hat = vae(x)[0]
f, axs = plt.subplots(1, 2, figsize=(16, 8))
ax = axs[0]
# Show the batch element that was actually reconstructed. dm.train_data[ii] is
# the ii-th dataset item, which need not be the ii-th element of this batch.
ax.imshow(x[ii].detach().cpu().numpy().squeeze())
ax.set_title('input')
ax = axs[1]
ax.imshow(img_hat[ii].detach().cpu().numpy().squeeze())
ax.set_title('reconstruction')

In [None]:

# Compare one random validation image (left) with its reconstruction (right).
x = next(iter(dm.val_dataloader()))
ii = np.random.randint(len(x))
vae.eval()  # switch dropout layers to inference behaviour for plotting
with torch.no_grad():  # no autograd graph needed just to visualise
    img_hat = vae(x)[0]
f, axs = plt.subplots(1, 2, figsize=(16, 8))
ax = axs[0]
ax.imshow(x[ii].detach().cpu().numpy().squeeze())
ax.set_title('input')
ax = axs[1]
ax.imshow(img_hat[ii].detach().cpu().numpy().squeeze())
ax.set_title('reconstruction')

In [None]:
# Check which device the trained model lives on; later cells pass it to vae.sample().
vae.device

In [None]:
# Decode random samples from the VAE (via its `sample` method) onto a grid.
# The number of samples is derived from the grid so the two cannot disagree.
n_rows, n_cols = 4, 4
vae.eval()
with torch.no_grad():
    img_samples = vae.sample(n_rows * n_cols, vae.device)
f, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
for ax, img_sample in zip(axs.ravel(), img_samples):
    ax.imshow(img_sample.detach().cpu().numpy().squeeze())
    ax.axis('off')
plt.tight_layout()

## Images from ODE flow fields

In [None]:
# Raster size, plus a drawing box inset by `border` pixels on every side.
out_shape = (256, 256)
border = 5
img_drawbox = box(border, border, *(side - border for side in out_shape))

# Coordinate window over which the ODE vector field is evaluated, sampled on an
# out_shape grid.
xmin, xmax = ymin, ymax = (-18, 18)
xrange = np.linspace(xmin, xmax, out_shape[0])
yrange = np.linspace(ymin, ymax, out_shape[1])
xgs, ygs = np.meshgrid(xrange, yrange)

In [None]:
# Get a set of ODE parameters from OdeParams (defined earlier in the notebook);
# `args` is splatted into ode(...) in the cells below.
args = OdeParams().get_args()

In [None]:
# Inspect the parameters returned by OdeParams().get_args().
args

In [None]:
# Quick check: evaluate the ODE at one grid point (the first). xg/yg are defined
# here so the cell does not depend on the evaluation loop below having run.
xg, yg = xgs.flat[0], ygs.flat[0]
dx, dy = ode((xg, yg), 0, *args)

In [None]:
# Evaluate the ODE vector field at every grid point (row-major order of the
# meshgrid) and record its components, direction and magnitude.
dxs, dys, angles, mags = [], [], [], []
for xg, yg in zip(xgs.flat, ygs.flat):
    dx, dy = ode((xg, yg), 0, *args)
    dxs.append(dx)
    dys.append(dy)
    # NOTE(review): np.arctan2 takes (y, x); passing (dx, dy) measures the
    # angle from the +y axis rather than the +x axis. Confirm this convention
    # is intended before using `angles`.
    angles.append(np.arctan2(dx, dy))
    mags.append(np.sqrt(dx ** 2 + dy ** 2))
    

In [None]:
# Magnitude of the ODE vector field over the grid. `mags` was filled in the
# row-major order of xgs.ravel(), so restore the meshgrid's own shape
# (len(yrange), len(xrange)); reshaping to out_shape would transpose the field
# whenever the grid is not square.
plt.imshow(np.array(mags).reshape(xgs.shape))

In [None]:
# for ii in tqdm(range(20000)):
#     args = OdeParams().get_args()
#     lss = no_overlap_odeint(ode, pts, args, t, break_dist=0.5, min_len=0.9)
#     lss = gp.make_like(gp.merge_LineStrings(lss), img_drawbox)
#     img = features.rasterize(lss, out_shape=out_shape)
#     linesets.append(lss)
#     imgs.append(img)

In [None]:
# np.savez_compressed('/home/naka/data/ode_vae/test.npz', imgs)

In [None]:
# np.load on an .npz returns an NpzFile that keeps the file open; use it as a
# context manager so the handle is released once the array has been read.
with np.load('/home/naka/data/ode_vae/test.npz') as npz:
    imgs = npz['arr_0']

# New data: hatch-filled polygons

In [None]:
# 128x128 rasters; with border = 0 the drawing box covers the whole raster.
out_shape = (128, 128)
border = 0
img_drawbox = box(border, border, *(side - border for side in out_shape))

In [None]:
# Generate `n_images` rasters of randomly hatch-filled regular polygons.
imgs = []
n_images = 2000

for ii in tqdm(range(n_images)):
    # randint's upper bound is exclusive, so this is always 1 (one polygon per
    # image); raise the bound to place a grid of polygons instead.
    n_shapes = np.random.randint(1, 2)
    xs = np.linspace(0, 6, n_shapes) + np.random.uniform(-2, 2, n_shapes)
    ys = np.linspace(0, 6, n_shapes) + np.random.uniform(-2, 2, n_shapes)

    # One polygon centre per (x, y) combination. A comprehension keeps the loop
    # variables from overwriting the notebook-level batch tensor `x`.
    points = [Point(px, py) for px in xs for py in ys]

    polys = [
        gp.RegPolygon(
            p,
            radius=2,
            n_corners=np.random.randint(3, 8),
            # Both bounds are given explicitly: np.random.uniform(180) would
            # mean uniform(low=180, high=1.0), a reversed range that numpy
            # documents as undefined.
            rotation=np.random.uniform(0, 180),
        ).poly
        for p in points
    ]
    # Map the merged polygons onto img_drawbox with gp.make_like, then union them.
    polys = gp.make_like(gp.merge_Polygons(polys), img_drawbox)
    poly = so.unary_union(polys)

    prms = gp.ScaleTransPrms(
        d_buffer=np.random.uniform(-7.4, -2.8),
        n_iters=400,
        d_translate_factor=0.7,
        angles=np.random.uniform(0, 180),
    )

    poly = gp.Poly(poly)
    poly.fill_scale_trans(**prms.prms)
    lss = gp.make_like(gp.merge_LineStrings(poly.fill), img_drawbox)

    # Random background (fill, 1..9) and line (default_value, 10..99)
    # intensities; the ranges do not overlap, so lines are always brighter.
    img = features.rasterize(
        lss,
        out_shape=out_shape,
        fill=np.random.randint(1, 10),
        default_value=np.random.randint(10, 100),
    )
    imgs.append(img)

In [None]:
# Preview the first 100 rasters on a 10x10 grid without axes.
f, axs = plt.subplots(10, 10, figsize=(18, 18))
axs = axs.ravel()
for ii, ax in enumerate(axs):
    ax.imshow(imgs[ii])
    ax.axis('off')

plt.tight_layout()

In [None]:
# Cache the generated rasters (stored under the key 'arr_0') so the generation
# loop does not have to be re-run. Create the target directory first so saving
# does not fail with FileNotFoundError on a fresh machine.
hatch_fill_path = Path('/home/naka/data/ode_vae/hatch_fill_polys.npz')
hatch_fill_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(hatch_fill_path, imgs)

In [None]:
# np.load on an .npz returns an NpzFile that keeps the file open; use it as a
# context manager so the handle is released once the array has been read.
with np.load('/home/naka/data/ode_vae/hatch_fill_polys.npz') as npz:
    imgs = npz['arr_0']

In [None]:
# Train an MLP VAE on the hatch-filled polygon rasters, logging to the Weights &
# Biases project 'hatch_fill_poly_blob_vae'.
os.environ['WANDB_NOTEBOOK_NAME'] = '052_ode_vae.ipynb'
run = wandb.init(
    project='hatch_fill_poly_blob_vae',
    entity='alex_naka',
#     mode='disabled',
)

pl.seed_everything(117)

# Stop once the validation reconstruction loss has not improved for 300
# consecutive validation checks.
earlystopper = EarlyStopping(monitor='val_x_recon_loss', patience=300, mode='min')
callbacks = [earlystopper]

logger = WandbLogger()

dm = ImageDataModule(
    np.stack(imgs),
    batch_size=128,
    num_workers=0,
)

vae = MLPVAE(
    latent_dim=24,
    input_size=(128, 128),

    lr=1e-3,
    weight_decay=1e-14,

    encoder_dropout=0.1,
    kld_loss_weight=1e-7,
    hidden_dims=[256, 64],
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=1000,
    progress_bar_refresh_rate=50,
    logger=logger,
    callbacks=callbacks,
)

# Pass the datamodule by keyword: fit()'s second positional parameter is the
# training dataloader(s), not the datamodule.
trainer.fit(vae, datamodule=dm)

# Mark the W&B run finished now that training is done, instead of leaving it
# active until the next wandb.init() in this kernel.
run.finish()


In [None]:
# Build a datamodule over the hatch-fill rasters and pull one training batch for
# a forward-pass shape check.
dm = ImageDataModule(np.stack(imgs), batch_size=128, num_workers=48)
dm.setup()
x = next(iter(dm.train_dataloader()))

In [None]:
# Convolutional VAE used for a forward-pass shape check (keyword arguments are
# grouped by role; order does not matter for keyword calls).
vae = VAE(
    # data / latent space
    input_size=128,
    channels=1,
    latent_dim=32,
    # architecture
    layer_count=3,
    depth=2,
    mul=1,
    kernel_size=3,
    stride=2,
    padding=0,
    fc_scale=8,
    fc_size=1800,
    # regularisation
    encode_dropout=0.0,
    decode_dropout=0.0,
    kld_loss_weight=1e-3,
    # optimisation
    lr=1e-3,
    weight_decay=1e-9,
)

In [None]:
# Shape check: run one forward pass on the batch `x` and print output/input shapes.
x_hat, inputs, mu, log_var = vae(x)
print(x_hat.shape, inputs.shape, sep='\n')

In [None]:
# Train the convolutional VAE on the hatch-filled polygon rasters, logging to
# the Weights & Biases project 'hatch_fill_poly_blob_vae'.
os.environ['WANDB_NOTEBOOK_NAME'] = '052_ode_vae.ipynb'
run = wandb.init(
    project='hatch_fill_poly_blob_vae',
    entity='alex_naka',
#     mode='disabled',
)

pl.seed_everything(117)

# Stop once the validation reconstruction loss has not improved for 300
# consecutive validation checks.
earlystopper = EarlyStopping(monitor='val_x_recon_loss', patience=300, mode='min')
callbacks = [earlystopper]

logger = WandbLogger()

dm = ImageDataModule(
    np.stack(imgs),
    batch_size=128,
    num_workers=0,
)

vae = VAE(
    latent_dim=32,
    input_size=128,
    layer_count=3,
    channels=1,
    depth=2,
    lr=1e-2,
    weight_decay=1e-9,
    kld_loss_weight=1e-7,
    mul=1,
    fc_scale=8,
    encode_dropout=0.0,
    decode_dropout=0.0,
    kernel_size=3,
    stride=2,
    fc_size=1800,
    padding=0,
)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=1000,
    progress_bar_refresh_rate=50,
    logger=logger,
    callbacks=callbacks,
)

# Pass the datamodule by keyword: fit()'s second positional parameter is the
# training dataloader(s), not the datamodule.
trainer.fit(vae, datamodule=dm)

# Mark the W&B run finished now that training is done, instead of leaving it
# active until the next wandb.init() in this kernel.
run.finish()


In [None]:

# Compare one random training image (left) with its reconstruction (right).
x = next(iter(dm.train_dataloader()))
ii = np.random.randint(len(x))
vae.eval()  # switch any dropout layers to inference behaviour for plotting
with torch.no_grad():  # no autograd graph needed just to visualise
    img_hat = vae(x)
f, axs = plt.subplots(1, 2, figsize=(16, 8))
ax = axs[0]
# forward() returns (x_hat, inputs, mu, log_var): index 1 is the input as the
# model saw it, index 0 the reconstruction.
ax.imshow(img_hat[1][ii].detach().cpu().numpy().squeeze())
ax.set_title('input')
ax = axs[1]
ax.imshow(img_hat[0][ii].detach().cpu().numpy().squeeze())
ax.set_title('reconstruction')

In [None]:

# Compare one random validation image (left) with its reconstruction (right).
x = next(iter(dm.val_dataloader()))
ii = np.random.randint(len(x))
vae.eval()  # switch any dropout layers to inference behaviour for plotting
with torch.no_grad():  # no autograd graph needed just to visualise
    img_hat = vae(x)[0]
f, axs = plt.subplots(1, 2, figsize=(16, 8))
ax = axs[0]
ax.imshow(x[ii].detach().cpu().numpy().squeeze())
ax.set_title('input')
ax = axs[1]
ax.imshow(img_hat[ii].detach().cpu().numpy().squeeze())
ax.set_title('reconstruction')

In [None]:
# Check which device the trained model lives on; the next cell passes it to vae.sample().
vae.device

In [None]:
# Decode random samples from the VAE (via its `sample` method) onto a grid.
# The number of samples is derived from the grid so the two cannot disagree.
n_rows, n_cols = 4, 4
vae.eval()
with torch.no_grad():
    img_samples = vae.sample(n_rows * n_cols, vae.device)
f, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
for ax, img_sample in zip(axs.ravel(), img_samples):
    ax.imshow(img_sample.detach().cpu().numpy().squeeze())
    ax.axis('off')
plt.tight_layout()

In [None]:
# Draw a small batch of standard-normal latent vectors and decode the first one.
# `z` is created on the model's device so decode() works wherever vae lives.
num_samples = 2
z = torch.randn(num_samples, vae.latent_dim, device=vae.device)
with torch.no_grad():
    _img = vae.decode(z)[0]
plt.imshow(_img.detach().cpu().numpy().squeeze())

In [None]:
# Scratch check: the first latent coordinate of the first sample, offset by 0.1.
z[0, 0] + 0.1

In [None]:
# Number of latent dimensions; sets the row count of the traversal grid below.
vae.latent_dim

In [None]:
# Latent traversal: starting from the first vector in `z`, add offsets in
# [-4.5, 4.5] to one latent dimension at a time and decode. Rows are latent
# dimensions, columns are offsets.
n_interps = 10
offsets = np.linspace(-4.5, 4.5, n_interps)

# squeeze=False keeps `axs` 2-D even when latent_dim or n_interps is 1.
f, axs = plt.subplots(
    vae.latent_dim, n_interps,
    figsize=(1.2 * n_interps, 1.2 * vae.latent_dim),
    squeeze=False,
)
with torch.no_grad():
    for dim in range(vae.latent_dim):
        for jj, zp in enumerate(offsets):
            ax = axs[dim, jj]
            _z = z.clone()
            _z[0][dim] += zp
            _img = vae.decode(_z)[0]
            ax.imshow(_img.detach().cpu().numpy().squeeze())
            ax.axis('off')
plt.tight_layout()

In [None]:
# Draw a fresh batch of standard-normal latent vectors (starting point for the
# random walk below) and decode the first one. `z` is created on the model's
# device so decode() works wherever vae lives.
num_samples = 2
z = torch.randn(num_samples, vae.latent_dim, device=vae.device)
with torch.no_grad():
    _img = vae.decode(z)[0]
plt.imshow(_img.detach().cpu().numpy().squeeze())

In [None]:
# Number of latent dimensions of the trained model.
vae.latent_dim

In [None]:
# Random walk in latent space: starting from `z`, repeatedly add Gaussian steps
# (std `step_size`) and decode the first vector after every step.
n_walks = 64
step_size = 0.25
# Near-square grid with room for every step (8x8 for 64 steps).
n_cols = int(np.ceil(np.sqrt(n_walks)))
n_rows = int(np.ceil(n_walks / n_cols))

_z = z.clone()
f, axs = plt.subplots(n_rows, n_cols, figsize=(1.5 * n_cols, 1.5 * n_rows), squeeze=False)
axs = axs.ravel()
with torch.no_grad():
    for ii in range(n_walks):
        ax = axs[ii]
        # randn_like matches _z's shape, dtype and device.
        _z += torch.randn_like(_z) * step_size
        _img = vae.decode(_z)[0]
        ax.imshow(_img.detach().cpu().numpy().squeeze())
        ax.axis('off')
# Hide any grid cells left over when n_walks is not a perfect rectangle.
for ax in axs[n_walks:]:
    ax.axis('off')
plt.tight_layout()