In [4]:
import math
import warnings

import hexagdly
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, TensorDataset

In [None]:
def _load_sample(data_path, num):
    """Read one sample's files and return ``(adata, obs)``.

    ``adata`` is the transposed expression matrix with the sample metadata as
    ``obs``. ``obs`` is that metadata inner-joined with the spot grid
    coordinates (``row``, ``col``) on ``barcode``. It carries an extra
    ``_matrix_pos`` column holding each spot's row index in ``adata.X``.
    """
    prefix = f'{data_path}/{num}/ck17_{num}'
    cells = pd.read_csv(f'{prefix}_cell_barcodes.txt')
    gene = pd.read_csv(f'{prefix}_gene_names.txt', dtype={
        'no': 'int64',
        'gene name': 'string'
    })
    meta = pd.read_csv(f'{prefix}_metadata.txt')
    if 'barcode' not in meta.columns:
        # Presumably the barcodes were written as an unnamed leading column.
        meta = meta.rename(columns={"Unnamed: 0": "barcode"})
    adata = sc.read_mtx(f'{prefix}_gex_data.txt').T
    # NOTE(review): `names` lists 5 columns. If the file has one extra leading
    # (barcode) column, pandas uses it as the index, which the next line relies
    # on. TODO confirm against the file format.
    position = pd.read_csv(f'{prefix}_tissue_positions_list.csv',
                           names=['in_tissue', 'row', 'col', 'pixel_row', 'pixel_col'])
    position['barcode'] = position.index
    adata.obs.index = cells['x']  # obs is replaced by `meta` a few lines below
    adata.var.index = gene['x']
    # NOTE(review): assumes the metadata rows are in the same order as the
    # matrix spots (cells['x']). Nothing here checks that. TODO verify.
    meta.index = meta.iloc[:, 0]
    adata.obs = meta
    # Rename the index, presumably so 'barcode' is not both an index name and a
    # column label when merging on it below.
    adata.obs.index.name = 'idx'
    # Record each spot's row in adata.X *before* merging. merge() discards the
    # index, and the inner join can drop spots, so a spot's position in the
    # merged frame is not its matrix row.
    spots = adata.obs.assign(_matrix_pos=np.arange(adata.n_obs))
    obs = spots.merge(position[['row', 'col', 'barcode']], on='barcode', how='inner')
    return adata, obs


def _top_k_detected_genes(gex, k):
    """Return column indices of the ``k`` genes with a value > 0 in the most rows.

    ``gex`` is a 2-D array (spots x genes). The order of the returned indices
    is unspecified (``np.argpartition``). Raises ValueError if ``k`` exceeds the
    number of genes.
    """
    detected = (gex > 0).sum(axis=0)  # 1-D: number of spots expressing each gene
    n_genes = detected.shape[0]
    if k > n_genes:
        raise ValueError(f'k={k} exceeds the number of genes ({n_genes})')
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    # kth = n_genes - k puts exactly the k largest values at positions [-k:].
    return np.argpartition(detected, n_genes - k)[-k:]


def data_processing(data_path, file_nums, k):
    """Build per-sample gene-expression images on the spot grid.

    For each sample id ``num`` in ``file_nums``, the function:

    * reads the ``{data_path}/{num}/ck17_{num}_*`` files;
    * selects the ``k`` genes detected (value > 0) in the most spots of that
      sample;
    * writes their expression into a ``(k, 64, 78)`` grid, where spot
      ``(row, col)`` goes to cell ``[col // 2, row]``.

    Args:
        data_path: directory containing one sub-directory per sample id.
        file_nums: sample ids; also used inside the file names.
        k: number of genes (channels) kept per sample.

    Returns:
        ``(features, labels)``. ``features`` is a float64 array of shape
        ``(len(file_nums), k, 64, 78)``; grid cells without a spot stay 0.
        ``labels`` is a list with 0 where the sample's ``ici_response`` is
        'NR' and 1 otherwise.

    Raises:
        ValueError: if a sample's spots carry more than one ``ici_response``
        value, or if ``k`` exceeds the number of genes.
    """
    labels = []
    features = np.zeros((len(file_nums), k, 64, 78))
    for i, num in enumerate(file_nums):
        print(f'processing no. {num}')
        adata, obs = _load_sample(data_path, num)

        # Label: 0 for non-responder ('NR'), 1 otherwise. .item() raises unless
        # there is exactly one distinct response value.
        labels.append(0 if obs['ici_response'].unique().item() == 'NR' else 1)

        gex_total = np.asarray(adata.X.todense())  # densify once, reuse below
        # NOTE(review): genes are picked per sample, so channel j need not be
        # the same gene across samples.
        genes = _top_k_detected_genes(gex_total, k)

        # Grid -> matrix mapping: the grid row becomes the matrix column, and
        # the grid column is halved (presumably because each grid row only
        # occupies every other column; TODO confirm).
        mat_rows = obs['col'].to_numpy(dtype=np.int64) // 2
        mat_cols = obs['row'].to_numpy(dtype=np.int64)
        spot_pos = obs['_matrix_pos'].to_numpy()
        features[i][:, mat_rows, mat_cols] = gex_total[np.ix_(spot_pos, genes)].T

    return features, labels

