In [1]:
import torch
import os
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

import torch.utils.data as data

import torchvision.models as models
import torch.nn as nn

from PIL import Image

import glob
from random import shuffle

制作数据集

In [2]:
# 数据分类
files = glob.glob('../data/patches/train/*/*.tiff')
shuffle(files)

train_cancer = 0
train_normal = 0

train_imgs = []
val_imgs = []
test_imgs = []

# 先遍历一遍得到总数
cancer_sum = 0
normal_sum = 0
for file in files:
    if file.split('/')[-2].find('cancer') > -1:
        cancer_sum += 1
    else:
        normal_sum += 1
# 再遍历一遍分类
p80 = len(files) * 4 // 5

for file in files:
    # _type = 0 if file.split('/')[-2].find('cancer') > -1 else 1
    if 'cancer' in file:
    # if file.split('/')[-2].find('cancer') > -1:
        # 是cancer
        if train_cancer < cancer_sum * 4 // 5:
            train_imgs.append( ( file, 0 ) )
            train_cancer += 1
        else:
            val_imgs.append( ( file, 0 ) )
    else:
        # 是 normal
        if train_normal < normal_sum * 4 // 5:
            train_imgs.append( ( file, 1 ) )
            train_normal += 1
        else:
            val_imgs.append( ( file, 1 ) )

print('cancer_sum: %5d, normal_sum: %5d, train_cancer: %5d, train_normal: %5d, ' %
      (cancer_sum, normal_sum, train_cancer, train_normal))

test_files = glob.glob('../data/patches/test/*/*.tiff')
for file in test_files:
    if 'cancer' in file:
    # if file.split('/')[-2].find('cancer') > -1:
        # 是cancer
        test_imgs.append( ( file, 0 ) )
    else:
        # 是 normal
        test_imgs.append( ( file, 1 ) )
print('test_imgs_sum :', len(test_imgs))

cancer_sum: 56123, normal_sum: 57033, train_cancer: 44898, train_normal: 45626, 
test_imgs_sum : 4806


In [3]:
# 数据集的加载
def default_loader(path):
    return Image.open(path).convert('RGB')


class MyDataset(data.Dataset):
    def __init__(self, train=False, val=False, test=False, transform=None, target_transform=None, loader=default_loader):
        if train:
            self.imgs = train_imgs
        elif val:
            self.imgs = val_imgs
        elif test:
            self.imgs = test_imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)

In [4]:
# default_loader(file)["size"]
# default_loader(file).size
# type(default_loader(file))

数据集的预处理

In [4]:

