In [1]:
import os

from probe_experiment import *
from handcrafted_datasets import Dataset
from models import *
import torch
import torch.nn as nn
from sklearn.metrics import roc_curve, roc_auc_score

In [2]:
# Data location: set the GO_AI_DATA_DIR environment variable to point elsewhere
# instead of editing a machine-specific absolute path. Falls back to the original path.
DATA_DIR = os.environ.get('GO_AI_DATA_DIR', r'C:\Users\andre\go-ai\data')
TRAIN_FRAC, TEST_FRAC = 0.8, 0.2  # split fractions (the printed 552/138 split is 80/20)
MAX_RAM_FILES = 25  # forwarded as `max_ram_files` to every Dataset.loader call

# NOTE(review): the trailing 1024 is presumably a batch size — confirm against
# handcrafted_datasets.Dataset before relying on it.
dataset = Dataset('surround', DATA_DIR, TRAIN_FRAC, TEST_FRAC, 1024)


def get_train_loader():
    """Reshuffle the training split in place, then return a loader over it.

    Passed to ProbeExperiment as a factory, so every call yields a freshly shuffled
    order (presumably one call per epoch). Side effect: mutates the module-level
    `dataset`'s train ordering.
    """
    dataset.shuffle('train')
    return dataset.loader('train', max_ram_files=MAX_RAM_FILES)


def get_test_loader():
    """Return a loader over the test split. The test split is not reshuffled."""
    return dataset.loader('test', max_ram_files=MAX_RAM_FILES)

done loading data
split sizes:
train 552
test 138


In [3]:
# Built once and reused by every call, rather than constructing a new loss module
# for each batch. Default reduction is the mean over all elements.
_bce_with_logits = nn.BCEWithLogitsLoss()


def flatten_bce_loss(pred, target):
    """Binary cross-entropy on raw logits, with the target flattened to match `pred`.

    The probe heads end in ``nn.Flatten()``, so ``pred`` has one logit per position
    per sample. Every dimension of ``target`` after the batch dimension is collapsed
    (exactly what ``nn.Flatten()`` does) so the two line up element-wise.

    Args:
        pred: logits of shape (batch, N), no sigmoid applied.
        target: binary labels with the same batch size and N elements per sample
            (presumably (batch, H, W) or (batch, 1, H, W) label maps — confirm).

    Returns:
        Scalar tensor: mean BCE-with-logits over all elements.
    """
    flat_target = torch.flatten(target, start_dim=1)
    # BCEWithLogitsLoss needs a floating target of the logits' dtype. Casting lets
    # integer/bool label maps work too, and is a no-op for matching float targets.
    return _bce_with_logits(pred, flat_target.to(pred.dtype))


criterion = flatten_bce_loss

In [4]:
go_model = load_go_model_from_ckpt('model_ckpt.pth.tar', rm_prefix=True)

# Channel count of the features at each cut depth (index = depth). Entry d must
# match the output of CutModel(go_model, d), because probe d consumes those features.
n_channels = [8, 64, 64, 64, 48, 48, 32, 32]

# One truncated feature extractor per depth, all built from the same go_model.
feat_models = []
for cut_depth in range(len(n_channels)):
    feat_models.append(CutModel(go_model, cut_depth).cuda())


def _make_probe(in_channels):
    """Build a linear probe: one 19x19 conv to a single output map, flattened per sample.

    With stride 1, a 19-wide kernel and padding=9 leave the spatial size unchanged.
    The flattened output therefore holds one logit per input position.
    """
    conv = nn.Conv2d(in_channels, 1, 19, padding=9)
    return nn.Sequential(conv, nn.Flatten()).cuda()


probe_models = [_make_probe(channels) for channels in n_channels]

In [5]:
N_EPOCHS = 4
# Depth 0 is skipped; probe every other cut depth that has a feature model.
PROBE_DEPTHS = range(1, len(feat_models))

aucs = []  # aucs[i] is the test AUC for depth PROBE_DEPTHS[i]
for depth in PROBE_DEPTHS:
    print('probing at depth %d' % depth)
    feat_model = feat_models[depth]
    probe_model = probe_models[depth]
    # Re-initialize the probe so that re-running this cell trains from scratch.
    # Otherwise it would continue from the weights left behind by a previous run.
    for layer in probe_model.modules():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
    exp = ProbeExperiment(get_train_loader, get_test_loader, feat_model)
    config = {
        'name': 'surround/surround_19x19cnn_d%d' % depth,
        'write_log': True,
        'progress_bar': False,
        'save_ckpt': False,
    }
    # Fresh optimizer per depth, so no Adam state carries over between probes.
    optimizer = torch.optim.Adam(probe_model.parameters())
    exp.run(probe_model, criterion, optimizer, N_EPOCHS, config)
    # preds are presumably raw logits (the probe has no sigmoid). ROC AUC depends
    # only on ranking, so logits and probabilities give the same score.
    preds, labels = exp.get_predictions(probe_model, get_test_loader())
    auc = roc_auc_score(labels.flatten(), preds.flatten())
    aucs.append(auc)
    print(auc)

probing at depth 1
[LOG] epoch 0 loss 0.018990, new best
[LOG] epoch 1 loss 0.014795, new best
[LOG] epoch 2 loss 0.013469, new best
[LOG] epoch 3 loss 0.012375, new best
0.9969218020919967
probing at depth 2
[LOG] epoch 0 loss 0.024452, new best
[LOG] epoch 1 loss 0.019703, new best
[LOG] epoch 2 loss 0.017807, new best
[LOG] epoch 3 loss 0.016749, new best
0.9937401579200829
probing at depth 3
[LOG] epoch 0 loss 0.021852, new best
[LOG] epoch 1 loss 0.018065, new best
[LOG] epoch 2 loss 0.016437, new best
[LOG] epoch 3 loss 0.015862, new best
0.9947033812751381
probing at depth 4
[LOG] epoch 0 loss 0.031373, new best
[LOG] epoch 1 loss 0.026548, new best
[LOG] epoch 2 loss 0.024066, new best
[LOG] epoch 3 loss 0.022915, new best
0.9849409017514535
probing at depth 5
[LOG] epoch 0 loss 0.041520, new best
[LOG] epoch 1 loss 0.037042, new best
[LOG] epoch 2 loss 0.035056, new best
[LOG] epoch 3 loss 0.033995, new best
0.948609273482293
probing at depth 6
[LOG] epoch 0 loss 0.047609, new