# Virtually Staining a Biological Tissue with a Conditional GAN

The `virtual_staining.ipynb` notebook provides you with a complete code example to virtually stain a biological tissue with a conditional GAN.

## Downloading the Human Motor Neurons Dataset

You’ll use a dataset that was originally published in the article: E. M. Christiansen et al., Cell 173:792-803, 2018. Specifically, you’ll use the human motor neurons dataset that is designated as “Condition A” in the article. This dataset comprises 22 pairs of brightfield and corresponding fluorescent images, with each pair including spatially registered images showcasing two fluorescent channels: Hoechst stain, revealing nuclei with a blue stain, and anti-TuJ1 stain, highlighting neurons in green. Notably, the brightfield images encompass a z-stack of 13 images across different focal planes, offering a comprehensive view of the cellular structures in question.

Download the dataset from Google Storage. The dataset is about 10 GB and will take at least a few minutes to download, or longer depending on your internet connection.

In [1]:
import os, subprocess

url, path = "gs://in-silico-labeling/paper_data/", "virtual_staining_dataset"

# Download only when the local dataset folder does not exist yet; an existing
# folder is taken to mean that the data has already been fetched.
if not os.path.exists(path):
    for dataset in ["train", "test"]:
        # Remote folder holding this split and its local destination.
        remote_folder = f"{url}{dataset}/Rubin/scott_1_0"
        local_folder = os.path.join(path, dataset)
        os.makedirs(local_folder, exist_ok=True)

        # Recursive, parallel (-m) copy; check=True raises on gsutil failure.
        subprocess.run(
            ["gsutil", "-m", "cp", "-r", remote_folder, local_folder],
            check=True,
        )

## Creating a Dataset

Implement a class containing the dataset ...

In [2]:
import glob
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

class VirtualStainingDataset(Dataset):
    """Dataset of paired brightfield z-stacks and fluorescence images.

    Each sample is a couple ``(input_image, target_image)``:

    * ``input_image`` stacks the brightfield focal planes (read as 8-bit
      grayscale) along the last axis, i.e. an array of shape ``(H, W, 13)``;
    * ``target_image`` is the fluorescence maximum-projection image, loaded
      as stored on disk.

    Samples are discovered by globbing the ``scott_1_0`` sub-folder of
    ``directory`` for the fluorescence files; the matching brightfield file
    names are derived from each fluorescence file name.
    """

    _cache = {}  # Class variable to cache loaded images.
    _NUM_Z_PLANES = 13  # Brightfield focal planes stacked per sample.

    def __init__(self, directory, transform=None, preload=False):
        """Initialize dataset.

        Parameters
        ----------
        directory : str
            Folder containing the ``scott_1_0`` image folder.
        transform : sequence of two callables, optional
            ``transform[0]`` is applied to the input image and
            ``transform[1]`` to the target image.
        preload : bool
            If True, load all images into memory up front. Loaded images are
            cached at class level (keyed by image folder), so a second
            dataset on the same folder reuses them instead of reloading.
        """
        self.transform, self.preload = transform, preload
        self.images = []

        self.image_dir = os.path.join(directory, "scott_1_0")
        pattern = ("lab-Rubin,condition-scott_1_0,acquisition_date,"
                   "year-2016,month-2,day-6,"
                   "well-r0*c0*,"
                   "depth_computation,"
                   "value-MAXPROJECT,is_mask-false,kind,value-ORIGINAL.png")
        # glob returns matches in arbitrary, filesystem-dependent order.
        # Sorting makes sample indices reproducible across runs and machines.
        self.image_list = sorted(
            glob.glob(os.path.join(self.image_dir, pattern))
        )

        self.cache_key = self.image_dir
        if self.preload:
            if self.cache_key in VirtualStainingDataset._cache:
                self.images = VirtualStainingDataset._cache[self.cache_key]
            else:
                for image_path in tqdm(self.image_list,
                                       total=len(self.image_list),
                                       desc="Preloading images"):
                    self.images.append(self.load_image(image_path))
                VirtualStainingDataset._cache[self.cache_key] = self.images

    def load_image(self, image_path):
        """Load input-target image couple.

        ``image_path`` is the path of the fluorescence (target) image. The
        brightfield file for each focal plane is obtained by rewriting the
        ``depth_computation`` and ``value-MAXPROJECT`` fields of that path.
        """
        planes = []
        for i in range(self._NUM_Z_PLANES):
            img_path = (image_path
                        .replace("depth_computation", f"z_depth-{i},channel")
                        .replace("value-MAXPROJECT", "value-BRIGHTFIELD"))
            # Context manager guarantees the file handle is released.
            with Image.open(img_path) as plane:
                planes.append(np.array(plane.convert("L")))
        input_image = np.stack(planes, axis=-1)

        with Image.open(image_path) as target:
            target_image = np.array(target)

        return input_image, target_image

    def __len__(self):
        """Return number of images."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Get the input-target image couple at position ``idx``."""
        if self.preload:
            input_image, target_image = self.images[idx]
        else:
            input_image, target_image = self.load_image(self.image_list[idx])

        if self.transform:
            # Both transforms are run from the same torch seed so that any
            # random operation (crop, flip, ...) draws the same parameters
            # for input and target, provided both pipelines consume random
            # numbers in the same order.
            # NOTE: this reseeds torch's global RNG on every access.
            seed = np.random.randint(1_000_000_000)

            torch.manual_seed(seed)
            input_image = self.transform[0](input_image)

            torch.manual_seed(seed)
            target_image = self.transform[1](target_image)

        return input_image, target_image

