In [1]:
import os
import time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES=True
from torchvision import models
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from tempfile import TemporaryDirectory

In [2]:
class CustomDataset(Dataset):
    """Image dataset indexed by a ';'-delimited CSV file.

    Each CSV row provides an image stem in column ``filename_hr`` and a label
    in column ``normal``; the image is read from ``<imgs_path>/<stem>.jpg``.

    Args:
        imgs_path: Directory that contains the ``.jpg`` files.
        csv_path: Path to the CSV index.
        transform: Optional callable applied to the RGB PIL image.
        start_idx: Optional first CSV row (inclusive) to keep.
        end_idx: Optional last CSV row (exclusive) to keep.
    """

    def __init__(self, imgs_path, csv_path, transform=None, start_idx=None, end_idx=None):
        self.imgs = imgs_path
        self.csv = pd.read_csv(csv_path, delimiter=';')
        self.transform = transform
        self.start_idx = start_idx
        self.end_idx = end_idx

        # Apply the slice when either bound is given; a missing bound is left
        # open (``iloc[None:end]`` / ``iloc[start:None]``) instead of the whole
        # request being silently ignored.
        if start_idx is not None or end_idx is not None:
            self.csv = self.csv.iloc[start_idx:end_idx].copy().reset_index(drop=True)

    def __len__(self):
        """Return the number of CSV rows (rows with missing images included)."""
        return len(self.csv)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for row ``idx``.

        If the image file of row ``idx`` does not exist, the next row whose
        file exists is used instead, wrapping around at the end of the table.

        Raises:
            IndexError: ``idx`` is outside ``[-len(self), len(self))``.
            FileNotFoundError: no row in the table has an existing image file.
        """
        total = len(self.csv)
        # Keep the sequence protocol intact: out-of-range indices must still
        # raise IndexError rather than being wrapped by the modulo below.
        if not -total <= idx < total:
            raise IndexError(f'index {idx} is out of range for dataset of size {total}')

        name_col = self.csv.columns.get_loc('filename_hr')
        target_col = self.csv.columns.get_loc('normal')

        # Iterative search (no recursion) so long runs of missing files cannot
        # overflow the stack, and a missing last row cannot run off the table.
        for offset in range(total):
            row = (idx + offset) % total
            img_name = self.csv.iloc[row, name_col]
            img_path = os.path.join(self.imgs, img_name + '.jpg')
            if not os.path.exists(img_path):
                continue

            target = self.csv.iloc[row, target_col]
            # ``convert`` returns a fully loaded copy, so the file handle can
            # be released as soon as the ``with`` block exits.
            with Image.open(img_path) as img_file:
                image = img_file.convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, target

        raise FileNotFoundError(f'no image file found under {self.imgs!r} for any CSV row')

In [3]:
# --- Data configuration ---
imgs_path = './ECGs/'
csv_path = './scp_codes.csv'
batch_size = 32
val_split = 0.15
test_split = 0.15
# Seed for the train/val/test partition. Without it every kernel restart draws
# a different split, so a checkpoint from an earlier session may be scored on
# samples it was trained on.
split_seed = 42

# Resize to 224x224 and normalize each channel with the mean/std values
# conventionally used with torchvision's ImageNet-pretrained models.
transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

dataset = CustomDataset(imgs_path, csv_path, transform)

# The training split takes whatever is left after the integer-truncated
# validation and test sizes, so the three sizes always sum to len(dataset).
num_samples = len(dataset)
num_val = int(val_split * num_samples)
num_test = int(test_split * num_samples)
num_train = num_samples - num_val - num_test

train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [num_train, num_val, num_test],
    generator=torch.Generator().manual_seed(split_seed),
)

# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Index = value of the model's predicted class.
class_names = ['Abnormal', 'Normal']

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def train_model(model, criterion, optimizer, scheduler, train_dataloader, val_dataloader, num_epochs):
    """Train ``model`` and return it loaded with its best-validation weights.

    Every epoch runs one pass over ``train_dataloader`` (with gradient updates)
    followed by one pass over ``val_dataloader`` (without). ``scheduler.step()``
    is called once per epoch after the training pass. Whenever the validation
    accuracy exceeds the best seen so far, the weights are checkpointed to a
    temporary directory; the best checkpoint is reloaded before returning.

    Batches are moved to the module-level ``device``.

    Args:
        model: Network to optimise.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped once per epoch.
        train_dataloader: Loader yielding ``(inputs, labels)`` training batches.
        val_dataloader: Loader yielding ``(inputs, labels)`` validation batches.
        num_epochs: Number of epochs to run.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    start_time = time.time()
    loaders = {'train': train_dataloader, 'val': val_dataloader}

    # Checkpoints live in a throw-away directory removed when the block exits.
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

        # Seed the checkpoint so there is always something to reload.
        torch.save(model.state_dict(), best_model_params_path)
        best_val_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            # Training pass first, then validation pass.
            for phase, dataloader in loaders.items():
                is_training = phase == 'train'
                model.train(is_training)  # train(False) is what eval() does

                loss_sum = 0.0
                correct_count = 0

                for inputs, labels in dataloader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()

                    # Autograd history is only recorded during training.
                    with torch.set_grad_enabled(is_training):
                        outputs = model(inputs)
                        preds = torch.max(outputs, 1).indices
                        loss = criterion(outputs, labels)

                        if is_training:
                            loss.backward()
                            optimizer.step()

                    # Weight the batch-mean loss by batch size so the epoch
                    # figure is a per-sample average.
                    loss_sum += loss.item() * inputs.size(0)
                    correct_count += (preds == labels.data).sum()

                if is_training:
                    scheduler.step()

                num_examples = len(dataloader.dataset)
                epoch_loss = loss_sum / num_examples
                epoch_acc = correct_count.double() / num_examples

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                # Keep the weights of the best validation epoch.
                if not is_training and epoch_acc > best_val_acc:
                    best_val_acc = epoch_acc
                    torch.save(model.state_dict(), best_model_params_path)

            print()

        time_elapsed = time.time() - start_time
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_val_acc:.4f}')

        # Restore the best checkpoint before the temp directory disappears.
        model.load_state_dict(torch.load(best_model_params_path))
    return model

