In [1]:
import numpy as np
import scanpy as sc
import torch
import torch.nn as nn

from MaskedLinear import MaskedLinear, DetMaskLinear

In [2]:
# Load the annotated single-cell dataset (obs = cells, var = genes).
adata = sc.read(filename='mouse_retina_sbs.h5ad')

In [3]:
# Boolean mask over genes: True where the row of varm['I'] has a positive sum
# (assumes 'I' is a 0/1 gene x term membership matrix -- TODO confirm).
select_genes = adata.varm['I'].sum(axis=1) > 0

In [4]:
# Keep only the genes selected above. Public AnnData slicing (instead of the
# private `_inplace_subset_var`) subsets X and the per-gene `varm['I']` together;
# `.copy()` materializes a real object so the in-place centering below does not
# operate on a view.
adata = adata[:, select_genes].copy()

In [5]:
# Center each gene (column of X) to zero mean across cells.
adata.X -= adata.X.mean(axis=0)

In [6]:
# Training hyperparameters used by the training loops below.
EPOCH = 20
BATCH_SIZE = 62
LR = 0.005  # Adam learning rate
# Regularization weights for the two-mask model's loss (see get_loss_func).
ALPHA1 = 0.24
ALPHA2 = 0.17
ALPHA3 = 0.18
ALPHA4 = 0.2

# Fix RNG state so that weight initialization, any stochastic mask sampling in
# MaskedLinear, and the row sampling done by `adata.chunk_X` (presumably via
# NumPy's global RNG -- verify) are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [7]:
class MaskedAutoencoder(nn.Module):
    """Autoencoder whose two encoder layers are both ``MaskedLinear``.

    Encoder: n_vars -> n_terms (MaskedLinear, no bias) -> ELU -> n_latent (MaskedLinear).
    Decoder: n_latent -> n_terms -> ELU -> n_vars (plain ``nn.Linear`` layers).

    ``estimator`` and ``f_eval`` are forwarded unchanged to both ``MaskedLinear``
    layers; their meaning is defined by the ``MaskedLinear`` module.
    ``forward`` returns the pair ``(encoded, decoded)``.
    """

    def __init__(self, n_vars, n_terms, n_latent, estimator='ST', f_eval='Mode'):
        super().__init__()

        # Sequential positions matter: the training loop reads encoder[0] and encoder[2].
        term_layer = MaskedLinear(n_vars, n_terms, estimator, f_eval, bias=False)
        latent_layer = MaskedLinear(n_terms, n_latent, estimator, f_eval)
        self.encoder = nn.Sequential(term_layer, nn.ELU(), latent_layer)

        self.decoder = nn.Sequential(
            nn.Linear(n_latent, n_terms),
            nn.ELU(),
            nn.Linear(n_terms, n_vars),
        )

    def forward(self, x):
        z = self.encoder(x)
        return z, self.decoder(z)

In [8]:
def get_loss_func(I, alpha1, alpha2, alpha3, alpha4):
    """Build the loss used for the two-mask autoencoder.

    Args:
        I: float tensor; ``I.t()`` is the BCE target for ``logits1`` and must
            have the same shape as ``logits1``.
        alpha1, alpha2: weights of the two terms regularizing ``logits1``.
        alpha3, alpha4: weights of the two terms regularizing ``logits2``.

    Returns:
        ``regularized_loss(X, Y, logits1, logits2)``, which returns a 3-tuple
        (with p1 = sigmoid(logits1), p2 = sigmoid(logits2)):
          - ``MSE(X, Y)``
          - ``alpha1 * BCEWithLogits(logits1, I.t()) + alpha2 * mean(p1 * (1 - p1))``
          - ``alpha3 * mean(p2) + alpha4 * mean(p2 * (1 - p2))``
    """
    mse = nn.MSELoss()
    bce_with_logits = nn.BCEWithLogitsLoss()
    target = I.t()  # invariant across calls, so computed once

    def regularized_loss(X, Y, logits1, logits2):
        p1 = torch.sigmoid(logits1)
        p2 = torch.sigmoid(logits2)
        recon = mse(X, Y)
        # Pull sigmoid(logits1) toward I.t(); the p*(1-p) term is minimized at 0/1.
        reg1 = alpha1 * bce_with_logits(logits1, target) + alpha2 * torch.mean(p1 * (1 - p1))
        # Penalize the mean of sigmoid(logits2) and push it toward 0/1.
        reg2 = alpha3 * torch.mean(p2) + alpha4 * torch.mean(p2 * (1 - p2))
        return recon, reg1, reg2

    return regularized_loss

