In [2]:
import math
import copy
import scipy.special
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.insert(0, "/home/rob/Git/meta-fsl-nas/metanas")

import metanas.utils.genotypes as gt

In [459]:
class Model:
    """Model for a single DARTS cell.

    Every intermediate node ``i`` (graph vertex ``i + 2``) owns a tensor of
    architecture weights ``alphas[i]`` with one row per incoming edge and one
    column per operation. ``norm_alphas[i]`` holds the row-wise softmax of
    ``alphas[i]``.

    Args:
        n_ops: Number of candidate operations per edge (columns of each
            alpha tensor).
        n_nodes: Number of intermediate nodes; the graph has ``n_nodes + 2``
            vertices because vertices 0 and 1 are the two input nodes.
    """

    # Added to every probability of an edge before taking the log, so that
    # no entry is 0 (log(0) = -inf).
    _PROB_FLOOR = 0.01
    # Largest probability an operation can be raised to by ``increase_op``
    # when the requested increase would reach or exceed 1.
    _PROB_CEIL = 0.99
    # Additive constant used when mapping probabilities back to alphas.
    # It cancels out in the softmax.
    _LOG_OFFSET = math.log(10.)

    def __init__(self, n_ops=7, n_nodes=3):
        # Honour the constructor arguments so that the adjacency matrix, the
        # alpha tensors and the stored attributes always agree.
        self.n_ops = n_ops
        self.n_nodes = n_nodes

        self.encoded_states = []
        self.states = []
        self.topk = []

        self.alphas = []
        self.norm_alphas = []

        # Adjacency matrix: fully connected without self-loops ...
        n_vertices = self.n_nodes + 2
        self.A = np.ones((n_vertices, n_vertices)) - np.eye(n_vertices)

        # ... minus the edge between the two input nodes (vertices 0 and 1).
        self.A[0, 1] = 0
        self.A[1, 0] = 0

        for i in range(self.n_nodes):
            a = nn.Parameter(1e-3 * torch.randn(i + 2, self.n_ops))
            self.alphas.append(a)
            self.norm_alphas.append(F.softmax(a, dim=-1))

        # Make the edge lookups available right away, so increase_op /
        # decrease_op work without a prior calculate_states() call.
        self._build_edge_maps()

    def _build_edge_maps(self):
        """Rebuild ``edge_to_index`` and ``edge_to_alpha`` from the alpha shapes."""
        self.edge_to_index = {}
        self.edge_to_alpha = {}

        s_idx = 0
        for i, edges in enumerate(self.norm_alphas):
            node = i + 2
            for j in range(edges.shape[0]):
                # Both orientations of an edge get their own state index ...
                self.edge_to_index[(j, node)] = s_idx
                self.edge_to_index[(node, j)] = s_idx + 1

                # ... but point at the same (alpha tensor, row) position.
                self.edge_to_alpha[(j, node)] = (i, j)
                self.edge_to_alpha[(node, j)] = (i, j)
                s_idx += 2

    @staticmethod
    def _rank_edges(edges, k=2):
        """Return the strongest op per edge and the ``k`` strongest edges.

        Args:
            edges: Tensor(n_edges, n_ops) of normalized alphas.
            k: Number of incoming edges to select.

        Returns:
            ``(best_op, topk_edge_indices)``: ``best_op`` is the
            Tensor(n_edges, 1) of argmax operation indices, and
            ``topk_edge_indices`` is the Tensor(k) of edge indices sorted by
            their best operation weight, largest first.
        """
        best_weight, best_op = torch.topk(edges, 1)
        _, topk_edge_indices = torch.topk(best_weight.view(-1), k)
        return best_op, topk_edge_indices

    def print_topk(self):
        """Print, for every node, the indices of its top-2 incoming edges."""
        for edges in self.norm_alphas:
            _, topk_edge_indices = self._rank_edges(edges, k=2)
            print(topk_edge_indices)

    def parse(self, alpha, k=2, primitives=None):
        """Turn per-node alphas into a discrete gene.

        Args:
            alpha: Iterable of Tensor(n_edges, n_ops), one per node.
            k: Number of incoming edges kept per node.
            primitives: Sequence mapping an operation index to its name.
                Defaults to ``gt.PRIMITIVES_FEWSHOT``, looked up at call time.

        Returns:
            A list with one entry per node, each a list of ``k``
            ``(primitive_name, edge_index)`` tuples.
        """
        if primitives is None:
            primitives = gt.PRIMITIVES_FEWSHOT

        gene = []
        for edges in alpha:
            primitive_indices, topk_edge_indices = self._rank_edges(edges, k)

            node_gene = []
            for edge_idx in topk_edge_indices:
                prim_idx = primitive_indices[edge_idx].item()
                node_gene.append((primitives[prim_idx], edge_idx.item()))

            gene.append(node_gene)
        return gene

    def calculate_states(self):
        """Recompute the per-edge state from the current ``norm_alphas``.

        Side effects: resets and refills ``self.states``, ``self.topk`` and
        ``self.encoded_states`` (the latter ends up as an ndarray with one
        one-hot row per edge marking that edge's argmax operation) and
        rebuilds ``self.edge_to_index`` / ``self.edge_to_alpha``.

        Returns:
            ``(d, states)``: ``d`` is a dict with the previous values under
            ``'prev_topk'`` and ``'prev_edges'`` (as ndarrays), and ``states``
            is a list with one ``[0]``/``[1]`` entry per edge, 1 meaning the
            edge is among the top-2 incoming edges of its node.
        """
        prev_topk = copy.deepcopy(self.topk)
        prev_edge = copy.deepcopy(self.encoded_states)

        self.topk = []
        self.states = []
        self.encoded_states = []
        self._build_edge_maps()

        for edges in self.norm_alphas:
            # edges: Tensor(n_edges, n_ops)
            best_op, topk_edge_indices = self._rank_edges(edges, k=2)
            selected = set(topk_edge_indices.tolist())

            edge_one_hot = torch.zeros_like(edges)
            for hot_e, op in zip(edge_one_hot, best_op):
                hot_e[op.item()] = 1

            for j, edge in enumerate(edge_one_hot):
                self.encoded_states.append(edge.numpy())
                # One entry per edge, although edge_to_index reserves two
                # indices for it (one per orientation).
                self.states.append([int(j in selected)])
                self.topk.append([int(j in selected)])

        d = {'prev_topk': np.array(prev_topk),
             'prev_edges': np.array(prev_edge)}

        self.encoded_states = np.array(self.encoded_states)
        return d, self.states

    def _inverse_softmax(self, x, C):
        """Map probabilities ``x`` to logits; softmax(log(x) + C) == x / sum(x)."""
        return torch.log(x) + C

    def _set_op_prob(self, row_idx, edge_idx, op_idx, new_prob):
        """Write ``new_prob`` for one op of one edge back into the alphas.

        Works on a detached copy of the edge's probabilities, so the
        (possibly autograd-tracked) ``norm_alphas`` tensors are never
        modified in place. Afterwards ``norm_alphas`` is recomputed from
        ``alphas``, which renormalizes the edge to sum to 1.
        """
        edge_probs = self.norm_alphas[row_idx][edge_idx].detach().clone()
        edge_probs[op_idx] = new_prob

        # Prevent 0.00 normalized alpha values, resulting in -inf
        edge_probs += self._PROB_FLOOR

        # In-place write into the leaf Parameter must not be recorded.
        with torch.no_grad():
            self.alphas[row_idx][edge_idx] = self._inverse_softmax(
                edge_probs, self._LOG_OFFSET)

        self.norm_alphas = [
            F.softmax(alpha, dim=-1).detach().cpu()
            for alpha in self.alphas]

    def increase_op(self, cur_node, next_node, op_idx, prob=0.7, n_ops=7):
        """Raise the probability of operation ``op_idx`` on an edge.

        Args:
            cur_node, next_node: Vertices of the edge (either orientation).
            op_idx: Column index of the operation to strengthen.
            prob: Amount added to the operation's current probability. If
                the sum would reach or exceed 1.0 the operation is set to
                0.99 instead.
            n_ops: Unused; kept for interface compatibility.

        Raises:
            KeyError: If ``(cur_node, next_node)`` is not an edge of the cell.
        """
        row_idx, edge_idx = self.edge_to_alpha[(cur_node, next_node)]

        current = self.norm_alphas[row_idx][edge_idx][op_idx].item()
        new_prob = current + prob
        # ">=" so that landing exactly on 1.0 is clamped as well instead of
        # being silently ignored.
        if new_prob >= 1.0:
            new_prob = self._PROB_CEIL

        self._set_op_prob(row_idx, edge_idx, op_idx, new_prob)

    def decrease_op(self, cur_node, next_node, op_idx, prob=0.7, n_ops=7):
        """Lower the probability of operation ``op_idx`` on an edge.

        Args:
            cur_node, next_node: Vertices of the edge (either orientation).
            op_idx: Column index of the operation to weaken.
            prob: Amount subtracted from the operation's current probability.
                If the difference would reach or fall below 0.0 the
                operation is set to 0.01 instead.
            n_ops: Unused; kept for interface compatibility.

        Raises:
            KeyError: If ``(cur_node, next_node)`` is not an edge of the cell.
        """
        row_idx, edge_idx = self.edge_to_alpha[(cur_node, next_node)]

        current = self.norm_alphas[row_idx][edge_idx][op_idx].item()
        new_prob = current - prob
        # "<=" so that landing exactly on 0.0 is clamped as well instead of
        # being silently ignored.
        if new_prob <= 0.0:
            new_prob = self._PROB_FLOOR

        self._set_op_prob(row_idx, edge_idx, op_idx, new_prob)

