In [None]:
import torch
print("PyTorch version:", torch.__version__)

# Pick the compute device once, up front. The CPU fallback is handled first so
# the GPU branch below can assume CUDA is usable.
if not torch.cuda.is_available():
    print("can't find GPU，and use CPU")
    device = torch.device("cpu")
    parallel_strategy = "CPU"
else:
    num_gpus = torch.cuda.device_count()
    print(f"GPU: {num_gpus} device(s) available")
    for gpu_index in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        print(f"  GPU{gpu_index}: {gpu_name}")

    # Make the first GPU both the CUDA default and the explicit target device.
    torch.cuda.set_device(0)
    device = torch.device("cuda:0")

    # Only the strategy label is chosen here; no model is wrapped at this point.
    if num_gpus > 1:
        print("Multi-GPU Training with DataParallel Strategy")
        parallel_strategy = "DataParallel"
    else:
        parallel_strategy = "Single GPU"

print("current device:", device)
print("strategy:", parallel_strategy)

In [None]:
class XRayDataset(Dataset):
    """Multi-label image dataset driven by a CSV of file names and labels.

    The first CSV column holds the image file name; every remaining column is
    read as one label value and converted to float32.
    """

    def __init__(self, csv_file, img_dirs, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            img_dirs (list): Directories to search for each image, in order.
                Each entry is resolved relative to the current working
                directory; a leading slash (e.g. '/resized_images') is accepted
                and ignored.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.labels_frame = pd.read_csv(csv_file)  # image names + label columns
        self.img_dirs = img_dirs                   # candidate image directories
        self.transform = transform                 # optional image transform

    def __len__(self):
        """Return the number of rows in the annotation CSV."""
        return len(self.labels_frame)

    def _resolve_path(self, img_name):
        """Return the first existing path for ``img_name``, or None if absent.

        Directories are joined onto the working directory with ``os.path.join``
        after stripping leading separators, so both 'images' and '/images'
        resolve to '<cwd>/images'. Plain string concatenation only worked for
        the second form.
        """
        cwd = os.getcwd()
        for img_dir in self.img_dirs:
            candidate = os.path.join(cwd, img_dir.lstrip("/\\"), img_name)
            if os.path.exists(candidate):
                return candidate
        return None

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, labels)``.

        Raises:
            FileNotFoundError: if the image is in none of ``img_dirs``.
        """
        img_name = self.labels_frame.iloc[idx, 0]
        img_path = self._resolve_path(img_name)
        if img_path is None:
            # The image could not be found in any of the configured directories.
            raise FileNotFoundError(f"{img_name} not found in any of the provided directories.")

        # convert() returns a new, fully loaded image, so the underlying file
        # handle can be closed as soon as the conversion is done.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Labels as a float tensor (multi-label: one value per label column).
        labels = torch.tensor(self.labels_frame.iloc[idx, 1:].values.astype('float32'))

        if self.transform is not None:
            image = self.transform(image)

        return image, labels

In [None]:
def filter_invalid_samples(dataset):
    """Return a Subset of ``dataset`` containing only the loadable samples.

    A sample is treated as invalid when indexing it raises FileNotFoundError
    (which is how XRayDataset reports a missing image) or when it returns None.
    Checking for None alone would let a single missing image abort the scan.

    Note: every sample is actually loaded once, so this costs a full pass over
    the dataset.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.

    Returns:
        torch.utils.data.Subset over the indices of the valid samples.
    """
    valid_indices = []
    for idx in range(len(dataset)):
        try:
            sample = dataset[idx]
        except FileNotFoundError:
            # Missing image file: skip this index instead of failing the scan.
            continue
        if sample is not None:
            valid_indices.append(idx)

    print(f"original samples: {len(dataset)}, valid samples: {len(valid_indices)}")
    return Subset(dataset, valid_indices)

# Preprocessing applied to every image before it reaches a model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),               # Resize to 224x224 (required by CNN)
    transforms.ToTensor(),                       # Convert to tensor (value range [0,1])
    transforms.Normalize([0.485, 0.456, 0.406],   # Normalize using ImageNet means
                         [0.229, 0.224, 0.225])   # and standard deviations
])
# Define the path to the CSV file and image directories
csv_file = './labels.csv'
img_dir_1 = '/resized_images_add1'
img_dir_2 = '/resized_images_add2'
img_dir_3 = '/resized_images/resized_images'
img_dirs = [img_dir_3]  # the baseline uses only the original images
# Create the dataset and drop samples whose image cannot be loaded
dataset = XRayDataset(csv_file=csv_file, img_dirs=img_dirs, transform=transform)
filtered_dataset = filter_invalid_samples(dataset)
# read the disease names from the CSV file
df = pd.read_csv(csv_file)
disease_names = df.columns[1:].tolist()  # Skip the first column (image names)
# Split the dataset into training and validation sets (80/20).
# A seeded generator keeps the split identical across kernel restarts, so
# results from different runs are comparable.
SPLIT_SEED = 42
train_size = int(0.8 * len(filtered_dataset))
test_size = len(filtered_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    filtered_dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Create DataLoader for training and validation sets
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation order does not affect the metrics, so shuffling is unnecessary.
val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class CNN(nn.Module):
    """Baseline CNN: three conv+ReLU+max-pool stages followed by a two-layer head.

    The first fully connected layer expects 64 * 28 * 28 features, which is
    what three 2x2 poolings produce from a 224x224 input. The network returns
    raw logits, one per class.
    """

    def __init__(self, num_classes=15):
        super().__init__()
        # Stage 1: [B, 3, 224, 224] -> conv -> [B, 16, 224, 224] -> pool -> [B, 16, 112, 112]
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: -> conv -> [B, 32, 112, 112] -> pool -> [B, 32, 56, 56]
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: -> conv -> [B, 64, 56, 56] -> pool -> [B, 64, 28, 28]
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head on the flattened feature map.
        flat_features = 64 * 28 * 28
        self.fc1 = nn.Linear(flat_features, 512)
        self.dropout = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of images to per-class logits."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(F.relu(conv(x)))

        # Keep the batch dimension, flatten everything else.
        hidden = F.relu(self.fc1(x.flatten(start_dim=1)))
        return self.fc2(self.dropout(hidden))

In [None]:
# Evaluation
from sklearn.metrics import f1_score
def evaluate_CNN(net, val_loader, disease_names, epoch = None, criterion = nn.BCEWithLogitsLoss()):
    """Evaluate ``net`` on ``val_loader`` and return the average loss per sample.

    The criterion is applied to the raw model outputs (logits). The default
    BCEWithLogitsLoss already applies a sigmoid internally, so feeding it
    sigmoid probabilities would squash the outputs twice and report a wrong
    loss. Probabilities are computed separately and used only for thresholding.

    Args:
        net: Model to evaluate; the batches are moved to its device.
        val_loader: DataLoader yielding ``(images, labels)`` batches.
        disease_names: Class names, one per label column.
        epoch: When None, a per-class accuracy / F1 report is printed as well.
            Any other value suppresses the report (used during training).
        criterion: Loss taking ``(logits, labels)``.

    Returns:
        float: summed batch loss divided by ``len(val_loader.dataset)``.
    """
    net.eval()
    device = next(net.parameters()).device

    target_batches = []
    pred_batches = []
    total_correct = torch.zeros(len(disease_names), device=device)
    total_samples = 0
    total_loss = 0.0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            logits = net(images)

            # Loss on logits; weight by batch size so the final value is a
            # per-sample mean even when the last batch is smaller.
            loss = criterion(logits, labels)
            total_loss += loss.item() * images.size(0)

            probs = torch.sigmoid(logits)
            predictions = (probs > 0.5).float()
            total_correct += (predictions == labels).sum(dim=0)
            total_samples += labels.size(0)

            # Move to CPU right away so batches do not pile up in GPU memory.
            target_batches.append(labels.cpu())
            pred_batches.append(predictions.cpu())

    avg_loss = total_loss / len(val_loader.dataset)

    if epoch is None:
        all_targets = torch.cat(target_batches).numpy()
        all_preds = torch.cat(pred_batches).numpy()

        accuracy = (total_correct / total_samples * 100).cpu()
        print("\n=== Per Disease Accuracy ===")
        for disease, acc in zip(disease_names, accuracy):
            print(f"{disease}: {acc:.2f}%")
        min_idx = int(accuracy.argmin())
        print(f"=== Worst Performing Disease by accuracy: {disease_names[min_idx]} ({accuracy[min_idx]:.2f}%)")

        print("\n=== Per Disease F1 Score ===")
        f1_scores = f1_score(all_targets, all_preds, average=None, zero_division=0)
        for disease, f1 in zip(disease_names, f1_scores):
            print(f"{disease}: {f1:.4f}")
        min_idx = int(f1_scores.argmin())
        print(f"=== Worst Performing Disease by F1: {disease_names[min_idx]} (F1: {f1_scores[min_idx]:.4f})")

        f1_micro = f1_score(all_targets, all_preds, average="micro", zero_division=0)
        f1_macro = f1_score(all_targets, all_preds, average="macro", zero_division=0)
        print(f"\nF1 Score (Micro): {f1_micro:.4f}")
        print(f"F1 Score (Macro): {f1_macro:.4f}")

    return avg_loss

In [None]:
def train_CNN(net, train_loader, val_loader, device, num_epochs=20):
    """Train ``net`` with Adam + BCEWithLogitsLoss and validate after each epoch.

    The learning rate starts at 1e-3 and is halved after epochs 5, 10 and 15.

    Args:
        net: Model to train (expected to already live on ``device``).
        train_loader: DataLoader of ``(images, labels)`` training batches.
        val_loader: DataLoader passed to ``evaluate_CNN`` after every epoch.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        tuple[list, list]: per-epoch training losses and validation losses.
        Returning them keeps the histories available after the call; as
        function locals they would be lost.
    """
    criterion = nn.BCEWithLogitsLoss() # Use Binary Cross Entropy for multi-label classification
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.5)

    start_time = time.time()
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        net.train()
        running_loss = 0.0
        samples_seen = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per sample.
            running_loss += loss.item() * images.size(0)
            samples_seen += images.size(0)

        scheduler.step()
        avg_loss = running_loss / samples_seen
        train_losses.append(avg_loss)
        elapsed = (time.time() - start_time) / 60
        # Passing the epoch index suppresses the detailed per-class report.
        val_loss = evaluate_CNN(net, val_loader, disease_names, epoch, criterion)
        val_losses.append(val_loss)
        print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.4f}, Validation Loss: {val_loss:.4f}, Time: {elapsed:.2f} min")

    return train_losses, val_losses

baseline_cnn_net = CNN(num_classes=len(disease_names)).to(device)
# Keep the loss histories at notebook level; the loss-curve cell plots them.
train_loss_CNN_plot, val_loss_CNN_plot = train_CNN(baseline_cnn_net, train_loader, val_loader, device)
torch.save(baseline_cnn_net.state_dict(), "baseline_cnn_net.pth")
print("Model saved as baseline_cnn_net.pth")

In [None]:
class ResNet50_MultiLabel(nn.Module):
    """ImageNet-pretrained ResNet-50 with its classifier replaced by a multi-label head.

    The original ``fc`` layer is swapped for Dropout -> Linear(512) -> ReLU ->
    Dropout -> Linear(n_classes). The model returns raw logits, one per class.

    Args:
        n_classes: Number of output logits.
    """

    def __init__(self, n_classes):
        super().__init__()

        # Newer torchvision releases deprecate `pretrained=True` in favour of
        # the `weights=` argument; IMAGENET1K_V1 is the checkpoint the legacy
        # flag selected. Fall back to the old flag when the weights enum is
        # not available in the installed version.
        if hasattr(models, "ResNet50_Weights"):
            self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        else:
            self.backbone = models.resnet50(pretrained=True)

        in_features = self.backbone.fc.in_features

        # Replace the stock classifier with the multi-label head.
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.16),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, n_classes)
        )

    def forward(self, x):
        """Run the backbone (including the replaced head) on a batch of images."""
        return self.backbone(x)

# Size the head from the label columns (as the CNN baseline does) rather than a
# hard-coded class count, so the model follows the CSV.
resnet = ResNet50_MultiLabel(len(disease_names)).to(device)

In [None]:
def evaluate_resnet(resnet, val_loader, criterion = nn.BCEWithLogitsLoss(), epoch=None):
    """Evaluate ``resnet`` on ``val_loader`` and return the average loss per sample.

    Args:
        resnet: Model to evaluate; the batches are moved to its device.
        val_loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        epoch: When None, a per-class accuracy / AUC-ROC / F1 report is printed
            using the notebook-level ``disease_names``. Any other value
            suppresses the report (used during training).

    Returns:
        float: summed batch loss divided by ``len(val_loader.dataset)``.
    """
    resnet.eval()
    # Use the model's own device instead of a notebook-level global.
    device = next(resnet.parameters()).device
    total_loss = 0.0

    label_batches = []
    prob_batches = []
    pred_batches = []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = resnet(images)

            loss = criterion(outputs, labels)
            total_loss += loss.item() * images.size(0)

            probs = torch.sigmoid(outputs)
            predictions = (probs > 0.5).float()

            # Collect whole batches; the number of classes is whatever the
            # model emits rather than a fixed constant.
            label_batches.append(labels.cpu())
            prob_batches.append(probs.cpu())
            pred_batches.append(predictions.cpu())

    avg_loss = total_loss / len(val_loader.dataset)

    if epoch is None:
        # Arrays are (n_samples, n_classes): the layout scikit-learn expects
        # for multi-label input. A (n_classes, n_samples) nested list would be
        # read as transposed, yielding one "class" per sample.
        all_labels = torch.cat(label_batches).numpy()
        all_probs = torch.cat(prob_batches).numpy()
        all_preds = torch.cat(pred_batches).numpy()

        print("\n=== Per Disease Accuracy ===")
        acc_list = ((all_preds == all_labels).mean(axis=0) * 100).tolist()
        for disease, accuracy in zip(disease_names, acc_list):
            print(f"{disease}: {accuracy:.2f}%")

        min_idx = acc_list.index(min(acc_list))
        print(f"\nWorst performing disease: {disease_names[min_idx]} ({acc_list[min_idx]:.2f}%)")

        print("\n=== Per Disease AUC-ROC ===")
        for i, disease in enumerate(disease_names):
            try:
                auc = roc_auc_score(all_labels[:, i], all_probs[:, i])
                print(f"{disease}: AUC = {auc:.4f}")
            except ValueError:
                print(f"{disease}: AUC = N/A (only one class present in labels)")

        print("\n=== Per Disease F1 Score ===")
        # One call yields every per-class score; zero_division=0 covers classes
        # with no positive labels or predictions.
        f1_scores = f1_score(all_labels, all_preds, average=None, zero_division=0)
        for disease, f1 in zip(disease_names, f1_scores):
            print(f"{disease}: F1 = {f1:.4f}")
        min_idx = int(f1_scores.argmin())
        print(f"=== Worst Performing Disease by F1: {disease_names[min_idx]} (F1: {f1_scores[min_idx]:.4f})")

        f1_micro = f1_score(all_labels, all_preds, average="micro", zero_division=0)
        f1_macro = f1_score(all_labels, all_preds, average="macro", zero_division=0)
        print(f"\nF1 Score (Micro): {f1_micro:.4f}")
        print(f"F1 Score (Macro): {f1_macro:.4f}")

    return avg_loss


In [None]:
def train_resnet(net, train_loader, val_loader, num_epochs=20):
    """Train ``net`` with Adam + BCEWithLogitsLoss and validate after each epoch.

    Everything operates on the ``net`` argument, so any model passed in is the
    one that gets trained, and ``num_epochs`` controls the loop length. The
    learning rate starts at 1e-3 and is halved after epochs 5, 10 and 15.

    Args:
        net: Model to train; the batches are moved to its device.
        train_loader: DataLoader of ``(images, labels)`` training batches.
        val_loader: DataLoader passed to ``evaluate_resnet`` after every epoch.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        tuple[list, list]: per-epoch training losses and validation losses.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10, 15], gamma=0.5)
    device = next(net.parameters()).device

    start_time = time.time()
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        net.train()
        running_loss = 0.0
        samples_seen = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            # Set the gradients to zeros
            optimizer.zero_grad()
            # forward the minibatch through the net
            outputs = net(images)
            # Compute the average of the losses of the data points in the minibatch
            loss = criterion(outputs, labels)
            # backward pass to compute the parameter gradients
            loss.backward()
            # do one optimizer step
            optimizer.step()
            # Weight by batch size so the epoch average is per sample, the same
            # convention the validation loss uses.
            running_loss += loss.item() * images.size(0)
            samples_seen += images.size(0)

        # Advance the LR schedule once per epoch; without this call the
        # milestones never take effect.
        scheduler.step()
        avg_loss = running_loss / samples_seen
        train_losses.append(avg_loss)
        elapsed = (time.time() - start_time) / 60
        # Passing the epoch index suppresses the detailed per-class report.
        val_loss = evaluate_resnet(net, val_loader, criterion, epoch)
        val_losses.append(val_loss)
        print(f"\nEpoch {epoch+1}, Train Loss: {avg_loss:.4f}, Validation Loss: {val_loss:.4f}, Time: {elapsed:.2f} min")

    return train_losses, val_losses

