In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torchvision import transforms, datasets, models
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score
from torchmetrics.utilities.data import dim_zero_cat
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Setup Dataloaders

In [3]:

# Define the paths to the train and test data folders
DATA_DIR = r'..\data'

# Define a transform to preprocess the images (resize and normalize without standardization)
train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.Resize((256, 256)),  # Resize the images to a fixed size
    transforms.RandomRotation((-10, 10)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize the images (mean and std for grayscale images)
])

val_test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.Resize((256, 256)),  # Resize the images to a fixed size
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize the images (mean and std for grayscale images)
])

# Create datasets for train and test
train_dataset = datasets.ImageFolder(root=DATA_DIR + '/train', transform=train_transform)
val_dataset = datasets.ImageFolder(root=DATA_DIR + '/train', transform=val_test_transform)
test_dataset = datasets.ImageFolder(root=DATA_DIR + '/test', transform=val_test_transform)


train_indices, val_indices = train_test_split(torch.arange(len(train_dataset)), test_size=0.15, random_state=42)
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)

# Create DataLoaders for train and test datasets
batch_size = 8
# initial_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True, pin_memory=True, num_workers=4)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [4]:
len(train_loader), len(val_loader), len(test_loader), len(train_dataset), len(val_dataset), len(test_dataset)

(555, 98, 78, 4433, 783, 624)

In [6]:
# not used
class CustomTensorDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform):
        """
        Args:
            data (List[torch.Tensor]): The preloaded list of tensors containing your inputs and labels.
        """
        self.inputs = data[0]
        self.labels = data[1]
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        item = [self.inputs[index], self.labels[index]]
        return item

## Load the models and change the architecture such that they take grayscale images and add a new classification head

1) Convnext Model

In [8]:
convnext = models.convnext_tiny(weights='DEFAULT')

In [9]:
# Freeze all the pre-trained layers
for param in convnext.parameters():
   param.requires_grad = False

In [10]:
convnext

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [11]:
state_dict = convnext.state_dict()