orig tensor([ 1.1632, -0.5108, -1.6094, -0.9163,  0.8755,  1.3083, -2.3026],
       grad_fn=<SelectBackward>)
norm tensor([0.3019, 0.0566, 0.0189, 0.0377, 0.2264, 0.3491, 0.0094],
       grad_fn=<SelectBackward>) tensor(1.0000, grad_fn=<SumBackward0>)


In [None]:
# Build a fresh cell model and show the available primitives.
model = Model()
print("primitives:", gt.PRIMITIVES_FEWSHOT, "\n")

# Print the top-2 incoming edges per node, then derive the gene
# (the parse result is not kept).
model.print_topk()
model.parse(model.norm_alphas, k=2)

# State before any alpha is changed.
_, b = model.calculate_states()
print("init:", b)

# Strengthen op 5 on edge 1->2 and op 3 on edge 1->3.
model.increase_op(1, 2, 5)
model.increase_op(1, 3, 3)

# Recompute the states; `d` holds the values from before this call.
d, _ = model.calculate_states()

In [None]:
import os
import time
import copy
import glob
import shelve

import igraph as ig
from igraph import Graph
from PIL import Image

In [None]:
def generate_graph_path(path, last_steps=None, paths_left=5):
    """Extract the most frequently taken walk from a shelve file of walks.

    All values stored in the shelf are concatenated into one list of walks
    (each walk being a sequence of hashable steps, e.g. ``(from, to)``
    tuples). Step by step, the most common edge among the walks that still
    follow the chosen path is appended to the result, and only the walks
    that took that edge are kept for the next step.

    Args:
        path: File name of the shelve database.
        last_steps: If given, only ``walks[:last_steps]`` are considered.
        paths_left: Stop as soon as fewer than this many walks reach the
            current step.

    Returns:
        ``(walk_path, weights)``: ``walk_path`` starts with the fixed edge
        ``(0, 2)`` followed by the chosen edges; ``weights`` starts with
        ``1`` followed, for every chosen edge, by the fraction of the
        remaining walks that took it.
    """
    # The context manager closes the shelf even if reading fails.
    with shelve.open(path) as db:
        walks = [walk for group in db.values() for walk in group]

    if last_steps is not None:
        # NOTE(review): this keeps the FIRST `last_steps` walks although the
        # parameter name suggests the last ones; confirm the intent.
        walks = walks[:last_steps]

    # TODO: Starting path might be variable
    walk_path = [(0, 2)]
    weights = [1]

    survivors = walks
    # Bound the step index by the longest walk, not by the number of walks.
    max_steps = max((len(walk) for walk in walks), default=0)

    for step in range(max_steps):
        # Walks that are long enough to have a step at this position.
        candidates = [walk for walk in survivors if step < len(walk)]

        edge_counts = {}
        for walk in candidates:
            edge = walk[step]
            edge_counts[edge] = edge_counts.get(edge, 0) + 1

        # Stop if the path ended or,
        if not edge_counts:
            break
        # or if fewer than `paths_left` walks are left.
        total = sum(edge_counts.values())
        if total < paths_left:
            break

        # Step with highest count (first seen wins on ties).
        max_edge = max(edge_counts, key=edge_counts.get)
        walk_path.append(max_edge)
        weights.append(edge_counts[max_edge] / total)

        # Build a new list instead of removing from the one being iterated.
        survivors = [walk for walk in candidates if walk[step] == max_edge]

    return walk_path, weights

