In [1]:
import matplotlib.pyplot as plt
import torch

from toumei.probe import print_modules

from toumei.models import SimpleCNN

from ingredients.utils import *

# Prefer the GPU when one is usable, otherwise fall back to the CPU so the notebook
# does not depend on CUDA being present on the machine that re-runs it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from experiments.research.mnist_cnn.ingredients.mnist_dataset import MNISTDataset
# Build the dataset once, handing the selected `device` to its constructor; this same
# `dataset` object is reused by the broadness and dataset-finder cells further down.
# NOTE(review): presumably the samples are placed on `device` — confirm in MNISTDataset.
dataset = MNISTDataset(device=device)

In [11]:
# Rebuild the architecture on `device`, then restore the trained parameters from disk.
# NOTE(review): the first positional argument (2) matches the 2 input channels that the
# module listing reports for conv1; the meaning of the second (1) is not visible here.
model = SimpleCNN(2, 1, redirected_relu=True).to(device)
# map_location makes the checkpoint tensors deserialize onto `device` no matter which
# device they were saved from, so the load does not fail when the saving device is absent.
model.load_state_dict(torch.load('model_weights(2).pth', map_location=device))
# Print the name -> module table; these names are what layer identifiers refer to.
print_modules(model)


Name            | Module                                                                      
--------------------------------------------------------------------------------------------
conv1           | Conv2d(2, 32, kernel_size=(3, 3), stride=(1, 1))                            
pool1           | MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  
relu1           | RedirectedReluLayer()                                                       
conv2           | Conv2d(32, 8, kernel_size=(3, 3), stride=(1, 1))                            
pool2           | MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  
relu2           | RedirectedReluLayer()                                                       
adaptive_pool   | AdaptiveAvgPool2d(output_size=(8, 8))                                       
fc1             | Linear(in_features=512, out_features=32, bias=True)                         
relu_fc1        | RedirectedReluLayer()            

In [12]:
from toumei.misc.model_broadness import BroadnessMeasurer
def gt_func(x):
    """Return the element-wise sum of the first two columns of ``x``.

    ``x`` must support 2-D indexing of the form ``x[:, k]``. Only columns 0 and 1
    are read; any further columns are ignored.
    NOTE(review): presumably ``x`` is a batch of label pairs (one row per sample)
    and the result is the regression target — confirm against the dataset.
    """
    first_column = x[:, 0]
    second_column = x[:, 1]
    return first_column + second_column

In [13]:
# Measure how the MSE loss (model output vs. gt_func target) behaves over several levels.
# The loss is built through the explicit `import torch` rather than a bare `nn`, which is
# never imported by name in this notebook and would only exist via a wildcard import.
bm = BroadnessMeasurer(model, dataset, torch.nn.MSELoss(), gt_func)
# Reuse the notebook-wide device instead of hard-coding it a second time.
bm.device = device
# Five evenly spaced levels: 0.0, 0.0005, ..., 0.002, each run with num_itrs=100.
# NOTE(review): presumably these are weight-perturbation magnitudes — confirm against
# BroadnessMeasurer.run. The call stays the last expression so its result is displayed.
bm.run([0.0005*i for i in range(5)], num_itrs=100)

Measuring Broadness: 100%|██████████| 1/1 [00:00<00:00,  8.67it/s, mean_loss=0.196]
Measuring Broadness: 100%|██████████| 100/100 [00:03<00:00, 27.87it/s, mean_loss=3.42]
Measuring Broadness: 100%|██████████| 100/100 [00:03<00:00, 27.90it/s, mean_loss=6.51]
Measuring Broadness: 100%|██████████| 100/100 [00:03<00:00, 27.60it/s, mean_loss=9.78]
Measuring Broadness: 100%|██████████| 100/100 [00:03<00:00, 27.31it/s, mean_loss=13.4]
  losses_measured = np.array(losses_measured)


(array([array([0.]),
        array([3.46072508, 4.10355188, 2.30766894, 4.79612543, 3.20806862,
               2.22931291, 2.76524855, 2.23061253, 3.3514619 , 5.80408908,
               2.29609729, 3.20768858, 2.70322324, 3.16236784, 2.54911877,
               5.19566156, 2.59179546, 2.96692397, 2.67630626, 3.24573423,
               3.48290254, 3.66744091, 4.17602398, 2.32385804, 2.98793127,
               4.1041484 , 3.15680505, 3.38520886, 2.76834156, 3.17690183,
               2.39863373, 2.82470943, 5.60223152, 3.48056199, 2.28573824,
               2.63359667, 3.03087689, 2.40528156, 3.1690364 , 2.6020713 ,
               2.92894031, 3.4496503 , 3.99398758, 2.21916272, 2.93861867,
               2.66474916, 3.41340877, 2.66554667, 2.03008796, 2.62060501,
               3.24834658, 2.71214272, 6.19501258, 4.36343242, 4.48533775,
               2.49220468, 3.62255718, 2.9214728 , 3.76953699, 3.61179306,
               2.99586298, 4.06617452, 3.2172599 , 2.13304497, 1.88057519,
    

In [3]:
from toumei.cnns.featurevis.dataset_finder import DatasetFinder
# For each target index 0..19, search a 2048-sample draw of the dataset with a
# "fc4:0" neuron objective wrapped around that target, then report the top label.
# NOTE(review): `obj` is not imported in any visible cell, and the module listing shown
# in this notebook ends at relu_fc1 — confirm both `obj` and the "fc4" layer exist.
for i in range(20):
    objective = obj.TargetWrapper(obj.Neuron("fc4:0"), i)
    df = DatasetFinder(dataset=dataset, obj_func=objective, sample_size=2048)
    df.attach(model)
    df.optimize(verbose=False)
    top_label = df.get_top_label()
    print(i, top_label)

0 tensor([0, 1], device='cuda:0')
1 tensor([0, 1], device='cuda:0')
2 tensor([1, 1], device='cuda:0')
3 tensor([1, 2], device='cuda:0')
4 tensor([3, 1], device='cuda:0')
5 tensor([3, 3], device='cuda:0')
6 tensor([3, 3], device='cuda:0')
7 tensor([6, 2], device='cuda:0')
8 tensor([0, 8], device='cuda:0')
9 tensor([0, 9], device='cuda:0')
10 tensor([3, 7], device='cuda:0')
11 tensor([6, 6], device='cuda:0')
12 tensor([6, 7], device='cuda:0')
13 tensor([5, 9], device='cuda:0')
14 tensor([5, 9], device='cuda:0')
15 tensor([7, 9], device='cuda:0')
16 tensor([9, 8], device='cuda:0')
17 tensor([9, 9], device='cuda:0')
18 tensor([9, 9], device='cuda:0')
19 tensor([9, 9], device='cuda:0')