In [9]:
def train_autoencoder(autoencoder, loss_func):
    """Train the two-mask autoencoder with Adam and print losses.

    Reads the module-level ``adata``, ``EPOCH``, ``BATCH_SIZE`` and ``LR``.

    Each optimizer step draws ``BATCH_SIZE`` rows with ``adata.chunk_X`` and
    runs one forward/backward pass per row, scaling each row's loss by
    ``1 / BATCH_SIZE`` so the accumulated gradient is that of the batch mean.
    After every epoch the whole of ``adata.X`` is evaluated and the total loss
    and its components are printed.

    Args:
        autoencoder: model whose ``encoder[0]`` and ``encoder[2]`` expose a
            ``logits`` tensor and whose forward returns ``(encoded, decoded)``.
        loss_func: callable ``(decoded, target, logits1, logits2)`` returning a
            tuple of three loss tensors.
    """
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)

    t_X = torch.from_numpy(adata.X)

    logits1 = autoencoder.encoder[0].logits
    logits2 = autoencoder.encoder[2].logits

    n_steps = adata.n_obs // BATCH_SIZE

    for epoch in range(EPOCH):
        autoencoder.train()
        for step in range(n_steps):
            batch = torch.from_numpy(adata.chunk_X(BATCH_SIZE))
            optimizer.zero_grad()
            batch_loss = 0.0

            # One forward pass per row: gradients accumulate across the rows,
            # and the 1/BATCH_SIZE factor turns the sum into a batch mean.
            for sample in batch:
                sample = sample[None, :]
                _, decoded = autoencoder(sample)

                loss = sum(loss_func(decoded, sample, logits1, logits2)) / BATCH_SIZE
                loss.backward()

                batch_loss += loss.item()

            optimizer.step()
            if step % 100 == 0:
                print('Epoch: ', epoch, '| batch train loss: %.4f' % batch_loss)

        # Full-dataset evaluation. no_grad avoids building an autograd graph
        # over every observation, which would only waste memory here.
        autoencoder.eval()
        with torch.no_grad():
            _, t_decoded = autoencoder(t_X)
            t_loss = loss_func(t_decoded, t_X, logits1, logits2)
        t_loss = [sum(t_loss)] + list(t_loss)
        t_loss = tuple(l.item() for l in t_loss)

        print('Epoch: ', epoch, '-- total train loss: %.4f=%.4f+%.4f+%.4f' % t_loss)

In [10]:
# Two-mask model: one hidden unit per annotated term, 50 latent dimensions.
autoencoder = MaskedAutoencoder(n_vars=adata.n_vars, n_terms=len(adata.uns['terms']), n_latent=50)

In [11]:
# Gene x term matrix as float32 (BCEWithLogitsLoss needs floating-point targets).
I = torch.from_numpy(adata.varm['I']).to(torch.float32)
loss_func = get_loss_func(I, alpha1=ALPHA1, alpha2=ALPHA2, alpha3=ALPHA3, alpha4=ALPHA4)

In [12]:
# Train the two-mask model with the 4-term regularized loss.
train_autoencoder(autoencoder=autoencoder, loss_func=loss_func)

Epoch:  0 | batch train loss: 0.5576
Epoch:  0 | batch train loss: 0.4497
Epoch:  0 | batch train loss: 0.4200
Epoch:  0 -- total train loss: 0.4096=0.1723+0.1348+0.1025
Epoch:  1 | batch train loss: 0.4080
Epoch:  1 | batch train loss: 0.3522
Epoch:  1 | batch train loss: 0.3300
Epoch:  1 -- total train loss: 0.3221=0.1534+0.0948+0.0739
Epoch:  2 | batch train loss: 0.3282
Epoch:  2 | batch train loss: 0.2913
Epoch:  2 | batch train loss: 0.2778
Epoch:  2 -- total train loss: 0.2733=0.1489+0.0711+0.0533
Epoch:  3 | batch train loss: 0.2831
Epoch:  3 | batch train loss: 0.2632
Epoch:  3 | batch train loss: 0.2395
Epoch:  3 -- total train loss: 0.2349=0.1397+0.0561+0.0391
Epoch:  4 | batch train loss: 0.2400
Epoch:  4 | batch train loss: 0.2299
Epoch:  4 | batch train loss: 0.2143
Epoch:  4 -- total train loss: 0.2096=0.1341+0.0459+0.0296
Epoch:  5 | batch train loss: 0.2149
Epoch:  5 | batch train loss: 0.2044
Epoch:  5 | batch train loss: 0.1931
Epoch:  5 -- total train loss: 0.1902=0

