Import necessary packages

In [None]:
import argparse
import copy
import os
import random
import time

import numpy as np
from sklearn.metrics import recall_score, precision_score

import torch
from torch.utils.data import Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

from cords.selectionstrategies.supervisedlearning import CRAIGStrategy
from dataloader import ChestXRayImageData
from arch import XRayNet

Set random seeds for reproducibility

In [None]:
# Fix the seed of every random number generator this notebook relies on
# (Python's `random`, NumPy and PyTorch) so repeated runs draw the same numbers.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

Helper functions used by the main process: evaluation, training, validation, and CORDS subset selection

In [None]:
def get_acc(model, dloader):
    """Evaluate ``model`` on ``dloader`` and return its accuracy in percent.

    Every batch yielded by ``dloader`` is indexed with the keys ``'image'``
    and ``'label'``. The predicted class of a sample is the index of the
    largest log-softmax score along dimension 1 of the model output.

    Side effects: switches ``model`` to eval mode and prints the recall and
    precision computed over all batches.

    Returns:
        ``100 * correct / total`` as a float.
    """
    model.eval()
    predictions = []
    targets = []
    n_seen = 0
    n_hits = 0
    with torch.no_grad():
        for batch in dloader:
            images, labels = batch['image'], batch['label']
            logits = model(images)
            log_probs = torch.nn.functional.log_softmax(logits.data, dim=1)
            # max(dim=1) returns (values, indices); the indices are the class ids.
            predicted = log_probs.max(dim=1)[1]
            targets += labels.tolist()
            predictions += predicted.tolist()
            n_seen += labels.size(0)
            n_hits += (predicted == labels).sum().item()

    print(f"Recall: {recall_score(targets, predictions)}")
    print(f"Precision: {precision_score(targets, predictions)}")
    return 100 * n_hits / n_seen


def train(model, optimizer, objective, dloader, epoch, writer, gammas=None):
    """Run one optimisation pass over ``dloader``.

    Args:
        model: network being optimised (the caller puts it in train mode).
        optimizer: optimiser stepped once per batch.
        objective: loss callable. When ``gammas`` is given it must return one
            loss value per sample, because the values are combined with
            ``torch.dot``.
        dloader: loader whose batches are indexed with ``'image'`` and
            ``'label'``.
        epoch: epoch number, used only to compute the logging step.
        writer: object with an ``add_scalar(tag, value, step)`` method.
        gammas: optional 1-D tensor of per-sample weights, indexed by the
            positions produced by ``dloader.batch_sampler``.

    The average loss of every group of 10 consecutive steps is logged under
    the tag ``"training loss"``.
    """
    weighted = gammas is not None
    if weighted:
        # Sample positions of each batch, obtained by iterating the batch
        # sampler once up front. NOTE(review): this only lines up with the
        # batches yielded below if the sampler is deterministic (no shuffling).
        batch_positions = list(dloader.batch_sampler)

    loss_window = 0.0
    for step, batch in enumerate(dloader):
        optimizer.zero_grad()

        loss = objective(model(batch['image']), batch['label'])
        if weighted:
            # Weighted mean of the per-sample losses of this batch.
            weights = gammas[batch_positions[step]]
            loss = torch.dot(loss, weights) / weights.sum()
        loss.backward()
        optimizer.step()

        loss_window += loss.item()
        if (step + 1) % 10 == 0:
            writer.add_scalar("training loss", loss_window / 10, epoch * len(dloader) + step)
            loss_window = 0.0


