In [129]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split, SubsetRandomSampler
from sklearn.model_selection import train_test_split

import os
import random
import pandas as pd
import numpy as np
from scipy.sparse import coo_matrix

# Seed every random number generator the notebook can touch so that
# Restart Kernel -> Run All reproduces the same results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)  # NumPy's global RNG was previously left unseeded
torch.manual_seed(SEED)
# Only touch the MPS generator when the backend is actually usable.
if torch.backends.mps.is_available():
    torch.mps.manual_seed(SEED)

In [130]:
def read_mtx(path, index):
    """Load a contact table and return it as a dense square matrix.

    Parameters
    ----------
    path : str
        Tab-separated file with a header row containing at least the columns
        ``chrom1``, ``start1``, ``chrom2``, ``start2`` and ``count``.
    index : pandas.DataFrame
        Lookup table with a ``loc`` column holding ``"<chrom>_<start>"`` keys
        and at least one further column; the first non-``loc`` column is used
        as the integer matrix position of each locus.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(index), len(index))``. Entry ``[i, j]`` is the
        sum of ``count`` over all rows whose first locus maps to ``i`` and
        second locus maps to ``j`` (duplicate pairs are summed by the sparse
        to dense conversion). Rows involving ``chrY`` are discarded.

    Raises
    ------
    ValueError
        If ``index`` has no position column, or if a locus in the contact
        table is absent from ``index`` (previously this surfaced as an obscure
        failure inside ``coo_matrix`` because the left join produced NaN).
    """
    # Pick the position column by name instead of by its position after a
    # merge, so extra columns in `index` cannot shift which column is used.
    position_cols = [col for col in index.columns if col != "loc"]
    if not position_cols:
        raise ValueError("index must contain a position column besides 'loc'")
    position_col = position_cols[0]
    lookup = index[["loc", position_col]]

    contacts = pd.read_csv(path, sep="\t", header=0)
    keep = (contacts["chrom1"] != "chrY") & (contacts["chrom2"] != "chrY")
    contacts = contacts[keep]

    pairs = pd.DataFrame({
        "loc1": contacts["chrom1"] + "_" + contacts["start1"].astype(str),
        "loc2": contacts["chrom2"] + "_" + contacts["start2"].astype(str),
        "count": contacts["count"],
    })
    pairs = pairs.merge(
        lookup.rename(columns={"loc": "loc1", position_col: "index1"}),
        on="loc1", how="left",
    )
    pairs = pairs.merge(
        lookup.rename(columns={"loc": "loc2", position_col: "index2"}),
        on="loc2", how="left",
    )

    # A left join leaves NaN for loci that are not in `index`; fail loudly
    # with the offending loci rather than inside scipy.
    missing_rows = pairs["index1"].isna()
    missing_cols = pairs["index2"].isna()
    if missing_rows.any() or missing_cols.any():
        missing = sorted(
            set(pairs.loc[missing_rows, "loc1"]) | set(pairs.loc[missing_cols, "loc2"])
        )
        preview = ", ".join(missing[:5])
        raise ValueError(
            f"{len(missing)} loci in {path!r} are not present in index: {preview}"
        )

    n_bins = index.shape[0]
    rows = pairs["index1"].to_numpy().astype(np.int64)
    cols = pairs["index2"].to_numpy().astype(np.int64)
    # NOTE(review): the matrix is filled exactly as listed in the file and is
    # not symmetrised; confirm whether the input stores both triangles.
    sparse = coo_matrix((pairs["count"].to_numpy(), (rows, cols)), shape=(n_bins, n_bins))
    return sparse.toarray()

In [143]:
# --- Window configuration ---------------------------------------------------
TARGET_LOC = "chr7_54000000"  # locus the extracted windows are centred on
N_ROWS = 5                    # rows kept around the target bin
N_COLS = 45                   # columns kept around the target bin (feature 1 only)
DATA_DIR = "data"             # one sub-directory per sample, each holding matrix.mtx

# Use Apple's MPS backend when it is usable and fall back to the CPU otherwise,
# so the cell also runs on machines without MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

index = pd.read_csv("index.csv", header=0)
loc_hits = np.where(index["loc"] == TARGET_LOC)[0]
if loc_hits.size == 0:
    raise ValueError(f"{TARGET_LOC!r} not found in the 'loc' column of index.csv")
