# Imports

In [1]:
import os
import sys
sys.path.append(os.path.realpath(".."))

from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from evlp_bronch.dataset import ALL_LUNG_IDS, RawEVLPDataset, ProcessedEVLPDataset

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.init as init

import numpy as np
import tqdm as tqdm
from scipy.stats import pearsonr


# Device Setup

In [2]:
device = torch.device("mps")
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x13aabacb0>

# Load Data

In [3]:
def load_data(seed):   
    train_lung_ids, test_lung_ids = train_test_split(
    ALL_LUNG_IDS, test_size=2, random_state=seed
    )
    train_lung_ids, val_lung_ids = train_test_split(
    train_lung_ids, test_size=2, random_state=seed
    )

    train_dataset = ProcessedEVLPDataset(train_lung_ids)
    val_dataset = ProcessedEVLPDataset(val_lung_ids)
    test_dataset = ProcessedEVLPDataset(test_lung_ids)
    print(len(train_dataset), len(val_dataset), len(test_dataset))

    return train_dataset, val_dataset, test_dataset

In [4]:
train_dataset, val_dataset, test_dataset = load_data(1)

2023-12-11 21:29:14,702 - evlp_bronch.dataset - INFO - Lung_id 29: Interpolation Dy_comp: between 329 and 331: 142.764474, 157.684626
2023-12-11 21:29:14,727 - evlp_bronch.dataset - INFO - Lung_id 47: Interpolation Dy_comp: between 51 and 56: 50.582276, 72.674668
2023-12-11 21:29:14,738 - evlp_bronch.dataset - INFO - Lung_id 53: Interpolation Dy_comp: between 2813 and 2824: 60.586801, 106.632301


143 7 4


# Utils Functions

In [5]:
def find_last(lst, value): # find the last occurence of a value in a list
    lst.reverse()
    i = lst.index(value)
    lst.reverse()
    return len(lst) - i - 1

In [6]:
def right_pad_sequence(sequence, target_length):
    current_length = len(sequence)
    total_padding = target_length - current_length
    if total_padding <= 0:
        return sequence
    pad_after = total_padding

    return np.pad(sequence, (0, pad_after), mode='edge')


In [7]:
def find_max_length_x():
    m = [0,0,0]
    for i in train_dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        if metric_start > m[0]:
            m[0]=metric_start+1
    for i in val_dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        if metric_start > m[1]:
            m[1]=metric_start+1
    for i in test_dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        if metric_start > m[2]:
            m[2]=metric_start+1
    return max(m)
max_l = find_max_length_x()
max_l

1416

In [8]:
def find_max_length_y(dataset):
    m = 0
    for i in dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        changes = np.where(np.diff(i['Is_assessment']) == 1)[0]  # Find where each assessment period begins
        # Find the first assessment period that starts after the last bronch occurrence
        first_assessment_after_bronch = None
        for change in changes:
            if change > metric_start:
                first_assessment_after_bronch = change
                break
        if metric_start< (len(i['Is_assessment']) - 1) * 0:
            continue
        if first_assessment_after_bronch is None:
            first_assessment_after_bronch = len(i['Is_assessment']) - 1
        if len(i['Dy_comp'][metric_start:first_assessment_after_bronch]) == 0: # if bronch紧接着assessment
            continue
        if len(i['Dy_comp'][metric_start:first_assessment_after_bronch]) > m:
            m = first_assessment_after_bronch - metric_start
    return m + 1
max_y_train = find_max_length_y(train_dataset)
max_y_val = find_max_length_y(val_dataset)
max_y_test = find_max_length_y(test_dataset)
max_y = max(max_y_train, max_y_val, max_y_test)
max_y

368

