In [1]:
! pip install medmnist
from tqdm import tqdm
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import medmnist
from medmnist import INFO, Evaluator

Looking in indexes: https://repo.huaweicloud.com/repository/pypi/simple


In [2]:
data_flag = 'tissuemnist'
download = True

NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

In [3]:
# preprocessing
train_data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),  # 随机将图片水平翻转
    transforms.RandomRotation(15),  # 随机旋转图片 (-15,15)
    transforms.Normalize(mean=[.5], std=[.5])
])
test_data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])
# load the data
train_dataset = DataClass(split='train', transform=train_data_transform, download=download)
test_dataset = DataClass(split='test', transform=test_data_transform, download=download)
pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

Downloading https://zenodo.org/record/6496656/files/tissuemnist.npz?download=1 to /root/.medmnist/tissuemnist.npz


  0%|          | 0/124962739 [00:00<?, ?it/s]

Using downloaded and verified file: /root/.medmnist/tissuemnist.npz
Using downloaded and verified file: /root/.medmnist/tissuemnist.npz


In [4]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
            
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channels=1, num_classes=2):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(in_channels, num_classes):
    return ResNet(BasicBlock, [2, 2, 2, 2], in_channels=in_channels, num_classes=num_classes)


def ResNet50(in_channels, num_classes):
    return ResNet(Bottleneck, [3, 4, 6, 3], in_channels=in_channels, num_classes=num_classes)

model =  ResNet18(in_channels=n_channels, num_classes=n_classes)


# define loss function and optimizer
if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

In [5]:
def test(split):
    model.eval()
    y_true = torch.tensor([]).to(device)
    y_score = torch.tensor([]).to(device)
    
    data_loader = train_loader_at_eval if split == 'train' else test_loader

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device) 
            targets = targets.to(device)
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze().long()
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)

            
            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)

        y_true = y_true.cpu().numpy() 
        y_score = y_score.detach().cpu().numpy()
        
        evaluator = Evaluator(data_flag, split)
        metrics = evaluator.evaluate(y_score)
    
        print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))

In [6]:
# train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device) 
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
NUM_EPOCHS=30
for epoch in range(NUM_EPOCHS):
    train_correct = 0
    train_total = 0
    test_correct = 0
    test_total = 0
    
    model.train()
    for inputs, targets in tqdm(train_loader):
        model.to(device) 
        inputs = inputs.to(device) 
        targets = targets.to(device)
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        
        if task == 'multi-label, binary-class':
            targets = targets.to(torch.float32)
            loss = criterion(outputs, targets)
        else:
            targets = targets.squeeze().long()
            loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    test('test')

100%|██████████| 1293/1293 [01:34<00:00, 13.73it/s]


test  auc: 0.885  acc:0.581


100%|██████████| 1293/1293 [01:33<00:00, 13.85it/s]


test  auc: 0.895  acc:0.612


100%|██████████| 1293/1293 [01:35<00:00, 13.47it/s]


test  auc: 0.909  acc:0.631


100%|██████████| 1293/1293 [01:37<00:00, 13.27it/s]


test  auc: 0.915  acc:0.641


100%|██████████| 1293/1293 [01:28<00:00, 14.69it/s]


test  auc: 0.917  acc:0.650


100%|██████████| 1293/1293 [01:34<00:00, 13.63it/s]


test  auc: 0.917  acc:0.642


100%|██████████| 1293/1293 [01:35<00:00, 13.51it/s]


test  auc: 0.924  acc:0.658


100%|██████████| 1293/1293 [01:34<00:00, 13.67it/s]


test  auc: 0.926  acc:0.668


100%|██████████| 1293/1293 [01:35<00:00, 13.50it/s]


test  auc: 0.927  acc:0.668


100%|██████████| 1293/1293 [01:34<00:00, 13.63it/s]


test  auc: 0.927  acc:0.670


100%|██████████| 1293/1293 [01:37<00:00, 13.29it/s]


test  auc: 0.926  acc:0.665


100%|██████████| 1293/1293 [01:34<00:00, 13.72it/s]


test  auc: 0.928  acc:0.676


100%|██████████| 1293/1293 [01:36<00:00, 13.42it/s]


test  auc: 0.929  acc:0.667


100%|██████████| 1293/1293 [01:35<00:00, 13.52it/s]


test  auc: 0.932  acc:0.684


100%|██████████| 1293/1293 [01:35<00:00, 13.54it/s]


test  auc: 0.929  acc:0.674


100%|██████████| 1293/1293 [01:35<00:00, 13.54it/s]


test  auc: 0.930  acc:0.675


100%|██████████| 1293/1293 [01:33<00:00, 13.76it/s]


test  auc: 0.930  acc:0.680


100%|██████████| 1293/1293 [01:36<00:00, 13.40it/s]


test  auc: 0.929  acc:0.677


100%|██████████| 1293/1293 [01:36<00:00, 13.42it/s]


test  auc: 0.921  acc:0.646


100%|██████████| 1293/1293 [01:34<00:00, 13.63it/s]


test  auc: 0.931  acc:0.679


100%|██████████| 1293/1293 [01:33<00:00, 13.77it/s]


test  auc: 0.929  acc:0.679


100%|██████████| 1293/1293 [01:35<00:00, 13.49it/s]


test  auc: 0.928  acc:0.669


100%|██████████| 1293/1293 [01:35<00:00, 13.47it/s]


test  auc: 0.930  acc:0.677


100%|██████████| 1293/1293 [01:36<00:00, 13.45it/s]


test  auc: 0.925  acc:0.670


100%|██████████| 1293/1293 [01:35<00:00, 13.60it/s]


test  auc: 0.926  acc:0.674


100%|██████████| 1293/1293 [01:35<00:00, 13.51it/s]


test  auc: 0.927  acc:0.677


100%|██████████| 1293/1293 [01:35<00:00, 13.51it/s]


test  auc: 0.925  acc:0.660


100%|██████████| 1293/1293 [01:36<00:00, 13.43it/s]


test  auc: 0.925  acc:0.665


100%|██████████| 1293/1293 [01:35<00:00, 13.48it/s]


test  auc: 0.920  acc:0.657


100%|██████████| 1293/1293 [01:36<00:00, 13.45it/s]


test  auc: 0.924  acc:0.662


In [7]:
print('==> Evaluating ...')
test('train')
test('test')

==> Evaluating ...
train  auc: 0.973  acc:0.789
test  auc: 0.924  acc:0.662
