# Tải bộ dữ liệu

In [None]:

from google.colab import files
# Upload the Kaggle API token (kaggle.json) through the Colab upload widget;
# the next cell copies it from the working directory into ~/.kaggle/.
files.upload()

In [None]:
%%capture
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d vincenttamml/celebamaskhq512
!unzip celebamaskhq512

In [None]:
import os
import shutil
import random

# Directory holding the images and root directory that receives the split.
source_folder = "/content/image"  # replace with your source directory
output_folder = "/content/image_divice"  # replace with your destination directory

# Make sure the 10 destination sub-folders (folder_1 ... folder_10) exist.
num_folders = 10
for folder_index in range(num_folders):
    bucket_dir = os.path.join(output_folder, f"folder_{folder_index + 1}")
    os.makedirs(bucket_dir, exist_ok=True)

# Collect the regular files of the source directory (sub-directories are skipped).
files = []
for entry in os.listdir(source_folder):
    if os.path.isfile(os.path.join(source_folder, entry)):
        files.append(entry)

# Move every file into a destination folder whose number is drawn at random.
for entry in files:
    bucket = random.randint(1, num_folders)
    shutil.move(
        os.path.join(source_folder, entry),
        os.path.join(os.path.join(output_folder, f"folder_{bucket}"), entry),
    )

print("Ảnh đã được chia vào 10 thư mục!")


In [None]:
# Resize tất cả các ảnh trong thư mục về kích thước 512x512
import os
from PIL import Image

input_folder = '/content/image_divice/folder_2'  # change if needed
output_folder = '/content/image_resize_512_2'

# Create the output directory if it does not exist yet.
os.makedirs(output_folder, exist_ok=True)

# Resize every image of the input folder to 512x512 and store it under the same name.
for filename in os.listdir(input_folder):
    # Compare the extension case-insensitively so that e.g. "IMG_01.JPG" is not skipped.
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue
    image_path = os.path.join(input_folder, filename)
    # The context manager closes the source file once the resized copy exists,
    # instead of leaving one open handle per processed image.
    with Image.open(image_path) as image:
        resized = image.resize((512, 512), Image.Resampling.LANCZOS)
    resized.save(os.path.join(output_folder, filename))

print("Resize hoàn tất!")


# Chuẩn bị môi trường

In [None]:
%%capture
# Clone the MAT repository into the current working directory.
!git clone https://github.com/fenglinglwb/MAT.git

# Switch into the clone and install its Python requirements.
# NOTE(review): later cells import torch_utils, legacy, networks and losses as
# top-level modules, presumably relying on this working directory - confirm.
%cd MAT
!pip install -r requirements.txt

# Chuẩn bị dữ liệu

In [None]:
from random import randint, seed
import numpy as np
import cv2

