# Setup packages and GPUs

In [2]:
# !pip install -U torch -q
# !pip install -U torchvision -q
# !pip install -U Pillow -q

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim
from torch.optim import lr_scheduler
import segmentation_models_pytorch as smp
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import shutil
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF
import os, warnings
import random
from pathlib import Path
from PIL import Image
import glob
import pandas as pd
import time
import copy
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns


warnings.filterwarnings('ignore')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Import Data

Split the original database into the training, testing and validation sets.
- Build the required directory structure first.

In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
root_dir = '/content/drive/MyDrive/Projects/datasets/COVID-19_Radiography_Dataset'
splitted_dir = os.path.join(root_dir, "Splitted_database")
os.makedirs(splitted_dir, exist_ok=True)
source_dir = '/content/drive/MyDrive/Projects/datasets/COVID-19_Radiography_Dataset'
target_dir = '/content/drive/MyDrive/Projects/datasets/COVID-19_Radiography_Dataset/Splitted_database'
os.makedirs(target_dir, exist_ok=True)
train_dir = os.path.join(target_dir, 'train')
val_dir = os.path.join(target_dir, 'val')
test_dir = os.path.join(target_dir, 'test')

Mounted at /content/drive


Function for splitting the database. Here is the folder structure after splitting:
* **Splitted_database/**
  * **train/** - [70% of data]
    * **COVID/**
      * **images/** - [70% of COVID images]
      * **masks/** - [70% of COVID masks]
    * **Viral_Pneumonia/**
      * **images/** - [70% of Viral_Pneumonia images]
      * **masks/** - [70% of Viral_Pneumonia masks]
    * **Normal/**
      * **images/** - [70% of Normal images]
      * **masks/** - [70% of Normal masks]
  * **val/** - [15% of data]
    * **COVID/**
      * **images/** - [15% of COVID images]
      * **masks/** - [15% of COVID masks]
    * **Viral_Pneumonia/**
      * **images/** - [15% of Viral_Pneumonia images]
      * **masks/** - [15% of Viral_Pneumonia masks]
    * **Normal/**
      * **images/** - [15% of Normal images]
      * **masks/** - [15% of Normal masks]
  * **test/** - [15% of data]
    * **COVID/**
      * **images/** - [15% of COVID images]
      * **masks/** - [15% of COVID masks]
    * **Viral_Pneumonia/**
      * **images/** - [15% of Viral_Pneumonia images]
      * **masks/** - [15% of Viral_Pneumonia masks]
    * **Normal/**
      * **images/** - [15% of Normal images]
      * **masks/** - [15% of Normal masks]

In [3]:
def split_dataset(test_ratio=0.15, val_ratio=0.15, seed=42):
    """
    Split the dataset into training, validation and testing sets while maintaining the folder structure.

    Args:
        source_dir (str): Path to the source dataset directory
        target_dir (str): Path to the target directory where train/val/test will be created
        test_ratio (float): Ratio of test data (default: 0.15)
        val_ratio (float): Ratio of validation data (default: 0.15)
        seed (int): Random seed for reproducibility
    """
    random.seed(seed)
    np.random.seed(seed)
    # Create the folder structure
    for class_name in ['COVID', 'Viral_Pneumonia', 'Normal']:
        os.makedirs(os.path.join(train_dir, class_name, 'images'), exist_ok=True)
        os.makedirs(os.path.join(train_dir, class_name, 'masks'), exist_ok=True)
        os.makedirs(os.path.join(val_dir, class_name, 'images'), exist_ok=True)
        os.makedirs(os.path.join(val_dir, class_name, 'masks'), exist_ok=True)
        os.makedirs(os.path.join(test_dir, class_name, 'images'), exist_ok=True)
        os.makedirs(os.path.join(test_dir, class_name, 'masks'), exist_ok=True)
    for class_name in ['COVID', 'Viral_Pneumonia', 'Normal']:
        print(f"Processing {class_name} class...")
        img_dir = os.path.join(source_dir, class_name, 'images')
        mask_dir = os.path.join(source_dir, class_name, 'masks')

        image_files = sorted([f for f in os.listdir(img_dir) if not f.startswith('.') and "(" not in f])
        mask_files = sorted([f for f in os.listdir(mask_dir) if not f.startswith('.') and "(" not in f])
        # Check if the numbers of image and mask are equal
        if len(image_files) != len(mask_files):
            print(f"Warning: Number of images ({len(image_files)}) doesn't match number of masks ({len(mask_files)}) for {class_name}")
        # Paired image with its corresponding masks
        paired_files = []
        for img_file in image_files:
            mask_basename = os.path.splitext(img_file)[0]
            matching_masks = [m for m in mask_files if os.path.splitext(m)[0] == mask_basename]

            if matching_masks:
                paired_files.append((img_file, matching_masks[0]))
            else:
                print(f"Warning: No matching mask found for image {img_file}")
        random.shuffle(paired_files)
        total_count = len(paired_files)
        test_size = int(total_count * test_ratio)
        val_size = int(total_count * val_ratio)
        train_size = total_count - test_size - val_size
        test_pairs = paired_files[:test_size]
        val_pairs = paired_files[test_size:test_size+val_size]
        train_pairs = paired_files[test_size+val_size:]
        print(f"  {class_name}: Total: {total_count}, Train: {len(train_pairs)}, Val: {len(val_pairs)}, Test: {len(test_pairs)}")
        # Load images into new directory in splitted directory
        for img_file, mask_file in train_pairs:
            src_img = os.path.join(source_dir, class_name, 'images', img_file)
            dst_img = os.path.join(train_dir, class_name, 'images', img_file)
            shutil.copy2(src_img, dst_img)
            src_mask = os.path.join(source_dir, class_name, 'masks', mask_file)
            dst_mask = os.path.join(train_dir, class_name, 'masks', mask_file)
            shutil.copy2(src_mask, dst_mask)
        for img_file, mask_file in val_pairs:
            src_img = os.path.join(source_dir, class_name, 'images', img_file)
            dst_img = os.path.join(val_dir, class_name, 'images', img_file)
            shutil.copy2(src_img, dst_img)
            src_mask = os.path.join(source_dir, class_name, 'masks', mask_file)
            dst_mask = os.path.join(val_dir, class_name, 'masks', mask_file)
            shutil.copy2(src_mask, dst_mask)
        for img_file, mask_file in test_pairs:
            src_img = os.path.join(source_dir, class_name, 'images', img_file)
            dst_img = os.path.join(test_dir, class_name, 'images', img_file)
            shutil.copy2(src_img, dst_img)
            src_mask = os.path.join(source_dir, class_name, 'masks', mask_file)
            dst_mask = os.path.join(test_dir, class_name, 'masks', mask_file)
            shutil.copy2(src_mask, dst_mask)
    print("Dataset splitting complete!")
    # Print dataset structure
    print("\nDataset Structure:")
    print("-" * 60)
    print("Class       | Train Images | Valid Images  | Test Images | Total")
    print("-" * 60)
    total_train = 0
    total_val = 0
    total_test = 0
    for class_name in ['COVID', 'Viral_Pneumonia', 'Normal']:
        train_count = len(os.listdir(os.path.join(train_dir, class_name, 'images')))
        val_count = len(os.listdir(os.path.join(val_dir, class_name, 'images')))
        test_count = len(os.listdir(os.path.join(test_dir, class_name, 'images')))
        total = train_count + val_count + test_count
        print(f"{class_name:<12}| {train_count:<13}| {val_count:<12}| {test_count:<12}| {total}")
        total_train += train_count
        total_val += val_count
        total_test += test_count
    print("-" * 60)
    print(f"Total       | {total_train:<13}| {total_val:<12}| {total_test:<12}| {total_train + total_val + total_test}")
    print("-" * 60)

In [4]:
# Except the first time or the dataset not splitted as expect, don't run this line:
# split_dataset(test_ratio=0.15, val_ratio=0.15)

Check if the database is splitted as expected data proportion(Here is test/valdation/train = 15%/15%/70%), pairred masks and structure.   
This step is to check the last step is finished or not (cuz take too long time last step).

In [5]:
def count_files(directory):
    """Count the number of files in a directory."""
    if not os.path.exists(directory):
        return 0
    return len([f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))])


def validate_split(base_path):
    # if new split method (as 10/10/80) should change here:
    expected_counts = {
        'COVID': {'total': 3616, 'train': 2532, 'val': 542, 'test': 542},
        'Viral_Pneumonia': {'total': 1345, 'train': 943, 'val': 201, 'test': 201},
        'Normal': {'total': 10192, 'train': 7136, 'val': 1528, 'test': 1528}
    }
    actual_counts = {
        'COVID': {'train_images': 0, 'train_masks': 0,
                  'val_images': 0, 'val_masks': 0,
                  'test_images': 0, 'test_masks': 0},
        'Viral_Pneumonia': {'train_images': 0, 'train_masks': 0,
                           'val_images': 0, 'val_masks': 0,
                           'test_images': 0, 'test_masks': 0},
        'Normal': {'train_images': 0, 'train_masks': 0,
                  'val_images': 0, 'val_masks': 0,
                  'test_images': 0, 'test_masks': 0}
    }
    for split in ['train', 'val', 'test']:
        split_path = os.path.join(base_path, split)
        if not os.path.exists(split_path):
            print(f"Warning: Split directory '{split_path}' does not exist.")
            continue

        for class_name in ['COVID', 'Viral_Pneumonia', 'Normal']:
            class_path = os.path.join(split_path, class_name)
            if not os.path.exists(class_path):
                print(f"Warning: Class directory '{class_path}' does not exist.")
                continue

            images_path = os.path.join(class_path, 'images')
            masks_path = os.path.join(class_path, 'masks')

            if os.path.exists(images_path):
                actual_counts[class_name][f'{split}_images'] = count_files(images_path)
            else:
                print(f"Warning: Images directory '{images_path}' does not exist.")

            if os.path.exists(masks_path):
                actual_counts[class_name][f'{split}_masks'] = count_files(masks_path)
            else:
                print(f"Warning: Masks directory '{masks_path}' does not exist.")
    results = []

    for class_name, counts in actual_counts.items():
        train_images = counts['train_images']
        val_images = counts['val_images']
        test_images = counts['test_images']
        total_images = train_images + val_images + test_images

        train_masks = counts['train_masks']
        val_masks = counts['val_masks']
        test_masks = counts['test_masks']
        total_masks = train_masks + val_masks + test_masks

        expected = expected_counts[class_name]
        images_match = (
            train_images == expected['train'] and
            val_images == expected['val'] and
            test_images == expected['test'] and
            total_images == expected['total']
        )
        masks_match_images = (
            train_masks == train_images and
            val_masks == val_images and
            test_masks == test_images and
            total_masks == total_images
        )
        train_pct = round(train_images / total_images * 100, 1) if total_images > 0 else 0
        val_pct = round(val_images / total_images * 100, 1) if total_images > 0 else 0
        test_pct = round(test_images / total_images * 100, 1) if total_images > 0 else 0
        results.append({
            'Class': class_name,
            'Train Images': train_images,
            'Val Images': val_images,
            'Test Images': test_images,
            'Total Images': total_images,
            'Train %': train_pct,
            'Val %': val_pct,
            'Test %': test_pct,
            'Train Masks': train_masks,
            'Val Masks': val_masks,
            'Test Masks': test_masks,
            'Total Masks': total_masks,
            'Expected Train': expected['train'],
            'Expected Val': expected['val'],
            'Expected Test': expected['test'],
            'Expected Total': expected['total'],
            'Images Match Expected': images_match,
            'Masks Match Images': masks_match_images
        })
    df = pd.DataFrame(results)
    print("\n=== Dataset Split Validation Results ===\n")
    print(df[['Class', 'Train Images', 'Train %', 'Val Images', 'Val %', 'Test Images', 'Test %', 'Total Images',
              'Images Match Expected', 'Masks Match Images']])
    mismatches = df[~df['Images Match Expected'] | ~df['Masks Match Images']]
    if not mismatches.empty:
        print("\n=== Detailed Analysis for Mismatches ===\n")
        for _, row in mismatches.iterrows():
            class_name = row['Class']
            print(f"Class: {class_name}")

            if not row['Images Match Expected']:
                print("  Image count mismatch:")
                print(f"    Train: {row['Train Images']} (Expected: {row['Expected Train']})")
                print(f"    Val: {row['Val Images']} (Expected: {row['Expected Val']})")
                print(f"    Test: {row['Test Images']} (Expected: {row['Expected Test']})")
                print(f"    Total: {row['Total Images']} (Expected: {row['Expected Total']})")

            if not row['Masks Match Images']:
                print("  Mask-Image count mismatch:")
                print(f"    Train: {row['Train Masks']} masks vs {row['Train Images']} images")
                print(f"    Val: {row['Val Masks']} masks vs {row['Val Images']} images")
                print(f"    Test: {row['Test Masks']} masks vs {row['Test Images']} images")
                print(f"    Total: {row['Total Masks']} masks vs {row['Total Images']} images")
            print()
    else:
        print("\nAll classes have correct splits and matching masks/images counts! ✓")
    overall_correct_split = True
    for _, row in df.iterrows():
        # if new split method (as 10/10/80) should change here also:
        if abs(row['Train %'] - 70) > 1 or abs(row['Val %'] - 15) > 1 or abs(row['Test %'] - 15) > 1:
            overall_correct_split = False
    if overall_correct_split:
        print("\nSplit Correctly!")
    else:
        print("\nWarning: The distribution doesn't match the expected distribution!!!")
    return df

In [9]:
if __name__ == "__main__":
    results = validate_split(splitted_dir)


=== Dataset Split Validation Results ===

             Class  Train Images  Train %  Val Images  Val %  Test Images  \
0            COVID          2532     70.0         542   15.0          542   
1  Viral_Pneumonia           943     70.1         201   14.9          201   
2           Normal          7136     70.0        1528   15.0         1528   

   Test %  Total Images  Images Match Expected  Masks Match Images  
0    15.0          3616                   True                True  
1    14.9          1345                   True                True  
2    15.0         10192                   True                True  

All classes have correct splits and matching masks/images counts! ✓

Split Correctly!


load data and apply data pre-processing steps on it.

In [6]:
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
class_types = ['train', 'val', 'test']
def data_pre_processing(img_path, mask_path):
    try:
        # Normalization Config from ImageNet based models
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        img = Image.open(img_path).convert('RGB')
        mask = Image.open(mask_path).convert('L')
        img_tensor = transform(img)
        mask = mask.resize((224, 224), Image.NEAREST)
        mask_temsor = TF.to_tensor(mask)
        mask_temsor = (mask_temsor > 0.5).float()
    except Exception as e:
        print(f"Error processing {img_path} and {mask_path}: {e}")
        return None, None
    return img_tensor, mask_temsor

In [8]:
class_map = {'COVID': 0, 'Viral_Pneumonia': 1, 'Normal': 2}
dataloaders = {}
# only shuffle the training set
for class_l in class_types:
    images = []
    masks = []
    labels = []
    cur_class_path = os.path.join(target_dir, class_l)
    if not os.path.exists(cur_class_path):
        print(f"Warning! {cur_class_path} does not exists!")
        break
    else:
        print(cur_class_path)
        for class_name in class_map.keys():
            class_dir = os.path.join(cur_class_path, class_name)
        if not os.path.exists(class_dir):
            print(f"Warning! {class_dir} does not exists!")
            break
        else:
            image_dir = os.path.join(class_dir, 'images')
            mask_dir = os.path.join(class_dir, 'masks')
            if not os.path.exists(image_dir) or not os.path.exists(mask_dir):
                print(f"Warning! {image_dir} or {mask_dir} does not exists!")
                break
            else:
                image_files = glob.glob(os.path.join(image_dir, '*'))
                for img_path in image_files:
                    # get the correspounding mask file
                    # the mask and the correspounding original image has the same name
                    img_filename = os.path.basename(img_path)
                    mask_path = os.path.join(mask_dir, img_filename)
                    if os.path.exists(mask_path):
                        img_tensor, mask_tensor = data_pre_processing(img_path, mask_path)
                        if img_tensor is not None or mask_tensor is not None:
                            images.append(img_tensor)
                            masks.append(mask_tensor)
                            labels.append(class_map[class_name])
                if images:
                    images = torch.stack(images)
                    masks = torch.stack(masks)
                    labels = torch.tensor(labels)
                    print(f"{class_l} loaded {len(images)} samples")
                    dataset = TensorDataset(images, masks, labels)
                    shuffle = (class_l == 'train')
                    dataloaders[class_l] = DataLoader(dataset, batch_size=16, shuffle=shuffle)
                    print(dataloaders)
                else:
                    print(f"No images found in {class_l}, please check code + data folder again")
                    break


/content/drive/MyDrive/Projects/datasets/COVID-19_Radiography_Dataset/Splitted_database/train


KeyboardInterrupt: 

Show a batch of the sample.

In [15]:
print(dataloaders)
images, masks, labels = next(iter(dataloaders['train']))
mean = torch.tensor(mean).view(1, 3, 1, 1)
std = torch.tensor(std).view(1, 3, 1, 1)
images_display = (images * std) + mean
images_display = images_display.clamp(0, 1)
fig, axes = plt.subplots(4, 3, figsize=(12, 10))
# only display first four
for i in range(4):
    img = images_display[i].permute(1, 2, 0).cpu().numpy()
    axes[i, 0].imshow(img)
    axes[i, 0].set_title(f'Image {i+1}, Class: {class_types[labels[i]]}')
    axes[i, 0].axis('off')
    axes[i, 1].imshow(masks[i].cpu().numpy(), cmap='gray')
    axes[i, 1].set_title(f'Mask {i+1}')
    axes[i, 1].axis('off')
    overlay = img.copy()
    mask_np = masks[i][0].numpy()
    # show mask as a red overlay layer
    for c in range(3):
        channel_overlay = overlay[:, :, c].copy()
        if c == 0:
            channel_overlay[mask_np > 0.5] = 1.0
        else:
            channel_overlay[mask_np > 0.5] = 0.5
        overlay[:, :, c] = channel_overlay
    alpha = 0.3
    blended = (1-alpha) * img + alpha * overlay
    axes[i, 2].imshow(blended)
    axes[i, 2].set_title(f'Blended {i+1}')
    axes[i, 2].axis('off')
plt.tight_layout()
plt.show()

{}


KeyError: 'train'

# Create The Model
## Create the Unet

In [None]:
unet_model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3,
    activation="sigmoid"
)

## Create the ResNet-18

In [None]:
# this pretrain is on imagenet by pytorch
resnet18_model = models.resnet18(pretrained=True)
num_ftrs = resnet18_model.fc.in_features
resnet18_model = nn.Linear(num_ftrs, 3)
return resnet_18

## Combined two models together

Loss function Combination.

In [None]:
# loss weight should be [seg_weight:cls_weight] to define which event is more important
def loss_function(seg_pred, cls_pred, seg_true, cls_true, loss_weight):
    seg_criterion = nn.BCELoss()
    cls_criterion = nn.CrossEntropyLoss()
    seg_loss = seg_criterion(seg_pred, seg_true)
    cls_loss = cls_criterion(cls_pred, cls_true)
    total_loss = seg_loss * loss_weight[0] + cls_loss * loss_weight[1]
    return total_loss, seg_loss, cls_loss

In [None]:
def combine_model(x):
    original_conv = resnet18_model.conv1
    resnet18_model.conv1 = nn.Conv2d(
        4,
        original_conv.out_channels,
        kernel_size=original_conv.kernel_size,
        stride=original_conv.stride
        padding=original_conv.padding,
        bias=original_conv.bias
    )
    with torch.no_grad():
        resnet18_model.conv1.weight[:, :3] = original_conv.weight
    # foward propagation function
    def foward_pf(x):
        mask = unet_model(x)
        x_combine = torch.cat((images, mask), dim=1)
        classification = resnet18_model(x_combine)
        return mask, classification
    return foward_pf, (unet_model, resnet18_model)

# Training the Model

In [None]:
# scheduler: to adjust the study rate
# patience: for early stopping
def train_model(foward_pf, models, dataloaders, optimizer, loss_weight, scheduler=None, num_epochs=25, patience=5):
    # preparation stage
    seg_model, classfy_model = models
    seg_model = seg_model.to(device)
    classify_model = classify_model.to(device)
    best_seg_state = copy.deepcopy(seg_model.state_dict())
    best_cls_state = copy.deepcopy(classify_model.state_dict())
    best_acc = 0.0
    best_epoch = 0
    history = {
        'train_loss': [], 'val_loss': [],
        'train_seg_loss': [], 'val_seg_loss': [],
        'train_cls_loss': [], 'val_cls_loss': [],
        'train_acc': [], 'val_acc': []
    }
    counter = 0
    # training stage
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                seg_model.train()
                classify_model.train()
            else:
                seg_model.eval()
                classify_model.eval()
            running_loss = 0.0
            running_seg_loss = 0.0
            running_cls_loss = 0.0
            running_corrects = 0

            for inputs, masks, labels in dataloaders[phase]:
