In [1]:
import numpy as np
import torch
import torch.nn as nn
import os

In [2]:
import pickle as pkl

# Deserialize the annotation records used by every later cell.
# NOTE(review): unpickling can execute arbitrary code; only load files
# from a trusted source.
annotations_path = 'data/filtered_annotations.pkl'
with open(annotations_path, 'rb') as file:
    annotations = pkl.load(file)

In [3]:
from probe_dataset import *

# Fraction of the (file-name-ordered) annotations assigned to training.
split = 0.8

# Sort in place by file name so the train/test cut is deterministic.
annotations.sort(key=lambda record: record['f_name'])
n_train = int(len(annotations) * split)
train_ants, test_ants = annotations[:n_train], annotations[n_train:]

# Vocabulary of keywords the probe is trained to predict (one output each).
keywords = [
    'territory', 'cut', 'sente', 'shape', 'moyo',
    'ko', 'invasion', 'influence', 'wall', 'eye',
]

# Build both datasets first, then pass each through load_to_memory
# (train before test, matching the order of the progress bars).
train_dataset, test_dataset = [
    SPBoWDataset(ants, keywords) for ants in (train_ants, test_ants)
]
train_dataset, test_dataset = [
    load_to_memory(dataset) for dataset in (train_dataset, test_dataset)
]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 187250/187250 [04:21<00:00, 715.24it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46813/46813 [01:04<00:00, 731.38it/s]


In [4]:
from models import *

# Restore the Go model from its checkpoint file.
go_model = load_go_model_from_ckpt('model_ckpt.pth.tar', rm_prefix=True)

# Wrap the same go_model in one CutModel per index 0..7 and move each to the GPU.
# NOTE(review): the index presumably selects the layer at which the network
# is cut -- confirm against models.CutModel.
feat_models = []
for layer_idx in range(8):
    feat_models.append(CutModel(go_model, layer_idx).cuda())

In [9]:
# Channel count for each of the eight feature layers.
layer_n_channels = [8, 64, 64, 64, 48, 48, 32, 32]

# Probe input size per layer: channels * 19 * 19
# (presumably the 19x19 Go board flattened -- confirm against CutModel output).
layer_dims = [channels * 19 * 19 for channels in layer_n_channels]


def init_probe_model(layer):
    """Create a freshly initialised linear probe for one feature layer.

    Args:
        layer: index into ``layer_dims`` selecting the probe's input size.

    Returns:
        An ``nn.Linear`` mapping ``layer_dims[layer]`` inputs to
        ``len(keywords)`` outputs, moved to the GPU via ``.cuda()``.
    """
    in_features = layer_dims[layer]
    out_features = len(keywords)
    probe = nn.Linear(in_features, out_features)
    return probe.cuda()

In [10]:
from probe_experiment import ProbeExperiment

# Loss shared by every run below; it takes raw logits.
criterion = nn.BCEWithLogitsLoss()

# Experiment object reused by every run below.
exp = ProbeExperiment(train_dataset, test_dataset, keywords)

In [12]:
# Learning-rate sweep with SGD on the features of layer k.
lrs = [0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]
k = 7
optim = torch.optim.SGD
name = 'SGD'
feat_model = feat_models[k]

for lr in lrs:
    # Start every learning rate from a newly initialised probe.
    probe_model = init_probe_model(k)
    optimizer = optim(probe_model.parameters(), lr=lr)
    config = dict(
        num_epochs=10,
        batch_size=512,
        criterion=criterion,
        optimizer=optimizer,
        write_log=True,
        save_ckpt=False,
    )
    # The run name encodes optimizer, layer and learning rate.
    run_name = 'optim_exp/%s_layer%d_lr%f' % (name, k, lr)
    exp.run(run_name, feat_model, probe_model, config)

[LOG] epoch 0 loss 0.188911, new best
[LOG] epoch 1 loss 0.169074, new best
[LOG] epoch 2 loss 0.164735, new best
[LOG] epoch 3 loss 0.163185, new best
[LOG] epoch 4 loss 0.162426, new best
[LOG] epoch 5 loss 0.161958, new best
[LOG] epoch 6 loss 0.161613, new best
[LOG] epoch 7 loss 0.161333, new best
[LOG] epoch 8 loss 0.161093, new best
[LOG] epoch 9 loss 0.160878, new best
[LOG] epoch 0 loss 0.300985, new best
[LOG] epoch 1 loss 0.221930, new best
[LOG] epoch 2 loss 0.194279, new best
[LOG] epoch 3 loss 0.181597, new best
[LOG] epoch 4 loss 0.174836, new best
[LOG] epoch 5 loss 0.170865, new best
[LOG] epoch 6 loss 0.168366, new best
[LOG] epoch 7 loss 0.166707, new best
[LOG] epoch 8 loss 0.165556, new best
[LOG] epoch 9 loss 0.164730, new best
[LOG] epoch 0 loss 0.473753, new best
[LOG] epoch 1 loss 0.362813, new best
[LOG] epoch 2 loss 0.300890, new best
[LOG] epoch 3 loss 0.263241, new best
[LOG] epoch 4 loss 0.238736, new best
[LOG] epoch 5 loss 0.221922, new best
[LOG] epoch 

KeyboardInterrupt: 

In [11]:
# Learning-rate sweep with Adam on the features of layer k.
lrs = [0.1, 0.03, 0.01, 0.003, 0.001, 0.0003, 0.0001]
k = 7
optim = torch.optim.Adam
name = 'Adam'
feat_model = feat_models[k]

for lr in lrs:
    # Start every learning rate from a newly initialised probe.
    probe_model = init_probe_model(k)
    optimizer = optim(probe_model.parameters(), lr=lr)
    config = dict(
        num_epochs=10,
        batch_size=512,
        criterion=criterion,
        optimizer=optimizer,
        write_log=True,
        save_ckpt=False,
    )
    # The run name encodes optimizer, layer and learning rate.
    run_name = 'optim_exp/%s_layer%d_lr%f' % (name, k, lr)
    exp.run(run_name, feat_model, probe_model, config)

[LOG] epoch 0 loss 0.229096, new best
[LOG] epoch 1 loss 0.276639
[LOG] epoch 2 loss 0.263926
[LOG] epoch 3 loss 0.273811
[LOG] epoch 4 loss 0.289331
[LOG] epoch 5 loss 0.294964
[LOG] epoch 6 loss 0.309539
[LOG] epoch 7 loss 0.304523
[LOG] epoch 8 loss 0.308369
[LOG] epoch 9 loss 0.318987
[LOG] epoch 0 loss 0.170148, new best
[LOG] epoch 1 loss 0.178172
[LOG] epoch 2 loss 0.185362
[LOG] epoch 3 loss 0.191040
[LOG] epoch 4 loss 0.198960
[LOG] epoch 5 loss 0.201314
[LOG] epoch 6 loss 0.206242
[LOG] epoch 7 loss 0.210041
[LOG] epoch 8 loss 0.213741
[LOG] epoch 9 loss 0.215433
[LOG] epoch 0 loss 0.158996, new best
[LOG] epoch 1 loss 0.161707
[LOG] epoch 2 loss 0.167029
[LOG] epoch 3 loss 0.168161
[LOG] epoch 4 loss 0.170790
[LOG] epoch 5 loss 0.173142
[LOG] epoch 6 loss 0.177294
[LOG] epoch 7 loss 0.178361
[LOG] epoch 8 loss 0.180899
[LOG] epoch 9 loss 0.183245
[LOG] epoch 0 loss 0.156678, new best
[LOG] epoch 1 loss 0.157287
[LOG] epoch 2 loss 0.158128
[LOG] epoch 3 loss 0.159249
[LOG] ep