class CustomMaskGenerator1():
    """Random irregular mask generator built from lines, filled circles and ellipse arcs.

    Pixels covered by a drawn shape are returned as 1 and the untouched
    background as 0, following the "hole as 0, reserved as 1" convention noted
    where the dataset consumes these masks.

    NOTE(review): because the *shapes* are the reserved region, a typical mask
    is mostly hole. Confirm this polarity is intended.

    Args:
        height: Mask height in pixels.
        width: Mask width in pixels.
        num_lines: Upper bound for the number of lines; the count is drawn from [1, num_lines].
        num_circles: Upper bound for the number of filled circles.
        num_elips: Upper bound for the number of ellipse arcs.
        rand_seed: When not None (0 included), seeds Python's global ``random`` module.
        hole_range: ``(low, high)`` bounds for the accepted fraction of zero pixels.
    """

    def __init__(self, height=512, width=512, num_lines=20, num_circles=20, num_elips=20, rand_seed=None, hole_range=(0, 1)):
        self.height = height
        self.width = width
        self.channels = 1  # Output should be single-channel (grayscale)
        self.num_lines = num_lines
        self.num_circles = num_circles
        self.num_elips = num_elips
        self.hole_range = hole_range  # Accepted (low, high) fraction of zero pixels

        # `is not None` so that a seed of 0 is honoured as well.
        if rand_seed is not None:
            seed(rand_seed)

    def _draw_shapes(self, img, size):
        """Draw one random set of lines, circles and ellipse arcs (value 255) onto ``img`` in place."""
        # Random lines
        for _ in range(randint(1, self.num_lines)):
            x1, x2 = randint(1, self.width), randint(1, self.width)
            y1, y2 = randint(1, self.height), randint(1, self.height)
            thickness = randint(3, size)
            cv2.line(img, (x1, y1), (x2, y2), 255, thickness)

        # Random filled circles
        for _ in range(randint(1, self.num_circles)):
            x1, y1 = randint(1, self.width), randint(1, self.height)
            radius = randint(3, size)
            cv2.circle(img, (x1, y1), radius, 255, -1)

        # Random ellipse arcs
        for _ in range(randint(1, self.num_elips)):
            x1, y1 = randint(1, self.width), randint(1, self.height)
            s1, s2 = randint(1, self.width), randint(1, self.height)
            a1, a2, a3 = randint(3, 180), randint(3, 180), randint(3, 180)
            thickness = randint(3, size)
            cv2.ellipse(img, (x1, y1), (s1, s2), a1, a2, a3, 255, thickness)

    def _generate_mask(self):
        """Generate a random irregular mask.

        Shapes are redrawn from scratch until the fraction of zero pixels lies
        inside ``hole_range``. A range that can never be met makes this loop
        forever, as in the original implementation.

        Returns:
            ``float32`` array of shape ``(1, height, width)`` holding 0 (hole)
            and 1 (reserved).
        """
        img = np.zeros((self.height, self.width), np.uint8)  # Single channel binary image

        # Upper bound for line thickness / circle radius. Clamped to 3 so that
        # randint(3, size) stays valid for small masks.
        size = max(3, int((self.width + self.height) * 0.03))
        low, high = self.hole_range[0], self.hole_range[1]

        while True:
            img.fill(0)  # Start every attempt from an empty canvas
            self._draw_shapes(img, size)
            hole_ratio = 1 - np.mean(img / 255)  # Fraction of pixels no shape touched
            if low <= hole_ratio <= high:
                break

        # Binarise to {0, 1}: returning the raw 0/255 canvas contradicts the
        # "reserved as 1" convention of the consumer.
        return (img > 0).astype(np.float32)[np.newaxis, ...]


In [None]:


import cv2
import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import dnnlib
import random
from random import randint, seed

try:
    import pyspng
except ImportError:
    pyspng = None

from MAT.datasets.dataset_512 import Dataset

# Module-level mask generator: 512x512 masks, at most 20 shapes of each kind,
# no seeding, and any hole ratio in [0, 1] accepted.
gen_mask = CustomMaskGenerator1(height=512, width=512, num_lines=20, num_circles=20, num_elips=20, rand_seed=None, hole_range=[0, 1])


