In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.io import read_image
import numpy as np
from torch.utils.data import DataLoader
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
import load_dataset
import common_functions
from time import time
from tqdm import tqdm
import unet_model
import matplotlib.pyplot as plt
from pytorch_msssim import MS_SSIM, ms_ssim, SSIM, ssim
import unet_model_v2
import torch.nn.functional as F

In [None]:
class MS_SSIM_Loss(MS_SSIM):
    """Loss wrapper around ``MS_SSIM``: ``75 * (1 - MS_SSIM(img1, img2))``.

    The constructor is inherited unchanged from ``MS_SSIM``; only ``forward``
    is overridden so that a higher similarity score yields a lower loss.
    """

    def forward(self, img1, img2):
        """Return the scaled dissimilarity between ``img1`` and ``img2``."""
        similarity = super().forward(img1, img2)
        return 75*( 1 - similarity )

    
class MLP(nn.Module):
    """Small regression head applied to an encoder feature map.

    Pipeline: ``AvgPool2d(28)`` -> flatten -> ``Linear(256, 128)`` -> ReLU ->
    ``Linear(128, 2)``. The linear layer sizes imply the pooled and flattened
    input must have 256 features per sample.
    """

    def __init__(self):
        super().__init__()
        self.pooling = nn.AvgPool2d(28)
        self.layer1 = nn.Linear(256, 128)
        self.layer2 = nn.Linear(128, 2)

    def forward(self, x):
        """Pool and flatten ``x``, then map it to 2 output values per sample."""
        pooled = self.pooling(x)
        flat = nn.Flatten()(pooled)
        hidden = F.relu(self.layer1(flat))
        return self.layer2(hidden)


# --- Runtime configuration ---------------------------------------------------
# Use the first visible GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
NUM_WORKERS = 4
BATCHSIZE = 32

# --- Datasets ----------------------------------------------------------------
GazeCapture = load_dataset.GazeCapture()
size = len(GazeCapture)
train_size = int(size * 0.8)      # 80 % of GazeCapture for training
test_size = size - train_size     # remaining samples form the held-out split
mpii_data = load_dataset.MPIIDataset()

# Fixed-seed generator so the train/test partition is the same on every run.
split_rng = torch.Generator().manual_seed(0)
train_dataset, test_dataset = torch.utils.data.random_split(
    GazeCapture,
    [train_size, test_size],
    generator=split_rng,
)

# --- Data loaders ------------------------------------------------------------
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCHSIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
# Note: evaluation iterates over the MPII dataset, not over `test_dataset`.
test_dataloader = DataLoader(
    mpii_data,
    batch_size=BATCHSIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# --- Trainable autoencoder parts ---------------------------------------------
# Two encoders (one for the raw input, one for the anchor face) feed a single
# decoder; `gaze_mlp` is a small head applied to the raw encoder's output.
Encoder_raw = unet_model_v2.UNetEncoder(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4).to(device)
Encoder_anchor = unet_model_v2.UNetEncoder(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4).to(device)
Decoder = unet_model_v2.UNetDecoder(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4).to(device)
gaze_mlp = MLP().to(device)

# --- Gaze estimator used inside the training loss -----------------------------
# ResNet-18 with a 2-output head, restored from a checkpoint.
# `map_location=device` lets a checkpoint that was saved on a GPU load even when
# `device` has fallen back to the CPU; without it torch.load tries to restore the
# tensors onto the device they were saved from.
gaze_estimator = resnet18().to(device)
gaze_estimator.fc = nn.Linear(512, 2).to(device)
gaze_estimator.load_state_dict(
    torch.load("/data/volume_2/GazePrivacyModelsV2/PretrainRes18GazeCapture/model.pt", map_location=device))


# --- Identity classifier ------------------------------------------------------
# ResNet-18 with a 15-way head. No checkpoint is loaded for it in this section.
identity_net = resnet18().to(device)
identity_net.fc = nn.Linear(512, 15).to(device)

# --- Losses -------------------------------------------------------------------
loss_reconstruction = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3)
loss_gaze = nn.L1Loss()
loss_idenetity = nn.CrossEntropyLoss()

# --- Optimizers ---------------------------------------------------------------
# optimizer_ae jointly updates both encoders, the decoder and the gaze MLP;
# optimizer_id updates only the identity classifier.
optimizer_ae = torch.optim.Adam(list(Encoder_raw.parameters())+list(Encoder_anchor.parameters())
                                + list(Decoder.parameters()) + list(gaze_mlp.parameters()), lr=1e-3)
optimizer_id = torch.optim.Adam(identity_net.parameters(), lr=5e-3)
# NOTE(review): optimizer_de is not referenced by any function visible in this notebook section.
optimizer_de = torch.optim.Adam(Decoder.parameters(), lr=1e-3)