... define the image transformations and normalizations ...

In [3]:
from torchvision import transforms as trans

def _paired_pipeline(n_channels):
    """Build the conversion/augmentation/normalization chain for one image.

    The chain converts the image to a tensor, takes a random 256x256 crop,
    applies random horizontal and vertical flips (probability 0.5 each), and
    normalizes every one of the ``n_channels`` channels with mean 0.5 and
    standard deviation 0.5.
    """
    half = [0.5] * n_channels
    steps = [
        trans.ToTensor(),
        trans.RandomCrop((256, 256)),
        trans.RandomHorizontalFlip(p=0.5),
        trans.RandomVerticalFlip(p=0.5),
        trans.Normalize(mean=half, std=half),
    ]
    return trans.Compose(steps)


# Same sequence of random operations in both pipelines; only the channel
# count differs (13 brightfield planes versus 3 fluorescence channels).
trans_bright = _paired_pipeline(13)
trans_fluorescent = _paired_pipeline(3)

... create the training and testing datasets ...

In [None]:
# Build the training and testing datasets from the "train" and "test"
# sub-folders, with identical transforms and with images preloaded in memory.
train_dataset, test_dataset = (
    VirtualStainingDataset(
        directory=os.path.join(path, split),
        transform=[trans_bright, trans_fluorescent],
        preload=True,
    )
    for split in ("train", "test")
)

... and visualize brightfield and corresponding fluorescence images.

In [None]:
import matplotlib.pyplot as plt

# Pick one training couple at a random index.
sample_index = np.random.randint(0, len(train_dataset))
bf_img, fl_img = train_dataset[sample_index]

def denormalize(image):
    """Map values from the [-1, 1] range back to [0, 1] for visualization.

    Inverse of a normalization with mean 0.5 and standard deviation 0.5.
    """
    return 0.5 * (image + 1)

plt.figure()

# One row of three panels: (title, image to show, colormap).
# Channel 1 of the fluorescence image is shown as neurons (green) and
# channel 2 as nuclei (blue); the brightfield stack is averaged over planes.
panels = [
    ("Brightfield Image", bf_img.mean(axis=0), "gray"),
    ("Stained Neurons", denormalize(fl_img[1, :, :].numpy()), "Greens"),
    ("Stained Nuclei", denormalize(fl_img[2, :, :].numpy()), "Blues"),
]
for position, (panel_title, picture, colormap) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(picture, cmap=colormap)
    plt.title(panel_title)
    plt.axis("off")

plt.tight_layout()
plt.show()

## Instantiating the Generator and Discriminator

Determine the device to be used in the computations ...

