In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd

In [10]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


# Define the linear regression model
class LinearRegression(nn.Module):
    """A single affine layer mapping ``input_size`` features to ``output_size`` outputs.

    Args:
        input_size: Number of input features per sample.
        output_size: Number of values produced per sample.

    The forward pass returns the raw output of the linear layer; no sigmoid
    or softmax is applied inside the module.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result unchanged."""
        scores = self.linear(x)
        return scores
    
# Define the graph-based regularization loss (the penalty is weighted by the adjacency matrix)
class GraphRegularizationLoss(nn.Module):
    """Adjacency-weighted L1 penalty over the rows of a weight matrix.

    For every row ``w`` of the matrix passed to ``forward`` the term
    ``sum(adjacency_matrix * |w|)`` is computed, with ``|w|`` broadcast against
    the adjacency matrix. The per-row terms are added together and the total
    is scaled by ``lambda_reg``.

    Args:
        adjacency_matrix: Tensor multiplied element-wise with ``|w|``.
        lambda_reg: Scalar factor applied to the summed penalty.
    """

    def __init__(self, adjacency_matrix, lambda_reg):
        super().__init__()
        # Both values are kept as plain attributes (not registered as buffers
        # or parameters of the module).
        self.lambda_reg = lambda_reg
        self.adjacency_matrix = adjacency_matrix

    def forward(self, weights_matrix):
        """Return ``lambda_reg`` times the summed per-row penalty.

        Iterating ``weights_matrix`` yields one row at a time (e.g. one row
        per output class when given the weight of an ``nn.Linear`` layer).
        """
        per_row_penalties = (
            torch.sum(self.adjacency_matrix * torch.abs(row))
            for row in weights_matrix
        )
        # The built-in sum starts from the integer 0 and adds each row's
        # penalty in iteration order.
        return self.lambda_reg * sum(per_row_penalties)

# COPDGene SOMASCAN 1.3 Dataset
# NOTE(review): this is an absolute, machine-specific path; consider moving it
# to a configurable data directory.
COPDGene_SOMASCAN13 = pd.read_csv('/home/shussein/NetCO/data/SOMASCAN13/COPDGene_SOMASCAN13_subjects.csv')
# Every column except the label column is used as a feature.
X = COPDGene_SOMASCAN13.loc[:,COPDGene_SOMASCAN13.columns != 'finalgold_visit']
y = COPDGene_SOMASCAN13['finalgold_visit']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Standardize the features. The scaler is fitted on the training split only
# and then re-used for the test split, so both splits share the same
# mean/variance and no test-set statistics leak into preprocessing.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.to_numpy())
X_test_scaled = scaler.transform(X_test.to_numpy())

# Convert data to PyTorch tensors (float features, integer class labels)
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.to_numpy(), dtype=torch.long)
y_test_tensor = torch.tensor(y_test.to_numpy(), dtype=torch.long)

# NOTE(review): read_csv treats the first line of the file as a header by
# default — confirm the adjacency file really has a header row.
adjacency_matrix = torch.tensor(pd.read_csv('../data/PPI_Yong/ppi_graph_1183_mRNA_updated_root_2656sub.csv', delimiter='\t').to_numpy(),
                                dtype=torch.float32)

# Symmetric normalized Laplacian: I - D^{-1/2} A D^{-1/2}
degree_matrix = torch.sum(adjacency_matrix, dim=1)
# A node with degree 0 would give 0 ** -0.5 = inf and then inf * 0 = NaN in the
# products below; use 0 for those entries so isolated nodes keep a plain
# identity row/column in the Laplacian.
degree_matrix_sqrt_inv = torch.diag(
    torch.where(degree_matrix > 0,
                torch.pow(degree_matrix, -0.5),
                torch.zeros_like(degree_matrix))
)
laplacian_matrix = torch.eye(adjacency_matrix.size(0)) - degree_matrix_sqrt_inv.matmul(adjacency_matrix).matmul(degree_matrix_sqrt_inv)
print(laplacian_matrix)


# Hyperparameters
input_size = X_train_tensor.shape[1]
# NOTE(review): assumes the labels in 'finalgold_visit' take exactly the
# values 0, 1 and 2 — confirm against the data.
output_size = 3
lambda_reg = 0.001
learning_rate = 0.001
num_epochs = 300
random_seed = 42

# Seed PyTorch's RNG so the random initialization of the linear layer (and
# therefore the whole training run) is reproducible across kernel restarts.
torch.manual_seed(random_seed)

