In [8]:
import pandas as pd
import numpy as np
import torch as torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from pycaret.regression import *
from sklearn.preprocessing import *
from sklearn.metrics import r2_score

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} for computation')

Using cuda for computation


In [9]:
def compute_r2_score(
    y_true: torch.Tensor,  # ground truth values
    y_pred: torch.Tensor  # predicted values
) -> torch.Tensor:  # R^2 score as a scalar tensor
    """
    Compute the R² score for a regression problem.

    Args:
        y_true (torch.Tensor): ground truth values of shape (n_samples,).
        y_pred (torch.Tensor): predicted values of shape (n_samples,).

    Returns:
        torch.Tensor: R^2 score as a scalar tensor.

    """
    # calculate residual sum of squares
    ss_res = torch.sum((y_true - y_pred) ** 2)

    # calculate total sum of squares
    ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2)

    # calculate R^2 score
    r2_score = 1 - ss_res / ss_tot

    return r2_score

In [10]:
def find_k_nearest_neighbors(lipid_coords, gene_coords, k=1000):
    """
    Find k nearest neighbors in gene_coords for each point in lipid_coords.
    Args:
    - lipid_coords (Tensor): Coordinates of lipids.
    - gene_coords (Tensor): Coordinates of genes.
    - k (int): Number of nearest neighbors to find.

    Returns:
    - Tensor: Indices of the k nearest neighbors in gene_coords for each point in lipid_coords.
    """
    # Calculate pairwise distances
    lipid_coords = lipid_coords.unsqueeze(1)  # Add an extra dimension for broadcasting
    distances = torch.cdist(lipid_coords, gene_coords)  # Compute the distance

    # Find the indices of the k smallest distances
    _, indices = torch.topk(distances, k, largest=False)

    return indices


In [11]:
lipid_path = 'data/section12/lipids_section_12.parquet'
gene_path = 'data/section12/genes_section_12.parquet'

# Loading the dataset
lipids_section_12 = pd.read_parquet(lipid_path, engine='pyarrow')
genes_section_12 = pd.read_parquet(gene_path, engine='pyarrow')

# Extracting the relevant columns
gene_values = genes_section_12.iloc[:, 46:-50].values
lipid_values = lipids_section_12.iloc[:, 13:].values
lipid_values *= 1000

# Standardizing the gene values
scaler = StandardScaler()
gene_values_standardized = scaler.fit_transform(gene_values)

# Coordinates
gene_coords = genes_section_12[['y_ccf', 'z_ccf']].values
lipid_coords = lipids_section_12[['y_ccf', 'z_ccf']].values

# Convert to PyTorch tensors
gene_values_tensor = torch.tensor(gene_values_standardized, dtype=torch.float32)
gene_coords_tensor = torch.tensor(gene_coords, dtype=torch.float32)
lipid_coords_tensor = torch.tensor(lipid_coords, dtype=torch.float32)
lipid_values_tensor = torch.tensor(lipid_values, dtype=torch.float32)

gene_values_tensor = gene_values_tensor.to(device)
gene_coords_tensor = gene_coords_tensor.to(device)
lipid_coords_tensor = lipid_coords_tensor.to(device)
lipid_values_tensor = lipid_values_tensor.to(device)