# Keep the loss histories at notebook level; the loss-curve cell plots them.
train_loss_resnet_plot, val_loss_resnet_plot = train_resnet(resnet, train_loader, val_loader, num_epochs=20)
torch.save(resnet.state_dict(), "resnet.pth")
print("✅ Model saved as resnet.pth")

In [None]:
# One figure, two panels side by side: CNN on the left, ResNet on the right.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))

# (title, training curve, validation curve) for each panel, left to right.
loss_panels = (
    ('CNN Loss', train_loss_CNN_plot, val_loss_CNN_plot),
    ('ResNet Loss', train_loss_resnet_plot, val_loss_resnet_plot),
)
for panel_ax, (panel_title, train_curve, val_curve) in zip(axs, loss_panels):
    panel_ax.plot(train_curve, label='Train Loss')
    panel_ax.plot(val_curve, label='Validation Loss')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel('Loss')
    panel_ax.legend(loc='upper right')

# automatically adjust subplot spacing so labels do not overlap
plt.tight_layout()
plt.show()

# Effect of Image Augmentation and Transformation
Because the previous models performed poorly, further tuning was needed. Based on what we observed in the earlier training runs, we made the following changes to train a better-performing model.

1.   Image catalogue update and Image Augmentation

Images produced by data augmentation and image transformation have been added to the dataset. Refer to the Data Preprocessing section for the specific operations.

