In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import scanpy as sc
import plotly.express as px
import plotly.io as pio
import sklearn.preprocessing
import sklearn.model_selection





# Autograd anomaly detection is a debugging aid: the PyTorch docs note it slows
# execution of every backward pass, so keep it off for full training runs and
# set this to True only while chasing NaNs / in-place-modification errors.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY);

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7ff28aa511c0>

In [2]:
import platform


def get_device_and_gmount():
    """Choose the torch device for this machine and detect Google Colab.

    Also sets the default plotly renderer for the detected front-end.

    Returns
    -------
    device : torch.device
        ``mps`` when PyTorch's Apple Metal backend is available, otherwise
        ``cuda`` when a CUDA GPU is available, otherwise ``cpu``.
    gmount : bool
        True when running inside Google Colab, i.e. when the data has to be
        read from a mounted Google Drive. Mounting is left to the caller.
    """
    # platform.release() returns the kernel version (e.g. '5.15.0-...'), never
    # an Ubuntu release like '18.04', so detect Colab by its module instead.
    try:
        import google.colab  # noqa: F401  (only importable inside Colab)
        in_colab = True
    except ImportError:
        in_colab = False

    if in_colab:
        pio.renderers.default = 'colab'
        print("Using Colab on Linux")
    elif platform.system() == 'Darwin':
        pio.renderers.default = 'notebook'

    # Probe the backends directly instead of guessing from the CPU architecture:
    # Apple-silicon Macs report 'arm64', and Linux hosts without a GPU would
    # otherwise be handed an unusable 'cuda' device.
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        print("Using Apple MPS on Macbook Pro")
    elif torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print("Using device:", device)

    return device, in_colab


In [3]:
device, gmount = get_device_and_gmount()


Using Apple MPS on Macbook Pro
Using device: mps


In [4]:
# Where GEX.h5ad lives: on the mounted Google Drive when running in Colab
# (gmount=True), otherwise in the local project directory.
COLAB_GEX_PATH = "/content/drive/MyDrive/scintegration/GEX.h5ad"
LOCAL_GEX_PATH = "/Users/eamonmcandrew/Desktop/Single_cell_integration/Data/Multi-ome/GEX.h5ad"

if gmount:
    from google.colab import drive
    # Drive has to be mounted before the file can be read; it is mounted at
    # /content/drive, which is where COLAB_GEX_PATH points.
    drive.mount('/content/drive')
    scdata = sc.read_h5ad(COLAB_GEX_PATH)
else:
    scdata = sc.read_h5ad(LOCAL_GEX_PATH)

In [5]:
def stratified_split(data, test_size, random_state, split_criteria):
    """Split cell names into train/test sets, separately within each group.

    Every distinct value of ``data.obs[split_criteria]`` (e.g. ``"batch"``) is
    split on its own with ``sklearn.model_selection.train_test_split`` using the
    same ``test_size`` and ``random_state``, so each group contributes to both
    sets in roughly the requested proportion.

    Parameters
    ----------
    data : AnnData
        Object whose ``obs`` DataFrame holds the grouping column.
    test_size : float or int
        Passed to ``train_test_split`` for every group.
    random_state : int or None
        Passed to ``train_test_split`` for every group.
    split_criteria : str
        Name of the ``obs`` column to stratify on.

    Returns
    -------
    train, test : list of str
        Observation (cell) names, collected group by group in the order the
        groups first appear in ``obs``.

    Notes
    -----
    A group with a single cell cannot be split and is put entirely in ``train``.
    Cells whose group value is missing (NaN) match no group and are returned
    in neither list.
    """
    obs_names = data.obs.index
    groups = data.obs[split_criteria]
    train = []
    test = []
    for group in groups.unique():
        # Split the obs names directly rather than slicing the AnnData, which
        # avoids building a per-group AnnData view just to read its index.
        group_names = obs_names[(groups == group).to_numpy()]
        if len(group_names) < 2:
            # train_test_split raises on groups this small; keep them for training.
            train.extend(group_names)
            continue
        group_train, group_test = sklearn.model_selection.train_test_split(
            group_names, test_size=test_size, random_state=random_state
        )
        train.extend(group_train)
        test.extend(group_test)

    return train, test


