In [43]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

In [44]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [45]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import copy
import gc

#ScDeepSort Imports
from dance.modules.single_modality.cell_type_annotation.scdeepsort import ScDeepSort
from dance.utils import set_seed

import os
os.environ["DGLBACKEND"] = "pytorch"
from pprint import pprint
from dance.datasets.singlemodality import ScDeepSortDataset

import scanpy as sc
from dance.transforms import AnnDataTransform, FilterGenesPercentile
from dance.transforms import Compose, SetConfig
from dance.transforms.graph import PCACellFeatureGraph, CellFeatureGraph
from dance.typing import LogLevel, Optional

from data_pre import data_pre
from WordSage import WordSAGE

In [46]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    x = torch.ones(1, device=device)
    print (x)
elif torch.cuda.is_available():
    device = torch.device("cuda")
    x = torch.ones(1, device=device)
    print(x)
else:
    print ("GPU not found.")
    device = torch.device('cpu')

tensor([1.], device='cuda:0')


In [47]:
class BetaVAE(nn.Module):
    def __init__(self, z_dim=8):
        super(BetaVAE, self).__init__()
        self.z_dim = z_dim
        self.encoder = nn.Sequential(
            nn.Linear(100, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, z_dim*2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(z_dim, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 21)
        )

    def forward(self, x):
        distributions = self.encoder(x)
        x_map = F.softmax(distributions[:, :self.z_dim])
        #std = F.softplus(distributions[:, self.z_dim:])
        y_map = F.softmax(distributions[:, self.z_dim:])
        #z = self.reparametrize(mu, std)
        z = z = (x_map + y_map) / 2
        logit = self.decoder(z)
        return logit, x_map, y_map

In [48]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0.01)

In [49]:
num_clients = 3
model1 = BetaVAE().to(device)
model2 = BetaVAE().to(device)
model3 = BetaVAE().to(device)
models = BetaVAE().to(device)

In [50]:
model1.apply(init_weights)
model2.apply(init_weights)
model3.apply(init_weights)
models.apply(init_weights)
model_e = models.encoder
model_g = models.decoder

In [51]:
data = data_pre()
in_channels = 100
hidden_channels = 100
out_channels = 100
num_classes = 21
WordSage = WordSAGE(in_channels, hidden_channels, out_channels, num_classes)
inputs, targets, genes, normalized_raw_data, test, y_test = data.read_w2v()
seed=42
set_seed(42)
encoding = np.hstack([targets, y_test])
label_encoder = LabelEncoder().fit(encoding)
targets_encoded = label_encoder.transform(targets)
targets_encoded = torch.tensor(targets_encoded, dtype=torch.long).to(device)
num_classes = max(targets_encoded)+1
targets_encoded = F.one_hot(targets_encoded, num_classes=num_classes)
test_encoded = label_encoder.transform(y_test)
test_encoded = torch.tensor(test_encoded, dtype=torch.long).to(device)
test_encoded = F.one_hot(test_encoded, num_classes=num_classes)
train_inputs, train_targets = WordSAGE.mix_data(self='', seed=seed, inputs=inputs, targets=targets_encoded)
test_inputs, test_targets = WordSAGE.mix_data(self='', seed=seed, inputs=test, targets=test_encoded)
train_inputs = torch.tensor(train_inputs, dtype=torch.float32).to(device)
test_inputs = torch.tensor(test_inputs, dtype=torch.float32).to(device)

[INFO][2023-10-05 14:04:49,072][dance][set_seed] Setting global random seed to 42
  np.random.shuffle(targets)


In [52]:
X1 ,X2, X3 = np.array_split(train_inputs, num_clients)
y1, y2, y3 = np.array_split(train_targets, num_clients)

In [53]:
d1 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(X1), torch.tensor(y1).long()), batch_size=32, shuffle=True)
d2 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(X2), torch.tensor(y2).long()), batch_size=32, shuffle=True)
d3 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(X3), torch.tensor(y3).long()), batch_size=32, shuffle=True)

  d1 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(X1), torch.tensor(y1).long()), batch_size=32, shuffle=True)
  d2 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(X2), torch.tensor(y2).long()), batch_size=32, shuffle=True)
  d3 = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(X3), torch.tensor(y3).long()), batch_size=32, shuffle=True)


In [54]:
o1 = torch.optim.SGD(model1.parameters(), lr=1e-3, momentum=0.9)
o2 = torch.optim.SGD(model2.parameters(), lr=1e-3, momentum=0.9)
o3 = torch.optim.SGD(model3.parameters(), lr=1e-3, momentum=0.9)

In [55]:
def train(epoch, model, optim, data, verbose, alpha=0.1, beta=1e-3, lambda_=0.5):
    cl_loss, acc = 0, 0
    for inputs, targets in data:
        logits, mu, std = model(inputs)

        target_labels = torch.argmax(targets, dim=1)
        class_loss = F.cross_entropy(logits, target_labels).div(math.log(2))

        probs = F.softmax(logits, dim=1)

        prediction = probs.max(1)[1]
        accuracy = torch.eq(prediction, target_labels).float().mean()

        cl_loss += class_loss.item()
        acc += accuracy.item()
    else:
        if verbose:
            cl_loss /= len(data)
            acc /= len(data)
            print(f'Epoch [{str(epoch).zfill(3)}], Class loss:{cl_loss:.4f}, Acc. {acc * 100:.2f}%')

In [56]:
for e in range(300):
    train(e+1, model1, o1, d1, (e+1) % 100 == 0)

  x_map = F.softmax(distributions[:, :self.z_dim])
  y_map = F.softmax(distributions[:, self.z_dim:])


Epoch [100], Class loss:4.8190, Acc. 0.06%
Epoch [200], Class loss:4.8220, Acc. 0.06%
Epoch [300], Class loss:4.8200, Acc. 0.06%


In [57]:
for e in range(300):
    train(e+1, model2, o2, d2, (e+1) % 100 == 0)

  x_map = F.softmax(distributions[:, :self.z_dim])
  y_map = F.softmax(distributions[:, self.z_dim:])


Epoch [100], Class loss:4.4486, Acc. 1.87%
Epoch [200], Class loss:4.4479, Acc. 1.89%
Epoch [300], Class loss:4.4480, Acc. 1.85%


In [58]:
for e in range(300):
    train(e+1, model3, o3, d3, (e+1) % 100 == 0)

  x_map = F.softmax(distributions[:, :self.z_dim])
  y_map = F.softmax(distributions[:, self.z_dim:])


Epoch [100], Class loss:4.5734, Acc. 0.45%
Epoch [200], Class loss:4.5737, Acc. 0.45%
Epoch [300], Class loss:4.5732, Acc. 0.45%


In [59]:
tl = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(test_inputs), torch.tensor(test_targets).long()), batch_size=128)

  tl = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.tensor(test_inputs), torch.tensor(test_targets).long()), batch_size=128)


In [60]:
@torch.no_grad()
def test(model, data):
    acc = 0
    for inputs, targets in data:
        logits = model(inputs.float())[0]

        probs = F.softmax(logits, dim=1)
        prediction = probs.max(1)[1]
        target_labels = torch.argmax(targets, dim=1)
        accuracy = torch.eq(prediction, target_labels).float().mean()
        acc += accuracy.item()
    else:
        acc /= len(data)
        print(f'Acc. {acc * 100:.2f}%')

In [61]:
test(model1, tl)
test(model2, tl)
test(model3, tl)

Acc. 86.94%
Acc. 0.00%
Acc. 0.00%


  x_map = F.softmax(distributions[:, :self.z_dim])
  y_map = F.softmax(distributions[:, self.z_dim:])
