In [1]:
from tqdm import tqdm
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import medmnist
from medmnist import INFO, Evaluator

In [53]:
data_flag = 'pathmnist'
download = True

# Training hyper-parameters. NUM_EPOCHS matches the epoch count the training
# cell actually runs with (that cell also assigns it; keep the two in sync).
NUM_EPOCHS = 30
BATCH_SIZE = 128
lr = 0.001
SEED = 42

# Seed the RNGs used for weight initialisation, DataLoader shuffling and the
# random augmentation transforms so Restart & Run All is repeatable
# (some CUDA kernels may still be nondeterministic).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset metadata looked up in MedMNIST's INFO table for the chosen flag.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

# Dataset class exposed by the medmnist package under the name info['python_class'].
DataClass = getattr(medmnist, info['python_class'])

In [54]:
# preprocessing
# Training-time pipeline: random augmentation, then Normalize(.5, .5) maps [0, 1] pixels to [-1, 1].
train_data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),  # randomly flip the image horizontally
    transforms.RandomRotation(15),  # randomly rotate the image by an angle in [-15, 15] degrees
    transforms.Normalize(mean=[.5], std=[.5])
])
# Deterministic pipeline (no augmentation) used for every evaluation loader.
test_data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])
# load the data
train_dataset = DataClass(split='train', transform=train_data_transform, download=download)
# Second view of the training split WITHOUT augmentation: evaluating on 'train'
# must score the real images, not randomly flipped/rotated copies of them.
train_eval_dataset = DataClass(split='train', transform=test_data_transform, download=download)
test_dataset = DataClass(split='test', transform=test_data_transform, download=download)
# Training split with no transform (presumably raw PIL images, per the name);
# not referenced by the cells shown here.
pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation loaders keep dataset order (shuffle=False) so the concatenated
# scores stay aligned with sample order.
train_loader_at_eval = data.DataLoader(dataset=train_eval_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

Using downloaded and verified file: C:\Users\lwl89\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\lwl89\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\lwl89\.medmnist\pathmnist.npz


In [55]:
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ``ResNet18``.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    The input is added back through ``shortcut``. That is an identity when the
    shape is unchanged, or a 1x1 conv + BN projection when the stride or channel
    count changes. A final ReLU follows the addition.
    """

    # Output channels = expansion * planes; ResNet reads this to size later layers.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            # 1x1 conv reproduces the main path's downsampling and channel count.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + residual)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block used by ``ResNet50``.

    The first 1x1 conv narrows to ``planes`` channels. The 3x3 conv carries the
    stride, and the last 1x1 conv widens to ``expansion * planes`` channels.
    The shortcut is an identity, or a 1x1 conv + BN projection when the stride
    or channel count changes. A final ReLU follows the residual addition.
    """

    # Output channels = expansion * planes; ResNet reads this to size later layers.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Empty list -> empty Sequential -> identity shortcut.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        residual = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + residual)

class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem (no max-pool) and four residual stages.

    Args:
        block: residual block class; must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages (indices 0-3 are used).
        in_channels: channels of the input images.
        num_classes: length of the output logit vector.
    """

    # (width, stride of the stage's first block) for layer1 .. layer4.
    _STAGES = ((64, 1), (128, 2), (256, 2), (512, 2))

    def __init__(self, block, num_blocks, in_channels=1, num_classes=2):
        super().__init__()
        # Running input-channel count; _make_layer advances it as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Registers self.layer1 .. self.layer4 in order.
        for idx, (width, first_stride) in enumerate(self._STAGES):
            stage = self._make_layer(block, width, num_blocks[idx], stride=first_stride)
            setattr(self, f'layer{idx + 1}', stage)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1.

        Note: a non-positive ``num_blocks`` still yields a single block.
        """
        trailing = [1] * (num_blocks - 1)
        stage_blocks = []
        for s in (stride, *trailing):
            stage_blocks.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage_blocks)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
        h = self.avgpool(h)
        h = h.view(h.size(0), -1)
        return self.linear(h)


def ResNet18(in_channels, num_classes):
    """ResNet-18: four stages of two ``BasicBlock``s each."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths,
                  in_channels=in_channels, num_classes=num_classes)


def ResNet50(in_channels, num_classes):
    """ResNet-50: ``Bottleneck`` stages with 3, 4, 6 and 3 blocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths,
                  in_channels=in_channels, num_classes=num_classes)

model = ResNet18(in_channels=n_channels, num_classes=n_classes)


# define loss function (the optimizer is created in the training cell).
# Multi-label tasks score every label independently with BCE on logits;
# every other task is treated as single-label classification.
criterion = (
    nn.BCEWithLogitsLoss()
    if task == "multi-label, binary-class"
    else nn.CrossEntropyLoss()
)