# --- Anchor image -------------------------------------------------------------
# Read the anchor face, scale pixel values by 1/255, add a batch dimension and
# resize to 224x224.
anchor_image = torch.unsqueeze(read_image("/data/volume_2/Gaze_privacy_v2/Gaze_privacy/gazecapture_average_face_efficientnet.png")/255., 0).to(device)
anchor_image = F.interpolate(anchor_image, size=(224, 224))

# --- Additional gaze estimators used only for evaluation ----------------------
# Each backbone gets a 2-output head and is then restored from its own checkpoint.
# `map_location=device` lets checkpoints that were saved on a GPU load even when
# `device` has fallen back to the CPU.
# Where `pretrained=True` is requested, the checkpoint loaded right afterwards
# replaces those weights.

# VGG-11
gaze_model_vgg = torch.hub.load('pytorch/vision:v0.10.0', 'vgg11', pretrained=False).to(device)
gaze_model_vgg.classifier._modules['6'] = nn.Linear(4096, 2).to(device)
gaze_model_vgg.load_state_dict(
    torch.load("/data/volume_2/GazePrivacyModels/PretrainVGG11XGaze/model.pt", map_location=device))

# ResNet-18
gaze_model_res18 = resnet18().to(device)
gaze_model_res18.fc = nn.Linear(512, 2).to(device)
gaze_model_res18.load_state_dict(
    torch.load("/data/volume_2/GazePrivacyModels/PretrainRes18XGaze/model.pt", map_location=device))

# MobileNetV2
gaze_model_mobilenet = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True).to(device)
gaze_model_mobilenet.classifier[1] = nn.Linear(1280, 2).to(device)
gaze_model_mobilenet.load_state_dict(
    torch.load("/data/volume_2/GazePrivacyModels/PretrainMobileNetV2XGaze/model.pt", map_location=device))

# EfficientNet-B0 (NVIDIA hub implementation)
gaze_model_efficientnet = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b0', pretrained=True).to(device)
gaze_model_efficientnet.classifier.fc = nn.Linear(1280, 2).to(device)
gaze_model_efficientnet.load_state_dict(
    torch.load("/data/volume_2/GazePrivacyModels/PretrainEffiNetXGaze/model.pt", map_location=device))

# ShuffleNetV2 x1.0
gaze_model_shufflenet = torch.hub.load('pytorch/vision:v0.10.0', 'shufflenet_v2_x1_0', pretrained=True).to(device)
gaze_model_shufflenet.fc = nn.Linear(1024, 2).to(device)
gaze_model_shufflenet.load_state_dict(
    torch.load("/data/volume_2/GazePrivacyModels/PretrainShuffleNetXGaze/model.pt", map_location=device))

In [None]:
def ID_train(dataloader, ID):
    """Train the identity classifier ``ID`` on reconstructed images for one pass.

    Every batch is pushed through the frozen ``Encoder_raw`` / ``Encoder_anchor``
    / ``Decoder`` pipeline and ``ID`` is optimised with cross-entropy to predict
    the identity label from the reconstruction.

    Args:
        dataloader: yields ``(X, y)`` batches where ``y[:, 2]`` holds the
            identity class index (``y[:, 0:2]`` is not used here).
        ID: the classifier to train. Updates are applied through the
            module-level ``optimizer_id``, which was built over
            ``identity_net``'s parameters, so ``ID`` is expected to be
            ``identity_net``.

    Every 300 batches the progress, the accuracy on the current batch and the
    first reconstructed image of the batch are shown.
    """
    size = len(dataloader.dataset)
    Encoder_raw.eval()
    Encoder_anchor.eval()
    Decoder.eval()
    ID.train()

    # The anchor encoder is in eval mode and is not updated in this function, so
    # the anchor features are identical for every batch: compute them once.
    with torch.no_grad():
        _, anchor_feats = Encoder_anchor(anchor_image)

    for batch, (X, y) in enumerate(tqdm(dataloader)):
        X = X.to(device)
        y_id = y[:, 2].type(torch.LongTensor).to(device)
        batch_len = y_id.shape[0]

        # The autoencoder is frozen here. Running it without autograd keeps
        # loss.backward() from depositing gradients in the encoder/decoder
        # parameters; optimizer_id.zero_grad() would never clear those, and they
        # would be picked up by a later optimizer_ae.step().
        with torch.no_grad():
            representation, _ = Encoder_raw(X)
            anchor_mid = [torch.tile(anchor_feats[i], (batch_len, 1, 1, 1)) for i in range(4)]
            reconstructed_img = Decoder(representation, anchor_mid)

        # Clear gradients before the backward pass so the step never uses stale ones.
        optimizer_id.zero_grad()
        id_pre = ID(reconstructed_img)
        loss = loss_idenetity(id_pre, y_id)
        loss.backward()
        optimizer_id.step()

        if batch % 300 == 0:
            current = (batch + 1) * len(X)
            print(f"[{current:>5d}/{size:>5d}]")

            # Monitoring only: re-evaluate the batch after the update without
            # building an autograd graph (ID is still in train mode, as before).
            with torch.no_grad():
                id_pre = ID(reconstructed_img)
            _, predictions = torch.max(id_pre, 1)
            correct = (predictions == y_id).sum().item()
            # Divide by the real batch length; the final batch can be smaller
            # than BATCHSIZE and the loader may use a different batch size.
            print("id acc", correct / batch_len)
            img = reconstructed_img.detach().cpu().numpy()[0, :, :, :]
            img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for matplotlib
            plt.imshow(img)
            plt.show()

