# Macrophage Classifier Training Using Fluorescent and Brightfield Channels

We train a predefined "MacNet" CNN to distinguish alveolar (tissue-resident) macrophages from bone marrow-derived macrophages (used as a proxy for monocyte-derived macrophages).

The input consists of 1–4 channels drawn from: brightfield, lipid stain (BODIPY), nuclear stain (Hoechst), mitochondrial stain (MitoTracker Red), or cell autofluorescence in the green, red, or blue channels.

The output is a binary classification of whether the cell is a bone marrow macrophage or an alveolar macrophage.

## Constants/Variables

In [1]:
# Root directory of the processed dataset (labels.csv and the images live here).
PATH = r"D:\data\processed\autof_2"

# Cross-validation and training schedule.
NUM_FOLDS, NUM_EPOCHS = 5, 10
# NOTE: despite its name, this value is passed to DataLoader as ``batch_size``,
# i.e. it is the number of samples per batch, not a number of batches.
NUM_BATCHES = 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
%matplotlib inline

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch import nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torchvision import models
#from torchvision.transforms import functional

from torch.utils.tensorboard import SummaryWriter

from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz

import random
from macdataset import MacDataset
import macnet
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statistics

## Equal Class Sampler
This function reads the classes from the dataframe's `label` column and returns a `WeightedRandomSampler` that weights each row by the inverse of its class size, so that each class is sampled equally often in aggregate.

In [3]:
def equal_classes_sampler(df):
    """Return a sampler that draws every label class equally often on average.

    Each row is weighted by ``1 / (number of rows with the same label)``, so
    every class has the same total weight. The returned
    ``WeightedRandomSampler`` draws ``len(df)`` row indices (with replacement,
    the sampler's default) according to these weights.

    Args:
        df: DataFrame with a ``label`` column. Labels may be any hashable
            values; they need not be contiguous integers starting at 0.

    Returns:
        ``torch.utils.data.WeightedRandomSampler`` over the rows of ``df``.
    """
    labels = df["label"]
    # Look each row's class size up by label *value*. ``value_counts()`` is
    # ordered by frequency, not by label, so indexing its array positionally
    # with the label would assign the wrong weight to each class.
    class_sizes = labels.map(labels.value_counts())
    samples_weight = torch.from_numpy(1.0 / class_sizes.to_numpy(dtype=np.float64))
    return WeightedRandomSampler(samples_weight, len(samples_weight))

In [4]:
def visualize_samples(dataloader, num_samples):
    """Show the first channel of up to ``num_samples`` different images.

    Images are taken in order from consecutive batches of ``dataloader``;
    each batch is expected to be a dict with an ``"image"`` entry whose first
    dimension indexes the samples in the batch. After each plot the function
    waits on ``input()``, so the user steps through the images with Enter.

    Args:
        dataloader: iterable of batch dicts (e.g. a ``DataLoader``).
        num_samples: maximum number of images to display.
    """
    if num_samples <= 0:
        return
    shown = 0
    # Iterate the loader directly: DataLoader iterators do not provide a
    # ``.next()`` method, and walking the batch shows distinct images
    # rather than the same first image repeatedly.
    for batch in dataloader:
        for image in batch["image"]:
            plt.imshow(image[0])  # first channel only
            plt.show()
            input()
            shown += 1
            if shown >= num_samples:
                return

## Transforms
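This section defines the per-sample transforms. Both the training and test pipelines standardize each image; the training pipeline additionally applies a random 90° rotation for augmentation. A random center-crop transform is also defined but is not used in either pipeline.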

In [5]:
class standardize_input(object):
    """Z-score an image sample and convert it to a torch tensor.

    The mean and standard deviation are computed over the whole image array
    (all channels together). An image whose standard deviation is exactly 0
    is only mean-centred (giving all zeros) instead of producing NaN/inf.

    Expects ``sample`` to be a dict with ``'image'`` (a NumPy array) and
    ``'label'``; returns a new dict holding the standardized image as a
    tensor and the unchanged label.
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        mean = np.mean(image)
        stdev = np.std(image)
        centered = image - mean
        # Guard the division: a featureless image would otherwise turn into
        # NaNs that propagate through the network and the loss.
        image = centered / stdev if stdev > 0 else centered
        return {'image': torch.from_numpy(image),
                'label': label}

    def __repr__(self):
        return self.__class__.__name__ + '()'
class rotate_90_input(object):
    """Randomly rotate an image sample by 0, 90, 180 or 270 degrees.

    The rotation is applied in the plane of dims 1 and 2 of ``sample['image']``
    (presumably (channels, height, width) — TODO confirm against MacDataset),
    so dim 0 is left untouched. The label is passed through unchanged.
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        num_rot = random.randint(0, 3)  # inclusive: 0..3 quarter turns
        image = torch.rot90(image, num_rot, [1, 2])
        return {'image': image,
                'label': label}

    def __repr__(self):
        # The transform has no parameters (and no ``size`` attribute), so the
        # repr must not reference one; ``transforms.Compose`` calls this when
        # the pipeline is printed.
        return self.__class__.__name__ + '()'
        
class center_crop(object):
    """Keep a random-size centred square of the image and zero the border.

    The crop side length is ``2 * randint(int(lo / 2), int(hi / 2))`` where
    ``(lo, hi) = size_range``, i.e. an even size roughly within the range.
    It is clamped to the image size. Everything outside the centred crop is
    set to 0, so the output always has exactly the input's shape.

    Args:
        size_range: ``(min_size, max_size)`` of the crop, in pixels.
    """

    def __init__(self, size_range):
        self.range = size_range

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        height, width = image.shape[-2], image.shape[-1]
        crop_size = random.randint(int(self.range[0] / 2), int(self.range[1] / 2)) * 2
        # A crop larger than the image simply keeps the whole image.
        crop_h, crop_w = min(crop_size, height), min(crop_size, width)
        top = (height - crop_h) // 2
        left = (width - crop_w) // 2
        # Same effect as centre-crop followed by zero padding back to the
        # original size, but the shape is preserved even when the margin is
        # odd, and no separate functional crop helper is needed.
        out = torch.zeros_like(image)
        out[..., top:top + crop_h, left:left + crop_w] = \
            image[..., top:top + crop_h, left:left + crop_w]
        return {'image': out,
                'label': label}

    def __repr__(self):
        return self.__class__.__name__ + '(size_range={0})'.format(self.range)

# Evaluation pipeline: per-image standardisation only.
test_transforms = transforms.Compose([standardize_input()])

# Training pipeline: the same standardisation plus a random quarter-turn
# rotation for augmentation.
train_transforms = transforms.Compose([standardize_input(), rotate_90_input()])

In [6]:
import os

# Build the label-file path with the platform's separator instead of a
# hard-coded backslash; on Windows this yields the same string as before.
csv_path = os.path.join(PATH, 'labels.csv')
raw_data = pd.read_csv(csv_path)

In [7]:
# Shuffle the rows once, then cut them into NUM_FOLDS nearly equal folds
# (the first len % NUM_FOLDS folds get one extra row, as with np.array_split).
# Splitting positional indices instead of passing the DataFrame itself to
# np.array_split avoids its reliance on the deprecated DataFrame.swapaxes.
shuffled_data = raw_data.sample(frac=1)
split_data = [shuffled_data.iloc[fold_positions]
              for fold_positions in np.array_split(np.arange(len(shuffled_data)), NUM_FOLDS)]

In [8]:
def train(dataloader, model, loss_fn, optimizer, batches):
    """Run one training epoch and return the epoch's training accuracy.

    Args:
        dataloader: yields dicts with ``"image"`` and ``"label"`` tensors.
        model: network whose output is treated as one probability per sample
            (it is fed to ``loss_fn`` and rounded to 0/1 for accuracy).
        loss_fn: loss taking ``(predictions, float_labels)``, e.g. ``nn.BCELoss()``.
        optimizer: optimizer over ``model``'s parameters.
        batches: kept for backward compatibility; the number of samples in a
            batch is now read from the batch itself, so a smaller final batch
            is counted correctly.

    Returns:
        Fraction of successfully processed training samples that were
        classified correctly (0.0 if no batch could be processed).

    Batches that raise an exception are skipped and counted as "bad"; the
    count (and the last error, if any) is printed at the end of the epoch.
    """
    size = len(dataloader.dataset)
    # Use the device the model lives on rather than a global ``device``.
    device = next(model.parameters()).device
    # Must happen before any forward pass: ``test`` puts the model in eval
    # mode, which changes dropout / batch-norm behaviour.
    model.train()

    total_done = 0
    correct = 0
    bad_batches = 0
    last_error = None

    for batch, data in enumerate(dataloader):
        try:
            X, y = data["image"].to(device), data["label"].to(device)

            # Compute prediction error. reshape(-1) flattens (N, 1) -> (N,)
            # without collapsing a batch of one into a 0-d scalar.
            pred = model(X.float()).reshape(-1)
            loss = loss_fn(pred, y.float())

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        except Exception as err:  # best effort: skip malformed batches
            bad_batches += 1
            last_error = err
            continue

        # Count every processed sample, not just every 25th batch.
        correct += (pred.detach().round() == y).sum().item()
        total_done += len(X)
        if batch % 25 == 0:
            training_acc = correct / total_done
            current = batch * len(X)
            print(f"Avg. Loss: {loss.item():>7f}, Accuracy: {training_acc:>.2%} [{current:>5d}/{size:>5d}]", end="\r")
    print()
    print("bad batches:" + str(bad_batches))
    if last_error is not None:
        print("last bad-batch error: " + repr(last_error))
    return correct / total_done if total_done else 0.0
    

def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct, bad_batches = 0, 0, 0
    with torch.no_grad():
        for data in dataloader:
            try:
                X, y = data["image"].to(device), data["label"].to(device)
                pred = model(X.float())
                test_loss += loss_fn(pred, torch.unsqueeze(y, 1).float()).item()
                correct += (torch.squeeze(pred).round() == y).type(torch.float).sum().item()
            except:
                bad_batches += 1
    test_loss /= size
    correct /= size
    print(f"\nTest Error: \nAvg. Loss: {test_loss:>7f}, Accuracy: {correct:>0.2%}\n" \
        , " bad_batches: " + str(bad_batches))
    return correct

In [None]:
# k-fold cross-validation: each fold is held out once as the test set while a
# fresh model is trained on the remaining folds.
# NOTE: the values collected below are accuracies (fraction correct), not
# error rates; the ``*_errors`` names are kept for compatibility.
testing_errors = []   # per fold: mean test accuracy over all epochs
training_errors = []  # per fold: training accuracy of the final epoch
best_test_acc = None  # best final-epoch test accuracy so far (checkpointing)

# Get cpu or gpu device for training. Kept at module level so helpers that
# read a global ``device`` still find it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

for i in range(len(split_data)):
    print("\n FOLD " + str(i + 1) + " OF " + str(NUM_FOLDS))
    print("=========================================================\n")

    # Train on every fold except the held-out one. DataFrame.append was
    # removed in pandas 2.0, so concatenate the folds instead.
    train_df = pd.concat([fold for j, fold in enumerate(split_data) if j != i])

    train_data = MacDataset(root_dir=PATH, dataframe=train_df,
                            transform=train_transforms)
    test_data = MacDataset(root_dir=PATH, dataframe=split_data[i],
                           transform=test_transforms)

    # Both loaders draw class-balanced samples with replacement, so the
    # reported test accuracy estimates a class-balanced accuracy.
    train_sampler = equal_classes_sampler(train_data.macs_frame)
    test_sampler = equal_classes_sampler(test_data.macs_frame)

    # NUM_BATCHES is used as the number of samples per batch.
    dataloader = DataLoader(train_data, batch_size=NUM_BATCHES, sampler=train_sampler,
                            shuffle=False, num_workers=0)
    dataloader_test = DataLoader(test_data, batch_size=NUM_BATCHES, sampler=test_sampler,
                                 shuffle=False, num_workers=0)

    model = macnet.Net().to(device)
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print("\nTraining Start")

    training_error = []
    testing_error = []
    for t in range(NUM_EPOCHS):
        print(f"Epoch {t+1}\n-------------------------------")
        print("\nTraining Error:")
        training_error.append(train(dataloader, model, loss_fn, optimizer, NUM_BATCHES))
        testing_error.append(test(dataloader_test, model))

    training_errors.append(training_error[-1])
    testing_errors.append(statistics.mean(testing_error))

    # Checkpoint the model with the best final-epoch test accuracy. Compare
    # against the best *final-epoch* accuracy seen so far rather than against
    # ``testing_errors`` (which holds per-fold means) so the metrics match.
    curr_test_acc = testing_error[-1]
    if best_test_acc is None or curr_test_acc > best_test_acc:
        best_test_acc = curr_test_acc
        torch.save(model, "./model")

training_errors = [round(error, 4) for error in training_errors]
testing_errors = [round(error, 4) for error in testing_errors]
print("training errors per fold")
print(training_errors)
print("testing errors per fold")
print(testing_errors)