In [0]:
# Mount Google Drive in Colab; DIR below points inside this mount, so the
# checkpoints and logs written by the training loop land on Drive.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

In [0]:
import numpy as np
import pandas as pd 
import os
from glob import glob
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from tqdm.notebook import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score
import ast

from torch.autograd import Function
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as tfm
from torchvision import models
from torch import optim
from torch.optim import lr_scheduler
from torch.nn import MSELoss
import torch
import random
from torch import nn

plt.style.use(style='default')  # reset matplotlib to its stock style for all figures below

In [0]:
# ---- Run mode ----
DEBUG = True     # smoke-test mode: 1 epoch and at most 101 batches per loop
RESUME = False   # resume flag read by the training loop

# ---- Optimisation ----
EPOCHS = 1 if DEBUG else 20
LR = 1e-4        # Adam learning rate
WD = 1e-5        # Adam weight decay
PAT = 5          # ReduceLROnPlateau patience (epochs)
FACTOR = 0.1     # ReduceLROnPlateau LR multiplier
BS = 16          # batch size for both loaders

# ---- Data / model ----
NUM_CLASSES = 14
CROP_SIZE = 896  # side of the square image fed to the model (train crop / val resize)
SEED = 42

# Loss weighting: lam * classification loss + (1 - lam) * reconstruction loss.
lam = 0.9

DIR = '/content/gdrive/My Drive/ChestXray14/'

In [0]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed: integer seed applied to every generator.

    Note: PYTHONHASHSEED is read at interpreter start-up, so exporting it here
    does not change string hashing inside the already-running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
 
seed_everything(seed=SEED)  # seed once, before any sampling or model initialisation

In [0]:
invalid = pd.read_csv('../input/chestxray14-invalid/invalid.txt', sep=' ', header=None)

# NOTE(review): the parsing assumes the file's first line is a printed Python list
# of paths, so each space-separated token looks like "['p1'," ... "'pN']".
# Strip the list punctuation, parse the quoted path, and keep only the file name.
# Every token is processed (no hard-coded count); non-string cells (NaN produced
# by stray separators) are skipped.
invalid_list = [
    ast.literal_eval(token.strip('[],')).split('/')[-1]
    for token in invalid.values[0]
    if isinstance(token, str)
]

In [0]:
# Official split lists; their entries are matched against 'Image Index' further down.
train_idx = pd.read_csv('../input/chestxray14-csvs/train_val_list.txt', sep=' ', header=None).values.ravel()
test_idx = pd.read_csv('../input/chestxray14-csvs/test_list.txt', sep=' ', header=None).values.ravel()

# Metadata table taken from: https://www.kaggle.com/kmader/train-simple-xray-cnn/input?select=Data_Entry_2017.csv
all_xray_df = pd.read_csv('../input/csv-for-chest/Data_Entry_2017_v2020.csv')

# Map file name -> full path for every PNG under ../input/data/*/images.
png_files = glob(os.path.join('..', 'input', 'data','*','images', '*.png'))
all_image_paths = {os.path.basename(png): png for png in png_files}
print('Scans found:', len(all_image_paths), ', Total Headers', all_xray_df.shape[0])

# Images without a file on disk get a missing path here.
all_xray_df['path'] = all_xray_df['Image Index'].map(lambda name: all_image_paths.get(name))
all_xray_df.sample(3)

In [0]:
# taken from: https://www.kaggle.com/kmader/train-simple-xray-cnn/input?select=Data_Entry_2017.csv

# 'No Finding' images get an empty label string, i.e. an all-zero target vector.
all_xray_df['Finding Labels'] = all_xray_df['Finding Labels'].map(lambda x: x.replace('No Finding', ''))

# 'Finding Labels' holds '|'-separated finding names; split once and reuse.
label_lists = all_xray_df['Finding Labels'].str.split('|')
all_labels = sorted({name for names in label_lists for name in names if name})
print('All Labels ({}): {}'.format(len(all_labels), all_labels))

# One 0.0/1.0 column per finding. Match whole tokens: a substring test
# (`c_label in 'A|B'`) would also fire whenever one label name is contained in another.
for c_label in all_labels:
    all_xray_df[c_label] = label_lists.map(lambda names, target=c_label: 1.0 if target in names else 0.0)

all_xray_df.sample(3)

