In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from module import *
from earlystopping import *modu
import gc

In [3]:
# load dataset
feature_path = '../dataset/'
train_cite_X = np.load(feature_path+'train_cite_X.npy')
test_cite_X = np.load(feature_path+'test_cite_X.npy')
train_cite_y = np.load(feature_path+'train_cite_targets.npy') 

In [4]:
A = train_cite_X  
B = train_cite_y  
A_tensor = torch.tensor(A, dtype=torch.float32)
B_tensor = torch.tensor(B, dtype=torch.float32)
dataset = TensorDataset(A_tensor, B_tensor)

In [5]:
train_cite_X.shape, train_cite_y.shape

((70988, 1009), (70988, 140))

In [7]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
criterion = nn.MSELoss()

In [8]:
%%time
# Training
k_folds = 5
kfold = KFold(n_splits=k_folds, shuffle=True)

num_epochs = 50

for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

    trainloader = DataLoader(dataset, batch_size=8192, sampler=train_subsampler,num_workers=8)
    testloader = DataLoader(dataset, batch_size=8192, sampler=test_subsampler,num_workers=8)

    generator, discriminator = create_models()
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)

    early_stopping = EarlyStopping(patience=5, min_delta=0.01)

    for epoch in range(num_epochs):
        generator.train()
        discriminator.train()

        for data in trainloader:
            A_batch, B_batch = data
            A_batch, B_batch = A_batch.to(device), B_batch.to(device)

            # Train Discriminator
            d_optimizer.zero_grad()
            real_output = discriminator(B_batch)
            fake_B = generator(A_batch)
            fake_output = discriminator(fake_B.detach())
            d_loss_real = criterion(real_output, torch.ones_like(real_output))
            d_loss_fake = criterion(fake_output, torch.zeros_like(fake_output))
            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            d_optimizer.step()

            # Train Generator
            g_optimizer.zero_grad()
            fake_output = discriminator(fake_B)
            g_loss = criterion(fake_output, torch.ones_like(fake_output)) + criterion(fake_B, B_batch)
            g_loss.backward()
            g_optimizer.step()

        # Validate
        generator.eval()
        val_loss = 0
        with torch.no_grad():
            for data in testloader:
                A_batch, B_batch = data
                A_batch, B_batch = A_batch.to(device), B_batch.to(device)
                fake_B = generator(A_batch)
                val_loss += criterion(fake_B, B_batch).item()
        
        val_loss /= len(testloader)
        print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, val_loss: {val_loss:.4f}')

        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping")
            break




FOLD 0
--------------------------------


  from .autonotebook import tqdm as notebook_tqdm


Epoch [1/50], d_loss: 0.1902, g_loss: 11.5589, val_loss: 10.7048
Epoch [2/50], d_loss: 0.1150, g_loss: 8.4411, val_loss: 7.4997
Epoch [3/50], d_loss: 0.0505, g_loss: 6.3626, val_loss: 5.4850
Epoch [4/50], d_loss: 0.0154, g_loss: 5.9516, val_loss: 4.9735
Epoch [5/50], d_loss: 0.0056, g_loss: 5.6897, val_loss: 4.7968
Epoch [6/50], d_loss: 0.0022, g_loss: 5.4459, val_loss: 4.4443
Epoch [7/50], d_loss: 0.0020, g_loss: 5.0378, val_loss: 4.0786
Epoch [8/50], d_loss: 0.0042, g_loss: 4.7848, val_loss: 3.9401
Epoch [9/50], d_loss: 0.0146, g_loss: 4.8377, val_loss: 4.0395
Epoch [10/50], d_loss: 0.0255, g_loss: 5.0580, val_loss: 4.0211
Epoch [11/50], d_loss: 0.0015, g_loss: 4.8313, val_loss: 3.8224
Epoch [12/50], d_loss: 0.0021, g_loss: 4.7325, val_loss: 3.7033
Epoch [13/50], d_loss: 0.0004, g_loss: 4.5692, val_loss: 3.6301
Epoch [14/50], d_loss: 0.0006, g_loss: 4.5084, val_loss: 3.5705
Epoch [15/50], d_loss: 0.0003, g_loss: 4.5760, val_loss: 3.5379
Epoch [16/50], d_loss: 0.0004, g_loss: 4.4441, 

In [11]:
def generate_B_from_A(new_A):
    new_A_tensor = torch.tensor(new_A, dtype=torch.float32)
    dataset = TensorDataset(new_A_tensor)
    dataloader = DataLoader(dataset, batch_size=64)

    generator.eval()
    generated_B = []
    with torch.no_grad():
        for data in dataloader:
            A_batch = data[0].to(device)
            fake_B = generator(A_batch)
            generated_B.append(fake_B.cpu().numpy())
    
    generated_B = np.concatenate(generated_B, axis=0)
    return generated_B

generated_B = generate_B_from_A(test_cite_X)
generated_B

array([[ 0.5994527 ,  0.51441234,  0.97809494, ...,  0.93337363,
         4.049357  ,  4.8270545 ],
       [ 0.5995103 ,  0.51434344,  0.97802305, ...,  0.9333096 ,
         4.0486646 ,  4.8266973 ],
       [ 0.5865684 ,  0.520134  ,  0.9948261 , ...,  0.943077  ,
         4.16027   ,  4.852188  ],
       ...,
       [ 0.8174296 ,  0.40102977,  1.154908  , ...,  1.200386  ,
         6.4875264 ,  5.346298  ],
       [-0.09853111,  0.12912661,  1.0921547 , ...,  0.70547223,
         4.8214574 ,  3.651095  ],
       [ 0.03799064,  0.32101083,  1.3517659 , ...,  0.82057786,
         6.433246  ,  4.56198   ]], dtype=float32)

In [12]:
generated_B.shape

(48663, 140)

In [None]:
######Data collation

In [13]:
input_path = '../dataset/'
metadata = pd.read_csv(input_path+'metadata.csv')[['cell_id','technology']]
evaluation_ids = pd.read_csv(input_path+'evaluation_ids.csv')
evaluation_ids = evaluation_ids.merge(metadata, on=['cell_id'], how='left')

# cite
train_cite_targets = pd.read_hdf(input_path+'train_cite_targets.h5')
cite_targets = train_cite_targets.columns.values.tolist()
del train_cite_targets
gc.collect()
test_preds_cite = pd.DataFrame(generated_B, columns=cite_targets)

test_cite_inputs_id = pd.read_feather(feature_path+'test_cite_inputs_id.feather')
test_preds_cite['cell_id'] = test_cite_inputs_id['cell_id']
test_preds_cite = test_preds_cite[test_preds_cite['cell_id'].isin(evaluation_ids['cell_id'])]
test_preds_cite = pd.melt(test_preds_cite,id_vars='cell_id')
test_preds_cite.columns = ['cell_id','gene_id','target']
del test_cite_inputs_id
gc.collect()

test_preds_cite.to_csv('../dataset/pred_cite.csv')
test_preds_cite

Unnamed: 0,cell_id,gene_id,target
0,c2150f55becb,CD86,0.599453
1,65b7edf8a4da,CD86,0.599510
2,c1b26cb1057b,CD86,0.586568
3,917168fa6f83,CD86,0.593099
4,2b29feeca86d,CD86,0.589955
...,...,...,...
6812815,a9b4d99f1f50,CD224,2.738467
6812816,0e2c1d0782af,CD224,2.715506
6812817,a3cbc5aa0ec3,CD224,5.346298
6812818,75b350243add,CD224,3.651095
