In [1]:
%load_ext autoreload
%autoreload 1
%aimport rails,aise

In [2]:
import torch
import torch.nn as nn
from collections import deque
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import time
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

import rails
import model as _model
import utils as _tools

from resnet import ResNet18

from sklearn.metrics import confusion_matrix
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve

In [3]:
# set the deive
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
"""
initialize the trainned model
"""

start = time.perf_counter()

resnet = _model.TransferModel(use_cpu=False)

if torch.cuda.is_available():
    resnet.model.load_state_dict(torch.load("models/Cardiomegaly_resnet18.pth"))
else:
    resnet.model.load_state_dict(torch.load("models/Cardiomegaly_resnet18.pth", map_location=torch.device('cpu')))
resnet.model.to(DEVICE)
resnet.model.eval()

end = time.perf_counter()

print(f"Done within {end-start:.3f} secs.")

Done within 28.223 secs.


In [5]:
"""
import resenet architechure
"""

model = ResNet18(num_classes=2)

if torch.cuda.is_available():
    model.load_state_dict(torch.load("./models/chexpert_resnet18.pt"))
else:    
    model.load_state_dict(torch.load("./models/chexpert_resnet18.pt", map_location=torch.device('cpu')))
model.to(DEVICE)
model.eval()

print("done")

done


In [6]:
def reformatter(loader, n_rows, perturb_std):
    """A function to reformat data from ther dataloaders
    
    args:
       : loader (dataloader)
       : n_rows (int): number of rows to sample
       : perturb_std (float or None): the standard deviation of the 
           perturbations
        
    returns:
        : x (torch.FloatTensor): the x data
        : x_pert (torch.FloatTensor): the perturbed x data
        : y (torch.LongTensor): the y data
    """
    row_count = 0
    
    x = []
    y = []
    x_pert = []
    
    for image, labels in loader:
        row_count += 1
        x.append(image)
        
        if not perturb_std is None:
            noise = np.random.normal(loc=0.0, 
                                  scale=perturb_std, 
                                  size=image.shape)
            x_pert.append(noise + image)

        y.append(labels[0]) # the label for cadiomegaly

        if row_count == n_rows:
            break
            
    x = torch.FloatTensor(x)
    x_pert = torch.FloatTensor(x_pert)
    y = torch.LongTensor(y)
    
    return x, x_pert, y

In [7]:
"""
reformat the trainning data
""" 

PERTURBATION_AMT_STD = None
LOADER = resnet.dataloader_train.dataset
TRAIN_ROWS = 100000

x_train, x_train_p, y_train = reformatter(LOADER, TRAIN_ROWS, PERTURBATION_AMT_STD)

print(f"x_train: {x_train.shape}")
print(f"x_train_p: {x_train_p.shape}")
print(f"y_train: {y_train.shape}")


print("done")

x_train: torch.Size([37500, 1, 224, 224])
x_train_p: torch.Size([0])
y_train: torch.Size([37500])
done


  y = torch.LongTensor(y)


In [8]:
"""
validation data reformatting
"""

PERTURBATION_AMT_STD = 0.01
LOADER = resnet.dataloader_valid.dataset
TEST_ROWS = 300 # just needs to be more than 234

x_test, x_test_p, y_test = reformatter(LOADER, TEST_ROWS, PERTURBATION_AMT_STD)

print(f"x_test: {x_test.shape}")
print(f"x_test_p: {x_test_p.shape}")
print(f"y_test: {y_test.shape}")

print("done")

x_test: torch.Size([234, 1, 224, 224])
x_test_p: torch.Size([234, 1, 224, 224])
y_test: torch.Size([234])
done


  y = torch.LongTensor(y)


In [9]:
"""
evaluate ResNet18
"""

# get results on dev set
results = resnet.evaluate_model(resnet.model.state_dict(), 
                                resnet.dataloader_valid, 
                                resnet.valid_map)

res = _tools.get_classification_metrics(results)
res

Unnamed: 0,0
optimal_threshold,0.337538
true negatives,109.0
true positives,51.0
false positives,57.0
false negatives,17.0
sensitivity,0.75
specificity,0.656627
F1-score,0.579545
precision,0.472222
recall,0.75


In [17]:
start = time.perf_counter()

CONFIG = {
    "start_layer": 1,
    "n_class": 2,
    "aise_params": [
        {"hidden_layer": 3, 
         "sampling_temperature": 1, 
         "max_generation": 2, 
         "n_neighbors" : 5,
         "n_population" : 2 * 5,
         "mut_range": (.005, .015)}, 
    ]
}