Purpose: To improve the prediction performance on unseen samples by increasing the diversity of samples and mitigating overfitting.
2.   Number of training rounds (Epoch) adjustment

The number of training epochs is set to 5.
In the previous training experiments, the model showed clear overfitting around the 8th epoch (training loss kept falling while validation loss rose), so we deliberately reduced the number of epochs in this round to keep the overfitting from recurring.

Purpose: To maintain the generalisation ability during training while learning sufficiently, and to improve the stability of the final model on the validation set.

In [None]:
# Search the original image directory first, then the two additional ones.
new_img_dirs = [img_dir_3, img_dir_1, img_dir_2]
csv_file = './new_labels.csv'
# Create the dataset
new_dataset = XRayDataset(csv_file=csv_file, img_dirs=new_img_dirs, transform=transform)

# read the disease names from the CSV file
new_df = pd.read_csv(csv_file)
new_disease_names = new_df.columns[1:].tolist()  # Skip the first column (image names)
# Split the dataset into training and validation sets (80/20); the seeded
# generator makes the split reproducible across kernel restarts.
new_train_size = int(0.8 * len(new_dataset))
new_test_size = len(new_dataset) - new_train_size
new_train_dataset, new_test_dataset = torch.utils.data.random_split(
    new_dataset,
    [new_train_size, new_test_size],
    generator=torch.Generator().manual_seed(42),
)
# Create DataLoader for training and validation sets
batch_size = 32
new_train_loader = DataLoader(new_train_dataset, batch_size=batch_size, shuffle=True)
new_val_loader = DataLoader(new_test_dataset, batch_size=batch_size, shuffle=False)
# Train with new data: the model is sized from the new label columns and fed
# the loaders built above, not the baseline loaders.
full_resnet_net = ResNet50_MultiLabel(len(new_disease_names)).to(device)
train_resnet(full_resnet_net, new_train_loader, new_val_loader, num_epochs=5)
torch.save(full_resnet_net.state_dict(), "full_resnet.pth")
print("✅ Model saved as full_resnet.pth")