In [10]:
import pandas as pd
import numpy as np
from scipy.linalg import sqrtm
from dataset_preprocessing import Paths, Dataset
import plotly.express as px
import torch
import logging
from gan import Gen_dcgan_gp_1d, Gen_ac_wgan_gp_1d

In [11]:
CHANNELS_IMG = 1
FEATURES_GEN = 120
Z_DIM = 100
BATCH_SIZE = 16
FID_ITERATIONS = 30
IMG_SIZE = 120
GEN_EMBEDDING = 100
NUM_CLASSES = 20

LOGGING_FILE = "logs/prob_vectors_analysis.log"
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [12]:
inc_v3_path =  Paths.pandora_18k + 'Conv_models/Inception-V3/'

ds = Dataset(Paths.pandora_18k)

train_path = inc_v3_path + 'train_full_emb.csv'
valid_path = inc_v3_path + 'valid_full_emb.csv'
test_path = inc_v3_path + 'test_full_emb.csv'

df_train = pd.read_csv(train_path)
df_valid = pd.read_csv(valid_path)
df_test = pd.read_csv(test_path)

In [13]:
data = {
    "num_features" : range(1, 120+1),
}

for i in range(ds.number_of_classes):
    style_mean = df_train.query(f"label == {i+1}").drop(["label"], axis=1).mean()
    data[ds.classes[i]] = style_mean

styles = list(data.keys())
styles.remove("num_features")

In [14]:
fig = px.line(data, x="num_features", y=styles, labels={'value':'Mean probability','variable':'Style name'}, markers=True)
fig.update_layout(
    width=1100,
    height=600,
)
fig.show()

## Evaluating of synthetic probability vectors 

### Frechet inception distance

In [6]:
def fid(df_real, df_fake):

    x_real = np.asarray(df_real.drop(["label"], axis=1))

    x_fake = np.asarray(df_fake.drop(["label"], axis=1))

    real_mean, real_cov = x_real.mean(axis=0), np.cov(x_real, rowvar=False)

    fake_mean, fake_cov = x_fake.mean(axis=0), np.cov(x_fake, rowvar=False)

    ssdiff = np.sum((real_mean-fake_mean)**2.0)

    covmean = sqrtm(real_cov.dot(fake_cov))

    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = ssdiff + np.trace(real_cov + fake_cov - 2*covmean)

    return fid

In [8]:
logging.basicConfig(level=logging.INFO, filename=LOGGING_FILE,filemode="a",
                    format="%(asctime)s %(levelname)s %(message)s")

df = pd.concat([df_train, df_valid], axis=0)

classes = ds.classes

#### WGAN-GP

In [9]:
logging.info("Generating WGAN_GP synthetic samples started")
logging.info(f"FID iterations : {FID_ITERATIONS}")

for ind, cl in enumerate(classes):

    gen_path = Paths.pandora_18k + 'Generation/model/gen_' + cl + '.pkl'
    gen_path_filtered = Paths.pandora_18k + 'Generation/model/gen_' + cl + '_filtered.pkl'

    gen = Gen_dcgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN)
    gen_filtered = Gen_dcgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN, filtered=True)

    gen.load_state_dict(torch.load(gen_path))
    gen_filtered.load_state_dict(torch.load(gen_path_filtered))
    
    gen.to(DEVICE)
    gen_filtered.to(DEVICE)

    gen.eval()
    gen_filtered.eval()

    df_cl = df.query(f"label == {ind+1}")

    fid_mean = 0

    for _ in range(FID_ITERATIONS):

        fake_vectors = pd.DataFrame()
        fake_vectors_filtered = pd.DataFrame()

        for _ in range(len(df_cl) // BATCH_SIZE):

            noise = torch.randn((BATCH_SIZE, Z_DIM, 1)).to(DEVICE)

            y = gen(noise).squeeze()
            y_filtered = gen_filtered(noise).squeeze()

            fake_vectors = pd.concat([fake_vectors, pd.DataFrame(data=y.detach().cpu())])
            fake_vectors_filtered = pd.concat([fake_vectors_filtered, pd.DataFrame(data=y_filtered.detach().cpu())])

        fake_vectors["label"] = pd.Series([ind+1 for _ in range(len(fake_vectors))])
        fake_vectors.columns = df.columns

        fake_vectors_filtered["label"] = pd.Series([ind+1 for _ in range(len(fake_vectors_filtered))])
        fake_vectors_filtered.columns = df.columns

        fid_nf = fid(df_cl, fake_vectors)
        fid_f = fid(df_cl, fake_vectors_filtered)

        fid_mean += fid_nf

    fid_mean /= FID_ITERATIONS

    print(f"FID for {cl} unfiltered case : {round(fid_mean, 3)}")
    logging.info(f"FID for {cl} unfiltered case : {round(fid_mean, 3)}")
    print(f"FID for {cl} filtered case : {round(fid_f, 3)}")
    logging.info(f"FID for {cl} filtered case : {round(fid_f, 3)}")

FID for 01_Byzantin_Iconography unfiltered case : 0.196
FID for 01_Byzantin_Iconography filtered case : 0.856
FID for 02_Early_Renaissance unfiltered case : 0.538
FID for 02_Early_Renaissance filtered case : 0.304
FID for 03_Northern_Renaissance unfiltered case : 0.334
FID for 03_Northern_Renaissance filtered case : 0.278
FID for 04_High_Renaissance unfiltered case : 0.178
FID for 04_High_Renaissance filtered case : 0.24
FID for 05_Baroque unfiltered case : 0.511
FID for 05_Baroque filtered case : 0.292
FID for 06_Rococo unfiltered case : 0.182
FID for 06_Rococo filtered case : 0.28
FID for 07_Romanticism unfiltered case : 0.348
FID for 07_Romanticism filtered case : 0.147
FID for 08_Realism unfiltered case : 0.256
FID for 08_Realism filtered case : 0.184
FID for 09_Impressionism unfiltered case : 0.254
FID for 09_Impressionism filtered case : 0.31
FID for 10_Post_Impressionism unfiltered case : 0.29
FID for 10_Post_Impressionism filtered case : 0.218
FID for 11_Expressionism unfiltere

#### CWGAN-GP

In [10]:
logging.info("Generating CWGAN_GP synthetic samples started")
logging.info(f"FID iterations : {FID_ITERATIONS}")

gen_cond_path = Paths.pandora_18k + 'Generation/model/gen_cond.pkl'
gen_cond_f_path = Paths.pandora_18k + 'Generation/model/gen_cond_filtered.pkl'

gen_cond = Gen_ac_wgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING).to(DEVICE)
gen_cond_filtered = Gen_ac_wgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING, filtered=True).to(DEVICE)