In [None]:
def generate_gif(path, weights, save_paths, format_path,
                 gif_path='graph_walk.gif'):
    """Render a walk over the 5-vertex cell graph as an animated GIF.

    One PNG frame is drawn per step of the walk: the edge taken in the
    current step is red, edges taken in earlier steps are purple and all
    other edges are gray.

    Args:
        path: Sequence of ``(from, to)`` edges; either orientation of an
            edge is accepted.
        weights: Sequence of weights, one per edge in ``path``; shown in the
            edge label.
        save_paths: Frame files from a previous run that are deleted first.
            Either a list of file names or a glob pattern string.
        format_path: Format string with one placeholder for the step number,
            naming the PNG written for each frame.
        gif_path: File name of the resulting GIF.

    Raises:
        ValueError: If ``path`` contains an edge that is not in the graph,
            or if no frame was produced (empty ``path`` or ``weights``).
    """
    if isinstance(save_paths, str):
        save_paths = glob.glob(save_paths)
    for stale_frame in save_paths:
        os.remove(stale_frame)

    edges = [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]
    edge_color = ["gray"]*len(edges)

    # Remember the frames in step order instead of re-discovering them via
    # glob + modification time, which can tie or pick up unrelated files.
    frame_paths = []

    for i, (edge, weight) in enumerate(zip(path, weights)):
        g = Graph(edges)

        if edge in edges:
            index = edges.index(edge)
        else:
            # Edge given in the opposite orientation.
            index = edges.index((edge[1], edge[0]))

        edge_color[index] = 'red'

        lb = [""]*len(edges)
        lb[index] = f"step {i}: {edge[0]} -> {edge[1]}, weight: {weight:.2f}"

        # 5 Nodes
        g.vs["label"] = ["0", "1", "2", "3", "4"]
        g.vs["input"] = [True, True, False, False, False]
        g.es["color"] = edge_color
        g.es["label"] = lb

        frame_path = format_path.format(i)
        ig.plot(
            g,
            vertex_size=40,
            edge_width=[3],
            vertex_color=['yellow', 'yellow', 'blue', 'blue', 'blue'],
            target=frame_path,
            bbox=(800, 800),
            margin=200
        )
        frame_paths.append(frame_path)

        # Mark the edge as visited for all following frames.
        edge_color[index] = 'purple'

    if not frame_paths:
        raise ValueError("no frames to assemble: the walk is empty")

    frames = [Image.open(frame_path) for frame_path in frame_paths]
    try:
        frames[0].save(gif_path, format='GIF', append_images=frames[1:],
                       save_all=True, duration=1800, loop=0)
    finally:
        for frame in frames:
            frame.close()

