In [10]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [11]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

In [12]:
import fastai
fastai.defaults.device = torch.device('cpu')

In [13]:
from fastai.callbacks import *

In [14]:
import myutils as my

In [15]:
def create_head_sigmoid(nf:int, nc:int, lin_ftrs:Optional[Collection[int]]=None, ps:Floats=0.5):
    """Model head that takes `nf` features, runs through `lin_ftrs`, and about `nc` classes.
    :param ps: dropout, can be a single float or a list for each layer."""
    lin_ftrs = [nf, 512, nc] if lin_ftrs is None else [nf] + lin_ftrs + [nc]
    ps = listify(ps)
    if len(ps)==1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    actns = [nn.Sigmoid()] * (len(lin_ftrs)-2) + [None]
    layers = [AdaptiveConcatPool2d(), Flatten()]
    for ni,no,p,actn in zip(lin_ftrs[:-1],lin_ftrs[1:],ps,actns):
        layers += bn_drop_lin(ni,no,True,p,actn)
    del(layers[-3])
    return nn.Sequential(*layers)

In [16]:
path = Path('data/mnist_png/')
classes = list(range(8))
head = create_head_sigmoid(1024,8,lin_ftrs=[48])

In [17]:
tfms = get_transforms(do_flip=False)
data2 = ImageDataBunch.single_from_classes(path, classes, tfms=tfms, size=26)

In [19]:
learner = create_cnn(data2, models.resnet18, custom_head=head).load('test')

In [20]:
flatten_model(learner.model)[-1].out_features

8

In [21]:
class StoreHook(Callback):
    def __init__(self, module):
        super().__init__()
        self.custom_hook = hook_output(module)
        self.outputs = []
        
    def on_batch_end(self, train, **kwargs): 
        if (not train): self.outputs.append(self.custom_hook.stored)

In [22]:
nn_module = learner.model[-1][-3]
learner.callbacks = [ StoreHook(nn_module) ]

In [23]:
head

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Lambda()
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25)
  (4): Linear(in_features=1024, out_features=48, bias=True)
  (5): Sigmoid()
  (6): Dropout(p=0.5)
  (7): Linear(in_features=48, out_features=8, bias=True)
)

In [24]:
nn_module

Sigmoid()

In [25]:
for i in (path/'9').iterdir():
    print(i)

data/mnist_png/9/22.png
data/mnist_png/9/48.png
data/mnist_png/9/4.png
data/mnist_png/9/428.png
data/mnist_png/9/764.png
data/mnist_png/9/78.png
data/mnist_png/9/33.png
data/mnist_png/9/116.png
data/mnist_png/9/800.png
data/mnist_png/9/601.png
data/mnist_png/9/176.png
data/mnist_png/9/227.png
data/mnist_png/9/383.png
data/mnist_png/9/282.png
data/mnist_png/9/520.png
data/mnist_png/9/154.png
data/mnist_png/9/267.png
data/mnist_png/9/153.png
data/mnist_png/9/471.png
data/mnist_png/9/727.png
data/mnist_png/9/362.png
data/mnist_png/9/482.png
data/mnist_png/9/589.png
data/mnist_png/9/423.png
data/mnist_png/9/460.png
data/mnist_png/9/319.png
data/mnist_png/9/826.png
data/mnist_png/9/424.png
data/mnist_png/9/621.png
data/mnist_png/9/580.png
data/mnist_png/9/58.png
data/mnist_png/9/550.png
data/mnist_png/9/793.png
data/mnist_png/9/226.png
data/mnist_png/9/313.png
data/mnist_png/9/827.png
data/mnist_png/9/741.png
data/mnist_png/9/350.png
data/mnist_png/9/57.png
data/mnist_png/9/12.png
data/mnis

In [26]:
learner.callbacks[0].outputs = []
my.run_folder(path/'9', learner)

In [27]:
learner.callbacks[0].outputs

