In [None]:
import torch
from torch import nn, optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import scanpy as sc
import plotly.express as px
import plotly.io as pio
import sklearn.preprocessing
import sklearn.model_selection




# Enable autograd anomaly detection for the whole session.
# NOTE(review): this is a debugging aid that adds overhead to every backward pass;
# consider switching it off for full training runs.
torch.autograd.set_detect_anomaly(True)

In [None]:
import platform

# Exact platform strings of the two machines this notebook was developed on.
_MACBOOK_PLATFORM = 'macOS-10.16-x86_64-i386-64bit'
_COLAB_PLATFORM = 'Linux-5.10.133+-x86_64-with-Ubuntu-18.04-bionic'

_platform_id = platform.platform()

# Every branch below defines both `device` and `gmount`, so later cells never
# hit a NameError regardless of where the notebook is executed.
if _platform_id == _MACBOOK_PLATFORM:
    pio.renderers.default = 'notebook'
    device = torch.device('mps')
    print("Using Apple MPS on Macbook Pro")
    gmount = False

elif _platform_id == _COLAB_PLATFORM:
    pio.renderers.default = 'colab'
    # The data lives on Google Drive on Colab whether or not a GPU is attached.
    gmount = True
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using CUDA GPU on Colab")
    else:
        device = torch.device("cpu")
        print("CUDA not available on Colab, falling back to CPU")

else:
    # Unknown machine: keep the default plotly renderer, skip the Drive mount and
    # pick the best accelerator that is actually available.
    gmount = False
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Unrecognised platform {_platform_id!r}; using device '{device}'")

In [None]:
# Load the GEX .h5ad file with scanpy.
# NOTE(review): this is an absolute, machine-specific path, so the cell fails on any
# other machine (including Colab, which reloads `scdata` in a later cell).
scdata = sc.read_h5ad("/Users/eamonmcandrew/Desktop/Single_cell_integration/Data/Multi-ome/GEX.h5ad")

In [None]:
# Display the summary representation of the loaded object.
scdata

In [None]:
def stratified_split(data, test_size, random_state, split_criteria):
    """
    Split observations into train and test sets, stratified by an ``obs`` column.

    Every distinct value of ``data.obs[split_criteria]`` is split on its own with
    ``sklearn.model_selection.train_test_split``, so each group contributes to both
    sets in (approximately) the requested proportion.

    Parameters
    ----------
    data : object exposing a pandas ``obs`` DataFrame (e.g. an AnnData object)
    test_size : passed through to ``train_test_split``
    random_state : passed through to ``train_test_split`` (same seed for every group)
    split_criteria : str
        Name of the ``obs`` column that defines the groups.

    Returns
    -------
    (train, test) : tuple of lists
        Observation index labels assigned to each set.

    Notes
    -----
    * Only the index labels are split, not slices of ``data``; for a given seed the
      partition is the same, without materialising a view per group.
    * A group with a single observation cannot be split and is placed in ``train``.
    * Rows whose group value is missing (NaN) match no group and are left out.
    """
    obs_index = data.obs.index
    group_labels = data.obs[split_criteria]

    train = []
    test = []
    for group in group_labels.unique():
        members = obs_index[(group_labels == group).to_numpy()].tolist()

        # train_test_split raises when one side would be empty; keep lone samples
        # for training instead of aborting the whole split.
        if len(members) < 2:
            train.extend(members)
            continue

        group_train, group_test = sklearn.model_selection.train_test_split(
            members, test_size=test_size, random_state=random_state
        )
        train.extend(group_train)
        test.extend(group_test)

    return train, test


In [None]:
# Hold out 20% of each cell type for testing; the fixed seed keeps the split repeatable.
train, test = stratified_split(scdata, 0.2, 9000, split_criteria='cell_type')

In [None]:
# Subset the loaded object with the index labels returned by the split.
train_data = scdata[train]
test_data = scdata[test]

# Show the size of each subset as a quick sanity check.
len(train_data), len(test_data)

In [None]:
# On Colab the data lives on Google Drive: mount it and reload the dataset from there.
if gmount:
    from google.colab import drive
    drive.mount('/content/drive')
    path = '/content/drive/My Drive/Colab Notebooks/Experiments/' 
    # The file must be read from underneath the mount point used above
    # ('/content/drive'); '/content/gdrive' is never mounted in this notebook.
    scdata = sc.read_h5ad("/content/drive/MyDrive/scintegration/GEX.h5ad")

In [None]:
# Experiment tracking uses Weights & Biases. Log in with your own account by
# supplying its auth token when prompted.
# NOTE(review): the original note says key = 'offline' can be used to work offline —
# confirm the exact mechanism against the wandb documentation.

