In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
os.chdir("/content/drive/MyDrive/GSET25")
!pwd

/content/drive/MyDrive/GSET25


In [3]:
!pip install medmnist
!pip install monai-generative
!pip install lpips

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.0.tar.gz (87 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/87.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading medmnist-3.0.2-py3-none-any.whl (25 kB)
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.7.0-py3-none-any.whl size=114249 sha256=331c24485028836607feca15b6048ffce21d798289ece423d0767d30f129a108
  Stored in directory: /root/.cache/pip/wheels/19/39/2f/2d3cadc408a8804103f1c34ddd4b9f6a93497b11fa96fe738e
Successfully built fire
Installing collected packages: fire, medmnist
Successfully installed fire-0.7.0 medmnist-3.0.2
Collecting monai-generat

In [4]:
import numpy as np
from medmnist import PneumoniaMNIST, ChestMNIST, BloodMNIST
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, ToTensor



class MyMedMNIST(BloodMNIST):
    def __getitem__(self, item):
        img, _ = super().__getitem__(item)
        return img

if __name__ == '__main__':
    image_size = 64
    train_data = MyMedMNIST(split="train", download=True, size=image_size,root="/content/drive/MyDrive/GSET25", transform=ToTensor())
    print(train_data)
    # indices = list(range(4000))
    # train_data = Subset(train_data, indices)
    print(len(train_data))

    val_data = MyMedMNIST(split="val", download=True, size=image_size,root="/content/drive/MyDrive/GSET25" ,transform=ToTensor())
    print(val_data)
    # val_data = Subset(val_data, list(range(500)))
    print(len(val_data))

    test_data = MyMedMNIST(split="test", download=True, size=image_size,root="/content/drive/MyDrive/GSET25", transform=ToTensor())
    print(test_data)
    # test_data = Subset(test_data, list(range(500)))
    print(len(test_data))

    train_loader = DataLoader(
        dataset=train_data,
        batch_size=128,
        num_workers=2,
        shuffle=True,
        drop_last=False
    )
    val_loader = DataLoader(
        dataset=val_data,
        batch_size=128,
        num_workers=2,
        shuffle=True,
        drop_last=False
    )
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=128,
        num_workers=2,
        shuffle=True,
        drop_last=False
    )

Using downloaded and verified file: /content/drive/MyDrive/GSET25/bloodmnist_64.npz
Dataset MyMedMNIST of size 64 (bloodmnist_64)
    Number of datapoints: 11959
    Root location: /content/drive/MyDrive/GSET25
    Split: train
    Task: multi-class
    Number of channels: 3
    Meaning of labels: {'0': 'basophil', '1': 'eosinophil', '2': 'erythroblast', '3': 'immature granulocytes(myelocytes, metamyelocytes and promyelocytes)', '4': 'lymphocyte', '5': 'monocyte', '6': 'neutrophil', '7': 'platelet'}
    Number of samples: {'train': 11959, 'val': 1712, 'test': 3421}
    Description: The BloodMNIST is based on a dataset of individual normal cells, captured from individuals without infection, hematologic or oncologic disease and free of any pharmacologic treatment at the moment of blood collection. It contains a total of 17,092 images and is organized into 8 classes. We split the source dataset with a ratio of 7:1:2 into training, validation and test set. The source images with resolution

In [5]:
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

import torch
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
from scipy import linalg
import pathlib
from INCEPTION import InceptionV3

from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import LatentDiffusionInferer
from generative.losses.adversarial_loss import PatchAdversarialLoss
from generative.losses.perceptual import PerceptualLoss
from generative.networks.nets import DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler
from generative_custom.networks.nets import AutoencoderKL


  @torch.cuda.amp.autocast(enabled=False)
  @torch.cuda.amp.autocast(enabled=False)


In [None]:
set_determinism(42)
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmpfy60jtzs


In [None]:
val_data = MedNISTDataset(root_dir=root_dir, section="validation", download=True, seed=0)
val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "Hand"]
val_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
    ]
)
val_ds = Dataset(data=val_datalist, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)

MedNIST.tar.gz: 59.0MB [00:01, 55.2MB/s]                            

2024-12-21 08:42:49,591 - INFO - Downloaded: /tmp/tmpfy60jtzs/MedNIST.tar.gz





2024-12-21 08:42:49,702 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2024-12-21 08:42:49,703 - INFO - Writing into directory: /tmp/tmpfy60jtzs.


Loading dataset: 100%|██████████| 5895/5895 [00:04<00:00, 1368.26it/s]


