In [855]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [856]:
import sys
sys.path.append('../../')

In [857]:
from regr.graph.concept import EnumConcept
from regr.graph import Graph, Concept, Relation
from regr.graph.logicalConstrain import ifL, nandL, orL, notL, andL, atMostL

digitRange = 10
summationRange = digitRange * 2 - 1

Graph.clear()
Concept.clear()
Relation.clear()

with Graph(name='global') as graph:
    image = Concept(name='image')
    
    digit = image(name='digit', ConceptClass=EnumConcept, values=list(map(lambda v: f'd_{v}', range(digitRange))))
    
    addition = Concept(name='addition')
    (operand1, operand2) = addition.has_a(operand1=image, operand2=image)
    
    summation = addition(name='summation', ConceptClass=EnumConcept, values=list(map(lambda v: f's_{v}', range(summationRange))))

    ifL(image,  atMostL(*digit.attributes))
    ifL(addition, atMostL(*summation.attributes))
    
    for i in range(digitRange):
        for j in range(i, digitRange):
            sumVal = i + j
            
            ifL(
                getattr(digit, f'd_{i}')('i'),
                ifL(
                    getattr(digit, f'd_{j}')('j', path=('i', operand1.reversed, operand2)),
                    getattr(summation, f's_{i+j}')('a', path=('j', operand2.reversed))
                ),
                active = True
            )




In [858]:
import torch
from torch import nn

class Net(torch.nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.recognition = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                    #   nn.LogSoftmax(dim=1)
                      )
    def forward(self, x):
        y = self.recognition(x)
        return y

In [859]:
def sum_func(d_1_distr, d_2_distr, prob_func = lambda d: F.softmax(d, dim=1)):
    # given d_1 and d_2 logits, get P(d_1) and P(d_2)
    # using P(d_1) and P(d_2), find P(d_1 + d_2)
    
    #print(d_1_distr.shape)
    
    #print(d_1_distr)
    
    Pd_1 = prob_func(d_1_distr)[0]
    Pd_2 = prob_func(d_2_distr)[0]
    
    #print(Pd_1, Pd_1.shape)
    
    Pd_sum = torch.zeros((summationRange,))
    
    for i in range(digitRange):
        for j in range(digitRange):
            Pd_sum[i + j] += Pd_1[i] * Pd_2[j]
    
    #print(Pd_sum.shape)
    
    return Pd_sum

In [860]:
input_size = 784
hidden_sizes = [128, 64]
epochs = 15
lr = 0.01

In [861]:
import torch.nn.functional as F

dummy_d1 = F.softmax(torch.rand((digitRange,)), dim=0)
dummy_d2 = F.softmax(torch.rand((digitRange,)), dim=0)

In [862]:
from regr.program.model.ilpu import ILPUModel
from regr.program.metric import MacroAverageTracker, PRF1Tracker, DatanodeCMMetric, MultiClassCMWithLogitsMetric
from regr.program.loss import NBCrossEntropyLoss

class Model(ILPUModel):
    def __init__(self, graph):
        super().__init__(
            graph,
            poi=(image, addition, summation),
            loss=MacroAverageTracker(NBCrossEntropyLoss()),
            metric={
                'ILP': PRF1Tracker(DatanodeCMMetric()),
                'argmax': PRF1Tracker(DatanodeCMMetric('local/argmax'))},
            inferTypes=['ILP', 'local/argmax'])

In [863]:
from regr.sensor.pytorch.sensors import FunctionalSensor, ReaderSensor, ConstantSensor, JointSensor
from regr.sensor.pytorch.learners import ModuleLearner
from regr.program import LearningBasedProgram
from regr.sensor.pytorch.relation_sensors import EdgeSensor

class ConstantEdgeSensor(ConstantSensor, EdgeSensor): pass

image['pixels'] = ReaderSensor(keyword='pixels')

addition[operand1.reversed] = ConstantEdgeSensor(image['pixels'], data=[[1,0]], relation=operand1.reversed)
addition[operand2.reversed] = ConstantEdgeSensor(image['pixels'], data=[[0,1]], relation=operand1.reversed)

image['logits'] = ModuleLearner('pixels', module=Net(input_size, hidden_sizes, digitRange))

image[digit] = FunctionalSensor('logits', forward=lambda x: x)

def test(x1, x2):
    print(x1.shape, x2.shape)
    
    return torch.zeros(summationRange)

addition[summation] = ReaderSensor(keyword='summation', label=True)
addition[summation] = FunctionalSensor(operand2.reversed('logits'), operand2.reversed('logits'), forward=sum_func)

program = LearningBasedProgram(graph, Model)

In [864]:
from data import get_readers

In [865]:
trainloader, validloader, testloader = get_readers()

In [876]:
from regr.program import POIProgram, IMLProgram, SolverPOIProgram
from regr.program.metric import ValueTracker

#program = LearningBasedProgram(graph, Model)

program = SolverPOIProgram(graph,
                            poi=(image, addition, summation),
                            inferTypes=['local/argmax'],
                            loss=MacroAverageTracker(NBCrossEntropyLoss()),
                            metric={'argmax':ValueTracker(DatanodeCMMetric('local/argmax'))
                                   }
                           )

In [877]:
from functools import partial

import logging
logging.basicConfig(level=logging.INFO)

program.train(trainloader,
              valid_set=validloader,
              test_set=testloader,
              train_epoch_num=1,
              Optim=partial(torch.optim.SGD,
                            lr=lr),
              device='auto')

INFO:regr.program.program:Epoch: 1
INFO:regr.program.program:Training:
Epoch 1 Training: 100%|██████████| 300/300 [00:05<00:00, 52.10it/s]
INFO:regr.program.program: - loss:
INFO:regr.program.program:{'summation': tensor(2.9291)}
INFO:regr.program.program: - metric:
INFO:regr.program.program: - - argmax
INFO:regr.program.program:{'summation': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, 

KeyboardInterrupt: 

In [870]:
node.getAttributes().keys()

dict_keys(['pixels', 'logits', '<digit>', '<digit>/local/softmax', '<digit>/local/argmax', '<digit>/ILP/x', '<digit>/ILP/xP', '<digit>/ILP'])

In [879]:
node.getRelationLinks()['addition'][0].getAttributes()

{'operand1.reversed': tensor([1, 0]),
 'operand2.reversed': tensor([0, 1]),
 '<summation>/label': 6}

In [892]:
for node in program.populate(validloader, device='auto'):
    node.inferILPResults()

    addition = node.getRelationLinks()['addition'][0]

    operands = addition.getRelationLinks()
    operand1 = operands['operand1'][0]
    operand2 = operands['operand2'][0]

    distr1 = operand1.getAttribute('<digit>/ILP')
    distr2 = operand2.getAttribute('<digit>/ILP')
    
    pred_digit_1 = torch.argmax(distr1)
    pred_digit_2 = torch.argmax(distr2)
    pred_sum = torch.argmax(sum_func(torch.unsqueeze(distr1, dim=0), torch.unsqueeze(distr2, dim=0)))
    
    print(pred_digit_1, pred_digit_2, pred_sum)

    break

dict_keys(['pixels', 'logits', '<digit>', '<digit>/local/softmax', '<digit>/local/argmax', '<digit>/ILP/x', '<digit>/ILP/xP', '<digit>/ILP'])
tensor(6) tensor(6) tensor(9)
