In [None]:
import xarray
import pandas as pd
import seaborn
import os
import pickle
import statistics
import numpy as np

from neural_structural_optimization import pipeline_utils
from neural_structural_optimization import problems
from neural_structural_optimization import models
from neural_structural_optimization import topo_api
from neural_structural_optimization import train_switch
import matplotlib.pyplot as plt
from neural_structural_optimization.problems import PROBLEMS_BY_NAME

In [None]:
# Run configuration shared by the cells below.
# max_iterations is embedded in the CNN result filenames and used as the final CNN step.
max_iterations = 200
# Suffixes of the "pixel_switch_<k>" result files that are looked up: 1, 2, 4, ..., 64.
switch = [2 ** exponent for exponent in range(7)]
# NOTE(review): `width` and `cnn_kwargs` are not read by any cell visible here;
# presumably they mirror the settings of the training runs - confirm before removing.
width = [128, 64, 32, 16, 1]
cnn_kwargs = {"resizes": (1, 1, 2, 2, 1)}

In [None]:
# All available problems, keyed by name (the loading loop indexes `examples[name]`).
examples = PROBLEMS_BY_NAME
example_list = [problem_name for problem_name in examples]
# Displayed as the cell output: number of distinct problem names.
len(set(example_list))

In [None]:
# Per-example stores filled by the loading loop:
#   D[name]      -> {model label: loss series as loaded from the result file}
#   D_norm[name] -> {model label: list of min-max normalized losses}
D = dict()
D_norm = dict()

In [None]:
def _load_pickle_if_present(path):
    """Return the object pickled at ``path``, or ``None`` when the file does not exist.

    The file is read inside a ``with`` block so the handle is always closed.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading; only
    point this at result files written by trusted runs.
    """
    if not os.path.exists(path):
        return None
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# For every problem, collect the loss series of each available model into D and a
# min-max normalized copy into D_norm.
for name in example_list:
    # The CNN and Pixel baselines are both required; skip the example otherwise.
    # The dataset names are only rebound on a successful load, so a missing file
    # leaves the previously loaded dataset in place.
    loaded = _load_pickle_if_present("results/cnn_{}.pkl".format(name+"_"+str(max_iterations)))
    if loaded is None:
        continue
    ds_cnn_all = loaded

    loaded = _load_pickle_if_present("results/pixel_{}.pkl".format(name))
    if loaded is None:
        continue
    ds_pixel = loaded

    result = {"CNN": ds_cnn_all.loss, "Pixel": ds_pixel.loss}

    # Switch runs are optional: record whichever result files exist.
    for i in switch:
        loaded = _load_pickle_if_present("results/pixel_switch_{}.pkl".format(name+"_"+str(i)))
        if loaded is not None:
            ds_pixel_switch = loaded
            result["Pixel_switch_{}".format(i)] = ds_pixel_switch.loss

    D[name] = result

    # NORMALIZATION
    # Convert each loss series to a plain list once; the lists are reused for both
    # the pooled min/max and the per-model normalization.
    loss_values = {model: loss.to_dict()['data'] for model, loss in result.items()}
    pooled_losses = [value for values in loss_values.values() for value in values]
    min_example = min(pooled_losses)
    max_example = max(pooled_losses)
    loss_range = max_example - min_example

    # Min-max scale against the pooled losses of all models for this example.
    # When every pooled loss is identical the range is 0; map everything to 0.0
    # instead of raising ZeroDivisionError.
    D_norm[name] = {
        model: [
            (float(value) - min_example) / loss_range if loss_range else 0.0
            for value in values
        ]
        for model, values in loss_values.items()
    }

In [None]:
# Lowest normalized loss reached by each model on each example:
# D_norm_min_loss[example][model] -> min over that model's normalized losses.
D_norm_min_loss = {
    example_name: {
        model_name: min(normalized_losses)
        for model_name, normalized_losses in per_model.items()
    }
    for example_name, per_model in D_norm.items()
}

In [None]:
# Outer keys (example names) become columns and inner keys (model labels) become rows.
table = pd.DataFrame.from_dict(D_norm_min_loss, orient='columns')
table

In [None]:
# Keep only the examples (columns) that have a value for every model, i.e. drop
# any column containing NaN. numpy is already imported in the first cell, so it
# is not re-imported here; the NaN test is done on the whole column at once.
d = {
    col: table[col]
    for col in table
    if not np.isnan(table[col].values).any()
}
table_filtered = pd.DataFrame(d)

In [None]:
# Mean of each model's row in table_filtered (one value per model, in row order).
row_means = [
    (model_label, statistics.mean(list(table_filtered.loc[model_label, :])))
    for model_label in table_filtered.axes[0]
]
mean_value_column = [row_mean for _model_label, row_mean in row_means]
for model_label, row_mean in row_means:
    print(model_label, " ", row_mean)

In [None]:
# For each model row of table_filtered: mean of the first normalized loss over the
# examples that have results for that model.
# NOTE(review): this averages over every example in D_norm, not only the columns
# kept in table_filtered - confirm that the differing example sets are intended.
init_loss_norm = [
    statistics.mean(
        [
            per_model[model_label][0]
            for per_model in D_norm.values()
            if model_label in per_model
        ]
    )
    for model_label in table_filtered.axes[0]
]
init_loss_norm