In [56]:
def test(split):
    """Evaluate the global ``model`` on one data split and print AUC and accuracy.

    Relies on notebook globals: ``model``, ``device`` (defined in the training
    cell), ``task``, ``data_flag``, ``train_loader_at_eval`` and ``test_loader``.
    Leaves the model in eval mode; the training loop calls ``model.train()`` at
    the start of every epoch.

    Args:
        split: ``'train'`` or ``'test'``. Selects the loader and is passed to
            ``Evaluator`` together with ``data_flag``.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'test'``.
    """
    loaders = {'train': train_loader_at_eval, 'test': test_loader}
    if split not in loaders:
        # Any other split name (e.g. 'val') would score test-set predictions
        # under that split's Evaluator, so reject it instead of silently
        # falling back to test_loader.
        raise ValueError(f"split must be 'train' or 'test', got {split!r}")
    data_loader = loaders[split]
    multi_label = task == 'multi-label, binary-class'

    model.eval()
    batch_scores = []
    with torch.no_grad():
        # Targets are unused: Evaluator is built from (data_flag, split) and
        # only receives the scores.
        for inputs, _ in data_loader:
            logits = model(inputs.to(device))
            # Multi-label heads are trained with BCEWithLogitsLoss, so each label
            # gets an independent sigmoid probability; single-label heads use
            # softmax across classes.
            probs = logits.sigmoid() if multi_label else logits.softmax(dim=-1)
            batch_scores.append(probs.cpu())
    # Concatenate once rather than re-growing a tensor with torch.cat per batch.
    y_score = torch.cat(batch_scores, dim=0).numpy()

    evaluator = Evaluator(data_flag, split)
    metrics = evaluator.evaluate(y_score)
    print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))

In [57]:
# train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)  # once, before training; no need to repeat per batch
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
NUM_EPOCHS = 30  # NOTE(review): re-assigns the config-cell value; keep both in sync or drop this line
multi_label = task == 'multi-label, binary-class'

for epoch in range(NUM_EPOCHS):
    # test() below switches the model to eval mode, so re-enable training mode each epoch.
    model.train()
    for inputs, targets in tqdm(train_loader, desc=f'epoch {epoch + 1}/{NUM_EPOCHS}'):
        inputs = inputs.to(device)
        targets = targets.to(device)
        if multi_label:
            # BCEWithLogitsLoss expects float targets with the same shape as the logits.
            targets = targets.to(torch.float32)
        else:
            # Flatten to a 1-D class-index vector. Unlike squeeze(), this keeps a
            # batch of size 1 as shape (1,) instead of collapsing it to a 0-d tensor.
            targets = targets.reshape(-1).long()

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    test('test')

100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:28<00:00,  7.96it/s]


test  auc: 0.946  acc:0.709


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:28<00:00,  7.92it/s]


test  auc: 0.958  acc:0.720


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:28<00:00,  7.92it/s]


test  auc: 0.934  acc:0.618


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.90it/s]


test  auc: 0.955  acc:0.792


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:28<00:00,  7.91it/s]


test  auc: 0.951  acc:0.719


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.91it/s]


test  auc: 0.963  acc:0.772


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.86it/s]


test  auc: 0.966  acc:0.732


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:30<00:00,  7.82it/s]


test  auc: 0.963  acc:0.716


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.87it/s]


test  auc: 0.971  acc:0.767


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.83it/s]


test  auc: 0.973  acc:0.795


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.84it/s]


test  auc: 0.978  acc:0.814


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.84it/s]


test  auc: 0.963  acc:0.739


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.83it/s]


test  auc: 0.970  acc:0.814


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:30<00:00,  7.82it/s]


test  auc: 0.972  acc:0.858


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.89it/s]


test  auc: 0.973  acc:0.857


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.86it/s]


test  auc: 0.975  acc:0.842


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.86it/s]


test  auc: 0.984  acc:0.830


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.89it/s]


test  auc: 0.973  acc:0.733


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.87it/s]


test  auc: 0.982  acc:0.869


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.88it/s]


test  auc: 0.985  acc:0.866


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.89it/s]


test  auc: 0.984  acc:0.865


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.86it/s]


test  auc: 0.975  acc:0.807


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.88it/s]


test  auc: 0.983  acc:0.853


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:31<00:00,  7.70it/s]


test  auc: 0.980  acc:0.822


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.86it/s]


test  auc: 0.985  acc:0.876


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.86it/s]


test  auc: 0.964  acc:0.761


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.84it/s]


test  auc: 0.981  acc:0.852


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.85it/s]


test  auc: 0.977  acc:0.797


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.88it/s]


test  auc: 0.975  acc:0.813


100%|████████████████████████████████████████████████████████████████████████████████| 704/704 [01:29<00:00,  7.89it/s]


test  auc: 0.980  acc:0.846


In [58]:
print('==> Evaluating ...')
# Final scores of the trained model on the training split, then the test split.
for split_name in ('train', 'test'):
    test(split_name)

==> Evaluating ...
train  auc: 1.000  acc:0.991
test  auc: 0.980  acc:0.846
