# Image discriminator

This notebook runs an attribute classifier for the 40 CelebA attributes on a dataset of images, in preparation for the SVM step of InterFaceGAN.

Given a proportion `p` (a percentage), the notebook runs the classifier on the dataset. For each of the 40 attributes, it saves the scores of all images in a JSON file, labels the images with the highest `p`% of scores +1 and those with the lowest `p`% of scores -1, and saves the resulting labels in a second JSON file.
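
The labeling rule above can be sketched on toy data as follows. This is a minimal illustration, not the notebook's actual function: `label_extremes` and the toy filenames and scores are hypothetical names introduced here.

```python
def label_extremes(scores, proportion):
    """Label the top `proportion`% of scores +1 and the bottom `proportion`% -1.

    `scores` maps a filename to a classifier score (hypothetical helper).
    """
    # Sort (filename, score) pairs from lowest to highest score
    ranked = sorted(scores.items(), key=lambda item: item[1])
    num_top = int(proportion * len(ranked) / 100)
    if num_top == 0:
        # Too few images for this proportion: no labels at all
        return {}
    labels = {name: -1 for name, _ in ranked[:num_top]}
    labels.update({name: 1 for name, _ in ranked[-num_top:]})
    return labels


# Toy scores for ten hypothetical images: img0 scores lowest, img9 highest
toy_scores = {f"img{i}.jpg": float(i) for i in range(10)}
toy_labels = label_extremes(toy_scores, proportion=20)
print(toy_labels)
# {'img0.jpg': -1, 'img1.jpg': -1, 'img8.jpg': 1, 'img9.jpg': 1}
```

With 10 images and `p = 20`, two images land in each class and the six images in the middle stay unlabeled.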

The classifier was provided by Gwilherm Lesné (pre-trained, based on EfficientNet).

The list of attributes is as follows:

<details>
  <summary>Click to expand</summary>
  
  | Attribute # | Name |
  |-------------|------|
  |0|`5_o_Clock_Shadow`|
  |1|`Arched_Eyebrows`|
  |2|`Attractive`|
  |3|`Bags_Under_Eyes`|
  |4|`Bald`|
  |5|`Bangs`|
  |6|`Big_Lips`|
  |7|`Big_Nose`|
  |8|`Black_Hair`|
  |9|`Blond_Hair`|
  |10|`Blurry`|
  |11|`Brown_Hair`|
  |12|`Bushy_Eyebrows`|
  |13|`Chubby`|
  |14|`Double_Chin`|
  |15|`Eyeglasses`|
  |16|`Goatee`|
  |17|`Gray_Hair`|
  |18|`Heavy_Makeup`|
  |19|`High_Cheekbones`|
  |20|`Male`|
  |21|`Mouth_Slightly_Open`|
  |22|`Mustache`|
  |23|`Narrow_Eyes`|
  |24|`No_Beard`|
  |25|`Oval_Face`|
  |26|`Pale_Skin`|
  |27|`Pointy_Nose`|
  |28|`Receding_Hairline`|
  |29|`Rosy_Cheeks`|
  |30|`Sideburns`|
  |31|`Smiling`|
  |32|`Straight_Hair`|
  |33|`Wavy_Hair`|
  |34|`Wearing_Earrings`|
  |35|`Wearing_Hat`|
  |36|`Wearing_Lipstick`|
  |37|`Wearing_Necklace`|
  |38|`Wearing_Necktie`|
  |39|`Young`|
</details>
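
Since the output files are named by attribute number (`att{N}_scores.json`, `att{N}_labels.json`), a lookup from attribute name to file name can be handy. The sketch below copies the table above into a list; `CELEBA_ATTRIBUTES` and `labels_filename` are hypothetical names, not part of the notebook.

```python
# Attribute names in the order of the table above (index = attribute number)
CELEBA_ATTRIBUTES = [
    "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes",
    "Bald", "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair",
    "Blurry", "Brown_Hair", "Bushy_Eyebrows", "Chubby", "Double_Chin",
    "Eyeglasses", "Goatee", "Gray_Hair", "Heavy_Makeup", "High_Cheekbones",
    "Male", "Mouth_Slightly_Open", "Mustache", "Narrow_Eyes", "No_Beard",
    "Oval_Face", "Pale_Skin", "Pointy_Nose", "Receding_Hairline",
    "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair", "Wavy_Hair",
    "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick",
    "Wearing_Necklace", "Wearing_Necktie", "Young",
]


def labels_filename(attribute_name):
    """Return the labels file name used by this notebook for an attribute."""
    index = CELEBA_ATTRIBUTES.index(attribute_name)  # raises ValueError if unknown
    return f"att{index}_labels.json"


print(labels_filename("Smiling"))  # att31_labels.json
```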

Update torchvision to fix a checksum mismatch (see https://github.com/pytorch/vision/issues/7744).

In [None]:
!pip install -U torchvision

Collecting torchvision
  Downloading torchvision-0.16.2-cp310-cp310-manylinux1_x86_64.whl (6.8 MB)
Collecting torch==2.1.2 (from torchvision)
  Downloading torch-2.1.2-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.1.2->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.1.2->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


**Settings**

In [None]:
proportion = 2 # Percentage of images to keep for each class

archive_path = '/content/drive/MyDrive/Projet_IIN/InterFaceGAN/test_dataset/test_UnlabeledImages.tar.gz' # Path to archive containing images to classify
dataset_path = '/content/test_UnlabeledImages' # Path to dataset after extraction from the archive
output_path = '/content/drive/MyDrive/Projet_IIN/InterFaceGAN/test_dataset/Labels' # Path where the output JSON files will be placed

In [None]:
!tar zxvf {archive_path}

The output stream was truncated and contains only the last 5000 lines.
test_UnlabeledImages/img4000.jpg
test_UnlabeledImages/img4001.jpg
test_UnlabeledImages/img4002.jpg
test_UnlabeledImages/img4003.jpg
test_UnlabeledImages/img4004.jpg
test_UnlabeledImages/img4005.jpg
test_UnlabeledImages/img4006.jpg
test_UnlabeledImages/img4007.jpg
test_UnlabeledImages/img4008.jpg
test_UnlabeledImages/img4009.jpg
test_UnlabeledImages/img4010.jpg
test_UnlabeledImages/img4011.jpg
test_UnlabeledImages/img4012.jpg
test_UnlabeledImages/img4013.jpg
test_UnlabeledImages/img4014.jpg
test_UnlabeledImages/img4015.jpg
test_UnlabeledImages/img4016.jpg
test_UnlabeledImages/img4017.jpg
test_UnlabeledImages/img4018.jpg
test_UnlabeledImages/img4019.jpg
test_UnlabeledImages/img4020.jpg
test_UnlabeledImages/img4021.jpg
test_UnlabeledImages/img4022.jpg
test_UnlabeledImages/img4023.jpg
test_UnlabeledImages/img4024.jpg
test_UnlabeledImages/img4025.jpg
test_UnlabeledImages/img4026.jpg
test_UnlabeledIma

Imports

In [None]:
from torchvision import models, transforms
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
from os.path import isfile
from tqdm import tqdm
import json
import argparse

## Downloading and loading the discriminator

In [None]:
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cuda


In [None]:
class EfficientNetB0(nn.Module):
    def __init__(self):
        super(EfficientNetB0, self).__init__()
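        # `pretrained=True` is deprecated since torchvision 0.13; request the same ImageNet weights explicitly
        self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)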
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))
        self.head = nn.Sequential(
            nn.Linear(1280, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),
            nn.Linear(1024 , 256),
            nn.BatchNorm1d(256),
            nn.Dropout(0.2),
            nn.Linear(256 , 40)
            )

    def forward(self,x):
        x = self.model.features(x)
        #-----#
        x = self.avgpool(x)
        x = x.view(x.shape[0], -1)
        x = self.head(x)
        return x


Downloading the weights of the discriminator:

In [None]:
!wget -nc https://github.com/GurvanR/GANSpace-Reimplementation/raw/main/atclas2.pt

--2024-01-12 14:41:27--  https://github.com/GurvanR/GANSpace-Reimplementation/raw/main/atclas2.pt
Resolving github.com (github.com)... 140.82.112.4
Connecting to github.com (github.com)|140.82.112.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/GurvanR/GANSpace-Reimplementation/main/atclas2.pt [following]
--2024-01-12 14:41:27--  https://raw.githubusercontent.com/GurvanR/GANSpace-Reimplementation/main/atclas2.pt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 27818833 (27M) [application/octet-stream]
Saving to: ‘atclas2.pt’


2024-01-12 14:41:28 (151 MB/s) - ‘atclas2.pt’ saved [27818833/27818833]



Loading the discriminator

In [None]:
# Loading Discriminator

net = EfficientNetB0()
# map_location lets the checkpoint load on CPU-only machines as well
net.load_state_dict(torch.load('atclas2.pt', map_location=device))
net = net.to(device)
net.eval()

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 109MB/s] 


EfficientNetB0(
  (model): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_

## Classification

In [None]:
def run_classification(dataset_path, output_path, proportion, batch_size=10):
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    if not os.path.isdir(output_path):
        raise ValueError("Output path must be a directory, not overwriting existing file")
    filenames = [f for f in os.listdir(dataset_path) if isfile(os.path.join(dataset_path, f))]
    filenames_and_scores = [[] for _ in range(40)]
    totensor = transforms.ToTensor()

    for i in tqdm(range(0, len(filenames), batch_size)):
        # Load the batch (the last one may hold fewer than batch_size images)
        batch_filenames = filenames[i:i+batch_size]
        # Assumes every image is a 1024x1024 RGB image
        batch = torch.zeros((len(batch_filenames), 3, 1024, 1024))
        for j, f in enumerate(batch_filenames):
            path = os.path.join(dataset_path, f)
            with Image.open(path) as img:
                # ToTensor yields values in [0, 1]; rescale them to [0, 255]
                batch[j] = 255*totensor(img)
        batch = batch.to(device)

        # Feed the batch to the network and store the results
        with torch.inference_mode():
            out = net(batch).cpu()
        for j, f in enumerate(batch_filenames):
            for att in range(40):
                filenames_and_scores[att].append((f, out[j, att].item()))

    # Store the results in JSON files
    for att in range(40):
        with open(os.path.join(output_path, f"att{att}_scores.json"), "w") as outfile:
            json.dump(dict(filenames_and_scores[att]), outfile)

        # Create two classes with the samples classified with the highest confidence
        ranked = sorted(filenames_and_scores[att], key=lambda p: p[1])
        num_top = int(proportion*len(ranked)/100)
        # Lowest scores get -1, highest scores get +1
        filenames_minus1 = [(p[0], -1) for p in ranked[:num_top]]
        # Index from the start: ranked[-num_top:] would select every image when num_top == 0
        filenames_plus1 = [(p[0], 1) for p in ranked[len(ranked)-num_top:]]
        with open(os.path.join(output_path, f"att{att}_labels.json"), "w") as outfile:
            json.dump(dict(filenames_minus1 + filenames_plus1), outfile)

In [None]:
run_classification(dataset_path, output_path, proportion)

100%|██████████| 1000/1000 [09:23<00:00,  1.78it/s]
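
Each `att{N}_labels.json` file maps an image filename to +1 or -1. As a rough sketch of how a later step might read one of these files back, the hypothetical helper `load_labels` below (not part of the notebook) returns parallel lists of filenames and labels. The demo writes a toy file to a temporary directory so that it runs without the real outputs.

```python
import json
import os
import tempfile


def load_labels(output_path, att):
    """Read att{att}_labels.json and return (filenames, labels) as parallel lists."""
    with open(os.path.join(output_path, f"att{att}_labels.json")) as infile:
        labels = json.load(infile)
    # Sort filenames so the order is reproducible across runs
    filenames = sorted(labels)
    return filenames, [labels[f] for f in filenames]


# Demo on a toy labels file (hypothetical content, same format as the notebook's output)
with tempfile.TemporaryDirectory() as tmp_dir:
    with open(os.path.join(tmp_dir, "att31_labels.json"), "w") as outfile:
        json.dump({"img2.jpg": 1, "img1.jpg": -1}, outfile)
    demo_filenames, demo_labels = load_labels(tmp_dir, 31)

print(demo_filenames, demo_labels)
# ['img1.jpg', 'img2.jpg'] [-1, 1]
```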