In [None]:
# Append the two per-model summary columns to table_filtered (in place), then show
# the table ordered by mean value.
# NOTE: because table_filtered is modified here, re-running the cell that computes
# mean_value_column afterwards would average over these summary columns as well.
for summary_name, summary_values in (
    ('Mean Value', mean_value_column),
    ('Mean Initial Loss', init_loss_norm),
):
    table_filtered[summary_name] = summary_values
table_filtered.sort_values('Mean Value')

In [None]:
# One-column summary: each model's 'Mean Value' rounded to 5 decimals.
table_mean_val = {
    model_label: round(table_filtered.loc[model_label, 'Mean Value'], 5)
    for model_label in table_filtered.axes[0]
}

table_summary = pd.DataFrame.from_dict(table_mean_val, orient='index')


# Summary curve


In [None]:
# Model labels, taken from the first example stored in D_norm.
models = list(D_norm[list(D_norm)[0]])
# table_filtered carries the summary columns 'Mean Value' and 'Mean Initial Loss'
# next to the real example columns. Keep only columns that are keys of D_norm, so
# that D_norm[example] lookups in the plotting cell cannot hit a summary label
# (filtering out 'Mean Value' alone left 'Mean Initial Loss' in and raised KeyError).
examples = [e for e in table_filtered.columns if e in D_norm]

In [None]:
%matplotlib inline
thresholds = np.linspace(0, 0.01, 50)
for model in models:
    curve = []
    for threshold in thresholds:
        percentage = len([example for example in examples if min(D_norm[example][model]) < threshold]) / float(len(examples))
        curve.append(percentage)
    plt.plot(thresholds, curve, label=model)
plt.xlabel("Relative error threshold")
plt.ylabel("Cumulative probability")
plt.legend()


# Examples


In [None]:
# Draw 4 distinct problem names with a fixed seed so the same sample is drawn on every run.
examples = PROBLEMS_BY_NAME
rng = np.random.RandomState(827)
sampled_names = rng.choice(list(examples), size=4, replace=False)
example_list = list(sampled_names)
example_list

In [None]:
def _read_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is read inside a ``with`` block so the handle is always closed.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading; only
    use it on result files written by trusted runs.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _save_design_image(design, step, path):
    """Draw the design frame selected by ``step`` with ``imshow`` and save it to ``path``."""
    frame = design.sel(step=step)
    plt.imshow(frame.to_masked_array().data)
    plt.savefig(path)
    plt.close()


# Make sure the output directory exists before any figure is saved into it.
os.makedirs("eeml_results", exist_ok=True)

# For each sampled problem with a complete set of result files, save the final
# design of every model and one figure comparing their best-so-far losses.
for name in example_list:
    result_paths = {
        "cnn": "results/cnn_{}.pkl".format(name+"_"+str(max_iterations)),
        "pixel": "results/pixel_{}.pkl".format(name),
        "switch_8": "results/pixel_switch_{}.pkl".format(name+"_"+str(8)),
        "switch_64": "results/pixel_switch_{}.pkl".format(name+"_"+str(64)),
    }
    # All four result files are required; skip the example if any is missing.
    if any(not os.path.exists(path) for path in result_paths.values()):
        continue

    ds_cnn = _read_pickle(result_paths["cnn"])
    ds_pixel = _read_pickle(result_paths["pixel"])
    ds_pixel_switch_8 = _read_pickle(result_paths["switch_8"])
    ds_pixel_switch_64 = _read_pickle(result_paths["switch_64"])

    # CNN: the design at step == max_iterations.
    _save_design_image(ds_cnn.design, max_iterations, "eeml_results/"+"cnn_"+name+".pdf")
    # Read the loss from the dataset loaded for THIS example (not from a dataset
    # left over from an earlier cell).
    loss_cnn_all = float(ds_cnn.loss[max_iterations])

    # Pixel-based runs: the design at the last step, taken as (number of steps - 1).
    for file_prefix, ds in (
        ("pixel_", ds_pixel),
        ("pixel_switch_8_", ds_pixel_switch_8),
        ("pixel_switch_64_", ds_pixel_switch_64),
    ):
        last_step = len(ds.design.step) - 1
        _save_design_image(ds.design, last_step, "eeml_results/"+file_prefix+name+".pdf")

    # Running minimum of the loss for each model, on a log-scaled y axis.
    fig = plt.figure(num=None, figsize=(8, 3.5), dpi=80, facecolor='w', edgecolor='k')
    ax = fig.gca()
    for curve_label, ds in (
        ("CNN", ds_cnn),
        ("Pixel", ds_pixel),
        ("Switch-8", ds_pixel_switch_8),
        ("Switch-64", ds_pixel_switch_64),
    ):
        ds.loss.to_pandas().cummin().plot(ax=ax, linewidth=2, label=curve_label)

    ax.legend()
    ax.set_yscale("log")
    ax.set_ylabel('Compliance (loss)')
    ax.set_xlabel('Optimization step')
    # Despine before saving/closing; afterwards it would act on a new empty figure.
    seaborn.despine(fig=fig)
    fig.savefig("eeml_results/"+name+"_losses"+".pdf")
    plt.close(fig)