pos_coord = loc_hits[0]

# Half-widths: the slices below keep (2 * w - 1) bins centred on pos_coord.
w0 = (N_ROWS + 1) // 2
w1 = (N_COLS + 1) // 2

# A window running past either edge would silently yield a wrong-sized slice
# (negative starts wrap around in NumPy), so reject it up front.
half_span = max(w0, w1)
if pos_coord - half_span + 1 < 0 or pos_coord + half_span > index.shape[0]:
    raise ValueError(
        f"window around {TARGET_LOC!r} (row {pos_coord}) does not fit in "
        f"{index.shape[0]} bins"
    )

matrices_f1 = []
matrices_f2 = []

directories = sorted(d for d in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, d)))
if not directories:
    raise FileNotFoundError(f"no sample directories found under {DATA_DIR!r}")

for directory in directories:
    mtx = read_mtx(os.path.join(DATA_DIR, directory, "matrix.mtx"), index)

    # Feature 2: the N_ROWS rows around the target bin, all columns.
    mtx_f2 = mtx[(pos_coord - w0 + 1):(pos_coord + w0), :]
    # Feature 1: the same rows restricted to N_COLS columns around the target bin.
    mtx_f1 = mtx_f2[:, (pos_coord - w1 + 1):(pos_coord + w1)]

    matrices_f1.append(torch.from_numpy(mtx_f1))
    matrices_f2.append(torch.from_numpy(mtx_f2))

# Feature 1 is log10(count + 1); feature 2 is a 0/1 mask of non-zero contacts.
input_f1 = torch.log10(torch.stack(matrices_f1, dim=0).to(device) + 1)
input_f2 = (torch.stack(matrices_f2, dim=0).to(device) > 0).int()

# One label per sample, all ones. Built as float32 on the same device as the
# inputs: the previous float64 CPU tensor could not be moved to MPS.
label = torch.ones(input_f1.shape[0], dtype=torch.float32, device=device)
# TODO: presumably zero labels (np.zeros(input_f1.shape[0])) are to be
# concatenated here once a second class of samples is loaded; confirm.

In [144]:
class TripleDataset(Dataset):
    """Train or validation view over three tensors that share their first dimension.

    The sample indices are shuffled and partitioned with ``random_split``
    using a dedicated generator seeded with ``seed``. Two instances built with
    the same tensors, ``val_split`` and ``seed`` therefore see the same
    partition, so the ``train=True`` and ``train=False`` views are disjoint
    and together cover every sample. (Previously each instance drew its own
    split from the global RNG, so samples could appear in both views.)

    Parameters
    ----------
    tensor1, tensor2, tensor3 : torch.Tensor
        Tensors with the same length along dimension 0.
    train : bool, default True
        Select the training part (True) or the validation part (False).
    val_split : float, default 0.1
        Fraction of samples assigned to validation; the validation size is
        ``int(val_split * n_samples)``.
    seed : int, default 42
        Seed of the generator that drives the split.

    Raises
    ------
    ValueError
        If the tensors differ in length or ``val_split`` is outside [0, 1].
    """

    def __init__(self, tensor1, tensor2, tensor3, train=True, val_split=0.1, seed=42):
        lengths = {len(tensor1), len(tensor2), len(tensor3)}
        if len(lengths) != 1:
            raise ValueError(f"tensors must have the same length, got {sorted(lengths)}")
        if not 0.0 <= val_split <= 1.0:
            raise ValueError(f"val_split must be within [0, 1], got {val_split}")

        self.dataset1 = TensorDataset(tensor1)
        self.dataset2 = TensorDataset(tensor2)
        self.dataset3 = TensorDataset(tensor3)

        self.train = train

        dataset_size = len(self.dataset1)
        val_size = int(val_split * dataset_size)
        train_size = dataset_size - val_size

        # A private, seeded generator makes the partition a pure function of
        # (dataset_size, val_split, seed), independent of the global RNG state.
        generator = torch.Generator().manual_seed(seed)
        indices = list(range(dataset_size))
        train_indices, val_indices = random_split(
            indices, [train_size, val_size], generator=generator
        )

        self.indices = train_indices if train else val_indices

    def __len__(self):
        """Number of samples in the selected (train or validation) part."""
        return len(self.indices)

    def __getitem__(self, index):
        """Return the three samples at position ``index`` of the selected part.

        Each element is the 1-tuple produced by the wrapped ``TensorDataset``,
        so the actual tensor of e.g. the third element is ``result[2][0]``.
        """
        dataset_index = self.indices[index]
        sample1 = self.dataset1[dataset_index]
        sample2 = self.dataset2[dataset_index]
        sample3 = self.dataset3[dataset_index]
        return sample1, sample2, sample3