class ImageFolderMaskDataset(Dataset):
    """Image dataset (directory or zip) that pairs every image with a random mask.

    Images are loaded as ``uint8`` CHW arrays, padded by reflection when smaller
    than 512 pixels on a side and randomly cropped to 512x512. Each item also
    carries a freshly generated mask whose hole ratio respects ``hole_range``.

    Args:
        path: Path to a directory or a ``.zip`` archive containing the images.
        resolution: When not None, the loaded images must have this height and width.
        hole_range: ``(low, high)`` hole-ratio bounds forwarded to the mask generator.
        **super_kwargs: Additional arguments for the ``Dataset`` base class.

    Raises:
        IOError: If ``path`` is neither a directory nor a zip, contains no image,
            or the images do not match ``resolution``.
    """

    # Side length of the square crop produced by _load_raw_image and of the masks.
    _RES = 512

    def __init__(self,
        path,                   # Path to directory or zip.
        resolution      = None, # Ensure specific resolution, None = highest available.
        hole_range      = (0, 1), # Accepted hole ratio of the generated masks.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None
        self._hole_range = hole_range
        # Per-dataset generator so that `hole_range` is actually honoured.
        # The remaining settings match the module-level generator used before.
        self._mask_generator = CustomMaskGenerator1(
            height=self._RES, width=self._RES,
            num_lines=20, num_circles=20, num_elips=20,
            rand_seed=None, hole_range=self._hole_range,
        )

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self._path)
                for root, _dirs, files in os.walk(self._path)
                for fname in files
            }
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()  # Populates PIL.Image.EXTENSION
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        # [N, C, H, W], with C/H/W taken from the first image.
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname`` including the leading dot."""
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        """Return the zip archive handle, opening it on first use."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to the dataset root) for binary reading."""
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        """Close the zip archive if one is open."""
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # The zip handle is dropped from the pickled state; _get_zipfile reopens it on demand.
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        """Load image ``raw_idx`` and return a random 512x512 crop as a contiguous CHW array."""
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC

        # Grayscale images are expanded to three identical channels.
        if image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)

        # Restricted to 512x512: pad by reflection at the bottom/right when the
        # image is too small, then take a random crop.
        res = self._RES
        H, W, C = image.shape
        if H < res or W < res:
            top = 0
            bottom = max(0, res - H)
            left = 0
            right = max(0, res - W)
            image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_REFLECT)
        H, W, C = image.shape
        h = random.randint(0, H - res)
        w = random.randint(0, W - res)
        image = image[h:h+res, w:w+res, :]

        image = np.ascontiguousarray(image.transpose(2, 0, 1)) # HWC => CHW

        return image

    def _load_raw_labels(self):
        """Load labels from ``labels.json`` if present; return None otherwise."""
        fname = 'labels.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        # Label keys use forward slashes regardless of the platform's path separator.
        labels = [labels[image_fname.replace('\\', '/')] for image_fname in self._image_fnames]
        labels = np.array(labels)
        # 1-D labels are stored as int64, 2-D labels as float32.
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels

    def __getitem__(self, idx):
        """Return ``(image, mask, label)`` for item ``idx``."""
        image = self._load_raw_image(self._raw_idx[idx])

        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        mask = self._mask_generator._generate_mask()  # hole as 0, reserved as 1
        return image.copy(), mask, self.get_label(idx)


In [None]:
import os
import time
import copy
import pickle
import numpy as np
import torch
from torch_utils import misc
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import legacy

from networks.mat import Generator, Discriminator
from losses.loss import TwoStageLoss  # Import lớp TwoStageLoss

# Wall-clock timestamp taken when this setup cell runs.
start_time = time.time()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Enable the cuDNN benchmark mode and TF32 for matmul / cuDNN operations.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# NOTE(review): the resize cell of this notebook writes to
# '/content/image_resize_512_2'. Confirm that this directory exists; otherwise
# ImageFolderMaskDataset raises IOError.
dataset_path = '/content/image_resize_512'
image_resolution=512
batch_size = 2

# Load dataset
print('Loading dataset...')
training_set = ImageFolderMaskDataset(path=dataset_path, resolution=image_resolution, hole_range=[0, 1])
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
print(f'Dataset size: {len(training_set)} images')

# Label dimensionality and channel count reported by the dataset; the network
# construction cell passes them to Generator / Discriminator.
c_dim = training_set.label_dim
img_channels = training_set.num_channels

In [None]:
# Same dataset setup as the previous cell, pointed at the output directory of
# the resize cell.
dataset_path = '/content/image_resize_512_2'
image_resolution=512
batch_size = 2

# Load dataset
print('Loading dataset...')
training_set = ImageFolderMaskDataset(path=dataset_path, resolution=image_resolution, hole_range=[0, 1])
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
print(f'Dataset size: {len(training_set)} images')

# Label dimensionality and channel count reported by the dataset; the network
# construction cell passes them to Generator / Discriminator.
c_dim = training_set.label_dim
img_channels = training_set.num_channels

# Fine-tune


In [None]:
# Named pretrained checkpoints that can serve as the fine-tuning starting point.
# NOTE(review): the URLs point at StyleGAN2-ADA transfer-learning source nets,
# while the networks built later come from networks.mat and are loaded with
# require_all=False - confirm that the weights actually match.
resume_specs = {
    'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
    'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
    'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
    'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
    'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
}

# Choose the pretrained model to fine-tune.
model_path = resume_specs['ffhq512']

# Loss settings, forwarded to TwoStageLoss in the network construction cell:
# pr -> pcp_ratio, pl -> switches pl_weight between 0 and 1.0, gamma -> r1_gamma.
pr=0.1
pl=False
gamma=5.0