def validate(model, optimizer, objective, dloader, epoch, writer):
    """Compute the validation loss of ``model`` over ``dloader`` and log it.

    Args:
        model: network to evaluate; it is switched to eval mode.
        optimizer: unused. Kept only so existing callers keep working;
            no gradients are produced here, so nothing needs to be zeroed.
        objective: loss callable; its output is averaged with ``.mean()``
            for every batch.
        dloader: sized iterable whose batches are indexed with ``'image'``
            and ``'label'``.
        epoch: global step under which the scalar is logged.
        writer: object with an ``add_scalar(tag, value, step)`` method.

    Returns:
        A 0-dim tensor holding the SUM of the per-batch mean losses (the
        value logged under ``"validation loss"`` is this sum divided by the
        number of batches). The tensor carries no autograd graph.
    """
    model.eval()
    val_loss = 0.0
    # Evaluation only: without no_grad every batch would keep its autograd
    # graph alive through the accumulated loss.
    with torch.no_grad():
        for data in dloader:
            imgs, labels = data['image'], data['label']
            val_loss = val_loss + objective(model(imgs), labels).mean()

    # An empty loader leaves the accumulator a plain float; normalise it so
    # the return type is always a tensor.
    if not torch.is_tensor(val_loss):
        val_loss = torch.tensor(val_loss)

    # max(..., 1) avoids a division by zero for an empty loader.
    writer.add_scalar("validation loss", val_loss.item() / max(len(dloader), 1), epoch)
    return val_loss


def cords(strategy, budget, model, optimizer, objective, tdloader, vdloader):
    """Ask ``strategy`` for a weighted training subset.

    ``strategy.select`` is called with ``budget`` and the state dict of a
    deep copy of ``model``; the elapsed time of that call is printed.
    Entries whose weight is exactly zero are removed from the result.

    ``optimizer``, ``objective``, ``tdloader`` and ``vdloader`` are accepted
    but not used by this function.

    Returns:
        ``(idxs, gammas)``: the selected indices as a NumPy array and their
        weights as a float32 CPU tensor.
    """
    replica = copy.deepcopy(model)
    tic = time.time()
    selected, weights = strategy.select(budget, replica.state_dict())
    print(f"Subset selection took {time.time() - tic} secs ..")

    idxs = np.array(selected)
    gammas = np.array(weights)

    if not gammas.all():
        print("Found zeros in gammas ..")
        keep = (gammas != 0.0).nonzero()[0]
        idxs = idxs[keep]
        gammas = gammas[keep]

    weight_tensor = torch.from_numpy(gammas.copy())
    return idxs, weight_tensor.to("cpu").to(torch.float32)



Argument parser

In [None]:
def parse_args(argv=None):
    """Parse the command-line options of the pneumonia classifier.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before. Passing
            a list allows calling this from a notebook or a test.

    Returns:
        The populated ``argparse.Namespace``.

    Exits (via ``parser.error``, status 2) when ``--cords`` is given without
    a ``--budget`` in the interval (0, 1]. The budget is used downstream as
    the fraction of the training set to sample without replacement, so a
    missing or out-of-range value would otherwise only fail later.
    """
    parser = argparse.ArgumentParser(description="Pneumonia classifier for ChestXRay Images")
    parser.add_argument('dataset_path', type=str)
    parser.add_argument('save_model_path', type=str)
    parser.add_argument('--budget', type=float)
    parser.add_argument('--no_training', action="store_true")
    parser.add_argument('--cords', action="store_true")
    parser.add_argument('--verbosity', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--shuffle', action="store_true")

    args = parser.parse_args(argv)

    # Fail early with a readable message instead of a TypeError/ValueError
    # deep inside the training setup.
    if args.cords and (args.budget is None or not 0.0 < args.budget <= 1.0):
        parser.error("--cords requires --budget to be a fraction in (0, 1]")

    if args.verbosity:
        print(f"Args: {args}")
    return args



Main entry point: parse the arguments and create the TensorBoard writer

In [None]:
# Script entry point. The body of this `if` continues in the cells below.
if __name__ == "__main__":
    # Command-line options; every later step reads its settings from `args`.
    args = parse_args()
    # Tensorboard writer for train/val loss monitoring
    # (it is given the same directory that the best model is saved to).
    writer = SummaryWriter(args.save_model_path)


Instantiate full datasets and corresponding dataloaders

In [None]:
    # One dataset per split; each split is read from its own sub-directory
    # (train/, val/, test/) of the dataset path given on the command line.
    trainDataset, valDataset, testDataset = (
        ChestXRayImageData(f"{args.dataset_path}/{split}/")
        for split in ("train", "val", "test")
    )

    # All three loaders use the same batch size, shuffle flag and worker count.
    trainDataloader, valDataloader, testDataloader = (
        torch.utils.data.DataLoader(
            split_data,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            num_workers=8,
        )
        for split_data in (trainDataset, valDataset, testDataset)
    )


Instantiate Model for Classification task

In [None]:
    # Two output classes; the same value is later handed to the CORDS strategy.
    num_classes = 2
    # NOTE(review): the first argument (3) is presumably the number of input
    # channels -- confirm against arch.XRayNet.
    model = XRayNet(3, num_classes)
    print(f"Model summary: {model}")

Instantiate the loss function, optimizer, and learning-rate scheduler

In [None]:
    if not args.no_training:
        print(f"Training for {args.epochs} epochs ...")

        # With CORDS the training step weights each sample's loss itself, so
        # the criterion must return per-sample values instead of their mean.
        if args.cords:
            reduction = "none"
        else:
            reduction = "mean"

        # Class-weighted cross entropy (one weight per class).
        objective = torch.nn.CrossEntropyLoss(
            weight=torch.FloatTensor([1.95, 0.67]),
            reduction=reduction,
        )

        # SGD with momentum; the learning rate follows a cosine schedule that
        # spans the whole run (T_max = number of epochs).
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.5)
        lrScheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

        # Starting threshold for the validation-loss check done when checkpointing.
        best_loss = float('inf')


