In [1]:
import os, cv2, time, copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils import data
from torchvision import models
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim import lr_scheduler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## define global variables
VERSION = 1.0                 # dataset version; formatted into data/checkpoint paths as "v1.0"
INPUT_SIZE = 512              # tile side length; data lives under "<INPUT_SIZE>x<INPUT_SIZE>_v<VERSION>"
NUM_CLASSES = 10
BATCH_SIZE = 7
MODEL_SAVE_NAME = 'deeplab_resnet50_model_50'
SEED = 42                     # change for a different, but still reproducible, run

# Seed NumPy and PyTorch so DataLoader shuffling and freshly initialised layers are reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

print("Using GPU: ", torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Using GPU:  True


In [3]:
class CustomDataset(data.Dataset):
    """Image / label-mask pairs loaded from disk with OpenCV.

    ``dataset[i]`` returns ``(x, y)``:

    * ``x`` -- the image at ``inputs[i]`` as a float32 tensor, channels first
      ``(C, H, W)``, divided by 255 (so 8-bit images land in [0, 1]).
    * ``y`` -- the label image at ``targets[i]`` read unchanged, as an int64
      tensor ``(H, W)`` of class indices.

    NOTE(review): ``cv2.imread`` returns colour images in BGR channel order and no
    mean/std normalisation is applied here; confirm this matches what the model expects.

    Args:
        inputs: image file paths.
        targets: label file paths, paired with ``inputs`` by position.
        transform: optional callable ``(x, y) -> (x, y)`` applied to the numpy arrays
            (``x`` already channels-first) before conversion to tensors.

    Raises:
        FileNotFoundError: from ``__getitem__`` when OpenCV cannot read a file.
    """

    def __init__(self, inputs: list, targets: list, transform=None):
        self.inputs = inputs
        self.targets = targets
        self.transform = transform
        self.inputs_dtype = torch.float32
        self.targets_dtype = torch.long

    def __len__(self):
        return len(self.inputs)

    @staticmethod
    def _read(path, *flags):
        """``cv2.imread`` that raises instead of returning None for a missing/unreadable file."""
        img = cv2.imread(path, *flags)
        if img is None:
            raise FileNotFoundError(f'cv2.imread could not read {path!r}')
        return img

    def __getitem__(self, index: int):
        input_path = self.inputs[index]
        target_path = self.targets[index]

        # Load input and target
        x = self._read(input_path) / 255
        y = self._read(target_path, cv2.IMREAD_UNCHANGED)  # label mask (e.g. .tif), shape (H, W)

        x = np.moveaxis(x, -1, 0)  # (H, W, C) -> (C, H, W)

        # Preprocessing should already be done in a previous step; load pre-processed data.
        if self.transform is not None:
            x, y = self.transform(x, y)

        # Labels go through numpy int64 first: torch.from_numpy only accepts a fixed set of
        # dtypes, and wider unsigned masks (e.g. uint16) are rejected by many torch versions.
        x = torch.from_numpy(x).type(self.inputs_dtype)
        y = torch.from_numpy(np.asarray(y, dtype=np.int64)).type(self.targets_dtype)

        return x, y

In [4]:
# Root of the tiled dataset for this INPUT_SIZE / VERSION; each split has images\ and labels\.
DATASET_DIR = r'\\babyserverdw5\Digital pathology image lib\JHU\Ie-Ming Shih\lymphocytes\230110 dataset\{}x{}_v{}'.format(INPUT_SIZE, INPUT_SIZE, VERSION)


def list_data_files(directory):
    """Return the full paths of the files in `directory`, sorted, skipping ``*.db`` entries.

    ``os.listdir`` returns names in arbitrary order, so sorting is what keeps an image
    list and its label list aligned position-by-position (assumes each image and its
    label sort to the same position in their folders -- TODO confirm naming scheme).
    ``*.db`` files (presumably Windows ``Thumbs.db`` caches) are not data.
    """
    return sorted(
        os.path.join(directory, f) for f in os.listdir(directory) if not f.endswith('.db')
    )


train_im_dir = DATASET_DIR + r'\train\images'
train_ann_dir = DATASET_DIR + r'\train\labels'
train_im_list = list_data_files(train_im_dir)
train_ann_list = list_data_files(train_ann_dir)

val_im_dir = DATASET_DIR + r'\val\images'
val_ann_dir = DATASET_DIR + r'\val\labels'
val_im_list = list_data_files(val_im_dir)
val_ann_list = list_data_files(val_ann_dir)

# CustomDataset pairs images and labels by position, so the counts must match.
assert len(train_im_list) == len(train_ann_list), (len(train_im_list), len(train_ann_list))
assert len(val_im_list) == len(val_ann_list), (len(val_im_list), len(val_ann_list))

train_dataset = CustomDataset(inputs=train_im_list, targets=train_ann_list, transform=None)
train_dataloader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = CustomDataset(inputs=val_im_list, targets=val_ann_list, transform=None)
val_dataloader = data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
def summarize_first_batch(title, loader, dataset):
    """Print shape / dtype / value-range / class summary of the first batch of `loader`.

    Returns that ``(x, y)`` batch so it can be inspected further.
    """
    first_x, first_y = next(iter(loader))
    print(title)
    print(f'x = shape: {first_x.shape}; type: {first_x.dtype}')
    print(f'x = min: {first_x.min()}; max: {first_x.max()}')
    print(f'y = shape: {first_y.shape}; class: {first_y.unique()}; type: {first_y.dtype}')
    print(f'num samples: {len(dataset)}')
    return first_x, first_y


x, y = summarize_first_batch('## train samples ##', train_dataloader, train_dataset)
x, y = summarize_first_batch('## val samples ##', val_dataloader, val_dataset)

## train samples ##
x = shape: torch.Size([7, 3, 512, 512]); type: torch.float32
x = min: 0.027450980618596077; max: 1.0
y = shape: torch.Size([7, 512, 512]); class: tensor([0, 4, 5, 9]); type: torch.int64
num samples: 3600
## val samples ##
x = shape: torch.Size([7, 3, 512, 512]); type: torch.float32
x = min: 0.0117647061124444; max: 1.0
y = shape: torch.Size([7, 512, 512]); class: tensor([0, 1, 2, 4, 5, 8, 9]); type: torch.int64
num samples: 1200


In [6]:
# DeepLabV3-ResNet50 initialised from torchvision's COCO-with-VOC-labels weights. The final
# 1x1 conv of the main head (classifier[4]) and of the auxiliary head (aux_classifier[4])
# is replaced so both emit NUM_CLASSES channels; the rest of the pretrained network is kept.
# NOTE(review): training/evaluation below read only model(...)['out'], so the auxiliary
# head is computed but never receives a gradient.
FREEZE_BACKBONE = False  # True -> backbone weights stay fixed; only the heads are trained

model = torchvision.models.segmentation.deeplabv3_resnet50(weights='COCO_WITH_VOC_LABELS_V1')
for head in (model.classifier, model.aux_classifier):
    head[4] = torch.nn.Conv2d(head[4].in_channels, NUM_CLASSES, kernel_size=(1, 1), stride=(1, 1))
if FREEZE_BACKBONE:
    for p in model.backbone.parameters():
        p.requires_grad = False
model = model.to(device)

In [None]:
# get class weighting of entire dataset
def get_total_class_weights(dataloader=None, num_classes=None):
    """Count how many label pixels belong to each class over a whole dataloader.

    Despite the name, this returns raw per-class counts, not weights.

    Args:
        dataloader: iterable (with ``len``) of ``(inputs, labels)`` batches; defaults to
            the global ``train_dataloader``.
        num_classes: length of the returned vector; defaults to the global ``NUM_CLASSES``.

    Returns:
        float64 ``np.ndarray`` of shape ``(num_classes,)``; entry ``c`` is the number of
        label elements equal to ``c``.

    Raises:
        ValueError: if a label is negative or ``>= num_classes`` (a negative label would
            otherwise be counted silently into the last class via numpy wrap-around).
    """
    if dataloader is None:
        dataloader = train_dataloader
    if num_classes is None:
        num_classes = NUM_CLASSES

    class_counts = np.zeros(num_classes)
    n_batches = len(dataloader)
    for ii, (_, labels) in enumerate(dataloader):
        flat = np.asarray(labels).ravel()
        if flat.size:
            lo, hi = flat.min(), flat.max()
            if lo < 0 or hi >= num_classes:
                raise ValueError(
                    f'label values must be in [0, {num_classes}); got range [{lo}, {hi}]')
            # bincount is a single O(n) pass, unlike sorting each batch with np.unique
            class_counts += np.bincount(flat, minlength=num_classes)
        print('{}/{}'.format(ii + 1, n_batches), end='\r')
    return class_counts

# Tally label pixels per class with one full pass over the training loader.
class_counts = get_total_class_weights()
print('Class counts:', class_counts, sep=' ')

440/515

In [None]:
# Inverse-frequency ("balanced") class weights over every class except the last:
#     w_c = N / (K * n_c),  N = total count over those K classes, n_c = count of class c.
# The last class gets weight 0, so its pixels contribute nothing to the weighted loss.
# Weights are capped at MAX_CLASS_WEIGHT; an absent class (n_c == 0 -> inf) is capped too.
MAX_CLASS_WEIGHT = 2000

class_counts_part = class_counts[:-1]
assert class_counts_part.sum() > 0, 'no labelled pixels in the weighted classes'

with np.errstate(divide='ignore'):  # n_c == 0 yields inf; the clip below caps it
    CLASS_WEIGHTS = class_counts_part.sum() / (class_counts_part.shape[0] * class_counts_part)
CLASS_WEIGHTS = np.append(CLASS_WEIGHTS, 0.0)
CLASS_WEIGHTS = np.clip(CLASS_WEIGHTS, 0, MAX_CLASS_WEIGHT)

for ii, (c, w) in enumerate(zip(class_counts, CLASS_WEIGHTS)):
    print(f'cls {ii} | cnt {str(int(c)):12s} | weight {str(w):12s}')

In [None]:
# Raw CLASS_WEIGHTS vector, as later converted to a float tensor for CrossEntropyLoss(weight=...).
print(CLASS_WEIGHTS)
def get_acc(model, dataloader, num_classes=None, eval_device=None):
    """Pixel accuracy of `model` over `dataloader`, restricted to the scored classes.

    Only pixels whose ground-truth label lies in ``1 .. num_classes - 2`` are scored:
    class 0 and the last class (``num_classes - 1``) are excluded. Gradients are
    disabled here, but putting the model in eval mode is the caller's job.

    Args:
        model: callable returning a dict whose ``'out'`` entry holds logits with the
            class axis at dim 1 (e.g. ``(N, C, H, W)``).
        dataloader: iterable of ``(inputs, labels)`` batches; labels e.g. ``(N, H, W)``.
        num_classes: defaults to the global ``NUM_CLASSES``.
        eval_device: device batches are moved to; defaults to the global ``device``.

    Returns:
        Accuracy as a float in [0, 1], or ``nan`` when no scored pixel was seen
        (instead of raising ZeroDivisionError).
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    if eval_device is None:
        eval_device = device

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(eval_device)
            labels = labels.to(eval_device)
            pred = torch.argmax(model(inputs)['out'], 1)

            scored = (labels > 0) & (labels < num_classes - 1)
            n_scored = int(scored.sum().item())
            if n_scored > 0:
                correct += int((pred[scored] == labels[scored]).sum().item())
                total += n_scored

    return correct / total if total > 0 else float('nan')
        

In [None]:
## training loop
def _train_one_epoch(model, train_loader, criterion, optimizer, epoch):
    """Run one optimisation pass over `train_loader`; return the mean of the per-batch losses."""
    running_avg_loss = 0
    for ii, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)['out']
        loss = criterion(outputs, labels)

        # zero the parameter gradients, backpropagate, update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # incremental mean of the batch losses seen so far this epoch
        running_avg_loss = running_avg_loss * (ii / (ii + 1)) + loss.item() * (1 / (ii + 1))

        if ii % 25 == 0:
            print(f'E{epoch}B{ii} | loss: {loss.item():.3f} ({running_avg_loss:.3f})')
    return running_avg_loss


def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs, num_classes):
    """Train `model`, checkpointing the best-validation-accuracy weights and the final weights.

    Writes ``<MODEL_SAVE_NAME>.tar`` whenever validation accuracy improves and
    ``<MODEL_SAVE_NAME>_last.tar`` after the final epoch, both as ``state_dict`` files.

    Args:
        model: segmentation model whose forward returns a dict with an ``'out'`` logits entry.
        dataloaders: ``(train_loader, val_loader)`` pair.
        criterion: loss applied to ``(logits, labels)``.
        optimizer: optimizer over the model's trainable parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over the training loader.
        num_classes: unused; kept for call compatibility.

    Returns:
        The same ``model`` object in its final-epoch state (left in train mode).
    """
    model_save_dir = r'\\babyserverdw3\PW Cloud Exp Documents\Lab work documenting\W-22-09-10 AT Build Competent multi task DL model for tissue labeling/saved_models/230110/{}x{}_v{}'.format(INPUT_SIZE, INPUT_SIZE, VERSION)
    # makedirs also creates missing parent folders and tolerates an existing directory
    os.makedirs(model_save_dir, exist_ok=True)
    path_to_saved_model = os.path.join(model_save_dir, '{}.tar'.format(MODEL_SAVE_NAME))
    last_path_to_saved_model = os.path.join(model_save_dir, '{}_last.tar'.format(MODEL_SAVE_NAME))
    print('saving model to: ', path_to_saved_model)

    model.train()
    best_acc = float('-inf')  # so the first (finite) validation accuracy always writes a checkpoint
    start = time.time()

    train_loader, val_loader = dataloaders

    for epoch in range(num_epochs):
        running_avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, epoch)

        model.eval()
        val_acc = get_acc(model, val_loader)
        model.train()

        # linear extrapolation of elapsed time to the remaining epochs
        done = (epoch + 1) / num_epochs
        eta = (time.time() - start) / done * (1 - done) / 60

        print(f'Epoch {epoch}/{num_epochs} | loss: {running_avg_loss:.3f} | acc: {val_acc:.3f} | ETA: {eta:.2f} min')
        if val_acc > best_acc:
            print("===== Best validation performance - saving best model =====")
            best_acc = val_acc
            torch.save(model.state_dict(), path_to_saved_model)

        print()
        scheduler.step()

    # just save the last model too
    torch.save(model.state_dict(), last_path_to_saved_model)
    return model

In [None]:
# Optimisation hyper-parameters
LEARNING_RATE = 1e-3
LR_STEP_EPOCHS = 30   # StepLR: multiply the learning rate by LR_GAMMA every LR_STEP_EPOCHS epochs
LR_GAMMA = 0.1
NUM_EPOCHS = 50

dataloaders = [train_dataloader, val_dataloader]
# Class-weighted cross-entropy; .to(device) moves the weight tensor next to the model.
# NOTE(review): with mean reduction the loss is divided by the summed target weights, so a
# batch containing only zero-weight (last-class) pixels would give NaN -- confirm this cannot occur.
criterion = torch.nn.CrossEntropyLoss(weight=torch.from_numpy(CLASS_WEIGHTS).float()).to(device)
# Only parameters left trainable (requires_grad=True) are optimised.
optimizer_ft = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=LR_STEP_EPOCHS, gamma=LR_GAMMA)

# train_model updates `model` in place; only rebind the name if it actually returns a model,
# so `model` never becomes None for the evaluation cells below.
trained_model = train_model(model, dataloaders, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=NUM_EPOCHS, num_classes=NUM_CLASSES)
if trained_model is not None:
    model = trained_model

In [None]:
# Reload the best-validation checkpoint (same directory and file name train_model saves to).
checkpoint_path = os.path.join(r'\\babyserverdw3\PW Cloud Exp Documents\Lab work documenting\W-22-09-10 AT Build Competent multi task DL model for tissue labeling/saved_models/230110/{}x{}_v{}'.format(INPUT_SIZE, INPUT_SIZE, VERSION), '{}.tar'.format(MODEL_SAVE_NAME))
# map_location: a checkpoint saved from GPU tensors then also loads on a CPU-only machine.
# torch.load unpickles the file -- only load checkpoints from a trusted location.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print('===== Trained Model Summary =====')
print('Train Accuracy:      {:.3f}%'.format(100*get_acc(model, train_dataloader)))
print('Validation Accuracy: {:.3f}%'.format(100*get_acc(model, val_dataloader)))

In [None]:
# evaluate

import math

model.eval()   # Set model to the evaluation mode

# Get the first batch
inputs, labels = next(iter(train_dataloader))
inputs = inputs.to(device)
labels = labels.to(device)
print('inputs.shape', inputs.shape)
print('labels.shape', labels.shape)

# Predict. no_grad: this forward pass is display-only, so skip building the autograd graph.
with torch.no_grad():
    logits = model(inputs)['out']
# 'out' holds unnormalised per-class scores; argmax over the class axis (dim 1) is the label map.
pred = logits.cpu().numpy()
print('pred.shape', pred.shape)
pred = np.argmax(pred, 1)
print('pred.shape', pred.shape)
true = labels.cpu().numpy()

# Shared vmin/vmax so one colour means one class in both panels
# (imshow otherwise rescales each image to its own min/max).
fig, (ax_true, ax_pred) = plt.subplots(1, 2)
ax_true.set_title('true')
ax_true.imshow(true[0], vmin=0, vmax=NUM_CLASSES - 1, interpolation='nearest')
ax_pred.set_title('pred')
ax_pred.imshow(pred[0], vmin=0, vmax=NUM_CLASSES - 1, interpolation='nearest')
print(np.unique(pred[0]))