In [None]:
# Shelve database with the recorded graph walks. Kept under its own name so
# that `path` below only ever refers to the extracted list of edges.
shelve_path = "/home/rob/Git/meta-fsl-nas/metanas/results/triplemnist/ppo_metad2a_environment_1/seed_2/graph_walk.shlv"
# Frames left over from a previous run, and the naming scheme for new frames.
save_paths = glob.glob("/home/rob/Git/meta-fsl-nas/notebooks/path/*.png")
format_path = "/home/rob/Git/meta-fsl-nas/notebooks/path/{0}.png"

path, weights = generate_graph_path(shelve_path)
generate_gif(path, weights, save_paths, format_path)

In [None]:
import cv2 as cv

In [None]:
def calc_avg_mean_std_dataset(path, img_names=None, img_root=None):
    """Average the per-image channel mean and std over a set of images.

    Every file matching ``path`` is read, converted from BGR to RGB and
    divided by 255 before its per-channel mean and standard deviation are
    computed.

    Note that the second return value is the average of the per-image
    standard deviations, not the standard deviation over all pixels of the
    dataset.

    Args:
        path: Glob pattern selecting the image files.
        img_names: Unused; kept for interface compatibility.
        img_root: Unused; kept for interface compatibility.

    Returns:
        ``(mean, std)``: two arrays of length 3 (R, G, B).

    Raises:
        ValueError: If no file matches ``path`` or a file cannot be read.
    """
    mean_sum = np.zeros(3)
    std_sum = np.zeros(3)

    n_images = 0
    for file in glob.glob(path):
        # The notebook imports OpenCV as `cv`.
        img = cv.imread(file)
        if img is None:
            # imread signals an unreadable file by returning None.
            raise ValueError(f"could not read image: {file}")
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
        img = img/255
        mean, std = cv.meanStdDev(img)

        mean_sum += np.squeeze(mean)
        std_sum += np.squeeze(std)
        n_images += 1

    if n_images == 0:
        raise ValueError(f"no images matched: {path}")

    return (mean_sum / n_images, std_sum / n_images)

# Glob pattern matching every image of the dataset.
path = '/home/rob/Git/meta-fsl-nas/data/triplemnist/triple_mnist_seed_123_image_size_84_84/*/*/*.png'

# The function body never reads its 2nd and 3rd argument, and no
# `img_names` / `img_root` variables are defined in this notebook, so pass
# explicit placeholders instead of undefined names.
mean, std = calc_avg_mean_std_dataset(path, None, None)
mean, std