In [1]:
import os

from visualizations import *
from graphs import *

# Part 1. Visualize Attention

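## 1.1. Choose a Model to Load

Run **one** of the two loader cells below (MNIST or COCO). Both define `net`, `runner`, and `test_loader`, so whichever runs last is the model used in the rest of the notebook.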
### Load MNIST Model

In [None]:
# Settings for the MNIST variant of the model, passed to load_model_and_data.
modeltype = "mnist"                                 # which model family: 'mnist' or 'coco'
modelpath = "../saved/models/paper_mnist_model.pt"  # saved model checkpoint
strength = 0.2                                      # strength of attention
n = 2                                               # number of objects

net, runner, test_loader = load_model_and_data(
    modelpath,
    n=n,
    strength=strength,
    modeltype=modeltype,
)

### ... OR Load COCO Model

In [2]:
# Root of the local COCO copy. The default is the original author's directory
# layout; set the COCO_ROOT environment variable to use your own copy without
# editing this cell. Expected sub-layout: images/val2017 and
# annotations/instances_val2017.json.
coco_dir = os.environ.get("COCO_ROOT", "../../../../data/jordanlei/coco")

modelpath = "../saved/models/paper_coco_model.pt"  # path to the saved model
modeltype = "coco"  # type of model: 'mnist' or 'coco'
cocoroot = os.path.join(coco_dir, "images", "val2017")  # path to the coco val dataset
annpath = os.path.join(coco_dir, "annotations", "instances_val2017.json")  # path to the coco val annotations
metadatapath = "../data/metadata/cocometadata_test.p"  # path to metadata file (will create one if none exists)

n = 2  # number of objects
strength = 0.9  # strength of attention
net, runner, test_loader = load_model_and_data(
    modelpath,
    n=n,
    strength=strength,
    modeltype=modeltype,
    cocoroot=cocoroot,
    annpath=annpath,
    metadatapath=metadatapath,
)

COCO Object-Based Attention Model v3
loading annotations into memory...
Done (t=0.73s)
creating index...
index created!


## 1.2. Show an Input Image and Its Attention Phases

In [None]:
# Visualize the attention phases for the first test batch only.
# next(iter(...)) takes the first batch directly instead of looping and
# breaking; an empty loader now raises StopIteration rather than silently
# displaying nothing.
x, data, labels = next(iter(test_loader))

print("INPUT IMAGE")
toshow(x[0].detach().cpu().numpy())  # first image of the batch
masks, hiddens, ior, _ = runner.visualize(x, data, labels)

# One iteration per phase; assumes masks, hiddens and ior have one entry per
# phase and are aligned by position (zip stops at the shortest) — TODO confirm.
for phase, (mask, hidden, ior_mask) in enumerate(zip(masks, hiddens, ior), start=1):
    print("PHASE %s" % phase)
    print("\tmasked input")
    toshow(mask[0])
    print("\tattention mask")
    toshow(hidden[0])
    print("\tIOR (for next phase)")
    # presumably a boolean mask; multiplying by 1.0 promotes it to float for display
    toshow(ior_mask[0] * 1.0)

# Part 2. Plot Graphs

In [None]:
# Per-run metric CSVs for models 2 through 5 (inclusive).
metric_files = ["model%d.csv" % model_id for model_id in range(2, 6)]

In [None]:
# Combine the metric CSVs into a single frame; `df` is reused by the next cell.
df = files_to_df(metric_files)

# Mean final accuracy per (lr, penalty) pair: lr values as rows, penalty values as columns.
(
    df.groupby(["lr", "penalty"])["final_acc"]
    .mean()
    .unstack("penalty")
)

In [None]:
# Keep only the runs with lr == 0.001 and draw their box plots without saving.
# NOTE(review): this rebinds `df`; re-run the previous cell to get all learning rates back.
df = df.loc[df["lr"].eq(0.001)]
plot_boxplots(df, save=False)