import wandb
wandb.login()


In [None]:
# Start a run in the shared project; the wandb.log calls in later cells report to it.
wandb.init(project="Single Cell Omics integration", entity="scintegration")

In [None]:
# sweep_id = wandb.sweep(sweep=sweep_configuration, project="project-name")
# sweep_configuration = {
#     'method': 'random',
#     'name': 'sweep',
#     'metric': {
#         'goal': 'maximize',
#         'name': 'accuracy'
# 		},
#     'parameters': {
#         'batch_size': {'values': [128, 256, 512]},
#         'epochs': {'values': [5, 10, 15]},
#         'lr': {'max': 0.1, 'min': 0.0001}
#      }
# }
# wandb.agent(sweep_id=sweep_id, function=function_name)

In [None]:
class GEX_Dataset(torch.utils.data.Dataset):
    """
    Torch dataset pairing each row of ``data.X`` with an encoded label from ``data.obs``.

    Parameters
    ----------
    data : object exposing ``X`` (dense or scipy-sparse matrix) and a pandas ``obs``
        DataFrame, e.g. an AnnData object.
    scaler : {"Standard", "MinMax"} or anything else, optional
        Feature scaling applied to ``X``. Any other value leaves the features unscaled.
    cat_var : str
        Name of the ``obs`` column used as the label.
    label_encoder : {"numeric", "range_map", "one_hot"}
        * ``"numeric"``   -> 1-D ``long`` tensor of class indices.
        * ``"range_map"`` -> ``(n, 1)`` float tensor, class indices rescaled to [0, 1].
        * ``"one_hot"``   -> ``(n, n_classes)`` float tensor.

    Raises
    ------
    ValueError
        If ``label_encoder`` is not one of the supported values.

    Notes
    -----
    The label encoder and the feature scaler are fitted on this dataset alone.
    NOTE(review): separately constructed train/test datasets therefore use
    independent scaling statistics and label mappings — confirm this is intended.
    """

    _LABEL_ENCODERS = ("numeric", "range_map", "one_hot")

    def __init__(self, data, scaler=None, cat_var=None, label_encoder=None):
        # Fail early with a clear message instead of a NameError further down.
        if label_encoder not in self._LABEL_ENCODERS:
            raise ValueError(
                f"label_encoder must be one of {self._LABEL_ENCODERS}, got {label_encoder!r}"
            )

        self.data = data
        self.cat_var = cat_var

        # Work with a dense ndarray: sparse input is expanded with toarray()
        # (todense() would give np.matrix, which scikit-learn no longer accepts),
        # and input that is already dense is used as is.
        matrix = data.X
        if hasattr(matrix, "toarray"):
            matrix = matrix.toarray()
        self.values = np.asarray(matrix)

        # Integer class index per row; classes are numbered in sorted order.
        label_codes = sklearn.preprocessing.LabelEncoder().fit_transform(self.data.obs[self.cat_var])

        if label_encoder == "numeric":
            cat_var_data = torch.tensor(label_codes, dtype=torch.long)
        elif label_encoder == "range_map":
            rescaled = sklearn.preprocessing.MinMaxScaler().fit_transform(label_codes.reshape(-1, 1))
            cat_var_data = torch.tensor(rescaled, dtype=torch.float32)
        else:  # "one_hot"
            one_hot = sklearn.preprocessing.OneHotEncoder().fit_transform(label_codes.reshape(-1, 1)).toarray()
            cat_var_data = torch.tensor(one_hot, dtype=torch.float32)

        # Already a tensor: store it directly rather than re-wrapping with torch.tensor().
        self.cat_var_data = cat_var_data

        # Scale the features according to the `scaler` argument.
        if scaler == "Standard":
            self.scaled_values = torch.tensor(sklearn.preprocessing.StandardScaler().fit_transform(self.values), dtype=torch.float32)
        elif scaler == "MinMax":
            self.scaled_values = torch.tensor(sklearn.preprocessing.MinMaxScaler().fit_transform(self.values), dtype=torch.float32)
        else:
            self.scaled_values = torch.tensor(self.values, dtype=torch.float32)

    @property
    def n_features(self):
        """Number of feature columns in ``X``."""
        return self.values.shape[1]

    @property
    def n_catagories(self):
        """
        Width of the label representation.

        For 2-D labels this is the number of columns (number of classes for
        ``"one_hot"``, 1 for ``"range_map"``). For the 1-D ``"numeric"`` encoding it
        is the number of classes, i.e. the largest class index plus one.
        """
        labels = self.cat_var_data
        if labels.dim() == 1:
            if labels.numel() == 0:
                return 0
            return int(labels.max().item()) + 1
        return labels.shape[1]

    # A Dataset needs the following two methods to work with DataLoader.

    def __len__(self):
        """Number of rows (observations) in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (scaled features, encoded label) pair at position ``idx``."""
        return self.scaled_values[idx], self.cat_var_data[idx]

    

