# TB detection using Tensor Networks


In [1]:
%matplotlib inline

import os
import shutil
import random
import torch
import torchvision
import numpy as np
import time
import torch
from models.lotenet import loTeNet
from torchvision import transforms, datasets
import pdb
from PIL import Image
from matplotlib import pyplot as plt
from models.Densenet import *
from utils.tools import *
import argparse
torch.manual_seed(0)

print('Using PyTorch version', torch.__version__)

Using PyTorch version 1.3.1


In [2]:
"""
import os
import numpy as np
import shutil
import random

# # Creating Train / Val / Test folders (One time use)
root_dir = '/home/mashjunior/loTeNet_pytorch/TBCXRDatabase/TB_Chest_Radiography_Database/'
classes_dir = ['Normal', 'Tuberculosis']

val_ratio = 0.15
test_ratio = 0.05

for cls in classes_dir:
    os.makedirs(root_dir +'/train' + cls)
    os.makedirs(root_dir +'/val' + cls)
    os.makedirs(root_dir +'/test' + cls)


    # Creating partitions of the data after shuffeling
    src = root_dir + cls # Folder to copy images from

    allFileNames = os.listdir(src)
    np.random.shuffle(allFileNames)
    train_FileNames, val_FileNames, test_FileNames = np.split(np.array(allFileNames),
                                                              [int(len(allFileNames)* (1 - val_ratio + test_ratio)), 
                                                               int(len(allFileNames)* (1 - test_ratio))])


    train_FileNames = [src+'/'+ name for name in train_FileNames.tolist()]
    val_FileNames = [src+'/' + name for name in val_FileNames.tolist()]
    test_FileNames = [src+'/' + name for name in test_FileNames.tolist()]

    print('Total images: ', len(allFileNames))
    print('Training: ', len(train_FileNames))
    print('Validation: ', len(val_FileNames))
    print('Testing: ', len(test_FileNames))

    # Copy-pasting images
    for name in train_FileNames:
        shutil.copy(name, root_dir +'/train' + cls)

    for name in val_FileNames:
        shutil.copy(name, root_dir +'/val' + cls)

    for name in test_FileNames:
        shutil.copy(name, root_dir +'/test' + cls)
"""

"\nimport os\nimport numpy as np\nimport shutil\nimport random\n\n# # Creating Train / Val / Test folders (One time use)\nroot_dir = '/home/mashjunior/loTeNet_pytorch/TBCXRDatabase/TB_Chest_Radiography_Database/'\nclasses_dir = ['Normal', 'Tuberculosis']\n\nval_ratio = 0.15\ntest_ratio = 0.05\n\nfor cls in classes_dir:\n    os.makedirs(root_dir +'/train' + cls)\n    os.makedirs(root_dir +'/val' + cls)\n    os.makedirs(root_dir +'/test' + cls)\n\n\n    # Creating partitions of the data after shuffeling\n    src = root_dir + cls # Folder to copy images from\n\n    allFileNames = os.listdir(src)\n    np.random.shuffle(allFileNames)\n    train_FileNames, val_FileNames, test_FileNames = np.split(np.array(allFileNames),\n                                                              [int(len(allFileNames)* (1 - val_ratio + test_ratio)), \n                                                               int(len(allFileNames)* (1 - test_ratio))])\n\n\n    train_FileNames = [src+'/'+ name for name

In [3]:
class_names = ['normal', 'tuberculosis']
root_dir = '/home/mashjunior/loTeNet_pytorch/TBCXRDatabase/TB_Chest_Radiography_Database/loTenet_data'
source_dirs = ['normal' ,'tuberculosis']

In [4]:
class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, image_dirs,transform):
        def get_images(class_name):
            images = [x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('png')]
            print(f'Found {len(images)}{class_name}')
            return images
        self.images={}
        self.class_names=['normal','tuberculosis']
        for c in self.class_names:
            self.images[c]=get_images(c)
        self.image_dirs=image_dirs
        self.transform=transform
    def __len__(self):
        return sum([len(self.images[c]) for c in self.class_names])
    def __getitem__(self, index):
        class_name=random.choice(self.class_names)
        index=index%len(self.images[class_name])
        image_name=self.images[class_name][index]
        image_path =os.path.join(self.image_dirs[class_name], image_name)
        image=Image.open(image_path).convert('L')
        return self.transform(image), self.class_names.index(class_name)

