In [5]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
import torch.nn.functional as F


def load_preprocess_data(file_path):
    data = pd.read_csv(file_path, index_col=0)
    data.columns = range(1, len(data.columns) + 1)
    expression_sums = data.iloc[:,:-1].sum(axis=1)
    filtered_data = data[(expression_sums < 2000) & (expression_sums > 0)]
    return filtered_data


def prepare_data(data, window_size, train_len, test_size=0.1, val_size=0.1):
    all_data = data.values[:, 0:train_len+1]
    X, y = sliding_window(all_data, window_size)
    X = X.astype(float).reshape(-1, window_size)
    y = y.astype(float).reshape(-1, 1)

    X_temp, X_test, y_temp, y_test = train_test_split(X, y, test_size=test_size, random_state=42)
    adjusted_val_size = val_size / (1 - test_size)

    X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=adjusted_val_size, random_state=42)

    return X_train, X_val, X_test, y_train, y_val, y_test


class GeneExpressionDataset(Dataset):
    def __init__(self, features, targets):
        self.features = features
        self.targets = targets
    def __len__(self):
        return len(self.features)
    def __getitem__(self, idx):
        return self.features[idx], self.targets[idx]

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)  
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

class SelfAttention(nn.Module):
    def __init__(self, channels):
        super(SelfAttention, self).__init__()
        self.query_weight = nn.Linear(channels, channels)

    def forward(self, x):
        batch_size, channels, length = x.size()

        x_reshaped = x.view(batch_size * length, channels)

        Q = self.query_weight(x_reshaped)  # [batch_size * length, channels]
        K = x_reshaped

        # Reshape Q, K for bmm
        Q = Q.view(batch_size, length, channels)  # [batch_size, length, channels]
        K = K.view(batch_size, channels, length)  # [batch_size, channels, length]

        # Attention Scores
        attention_scores = torch.bmm(Q, K)  # [batch_size, length, length]
        attention_scores = F.softmax(attention_scores, dim=-1)

        # Apply Attention Weights
        V = x.view(batch_size, channels, length)  # [batch_size, channels, length]
        output = torch.bmm(V, attention_scores.transpose(-2, -1))  # [batch_size, channels, length]

        return output




class CNN1D(nn.Module):
    def __init__(self):
        super(CNN1D, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=2, stride=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=64, kernel_size=2, stride=1)
        self.relu2 = nn.ReLU()

        self.attention = SelfAttention(64)

        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        self.fc1 = nn.Linear(64, 16)
        self.fc2 = nn.Linear(16, 1)

    def forward(self, x):
        x = x.unsqueeze(1)

        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))

        x = self.attention(x)

        x = self.adaptive_pool(x)

        x = x.view(x.size(0), -1)

        x = self.fc1(x)
        x = self.fc2(x)

        return x
    
class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)
    
def sliding_window(data, window_size):
    features = []
    targets = []
    for i in range(data.shape[1] - window_size):
        features.append(data[:, i:i + window_size])
        targets.append(data[:, i + window_size])
    return np.array(features), np.array(targets)

# Create dataloaders
def create_dataloaders(X_train, X_val, X_test, y_train, y_val, y_test):
    X_train_tensor, y_train_tensor = torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.float32)
    X_val_tensor, y_val_tensor = torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.float32)
    train_dataset = GeneExpressionDataset(X_train_tensor, y_train_tensor)
    val_dataset = GeneExpressionDataset(X_val_tensor, y_val_tensor)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    if X_test is not None and y_test is not None:
        X_test_tensor, y_test_tensor = torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32)
        test_dataset = GeneExpressionDataset(X_test_tensor, y_test_tensor)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
        return train_loader, val_loader, test_loader
    else:
        return train_loader, val_loader