Prepare datasets and dataloaders for CORDS

In [None]:
        if not args.cords:
            # Plain training: the full training loader, no per-sample weights.
            gammas = None
            trainDataLoader = trainDataloader
        else:
            # The selection strategy works on its own copy of the model.
            shadowModel = copy.deepcopy(model)
            strategy = CRAIGStrategy(
                trainDataloader, valDataloader, shadowModel, objective,
                "cpu", num_classes, False, False, 'PerBatch',
            )
            # Number of training samples to keep (budget is a fraction).
            budget = int(args.budget * len(trainDataloader.dataset))

            # Until the first CORDS selection happens, train on a uniformly
            # random subset of that size in which every sample has weight 1.
            ids = np.random.choice(len(trainDataloader.dataset), size=budget, replace=False)
            gammas = torch.ones(len(ids))
            trainSubset = Subset(trainDataset, ids)
            # shuffle=False keeps batch order aligned with the order of `gammas`.
            trainDataLoader = torch.utils.data.DataLoader(
                trainSubset,
                batch_size=args.batch_size,
                shuffle=False,
            )



Start training

In [None]:
        for epoch in range(args.epochs):  # loop over the dataset multiple times
            print(f"Epoch: {epoch}")

            # CORDS
            if args.cords and ((epoch + 1) % 5 == 0):
                print(f"Performing CORDS ..")
                ids, gammas = cords(strategy, budget, model, optimizer, objective, trainDataloader, valDataloader)
                trainSubset = Subset(trainDataset, ids)
                trainDataLoader = torch.utils.data.DataLoader(trainSubset, batch_size=args.batch_size, shuffle=False)

            model.train()
            train(model, optimizer, objective, trainDataLoader, epoch, writer, gammas)
            val_loss = validate(model, optimizer, objective, valDataloader, epoch, writer)
            if best_loss > val_loss:
                torch.save(model, args.save_model_path+"/best_model")
            lrScheduler.step()

        print('Finished Training')
    else:
        print(f"No training ...")

Load the best saved model and report its test accuracy, recall, and precision

In [None]:
    # The checkpoint is a whole pickled model object (written with
    # torch.save(model, ...)), not a state dict.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted location.
    best_model = torch.load(args.save_model_path+"/best_model")
    # get_acc also prints recall and precision before the accuracy line below.
    test_acc = get_acc(best_model, testDataloader)
    print(f"Test acc: {test_acc}")