In [6]:
def to_dataloader(features, labels, batch_size=2, n_train=7):
    """Split samples into train/validation sets and wrap each in a DataLoader.

    The first ``n_train`` samples (in the given order) form the training set.
    The remaining samples form the validation set.

    Args:
        features: array-like of per-sample inputs; axis 0 indexes samples.
        labels: sequence of per-sample scalar labels, same length as features.
        batch_size: mini-batch size for both loaders.
        n_train: number of leading samples used for training. If this would
            leave nothing for validation and there are at least two samples,
            it is reduced so that the last sample is held out, and a
            UserWarning is emitted.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and the
        validation loader does not. Labels are float32 with shape ``(n, 1)``.

    Raises:
        ValueError: if features and labels differ in length, or n_train < 0.
    """
    data = torch.tensor(np.asarray(features), dtype=torch.float32)
    targets = torch.tensor(np.asarray(labels), dtype=torch.float32).unsqueeze(1)
    n_samples = len(data)
    if len(targets) != n_samples:
        raise ValueError(f'features has {n_samples} samples but labels has {len(targets)}')
    if n_train < 0:
        raise ValueError(f'n_train must be >= 0, got {n_train}')
    if n_train >= n_samples > 1:
        # Otherwise the validation split would be empty.
        warnings.warn(
            f'n_train={n_train} leaves no validation samples out of {n_samples}; '
            'holding out the last sample instead.', stacklevel=2)
        n_train = n_samples - 1

    train_dataset = TensorDataset(data[:n_train], targets[:n_train])
    val_dataset = TensorDataset(data[n_train:], targets[n_train:])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Validate on the held-out split (previously this wrapped train_dataset).
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

In [7]:
class hex_model(nn.Module):
    """Hexagonal-grid CNN that maps an image to a single raw score per sample.

    Architecture:
      * two hexagdly stages, each conv -> BatchNorm -> ReLU -> hex max-pool;
      * dropout;
      * an MLP head (flattened -> 512 -> 100 -> 1), with ReLU and dropout
        between the linear layers.

    Args:
        nin: number of input channels.

    NOTE(review): fc1 expects 64 * 32 * 39 flattened features, i.e. the second
    pooling stage must emit 64 x 32 x 39 per sample. This is presumably tuned
    to 64 x 78 inputs with these hexagdly settings; TODO confirm. The notebook's
    recorded RuntimeError (size mismatch at tensor dim 3) also still needs
    tracing; it presumably originates inside a hexagdly layer.
    """

    def __init__(self, nin):
        super().__init__()
        self.name = 'hex_model'
        # Modules are created in the original order so that parameter
        # initialisation consumes the RNG identically.
        self.hexconv1 = hexagdly.Conv2d(in_channels=nin, out_channels=128,
                                        kernel_size=2, stride=1, bias=True)
        self.hexpool1 = hexagdly.MaxPool2d(kernel_size=2, stride=2)
        self.hexconv2 = hexagdly.Conv2d(128, 64, 2, 1, bias=True)
        self.hexpool2 = hexagdly.MaxPool2d(kernel_size=2, stride=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(64)

        # A single p=0.5 dropout module, reused after the conv stack and after fc1/fc2.
        self.dropout = nn.Dropout(0.5)

        self.fc1 = nn.Linear(64 * 32 * 39, 512)
        self.fc2 = nn.Linear(512, 100)
        self.fc3 = nn.Linear(100, 1)

    def forward(self, x):
        """Return unactivated scores of shape (batch, 1)."""
        stages = (
            (self.hexconv1, self.bn1, self.hexpool1),
            (self.hexconv2, self.bn2, self.hexpool2),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))

        x = self.dropout(x)
        flat = x.view(x.size(0), -1)

        hidden = self.dropout(F.relu(self.fc1(flat)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

In [30]:
# Configuration for loading: data location, sample ids, genes kept per sample
# (these become the CNN input channels), and mini-batch size.
DATA_DIR = "Visium"
SAMPLE_IDS = [5, 7, 12, 19, 208, 209, 1294]
BATCH_SIZE = 2
input_channel = 1000

features, labels = data_processing(DATA_DIR, SAMPLE_IDS, input_channel)
train_loader, val_loader = to_dataloader(features, labels, batch_size=BATCH_SIZE)

processing no. 5
processing no. 7
processing no. 12
processing no. 19
processing no. 208
processing no. 209
processing no. 1294


In [33]:
# Train the hex CNN and record per-epoch train/validation loss for the plot below.
torch.manual_seed(0)  # reproducible weight init, dropout masks and shuffling
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = hex_model(input_channel).to(device)
# NOTE(review): targets are 0/1 labels and fc3 returns an unbounded score, so
# MSE treats this as regression. nn.BCEWithLogitsLoss is the usual loss for a
# binary target; consider switching.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)
epochs = 20
train_losses, val_losses = [], []
for epoch in range(epochs):
    # training period
    model.train()
    running_loss, n_train_seen = 0.0, 0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch doesn't skew the epoch mean.
        running_loss += loss.item() * inputs.size(0)
        n_train_seen += inputs.size(0)
    train_losses.append(running_loss / n_train_seen if n_train_seen else float('nan'))

    # evaluation period
    model.eval()
    val_loss, n_val_seen = 0.0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)
            n_val_seen += inputs.size(0)
    # Record NaN instead of dividing by zero when the validation split is empty.
    val_losses.append(val_loss / n_val_seen if n_val_seen else float('nan'))

    print(f'Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Validation Loss: {val_losses[-1]:.4f}')

RuntimeError: The size of tensor a (10) must match the size of tensor b (9) at non-singleton dimension 3

In [None]:
# Per-epoch training vs. validation loss recorded by the training cell.
# Epochs are numbered from 1 to match the printed training log.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Training vs. validation loss')
ax.legend()
plt.show()