In [1]:
#!/usr/bin/env python3
import os
import random
import time
import torch
from torchmps import MPS
from torchvision import transforms, datasets
import argparse
import torchvision
from PIL import Image
import numpy as np

In [2]:
class_names = ['normal', 'covid','viral']
root_dir = '/home/mashjunior/loTeNet_pytorch/Covid-19_images/COVID-19_Radiography/COVID-19 Radiography Database/'
source_dirs = ['normal' ,'covid','viral']

In [3]:
parser = argparse.ArgumentParser()

In [4]:
parser.add_argument('--num_epochs', type=int, default=100, help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--l2_reg', type=float, default=0, help='L2 regularisation')
parser.add_argument('--aug', action='store_true', default=False, help='Use data augmentation')
parser.add_argument('--data_path', type=str, default=root_dir,help='Path to data.')
parser.add_argument('--bond_dim', type=int, default=5, help='MPS Bond dimension')
parser.add_argument('--nChannel', type=int, default=1, help='Number of input channels')
parser.add_argument('--dense_net', action='store_true', default=False, help='Using Dense Net model')

_StoreTrueAction(option_strings=['--dense_net'], dest='dense_net', nargs=0, const=True, default=False, type=None, choices=None, help='Using Dense Net model', metavar=None)

In [5]:
args = parser.parse_args([])

In [6]:
# Miscellaneous initialization
torch.manual_seed(0)
start_time = time.time()

# MPS parameters
#bond_dim = 20
adaptive_mode = False
periodic_bc = False

# Training parameters
#num_train = 2000
#num_test = 1000
#batch_size = 100
num_epochs = args.num_epochs
learn_rate = args.lr
l2_reg = args.l2_reg


batch_size = args.batch_size
bond_dim = args.bond_dim


# LoTeNet parameters
adaptive_mode = False 
periodic_bc   = False

kernel = 2 # Stride along spatial dimensions
output_dim = 1 # output dimension
 
feature_dim = 2

#logFile = time.strftime("%Y%m%d_%H_%M")+'.txt'
#makeLogFile(logFile)

normTensor = 0.5*torch.ones(args.nChannel)
### Data processing and loading....

In [7]:
class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, image_dirs,transform):
        def get_images(class_name):
            images = [x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('png')]
            print(f'Found {len(images)}{class_name}')
            return images
        self.images={}
        self.class_names=['normal','covid', 'viral']
        for c in self.class_names:
            self.images[c]=get_images(c)
        self.image_dirs=image_dirs
        self.transform=transform
    def __len__(self):
        return sum([len(self.images[c]) for c in self.class_names])
    def __getitem__(self, index):
        class_name=random.choice(self.class_names)
        index=index%len(self.images[class_name])
        image_name=self.images[class_name][index]
        image_path =os.path.join(self.image_dirs[class_name], image_name)
        image=Image.open(image_path).convert('L')
        return self.transform(image), self.class_names.index(class_name)

In [8]:
### Data processing and loading....
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.Resize(size=(128,128)),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(20),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=normTensor,std=normTensor)])

valid_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(128,128)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomRotation(20),
    torchvision.transforms.ToTensor(),
    transforms.Normalize(mean=normTensor,std=normTensor)
    #torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,0.224,0.225])
])

In [9]:
train_dirs = {
    'normal': root_dir + 'train/normal',
    'covid': root_dir + 'train/covid',
    'viral': root_dir + 'train/viral'
}
train_dataset=ChestXRayDataset(train_dirs, train_transform)

Found 600normal
Found 600covid
Found 600viral


In [10]:
valid_dirs = {
    'normal': root_dir + '/valid/normal',
    'covid': root_dir + '/valid/covid',
    'viral': root_dir + '/valid/viral'
}

valid_dataset = ChestXRayDataset(valid_dirs, valid_transform)

Found 400normal
Found 400covid
Found 400viral


In [11]:
test_dirs = {
    'normal': root_dir + '/test/normal',
    'covid': root_dir + '/test/covid',
    'viral': root_dir + '/test/viral'
}

test_dataset = ChestXRayDataset(test_dirs, valid_transform)

Found 100normal
Found 100covid
Found 100viral


In [12]:
# Training parameters
num_train = 600
num_valid = 400
num_test = 100

In [13]:
mps = MPS(
    input_dim=128 ** 2,
    output_dim=3,
    bond_dim=bond_dim,
    adaptive_mode=adaptive_mode,
    periodic_bc=periodic_bc,
)

In [14]:
# Set our loss function and optimizer
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mps.parameters(), lr=learn_rate, weight_decay=l2_reg)

In [15]:
# Put MNIST data into dataloaders
samplers = {
    "train": torch.utils.data.SubsetRandomSampler(range(num_train)),
    "valid": torch.utils.data.SubsetRandomSampler(range(num_valid)),
    "test": torch.utils.data.SubsetRandomSampler(range(num_test)),
}
loaders = {
    name: torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=samplers[name], drop_last=True
    )
    for (name, dataset) in [("train", train_dataset), ("valid", valid_dataset), ("test", test_dataset)]
}
num_batches = {
    name: total_num // batch_size
    for (name, total_num) in [("train", num_train), ("valid", num_valid), ("test", num_test)]
}

In [16]:
samplers.items

<function dict.items>

In [17]:
num_batches

{'train': 9, 'valid': 6, 'test': 1}