# Train model function
def train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, device):
    for epoch in range(num_epochs):
        model.train()
        for batch_features, batch_targets in train_loader:
            batch_features, batch_targets = batch_features.to(device), batch_targets.to(device)
            outputs = model(batch_features)
            loss = criterion(outputs, batch_targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch_features, batch_targets in val_loader:
                batch_features, batch_targets = batch_features.to(device), batch_targets.to(device)
                outputs = model(batch_features)
                loss = criterion(outputs, batch_targets)
                val_loss += loss.item()
        val_loss /= len(val_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {loss.item():.4f}, Validation Loss: {val_loss:.4f}')

# Test model function
def test_model(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch_features, batch_targets in test_loader:
            batch_features, batch_targets = batch_features.to(device), batch_targets.to(device)
            outputs = model(batch_features)
            loss = criterion(outputs, batch_targets)
            total_loss += loss.item()
    average_test_loss = total_loss / len(test_loader)
    print(f"Average Test Loss: {average_test_loss:.4f}")

# Get prediction function
def get_prediction(input_data, model, device):
    input_tensor = torch.tensor(input_data, dtype=torch.float32).to(device)
    with torch.no_grad():
        prediction = model(input_tensor)
    return prediction.cpu().numpy().reshape(-1, 1)

def generate_cluster_specific_prediction(predict_nums, train_len, data, window_size, models, device):

    gene_symbols = data.index


    all_data = data.values[:, :train_len]
    predictions = np.zeros((all_data.shape[0], predict_nums))

    for cluster, model in models.items():
        cluster_indices = np.where(data.iloc[:, -1] == cluster)[0]
        cluster_data = all_data[cluster_indices, :]

        for i in range(predict_nums):
            input_data = cluster_data[:, -window_size:]

            pred = get_prediction(input_data, model, device)
            predictions[cluster_indices, i] = pred.squeeze()
            cluster_data = np.hstack((cluster_data, pred))



    # Combine predictions with original data
    extended_data = np.concatenate((all_data, predictions), axis=1)
    predicted_cols = ["predicted_" + str(train_len + i) for i in range(1, predict_nums + 1)]
    columns = data.columns.tolist()[:train_len] + predicted_cols
    extended_df = pd.DataFrame(extended_data, columns=columns)

    # Reattach gene symbols
    extended_df.insert(0, 'gene symbol', gene_symbols)

    return extended_df

def train_models_for_each_cluster(data, window_size, train_len, num_epochs, device):
    unique_clusters = np.unique(data.iloc[:, -1])
    models = {}
    
    for cluster in unique_clusters:
        print(f"Training model for cluster {cluster}")
        cluster_data = data[data.iloc[:, -1] == cluster]
        X_train, X_val, _, y_train, y_val, _ = prepare_data(cluster_data, window_size, train_len)
        loaders = create_dataloaders(X_train, X_val, None, y_train, y_val, None)
        train_loader = loaders[0]
        val_loader = loaders[1]

        model = LinearRegression(3,1).to(device)
        criterion = nn.L1Loss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        train_model(model, train_loader, val_loader, num_epochs, criterion, optimizer, device)
        
        models[cluster] = model
    
    return models



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
data_path = 'ForeBrain_TPM_c6.csv'
data = load_preprocess_data(data_path)
train_len = 4
window_size = 3
num_epochs = 30

models = train_models_for_each_cluster(data, window_size, train_len, num_epochs, device)
results = generate_cluster_specific_prediction(3, train_len, data, window_size, models, device)

results.to_csv(f'{data_path[:-4]}_results.csv', index=False)
print("Prediction generation complete.")

Training model for cluster 1
Epoch [1/30], Training Loss: 25.1625, Validation Loss: 31.0282
Epoch [2/30], Training Loss: 23.3245, Validation Loss: 29.2856
Epoch [3/30], Training Loss: 21.5190, Validation Loss: 27.5848
Epoch [4/30], Training Loss: 19.6944, Validation Loss: 25.8975
Epoch [5/30], Training Loss: 17.8872, Validation Loss: 24.2405
Epoch [6/30], Training Loss: 16.1007, Validation Loss: 22.6058
Epoch [7/30], Training Loss: 14.3034, Validation Loss: 20.9643
Epoch [8/30], Training Loss: 12.5520, Validation Loss: 19.3826
Epoch [9/30], Training Loss: 10.7848, Validation Loss: 17.8065
Epoch [10/30], Training Loss: 9.0975, Validation Loss: 16.2902
Epoch [11/30], Training Loss: 8.3219, Validation Loss: 14.8547
Epoch [12/30], Training Loss: 7.7273, Validation Loss: 13.4885
Epoch [13/30], Training Loss: 7.1925, Validation Loss: 12.2191
Epoch [14/30], Training Loss: 6.6830, Validation Loss: 11.1065
Epoch [15/30], Training Loss: 6.2071, Validation Loss: 10.1365
Epoch [16/30], Training Lo

Epoch [10/30], Training Loss: 1.5946, Validation Loss: 17.1880
Epoch [11/30], Training Loss: 0.9648, Validation Loss: 15.6872
Epoch [12/30], Training Loss: 0.3485, Validation Loss: 14.2645
Epoch [13/30], Training Loss: 0.7895, Validation Loss: 12.9632
Epoch [14/30], Training Loss: 1.3196, Validation Loss: 11.7651
Epoch [15/30], Training Loss: 1.8322, Validation Loss: 10.6591
Epoch [16/30], Training Loss: 2.3123, Validation Loss: 9.6827
Epoch [17/30], Training Loss: 2.7788, Validation Loss: 8.8315
Epoch [18/30], Training Loss: 3.2007, Validation Loss: 8.1275
Epoch [19/30], Training Loss: 3.5920, Validation Loss: 7.5319
Epoch [20/30], Training Loss: 3.9402, Validation Loss: 7.0757
Epoch [21/30], Training Loss: 4.2508, Validation Loss: 6.7259
Epoch [22/30], Training Loss: 4.5292, Validation Loss: 6.4815
Epoch [23/30], Training Loss: 4.7647, Validation Loss: 6.3135
Epoch [24/30], Training Loss: 4.9432, Validation Loss: 6.2243
Epoch [25/30], Training Loss: 5.0787, Validation Loss: 6.1688
Ep