In [1]:
import json
import numpy as np
# import torch

import sys
sys.path.append('..')

from typing import List, Tuple

from models import ops

In [2]:
# channels = 16
# o = ops.SepConv(channels, channels, 3, 1, 1, affine=False, num_domains=1)
# o = ops.SepConv(channels, channels, 5, 1, 2, affine=False, num_domains=1)
# o = ops.DilConv(channels, channels, 3, 1, 2, 2, affine=False, num_domains=1)
# o = ops.DilConv(channels, channels, 5, 1, 4, 2, affine=False, num_domains=1)

# sum([p.numel() for p in o.parameters()])

In [3]:
ops_params = {
    'maxpool': 0,
    'avgpool': 0,
    'sepconv3x3': 800,
    'sepconv5x5': 1312,
    'dilconv3x3': 400,
    'dilconv5x5': 656,
    'skipconnect': 0,
}

In [4]:
arch_path = 'architecture_0.json'
arch = json.loads(open(arch_path).read())
arch[0]

{'reduce_n2_p0': 'sepconv5x5',
 'reduce_n2_p1': 'dilconv3x3',
 'reduce_n2_switch': [1, 0],
 'reduce_n3_p0': 'sepconv3x3',
 'reduce_n3_p1': 'sepconv3x3',
 'reduce_n3_p2': 'sepconv3x3',
 'reduce_n3_switch': [0, 2],
 'reduce_n4_p0': 'dilconv3x3',
 'reduce_n4_p1': 'skipconnect',
 'reduce_n4_p2': 'sepconv5x5',
 'reduce_n4_p3': 'dilconv5x5',
 'reduce_n4_switch': [1, 2],
 'reduce_n5_p0': 'skipconnect',
 'reduce_n5_p1': 'avgpool',
 'reduce_n5_p2': 'dilconv5x5',
 'reduce_n5_p3': 'dilconv5x5',
 'reduce_n5_p4': 'maxpool',
 'reduce_n5_switch': [0, 2]}

In [5]:
def parse_single(arch) -> List[Tuple[int, int, str]]:
    n_nodes = max([int(key.split('_')[1][1:]) for key in arch])
    edges = [] # (from, to, op)
    for key, value in arch.items():
        if 'switch' in key:
            continue
        n, p = list(map(lambda s: int(s[1:]), key.split('_')[1:]))
        edges.append((p, n, value))
    # filter edges
    edges_to_select = []
    for key, value in arch.items():
        if not 'switch' in key:
            continue
        n = int(key.split('_')[1][1:])
        edges_to_select.extend([(p, n) for p in value])
    edges = [e for e in edges if e[:-1] in edges_to_select]
    return edges

In [6]:
parse_single(arch[0])

[(0, 2, 'sepconv5x5'),
 (1, 2, 'dilconv3x3'),
 (0, 3, 'sepconv3x3'),
 (2, 3, 'sepconv3x3'),
 (1, 4, 'skipconnect'),
 (2, 4, 'sepconv5x5'),
 (0, 5, 'skipconnect'),
 (2, 5, 'dilconv5x5')]

In [7]:
def merge_archs(archs):
    return list(sorted(list(set(e for a in archs for e in parse_single(a))),
                       key=lambda x: x[1]))
    

In [8]:
merge_archs(arch)

[(0, 2, 'skipconnect'),
 (0, 2, 'sepconv5x5'),
 (1, 2, 'dilconv3x3'),
 (1, 2, 'avgpool'),
 (0, 2, 'maxpool'),
 (1, 2, 'sepconv5x5'),
 (0, 3, 'sepconv3x3'),
 (1, 3, 'dilconv5x5'),
 (2, 3, 'sepconv3x3'),
 (2, 4, 'sepconv5x5'),
 (0, 4, 'maxpool'),
 (1, 4, 'skipconnect'),
 (2, 4, 'sepconv3x3'),
 (3, 4, 'maxpool'),
 (2, 4, 'maxpool'),
 (2, 5, 'maxpool'),
 (0, 5, 'dilconv5x5'),
 (2, 5, 'dilconv5x5'),
 (0, 5, 'skipconnect'),
 (4, 5, 'avgpool'),
 (4, 5, 'sepconv5x5')]

In [9]:
merge_archs([arch[0], arch[0]])

[(1, 2, 'dilconv3x3'),
 (0, 2, 'sepconv5x5'),
 (0, 3, 'sepconv3x3'),
 (2, 3, 'sepconv3x3'),
 (2, 4, 'sepconv5x5'),
 (1, 4, 'skipconnect'),
 (2, 5, 'dilconv5x5'),
 (0, 5, 'skipconnect')]

### Main function

In [10]:
def get_arch_params(arch) -> int:
    merged = merge_archs(arch)
    return sum([ops_params[e[-1]] for e in merged])

In [11]:
get_arch_params([arch[0], arch[0]]), get_arch_params([arch[0]]), get_arch_params(arch)

(5280, 5280, 10016)