In [1]:
# Re-import edited modules (e.g. `myutils`) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

In [3]:
import fastai
# Run everything on the CPU by overriding fastai's default device.
fastai.defaults.device = torch.device('cpu')

In [4]:
from fastai.callbacks import *

In [5]:
import myutils as my

In [6]:
def create_head_sigmoid(nf:int, nc:int, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5):
    """Model head like fastai's `create_head`, but with sigmoid hidden activations.

    Pools the backbone output with `AdaptiveConcatPool2d` (average and max pool concatenated),
    flattens it, then runs it through the hidden sizes in `lin_ftrs` and ends with `nc` outputs.
    Each hidden group is BatchNorm -> Dropout -> Linear -> Sigmoid. The final group has no
    BatchNorm and no activation (raw scores).

    :param nf: number of pooled input features (avg+max concatenated, so typically twice the
        backbone's channel count).
    :param nc: number of outputs (classes).
    :param lin_ftrs: hidden layer sizes; defaults to a single hidden layer of 512 units.
    :param ps: dropout, can be a single float or a list for each layer. A single float `p`
        is used as-is before the last Linear layer and as `p/2` before every hidden one.
    :returns: an `nn.Sequential` with the layers described above.
    """
    lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]
    ps = listify(ps)
    if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    n_groups = len(lin_ftrs) - 1
    # One shared Sigmoid instance is fine: it has no parameters or state.
    actns = [nn.Sigmoid()] * (n_groups-1) + [None]
    layers = [AdaptiveConcatPool2d(), Flatten()]
    for i,(ni,no,p,actn) in enumerate(zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns)):
        # Skip BatchNorm on the last group up front rather than deleting `layers[-3]` afterwards:
        # when the last group's dropout is 0, `bn_drop_lin` emits no Dropout, and `layers[-3]`
        # would then be the previous group's Sigmoid instead of the BatchNorm.
        is_last = i == n_groups-1
        layers += bn_drop_lin(ni, no, not is_last, p, actn)
    return nn.Sequential(*layers)

In [7]:
# Model/data configuration (values unchanged, just named).
N_CLASSES = 8           # classes 0..7; folders '8' and '9' are also evaluated further down
CODE_BITS = 48          # width of the sigmoid "code" layer that is hooked and binarised later
POOLED_FEATURES = 1024  # AdaptiveConcatPool2d width for the resnet18 body (see BatchNorm1d(1024) in the head)

path = Path('data/mnist_png/')
classes = list(range(N_CLASSES))
head = create_head_sigmoid(POOLED_FEATURES, N_CLASSES, lin_ftrs=[CODE_BITS])

In [8]:
# No flips: a mirrored digit is not a valid digit.
tfms = get_transforms(do_flip=False)
# DataBunch built only from the class list (no training data), used to run the saved model.
data2 = ImageDataBunch.single_from_classes(path, classes, tfms=tfms, size=26)

In [9]:
# resnet18 body + the custom sigmoid head; weights come from the checkpoint saved as 'test'.
learner = create_cnn(data2, models.resnet18, custom_head=head).load('test')

In [10]:
# Sanity check: the last layer should produce one output per class (expected 8).
flatten_model(learner.model)[-1].out_features

8

In [11]:
class StoreHook(Callback):
    """Callback that records the output of `module` for every non-training batch.

    The module output is captured with fastai's `hook_output(module)`. After each batch
    with `train=False`, the hook's `.stored` value is appended to `self.outputs`.
    """
    def __init__(self, module):
        super().__init__()
        # NOTE(review): the forward hook stays registered until `custom_hook.remove()` is called.
        self.custom_hook = hook_output(module)
        self.outputs = []

    def on_batch_end(self, train, **kwargs):
        # Training batches are ignored; only evaluation/inference activations are kept.
        if train: return
        latest = self.custom_hook.stored
        self.outputs.append(latest)

In [12]:
# Detach hooks left by a previous run of this cell; otherwise every re-run adds
# another forward hook on the same module that keeps firing on each batch.
for old_cb in getattr(learner, 'callbacks', []):
    if hasattr(old_cb, 'custom_hook'): old_cb.custom_hook.remove()