In [None]:
train_data = MedNISTDataset(root_dir=root_dir, section="training", download=True, seed=0)
train_datalist = [{"image": item["image"]} for item in train_data.data if item["class_name"] == "Hand"]
image_size = 64
train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
        transforms.RandAffined(
            keys=["image"],
            rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],
            translate_range=[(-1, 1), (-1, 1)],
            scale_range=[(-0.05, 0.05), (-0.05, 0.05)],
            spatial_size=[image_size, image_size],
            padding_mode="zeros",
            prob=0.5,
        ),
    ]
)
train_ds = Dataset(data=train_datalist, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True)

2024-12-21 08:43:20,966 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.
2024-12-21 08:43:20,968 - INFO - File exists: /tmp/tmpfy60jtzs/MedNIST.tar.gz, skipped downloading.
2024-12-21 08:43:20,971 - INFO - Non-empty folder exists in /tmp/tmpfy60jtzs/MedNIST, skipped extracting.


Loading dataset: 100%|██████████| 47164/47164 [00:40<00:00, 1168.55it/s]


In [8]:
check_data = next(iter((train_loader)))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# autoencoderkl = AutoencoderKL(
#     spatial_dims=2,
#     in_channels=3,
#     out_channels=3,
#     num_channels=(128, 128, 256),
#     latent_channels=3,
#     num_res_blocks=2,
#     attention_levels=(False, False, False),
#     with_encoder_nonlocal_attn=False,
#     with_decoder_nonlocal_attn=False,
# )
autoencoderkl = AutoencoderKL(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_channels=(128, 128, 256),
    latent_channels=3,
    num_res_blocks=2,
    attention_levels=(False, True, True),
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
)
autoencoderkl = autoencoderkl.to(device)
# unet = DiffusionModelUNet(
#     spatial_dims=2,
#     in_channels=3,
#     out_channels=3,
#     num_res_blocks=2,
#     num_channels=(128, 256, 512),
#     attention_levels=(False, True, True),
#     num_head_channels=(0, 256, 512),
# )
unet = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_res_blocks=2,
    num_channels=(128, 256, 512),
    attention_levels=(True, True, True),
    num_head_channels=(128, 256, 512),
)

with torch.no_grad():
    with autocast(enabled=True):
        z = autoencoderkl.encode_stage_2_inputs(check_data.to(device))
print(f"Scaling factor set to {1/torch.std(z)}")
scale_factor = 1 / torch.std(z)
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195)
inferer = LatentDiffusionInferer(scheduler, scale_factor=scale_factor)
unet = unet.to(device)


  with autocast(enabled=True):


Scaling factor set to 0.8932662606239319


In [None]:
# check_point = torch.load('/content/drive/MyDrive/GSET25/AEKL_Hand_Sep.pt')
# autoencoderkl.load_state_dict(check_point["autoencoderkl"])
# check_point = torch.load('/content/drive/MyDrive/GSET25/uNet_Hand_Sep.pt')
# unet.load_state_dict(check_point["unet"],strict=False)

  check_point = torch.load('/content/drive/MyDrive/GSET25/AEKL_Hand_Sep.pt')
  check_point = torch.load('/content/drive/MyDrive/GSET25/uNet_Hand_Sep.pt')