In [9]:
def set_dataset(dataset, max_l, max_y):
    X_dc = []
    X_is_normal = []
    X_is_bronch = []

    Y = []
    Y_len = []

    for i in dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        changes = np.where(np.diff(i['Is_assessment']) == 1)[0]  # Find where each assessment period begins
        # Find the first assessment period that starts after the last bronch occurrence
        first_assessment_after_bronch = None
        for change in changes:
            if change > metric_start:
                first_assessment_after_bronch = change
                break
        if metric_start< (len(i['Is_assessment']) - 1) * 0:
            continue

        if first_assessment_after_bronch is None:
            first_assessment_after_bronch = len(i['Is_assessment']) - 1
        if len(i['Dy_comp'][metric_start + 5:first_assessment_after_bronch]) == 0: # if bronch紧接着assessment
            continue

        # 做padding，保证长度一致，用最长的长度
        X_dc.append(right_pad_sequence(i['Dy_comp'][:metric_start+1], max_l))
        X_is_normal.append(right_pad_sequence(i['Is_normal'][:metric_start+1], max_l))
        X_is_bronch.append(right_pad_sequence(i['Is_bronch'][:metric_start+1], max_l))

        Y_len.append(len(i['Dy_comp'][metric_start + 5:first_assessment_after_bronch])) # 记录长度，用于计算loss)
        Y.append(right_pad_sequence(i['Dy_comp'][metric_start + 5:first_assessment_after_bronch], max_y))

    print(f"length is {len(X_dc)}")    
    assert len(X_dc) == len(X_is_bronch) == len(X_is_normal) == len(Y), "Inconsistent number of samples"

    X_dc = np.array(X_dc).reshape(-1, max_l)
    X_is_normal = np.array(X_is_normal).reshape(-1, max_l)
    X_is_bronch = np.array(X_is_bronch).reshape(-1, max_l)
    Y = torch.from_numpy(np.array(Y)).float()
    Y_len = torch.from_numpy(np.array(Y_len)).int()

    X_combined = np.stack([X_dc, X_is_normal, X_is_bronch], axis=1)  # Shape becomes [N, 3, 1470]
    X_combined = torch.from_numpy(X_combined).float()

    return X_combined, Y, Y_len



# Set Dataset

In [10]:
class EVLPDataset(Dataset):
    def __init__(self, X_combined, Y, Y_len):
        self.X_combined = X_combined
        self.Y = Y
        self.Y_len = Y_len

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        return self.X_combined[idx], self.Y[idx], self.Y_len[idx]

In [11]:
x_combine_train, y_train, y_len_train = set_dataset(train_dataset, max_l, max_y)
x_combine_val, y_val, y_len_val = set_dataset(val_dataset, max_l, max_y)
x_combine_test, y_test, y_len_test = set_dataset(test_dataset, max_l, max_y)

train_loader = DataLoader(EVLPDataset(x_combine_train, y_train, y_len_train), batch_size=32, shuffle=True)
val_loader = DataLoader(EVLPDataset(x_combine_val, y_val, y_len_val), batch_size=1, shuffle=False)
test_loader = DataLoader(EVLPDataset(x_combine_test, y_test, y_len_test), batch_size=1, shuffle=False)


length is 142
length is 7
length is 4


