In [None]:
from google.colab import drive
# Mount Google Drive so the dataset and checkpoint paths under /content/drive resolve.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# **Dependencies**

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import os
import cv2
from glob import glob
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
# Seed every RNG the notebook may touch. torch drives weight init, random_split and
# DataLoader shuffling; Python's `random` and NumPy are seeded too so any future use
# of them is reproducible under Restart & Run All.
import random

SEED = 0
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator's repr
random.seed(SEED)
np.random.seed(SEED)

# **Dataset Preparation**




In [None]:
def pair_images_with_masks(image_dir, mask_dir, mask_suffix_len=9):
    """Return aligned ``(images, masks)`` path lists for one class directory.

    For every file in ``mask_dir`` the matching image is expected at
    ``image_dir + "/" + <mask basename minus its last mask_suffix_len chars> + ".png"``.
    The default of 9 characters fits a ``_mask.png`` suffix. NOTE(review): that
    suffix is inferred from the original ``[:-9]`` slicing; confirm it against
    the dataset.

    Masks whose image is missing are printed and dropped, so ``images[k]``
    always belongs to ``masks[k]``. The previous code kept every mask while
    dropping unmatched images, which misaligned the two lists.

    NOTE: glob order is filesystem-dependent. It is kept unchanged here so
    random_split assigns the same files as before on the same filesystem.
    """
    # Glob the image directory once. The original re-globbed it for every mask (O(n^2)).
    available_images = set(glob(os.path.join(image_dir, "*")))
    images, masks = [], []
    for mask_path in glob(os.path.join(mask_dir, "*")):
        img_path = image_dir + "/" + os.path.basename(mask_path)[:-mask_suffix_len] + ".png"
        if img_path in available_images:
            images.append(img_path)
            masks.append(mask_path)
        else:
            print(img_path)
    return images, masks


COVID_IMAGE_DIR = "/content/drive/MyDrive/chest-xray-dataset/images/covid"
COVID_MASK_DIR = "/content/drive/MyDrive/chest-xray-dataset/masks/covid"
PNEUMONIA_IMAGE_DIR = "/content/drive/MyDrive/chest-xray-dataset/images/pneumonia"
PNEUMONIA_MASK_DIR = "/content/drive/MyDrive/chest-xray-dataset/masks/pneumonia"
NORMAL_IMAGE_DIR = "/content/drive/MyDrive/chest-xray-dataset/images/normal"
NORMAL_MASK_DIR = "/content/drive/MyDrive/chest-xray-dataset/masks/normal"

covid_images, covid_masks = pair_images_with_masks(COVID_IMAGE_DIR, COVID_MASK_DIR)
pneumonia_images, pneumonia_masks = pair_images_with_masks(PNEUMONIA_IMAGE_DIR, PNEUMONIA_MASK_DIR)
normal_images, normal_masks = pair_images_with_masks(NORMAL_IMAGE_DIR, NORMAL_MASK_DIR)

images = covid_images + pneumonia_images + normal_images
masks = covid_masks + pneumonia_masks + normal_masks
assert len(images) == len(masks), "image and mask lists are misaligned"

if not images:
    print("WARNING: no image/mask pairs found - check that Drive is mounted and the dataset paths exist")

print("no. of covid images :", len(covid_images))
print("no. of covid masks :", len(covid_masks))
print("no. of normal images :", len(normal_images))
print("no. of normal masks :", len(normal_masks))
print("no. of pneumonia images :", len(pneumonia_images))
print("no. of pneumonia masks :", len(pneumonia_masks))

In [None]:
class ChestXRAYDataset(Dataset):
  """Pairs chest X-ray images with their segmentation masks.

  Args:
    images: list of image file paths.
    masks: list of mask file paths; ``masks[k]`` must belong to ``images[k]``.
    transform: optional callable applied separately to the image and to the mask.
    size: side length both arrays are resized to. Defaults to 512, which the
      training code in this notebook assumes.

  Raises:
    ValueError: if ``images`` and ``masks`` differ in length.
  """

  def __init__(self, images, masks, transform=None, size=512):
    if len(images) != len(masks):
      raise ValueError(
          f"got {len(images)} images but {len(masks)} masks; the lists must be aligned")
    self.transform = transform
    self.images = images
    self.masks = masks
    self.size = size

  def __getitem__(self, index):
    img = self._read(self.images[index], cv2.IMREAD_COLOR)
    mask = self._read(self.masks[index], cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (self.size, self.size))
    # Nearest-neighbour keeps the mask's original label values. The default
    # bilinear interpolation invents intermediate values along mask edges,
    # which are then truncated when the mask is cast to class indices.
    mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)

    if self.transform is not None:
      img = self.transform(img)
      mask = self.transform(mask)
    return img, mask

  def __len__(self):
    return len(self.images)

  @staticmethod
  def _read(path, flags):
    """cv2.imread that raises instead of silently returning None."""
    arr = cv2.imread(path, flags)
    if arr is None:
      raise FileNotFoundError(f"could not read image file: {path}")
    return arr
      

# **Load and Transform**

In [None]:

# Number of samples per optimisation step / evaluation batch.
BATCH_SIZE: int = 1

In [None]:
# ToTensor converts an HxWxC uint8 array into a CxHxW float tensor scaled to [0, 1].
# Normalize with mean 0 / std 1 leaves values unchanged (effectively a no-op).
# The same pipeline is applied to the images and to the masks.
_transform_steps = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.], std=[1.]),
]
transform = torchvision.transforms.Compose(_transform_steps)

In [None]:
# One dataset over all three classes; it is split into train/val/test below.
dataset = ChestXRAYDataset(images=images, masks=masks, transform=transform)

In [None]:
# Roughly 80/10/10 train/val/test split. Val and test each get floor(10%) of the pairs.
VAL_TEST_FRACTION = 0.1
total = len(images)
p_10 = int(total * VAL_TEST_FRACTION)
if p_10 == 0:
    # With empty val/test splits, validation divides by zero, and an empty train split
    # makes next(iter(train_dl)) raise StopIteration. Fail here with a clear message.
    raise RuntimeError(
        f"Only {total} image/mask pairs found; at least 10 are needed for non-empty "
        "validation and test splits. Check the Drive mount and dataset paths.")
train_size = total - p_10 * 2

In [None]:
# The lengths must sum to len(dataset). The split draws from the global torch RNG,
# which is seeded near the top of the notebook.
split_lengths = [train_size, p_10, p_10]
train_ds, val_ds, test_ds = torch.utils.data.random_split(dataset, split_lengths)

In [None]:
tuple(map(len, (train_ds, val_ds, test_ds)))

(0, 0, 0)

In [None]:
# The truncated `cpuset_checked` output further down is presumably the DataLoader warning
# that 4 workers exceed the CPUs available to this process, so cap the worker count.
_available_cpus = (len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity")
                   else (os.cpu_count() or 1))
NUM_WORKERS = min(4, _available_cpus)
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

train_dl = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
# Averaged evaluation metrics do not depend on order. A fixed order makes runs and
# the example plots at the end reproducible.
test_dl = DataLoader(test_ds, shuffle=False, **_loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

# **Preview**

In [None]:
# Pull one training batch to sanity-check tensor shapes.
inputs, targets = next(iter(train_dl))
for label, tensor in (("Inputs: ", inputs), ("Targets: ", targets)):
    print(label, tensor.size())

In [None]:
inputs, targets = next(iter(train_dl))

# (C, H, W) -> (H, W, C) for matplotlib; squeeze() drops a singleton channel axis.
input_img = inputs[0].permute(1, 2, 0)
target_img = targets[0].permute(1, 2, 0)

for shown, colour_map in ((input_img, None), (target_img, 'gray')):
    plt.imshow(shown.squeeze(), cmap=colour_map)
    plt.show()

# **Model**

## Architecture

In [None]:


class DenseLayer(nn.Sequential):
    """BN -> ReLU -> 3x3 conv -> Dropout2d(0.2) unit producing `growth_rate` channels.

    Stride 1 with padding 1 keeps the spatial size. The child names
    ('norm', 'relu', 'conv', 'drop') become the state_dict keys stored in checkpoints.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        stages = (
            ('norm', nn.BatchNorm2d(in_channels)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3,
                               stride=1, padding=1, bias=True)),
            ('drop', nn.Dropout2d(0.2)),
        )
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

    def forward(self, x):
        # Apply the registered stages in order.
        return super().forward(x)


class DenseBlock(nn.Module):
    """Stack of `n_layers` DenseLayers with dense (concatenative) connectivity.

    Layer k receives the block input concatenated with the outputs of layers
    0..k-1, i.e. in_channels + k * growth_rate channels.

    Returns:
        With upsample=False: the input concatenated with every new feature map,
        giving in_channels + n_layers * growth_rate channels.
        With upsample=True: only the new feature maps, giving
        n_layers * growth_rate channels.
    """

    def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
        super().__init__()
        self.upsample = upsample
        dense_layers = []
        for idx in range(n_layers):
            dense_layers.append(DenseLayer(in_channels + idx * growth_rate, growth_rate))
        self.layers = nn.ModuleList(dense_layers)

    def forward(self, x):
        produced = []
        features = x
        for dense_layer in self.layers:
            fresh = dense_layer(features)
            features = torch.cat([features, fresh], 1)  # dim 1 = channel axis
            if self.upsample:
                produced.append(fresh)
        if self.upsample:
            return torch.cat(produced, 1)
        return features


class TransitionDown(nn.Sequential):
    """Encoder downsampling step.

    BN -> ReLU -> 1x1 conv (channel count unchanged) -> Dropout2d(0.2) -> 2x2
    max-pool. Channels are kept and height/width are halved (floored). The child
    names are the state_dict keys stored in checkpoints.
    """

    def __init__(self, in_channels):
        super().__init__()
        pipeline = [
            ('norm', nn.BatchNorm2d(num_features=in_channels)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(in_channels, in_channels, kernel_size=1,
                               stride=1, padding=0, bias=True)),
            ('drop', nn.Dropout2d(0.2)),
            ('maxpool', nn.MaxPool2d(2)),
        ]
        for child_name, child in pipeline:
            self.add_module(child_name, child)

    def forward(self, x):
        return super().forward(x)


class TransitionUp(nn.Module):
    """Decoder upsampling step.

    A 3x3, stride-2, unpadded transposed convolution maps a side of length H to
    2H+1. The result is centre-cropped to the skip tensor's spatial size and then
    concatenated with the skip along the channel axis, giving
    out_channels + skip channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.convTrans = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=0,
            bias=True,
        )

    def forward(self, x, skip):
        target_h, target_w = skip.size(2), skip.size(3)
        upsampled = center_crop(self.convTrans(x), target_h, target_w)
        return torch.cat([upsampled, skip], 1)


class Bottleneck(nn.Sequential):
    """Dense block at the bottom of the encoder/decoder.

    Wraps one upsample-mode DenseBlock registered as 'bottleneck', so it returns
    only the newly produced growth_rate * n_layers channels.
    """

    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        inner = DenseBlock(in_channels, growth_rate, n_layers, upsample=True)
        self.add_module('bottleneck', inner)

    def forward(self, x):
        return super().forward(x)


def center_crop(layer, max_height, max_width):
    """Return the centred max_height x max_width window of a 4-D (N, C, H, W) tensor.

    Offsets use floor division, so an odd surplus leaves the extra row/column
    at the bottom/right edge.
    """
    _, _, height, width = layer.size()
    top = (height - max_height) // 2
    left = (width - max_width) // 2
    rows = slice(top, top + max_height)
    cols = slice(left, left + max_width)
    return layer[:, :, rows, cols]

class FCDenseNet(nn.Module):
    """Fully convolutional DenseNet for per-pixel classification.

    Structure (as built in __init__):
      * first 3x3 conv (padding 1) -> out_chans_first_conv channels;
      * encoder: for each entry of ``down_blocks``, a DenseBlock that appends
        growth_rate * n channels, then a TransitionDown (2x2 max-pool). Each
        DenseBlock output is kept as a skip connection;
      * bottleneck: an upsample-mode DenseBlock that emits only its
        growth_rate * bottleneck_layers new channels;
      * decoder: for each entry of ``up_blocks``, a TransitionUp (transposed
        conv, crop to the skip's size, concat with the skip), then a
        DenseBlock. All but the last return only their new features;
      * head: 1x1 conv to ``n_classes`` channels, then LogSoftmax over dim 1.

    Args:
        in_channels: channels of the input image.
        down_blocks: DenseLayers per encoder block.
        up_blocks: DenseLayers per decoder block. forward() pops one skip
            connection per up block, and the final block is sized from the
            first encoder skip, so this assumes len(up_blocks) == len(down_blocks).
        bottleneck_layers: DenseLayers in the bottleneck block.
        growth_rate: channels added by each DenseLayer.
        out_chans_first_conv: output channels of the first convolution.
        n_classes: number of output channels (classes).

    NOTE(review): the head returns log-probabilities, and the training cells
    pair the model with nn.CrossEntropyLoss, which applies log_softmax again.
    log_softmax is idempotent, so the loss value is unchanged, but NLLLoss is
    the conventional partner of a LogSoftmax head.

    NOTE(review): optimizer.load_state_dict (used when resuming) matches state
    to parameters by position. Reordering the submodule registrations below
    would silently misalign a saved optimizer state.
    """
    def __init__(self, in_channels=3, down_blocks=(5,5,5,5,5),
                 up_blocks=(5,5,5,5,5), bottleneck_layers=5,
                 growth_rate=16, out_chans_first_conv=48, n_classes=12):
        super().__init__()
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        cur_channels_count = 0
        # Channel counts of the encoder skips, most recent (deepest) first,
        # because each new count is inserted at index 0.
        skip_connection_channel_counts = []

        ## First Convolution ##

        self.add_module('firstconv', nn.Conv2d(in_channels=in_channels,
                  out_channels=out_chans_first_conv, kernel_size=3,
                  stride=1, padding=1, bias=True))
        cur_channels_count = out_chans_first_conv

        #####################
        # Downsampling path #
        #####################

        self.denseBlocksDown = nn.ModuleList([])
        self.transDownBlocks = nn.ModuleList([])
        for i in range(len(down_blocks)):
            self.denseBlocksDown.append(
                DenseBlock(cur_channels_count, growth_rate, down_blocks[i]))
            # Non-upsample DenseBlock: input channels plus growth_rate per layer.
            cur_channels_count += (growth_rate*down_blocks[i])
            skip_connection_channel_counts.insert(0,cur_channels_count)
            self.transDownBlocks.append(TransitionDown(cur_channels_count))

        #####################
        #     Bottleneck    #
        #####################

        self.add_module('bottleneck',Bottleneck(cur_channels_count,
                                     growth_rate, bottleneck_layers))
        # The bottleneck returns only its new features.
        prev_block_channels = growth_rate*bottleneck_layers
        # This value is overwritten inside the upsampling loop before it is read.
        cur_channels_count += prev_block_channels

        #######################
        #   Upsampling path   #
        #######################

        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        for i in range(len(up_blocks)-1):
            self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels))
            # TransitionUp output = upsampled features concatenated with the i-th deepest skip.
            cur_channels_count = prev_block_channels + skip_connection_channel_counts[i]

            self.denseBlocksUp.append(DenseBlock(
                cur_channels_count, growth_rate, up_blocks[i],
                    upsample=True))
            # Upsample-mode block: only its new features feed the next TransitionUp.
            prev_block_channels = growth_rate*up_blocks[i]
            cur_channels_count += prev_block_channels

        ## Final DenseBlock ##

        self.transUpBlocks.append(TransitionUp(
            prev_block_channels, prev_block_channels))
        # The last decoder stage joins the shallowest skip (first encoder block).
        cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1]

        # upsample=False: this block returns its input plus all new features.
        self.denseBlocksUp.append(DenseBlock(
            cur_channels_count, growth_rate, up_blocks[-1],
                upsample=False))
        cur_channels_count += growth_rate*up_blocks[-1]

        ## Softmax ##

        self.finalConv = nn.Conv2d(in_channels=cur_channels_count,
               out_channels=n_classes, kernel_size=1, stride=1,
                   padding=0, bias=True)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return per-pixel class log-probabilities of shape (N, n_classes, H', W').

        The last TransitionUp crops to the first encoder skip. That skip has the
        firstconv output size (stride 1, padding 1), so H', W' equal the input's H, W.
        """
        out = self.firstconv(x)

        skip_connections = []
        for i in range(len(self.down_blocks)):
            out = self.denseBlocksDown[i](out)
            skip_connections.append(out)
            out = self.transDownBlocks[i](out)

        out = self.bottleneck(out)
        for i in range(len(self.up_blocks)):
            # LIFO: the deepest skip is consumed first.
            skip = skip_connections.pop()
            out = self.transUpBlocks[i](out, skip)
            out = self.denseBlocksUp[i](out)

        out = self.finalConv(out)
        out = self.softmax(out)
        return out


In [None]:
def FCDenseNet103(n_classes):
    """Build the FC-DenseNet103 configuration for 3-channel input.

    Uses growth rate 16, a 48-channel first conv, 15 bottleneck layers and
    mirrored encoder/decoder block depths.
    """
    config = dict(
        in_channels=3,
        down_blocks=(4, 5, 7, 10, 12),
        up_blocks=(12, 10, 7, 5, 4),
        bottleneck_layers=15,
        growth_rate=16,
        out_chans_first_conv=48,
    )
    return FCDenseNet(n_classes=n_classes, **config)

## Implementation

In [None]:
# Two output channels: binary segmentation (mask background vs. foreground).
model = FCDenseNet103(n_classes=2)

In [None]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# nn.Module.to() moves the parameters in place and returns the same module. Assigning
# the result stops the notebook from dumping the very long module repr as cell output.
model = model.to(device)

FCDenseNet(
  (firstconv): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (denseBlocksDown): ModuleList(
    (0): DenseBlock(
      (layers): ModuleList(
        (0): DenseLayer(
          (norm): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv): Conv2d(48, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (drop): Dropout2d(p=0.2, inplace=False)
        )
        (1): DenseLayer(
          (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv): Conv2d(64, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (drop): Dropout2d(p=0.2, inplace=False)
        )
        (2): DenseLayer(
          (norm): BatchNorm2d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv): Conv2d(80, 16, kernel_size=(3, 3), strid

# **Training and Validation**

In [None]:
# SGD hyper-parameters. A weight decay of 1e-10 is negligible in practice.
LEARNING_RATE = 1e-2
WEIGHT_DECAY = 1e-10

In [None]:
# Plain SGD (no momentum) over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# Per-pixel classification loss; targets are class indices.
criterion = nn.CrossEntropyLoss()

In [None]:
# Divide the LR by 10 after 20, 30 and 40 scheduler steps. The training loop steps it
# once per epoch and checkpoints its state, so the count carries across resumed runs.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[20, 30, 40],
    gamma=0.1,
)

In [None]:
# Epochs to run in this session, on top of any epochs restored from a checkpoint.
EPOCHS: int = 15

In [None]:
# Checkpoint file. The training loop overwrites it every epoch, and the "Load Model" cell reads it.
MODEL_PATH = os.path.join("/content/drive/MyDrive", "final_year_project", "fcdensenet6.pth")

In [None]:
def get_IOU(pred_m, gt_m):
    """Smoothed Jaccard index (IoU) between a predicted and a ground-truth mask.

    Any non-zero value counts as foreground, so label encodings such as 0/1 and
    0/255 score identically. Arguments may be torch tensors on any device, or
    array-likes, of the same shape. The whole input is scored as a single mask,
    so with a batch every pixel of every sample is pooled together.

    Returns:
        float: (|P and G| + 1) / (|P or G| + 1). The +1 smoothing makes two
        empty masks score 1.0 instead of dividing by zero. A plain Python float
        is returned (not a 0-d tensor), so metric history lists hold numbers.
    """
    pred = torch.as_tensor(pred_m).detach().cpu() != 0
    gt = torch.as_tensor(gt_m).detach().cpu() != 0
    intersection = (pred & gt).sum().item()
    union = gt.sum().item() + pred.sum().item() - intersection
    return (intersection + 1.0) / (union + 1.0)

**Load Model**

In [None]:
# Epochs completed before this session. Overwritten when a checkpoint is loaded.
current_epochs: int = 0

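Run the cell below to resume from the checkpoint saved at `MODEL_PATH` (weights, optimizer, LR scheduler and epoch counter). Skip it to train from scratch.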

In [None]:
# Resume from MODEL_PATH when a checkpoint exists; otherwise train from scratch.
# The guard keeps Restart & Run All working on a fresh runtime.
checkpoint = None
if os.path.exists(MODEL_PATH):
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
    # NOTE(review): the checkpoint also pickles the loss module under 'loss'. On PyTorch
    # versions whose torch.load defaults to weights_only=True this may require
    # weights_only=False, which is only safe for checkpoints you created yourself.
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    current_epochs = checkpoint['epoch']
    criterion = checkpoint['loss']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
else:
    print(f"No checkpoint found at {MODEL_PATH}; training from scratch.")

In [None]:
def validate(model, dataloader, optimizer, criterion, val_data):
    """Evaluate `model` on `dataloader` and return (mean loss, mean IoU) over batches.

    Args:
        model: segmentation network returning per-pixel class scores (N, C, H, W).
        dataloader: yields (images, masks). Masks are expected as (N, 1, H, W), which
            is what ToTensor produces for single-channel masks.
        optimizer: unused; kept so existing calls keep working.
        criterion: loss called as criterion(scores, targets) with (N, H, W) long targets.
        val_data: unused; kept so existing calls keep working.

    Runs under torch.no_grad(), so callers need not wrap it, and leaves the model
    in eval mode. Works for any batch size (the old code hard-coded batch size 1).

    Raises:
        ValueError: if the dataloader is empty. Previously this was a ZeroDivisionError.
    """
    if len(dataloader) == 0:
        raise ValueError("validate() received an empty dataloader")
    print('\nValidating')
    model.eval()
    run_device = next(model.parameters()).device
    running_loss = 0.0
    iou = 0.0
    loop = tqdm(dataloader, total=len(dataloader))
    with torch.no_grad():
        for images, masks in loop:
            images = images.to(run_device, dtype=torch.float32)
            # (N, 1, H, W) -> (N, H, W): CrossEntropyLoss expects class indices without
            # a channel axis.
            masks = masks.to(run_device, dtype=torch.long).squeeze(1)

            output = model(images)
            loss = criterion(output, masks)
            # Highest-scoring class per pixel -> (N, H, W).
            pred = torch.argmax(output, dim=1)

            iou += get_IOU(pred, masks)
            running_loss += loss.item()
            loop.set_postfix(loss=loss.item())

    val_iou = iou / len(dataloader)
    loss = running_loss / len(dataloader)
    del loop
    torch.cuda.empty_cache()
    return loss, val_iou

In [None]:
# Per-epoch metric history. The training loop appends to these lists and stores
# them in every checkpoint.
train_losses, val_losses = [], []
train_ious, val_ious = [], []

Run the cell below only if you loaded a saved checkpoint above. It restores the epoch counter and the training/validation loss and IoU history.

In [None]:

# Restore the epoch counter and metric history saved alongside the weights.
current_epochs = checkpoint['epoch']
train_losses, train_ious = checkpoint['train_losses'], checkpoint['train_ious']
val_losses, val_ious = checkpoint['val_losses'], checkpoint['val_ious']

In [None]:
# Early-stopping state: stop after `n_epochs_stop` consecutive epochs without a new
# best validation loss.
n_epochs_stop = 6
epochs_no_improve = 0
# np.Inf was removed in NumPy 2.0; float("inf") works with every version.
min_val_loss = float("inf")
early_stop = False

In [None]:
def train_one_epoch(model, dataloader, optimizer, criterion):
    """Run one optimisation pass over `dataloader` and return (mean loss, mean IoU).

    Masks arrive as (N, 1, H, W) from ToTensor. They are squeezed to the (N, H, W)
    class-index layout nn.CrossEntropyLoss expects, so any batch size works (the
    old loop hard-coded view(1, 512, 512)).
    """
    if len(dataloader) == 0:
        raise ValueError("train_one_epoch() received an empty dataloader")
    model.train()
    run_device = next(model.parameters()).device
    running_loss = 0.0
    iou = 0.0
    loop = tqdm(dataloader, total=len(dataloader))
    for images, masks in loop:
        images = images.to(run_device, dtype=torch.float32)
        masks = masks.to(run_device, dtype=torch.long).squeeze(1)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, masks)
        loss.backward()
        optimizer.step()

        # Predicted class per pixel. Detached so metric bookkeeping stays out of autograd.
        pred = torch.argmax(output.detach(), dim=1)
        iou += get_IOU(pred, masks)
        running_loss += loss.item()
        loop.set_postfix(loss=loss.item())
    del loop
    torch.cuda.empty_cache()
    return running_loss / len(dataloader), iou / len(dataloader)


for epoch in range(EPOCHS):
    train_loss, train_iou = train_one_epoch(model, train_dl, optimizer, criterion)

    # Validation. The no_grad wrapper is harmless if validate() already disables grads.
    with torch.no_grad():
        val_loss, val_iou = validate(model, val_dl, optimizer, criterion, val_ds)

    # Early stopping criteria
    if val_loss < min_val_loss:
        epochs_no_improve = 0
        min_val_loss = val_loss
    else:
        epochs_no_improve += 1

    if current_epochs + epoch > 3 and epochs_no_improve == n_epochs_stop:
        # The stopping epoch is neither recorded nor checkpointed.
        print('Early stopping at ', current_epochs + epoch, ' epochs.')
        early_stop = True
        break

    # Tracking performance metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_ious.append(train_iou)
    val_ious.append(val_iou)

    lr_scheduler.step()
    # Overwrites the checkpoint every epoch (latest state, not necessarily the best).
    torch.save({
        'epoch': current_epochs + epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
        'train_ious': train_ious,
        'val_ious': val_ious,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'lr_scheduler_state_dict': lr_scheduler.state_dict()
    }, MODEL_PATH)
    print('\nEpoch: %d, Training Loss: %.4f , IoU : %.4f' % (current_epochs + epoch + 1, train_loss, train_iou))
    print('Epoch: %d, Validation Loss: %.4f , IoU : %.4f\n' % (current_epochs + epoch + 1, val_loss, val_iou))

print("Finished Training..")

# **Performance**

In [None]:
# Count completed epochs from the recorded history instead of `current_epochs + epoch + 1`.
# The old expression grew each time the cell was re-run (outputs show 20 vs 17) and
# over-counted by one after an early-stopping break, whose epoch is never recorded.
current_epochs = len(train_losses)


20


In [None]:
print(f"{current_epochs}")

17


In [None]:
# 1-based epoch numbers for the x-axis. They are derived from the history itself, so
# the axis always has the same length as the plotted metric lists.
iters = list(range(1, len(train_losses) + 1))

In [None]:
def plot_history(train_values, val_values, ylabel):
    """Plot per-epoch training and validation curves on a new figure.

    Values go through float(), so histories that hold 0-d tensors plot the same
    as plain numbers. The x-axis is derived from the history length, so it
    cannot disagree with the data.
    """
    train_values = [float(v) for v in train_values]
    val_values = [float(v) for v in val_values]
    epoch_axis = list(range(1, len(train_values) + 1))
    fig, ax = plt.subplots()
    ax.set_title("Training Curve (batch_size={}, lr={})".format(BATCH_SIZE, LEARNING_RATE))
    ax.plot(epoch_axis, train_values, label="Train")
    ax.plot(epoch_axis, val_values, label="Validation")
    ax.set_xticks(epoch_axis)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.legend(loc='best')
    plt.show()


plot_history(train_losses, val_losses, "Loss")
# The second metric is the smoothed IoU from get_IOU, not classification accuracy.
plot_history(train_ious, val_ious, "IoU")

if train_ious:
    print('\n Final training IoU: %.4f' % float(train_ious[-1]))
    print('\n Final validation IoU: %.4f' % float(val_ious[-1]))

In [None]:
# Evaluate on the held-out test split without building autograd graphs. Unlike the
# call in the training loop, this one was previously not wrapped in no_grad.
with torch.no_grad():
    loss, iou = validate(model, test_dl, optimizer, criterion, test_ds)
print('\n Test: loss = %.4f, IoU = %.4f' % (loss, iou))


Validating


  cpuset_checked))
100%|██████████| 130/130 [00:26<00:00,  4.91it/s, loss=0.078]


 Test: loss = 0.1800, IoU = 0.7729





# **Testing**

# Display outputs

In [None]:
# A fresh iterator over the test loader. The next cell draws example batches from it.
data = iter(test_dl)

  cpuset_checked))


In [None]:
N_EXAMPLES = 3  # number of test batches to visualise

# Inference mode: disables dropout and uses BatchNorm running statistics.
model.eval()
with torch.no_grad():
    # zip() stops quietly if the iterator has fewer than N_EXAMPLES batches left
    # (bare next() would raise StopIteration).
    for example_idx, (images, masks) in zip(range(N_EXAMPLES), data):
        images = images.to(device)
        out = model(images)
        # Predicted class per pixel for the first sample of the batch -> (H, W).
        om = torch.argmax(out[0], dim=0).cpu()
        image = images[0].cpu()  # (C, H, W)
        # Zero out pixels predicted as class 0. Broadcasts (H, W) * (C, H, W) -> (C, H, W).
        segmented_img = np.multiply(om.numpy(), image.numpy())

        fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(12, 12))
        ax[0].imshow(image.permute(1, 2, 0).numpy().squeeze(), cmap='gray')
        ax[0].set_title("chest x-ray")

        ax[1].imshow(om, cmap='gray')
        ax[1].set_title("predicted mask")

        # Shows a single channel. NOTE(review): assumes the X-ray is grayscale stored
        # as 3 identical channels, so any channel is representative.
        ax[2].set_title("segmented image")
        ax[2].imshow(segmented_img[1], cmap='gray')

        ax[3].set_title("original mask")
        ax[3].imshow(masks[0].cpu().permute(1, 2, 0).numpy().squeeze(), cmap='gray')

        fig.tight_layout(pad=1.0)
        plt.show()
        plt.close(fig)
