# Explainable Adversarial Auto-Encoder Network (xAAEnet)

/!\ Exécuter ce notebook sur un GPU /!\
Dans Colab : Runtime -> Change runtime type -> T4 GPU -> Save

Cloner le dépôt GitHub et se placer dans le dossier correspondant

In [1]:
! git clone https://github.com/LucaLaFisca/Human-Centered-xAI.git
%cd Human-Centered-xAI

Cloning into 'Human-Centered-xAI'...
remote: Enumerating objects: 76, done.[K
remote: Counting objects: 100% (76/76), done.[K
remote: Compressing objects: 100% (62/62), done.[K
remote: Total 76 (delta 37), reused 40 (delta 12), pack-reused 0[K
Receiving objects: 100% (76/76), 3.06 MiB | 15.72 MiB/s, done.
Resolving deltas: 100% (37/37), done.
/content/Human-Centered-xAI


Importer les bibliothèques nécessaires

In [2]:
import torch
from fastai.vision.all import *
from fastai.data.all import *

from model import AAE
from utils import label_func, FreezeDiscriminator, GetLatentSpace, LossAttrMetric, distrib_regul_regression, compute_main_direction

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

Définir les données d'entrée

In [3]:
### Define the DataLoaders
data_path = untar_data(URLs.PETS)  # other downloadable datasets are listed in fastai's `URLs`
print(data_path.ls())

# `encoded=True`: the targets returned by `label_func` are expected to be already
# encoded over the vocab ['cat', 'dog'] (see utils.label_func -- TODO confirm format)
catblock = MultiCategoryBlock(encoded=True, vocab=['cat', 'dog'])
dblock = DataBlock(
    blocks=(ImageBlock(), catblock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=label_func,
    item_tfms=Resize(128),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)],
)

# Create the DataLoaders
dls = dblock.dataloaders(data_path/"images", bs=16, drop_last=True)
print('dls created')


def _class_indices(dset):
    """Return the class index (argmax of the encoded target) of every item in `dset`.

    Only the target pipeline `dset.tls[1]` (get_y + encoding) is evaluated.
    `dset[i]` would return the same target but would also open and decode the
    image file of every item just to read its label.
    """
    targets = dset.tls[1]
    return [targets[i].argmax().item() for i in range(len(targets))]


# Labels in train-then-valid order: this must match the order in which the latent
# space is computed later (get_preds on ds_idx=0, then on ds_idx=1).
labels = torch.tensor(_class_indices(dls.train_ds) + _class_indices(dls.valid_ds))
print(f'final labels: {labels}')

[Path('/root/.fastai/data/oxford-iiit-pet/annotations'), Path('/root/.fastai/data/oxford-iiit-pet/images')]
dls created
final labels: tensor([0, 1, 1,  ..., 1, 0, 0])


Créer le modèle

In [5]:
# Adversarial auto-encoder: 3-channel 128x128 inputs (images are resized to 128),
# a 128-dimensional latent code and 2 output classes (vocab: cat / dog).
aae_hparams = dict(
    input_size=128,
    input_channels=3,
    encoding_dims=128,
    classes=2,
)
model = AAE(**aae_hparams)

Entraînement de l'auto-encodeur

In [None]:
### Train Autoencoder ###
model_file = 'cat_dog_ae_test'
loss_func = model.ae_loss_func

metrics = [LossAttrMetric("recons_loss"), accuracy_multi]
learn = Learner(dls, model, loss_func=loss_func, metrics=metrics)

# Use the learning rate suggested at the "valley" of the LR-finder curve
learning_rate = learn.lr_find()
# SaveModelCallback and EarlyStoppingCallback both subclass TrackerCallback and already
# track the monitored value, so a standalone TrackerCallback() adds nothing.
learn.fit(100, lr=learning_rate.valley,
          cbs=[SaveModelCallback(fname=model_file),
               EarlyStoppingCallback(min_delta=1e-4, patience=10)])

# Reload the best checkpoint through the Learner: the file location
# (learn.path/learn.model_dir) and the target device are resolved exactly as
# SaveModelCallback used them, instead of a hard-coded 'models/' path.
# SaveModelCallback saves without optimizer state by default, hence with_opt=False.
learn.load(model_file, with_opt=False, strict=False);

Entraînement adversarial

In [None]:
### Train Adversarial ###
model_file = 'cat_dog_aae_test'
loss_func = model.aae_loss_func

metrics = [LossAttrMetric("adv_loss"), LossAttrMetric("recons_loss"), LossAttrMetric("crit_loss"),
           accuracy_multi]
learn = Learner(dls, model, loss_func=loss_func, metrics=metrics)

# GradientAccumulation counts *samples*: derive n_acc from the batch size so that
# gradients are always accumulated over 4 batches, even if `bs` is changed.
# SaveModelCallback/EarlyStoppingCallback already track the monitored value,
# so no standalone TrackerCallback() is needed.
learn.fit(100, lr=5e-3,
          cbs=[GradientAccumulation(n_acc=dls.bs * 4),
               SaveModelCallback(fname=model_file),
               EarlyStoppingCallback(min_delta=1e-4, patience=10),
               FreezeDiscriminator()])

# Reload the best checkpoint through the Learner (path and device resolved the same
# way SaveModelCallback wrote it); the checkpoint carries no optimizer state.
learn.load(model_file, with_opt=False, strict=False);

Entraînement du classifieur

In [None]:
### Train Classifier ###
model_file = 'cat_dog_classif_test'
loss_func = model.classif_loss_func

metrics = [LossAttrMetric("adv_loss"), LossAttrMetric("recons_loss"),
           LossAttrMetric("classif_loss"), LossAttrMetric("crit_loss"),
           accuracy_multi]
monitor_loss = 'valid_loss'
learn = Learner(dls, model, loss_func=loss_func, metrics=metrics)

learn.fit(100, lr=1e-2,
            cbs=[GradientAccumulation(n_acc=16*4),
                 TrackerCallback(monitor=monitor_loss),
                 SaveModelCallback(fname=model_file,monitor=monitor_loss),
                 EarlyStoppingCallback(min_delta=1e-4,patience=10,monitor=monitor_loss),
                 FreezeDiscriminator()])

Calcul de l'espace latent final

In [None]:
# Load the best classifier weights written by SaveModelCallback
learn.load(model_file, strict=False)

# Compute the latent space of the training set (ds_idx=0) then of the validation set
# (ds_idx=1), so that its rows line up with `labels` (built in the same order).
# GetLatentSpace presumably appends the latent codes to `learn.zi_valid` (see utils),
# so that buffer is reset to an empty tensor before each pass.
# The device is taken from the model rather than assuming CUDA, so this also runs on CPU.
dev = next(learn.model.parameters()).device
latent_chunks = []
for ds_idx in (0, 1):
    learn.zi_valid = torch.tensor([]).to(dev)
    learn.get_preds(ds_idx=ds_idx, cbs=[GetLatentSpace()])
    latent_chunks.append(learn.zi_valid)
z = torch.vstack(latent_chunks)
torch.save(z, 'z_aae.pt')
print(z.shape)

Affichage de l'espace latent

In [None]:
# Project the latent codes to 2-D with t-SNE for visualisation
tsne = TSNE(random_state=42)
z = z.view(-1, 128)
z_numpy = z.cpu().detach().numpy()
predictions_embedded = tsne.fit_transform(z_numpy)

fig, ax = plt.subplots()
sns.scatterplot(x=predictions_embedded[:, 0], y=predictions_embedded[:, 1],
                hue=labels, s=55, ax=ax)

# Arrow along the main direction returned by compute_main_direction
# (presumably the first principal component -- see utils)
start, end = compute_main_direction(predictions_embedded, labels)
dx, dy = end[0] - start[0], end[1] - start[1]
arrow_color = '#8B0000'
ax.arrow(start[0], start[1], dx, dy, linewidth=3,
         head_width=10, head_length=10, fc=arrow_color, ec=arrow_color,
         length_includes_head=True)

# Symmetric axis limits with a 5-unit margin beyond the farthest point
maxabs = np.max(np.abs(predictions_embedded)) + 5
ax.set_xlim([-maxabs, maxabs])
ax.set_ylim([-maxabs, maxabs])

# Bare figure: hide the ticks and drop the legend
ax.set_xticks([])
ax.set_yticks([])
ax.get_legend().remove()

Calcul du score final

In [10]:
# Final scoring: linear regression computed from the latent space
# (see distrib_regul_regression in utils)
latent_numpy = z.cpu().detach().numpy()
y_pred = distrib_regul_regression(latent_numpy, labels)