# TB detection using Tensor Networks

# #

In [1]:
%matplotlib inline

import os
import shutil
import random
import torch
import torchvision
import numpy as np
import time
import torch
from models.lotenet import loTeNet
from torchvision import transforms, datasets
import pdb
from PIL import Image
from matplotlib import pyplot as plt
from models.Densenet import *
from utils.tools import *
import argparse
torch.manual_seed(0)

print('Using PyTorch version', torch.__version__)

Using PyTorch version 1.3.1


In [2]:
class_names = ['normal', 'pneumonia']
root_dir = '/home/mashjunior/loTeNet_pytorch/TBChestXRays/chest_xray/chest_xray'
source_dirs = ['normal' ,'pneumonia']

## create dataset 

In [3]:
class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, image_dirs,transform):
        def get_images(class_name):
            images = [x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('jpeg')]
            print(f'Found {len(images)}{class_name}')
            return images
        self.images={}
        self.class_names=['normal','pneumonia']
        for c in self.class_names:
            self.images[c]=get_images(c)
        self.image_dirs=image_dirs
        self.transform=transform
    def __len__(self):
        return sum([len(self.images[c]) for c in self.class_names])
    def __getitem__(self, index):
        class_name=random.choice(self.class_names)
        index=index%len(self.images[class_name])
        image_name=self.images[class_name][index]
        image_path =os.path.join(self.image_dirs[class_name], image_name)
        image=Image.open(image_path).convert('L')
        return self.transform(image), self.class_names.index(class_name)

## Tensor model

In [4]:
# Globally load device identifier
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
def evaluate(loader):
     ### Evaluation funcntion for validation/testing

    with torch.no_grad():
        vl_acc = 0.
        vl_loss = 0.
        labelsNp = np.zeros(1)
        predsNp = np.zeros(1)
        model.eval()

        for i, (inputs, labels) in enumerate(loader):

            inputs = inputs.to(device).type(dtype=torch.float)
            labels = labels.to(device).type(dtype=torch.float)
            labelsNp = np.concatenate((labelsNp, labels.cpu().numpy()))

            # Inference
            scores = torch.sigmoid(model(inputs)).type(dtype=torch.float)

            preds = scores
            loss = loss_fun(scores, labels)
            predsNp = np.concatenate((predsNp, preds.cpu().numpy()))
            vl_loss += loss.item()

        # Compute AUC over the full (valid/test) set
        vl_acc = computeAuc(labelsNp[1:],predsNp[1:])
        vl_loss = vl_loss/len(loader)

    return vl_acc, vl_loss

In [6]:
# Miscellaneous initialization
torch.manual_seed(1)
start_time = time.time()

In [7]:
parser = argparse.ArgumentParser()

In [8]:
parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--l2', type=float, default=0, help='L2 regularisation')
parser.add_argument('--aug', action='store_true', default=False, help='Use data augmentation')
parser.add_argument('--data_path', type=str, default=root_dir,help='Path to data.')
parser.add_argument('--bond_dim', type=int, default=5, help='MPS Bond dimension')
parser.add_argument('--nChannel', type=int, default=1, help='Number of input channels')
parser.add_argument('--dense_net', action='store_true', default=False, help='Using Dense Net model')

_StoreTrueAction(option_strings=['--dense_net'], dest='dense_net', nargs=0, const=True, default=False, type=None, choices=None, help='Using Dense Net model', metavar=None)

In [9]:
args = parser.parse_args([])

In [10]:
batch_size = args.batch_size

# LoTeNet parameters
adaptive_mode = False 
periodic_bc   = False

kernel = 2 # Stride along spatial dimensions
output_dim = 1 # output dimension
 
feature_dim = 2

#logFile = time.strftime("%Y%m%d_%H_%M")+'.txt'
#makeLogFile(logFile)

normTensor = 0.5*torch.ones(args.nChannel)
### Data processing and loading....

In [11]:
### Data processing and loading....
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.Resize(size=(128,128)),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(20),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=normTensor,std=normTensor)])

valid_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(128,128)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(20),
    torchvision.transforms.ToTensor(),
    transforms.Normalize(mean=normTensor,std=normTensor)
    #torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225])
])

## data loaders

In [12]:
train_dirs = {
    'normal': root_dir + '/train/normal',
    'pneumonia': root_dir + '/train/pneumonia'
}
train_dataset=ChestXRayDataset(train_dirs, train_transform)

Found 1341normal
Found 3875pneumonia


In [13]:
valid_dirs = {
    'normal': root_dir + '/val/normal',
    'pneumonia': root_dir + '/val/pneumonia'
}

valid_dataset = ChestXRayDataset(valid_dirs, valid_transform)

Found 234normal
Found 390pneumonia


In [14]:
test_dirs = {
    'normal': root_dir + '/test/normal',
    'pneumonia': root_dir + '/test/pneumonia'
}

test_dataset = ChestXRayDataset(test_dirs, valid_transform)

Found 8normal
Found 8pneumonia


In [15]:
dl_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dl_valid = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
dl_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
print('Num of training batches', len(dl_train))
print('Num of validation batches', len(dl_valid))
print('Num of test batches', len(dl_test))

Num of training batches 163
Num of validation batches 20
Num of test batches 1


In [16]:
# Initiliaze input dimensions
dim = torch.ShortTensor(list(train_dataset[0][0].shape[1:]))
nCh = int(train_dataset[0][0].shape[0])