In [12]:
class GeneLipidModel(nn.Module):
    def __init__(self):
        super(GeneLipidModel, self).__init__()
        self.num_genes = 500
        self.num_data_points_per_lipid = 2000
        self.position_input_size = 2
        self.gene_value_input_size = self.num_genes * self.num_data_points_per_lipid
        self.gene_position_input_size = 2 * self.num_data_points_per_lipid

        # LSTM input size
        self.lstm_input_size = 2 * self.num_data_points_per_lipid + self.position_input_size + self.gene_position_input_size // self.num_genes
        # Bidirectional LSTM
        self.lstm = nn.LSTM(input_size=self.lstm_input_size, hidden_size=1024, num_layers=1, batch_first=True, dropout=0.3, bidirectional=True)

        # Fully connected layers
        self.fc1 = nn.Linear(1024 * 2, 1024)  # For bidirectional LSTM
        self.fc1_residual = nn.Linear(1024 * 2, 1024)  # Additional layer for residual connection
        self.dropout1 = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(1024)
        
        self.fc2 = nn.Linear(1024, 500)
        self.fc2_residual = nn.Linear(1024, 500)  # Additional layer for residual connection
        self.dropout2 = nn.Dropout(p=0.2)
        self.bn2 = nn.BatchNorm1d(500)

        self.fc3 = nn.Linear(500, 156)

    def forward(self, x):
        # Extract lipid positions, gene values, and gene positions
        lipid_positions = x[:, :self.position_input_size]
        gene_values = x[:, self.position_input_size:self.position_input_size + self.gene_value_input_size]
        gene_positions = x[:, self.position_input_size + self.gene_value_input_size:]

        # Feature augmentation: Squaring gene values
        squared_gene_values = gene_values ** 2
        augmented_gene_values = torch.cat([gene_values, squared_gene_values], dim=1)

        # Reshape and concatenate for LSTM input
        augmented_gene_values = augmented_gene_values.view(augmented_gene_values.size(0), self.num_genes, -1)
        gene_positions = gene_positions.view(gene_positions.size(0), self.num_genes, -1)
        lipid_positions_expanded = lipid_positions.unsqueeze(1).repeat(1, self.num_genes, 1)
        lstm_input = torch.cat((augmented_gene_values, lipid_positions_expanded, gene_positions), dim=2)

        # LSTM
        lstm_out, _ = self.lstm(lstm_input)
        lstm_out = torch.cat((lstm_out[:, -1, :1024], lstm_out[:, 0, 1024:]), dim=1)

        # Dense layers with residual connections
        identity = lstm_out
        x = self.dropout1(F.gelu(self.bn1(self.fc1(lstm_out))))
        x = x + self.fc1_residual(identity)  # First residual connection

        identity = x
        x = self.dropout2(F.gelu(self.bn2(self.fc2(x))))
        x = x + self.fc2_residual(identity)  # Second residual connection

        output = self.fc3(x)
        return output


In [13]:
# Assuming gene_values_tensor, gene_coords_tensor, and lipid_coords_tensor are on the GPU
num_epochs = 100
batch_size = 64
patience = 10  # Set the patience for early stopping
best_avg_r2_score = float('-inf')  # Initialize the best average R² score
counter = 0  # Initialize counter for early stopping

# Model, Loss, and Optimizer
model = GeneLipidModel().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.2, threshold_mode='rel', verbose=True)