In [6]:
if gmount:
    from google.colab import drive
    # Mount Google Drive at /content/drive; its files appear under
    # /content/drive/MyDrive. The dataset itself is read by the data-loading
    # cell above, so it is not re-read here.
    drive.mount('/content/drive')
    path = '/content/drive/MyDrive/Colab Notebooks/Experiments/'

In [7]:
import wandb
# Authenticate with Weights & Biases for the experiment logging in the training
# cell. No API key is passed in code, so wandb presumably falls back to its own
# stored or interactive credentials.
wandb.login()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33meamomc[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [8]:
class GEX_Dataset(torch.utils.data.Dataset):
    """Torch dataset of (scaled expression row, encoded category) pairs.

    Parameters
    ----------
    data : AnnData
        Object with an obs x var expression matrix ``data.X`` (scipy sparse or
        dense) and per-observation metadata ``data.obs``.
    scaler : {"Standard", "MinMax", None}
        Per-column scaling applied to the dense expression matrix
        (sklearn StandardScaler / MinMaxScaler); None keeps the raw values.
    cat_var : str or None
        Column of ``data.obs`` to encode (e.g. ``"batch"``).
    label_encoder : {"numeric", "one_hot", "range_map", None}
        How ``cat_var`` is encoded; labels are coded 0..k-1 in sorted order.
        "numeric"   -> int64 codes, shape (n_obs,)
        "one_hot"   -> float32 one-hot matrix, shape (n_obs, k)
        "range_map" -> float32 codes min-max scaled to [0, 1], shape (n_obs, 1)
        None        -> no encoding; items carry an empty tensor instead.

    Raises
    ------
    ValueError
        If ``scaler`` or ``label_encoder`` is an unrecognised string, or if
        ``label_encoder`` is set without ``cat_var``.
    """

    def __init__(self, data, scaler=None, cat_var=None, label_encoder=None):
        self.data = data
        matrix = data.X
        # Sparse matrices expose .toarray(); dense arrays are used as they are.
        self.values = np.asarray(matrix.toarray() if hasattr(matrix, "toarray") else matrix)
        self.cat_var = cat_var
        self.categories = None  # sorted category labels once cat_var is encoded
        self.cat_var_data = self._encode_categories(label_encoder)
        self.scaled_values = torch.tensor(self._scale_values(scaler), dtype=torch.float32)
        # Returned instead of a category tensor when nothing was encoded, so every
        # item is still an (x, c) pair that the default DataLoader collate can batch.
        self._no_category = torch.empty(0)

    def _encode_categories(self, label_encoder):
        """Encode ``data.obs[cat_var]`` as requested; None when label_encoder is None."""
        if label_encoder is None:
            return None
        if label_encoder not in ("numeric", "range_map", "one_hot"):
            raise ValueError(
                f"Unknown label_encoder {label_encoder!r}; expected 'numeric', "
                "'range_map', 'one_hot' or None"
            )
        if self.cat_var is None:
            raise ValueError("cat_var must be given when label_encoder is set")

        # Integer codes 0..k-1 following the sorted label order.
        labels = np.asarray(self.data.obs[self.cat_var])
        self.categories, codes = np.unique(labels, return_inverse=True)
        codes = codes.reshape(-1)
        n_classes = len(self.categories)

        if label_encoder == "numeric":
            return torch.tensor(codes, dtype=torch.long)
        if label_encoder == "one_hot":
            return torch.tensor(np.eye(n_classes)[codes], dtype=torch.float32)
        # "range_map": codes min-max scaled into [0, 1]; a single class maps to 0.
        scaled = codes.reshape(-1, 1) / max(n_classes - 1, 1)
        return torch.tensor(scaled, dtype=torch.float32)

    def _scale_values(self, scaler):
        """Return ``self.values`` scaled column-wise by ``scaler`` (unchanged for None)."""
        if scaler is None:
            return self.values
        if scaler not in ("Standard", "MinMax"):
            raise ValueError(f"Unknown scaler {scaler!r}; expected 'Standard', 'MinMax' or None")
        if scaler == "Standard":
            scaler_cls = sklearn.preprocessing.StandardScaler
        else:
            scaler_cls = sklearn.preprocessing.MinMaxScaler
        return scaler_cls().fit_transform(self.values)

    @property
    def n_features(self):
        """Number of columns (features) in the expression matrix."""
        return self.values.shape[1]

    @property
    def n_catagories(self):
        """Size of the categorical encoding; 0 when nothing was encoded.

        For the 2-D encodings this is the column count ("one_hot": one per
        category, "range_map": 1). For the 1-D "numeric" codes it is the
        number of distinct categories.
        """
        if self.cat_var_data is None:
            return 0
        if self.cat_var_data.dim() == 1:
            return len(self.categories)
        return self.cat_var_data.shape[1]

    @property
    def n_categories(self):
        """Correctly spelled alias of ``n_catagories``."""
        return self.n_catagories

    def __len__(self):
        """Number of observations (rows of the expression matrix)."""
        return self.scaled_values.shape[0]

    def __getitem__(self, idx):
        """Return ``(scaled expression row, encoded category)`` for row ``idx``."""
        if self.cat_var_data is None:
            return self.scaled_values[idx], self._no_category
        return self.scaled_values[idx], self.cat_var_data[idx]


In [9]:
# This cell rebinds the name GEX_Dataset from the class to an instance, and the
# cells below use it as the instance. On a re-run, recover the class from that
# instance so the cell stays idempotent instead of raising "object is not callable".
_GEX_Dataset_cls = GEX_Dataset if isinstance(GEX_Dataset, type) else type(GEX_Dataset)
GEX_Dataset = _GEX_Dataset_cls(scdata, scaler="Standard", cat_var="batch", label_encoder="one_hot")

In [10]:
# Define the generator network
class Generator(nn.Module):
    """Three-layer MLP mapping a noise vector to a synthetic expression profile.

    Layout: Linear(input_size -> hidden_size) -> ELU
            -> Linear(hidden_size -> hidden_size) -> ELU
            -> Linear(hidden_size -> output_size)
    The final layer has no activation, so outputs are unbounded reals.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names double as state_dict keys (map1.weight, ...).
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = F.elu(self.map1(x))
        hidden = F.elu(self.map2(hidden))
        return self.map3(hidden)

In [11]:
# Define the discriminator network
class Discriminator(nn.Module):
    """Three-layer MLP that scores each input with a value in (0, 1).

    Layout: Linear -> ELU -> Linear -> ELU -> Linear -> sigmoid. The sigmoid
    output is the probability form that ``nn.BCELoss`` expects.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = x
        for layer in (self.map1, self.map2):
            h = F.elu(layer(h))
        return torch.sigmoid(self.map3(h))


In [12]:
# Define the GAN model
class GAN(nn.Module):
    """Generator followed by discriminator, registered together as one module.

    ``gan(z)`` scores freshly generated samples, i.e. it equals
    ``gan.discriminator(gan.generator(z))``. Both networks are registered
    submodules, so ``gan.parameters()`` yields the parameters of both.
    """

    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, x):
        generated = self.generator(x)
        return self.discriminator(generated)

