In [11]:
# Example plotting multiple values
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import nrrd
import pandas as pd
import os
import sys
from torchmetrics.image.fid import FrechetInceptionDistance

# sys.path.append("/mnt/raid/C1_ML_Analysis/source/autoencoder/src")
sys.path.append("/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl/")
sys.path.append("/mnt/raid/C1_ML_Analysis/source/famli-ultra-sim/dl/nets")

from nets import diffusion, spade, lotus, cut, layers, cut_G
from loaders import ultrasound_dataset as usd
import monai

import plotly.express as px
import plotly.graph_objects as go


In [2]:
def toint(x):
    """Convert an image tensor to ``torch.uint8`` intensities.

    The input is multiplied by 255, limited to the range ``[0, 255]`` and then
    cast to ``torch.uint8`` (the cast truncates the fractional part).

    Args:
        x: Tensor to convert.
           NOTE(review): presumably intensities in ``[0, 1]`` - confirm against callers.

    Returns:
        A ``torch.uint8`` tensor with the same shape as ``x``.
    """
    scaled = torch.clamp(x * 255, min=0, max=255)
    return scaled.to(dtype=torch.uint8)

In [27]:
def compute_fid_all_types(dl, model, types=("Voluson", "Butterfly", "Clarius"), num_repeats=10):
    """Compute per-source FID scores before and after translation by ``model``.

    On every repeat ``dl`` is iterated to exhaustion and its ``(x, y)`` batches
    are concatenated along dim 0. The rows are then split into three
    interleaved groups (rows ``3k``, ``3k + 1`` and ``3k + 2`` for
    ``k < model.hparams.num_samples_test``), labelled ``types[0]``,
    ``types[1]`` and ``types[2]``. For each group two FID values are computed,
    always with ``y`` as the reference ("real") set:

    * baseline  - ``y`` against the unmodified input ``x``;
    * generated - ``y`` against ``model(x)``.

    Args:
        dl: Iterable yielding ``(x, y)`` tensor pairs.
            NOTE(review): images are tiled 3x along dim 1 before being fed to
            the metric, so they are presumably single-channel ``(N, 1, H, W)``
            with values in ``[0, 1]`` - confirm against the loader.
        model: Callable applied to CUDA batches of ``x``; must expose
            ``hparams.num_samples_test``.
        types: Sequence holding at least three labels, one per interleaved
            group. Only the first three are used.
        num_repeats: Number of times the whole evaluation is repeated.
            NOTE(review): repeating only adds information if ``dl`` yields
            different samples on each pass - confirm the loader is stochastic.

    Returns:
        Tuple ``(values_real, values_real_type, values_fake, values_fake_type)``.
        ``values_real`` / ``values_fake`` are 1-D NumPy arrays holding the
        baseline / generated FID values (three per repeat, in ``types`` order);
        the ``*_type`` lists hold the matching label for every entry.

    Raises:
        ValueError: If ``num_repeats`` is smaller than 1.
        IndexError: If ``types`` holds fewer than three labels.
    """
    if num_repeats < 1:
        raise ValueError(f"num_repeats must be >= 1, got {num_repeats}")

    # Stride of the interleaved groups inside the concatenated loader output.
    num_sources = 3
    # Resolve the labels up front so a too-short ``types`` fails before any GPU work.
    labels = [types[k] for k in range(num_sources)]

    metric = FrechetInceptionDistance(feature=64).cuda()

    def _to_metric_input(images):
        # uint8 conversion, channel dim tiled 3x, moved to the GPU the metric lives on.
        return toint(images).repeat(1, 3, 1, 1).cuda()

    def _fid(reference, candidate):
        # One self-contained FID evaluation; the metric is left empty afterwards.
        metric.update(reference, real=True)
        metric.update(candidate, real=False)
        score = metric.compute()
        metric.reset()
        return score

    values_real, values_real_type = [], []
    values_fake, values_fake_type = [], []

    for _ in range(num_repeats):
        # The loader is re-read on every repeat on purpose (see num_repeats above).
        batches_x, batches_y = [], []
        for x, y in dl:
            batches_x.append(x)
            batches_y.append(y)
        all_data = torch.cat(batches_x)
        all_target = torch.cat(batches_y)

        base_idx = torch.arange(0, model.hparams.num_samples_test) * num_sources
        sources = [all_data[base_idx + k] for k in range(num_sources)]
        # Targets are converted once and reused for both the baseline and generated FID.
        targets = [_to_metric_input(all_target[base_idx + k]) for k in range(num_sources)]

        with torch.no_grad():
            fakes = [model(source.cuda()) for source in sources]

        # Baseline: target vs. untouched input.
        for label, target, source in zip(labels, targets, sources):
            values_real.append(_fid(target, _to_metric_input(source)))
            values_real_type.append(label)

        # Generated: target vs. model output.
        for label, target, fake in zip(labels, targets, fakes):
            values_fake.append(_fid(target, _to_metric_input(fake)))
            values_fake_type.append(label)

    values_real = torch.stack(values_real)
    values_fake = torch.stack(values_fake)
    return values_real.cpu().numpy(), values_real_type, values_fake.cpu().numpy(), values_fake_type