num_epochs = 1

# Adam hyper-parameters shared by the generator and discriminator optimizers.
betas=[0, 0.99]
lr=0.0001

In [None]:
from torch_utils.misc import named_params_and_buffers

def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Overwrite the tensors of ``dst_module`` with the same-named tensors of ``src_module``.

    Every parameter/buffer of ``dst_module`` whose name also exists in
    ``src_module`` gets its ``.data`` replaced by a detached clone of the source
    tensor. Names missing from the source are skipped, unless ``require_all``
    is set, in which case the assertion fails.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    source_lookup = dict(named_params_and_buffers(src_module))
    for key, target in named_params_and_buffers(dst_module):
        if key not in source_lookup:
            assert not require_all
            continue
        target.data = source_lookup[key].detach().clone()
        target.requires_grad_(target.requires_grad)

In [None]:
# Show the label dimension, image resolution and channel count used to build the networks.
print(c_dim,image_resolution, img_channels)

In [None]:
# Number of batches the loader yields per epoch.
print(len(training_loader))

In [None]:
import gc
# Initialize networks
# G and G_ema are created on the default device here; all three networks are
# moved to `device` in the loading loop below.
G = Generator(z_dim=512, c_dim=c_dim, w_dim=512, img_resolution=image_resolution, img_channels=img_channels)
D = Discriminator(c_dim=c_dim,img_resolution=image_resolution, img_channels=img_channels).to(device)
# Copy of the generator kept in eval mode; the training loop updates it as a moving average of G.
G_ema = copy.deepcopy(G).eval()

# Initialize loss
# The fallbacks (10.0 / 0 / 0.0) apply only when the corresponding setting is None / falsy.
loss_instance = TwoStageLoss(
    device=device,
    G_mapping=G.mapping,
    G_synthesis=G.synthesis,
    D=D,
    r1_gamma=gamma if gamma is not None else 10.0,
    pl_weight=0 if not pl else 1.0,
    pcp_ratio=pr if pr is not None else 0.0
)

# Loading model
# Tensors are copied by name; names missing from the checkpoint are left at
# their initial values because require_all=False.
print(f'Loading pretrained model from "{model_path}"')
with dnnlib.util.open_url(model_path) as f:
    resume_data = legacy.load_network_pkl(f)
    for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
        copy_params_and_buffers(resume_data[name], module, require_all=False)
        module.to(device)

# Optimizers
G_opt = torch.optim.Adam(G.parameters(), lr=lr, betas=(betas[0], betas[1]))
D_opt = torch.optim.Adam(D.parameters(), lr=lr, betas=(betas[0], betas[1]))

# Define the number of epochs
num_epochs = 1

checkpoint_interval = 1  # Save a checkpoint every `checkpoint_interval` epochs

def training_loop(
    training_loader,
    D_opt,
    G_opt,
    loss_instance,
    G,
    G_ema
):
    """Run one pass over ``training_loader``, alternating discriminator and generator steps.

    For every batch the 'Dmain' phase is accumulated and stepped with ``D_opt``,
    then the 'Gmain' phase with ``G_opt``, and finally ``G_ema`` is updated as an
    exponential moving average of ``G``.

    Args:
        training_loader: Iterable yielding ``(real_imgs, masks, labels)`` batches;
            images are expected in the 0..255 range and are rescaled to [-1, 1].
        D_opt: Optimizer of the discriminator.
        G_opt: Optimizer of the generator.
        loss_instance: Object exposing ``accumulate_gradients(phase=..., ...)``.
        G: Generator being trained; must expose ``z_dim``.
        G_ema: Moving-average copy of ``G`` with the same parameter/buffer layout.
    """
    # Work on the device the generator lives on instead of relying on a global.
    device = next(G.parameters()).device
    ema_kimg = 10

    for step, (real_imgs, masks, labels) in enumerate(training_loader, start=1):
        print(step)
        real_imgs = real_imgs.to(device).float() / 127.5 - 1  # Normalize to [-1, 1]
        masks = masks.to(device)
        labels = labels.to(device)

        # Size the latents from the batch itself: the last batch of an epoch can
        # be smaller than the configured batch size.
        cur_batch = real_imgs.shape[0]
        gen_z = torch.randn([cur_batch, G.z_dim], device=device)
        gen_c = labels

        # Train Discriminator
        D_opt.zero_grad()
        loss_instance.accumulate_gradients(
            phase='Dmain',
            real_img=real_imgs,
            mask=masks,
            real_c=labels,
            gen_z=gen_z,
            gen_c=gen_c,
            sync=True,
            gain=1.0
        )
        D_opt.step()
        gc.collect()

        # Train Generator
        G_opt.zero_grad()
        loss_instance.accumulate_gradients(
            phase='Gmain',
            real_img=real_imgs,
            mask=masks,
            real_c=labels,
            gen_z=gen_z,
            gen_c=gen_c,
            sync=True,
            gain=1.0
        )
        G_opt.step()

        # Update EMA for G: parameters are interpolated towards G, buffers are
        # copied verbatim so that G_ema does not keep stale buffer values.
        ema_beta = 0.5 ** (cur_batch / (ema_kimg * 1000))
        with torch.no_grad():
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)

        gc.collect()