transform = transforms.Compose(
    [
        # transforms.RandomSizedCrop(224),
        # transforms.RandomCrop(32, padding=2),
        transforms.RandomHorizontalFlip(),
#         transforms.RandomVerticalFlip(),
        # transforms.Scale(244),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

trainset = MyDataset(train=True, transform=transform)
valset = MyDataset(val=True, transform=transform)
testset = MyDataset(test=True, transform=transform)

trainloader = data.DataLoader(trainset, batch_size=180,
                                          shuffle=True, num_workers=4)
valloader = data.DataLoader(valset, batch_size=50,
                                          shuffle=False, num_workers=4)
testloader = data.DataLoader(testset, batch_size=50,
                                         shuffle=False, num_workers=4)

classes = ('cancer', 'normal')

定义net或加载之前的net

In [5]:
use_cuda = torch.cuda.is_available()

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# 选网络
net = models.resnet18(num_classes=2)
# print('net', net)

加载之前的

In [9]:
checkpoint = torch.load('./checkpoint/ckpt.t7')
net = checkpoint['net']
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

In [10]:
if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

# 加载之前的参数
# net.load_state_dict( torch.load('../net_state/resnet18_patches_epoch104_params.pkl') )

In [7]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

import torch.optim.lr_scheduler as lr_scheduler
# scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer)

定义训练和验证

In [9]:
import time

def train(epoch):
    print('\nEpoch: %d' % epoch)
    t1 = time.time()
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # get the inputs
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # zero the parameter gradients
        optimizer.zero_grad()
        # wrap them in Variable
        inputs, targets = Variable(inputs), Variable(targets)
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

        if batch_idx % 200 == 199:    # print every 200 mini-batches
            print('m-b %4d loss: %.3f | Acc: %.3f%% | lr: %.4f | time: %.2f' %
                ( batch_idx+1, train_loss/batch_idx+1, 100.*correct/total, optimizer.param_groups[0]['lr'], time.time() - t1 ) )
            
def validation(epoch):
    global best_acc
    net.eval()
    t1 = time.time()
    val_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(valloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        val_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    print('val loss: %.3f | Acc: %.3f%% | lr: %.4f | time: %.2f' %
        ( val_loss/batch_idx+1, 100.*correct/total, optimizer.param_groups[0]['lr'], time.time() - t1 ) )

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.module if use_cuda else net,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.t7')
        best_acc = acc

    return val_loss / len(valloader)


In [10]:
%%time
# start_epoch = 0
for epoch in range(start_epoch, start_epoch + 80):
    train(epoch)
    val_loss = validation(epoch)
    scheduler.step(val_loss)


Epoch: 25
m-b  200 loss: 1.002 | Acc: 99.908% | lr: 0.0100 | time: 139.03
m-b  400 loss: 1.002 | Acc: 99.914% | lr: 0.0100 | time: 273.62
val loss: 1.001 | Acc: 99.947% | lr: 0.0100 | time: 38.26
Saving..

Epoch: 26
m-b  200 loss: 1.002 | Acc: 99.922% | lr: 0.0100 | time: 135.29
val loss: 1.002 | Acc: 99.943% | lr: 0.0100 | time: 37.73

Epoch: 27
m-b  200 loss: 1.002 | Acc: 99.947% | lr: 0.0100 | time: 135.35
m-b  400 loss: 1.002 | Acc: 99.949% | lr: 0.0100 | time: 270.24
val loss: 1.003 | Acc: 99.916% | lr: 0.0100 | time: 37.71

Epoch: 28
m-b  200 loss: 1.002 | Acc: 99.933% | lr: 0.0100 | time: 135.41
m-b  400 loss: 1.002 | Acc: 99.936% | lr: 0.0100 | time: 270.46
val loss: 1.002 | Acc: 99.943% | lr: 0.0100 | time: 37.71

Epoch: 29
m-b  200 loss: 1.002 | Acc: 99.958% | lr: 0.0100 | time: 135.81
m-b  400 loss: 1.002 | Acc: 99.954% | lr: 0.0100 | time: 271.27
val loss: 1.002 | Acc: 99.947% | lr: 0.0100 | time: 37.81

Epoch: 30
m-b  200 loss: 1.001 | Acc: 99.958% | lr: 0.0100 | time: 13

val loss: 1.002 | Acc: 99.925% | lr: 0.0001 | time: 37.60

Epoch: 68
m-b  200 loss: 1.001 | Acc: 99.992% | lr: 0.0001 | time: 135.94
m-b  400 loss: 1.001 | Acc: 99.982% | lr: 0.0001 | time: 271.45
val loss: 1.002 | Acc: 99.934% | lr: 0.0001 | time: 37.73

Epoch: 69
m-b  400 loss: 1.001 | Acc: 99.972% | lr: 0.0001 | time: 271.32
val loss: 1.001 | Acc: 99.943% | lr: 0.0001 | time: 37.80

Epoch: 70
m-b  200 loss: 1.001 | Acc: 99.964% | lr: 0.0001 | time: 135.42
m-b  400 loss: 1.001 | Acc: 99.969% | lr: 0.0001 | time: 270.41
val loss: 1.001 | Acc: 99.947% | lr: 0.0001 | time: 37.76

Epoch: 71
m-b  200 loss: 1.001 | Acc: 99.969% | lr: 0.0001 | time: 135.52
m-b  400 loss: 1.001 | Acc: 99.967% | lr: 0.0001 | time: 270.53
val loss: 1.002 | Acc: 99.938% | lr: 0.0001 | time: 37.68

Epoch: 72
m-b  200 loss: 1.001 | Acc: 99.972% | lr: 0.0001 | time: 135.84
m-b  400 loss: 1.001 | Acc: 99.974% | lr: 0.0001 | time: 271.25
val loss: 1.001 | Acc: 99.956% | lr: 0.0001 | time: 37.65

Epoch: 73
m-b  200 l

test

In [11]:
%%time
def test(epoch):
    print(epoch)
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        test_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    print('test loss: %.3f | Acc: %.3f%% | correct: %5d | total: %5d' %
        ( test_loss/batch_idx+1, 100.*correct/total, correct, total ) )
test(start_epoch)

31
test loss: 1.302 | Acc: 92.717% | correct:  4456 | total:  4806
CPU times: user 6.14 s, sys: 1.4 s, total: 7.54 s
Wall time: 7.6 s


In [12]:
torch.save(net.state_dict(), '../net_state/resnet18_patches_epoch104_params.pkl')