In [1]:
import os
import sys
import time
import glob
import numpy as np
import torch
import logging
import argparse
from IPython import embed

from models.search.darts.visualize import plot

In [2]:
def parse_args(argv=None):
    """Build the argument parser for the NTU modality-optimization run and parse it.

    Called with no arguments (as the notebook does), every option takes its
    default value. The running process's own command line (``sys.argv``) is
    intentionally never consulted: inside a Jupyter kernel it holds the
    kernel's launch flags rather than options meant for this parser.

    Args:
        argv: Optional overrides. Either a sequence of tokens such as
            ``['--seed', '3', '--parallel']`` or a single shell-style string
            such as ``'--seed 3 --parallel'`` (tokenized with ``shlex.split``).
            ``None`` (the default) or an empty string/sequence means
            "use all defaults".

    Returns:
        argparse.Namespace with one attribute per option registered below.
    """
    import shlex

    # argparse iterates whatever it is given, so a raw string would be split
    # into single characters; tokenize strings shell-style instead.
    if argv is None:
        tokens = []
    elif isinstance(argv, str):
        tokens = shlex.split(argv)
    else:
        tokens = list(argv)

    parser = argparse.ArgumentParser(description='Modality optimization.')

    parser.add_argument('--seed', type=int, default=2, help='random seed')

    # Input/output locations.
    parser.add_argument('--checkpointdir', type=str, help='output base dir',
                        default='checkpoints/ntu')
    parser.add_argument('--datadir', type=str, help='data directory',
                        default='BM-NAS_dataset/NTU/')

    # Pretrained unimodal backbone checkpoints (file names, expected inside checkpointdir).
    parser.add_argument('--ske_cp', type=str, help='Skeleton net checkpoint (assuming is contained in checkpointdir)',
                        default='skeleton_32frames_85.24.checkpoint')
    parser.add_argument('--rgb_cp', type=str, help='RGB net checkpoint (assuming is contained in checkpointdir)',
                        default='rgb_8frames_83.91.checkpoint')

    # args for darts: optimizer of the architecture encoding.
    parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
    parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')

    parser.add_argument('--weight_decay', type=float, default=1e-4, help='weight decay')

    # Cell-level search-space shape.
    parser.add_argument('--num_input_nodes', type=int, help='cell input', default=8)
    parser.add_argument('--num_keep_edges', type=int, help='cell step connect', default=2)
    parser.add_argument('--multiplier', type=int, help='cell output concat', default=4)
    parser.add_argument('--steps', type=int, help='cell steps', default=4)

    # Inner-node (per-step) search-space shape.
    parser.add_argument('--node_multiplier', type=int, help='inner node output concat', default=1)
    parser.add_argument('--node_steps', type=int, help='inner node steps', default=2)

    # for darts operations and inner representation size
    parser.add_argument('--C', type=int, help='channels for conv layer', default=256)
    parser.add_argument('--L', type=int, help='length after conv and pool', default=8)
    parser.add_argument('--batchsize', type=int, help='batch size', default=96)
    parser.add_argument('--parallel', help='Use several GPUs', action='store_true', default=False)
    # NOTE(review): help text is empty; the default 'both' presumably selects
    # skeleton + RGB (cf. --ske_cp / --rgb_cp) — confirm accepted values with the consumer.
    parser.add_argument('--modality', type=str, help='', default='both')

    parser.add_argument('--small_dataset', action='store_true', default=False, help='dataset scale')

    parser.add_argument('--num_outputs', type=int, help='output dimension', default=60)

    # Training length and learning-rate schedule.
    # NOTE(review): eta_max / eta_min / Ti / Tm look like cosine-annealing-with-
    # warm-restarts parameters — confirm against the scheduler that reads them.
    parser.add_argument('--epochs', type=int, help='training epochs', default=80)
    parser.add_argument('--eta_max', type=float, help='eta max', default=3e-4)
    parser.add_argument('--eta_min', type=float, help='eta min', default=1e-6)
    parser.add_argument('--Ti', type=int, help='epochs Ti', default=1)
    parser.add_argument('--Tm', type=int, help='epochs multiplier Tm', default=2)
    parser.add_argument('--num_workers', type=int, help='Dataloader CPUS', default=8)

    # action="store" and dest="drpt" are argparse's defaults for this flag.
    parser.add_argument('--drpt', type=float, default=0.2, help='dropout')
    parser.add_argument('--save', type=str, default='EXP', help='experiment name')

    return parser.parse_args(tokens)

In [3]:
# Parse with no overrides so every option takes its parser default.
args: argparse.Namespace = parse_args()

In [4]:
# Echo the parsed configuration so the notebook records which settings were used.
args

Namespace(seed=2, checkpointdir='checkpoints/ntu', datadir='BM-NAS_dataset/NTU/', ske_cp='skeleton_32frames_85.24.checkpoint', rgb_cp='rgb_8frames_83.91.checkpoint', arch_learning_rate=0.0003, arch_weight_decay=0.001, weight_decay=0.0001, num_input_nodes=8, num_keep_edges=2, multiplier=4, steps=4, node_multiplier=1, node_steps=2, C=256, L=8, batchsize=96, parallel=False, modality='both', small_dataset=False, num_outputs=60, epochs=80, eta_max=0.0003, eta_min=1e-06, Ti=1, Tm=2, num_workers=8, drpt=0.2, save='EXP')

In [5]:
# Task identifier forwarded to plot(); matches the NTU paths in the parser defaults.
task = "NTU"

In [6]:
from collections import namedtuple

# Lightweight containers describing a searched architecture.
# Genotype.steps holds a list of StepGenotype entries, one per cell step.
Genotype = namedtuple('Genotype', ['edges', 'steps', 'concat'])
StepGenotype = namedtuple('StepGenotype', ['inner_edges', 'inner_steps', 'inner_concat'])

In [7]:
# Example architecture to visualize: four cell-level edges, two inner steps
# (one built from LinearGLU ops, one from ScaleDotAttn ops), and the cell
# output formed from nodes 8 and 9.
# NOTE(review): the integers in the (op, index) pairs are presumably node
# indices — confirm against the search code that produced this genotype.
genotype = Genotype(
    edges=[
        ('skip', 2),
        ('skip', 7),
        ('skip', 2),
        ('skip', 3),
    ],
    steps=[
        StepGenotype(
            inner_edges=[('skip', 0), ('skip', 1), ('skip', 2), ('skip', 0)],
            inner_steps=['LinearGLU', 'LinearGLU'],
            inner_concat=[2, 3],
        ),
        StepGenotype(
            inner_edges=[('skip', 0), ('skip', 1), ('skip', 2), ('skip', 0)],
            inner_steps=['ScaleDotAttn', 'ScaleDotAttn'],
            inner_concat=[2, 3],
        ),
    ],
    concat=[8, 9],
)

In [8]:
# Output name handed to plot() for the structure visualization.
file_name = "structure_vis_example"

In [9]:
# Render the example architecture with plot() from models.search.darts.visualize.
# NOTE(review): presumably it draws `genotype` and writes the figure under
# `file_name`, with `args` / `task` steering labels or layout — confirm in that module.
plot(genotype, file_name, args, task)