In [None]:
# Mixture model

class Topic(nn.Module):
    """Mixture ("topic") model mapping a latent vector to observation means.

    ``fc1`` turns a ``z_dim`` latent vector into ``no_topics`` sigmoid
    activations; these are rescaled so that every row sums to one and are then
    passed through the bias-free linear layer ``fc2`` to produce ``obs_dim``
    outputs. Each output row is therefore a weighted combination of the
    columns of ``fc2.weight``, with weights given by the topic proportions.

    Args:
        obs_dim: Number of output features produced by ``decode``.
        z_dim: Size of the latent input vector.
        no_topics: Number of mixture components (topics).
    """

    def __init__(self, obs_dim, z_dim, no_topics):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, no_topics)
        self.fc2 = nn.Linear(no_topics, obs_dim, bias=False)

    def encode(self, x):
        """Return unnormalised topic activations in (0, 1), one row per input row."""
        return torch.sigmoid(self.fc1(x))

    def rescale(self, y):
        """Normalise each row of the 2-D tensor ``y`` so that it sums to one.

        The row sums are kept as a column (``keepdim=True``) and broadcast over
        the topics. Unlike building the divisor from a freshly created
        ``torch.ones`` tensor, this stays on ``y``'s device and dtype and does
        not allocate a full-size matrix on every call.
        """
        return y / torch.sum(y, dim=1, keepdim=True)

    def decode(self, z):
        """Map topic proportions ``z`` to reconstructed observation means."""
        return self.fc2(z)

    def forward(self, x):
        """Return ``(x_reconst_mu, z)``: the reconstruction and the topic proportions."""
        y = self.encode(x)
        z = self.rescale(y)
        x_reconst_mu = self.decode(z)

        return x_reconst_mu, z

In [None]:
# Weight initialization

# Densify AD and DP, keep only the rows whose variance exceeds the threshold,
# and transpose so that the selected rows become columns.
# NOTE(review): presumably AD / DP are alternate-allele and total read depth
# matrices with variants as rows — confirm against the cell that loads them.
AD_dense = AD.todense()[allgene_var > threshold_var, :].T
DP_dense = DP.todense()[allgene_var > threshold_var, :].T
# One column per cluster, SNP.shape[1] rows; filled in by the loop below.
binomial = np.empty((SNP.shape[1], cluster_no))

# For each cluster, pool the rows listed in cluster[u] and take the ratio of
# summed AD to summed DP. A zero denominator yields NaN, patched afterwards.
for u in range(cluster_no):
    
    binomial[:, u] = np.sum(AD_dense[cluster[u], :], 0) / np.sum(DP_dense[cluster[u], :], 0)
    
# Replace NaN entries with the row-wise nan-mean of AD / DP over the selected
# rows; np.outer repeats that mean across all clusters so the boolean mask on
# both sides lines up element for element.
binomial[np.isnan(binomial)] = np.outer(np.nanmean((AD / DP)[allgene_var > threshold_var, :], 1), np.ones(cluster_no))[np.isnan(binomial)]    
binomial = torch.tensor(binomial, requires_grad = True).float()

# Per-cluster logistic-regression coefficients / intercepts, used below to
# initialise the first linear layer of the topic model.
coefficient = []
intercept = []

for w in range(cluster_no):
    
    # Indicator vector: 1 for rows belonging to cluster w, 0 otherwise.
    dummy = np.zeros(SNP.shape[0])
    dummy[cluster[w]] = 1
    # One-vs-rest fit of cluster membership from the latent representation.
    clf = LogisticRegression(max_iter = 1000).fit(latent.numpy(), dummy)
    coefficient.append(clf.coef_)
    intercept.append(clf.intercept_)
    
# Stacking the per-cluster results adds a leading cluster axis; the extra
# singleton axis is dropped by the indexing below (.T[0] and [:, 0, :]).
# NOTE(review): assumes clf.coef_ is (1, n_features) and clf.intercept_ is (1,)
# as for a binary scikit-learn LogisticRegression — confirm the import.
bias_initial = torch.tensor(np.array(intercept), requires_grad = True).float()
weight_initial = torch.tensor(np.array(coefficient), requires_grad = True).float()

# Build the topic model and overwrite its default initialisation:
# fc1 gets the logistic-regression fit, fc2 gets the per-cluster ratios.
# NOTE(review): fc2.weight must be (obs_dim, cluster_no), so this assumes
# obs_dim == SNP.shape[1] — confirm where obs_dim is defined.
tmodel = Topic(obs_dim, z_dim, cluster_no)
tmodel.fc1.bias = nn.Parameter(bias_initial.T[0], requires_grad = True)
tmodel.fc1.weight = nn.Parameter(weight_initial[:, 0, :], requires_grad = True)
tmodel.fc2.weight = nn.Parameter(torch.clone(binomial), requires_grad = True)

In [None]:
# Training

# Optimisation settings for the topic model.
tnum_epochs = 1000
tlearning_rate = 0.005
# The batch size is SNP.shape[0], and the loss below compares the reconstruction
# against the whole of SNP, so every batch is expected to contain all rows.
# NOTE(review): assumes latent has SNP.shape[0] rows — confirm.
tdata_loader = DataLoader(latent, SNP.shape[0])
optimizer = torch.optim.Adam(tmodel.parameters(), lr = tlearning_rate)

# Loss value recorded once per epoch (from the last batch of that epoch).
tcost_training = np.empty(tnum_epochs)

for epoch in range(tnum_epochs):

    for batch, x in enumerate(tdata_loader):

        x_reconst_mu, mu = tmodel(x)
        # Keep the reconstruction inside [0, 1] before it is scored.
        x_reconst_mu = torch.clip(x_reconst_mu, min = 0, max = 1)

        # Element-wise loss weighted by depth, averaged over all entries.
        loss = torch.mean(loss_fn(x_reconst_mu, SNP) * depth)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Project fc2.weight back onto [0, 1] IN PLACE. Rebinding the attribute
        # to a new nn.Parameter would leave the optimizer holding the old
        # tensor, so fc2.weight would stop being trained after the first step.
        with torch.no_grad():
            tmodel.fc2.weight.clamp_(min = 0, max = 1)

    # Store a plain Python float: the loss tensor still requires grad and
    # cannot be safely converted by NumPy, and keeping it would retain the graph.
    tcost_training[epoch] = loss.item()

# Topic proportions for every row of latent, as a NumPy array.
latent_topic = tmodel.rescale(tmodel.encode(torch.tensor(latent))).detach().numpy()