In [28]:
import scanpy as sc
import random
import scipy
from dit.shannon import mutual_information
from dit import *
import numpy as np
import argparse
from model import *
from scipy.stats import entropy
import torch 
import torch.nn as nn
from dit import Distribution

In [29]:
def parse_args():
	'''
	Parses the RandomJump arguments.
	'''
	parser = argparse.ArgumentParser(description="Run RandomJump.")
	parser.add_argument('--name', type=str, default='Marton',
	                    help='Name of the Dataset')

	#For the training of model for jump parameters 
	parser.add_argument('--epochs', type=int, default=100000,
	                    help='The number of epoch for jump parameter training')
	
	parser.add_argument('--case_num', type=int, default=100,
	                    help='The number of epoch for jump parameter training')
	
	parser.add_argument('--lr', type=float, default=0.01,
	                    help='The learning rate for jump parameter training')

	parser.add_argument('--dropout', type=float, default=0.6,
	                    help='The dropout for jump parameter training')

	parser.add_argument('--bias', action='store_true', default=True,
	                    help='Boolean specifying bias. Default is True.')
	
	parser.add_argument('--if_print', type=bool, default=True,
	                    help='Decide if print the case configurations')

	return parser.parse_args(args=[])



In [30]:
def matrices_normalize(matrix):
    matrix = matrix/matrix.sum(axis=1)[:,None]
    return matrix
def renorm_the_data(a0, a1, a2, m):
    a = [a0, a1, a2, m]
    a = np.array(a)
    a = np.reshape(a,(1,4))
    a = torch.from_numpy(a)
    a = a.requires_grad_()
    a = a.to(torch.float32)

    return a

Generate Cases for Training

In [31]:
def generate_train_case():
    T_1 = np.random.random((2, 2))
    T_1 = matrices_normalize(T_1)

    T_2 = np.random.random((2, 2))
    T_2 = matrices_normalize(T_2)

    X1 = np.random.random((1, 2))
    X1 = matrices_normalize(X1)
    X2 = np.random.random((1, 2))
    X2 = matrices_normalize(X2)

    Y1 = np.dot(X1, T_1)
    Z1 = np.dot(X1, T_2)

    Y2 = np.dot(X2, T_1)
    Z2 = np.dot(X2, T_2)

    X1 = np.squeeze(X1)
    X2 = np.squeeze(X2)
    Y1 = np.squeeze(Y1)
    Y2 = np.squeeze(Y2)
    Z1 = np.squeeze(Z1)
    Z2 = np.squeeze(Z2)

    return X1, Y1, Z1, X2, Y2, Z2

In [32]:
def print_the_case(case_num, T_1, T_2, X1, X2, a):
    
    print("The first transposition matrix:\n", T_1)
    print("The second transposition matrix:\n", T_2)
    print("The first message X1:\n", X1)
    print("The second message X2:\n", X2)
    print("For case", case_num, " Loss is smaller than 10^(-4):", a)
    

Parameter Initialization:

1. Two transposition matrices $T_1, T_2$;

2. Parameter of gammas: $a0, a1, a2$;

3. The sliding parameter: $m$;

In [33]:
T_1 = np.random.random((2, 2))
T_1 = matrices_normalize(T_1)

T_2 = np.random.random((2, 2))
T_2 = matrices_normalize(T_2)

print("The first transposition matrix:\n", T_1)
print("The second transposition matrix:\n", T_2)


The first transposition matrix:
 [[0.61038416 0.38961584]
 [0.10390053 0.89609947]]
The second transposition matrix:
 [[0.51502667 0.48497333]
 [0.67298263 0.32701737]]


Generate the random messages and go through the channels

In [34]:
X1 = np.random.random((1, 2))
X1 = matrices_normalize(X1)
X2 = np.random.random((1, 2))
X2 = matrices_normalize(X2)

Y1 = np.dot(X1, T_1)
Z1 = np.dot(X1, T_2)

Y2 = np.dot(X2, T_1)
Z2 = np.dot(X2, T_2)

X1 = np.squeeze(X1)
X2 = np.squeeze(X2)
Y1 = np.squeeze(Y1)
Y2 = np.squeeze(Y2)
Z1 = np.squeeze(Z1)
Z2 = np.squeeze(Z2)

dict_inf = {}
dict_inf['X1'] = X1
dict_inf['X2'] = X2
dict_inf['Y1'] = Y1
dict_inf['Y2'] = Y2
dict_inf['Z1'] = Z1
dict_inf['Z2'] = Z2
# dict_input = Distribution(dict_inf)


print("The first message X1:\n", X1)
print("Y1 after transpostion T_1:\n", Y1)
print("Z1 after transpostion T_2:\n", Z1)

print("----")

print("The second message X2:\n", X2)
print("Y2 after transpostion T_1:\n", Y2)
print("Z2 after transpostion T_2:\n", Z2)


The first message X1:
 [0.60471944 0.39528056]
Y1 after transpostion T_1:
 [0.41018103 0.58981897]