In [13]:
# Saves the whole module object (not just its state_dict). Loading it later
# needs the class definitions and MaskedLinear to be importable; classes defined
# in the notebook cannot have their source checked, hence the warning.
torch.save(autoencoder, f='auto_masked.pt')

  "type " + obj.__name__ + ". It won't be checked "


In [14]:
# New regularization weights for the single-mask model trained next. The loss
# built earlier already captured the old values, so it is unaffected.
ALPHA1, ALPHA2 = 0.22, 0.15

In [15]:
class MaskedLinAutoencoder(nn.Module):
    """Autoencoder with only the first encoder layer masked.

    Encoder: n_vars -> n_terms (MaskedLinear, no bias) -> ELU -> n_latent (nn.Linear).
    Decoder: n_latent -> n_terms -> ELU -> n_vars (plain ``nn.Linear`` layers).

    ``estimator`` and ``f_eval`` are forwarded unchanged to ``MaskedLinear``.
    ``forward`` returns the pair ``(encoded, decoded)``.
    """

    def __init__(self, n_vars, n_terms, n_latent, estimator='ST', f_eval='Mode'):
        super().__init__()

        # encoder[0] must stay the MaskedLinear: the training loop reads its logits.
        term_layer = MaskedLinear(n_vars, n_terms, estimator, f_eval, bias=False)
        latent_layer = nn.Linear(n_terms, n_latent)
        self.encoder = nn.Sequential(term_layer, nn.ELU(), latent_layer)

        self.decoder = nn.Sequential(
            nn.Linear(n_latent, n_terms),
            nn.ELU(),
            nn.Linear(n_terms, n_vars),
        )

    def forward(self, x):
        z = self.encoder(x)
        return z, self.decoder(z)

In [16]:
def get_loss_func(I, alpha1, alpha2):
    """Build the loss used for the single-mask autoencoder.

    Args:
        I: float tensor; ``I.t()`` is the BCE target for ``logits`` and must
            have the same shape as ``logits``.
        alpha1: weight of the BCE term between ``logits`` and ``I.t()``.
        alpha2: weight of the ``mean(p * (1 - p))`` term, p = sigmoid(logits).

    Returns:
        ``regularized_loss(X, Y, logits)``, which returns the 2-tuple
        ``(MSE(X, Y), alpha1 * BCEWithLogits(logits, I.t()) + alpha2 * mean(p * (1 - p)))``.
    """
    mse = nn.MSELoss()
    bce_with_logits = nn.BCEWithLogitsLoss()
    target = I.t()  # invariant across calls, so computed once

    def regularized_loss(X, Y, logits):
        p = torch.sigmoid(logits)
        recon = mse(X, Y)
        # Pull sigmoid(logits) toward I.t(); the p*(1-p) term is minimized at 0/1.
        reg = alpha1 * bce_with_logits(logits, target) + alpha2 * torch.mean(p * (1 - p))
        return recon, reg

    return regularized_loss