In [7]:
def get_device():
    """Return the torch device to compute on.

    Preference order: first CUDA GPU, then Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        name = "cuda:0"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)

device = get_device()

... define the generator ...

In [None]:
import deeplay as dl
import torch.nn as nn

# Generator: a 2D U-Net taking 13 input channels (one per brightfield focal
# plane stacked by the dataset) and producing 3 output channels (matching the
# 3-channel normalization applied to the fluorescence target).
gen = dl.UNet2d(
    in_channels=13, channels=[32, 64, 128, 256, 512], out_channels=3,
)
# Activation set-up. NOTE(review): the "#:-1" / "#-1" suffixes presumably
# select "all but the last" / "the last" matching activation (deeplay selector
# syntax) -- confirm against the deeplay documentation.
gen["encoder", ..., "activation"].configure(nn.LeakyReLU, negative_slope=0.2)
gen["decoder", ..., "activation#:-1"] \
    .configure(nn.LeakyReLU, negative_slope=0.2)
# Tanh bounds the output to [-1, 1], the range of targets normalized with
# mean 0.5 and standard deviation 0.5.
gen["decoder", ..., "activation#-1"].configure(nn.Tanh)
# Instance normalization on the decoder blocks selected by the slice :-1.
gen["decoder", "blocks", :-1].all.normalized(nn.InstanceNorm2d)
# Within each block: layer first, then normalization, then activation.
gen[..., "blocks"].configure(order=["layer", "normalization", "activation"])
# The configuration calls above precede build(); presumably build() is what
# instantiates the configured layers -- confirm in the deeplay documentation.
gen.build()
gen.to(device);

print(gen)


... and define the discriminator.

In [None]:
# Discriminator: a convolutional network with 16 input channels, i.e. the
# 13-channel brightfield input concatenated with a 3-channel (real or
# generated) fluorescence image, and a single output channel.
disc = dl.ConvolutionalNeuralNetwork(
    in_channels=16, hidden_channels=[8, 16, 32, 64], out_channels=1,
)
# 4x4 convolutions with stride 2 and padding 1 in the selected layers.
disc["blocks", ..., "layer"].configure(kernel_size=4, stride=2, padding=1)
# NOTE(review): "#:-1" / "#-1" presumably select "all but the last" / "the
# last" matching activation (deeplay selector syntax) -- confirm.
disc["blocks", ..., "activation#:-1"] \
    .configure(nn.LeakyReLU, negative_slope=0.2)
# Instance normalization on the blocks selected by the slice 1:-1 (neither
# the first nor the last block).
disc["blocks", 1:-1].all.normalized(nn.InstanceNorm2d)
# Sigmoid keeps the output scores in [0, 1], compatible with the 0/1 labels
# used when training.
disc["blocks", ..., "activation#-1"].configure(nn.Sigmoid)
# Within each block: layer first, then normalization, then activation.
disc["blocks"].configure(order=["layer", "normalization", "activation"])
# The configuration calls above precede build(); presumably build() is what
# instantiates the configured layers -- confirm in the deeplay documentation.
disc.build()
disc.to(device);

print(disc)

## Compiling the Conditional GAN

Define the losses ...

In [10]:
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Reconstruction criterion: mean absolute error between generated and target
# images.
loss_gen = torch.nn.L1Loss()

# Adversarial criterion: mean squared error between discriminator scores and
# the 0/1 labels.
loss_disc = torch.nn.MSELoss()

# Perceptual criterion: LPIPS with a VGG backbone, moved to the compute device.
loss_LPIPS = (
    LearnedPerceptualImagePatchSimilarity(net_type="vgg")
    .to(device)
)

... define the optimizers ...

In [11]:
# Both networks use Adam with the same momentum coefficients; the
# discriminator's learning rate is four times smaller than the generator's.
adam_betas = (0.5, 0.999)
optim_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=adam_betas)
optim_disc = torch.optim.Adam(disc.parameters(), lr=5e-5, betas=adam_betas)

... and define the data loaders.

In [12]:
# Batches of two image couples; only the training data is shuffled.
make_loader = torch.utils.data.DataLoader
train_loader = make_loader(train_dataset, batch_size=2, shuffle=True)
test_loader = make_loader(test_dataset, batch_size=2, shuffle=False)

## Training the Conditional GAN

Define a function to train the discriminator ...

In [13]:
def train_disc(inputs, targets, optim, loss_disc=loss_disc):
    """Run one optimization step of the discriminator.

    The discriminator ``disc`` (module-level) scores the input concatenated
    channel-wise with either the real target or the image produced by the
    generator ``gen`` (module-level). Real couples are pushed towards label 1
    and generated couples towards label 0.

    Parameters
    ----------
    inputs : torch.Tensor
        Batch of input images.
    targets : torch.Tensor
        Batch of real target images paired with ``inputs``.
    optim : torch.optim.Optimizer
        Optimizer updating the discriminator parameters.
    loss_disc : callable
        Criterion comparing discriminator scores with the labels.

    Returns
    -------
    torch.Tensor
        Average of the real and fake losses, detached from the autograd
        graph (it is meant for logging only).
    """
    optim.zero_grad()

    # Compute real loss.
    output_real = disc(torch.cat([inputs, targets], dim=1))
    label_real = torch.ones_like(output_real)
    loss_real = loss_disc(output_real, label_real)

    # Compute fake loss. The generator is not being trained in this step, so
    # its output is produced without autograd tracking: otherwise backward()
    # would also backpropagate through the whole generator, costing time and
    # memory for gradients that are never used.
    with torch.no_grad():
        fake_targets = gen(inputs)
    output_fake = disc(torch.cat([inputs, fake_targets], dim=1))
    label_fake = torch.zeros_like(output_fake)
    loss_fake = loss_disc(output_fake, label_fake)

    loss = (loss_real + loss_fake) / 2
    loss.backward()
    optim.step()

    return loss.detach()

... define a function to train the generator ...

In [17]:
def train_gen(inputs, targets, optim, loss_disc=loss_disc, loss_gen=loss_gen,
              loss_LPIPS=loss_LPIPS, L1_Lambda=100, LPIPS_Lambda=10):
    """Run one optimization step of the generator.

    The objective is the adversarial term (discriminator scores on generated
    couples compared with label 1) plus the reconstruction term weighted by
    ``L1_Lambda`` plus the perceptual term weighted by ``LPIPS_Lambda``.

    Returns the three unweighted terms as a tuple
    ``(adversarial, reconstruction, perceptual)``.
    """
    optim.zero_grad()

    # Generate the images and let the discriminator score them together
    # with the inputs they were generated from.
    fake_targets = gen(inputs)
    scores = disc(torch.cat([inputs, fake_targets], dim=1))

    # The generator wants its images to be scored as real (label 1).
    loss_GAN = loss_disc(scores, torch.ones_like(scores))
    loss_L1 = loss_gen(fake_targets, targets)
    loss_P = loss_LPIPS(fake_targets, targets)

    objective = loss_GAN + L1_Lambda * loss_L1 + LPIPS_Lambda * loss_P
    objective.backward()
    optim.step()

    return loss_GAN, loss_L1, loss_P

... and a function to evaluate the model on the test dataset ...

In [18]:
def evaluate_model(image, label):
    """Show the generator's prediction for the first image of a batch.

    Runs the module-level generator ``gen`` in evaluation mode without
    gradient tracking and plots a 2x3 figure: the input (mean over its
    channels), the predicted neurons (channel 1) and nuclei (channel 2) in
    the top row, and the corresponding ground-truth channels below them.

    Parameters
    ----------
    image : torch.Tensor
        Batch of input images; only the first element is displayed.
    label : torch.Tensor
        Batch of target images; only the first element is displayed.
    """
    # Remember the current mode so that evaluation does not silently switch
    # the generator into training mode, and restore it even if inference
    # raises.
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            prediction = gen(image)
    finally:
        gen.train(was_training)

    # Convert each image once: first batch element, channels last, on CPU.
    input_gray = denormalize(
        image[0].permute(1, 2, 0).cpu().numpy().mean(axis=-1)
    )
    predicted = denormalize(prediction[0].permute(1, 2, 0).cpu().numpy())
    expected = denormalize(label[0].permute(1, 2, 0).cpu().numpy())

    # (grid position, image, colormap, title) for every populated panel.
    panels = [
        ((0, 0), input_gray, "gray", "Input Image"),
        ((0, 1), predicted[:, :, 1], "Greens", "Prediction - Neurons"),
        ((1, 1), expected[:, :, 1], "Greens", "Ground truth - Neurons"),
        ((0, 2), predicted[:, :, 2], "Blues", "Prediction - Nuclei"),
        ((1, 2), expected[:, :, 2], "Blues", "Ground truth - Nuclei"),
    ]

    fig, ax = plt.subplots(2, 3, figsize=(10, 5))
    for position, picture, colormap, title in panels:
        ax[position].imshow(picture, cmap=colormap)
        ax[position].set_title(title)
        ax[position].axis("off")

    ax[1, 0].axis("off")  # Leave the [1, 0] subplot empty.

    plt.tight_layout()
    plt.show()
    plt.close(fig)

... and implement the training cycle.

In [None]:
import time
from datetime import timedelta

epochs, L1_Lambda, LPIPS_Lambda = 500, 100, 10
log_every = 5       # Print the batch losses every `log_every` batches.
eval_every = 50     # Show test predictions every `eval_every` epochs.
gen_steps = 2       # Generator updates per discriminator update.

# Per-epoch averages of the adversarial, discriminator, reconstruction and
# perceptual losses.
gen_losses, disc_losses, recon_losses, percep_losses = [], [], [], []
# Train for exactly `epochs` epochs, so the printed counter never exceeds
# the announced total.
for epoch in range(epochs):
    start_time = time.time()

    print("\n" + f"Epoch {epoch + 1}/{epochs}" + "\n" + "-" * 10)

    gen_loss_epochs, disc_loss_epochs, recon_loss_epochs, percep_loss_epochs \
        = [], [], [], []
    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        # 1. Train the discriminator.
        disc_loss = train_disc(
            inputs=inputs,
            targets=labels,
            optim=optim_disc,
        )

        # 2. Train the generator (losses of the last update are logged).
        for _ in range(gen_steps):
            adv_loss, rec_loss, percep_loss = train_gen(
                inputs=inputs,
                targets=labels,
                optim=optim_gen,
                L1_Lambda=L1_Lambda,
                LPIPS_Lambda=LPIPS_Lambda,
            )

        # NOTE: unweighted sum of the three terms, used for display only; the
        # optimized objective weights them with L1_Lambda and LPIPS_Lambda.
        gen_loss = adv_loss + rec_loss + percep_loss

        if i % log_every == 0:
            print(f"Batch {i+1}/{len(train_loader)} : "
                  + f"Generator Loss: {gen_loss.item():.4f}, "
                  + f"Discriminator Loss: {disc_loss.item():.4f}")

        gen_loss_epochs.append(adv_loss.item())
        disc_loss_epochs.append(disc_loss.item())
        recon_loss_epochs.append(rec_loss.item())
        percep_loss_epochs.append(percep_loss.item())

    # Stop the clock before evaluating, so the reported time covers training
    # only.
    end_time = time.time()

    # Evaluate on the first test batch periodically and after the last epoch.
    if epoch % eval_every == 0 or epoch == epochs - 1:
        test_inputs, test_labels = next(iter(test_loader))
        test_inputs = test_inputs.to(device)
        test_labels = test_labels.to(device)
        evaluate_model(test_inputs, test_labels)

    print("-" * 10 + "\n" + f"Epoch {epoch + 1}/{epochs} : "
          + f"Generator loss: {np.mean(gen_loss_epochs):.4f}, "
          + f"Discriminator Loss: {np.mean(disc_loss_epochs):.4f}, "
          + f"Reconstruction Loss: {np.mean(recon_loss_epochs):.4f}, "
          + f"Perception Loss: {np.mean(percep_loss_epochs):.4f}"
          + "\n" f"Time taken: {timedelta(seconds=end_time - start_time)}")

    gen_losses.append(np.mean(gen_loss_epochs))
    disc_losses.append(np.mean(disc_loss_epochs))
    recon_losses.append(np.mean(recon_loss_epochs))
    percep_losses.append(np.mean(percep_loss_epochs))