for epoch in range(num_epochs):
    print(f'Starting epoch {epoch + 1}...')
    gc.collect()

    # One full pass over the training data.
    training_loop(training_loader, D_opt, G_opt, loss_instance, G, G_ema)

    print(f'Epoch {epoch + 1} completed.')

    # Every `checkpoint_interval` epochs, write networks and optimizer states
    # to the current working directory.
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_path = f'checkpoint_epoch_{epoch + 1}.pt'
        checkpoint_state = dict(
            epoch=epoch,
            G_state_dict=G.state_dict(),
            D_state_dict=D.state_dict(),
            G_ema_state_dict=G_ema.state_dict(),
            G_opt_state_dict=G_opt.state_dict(),
            D_opt_state_dict=D_opt.state_dict(),
        )
        torch.save(checkpoint_state, checkpoint_path)

    torch.cuda.empty_cache()
print('Training completed.')

In [None]:
# File name of the checkpoint for the last epoch run (relies on `epoch` left over from the loop).
print(f'checkpoint_epoch_{epoch + 1}.pt')

In [None]:
for epoch in range(num_epochs):
    print(f'Starting epoch {epoch + 1}...')
    gc.collect()

    # One full pass over the training data.
    training_loop(training_loader, D_opt, G_opt, loss_instance, G, G_ema)

    print(f'Epoch {epoch + 1} completed.')

    # Every `checkpoint_interval` epochs, write networks and optimizer states
    # to /content.
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_path = f'/content/checkpoint_epoch_{epoch + 1}.pt'
        checkpoint_state = dict(
            epoch=epoch,
            G_state_dict=G.state_dict(),
            D_state_dict=D.state_dict(),
            G_ema_state_dict=G_ema.state_dict(),
            G_opt_state_dict=G_opt.state_dict(),
            D_opt_state_dict=D_opt.state_dict(),
        )
        torch.save(checkpoint_state, checkpoint_path)

    torch.cuda.empty_cache()
print('Training completed.')

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# After training has finished, store the model on Google Drive.

checkpoint_path = '/content/drive/MyDrive/checkpoint_epoch_final.pt'

torch.save({
    'epoch': num_epochs - 1,  # Index of the last epoch
    'G_state_dict': G.state_dict(),
    'D_state_dict': D.state_dict(),
    'G_ema_state_dict': G_ema.state_dict(),
    'G_opt_state_dict': G_opt.state_dict(),
    'D_opt_state_dict': D_opt.state_dict()
}, checkpoint_path)

print(f'Checkpoint đã được lưu tại {checkpoint_path}')


# Đánh giá FID

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
import shutil
import random

import cv2
import numpy as np

# Folder with the evaluation images and the Drive folders receiving the
# originals and their masks.
# NOTE(review): the checkpoint cell writes below /content/drive/MyDrive/ while
# these paths have no MyDrive component - confirm they are writable.
source_folder = "/content/image_divice/folder_5"  # replace with your source directory
origin_folder = "/content/drive/dataset/MAT/origin_image"  # destination for the original images
mask_folder = "/content/drive/dataset/MAT/mask_image"  # destination for the generated masks