In [17]:
def train_autoencoder(autoencoder, loss_func):
    """Train the single-mask autoencoder with Adam and print losses.

    Reads the module-level ``adata``, ``EPOCH``, ``BATCH_SIZE`` and ``LR``.

    Each optimizer step draws ``BATCH_SIZE`` rows with ``adata.chunk_X`` and
    runs one forward/backward pass per row, scaling each row's loss by
    ``1 / BATCH_SIZE`` so the accumulated gradient is that of the batch mean.
    After every epoch the whole of ``adata.X`` is evaluated and the total loss
    and its components are printed.

    Args:
        autoencoder: model whose ``encoder[0]`` exposes a ``logits`` tensor and
            whose forward returns ``(encoded, decoded)``.
        loss_func: callable ``(decoded, target, logits)`` returning a tuple of
            two loss tensors.
    """
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)

    t_X = torch.from_numpy(adata.X)

    logits = autoencoder.encoder[0].logits

    n_steps = adata.n_obs // BATCH_SIZE

    for epoch in range(EPOCH):
        autoencoder.train()
        for step in range(n_steps):
            batch = torch.from_numpy(adata.chunk_X(BATCH_SIZE))
            optimizer.zero_grad()
            batch_loss = 0.0

            # One forward pass per row: gradients accumulate across the rows,
            # and the 1/BATCH_SIZE factor turns the sum into a batch mean.
            for sample in batch:
                sample = sample[None, :]
                _, decoded = autoencoder(sample)

                loss = sum(loss_func(decoded, sample, logits)) / BATCH_SIZE
                loss.backward()

                batch_loss += loss.item()

            optimizer.step()
            if step % 100 == 0:
                print('Epoch: ', epoch, '| batch train loss: %.4f' % batch_loss)

        # Full-dataset evaluation. no_grad avoids building an autograd graph
        # over every observation, which would only waste memory here.
        autoencoder.eval()
        with torch.no_grad():
            _, t_decoded = autoencoder(t_X)
            t_loss = loss_func(t_decoded, t_X, logits)
        t_loss = [sum(t_loss)] + list(t_loss)
        t_loss = tuple(l.item() for l in t_loss)

        print('Epoch: ', epoch, '-- total train loss: %.4f=%.4f+%.4f' % t_loss)

In [18]:
# Single-mask model: one hidden unit per annotated term, 50 latent dimensions.
autoencoder = MaskedLinAutoencoder(n_vars=adata.n_vars, n_terms=len(adata.uns['terms']), n_latent=50)

In [19]:
# 2-term regularized loss using the re-weighted ALPHA1 / ALPHA2.
loss_func = get_loss_func(I, alpha1=ALPHA1, alpha2=ALPHA2)

In [20]:
# Train the single-mask model.
train_autoencoder(autoencoder=autoencoder, loss_func=loss_func)

Epoch:  0 | batch train loss: 0.3905
Epoch:  0 | batch train loss: 0.3084
Epoch:  0 | batch train loss: 0.2768
Epoch:  0 -- total train loss: 0.2674=0.1416+0.1258
Epoch:  1 | batch train loss: 0.2629
Epoch:  1 | batch train loss: 0.2459
Epoch:  1 | batch train loss: 0.2324
Epoch:  1 -- total train loss: 0.2229=0.1323+0.0906
Epoch:  2 | batch train loss: 0.2225
Epoch:  2 | batch train loss: 0.2020
Epoch:  2 | batch train loss: 0.1963
Epoch:  2 -- total train loss: 0.1962=0.1281+0.0681
Epoch:  3 | batch train loss: 0.2025
Epoch:  3 | batch train loss: 0.1896
Epoch:  3 | batch train loss: 0.1774
Epoch:  3 -- total train loss: 0.1794=0.1266+0.0529
Epoch:  4 | batch train loss: 0.1751
Epoch:  4 | batch train loss: 0.1801
Epoch:  4 | batch train loss: 0.1578
Epoch:  4 -- total train loss: 0.1674=0.1251+0.0423
Epoch:  5 | batch train loss: 0.1564
Epoch:  5 | batch train loss: 0.1700
Epoch:  5 | batch train loss: 0.1507
Epoch:  5 -- total train loss: 0.1592=0.1246+0.0346
Epoch:  6 | batch trai

In [21]:
# Saves the whole module object (not just its state_dict); loading requires the
# defining classes and MaskedLinear to be importable.
torch.save(autoencoder, f='auto_masked_2nd_lin.pt')

  "type " + obj.__name__ + ". It won't be checked "


