<a href="https://www.kaggle.com/code/khunanonr/tuberculosis-chest-x-rays-shenzhen?scriptVersionId=131113948" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install -Uqq fastbook

In [None]:
import fastbook
# import the vision modules
from fastai.vision.all import *
from fastbook import *

In [None]:
path = Path("../input/tuberculosis-chest-xrays-shenzhen/images/images")
path.ls()

In [None]:
shenzen = pd.read_csv("../input/tuberculosis-chest-xrays-shenzhen/shenzhen_metadata.csv")

In [None]:
# Build the full image location for every row by joining the image directory
# (`path`) with the value stored in 'study_id'.
# NOTE(review): this presumes 'study_id' holds the image file name including
# its extension — confirm against the metadata CSV.
shenzen['path'] = [(path/x) for x in shenzen['study_id']]

In [None]:
# Disabled alternative: number the distinct values of 'findings' and use that
# number as the label.
#unique = shenzen['findings'].unique()
#map_lebel = {unique[i]:i for i in range(len(unique))}
#shenzen['label'] = shenzen['findings'].map(map_lebel)
# Active labelling: look at the fifth character from the end of 'study_id';
# "0" gives label 0, any other character gives label 1.
# NOTE(review): presumably study_id is a file name such as "CHNCXR_0001_0.png",
# where the digit just before the 4-character extension flags normal (0) vs
# tuberculosis (1) — confirm; a different extension length would silently
# read the wrong character.
shenzen['label'] = shenzen['study_id'].apply(lambda x: 0 if x[-5]=="0" else 1)

In [None]:
shenzen.describe()

In [None]:
shenzen.head()

# Data analysis

In [None]:
# Count rows per 'sex' value, separately for label 0 ("normal") and
# label 1 ("sick"), and collect both counts side by side.
sex = pd.DataFrame()
sex['normal'] = shenzen[shenzen['label'] == 0 ]['sex'].value_counts()
# NOTE(review): pandas aligns this column on the index created by the
# 'normal' column, so a 'sex' value that only occurs among label-1 rows
# would presumably be dropped here — confirm if that can happen.
sex['sick'] = shenzen[shenzen['label'] == 1 ]['sex'].value_counts()
# subplots=True draws one pie per column.
sex.plot.pie(subplots=True,figsize=(5, 5))

In [None]:
shenzen.plot.box(y='age')

# Split Data

In [None]:
def pandas_split(f, data=None):
    """Split a DataFrame into two positional subsets chosen by a splitter.

    `f` is called with `data` and must return a pair of positional index
    collections; the rows at those positions are returned as two frames,
    in the same order as the pair.
    """
    first_positions, second_positions = f(data)
    first_part = data.iloc[first_positions]
    second_part = data.iloc[second_positions]
    return first_part, second_part

In [None]:
# Hold out `test_size` of the rows as a final test set. The split is
# stratified on 'label' and uses a fixed random_state so it is repeatable.
test_size = 0.1
f = TrainTestSplitter(test_size=test_size, random_state=60, stratify=shenzen['label'])
train, test = pandas_split(f, shenzen)
# Show the class balance of both subsets. The original line put both
# expressions on one line without a separator, which is a SyntaxError.
print(train['label'].value_counts())
print(test['label'].value_counts())

In [None]:
train.head()

# Train

In [None]:
def get_x(df):
    """Return the value stored under 'path' in the given row (the model input)."""
    image_location = df['path']
    return image_location
def get_y(df):
    """Return the value stored under 'label' in the given row (the target)."""
    target = df['label']
    return target

# Augmentations, each applied with probability 0.5: rotation of at most
# 5 degrees, a zoom configured through fixed draw values, and a flip.
# NOTE(review): flipping a chest X-ray mirrors left/right anatomy — confirm
# this is intended for this task.
tfms = [Rotate(max_deg=5, p=0.5),Zoom(draw=1.1, draw_x=0.5, draw_y=1, p=0.5), Flip(p=0.5)]
# Fraction of `train` set aside for validation.
# NOTE(review): presumably 1/9 because `train` is 90% of the data, which
# makes the validation set about 10% of the full dataset — confirm.
valid_size = 1/9
# Image -> category pipeline: inputs and targets come from get_x / get_y,
# the validation split is random with a fixed seed, every item is resized
# to 224 and the augmentations above are passed as batch transforms.
x_ray = DataBlock(blocks = (ImageBlock, CategoryBlock),
               get_x = get_x,
               get_y = get_y,
               splitter = RandomSplitter(valid_pct=valid_size, seed=42),
               item_tfms = Resize(224),
               batch_tfms = setup_aug_tfms(tfms)
                 )
# Build the loaders from the training frame only (the held-out `test`
# frame is not passed in), with a batch size of 32.
dls = x_ray.dataloaders(train, bs=32)

In [None]:
dls.valid.show_batch(max_n=10, nrows=4)

In [None]:
dls.train.show_batch(max_n=20, nrows=4)

In [None]:
# list_pre_train = [densenet121,densenet169, densenet201, densenet161]
# result = []
# for model in list_pre_train:
#  torch.cuda.empty_cache()
#  learn = cnn_learner(
#      dls,
#      model,
#      metrics=accuracy
#  )
#  learn.fine_tune(10, freeze_epochs=3)
#  result.append(learn.validate())

In [None]:
# Learner built on a resnet50 architecture over `dls`, reporting accuracy.
# NOTE(review): recent fastai releases expose this constructor as
# `vision_learner` and keep `cnn_learner` only as a deprecated alias —
# confirm against the installed version.
learn = cnn_learner(
    dls,
    resnet50,
    metrics=accuracy
)

