# CNN Chest X-Ray Scans - Multi-class Classification

### Set Up Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
import torchvision
from torchvision import datasets, models
import torchvision.transforms as T
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import time
import copy
import os

In [None]:
SEED = 0  # one seed shared by NumPy's global RNG and PyTorch's RNG
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
%matplotlib inline

In [None]:
# Training hyper-parameters.
batch_size = 64
learning_rate = 1e-3
epochs = 100

# (mean, std) given to T.Normalize and inverted by denorm(); one value each
# because images are converted to a single grayscale channel.
stats = (0.5, 0.5)

# Image geometry and preview-grid settings.
image_size = 128
images_count = 64      # default number of images drawn by show_images/show_batch
images_row_count = 8   # images per row in the preview grid

In [None]:
def to_device(data, device):
    """Move a tensor/module, or a (possibly nested) list/tuple of them, to `device`.

    Lists and tuples come back as lists; every leaf is moved with
    ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]


class DeviceDataLoader:
    """Wrap a data loader so every batch it yields is first moved to `device`."""

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        moved_batches = (to_device(batch, self.device) for batch in self.dl)
        yield from moved_batches

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)


cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
device

In [None]:
def denorm(tensors):
    """Invert ``T.Normalize(*stats)``: return ``tensors * std + mean``."""
    return tensors * stats[1] + stats[0]


def show_images(images, nmax=images_count, interpolation='antialiased'):
    """Plot up to `nmax` normalized images as a grid with `images_row_count` per row.

    Args:
        images: batch of normalized image tensors (assumed (N, C, H, W), the
            layout ``make_grid`` expects); may live on any device.
        nmax: maximum number of images to draw.
        interpolation: forwarded to ``Axes.imshow`` (previously ignored).
    """
    grid = make_grid(denorm(images.detach().cpu()[:nmax]), nrow=images_row_count)

    _, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    # make_grid returns (C, H, W); imshow wants channels last.
    ax.imshow(grid.permute(1, 2, 0), interpolation=interpolation)


def show_batch(dl, nmax=images_count, interpolation='antialiased'):
    """Show the images of the first batch yielded by `dl` (labels are ignored)."""
    for images, _ in dl:
        show_images(images, nmax, interpolation)
        break

### Set Up the Dataset

In [None]:
 ! pip install -q kaggle

In [None]:
# Colab-only: prompts for a file upload. Upload kaggle.json here; it is
# expected in the working directory by the `cp kaggle.json` command below.
from google.colab import files
files.upload()

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
! kaggle datasets download -d paultimothymooney/chest-xray-pneumonia

In [None]:
! mkdir dataset
! unzip chest-xray-pneumonia.zip -d dataset

In [None]:
# Create VIRUS/ and BACTERIA/ class folders in each split (no-op if present),
# in the same order as before: test, train, val; VIRUS before BACTERIA.
for split in ('test', 'train', 'val'):
    for label in ('VIRUS', 'BACTERIA'):
        os.makedirs(f'./dataset/chest_xray/{split}/{label}', exist_ok=True)

In [None]:
# Split the PNEUMONIA class into two classes by filename: images whose names
# contain "virus" move to VIRUS/, those containing "bacteria" to BACTERIA/.
# `2>/dev/null` hides mv errors (e.g. no matches when the cell is re-run).
# The PNEUMONIA folder is then removed with rm -rf, so any image matching
# neither pattern is deleted with it. NORMAL/ is not touched.
# test split
! mv ./dataset/chest_xray/test/PNEUMONIA/*virus*.jpeg ./dataset/chest_xray/test/VIRUS/ 2>/dev/null
! mv ./dataset/chest_xray/test/PNEUMONIA/*bacteria*.jpeg ./dataset/chest_xray/test/BACTERIA/ 2>/dev/null
! rm -rf ./dataset/chest_xray/test/PNEUMONIA/

# train split
! mv ./dataset/chest_xray/train/PNEUMONIA/*virus*.jpeg ./dataset/chest_xray/train/VIRUS/ 2>/dev/null
! mv ./dataset/chest_xray/train/PNEUMONIA/*bacteria*.jpeg ./dataset/chest_xray/train/BACTERIA/ 2>/dev/null
! rm -rf ./dataset/chest_xray/train/PNEUMONIA/

# val split
! mv ./dataset/chest_xray/val/PNEUMONIA/*virus*.jpeg ./dataset/chest_xray/val/VIRUS/ 2>/dev/null
! mv ./dataset/chest_xray/val/PNEUMONIA/*bacteria*.jpeg ./dataset/chest_xray/val/BACTERIA/ 2>/dev/null
! rm -rf ./dataset/chest_xray/val/PNEUMONIA/

In [None]:
# Preprocessing shared by the train and test splits:
# resize, center-crop to image_size x image_size, collapse to one grayscale
# channel, convert to a tensor, then normalize with (mean, std) = stats.
transforms = T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.Grayscale(num_output_channels=1),
    T.ToTensor(),
    T.Normalize(*stats),
])

In [None]:
def _image_folder(root):
    """Return an ImageFolder over `root` that applies the shared `transforms`."""
    return datasets.ImageFolder(root=root, transform=transforms)


train_dataset = _image_folder("./dataset/chest_xray/train/")
test_dataset = _image_folder("./dataset/chest_xray/test/")

Check the class names (taken from the sub-folder names) that the dataloaders will use:

In [None]:
# Label index i corresponds to train_dataset.classes[i]. After the folder
# reshuffle this is presumably BACTERIA, NORMAL, VIRUS — confirm in the output.
train_dataset.classes

In [None]:
# Pinned (page-locked) host memory only helps host->GPU copies. On a CPU-only
# runtime PyTorch just warns about it, so enable pinning only when CUDA exists.
# Shuffling the test loader changes batch order only, not the test metrics.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=cuda_available)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=cuda_available)
# Wrap both loaders so each batch is moved to `device` as it is yielded.
train_dataloader = DeviceDataLoader(train_dataloader, device)
test_dataloader = DeviceDataLoader(test_dataloader, device)

In [None]:
# Preview the first (shuffled) training batch.
show_batch(train_dataloader)

In [None]:
# ResNeXt-50 (32x4d) with pretrained weights.
net = models.resnext50_32x4d(pretrained=True)

# The stock stem takes 3-channel input, but `transforms` produces 1 grayscale
# channel. Replace it with a 1-channel *2-D* convolution of the same geometry.
# nn.Conv1d (used before) is meant for (N, C, L) sequences, not 4-D image
# batches.
#
# Initialize it with the pretrained kernels summed over the input channels.
# For an image replicated into R=G=B this reproduces the original stem's
# response (up to the different input normalization), so the pretrained stem
# is kept rather than restarting from random weights.
pretrained_stem = net.conv1.weight.detach()
net.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
with torch.no_grad():
    net.conv1.weight.copy_(pretrained_stem.sum(dim=1, keepdim=True))
del pretrained_stem
net = to_device(net, device)
net

In [None]:
# Same value as `cuda_available` (computed when `device` was chosen); this
# name is not referenced anywhere else in the visible notebook.
use_cuda = torch.cuda.is_available()

In [None]:
criterion = nn.CrossEntropyLoss().to(device)


def accuracy(out, labels):
    """Return how many predictions in `out` match `labels` (a count, not a ratio).

    `out` is assumed to be (batch, num_classes) scores; the prediction is the
    argmax along dim 1.
    """
    _, pred = torch.max(out, dim=1)
    return torch.sum(pred == labels).item()


# Replace the pretrained classification head with one output per dataset
# class. CrossEntropyLoss treats every output as a class, so a wider head
# (128 outputs before) lets the model predict labels that do not exist.
num_classes = len(train_dataset.classes)
features_count = net.fc.in_features
net.fc = nn.Linear(features_count, num_classes).to(device)

# Build the optimizer only AFTER the head is replaced. An optimizer captures
# the parameters that exist when it is constructed, so a head swapped in
# afterwards is never updated by optimizer.step().
optimizer = optim.Adam(net.parameters(), lr=learning_rate, betas=(0.5, 0.999))

In [None]:
! pip install wandb

In [None]:
!wandb login
import wandb

In [None]:
def _train_one_epoch():
    """Run one optimization pass of ``net`` over ``train_dataloader``.

    Returns:
        ``(mean loss per training batch, training accuracy in percent)``.
    """
    net.train()
    total_loss = 0.0
    correct = 0
    seen = 0
    for data_, target_ in tqdm(train_dataloader):
        # A no-op when the loader is a DeviceDataLoader; harmless otherwise.
        data_, target_ = data_.to(device), target_.to(device)
        optimizer.zero_grad()

        outputs = net(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred == target_).item()
        seen += target_.size(0)
    return total_loss / len(train_dataloader), 100 * correct / seen


def _evaluate():
    """Evaluate ``net`` on ``test_dataloader`` without tracking gradients.

    ``net`` is put in eval mode for the pass and back in train mode afterwards.

    Returns:
        ``(mean loss per test batch, test accuracy in percent, number correct)``.
    """
    total_loss = 0.0
    correct = 0
    seen = 0
    net.eval()
    try:
        with torch.no_grad():
            for data_test, target_test in test_dataloader:
                data_test, target_test = data_test.to(device), target_test.to(device)
                outputs_test = net(data_test)
                total_loss += criterion(outputs_test, target_test).item()
                _, pred_t = torch.max(outputs_test, dim=1)
                correct += torch.sum(pred_t == target_test).item()
                seen += target_test.size(0)
    finally:
        net.train()
    # Average over the number of *test* batches, not training batches.
    return total_loss / len(test_dataloader), 100 * correct / seen, correct


def fit(epochs, start_idx=1):
    """Train ``net`` for `epochs` epochs and evaluate on the test split after each.

    Uses the notebook-level ``net``, ``optimizer``, ``criterion``,
    ``train_dataloader``, ``test_dataloader``, ``device``, ``batch_size`` and
    ``learning_rate``. Per-epoch metrics are printed and logged to Weights &
    Biases (project ``xray-data-cnn``, ``step`` = epoch number).
    ``net.state_dict()`` is saved to ``resnet.pt`` whenever the test loss
    reaches a new minimum.

    Args:
        epochs: number of epochs to run.
        start_idx: number given to the first epoch (e.g. when resuming).

    Returns:
        ``(test_losses, test_accuracies, train_losses, train_accuracies)``,
        one entry per epoch. Losses are mean loss per batch; accuracies are
        percentages.
    """
    torch.cuda.empty_cache()
    valid_loss_min = np.inf  # the np.Inf alias no longer exists in NumPy >= 2.0
    test_losses = []
    test_accuracies = []
    train_losses = []
    train_accuracies = []
    last_epoch = start_idx + epochs - 1

    config_defaults = {
        'epochs': epochs,
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        'optimizer': 'adam',
        # Report the head actually attached instead of a hard-coded size.
        'fc_layer_size': net.fc.out_features,
    }

    wandb.init(project='xray-data-cnn', config=config_defaults)
    try:
        wandb.watch(net)
    except Exception as err:
        # Gradient/parameter logging is optional; keep training without it.
        print("Error watching model, ignoring.", err)

    for epoch in range(start_idx, last_epoch + 1):
        epoch_train_loss, epoch_train_accuracy = _train_one_epoch()
        train_accuracies.append(epoch_train_accuracy)
        train_losses.append(epoch_train_loss)
        print("Epoch [{}/{}]".format(epoch, last_epoch))
        print("Train Loss: {:.4f}, Train Accuracy: {:.4f}\n".format(
            epoch_train_loss, epoch_train_accuracy
        ))
        print("Average Train Loss: {:.4f}, Average Train Accuracy: {:.4f}\n".format(
            np.mean(train_losses),
            np.mean(train_accuracies),
        ))

        epoch_test_loss, epoch_test_accuracy, correct_test_predictions = _evaluate()
        test_accuracies.append(epoch_test_accuracy)
        test_losses.append(epoch_test_loss)
        print("Test Loss: {:.4f}, Test Accuracy: {:.4f}\n".format(
            epoch_test_loss, epoch_test_accuracy
        ))
        print("Avg Loss: {:.4f}, Avg Accuracy: {:.4f}\n".format(
            np.mean(test_losses), np.mean(test_accuracies)
        ))

        # NOTE(review): the checkpoint is chosen on the test split, so the
        # saved model's test accuracy is optimistically biased.
        if epoch_test_loss < valid_loss_min:
            valid_loss_min = epoch_test_loss
            torch.save(net.state_dict(), 'resnet.pt')
            print('Saving checkpoint...')

        # Existing W&B charts key on these metric names (spelling included);
        # rename them deliberately if at all.
        wandb.log({
            "test loss": epoch_test_loss,
            "test accuracy": epoch_test_accuracy,
            "correct test predications": correct_test_predictions,
            "average test loss": np.mean(test_losses),
            "average test accuracy": np.mean(test_accuracies),
            "train accuracy": epoch_train_accuracy,
            "train loss": epoch_train_loss,
            "avaerage train loss": np.mean(train_losses),
            "avaerage train accuracy": np.mean(train_accuracies),
        }, step=epoch)
    return test_losses, test_accuracies, train_losses, train_accuracies

In [None]:
# Train, then unpack the four per-epoch metric lists returned by fit().
history = fit(epochs)
(test_losses, test_accuracies,
 train_losses, train_accuracies) = history

In [None]:
# Per-epoch accuracy curves: training split vs. the held-out test split.
fig, ax = plt.subplots(figsize=(20, 10))
ax.set_title("Train-Validation Accuracy")
for series, series_label in ((train_accuracies, 'train'), (test_accuracies, 'validation')):
    ax.plot(series, label=series_label)
ax.set_xlabel('num_epochs', fontsize=12)
ax.set_ylabel('accuracy', fontsize=12)
ax.legend(loc='best')

In [None]:
def show_image(tensors, title=None):
    """Draw one normalized image tensor with matplotlib and return it.

    Args:
        tensors: a single image tensor, assumed (C, H, W) and normalized with
            ``stats`` = (mean, std); may live on any device.
        title: optional axes title.

    Returns:
        The de-normalized image as an (H, W, C) NumPy array clipped to [0, 1].
    """
    image = tensors.detach().cpu().numpy().transpose((1, 2, 0))
    image = np.clip(image * stats[1] + stats[0], 0, 1)
    if image.shape[2] == 1:
        # Single-channel images need an explicit colormap; imshow rejects (H, W, 1).
        plt.imshow(image[:, :, 0], cmap='gray')
    else:
        plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.axis('off')
    return image

def visualize_model(net, num_images=16):
    """Show test images titled with their true and predicted class names.

    Takes images from ``test_dataloader`` until `num_images` have been drawn
    on an ``ImageGrid`` (4x4 for the default of 16), then stops. The model
    runs in eval mode under ``torch.no_grad()``. In train mode BatchNorm would
    use and update batch statistics, and autograd would record a graph. The
    model's previous train/eval mode is restored afterwards.

    Args:
        net: model mapping an image batch to per-class scores.
        num_images: number of images to draw; nothing is drawn if <= 0.
    """
    if num_images <= 0:
        return
    ncols = math.ceil(math.sqrt(num_images))
    nrows = math.ceil(num_images / ncols)
    fig = plt.figure(1, (16., 16.))
    grid = ImageGrid(fig, 111,
      nrows_ncols=(nrows, ncols),
      axes_pad=1,
    )
    classes = test_dataset.classes

    def class_name(index):
        # Fall back to the raw index if the head has more outputs than classes.
        return classes[index] if 0 <= index < len(classes) else str(index)

    was_training = net.training
    net.eval()
    shown = 0
    try:
        with torch.no_grad():
            for inputs, labels in test_dataloader:
                inputs = inputs.to(device)
                _, predictions = torch.max(net(inputs), 1)
                labels = labels.cpu().numpy()
                predictions = predictions.cpu().numpy()
                for j in range(min(inputs.size(0), num_images - shown)):
                    image = inputs[j].cpu().numpy().transpose((1, 2, 0))
                    image = np.clip(image * stats[1] + stats[0], 0, 1)
                    ax = grid[shown]
                    ax.set_title('actual: {} \n predicts: {}'.format(
                        class_name(labels[j]), class_name(predictions[j])))
                    # 1 - x with the reversed 'Greys' map renders like 'gray'.
                    ax.imshow(1 - image[:, :, 0], cmap='Greys', interpolation='nearest')
                    shown += 1
                # Stop once the grid is full instead of running the whole test set.
                if shown >= num_images:
                    break
    finally:
        net.train(was_training)

# Turn matplotlib interactive mode on while drawing the prediction grid,
# then switch it off again.
plt.ion()
visualize_model(net)
plt.ioff()

In [None]:
# Free memory after training: collect unreachable Python objects, then
# release unused cached GPU memory held by PyTorch.
import gc

gc.collect()

torch.cuda.empty_cache()