# Hook the last Sigmoid of the head (the activation after the 48-unit Linear layer),
# found by type rather than by the position-dependent index `[-3]`.
nn_module = [m for m in learner.model[-1] if isinstance(m, nn.Sigmoid)][-1]
learner.callbacks = [ StoreHook(nn_module) ]

In [13]:
# Display the head's layer layout.
head

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Lambda()
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25)
  (4): Linear(in_features=1024, out_features=48, bias=True)
  (5): Sigmoid()
  (6): Dropout(p=0.5)
  (7): Linear(in_features=48, out_features=8, bias=True)
)

In [14]:
# The hooked module: expected to be the Sigmoid after the 48-unit Linear layer.
nn_module

Sigmoid()

In [15]:
# List the image files of the digit-9 folder.
digit9_dir = path/'9'
for img_path in digit9_dir.iterdir():
    print(img_path)

data/mnist_png/9/22.png
data/mnist_png/9/48.png
data/mnist_png/9/4.png
data/mnist_png/9/428.png
data/mnist_png/9/764.png
data/mnist_png/9/78.png
data/mnist_png/9/33.png
data/mnist_png/9/116.png
data/mnist_png/9/800.png
data/mnist_png/9/601.png
data/mnist_png/9/176.png
data/mnist_png/9/227.png
data/mnist_png/9/383.png
data/mnist_png/9/282.png
data/mnist_png/9/520.png
data/mnist_png/9/154.png
data/mnist_png/9/267.png
data/mnist_png/9/153.png
data/mnist_png/9/471.png
data/mnist_png/9/727.png
data/mnist_png/9/362.png
data/mnist_png/9/482.png
data/mnist_png/9/589.png
data/mnist_png/9/423.png
data/mnist_png/9/460.png
data/mnist_png/9/319.png
data/mnist_png/9/826.png
data/mnist_png/9/424.png
data/mnist_png/9/621.png
data/mnist_png/9/580.png
data/mnist_png/9/58.png
data/mnist_png/9/550.png
data/mnist_png/9/793.png
data/mnist_png/9/226.png
data/mnist_png/9/313.png
data/mnist_png/9/827.png
data/mnist_png/9/741.png
data/mnist_png/9/350.png
data/mnist_png/9/57.png
data/mnist_png/9/12.png
data/mnis

In [16]:
# Clear previously captured activations, then run every image of folder '9' through the model.
learner.callbacks[0].outputs = []
my.run_folder(path/'9', learner)

In [17]:
# Raw captured sigmoid activations (one tensor appended per evaluation batch).
learner.callbacks[0].outputs