In [5]:
# Globally load device identifier
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
def evaluate(loader):
     ### Evaluation funcntion for validation/testing

    with torch.no_grad():
        vl_acc = 0.
        vl_loss = 0.
        labelsNp = np.zeros(1)
        predsNp = np.zeros(1)
        model.eval()

        for i, (inputs, labels) in enumerate(loader):

            inputs = inputs.to(device).type(dtype=torch.float)
            labels = labels.to(device).type(dtype=torch.float)
            labelsNp = np.concatenate((labelsNp, labels.cpu().numpy()))

            # Inference
            scores = torch.sigmoid(model(inputs)).type(dtype=torch.float)

            preds = scores
            loss = loss_fun(scores, labels)
            predsNp = np.concatenate((predsNp, preds.cpu().numpy()))
            vl_loss += loss.item()

        # Compute AUC over the full (valid/test) set
        vl_acc = computeAuc(labelsNp[1:],predsNp[1:])
        vl_loss = vl_loss/len(loader)

    return vl_acc, vl_loss

In [7]:
# Miscellaneous initialization
torch.manual_seed(1)
start_time = time.time()

In [8]:
parser = argparse.ArgumentParser()

In [9]:
parser.add_argument('--num_epochs', type=int, default=10, help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--l2', type=float, default=0, help='L2 regularisation')
parser.add_argument('--aug', action='store_true', default=False, help='Use data augmentation')
parser.add_argument('--data_path', type=str, default=root_dir,help='Path to data.')
parser.add_argument('--bond_dim', type=int, default=5, help='MPS Bond dimension')
parser.add_argument('--nChannel', type=int, default=1, help='Number of input channels')
parser.add_argument('--dense_net', action='store_true', default=False, help='Using Dense Net model')

_StoreTrueAction(option_strings=['--dense_net'], dest='dense_net', nargs=0, const=True, default=False, type=None, choices=None, help='Using Dense Net model', metavar=None)

In [10]:
args = parser.parse_args([])

In [11]:
batch_size = args.batch_size

# LoTeNet parameters
adaptive_mode = False 
periodic_bc   = False

kernel = 2 # Stride along spatial dimensions
output_dim = 1 # output dimension
 
feature_dim = 2

#logFile = time.strftime("%Y%m%d_%H_%M")+'.txt'
#makeLogFile(logFile)

normTensor = 0.5*torch.ones(args.nChannel)
### Data processing and loading....

In [12]:
### Data processing and loading....
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.Resize(size=(128,128)),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(20),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=normTensor,std=normTensor)])

valid_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(128,128)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(20),
    torchvision.transforms.ToTensor(),
    transforms.Normalize(mean=normTensor,std=normTensor)
    #torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225])
])

In [13]:
train_dirs = {
    'normal': root_dir + '/train/normal/',
    'tuberculosis': root_dir + '/train/tuberculosis/'
}
train_dataset=ChestXRayDataset(train_dirs, train_transform)

Found 3150normal
Found 3150tuberculosis


In [14]:
valid_dirs = {
    'normal': root_dir + '/val/normal/',
    'tuberculosis': root_dir + '/val/tuberculosis/'
}

valid_dataset = ChestXRayDataset(valid_dirs, valid_transform)

Found 175normal
Found 175tuberculosis


In [15]:
test_dirs = {
    'normal': root_dir + '/test/normal/',
    'tuberculosis': root_dir + '/test/tuberculosis/'
}

test_dataset = ChestXRayDataset(test_dirs, valid_transform)

Found 175normal
Found 175tuberculosis


In [16]:
dl_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dl_valid = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
dl_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
print('Num of training batches', len(dl_train))
print('Num of validation batches', len(dl_valid))
print('Num of test batches', len(dl_test))