In [0]:
class ChestDataset(Dataset):
    """Chest X-ray images paired with multi-hot label vectors.

    Args:
        df: DataFrame with a 'path' column and one column per name in the global
            `all_labels`.
        transform: callable applied to the image array returned by `cv2.imread`;
            skipped when falsy.
        mode: 'val' resizes every image to CROP_SIZE x CROP_SIZE before
            `transform`; any other value leaves the image size untouched.

    __getitem__ returns (image, float32 label tensor).
    Raises FileNotFoundError for a missing path and OSError when OpenCV cannot
    decode the file, instead of failing later inside the transform.
    """

    def __init__(self, df, transform, mode='train'):
        self.label_df = df[all_labels].values
        self.path_df = df['path'].values
        self.transform = transform
        self.mode = mode

    def __len__(self):
        return len(self.path_df)

    def __getitem__(self, idx):
        path = self.path_df[idx]
        # The path column can hold None/NaN for images absent from disk.
        if not isinstance(path, (str, os.PathLike)):
            raise FileNotFoundError(f'No image file for dataset index {idx} (path={path!r})')

        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise OSError(f'OpenCV could not read image at dataset index {idx}: {path!r}')

        if self.mode == 'val':
            img = cv2.resize(img, (CROP_SIZE, CROP_SIZE))

        if self.transform:
            img = self.transform(img)

        label = self.label_df[idx]
        return img, torch.tensor(label, dtype=torch.float)

In [0]:
# Training pipeline: random CROP_SIZE crop, +/-5 degree rotation, horizontal flip,
# then conversion to a float tensor.
train_steps = [
    tfm.ToPILImage(),
    tfm.RandomCrop(CROP_SIZE),
    tfm.RandomRotation(5),
    tfm.RandomHorizontalFlip(),
    tfm.ToTensor(),
]
train_tfm = tfm.Compose(train_steps)

# Validation pipeline: tensor conversion only (the dataset resizes in 'val' mode).
val_tfm = tfm.Compose([
    tfm.ToPILImage(),
    tfm.ToTensor(),
])

In [0]:
# Target number of validation images; the patient-wise split lands at or just under it.
VAL_SIZE = 10000

# Drop images listed as invalid, then carve the official train/val list into train and val.
all_xray_df = all_xray_df[~all_xray_df['Image Index'].isin(invalid_list)].reset_index(drop=True)
train_valdf = all_xray_df[all_xray_df['Image Index'].isin(train_idx)].reset_index(drop=True)

# Hold out whole patients rather than individual images. ChestX-ray14 has several
# images per patient, and an image-level random sample puts the same patient in
# both train and val, which makes the validation loss optimistic.
images_per_patient = train_valdf.groupby('Patient ID').size()  # sorted by patient id -> deterministic
split_rng = np.random.RandomState(SEED)
shuffled_patients = split_rng.permutation(images_per_patient.index.values)
cum_images = images_per_patient.loc[shuffled_patients].cumsum().values
val_patients = shuffled_patients[cum_images <= VAL_SIZE]

is_val = train_valdf['Patient ID'].isin(val_patients)
valdf = train_valdf[is_val].reset_index(drop=True)
traindf = train_valdf[~is_val].reset_index(drop=True)
testdf = all_xray_df[all_xray_df['Image Index'].isin(test_idx)].reset_index(drop=True)

assert set(traindf['Patient ID']).isdisjoint(valdf['Patient ID'])
assert len(traindf) + len(valdf) == len(train_valdf)

In [0]:
# Loader settings shared by the train and validation loaders.
loader_kwargs = dict(batch_size=BS, num_workers=4)

train_dataset = ChestDataset(traindf, train_tfm)
trainloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

val_dataset = ChestDataset(valdf, val_tfm, mode='val')
valloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [0]:
# Visual sanity check: 15 random training samples after the train_tfm augmentation.
fig, axes = plt.subplots(3, 5, figsize=(20, 10))

for ax in axes.flat:
    # Sample from the whole dataset; len(trainloader) is the number of batches,
    # which would restrict the draw to the first ~1/BS of the images.
    idx = np.random.randint(len(train_dataset))
    img, _ = train_dataset[idx]
    ax.imshow(img[0].numpy(), cmap='gray')  # channel 0 of the (C, H, W) tensor
    ax.axis('off')

