# Imports

In [1]:
import os
import sys
sys.path.append(os.path.realpath(".."))

from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from evlp_bronch.dataset import ALL_LUNG_IDS, RawEVLPDataset, ProcessedEVLPDataset

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.init as init

import numpy as np
import tqdm as tqdm
from scipy.stats import pearsonr


# Device Setup

In [2]:
device = torch.device("cuda")
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x221b1215610>

# Load Data

In [3]:
train_lung_ids, test_lung_ids = train_test_split(
    ALL_LUNG_IDS, test_size=2, random_state=1
)
train_lung_ids, val_lung_ids = train_test_split(
    train_lung_ids, test_size=2, random_state=1
)

train_dataset = ProcessedEVLPDataset(train_lung_ids)
val_dataset = ProcessedEVLPDataset(val_lung_ids)
test_dataset = ProcessedEVLPDataset(test_lung_ids)

2023-12-03 21:17:11,719 - evlp_bronch.dataset - INFO - Lung_id 29: Interpolation Dy_comp: between 329 and 331: 142.764474, 157.684626
2023-12-03 21:17:11,767 - evlp_bronch.dataset - INFO - Lung_id 47: Interpolation Dy_comp: between 51 and 56: 50.582276, 72.674668
2023-12-03 21:17:11,798 - evlp_bronch.dataset - INFO - Lung_id 53: Interpolation Dy_comp: between 2813 and 2824: 60.586801, 106.632301


In [4]:
len(train_dataset), len(val_dataset), len(test_dataset)

(143, 7, 4)

# Utils Functions

In [5]:
def find_last(lst, value): # find the last occurence of a value in a list
    lst.reverse()
    i = lst.index(value)
    lst.reverse()
    return len(lst) - i - 1

In [6]:
def right_pad_sequence(sequence, target_length):
    current_length = len(sequence)
    total_padding = target_length - current_length
    if total_padding <= 0:
        return sequence
    pad_after = total_padding

    return np.pad(sequence, (0, pad_after), mode='edge')
len(right_pad_sequence(train_dataset[0]['Dy_comp'][:1+1], 1415))


1415

In [7]:
def find_max_length_x():
    m = [0,0,0]
    for i in train_dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        if metric_start > m[0]:
            m[0]=metric_start+1
    for i in val_dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        if metric_start > m[1]:
            m[1]=metric_start+1
    for i in test_dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        if metric_start > m[2]:
            m[2]=metric_start+1
    return max(m)
max_l = find_max_length_x()
max_l

1416

In [8]:
def calculate_slope(start_list, end_list, l):
    y1 = np.mean(start_list, axis=0)
    y2 = np.mean(end_list, axis=0)
    slope = (y2 - y1) / l
    return slope

In [9]:
def set_dataset(dataset):
    X_dc = []
    X_is_normal = []
    X_is_bronch = []

    Y = []

    for i in dataset:
        metric_start = find_last(list(i['Is_bronch']), 1) # find the last bronch
        changes = np.where(np.diff(i['Is_assessment']) == 1)[0]  # Find where each assessment period begins
        # Find the first assessment period that starts after the last bronch occurrence
        first_assessment_after_bronch = None
        for change in changes:
            if change > metric_start:
                first_assessment_after_bronch = change
                break
        if metric_start< (len(i['Is_assessment']) - 1) * 0:
            continue

        if first_assessment_after_bronch is None:
            first_assessment_after_bronch = len(i['Is_assessment']) - 1
        if len(i['Dy_comp'][metric_start:first_assessment_after_bronch]) == 0: # if bronch紧接着assessment
            continue

        # 做padding，保证长度一致，用最长的长度
        X_dc.append(right_pad_sequence(i['Dy_comp'][:metric_start+1], max_l))
        X_is_normal.append(right_pad_sequence(i['Is_normal'][:metric_start+1], max_l))
        X_is_bronch.append(right_pad_sequence(i['Is_bronch'][:metric_start+1], max_l))
        Y.append(calculate_slope(i['Dy_comp'][metric_start + 1: metric_start + 5], i['Dy_comp'][-4:-1], len(i['Dy_comp']) - metric_start - 1))

    print(f"length is {len(X_dc)}")    
    assert len(X_dc) == len(X_is_bronch) == len(X_is_normal) == len(Y), "Inconsistent number of samples"

    X_dc = np.array(X_dc).reshape(-1, max_l)
    X_is_normal = np.array(X_is_normal).reshape(-1, max_l)
    X_is_bronch = np.array(X_is_bronch).reshape(-1, max_l)
    Y = torch.from_numpy(np.array(Y)).float()

    X_combined = np.stack([X_dc, X_is_normal, X_is_bronch], axis=1)  # Shape becomes [N, 3, 1470]
    X_combined = torch.from_numpy(X_combined).float()

    return X_combined, Y



# Set Dataset

In [10]:
class EVLPDataset(Dataset):
    def __init__(self, X_combined, Y):
        self.X_combined = X_combined
        self.Y = Y

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, idx):
        return self.X_combined[idx], self.Y[idx]