rails_clf = rails.RAILS(model, 
              CONFIG, 
              x_train,
              y_train)


y_proba = rails_clf.predict(x_test)
y_pred = y_proba.argmax(axis=1)

end = time.perf_counter()
print(f"Done within {end-start:.3f} secs.")

Done within 298.690 secs.


In [11]:
def get_metrics(results):
    """A function to return classification metrics 
    
    args:
        : results (pd.DataFrame): with columns [y_true, y_pred, y_prob]
        
    returns:
        : metrics (pd.DataFrame): results
    """
    metrics = {}
    
    y_prob = results['y_prob'].astype(float)
    y_true = results['y_true'].astype(int)
    y_pred = results['y_pred'].astype(int)
    
    # confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    tn, fp, fn, tp = cm.ravel()
    
    metrics['true_positive'] = int(tp)
    metrics['true_negative'] = int(tn)
    metrics['false_positive'] = int(fp)
    metrics['false_negative'] = int(fn)
    
    metrics['accuracy'] = (tp + tn) / (tn + tp + fn + fp)
    metrics['precision'] = tp / (tp + fp)
    metrics['recall'] = tp / (tp + fn)
    metrics['f1_score'] = 2 * tp / (2*tp + fp + fn)
    
    metrics['aucroc'] = roc_auc_score(y_true, y_prob)

    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    metrics['aucpr'] = auc(recall, precision)
    
    metrics = pd.DataFrame.from_dict(metrics, orient='index').round(4)
    return metrics    

In [18]:
"""
get classification metrics
"""

df = pd.DataFrame({
    "y_true": y_test.numpy(),
    "y_pred": y_pred,
    "y_prob": y_proba[:,1]
})

get_metrics(df)

Unnamed: 0,0
true_positive,19.0
true_negative,144.0
false_positive,22.0
false_negative,49.0
accuracy,0.6966
precision,0.4634
recall,0.2794
f1_score,0.3486
aucroc,0.5734
aucpr,0.4761


In [13]:
"""
test on the perturbed data
"""

y_proba = rails_clf.predict(x_test_p)
y_pred = y_proba.argmax(axis=1)

end = time.perf_counter()
print(f"Done within {end-start:.3f} secs.")

"""
get classification metrics
"""

df = pd.DataFrame({
    "y_true": y_test.numpy(),
    "y_pred": y_pred,
    "y_prob": y_proba[:,1]
})

df.head()

get_metrics(df)

Done within 341.668 secs.


Unnamed: 0,0
true_positive,16.0
true_negative,150.0
false_positive,16.0
false_negative,52.0
accuracy,0.7094
precision,0.5
recall,0.2353
f1_score,0.32
aucroc,0.5695
aucpr,0.4788


In [14]:
import torch

class Dataset(torch.utils.data.Dataset):
    """ Characterizes a dataset for PyTorch """
    def __init__(self, data, labels):
        'Initialization'
        self.labels = labels
        self.data = data

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.data)

    def __getitem__(self, index):
        'Generates one sample of data'
        
        image = self.data[index]
        label = self.labels[index]
        
        return (image, label)

In [15]:
"""
evaluate ResNet on perturbed images 
"""

new_rows = []

t = Dataset(x_test_p, y_test)
x_perturbed = DataLoader(t, batch_size=1, shuffle=False)
label_map = {resnet.condition : 0}

for i, (inputs, label) in enumerate(x_perturbed):
    output = resnet.model(inputs.to(resnet.device)).to(resnet.device)

    _, y_pred = torch.max(output, 1)
    y_prob = torch.sigmoid(output)
    top_p, _ = y_prob.topk(1, dim=1)
    
    row = {
        'y_prob': 1 - top_p.cpu().detach().numpy()[0][0],
        'y_pred': y_pred.cpu().detach().numpy()[0],
        'y_true': label.cpu().detach().numpy()[0]
    }
    
    new_rows.append(row)


results2 = pd.DataFrame(new_rows)       

get_metrics(results2)


Unnamed: 0,0
true_positive,7.0
true_negative,162.0
false_positive,4.0
false_negative,61.0
accuracy,0.7222
precision,0.6364
recall,0.1029
f1_score,0.1772
aucroc,0.764
aucpr,0.5762