_IncompatibleKeys(missing_keys=['down_blocks.0.attentions.0.to_q.weight', 'down_blocks.0.attentions.0.to_q.bias', 'down_blocks.0.attentions.0.to_k.weight', 'down_blocks.0.attentions.0.to_k.bias', 'down_blocks.0.attentions.0.to_v.weight', 'down_blocks.0.attentions.0.to_v.bias', 'down_blocks.0.attentions.0.proj_attn.weight', 'down_blocks.0.attentions.0.proj_attn.bias', 'down_blocks.0.attentions.1.to_q.weight', 'down_blocks.0.attentions.1.to_q.bias', 'down_blocks.0.attentions.1.to_k.weight', 'down_blocks.0.attentions.1.to_k.bias', 'down_blocks.0.attentions.1.to_v.weight', 'down_blocks.0.attentions.1.to_v.bias', 'down_blocks.0.attentions.1.proj_attn.weight', 'down_blocks.0.attentions.1.proj_attn.bias', 'down_blocks.1.attentions.0.to_q.weight', 'down_blocks.1.attentions.0.to_q.bias', 'down_blocks.1.attentions.0.to_k.weight', 'down_blocks.1.attentions.0.to_k.bias', 'down_blocks.1.attentions.0.to_v.weight', 'down_blocks.1.attentions.0.to_v.bias', 'down_blocks.1.attentions.0.proj_attn.weight',

In [9]:
unet.load_state_dict(torch.load('/content/drive/MyDrive/GSET25/unet_sep_model_blood.pth'))
scheduler.load_state_dict(torch.load('/content/drive/MyDrive/GSET25/scheduler_sep_model_blood.pth'))
autoencoderkl.load_state_dict(torch.load('/content/drive/MyDrive/GSET25/autoencoder_sep_model_blood.pth'))

  unet.load_state_dict(torch.load('/content/drive/MyDrive/GSET25/unet_sep_model_blood.pth'))
  scheduler.load_state_dict(torch.load('/content/drive/MyDrive/GSET25/scheduler_sep_model_blood.pth'))
  autoencoderkl.load_state_dict(torch.load('/content/drive/MyDrive/GSET25/autoencoder_sep_model_blood.pth'))


<All keys matched successfully>

In [10]:
list_inter = []
start_time = time.time()
for i in range(100):
  unet.eval()
  scheduler.set_timesteps(num_inference_steps=1000)
  noise = torch.randn((1, 3, 16, 16))
  noise = noise.to(device)
  with torch.no_grad():
      image, intermediates = inferer.sample(
          input_noise=noise,
          diffusion_model=unet,
          scheduler=scheduler,
          save_intermediates=True,
          intermediate_steps=100,
          autoencoder_model=autoencoderkl,
      )
  list_inter.append(intermediates)
elapsed_time = time.time() - start_time
print(f"Total training time: {elapsed_time:.2f} seconds")

100%|██████████| 1000/1000 [00:21<00:00, 45.46it/s]
100%|██████████| 1000/1000 [00:23<00:00, 42.84it/s]
100%|██████████| 1000/1000 [00:23<00:00, 42.58it/s]
100%|██████████| 1000/1000 [00:24<00:00, 41.02it/s]
100%|██████████| 1000/1000 [00:24<00:00, 40.11it/s]
100%|██████████| 1000/1000 [00:21<00:00, 46.56it/s]
100%|██████████| 1000/1000 [00:22<00:00, 45.14it/s]
100%|██████████| 1000/1000 [00:21<00:00, 46.28it/s]
100%|██████████| 1000/1000 [00:20<00:00, 48.95it/s]
100%|██████████| 1000/1000 [00:21<00:00, 46.66it/s]
100%|██████████| 1000/1000 [00:20<00:00, 48.13it/s]
100%|██████████| 1000/1000 [00:22<00:00, 45.08it/s]
100%|██████████| 1000/1000 [00:19<00:00, 50.39it/s]
100%|██████████| 1000/1000 [00:20<00:00, 48.95it/s]
100%|██████████| 1000/1000 [00:21<00:00, 47.15it/s]
100%|██████████| 1000/1000 [00:20<00:00, 49.58it/s]
100%|██████████| 1000/1000 [00:20<00:00, 47.93it/s]
100%|██████████| 1000/1000 [00:19<00:00, 50.51it/s]
100%|██████████| 1000/1000 [00:21<00:00, 47.43it/s]
100%|███████

Total training time: 2077.71 seconds





In [11]:
for intermediates in list_inter:
  decoded_images = []
  for image in intermediates:
      with torch.no_grad():
          decoded_images.append(image)
  plt.figure(figsize=(10, 12))
  chain = torch.cat(decoded_images, dim=-1)
  plt.style.use("default")
  plt.imshow(chain[0, 0].cpu(), vmin=0, vmax=1)
  plt.tight_layout()
  plt.axis("off")

Output hidden; open in https://colab.research.google.com to view.

In [None]:
from PIL import Image
os.makedirs("src_hand_all", exist_ok=True)  # Đổi tên thư mục để phân biệt
selected_images = []
image_count = 0  # Thêm biến đếm số hình ảnh đã xử lý

for i, batch in enumerate(val_loader):
    images = batch['image']
    for j in range(images.size(0)):
        selected_images.append(images[j])
        image = images[j].squeeze(0)
        image = image.cpu().numpy()
        image = (image * 255).astype(np.uint8)
        pil_image = Image.fromarray(image, mode="L")
        pil_image.save(f"src_hand_all/src_hand_pic_{image_count}.jpg")
        image_count += 1  # Tăng biến đếm sau mỗi lần lưu thành công

# Không cần kiểm tra số lượng hình ảnh được chọn nữa


In [14]:
from PIL import Image
import torchvision.transforms as transforms
os.makedirs("src_blood_all", exist_ok=True)  # Đổi tên thư mục để phân biệt
selected_images = []
image_count = 0  # Thêm biến đếm số hình ảnh đã xử lý

for i, batch in enumerate(val_loader):
    images = batch
    for j in range(images.size(0)):
        image = intermediates[-1].squeeze(0)
        transform = transforms.ToPILImage()
        pil_image = transform(image)
        pil_image.save(f"src_blood_all/src_blood_pic_{image_count}.jpg")
        image_count += 1  # Tăng biến đếm sau mỗi lần lưu thành công

# Không cần kiểm tra số lượng hình ảnh được chọn nữa

In [None]:
os.makedirs("src_chest_100", exist_ok=True)
selected_images = []
for i, images in enumerate(val_loader):
    if len(selected_images) >= 10:
        break
    for j in range(images.size(0)):
        selected_images.append(images[j])
        if len(selected_images) == 10:
            break
for i, image in enumerate(selected_images):
    image = image.squeeze(0)
    image = image.cpu().numpy()
    image = (image * 255).astype(np.uint8)
    pil_image = Image.fromarray(image, mode="L")
    pil_image.save(f"src_chest_100/src_chest_pic_{i}.jpg")

In [13]:
from PIL import Image
import torchvision.transforms as transforms
os.makedirs("gen_blood_sep_100", exist_ok=True)
i= 0
for intermediates in list_inter:
    image = intermediates[-1]
    print("Original shape:", image.shape)
    image = intermediates[-1].squeeze(0)
    transform = transforms.ToPILImage()
    pil_image = transform(image)
    pil_image.save(f"gen_blood_sep_100/gen_sep_pic_{i}.jpg")
    i+=1

Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original shape: torch.Size([1, 3, 64, 64])
Original sh

In [None]:
from PIL import Image
os.makedirs("gen_blood_moon_100", exist_ok=True)
i = 0

for intermediates in list_inter:
    image = intermediates[-1]
    image = image.squeeze(0).squeeze(0)
    image = image.cpu().numpy()
    image = (image * 255).astype(np.uint8)
    pil_image = Image.fromarray(image, mode="L")
    pil_image.save(f"gen_blood_moon_100/gen_moon_pic_{i}.jpg")
    i += 1

ValueError: Too many dimensions: 3 > 2.

In [15]:
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from scipy import linalg
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
from PIL import Image
from INCEPTION import InceptionV3
import pathlib


In [16]:
IMAGE_EXTENSIONS = {'jpg'}


class ImagePathDataset(Dataset):
    def __init__(self, files, transform=None):
        self.files = files
        # Only include ToTensor if not already in the provided transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor()  # Default: Convert PIL image to tensor
            ])
        else:
            self.transform = transform  # Use provided transform as-is

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        path = self.files[i]
        img = Image.open(path).convert('RGB')  # Open the image as PIL
        img = self.transform(img)  # Apply transformations
        return img

