In [1]:
import os

from probe_experiment import *
from handcrafted_datasets import Dataset
from models import *
import torch
import torch.nn as nn
from sklearn.metrics import roc_curve, roc_auc_score

In [2]:
# Root of the Go data files. Defaults to the original author's local checkout;
# set the GO_AI_DATA_DIR environment variable to use a different location
# instead of editing this cell.
DATA_DIR = os.environ.get('GO_AI_DATA_DIR', 'C:\\Users\\andre\\go-ai\\data')
# max_ram_files passed to every Dataset.loader call below; shared so the train
# and test loaders cannot drift apart.
MAX_RAM_FILES = 100

# 'eyes' probing task. The positional 0.8 / 0.2 / 1024 arguments are forwarded
# unchanged to handcrafted_datasets.Dataset; the printed split sizes (554 / 138)
# match an ~80/20 train/test split. 1024 is presumably a batch or chunk size
# -- TODO confirm against Dataset's signature.
dataset = Dataset('eyes', DATA_DIR, 0.8, 0.2, 1024)


def get_train_loader():
    """Reshuffle the training split in place, then return a loader over it.

    Reshuffles on every call, so each call yields a new example order.
    """
    dataset.shuffle('train')
    return dataset.loader('train', max_ram_files=MAX_RAM_FILES)


def get_test_loader():
    """Return a loader over the test split (no reshuffling)."""
    return dataset.loader('test', max_ram_files=MAX_RAM_FILES)

done loading data
split sizes:
train 554
test 138


In [3]:
def flatten_bce_loss(pred, target):
    return torch.nn.BCEWithLogitsLoss()(pred, nn.Flatten()(target))
criterion = flatten_bce_loss

In [4]:
# Trained Go network whose intermediate activations are probed.
go_model = load_go_model_from_ckpt('model_ckpt.pth.tar', rm_prefix=True)

# Channel count of the activation fed to the probe at each depth (index = the
# depth passed to CutModel). Its length is the single source of truth for how
# many depths are probed, instead of a separately hard-coded 8.
n_channels = [8, 64, 64, 64, 48, 48, 32, 32]
# Probe kernel width. padding = kernel // 2 keeps the spatial size unchanged
# (out = in + 2*pad - kernel + 1 = in); with 19 it presumably spans the whole
# 19x19 board -- TODO confirm board size.
PROBE_KERNEL_SIZE = 19

feat_models = [CutModel(go_model, depth).cuda() for depth in range(len(n_channels))]
# One linear conv probe per depth: nc feature maps -> a single logit map,
# flattened to one logit per spatial position.
probe_models = [
    nn.Sequential(
        nn.Conv2d(nc, 1, PROBE_KERNEL_SIZE, padding=PROBE_KERNEL_SIZE // 2),
        nn.Flatten(),
    ).cuda()
    for nc in n_channels
]

In [5]:
# Training epochs per probe.
N_EPOCHS = 10

# Feature extractors and probes are paired by index (index = probed depth).
# Iterate over the lists themselves rather than a hard-coded range(8), and fail
# loudly if they were built with different lengths.
assert len(feat_models) == len(probe_models), 'need exactly one probe per feature model'

aucs = []  # aucs[depth] = test ROC-AUC of the probe trained at that depth
for depth, (feat_model, probe_model) in enumerate(zip(feat_models, probe_models)):
    print('probing at depth %d' % depth)
    # probe_models is built in a different cell, so re-running this cell (e.g.
    # after an interrupted run) would otherwise keep training probes that are
    # already fitted. Re-initialise every sub-layer that supports it so each
    # run of this cell trains from scratch.
    for layer in probe_model.modules():
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
    exp = ProbeExperiment(get_train_loader, get_test_loader, feat_model)
    # NOTE(review): the test loader is also handed to ProbeExperiment; if run()
    # uses it for the "new best" selection, the AUC below is not a fully
    # held-out estimate -- confirm against ProbeExperiment.run.
    config = {'name':'eyes/eyes_19x19cnn_d%d'%depth, 'write_log':True, 'progress_bar':False, 'save_ckpt':False}
    optimizer = torch.optim.Adam(probe_model.parameters())
    exp.run(probe_model, criterion, optimizer, N_EPOCHS, config)
    preds, labels = exp.get_predictions(probe_model, get_test_loader())
    # Flatten both so every output position of every test example is scored as
    # one pooled binary ROC problem.
    auc = roc_auc_score(labels.flatten(), preds.flatten())
    aucs.append(auc)
    print(auc)

probing at depth 0
[LOG] epoch 0 loss 0.024897, new best
[LOG] epoch 1 loss 0.018803, new best
[LOG] epoch 2 loss 0.014821, new best
[LOG] epoch 3 loss 0.011855, new best
[LOG] epoch 4 loss 0.009625, new best
[LOG] epoch 5 loss 0.007980, new best
[LOG] epoch 6 loss 0.006710, new best
[LOG] epoch 7 loss 0.005726, new best
[LOG] epoch 8 loss 0.004961, new best
[LOG] epoch 9 loss 0.004302, new best
0.9999147633450837
probing at depth 1
[LOG] epoch 0 loss 0.006176, new best
[LOG] epoch 1 loss 0.002967, new best
[LOG] epoch 2 loss 0.001878, new best
[LOG] epoch 3 loss 0.001328, new best
[LOG] epoch 4 loss 0.000984, new best
[LOG] epoch 5 loss 0.000759, new best
[LOG] epoch 6 loss 0.000593, new best
[LOG] epoch 7 loss 0.000467, new best
[LOG] epoch 8 loss 0.000380, new best
[LOG] epoch 9 loss 0.000299, new best
0.9999999508091887
probing at depth 2
[LOG] epoch 0 loss 0.012909, new best
[LOG] epoch 1 loss 0.007618, new best
[LOG] epoch 2 loss 0.005588, new best
[LOG] epoch 3 loss 0.004483, ne

KeyboardInterrupt: 