In [1]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch

In [3]:
# Notebook-wide settings.
# NOTE(review): only range_max, num_train_samples and num_digits are referenced by the
# cells visible in this notebook; epochs, range_min, batch_size, num_classes and
# num_test_samples are not used anywhere in view — possibly leftovers, confirm before removing.
epochs = 10
range_min = -999
# Upper bound used when sampling the random operands.
range_max = 999
batch_size = 4 
num_classes = 2
# Number of random operand pairs to generate.
num_train_samples = 100
num_test_samples = 10
# Digits per operand; products are encoded with 2 * num_digits digits.
num_digits = 3

In [5]:
def sep(x,num_digits):
    """Split a non-negative integer into its decimal digits, left-padded with zeros.

    Args:
        x: Non-negative integer (Python or NumPy int). Its ``str()`` form must
            consist of decimal digits only.
        num_digits: Minimum number of digits in the result.

    Returns:
        1-D NumPy integer array of digits, most significant first. Its length is
        ``num_digits``, or the number's own digit count if that is larger (the
        value is never truncated).

    Raises:
        ValueError: If ``str(x)`` contains a non-digit character such as the
            ``-`` of a negative number or the ``.`` of a float.
    """
    # Pad once on the string instead of growing the array with repeated
    # np.insert calls, each of which reallocates the whole array.
    padded = str(x).zfill(num_digits)
    return np.array([int(ch) for ch in padded])

In [16]:
# Build the training set: random non-negative operand pairs and the digits of their product.

# Local, seeded generator so that Restart & Run All reproduces the same samples
# without touching NumPy's global random state.
rng = np.random.RandomState(0)

# randint excludes `high`, so add 1 to make range_max itself a possible operand.
x_train_a = rng.randint(low=0, high=range_max + 1, size=num_train_samples)
x_train_b = rng.randint(low=0, high=range_max + 1, size=num_train_samples)

# x_train[i, k] holds the digits of operand k (most significant first) followed by num_digits.
x_train = np.empty((num_train_samples, 2, num_digits + 1), dtype=np.int32)

# Multiply in 64-bit so the product cannot wrap where the default integer is 32-bit.
y_train_a = x_train_a.astype(np.int64) * x_train_b
# A product of two num_digits-digit numbers has at most 2 * num_digits digits.
y_train = np.empty((num_train_samples, num_digits * 2), dtype=np.int32)

for i in range(num_train_samples):
    x_train[i, 0, :-1] = sep(x_train_a[i], num_digits)
    x_train[i, 1, :-1] = sep(x_train_b[i], num_digits)
    y_train[i] = sep(y_train_a[i], num_digits * 2)

# The last slot of every operand row carries the digit count.
x_train[:, :, -1] = num_digits

print('\t',x_train.shape)
print('\t',y_train.shape)

	 (100, 2, 4)
	 (100, 6)


In [17]:
# Wrap the NumPy training arrays as tensors: float32 inputs and int64 labels.
train_data = torch.from_numpy(x_train).float()
train_labels = torch.from_numpy(y_train).long()
# No x_test / y_test arrays are built in the cells above; the matching
# conversions are kept here, disabled, for when a test split exists:
# test_data = torch.from_numpy(x_test).float()
# test_labels = torch.from_numpy(y_test).long()

In [11]:
class Net(nn.Module):
    """Small fully connected classifier: 2 inputs -> 30 hidden units -> 82 classes.

    NOTE(review): the 82 outputs presumably stand for the products 0..81 of two
    digits in 0..9 (the instance loaded below is named ``model_mul_0_9``) — confirm.
    """

    def __init__(self):
        super(Net,self).__init__()
        self.fc1 = nn.Linear(2,30)
        self.fc2 = nn.Linear(30,82)

    def forward(self,x):
        """Return class probabilities for a batch of input pairs.

        Args:
            x: Float tensor whose last dimension has size 2.

        Returns:
            Tensor with a last dimension of size 82 whose entries sum to 1.
        """
        hidden = F.relu(self.fc1(x))
        # Normalise over the class dimension explicitly. Calling softmax without
        # `dim` is deprecated and emits a UserWarning; for the 2-D batches used
        # here the implicit choice was this same dimension.
        return F.softmax(self.fc2(hidden), dim=-1)

    def predict(self,x):
        """Return the index of the most probable class for each row, as a float tensor.

        Args:
            x: Float tensor of shape (batch, 2).

        Returns:
            1-D FloatTensor of length ``batch`` holding the arg-max class index.
        """
        probs = self.forward(x)
        _, out = torch.max(probs, 1)
        return out.type(torch.FloatTensor)

# Restore the saved digit-multiplication model from the file 'mul_0_9'.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# The file presumably holds a whole pickled Net instance (its .predict is called later),
# which requires the Net class above to be defined before this line runs — confirm.
model_mul_0_9 = torch.load('mul_0_9')

In [10]:
# This model adds two numbers
class Net_add(nn.Module):
    """Adds the two entries of each input row using a fixed-weight linear layer."""

    def __init__(self):
        super().__init__()
        # Bias-free 2 -> 1 layer whose weights are pinned to [1, 1], i.e. out = a + b.
        adder = nn.Linear(in_features=2, out_features=1, bias=False)
        adder.weight.data = torch.tensor([[1, 1]], dtype=torch.float32)
        self.fc1 = adder

    def forward(self,x):
        """Return the row-wise sum of ``x`` (last dimension of size 2) with a trailing size-1 dimension."""
        return self.fc1(x)

    def predict(self,x):
        """Same as ``forward``, with the result cast to a FloatTensor."""
        total = self.forward(x)
        return total.type(torch.FloatTensor)

# Restore the saved addition model from the file 'model_add'.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# The file presumably holds a whole pickled Net_add instance, which requires the
# Net_add class above to be defined before this line runs — confirm.
model_add = torch.load('model_add')

In [12]:
# This model subtracts two numbers
class Net_sub(nn.Module):
    """Subtracts the second entry of each input row from the first using a fixed linear layer."""

    def __init__(self):
        super().__init__()
        # 2 -> 1 layer pinned to weights [1, -1] and bias 0, i.e. out = a - b.
        differ = nn.Linear(in_features=2, out_features=1)
        differ.weight.data = torch.tensor([[1, -1]], dtype=torch.float32)
        differ.bias.data = torch.zeros(1, dtype=torch.float32)
        self.fc1 = differ

    def forward(self,x):
        """Return ``x[..., 0] - x[..., 1]`` with a trailing size-1 dimension."""
        return self.fc1(x)

    def predict(self,x):
        """Same as ``forward``, with the result cast to a FloatTensor."""
        difference = self.forward(x)
        return difference.type(torch.FloatTensor)
    
# Restore the saved subtraction model from the file 'model_sub'.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# The file presumably holds a whole pickled Net_sub instance, which requires the
# Net_sub class above to be defined before this line runs — confirm.
model_sub = torch.load('model_sub')

In [13]:
# This model will separate a number into digits
# input: 2 inputs - (number,num_digits)
# output: separated digits (num_digits,)  

class Net_separate(nn.Module):
    """Splits each number in a batch into decimal digits, most significant first.

    Each input row is ``(number, num_digits)``. The digit count is read from the
    first row only and applied to the whole batch.
    """

    def __init__(self,model_sub):
        super().__init__()

        # Fixed 1 -> 1 scaling layers: x * 0.1 and x * 10.
        tenth = nn.Linear(1, 1, bias=False)
        tenth.weight.data = torch.tensor([[0.1]], dtype=torch.float32)
        self.divide_by_10 = tenth

        tenfold = nn.Linear(1, 1, bias=False)
        tenfold.weight.data = torch.tensor([[10]], dtype=torch.float32)
        self.multiply_by_10 = tenfold

        # Object whose predict() receives (batch, 2) pairs; used to compute a - b.
        self.model_sub = model_sub

    def forward(self,x):
        """Return a (batch, num_digits) tensor with the digits of ``x[:, 0]``."""
        remaining = x[:, 0].unsqueeze(-1)
        steps = int(x[0, 1].int())

        digits_low_to_high = []
        for _ in range(steps):
            # floor(remaining / 10) drops the lowest digit ...
            quotient = self.divide_by_10(remaining).floor()
            # ... and remaining - 10 * quotient recovers that digit.
            pair = torch.cat((remaining, self.multiply_by_10(quotient)), 1)
            digits_low_to_high.append(self.model_sub.predict(pair))
            remaining = quotient

        # Digits were collected lowest first; emit them highest first.
        return torch.cat(digits_low_to_high[::-1], 1)

    def predict(self,x):
        """Same as ``forward``, with the result cast to a FloatTensor."""
        digits = self.forward(x)
        return digits.type(torch.FloatTensor)
    
# Restore the saved digit-separation model from the file 'model_separate'.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
# The file presumably holds a whole pickled Net_separate instance (and the Net_sub it wraps),
# which requires both classes to be defined before this line runs — confirm.
model_separate = torch.load('model_separate')

In [118]:
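# This model computes the product of two non-negative numbers (signs are not handled).
# input: the two numbers as digit rows, each followed by num_digits : (batch, 2, num_digits+1)
# output: the digits of the product, most significant first : (batch, 2*num_digits)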

class Net_multiply_abs(nn.Module):
    """Schoolbook long multiplication of two digit-encoded numbers, built from sub-models.

    Input ``x`` has shape (batch, 2, k + 1): row 0 and row 1 hold the k digits of
    the two operands (most significant first); the last slot of each row is not
    used as a digit. Output has shape (batch, 2 * k): the digits of the product,
    most significant first.

    Args:
        model_mul_0_9: Object whose ``predict`` takes a (batch, 2) tensor of digit
            pairs and returns a 1-D tensor (it is unsqueezed to (batch, 1) here).
            Presumably the product of the two digits — confirm.
        model_add: Callable taking a (batch, 2) tensor and returning the
            (batch, 1) row-wise sum.
        model_separate: Object whose ``predict`` takes (batch, 2) rows of
            ``(value, 2)``; column 0 of its result is used as the carry and the
            remaining column as the ones digit.

    NOTE(review): every running total is split into exactly two digits, so the
    result is presumably only valid while each column total stays below 100.
    """

    def __init__(self,model_mul_0_9,model_add,model_separate):
        super(Net_multiply_abs,self).__init__()

        self.model_add = model_add
        self.model_separate = model_separate
        self.model_mul_0_9 = model_mul_0_9

        # Plain list (not a registered ModuleList) of the all-ones summing layers
        # built by the most recent forward pass, one per output column after the first.
        self.final_row_add_layers = []

    def forward(self,x):
        """Return the (batch, 2 * k) product digits for ``x`` of shape (batch, 2, k + 1).

        Raises:
            ValueError: If ``x`` carries no digit columns (last dimension < 2).
        """
        width = x.shape[-1] - 1
        if width < 1:
            raise ValueError("expected at least one digit column, got last dimension %d" % x.shape[-1])

        # (batch, 1) column of 2.0: the digit count passed to model_separate.
        two = torch.full_like(x[:, 0, -1].unsqueeze(-1), 2.0)

        rows = self._partial_rows(x, width, two)
        return self._sum_columns(rows, width, two)

    def predict(self,x):
        """Same as ``forward``, with the result cast to a FloatTensor."""
        return self.forward(x).type(torch.FloatTensor)

    def _split_two_digits(self, value, two):
        """Split a (batch, 1) value into (carry, ones_digit) using model_separate."""
        sep = self.model_separate.predict(torch.cat((value, two), 1))
        return sep[:, 0:1], sep[:, 1:]

    def _partial_rows(self, x, width, two):
        """Build one partial product per digit of operand 0, lowest digit first.

        Each row is a list of ``width + 1`` (batch, 1) tensors, least significant
        first: the ``width`` ones digits followed by the final carry.
        """
        rows = []
        for a_idx in reversed(range(width)):
            a_digit = x[:, 0, a_idx:a_idx + 1]
            row = []
            carry = None
            for b_idx in reversed(range(width)):
                b_digit = x[:, 1, b_idx:b_idx + 1]
                # Argument order (operand-1 digit, operand-0 digit) is kept as the
                # multiplication model was always called this way.
                pair = torch.cat((b_digit, a_digit), 1)
                value = self.model_mul_0_9.predict(pair).unsqueeze(-1)
                if carry is not None:
                    value = self.model_add(torch.cat((value, carry), 1))
                carry, ones_digit = self._split_two_digits(value, two)
                row.append(ones_digit)
            row.append(carry)
            rows.append(row)
        return rows

    def _sum_columns(self, rows, width, two):
        """Add the shifted partial products column by column, propagating carries."""
        # Start from an empty list on every pass. Previously each call appended
        # new layers and never removed the old ones, so the list (and the pickled
        # model) grew with every forward call.
        self.final_row_add_layers = []

        # Column 0 has a single contributor, which is already a single digit.
        digits = [rows[0][0]]
        carry, _ = self._split_two_digits(rows[0][0], two)

        for col in range(1, 2 * width):
            # Row r is shifted left by r places, so its entry col - r lands in this column.
            terms = [rows[r][col - r] for r in range(width) if 0 <= col - r <= width]
            terms.append(carry)

            # Summation expressed as a bias-free linear layer with all-ones weights.
            adder = nn.Linear(len(terms), 1, bias=False)
            adder.weight.data = torch.ones(1, len(terms), dtype=torch.float32)
            self.final_row_add_layers.append(adder)

            total = adder(torch.cat(terms, -1))
            carry, ones_digit = self._split_two_digits(total, two)
            digits.append(ones_digit)

        # Digits were produced lowest first; emit them highest first.
        return torch.cat(digits[::-1], -1)

In [119]:
# Compose the multi-digit multiplier from the three sub-models loaded above
# (digit multiplication, addition, digit separation).
net = Net_multiply_abs(model_mul_0_9,model_add,model_separate)

In [121]:
# Eyeball check on the first ten samples: the printed label digits should match
# the digits displayed as this cell's output.
# print(train_data[:1])
print(train_labels[:10])
net.forward(train_data[:10])

tensor([[0, 2, 0, 8, 8, 0],
        [6, 8, 3, 8, 0, 8],
        [4, 6, 1, 1, 2, 5],
        [4, 3, 9, 3, 0, 0],
        [3, 7, 6, 7, 1, 1],
        [0, 9, 3, 7, 5, 0],
        [2, 7, 9, 3, 9, 5],
        [6, 1, 9, 1, 4, 6],
        [0, 6, 4, 5, 6, 5],
        [2, 1, 9, 3, 6, 9]])


  if __name__ == '__main__':


tensor([[0., 2., 0., 8., 8., 0.],
        [6., 8., 3., 8., 0., 8.],
        [4., 6., 1., 1., 2., 5.],
        [4., 3., 9., 3., 0., 0.],
        [3., 7., 6., 7., 1., 1.],
        [0., 9., 3., 7., 5., 0.],
        [2., 7., 9., 3., 9., 5.],
        [6., 1., 9., 1., 4., 6.],
        [0., 6., 4., 5., 6., 5.],
        [2., 1., 9., 3., 6., 9.]], grad_fn=<CatBackward>)

In [122]:
# Persist the composed multiplier (the whole module object, not just a state_dict)
# to the file 'model_multiply_abs'.
# NOTE(review): loading this file back presumably requires the class definitions
# from this notebook to be available under the same names — confirm.
torch.save(net,'model_multiply_abs')

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