In [None]:
# Build the train and test datasets: standard-scaled features, one-hot encoded "batch" labels.
GEX_Dataset_train = GEX_Dataset(train_data, scaler = "Standard", cat_var = "batch", label_encoder = "one_hot")

GEX_Dataset_test = GEX_Dataset(test_data, scaler = "Standard", cat_var = "batch", label_encoder = "one_hot")


In [None]:
# Network dimensions are derived from the training dataset.
input_size = GEX_Dataset_train.n_features
output_size = GEX_Dataset_train.n_catagories

# Training hyperparameters.
batch_size = 256
epochs = 30
lr = 1e-4
dropout = 0.2

# Record the hyperparameters on the active run. Assigning a plain dict to
# `wandb.config` would only rebind the module attribute and record nothing, so the
# run's config object is updated instead. `allow_val_change=True` lets this cell be
# re-run after tweaking a value.
wandb.config.update(
    {
        "learning_rate": lr,
        "epochs": epochs,
        "batch_size": batch_size,
        "dropout": dropout,
    },
    allow_val_change=True,
)

# Print a progress line every `log_interval` batches.
log_interval = 100

In [None]:
# Display the size of the classifier's output layer.
output_size


In [None]:
class classifier(nn.Module):
    """
    Single-hidden-layer MLP classifier: Linear -> ReLU -> Dropout -> Linear.

    ``forward`` returns raw, unnormalised logits. ``nn.CrossEntropyLoss`` applies
    log-softmax itself, so a softmax here would be applied twice and flatten the
    gradients. ``argmax`` over logits gives the same predictions as over
    probabilities; call ``F.softmax(logits, dim=1)`` when probabilities are needed.

    Parameters
    ----------
    in_features : int, optional
        Input width; defaults to the notebook-level ``input_size``.
    out_features : int, optional
        Number of output classes; defaults to the notebook-level ``output_size``.
    hidden_size : int, optional
        Width of the hidden layer (default 20).
    dropout_p : float, optional
        Dropout probability; defaults to the notebook-level ``dropout``.
    """

    def __init__(self, in_features=None, out_features=None, hidden_size=20, dropout_p=None):
        super().__init__()
        # Fall back to the notebook globals so that `classifier()` keeps working.
        if in_features is None:
            in_features = input_size
        if out_features is None:
            out_features = output_size
        if dropout_p is None:
            dropout_p = dropout

        self.cfc1 = nn.Linear(in_features, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.cfc2 = nn.Linear(hidden_size, out_features)

    def forward(self, x):
        """Map a ``(batch, in_features)`` tensor to ``(batch, out_features)`` logits."""
        x = self.cfc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        return self.cfc2(x)

In [None]:
# Mini-batch loaders. Only the training data is shuffled: evaluation does not
# benefit from shuffling, and a fixed order keeps per-batch evaluation metrics
# reproducible from epoch to epoch.
GEX_dataloader_train = torch.utils.data.DataLoader(GEX_Dataset_train, batch_size = batch_size, shuffle = True)
GEX_dataloader_test = torch.utils.data.DataLoader(GEX_Dataset_test, batch_size = batch_size, shuffle = False)

# Model and optimiser, placed on the device selected earlier in the notebook.
model = classifier()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
model.train()

# Cross-entropy loss over the classifier outputs.
criterion = nn.CrossEntropyLoss()

In [None]:
def train_one_epoch(epoch, GEX_dataloader_train , model, optimizer, criterion):
    """
    Run one optimisation pass over the training loader.

    Parameters
    ----------
    epoch : int
        Epoch number, used only in the progress printout.
    GEX_dataloader_train : DataLoader yielding ``(data, target)`` batches.
        Targets must be 2-D (e.g. one-hot), because accuracy is computed from
        ``target.argmax(1)``.
    model, optimizer, criterion : the network, its optimiser and the loss function.

    Returns
    -------
    (epoch_loss, epoch_accuracy) : unweighted means of the per-batch loss and accuracy.

    Side effects: logs per-batch and per-epoch metrics to wandb and prints a progress
    line every ``log_interval`` batches. Uses the notebook-level ``device``.
    """
    model.train()

    # The accumulators must live outside the batch loop; re-creating them per batch
    # would make the epoch averages reflect the final batch only.
    epoch_loss_list = []
    epoch_accuracy_list = []

    for batch_idx, (data, target) in enumerate(GEX_dataloader_train):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        accuracy = (output.argmax(1) == target.argmax(1)).type(torch.float).mean().item()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss_list.append(batch_loss)
        epoch_accuracy_list.append(accuracy)

        # One log call per batch so both metrics share the same wandb step.
        wandb.log({"Train loss": batch_loss, "Train accuracy": accuracy})

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(GEX_dataloader_train.dataset),
                100. * batch_idx / len(GEX_dataloader_train), batch_loss))

    epoch_loss = np.mean(epoch_loss_list)
    epoch_accuracy = np.mean(epoch_accuracy_list)
    wandb.log({"Train epoch loss": epoch_loss, "Train epoch accuracy": epoch_accuracy})

    return epoch_loss, epoch_accuracy

            
        