# Split the dataset
X_train, X_test, y_train, y_test = train_test_split(lipid_coords_tensor, lipid_values_tensor, test_size=0.3, random_state=100003)
total_batches = len(y_train) // batch_size
# Training loop
for epoch in range(num_epochs):
    model.train()
    
    for batch in tqdm(range(total_batches), desc=f"Epoch {epoch+1}/{num_epochs}"):
        start_idx = batch * batch_size
        end_idx = min((batch + 1) * batch_size, len(X_train))

        # Nearest neighbors search
        neighbors_indices = find_k_nearest_neighbors(X_train[start_idx:end_idx], gene_coords_tensor, k=2000)

        # Prepare inputs and labels
        neighbors_gene_values = gene_values_tensor[neighbors_indices].view(-1, 500 * 2000)  # Reshape accordingly
        neighbors_coords = gene_coords_tensor[neighbors_indices].view(-1, 2 * 2000)  # Reshape accordingly
        inputs = torch.hstack([X_train[start_idx:end_idx], neighbors_gene_values, neighbors_coords]).to(device)
        labels = y_train[start_idx:end_idx].to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Evaluate the model on the test set for R² score
    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0 
    total_test_batches = len(X_test) // batch_size
    test_predictions = []
    test_actuals = []

    with torch.no_grad():
        for batch in range(total_test_batches):
            start_idx = batch * batch_size
            end_idx = min((batch + 1) * batch_size, len(X_test))

            # Prepare test set inputs
            test_neighbors_indices = find_k_nearest_neighbors(X_test[start_idx:end_idx], gene_coords_tensor, k=2000)
            test_neighbors_gene_values = gene_values_tensor[test_neighbors_indices].view(-1, 500 * 2000)
            test_neighbors_coords = gene_coords_tensor[test_neighbors_indices].view(-1, 2 * 2000)
            test_inputs = torch.hstack([X_test[start_idx:end_idx], test_neighbors_gene_values, test_neighbors_coords]).to(device)

            # Get model predictions
            test_outputs = model(test_inputs)
            test_predictions.append(test_outputs)
            test_actuals.append(y_test[start_idx:end_idx])
            
            test_outputs = model(test_inputs)
            test_loss += criterion(test_outputs, y_test[start_idx:end_idx]).item()
    
    test_loss /= total_test_batches
    scheduler.step(test_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {test_loss:.4e}')
    
    # Concatenate predictions and actuals, compute R² score for each lipid
    test_predictions = torch.cat(test_predictions, dim=0)
    test_actuals = torch.cat(test_actuals, dim=0)
    r2_scores_per_lipid = [compute_r2_score(test_actuals[:, i], test_predictions[:, i]).item() 
                           for i in range(156)]

    # Display R² scores using DataFrame
    lipid_names = lipids_section_12.columns[13:]
    r2_scores_df = pd.DataFrame({'Lipid Name': lipid_names, 'R² Score': r2_scores_per_lipid})
    display(r2_scores_df)

    # Calculate and print the average R² score
    average_r2_score = np.mean(r2_scores_per_lipid)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average R² Score: {average_r2_score:.4f}')

    # Early stopping and saving the best model based on average R² score
    if average_r2_score > best_avg_r2_score:
        best_avg_r2_score = average_r2_score
        counter = 0  # Reset the counter
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Epoch {epoch+1}: Average R² improved to {average_r2_score:.4f}, saving model...')
    else:
        counter += 1
        print(f'Epoch {epoch+1}: Average R² did not improve (counter {counter}/{patience})')

    if counter >= patience:
        print(f'Early stopping triggered after {epoch+1} epochs')
        break


Epoch 1/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [1/100], Test Loss: 4.4769e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,-0.300412
1,LPC 15:1,0.628244
2,LPC O-18:3,-0.103555
3,LPC 20:4,0.217633
4,Cer 36:1,-0.176903
...,...,...
151,LPC O-16:2,0.341994
152,LPC 16:0,0.235558
153,LPC O-18:2,-0.032952
154,LPC 18:0,-0.004256


Epoch [1/100], Average R² Score: 0.2797
Epoch 1: Average R² improved to 0.2797, saving model...


Epoch 2/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [2/100], Test Loss: 4.2807e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.153670
1,LPC 15:1,0.664800
2,LPC O-18:3,0.129709
3,LPC 20:4,0.283392
4,Cer 36:1,-0.004850
...,...,...
151,LPC O-16:2,0.390317
152,LPC 16:0,0.183581
153,LPC O-18:2,-0.081396
154,LPC 18:0,0.170260


Epoch [2/100], Average R² Score: 0.3326
Epoch 2: Average R² improved to 0.3326, saving model...


Epoch 3/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [3/100], Test Loss: 4.5839e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.122348
1,LPC 15:1,0.667959
2,LPC O-18:3,0.128182
3,LPC 20:4,0.305803
4,Cer 36:1,-0.029419
...,...,...
151,LPC O-16:2,0.511914
152,LPC 16:0,0.240656
153,LPC O-18:2,-0.247813
154,LPC 18:0,0.074233


Epoch [3/100], Average R² Score: 0.3313
Epoch 3: Average R² did not improve (counter 1/10)


Epoch 4/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [4/100], Test Loss: 4.5581e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,-0.117533
1,LPC 15:1,0.687749
2,LPC O-18:3,0.073953
3,LPC 20:4,0.180318
4,Cer 36:1,0.031828
...,...,...
151,LPC O-16:2,0.197663
152,LPC 16:0,0.097879
153,LPC O-18:2,0.020702
154,LPC 18:0,-0.027202


Epoch [4/100], Average R² Score: 0.2901
Epoch 4: Average R² did not improve (counter 2/10)


Epoch 5/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [5/100], Test Loss: 4.4750e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.185844
1,LPC 15:1,0.665680
2,LPC O-18:3,-0.054160
3,LPC 20:4,0.261976
4,Cer 36:1,0.003873
...,...,...
151,LPC O-16:2,0.527277
152,LPC 16:0,0.263118
153,LPC O-18:2,0.120014
154,LPC 18:0,0.144625


Epoch [5/100], Average R² Score: 0.3148
Epoch 5: Average R² did not improve (counter 3/10)


Epoch 6/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch 00006: reducing learning rate of group 0 to 6.0000e-05.
Epoch [6/100], Test Loss: 4.4681e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.132746
1,LPC 15:1,0.661523
2,LPC O-18:3,0.128955
3,LPC 20:4,0.359950
4,Cer 36:1,0.025082
...,...,...
151,LPC O-16:2,0.515046
152,LPC 16:0,0.128743
153,LPC O-18:2,0.153242
154,LPC 18:0,-0.002599


Epoch [6/100], Average R² Score: 0.3491
Epoch 6: Average R² improved to 0.3491, saving model...


Epoch 7/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [7/100], Test Loss: 3.8498e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.223204
1,LPC 15:1,0.721030
2,LPC O-18:3,0.177805
3,LPC 20:4,0.383676
4,Cer 36:1,0.048481
...,...,...
151,LPC O-16:2,0.535194
152,LPC 16:0,0.306419
153,LPC O-18:2,0.183215
154,LPC 18:0,0.258446


Epoch [7/100], Average R² Score: 0.4253
Epoch 7: Average R² improved to 0.4253, saving model...


Epoch 8/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [8/100], Test Loss: 3.8683e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.219088
1,LPC 15:1,0.720297
2,LPC O-18:3,0.176041
3,LPC 20:4,0.388009
4,Cer 36:1,0.039928
...,...,...
151,LPC O-16:2,0.534547
152,LPC 16:0,0.309037
153,LPC O-18:2,0.179484
154,LPC 18:0,0.254734


Epoch [8/100], Average R² Score: 0.4243
Epoch 8: Average R² did not improve (counter 1/10)


Epoch 9/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [9/100], Test Loss: 3.8806e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.195622
1,LPC 15:1,0.720161
2,LPC O-18:3,0.175386
3,LPC 20:4,0.392823
4,Cer 36:1,0.028389
...,...,...
151,LPC O-16:2,0.529320
152,LPC 16:0,0.306970
153,LPC O-18:2,0.176303
154,LPC 18:0,0.264361


Epoch [9/100], Average R² Score: 0.4243
Epoch 9: Average R² did not improve (counter 2/10)


Epoch 10/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [10/100], Test Loss: 3.8874e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.204529
1,LPC 15:1,0.721082
2,LPC O-18:3,0.174757
3,LPC 20:4,0.397936
4,Cer 36:1,0.038123
...,...,...
151,LPC O-16:2,0.534814
152,LPC 16:0,0.306909
153,LPC O-18:2,0.170074
154,LPC 18:0,0.278369


Epoch [10/100], Average R² Score: 0.4261
Epoch 10: Average R² improved to 0.4261, saving model...


Epoch 11/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch 00011: reducing learning rate of group 0 to 1.2000e-05.
Epoch [11/100], Test Loss: 3.8865e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.193053
1,LPC 15:1,0.721701
2,LPC O-18:3,0.178878
3,LPC 20:4,0.396566
4,Cer 36:1,0.043084
...,...,...
151,LPC O-16:2,0.534752
152,LPC 16:0,0.308330
153,LPC O-18:2,0.168171
154,LPC 18:0,0.274575


Epoch [11/100], Average R² Score: 0.4277
Epoch 11: Average R² improved to 0.4277, saving model...


Epoch 12/100:   0%|          | 0/977 [00:00<?, ?it/s]

Epoch [12/100], Test Loss: 3.8500e-01


Unnamed: 0,Lipid Name,R² Score
0,LPC O- 18:3,0.234475
1,LPC 15:1,0.723428
2,LPC O-18:3,0.187114
3,LPC 20:4,0.400876
4,Cer 36:1,0.062839
...,...,...
151,LPC O-16:2,0.562366
152,LPC 16:0,0.311056
153,LPC O-18:2,0.198194
154,LPC 18:0,0.285437


Epoch [12/100], Average R² Score: 0.4389
Epoch 12: Average R² improved to 0.4389, saving model...


Epoch 13/100:   0%|          | 0/977 [00:00<?, ?it/s]

KeyboardInterrupt: 