In [13]:
# Number of expression features per cell; the next cell uses it as the
# generator's output size and the discriminator's input size.
GEX_Dataset.n_features

13431

In [14]:
# Create instance of the generator and discriminator networks
generator = Generator(input_size=100, hidden_size=256, output_size=GEX_Dataset.n_features)
discriminator = Discriminator(input_size=GEX_Dataset.n_features, hidden_size=256, output_size=1)

# Create instance of the GAN model
gan = GAN(generator, discriminator)

In [15]:
# Binary cross-entropy on the discriminator's sigmoid probabilities.
loss_fn = nn.BCELoss()

# NOTE(review): gan.parameters() spans BOTH the generator and the discriminator,
# so a single step() moves both networks with whatever gradients are currently
# stored on each. Alternating GAN updates need one optimizer per network.
LEARNING_RATE = 0.0002
optimizer = torch.optim.Adam(gan.parameters(), lr=LEARNING_RATE)


In [16]:
# Train the GAN with alternating updates: one discriminator step, then one
# generator step, each driven by its own optimizer.
N_EPOCHS = 100
BATCH_SIZE = 128
GAN_LR = 0.0002
latent_dim = gan.generator.map1.in_features  # noise length the generator expects

# One optimizer per network. With a single optimizer over gan.parameters(), the
# zero_grad() before the generator's backward wipes the discriminator's own
# gradients, and step() then moves the discriminator along the generator's
# "fool the discriminator" gradient, so the discriminator never learns.
# Creating them here also gives every re-run of this cell fresh optimizer state.
optimizer_G = torch.optim.Adam(gan.generator.parameters(), lr=GAN_LR)
optimizer_D = torch.optim.Adam(gan.discriminator.parameters(), lr=GAN_LR)