In [None]:
# learn.fit_one_cycle(3,lr_steep)
# learn.unfreeze()
# lr_steep= learn.lr_find(suggest_funcs=(steep))
# lr_steep

In [None]:
# base_lr = lr_steep[0]/2
# learn.fit_one_cycle(30, lr_max=slice(base_lr/100,base_lr), cbs=[ShowGraphCallback(), SaveModelCallback()])

In [None]:
# learn.recorder.plot_loss()

In [None]:
# Fine-tune with freeze_epochs=4 and 30 epochs, plotting progress
# (ShowGraphCallback) and checkpointing (SaveModelCallback).
# NOTE(review): SaveModelCallback is used with its defaults here, so the
# checkpoint name and the monitored metric are whatever fastai defaults to —
# confirm if a specific file name is expected later.
learn.fine_tune(30 , freeze_epochs=4, cbs=[ShowGraphCallback(), SaveModelCallback()])

In [None]:
learn.validate()

In [None]:
learn.show_results() 

## Normal accuracy


In [None]:
def nomal_accuracy(model, test):
    """Score `model` on the rows of `test` with the `accuracy` metric.

    Every entry of test['path'] is passed to `model.predict`; the third
    element of each result is collected, stacked into one tensor and
    compared against test['label'] by `accuracy`.
    """
    per_image_outputs = []
    for image_path in test['path']:
        prediction = model.predict(image_path)
        per_image_outputs.append(list(prediction[2]))
    predicted = tensor(per_image_outputs)
    expected = tensor(test['label'])
    return accuracy(predicted, expected)

In [None]:
# Build an interpretation object from the trained learner and draw its
# confusion matrix.
# NOTE(review): from_learner is called without a dataset argument, so it
# presumably evaluates the validation set, not the held-out `test` frame —
# confirm.
interp = ClassificationInterpretation.from_learner(learn)
# Raw confusion-matrix values; not used again in the visible cells.
confusion_m = interp.confusion_matrix()
interp.plot_confusion_matrix(figsize=(7,7))

In [None]:
# Second learner with a resnet101 architecture on the same loaders, then
# restore weights saved under the name "resnet101".
# NOTE(review): no visible cell saves a model under that name, so the file
# presumably comes from an earlier run — confirm it exists before running.
resnet101_model =  cnn_learner(
    dls,
    resnet101,
    metrics=accuracy
)
resnet101_model.load("resnet101")

In [None]:
resnet101_model.validate()

In [None]:
nomal_accuracy(learn, test)

In [None]:
# NOTE: roc_accuracy is defined in a cell further down this notebook; that
# cell has to be executed before this one, otherwise this call raises
# NameError when the notebook is run top to bottom.
roc_accuracy(learn, test)

In [None]:
# Third learner with a densenet201 architecture on the same loaders, then
# restore weights saved under the name "densnet201" (spelled without the
# "e", so the file on disk must use that exact spelling).
# NOTE(review): no visible cell saves a model under that name — presumably
# produced by an earlier run; confirm it exists before running.
dens =  cnn_learner(
    dls,
    densenet201,
    metrics=accuracy
)
dens.load("densnet201")

## Accuracy 
Accuracy = sensitivity × prevalence + specificity × (1 − prevalence), where prevalence is the fraction of samples that are truly positive.

In [None]:
def confusion_matrix(pred, targ):
    """Tally binary predictions against targets.

    Each element is judged by its truthiness. Positions are taken from
    `pred`, and `targ` is indexed at the same position.

    Returns:
        A tuple ``(tp, fp, tn, fn)`` of integer counts.
    """
    # Keyed by (predicted positive?, actually positive?).
    tally = {
        (True, True): 0,    # true positives
        (True, False): 0,   # false positives
        (False, False): 0,  # true negatives
        (False, True): 0,   # false negatives
    }
    for position, predicted in enumerate(pred):
        said_positive = bool(predicted)
        is_positive = bool(targ[position])
        tally[said_positive, is_positive] += 1
    return (
        tally[True, True],
        tally[True, False],
        tally[False, False],
        tally[False, True],
    )

In [None]:
def roc_accuracy(model, test):
    """Accuracy of `model` on `test`, rebuilt from sensitivity and specificity.

    Computes ``sensitivity * prevalence + specificity * (1 - prevalence)``,
    where prevalence is the fraction of rows in `test` whose true label is 1.

    Args:
        model: learner whose `predict` is called on every test['path'] entry.
        test: frame with a 'path' column (inputs) and a 'label' column
            (ground truth, 0 or 1).

    Returns:
        The accuracy as a tensor scalar.

    Raises:
        ValueError: if `test` has no rows.
    """
    # Use the learner that was passed in. The previous version ignored
    # `model` and always evaluated the notebook-global `learn`.
    pred = tensor([list(model.predict(x)[2]) for x in test['path']])
    targ = tensor(test['label'])
    # NOTE(review): argmax gives a class index; this assumes index 0/1 lines
    # up with label 0/1 in the learner's vocabulary — confirm.
    pred, targ = flatten_check(pred.argmax(dim=-1), targ)
    tp, fp, tn, fn = confusion_matrix(pred, targ)
    if tp + fp + tn + fn == 0:
        raise ValueError("test set is empty")
    # If a class is absent its rate is undefined, but its weight in the
    # final sum is zero as well, so 0.0 is a safe stand-in and avoids a
    # ZeroDivisionError.
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (fp + tn) if fp + tn else 0.0
    # Prevalence is the share of truly positive samples, so it must come
    # from the targets. Taking it from the predictions (as before) makes the
    # weighted sum differ from the real accuracy.
    prevalence = (targ == 1).sum() / len(targ)
    return sensitivity * prevalence + specificity * (1 - prevalence)