In [1]:
import numpy as np
from scipy.spatial.distance import cdist
import scprep
import torch
import sys
sys.path.append('..')
from utils import sinkhorn_knopp_unbalanced

In [2]:
def load_data_full(datafile):
    """Load the leading PCA components and the sample labels from an ``.npz`` archive.

    Parameters
    ----------
    datafile : str or path-like
        Path to an ``.npz`` archive holding the arrays ``'pca'`` and
        ``'sample_labels'``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(pca[:, :10], sample_labels)`` -- at most the first ten columns of
        ``'pca'`` and the ``'sample_labels'`` array unchanged.
    """
    # np.load on an .npz returns a lazily-read archive that keeps the file
    # open; the context manager closes it once both arrays are in memory.
    with np.load(datafile) as dat:
        return dat['pca'][:, :10], dat['sample_labels']

def get_transform_matrix(gamma, a, epsilon=1e-8):
    """Return matrix T such that T @ a = b (up to the ``epsilon`` smoothing).

    Parameters
    ----------
    gamma : numpy.ndarray, shape (m, n)
        Coupling with gamma @ 1 = a and gamma^T @ 1 = b.
    a : numpy.ndarray, shape (m,)
        Row marginal of ``gamma``.
    epsilon : float
        Added to ``a`` before inverting so that zero-mass rows do not divide
        by zero.

    Returns
    -------
    numpy.ndarray, shape (n, m)
        ``gamma`` with row ``i`` divided by ``a[i] + epsilon``, transposed.
    """
    inv_mass = 1.0 / (np.asarray(a) + epsilon)
    # Broadcasting scales each row directly; this equals diag(inv_mass) @ gamma
    # without materialising a dense m x m matrix or paying for a full matmul.
    return (inv_mass[:, None] * gamma).T

def get_growth_coeffs(gamma, a, epsilon=1e-8, normalize=False):
    """Sum the transform matrix of ``gamma`` over its first axis.

    The transform matrix comes from ``get_transform_matrix(gamma, a, epsilon)``;
    the result holds one coefficient per row of ``gamma``.

    With ``normalize=True`` the coefficients are rescaled so that they sum to
    their own count, i.e. their mean is one.
    """
    transform = get_transform_matrix(gamma, a, epsilon)
    coeffs = np.sum(transform, axis=0)
    if normalize:
        # Divide by the total first, then multiply by the count.
        coeffs = coeffs / np.sum(coeffs) * len(coeffs)
    return coeffs

In [3]:
def get_all_growth_coeffs(alpha, populations=None, entropy_reg=0.1, reg_2=10000):
    """Compute growth coefficients between every pair of consecutive populations.

    For each consecutive pair ``(source, target)`` a coupling between uniform
    marginals is computed with ``sinkhorn_knopp_unbalanced`` on the pairwise
    ``cdist`` cost, and ``get_growth_coeffs`` turns it into one coefficient
    per source point.

    Parameters
    ----------
    alpha : float
        Passed to ``sinkhorn_knopp_unbalanced`` as its fifth positional
        argument (presumably the marginal relaxation on the source side --
        confirm against ``utils.sinkhorn_knopp_unbalanced``).
    populations : sequence of numpy.ndarray, optional
        One ``(n_points, n_features)`` array per timepoint. When omitted the
        notebook-level ``dfs`` list is used, which is what the function always
        read before this parameter existed.
    entropy_reg : float
        Fourth positional argument of ``sinkhorn_knopp_unbalanced``.
    reg_2 : float
        Sixth positional argument of ``sinkhorn_knopp_unbalanced``.

    Returns
    -------
    list of numpy.ndarray
        ``len(populations) - 1`` arrays; entry ``i`` has one value per row of
        ``populations[i]``.
    """
    if populations is None:
        # Backward-compatible default: fall back to the notebook global.
        populations = dfs

    gcs = []
    for source, target in zip(populations[:-1], populations[1:]):
        m, n = source.shape[0], target.shape[0]
        # Uniform mass on both sides; the source marginal is reused below.
        source_mass = np.ones(m) / m
        target_mass = np.ones(n) / n
        cost = cdist(source, target)
        gamma = sinkhorn_knopp_unbalanced(
            source_mass, target_mass, cost, entropy_reg, alpha, reg_2
        )
        gcs.append(get_growth_coeffs(gamma, source_mass))
    return gcs

In [4]:
# Load the PCA embedding (first 10 components) and the per-cell sample labels.
data, labels = load_data_full(
    'data/endocrine_cells_for_trajectorynet.npz'
)

# Split the cells into one array per timepoint; np.unique returns the labels
# sorted, so dfs is ordered by label value.
timepoints = np.unique(labels)
dfs = [
    data[labels == timepoint]
    for timepoint in timepoints
]

In [5]:
# The dataset does not depend on alpha, so load and split it once here
# instead of re-reading the file on every iteration of the sweep.
data, labels = load_data_full('data/endocrine_cells_for_trajectorynet.npz')

# One array of cells per timepoint (np.unique returns the labels sorted);
# get_all_growth_coeffs reads this list.
timepoints = np.unique(labels)
dfs = [data[labels == tp] for tp in timepoints]

# Sweep the regularisation value handed to get_all_growth_coeffs and store
# the concatenated growth coefficients of every run.
for alpha in [0.01, 0.1, 1, 2, 5, 10]:
    print(alpha)

    gcs = get_all_growth_coeffs(alpha)
    # Flatten the per-timepoint arrays into one vector, timepoint by timepoint.
    gcs = np.concatenate(gcs)
    print(gcs.shape)
    np.save(f'results/gcs_4tp_{alpha}.npy', gcs)

