In [1]:
import os
import pickle as pkl

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_curve, roc_auc_score

from go_model import GoModel
from probe_experiment import ProbeExperiment

In [2]:
# NOTE: pickle.load can execute arbitrary code; only load this locally produced, trusted file.
with open('data/filtered_annotations.pkl', 'rb') as file:
    annotations = pkl.load(file)
# Sorting by file name makes the split deterministic and keeps annotations that share
# an f_name contiguous.
annotations.sort(key=lambda ant: ant['f_name'])

TRAIN_FRACTION = 0.8
n_train = int(len(annotations) * TRAIN_FRACTION)
# Move the split forward to the next f_name boundary so no single source file contributes
# annotations to both train and test (presumably one file holds several annotated
# positions of the same game; splitting inside it would leak that game into the test set).
while (0 < n_train < len(annotations)
       and annotations[n_train]['f_name'] == annotations[n_train - 1]['f_name']):
    n_train += 1

train_ants = annotations[:n_train]
test_ants = annotations[n_train:]

# Keywords to probe for; the probes below get one output logit per keyword.
keywords = ['territory', 'cut', 'sente', 'shape', 'moyo',
            'ko', 'invasion', 'influence', 'wall', 'eye']

In [3]:
go_model = GoModel(None)
# Load to CPU: load_state_dict copies the tensors into go_model's existing parameters
# anyway, so staging them on the GPU first only costs GPU memory (and requires a GPU).
checkpoint = torch.load('model_ckpt.pth.tar', map_location='cpu')
state_dict = checkpoint['state_dict']
# Keys carry a 'module.' prefix (presumably saved from an nn.DataParallel/DDP wrapper).
# Strip it only where present rather than blindly dropping the first 7 characters.
prefix = 'module.'
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in state_dict.items()
}
go_model.load_state_dict(state_dict)

<All keys matched successfully>

In [4]:
class CutModel(nn.Module):
    """Feature extractor built from the first ``cut`` conv layers of a model.

    Each kept conv layer is followed by the model's shared ``nonlinear`` activation.
    The final activation map is flattened to one feature vector per batch element.
    """

    def __init__(self, model, cut):
        super().__init__()
        # Slicing keeps references to the original layer objects; no weights are copied.
        self.convs = model.convs[:cut]
        self.nonlinear = model.nonlinear
        self.cut = cut

    def forward(self, x):
        # Iterate over the layers actually kept (not range(self.cut)), so the loop
        # can never index past the end of self.convs.
        for conv in self.convs:
            x = self.nonlinear(conv(x))
        # Equivalent to nn.Flatten(): keep dim 0 (batch) and merge all remaining dims.
        return torch.flatten(x, start_dim=1)


def cut_model(model, cut_at_layer):
    """Return a CutModel that runs ``model``'s first ``cut_at_layer`` conv layers.

    Args:
        model: module exposing a sliceable ``convs`` sequence and a ``nonlinear`` callable.
        cut_at_layer: number of conv layers to keep; 0 yields the flattened raw input.

    Returns:
        CutModel sharing its layers with ``model``.

    Raises:
        ValueError: if ``cut_at_layer`` is outside ``[0, len(model.convs)]``.
    """
    n_convs = len(model.convs)
    if not 0 <= cut_at_layer <= n_convs:
        raise ValueError(
            f'cut_at_layer must be in [0, {n_convs}], got {cut_at_layer}')
    return CutModel(model, cut_at_layer)

# One feature extractor per cut point: cut 0 is the raw input, cut len(convs) is the full
# conv stack. Derived from the model instead of a hard-coded count so it can't drift.
models = [cut_model(go_model, i) for i in range(len(go_model.convs) + 1)]

In [5]:
# Expected channel count after each cut (cut 0 = raw input planes). Kept for reference
# and validated below; it does not have to cover every entry of `models`.
layer_n_channels = [8, 64, 64, 64, 48, 48, 32]
BOARD_SIZE = 19

# Measure each feature model's flattened output size with one dummy forward pass so there
# is exactly one correctly sized probe per entry of `models`.
# Assumes inputs are (batch, layer_n_channels[0], BOARD_SIZE, BOARD_SIZE), as the cut-0
# entry (8 planes on a 19x19 board) implies. TODO confirm against ProbeExperiment.
_was_training = go_model.training
go_model.eval()  # keep the dummy pass from updating any train-mode state (e.g. BatchNorm stats)
with torch.no_grad():
    _dummy_board = torch.zeros(1, layer_n_channels[0], BOARD_SIZE, BOARD_SIZE,
                               device=next(go_model.parameters()).device)
    layer_dims = [feat_model(_dummy_board).shape[1] for feat_model in models]
go_model.train(_was_training)

# The hard-coded channel counts must agree with the measured sizes for the cuts they cover.
for n_channels, dim in zip(layer_n_channels, layer_dims):
    assert dim == n_channels * BOARD_SIZE * BOARD_SIZE, (n_channels, dim)

# One linear probe per feature model, with one output logit per keyword.
probe_models = [nn.Linear(dim, len(keywords)).cuda() for dim in layer_dims]

In [6]:
# Experiment harness holding the train/test annotation splits and the keyword list;
# the training loop drives it via exp.run(...) and exp.test_inference(...).
# NOTE(review): how it turns annotations into model inputs/labels lives in
# probe_experiment.py and is not visible here.
exp = ProbeExperiment(train_ants, test_ants, keywords)

In [None]:
# Train one linear probe per feature depth and keep each probe's test-set predictions.
# Pair models with probes explicitly; indexing both by a hard-coded range crashes
# (after hours of training) as soon as the two lists differ in length.
assert len(models) == len(probe_models), (
    f'{len(models)} feature models but {len(probe_models)} probes: '
    'each cut depth needs exactly one probe')

preds = []
for depth, (feat_model, probe_model) in enumerate(zip(models, probe_models)):
    # Re-initialise so re-running this cell trains from scratch instead of continuing
    # from the weights a previous run left in probe_models.
    probe_model.reset_parameters()

    criterion = nn.BCEWithLogitsLoss()  # independent binary loss per keyword logit
    optimizer = optim.Adam(probe_model.parameters())

    exp.run('depth%d_probe' % depth, feat_model, probe_model, criterion, optimizer,
            batch_size=64, num_epochs=2)
    # NOTE(review): labels is overwritten each iteration; presumably identical across depths.
    preds_, labels = exp.test_inference(feat_model, probe_model)
    preds.append(preds_)

 22%|█████████████████████████▌                                                                                        | 656/2926 [01:06<03:59,  9.46it/s]