Z1 after transpostion T_2:
 [0.57746359 0.42253641]
----
The second message X2:
 [0.28768108 0.71231892]
Y2 after transpostion T_1:
 [0.24960629 0.75039371]
Z2 after transpostion T_2:
 [0.62754169 0.37245831]


Compute the Formulas.


In [35]:
# By Default, set a0, a1, a2 and m as 1
a0 = 1
a1 = 0.7
a2 = 0.5
m =0.8
V1_T1 = - a0 * m * entropy(X1,Y1) - a0 * (1-m) * entropy(X1,Z1) + max(a1 * entropy(X1,Y1), a2 * entropy(X1,Z1))
V1_T2 = - a0 * m * entropy(X2,Y2) - a0 * (1-m) * entropy(X2,Z2) + max(a1 * entropy(X2,Y2), a2 * entropy(X2,Z2))
V2_T12 = - a0 * m * entropy(X1*X2,Y1*Y2) - a0 * (1-m) * entropy(X1*X2,Z1*Z2) + max(a1 * entropy(X1*X2,Y1*Y2), a2 * entropy(X1*X2,Z1*Z2))

#print(mutual_information(dict_inf, ['X1','X2'],['Y1','Y2']))
#How to compute here?

print("V1_T1:\n", V1_T1)
print("V1_T2:\n", V1_T2)
print("V2_T12:\n", V2_T12)



V1_T1:
 -0.00795963017512362
V1_T2:
 0.06824722452032514
V2_T12:
 -0.018377250808475076


Define the model and loss function

In [36]:
args = parse_args()
criterion = V2_T12 - V1_T2 - V1_T1
model = CNN(args)
print(model)
learning_rate = args.lr
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

CNN(
  (fc1): Linear(in_features=4, out_features=4, bias=True)
  (layer4): Sequential(
    (0): Linear(in_features=4, out_features=4, bias=True)
    (1): Sigmoid()
    (2): Dropout(p=0.6, inplace=False)
  )
)


In [37]:
for case_num in range (0,args.case_num):
    X1, Y1, Z1, X2, Y2, Z2 = generate_train_case()
    for i in range(0,args.epochs):
        a = renorm_the_data(a0, a1, a2, m)
        a = model(args, a)
        a = a.detach().numpy()
        a = np.squeeze(a)
        a0 = a[0]
        a1 = a[1]
        a2 = a[2]
        m  = a[3]
        if a1+a2 > a0 and a0>a1 and a1 >= a2 and a2>=0 and m>=0 and m<1:
            V1_T1 = - a0 * m * entropy(X1,Y1) - a0 * (1-m) * entropy(X1,Z1) + max(a1 * entropy(X1,Y1), a2 * entropy(X1,Z1))
            V1_T2 = - a0 * m * entropy(X2,Y2) - a0 * (1-m) * entropy(X2,Z2) + max(a1 * entropy(X2,Y2), a2 * entropy(X2,Z2))
            V2_T12 = - a0 * m * entropy(X1*X2,Y1*Y2) - a0 * (1-m) * entropy(X1*X2,Z1*Z2) + max(a1 * entropy(X1*X2,Y1*Y2), a2 * entropy(X1*X2,Z1*Z2))
            criterion = torch.tensor(V2_T12 - V1_T2 - V1_T1).requires_grad_()
            #print("For ", i, " loss is:", criterion)
            train_para_vec(args, a, model, criterion, optimizer)
            #print(criterion.item())
            if(criterion.item() < 0.0001):
                if args.if_print:
                    print_the_case(case_num, T_1, T_2, X1, X2, a)
                break

        else:
            a = renorm_the_data(a0, a1, a2, m)
        


The first transposition matrix:
 [[0.61038416 0.38961584]
 [0.10390053 0.89609947]]
The second transposition matrix:
 [[0.51502667 0.48497333]
 [0.67298263 0.32701737]]
The first message X1:
 [0.39896997 0.60103003]
The second message X2:
 [0.17870937 0.82129063]
For case 2  Loss is smaller than 10^(-4): [1.191271  1.0803992 0.5908054 0.       ]
The first transposition matrix:
 [[0.61038416 0.38961584]
 [0.10390053 0.89609947]]
The second transposition matrix:
 [[0.51502667 0.48497333]
 [0.67298263 0.32701737]]
The first message X1:
 [0.46887333 0.53112667]
The second message X2:
 [0.7112106 0.2887894]
For case 3  Loss is smaller than 10^(-4): [1.6414692 1.538232  0.8821756 0.       ]
The first transposition matrix:
 [[0.61038416 0.38961584]
 [0.10390053 0.89609947]]
The second transposition matrix:
 [[0.51502667 0.48497333]
 [0.67298263 0.32701737]]
The first message X1:
 [0.02158088 0.97841912]
The second message X2:
 [0.36777161 0.63222839]
For case 4  Loss is smaller than 10^(-4): 

KeyboardInterrupt: 