fig.suptitle('Random augmented training samples')
fig.tight_layout()
plt.show()

In [0]:
class DenseNet121(nn.Module):
    """torchvision DenseNet-121 whose final layer is Linear(classCount) + Sigmoid.

    Args:
        classCount: number of output labels.
        isTrained: if True, start from torchvision's pretrained weights.

    forward(x) returns per-label probabilities of shape (batch, classCount).
    """

    def __init__(self, classCount, isTrained):
        super().__init__()
        self.densenet121 = models.densenet121(pretrained=isTrained)
        kernelCount = self.densenet121.classifier.in_features
        # Sigmoid head -> independent per-label probabilities (pairs with nn.BCELoss).
        self.densenet121.classifier = nn.Sequential(nn.Linear(kernelCount, classCount), nn.Sigmoid())

    def forward(self, x):
        return self.densenet121(x)


class AECNN(nn.Module):
    """Learned 4x down-sampler (autoencoder) feeding a DenseNet-121 classifier.

    The encoder maps a 1-channel image to a 1-channel map at 1/4 resolution
    (stride-4 conv). The decoder reconstructs the full-resolution image from that
    map (conv + PixelShuffle(4)). The classifier sees the encoded map broadcast to
    3 channels and normalised with the `self.normalize` mean/std.

    Args:
        classCount: number of output labels of the classifier.

    forward(x) returns (reconstruction, label_predictions); both the encoded map
    and the reconstruction are clipped to [0, 1] by Relu1.
    """

    def __init__(self, classCount):
        super().__init__()

        self.classCount = classCount
        # ImageNet mean/std applied to the encoded map before the classifier.
        self.normalize = tfm.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        self.encoder = nn.Sequential(
            # 1x896x896 -> 32x224x224 (for a CROP_SIZE=896 input)
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=4, padding=2),
            nn.ELU(),
            # 32x224x224 -> 1x224x224
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1, stride=1, padding=0),
        )

        self.decoder = nn.Sequential(
            # 1x224x224 -> 16x224x224
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
            # 16x224x224 -> 1x896x896
            nn.PixelShuffle(4),
        )

        # CLASSIFIER
        self.classifier = DenseNet121(classCount=self.classCount, isTrained=True)

    def forward(self, x):
        y = Relu1.apply(self.encoder(x))
        z1 = Relu1.apply(self.decoder(y))

        # Broadcast the (bs, 1, h, w) map against per-channel (1, 3, 1, 1) stats,
        # giving the (bs, 3, h, w) normalised input on y's own device/dtype.
        mean = y.new_tensor(self.normalize.mean).view(1, -1, 1, 1)
        std = y.new_tensor(self.normalize.std).view(1, -1, 1, 1)
        z2 = self.classifier((y - mean) / std)

        return z1, z2