In [18]:
#loaders['train'][0][0]

In [19]:
print(f"Maximum MPS bond dimension = {bond_dim}")
print(f" * {'Adaptive' if adaptive_mode else 'Fixed'} bond dimensions")
print(f" * {'Periodic' if periodic_bc else 'Open'} boundary conditions")
print(f"Using Adam w/ learning rate = {learn_rate:.1e}")
if l2_reg > 0:
    print(f" * L2 regularization = {l2_reg:.2e}")
print()

Maximum MPS bond dimension = 5
 * Fixed bond dimensions
 * Open boundary conditions
Using Adam w/ learning rate = 5.0e-04



In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [21]:
def evaluate(loader):
     ### Evaluation funcntion for validation/testing

    with torch.no_grad():
        vl_acc = 0.
        vl_loss = 0.
        labelsNp = np.zeros(1)
        predsNp = np.zeros(1)
        model.eval()

        for i, (inputs, labels) in enumerate(loader):

            inputs = inputs.to(device).type(dtype=torch.float)
            labels = labels.to(device).type(dtype=torch.float)
            labelsNp = np.concatenate((labelsNp, labels.cpu().numpy()))

            # Inference
            scores = torch.sigmoid(model(inputs)).type(dtype=torch.float)

            preds = scores
            loss = loss_fun(scores, labels)
            predsNp = np.concatenate((predsNp, preds.cpu().numpy()))
            vl_loss += loss.item()

        # Compute AUC over the full (valid/test) set
        vl_acc = computeAuc(labelsNp[1:],predsNp[1:])
        vl_loss = vl_loss/len(loader)

    return vl_acc, vl_loss

In [22]:
maxAuc = 0
minLoss = 1e3
convCheck = 5
convIter = 0

In [23]:
# Let's start training!
for epoch_num in range(1, num_epochs + 1):
    running_loss = 0.0
    running_acc = 0.0

    for inputs, labels in loaders["train"]:
        inputs, labels = inputs.view([batch_size, 128 ** 2]), labels.data

        # Call our MPS to get logit scores and predictions
        scores = mps(inputs)
        _, preds = torch.max(scores, 1)

        # Compute the loss and accuracy, add them to the running totals
        loss = loss_fun(scores, labels)
        with torch.no_grad():
            accuracy = torch.sum(preds == labels).item() / batch_size
            running_loss += loss
            running_acc += accuracy

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"### Epoch {epoch_num} ###")
    print(f"Average loss:           {running_loss / num_batches['train']:.4f}")
    print(f"Average train accuracy: {running_acc / num_batches['train']:.4f}")

    # Evaluate accuracy of MPS classifier on the validation set
    with torch.no_grad():
        running_acc = 0.0

        for inputs, labels in loaders["valid"]:
            inputs, labels = inputs.view([batch_size, 128 ** 2]), labels.data

            # Call our MPS to get logit scores and predictions
            scores = mps(inputs)
            _, preds = torch.max(scores, 1)
            running_acc += torch.sum(preds == labels).item() / batch_size

    print(f"Validation accuracy:          {running_acc / num_batches['valid']:.4f}")
    print(f"Runtime so far:         {int(time.time()-start_time)} sec\n")

### Epoch 1 ###
Average loss:           2989.7849
Average train accuracy: 0.3194
Validation accuracy:          0.4115
Runtime so far:         30 sec

### Epoch 2 ###
Average loss:           2.2395
Average train accuracy: 0.3194
Validation accuracy:          0.2682
Runtime so far:         57 sec

### Epoch 3 ###
Average loss:           1.1549
Average train accuracy: 0.3455
Validation accuracy:          0.3229
Runtime so far:         86 sec

### Epoch 4 ###
Average loss:           1.0957
Average train accuracy: 0.3559
Validation accuracy:          0.3984
Runtime so far:         115 sec

### Epoch 5 ###
Average loss:           1.1072
Average train accuracy: 0.3455
Validation accuracy:          0.3880
Runtime so far:         145 sec

### Epoch 6 ###
Average loss:           1.1102
Average train accuracy: 0.3733
Validation accuracy:          0.3828
Runtime so far:         172 sec

### Epoch 7 ###
Average loss:           1.0977
Average train accuracy: 0.3299
Validation accuracy:          0.39

### Epoch 56 ###
Average loss:           1.0938
Average train accuracy: 0.4062
Validation accuracy:          0.3620
Runtime so far:         1657 sec

### Epoch 57 ###
Average loss:           1.0904
Average train accuracy: 0.4097
Validation accuracy:          0.4271
Runtime so far:         1685 sec

### Epoch 58 ###
Average loss:           1.0811
Average train accuracy: 0.3646
Validation accuracy:          0.3802
Runtime so far:         1715 sec

### Epoch 59 ###
Average loss:           1.0675
Average train accuracy: 0.4340
Validation accuracy:          0.3932
Runtime so far:         1745 sec

### Epoch 60 ###
Average loss:           1.0415
Average train accuracy: 0.4497
Validation accuracy:          0.5312
Runtime so far:         1776 sec

### Epoch 61 ###
Average loss:           0.9277
Average train accuracy: 0.6007
Validation accuracy:          0.5781
Runtime so far:         1806 sec

### Epoch 62 ###
Average loss:           0.9036
Average train accuracy: 0.5868
Validation accuracy: 