In [1]:
import pandas as pd
import numpy as np
from scipy.linalg import sqrtm
from dataset_preprocessing import Paths, Dataset
import plotly.express as px
import torch.nn as nn
import torch.nn.functional as F
import torch
import logging
import random
from metrics import confusion_matrix, accuracy_per_class
from sklearn.metrics import accuracy_score
from torch.optim import lr_scheduler, Adam
from sklearn.utils import shuffle
from tqdm import tqdm
import time
import copy
from utils import MyDataset, FocalLoss
from gan import Gen_ac_wgan_gp_1d, Gen_dcgan_gp_1d
from snn import ShallowNN

  warn(f"Failed to load image Python extension: {e}")


In [2]:
CHANNELS_IMG = 1
FEATURES_GEN = 120
Z_DIM = 100
BATCH_SIZE = 16
IMG_SIZE = 120
GEN_EMBEDDING = 100
NUM_CLASSES = 20
BATCH_SIZE_SNN = 64
EPOCHS = 20

LOGGING_FILE = "logs/cases.log"
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
inc_v3_path =  Paths.pandora_18k + 'Conv_models/Inception-V3/'

ds = Dataset(Paths.pandora_18k)

train_path = inc_v3_path + 'train_full_emb.csv'
valid_path = inc_v3_path + 'valid_full_emb.csv'
test_path = inc_v3_path + 'test_full_emb.csv'

df_train = pd.read_csv(train_path)
df_valid = pd.read_csv(valid_path)
df_test = pd.read_csv(test_path)

df = shuffle(pd.concat([df_train, df_valid], axis=0))

logging.basicConfig(level=logging.INFO, filename=LOGGING_FILE,filemode="a",
                    format="%(asctime)s %(levelname)s %(message)s")

classes = ds.classes

snn_path = inc_v3_path + 'snn.pth'

dataset_valid = MyDataset(df_valid, num_classes=len(ds.classes))

dataset_test = MyDataset(df_test, num_classes=len(ds.classes))


dataset_valid = MyDataset(df_valid, num_classes=len(ds.classes))

dataloader_valid = torch.utils.data.DataLoader(dataset=dataset_valid, 
                                        batch_size=BATCH_SIZE_SNN, 
                                        shuffle=True, 
                                        num_workers=4,
                                        drop_last=True)

dataloader_test = torch.utils.data.DataLoader(dataset=dataset_test, 
                                        batch_size=BATCH_SIZE_SNN, 
                                        shuffle=True, 
                                        num_workers=4,
                                        drop_last=True)

#### WGAN generation

In [5]:
torch.manual_seed(42)

filter_need = False

fake_vectors = pd.DataFrame()
fake_vectors_filtered = pd.DataFrame()

gen_classes = range(20)