Num of training batches 197
Num of validation batches 11
Num of test batches 11


In [17]:
# Initiliaze input dimensions
dim = torch.ShortTensor(list(train_dataset[0][0].shape[1:]))
nCh = int(train_dataset[0][0].shape[0])

In [18]:
print(dim)

tensor([128, 128], dtype=torch.int16)


In [19]:
print(nCh)

1


In [20]:
# Initialize the models
if not args.dense_net:
	print("Using LoTeNet")
	model = loTeNet(input_dim=dim, output_dim=output_dim, 
				  nCh=nCh, kernel=kernel,
				  bond_dim=args.bond_dim, feature_dim=feature_dim,
				  adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1)
else:
	print("Densenet Baseline!")
	model = DenseNet(depth=40, growthRate=12, 
					reduction=0.5,bottleneck=True,nClasses=output_dim)
model = loTeNet(input_dim=dim, output_dim=output_dim, 
				  nCh=nCh, kernel=kernel,
				  bond_dim=args.bond_dim, feature_dim=feature_dim,
				  adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1)

Using LoTeNet


In [21]:
# Choose loss function and optimizer
loss_fun = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 
                             weight_decay=args.l2)

In [22]:
nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters:%d"%(nParam))
print(f"Maximum MPS bond dimension = {args.bond_dim}")

print("Bond dim: %d"%(args.bond_dim))
print("Number of parameters:%d"%(nParam),)

Number of parameters:945255
Maximum MPS bond dimension = 5
Bond dim: 5
Number of parameters:945255


In [23]:
print(f"Using Adam w/ learning rate = {args.lr:.1e}")
print("Feature_dim: %d, nCh: %d, B:%d"%(feature_dim,nCh,batch_size))

model = model.to(device)
nValid = len(dl_valid)
nTrain = len(dl_train)
nTest = len(dl_test)

maxAuc = 0
minLoss = 1e3
convCheck = 5
convIter = 0

Using Adam w/ learning rate = 5.0e-04
Feature_dim: 2, nCh: 1, B:32


In [24]:
# Let's start training!
for epoch in range(args.num_epochs):
    running_loss = 0.
    running_acc = 0.
    #t = time.time()
    model.train()
    predsNp = np.zeros(1)
    labelsNp = np.zeros(1)

    for i, (inputs, labels) in enumerate(dl_train):
        
        # convert inputs and labels  and scores to float tensor
        inputs = inputs.to(device).type(dtype=torch.float)
        labels = labels.to(device).type(dtype=torch.float)
        labelsNp = np.concatenate((labelsNp, labels.cpu().numpy()))

        scores = torch.sigmoid(model(inputs)).type(dtype=torch.float)

        preds = scores
        loss = loss_fun(scores, labels)

        with torch.no_grad():
            predsNp = np.concatenate((predsNp, preds.detach().cpu().numpy()))
            running_loss += loss

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 5 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, args.num_epochs, i+1, nTrain, loss.item()))

    accuracy = computeAuc(labelsNp,predsNp)

    # Evaluate on Validation set 
    with torch.no_grad():

        vl_acc, vl_loss = evaluate(dl_valid)
        if vl_acc > maxAuc or vl_loss < minLoss:
            if vl_loss < minLoss:
                minLoss = vl_loss
            if vl_acc > maxAuc:
                ### Predict on test set
                ts_acc, ts_loss = evaluate(dl_test)
                maxAuc = vl_acc
                print('New Max: %.4f'%maxAuc)
                print('Test Set Loss:%.4f	Auc:%.4f'%(ts_loss, ts_acc))
                print('Test Set Loss:%.4f	Auc:%.4f'%(ts_loss, ts_acc))
            convEpoch = epoch
            convIter = 0
        else:
            convIter += 1
        if convIter == convCheck:
            if not args.dense_net:
                print("MPS")
            else:
                print("DenseNet")
            print("Converged at epoch:%d with AUC:%.4f"%(convEpoch+1,maxAuc))

            break