In [None]:
def test(dataloader, Gaze, ID):
    """Evaluate gaze error and identity accuracy on reconstructed images.

    Each batch is reconstructed with ``Encoder_raw`` / ``Encoder_anchor`` /
    ``Decoder``; ``Gaze`` and ``ID`` are then applied to the reconstruction.
    Prints the sample-weighted mean of ``common_functions.avg_angle_error`` and
    the fraction of correctly classified identities.

    Args:
        dataloader: yields ``(X, y)`` batches; ``y[:, 0:2]`` is the gaze label
            and ``y[:, 2]`` the identity class index.
        Gaze: gaze estimator applied to the reconstructed images.
        ID: identity classifier applied to the reconstructed images.

    All five networks are evaluated in eval mode. Their previous train/eval
    flags are restored before returning, so calling this function in the middle
    of a training loop does not leave the trained networks stuck in eval mode.
    """
    size = len(dataloader.dataset)

    networks = (Gaze, Encoder_raw, Encoder_anchor, Decoder, ID)
    # Remember the mode of every sub-module so it can be restored exactly.
    saved_modes = [(sub, sub.training) for net in networks for sub in net.modules()]
    for net in networks:
        net.eval()

    test_loss = 0
    correct = 0
    try:
        with torch.no_grad():
            # In eval mode and without updates the anchor features are the same
            # for every batch, so encode the anchor image only once.
            _, anchor_feats = Encoder_anchor(anchor_image)

            for X, y in dataloader:
                X, y, y_id = X.to(device), y[:, 0:2].to(device), y[:, 2]
                y_id = y_id.type(torch.LongTensor).to(device)
                batch_len = y_id.shape[0]

                representation, _ = Encoder_raw(X)
                anchor_mid = [torch.tile(anchor_feats[i], (batch_len, 1, 1, 1)) for i in range(4)]
                reconstructed_img = Decoder(representation, anchor_mid)

                gaze_pre = Gaze(reconstructed_img)
                # avg_angle_error is a batch mean; weight by batch length so the
                # final figure is a per-sample mean even with a short last batch.
                test_loss += common_functions.avg_angle_error(gaze_pre, y).item() * y.shape[0]

                id_pre = ID(reconstructed_img)
                _, predictions = torch.max(id_pre, 1)
                correct += (predictions == y_id).sum().item()
    finally:
        for sub, was_training in saved_modes:
            sub.training = was_training

    test_loss /= size
    print("angular error on testing set:", test_loss, "id acc", correct / size)