odict_keys(['features.0.0.weight', 'features.0.0.bias', 'features.0.1.weight', 'features.0.1.bias', 'features.1.0.layer_scale', 'features.1.0.block.0.weight', 'features.1.0.block.0.bias', 'features.1.0.block.2.weight', 'features.1.0.block.2.bias', 'features.1.0.block.3.weight', 'features.1.0.block.3.bias', 'features.1.0.block.5.weight', 'features.1.0.block.5.bias', 'features.1.1.layer_scale', 'features.1.1.block.0.weight', 'features.1.1.block.0.bias', 'features.1.1.block.2.weight', 'features.1.1.block.2.bias', 'features.1.1.block.3.weight', 'features.1.1.block.3.bias', 'features.1.1.block.5.weight', 'features.1.1.block.5.bias', 'features.1.2.layer_scale', 'features.1.2.block.0.weight', 'features.1.2.block.0.bias', 'features.1.2.block.2.weight', 'features.1.2.block.2.bias', 'features.1.2.block.3.weight', 'features.1.2.block.3.bias', 'features.1.2.block.5.weight', 'features.1.2.block.5.bias', 'features.2.0.weight', 'features.2.0.bias', 'features.2.1.weight', 'features.2.1.bias', 'feature

In [12]:
state_dict['features.0.0.weight'].shape

torch.Size([96, 3, 4, 4])

In [13]:
first_cov_weights = state_dict['features.0.0.weight']
state_dict['features.0.0.weight'] = first_cov_weights.sum(dim=1, keepdim=True)

In [14]:
state_dict['features.0.0.weight'].shape

torch.Size([96, 1, 4, 4])

In [15]:
convnext.state_dict()['features.0.0.weight'].shape

torch.Size([96, 3, 4, 4])

In [None]:
convnext.features[0][0] = nn.Conv2d(1, 96, kernel_size=(4,4), stride=(4,4))
convnext.load_state_dict(state_dict)

In [18]:
convnext.classifier

Sequential(
  (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=768, out_features=1000, bias=True)
)

In [19]:
convnext.classifier[2] = nn.Sequential(
    nn.Linear(in_features=768, out_features=512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(in_features=512, out_features=128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(in_features=128, out_features=1),
    nn.Sigmoid(),
)

In [20]:
convnext.classifier[2][0].weight

Parameter containing:
tensor([[-0.0140,  0.0028,  0.0242,  ..., -0.0270, -0.0256,  0.0099],
        [-0.0295, -0.0324, -0.0196,  ..., -0.0115,  0.0142, -0.0328],
        [ 0.0230,  0.0194, -0.0321,  ...,  0.0086,  0.0261,  0.0320],
        ...,
        [ 0.0220, -0.0158,  0.0016,  ..., -0.0260, -0.0055, -0.0259],
        [ 0.0170,  0.0310, -0.0176,  ..., -0.0282, -0.0292,  0.0321],
        [-0.0072, -0.0256,  0.0234,  ...,  0.0328, -0.0232, -0.0278]],
       requires_grad=True)

In [21]:
convnext.classifier[0].weight.requires_grad

False

2) Resnet18

In [24]:
resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

Using cache found in C:\Users\hornh/.cache\torch\hub\pytorch_vision_v0.10.0


In [25]:
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [26]:
# Freeze all the pre-trained layers
for param in resnet.parameters():
   param.requires_grad = False

# change first layer to grayscale
state_dict2 = resnet.state_dict()
state_dict2['conv1.weight'] = state_dict2['conv1.weight'].sum(dim=1, keepdim=True)
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet.load_state_dict(state_dict2)

<All keys matched successfully>

In [27]:
# add classification head
resnet.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features=256),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(in_features=256, out_features=128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(in_features=128, out_features=1),
    nn.Sigmoid(),
)

In [28]:
resnet

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

3) Efficientnet

In [29]:
efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b0', pretrained=True)
efficientnet

Using cache found in C:\Users\hornh/.cache\torch\hub\NVIDIA_DeepLearningExamples_torchhub


EfficientNet(
  (stem): Sequential(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (activation): SiLU(inplace=True)
  )
  (layers): Sequential(
    (0): Sequential(
      (block0): MBConvBlock(
        (depsep): Sequential(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
          (act): SiLU(inplace=True)
        )
        (se): SequentialSqueezeAndExcitation(
          (squeeze): Linear(in_features=32, out_features=8, bias=True)
          (expand): Linear(in_features=8, out_features=32, bias=True)
          (activation): SiLU(inplace=True)
          (sigmoid): Sigmoid()
          (mul_a_quantizer): Identity()
          (mul_b_quantizer): Identity()
        )
      

In [30]:
efficientnet.stem[0]

Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

In [31]:
# Freeze all the pre-trained layers
for param in efficientnet.parameters():
   param.requires_grad = False

# change first layer to grayscale
effnet_state_dict = efficientnet.state_dict()
effnet_state_dict['stem.conv.weight'] = effnet_state_dict['stem.conv.weight'].sum(dim=1, keepdim=True)
efficientnet.stem.conv = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
efficientnet.load_state_dict(effnet_state_dict)

In [32]:
# add classification head
efficientnet.classifier.fc = nn.Sequential(
    nn.Linear(in_features=1280, out_features=256),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(in_features=256, out_features=128),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(in_features=128, out_features=1),
    nn.Sigmoid(),
)

In [33]:
efficientnet

EfficientNet(
  (stem): Sequential(
    (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (activation): SiLU(inplace=True)
  )
  (layers): Sequential(
    (0): Sequential(
      (block0): MBConvBlock(
        (depsep): Sequential(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
          (act): SiLU(inplace=True)
        )
        (se): SequentialSqueezeAndExcitation(
          (squeeze): Linear(in_features=32, out_features=8, bias=True)
          (expand): Linear(in_features=8, out_features=32, bias=True)
          (activation): SiLU(inplace=True)
          (sigmoid): Sigmoid()
          (mul_a_quantizer): Identity()
          (mul_b_quantizer): Identity()
        )
      

In [39]:
def create_confusion_matrix_plot(tp, fp, tn, fn):
    df_cm = pd.DataFrame(torch.tensor([[tp, fp], [fn, tn]]), index=['normal', 'pneumonia'],
                         columns=['normal', 'pneumonia'])
    plt.figure(figsize=(12, 7))    
    return sns.heatmap(df_cm, annot=True).get_figure()

In [None]:
criterion = nn.BCELoss()
accuracy_metric = BinaryAccuracy().to(device)
f1_score_metric = BinaryF1Score().to(device)

In [40]:
for model_idx, model in enumerate([convnext, resnet, efficientnet]):
    model.to(device)

    if model_idx == 2:
        optimizer = torch.optim.Adam(model.classifier.fc.parameters(), lr=0.001)
    elif model_idx == 1:
        optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
    elif model_idx == 0:
        continue

    epochs = 4
    writer = SummaryWriter()


    for epoch in range(epochs):  
        model.train()
        running_loss = 0.0
        f1_score_metric.reset() 
        accuracy_metric.reset()

        print("---------------------")
        print('epoch', epoch)

        for i, data in enumerate(tqdm(train_loader)):

            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs).reshape(-1)
            loss = criterion(outputs, labels.type(torch.float32))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            bin_outputs = torch.round(outputs)

            
            accuracy_metric.update(bin_outputs, labels)
            f1_score_metric.update(bin_outputs, labels)


        train_loss = running_loss/len(train_dataset)
        train_f1 = f1_score_metric.compute()
        accuracy = accuracy_metric.compute()
        
        writer.add_scalar('Train/Loss', train_loss, epoch)
        writer.add_scalar('Train/F1-Score', train_f1, epoch)
        writer.add_scalar('Train/Accuracy', accuracy, epoch)
        
        print(train_loss, train_f1, accuracy)

        model.eval()
        with torch.no_grad():
            val_loss, val_tp, val_fp, val_fn, val_correct = 0.0, 0.0, 0.0, 0.0, 0.0
            epoch_labels, epoch_predictions = [], []
            f1_score_metric.reset()  
            accuracy_metric.reset()

            for data in tqdm(val_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs).reshape(-1)
                loss = criterion(outputs, labels.type(torch.float32))
                
                val_loss += loss.item()
                bin_outputs = torch.round(outputs)

                accuracy_metric.update(bin_outputs, labels)
                f1_score_metric.update(bin_outputs, labels)
                epoch_labels.append(labels)
                epoch_predictions.append(outputs)
                

            val_loss = val_loss / len(val_dataset)
            tp, fp, tn, fn = f1_score_metric.tp, f1_score_metric.fp, f1_score_metric.tn, f1_score_metric.fn
            val_f1 = f1_score_metric.compute()
            val_acc = accuracy_metric.compute()
            # Log validation metrics to TensorBoard
            writer.add_scalar('Validation/Loss', val_loss, epoch)
            writer.add_scalar('Validation/F1-Score', val_f1, epoch)
            writer.add_scalar('Validation/Accuracy', val_acc, epoch)
            writer.add_pr_curve('pr_curve', dim_zero_cat(epoch_labels), dim_zero_cat(epoch_predictions), epoch)
            writer.add_figure("Confusion matrix", create_confusion_matrix_plot(tp, fp, tn, fn), epoch)
        


    torch.save(model, f'model_run_{model_idx}.pth')

    writer.close()

---------------------
epoch 0


  0%|          | 0/555 [00:00<?, ?it/s]

batch num 0


  0%|          | 2/555 [00:19<1:12:56,  7.91s/it]

batch num 1
batch num 2


  1%|          | 4/555 [00:19<24:59,  2.72s/it]  

batch num 3
batch num 4


  1%|          | 5/555 [00:19<16:52,  1.84s/it]

batch num 5


  1%|          | 6/555 [00:20<11:54,  1.30s/it]

batch num 6


  1%|          | 6/555 [00:20<31:46,  3.47s/it]


accuracy tensor(0.7500, device='cuda:0')
0.0007419497305777335 tensor(0.8462, device='cuda:0') tensor(0.7500, device='cuda:0')


  0%|          | 0/98 [00:00<?, ?it/s]