In [12]:
def train(model, train_loader, val_loader, test_loader, criterion, optimizer, epochs, plot=False):

    model.train()
    epoch_losses = []
    val_losses = []
    val_pearson_rs = []

    val_target = []
    val_predict = []

    for epoch in range(epochs):
        running_loss = 0.0

        for inputs, y, lengths in train_loader:
            inputs, y = inputs.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            mask = torch.arange(outputs.size(1)).expand(len(lengths), outputs.size(1)) < lengths.unsqueeze(1)
            mask = mask.to(device)
            outputs_masked = torch.masked_select(outputs, mask).to(device)
            y_masked = torch.masked_select(y, mask).to(device)

            loss = criterion(outputs_masked, y_masked)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            
        average_loss = running_loss / len(train_loader)
        epoch_losses.append(average_loss)

        model.eval()
        with torch.no_grad():
            running_val_loss = []
            running_val_pearson_r = []
            for inputs, y, lengths in val_loader:
                inputs, y = inputs.to(device), y.to(device)
                outputs = model(inputs)

                mask = torch.arange(outputs.size(1)).expand(len(lengths), outputs.size(1)) < lengths.unsqueeze(1)
                mask = mask.to(device)
                outputs_masked = torch.masked_select(outputs, mask).to(device)
                y_masked = torch.masked_select(y, mask).to(device)
                val_loss = criterion(outputs_masked, y_masked)
                running_val_loss.append(val_loss.cpu().item())
                for i in range(outputs.size()[0]):
                    val_pearson_r, _ = pearsonr(outputs[i, :lengths[i]].cpu().numpy(), y[i, :lengths[i]].cpu().numpy())
                    running_val_pearson_r.append(val_pearson_r)
                if epoch == epochs - 1:
                    val_target.append(y[:, :lengths])
                    val_predict.append(outputs[:, :lengths])

            average_val_loss = np.mean(running_val_loss)
            val_losses.append(average_val_loss)
            average_val_peason_r = np.mean(running_val_pearson_r)
            val_pearson_rs.append(average_val_peason_r)

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch + 1}, Training Loss: {average_loss}, Validation Loss: {average_val_loss}, Validation Pearson's R: {average_val_peason_r}")

        if epoch == epochs - 1:
            pearson_r_val = average_val_peason_r

        model.train()

    # Plotting
    if plot:
        num_samples = len(val_predict)
        cols = 2
        rows = num_samples // cols + (num_samples % cols > 0)
        plt.figure(figsize=(12, 4 * rows))
        for i in range(num_samples):
            plt.subplot(rows, cols, i + 1)
            plt.plot(val_predict[i][0].cpu().numpy(), label='Predicted')
            plt.plot(val_target[i][0].cpu().numpy(), label='Target', alpha=0.7)
            plt.title(f"Sample {i+1}")
            plt.xlabel("Time Steps")
            plt.ylabel("Values")
            plt.legend()


        plt.show()

        plt.figure(figsize=(7, 4))
        plt.plot(range(1, epochs+1), epoch_losses, marker='o', color='blue', label='Training Loss')
        plt.plot(range(1, epochs+1), val_losses, marker='o', color='red', label='Validation Loss')
        plt.title('Training and Validation Loss per Epoch')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.legend()
        plt.show()

    model.eval()

    output_ls = []
    target_ls = []

    test_loss = 0
    running_test_pearson_r = []
    with torch.no_grad():
        for data, y, lengths in test_loader:
            data, y = data.to(device), y.to(device)
            outputs = model(data)
            
            mask = torch.arange(outputs.size(1)).expand(len(lengths), outputs.size(1)) < lengths.unsqueeze(1)
            mask = mask.to(device)
            outputs_masked = torch.masked_select(outputs, mask).to(device)
            y_masked = torch.masked_select(y, mask).to(device)

            for i in range(outputs.size()[0]):
                pearson_r, _ = pearsonr(outputs[i, :lengths[i]].cpu().numpy(), y[i, :lengths[i]].cpu().numpy())
                running_test_pearson_r.append(pearson_r)

            output_ls.append(outputs[:, :lengths])
            target_ls.append(y[:, :lengths])

            case_loss = criterion(outputs_masked, y_masked).item()
            test_loss += case_loss

    test_loss /= len(test_loader.dataset)
    pearson_r_test = np.mean(running_test_pearson_r)

    # Plotting
    if plot:
        num_samples = len(output_ls)
        cols = 2
        rows = num_samples // cols + (num_samples % cols > 0)
        plt.figure(figsize=(12, 4 * rows))
        for i in range(num_samples):
            plt.subplot(rows, cols, i + 1)
            plt.plot(output_ls[i][0].cpu().numpy(), label='Predicted')
            plt.plot(target_ls[i][0].cpu().numpy(), label='Target', alpha=0.7)
            plt.title(f"Sample {i+1}")
            plt.xlabel("Time Steps")
            plt.ylabel("Values")
            plt.legend()

        plt.tight_layout()
        plt.show()

    print(f'Test set: Average loss: {test_loss:.4f}, Pearson\'s R: {pearson_r}')
    return average_loss, average_val_loss, test_loss, pearson_r_val, pearson_r_test

# CNN

In [13]:
class CNN(nn.Module):
    def __init__(self, kernel=3, num_filters=64, num_in_channels=3, padding=0):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(num_in_channels, num_filters, kernel_size=kernel, padding=padding)
        self.conv1.weight.data.uniform_(0, 0.01)
        self.conv2 = nn.Conv1d(num_filters, 128, kernel_size=kernel, padding=padding)
        self.conv2.weight.data.uniform_(0, 0.01)

        self.conv_seq = nn.Sequential(
            self.conv1,
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            self.conv2,
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
        )
        conv_output_size = self._calculate_conv_output_size(max_l, kernel, padding)
        self.fc1 = torch.nn.Linear(conv_output_size, 128)
        init.uniform_(self.fc1.weight, -0.01, 0.01)

        self.fc_seq = torch.nn.Sequential( 
            self.fc1,
            torch.nn.ReLU()
        )

        self.final_layer = nn.Linear(in_features=128, out_features=max_y)
        init.uniform_(self.final_layer.weight, -0.01, 0.01)

    def _calculate_conv_output_size(self, input_length, kernel, padding):
        size = (input_length - kernel + 2 * padding) + 1
        size = size // 2
        size = (size - kernel + 2 * padding) + 1
        size = size // 2
        return size * 128  
    
    def forward(self, x):
        x = self.conv_seq(x)
        x = x.view(x.size(0), -1)
        x = self.fc_seq(x)
        x = self.final_layer(x)
        return x

In [14]:
# epochs = 200
# model = CNN(kernel=3, num_filters=64).to(device)
# total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f"Total trainable parameters in the model: {total_params}")
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.000001, betas=(0.9, 0.999), eps=1e-08, amsgrad=True)
# criterion = nn.MSELoss()

# train(model, train_loader, val_loader, test_loader, criterion, optimizer, epochs)

# CSTM

