In [2]:
# import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from torchmetrics.classification import BinaryAUROC, BinaryAveragePrecision
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset

from torchvision.models import resnet50
from tqdm.notebook import tqdm

#Import my functions
import nbimporter
from Functions import MyDataset, train, validate, plot_loss_accuracy


%matplotlib inline

In [3]:
dtypes = {str(i): 'float64' for i in range(5000)}
dtypes.update({
    'Segment_Time': 'str',
    'lab_flag': 'int',
    'Gender': 'str',
    'Age': 'float64',
})


In [4]:
repo_location = 'D:/simedy' # Change this to the location of the repository on your machine

tests = [
    'Platelet Count',
    'Hematocrit',
    'White Blood Cells',
    'Hemoglobin'
]
panel = 'CBC'
time = '15 min'

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
num_epochs = 50
batch_size = 128

In [7]:
class ResNet50(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNet50, self).__init__()
        self.resnet = resnet50(weights=None)
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet.fc = nn.Linear(2048, num_classes)
        self.dropout = nn.Dropout(p=0.2, inplace=False)

    def forward(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.resnet.fc(x)
        x = F.softmax(x, dim=1)

        return x
    

class ResNet50_age_gender(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNet50_age_gender, self).__init__()
        self.resnet = resnet50(weights=None)
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet.fc = nn.Identity()  # Remove the original fully connected layer
        self.dropout = nn.Dropout(p=0.2, inplace=False)
        self.fc = nn.Linear(2048 + 2, num_classes)  # 2048 for ResNet50 features + 2 for age and gender

    def forward(self, x_img, age, gender):
        # Process image data through ResNet50
        x = self.resnet.conv1(x_img)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.dropout(x)

        # Combine image features with age and gender
        age_gender = torch.cat((age.unsqueeze(1), gender.unsqueeze(1)), dim=1)  # Shape: (batch_size, 2)
        combined = torch.cat((x, age_gender), dim=1)  # Shape: (batch_size, 2048 + 2)

        # Pass combined features through the final classification layer
        output = self.fc(combined)
        output = F.softmax(output, dim=1)

        return output

In [12]:
#if csv results file does not exist, create one
if not os.path.exists(os.path.join(repo_location, "Notebooks", panel, f"results/vital_{panel}_test_results.csv")):
    test_results = pd.DataFrame()
else:
    test_results = pd.read_csv(os.path.join(repo_location, "Notebooks", panel, f"results/vital_{panel}_test_results.csv"))

for test in tqdm(tests):
    for GA in ['GA_used', 'No_GA']:
        for _ in range(1):
            file_path = os.path.join(repo_location, f'CSVs/Vital_{test}_{time}.csv')
            df = pd.read_csv(file_path, dtype=dtypes)
            df = df.drop(columns=['SubjectID'])
            if GA == 'No_GA':
                df = df.drop(columns=['Gender', 'Age'])
            dataset = MyDataset(df)
            test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
            model = ResNet50(num_classes = 2)
            if GA == 'GA_used':
                model = ResNet50_age_gender(num_classes = 2)
            model.load_state_dict(torch.load(os.path.join(repo_location, f"FT models/bestmodel_{test}_{GA}.pth")))
            model.to(device)

            model.eval()
            #Create a confusion matrix
            confusion_matrix = np.zeros((2, 2))
            outputs = torch.tensor([]).to(device)
            targets = torch.tensor([]).to(device)

            #Iterate over the test data
            with torch.no_grad():
                for i, data in enumerate(test_loader, 0):
                    if GA == "GA_used":
                        inputs, age, gender, labels = data
                        inputs, age, gender, labels = inputs.to(device), age.to(device), gender.to(device), labels.to(device)
                        labels = labels.long()  # Convert labels to Long type
                        output = model(inputs, age, gender)
                    else:
                        inputs, labels = data
                        inputs, labels = inputs.to(device), labels.to(device)
                        labels = labels.long()
                        output = model(inputs.unsqueeze(1))
                    

                    _, predicted = torch.max(output.data, 1)
                    outputs = torch.cat((outputs, predicted), 0)
                    targets = torch.cat((targets, labels), 0)

                    for j in range(len(predicted)):
                        confusion_matrix[labels[j], predicted[j]] += 1

            #Calculate the accuracy
            accuracy = np.trace(confusion_matrix) / np.sum(confusion_matrix)

            # Calculate precision and recall for class 1
            if np.sum(confusion_matrix[:, 1]) == 0:
                precision = 0
            else:
                precision = confusion_matrix[1, 1] / np.sum(confusion_matrix[:, 1])

            if np.sum(confusion_matrix[1, :]) == 0:
                recall = 0
            else:
                recall = confusion_matrix[1, 1] / np.sum(confusion_matrix[1, :])

            # Calculate F1 score for class 1
            if precision + recall == 0:
                f1 = 0
            else:
                f1 = 2 * (precision * recall) / (precision + recall)

            # Calculate specificity and NPV for class 0
            if np.sum(confusion_matrix[0, :]) == 0:
                specificity = 0
            else:
                specificity = confusion_matrix[0, 0] / np.sum(confusion_matrix[0, :])
            if np.sum(confusion_matrix[:, 0]) == 0:
                npv = 0
            else:     
                npv = confusion_matrix[0, 0] / np.sum(confusion_matrix[:, 0])

            #Calculate auroc and auprc
            auroc = BinaryAUROC()
            auprc = BinaryAveragePrecision()


            #Create a dictionary of the results
            results_dict = {
                'test': test,
                'class_0_sample':df[df.lab_flag == 0].shape[0],
                'class_1_sample':df[df.lab_flag == 1].shape[0], 
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'specificity': specificity,
                'npv': npv,
                'auroc': auroc(outputs, targets.type(torch.int64)).item(),
                'auprc': auprc(outputs, targets.type(torch.int64)).item(),
                'interval': f"{int(time.split(' ')[0]) * 2} min",
                'Gender_age_used': GA,
                'notes': f'''''' 
            }

            #Append the results to the test_results dataframe
            test_results = pd.concat([test_results, pd.DataFrame([results_dict])], ignore_index=True)

            #Save the results to a csv file
            test_results.to_csv(os.path.join(repo_location, "Notebooks", panel, f"results/vital_{panel}_test_results.csv"), index=False)

  0%|          | 0/4 [00:00<?, ?it/s]

In [13]:
results_dict

{'test': 'Hemoglobin',
 'class_0_sample': 13175,
 'class_1_sample': 35601,
 'accuracy': 0.711476955879941,
 'precision': 0.7289725590299936,
 'recall': 0.9625853206370607,
 'f1': 0.8296473835203544,
 'specificity': 0.03294117647058824,
 'npv': 0.24575311438278596,
 'auroc': 0.4977632462978363,
 'auprc': 0.729006826877594,
 'interval': '30 min',
 'Gender_age_used': 'No_GA',
 'notes': ''}