In [17]:
print(dim)

tensor([128, 128], dtype=torch.int16)


In [18]:
print(nCh)

1


In [19]:
# Initialize the models
if not args.dense_net:
	print("Using LoTeNet")
	model = loTeNet(input_dim=dim, output_dim=output_dim, 
				  nCh=nCh, kernel=kernel,
				  bond_dim=args.bond_dim, feature_dim=feature_dim,
				  adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1)
else:
	print("Densenet Baseline!")
	model = DenseNet(depth=40, growthRate=12, 
					reduction=0.5,bottleneck=True,nClasses=output_dim)
model = loTeNet(input_dim=dim, output_dim=output_dim, 
				  nCh=nCh, kernel=kernel,
				  bond_dim=args.bond_dim, feature_dim=feature_dim,
				  adaptive_mode=adaptive_mode, periodic_bc=periodic_bc, virtual_dim=1)

Using LoTeNet


In [20]:
# Choose loss function and optimizer
loss_fun = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 
                             weight_decay=args.l2)

In [21]:
nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters:%d"%(nParam))
print(f"Maximum MPS bond dimension = {args.bond_dim}")

print("Bond dim: %d"%(args.bond_dim))
print("Number of parameters:%d"%(nParam),)


Number of parameters:945255
Maximum MPS bond dimension = 5
Bond dim: 5
Number of parameters:945255


In [22]:
print(f"Using Adam w/ learning rate = {args.lr:.1e}")
print("Feature_dim: %d, nCh: %d, B:%d"%(feature_dim,nCh,batch_size))

model = model.to(device)
nValid = len(dl_valid)
nTrain = len(dl_train)
nTest = len(dl_test)

maxAuc = 0
minLoss = 1e3
convCheck = 5
convIter = 0

Using Adam w/ learning rate = 5.0e-04
Feature_dim: 2, nCh: 1, B:32


In [23]:
# Let's start training!
for epoch in range(args.num_epochs):
    running_loss = 0.
    running_acc = 0.
    #t = time.time()
    model.train()
    predsNp = np.zeros(1)
    labelsNp = np.zeros(1)

    for i, (inputs, labels) in enumerate(dl_train):
        
        # convert inputs and labels  and scores to float tensor
        inputs = inputs.to(device).type(dtype=torch.float)
        labels = labels.to(device).type(dtype=torch.float)
        labelsNp = np.concatenate((labelsNp, labels.cpu().numpy()))

        scores = torch.sigmoid(model(inputs)).type(dtype=torch.float)

        preds = scores
        loss = loss_fun(scores, labels)

        with torch.no_grad():
            predsNp = np.concatenate((predsNp, preds.detach().cpu().numpy()))
            running_loss += loss

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 5 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, args.num_epochs, i+1, nTrain, loss.item()))

    accuracy = computeAuc(labelsNp,predsNp)

    # Evaluate on Validation set 
    with torch.no_grad():

        vl_acc, vl_loss = evaluate(dl_valid)
        if vl_acc > maxAuc or vl_loss < minLoss:
            if vl_loss < minLoss:
                minLoss = vl_loss
            if vl_acc > maxAuc:
                ### Predict on test set
                ts_acc, ts_loss = evaluate(dl_test)
                maxAuc = vl_acc
                print('New Max: %.4f'%maxAuc)
                print('Test Set Loss:%.4f	Auc:%.4f'%(ts_loss, ts_acc))
                print('Test Set Loss:%.4f	Auc:%.4f'%(ts_loss, ts_acc))
            convEpoch = epoch
            convIter = 0
        else:
            convIter += 1
        if convIter == convCheck:
            if not args.dense_net:
                print("MPS")
            else:
                print("DenseNet")
            print("Converged at epoch:%d with AUC:%.4f"%(convEpoch+1,maxAuc))

            break

Epoch [1/5], Step [5/163], Loss: 0.7428
Epoch [1/5], Step [10/163], Loss: 0.7542
Epoch [1/5], Step [15/163], Loss: 0.6743
Epoch [1/5], Step [20/163], Loss: 0.6873
Epoch [1/5], Step [25/163], Loss: 0.5684
Epoch [1/5], Step [30/163], Loss: 0.5836
Epoch [1/5], Step [35/163], Loss: 0.5208
Epoch [1/5], Step [40/163], Loss: 0.4957
Epoch [1/5], Step [45/163], Loss: 0.4612
Epoch [1/5], Step [50/163], Loss: 0.3088
Epoch [1/5], Step [55/163], Loss: 0.3481
Epoch [1/5], Step [60/163], Loss: 0.3933
Epoch [1/5], Step [65/163], Loss: 0.1834
Epoch [1/5], Step [70/163], Loss: 0.4002
Epoch [1/5], Step [75/163], Loss: 0.2036
Epoch [1/5], Step [80/163], Loss: 0.3790
Epoch [1/5], Step [85/163], Loss: 0.3848
Epoch [1/5], Step [90/163], Loss: 0.5487
Epoch [1/5], Step [95/163], Loss: 0.4169
Epoch [1/5], Step [100/163], Loss: 0.5524
Epoch [1/5], Step [105/163], Loss: 0.2403
Epoch [1/5], Step [110/163], Loss: 0.4750
Epoch [1/5], Step [115/163], Loss: 0.3068
Epoch [1/5], Step [120/163], Loss: 0.3765
Epoch [1/5],