gen_cond.load_state_dict(torch.load(gen_cond_path))
gen_cond_filtered.load_state_dict(torch.load(gen_cond_f_path))

gen_cond.to(DEVICE)
gen_cond_filtered.to(DEVICE)

gen_cond.eval()
gen_cond_filtered.eval()

for ind, cl in enumerate(classes):

    df_cl = df.query(f"label == {ind+1}")

    fid_mean = 0

    for _ in range(FID_ITERATIONS):

        fake_vectors = pd.DataFrame()
        fake_vectors_filtered = pd.DataFrame()

        for _ in range(len(df_cl) // BATCH_SIZE):

            noise = torch.randn((BATCH_SIZE, Z_DIM, 1)).to(DEVICE)
            labels = torch.tensor([ind for _ in range(BATCH_SIZE)])
            labels = labels.type(torch.LongTensor).to(DEVICE)

            y = gen_cond(noise, labels)
            y_filtered = gen_cond_filtered(noise, labels)

            fake_vectors = pd.concat([fake_vectors, pd.DataFrame(data=y.detach().cpu().squeeze())])
            fake_vectors_filtered = pd.concat([fake_vectors_filtered, pd.DataFrame(data=y_filtered.detach().cpu().squeeze())])

        fake_vectors["label"] = pd.Series([ind+1 for _ in range(len(fake_vectors))])
        fake_vectors.columns = df.columns

        fake_vectors_filtered["label"] = pd.Series([ind+1 for _ in range(len(fake_vectors_filtered))])
        fake_vectors_filtered.columns = df.columns

        fid_nf = fid(df_cl, fake_vectors)
        fid_f = fid(df_cl, fake_vectors_filtered)

        fid_mean += fid_nf

    fid_mean /= FID_ITERATIONS

    print(f"FID for {cl} unfiltered case : {round(fid_mean, 3)}")
    logging.info(f"FID for {cl} unfiltered case : {round(fid_mean, 3)}")
    print(f"FID for {cl} filtered case : {round(fid_f, 3)}")
    logging.info(f"FID for {cl} filtered case : {round(fid_f, 3)}")

FID for 01_Byzantin_Iconography unfiltered case : 0.186
FID for 01_Byzantin_Iconography filtered case : 0.168
FID for 02_Early_Renaissance unfiltered case : 0.766
FID for 02_Early_Renaissance filtered case : 0.992
FID for 03_Northern_Renaissance unfiltered case : 2.294
FID for 03_Northern_Renaissance filtered case : 1.1
FID for 04_High_Renaissance unfiltered case : 0.616
FID for 04_High_Renaissance filtered case : 0.861
FID for 05_Baroque unfiltered case : 1.441
FID for 05_Baroque filtered case : 1.408
FID for 06_Rococo unfiltered case : 1.325
FID for 06_Rococo filtered case : 1.196
FID for 07_Romanticism unfiltered case : 1.993
FID for 07_Romanticism filtered case : 1.412
FID for 08_Realism unfiltered case : 1.29
FID for 08_Realism filtered case : 1.18
FID for 09_Impressionism unfiltered case : 4.817
FID for 09_Impressionism filtered case : 1.083
FID for 10_Post_Impressionism unfiltered case : 3.353
FID for 10_Post_Impressionism filtered case : 2.364
FID for 11_Expressionism unfiltere