run = wandb.init(
    project="Single Cell Omics integration",
    entity="scintegration",
    config={"epochs": N_EPOCHS, "batch_size": BATCH_SIZE, "lr": GAN_LR, "latent_dim": latent_dim},
)
GEX_dataloader_train = torch.utils.data.DataLoader(GEX_Dataset, batch_size=BATCH_SIZE, shuffle=True)

# TODO: the model and the batches stay on the CPU; move both to `device` to use the GPU.
gan.train()
for epoch in range(N_EPOCHS):
    for real_batch, _ in GEX_dataloader_train:
        noise = torch.randn(real_batch.shape[0], latent_dim)
        fake_batch = gan.generator(noise)

        # Discriminator step: real -> 1, fake -> 0. detach() keeps this loss
        # from back-propagating into the generator.
        optimizer_D.zero_grad()
        pred_real = gan.discriminator(real_batch)
        pred_fake = gan.discriminator(fake_batch.detach())
        loss_real = loss_fn(pred_real, torch.ones_like(pred_real))
        loss_fake = loss_fn(pred_fake, torch.zeros_like(pred_fake))
        loss_discriminator = (loss_real + loss_fake) / 2
        loss_discriminator.backward()
        optimizer_D.step()

        # Generator step: push the updated discriminator to label fakes as real.
        # Gradients this leaves on the discriminator are cleared by
        # optimizer_D.zero_grad() before the discriminator's next step.
        optimizer_G.zero_grad()
        pred_fake = gan.discriminator(fake_batch)
        loss_generator = loss_fn(pred_fake, torch.ones_like(pred_fake))
        loss_generator.backward()
        optimizer_G.step()

        # A single log call per iteration keeps both losses on the same wandb step.
        wandb.log({
            "epoch": epoch,
            "loss_discriminator": loss_discriminator.item(),
            "loss_generator": loss_generator.item(),
        })

run.finish()

        
        
        

[34m[1mwandb[0m: Currently logged in as: [33meamomc[0m ([33mscintegration[0m). Use [1m`wandb login --relogin`[0m to force relogin


: 

: 

In [None]:
# Generate fake data
def generate_fake_data(generator, batch_size, n_features, latent_dim=100):
    """Sample ``batch_size`` synthetic rows from ``generator``.

    Parameters
    ----------
    generator : nn.Module
        Network mapping noise of shape (batch_size, latent_dim) to samples
        of shape (batch_size, n_features).
    batch_size : int
        Number of samples to draw.
    n_features : int or None
        Expected output width, checked against the generator output;
        pass None to skip the check.
    latent_dim : int, default 100
        Length of the standard-normal noise vector fed to the generator.

    Returns
    -------
    numpy.ndarray
        Generated samples as a host (CPU) array.

    Raises
    ------
    ValueError
        If the generator output width differs from ``n_features``.
    """
    # Draw the noise on the same device as the generator's weights.
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = generator.training
    generator.eval()
    try:
        # Sampling only: no autograd graph is needed.
        with torch.no_grad():
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_data = generator(noise)
    finally:
        # Restore the caller's mode so later training is not left in eval mode.
        generator.train(was_training)

    if n_features is not None and fake_data.shape[-1] != n_features:
        raise ValueError(
            f"generator produced {fake_data.shape[-1]} features, expected {n_features}"
        )
    return fake_data.cpu().numpy()


In [1]:
# Draw 128 synthetic cells from the trained generator.
fake_data = generate_fake_data(
    generator=generator,
    batch_size=128,
    n_features=GEX_Dataset.n_features,
)

NameError: name 'generate_fake_data' is not defined

In [None]:
# Wrap the generated samples in an AnnData carrying the real data's gene names,
# so generated and real cells can be compared gene by gene.
# NOTE(review): these values are presumably in the scaled space the GAN was
# trained on (GEX_Dataset.scaled_values), not raw expression — confirm before comparing.
fake = sc.AnnData(fake_data, var=pd.DataFrame(index=scdata.var_names))

In [None]:
# Build a 10-nearest-neighbour graph over the real cells, then embed it with UMAP;
# scanpy writes both results into scdata.
sc.pp.neighbors(adata=scdata, n_neighbors=10)
sc.tl.umap(adata=scdata)
