In [1]:
# Switch the kernel's working directory to the GenerativeModels checkout.
# NOTE(review): absolute, user-specific path -- presumably needed so that the
# `generative` imports below resolve; adjust for your machine.
%cd /home/jdafflon/GenerativeModels

/mnt_homes/home4T7/jdafflon/GenerativeModels


In [2]:
from generative.metrics import FID
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``x`` with a fixed mean subtracted from channels 0, 1 and 2.

    The input tensor is left untouched: the subtraction is applied to a clone,
    so re-running a cell that calls this function does not shift the same
    data a second time.

    Args:
        x: tensor indexed as ``x[:, channel, :, :]`` with at least three
            channels along dimension 1 (an ``IndexError`` is raised otherwise).

    Returns:
        A new tensor of the same shape as ``x``. Channels beyond the third are
        copied unchanged.
    """
    # Per-channel means for channels 0, 1, 2.
    # NOTE(review): presumably given in BGR order to match the caller's
    # RGB -> BGR reordering -- confirm against the model's training recipe.
    mean = (0.406, 0.456, 0.485)

    # Work on a copy so the caller's tensor is never modified in place.
    shifted = x.clone()
    for channel, channel_mean in enumerate(mean):
        shifted[:, channel, :, :] -= channel_mean
    return shifted

def normalize_tensor(x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Scale ``x`` by its L2 norm computed along dimension 1.

    Args:
        x: tensor to normalise; the norm is taken over dimension 1.
        eps: small constant added to the norm to avoid division by zero.

    Returns:
        Tensor of the same shape as ``x``, divided by ``norm + eps``.
    """
    squared_sum = x.pow(2).sum(dim=1, keepdim=True)
    l2_norm = squared_sum.sqrt()
    return x / (l2_norm + eps)

def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """Average ``x`` over dimensions 2 and 3.

    Args:
        x: tensor with at least four dimensions.
        keepdim: if True, the averaged dimensions are kept with size 1;
            otherwise they are dropped.

    Returns:
        The mean of ``x`` over dimensions 2 and 3.
    """
    averaged_dims = (2, 3)
    return torch.mean(x, dim=averaged_dims, keepdim=keepdim)

def get_features(image):
    """Extract one feature vector per image with the global ``radnet`` model.

    Processing steps:
      1. tile the channel dimension when the input has fewer than 3 channels,
      2. pick channels ``[2, 1, 0]`` (RGB -> BGR),
      3. subtract the per-channel means via ``subtract_mean``,
      4. run ``radnet`` without gradient tracking,
      5. average the model output over dimensions 2 and 3.

    Channel-wise normalisation (``normalize_tensor``) is not applied.

    Args:
        image: 4-D tensor with channels along dimension 1. It is expected to
            live on the same device as ``radnet`` (no transfer is done here).

    Returns:
        The ``radnet`` output averaged over dimensions 2 and 3 (those
        dimensions are dropped).
    """
    # Tile only when needed. The previous check (`if image.shape[1]:`) was
    # always true, so 3-channel batches were expanded to 9 channels before the
    # first three were selected -- same result, three times the memory.
    # NOTE(review): a 2-channel input still ends up as (c0, c1, c0) after the
    # reorder below; kept for backward compatibility, but it is probably not a
    # meaningful input for this model.
    if image.shape[1] < 3:
        image = image.repeat(1, 3, 1, 1)

    # Change order from 'RGB' to 'BGR'. Advanced indexing returns a new
    # tensor, so the caller's `image` is not affected by the steps below.
    image = image[:, [2, 1, 0], ...]

    # Subtract mean used during training
    image = subtract_mean(image)

    # Get model outputs
    with torch.no_grad():
        # Call the module itself rather than `.forward` so registered hooks run.
        feature_image = radnet(image)
        # Collapse dimensions 2 and 3 into their mean.
        feature_image = spatial_average(feature_image, keepdim=False)

    return feature_image

In [4]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

Using cuda


In [5]:
# Load the RadImageNet ResNet50 through torch.hub and place it on `device`.
radnet = torch.hub.load(
    "Warvito/radimagenet-models",
    model="radimagenet_resnet50",
    verbose=True,
)
radnet = radnet.to(device)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the model's repr.
radnet.eval()

Using cache found in /home/jdafflon/.cache/torch/hub/Warvito_radimagenet-models_main


ResNet50(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): BatchNorm2d(64, eps=1.001e-05, momentum=0.01, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(64, eps=1.001e-05, momentum=0.99, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1.001e-05, momentum=0.99, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      (bn3): BatchNorm2d(256, eps=1.001e-05, momentum=0.99, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm

In [11]:
# Two identical constant batches (every value 0.5), built and moved to `device`.
image1 = (torch.ones([1005, 2, 64, 64]) / 2).to(device)
image2 = (torch.ones([1005, 2, 64, 64]) / 2).to(device)

# Feature vectors for each batch.
features_image_1 = get_features(image1)
features_image_2 = get_features(image2)

# FID between the two feature sets.
fid = FID()
results = fid(features_image_1, features_image_2)
print(results)

tensor(-0.0041, device='cuda:0')


In [9]:
# Two constant batches with different intensities (1/2 vs 1/3), moved to `device`.
image1 = (torch.ones([3, 3, 144, 144]) / 2).to(device)
image2 = (torch.ones([3, 3, 144, 144]) / 3).to(device)

# Feature vectors for each batch.
features_image_1 = get_features(image1)
features_image_2 = get_features(image2)

# FID between the two feature sets.
fid = FID()
results = fid(features_image_1, features_image_2)
print(results)

tensor(7.2366, device='cuda:0')


In [10]:
# Show the shape of both feature tensors on a single line.
print(*(features.shape for features in (features_image_1, features_image_2)))

torch.Size([3, 2048]) torch.Size([3, 2048])