0.01
(11802,)
0.1
(11802,)
1
(11802,)
2
(11802,)
5
(11802,)
10
(11802,)


## Define and train the growth model

In [6]:
class GrowthNet(torch.nn.Module):
    """Three-layer fully connected network that maps each input row to one scalar.

    Parameters
    ----------
    input_dim : int
        Number of input features. The default of 11 matches the notebook's
        training rows (10 PCA columns plus the sample label).
    hidden_dim : int
        Width of the two hidden layers.

    The layer attribute names ``fc1``/``fc2``/``fc3`` are part of the saved
    model's parameter keys and must not be renamed.
    """

    def __init__(self, input_dim=11, hidden_dim=64):
        super().__init__()

        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Apply two leaky-ReLU hidden layers and a linear output layer.

        ``x`` has ``input_dim`` features in its last dimension; the result has
        the same leading dimensions and a last dimension of size 1.
        """
        x = torch.nn.functional.leaky_relu(self.fc1(x))
        x = torch.nn.functional.leaky_relu(self.fc2(x))
        # No activation on the output: the network is used as a regressor.
        return self.fc3(x)

In [7]:
alpha = 2 # default from original TrajectoryNet work
n_iterations = 100000
batch_size = 256
log_every = 100
seed = 0

# Seed both random generators used below so a re-run reproduces the same model.
np.random.seed(seed)
torch.manual_seed(seed)

data, labels = load_data_full('data/endocrine_cells_for_trajectorynet.npz')
gcs = np.load(f'results/gcs_4tp_{alpha}.npy')
timepoints = np.unique(labels)

# The saved coefficients are concatenated timepoint by timepoint (all cells of
# the first timepoint, then the second, ...). Order the feature rows the same
# way so that row k of X really belongs to row k of Y even when the file is
# not already sorted by label. The stable sort keeps the within-timepoint order.
features = np.concatenate([data, labels[:, None]], axis=1)
by_timepoint = np.argsort(labels, kind='stable')
# Cells of the last timepoint have no growth coefficient, so drop them.
X = features[by_timepoint][labels[by_timepoint] != timepoints[-1]]
Y = gcs[:, None]
assert len(X) == len(Y), f'{len(X)} feature rows but {len(Y)} growth coefficients'

if torch.cuda.is_available():
    # Keep using GPU 1 when it exists, but fall back to GPU 0 on single-GPU
    # machines instead of failing with an invalid device index.
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device("cuda:" + str(gpu_index))
else:
    device = torch.device("cpu")

model = GrowthNet().to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters())
# The loss module is stateless, so build it once rather than every iteration.
loss = torch.nn.MSELoss()

for it in range(n_iterations):
    optimizer.zero_grad()
    batch_idx = np.random.randint(len(X), size=batch_size)
    x = torch.from_numpy(X[batch_idx,:]).type(torch.float32).to(device)
    y = torch.from_numpy(Y[batch_idx,:]).type(torch.float32).to(device)
    # Negative samples: features drawn uniformly from [-4, 4) with a random
    # timepoint, all given the target value 1 below.
    negative_samples = np.concatenate([np.random.uniform(size=(batch_size,X.shape[1] - 1)) * 8 - 4,
                                       np.random.choice(timepoints, size=(batch_size,1))], axis=1)
    negative_samples = torch.from_numpy(negative_samples).type(torch.float32).to(device)
    x = torch.cat([x, negative_samples])
    y = torch.cat([y, torch.ones_like(y)])
    pred = model(x)
    output = loss(pred, y)
    output.backward()
    optimizer.step()
    if it % log_every == 0:
        # Print the plain number instead of the tensor repr with its grad_fn.
        print(it, f'{output.item():.4f}')

# NOTE: this pickles the whole module, so loading it requires the GrowthNet
# class to be importable under the same name.
torch.save(model, 'results/endocrine_cells_growth_model')

0 tensor(1.1218, grad_fn=<MseLossBackward0>)
100 tensor(0.0254, grad_fn=<MseLossBackward0>)
200 tensor(0.0156, grad_fn=<MseLossBackward0>)
300 tensor(0.0131, grad_fn=<MseLossBackward0>)
400 tensor(0.0120, grad_fn=<MseLossBackward0>)
500 tensor(0.0144, grad_fn=<MseLossBackward0>)
600 tensor(0.0122, grad_fn=<MseLossBackward0>)
700 tensor(0.0105, grad_fn=<MseLossBackward0>)
800 tensor(0.0105, grad_fn=<MseLossBackward0>)
900 tensor(0.0103, grad_fn=<MseLossBackward0>)
1000 tensor(0.0101, grad_fn=<MseLossBackward0>)
1100 tensor(0.0091, grad_fn=<MseLossBackward0>)
1200 tensor(0.0097, grad_fn=<MseLossBackward0>)
1300 tensor(0.0115, grad_fn=<MseLossBackward0>)
1400 tensor(0.0097, grad_fn=<MseLossBackward0>)
1500 tensor(0.0103, grad_fn=<MseLossBackward0>)
1600 tensor(0.0097, grad_fn=<MseLossBackward0>)
1700 tensor(0.0090, grad_fn=<MseLossBackward0>)
1800 tensor(0.0092, grad_fn=<MseLossBackward0>)
1900 tensor(0.0111, grad_fn=<MseLossBackward0>)
2000 tensor(0.0068, grad_fn=<MseLossBackward0>)
2100