class Relu1(Function):
    """ReLU clipped to [0, 1], implemented as a custom autograd Function.

    Forward clamps the input to [0, 1]. Backward passes the incoming gradient
    through where the input lies in [0, 1] (both ends included) and multiplies
    it by zero where the input is < 0 or > 1.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0, max=1)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        outside = (saved_input < 0) | (saved_input > 1)
        # Multiplying by the bool keep-mask zeroes the clipped positions.
        return grad_output * (~outside)

In [0]:
model = AECNN(NUM_CLASSES)

optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
# Multiply the LR by FACTOR once the monitored loss (the training loop steps it
# with the validation loss each epoch) has not improved for more than PAT
# consecutive epochs.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=PAT, factor=FACTOR, mode='min')

# Reconstruction loss for the autoencoder output.
criterion = MSELoss()
# Aliases for code that refers to the two loss terms as criterion1 (reconstruction)
# and criterion2 (multi-label classification).
criterion1 = criterion
# Assumes the classifier outputs probabilities in [0, 1] (sigmoid head);
# switch to nn.BCEWithLogitsLoss if it returns raw logits.
criterion2 = nn.BCELoss()


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    return optimizer.param_groups[0]['lr']

In [0]:
# Display the module tree (the bare expression is rendered as the cell output).
model

In [0]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _joint_loss(model, img, label, img_criterion, label_criterion):
    """Forward one batch; return lam * label_loss + (1 - lam) * reconstruction_loss.

    `model(img)` must return (reconstructed_img, predicted_labels); `lam` is the
    global loss weight.
    """
    ae_img, pred_label = model(img)
    img_loss = img_criterion(ae_img, img)
    label_loss = label_criterion(pred_label, label)
    return lam * label_loss + (1 - lam) * img_loss


def train_one_epoch(epoch, dataloader, model, criterion, optimizer, scheduler, label_criterion=None):
    """Train `model` for one epoch and return the mean joint loss per batch.

    Args:
        epoch: 0-based epoch index (only used for the progress-bar label).
        dataloader: yields (img, label) batches.
        model: returns (reconstruction, label_predictions).
        criterion: reconstruction loss, called as criterion(reconstruction, img).
        optimizer: stepped once per batch.
        scheduler: if truthy, `scheduler.step()` is called once per batch (so not
            a ReduceLROnPlateau, which needs a metric); pass None to skip.
        label_criterion: classification loss; defaults to nn.BCELoss(), which
            assumes predictions are probabilities in [0, 1].

    In DEBUG mode at most 101 batches are processed, and the mean is taken
    over the batches actually processed.
    """
    if label_criterion is None:
        label_criterion = nn.BCELoss()
    model.train()
    optimizer.zero_grad()
    total, n_batches = 0.0, 0
    iterator = tqdm(enumerate(dataloader), total=len(dataloader), leave=False, desc=f'Epoch {epoch+1}/{EPOCHS}')

    for i, (img, label) in iterator:
        img = img.to(device)
        label = label.to(device)

        loss = _joint_loss(model, img, label, criterion, label_criterion)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total += loss.item()
        n_batches += 1

        if DEBUG and i == 100:
            break

        if scheduler:
            scheduler.step()

    return total / max(n_batches, 1)


def validate_one_epoch(epoch, dataloader, model, criterion, label_criterion=None):
    """Evaluate `model` on `dataloader` and return the mean joint loss per batch.

    Arguments mirror `train_one_epoch`. Runs in eval mode under torch.no_grad(),
    so no autograd graph is built. In DEBUG mode at most 101 batches are used.
    """
    if label_criterion is None:
        label_criterion = nn.BCELoss()
    model.eval()
    total, n_batches = 0.0, 0
    iterator = tqdm(enumerate(dataloader), total=len(dataloader), leave=False, desc=f'Epoch {epoch+1}/{EPOCHS}')

    # Evaluation needs no gradients; skipping graph construction saves memory and time.
    with torch.no_grad():
        for i, (img, label) in iterator:
            img = img.to(device)
            label = label.to(device)

            total += _joint_loss(model, img, label, criterion, label_criterion).item()
            n_batches += 1

            if DEBUG and i == 100:
                break

    return total / max(n_batches, 1)

In [0]:
best_loss = np.inf
start_epoch = 0
recorder_path = DIR + 'chestxray14_ae_recorder.pth'

model.to(device)

if RESUME:
    # Continue from the last per-epoch checkpoint written by this loop.
    checkpoint = torch.load(recorder_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    start_epoch = checkpoint['epoch'] + 1
    # Older checkpoints do not store the best loss; fall back to "no best yet".
    best_loss = checkpoint.get('best_loss', np.inf)

for epoch in range(start_epoch, EPOCHS):
    t_loss = train_one_epoch(epoch, trainloader, model, criterion, optimizer, scheduler=None)
    lr = get_lr(optimizer)
    print('Epoch {}/{} (train) || Loss: {:.4f} LR: {:.5f}'.format(epoch+1, EPOCHS, t_loss, lr))

    v_loss = validate_one_epoch(epoch, valloader, model, criterion)
    print('Epoch {}/{} (validation) || Loss: {:.4f} '.format(epoch+1, EPOCHS, v_loss))

    # ReduceLROnPlateau is driven by the validation loss, once per epoch.
    scheduler.step(v_loss)

    # Keep the weights with the lowest validation loss seen so far.
    if v_loss < best_loss:
        best_loss = v_loss
        torch.save(model.state_dict(), DIR+'chestxray14_ae_loss.pth')

    # Full training state for RESUME; overwritten every epoch.
    recorder = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_loss': best_loss,
    }
    torch.save(recorder, recorder_path)

    content = 'Train Loss: {:.4f} Val Loss: {:.4f}'.format(t_loss, v_loss)
    with open(DIR+'chestxray14_ae.txt', 'a') as logger:
        logger.write(content + '\n')