In [1]:
import os
from tqdm import trange
import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.models import resnet18, resnet50
from tensorboardX import SummaryWriter
from collections import OrderedDict
import medmnist
from medmnist import INFO, Evaluator

In [2]:
data_flag = 'octmnist'
download = True

NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001
gamma=0.1
milestones = [0.5 * NUM_EPOCHS, 0.75 * NUM_EPOCHS]

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

In [3]:
# preprocessing
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# load the data
train_dataset = DataClass(split='train', transform=data_transform, download=download)
val_dataset = DataClass(split='val', transform=data_transform, download=download)
test_dataset = DataClass(split='test', transform=data_transform, download=download)

pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)
val_loader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

Using downloaded and verified file: /home/y2xiong/.medmnist/octmnist.npz
Using downloaded and verified file: /home/y2xiong/.medmnist/octmnist.npz
Using downloaded and verified file: /home/y2xiong/.medmnist/octmnist.npz
Using downloaded and verified file: /home/y2xiong/.medmnist/octmnist.npz


In [4]:
def trainModel(model, train_loader, task, criterion, optimizer, writer):
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    
    total_loss = []
    global iteration

    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs.to(device))

        if task == 'multi-label, binary-class':
            targets = targets.to(torch.float32).to(device)
            loss = criterion(outputs, targets)
        else:
            targets = torch.squeeze(targets, 1).long().to(device)
            loss = criterion(outputs, targets)

        total_loss.append(loss.item())
        writer.add_scalar('train_loss_logs', loss.item(), iteration)
        iteration += 1

        loss.backward()
        optimizer.step()
    
    loss = sum(total_loss)/len(total_loss)
    return loss

In [5]:
def testModel(model, evaluator, data_loader, task, criterion, save_folder=None):
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
        
    model.eval()
    
    total_loss = []
    y_score = torch.tensor([]).to(device)

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            outputs = model(inputs.to(device))
            
            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32).to(device)
                loss = criterion(outputs, targets)
                m = nn.Sigmoid()
                outputs = m(outputs).to(device)
            else:
                targets = torch.squeeze(targets, 1).long().to(device)
                loss = criterion(outputs, targets)
                m = nn.Softmax(dim=1)
                outputs = m(outputs).to(device)
                targets = targets.float().resize_(len(targets), 1)

            total_loss.append(loss.item())
            y_score = torch.cat((y_score, outputs), 0)

        y_score = y_score.detach().cpu().numpy()
        auc, acc = evaluator.evaluate(y_score, save_folder, 'AlexNet_modified')
        
        test_loss = sum(total_loss) / len(total_loss)

        return [test_loss, auc, acc]

In [6]:
class CNNImproved(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        # raise NotImplementedError
        self.layer1 = nn.Conv2d(n_channels, 48, 2)
        self.layer2 = nn.Conv2d(48, 24, 3)
        self.layer3 = nn.Conv2d(24, 16, 2)
        self.layer4 = nn.Linear(64, 32)
        self.layer5 = nn.Linear(32,n_classes)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.act1 = nn.ReLU()
        self.act2 = nn.ReLU()
        self.act3 = nn.ReLU()
        self.act4 = nn.ReLU()
        self.Batch1 = nn.BatchNorm2d(48)
        self.Batch2 = nn.BatchNorm2d(24)
        self.Batch3 = nn.BatchNorm2d(16)

    def forward(self, x):
        # raise NotImplementedError
        x = self.layer1(x)
        x = self.act1(x)
        x = self.pool1(x)
        x = self.Batch1(x)
        x = self.layer2(x)
        x = self.act2(x)
        x = self.pool2(x)
        x = self.Batch2(x)
        x = self.layer3(x)
        x = self.act3(x)
        x = self.pool3(x)
        x = self.Batch3(x)
        x = x.view(x.size(0), -1)
        x = self.layer4(x)
        x = self.act4(x)
        x = self.layer5(x)
        return x

In [7]:
model =  CNNImproved(n_channels, n_classes)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

output_path = './output/AlexNet_modified'
if not os.path.exists(output_path):
    os.makedirs(output_path)

logs = ['loss', 'auc', 'acc']
train_logs = ['train_'+log for log in logs]
val_logs = ['val_'+log for log in logs]
test_logs = ['test_'+log for log in logs]
log_dict = OrderedDict.fromkeys(train_logs+val_logs+test_logs, 0)
    
writer = SummaryWriter(log_dir=os.path.join(output_path, 'Tensorboard_Results'))

best_auc = 0
best_epoch = 0
best_model = model

global iteration
iteration = 0

if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()
        
train_evaluator = medmnist.Evaluator(data_flag, 'train')
val_evaluator = medmnist.Evaluator(data_flag, 'val')
test_evaluator = medmnist.Evaluator(data_flag, 'test')
    
for epoch in trange(NUM_EPOCHS): 
    train_loss = trainModel(model, train_loader, task, criterion, optimizer, writer)

    train_metrics = testModel(model, train_evaluator, train_loader_at_eval, task, criterion)
    val_metrics = testModel(model, val_evaluator, val_loader, task, criterion)
    test_metrics = testModel(model, test_evaluator, test_loader, task, criterion)

    scheduler.step()

    for i, key in enumerate(train_logs):
        log_dict[key] = train_metrics[i]
    for i, key in enumerate(val_logs):
        log_dict[key] = val_metrics[i]
    for i, key in enumerate(test_logs):
        log_dict[key] = test_metrics[i]

    for key, value in log_dict.items():
        writer.add_scalar(key, value, epoch)

    cur_auc = val_metrics[1]
    if cur_auc > best_auc:
        best_epoch = epoch
        best_auc = cur_auc
        best_model = model

state = {
    'net': best_model.state_dict(),
}

path = os.path.join(output_path, 'best_model.pth')
torch.save(state, path)

train_metrics = testModel(best_model, train_evaluator, train_loader_at_eval, task, criterion, output_path)
val_metrics = testModel(best_model, val_evaluator, val_loader, task, criterion, output_path)
test_metrics = testModel(best_model, test_evaluator, test_loader, task, criterion, output_path)

train_log = 'train  auc: %.5f  acc: %.5f\n' % (train_metrics[1], train_metrics[2])
val_log = 'val  auc: %.5f  acc: %.5f\n' % (val_metrics[1], val_metrics[2])
test_log = 'test  auc: %.5f  acc: %.5f\n' % (test_metrics[1], test_metrics[2])

log = '%s\n' % (data_flag) + train_log + val_log + test_log
print(log)

with open(os.path.join(output_path, '%s_log.txt' % (data_flag)), 'a') as f:
    f.write(log)  

writer.close()

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|██████████| 3/3 [01:49<00:00, 36.37s/it]


octmnist
train  auc: 0.95651  acc: 0.88023
val  auc: 0.95356  acc: 0.87768
test  auc: 0.93534  acc: 0.67400