# Fréchet Inception Distance (FID)

The **Fréchet Inception Distance (FID)** is a metric used to evaluate the quality of images generated by generative models such as GANs. It compares the **distribution of generated images** to the **distribution of real images** using features extracted from the **Inception v3** neural network.

## Key Concepts

- FID computes the squared **Fréchet distance** (equivalently, the squared 2-Wasserstein distance) between two multivariate Gaussians:
  - One fitted to the **real image features**
  - One fitted to the **generated image features**

- The formula for FID:
  
  $$
  \text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2})
  $$

  where:
  - $(\mu_r, \Sigma_r)$: mean and covariance of the Inception features of real images
  - $(\mu_g, \Sigma_g)$: mean and covariance of the Inception features of generated images

## Interpretation

- **Lower FID scores** indicate that the generated images are more similar to the real images.
- FID is often preferred over earlier metrics (like Inception Score) because it compares generated images against the statistics of real images, so it **captures both the quality and diversity** of generated images.

## Applications

- Commonly used in research to benchmark generative models like **GANs**.
- Used in various image synthesis tasks, such as super-resolution, style transfer, and image generation.


In [37]:
# Load the "all vs. Sonosite" CUT generator checkpoint in eval mode on the GPU.
model = cut.CutG.load_from_checkpoint("/mnt/raid/C1_ML_Analysis/train_output/Cut/allvslast/allvssonosite/v0.8/epoch=19-val_loss=6.33.ckpt").eval().cuda()

# The checkpoint's hyper-parameters name the data module class and supply its constructor arguments.
DATAMODULE = getattr(usd, model.hparams.data_module)
data = DATAMODULE(**model.hparams)

data.setup()
dl = data.test_dataloader()

# NOTE(review): the checkpoint path says "allvssonosite", yet "Sonosite" is also used as a
# source label below - confirm the third label matches the loader's third source.
values_real, values_real_type, values_fake, values_fake_type = compute_fid_all_types(dl, model, types=["Voluson", "Butterfly", "Sonosite"])

# Long-format frame: one row per FID value, "is" = source label, "type" = real (baseline) or fake (generated).
df = pd.DataFrame(values_real, columns=["fid"])
df["is"] = values_real_type
df["type"] = "real"

df_fake = pd.DataFrame(values_fake, columns=["fid"])
df_fake["is"] = values_fake_type
df_fake["type"] = "fake"

df = pd.concat([df, df_fake])

# Titled like the sibling "ALL v.s. Clarius" figure so the plot stands alone when skimmed.
fig = px.scatter(df, x="is", y="fid", color="type", title="ALL v.s. Sonosite")
fig.show()


In [38]:
# "All vs. Clarius" CUT generator: restore the checkpoint, then switch to eval mode on the GPU.
model = cut.CutG.load_from_checkpoint(
    "/mnt/raid/C1_ML_Analysis/train_output/Cut/allvslast/allvsclarius/v0.3/epoch=14-val_loss=6.75.ckpt"
)
model = model.eval().cuda()

# The data module class is looked up by the name stored in the checkpoint's hyper-parameters.
DATAMODULE = getattr(usd, model.hparams.data_module)
data = DATAMODULE(**model.hparams)
data.setup()
dl = data.test_dataloader()

# Baseline ("real") and generated ("fake") FID values for each source label.
(values_real, values_real_type,
 values_fake, values_fake_type) = compute_fid_all_types(
    dl, model, types=["Voluson", "Butterfly", "Sonosite"]
)

# Long-format table: one row per FID value, tagged with its source label and real/fake kind.
df_fake = pd.DataFrame({"fid": values_fake, "is": values_fake_type, "type": "fake"})
df = pd.concat([
    pd.DataFrame({"fid": values_real, "is": values_real_type, "type": "real"}),
    df_fake,
])

fig = px.scatter(df, x="is", y="fid", color="type", title="ALL v.s. Clarius")
fig.show()