In [None]:
def pre_train_ae(dataloader, Gaze, ID):
    """Train the encoders, decoder and gaze MLP with ``optimizer_ae``.

    For every batch the loss is the sum of
      * L1 gaze loss of ``Gaze`` applied to the reconstructed image,
      * L1 gaze loss of ``gaze_mlp`` applied to the raw encoder representation,
      * MS-SSIM loss between the reconstruction and the tiled anchor image.

    Args:
        dataloader: yields ``(X, y)`` batches; ``y[:, 0:2]`` is the gaze label
            and ``y[:, 2]`` the identity class index.
        Gaze: gaze estimator used in the loss; kept in eval mode and not
            updated by ``optimizer_ae``.
        ID: identity classifier, used for monitoring only (eval mode).

    Every 300 batches diagnostics and the first reconstructed image are shown.
    At batch 300 and at every non-zero multiple of 3000 the module-level
    ``gaze_estimator`` and ``gaze_model_efficientnet`` are evaluated on
    ``test_dataloader``. Training stops after batch index 12000.
    """
    size = len(dataloader.dataset)

    def _enter_training_mode():
        # Trained networks in train mode, auxiliary networks in eval mode.
        Encoder_raw.train()
        Encoder_anchor.train()
        Decoder.train()
        ID.eval()
        Gaze.eval()

    _enter_training_mode()

    for batch, (X, y) in enumerate(tqdm(dataloader)):
        X, y, y_id = X.to(device), y[:, 0:2].to(device), y[:, 2]
        y_id = y_id.type(torch.LongTensor).to(device)
        batch_len = y_id.shape[0]

        representation, _ = Encoder_raw(X)
        _, anchor_mid = Encoder_anchor(anchor_image)
        # Repeat the single-image anchor features along the batch dimension.
        anchor_mid = [torch.tile(anchor_mid[i], (batch_len, 1, 1, 1)) for i in range(4)]
        reconstructed_img = Decoder(representation, anchor_mid)

        gaze_pre = Gaze(reconstructed_img)
        gaze_pre_mlp = gaze_mlp(representation)
        loss = loss_gaze(gaze_pre, y) + loss_gaze(gaze_pre_mlp, y) + \
               loss_reconstruction(reconstructed_img, torch.tile(anchor_image, (batch_len, 1, 1, 1)))

        # Clear gradients first so that anything left over from an earlier
        # backward pass cannot leak into this update.
        optimizer_ae.zero_grad()
        loss.backward()
        optimizer_ae.step()

        if batch % 300 == 0:
            current = (batch + 1) * len(X)
            print(f"[{current:>5d}/{size:>5d}]")

            # Diagnostics only; no autograd graph is needed.
            with torch.no_grad():
                # Logged against the input X, whereas the training term above
                # targets the anchor image.
                loss_recon = loss_reconstruction(reconstructed_img, X)
                gaze_pre = Gaze(reconstructed_img)
                gaze_pre_mlp = gaze_mlp(representation)
                angle_error = common_functions.avg_angle_error(gaze_pre, y)
                angle_error_mlp = common_functions.avg_angle_error(gaze_pre_mlp, y)
                id_pre = ID(reconstructed_img)
            _, predictions = torch.max(id_pre, 1)
            correct = (predictions == y_id).sum().item()
            # Divide by the real batch length; the final batch can be smaller
            # than BATCHSIZE and the loader may use a different batch size.
            print("reconstruction loss", loss_recon.item(), "gaze loss", angle_error.item(),
                  "gaze loss mlp", angle_error_mlp.item(),
                  "id acc", correct / batch_len)
            img = reconstructed_img.detach().cpu().numpy()[0, :, :, :]
            img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for matplotlib
            plt.imshow(img)
            plt.show()
        if (batch % 3000 == 0 and batch!=0) or batch == 300:
            print("testing performance on whitebox gaze model")
            test(test_dataloader, gaze_estimator, identity_net)

            print("testing performance on efficientnet blackbox gaze model (XGaze)")
            test(test_dataloader, gaze_model_efficientnet, identity_net)

            # test() puts the encoders and decoder into eval mode; switch back
            # so the remaining batches are not trained in eval mode.
            _enter_training_mode()
        if batch == 12000:
            break

In [None]:
epochs = 1

# (console banner, gaze model) pairs evaluated after every epoch, in this order.
evaluation_suite = [
    ("testing performance on whitebox gaze model", gaze_estimator),
    ("testing performance on vgg11 blackbox gaze model (XGaze)", gaze_model_vgg),
    ("testing performance on resnet18 blackbox gaze model (XGaze)", gaze_model_res18),
    ("testing performance on mobilenetv2 blackbox gaze model (XGaze)", gaze_model_mobilenet),
    ("testing performance on efficientnet blackbox gaze model (XGaze)", gaze_model_efficientnet),
    ("testing performance on shufflenet blackbox gaze model (XGaze)", gaze_model_shufflenet),
]

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    # Train the autoencoder against the white-box gaze estimator.
    pre_train_ae(train_dataloader, gaze_estimator, identity_net)

    # Report gaze error / identity accuracy for each gaze model in turn.
    for banner, gaze_model in evaluation_suite:
        print(banner)
        test(test_dataloader, gaze_model, identity_net)
print("Done!")


In [None]:
# Persist the trained weights (state dicts only).
# The parent directory is created first: torch.save does not create missing
# directories, and failing here would discard the finished training run.
checkpoint_targets = [
    (Encoder_raw, "/data/volume_2/GazePrivacyModelsV2/GazeCaptureRes18EncRawMeanFace/efficientnet_avg_face.pt"),
    (Encoder_anchor, "/data/volume_2/GazePrivacyModelsV2/GazeCaptureRes18EncAnchorMeanFace/efficientnet_avg_face.pt"),
    (Decoder, "/data/volume_2/GazePrivacyModelsV2/GazeCaptureRes18DecTwoMeanFace/efficientnet_avg_face.pt"),
    (gaze_mlp, "/data/volume_2/GazePrivacyModelsV2/GazeCaptureMLP/efficientnet_avg_face.pt"),
]
for trained_module, checkpoint_path in checkpoint_targets:
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(trained_module.state_dict(), checkpoint_path)