# Training and validation views over the same three tensors, built with the
# same validation fraction (training view first, then validation view).
split_options = dict(val_split=0.1)
train_dual_dataset = TripleDataset(input_f1, input_f2, label, train=True, **split_options)
val_dual_dataset = TripleDataset(input_f1, input_f2, label, train=False, **split_options)

# Training batches are reshuffled every epoch; the validation set is served
# as one single batch in a fixed order.
batch_size = 32
train_dual_dataloader = DataLoader(
    dataset=train_dual_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_dual_dataloader = DataLoader(
    dataset=val_dual_dataset,
    batch_size=len(val_dual_dataset),
    shuffle=False,
)

In [145]:
# Sanity check: print the shape of the label tensor of every validation batch.
# Each element of a batch is a 1-element sequence, hence the [0].
for val_batch in val_dual_dataloader:
    samples1, samples2, samples3 = val_batch
    print(samples3[0].shape)

torch.Size([4])


In [5]:
class CNN(nn.Module):
    """Two-branch convolutional network with a single shared output layer.

    Branch 1 (``x1``): conv -> ReLU -> max-pool -> conv -> ReLU -> flatten ->
    linear(288 -> 64) -> ReLU -> dropout.
    Branch 2 (``x2``): conv -> ReLU -> max-pool -> conv -> ReLU -> max-pool ->
    flatten -> linear(608 -> 128) -> ReLU -> dropout.
    The two branch outputs and a normalised sum of the raw ``x2`` input are
    concatenated along dim 1 and passed through ``fc_l1`` and a sigmoid.

    NOTE(review): several sizes look mutually inconsistent (see the inline
    notes in ``__init__`` and ``forward``); the class appears to be work in
    progress and should be shape-checked against the real inputs.
    """
    def __init__(self):
        super(CNN, self).__init__()
        
        # Branch 1 convolutions: 1 -> 8 -> 16 channels.
        self.conv_f1_l1 = nn.Conv2d(1, 8, kernel_size=(5,4), stride=1, padding=0)
        self.conv_f1_l2 = nn.Conv2d(8, 16, kernel_size=(1,4), stride=1, padding=0)
        
        # Branch 2 convolutions: 1 -> 4 -> 8 channels, 45-wide kernels.
        self.conv_f2_l1 = nn.Conv2d(1, 4, kernel_size=(5,45), stride=1, padding=0)
        self.conv_f2_l2 = nn.Conv2d(4, 8, kernel_size=(1,45), stride=1, padding=0)
        
        # Square pooling windows: they shrink the height as well as the width.
        # NOTE(review): the first conv of each branch already removes 4 rows
        # (kernel height 5, no padding); with the 5-row windows built in the
        # data cell the feature map would be 1 row high, too small for a
        # 2x2 / 6x6 pool. Presumably (1, k) pooling was intended; TODO confirm.
        self.pool_f1_l1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.pool_f2_l1 = nn.MaxPool2d(kernel_size=6, stride=6, padding=0)
        self.pool_f2_l2 = nn.MaxPool2d(kernel_size=6, stride=6, padding=0)
        
        # Per-branch fully connected layers; in_features must equal the
        # flattened sizes used by the .view() calls in forward().
        self.fc_f1_l1 = nn.Linear(18 * 16, 64)
        self.fc_f2_l1 = nn.Linear(76 * 8, 128)
        # NOTE(review): forward() concatenates x1 (64 features), x2 (128
        # features) and x2_s along dim 1, i.e. at least 192 features, but this
        # layer accepts only 128; looks like a size mismatch, confirm.
        # NOTE(review): 10 outputs; the data cell builds one label per sample.
        # Confirm the intended output width.
        self.fc_l1 = nn.Linear(128, 10)
        
        self.dropout_f1 = nn.Dropout(0.5)
        self.dropout_f2 = nn.Dropout(0.5)

    def forward(self, x1, x2):
        """Run both branches and return the sigmoid of the shared output layer.

        Parameters
        ----------
        x1, x2 : torch.Tensor
            Inputs of branch 1 and branch 2. Both first go through a Conv2d
            with one input channel; presumably shaped (batch, 1, H, W),
            TODO confirm against the data cell, which stacks tensors without
            a channel dimension.
        """
        
        # Sum the raw branch-2 input over dim 1, then normalise with
        # F.normalize's default arguments.
        # NOTE(review): if x2 is (batch, 1, H, W) this result is 3-D and
        # cannot be concatenated with the 2-D branch outputs below; confirm
        # the intended reduction axis.
        x2_s = F.normalize(torch.sum(x2, dim=1))
        
        x1 = F.relu(self.conv_f1_l1(x1))
        x1 = self.pool_f1_l1(x1)
        x1 = F.relu(self.conv_f1_l2(x1))
        
        x2 = F.relu(self.conv_f2_l1(x2))
        x2 = self.pool_f2_l1(x2)
        x2 = F.relu(self.conv_f2_l2(x2))
        x2 = self.pool_f2_l2(x2)

        # Flatten to rows of 288 / 608 values; the leading size is inferred,
        # so a feature map of a different size silently changes the row count
        # instead of raising.
        x1 = x1.view(-1, 18 * 16)
        x2 = x2.view(-1, 76 * 8)

        x1 = F.relu(self.fc_f1_l1(x1))
        x2 = F.relu(self.fc_f2_l1(x2))
        
        x1 = self.dropout_f1(x1)
        x2 = self.dropout_f2(x2)
        
        # Concatenate branch outputs and the normalised sum feature-wise.
        x = torch.cat((x1, x2, x2_s), dim=1)
        
        x = F.sigmoid(self.fc_l1(x))

        return x

In [None]:
# Hold out 10% of the training indices for validation; random_state is fixed
# so the split is repeatable.
# NOTE(review): `mnist_train` / `mnist_test` are not defined in any cell
# visible here; this cell looks like template code, confirm before running.
all_indices = list(range(len(mnist_train)))
train_indices, val_indices = train_test_split(all_indices, test_size=0.1, random_state=42)

# Each sampler draws, in random order, only from its own index subset.
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Training and validation loaders read the same dataset through different
# samplers; the test loader iterates its dataset in order.
train_loader = DataLoader(dataset=mnist_train, batch_size=64, sampler=train_sampler)
val_loader = DataLoader(dataset=mnist_train, batch_size=64, sampler=val_sampler)
test_loader = DataLoader(dataset=mnist_test, batch_size=64, shuffle=False)

# Model, loss and optimiser used by the training loop.
# NOTE(review): CNN.forward ends with a sigmoid, whereas CrossEntropyLoss is
# documented to expect unnormalised logits; confirm the intended pairing.
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training loop: one optimisation pass and one validation pass per epoch,
# with the metrics reported at the end of every epoch.
# NOTE(review): the model is called with a single argument here, while
# CNN.forward takes (x1, x2); confirm which loaders/model this loop targets.
epochs = 5
for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    # Keep the training loss separately: the validation pass must not
    # overwrite the value reported as "Loss".
    average_train_loss = train_loss / max(len(train_loader), 1)

    # Validation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()

            # Count with the actual batch size; the last batch may be smaller.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # These statements used to sit outside the epoch loop, so only the last
    # epoch was ever reported. The guards avoid dividing by zero when a
    # loader is empty.
    average_val_loss = val_loss / max(len(val_loader), 1)
    accuracy = correct / total if total else 0.0

    print(f'Epoch {epoch + 1}/{epochs}, Loss: {average_train_loss:.4f}, Validation Loss: {average_val_loss:.4f}, Accuracy: {accuracy * 100:.2f}%')

# Final evaluation: accuracy over the whole test loader, with gradients
# disabled and the model in evaluation mode.
model.eval()
test_correct, test_total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        # The predicted class is the index of the largest output per sample.
        _, predicted = torch.max(outputs.data, dim=1)
        test_total += labels.shape[0]
        test_correct += (predicted == labels).sum().item()

test_accuracy = test_correct / test_total
print(f'Test Accuracy: {test_accuracy * 100:.2f}%')