gen_mask = CustomMaskGenerator1(height=512, width=512, num_lines=20, num_circles=20, num_elips=20, rand_seed=None, hole_range=[0, 1])


os.makedirs(origin_folder, exist_ok=True)
os.makedirs(mask_folder, exist_ok=True)
# Collect the regular files of the source folder.
files = [f for f in os.listdir(source_folder) if os.path.isfile(os.path.join(source_folder, f))]

# Move every image to the originals folder and store one random mask per image.
# The mask was previously generated but never written, leaving mask_folder empty.
for file in files:
    shutil.move(os.path.join(source_folder, file), os.path.join(origin_folder, file))
    mask = gen_mask._generate_mask()
    # Drop the leading channel axis and store the mask as a 0/255 image; the
    # `> 0` test makes this independent of the generator's value scale.
    mask_image = (np.squeeze(mask, axis=0) > 0).astype(np.uint8) * 255
    mask_name = os.path.splitext(file)[0] + '.png'
    cv2.imwrite(os.path.join(mask_folder, mask_name), mask_image)

print("Đã chuyển ảnh gốc và lưu mask tương ứng!")









In [None]:
import cv2
import os
import sys
sys.path.insert(0, '../')
import numpy as np
import math
import glob
import pyspng
import PIL.Image
import torch
import dnnlib
import scipy.linalg
import sklearn.svm


# Loaded detectors, keyed by (url, device), so each is fetched only once.
_feature_detector_cache = {}