In [None]:
def evaluate_one_epoch(epoch, GEX_Dataset_test, model, optimizer, criterion):
    """
    Evaluate the model on one pass over the evaluation loader (no gradient updates).

    Parameters
    ----------
    epoch : int
        Epoch number, used only in the progress printout.
    GEX_Dataset_test : DataLoader yielding ``(data, target)`` batches.
        Despite its name this is the evaluation *DataLoader*; the name is kept so
        existing calls keep working. Targets must be 2-D (e.g. one-hot).
    model, criterion : the network and the loss function.
    optimizer : unused; kept for signature compatibility with ``train_one_epoch``.

    Returns
    -------
    (epoch_loss, epoch_accuracy) : unweighted means of the per-batch loss and accuracy.

    Side effects: logs per-batch metrics, per-epoch metrics and one confusion matrix
    per epoch to wandb. Uses the notebook-level ``device`` and ``scdata``.
    """
    # Use the loader that was passed in rather than a notebook global.
    eval_loader = GEX_Dataset_test

    model.eval()

    # The accumulators must live outside the batch loop so that the averages and the
    # confusion matrix cover the whole epoch, not just the last batch.
    epoch_loss_list = []
    epoch_accuracy_list = []
    true_ids_per_batch = []
    pred_ids_per_batch = []

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(eval_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            predicted_class_ids = output.argmax(1)
            ground_truth_class_ids = target.argmax(1)
            accuracy = (predicted_class_ids == ground_truth_class_ids).type(torch.float).mean().item()

            batch_loss = loss.item()
            epoch_loss_list.append(batch_loss)
            epoch_accuracy_list.append(accuracy)
            true_ids_per_batch.append(ground_truth_class_ids.cpu().numpy())
            pred_ids_per_batch.append(predicted_class_ids.cpu().numpy())

            # One log call per batch so both metrics share the same wandb step.
            wandb.log({"Test loss": batch_loss, "Test accuracy": accuracy})

            if batch_idx % log_interval == 0:
                print('Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(eval_loader.dataset),
                    100. * batch_idx / len(eval_loader), batch_loss))

    epoch_loss = np.mean(epoch_loss_list)
    epoch_accuracy = np.mean(epoch_accuracy_list)
    epoch_log = {"Test epoch loss": epoch_loss, "Test epoch accuracy": epoch_accuracy}

    if true_ids_per_batch:
        # LabelEncoder numbers classes in sorted order, so the display names must be
        # sorted too; the appearance order from unique() would mislabel the axes.
        # NOTE(review): assumes every "batch" value in scdata also occurs in the
        # evaluation data — confirm.
        class_names = [str(name) for name in sorted(scdata.obs["batch"].unique())]
        epoch_log["conf_mat"] = wandb.plot.confusion_matrix(
            probs=None,
            y_true=np.concatenate(true_ids_per_batch),
            preds=np.concatenate(pred_ids_per_batch),
            class_names=class_names,
        )

    wandb.log(epoch_log)

    return epoch_loss, epoch_accuracy

In [None]:
# Main loop: one training pass followed by one evaluation pass per epoch (epochs are 1-based).
for epoch in range(1, epochs + 1):
    # Update the model on the training loader and report the epoch averages.
    train_epoch_loss, train_epoch_accuracy = train_one_epoch(epoch, GEX_dataloader_train, model, optimizer, criterion)
    print(f"Epoch training loss: {train_epoch_loss}, Epoch training accuracy: {train_epoch_accuracy}")
    # Measure performance on the held-out loader.
    test_epoch_loss, test_epoch_accuracy = evaluate_one_epoch(epoch, GEX_dataloader_test, model, optimizer, criterion)
    print(f"Epoch Eval loss: {test_epoch_loss}, Epoch Eval accuracy: {test_epoch_accuracy}")