# Virtually Staining a Biological Tissue with a Conditional GAN

<div style="background-color: #f0f8ff; border: 2px solid #4682b4; padding: 10px;">
<a href="https://colab.research.google.com/github/DeepTrackAI/DeepLearningCrashCourse/blob/main/Ch09_GAN/ec09_B_virtual_staining/virtual_staining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<strong>If using Colab/Kaggle:</strong> You need to uncomment the code in the cell below this one.
</div>

In [None]:
# !pip install deeplay  # Uncomment if using Colab/Kaggle.

The virtual_staining.ipynb notebook provides a complete code example for virtually staining a biological tissue with a conditional GAN.

<div style="background-color: #f0f8ff; border: 2px solid #4682b4; padding: 10px;">
<strong>Note:</strong> This notebook contains the Code Example 9-B from the book  

**Deep Learning Crash Course**  
Benjamin Midtvedt, Jesús Pineda, Henrik Klein Moberg, Harshith Bachimanchi, Joana B. Pereira, Carlo Manzo, Giovanni Volpe  
No Starch Press, San Francisco (CA), 2025  
ISBN-13: 9781718503922  

[https://nostarch.com/deep-learning-crash-course](https://nostarch.com/deep-learning-crash-course)

You can find the other notebooks on the [Deep Learning Crash Course GitHub page](https://github.com/DeepTrackAI/DeepLearningCrashCourse).
</div>

## Downloading the Human Motor Neurons Dataset

You’ll use a dataset that was originally published in the article: E. M. Christiansen et al., Cell 173:792-803, 2018. Specifically, you’ll use the human motor neurons dataset that is designated as “Condition A” in the article. This dataset comprises 22 pairs of brightfield and corresponding fluorescent images, with each pair including spatially registered images showcasing two fluorescent channels: Hoechst stain, revealing nuclei with a blue stain, and anti-TuJ1 stain, highlighting neurons in green. Notably, the brightfield images encompass a z-stack of 13 images across different focal planes, offering a comprehensive view of the cellular structures in question.

Download the virtual staining dataset.

In [None]:
import os

import subprocess

# Clone the dataset repository only if it is not already present locally.
# check=True raises CalledProcessError if the clone fails, instead of silently
# continuing (as an ignored os.system() return code would) and only failing
# much later with an empty dataset.
if not os.path.exists("vs_dataset"):
    subprocess.run(
        ["git", "clone", "https://github.com/DeepTrackAI/vs_dataset"],
        check=True,
    )

## Creating a Dataset

Implement a class containing the dataset ...

In [None]:
import glob
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm

class VirtualStainingDataset(torch.utils.data.Dataset):
    """Dataset of paired brightfield z-stacks and fluorescence images.

    Each item is an ``(input_image, target_image)`` pair. The input stacks
    the brightfield focal planes of one field of view along the last axis;
    the target is the fluorescence image whose filename is tagged
    ``MAXPROJECT``.

    Args:
        dir: Directory containing the PNG files.
        transforms: Optional pair ``(input_transform, target_transform)``.
        preload: If True, load every image pair into memory at construction.
            Preloaded pairs are cached per directory at class level, so later
            instances for the same directory reuse them.
    """

    _cache = {}  # Class-level cache: directory -> list of preloaded pairs.
    _N_Z_PLANES = 13  # Number of brightfield focal planes per input stack.

    def __init__(self, dir, transforms=None, preload=False):
        """Index the target images in ``dir`` and optionally preload them."""
        self.dir, self.transforms, self.preload = dir, transforms, preload
        self.images = []

        pattern = ("lab-Rubin,condition-scott_1_0,acquisition_date,"
                   "year-2016,month-2,day-6,well-r0*c0*,depth_computation,"
                   "value-MAXPROJECT,is_mask-false,kind,value-ORIGINAL.png")
        # glob order is filesystem-dependent; sorting makes item indices
        # (and therefore the preload cache and test batches) reproducible.
        self.image_list = sorted(glob.glob(os.path.join(self.dir, pattern)))

        if self.preload:
            if dir in VirtualStainingDataset._cache:
                # Reuse the pairs loaded by an earlier instance for this dir.
                self.images = VirtualStainingDataset._cache[dir]
            else:
                for image_path in tqdm(self.image_list,
                                       total=len(self.image_list),
                                       desc="Preloading images ..."):
                    self.images.append(self.load_image(image_path))
                VirtualStainingDataset._cache[dir] = self.images

    def load_image(self, image_path):
        """Load the input-target pair associated with a target image path.

        Args:
            image_path: Path of the fluorescence (``MAXPROJECT``) image.

        Returns:
            Tuple ``(input_image, target_image)`` of numpy arrays. The input
            is ``(H, W, 13)``: one grayscale ("L" mode) plane per focal depth.
        """
        input_planes = []
        for i in range(self._N_Z_PLANES):
            # The brightfield planes share the target's filename, except for
            # the channel value and the per-plane depth tag.
            plane_path = (
                image_path.replace("MAXPROJECT", "BRIGHTFIELD")
                .replace("depth_computation", f"z_depth-{i},channel")
            )
            # The context manager closes the file handle deterministically.
            with Image.open(plane_path) as plane:
                input_planes.append(np.array(plane.convert("L")))
        input_image = np.stack(input_planes, axis=-1)

        with Image.open(image_path) as target:
            target_image = np.array(target)

        return input_image, target_image

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.image_list)

    def __getitem__(self, i):
        """Return the i-th ``(input_image, target_image)`` pair.

        When transforms are set, torch is reseeded with the same fresh seed
        before each of the two transforms. As long as both transforms draw
        from the torch RNG in the same order, their random operations (e.g.
        crops and flips) then match, keeping input and target aligned.
        """
        if self.preload:
            input_image, target_image = self.images[i]
        else:
            input_image, target_image = self.load_image(self.image_list[i])

        if self.transforms:
            seed = np.random.randint(1_000_000_000)
            torch.manual_seed(seed)
            input_image = self.transforms[0](input_image)
            torch.manual_seed(seed)
            target_image = self.transforms[1](target_image)

        return input_image, target_image

... define the image transformations and normalizations ...

In [None]:
from torchvision import transforms

def _augmentation_pipeline(num_channels):
    """Build the augmentation + normalization pipeline for one modality.

    Both modalities must use the same random operations in the same order:
    the dataset reseeds torch before each transform, so identical sequences
    produce identical crops and flips for input and target.
    """
    return transforms.Compose([
        transforms.ToTensor(),  # (H, W, C) array -> (C, H, W) float tensor.
        transforms.RandomCrop((256, 256)),
        transforms.RandomHorizontalFlip(p=.5),
        transforms.RandomVerticalFlip(p=.5),
        # (x - 0.5) / 0.5 per channel; denormalize() later inverts this.
        transforms.Normalize(mean=[0.5] * num_channels,
                             std=[0.5] * num_channels),
    ])


trans_bright = _augmentation_pipeline(13)  # 13 brightfield focal planes.
trans_fluorescent = _augmentation_pipeline(3)  # 3 fluorescence channels.

... create the training and testing datasets ...

In [None]:
# Build both splits with the paired brightfield/fluorescence transforms;
# preloading keeps every image pair in memory during training.
train_set, test_set = (
    VirtualStainingDataset(
        dir=os.path.join("vs_dataset", split),
        transforms=[trans_bright, trans_fluorescent],
        preload=True,
    )
    for split in ("train", "test")
)

... and visualize brightfield and corresponding fluorescence images.

In [None]:
import matplotlib.pyplot as plt

# Draw one random (augmented) training pair to display.
sample_index = np.random.randint(0, len(train_set))
input_image, target_image = train_set[sample_index]

def denormalize(image):
    """Map an image normalized to [-1, 1] back to [0, 1] for display.

    Inverts ``Normalize(mean=0.5, std=0.5)``, i.e. computes ``(x + 1) / 2``.
    """
    shifted = image + 1
    return shifted / 2

# Show the brightfield stack (averaged over its channels) next to the two
# fluorescence channels of the same training pair.
panels = (
    (input_image.mean(axis=0), "gray", "Brightfield Image"),
    (denormalize(target_image[1, :, :].numpy()), "Greens", "Stained Neurons"),
    (denormalize(target_image[2, :, :].numpy()), "Blues", "Stained Nuclei"),
)
for position, (picture, colormap, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(picture, cmap=colormap)
    plt.title(caption)
    plt.axis("off")

plt.tight_layout()
plt.show()

## Instantiating the Generator and Discriminator

Determine the device to be used in the computations ...

In [None]:
def get_device():
    """Return the torch device to compute on.

    Preference order: the first CUDA GPU, then Apple MPS, then the CPU.
    """
    # Lambdas defer the availability checks (and attribute lookups) until
    # each candidate is actually considered.
    accelerators = (
        ("cuda:0", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, is_available in accelerators:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")


device = get_device()

print(device)

... define the generator ...

In [None]:
import deeplay as dl
from torch.nn import InstanceNorm2d, LeakyReLU, Tanh

# U-Net generator: 13 brightfield focal planes in, 3 fluorescence channels out.
gen = dl.UNet2d(
    in_channels=13, channels=[32, 64, 128, 256, 512], out_channels=3,
)
# Encoder: LeakyReLU for every activation.
gen["encoder", ..., "activation"].configure(LeakyReLU, negative_slope=0.2)
# Decoder: LeakyReLU for every activation except the last ...
gen["decoder", ..., "activation#:-1"].configure(LeakyReLU, negative_slope=0.2)
# ... and Tanh for the final one, bounding outputs to [-1, 1], the range
# that denormalize() maps back to [0, 1] for display.
gen["decoder", ..., "activation#-1"].configure(Tanh)
# Instance normalization on all decoder blocks except the last one.
gen["decoder", "blocks", :-1].all.normalized(InstanceNorm2d)
# Within each block: layer, then normalization, then activation.
gen[..., "blocks"].configure(order=["layer", "normalization", "activation"])
gen.build().to(device);  # Trailing ";" suppresses the notebook cell output.

... print out the generator's architecture ...

In [None]:
# Print the generator's module tree to inspect the configured architecture.
print(gen)

... define the discriminator ...

In [None]:
from torch.nn import Sigmoid

# Conditional discriminator: it sees the 13 brightfield channels concatenated
# with 3 (real or generated) fluorescence channels, hence 16 input channels.
disc = dl.ConvolutionalNeuralNetwork(
    in_channels=16, hidden_channels=[8, 16, 32, 64], out_channels=1,
)
# Every conv halves the spatial resolution (kernel 4, stride 2, padding 1).
disc["blocks", ..., "layer"].configure(kernel_size=4, stride=2, padding=1)
# LeakyReLU for all activations except the last ("#:-1", the same selector
# used for the generator's decoder). Selecting only "#-1" here would be
# overridden by the Sigmoid below, leaving the hidden blocks unconfigured.
disc["blocks", ..., "activation#:-1"].configure(LeakyReLU, negative_slope=0.2)
# Instance normalization on the inner blocks only (not the first or last).
disc["blocks", 1:-1].all.normalized(InstanceNorm2d)
# Final activation: Sigmoid, so the discriminator outputs lie in (0, 1).
disc["blocks", ..., "activation#-1"].configure(Sigmoid)
disc["blocks"].configure(order=["layer", "normalization", "activation"])
disc.build().to(device);  # Trailing ";" suppresses the notebook cell output.

... and print out the discriminator's architecture.

In [None]:
# Print the discriminator's module tree to inspect the configured architecture.
print(disc)

## Compiling the Conditional GAN

Define the losses ...

In [None]:
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Adversarial loss (MSE against 1/0 real/fake labels) and pixel-wise L1
# reconstruction loss.
loss_disc, loss_recon = torch.nn.MSELoss(), torch.nn.L1Loss()
# LPIPS perceptual loss on a VGG backbone. With torchmetrics' defaults it
# expects inputs in [-1, 1], the range of the Tanh generator output and of
# the normalized targets.
lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type="vgg")
loss_percep = lpips_metric.to(device)

... define the optimizers ...

In [None]:
# Both networks use Adam with beta1 = 0.5. The discriminator's learning rate
# is 4x smaller than the generator's (presumably so it does not overpower
# the generator).
adam_betas = (.5, .999)
optim_gen = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=adam_betas)
optim_disc = torch.optim.Adam(disc.parameters(), lr=5e-5, betas=adam_betas)

... and define the data loaders.

In [None]:
batch_size = 2
# Shuffle the training pairs every epoch.
train_loader = dl.DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Keep the test set in a fixed order so evaluation always draws the same
# images (their crops/flips still vary, since the transforms are random).
test_loader = dl.DataLoader(test_set, batch_size=batch_size, shuffle=False)

## Training the Conditional GAN

Define a function to train the discriminator ...

In [None]:
def train_disc(inputs, targets, optim_disc, loss_disc, *,
               generator=None, discriminator=None):
    """Run one optimization step of the discriminator.

    The discriminator sees the brightfield input concatenated along the
    channel axis with either the real target or the generator's prediction.
    It is trained to output 1 for real pairs and 0 for generated pairs.

    Args:
        inputs: Batch of brightfield inputs.
        targets: Batch of real fluorescence targets.
        optim_disc: Optimizer over the discriminator's parameters.
        loss_disc: Adversarial loss, called as ``loss_disc(outputs, labels)``.
        generator: Generator model; defaults to the global ``gen``.
        discriminator: Discriminator model; defaults to the global ``disc``.

    Returns:
        The mean of the real and fake losses returned by ``loss_disc``.
    """
    generator = gen if generator is None else generator
    discriminator = disc if discriminator is None else discriminator

    optim_disc.zero_grad()

    # Compute real loss.
    disc_outputs_real = discriminator(torch.cat([inputs, targets], dim=1))
    labels_real = torch.ones_like(disc_outputs_real)
    disc_loss_real = loss_disc(disc_outputs_real, labels_real)

    # Compute fake loss. The generator runs without gradient tracking: this
    # step only updates the discriminator, so backpropagating through the
    # generator would waste compute and leave stale gradients on it.
    with torch.no_grad():
        gen_outputs = generator(inputs)
    disc_outputs_fake = discriminator(torch.cat([inputs, gen_outputs], dim=1))
    labels_fake = torch.zeros_like(disc_outputs_fake)
    disc_loss_fake = loss_disc(disc_outputs_fake, labels_fake)

    disc_loss = (disc_loss_real + disc_loss_fake) / 2
    disc_loss.backward()
    optim_disc.step()

    return disc_loss

... define a function to train the generator ...

In [None]:
def train_gen(inputs, targets, optim_gen, loss_disc, loss_recon, loss_percep,
              recon_coef=100, percep_coef=10, *, generator=None,
              discriminator=None):
    """Run one optimization step of the generator.

    The total loss is::

        gen_loss = adv_loss + recon_coef * recon_loss
                   + percep_coef * percep_loss

    where ``adv_loss`` rewards predictions that the discriminator scores as
    real (label 1).

    Args:
        inputs: Batch of brightfield inputs.
        targets: Batch of real fluorescence targets.
        optim_gen: Optimizer over the generator's parameters.
        loss_disc: Adversarial loss, called as ``loss_disc(outputs, labels)``.
        loss_recon: Reconstruction loss between prediction and target.
        loss_percep: Perceptual loss between prediction and target.
        recon_coef: Weight of the reconstruction term.
        percep_coef: Weight of the perceptual term.
        generator: Generator model; defaults to the global ``gen``.
        discriminator: Discriminator model; defaults to the global ``disc``.

    Returns:
        Tuple ``(gen_loss, adv_loss, recon_loss, percep_loss)``.
    """
    generator = gen if generator is None else generator
    discriminator = disc if discriminator is None else discriminator

    optim_gen.zero_grad()

    gen_outputs = generator(inputs)
    # Condition the discriminator on the brightfield input (channel concat).
    disc_outputs = discriminator(torch.cat([inputs, gen_outputs], dim=1))

    labels = torch.ones_like(disc_outputs)
    adv_loss = loss_disc(disc_outputs, labels)
    recon_loss = loss_recon(gen_outputs, targets)
    percep_loss = loss_percep(gen_outputs, targets)

    gen_loss = adv_loss + recon_coef * recon_loss + percep_coef * percep_loss
    # Gradients also reach the discriminator's parameters here. Only
    # optim_gen steps, though, and the discriminator's gradients are reset by
    # optim_disc.zero_grad() at the start of its own update.
    gen_loss.backward()
    optim_gen.step()

    return gen_loss, adv_loss, recon_loss, percep_loss

... and a function to evaluate the model on the test dataset ...

In [None]:
def evaluate_model(input, target):
    """Plot the generator's prediction for the first sample of a batch.

    The figure is a 2x3 grid:
    - top-left: the brightfield input, averaged over its channels;
    - top row: the predicted neuron (channel 1) and nucleus (channel 2) maps;
    - bottom row: the corresponding ground-truth channels.

    Args:
        input: Batch of brightfield inputs, channels-first ``(N, C, H, W)``.
        target: Batch of fluorescence targets with the same layout.
    """
    # Run inference in eval mode without gradients, then restore whichever
    # mode the generator was in before (instead of always forcing train mode).
    was_training = gen.training
    gen.eval()
    with torch.no_grad():
        prediction = gen(input.to(device))
    gen.train(was_training)

    def to_display(batch):
        """Return sample 0 of ``batch`` as a channels-last, denormalized array."""
        return denormalize(batch[0].permute(1, 2, 0).cpu().numpy())

    predicted = to_display(prediction)
    expected = to_display(target)

    fig, ax = plt.subplots(2, 3, figsize=(10, 7))

    ax[0, 0].imshow(
        denormalize(input[0].permute(1, 2, 0).cpu().numpy().mean(axis=-1)),
        cmap="gray",
    )
    ax[0, 0].set_title("Input Image")

    panels = (
        (ax[0, 1], predicted[:, :, 1], "Greens", "Prediction - Neurons"),
        (ax[1, 1], expected[:, :, 1], "Greens", "Ground Truth - Neurons"),
        (ax[0, 2], predicted[:, :, 2], "Blues", "Prediction - Nuclei"),
        (ax[1, 2], expected[:, :, 2], "Blues", "Ground Truth - Nuclei"),
    )
    for panel_ax, image, cmap, title in panels:
        panel_ax.imshow(image, cmap=cmap)
        panel_ax.set_title(title)

    # Hide every axis frame; the [1, 0] subplot is intentionally left empty.
    for panel_ax in ax.flat:
        panel_ax.axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)

... and implement the training cycle.

In [None]:
import time
from datetime import timedelta

epochs = 500

# Per-epoch means of each loss, kept for plotting learning curves.
disc_losses, gen_losses, recon_losses, percep_losses = [], [], [], []
for epoch in range(epochs):
    start_time = time.time()
    print("\n" + f"Epoch {epoch + 1}/{epochs}" + "\n" + "-" * 10)

    disc_loss_epoch, gen_loss_epoch, recon_loss_epoch, percep_loss_epoch = \
        [], [], [], []
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # 1. Train the discriminator once per batch.
        disc_loss = train_disc(inputs, targets, optim_disc, loss_disc)

        # 2. Train the generator twice per batch; only the losses of the
        #    second step are logged below.
        for _ in range(2):
            gen_loss, adv_loss, recon_loss, percep_loss = train_gen(
                inputs, targets, optim_gen, loss_disc, loss_recon, loss_percep,
            )

        if i % 5 == 0:
            print(f"Batch {i + 1}/{len(train_loader)} : "
                  f"Total Generator Loss: {gen_loss.item():.4f}, "
                  f"Discriminator Loss: {disc_loss.item():.4f}")

        disc_loss_epoch.append(disc_loss.item())
        gen_loss_epoch.append(gen_loss.item())
        recon_loss_epoch.append(recon_loss.item())
        percep_loss_epoch.append(percep_loss.item())

    end_time = time.time()

    # Every 50 epochs (and after the last one), visualize predictions on the
    # first test batch; skip silently if the test set is empty.
    if epoch % 50 == 0 or epoch + 1 == epochs:
        first_test_batch = next(iter(test_loader), None)
        if first_test_batch is not None:
            test_input, test_target = first_test_batch
            evaluate_model(test_input, test_target)

    disc_losses.append(np.mean(disc_loss_epoch))
    gen_losses.append(np.mean(gen_loss_epoch))
    recon_losses.append(np.mean(recon_loss_epoch))
    percep_losses.append(np.mean(percep_loss_epoch))

    print("-" * 10 + "\n" + f"Epoch {epoch + 1}/{epochs} : "
          f"Discriminator Loss: {disc_losses[-1]:.4f}, "
          f"Total Generator Loss: {gen_losses[-1]:.4f}, "
          f"Reconstruction Loss: {recon_losses[-1]:.4f}, "
          f"Perceptual Loss: {percep_losses[-1]:.4f}" + "\n"
          f"Time taken: {timedelta(seconds=end_time - start_time)}")