In [19]:
class CNNLSTM(nn.Module):
    def __init__(self, kernel=3, num_filters=64, num_in_channels=3, padding=0, lstm_hidden_size=128, lstm_layers=1):
        super().__init__()

        conv1 = nn.Conv1d(num_in_channels, num_filters, kernel_size=kernel, padding=padding)
        conv1.weight.data.uniform_(0, 0.01)
        conv2 = nn.Conv1d(num_filters, 128, kernel_size=kernel, padding=padding)
        conv2.weight.data.uniform_(0, 0.01)

        self.conv_seq = nn.Sequential(
            conv1,
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            conv2,
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
        )

        self.lstm = nn.LSTM(input_size=128, hidden_size=lstm_hidden_size, num_layers=lstm_layers, batch_first=True)

        self.final_layer = nn.Linear(in_features=lstm_hidden_size, out_features=max_y)

    def forward(self, x):
        x = self.conv_seq(x)
        x = x.permute(0, 2, 1)
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = nn.functional.leaky_relu(x)
        x = self.final_layer(x)
        return x

In [None]:
epochs = 200
model1 = CNNLSTM(kernel=3, num_filters=64, lstm_hidden_size=128, lstm_layers=1).to(device)
total_params = sum(p.numel() for p in model1.parameters() if p.requires_grad)
print(f"Total trainable parameters in the model: {total_params}")
optimizer = torch.optim.Adam(model1.parameters(), lr=0.01, amsgrad=True)
criterion = nn.MSELoss()

train(model1, train_loader, val_loader, test_loader, criterion, optimizer, epochs)

# Auto Regressive  (Abandoned)
https://arxiv.org/pdf/1703.04122.pdf

In [24]:
# Define the Offset Network (Multilayer Perceptron)
class OffsetNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(OffsetNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.fc1.weight.data.uniform_(0.001, 0.005)

    def forward(self, x):
        x = nn.functional.leaky_relu(self.fc1(x))
        return x

# Define the Significance Network (Fully Convolutional Network)
class SignificanceNetwork(nn.Module):
    def __init__(self, channels):
        super(SignificanceNetwork, self).__init__()
        self.conv1 = nn.Conv1d(channels, max_y, kernel_size=3, dilation=1)
        self.conv1.weight.data.uniform_(0.001, 0.005)

    def forward(self, x):
        x = nn.functional.leaky_relu(self.conv1(x))
        return x

# Define the SOCNN Model
class SOCNN(nn.Module):
    def __init__(self, input_dim, output_dim, time_steps, channels):
        super(SOCNN, self).__init__()
        self.time_steps = time_steps
        self.output_dim = output_dim
    
        self.significance_network = SignificanceNetwork(channels).to(device)
        self.offset_network = OffsetNetwork(input_dim, output_dim).to(device)

    def forward(self, x):
        batch_size = x.shape[0] # [32, 3, 1416]
        
        self.W = nn.Parameter(torch.empty((batch_size, self.output_dim, self.time_steps)).uniform_(0, 0.005)).to(device) # [32, 368, 1416]
        
        significance = self.significance_network(x) # [32, 1416, 1414]

        offsets = torch.stack([self.offset_network(x[:, :, i]) for i in range(self.time_steps)]) # [1416, 32, 1416]
        offsets = torch.reshape(offsets, (offsets.shape[1], offsets.shape[0], offsets.shape[2])) # [32, 1416, 368]

        temp = torch.bmm(self.W, offsets) # [32, 368, 368]
        y_hat = torch.bmm(temp, significance) # [32, 368, 1414]
        
        y_hat = torch.sum(y_hat, dim=-1)
        
        return y_hat

In [18]:
epochs = 100
model1 = SOCNN(input_dim=3, output_dim=max_y, time_steps=max_l, channels=3).to(device)
total_params = sum(p.numel() for p in model1.parameters() if p.requires_grad)
print(f"Total trainable parameters in the model: {total_params}")
optimizer = torch.optim.Adam(model1.parameters(), lr=0.01, amsgrad=True)
criterion = nn.MSELoss()

train(model1, train_loader, val_loader, test_loader, criterion, optimizer, epochs)

Total trainable parameters in the model: 5152
Epoch 20, Training Loss: 35459.036328125, Validation Loss: 341370.01785714284
Epoch 40, Training Loss: 68819.5732421875, Validation Loss: 353839.02566964284
Epoch 60, Training Loss: 243427.71943359374, Validation Loss: 287327.56870814733
Epoch 80, Training Loss: 52285.768359375, Validation Loss: 261351.24441964287
Epoch 100, Training Loss: 91496.0357421875, Validation Loss: 261265.5767299107
Epoch 100, Pearson's R: 0.7527632184931585
Test set: Average loss: 242172.6172, Pearson's R: 0.9236166957326347
