In [None]:
!git clone https://github.com/cydonia999/VGGFace2-pytorch.git

In [None]:
# Detect whether the notebook runs on Google Colab; later cells choose data paths from IN_COLAB.
try:
    from google.colab import drive  # only importable inside Colab
    IN_COLAB = True
    print("Running on Google Colab. ")
    # NOTE(review): `drive` is imported but no drive.mount('/content/drive') call is visible
    # in this notebook; the Colab paths below assume Drive is already mounted.
except ImportError:
    # Catch only ImportError (incl. ModuleNotFoundError): a bare `except:` would also hide
    # KeyboardInterrupt/SystemExit and unrelated errors raised while importing.
    IN_COLAB = False
    print("Not running on Google Colab. ")

## Helper: build the image-list file used for NN2 evaluation

In [None]:
import os
def create_image_list_file(root_dir, output_file, ext = '.jpg'):

    image_paths = []

    for class_id in os.listdir(root_dir):
        class_dir = os.path.join(root_dir, class_id)
        
        if os.path.isdir(class_dir):

            for filename in os.listdir(class_dir):
                
                if filename.endswith(ext):  
                    image_path = f"{os.path.basename(root_dir)}/{class_id}/{filename}"  
                    image_paths.append(image_path)

    with open(output_file, 'w') as f:
        for image_path in image_paths:
            f.write(image_path + '\n')

    print(f"File di output creato con successo: {output_file}")

## Black-box evaluation of NN2 using the evaluation code from the network's original repository

In [None]:
import torch
from VGGFace2_pytorch.models import senet as SENet
from VGGFace2_pytorch.models.resnet import resnet50 as ResNet
from VGGFace2_pytorch import utils
from VGGFace2_pytorch.trainer import Validator
from torch.utils.data import DataLoader
from VGGFace2_pytorch.datasets.vgg_face2 import VGG_Faces2

import os
from torch.nn.modules.loss import CrossEntropyLoss

root_dir = "face_dataset/test_set_MTCNN_NN2" 
output_file = 'image_list_file_NN2.txt'
create_image_list_file(root_dir, output_file)
meta_file = ".\\face_dataset\identity_meta.csv"
id_label_dict = utils.get_id_label_map(meta_file)
model = SENet.senet50(num_classes = 8631, include_top = True)
utils.load_state_dict(model, ".\in_progress\senet50_ft_weight.pkl")
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)
print("device: ", device)
#import torchsummary
#torchsummary.summary(model, (3, 224, 224))


val_dataset = VGG_Faces2(".//face_dataset", output_file, id_label_dict, split = 'valid')
val_loader = DataLoader(val_dataset, batch_size = 1)


"""
for batch_idx, (imgs, target, img_files, class_ids) in enumerate(val_loader):
    imgs = imgs.to(device)
    target = target.to(device)
    with torch.no_grad():
        output = model(imgs)
        print("Target: ", target)
        print("Class ids: ", class_ids)
        pred = torch.argmax(output, dim = 1)
        print("Predictions: ", pred)
        break
"""

validator = Validator(
            cmd = "test",
            cuda = True,
            model = model,
            criterion = CrossEntropyLoss(),
            val_loader = val_loader,
            log_file = "./log_file",
            print_freq = 1000,
        )

accuracy = validator.validate()


In [None]:
# Report the value returned by validator.validate() in the evaluation cell.
accuracy_report = ("Accuracy: ", accuracy)
print(*accuracy_report)

In [None]:
from PIL import Image
from torchvision import transforms 
def make_inference_NN2(model, img_path,  id_label_dict, device, temp_file = 'temp.txt',
                       dataset_root = ".//face_dataset"):
    """Run the repository's Validator on a single image and return its inference result.

    VGG_Faces2 reads images from a list file, so ``img_path`` is written to a
    temporary one-line list file that is always removed afterwards, even on error.

    Args:
        model: network to evaluate.
        img_path: image path as it should appear in the list file; presumably resolved
            by VGG_Faces2 relative to ``dataset_root`` (TODO confirm).
        id_label_dict: identity-id -> label map (as built by ``utils.get_id_label_map``).
        device: device string or ``torch.device``; CUDA is enabled when it names "cuda".
        temp_file: path of the temporary list file.
        dataset_root: root directory passed to VGG_Faces2.

    Returns:
        Whatever ``Validator.validate(make_inference=True)`` returns.
    """
    # Create the temporary .txt list file containing only this image.
    with open(temp_file, 'w') as f:
        f.write(img_path)
    try:
        val_dataset = VGG_Faces2(dataset_root, temp_file, id_label_dict, split = 'valid')
        val_loader = DataLoader(val_dataset, batch_size = 1)
        validator = Validator(
                cmd = "test",
                # str() so a torch.device also works: `"cuda" in torch.device(...)` raises TypeError.
                cuda = "cuda" in str(device),
                model = model,
                criterion = CrossEntropyLoss(),
                val_loader = val_loader,
                log_file = "./log_file",
                print_freq = 1000,
            )
        return validator.validate(make_inference = True)
    finally:
        # Remove the temp list file even if dataset creation or validation fails.
        os.remove(temp_file)


In [None]:
# Smoke test: classify one image and compare it with its ground-truth identity label.
# NOTE(review): the Colab path assumes Google Drive is already mounted at /content/drive;
# no drive.mount() call is visible in this notebook.
img_path = (
    "/content/drive/Shareddrives/AI4CYBSEC/test_set_MTCNN_NN2/n000017/0082_01.jpg"
    if IN_COLAB
    else "test_set_MTCNN_NN2/n000017/0082_01.jpg"
)
attack, network = "PGD", "NN2"
temp_file = f"temp_{attack}_{network}.txt"

pred = make_inference_NN2(model, img_path, id_label_dict, device, temp_file=temp_file)
# The identity id is the name of the folder that contains the image.
true_class_id = os.path.basename(os.path.dirname(img_path))
print("Prediction: ", pred)
print("True class: ", id_label_dict[true_class_id])