for ind in gen_classes:

    cl = classes[ind]

    gen_path = Paths.pandora_18k + 'Generation/model/gen_' + cl + '.pkl'
    gen_path_filtered = Paths.pandora_18k + 'Generation/model/gen_' + cl + '_filtered.pkl'

    gen = Gen_dcgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN)
    gen_filtered = Gen_dcgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN, filtered=True)

    gen.load_state_dict(torch.load(gen_path))
    gen_filtered.load_state_dict(torch.load(gen_path_filtered))
    
    gen.to(DEVICE)
    gen_filtered.to(DEVICE)

    df_cl = df.query(f"label == {ind+1}")


    valid_cl = df_valid.query(f"label == {ind+1}")

    if filter_need:
        filtering = torch.tensor(valid_cl.mean())[range(ind*6, (ind+1)*6)]
    else:
        filtering = torch.zeros(6)

    cl_fake_vectors = pd.DataFrame()
    cl_fake_vectors_filtered = pd.DataFrame()

    while(len(cl_fake_vectors) < len(df_cl) // 2):

        noise = torch.randn((BATCH_SIZE, Z_DIM, 1)).to(DEVICE)

        fake = gen(noise).squeeze()

        fake_ind = []

        for i in range(len(fake)):

            if all(torch.abs(fake[i].detach().cpu()[range(ind*6, (ind+1)*6)] - filtering) >  0.0):
                fake_ind.append(i)

        cl_fake_vectors = pd.concat([cl_fake_vectors, pd.DataFrame(data=fake[fake_ind].detach().cpu())])

    cl_fake_vectors["label"] = pd.Series([ind+1 for _ in range(len(cl_fake_vectors))])
    cl_fake_vectors.columns = df.columns
    fake_vectors = pd.concat([fake_vectors, cl_fake_vectors])
    
    while(len(cl_fake_vectors_filtered) < len(df_cl) // 2):

        noise = torch.randn((BATCH_SIZE, Z_DIM, 1)).to(DEVICE)

        fake_filtered = gen_filtered(noise).squeeze()

        fake_filtered_ind = []

        for i in range(len(fake_filtered)):

            if all(torch.abs(fake_filtered[i].detach().cpu()[range(ind*6, (ind+1)*6)] - filtering) > 0.0):
                fake_filtered_ind.append(i)

        cl_fake_vectors_filtered = pd.concat([cl_fake_vectors_filtered, pd.DataFrame(data=fake_filtered[fake_filtered_ind].detach().cpu())])

    cl_fake_vectors_filtered["label"] = pd.Series([ind+1 for _ in range(len(cl_fake_vectors_filtered))])
    cl_fake_vectors_filtered.columns = df.columns
    fake_vectors_filtered = pd.concat([fake_vectors_filtered, cl_fake_vectors_filtered])

#### Conditional WGAN Generation

In [6]:
cond_fake_vectors = pd.DataFrame()
cond_fake_vectors_filtered = pd.DataFrame()

gen_cond_path = Paths.pandora_18k + 'Generation/model/gen_cond.pkl'
gen_cond_f_path = Paths.pandora_18k + 'Generation/model/gen_cond_filtered.pkl'

gen_cond = Gen_ac_wgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING).to(DEVICE)
gen_cond_filtered = Gen_ac_wgan_gp_1d(Z_DIM, CHANNELS_IMG, FEATURES_GEN, NUM_CLASSES, IMG_SIZE, GEN_EMBEDDING, filtered=True).to(DEVICE)

gen_cond.load_state_dict(torch.load(gen_cond_path))
gen_cond_filtered.load_state_dict(torch.load(gen_cond_f_path))

gen_cond.to(DEVICE)
gen_cond_filtered.to(DEVICE)

gen_cond.eval()
gen_cond_filtered.eval()

for ind, _ in enumerate(classes):

    cl_fake_vectors = pd.DataFrame()
    cl_fake_vectors_filtered = pd.DataFrame()

    df_cl = df.query(f"label == {ind+1}")

    for _ in range(len(df_cl) // BATCH_SIZE):

        noise = torch.randn((BATCH_SIZE, Z_DIM, 1)).to(DEVICE)
        labels = torch.tensor([ind for _ in range(BATCH_SIZE)])
        labels = labels.type(torch.LongTensor).to(DEVICE)

        fake = gen_cond(noise, labels)
        fake_filtered = gen_cond_filtered(noise, labels)

        cl_fake_vectors = pd.concat([cl_fake_vectors, pd.DataFrame(data=fake.detach().cpu().squeeze())])
        cl_fake_vectors_filtered = pd.concat([cl_fake_vectors_filtered, pd.DataFrame(data=fake_filtered.detach().cpu().squeeze())])

    cl_fake_vectors["label"] = pd.Series([ind+1 for _ in range(len(cl_fake_vectors))])
    cl_fake_vectors_filtered["label"] = pd.Series([ind+1 for _ in range(len(cl_fake_vectors_filtered))])

    cl_fake_vectors.columns = cl_fake_vectors_filtered.columns = df.columns

    cond_fake_vectors = pd.concat([cond_fake_vectors, cl_fake_vectors.iloc[:len(df_cl) // 2]])
    cond_fake_vectors_filtered = pd.concat([cond_fake_vectors_filtered, cl_fake_vectors_filtered.iloc[:len(df_cl) // 2]])

### Experiments with not filtered vectors with WGAN with augmentation of 3 classes with the lowest accuracy

In [19]:
dataset_train = MyDataset(shuffle(pd.concat([df, 
                                             fake_vectors.query("label == 6"), 
                                             fake_vectors.query("label == 10"), 
                                             fake_vectors.query("label == 11"), 
                                             ], 
                                             axis=0)), num_classes=len(ds.classes))

dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, 
                                        batch_size=BATCH_SIZE_SNN, 
                                        shuffle=True, 
                                        num_workers=4,
                                        drop_last=True)

dataloaders = {"train" : dataloader_train, "validation" : dataloader_valid, "test" : dataloader_test}


In [20]:
SEED = 121

torch.manual_seed(SEED)

logging.info(f"Seed {SEED}")

Net = ShallowNN().to(DEVICE)

optimizer_name = "Adam"

lr = 0.003

criterion_name = "FocalLoss"

optimizer = Adam(Net.parameters(), lr=lr, capturable=True)

criterion = FocalLoss(reduction="mean", gamma=2)

scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

logging.info(f"Net parameters {Net.parameters}")
logging.info(f"Optimizer :{optimizer_name}, lr : {lr}, criterion : {criterion_name}")

In [21]:
torch.manual_seed(SEED)

statistics_data = {
            'number of epochs' : range(1,EPOCHS+1),
            'training loss' : [],
            'validation loss' : [],
            'training accuracy' : [],
            'validation accuracy' : []
        }

start_time = time.time()

best_acc = 0.0
best_model_wts = copy.deepcopy(Net.state_dict())

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'validation']:
        if phase == 'train':
            Net.train()  # Set model to training mode
        else:
            Net.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        with tqdm(dataloaders[phase], unit='batch') as tepoch:
            for inputs, labels in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                inputs = inputs.to(DEVICE)
                labels = labels.type(torch.LongTensor).to(DEVICE)
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = torch.squeeze(Net(inputs))
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step(0.005)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            if phase == 'train':
                statistics_data['training loss'].append(epoch_loss)
                statistics_data['training accuracy'].append(epoch_acc.cpu().numpy())
            else:
                statistics_data['validation loss'].append(epoch_loss)
                statistics_data['validation accuracy'].append(epoch_acc.cpu().numpy())

            if phase == 'validation' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(Net.state_dict())
            
            print(f'{phase} loss : {epoch_loss:.4f} {phase} accuracy: {epoch_acc*100:.2f}%')
            logging.info(f'{phase} loss : {epoch_loss:.4f} {phase} accuracy: {epoch_acc*100:.2f}%')

        print()

time_elapsed = time.time() - start_time
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
logging.info(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best validation accuracy: {best_acc*100:.2f}%')
logging.info(f'Best validation accuracy: {best_acc*100:.2f}%')

Net.load_state_dict(best_model_wts)

Epoch 1/20
----------


Epoch 1: 100%|██████████| 266/266 [00:02<00:00, 118.83batch/s]


train loss : 2.2166 train accuracy: 65.32%



Epoch 1: 100%|██████████| 51/51 [00:00<00:00, 72.05batch/s] 


validation loss : 1.9959 validation accuracy: 68.21%

Epoch 2/20
----------


Epoch 2: 100%|██████████| 266/266 [00:02<00:00, 88.84batch/s] 


train loss : 1.7412 train accuracy: 89.77%



Epoch 2: 100%|██████████| 51/51 [00:00<00:00, 87.66batch/s] 


validation loss : 1.9505 validation accuracy: 69.03%

Epoch 3/20
----------


Epoch 3: 100%|██████████| 266/266 [00:02<00:00, 88.81batch/s] 


train loss : 1.7202 train accuracy: 90.04%



Epoch 3: 100%|██████████| 51/51 [00:00<00:00, 75.26batch/s] 


validation loss : 1.9406 validation accuracy: 69.18%

Epoch 4/20
----------


Epoch 4: 100%|██████████| 266/266 [00:03<00:00, 87.95batch/s] 


train loss : 1.7135 train accuracy: 90.14%



Epoch 4: 100%|██████████| 51/51 [00:00<00:00, 76.12batch/s] 


validation loss : 1.9318 validation accuracy: 69.57%

Epoch 5/20
----------


Epoch 5: 100%|██████████| 266/266 [00:02<00:00, 93.08batch/s] 


train loss : 1.7093 train accuracy: 90.22%



Epoch 5: 100%|██████████| 51/51 [00:00<00:00, 121.93batch/s]


validation loss : 1.9277 validation accuracy: 69.45%

Epoch 6/20
----------


Epoch 6: 100%|██████████| 266/266 [00:02<00:00, 89.09batch/s] 


train loss : 1.7065 train accuracy: 90.39%



Epoch 6: 100%|██████████| 51/51 [00:00<00:00, 116.49batch/s]


validation loss : 1.9244 validation accuracy: 70.05%

Epoch 7/20
----------


Epoch 7: 100%|██████████| 266/266 [00:02<00:00, 89.29batch/s] 


train loss : 1.7044 train accuracy: 90.50%



Epoch 7: 100%|██████████| 51/51 [00:00<00:00, 73.39batch/s]


validation loss : 1.9212 validation accuracy: 70.17%

Epoch 8/20
----------


Epoch 8: 100%|██████████| 266/266 [00:02<00:00, 91.70batch/s] 


train loss : 1.7027 train accuracy: 90.54%



Epoch 8: 100%|██████████| 51/51 [00:00<00:00, 71.02batch/s]


validation loss : 1.9204 validation accuracy: 70.08%

Epoch 9/20
----------


Epoch 9: 100%|██████████| 266/266 [00:02<00:00, 95.89batch/s] 


train loss : 1.7014 train accuracy: 90.58%



Epoch 9: 100%|██████████| 51/51 [00:00<00:00, 67.92batch/s] 


validation loss : 1.9190 validation accuracy: 70.23%

Epoch 10/20
----------


Epoch 10: 100%|██████████| 266/266 [00:02<00:00, 91.84batch/s] 


train loss : 1.7002 train accuracy: 90.65%



Epoch 10: 100%|██████████| 51/51 [00:00<00:00, 122.95batch/s]


validation loss : 1.9170 validation accuracy: 70.17%

Epoch 11/20
----------


Epoch 11: 100%|██████████| 266/266 [00:02<00:00, 90.12batch/s] 


train loss : 1.6992 train accuracy: 90.70%



Epoch 11: 100%|██████████| 51/51 [00:00<00:00, 82.94batch/s] 


validation loss : 1.9130 validation accuracy: 70.35%

Epoch 12/20
----------


Epoch 12: 100%|██████████| 266/266 [00:02<00:00, 93.21batch/s] 


train loss : 1.6985 train accuracy: 90.66%



Epoch 12: 100%|██████████| 51/51 [00:00<00:00, 85.85batch/s] 


validation loss : 1.9107 validation accuracy: 70.59%

Epoch 13/20
----------


Epoch 13: 100%|██████████| 266/266 [00:02<00:00, 100.45batch/s]


train loss : 1.6972 train accuracy: 90.78%



Epoch 13: 100%|██████████| 51/51 [00:00<00:00, 66.89batch/s]


validation loss : 1.9111 validation accuracy: 70.65%

Epoch 14/20
----------


Epoch 14: 100%|██████████| 266/266 [00:02<00:00, 98.79batch/s] 


train loss : 1.6972 train accuracy: 90.81%



Epoch 14: 100%|██████████| 51/51 [00:00<00:00, 103.91batch/s]


validation loss : 1.9111 validation accuracy: 70.56%

Epoch 15/20
----------


Epoch 15: 100%|██████████| 266/266 [00:02<00:00, 91.06batch/s] 


train loss : 1.6971 train accuracy: 90.82%



Epoch 15: 100%|██████████| 51/51 [00:00<00:00, 72.74batch/s] 


validation loss : 1.9097 validation accuracy: 70.80%

Epoch 16/20
----------


Epoch 16: 100%|██████████| 266/266 [00:02<00:00, 89.29batch/s] 


train loss : 1.6969 train accuracy: 90.83%



Epoch 16: 100%|██████████| 51/51 [00:00<00:00, 101.63batch/s]


validation loss : 1.9109 validation accuracy: 70.68%

Epoch 17/20
----------


Epoch 17: 100%|██████████| 266/266 [00:02<00:00, 89.20batch/s] 


train loss : 1.6969 train accuracy: 90.82%



Epoch 17: 100%|██████████| 51/51 [00:00<00:00, 65.26batch/s]


validation loss : 1.9106 validation accuracy: 70.62%

Epoch 18/20
----------


Epoch 18: 100%|██████████| 266/266 [00:02<00:00, 95.24batch/s] 


train loss : 1.6968 train accuracy: 90.83%



Epoch 18: 100%|██████████| 51/51 [00:00<00:00, 72.12batch/s]


validation loss : 1.9102 validation accuracy: 70.62%

Epoch 19/20
----------


Epoch 19: 100%|██████████| 266/266 [00:02<00:00, 101.03batch/s]


train loss : 1.6967 train accuracy: 90.81%



Epoch 19: 100%|██████████| 51/51 [00:00<00:00, 67.78batch/s]


validation loss : 1.9093 validation accuracy: 70.71%

Epoch 20/20
----------


Epoch 20: 100%|██████████| 266/266 [00:02<00:00, 89.58batch/s] 


train loss : 1.6964 train accuracy: 90.82%



Epoch 20: 100%|██████████| 51/51 [00:00<00:00, 65.03batch/s]


validation loss : 1.9094 validation accuracy: 70.68%

Training complete in 1m 10s
Best validation accuracy: 70.80%


<All keys matched successfully>

In [22]:
target = torch.tensor([], dtype=torch.int32).to(DEVICE)
pred = torch.tensor([], dtype=torch.int32).to(DEVICE)

for images, labels in dataloaders["test"]:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    target = torch.cat((target, labels))

    outputs = Net(images)
    _, predictions = torch.max(outputs, 2)
    predictions = torch.squeeze(predictions, 1)

    pred = torch.cat((pred, predictions))

target, pred = target.to(torch.int32).cpu(), pred.to(torch.int32).cpu()

print(f"Accuracy : {round(accuracy_score(target, pred) * 100, 3)} %")

Accuracy : 70.335 %


In [23]:
acc_per_class = accuracy_per_class(pred, target, ds.classes)

for style, acc in acc_per_class.items():
    print(f'Accuracy for {style}: {acc:.1f} %')

Accuracy for 01_Byzantin_Iconography: 94.7 %
Accuracy for 02_Early_Renaissance: 79.0 %
Accuracy for 03_Northern_Renaissance: 85.1 %
Accuracy for 04_High_Renaissance: 73.9 %
Accuracy for 05_Baroque: 59.0 %
Accuracy for 06_Rococo: 47.6 %
Accuracy for 07_Romanticism: 54.2 %
Accuracy for 08_Realism: 71.4 %
Accuracy for 09_Impressionism: 74.1 %
Accuracy for 10_Post_Impressionism: 61.7 %
Accuracy for 11_Expressionism: 47.0 %
Accuracy for 12_Symbolism: 63.4 %
Accuracy for 13_Fauvism: 55.8 %
Accuracy for 14_Cubism: 76.4 %
Accuracy for 15_Surrealism: 64.1 %
Accuracy for 16_AbstractArt: 70.1 %
Accuracy for 17_NaiveArt: 67.1 %
Accuracy for 18_PopArt: 73.2 %
Accuracy for 19_ChineseArt: 77.0 %
Accuracy for 20_JapaneseArt: 98.6 %


In [24]:
x = sorted(list(acc_per_class.items()), key=lambda x : x[1])[:3]

print(f"Mean accuracy for min 3 styles : {sum([el[1] for el in x]) / 3:.1f} %")

for style, acc in x:
    print(f'Accuracy for {style}: {acc:.1f} %')

Mean accuracy for min 3 styles : 49.6 %
Accuracy for 11_Expressionism: 47.0 %
Accuracy for 06_Rococo: 47.6 %
Accuracy for 07_Romanticism: 54.2 %
