# MetaSDF & Meta-SIREN

This is a colab to explore MetaSDF, and its applications to rapidly fit neural implicit representations.

Make sure to switch the runtime type to "GPU" under "Runtime --> Change Runtime Type"!

We will show you how to run two experiments using gradient-based meta-learning: 
* [Fitting an image in 3 gradient descent steps with SIREN](#section_1)
* [Fitting 2D Signed Distance Functions of MNIST digits](#section_2)

Let's go! 

First, the imports:

In [3]:
import gc
import os
from collections import OrderedDict
from collections.abc import Mapping

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage
import torch
import torchvision
from torch import nn
from torch.nn.init import _calculate_correct_fan
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

  from collections import OrderedDict, Mapping


In [4]:
# Toy check of Tensor.repeat: repeat(1, 5) tiles every row five times along dim 1
# (shape [2, 3] -> [2, 15]); view([-1, 5, 3]) then regroups each tiled row into five
# copies of the original 3-vector.
x = torch.tensor([[1,2,3], [10,20,30]]) 
print(x.repeat(1, 5).shape, x.shape,'\n>>\n', x.repeat(1, 5), '\n', x.repeat(1, 5).view([-1, 5, 3]), x )

torch.Size([2, 15]) torch.Size([2, 3]) 
>>
 tensor([[ 1,  2,  3,  1,  2,  3,  1,  2,  3,  1,  2,  3,  1,  2,  3],
        [10, 20, 30, 10, 20, 30, 10, 20, 30, 10, 20, 30, 10, 20, 30]]) 
 tensor([[[ 1,  2,  3],
         [ 1,  2,  3],
         [ 1,  2,  3],
         [ 1,  2,  3],
         [ 1,  2,  3]],

        [[10, 20, 30],
         [10, 20, 30],
         [10, 20, 30],
         [10, 20, 30],
         [10, 20, 30]]]) tensor([[ 1,  2,  3],
        [10, 20, 30]])


For meta-learning, we're using the excellent "Torchmeta" library. We have to install it:

In [5]:
# !pip install torchmeta
from torchmeta.modules import (MetaModule, MetaSequential, MetaLinear)

We're now ready to implement a few neural network layers: Fully connected networks, and SIREN.

In [6]:
class BatchLinear(nn.Linear, MetaModule):
    '''A linear meta-layer that can deal with batched weight matrices and biases, as for instance output by a
    hypernetwork.'''
    # The class attribute below replaces the docstring above with nn.Linear's documentation.
    __doc__ = nn.Linear.__doc__

    def forward(self, input, params=None):
        '''Apply the affine map using either the layer's own parameters or externally supplied ones.

        Args:
            input: tensor whose last dimension is ``in_features``.
            params: optional mapping with a ``'weight'`` entry and an optional ``'bias'`` entry.
                The weight may carry leading batch dimensions. When ``None``, the parameters
                registered on this layer are used.

        Returns:
            ``input @ weight^T`` (transposing only the last two weight axes), plus the bias when
            one is available.
        '''
        if params is None:
            params = OrderedDict(self.named_parameters())

        bias = params.get('bias', None)
        weight = params['weight']

        # Swap only the two trailing axes so any leading batch axes of the weight are preserved.
        output = input.matmul(weight.transpose(-1, -2))
        # A layer built with bias=False (or a params dict without 'bias') has no bias to add.
        if bias is not None:
            output += bias.unsqueeze(-2)
        return output


class MetaFC(MetaModule):
    '''Fully connected ReLU network whose weights can be replaced at call time, e.g. by a
    hypernetwork or by the fast weights produced in a MAML inner loop.
    '''
    def __init__(self, in_features, out_features,
                 num_hidden_layers, hidden_features,
                 outermost_linear=False):
        super().__init__()

        def relu_block(n_in, n_out):
            # One "linear + in-place ReLU" stage.
            return MetaSequential(BatchLinear(n_in, n_out), nn.ReLU(inplace=True))

        # Input stage followed by `num_hidden_layers` hidden stages.
        stages = [relu_block(in_features, hidden_features)]
        for _ in range(num_hidden_layers):
            stages.append(relu_block(hidden_features, hidden_features))

        # Output stage: purely linear when requested, otherwise also followed by a ReLU.
        if outermost_linear:
            stages.append(MetaSequential(BatchLinear(hidden_features, out_features)))
        else:
            stages.append(relu_block(hidden_features, out_features))

        self.net = MetaSequential(*stages)
        self.net.apply(init_weights_normal)

    def forward(self, coords, params=None, **kwargs):
        '''Plain forward pass (no spatial gradients); `params` optionally overrides the weights.'''
        net_params = self.get_subdict(params, 'net')
        return self.net(coords, params=net_params)


class SineLayer(MetaModule):
    # A linear layer followed by sin(omega_0 * x).
    #
    # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of omega_0.
    #
    # With is_first=True, omega_0 is a frequency factor that multiplies the activations before the
    # nonlinearity; different signals may need a different first-layer omega_0 (a hyperparameter).
    #
    # With is_first=False, the weights are divided by omega_0 so that the magnitude of the
    # activations stays constant while gradients to the weight matrix are boosted
    # (see supplement Sec. 1.5).

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.in_features = in_features
        self.is_first = is_first
        self.omega_0 = float(omega_0)

        self.linear = BatchLinear(in_features, out_features, bias=bias)
        self.init_weights()

    def init_weights(self):
        # Symmetric uniform init; only the half-width depends on the layer's position.
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0

        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input, params=None):
        # `params` (if given) supplies replacement weights for the wrapped linear layer.
        linear_params = self.get_subdict(params, 'linear')
        pre_activation = self.linear(input, params=linear_params)
        return torch.sin(self.omega_0 * pre_activation)


class Siren(MetaModule):
    '''Stack of SineLayers with an optional purely linear output layer.

    Args:
        in_features: size of the input coordinates.
        hidden_features: width of every hidden layer.
        hidden_layers: number of SineLayers after the first one.
        out_features: size of the output.
        outermost_linear: if True the last layer is a BatchLinear without a sine; otherwise it
            is another SineLayer.
        first_omega_0: omega_0 of the first layer.
        hidden_omega_0: omega_0 of all later layers; also scales the init of the final linear
            layer (previously a hard-coded 30., which equals the default).
        special_first: passed as `is_first` to the first SineLayer.
    '''
    def __init__(self, in_features, hidden_features, hidden_layers, out_features, outermost_linear=False,
                 first_omega_0=30, hidden_omega_0=30., special_first=True):
        super().__init__()
        self.hidden_omega_0 = hidden_omega_0

        layers = [SineLayer(in_features, hidden_features,
                            is_first=special_first, omega_0=first_omega_0)]

        for _ in range(hidden_layers):
            layers.append(SineLayer(hidden_features, hidden_features,
                                    is_first=False, omega_0=hidden_omega_0))

        if outermost_linear:
            final_linear = BatchLinear(hidden_features, out_features)

            # Same bound as a non-first SineLayer, so it follows hidden_omega_0 rather than a
            # constant that silently disagrees when the caller changes the frequency.
            bound = np.sqrt(6 / hidden_features) / hidden_omega_0
            with torch.no_grad():
                final_linear.weight.uniform_(-bound, bound)
            layers.append(final_linear)
        else:
            layers.append(SineLayer(hidden_features, out_features,
                                    is_first=False, omega_0=hidden_omega_0))

        # ModuleList keeps the parameter names as 'net.<i>....', which forward() relies on.
        self.net = nn.ModuleList(layers)

    def forward(self, coords, params=None):
        '''Run `coords` through every layer, optionally with externally supplied `params`.'''
        x = coords

        for i, layer in enumerate(self.net):
            # Hand each layer only its own slice of the parameter dictionary.
            x = layer(x, params=self.get_subdict(params, f'net.{i}'))

        return x
    
    
def init_weights_normal(m):
    '''Kaiming-normal (ReLU gain) weight init and zero bias for linear layers.

    Intended for use with ``Module.apply``: modules that are not ``nn.Linear`` (or a subclass
    of it, such as BatchLinear) are left untouched. The previous condition
    ``type(m) == BatchLinear or nn.Linear`` was always true, so every module with a ``weight``
    attribute was re-initialised.

    Args:
        m: the module visited by ``apply``.
    '''
    if not isinstance(m, nn.Linear):
        return

    if getattr(m, 'weight', None) is not None:
        torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    # A layer built with bias=False still has a `bias` attribute, but it is None.
    if getattr(m, 'bias', None) is not None:
        m.bias.data.fill_(0.)
            
            
def get_mgrid(sidelen):
    '''Return the (sidelen*sidelen, 2) float tensor of pixel coordinates of a square image.

    Each row is (row_index, col_index) divided by `sidelen` and shifted by -0.5, so the values
    lie in [-0.5, 0.5).
    '''
    rows, cols = np.mgrid[:sidelen, :sidelen]
    grid = np.stack([rows, cols], axis=-1)[None, ...].astype(np.float32)
    # Normalise in place so the array stays float32.
    grid /= sidelen
    grid -= 0.5
    return torch.Tensor(grid).view(-1, 2)

Sazan: Now let's implement our Cross-Attention Hypernetwork. It takes the image as input and generates matrices with the same dimensions as the weights of the SIREN.

In [7]:
import torch.nn.functional as F

from modules_custom import Conv2dResBlock

class CrossAttentionHyperNet(nn.Module):
    """Hypernetwork that turns an image into one tensor per SIREN parameter via cross-attention.

    A small CNN encodes the image into a set of 64 feature vectors. For every SIREN parameter,
    an ``nn.MultiheadAttention`` module uses the (batched) meta-learned parameter as the query
    and projections of the image features as keys and values.

    NOTE(review): all layer sizes are hard-coded for a SIREN with 64 hidden features, 2 inputs
    and 3 outputs (see the shape table below), and for inputs that reduce to an 8x8 feature
    map - presumably 32x32 RGB images; confirm against the caller.
    """
    def __init__(self):
        super().__init__()
        L = 64  # channel width of the last conv layers and of the per-parameter projections
        self.conv1 = nn.Conv2d(3, 32, 5, padding=5//2) # padding=kernel_size//2
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 5, padding=5//2) # padding=kernel_size//2
        self.conv3 = nn.Conv2d(64, L, 5, padding=5//2)
        self.conv4 = nn.Conv2d(L, L, 5, padding=5//2)
        # self.conv4_dim_reduction = nn.Conv2d(16+32+64, 64, 1, padding=0)
#         self.cnn = nn.Sequential(
#             nn.Conv2d(128, 256, 3, 1, 1),
#             nn.ReLU(),
#             Conv2dResBlock(256, 256),
#             Conv2dResBlock(256, 256),
#             Conv2dResBlock(256, 256),
#             Conv2dResBlock(256, 256),
#             nn.Conv2d(256, 256, 1, 1, 0)
#         )
#         self.relu_2 = nn.ReLU(inplace=True)
#         self.fc = nn.Linear(1024, 1)

        # Projections of the image features used as key/value for each weight matrix.
        if True:
            self.fc0 = nn.Linear(L, 2)
            self.fc1 = nn.Linear(L, L)
            self.fc2 = nn.Linear(L, L)
            self.fc3 = nn.Linear(L, L)
            self.fc4_1 = nn.Linear(L, 3)
            self.fc4_2 = nn.Linear(3, 3)
            self.fc4_bias = nn.Linear(L, 3)

        # Disabled: layers for the hand-written bias attention (bias_attention / compute_biases).
        # NOTE(review): because this branch never runs, self.bias*_fc and self.attn_bias*_fc do
        # not exist, so compute_biases() raises AttributeError if it is ever called.
        if False:
            self.weighted_mean = torch.nn.Conv1d(in_channels=64, out_channels=5, kernel_size=1)

            self.bias0_fc = nn.Linear(64+64, 64)
            self.bias1_fc = nn.Linear(64+64, 64)
            self.bias2_fc = nn.Linear(64+64, 64)
            self.bias3_fc = nn.Linear(64+64, 64)
            self.bias4_fc = nn.Linear(3+64, 3)
            
            self.attn_bias0_fc = nn.Linear(64, 1)
            self.attn_bias1_fc = nn.Linear(64, 1)
            self.attn_bias2_fc = nn.Linear(64, 1)
            self.attn_bias3_fc = nn.Linear(64, 1)
            self.attn_bias4_fc = nn.Linear(3, 1)


        # One single-head attention module per SIREN weight matrix ...
        # NOTE(review): batch_first is commented out, so these modules use PyTorch's default
        # input layout while forward() passes tensors with the batch dimension first - confirm
        # this is intended.
        self.wt_cross_attn0 = nn.MultiheadAttention(embed_dim=2, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.wt_cross_attn1 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.wt_cross_attn2 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.wt_cross_attn3 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.wt_cross_attn4 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        
        # ... and one per SIREN bias vector.
        self.bias_cross_attn0 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.bias_cross_attn1 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.bias_cross_attn2 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.bias_cross_attn3 = nn.MultiheadAttention(embed_dim=64, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        self.bias_cross_attn4 = nn.MultiheadAttention(embed_dim=3, num_heads=1, dropout=0.1, bias=True)#, batch_first=True)
        
        
    # Reference table (a no-op string expression): parameter name, its shape, and its shape
    # after being replicated over a meta-batch of 16.
    '''
    net.0.linear.weight :	 torch.Size([64, 2]) 	 torch.Size([16, 64, 2])
    net.0.linear.bias :	 torch.Size([64]) 	 torch.Size([16, 64])
    net.1.linear.weight :	 torch.Size([64, 64]) 	 torch.Size([16, 64, 64])
    net.1.linear.bias :	 torch.Size([64]) 	 torch.Size([16, 64])
    net.2.linear.weight :	 torch.Size([64, 64]) 	 torch.Size([16, 64, 64])
    net.2.linear.bias :	 torch.Size([64]) 	 torch.Size([16, 64])
    net.3.linear.weight :	 torch.Size([64, 64]) 	 torch.Size([16, 64, 64])
    net.3.linear.bias :	 torch.Size([64]) 	 torch.Size([16, 64])
    net.4.weight :	 torch.Size([3, 64]) 	 torch.Size([16, 3, 64])
    net.4.bias :	 torch.Size([3]) 	 torch.Size([16, 3])
    '''
    
    def forward_conv(self, x):
        """Encode a channel-last image batch and project it once per SIREN weight matrix.

        Returns the tuple ``(x, x0, x1, x2, x3, x4)``: the b x 64 x 64 feature set and the
        key/value tensors used for the weights of layers 0-4.
        """
        x = x.permute(0, 3, 1, 2).contiguous() # channel-last -> channel-first for the convolutions
        b = x.shape[0]
        # print('1>', x.shape)
        x = self.pool(F.relu(self.conv1(x))) # bx3x32x32 -> bx32x16x16 (assuming 32x32 input)
        # print('2>', x.shape)
        x = self.pool(F.relu(self.conv2(x))) # bx32x16x16 -> bx64x8x8
        x = F.relu(self.conv3(x)) # padded 5x5 conv: spatial size unchanged
        x = F.relu(self.conv4(x)) # padded 5x5 conv: spatial size unchanged
        x = x.permute(0, 2, 3, 1).contiguous() # bx64x8x8 -> bx8x8x64 (channel last)
        # print('3>', x.shape)
        # x3 = self.pool(F.relu(self.conv3(x2)))
        # x = torch.cat([x1, x2], -1)
        # x = F.relu(self.conv4_dim_reduction(x))
#         print('>>', x.shape)
        # Flatten the 8x8 grid into 64 tokens. The sizes are hard-coded, so any other input
        # resolution makes this view fail.
        x = x.view([b, 64, 64]) # bx8x8x64 -> bx64x64 ## channel last
        # print('4>', x.shape)
        
        x0 = F.relu(self.fc0(x)) # bx64x64 -> bx64x2
        x1 = F.relu(self.fc1(x)) # bx64x64 -> bx64x64
        x2 = F.relu(self.fc2(x)) # bx64x64 -> bx64x64
        x3 = F.relu(self.fc3(x)) # bx64x64 -> bx64x64
        # For the last layer the token and channel axes are swapped before projecting to 3.
        x4 = F.relu(self.fc4_1(x.permute(0,2,1))) # bx64x64 -> bx64x64 -> bx64x3
        x4 = F.relu(self.fc4_2(x4)).permute(0,2,1).contiguous() # bx64x3 -> bx64x3 -> bx3x64

        # x_biases = F.relu(self.weighted_mean(x)) # bx64x64 -> bx5x64
        # x0_bias = F.relu(self.bias0_fc(x_biases[:, 0, :])) # bx5x64 ->  bx64 -> bx64
        # x1_bias = F.relu(self.bias1_fc(x_biases[:, 1, :])) # bx5x64 ->  bx64 -> bx64
        # x2_bias = F.relu(self.bias2_fc(x_biases[:, 2, :])) # bx5x64 ->  bx64 -> bx64
        # x3_bias = F.relu(self.bias3_fc(x_biases[:, 3, :])) # bx5x64 ->  bx64 -> bx64
        # x4_bias = F.relu(self.bias4_fc(x_biases[:, 4, :])) # bx5x64 ->  bx64 -> bx3

        return x, x0, x1, x2, x3, x4#, x0_bias, x1_bias, x2_bias, x3_bias, x4_bias
    
    def bias_attention(self, x, meta_param_bias, bias_fc, attn_bias_fc):
        """Hand-written attention pooling that produces one bias vector per batch item.

        The meta bias is broadcast to all 64 image tokens, concatenated with the token
        features, projected by `bias_fc`, and then averaged with softmax weights computed by
        `attn_bias_fc`. Only reachable through compute_biases(), which forward() does not call.
        """
        b, c = meta_param_bias.shape
        meta_param_bias = meta_param_bias.view(b, 1, c)
        param_bias = torch.cat([meta_param_bias.repeat(1,64,1), x], -1) # [bx1xc, bx64x64] -> [bx64xc, bx64x64] -> bx64x(c+64)
        param_bias = F.relu(bias_fc(param_bias)) # bx64x(c+64) -> bx64xc
        attention_scores = F.relu(attn_bias_fc(param_bias)).permute(0,2,1) # bx64xc -> bx64x1 -> bx1x64
        attention_scores = F.softmax(attention_scores, -1) # bx1x64 -> bx1x64
        param_bias = torch.bmm(attention_scores, param_bias).view(b, c) # [bx1x64, bx64xc] -> bx1xc -> bxc
        return param_bias

    def compute_biases(self, x, meta_params):
        """Run bias_attention for the five SIREN bias vectors.

        NOTE(review): the bias*_fc / attn_bias*_fc layers used here are only created inside the
        disabled ``if False:`` branch of __init__, so this method fails as the class is written.
        """
        x0_bias = self.bias_attention(x=x, meta_param_bias=meta_params['net.0.linear.bias'], bias_fc=self.bias0_fc, attn_bias_fc=self.attn_bias0_fc)
        x1_bias = self.bias_attention(x=x, meta_param_bias=meta_params['net.1.linear.bias'], bias_fc=self.bias1_fc, attn_bias_fc=self.attn_bias1_fc)
        x2_bias = self.bias_attention(x=x, meta_param_bias=meta_params['net.2.linear.bias'], bias_fc=self.bias2_fc, attn_bias_fc=self.attn_bias2_fc)
        x3_bias = self.bias_attention(x=x, meta_param_bias=meta_params['net.3.linear.bias'], bias_fc=self.bias3_fc, attn_bias_fc=self.attn_bias3_fc)
        x4_bias = self.bias_attention(x=x, meta_param_bias=meta_params['net.4.bias'], bias_fc=self.bias4_fc, attn_bias_fc=self.attn_bias4_fc)
        return x0_bias, x1_bias, x2_bias, x3_bias, x4_bias 

    def compute_loss(self, specialized_param, gt_specialized_param):
        """Sum of per-parameter MSE between predicted and target parameters, divided by 10."""
        loss = 0.
        for key in specialized_param:
            loss += F.mse_loss(input=specialized_param[key], target=gt_specialized_param[key])
        # Fixed divisor; the shape table above lists 10 parameter tensors.
        loss /= 10 # not sure if it will help
        return loss

    def forward(self, x, meta_params):
        """Predict one tensor per SIREN parameter for the image batch `x`.

        Args:
            x: channel-last image batch (permuted to channel-first inside forward_conv).
            meta_params: mapping from SIREN parameter name to the meta-learned parameter
                replicated over the batch; used as attention queries.

        Returns:
            OrderedDict with the same keys as the shape table above. Each value is the
            attention output (element [0] of the MultiheadAttention result) for that parameter.
        """
        x, x0, x1, x2, x3, x4 = self.forward_conv(x) 
        # x0_bias, x1_bias, x2_bias, x3_bias, x4_bias = self.compute_biases(x, meta_params)

        # x = self.forward_conv(x) # bx32x32x3 -> bx64x64
        b, l, c = x.shape  # only `b` is used below
        # query -> from meta model ==> meta_params
        # key and value -> from this model ==> x
        specialized_param = OrderedDict()

        # print('1>', meta_params['net.4.bias'].shape, x.shape)
        # x_in = self.fc0(x)
        # 3-channel projection of the image tokens, used as key/value for the output bias.
        x_out = F.relu(self.fc4_bias(x))
        # Weights attend over their per-layer projection; biases are reshaped to a single
        # query token, attend over the raw token features, and are flattened back to (b, -1).
        specialized_param['net.0.linear.weight']  = self.wt_cross_attn0(query=meta_params['net.0.linear.weight'], key=x0, value=x0)[0]
        specialized_param['net.0.linear.bias']  = self.bias_cross_attn0(query=meta_params['net.0.linear.bias'].view(b, 1, -1), key=x, value=x)[0].view(b, -1)
        specialized_param['net.1.linear.weight']  = self.wt_cross_attn1(query=meta_params['net.1.linear.weight'], key=x1, value=x1)[0]
        specialized_param['net.1.linear.bias']  = self.bias_cross_attn1(query=meta_params['net.1.linear.bias'].view(b, 1, -1), key=x, value=x)[0].view(b, -1)
        specialized_param['net.2.linear.weight']  = self.wt_cross_attn2(query=meta_params['net.2.linear.weight'], key=x2, value=x2)[0]
        specialized_param['net.2.linear.bias']  = self.bias_cross_attn2(query=meta_params['net.2.linear.bias'].view(b, 1, -1), key=x, value=x)[0].view(b, -1)
        specialized_param['net.3.linear.weight']  = self.wt_cross_attn3(query=meta_params['net.3.linear.weight'], key=x3, value=x3)[0]
        specialized_param['net.3.linear.bias']  = self.bias_cross_attn3(query=meta_params['net.3.linear.bias'].view(b, 1, -1), key=x, value=x)[0].view(b, -1)
        specialized_param['net.4.weight']  = self.wt_cross_attn4(query=meta_params['net.4.weight'], key=x4, value=x4)[0]
        specialized_param['net.4.bias']  = self.bias_cross_attn4(query=meta_params['net.4.bias'].view(b, 1, -1), key=x_out, value=x_out)[0].view(b, -1)

#         loss = self.compute_loss(specialized_param, gt_specialized_param)

#         return loss, specialized_param
        return specialized_param



Now, we implement MAML. The important parts of the code are commented, so it's easy to understand how each part works! Start by looking at the "forward" function.



In [24]:
import time

def l2_loss(prediction, gt):
    """Mean squared error between `prediction` and `gt`, averaged over all elements."""
    squared_error = (prediction - gt) ** 2
    return squared_error.mean()


class MAML(nn.Module):
    """Gradient-based meta-learner wrapped around a hypo-network and a hypernetwork.

    The inner loop (``generate_params``) specialises the meta-learned initialisation of
    ``hypo_module`` to a context set with ``num_meta_steps`` gradient steps. ``forward``
    additionally asks ``crossAttHypNet`` for a per-parameter offset from the context image and
    evaluates the hypo-network with "initialisation + offset".

    Args:
        num_meta_steps: number of inner-loop gradient steps.
        hypo_module: the module whose weights are meta-learned; must accept ``params=``.
        crossAttHypNet: hypernetwork called as ``crossAttHypNet(x=image_batch)``; expected to
            return a mapping of parameter tensors.
        loss: callable ``loss(prediction, target)``.
        init_lr: initial inner-loop learning rate.
        lr_type: one of 'static', 'global', 'per_step', 'per_parameter',
            'per_parameter_per_step'. Any other value is only rejected (NotImplementedError)
            at the first update step.
        first_order: if True, inner-loop gradients are computed without building a graph.
    """

    def __init__(self, num_meta_steps, hypo_module, crossAttHypNet, loss, init_lr,
                 lr_type='static', first_order=False):
        super().__init__()

        self.hypo_module = hypo_module # The module whose weights we want to meta-learn.
        self.crossAttHypNet = crossAttHypNet
        self.first_order = first_order
        self.loss = loss
        self.lr_type = lr_type
        self.log = []

        self.register_buffer('num_meta_steps', torch.Tensor([num_meta_steps]).int())

        if self.lr_type == 'static':
            # Fixed learning rate: a buffer, so it is saved with the model but not trained.
            self.register_buffer('lr', torch.Tensor([init_lr]))
        elif self.lr_type == 'global':
            self.lr = nn.Parameter(torch.Tensor([init_lr]))
        elif self.lr_type == 'per_step':
            self.lr = nn.ParameterList([nn.Parameter(torch.Tensor([init_lr]))
                                        for _ in range(num_meta_steps)])
        elif self.lr_type == 'per_parameter': # As proposed in "Meta-SGD".
            # NOTE(review): indexed in _update_step by position in the fast-parameter dict,
            # which is built from meta_named_parameters(); this assumes both iterate in the
            # same order - confirm.
            self.lr = nn.ParameterList([])
            hypo_parameters = hypo_module.parameters()
            for param in hypo_parameters:
                self.lr.append(nn.Parameter(torch.ones(param.size()) * init_lr))
        elif self.lr_type == 'per_parameter_per_step':
            self.lr = nn.ModuleList([])
            for name, param in hypo_module.meta_named_parameters():
                self.lr.append(nn.ParameterList([nn.Parameter(torch.ones(param.size()) * init_lr)
                                                 for _ in range(num_meta_steps)]))

        # Report the total number of trainable scalars (hypo-net, hypernet and learning rates).
        param_count = 0
        for param in self.parameters():
            param_count += np.prod(param.shape)

        print(param_count)

    def _step_size(self, param_index, step):
        """Return the learning rate for parameter number `param_index` at inner step `step`."""
        if self.lr_type in ['static', 'global']:
            return self.lr
        if self.lr_type in ['per_step']:
            return self.lr[step]
        if self.lr_type in ['per_parameter']:
            return self.lr[param_index]
        if self.lr_type in ['per_parameter_per_step']:
            return self.lr[param_index][step]
        raise NotImplementedError

    def _update_step(self, loss, param_dict, step):
        """Take one gradient-descent step on every entry of `param_dict`.

        Returns:
            (params, grads): an OrderedDict of updated tensors keyed like `param_dict`, and
            the tuple of gradients returned by ``torch.autograd.grad``.
        """
        # Keep the graph unless first-order MAML was requested, so that the outer loop can
        # differentiate through this update.
        grads = torch.autograd.grad(loss, tuple(param_dict.values()),
                                    create_graph=not self.first_order)
        params = OrderedDict()
        for i, ((name, param), grad) in enumerate(zip(param_dict.items(), grads)):
            params[name] = param - self._step_size(i, step) * grad

        return params, grads

    def forward_with_params(self, query_x, fast_params, **kwargs):
        """Evaluate the hypo-network at `query_x` using the given specialised parameters."""
        output = self.hypo_module(query_x, params=fast_params)
        return output

    def generate_params(self, context_dict):
        """Specializes the model to the context set.

        Returns:
            (fast_params, intermed_predictions, meta_params): the adapted parameters, the
            hypo-network predictions made before each inner step, and the un-adapted
            initialisation replicated over the meta-batch.
        """
        x = context_dict.get('x').cuda()
        y = context_dict.get('y').cuda()

        meta_batch_size = x.shape[0]

        with torch.enable_grad():
            # First, replicate the initialization for each batch item.
            # This is the learned initialization, i.e., in the outer loop,
            # the gradients are backpropagated all the way into the
            # "meta_named_parameters" of the hypo_module.
            fast_params = OrderedDict()
            meta_params = OrderedDict()
            for name, param in self.hypo_module.meta_named_parameters():
                replicated = param[None, ...].repeat((meta_batch_size,) + (1,) * len(param.shape))
                # Both dicts may share the tensor: updates create new tensors, never modify it.
                fast_params[name] = replicated
                meta_params[name] = replicated

            prev_loss = 1e6
            intermed_predictions = []
            for j in range(self.num_meta_steps):
                # Using the current set of parameters, perform a forward pass with the context inputs.
                predictions = self.hypo_module(x, params=fast_params)

                # Compute the loss on the context labels.
                loss = self.loss(predictions, y)
                intermed_predictions.append(predictions)

                if loss > prev_loss:
                    print('inner lr too high?')

                # Using the computed loss, update the fast parameters.
                fast_params, grads = self._update_step(loss, fast_params, j)
                prev_loss = loss

        return fast_params, intermed_predictions, meta_params

    def forward(self, meta_batch, **kwargs):
        """Specialise the model with MAML and with the hypernetwork, then evaluate both.

        Args:
            meta_batch: dict with a 'context' dict ('x', 'y') and a 'query' dict ('x').

        Returns:
            (out_dict, fast_params, meta_params, pred_specialized_param_corrected), where
            out_dict holds 'model_out', 'intermed_predictions', 'crossAttHypNet_loss' and
            'model_output_hypernet'.
        """
        # The meta_batch consists of the "context" set (the observations we're conditioning on)
        # and the "query" inputs (the points where we want to evaluate the specialized model)
        context = meta_batch['context']
        query_x = meta_batch['query']['x'].cuda()
        # Move the targets once; previously the loss below received them on their original
        # device, which fails for a CPU batch because the predictions live on the GPU.
        context_y = context['y'].cuda()

        # Specialize the model with the "generate_params" function.
        fast_params, intermed_predictions, meta_params = self.generate_params(context)

        # The hypernetwork sees the context targets reshaped into an image.
        pred_specialized_param = self.crossAttHypNet(x=lin2img(context_y))

        # Hypernetwork outputs are offsets added to the meta-learned initialisation. The two
        # dicts are paired by position and the result is keyed like fast_params.
        pred_specialized_param_corrected = OrderedDict()
        for name1, name2 in zip(pred_specialized_param.keys(), fast_params.keys()):
            pred_specialized_param_corrected[name2] = meta_params[name2] + pred_specialized_param[name1]

        # Compute the final outputs.
        model_output = self.hypo_module(query_x, params=fast_params)
        model_output_hypernet = self.hypo_module(query_x, params=pred_specialized_param_corrected)
        # NOTE(review): predictions at the query inputs are compared with the context targets;
        # this assumes context and query hold the same points - confirm for other datasets.
        crossAttHypNet_loss = self.loss(model_output_hypernet, context_y)
        out_dict = {'model_out':model_output, 'intermed_predictions':intermed_predictions,
                    'crossAttHypNet_loss':crossAttHypNet_loss,
                    'model_output_hypernet': model_output_hypernet}
        return out_dict, fast_params, meta_params, pred_specialized_param_corrected

<a id='section_1'></a>
## Learning to fit images in 3 gradient descent steps

By learning an initialization for SIREN, we may fit any image in as few as 3 gradient descent steps! 
This has also been noted by Tancik et al. in "Learned Initializations for Optimizing Coordinate-Based Neural Representations" (2020).

We'll demonstrate here with CIFAR-10, but it works just as well with CelebA or ImageNet - try it out yourself!

In [25]:
class CIFAR10():
    """CIFAR-10 images served as (pixel coordinate, colour) sets for meta-learning.

    Each item is a dict with a 'context' and a 'query' entry, both holding the same coordinate
    grid ('x') and the same flattened, normalised image ('y').
    """

    def __init__(self, train=True):
        # ToTensor followed by per-channel normalisation with mean 0.5 and std 0.5.
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        pipeline = transforms.Compose([transforms.ToTensor(), normalize])

        self.dataset = torchvision.datasets.CIFAR10(
            root='./data', train=train, download=True, transform=pipeline)

        self.length = len(self.dataset)
        # The coordinate grid is identical for every image, so it is built once.
        self.meshgrid = get_mgrid(sidelen=32)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        image, _label = self.dataset[item]
        # Channel-first image -> one row of 3 colour values per pixel.
        pixels = image.permute(1, 2, 0).view(-1, 3)
        return {split: {'x': self.meshgrid, 'y': pixels}
                for split in ('context', 'query')}


def lin2img(tensor):
    """Reshape (batch, num_samples, channels) into a square (batch, side, side, channels) image.

    `side` is the truncated square root of `num_samples`; a trailing channel axis of size 1 is
    squeezed away.
    """
    batch_size, num_samples, channels = tensor.shape
    sidelen = int(np.sqrt(num_samples))
    image = tensor.view(batch_size, sidelen, sidelen, channels)
    return image.squeeze(-1)

    
def plot_sample_image(img_batch, ax):
    """Convert the first image of a flattened batch into a displayable array.

    Args:
        img_batch: tensor accepted by ``lin2img``, with values nominally in [-1, 1].
        ax: unused; kept for interface compatibility (the drawing calls are disabled).

    Returns:
        NumPy array of the first image rescaled from [-1, 1] to [0, 1] and clipped to [0, 1].
    """
    img = lin2img(img_batch)[0].detach().cpu().numpy()
    # For a CPU tensor, .numpy() shares memory with `img_batch`. Rescale out of place so the
    # caller's tensor is not modified (the previous in-place += and /= wrote through to it).
    img = (img + 1) / 2.
    img = np.clip(img, 0., 1.)
#     ax.set_axis_off()
#     ax.imshow(img)
    return img


def dict_to_gpu(ob):
    """Recursively call ``.cuda()`` on every leaf of a (possibly nested) mapping.

    Mappings are rebuilt as plain dicts with the same keys; any other object is returned as
    ``ob.cuda()``.
    """
    if not isinstance(ob, Mapping):
        return ob.cuda()

    moved = {}
    for key, value in ob.items():
        moved[key] = dict_to_gpu(value)
    return moved


# def dict_to_gpu(ob):
#     if isinstance(ob, Mapping):
#         return {k: dict_to_gpu(v) for k, v in ob.items()}
#     else:
#         return ob.cuda()    

Now, let's initialize our models and our dataset:

In [26]:
%load_ext autoreload
%autoreload 2

import meta_modules
# Hypo-network: SIREN mapping 2D pixel coordinates to RGB, restored from a checkpoint.
img_siren = Siren(in_features=2, hidden_features=256, hidden_layers=3, out_features=3, outermost_linear=True)
img_siren.load_state_dict(torch.load('img_siren_1.pth'))
# crossAttHypNet = CrossAttentionHyperNet().cuda()
# Hypernetwork from the project-local `meta_modules` package, restored from a checkpoint.
crossAttHypNet = meta_modules.ConvolutionalNeuralProcessImplicit2DHypernet(in_features=3,
                                                                    out_features=3,
                                                                    image_resolution=(32, 32))
crossAttHypNet.load_state_dict(torch.load('crossAttHypNet_1.pth'))
# MAML wrapper: 3 inner steps, one learned learning rate per parameter and per step.
meta_siren = MAML(num_meta_steps=3, hypo_module=img_siren.cuda(), crossAttHypNet=crossAttHypNet.cuda(), 
                  loss=l2_loss, init_lr=1e-5, 
                  lr_type='per_parameter_per_step').cuda()
meta_siren.load_state_dict(torch.load('meta_siren_1.pth'))
meta_siren = meta_siren.cuda()
# Throw away the hypernetwork weights that were just restored and attach a freshly initialised
# hypernetwork instead; the SIREN initialisation and learning rates loaded above are kept.
if True:
    del crossAttHypNet
    torch.cuda.empty_cache()
    crossAttHypNet = meta_modules.ConvolutionalNeuralProcessImplicit2DHypernet(in_features=3,
                                                                        out_features=3,
                                                                        image_resolution=(32, 32))
    meta_siren.crossAttHypNet = crossAttHypNet.cuda()
meta_siren.train()

# CIFAR-10 training split, shuffled, in meta-batches of 16 images.
dataset = CIFAR10()
dataloader = DataLoader(dataset, batch_size=16, num_workers=0, shuffle=True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
SingleBVPNet(
  (image_downsampling): ImageDownsampling()
  (net): FCBlock(
    (net): MetaSequential(
      (0): MetaSequential(
        (0): BatchLinear(in_features=2, out_features=256, bias=True)
        (1): Sine()
      )
      (1): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (2): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (3): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (4): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=3, bias=True)
      )
    )
  )
)
ConvolutionalNeuralProcessImplicit2DHypernet(
  (encoder): ConvImgEncoder(
    (conv_theta): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu): ReLU(inplace=True)
 

66906387
SingleBVPNet(
  (image_downsampling): ImageDownsampling()
  (net): FCBlock(
    (net): MetaSequential(
      (0): MetaSequential(
        (0): BatchLinear(in_features=2, out_features=256, bias=True)
        (1): Sine()
      )
      (1): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (2): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (3): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (4): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=3, bias=True)
      )
    )
  )
)
ConvolutionalNeuralProcessImplicit2DHypernet(
  (encoder): ConvImgEncoder(
    (conv_theta): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu): ReLU(inplace=True)
    (cnn): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1

Files already downloaded and verified


In [27]:
# crossAttHypNet

In [28]:
import cv2
# Quick check of OpenCV's PSNR on two bitmap files read from the working directory.
# NOTE(review): the name `psnr` bound here is re-bound by `from metrics import psnr` in the
# next cell.
img1 = cv2.imread('img1.bmp')
img2 = cv2.imread('img2.bmp')
psnr = cv2.PSNR(img1, img2)
psnr  # displayed as the cell output

nan

In [29]:
# !pip install piqa
from metrics import psnr, ssim_metric
from skimage.metrics import peak_signal_noise_ratio as psnr_sklearn

# print('PSNR:', psnr.psnr(x, y))
# print('SSIM:', ssim.ssim(x, y))

Let's train!

In [12]:
# Meta-training cell: optimise every parameter of `meta_siren` with Adam and
# log PSNR/SSIM of the hypernetwork reconstruction at every step.
steps_til_summary = 1000
num_epochs = 10  # used by both the training loop and the per-epoch report below

# optim = torch.optim.Adam(lr=5e-5, params=meta_siren.parameters())
optim = torch.optim.Adam(lr=5e-6, params=meta_siren.parameters())

hypernet_loss_multiplier = 1
# Experiment toggle (off): additionally regress meta_params + predicted offsets
# onto the detached fast parameters.
use_param_matching_loss = False

psnr_list = []
ssim_list = []
for epoch in range(num_epochs):
#     if epoch < 10:
#         hypernet_loss_multiplier += 100

    for step, sample in enumerate(dataloader):
        sample = dict_to_gpu(sample)
        # meta_siren returns (out_dict, fast_params, meta_params, pred_specialized_param);
        # out_dict carries 'model_out', 'intermed_predictions',
        # 'crossAttHypNet_loss' and 'model_output_hypernet'.
        model_output, fast_params, meta_params, pred_specialized_param = meta_siren(sample)
        loss = ((model_output['model_out'] - sample['query']['y'])**2).mean() + hypernet_loss_multiplier * model_output['crossAttHypNet_loss']

#         for name in pred_specialized_param:
#             loss += 10*((pred_specialized_param[name] - fast_params[name].detach()) ** 2).mean()
        if use_param_matching_loss:
            pred_specialized_param_corrected = OrderedDict()
            # The two dicts are paired positionally, not by key name.
            for name1, name2 in zip(pred_specialized_param.keys(), fast_params.keys()):
                pred_specialized_param_corrected[name2] = meta_params[name2] + pred_specialized_param[name1]
#                 pred_specialized_param_corrected[name2] = pred_specialized_param[name1]
                loss += 1*((pred_specialized_param_corrected[name2] - fast_params[name2].detach()) ** 2).mean()

        if step % steps_til_summary == 0:
            # NOTE: the "Total loss" printed here does not yet include the SSIM
            # term that is added to `loss` further down.
            print("Epoch %d, Step %d,\tTotal loss: %0.6f,\tHypernet loss: %0.6f" % (epoch, step, loss, model_output['crossAttHypNet_loss']))
            # Nothing has been logged at the very first summary; np.mean([])
            # would emit RuntimeWarnings, so report NaN explicitly instead.
            psnr_so_far = np.mean(psnr_list) if psnr_list else float('nan')
            ssim_so_far = np.mean(ssim_list) if ssim_list else float('nan')
            print('\tPSNR:', psnr_so_far, '\tSSIM:', ssim_so_far)
            # Real matplotlib axes are disabled; plain indices are passed as `ax`.
            fig, axes = [], list(range(6))#plt.subplots(1,6, figsize=(36,6))
            ax_titles = ['Learned Initialization', 'Inner step 1 output', 
                        'Inner step 2 output', 'Inner step 3 output', 
                        'HyperNet output', ## added by me
                        'Ground Truth']
            images = [plot_sample_image(inner_step_out, ax=axes[i])
                      for i, inner_step_out in enumerate(model_output['intermed_predictions'])]
            images.append(plot_sample_image(model_output['model_out'], ax=axes[-3]))
            images.append(plot_sample_image(model_output['model_output_hypernet'], ax=axes[-2]))

            img_ground_truth = plot_sample_image(sample['query']['y'], ax=axes[-1])
            # PSNR of every rendered image against the rendered ground truth.
            psnrs = [cv2.PSNR(img_ground_truth, img) for img in images]
            print(psnrs)
            plt.show()

        # Image-space metrics on the hypernetwork output. Both tensors are
        # mapped with (v + 1) / 2 and clipped to [0, 1]; the arithmetic is
        # out-of-place so neither the batch nor the model output is mutated.
        x = lin2img(model_output['model_output_hypernet']).permute(0,3,1,2).contiguous()
        x = torch.clip((x + 1.) / 2., 0., 1.)
        y = lin2img(sample['query']['y']).permute(0,3,1,2).contiguous()
        y = torch.clip((y + 1.) / 2., 0., 1.)

        psnr_temp = psnr(y, x)
        ssim_temp = ssim_metric(y, x)
        psnr_list.extend(psnr_temp.detach().cpu().tolist())
        ssim_list.extend(ssim_temp.detach().cpu().tolist())
#         loss += 1 * (torch.exp(-1. * psnr_temp)).mean()
        # SSIM is kept attached to the graph so (1 - SSIM) acts as a loss term.
        loss = loss + 1 * (1. - ssim_temp).mean()

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Drop references to this step's outputs and release cached GPU memory.
        del model_output, fast_params, meta_params, pred_specialized_param
        gc.collect()
        torch.cuda.empty_cache()

print('PSNR:', np.mean(psnr_list))
print('SSIM:', np.mean(ssim_list))

# Per-epoch averages: the logged values are split into num_epochs equal chunks.
l = len(psnr_list) // num_epochs
print(len(psnr_list), l)
for i in range(num_epochs):
    print(f'\nepoch:{i}\nPSNR:', np.mean(psnr_list[l*i:l*(i+1)]))
    print('SSIM:', np.mean(ssim_list[l*i:l*(i+1)]))

Epoch 0, Step 0,	Total loss: 0.459648,	Hypernet loss: 0.459164
	PSNR: nan 	SSIM: nan
[61.970139418908076, 79.45781866874914, 83.65020168564416, 87.23265898416778, 57.74506572405127]


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 0, Step 1000,	Total loss: 0.080177,	Hypernet loss: 0.079696
	PSNR: 15.634923721522092 	SSIM: 0.46937411817233077
[60.61849277678819, 77.01214189453661, 82.43402461441967, 87.7616340670144, 64.78318890138061]
Epoch 0, Step 2000,	Total loss: 0.058338,	Hypernet loss: 0.057823
	PSNR: 16.820098195508123 	SSIM: 0.533869205305702
[65.81378262269466, 82.40283570054422, 86.52193422411605, 89.89350623160762, 73.36657019560946]
Epoch 0, Step 3000,	Total loss: 0.065868,	Hypernet loss: 0.065047
	PSNR: 17.483636867652336 	SSIM: 0.5716190961322282
[59.73142375891341, 76.45133269791023, 81.57609823712068, 86.34791100869846, 66.12086240606239]
Epoch 1, Step 0,	Total loss: 0.061886,	Hypernet loss: 0.061123
	PSNR: 17.547758929891586 	SSIM: 0.5754110272672027
[57.53321154439581, 76.95803386871577, 83.41530550078932, 89.28270045197345, 66.40715829087782]
Epoch 1, Step 1000,	Total loss: 0.042106,	Hypernet loss: 0.041530
	PSNR: 17.988852461431964 	SSIM: 0.6016455089130975
[62.60576955073834, 77.3470570

PSNR: 21.418693832255364
SSIM: 0.7793747039702162
500000 50000

epoch:0
PSNR: 17.547758929891586
SSIM: 0.5754110272672027

epoch:1
PSNR: 19.74231793706894
SSIM: 0.7037630205029249

epoch:2
PSNR: 20.656207956695557
SSIM: 0.7510890162551404

epoch:3
PSNR: 21.286576559066773
SSIM: 0.781410020403862

epoch:4
PSNR: 21.7465523620224
SSIM: 0.8021818370044231

epoch:5
PSNR: 22.119077660236357
SSIM: 0.8171990915524959

epoch:6
PSNR: 22.416906337623598
SSIM: 0.8285210457319021

epoch:7
PSNR: 22.67175439588547
SSIM: 0.837555394500494

epoch:8
PSNR: 22.90044386089325
SSIM: 0.8451297107315063

epoch:9
PSNR: 23.099342323169708
SSIM: 0.8514868757522106


In [13]:
import json

# Persist the weights of the three modules.
# NOTE: the evaluation cell further down loads the `*_0.pth` files, not the
# `*_1.pth` files written here; adjust one side if these are meant to be used.
torch.save(img_siren.state_dict(), 'img_siren_ssim_loss_1.pth')
torch.save(crossAttHypNet.state_dict(), 'crossAttHypNet_ssim_loss_1.pth')
torch.save(meta_siren.state_dict(), 'meta_siren_ssim_ssim_1.pth')
# model.load_state_dict(torch.load(PATH))

# Dump the logged metrics; the context manager guarantees the file is flushed
# and closed (the bare open() previously left the handle dangling).
with open('psnr_ssim_list_hypernet+ssim_loss.json', 'w') as metrics_file:
    json.dump({'psnr_list': psnr_list, 'ssim_list': ssim_list}, metrics_file)

In [14]:
# Display the 20 most recently appended SSIM values.
ssim_list[-20:]


[0.9344650506973267,
 0.9029501676559448,
 0.907914400100708,
 0.748037576675415,
 0.9377844333648682,
 0.8202025890350342,
 0.8790472745895386,
 0.8672150373458862,
 0.8413331508636475,
 0.8708105087280273,
 0.9003541469573975,
 0.8264485597610474,
 0.8508522510528564,
 0.75550377368927,
 0.8427578210830688,
 0.8951382637023926,
 0.868256688117981,
 0.8538165092468262,
 0.8058369159698486,
 0.8353908061981201]

In [15]:
# Per-epoch averages of the logged metrics: the lists are split into
# `num_epochs_logged` equal, consecutive chunks (any remainder is ignored).
num_epochs_logged = 10  # must match the number of epochs of the training loop
l = len(psnr_list) // num_epochs_logged
print(len(psnr_list), l)
for i in range(num_epochs_logged):
    psnr_chunk = psnr_list[l*i:l*(i+1)]
    ssim_chunk = ssim_list[l*i:l*(i+1)]
    # An empty chunk (fewer logged values than epochs) is reported as NaN
    # without triggering numpy's "mean of empty slice" warnings.
    print(f'\nepoch:{i}\nPSNR:', np.mean(psnr_chunk) if psnr_chunk else float('nan'))
    print('SSIM:', np.mean(ssim_chunk) if ssim_chunk else float('nan'))

500000 50000

epoch:0
PSNR: 17.547758929891586
SSIM: 0.5754110272672027

epoch:1
PSNR: 19.74231793706894
SSIM: 0.7037630205029249

epoch:2
PSNR: 20.656207956695557
SSIM: 0.7510890162551404

epoch:3
PSNR: 21.286576559066773
SSIM: 0.781410020403862

epoch:4
PSNR: 21.7465523620224
SSIM: 0.8021818370044231

epoch:5
PSNR: 22.119077660236357
SSIM: 0.8171990915524959

epoch:6
PSNR: 22.416906337623598
SSIM: 0.8285210457319021

epoch:7
PSNR: 22.67175439588547
SSIM: 0.837555394500494

epoch:8
PSNR: 22.90044386089325
SSIM: 0.8451297107315063

epoch:9
PSNR: 23.099342323169708
SSIM: 0.8514868757522106


In [1]:
# # fast_params['net.0.linear.weight'].requires_grad
# # fast_params.keys()
# # fast_params['net.1.linear.weight'].shape 
# # for key in fast_params:
# #     print(fast_params[key].shape)
# for key in meta_params:
#     print(key, ':\t', meta_params[key].shape, '\t', fast_params[key].shape)


As you can see, after a few hundred steps of training, we can fit any of the CIFAR-10 images in only three gradient descent steps!

In [31]:
import pickle
steps_til_summary = 100

%load_ext autoreload
%autoreload 2

import meta_modules
img_siren = Siren(in_features=2, hidden_features=256, hidden_layers=3, out_features=3, outermost_linear=True)
img_siren.load_state_dict(torch.load('img_siren_ssim_loss_0.pth'))
# crossAttHypNet = CrossAttentionHyperNet().cuda()
crossAttHypNet = meta_modules.ConvolutionalNeuralProcessImplicit2DHypernet(in_features=3,
                                                                    out_features=3,
                                                                    image_resolution=(32, 32))
crossAttHypNet.load_state_dict(torch.load('crossAttHypNet_ssim_loss_0.pth'))
meta_siren = MAML(num_meta_steps=3, hypo_module=img_siren.cuda(), crossAttHypNet=crossAttHypNet.cuda(), 
                  loss=l2_loss, init_lr=1e-5, 
                  lr_type='per_parameter_per_step').cuda()
meta_siren.load_state_dict(torch.load('meta_siren_ssim_ssim_0.pth'))
meta_siren = meta_siren.cuda()
meta_siren.eval()

dataset = CIFAR10(train=False)
test_dataloader = DataLoader(dataset, batch_size=8, num_workers=0, shuffle=False)


# optim = torch.optim.Adam(lr=5e-5, params=meta_siren.parameters())
# optim = torch.optim.Adam(lr=5e-6, params=meta_siren.parameters())

hypernet_loss_multiplier = 1

psnr_list = []
ssim_list = []
for epoch in range(1):
#     if epoch < 10:
#         hypernet_loss_multiplier += 100

    for step, sample in enumerate(test_dataloader):
        sample = dict_to_gpu(sample)
        '''
        out_dict = {'model_out':model_output, 'intermed_predictions':intermed_predictions, 
                    'crossAttHypNet_loss':crossAttHypNet_loss, 
                    'model_output_hypernet': model_output_hypernet}

        return out_dict, fast_params, meta_params, pred_specialized_param
        '''
        t0 = time.time()
        model_output, fast_params, meta_params, pred_specialized_param = meta_siren(sample)    
        loss = ((model_output['model_out'] - sample['query']['y'])**2).mean() + hypernet_loss_multiplier * model_output['crossAttHypNet_loss']
        print(time.time() - t0)
#         for name in pred_specialized_param:
#             loss += 10*((pred_specialized_param[name] - fast_params[name].detach()) ** 2).mean()
        if False:
            pred_specialized_param_corrected = OrderedDict()
            l1, l2 = pred_specialized_param.keys(), fast_params.keys()
            for (name1, name2) in list(zip(l1, l2)):
                pred_specialized_param_corrected[name2] = meta_params[name2] + pred_specialized_param[name1]
#                 pred_specialized_param_corrected[name2] = pred_specialized_param[name1]
                loss += 1*((pred_specialized_param_corrected[name2] - fast_params[name2].detach()) ** 2).mean()
       
        
        if (step % steps_til_summary == 0) and (epoch % 1 == 0): 
            print("Epoch %d, Step %d,\tTotal loss: %0.6f,\tHypernet loss: %0.6f" % (epoch, step, loss, model_output['crossAttHypNet_loss']))
            print('\tPSNR:', np.mean(psnr_list), '\tSSIM:', np.mean(ssim_list))
            fig, axes = [], list(range(6))#plt.subplots(1,6, figsize=(36,6))
            ax_titles = ['Learned Initialization', 'Inner step 1 output', 
                        'Inner step 2 output', 'Inner step 3 output', 
                        'HyperNet output', ## added by me
                        'Ground Truth']
            images = []
            for i, inner_step_out in enumerate(model_output['intermed_predictions']):
                img = plot_sample_image(inner_step_out, ax=axes[i])
                images += [img]
#                 axes[i].set_title(ax_titles[i], fontsize=25)
            images += [plot_sample_image(model_output['model_out'], ax=axes[-3])]
#             axes[-3].set_title(ax_titles[-3], fontsize=25)

            if True:
                images += [plot_sample_image(model_output['model_output_hypernet'], ax=axes[-2])]
#                 axes[-2].set_title(ax_titles[-2], fontsize=25)

            img_ground_truth = plot_sample_image(sample['query']['y'], ax=axes[-1])
#             axes[-1].set_title(ax_titles[-1], fontsize=25)
            psnrs = [cv2.PSNR(img_ground_truth, img) for img in images]
#             for img_idx, img in enumerate(images):
#                 cv2.imwrite(f'./images/img_{step}_{ax_titles[img_idx]}.jpg', (img*255.).astype(int))
#                 cv2.imshow(f'./images/img_{step}_{ax_titles[img_idx]}.jpg', img)
#                 cv2.waitKey()
            
            images += [img_ground_truth]
            pickle.dump({step: images}, open(f'./images/img_{step}', 'wb'))
            print(psnrs)
            plt.show()
#             break
            
#         optim.zero_grad()
#         loss.backward()
#         optim.step()
        if True:
            x = lin2img(model_output['model_output_hypernet']).permute(0,3,1,2).contiguous().cpu().detach()
            x += 1.
            x /= 2.
            x = torch.clip(x, 0., 1.)
            y = lin2img(sample['query']['y']).permute(0,3,1,2).contiguous().cpu().detach()
            y += 1.
            y /= 2.
            y = torch.clip(y, 0., 1.)
    #         print(x.shape, lin2img(x).shape)
#             print(y.min(), y.max())
#             print(x.min(), x.max())

#             print('PSNR:', psnr(y, x).mean())
#             print('psnr_sklearn:', psnr_sklearn(y.numpy(), x.numpy(), data_range=1.))
#             print('SSIM:', ssim_metric(y, x))
            psnr_list += psnr(y, x).cpu().detach().numpy().tolist()
            ssim_list += ssim_metric(y, x).cpu().detach().numpy().tolist()
        del model_output, fast_params, meta_params, pred_specialized_param
        gc.collect()
        torch.cuda.empty_cache() 
        
print('PSNR:', np.mean(psnr_list))
print('SSIM:', np.mean(ssim_list))

l = len(psnr_list) // 1
print(len(psnr_list), l)
for i in range(1):
    print(f'\nepoch:{i}\nPSNR:', np.mean(psnr_list[l*i:l*(i+1)]))
    print('SSIM:', np.mean(ssim_list[l*i:l*(i+1)]))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
SingleBVPNet(
  (image_downsampling): ImageDownsampling()
  (net): FCBlock(
    (net): MetaSequential(
      (0): MetaSequential(
        (0): BatchLinear(in_features=2, out_features=256, bias=True)
        (1): Sine()
      )
      (1): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (2): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (3): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=256, bias=True)
        (1): Sine()
      )
      (4): MetaSequential(
        (0): BatchLinear(in_features=256, out_features=3, bias=True)
      )
    )
  )
)
ConvolutionalNeuralProcessImplicit2DHypernet(
  (encoder): ConvImgEncoder(
    (conv_theta): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu): ReLU(inplace=True)
 

66906387
Files already downloaded and verified
meta: 0.012998580932617188  0.01599860191345215, hyper: 0.006999969482421875 0.009999990463256836, all: 0.022998571395874023
0.023998737335205078
Epoch 0, Step 0,	Total loss: 0.381722,	Hypernet loss: 0.025654
	PSNR: nan 	SSIM: nan
[60.92382756007669, 64.9129061220766, 66.835854163684, 59.38227988700565, 68.22732945132158]
meta: 0.013069629669189453  0.015069961547851562, hyper: 0.006064653396606445 0.008064985275268555, all: 0.021134614944458008
0.02213430404663086
meta: 0.012187480926513672  0.013185739517211914, hyper: 0.007002115249633789 0.008000373840332031, all: 0.020187854766845703
0.021189451217651367
meta: 0.011001348495483398  0.013000011444091797, hyper: 0.006997346878051758 0.008996009826660156, all: 0.019997358322143555
0.020998001098632812
meta: 0.012048721313476562  0.014049768447875977, hyper: 0.005075216293334961 0.007076263427734375, all: 0.019124984741210938
0.020125150680541992
meta: 0.011075735092163086  0.012075662612

meta: 0.011101722717285156  0.012101411819458008, hyper: 0.006001949310302734 0.007001638412475586, all: 0.018103361129760742
0.0191037654876709
meta: 0.011313438415527344  0.012313127517700195, hyper: 0.00599980354309082 0.006999492645263672, all: 0.018312931060791016
0.019313812255859375
meta: 0.011151552200317383  0.013150453567504883, hyper: 0.006002187728881836 0.008001089096069336, all: 0.01915264129638672
0.02015209197998047
meta: 0.01110982894897461  0.013170957565307617, hyper: 0.006000995635986328 0.008062124252319336, all: 0.019171953201293945
0.020175933837890625
meta: 0.01102900505065918  0.013030290603637695, hyper: 0.006047964096069336 0.008049249649047852, all: 0.01907825469970703
0.020074129104614258
meta: 0.011067628860473633  0.01312112808227539, hyper: 0.005001068115234375 0.007054567337036133, all: 0.018122196197509766
0.019122600555419922
meta: 0.010067462921142578  0.011065483093261719, hyper: 0.007002592086791992 0.008000612258911133, all: 0.01806807518005371
0.

0.01812291145324707
meta: 0.011090993881225586  0.01309061050415039, hyper: 0.006000518798828125 0.00800013542175293, all: 0.019091129302978516
0.019091129302978516
meta: 0.011063575744628906  0.01206350326538086, hyper: 0.0060405731201171875 0.007040500640869141, all: 0.018104076385498047
0.0191042423248291
meta: 0.011108160018920898  0.012108564376831055, hyper: 0.006045341491699219 0.007045745849609375, all: 0.018153905868530273
0.01915287971496582
meta: 0.011105775833129883  0.013105630874633789, hyper: 0.006063699722290039 0.008063554763793945, all: 0.019169330596923828
0.02017068862915039
meta: 0.011082887649536133  0.014082670211791992, hyper: 0.006056547164916992 0.009056329727172852, all: 0.020139217376708984
0.021138429641723633
meta: 0.01132655143737793  0.013325691223144531, hyper: 0.006036996841430664 0.008036136627197266, all: 0.019362688064575195
0.02036428451538086
meta: 0.011066675186157227  0.013066768646240234, hyper: 0.0050008296966552734 0.007000923156738281, all: 

0.02008819580078125
meta: 0.011064529418945312  0.012064456939697266, hyper: 0.0060422420501708984 0.0070421695709228516, all: 0.018106698989868164
0.01910686492919922
meta: 0.01110386848449707  0.013105630874633789, hyper: 0.0050506591796875 0.007052421569824219, all: 0.01815629005432129
0.019156455993652344
meta: 0.011081695556640625  0.012083053588867188, hyper: 0.006070852279663086 0.0070722103118896484, all: 0.018153905868530273
0.020154237747192383
meta: 0.011092901229858398  0.012104034423828125, hyper: 0.005300045013427734 0.006311178207397461, all: 0.01740407943725586
0.018404006958007812
meta: 0.011076927185058594  0.012102842330932617, hyper: 0.006121397018432617 0.007147312164306641, all: 0.018224239349365234
0.020224571228027344
meta: 0.010998725891113281  0.011998653411865234, hyper: 0.006000995635986328 0.007000923156738281, all: 0.017999649047851562
0.018999814987182617
meta: 0.011088132858276367  0.01308894157409668, hyper: 0.004998922348022461 0.0069997310638427734, a

0.019114255905151367
meta: 0.012283563613891602  0.014284610748291016, hyper: 0.005049943923950195 0.007050991058349609, all: 0.01933455467224121
0.020333051681518555
meta: 0.010999441146850586  0.012999296188354492, hyper: 0.0059986114501953125 0.007998466491699219, all: 0.018997907638549805
0.018997907638549805
meta: 0.011095762252807617  0.01309514045715332, hyper: 0.0050008296966552734 0.0070002079010009766, all: 0.018095970153808594
0.018095970153808594
meta: 0.01100015640258789  0.012998580932617188, hyper: 0.006001710891723633 0.00800013542175293, all: 0.01900029182434082
0.020000219345092773
meta: 0.011008262634277344  0.01300811767578125, hyper: 0.005991458892822266 0.007991313934326172, all: 0.018999576568603516
0.018999576568603516
meta: 0.011000633239746094  0.012000322341918945, hyper: 0.005999565124511719 0.00699925422668457, all: 0.017999887466430664
0.018999576568603516
meta: 0.011574745178222656  0.013631582260131836, hyper: 0.005000114440917969 0.0070569515228271484, 

KeyboardInterrupt: 