[tensor([[0.1202, 0.8056, 0.5248, 0.8215, 0.8313, 0.0351, 0.3893, 0.0109, 0.5518,
          0.2883, 0.0114, 0.2036, 0.9603, 0.0834, 0.0062, 0.6572, 0.6752, 0.2453,
          0.6100, 0.1568, 0.0047, 0.9695, 0.8021, 0.0258, 0.5471, 0.9645, 0.1366,
          0.1089, 0.6032, 0.0895, 0.4923, 0.2885, 0.0131, 0.9210, 0.0037, 0.1265,
          0.9299, 0.0938, 0.1684, 0.0148, 0.0295, 0.0390, 0.0198, 0.6520, 0.7621,
          0.9765, 0.5287, 0.0514]]),
 tensor([[0.0168, 0.0011, 0.9616, 0.9968, 0.0027, 0.9997, 0.9970, 0.9995, 0.0050,
          0.0032, 0.0001, 0.0019, 0.0048, 0.0186, 0.9823, 0.0008, 0.9937, 0.0004,
          0.0007, 0.0100, 0.0001, 0.9983, 0.0180, 0.7043, 0.6730, 0.9891, 0.0053,
          0.0000, 0.0022, 0.6604, 0.9899, 0.6453, 0.0001, 0.0006, 0.9700, 0.9876,
          0.3991, 0.0031, 0.4164, 0.9594, 0.9984, 0.8787, 0.9568, 0.0001, 0.0022,
          0.9962, 0.0000, 0.0000]]),
 tensor([[0.1476, 0.5072, 0.1292, 0.5230, 0.7236, 0.2115, 0.0595, 0.0131, 0.8663,
          0.5848, 0.0037

In [18]:
# NOTE(review): presumably thresholds the captured outputs into 0/1 codes in place — confirm in myutils.binary_code.
my.binary_code(learner)

In [19]:
# Alias (not a copy) of the captured outputs list for digit 9.
test = learner.callbacks[0].outputs

In [20]:
# Reference code for digit 9: per-bit majority vote over all captured codes.
# Cast each code to float *before* summing: comparison results are uint8 in this torch
# version (see the dtype of ref9), and summing uint8 codes first wraps around past 255 images.
ref9 = sum(code.float() for code in test)/len(test) > 0.5

In [21]:
# Display the 48-bit reference code.
ref9

tensor([[0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0,
         1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0]],
       dtype=torch.uint8)

In [131]:

# Similarity of each captured digit-9 code to the reference code.
for code in learner.callbacks[0].outputs:
    print(my.similarity(code[0], ref9))
    

tensor(0.7708)
tensor(0.8333)
tensor(0.8750)
tensor(0.7708)
tensor(0.6875)
tensor(0.7292)
tensor(0.7917)
tensor(0.7083)
tensor(0.7292)
tensor(0.7292)
tensor(0.6042)
tensor(0.7500)
tensor(0.7708)
tensor(0.7083)
tensor(0.7500)
tensor(0.7292)
tensor(0.7083)
tensor(0.6875)
tensor(0.6667)
tensor(0.6875)


In [22]:
def test_accu(path, ref=None, learn=None):
    """Run every image in `path` through the model and compare its code with a reference code.

    Resets the first callback's `outputs`, runs the folder, binarises the outputs with
    `my.binary_code`, prints the similarity of each output to `ref`, and prints the average.

    :param path: folder of images passed to `my.run_folder`.
    :param ref: reference code to compare against; defaults to the notebook global `ref9`.
    :param learn: learner to run (its first callback must collect `outputs`); defaults to
        the notebook global `learner`.
    :returns: the mean similarity over all collected outputs.
    :raises ValueError: if no outputs were collected (e.g. an empty folder).
    """
    learn = learner if learn is None else learn
    ref = ref9 if ref is None else ref
    learn.callbacks[0].outputs = []
    my.run_folder(path, learn)
    my.binary_code(learn)
    # Re-read after binary_code in case it replaces the list rather than mutating it.
    outputs = learn.callbacks[0].outputs
    if not outputs:
        raise ValueError(f'no outputs collected for {path}')
    # Compute each similarity once (the original computed it twice per output).
    scores = [my.similarity(out[0], ref) for out in outputs]
    for score in scores:
        print(score)
    avg_accu = sum(scores)/len(scores)
    print('avg_accu: ' + str(avg_accu))
    return avg_accu

In [32]:
# Digit 0 vs. the digit-9 reference code.
test_accu(path/'0')

tensor(0.3958)
tensor(0.4167)
tensor(0.4167)
tensor(0.4167)
tensor(0.4375)
tensor(0.4167)
tensor(0.4167)
tensor(0.4167)
tensor(0.3958)
tensor(0.4375)
avg_accu: tensor(0.4167)


In [23]:
# Digit 1 vs. the digit-9 reference code.
test_accu(path/'1')

tensor(0.3958)
tensor(0.3958)
tensor(0.3958)
tensor(0.3750)
tensor(0.3958)
tensor(0.3958)
tensor(0.3958)
tensor(0.8125)
tensor(0.3958)
tensor(0.4167)
avg_accu: tensor(0.4375)


In [24]:
# Digit 2 vs. the digit-9 reference code.
test_accu(path/'2')

tensor(0.5625)
tensor(0.5417)
tensor(0.5417)
tensor(0.5417)
tensor(0.5417)
tensor(0.5417)
tensor(0.5625)
tensor(0.5417)
tensor(0.5417)
tensor(0.5417)
avg_accu: tensor(0.5458)


In [25]:
# Digit 3 vs. the digit-9 reference code.
test_accu(path/'3')

tensor(0.4792)
tensor(0.5000)
tensor(0.5000)
tensor(0.4792)
tensor(0.4792)
tensor(0.5000)
tensor(0.4792)
tensor(0.4792)
tensor(0.5208)
tensor(0.5000)
avg_accu: tensor(0.4917)


In [26]:
# Digit 4 vs. the digit-9 reference code.
test_accu(path/'4')

tensor(0.6250)
tensor(0.6042)
tensor(0.6042)
tensor(0.6042)
tensor(0.6042)
tensor(0.6042)
tensor(0.6042)
tensor(0.6042)
tensor(0.6042)
tensor(0.6042)
avg_accu: tensor(0.6062)


In [27]:
# Digit 5 vs. the digit-9 reference code.
test_accu(path/'5')

tensor(0.4375)
tensor(0.4167)
tensor(0.4583)
tensor(0.4583)
tensor(0.4167)
tensor(0.4583)
tensor(0.4792)
tensor(0.4583)
tensor(0.4583)
tensor(0.4167)
avg_accu: tensor(0.4458)


In [28]:
# Digit 6 vs. the digit-9 reference code.
test_accu(path/'6')

tensor(0.2917)
tensor(0.3125)
tensor(0.3125)
tensor(0.2917)
tensor(0.3125)
tensor(0.3125)
tensor(0.3125)
tensor(0.3125)
tensor(0.3125)
tensor(0.3125)
avg_accu: tensor(0.3083)


In [29]:
# Digit 7 vs. the digit-9 reference code.
test_accu(path/'7')

tensor(0.9375)
tensor(0.9167)
tensor(0.9167)
tensor(0.9167)
tensor(0.9167)
tensor(0.8958)
tensor(0.7500)
tensor(0.9167)
tensor(0.9167)
tensor(0.9167)
avg_accu: tensor(0.9000)


In [30]:
# Digit 8 vs. the digit-9 reference code.
test_accu(path/'8')

tensor(0.5417)
tensor(0.9375)
tensor(0.5208)
tensor(0.5208)
tensor(0.5417)
tensor(0.5208)
tensor(0.5000)
tensor(0.4583)
tensor(0.8333)
tensor(0.3750)
avg_accu: tensor(0.5750)


In [31]:
# Digit 9 vs. its own reference code (should score highest).
test_accu(path/'9')

tensor(0.9583)
tensor(0.4792)
tensor(0.7708)
tensor(0.9167)
tensor(0.9583)
tensor(0.7500)
tensor(0.9167)
tensor(0.9792)
tensor(0.9167)
tensor(0.6875)
tensor(0.9167)
tensor(0.9167)
tensor(0.9167)
tensor(0.8750)
tensor(0.9167)
tensor(0.9167)
tensor(0.9375)
tensor(0.9375)
tensor(0.9167)
tensor(0.8333)
tensor(0.7292)
tensor(0.7292)
tensor(0.7292)
tensor(0.9583)
tensor(0.6875)
tensor(0.9375)
tensor(0.7083)
tensor(0.6250)
tensor(0.8125)
tensor(0.9375)
tensor(0.9792)
tensor(0.8125)
tensor(1.)
tensor(0.9583)
tensor(1.)
tensor(0.8542)
tensor(0.8958)
tensor(0.9167)
tensor(0.9167)
tensor(0.8333)
tensor(0.9583)
tensor(0.8750)
tensor(0.8958)
tensor(0.4583)
tensor(0.9375)
tensor(0.6042)
tensor(0.9583)
tensor(0.4792)
tensor(0.8958)
tensor(0.7917)
tensor(0.7917)
tensor(0.9167)
tensor(0.8333)
tensor(0.8125)
tensor(0.9167)
tensor(0.9583)
tensor(0.8542)
tensor(0.6667)
tensor(0.9167)
tensor(0.8958)
tensor(0.9167)
tensor(0.9792)
tensor(0.5833)
tensor(0.9583)
tensor(0.6875)
tensor(0.9167)
tensor(0.9375)
ten