In [22]:
class DetAutoencoder(nn.Module):
    """Autoencoder whose first encoder layer is a ``DetMaskLinear`` built from ``I``.

    Encoder: n_vars -> n_terms (DetMaskLinear(I, ...), no bias) -> ELU -> n_latent (nn.Linear).
    Decoder: n_latent -> n_terms -> ELU -> n_vars (plain ``nn.Linear`` layers).

    ``I`` is passed straight to ``DetMaskLinear``; how it is used is defined
    in that module. ``forward`` returns the pair ``(encoded, decoded)``.
    """

    def __init__(self, I, n_vars, n_terms, n_latent):
        super().__init__()

        term_layer = DetMaskLinear(I, n_vars, n_terms, bias=False)
        latent_layer = nn.Linear(n_terms, n_latent)
        self.encoder = nn.Sequential(term_layer, nn.ELU(), latent_layer)

        self.decoder = nn.Sequential(
            nn.Linear(n_latent, n_terms),
            nn.ELU(),
            nn.Linear(n_terms, n_vars),
        )

    def forward(self, x):
        z = self.encoder(x)
        return z, self.decoder(z)

In [23]:
def train_autoencoder(autoencoder, loss_func):
    """Train a model with plain mini-batch Adam and print losses.

    Reads the module-level ``adata``, ``EPOCH``, ``BATCH_SIZE`` and ``LR``.
    Each step draws ``BATCH_SIZE`` rows with ``adata.chunk_X`` and does one
    batched forward/backward pass. After every epoch the whole of ``adata.X``
    is evaluated and the loss is printed.

    Args:
        autoencoder: model whose forward returns ``(encoded, decoded)``.
        loss_func: callable ``(decoded, target) -> scalar tensor``
            (e.g. ``nn.MSELoss()``).
    """
    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)

    t_X = torch.from_numpy(adata.X)
    n_steps = adata.n_obs // BATCH_SIZE

    for epoch in range(EPOCH):
        # Explicit mode switching, as in the other training loops. This is a
        # no-op unless some submodule behaves differently in eval mode.
        autoencoder.train()
        for step in range(n_steps):
            X = torch.from_numpy(adata.chunk_X(BATCH_SIZE))
            _, decoded = autoencoder(X)
            loss = loss_func(decoded, X)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                print('Epoch: ', epoch, '| batch train loss: %.4f' % loss.item())

        # Full-dataset evaluation. no_grad avoids building an autograd graph
        # over every observation, which would only waste memory here.
        autoencoder.eval()
        with torch.no_grad():
            _, t_decoded = autoencoder(t_X)
            t_loss = loss_func(t_decoded, t_X)
        print('Epoch: ', epoch, '-- total train loss: %.4f' % t_loss.item())

In [24]:
# Deterministic-mask model built from the gene x term matrix I, 50 latent dimensions.
autoencoder = DetAutoencoder(I, n_vars=adata.n_vars, n_terms=len(adata.uns['terms']), n_latent=50)

In [25]:
# The deterministic model is trained on reconstruction error only (no regularizers).
train_autoencoder(autoencoder=autoencoder, loss_func=nn.MSELoss())

Epoch:  0 | batch train loss: 0.1894
Epoch:  0 | batch train loss: 0.1301
Epoch:  0 | batch train loss: 0.1262
Epoch:  0 -- total train loss: 0.1284
Epoch:  1 | batch train loss: 0.1197
Epoch:  1 | batch train loss: 0.1335
Epoch:  1 | batch train loss: 0.1250
Epoch:  1 -- total train loss: 0.1258
Epoch:  2 | batch train loss: 0.1161
Epoch:  2 | batch train loss: 0.1257
Epoch:  2 | batch train loss: 0.1236
Epoch:  2 -- total train loss: 0.1249
Epoch:  3 | batch train loss: 0.1270
Epoch:  3 | batch train loss: 0.1179
Epoch:  3 | batch train loss: 0.1267
Epoch:  3 -- total train loss: 0.1243
Epoch:  4 | batch train loss: 0.1222
Epoch:  4 | batch train loss: 0.1229
Epoch:  4 | batch train loss: 0.1177
Epoch:  4 -- total train loss: 0.1237
Epoch:  5 | batch train loss: 0.1168
Epoch:  5 | batch train loss: 0.1215
Epoch:  5 | batch train loss: 0.1213
Epoch:  5 -- total train loss: 0.1234
Epoch:  6 | batch train loss: 0.1217
Epoch:  6 | batch train loss: 0.1269
Epoch:  6 | batch train loss: 0.

In [26]:
# Saves the whole module object (not just its state_dict); loading requires the
# defining classes and DetMaskLinear to be importable.
torch.save(autoencoder, f='auto_masked_det.pt')

  "type " + obj.__name__ + ". It won't be checked "