# Instantiate the model and loss functions
model = LinearRegression(input_size, output_size)
# CrossEntropyLoss takes the raw (un-normalized) model outputs together with
# integer class targets.
criterion = nn.CrossEntropyLoss()
graph_reg_loss = GraphRegularizationLoss(adjacency_matrix, lambda_reg)

# Define optimizer
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

# Training loop: every epoch is one full-batch step over the training tensor
for epoch in range(num_epochs):
    # Clear the gradients left over from the previous step
    optimizer.zero_grad()

    # Forward pass and task (classification) loss
    outputs = model(X_train_tensor)
    loss_task = criterion(outputs, y_train_tensor)

    # Graph regularization term, evaluated on the current layer weights
    weights = model.linear.weight
    loss_reg = graph_reg_loss(weights)

    # NOTE: only the task loss is optimized. loss_reg is computed and logged
    # below but is not added to the objective.
    total_loss = loss_task

    # Backward pass and parameter update
    total_loss.backward()
    optimizer.step()

    # Report progress every 10th epoch only
    if (epoch + 1) % 10 != 0:
        continue
    print(f'Epoch [{epoch+1}/{num_epochs}], Task Loss: {loss_task.item():.4f}, Graph Regularization Loss: {loss_reg.item():.4f}')

# Print the indices of the top 10 contributing features for each output class.
# The weight of an nn.Linear layer has shape (out_features, in_features), so
# the sort runs along the feature axis and the first 10 *columns* are kept.
# (Slicing with [:10] alone would select rows, i.e. classes, and return the
# full ranking of every feature.)
# Ranking is by signed weight, largest positive value first.
weight_matrix = model.linear.weight.detach().numpy()
top_indices = np.argsort(-weight_matrix, axis=1)[:, :10]
print("Top 10 Task Weights: %s" % top_indices)

# Evaluation on the held-out test split (no gradients needed)
with torch.no_grad():
    outputs = model(X_test_tensor)

    # The predicted class is the index of the largest output per sample
    _, predicted_labels = torch.max(outputs, 1)
    accuracy = torch.sum(predicted_labels == y_test_tensor).item() / len(y_test)
    # torch.max returns integer (int64) indices; cast them explicitly so that
    # MSELoss receives two float tensors instead of relying on implicit
    # type promotion between int64 and float.
    mse = nn.MSELoss()(predicted_labels.float(), y_test_tensor.float())
    print("accuracy: {:.4f}".format(accuracy))
    print(f'Mean Squared Error: {mse:.4f}')

tensor([[ 1.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  1.0000, -0.0232,  ..., -0.0158,  0.0000,  0.0000],
        [ 0.0000, -0.0232,  1.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000, -0.0158,  0.0000,  ...,  1.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  1.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  1.0000]])
Epoch [10/300], Task Loss: 1.1789, Graph Regularization Loss: 1.3737
Epoch [20/300], Task Loss: 1.1520, Graph Regularization Loss: 1.3721
Epoch [30/300], Task Loss: 1.1282, Graph Regularization Loss: 1.3707
Epoch [40/300], Task Loss: 1.1068, Graph Regularization Loss: 1.3696
Epoch [50/300], Task Loss: 1.0873, Graph Regularization Loss: 1.3687
Epoch [60/300], Task Loss: 1.0694, Graph Regularization Loss: 1.3678
Epoch [70/300], Task Loss: 1.0528, Graph Regularization Loss: 1.3670
Epoch [80/300], Task Loss: 1.0372, Graph Regularization Loss: 1.3664
Epoch [90/300]

In [31]:
tensor([[-3.3943e-02, -3.7290e-02, -3.3233e-02],
        [-7.1744e-02, -1.4546e-02, -5.1411e-02],
        [-2.4098e-02, -5.7717e-03,  3.4128e-02],
        [-3.4024e-02,  1.8222e-02,  1.3344e-03],
        [-1.2702e-02, -1.1512e-01,  4.0245e-02],
        [-5.3551e-02,  1.2136e-01,  7.5814e-02],
        [-2.6230e-02, -6.1743e-02, -2.1658e-02],
        [-3.8514e-02,  1.5583e-02, -7.0085e-02],
        [ 4.0977e-02,  9.9375e-02,  1.2206e-02],
        [ 1.9775e-02, -1.9514e-01, -1.1424e-01],
        [ 4.1035e-02, -1.2258e-02,  2.3415e-02],