def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Return the TorchScript feature detector stored at ``url``, loading it on first use.

    With several GPUs, rank 0 loads first while the other ranks wait at a
    barrier; rank 0 then reaches the barrier so the others can follow.
    """
    assert 0 <= rank < num_gpus
    cache_key = (url, device)
    if cache_key in _feature_detector_cache:
        return _feature_detector_cache[cache_key]

    leader = (rank == 0)
    multi_gpu = num_gpus > 1
    if multi_gpu and not leader:
        torch.distributed.barrier()  # wait for the leader
    with dnnlib.util.open_url(url, verbose=(verbose and leader)) as f:
        _feature_detector_cache[cache_key] = torch.jit.load(f).eval().to(device)
    if multi_gpu and leader:
        torch.distributed.barrier()  # release the other ranks
    return _feature_detector_cache[cache_key]


def read_image(image_path):
    """Read an image file into a ``uint8`` tensor of shape (1, C, H, W).

    ``.png`` files are decoded with pyspng when it is available, everything
    else with PIL. Single-channel images are expanded to three channels.
    """
    with open(image_path, 'rb') as handle:
        if pyspng is not None and image_path.endswith('.png'):
            pixels = pyspng.load(handle.read())
        else:
            pixels = np.array(PIL.Image.open(handle))

    # HW => HWC
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    # Grayscale => three identical channels
    if pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)

    chw = pixels.transpose(2, 0, 1)  # HWC => CHW
    return torch.from_numpy(chw).unsqueeze(0).to(torch.uint8)


class FeatureStats:
    """Accumulator for batches of feature vectors.

    Depending on the flags it keeps every appended batch (``capture_all``)
    and/or running sums from which mean and covariance are derived
    (``capture_mean_cov``). ``max_items`` caps the number of stored rows.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Fix the feature dimension on first call; later calls must agree with it."""
        if self.num_features is None:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
            return
        assert num_features == self.num_features

    def is_full(self):
        """Return True once ``max_items`` rows have been collected."""
        limit = self.max_items
        return (limit is not None) and (self.num_items >= limit)

    def append(self, x):
        """Add a 2-D batch (rows = items), truncating it so ``max_items`` is not exceeded."""
        batch = np.asarray(x, dtype=np.float32)
        assert batch.ndim == 2
        if self.max_items is not None:
            room = self.max_items - self.num_items
            if batch.shape[0] > room:
                if room <= 0:
                    return
                batch = batch[:room]

        self.set_num_features(batch.shape[1])
        self.num_items += batch.shape[0]
        if self.capture_all:
            self.all_features.append(batch)
        if self.capture_mean_cov:
            batch64 = batch.astype(np.float64)
            self.raw_mean += batch64.sum(axis=0)
            self.raw_cov += batch64.T @ batch64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Add a 2-D tensor; with several GPUs the batches of all ranks are gathered first."""
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            gathered = []
            for src in range(num_gpus):
                piece = x.clone()
                torch.distributed.broadcast(piece, src=src)
                gathered.append(piece)
            x = torch.stack(gathered, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return every stored row as one array (requires ``capture_all``)."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Return every stored row as a torch tensor."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Return ``(mean, covariance)`` derived from the running sums (requires ``capture_mean_cov``)."""
        assert self.capture_mean_cov
        count = self.num_items
        mean = self.raw_mean / count
        cov = self.raw_cov / count
        cov = cov - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle the instance attributes to ``pkl_file``."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Rebuild an instance from a file written by ``save``."""
        with open(pkl_file, 'rb') as f:
            state = dnnlib.EasyDict(pickle.load(f))
        restored = FeatureStats(capture_all=state.capture_all, max_items=state.max_items)
        restored.__dict__.update(state)
        return restored


def calculate_metrics(folder1, folder2):
    """Compute the FID between the images of two folders.

    Both folders must contain the same number of ``.png`` / ``.jpg`` / ``.jpeg``
    files. After sorting, the files are paired by position, and each pair must
    share the same name up to the first dot and the same image shape.

    Args:
        folder1: Folder with the first image set (e.g. inpainted results).
        folder2: Folder with the second image set (e.g. ground truth).

    Returns:
        The FID value as a real number.

    Raises:
        AssertionError: If the folders differ in file count, or a pair differs
            in name or shape.
        ValueError: If no images are found.
    """
    def _list_images(folder):
        # Sorted image paths of one folder; .jpeg is included because the
        # resize step of this notebook keeps that extension.
        return sorted(
            glob.glob(folder + '/*.png')
            + glob.glob(folder + '/*.jpg')
            + glob.glob(folder + '/*.jpeg')
        )

    l1 = _list_images(folder1)
    l2 = _list_images(folder2)
    assert(len(l1) == len(l2))
    if not l1:
        # Without this check the statistics below fail with an unrelated error.
        raise ValueError('No .png/.jpg/.jpeg images found in %r and %r' % (folder1, folder2))
    print('length:', len(l1))

    # build detector
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
    # Fall back to the CPU instead of failing when no GPU is present.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    detector = get_feature_detector(url=detector_url, device=device, num_gpus=1, rank=0, verbose=False)
    detector.eval()

    stat1 = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=len(l1))
    stat2 = FeatureStats(capture_all=True, capture_mean_cov=True, max_items=len(l1))

    with torch.no_grad():
        for i, (fpath1, fpath2) in enumerate(zip(l1, l2)):
            print(i)
            _, name1 = os.path.split(fpath1)
            _, name2 = os.path.split(fpath2)
            # Compare the part of the file name before the first dot.
            name1 = name1.split('.')[0]
            name2 = name2.split('.')[0]
            assert name1 == name2, 'Illegal mapping: %s, %s' % (name1, name2)

            img1 = read_image(fpath1).to(device)
            img2 = read_image(fpath2).to(device)
            assert img1.shape == img2.shape, 'Illegal shape'
            fea1 = detector(img1, **detector_kwargs)
            stat1.append_torch(fea1, num_gpus=1, rank=0)
            fea2 = detector(img2, **detector_kwargs)
            stat2.append_torch(fea2, num_gpus=1, rank=0)

    # calculate fid: ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    mu1, sigma1 = stat1.get_mean_cov()
    mu2, sigma2 = stat2.get_mean_cov()
    m = np.square(mu1 - mu2).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma1, sigma2), disp=False) # pylint: disable=no-member
    # sqrtm may return a complex matrix; only the real part is kept.
    fid = np.real(m + np.trace(sigma1 + sigma2 - s * 2))

    return fid


In [None]:
if __name__ == '__main__':
    # Placeholders: replace with the folder of inpainted results and the
    # folder of ground-truth images.
    folder1 = 'path to the inpainted result'
    folder2 = 'path to the gt'

    fid = calculate_metrics(folder1, folder2)
    # Format once, then both print the score and write it to a text file.
    summary = 'fid: %.4f' % (fid)
    print(summary)
    with open('fid_pids_uids.txt', 'w') as f:
        f.write(summary)