In [2]:
#Capsnet model for human brain dataset
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import datasets, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

USE_CUDA = False

In [3]:
read_counts = pd.read_table('/Users/marong/Dropbox/CG_project/Capsnet/6k/data6k_norm_counts_all.txt')
cell_labels = pd.read_table('/Users/marong/Dropbox/CG_project/Capsnet/6k/data6k_celltypes.txt')
read_counts_2 = pd.read_table('/Users/marong/Dropbox/CG_project/Capsnet/8k/data8k_norm_counts_all.txt')
cell_labels_2 = pd.read_table('/Users/marong/Dropbox/CG_project/Capsnet/8k/data8k_celltypes.txt')

count_6k = read_counts.iloc[:,1:]
count_8k = read_counts_2.iloc[:,1:]

all_data = count_6k.append(count_8k, sort=False)
data_processed = all_data.loc[:,count_6k.columns]
data_processed = data_processed.fillna(0)
all_labels = list(cell_labels['Celltype']) + list (cell_labels_2['Celltype'])

x  = data_processed
y = all_labels 

  """Entry point for launching an IPython kernel.
  
  This is separate from the ipykernel package so we can avoid doing imports until
  after removing the cwd from sys.path.


In [4]:
# extract info and reorganize the data
x = torch.FloatTensor(x.values)
y_labels = set(y)
y_label_dict = dict(zip(y_labels, range(0, len(y_labels))))
y_index = [y_label_dict[i] for i in y]

np.random.seed(888)
rand_index = np.random.permutation(len(y_index))
x_data = x[torch.tensor(rand_index),:]
y_data = [y_index[i] for i in rand_index]

In [5]:
x_data.shape

torch.Size([12957, 13137])

In [7]:
X_train, X_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=42)
X_train = X_train.view(-1, 1,13137)
X_test = X_test.view(-1, 1,13137)
X_valid, X_test, y_valid, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=50)

In [8]:
print(X_train.shape)
print(len(y_train))
print(X_valid.shape)
print(len(y_valid))
print(X_test.shape)
print(len(y_test))

torch.Size([10365, 1, 13137])
10365
torch.Size([1296, 1, 13137])
1296
torch.Size([1296, 1, 13137])
1296


In [9]:
from torch.utils.data import DataLoader, TensorDataset

# create 
train_data = TensorDataset(X_train, torch.LongTensor(y_train))
valid_data = TensorDataset(X_valid, torch.LongTensor(y_valid))
test_data = TensorDataset(X_test, torch.LongTensor(y_test))

# dataloaders 
batch_size = 100
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size,drop_last=True)
valid_loader = DataLoader(valid_data, shuffle=True, batch_size=batch_size,drop_last =True)
test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size,drop_last = True)


# check one of the batches
dataiter = iter(train_loader)
sample_x, sample_y = dataiter.next()
print('Sample input size: ', sample_x.size()) 
print('Sample input: \n', sample_x)
print()
print('Sample label size: ', sample_y.size()) 
print('Sample label: \n', sample_y)

Sample input size:  torch.Size([100, 1, 13137])
Sample input: 
 tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.7181,  ..., 0.0000, 0.0000, 0.0000]],

        ...,

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]])

Sample label size:  torch.Size([100])
Sample label: 
 tensor([3, 5, 2, 0, 0, 0, 4, 4, 4, 5, 2, 4, 2, 0, 0, 5, 2, 5, 4, 4, 4, 0, 4, 0,
        4, 5, 2, 0, 5, 0, 4, 1, 2, 5, 4, 0, 4, 2, 0, 0, 5, 0, 5, 4, 2, 0, 0, 0,
        5, 2, 0, 4, 4, 0, 5, 5, 2, 0, 0, 4, 2, 0, 2, 2, 2, 0, 2, 4, 5, 4, 3, 0,
        4, 0, 0, 5, 2, 5, 2, 4, 5, 4, 5, 2, 2, 2, 4, 2, 2, 2, 5, 0, 2, 5, 5, 4,
        0, 4, 4, 5])


In [10]:
class PrimaryCaps(nn.Module):
    def __init__(self, num_capsules=8, in_feature = 13137, out_feature=256):
        super(PrimaryCaps, self).__init__()

        self.capsules = nn.ModuleList([
            nn.Linear(in_feature, out_feature)
            for _ in range(num_capsules)])

    def forward(self, x):
        u = [capsule(x) for capsule in self.capsules]
        u = torch.stack(u, dim=1)
        u = u.view(x.size(0), 256, -1)

        return self.squash(u)

    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor
    
# output 100 , 256 , 8

In [11]:
class DigitCaps(nn.Module):
    def __init__(self, num_capsules=8, num_routes=256, in_channels=8, out_channels=16):
        super(DigitCaps, self).__init__()

        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules

        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        batch_size = x.size(0)
        num_iterations = 3

        x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4)
        W = torch.cat([self.W] * batch_size, dim=0)
        u_hat = torch.matmul(W, x)
        b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))

        for iteration in range(num_iterations):
            c_ij = F.softmax(b_ij)
            c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)

            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = self.squash(s_j)
            
            if iteration < num_iterations - 1:
                a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1))
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)

        return v_j.squeeze(1)
    
    def squash(self, input_tensor):
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        output_tensor = squared_norm *  input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
        return output_tensor
    
    
# output torch.Size([100, 8, 16, 1])

In [12]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 8, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 13137),
            nn.Sigmoid()
        )
        
    def forward(self, x, data):
        classes = torch.sqrt((x ** 2).sum(2))
        classes = F.softmax(classes)
        
        _, max_length_indices = classes.max(dim=1)
        masked = Variable(torch.sparse.torch.eye(8))
        masked = masked.index_select(dim=0, index=max_length_indices.squeeze(1).data)
        reconstructions = self.reconstraction_layers((x * masked[:, :, None, None]).view(x.size(0), -1))
        reconstructions = reconstructions.view(-1, 1, 13137)
        
        return reconstructions, masked

In [13]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.primary_capsules = PrimaryCaps()
        self.digit_capsules = DigitCaps()
        self.decoder = Decoder()

        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        output = self.digit_capsules(
            self.primary_capsules(data))

        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked

    def loss(self, data, x, target, reconstructions):
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, x, labels, size_average=True):
        batch_size = x.size(0)
        v_c = torch.sqrt((x**2).sum(dim=2, keepdim=True))

        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)

        loss = labels * left + 0.5 * (1.0 - labels) * right
        loss = loss.sum(dim=1).mean()

        return loss

    def reconstruction_loss(self, data, reconstructions):
        loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1))
        return loss * 0.0005

In [14]:
capsule_net = CapsNet()
optimizer = Adam(capsule_net.parameters())

batch_size = 100
n_epochs = 30

In [15]:
# Main function
early_stop_time = 0
max_accuracy = -1
early_stop = False

for epoch in range(n_epochs):
    capsule_net.train()
    train_loss = 0
    for batch_id, (data, target) in enumerate(train_loader):
        
        target = torch.sparse.torch.eye(8).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)
        data = data.float()
        
        optimizer.zero_grad()
        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()

        train_loss += loss.data
        if batch_id %10 ==0: 
            print(batch_id)

        if batch_id % 100 == 0:
            print("train accuracy:", sum(np.argmax(masked.data.cpu().numpy(), 1) ==
                                         np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size))

    print(train_loss / len(train_loader))

    # Valid Dataset
    capsule_net.eval()
    valid_loss = 0
    for batch_id, (data, target) in enumerate(valid_loader):

        target = torch.sparse.torch.eye(8).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)
        data = data.float()
        if USE_CUDA:
            data, target = data.cuda(), target.cuda()

        output, reconstructions, masked = capsule_net(data)
        loss = capsule_net.loss(data, output, target, reconstructions)

        valid_loss += loss.data

        if batch_id % 100 == 0:
            print("Valid accuracy:", sum(np.argmax(masked.data.cpu().numpy(), 1) ==
                                        np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size))
            
            val_accu = sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size)
            if val_accu >= max_accuracy:
                max_accuracy = val_accu
                #early_stop_time = 0
            else:
                early_stop = True
                #early_stop_time += 1

    if early_stop == True:
        break

    print(valid_loss / len(valid_loader))

  app.launch_new_instance()


0
train accuracy: 0.2
10
20
30
40
50
60
70
80
90
100
train accuracy: 0.64
tensor(0.5791)
Valid accuracy: 0.77
tensor(0.3824)
0
train accuracy: 0.72
10
20
30
40
50
60
70
80
90
100
train accuracy: 0.85
tensor(0.2910)
Valid accuracy: 0.86
tensor(0.2347)
0
train accuracy: 0.93
10
20
30
40
50
60
70
80
90
100
train accuracy: 0.91
tensor(0.1442)
Valid accuracy: 0.89
tensor(0.1720)
0
train accuracy: 0.93
10
20
30
40
50
60
70
80
90
100
train accuracy: 0.96
tensor(0.0725)
Valid accuracy: 0.95
tensor(0.1446)
0
train accuracy: 1.0
10
20
30
40
50
60
70
80
90
100
train accuracy: 0.96
tensor(0.0392)
Valid accuracy: 0.94


In [16]:
torch.save(capsule_net.state_dict(), '/Users/marong/Dropbox/CG_project/Capsnet/capsnet_fc_PBMC.pth')

In [17]:
# Testing the model
capsule_net.eval()
pred_results = []
test_loss = 0
accuracy = []
y_pred, y_true = [], []
for batch_id, (data, target) in enumerate(test_loader):
    target = torch.sparse.torch.eye(8).index_select(dim=0, index=target)
    data, target = Variable(data), Variable(target)
    data = data.float()
    output, reconstructions, masked = capsule_net(data)
    pred_results.append((output,reconstructions,masked))
    loss = capsule_net.loss(data, output, target, reconstructions)

    test_loss += loss.data
    yp= np.argmax(masked.data.cpu().numpy(), 1)
    yt = np.argmax(target.data.cpu().numpy(), 1)
    y_pred = y_pred + list(yp)
    y_true = y_true + list(yt)

    
    acc =sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1)) / float(batch_size)
    
    print("Test accuracy:", acc)
    accuracy.append(acc)

  app.launch_new_instance()


Test accuracy: 0.91
Test accuracy: 0.9
Test accuracy: 0.85
Test accuracy: 0.92
Test accuracy: 0.9
Test accuracy: 0.91
Test accuracy: 0.92
Test accuracy: 0.95
Test accuracy: 0.89
Test accuracy: 0.92
Test accuracy: 0.88
Test accuracy: 0.93


In [18]:
from sklearn.metrics import precision_recall_fscore_support
print(np.mean(accuracy))
print(precision_recall_fscore_support(y_true, y_pred, average='weighted'))

0.9066666666666667
(0.9107267417696119, 0.9066666666666666, 0.9075402269082887, None)


  _warn_prf(average, modifier, msg_start, len(result))