In [11]:
x_combine_train, y_train = set_dataset(train_dataset)
x_combine_val, y_val = set_dataset(val_dataset)
x_combine_test, y_test = set_dataset(test_dataset)

train_loader = DataLoader(EVLPDataset(x_combine_train, y_train), batch_size=32, shuffle=True)
val_loader = DataLoader(EVLPDataset(x_combine_val, y_val), batch_size=1, shuffle=False)
test_loader = DataLoader(EVLPDataset(x_combine_test, y_test), batch_size=1, shuffle=False)


length is 143
length is 7
length is 4


# PyTorch

In [12]:
def train(model, train_loader, val_loader, test_loader, criterion, optimizer, epochs):
    model.train()
    epoch_losses = []  # List to store average training loss per epoch
    val_losses = []    # List to store average validation loss per epoch

    val_target = []
    val_predict = []

    for epoch in range(epochs):
        running_loss = 0.0

        for inputs, y in train_loader:
            inputs, y = inputs.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            
        average_loss = running_loss / len(train_loader)
        epoch_losses.append(average_loss)

        model.eval()
        with torch.no_grad():
            running_val_loss = 0.0
            for inputs, y in val_loader:
                inputs, y = inputs.to(device), y.to(device)
                outputs = model(inputs)

                val_loss = criterion(outputs, y)
            
                if epoch == epochs - 1:
                    val_target.append(y)
                    val_predict.append(outputs)
    
                running_val_loss += val_loss.item()

            average_val_loss = running_val_loss / len(val_loader)
            val_losses.append(average_val_loss)

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch + 1}, Training Loss: {average_loss}, Validation Loss: {average_val_loss}")

        model.train()

    model.eval()

    output_ls = []
    target_ls = []

    test_loss = 0
    with torch.no_grad():
        for data, y in test_loader:
            data, y = data.to(device), y.to(device)
            outputs = model(data)

            output_ls.append(outputs)
            target_ls.append(y)

            case_loss = criterion(outputs, y).item()
            test_loss += case_loss

    test_loss /= len(test_loader.dataset)

    print(f'Test set: Average loss: {test_loss:.4f}')

In [13]:
class CNN(nn.Module):
    def __init__(self, kernel=3, num_filters=64, num_in_channels=3, padding=0):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(num_in_channels, num_filters, kernel_size=kernel, padding=padding)
        self.conv1.weight.data.uniform_(0, 0.01)
        self.conv2 = nn.Conv1d(num_filters, 128, kernel_size=kernel, padding=padding)
        self.conv2.weight.data.uniform_(0, 0.01)

        self.conv_seq = nn.Sequential(
            self.conv1,
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            self.conv2,
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
        )
        conv_output_size = self._calculate_conv_output_size(max_l, kernel, padding)
        max_y = 1
        self.fc1 = torch.nn.Linear(conv_output_size, 128)
        init.uniform_(self.fc1.weight, -0.01, 0.01)

        self.fc_seq = torch.nn.Sequential( 
            self.fc1,
            torch.nn.ReLU()
        )

        self.final_layer = nn.Linear(in_features=128, out_features=max_y)
        init.uniform_(self.final_layer.weight, -0.01, 0.01)

    def _calculate_conv_output_size(self, input_length, kernel, padding):
        size = (input_length - kernel + 2 * padding) + 1
        size = size // 2
        size = (size - kernel + 2 * padding) + 1
        size = size // 2
        return size * 128  
    
    def forward(self, x):
        x = self.conv_seq(x)
        x = x.view(x.size(0), -1)
        x = self.fc_seq(x)
        x = self.final_layer(x)
        return x

## Performance Comparison

In [14]:
epochs = 200
model = CNN(kernel=3, num_filters=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.000001, betas=(0.9, 0.999), eps=1e-08, amsgrad=True)
criterion = nn.MSELoss()

train(model, train_loader, val_loader, test_loader, criterion, optimizer, epochs)

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 20, Training Loss: 0.011284404690377415, Validation Loss: 0.0020649744879587422
Epoch 40, Training Loss: 0.01194277296308428, Validation Loss: 0.0018668715607158706
Epoch 60, Training Loss: 0.011357122007757426, Validation Loss: 0.0017704091782694117
Epoch 80, Training Loss: 0.011946255271323026, Validation Loss: 0.0018928102768508584
Epoch 100, Training Loss: 0.011590041918680072, Validation Loss: 0.00205516440577672
Epoch 120, Training Loss: 0.01133207492530346, Validation Loss: 0.0017770495941087055
Epoch 140, Training Loss: 0.011733517423272133, Validation Loss: 0.0019614668978777316
Epoch 160, Training Loss: 0.013275116588920355, Validation Loss: 0.0019082970746759592
Epoch 180, Training Loss: 0.011272911075502634, Validation Loss: 0.0018145603098673746
Epoch 200, Training Loss: 0.011474159918725491, Validation Loss: 0.0017742292244033056
Test set: Average loss: 0.0027