[tensor([[0.0062, 0.3018, 0.0653, 0.4681, 0.0023, 0.0541, 0.6735, 0.0126, 0.7343,
          0.0071, 0.0454, 0.8152, 0.5552, 0.2397, 0.3981, 0.8252, 0.0824, 0.0062,
          0.8677, 0.0085, 0.9289, 0.0129, 0.9909, 0.0795, 0.7134, 0.2710, 0.9317,
          0.0114, 0.4749, 0.0249, 0.6376, 0.0397, 0.5245, 0.8433, 0.1188, 0.0831,
          0.4174, 0.2334, 0.5094, 0.1513, 0.0404, 0.9460, 0.0026, 0.0062, 0.8272,
          0.9276, 0.3931, 0.6930]]),
 tensor([[0.1344, 0.1900, 0.2896, 0.8592, 0.0152, 0.8547, 0.0010, 0.2510, 0.0098,
          0.2019, 0.4214, 0.8322, 0.7771, 0.0550, 0.5601, 0.7957, 0.4192, 0.0417,
          0.1434, 0.0156, 0.1844, 0.5830, 0.6502, 0.1029, 0.4546, 0.2939, 0.5640,
          0.0598, 0.0083, 0.1236, 0.9760, 0.1095, 0.0031, 0.4315, 0.2118, 0.8254,
          0.0035, 0.9185, 0.7503, 0.0430, 0.4033, 0.7959, 0.0973, 0.0553, 0.8862,
          0.9696, 0.0018, 0.5309]]),
 tensor([[0.0044, 0.5393, 0.0015, 0.9418, 0.0020, 0.1062, 0.8445, 0.0040, 0.4962,
          0.0025, 0.0086

In [28]:
my.binary_code(learner)

In [30]:
test = learner.callbacks[0].outputs

In [31]:
ref9 = sum(test).float()/len(test) > 0.5

In [32]:
ref9

tensor([[0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0,
         0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1]],
       dtype=torch.uint8)

In [131]:

for i in range(0,len(learner.callbacks[0].outputs)):
    print(my.similarity(learner.callbacks[0].outputs[i][0], ref9))
    

tensor(0.7708)
tensor(0.8333)
tensor(0.8750)
tensor(0.7708)
tensor(0.6875)
tensor(0.7292)
tensor(0.7917)
tensor(0.7083)
tensor(0.7292)
tensor(0.7292)
tensor(0.6042)
tensor(0.7500)
tensor(0.7708)
tensor(0.7083)
tensor(0.7500)
tensor(0.7292)
tensor(0.7083)
tensor(0.6875)
tensor(0.6667)
tensor(0.6875)


In [45]:
def test_accu(path):
    learner.callbacks[0].outputs = []
    my.run_folder(path, learner)
    my.binary_code(learner)
    avg_accu = 0
    for i in range(0,len(learner.callbacks[0].outputs)):
        print(my.similarity(learner.callbacks[0].outputs[i][0], ref9))
        avg_accu += my.similarity(learner.callbacks[0].outputs[i][0], ref9)
    print('avg_accu: ' + str(avg_accu/len(learner.callbacks[0].outputs)))

In [46]:
test_accu(path/'1')

tensor(0.4167)
tensor(0.4375)
tensor(0.6042)
tensor(0.4167)
tensor(0.4375)
tensor(0.4167)
tensor(0.5417)
tensor(0.6667)
tensor(0.4375)
tensor(0.4167)
avg_accu: tensor(0.4792)


In [47]:
test_accu(path/'2')

tensor(0.6250)
tensor(0.5625)
tensor(0.5833)
tensor(0.5833)
tensor(0.5625)
tensor(0.5833)
tensor(0.6458)
tensor(0.4792)
tensor(0.5833)
tensor(0.5417)
avg_accu: tensor(0.5750)


In [48]:
test_accu(path/'3')

tensor(0.5208)
tensor(0.6667)
tensor(0.5208)
tensor(0.5000)
tensor(0.5208)
tensor(0.5000)
tensor(0.5208)
tensor(0.5208)
tensor(0.5208)
tensor(0.5208)
avg_accu: tensor(0.5312)


In [49]:
test_accu(path/'4')

tensor(0.5208)
tensor(0.5208)
tensor(0.5208)
tensor(0.5000)
tensor(0.5000)
tensor(0.5208)
tensor(0.5000)
tensor(0.5208)
tensor(0.5000)
tensor(0.5000)
avg_accu: tensor(0.5104)


In [50]:
test_accu(path/'5')

tensor(0.3958)
tensor(0.4375)
tensor(0.7500)
tensor(0.4375)
tensor(0.4792)
tensor(0.4375)
tensor(0.4167)
tensor(0.4375)
tensor(0.4167)
tensor(0.4375)
avg_accu: tensor(0.4646)


In [51]:
test_accu(path/'6')

tensor(0.3750)
tensor(0.3542)
tensor(0.4167)
tensor(0.3750)
tensor(0.3958)
tensor(0.4167)
tensor(0.3958)
tensor(0.3958)
tensor(0.3958)
tensor(0.3542)
avg_accu: tensor(0.3875)


In [52]:
test_accu(path/'7')

tensor(0.9583)
tensor(0.9375)
tensor(0.9167)
tensor(0.9167)
tensor(0.9167)
tensor(0.9583)
tensor(0.9167)
tensor(0.9167)
tensor(0.8958)
tensor(0.8750)
avg_accu: tensor(0.9208)


In [53]:
test_accu(path/'8')

tensor(0.7083)
tensor(0.9375)
tensor(0.6458)
tensor(0.6042)
tensor(0.5208)
tensor(0.6458)
tensor(0.5833)
tensor(0.5625)
tensor(0.9792)
tensor(0.7292)
avg_accu: tensor(0.6917)


In [54]:
test_accu(path/'9')

tensor(0.8750)
tensor(0.6875)
tensor(0.8125)
tensor(0.9167)
tensor(0.8958)
tensor(0.5417)
tensor(0.9583)
tensor(0.9375)
tensor(0.9375)
tensor(0.8125)
tensor(0.8542)
tensor(0.9792)
tensor(0.8958)
tensor(0.9167)
tensor(0.9167)
tensor(0.9375)
tensor(0.9583)
tensor(0.9375)
tensor(0.9583)
tensor(0.9375)
tensor(0.7083)
tensor(0.7708)
tensor(0.7500)
tensor(0.9792)
tensor(0.6042)
tensor(0.9792)
tensor(0.4792)
tensor(0.6250)
tensor(0.7708)
tensor(0.9792)
tensor(0.8750)
tensor(0.9792)
tensor(0.9167)
tensor(0.9167)
tensor(0.9375)
tensor(0.9375)
tensor(0.8958)
tensor(0.9583)
tensor(0.9375)
tensor(0.9375)
tensor(0.8750)
tensor(0.9583)
tensor(0.9583)
tensor(0.5625)
tensor(0.9375)
tensor(0.8750)
tensor(0.8542)
tensor(0.4583)
tensor(0.9583)
tensor(0.9792)
tensor(0.9583)
tensor(0.9792)
tensor(0.9167)
tensor(0.9167)
tensor(0.9792)
tensor(0.9375)
tensor(0.8750)
tensor(0.6667)
tensor(0.9375)
tensor(0.8333)
tensor(0.9583)
tensor(0.8958)
tensor(0.6250)
tensor(0.8750)
tensor(0.8750)
tensor(0.9792)
tensor(0.9