Epoch [1/10], Step [5/197], Loss: 0.8039
Epoch [1/10], Step [10/197], Loss: 0.8388
Epoch [1/10], Step [15/197], Loss: 0.8399
Epoch [1/10], Step [20/197], Loss: 0.6653
Epoch [1/10], Step [25/197], Loss: 0.6585
Epoch [1/10], Step [30/197], Loss: 0.5364
Epoch [1/10], Step [35/197], Loss: 0.5199
Epoch [1/10], Step [40/197], Loss: 0.5150
Epoch [1/10], Step [45/197], Loss: 0.3875
Epoch [1/10], Step [50/197], Loss: 0.3956
Epoch [1/10], Step [55/197], Loss: 0.5289
Epoch [1/10], Step [60/197], Loss: 0.3673
Epoch [1/10], Step [65/197], Loss: 0.3732
Epoch [1/10], Step [70/197], Loss: 0.2648
Epoch [1/10], Step [75/197], Loss: 0.4687
Epoch [1/10], Step [80/197], Loss: 0.2769
Epoch [1/10], Step [85/197], Loss: 0.4923
Epoch [1/10], Step [90/197], Loss: 0.4187
Epoch [1/10], Step [95/197], Loss: 0.4244
Epoch [1/10], Step [100/197], Loss: 0.3784
Epoch [1/10], Step [105/197], Loss: 0.3060
Epoch [1/10], Step [110/197], Loss: 0.3663
Epoch [1/10], Step [115/197], Loss: 0.3219
Epoch [1/10], Step [120/197], L

Epoch [5/10], Step [165/197], Loss: 0.2074
Epoch [5/10], Step [170/197], Loss: 0.2783
Epoch [5/10], Step [175/197], Loss: 0.2849
Epoch [5/10], Step [180/197], Loss: 0.3362
Epoch [5/10], Step [185/197], Loss: 0.4638
Epoch [5/10], Step [190/197], Loss: 0.2619
Epoch [5/10], Step [195/197], Loss: 0.3470
New Max: 0.9779
Test Set Loss:0.2294	Auc:0.9671
Test Set Loss:0.2294	Auc:0.9671
Epoch [6/10], Step [5/197], Loss: 0.4477
Epoch [6/10], Step [10/197], Loss: 0.2106
Epoch [6/10], Step [15/197], Loss: 0.3326
Epoch [6/10], Step [20/197], Loss: 0.3489
Epoch [6/10], Step [25/197], Loss: 0.5271
Epoch [6/10], Step [30/197], Loss: 0.1158
Epoch [6/10], Step [35/197], Loss: 0.1148
Epoch [6/10], Step [40/197], Loss: 0.2300
Epoch [6/10], Step [45/197], Loss: 0.3432
Epoch [6/10], Step [50/197], Loss: 0.2011
Epoch [6/10], Step [55/197], Loss: 0.3551
Epoch [6/10], Step [60/197], Loss: 0.4555
Epoch [6/10], Step [65/197], Loss: 0.1336
Epoch [6/10], Step [70/197], Loss: 0.2589
Epoch [6/10], Step [75/197], Los

Epoch [10/10], Step [135/197], Loss: 0.2192
Epoch [10/10], Step [140/197], Loss: 0.1493
Epoch [10/10], Step [145/197], Loss: 0.0851
Epoch [10/10], Step [150/197], Loss: 0.2100
Epoch [10/10], Step [155/197], Loss: 0.2177
Epoch [10/10], Step [160/197], Loss: 0.0948
Epoch [10/10], Step [165/197], Loss: 0.1352
Epoch [10/10], Step [170/197], Loss: 0.0500
Epoch [10/10], Step [175/197], Loss: 0.2344
Epoch [10/10], Step [180/197], Loss: 0.1593
Epoch [10/10], Step [185/197], Loss: 0.2021
Epoch [10/10], Step [190/197], Loss: 0.2100
Epoch [10/10], Step [195/197], Loss: 0.1376
New Max: 0.9864
Test Set Loss:0.1821	Auc:0.9802
Test Set Loss:0.1821	Auc:0.9802
