In [1]:
%matplotlib inline

In [2]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms,utils
import matplotlib.pyplot as plt
import time
import os
import copy
import pandas as pd
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from efficientnet_pytorch import EfficientNet
import PIL
import torch.multiprocessing
plt.ion()   # interactive mode

In [3]:
model_name = 'efficientnet-b0'
#writer = SummaryWriter("/p300/Result_New_Tboard_Tumor_Normal/")
# torch.multiprocessing.set_sharing_strategy('file_system')
#writer_path = '/p300/Tboard_try_paper'
writer_path = '/Users/weizhenliu/Downloads/Tboard_newest'
# data_dir = '/p300/hymenoptera_data'
data_dir = '/Users/weizhenliu/Downloads/AntsAndBees'
BatchSize = 8
NumWorkers = 2
# DownsamplePCT = 0.05
learning_rate = 0.001
# image_size = EfficientNet.get_image_size(model_name)
writer = SummaryWriter(writer_path)
WeightDecay=0
StepSize = 300
Gamma = 0.1
Momentum = 0.9
Epoch = 50

In [4]:
class ImageFolderWithPaths(datasets.ImageFolder):
    """Custom dataset that includes image file paths. Extends
    torchvision.datasets.ImageFolder
    """

    # override the __getitem__ method. this is the method that dataloader calls
    def __getitem__(self, index):
        # this is what ImageFolder normally returns 
        original_tuple = super(ImageFolderWithPaths, self).__getitem__(index)
        # the image file path
        path = self.imgs[index][0]
        # make a new tuple that includes original and the path
        tuple_with_path = (original_tuple + (path,))
        return tuple_with_path

In [5]:
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# data_dir = '/Users/weizhenliu/Downloads/AntsAndBees'
image_datasets = {x: ImageFolderWithPaths(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
Train_WSI_name_label = {}
Val_WSI_name_label = {}

class_label = image_datasets['train'].class_to_idx
group = [d for d in os.listdir(data_dir) if not d[0] == '.']
classes = [d for d in os.listdir(os.path.join(data_dir,group[0])) if not d[0] == '.']

for c in classes:
    for f in os.listdir(os.path.join(data_dir,"train",c)):
        if str(f)[-3:] == "jpg" and f.split("_")[0] not in Train_WSI_name_label:
            Train_WSI_name_label[f.split("_")[0]] = class_label[c]
            
TrainWSIlabel = [j for i,j in Train_WSI_name_label.items()]

for c in classes:
    for f in os.listdir(os.path.join(data_dir,"val",c)):
        if str(f)[-3:] == "jpg" and f.split("_")[0] not in Val_WSI_name_label:
            Val_WSI_name_label[f.split("_")[0]] = class_label[c]

ValWSIlabel = [j for i,j in Val_WSI_name_label.items()]

# print(Train_WSI_name_label,TrainWSIlabel)
# print("---------------------------------")
# print(Val_WSI_name_label,ValWSIlabel)
WSI_label = {"train":TrainWSIlabel,"val":ValWSIlabel}
WSI_name_label = {"train":Train_WSI_name_label,"val":Val_WSI_name_label}
# print(WSI_name_label["train"])
# print("----------")
# print(WSI_name_label["val"])


In [7]:
# WSI_preds = [[0 for i in range(2)] for i in range(len(WSI_label[phase]))]

In [8]:
print(WSI_label["val"])

[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]


In [9]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            WSI_preds = [[0 for i in range(2)] for i in range(len(WSI_label[phase]))]
            # Iterate over data.
            for inputs, labels, paths in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                WSI_name = [os.path.basename(path).split("_")[0] for path in paths]
#                 print("WSI_name:",WSI_name)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    pred = [p.item() for p in preds]
#                     print('pred:',pred)
                    for i,j in zip(WSI_name,pred):
                        WSI_preds[list(WSI_name_label[phase].keys()).index(i)][j] += 1
#                         print("running WSI_preds:",WSI_preds,"\t","---------------------")
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
#                 print("preds:",preds,"type:",type(preds))
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
#             print(phase,"OneEoch:",WSI_preds)
            Final_WSI_preds = [n.index(max(n)) for n in WSI_preds]
#             print(phase,"Final_WSI_preds:",torch.tensor(Final_WSI_preds))
#             print(phase,"torch.tensor(WSI_label[phase]).double():",torch.tensor(WSI_label[phase]).double())
#             print("torch.tensor(Final_WSI_preds) == torch.tensor(WSI_label[phase]):",torch.tensor(Final_WSI_preds) == torch.tensor(WSI_label[phase]))
#             print("sum:",torch.sum(torch.tensor(Final_WSI_preds) == torch.tensor(WSI_label[phase]).double()))
#             print("len(Final_WSI_preds):",len(Final_WSI_preds))
            WSI_Acc = torch.sum(torch.tensor(Final_WSI_preds) == torch.tensor(WSI_label[phase])).double()/len(Final_WSI_preds)
#             print("epoch_acc:",epoch_acc,type(epoch_acc))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            print('{} WSI Acc: {:.4f}'.format(phase, WSI_Acc))
            writer.add_scalar(phase+'/Loss', epoch_loss, epoch)
            writer.add_scalar(phase+'/Accuracy', epoch_acc, epoch)
            writer.add_scalar(phase+'/WSI Accuracy', WSI_Acc, epoch)
            writer.flush()

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [10]:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [11]:
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Epoch 0/24
----------
train Loss: 0.5860 Acc: 0.6653
train WSI Acc: 0.8077
val Loss: 0.4316 Acc: 0.8235
val WSI Acc: 0.9091

Epoch 1/24
----------
train Loss: 0.4308 Acc: 0.8367
train WSI Acc: 1.0000
val Loss: 0.2636 Acc: 0.9281
val WSI Acc: 1.0000

Epoch 2/24
----------
train Loss: 0.2722 Acc: 0.9469
train WSI Acc: 1.0000
val Loss: 0.2105 Acc: 0.9412
val WSI Acc: 1.0000

Epoch 3/24
----------
train Loss: 0.2635 Acc: 0.9020
train WSI Acc: 1.0000
val Loss: 0.1904 Acc: 0.9346
val WSI Acc: 1.0000

Epoch 4/24
----------
train Loss: 0.1961 Acc: 0.9429
train WSI Acc: 1.0000
val Loss: 0.1761 Acc: 0.9412
val WSI Acc: 1.0000

Epoch 5/24
----------
train Loss: 0.1656 Acc: 0.9551
train WSI Acc: 1.0000
val Loss: 0.1690 Acc: 0.9412
val WSI Acc: 1.0000

Epoch 6/24
----------
train Loss: 0.1618 Acc: 0.9429
train WSI Acc: 1.0000
val Loss: 0.1637 Acc: 0.9477
val WSI Acc: 1.0000

Epoch 7/24
----------
train Loss: 0.1293 Acc: 0.9633
train WSI Acc: 1.0000
val Loss: 0.1629 Acc: 0.9477
val WSI Acc: 1.0000



Train and evaluate
^^^^^^^^^^^^^^^^^^

It should take around 15-25 min on CPU. On GPU though, it takes less than a
minute.