In [17]:
def get_activations(files, model, batch_size, dims, device='cpu'):

    model.eval()

    if batch_size > len(files):
        batch_size = len(files)

    dataset = ImagePathDataset(files, transform=transforms.ToTensor())
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    pred_arr = np.empty((len(files), dims))
    start_idx = 0

    for batch in tqdm(data_loader):
        batch = batch.to(device)

        with torch.inference_mode():
            pred = model(batch)[0]

        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = F.adaptive_avg_pool2d(pred, output_size=(1, 1))

        pred = pred.squeeze(3).squeeze(2).cpu().numpy()
        pred_arr[start_idx:start_idx+pred.shape[0]] = pred
        start_idx = start_idx + pred.shape[0]

    return pred_arr

In [18]:
def calculate_frechet_distance(mu1, mu2, sigma1, sigma2, eps=1e-6):

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions'

    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces sigular product; adding %s to diagonal cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

In [19]:
def calculate_activation_statistics(files, model, batch_size, dims, device='cpu'):

    act = get_activations(files, model, batch_size, dims, device)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)

    return mu, sigma

def compute_statistics_of_path(path, model, batch_size, dims, device='cpu'):

    path = pathlib.Path(path)
    files = sorted([file for ext in IMAGE_EXTENSIONS for file in path.glob('*.{}'.format(ext))])
    mu, sigma = calculate_activation_statistics(files, model, batch_size, dims, device)

    return mu, sigma

def calculate_fid_given_paths(path1, path2, batch_size, dims, device='cpu'):

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    print(block_idx)
    model = InceptionV3([block_idx]).to(device)

    mu1, sigma1 = compute_statistics_of_path(path1, model, batch_size, dims, device)
    mu2, sigma2 = compute_statistics_of_path(path2, model, batch_size, dims, device)

    fid_value = calculate_frechet_distance(mu1, mu2, sigma1, sigma2)
    return print('FID distance:', round(fid_value, 3))

In [21]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 50
#BLOCK_INDEX_BY_DIM = {64: 0, 192: 1, 768: 2, 2048: 3}
dims = 768

src_path = os.getcwd() + '/src_blood_all'
gen_path = os.getcwd() + '/gen_blood_moon_100'

print('Total images in src_blood_all:', len(next(os.walk('src_blood_all'))[2]))
print('Total images in gen_blood_moon_100:', len(next(os.walk('gen_blood_moon_100'))[2]))

calculate_fid_given_paths(path1=src_path, path2=gen_path, batch_size=batch_size, dims=dims, device=device)

Total images in src_blood_all: 1712
Total images in gen_blood_moon_100: 100
2


100%|██████████| 35/35 [00:11<00:00,  3.11it/s]
100%|██████████| 2/2 [00:02<00:00,  1.13s/it]


FID distance: 0.951
