This demo shows how to load classification Compute Graph (CG) networks and convert them into Human Pose Estimation (HPE) and Semantic Segmentation networks for downstream tasks.

In [None]:
# First, let's load an example CG. We'll use two examples: the base 'resnet50.pkl' and AutoGO ResNet-50 Arch 1.
import torch as t
import pickle
import os
# Silence some of the spammy TF log messages (must be set before TensorFlow is imported)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# NOTE: This raises 'ModuleNotFoundError'
with open("../../architectures/samples/resnet50/resnet50.pkl", "rb") as f:
    resnet50 = pickle.load(f)

That failed because we are not at the top level of the repository. Unpickling the CG requires importing its class definition from 'model_src/comp_graph/tf_comp_graph.py', but Python cannot resolve the 'model_src' package because we are running from the same directory as 'tf_comp_graph.py', not from the top level.
Showing how to remedy this problem is why this tutorial lives in this folder rather than at the top level.
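
The failure mode described above can be reproduced in isolation. The following is a minimal sketch, not the repository's code: `demo_graph` and `DemoGraph` are hypothetical stand-ins for the CG module and class, and a temporary directory plays the role of the top-level folder. It illustrates that pickle stores only a reference to the class, so the defining module must be importable at load time.

```python
import importlib
import os
import pickle
import sys
import tempfile

# Hypothetical stand-in for the repository: a temporary "top level" holding
# one module, demo_graph.py, which defines the class we are going to pickle.
top_level = tempfile.mkdtemp()
with open(os.path.join(top_level, "demo_graph.py"), "w") as f:
    f.write(
        "class DemoGraph:\n"
        "    def __init__(self, n_nodes):\n"
        "        self.n_nodes = n_nodes\n"
    )

# Create the pickle while the module is importable.
sys.path.append(top_level)
importlib.invalidate_caches()
import demo_graph

payload = pickle.dumps(demo_graph.DemoGraph(n_nodes=3))

# Simulate a fresh session started elsewhere: forget the module and its location.
sys.path.remove(top_level)
del sys.modules["demo_graph"]
importlib.invalidate_caches()

try:
    pickle.loads(payload)
    failed_without_path = False
except ModuleNotFoundError:
    failed_without_path = True

# Putting the top level back on sys.path makes the class definition importable again.
sys.path.append(top_level)
importlib.invalidate_caches()
restored = pickle.loads(payload)

print(failed_without_path, restored.n_nodes)
```

The same reasoning applies to the CG pickles: the file contains the object's state plus the dotted path of its class, so the repository's top level has to be on `sys.path` before `pickle.load` is called.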

In [None]:
import sys
sys.path.append("../../")  # Append top-level to path

with open("../../architectures/samples/resnet50/resnet50.pkl", "rb") as f:
    resnet50 = pickle.load(f)

with open("../../architectures/samples/resnet50/autogo_arch1.pkl", "rb") as f:
    arch1 = pickle.load(f)

# Take a look at the number of nodes/edges
print(resnet50)
print(arch1)

Both of these networks have the same number of nodes and edges, but different str_id values (hashes generated from the nodes).
If you run them through model_src.comp_graph.tf_comp_graph_utils.compute_cg_flops, you will see that arch1 has a slightly higher FLOP count, as reported in the paper.
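
To illustrate how two graphs of identical size can still receive different identifiers, here is a toy sketch. The function `toy_str_id` and the node labels are hypothetical and are not the repository's actual str_id scheme; the sketch only assumes, as stated above, that the identifier is a hash derived from the nodes.

```python
import hashlib


def toy_str_id(node_labels):
    """Hash a canonical text form of the node labels (hypothetical scheme)."""
    canonical = "|".join(node_labels)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:12]


# Two toy graphs with the same number of nodes; only one operation differs.
graph_a = ["conv3x3", "bn", "relu"]
graph_b = ["conv5x5", "bn", "relu"]

id_a = toy_str_id(graph_a)
id_b = toy_str_id(graph_b)

print(len(graph_a) == len(graph_b))  # same node count
print(id_a != id_b)                  # different identifiers
```

Changing a single operation changes the hash, which is why matching node and edge counts do not imply matching str_id values.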

Anyway, let's make PyTorch networks from CGs for Classification, Human Pose Estimation and Semantic Segmentation!

In [None]:
from task_networks import cg_class, cg_hpe, cg_seg

class_input = t.rand(1, 3, 224, 224) # Sample ImageNet input

# These functions take filenames as input, not the CG objects. So we'll assign these variables to make things a little easier
r50_pkl = "../../architectures/samples/resnet50/resnet50.pkl"
arch1_pkl = "../../architectures/samples/resnet50/autogo_arch1.pkl"

r50_class = cg_class(r50_pkl, name="ResNet50-Class", net=True)  # net=False will return the CG.
print(f"ResNet-50 classification network output size: {r50_class(class_input).shape}")

arch1_class = cg_class(arch1_pkl, name="Arch1-Class", net=True)
print(f"AutoGO ResNet50 Arch 1 output size: {arch1_class(class_input).shape}")

Okay, that works. We need to be a bit more careful when converting Arch 1 for HPE/Segmentation, though: because it is mutated from ResNet50, it needs a config_name value.

In [None]:
hpe_input = t.rand(1, 3, 256, 256)  # MPII Image size in Zhou et al. 2017 (https://arxiv.org/abs/1704.02447)
print("Human Pose Estimation")

# NOTE: cg_hpe and cg_seg return the network and some information on the # of output channels
r50_hpe, r50_out_c = cg_hpe(r50_pkl, name="ResNet50-HPE", net=True)
print("ResNet50 HPE number output channels", r50_out_c)
print(f"ResNet50 HPE network output size: {r50_hpe(hpe_input).shape}")  # Expect 1, 2048, 8, 8 size

# NOTE the value for config_name - required since Arch1 is mutated from ResNet50
arch1_hpe, arch1_out_c = cg_hpe(arch1_pkl, config_name="resnet50.pkl", net=True)
print("AutoGO ResNet50 Arch 1 HPE number output channels", arch1_out_c)
print(f"AutoGO ResNet50 Arch 1 HPE output size: {arch1_hpe(hpe_input).shape}")  # Expect 1, 2048, 8, 8 size
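
The 8 x 8 spatial size expected in the comments above follows from the backbone's total stride. Below is a minimal sketch of that arithmetic, assuming the textbook ResNet-50 downsampling layers (a stride-2 7x7 stem convolution, a stride-2 3x3 max pool, and three stride-2 stages) with PyTorch-style padding. The layer list is an assumption about the standard architecture and is not read from the CG, so a mutated architecture could differ.

```python
def conv_out_size(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for each downsampling layer in a textbook ResNet-50.
# Assumed values for illustration; not extracted from resnet50.pkl.
RESNET50_DOWNSAMPLING = [
    (7, 2, 3),  # stem convolution
    (3, 2, 1),  # max pool
    (3, 2, 1),  # first block of stage 3
    (3, 2, 1),  # first block of stage 4
    (3, 2, 1),  # first block of stage 5
]


def backbone_out_size(size):
    """Apply every downsampling layer in turn to a square input size."""
    for kernel, stride, padding in RESNET50_DOWNSAMPLING:
        size = conv_out_size(size, kernel, stride, padding)
    return size


hpe_size = backbone_out_size(256)    # MPII-style input used above
class_size = backbone_out_size(224)  # ImageNet-style input used earlier
print(hpe_size, class_size)
```

Under these assumptions a 256 x 256 input is reduced by a factor of 32 to 8 x 8, which matches the shape noted in the HPE cell.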

Semantic Segmentation

In [None]:
seg_input = t.rand(1, 3, 713, 713)  # Typical for Cityscapes
print("Semantic Segmentation")

# NOTE: cg_hpe and cg_seg return the network and some information on the # of output channels
r50_seg, r50_out_net_params = cg_seg(r50_pkl, name="ResNet50-Seg", net=True)
print("ResNet50 net params:", r50_out_net_params)
# NOTE: Seg networks return 2 outputs:
outputs = r50_seg(seg_input)
for i, o in enumerate(outputs):
    print(f"Output {i} shape: {o.shape}")

# NOTE the value for config_name - required since Arch1 is mutated from ResNet50
arch1_seg, arch1_out_net_params = cg_seg(arch1_pkl, config_name="resnet50.pkl", net=True)
print("AutoGO ResNet50 Arch 1 net params:", arch1_out_net_params)
# NOTE: Seg networks return 2 outputs:
outputs = arch1_seg(seg_input)
for i, o in enumerate(outputs):
    print(f"Output {i} shape: {o.shape}")