In [5]:
def evaluate_model(model, criterion, test_dataloader):
    """Compute and print the mean loss and accuracy of ``model`` on a loader.

    The model is switched to evaluation mode and batches are moved to the
    module-level ``device``; no gradients are recorded for the forward pass.

    Args:
        model: Network to evaluate.
        criterion: Loss called as ``criterion(outputs, labels)``.
        test_dataloader: Loader yielding ``(inputs, labels)`` batches.

    Returns:
        ``(test_loss, test_acc)``: the per-sample mean loss as a float and the
        accuracy as a float64 tensor.
    """
    model.eval()
    loss_sum = 0.0
    correct_count = 0

    for inputs, labels in test_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            preds = torch.max(outputs, 1).indices

        # Weight the batch-mean loss by batch size for a per-sample average.
        loss_sum += loss.item() * inputs.size(0)
        correct_count += (preds == labels.data).sum()

    num_examples = len(test_dataloader.dataset)
    test_loss = loss_sum / num_examples
    test_acc = correct_count.double() / num_examples

    print(f'Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f}')

    return test_loss, test_acc

In [6]:
def visualize_model(model, dataloader, num_images=6,
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot the first ``num_images`` inputs of ``dataloader`` with predictions.

    Images are laid out in a two-column grid and titled with the predicted
    entry of the module-level ``class_names``. The per-channel normalization is
    undone before plotting so the images appear in their original colours
    instead of wrapped ``uint8`` values.

    Args:
        model: Network used to predict; its train/eval mode is restored on
            exit, even if plotting raises.
        dataloader: Loader yielding ``(inputs, labels)`` batches; labels are
            not used.
        num_images: Maximum number of images to draw (odd values supported).
        mean: Per-channel mean that was subtracted by the input transform.
        std: Per-channel std the inputs were divided by in the input transform.
    """
    was_training = model.training
    model.eval()
    images_so_far = 0
    # Ceil division so an odd ``num_images`` still gets enough grid rows.
    num_rows = (num_images + 1) // 2
    plt.figure(figsize=(10, 12))

    # Shaped (1, 1, C) to broadcast over an image laid out as (H, W, C).
    mean_t = torch.tensor(mean).view(1, 1, -1)
    std_t = torch.tensor(std).view(1, 1, -1)

    try:
        with torch.no_grad():
            for inputs, _ in dataloader:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    if images_so_far >= num_images:
                        return
                    images_so_far += 1
                    ax = plt.subplot(num_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {class_names[preds[j]]}')
                    # (C, H, W) -> (H, W, C), then invert ``(x - mean) / std``.
                    image = inputs[j].cpu().permute(1, 2, 0) * std_t + mean_t
                    # imshow expects float RGB values within [0, 1].
                    ax.imshow(image.clamp(0, 1).numpy())
    finally:
        model.train(mode=was_training)

In [7]:
# ImageNet-pretrained ResNet-18 whose final layer is replaced by a small
# fully connected head ending in 2 output logits.
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
fc = torch.nn.Sequential(
    torch.nn.Linear(in_features=num_ftrs, out_features=128),
    torch.nn.ReLU(inplace=True),
    torch.nn.Dropout(p=0.6),
    torch.nn.Linear(in_features=128, out_features=32),
    torch.nn.ReLU(inplace=True),
    torch.nn.Dropout(p=0.6),
    torch.nn.Linear(in_features=32, out_features=2),
)
model.fc = fc
model = model.to(device)

# Loss, optimizer over every parameter (backbone included), and a step
# schedule that multiplies the learning rate by 0.3 every 5 scheduler steps.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.3)

In [8]:
# Print the module tree of the network, including the replaced ``fc`` head.
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
# Train for 10 epochs; train_model returns the model reloaded with the weights
# of the epoch that reached the highest validation accuracy.
model = train_model(model, criterion, optimizer, scheduler, train_dataloader, val_dataloader, num_epochs=10)

Epoch 0/9
----------
train Loss: 0.4482 Acc: 0.7889
val Loss: 0.3981 Acc: 0.8168

Epoch 1/9
----------
train Loss: 0.3778 Acc: 0.8372
val Loss: 0.3393 Acc: 0.8431

Epoch 2/9
----------
train Loss: 0.3590 Acc: 0.8463
val Loss: 0.5339 Acc: 0.7755

Epoch 3/9
----------
train Loss: 0.3535 Acc: 0.8516
val Loss: 0.3123 Acc: 0.8703

Epoch 4/9
----------
train Loss: 0.3444 Acc: 0.8596
val Loss: 0.3082 Acc: 0.8626

Epoch 5/9
----------
train Loss: 0.2988 Acc: 0.8752
val Loss: 0.2860 Acc: 0.8776

Epoch 6/9
----------
train Loss: 0.2822 Acc: 0.8824
val Loss: 0.2915 Acc: 0.8807

Epoch 7/9
----------
train Loss: 0.2682 Acc: 0.8884
val Loss: 0.3036 Acc: 0.8786

Epoch 8/9
----------
train Loss: 0.2580 Acc: 0.8945
val Loss: 0.3089 Acc: 0.8807

Epoch 9/9
----------
train Loss: 0.2449 Acc: 0.9004
val Loss: 0.3165 Acc: 0.8730

Training complete in 450m 30s
Best val Acc: 0.8807


In [10]:
# Report mean loss and accuracy on the held-out test split.
evaluate_model(model, criterion, test_dataloader)

Test Loss: 0.3110 Test Acc: 0.8654


(0.3110252802216631, tensor(0.8654, dtype=torch.float64))

In [11]:
# Persist only the weights (state_dict) to the working directory.
# NOTE(review): the file name appears to encode hyperparameters (LR / D / E) —
# confirm it matches this run (lr=0.0003, dropout=0.6, 10 epochs).
torch.save(model.state_dict(), 'LR0003-D65-E10-FD.pth')

In [13]:
from sklearn.metrics import precision_score, recall_score

def evaluate_model_PR(model, criterion, test_dataloader):
    """Evaluate ``model`` and report loss, accuracy, precision and recall.

    The model is switched to evaluation mode and batches are moved to the
    module-level ``device``. Precision and recall are computed with
    scikit-learn's ``precision_score`` / ``recall_score`` using their default
    arguments over all collected labels and predictions.

    Args:
        model: Network to evaluate.
        criterion: Loss called as ``criterion(outputs, labels)``.
        test_dataloader: Loader yielding ``(inputs, labels)`` batches.

    Returns:
        ``(test_loss, test_acc, test_precision, test_recall)``.
    """
    model.eval()
    loss_sum = 0.0
    correct_count = 0
    y_true = []
    y_pred = []

    for inputs, labels in test_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            preds = torch.max(outputs, 1).indices

        # Weight the batch-mean loss by batch size for a per-sample average.
        loss_sum += loss.item() * inputs.size(0)
        correct_count += (preds == labels.data).sum()

        # Accumulate on the CPU for the scikit-learn metrics below.
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

    num_examples = len(test_dataloader.dataset)
    test_loss = loss_sum / num_examples
    test_acc = correct_count.double() / num_examples
    test_precision = precision_score(y_true, y_pred)
    test_recall = recall_score(y_true, y_pred)

    print(f'Test Loss: {test_loss:.4f} Test Acc: {test_acc:.4f}')
    print(f'Test Precision: {test_precision:.4f} Test Recall: {test_recall:.4f}')

    return test_loss, test_acc, test_precision, test_recall


In [14]:
# Evaluate the trained model on a second image set with its own CSV index,
# using the same preprocessing as the training data.
china_dataset = CustomDataset('./china_images/', './scp_codes_china.csv', transform)
# Evaluation gains nothing from shuffling; a fixed order matches the val/test
# loaders and keeps repeated runs identical.
china_dataloader = DataLoader(china_dataset, batch_size=batch_size, shuffle=False)

evaluate_model_PR(model, criterion, china_dataloader)

Test Loss: 0.6401 Test Acc: 0.7484
Test Precision: 0.0511 Test Recall: 0.4307


(0.6400832505407801,
 tensor(0.7484, dtype=torch.float64),
 0.05111633372502938,
 0.4306930693069307)

# Rebuild the architecture without pretrained weights; everything is
# overwritten by the checkpoint below. The head is constructed fresh instead of
# reusing the existing ``fc`` module, so loading the checkpoint cannot
# overwrite the head of the previously trained model in place.
model = models.resnet18()
model.fc = torch.nn.Sequential(
    torch.nn.Linear(model.fc.in_features, 128),
    torch.nn.ReLU(inplace=True),
    torch.nn.Dropout(0.6),
    torch.nn.Linear(128, 32),
    torch.nn.ReLU(inplace=True),
    torch.nn.Dropout(0.6),
    torch.nn.Linear(32, 2)
)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): this file name differs from the checkpoint saved earlier in the
# notebook ('LR0003-D65-E10-FD.pth') — confirm this is the intended file.
model.load_state_dict(torch.load('LR0001D65E10.pth', map_location=device))
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
# The scheduler must wrap the NEW optimizer; the previous scheduler is still
# bound to the old optimizer and would leave this learning rate unscheduled.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)
model = train_model(model, criterion, optimizer, scheduler, train_dataloader, val_dataloader, num_epochs=5)

evaluate_model(model, criterion, test_dataloader)