In [1]:
import torch
from torchvision import transforms
from torchvision import models
import torch.backends.cudnn as cudnn
import torch.nn as nn
import pandas as pd
import time
import copy
import logging
import random
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.optim import Adam
from torchvision import datasets
from torch.optim import lr_scheduler
from sklearn.utils import shuffle
import torch.nn.functional as F
from metrics import confusion_matrix, accuracy_per_class
from sklearn.metrics import accuracy_score

from dataset_preprocessing import Paths, Dataset
from snn import ShallowNN
from utils import FocalLoss, MyDataset

  warn(f"Failed to load image Python extension: {e}")


In [2]:
BATCH_SIZE = 64

EPOCHS = 20

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
inc_v3_path =  Paths.pandora_18k + 'Conv_models/Inception-V3/'

train_path = inc_v3_path + 'train_full_emb.csv'
valid_path = inc_v3_path + 'valid_full_emb.csv'
test_path = inc_v3_path + 'test_full_emb.csv'
mistakes_path = inc_v3_path + 'mistakes.csv'

fake_Baroque_path = inc_v3_path + 'fake_emb_Baroque.csv'
fake_Rococo_path = inc_v3_path + 'fake_emb_Rococo.csv'
fake_Romanticism_path = inc_v3_path + 'fake_emb_Romanticism.csv'
fake_Expressionism_path = inc_v3_path + 'fake_emb_Expressionism.csv'
fake_PostImpressionism_path = inc_v3_path + 'fake_emb_Post-Impressionism.csv'

train_prob =  inc_v3_path + 'train_prob.csv'
valid_prob =  inc_v3_path + 'valid_prob.csv'
test_prob =  inc_v3_path + 'test_prob.csv'

model_save_path = inc_v3_path + 'snn.pth'

logging_file = "logs/snn.log"

In [3]:
ds = Dataset(Paths.pandora_18k)

df_train = pd.read_csv(train_path)

df_valid = pd.read_csv(valid_path)

df_test = pd.read_csv(test_path)

In [4]:
random.seed(42)

dataset_train = MyDataset(shuffle(pd.concat([df_train, df_valid], axis=0)), num_classes=len(ds.classes))

dataset_valid = MyDataset(df_valid, num_classes=len(ds.classes))

dataset_test = MyDataset(df_test, num_classes=len(ds.classes))

dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, 
                                        batch_size=BATCH_SIZE, 
                                        shuffle=True, 
                                        num_workers=4,
                                        drop_last=True)

dataset_valid = MyDataset(df_valid, num_classes=len(ds.classes))

dataloader_valid = torch.utils.data.DataLoader(dataset=dataset_valid, 
                                        batch_size=BATCH_SIZE, 
                                        shuffle=True, 
                                        num_workers=4,
                                        drop_last=True)

dataloader_test = torch.utils.data.DataLoader(dataset=dataset_test, 
                                        batch_size=BATCH_SIZE, 
                                        shuffle=True, 
                                        num_workers=4,
                                        drop_last=True)

dataloaders = {"train" : dataloader_train, "validation" : dataloader_valid, "test" : dataloader_test}

In [5]:
SEED = 121

torch.manual_seed(SEED)

logging.basicConfig(level=logging.INFO, filename=logging_file,filemode="a",
                    format="%(asctime)s %(levelname)s %(message)s")

logging.info(f"Seed {SEED}")

Net = ShallowNN().to(DEVICE)

optimizer_name = "Adam"

lr = 0.003

criterion_name = "FocalLoss"

#criterion_name = "CrossEntropy"

optimizer = Adam(Net.parameters(), lr=lr, capturable=True)

criterion = FocalLoss(reduction="mean", gamma=2)

#criterion = nn.CrossEntropyLoss()

scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

logging.info(f"Net parameters {Net.parameters}")
logging.info(f"Optimizer :{optimizer_name}, lr : {lr}, criterion : {criterion_name}")

In [6]:
from torchinfo import summary

summary(Net, input_size=(BATCH_SIZE, 1, 120))

Layer (type:depth-idx)                   Output Shape              Param #
ShallowNN                                [64, 1, 20]               --
├─Sequential: 1-1                        [64, 1, 20]               --
│    └─Conv1d: 2-1                       [64, 1, 20]               6
│    └─Linear: 2-2                       [64, 1, 20]               420
│    └─Softmax: 2-3                      [64, 1, 20]               --
Total params: 426
Trainable params: 426
Non-trainable params: 0
Total mult-adds (M): 0.03
Input size (MB): 0.03
Forward/backward pass size (MB): 0.02
Params size (MB): 0.00
Estimated Total Size (MB): 0.05

In [7]:
torch.manual_seed(SEED)

statistics_data = {
            'number of epochs' : range(1,EPOCHS+1),
            'training loss' : [],
            'validation loss' : [],
            'training accuracy' : [],
            'validation accuracy' : []
        }

start_time = time.time()

best_acc = 0.0
best_model_wts = copy.deepcopy(Net.state_dict())

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'validation']:
        if phase == 'train':
            Net.train()  # Set model to training mode
        else:
            Net.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over data.
        with tqdm(dataloaders[phase], unit='batch') as tepoch:
            for inputs, labels in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                inputs = inputs.to(DEVICE)
                labels = labels.type(torch.LongTensor).to(DEVICE)
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = torch.squeeze(Net(inputs))
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step(0.005)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            if phase == 'train':
                statistics_data['training loss'].append(epoch_loss)
                statistics_data['training accuracy'].append(epoch_acc.cpu().numpy())
            else:
                statistics_data['validation loss'].append(epoch_loss)
                statistics_data['validation accuracy'].append(epoch_acc.cpu().numpy())

            if phase == 'validation' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(Net.state_dict())
            
            print(f'{phase} loss : {epoch_loss:.4f} {phase} accuracy: {epoch_acc*100:.2f}%')
            logging.info(f'{phase} loss : {epoch_loss:.4f} {phase} accuracy: {epoch_acc*100:.2f}%')

        print()

time_elapsed = time.time() - start_time
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
logging.info(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best validation accuracy: {best_acc*100:.2f}%')
logging.info(f'Best validation accuracy: {best_acc*100:.2f}%')

Net.load_state_dict(best_model_wts)
torch.save(Net.state_dict(), model_save_path)

Epoch 1/20
----------


Epoch 1: 100%|██████████| 247/247 [00:02<00:00, 86.41batch/s] 


train loss : 2.2427 train accuracy: 64.73%



Epoch 1: 100%|██████████| 51/51 [00:00<00:00, 97.95batch/s] 


validation loss : 2.0050 validation accuracy: 67.97%

Epoch 2/20
----------


Epoch 2: 100%|██████████| 247/247 [00:02<00:00, 90.55batch/s] 


train loss : 1.7401 train accuracy: 90.18%



Epoch 2: 100%|██████████| 51/51 [00:00<00:00, 92.65batch/s] 


validation loss : 1.9529 validation accuracy: 69.00%

Epoch 3/20
----------


Epoch 3: 100%|██████████| 247/247 [00:02<00:00, 93.88batch/s] 


train loss : 1.7187 train accuracy: 90.32%



Epoch 3: 100%|██████████| 51/51 [00:00<00:00, 68.93batch/s] 


validation loss : 1.9413 validation accuracy: 68.88%

Epoch 4/20
----------


Epoch 4: 100%|██████████| 247/247 [00:02<00:00, 87.10batch/s] 


train loss : 1.7113 train accuracy: 90.48%



Epoch 4: 100%|██████████| 51/51 [00:00<00:00, 83.42batch/s] 


validation loss : 1.9322 validation accuracy: 69.45%

Epoch 5/20
----------


Epoch 5: 100%|██████████| 247/247 [00:02<00:00, 84.22batch/s] 


train loss : 1.7074 train accuracy: 90.54%



Epoch 5: 100%|██████████| 51/51 [00:00<00:00, 79.66batch/s] 


validation loss : 1.9278 validation accuracy: 69.48%

Epoch 6/20
----------


Epoch 6: 100%|██████████| 247/247 [00:02<00:00, 84.83batch/s] 


train loss : 1.7048 train accuracy: 90.63%



Epoch 6: 100%|██████████| 51/51 [00:00<00:00, 68.59batch/s]


validation loss : 1.9246 validation accuracy: 69.72%

Epoch 7/20
----------


Epoch 7: 100%|██████████| 247/247 [00:02<00:00, 84.50batch/s] 


train loss : 1.7025 train accuracy: 90.75%



Epoch 7: 100%|██████████| 51/51 [00:00<00:00, 95.69batch/s] 


validation loss : 1.9206 validation accuracy: 69.93%

Epoch 8/20
----------


Epoch 8: 100%|██████████| 247/247 [00:02<00:00, 91.46batch/s] 


train loss : 1.7011 train accuracy: 90.73%



Epoch 8: 100%|██████████| 51/51 [00:00<00:00, 69.30batch/s]


validation loss : 1.9201 validation accuracy: 69.87%

Epoch 9/20
----------


Epoch 9: 100%|██████████| 247/247 [00:02<00:00, 84.21batch/s] 


train loss : 1.6998 train accuracy: 90.79%



Epoch 9: 100%|██████████| 51/51 [00:00<00:00, 88.17batch/s] 


validation loss : 1.9182 validation accuracy: 70.02%

Epoch 10/20
----------


Epoch 10: 100%|██████████| 247/247 [00:02<00:00, 87.09batch/s] 


train loss : 1.6987 train accuracy: 90.80%



Epoch 10: 100%|██████████| 51/51 [00:00<00:00, 73.63batch/s]


validation loss : 1.9160 validation accuracy: 70.11%

Epoch 11/20
----------


Epoch 11: 100%|██████████| 247/247 [00:02<00:00, 84.23batch/s] 


train loss : 1.6979 train accuracy: 90.86%



Epoch 11: 100%|██████████| 51/51 [00:00<00:00, 84.24batch/s] 


validation loss : 1.9127 validation accuracy: 70.29%

Epoch 12/20
----------


Epoch 12: 100%|██████████| 247/247 [00:02<00:00, 94.32batch/s] 


train loss : 1.6973 train accuracy: 90.93%



Epoch 12: 100%|██████████| 51/51 [00:00<00:00, 74.83batch/s] 


validation loss : 1.9104 validation accuracy: 70.50%

Epoch 13/20
----------


Epoch 13: 100%|██████████| 247/247 [00:02<00:00, 91.27batch/s] 


train loss : 1.6958 train accuracy: 90.99%



Epoch 13: 100%|██████████| 51/51 [00:00<00:00, 76.28batch/s] 


validation loss : 1.9112 validation accuracy: 70.38%

Epoch 14/20
----------


Epoch 14: 100%|██████████| 247/247 [00:02<00:00, 85.15batch/s] 


train loss : 1.6960 train accuracy: 90.97%



Epoch 14: 100%|██████████| 51/51 [00:00<00:00, 79.68batch/s] 


validation loss : 1.9111 validation accuracy: 70.38%

Epoch 15/20
----------


Epoch 15: 100%|██████████| 247/247 [00:02<00:00, 91.43batch/s] 


train loss : 1.6959 train accuracy: 90.99%



Epoch 15: 100%|██████████| 51/51 [00:00<00:00, 72.63batch/s] 


validation loss : 1.9096 validation accuracy: 70.53%

Epoch 16/20
----------


Epoch 16: 100%|██████████| 247/247 [00:02<00:00, 92.96batch/s] 


train loss : 1.6957 train accuracy: 90.99%



Epoch 16: 100%|██████████| 51/51 [00:00<00:00, 72.35batch/s]


validation loss : 1.9107 validation accuracy: 70.41%

Epoch 17/20
----------


Epoch 17: 100%|██████████| 247/247 [00:02<00:00, 96.45batch/s] 


train loss : 1.6954 train accuracy: 91.03%



Epoch 17: 100%|██████████| 51/51 [00:00<00:00, 87.52batch/s] 


validation loss : 1.9101 validation accuracy: 70.38%

Epoch 18/20
----------


Epoch 18: 100%|██████████| 247/247 [00:02<00:00, 86.67batch/s] 


train loss : 1.6955 train accuracy: 91.01%



Epoch 18: 100%|██████████| 51/51 [00:00<00:00, 100.36batch/s]


validation loss : 1.9098 validation accuracy: 70.44%

Epoch 19/20
----------


Epoch 19: 100%|██████████| 247/247 [00:02<00:00, 84.08batch/s] 


train loss : 1.6955 train accuracy: 91.00%



Epoch 19: 100%|██████████| 51/51 [00:00<00:00, 74.81batch/s]


validation loss : 1.9091 validation accuracy: 70.50%

Epoch 20/20
----------


Epoch 20: 100%|██████████| 247/247 [00:02<00:00, 90.21batch/s] 


train loss : 1.6953 train accuracy: 91.02%



Epoch 20: 100%|██████████| 51/51 [00:00<00:00, 73.12batch/s]


validation loss : 1.9090 validation accuracy: 70.53%

Training complete in 1m 9s
Best validation accuracy: 70.53%


In [8]:
torch.manual_seed(121)

Net.eval()

target = torch.tensor([], dtype=torch.float32).to(DEVICE)
pred = torch.tensor([], dtype=torch.float32).to(DEVICE)

for images, labels in dataloaders["test"]:
    images = images.to(DEVICE)
    labels = labels.to(DEVICE)

    target = torch.cat((target, labels))

    outputs = Net(images)
    _, predictions = torch.max(outputs, 2)
    predictions = torch.squeeze(predictions, 1)

    pred = torch.cat((pred, predictions))

target, pred = target.to(torch.int32).cpu(), pred.to(torch.int32).cpu()

print(f"Accuracy : {round(accuracy_score(target, pred) * 100, 3)} %")
logging.info(f"Accuracy : {round(accuracy_score(target, pred) * 100, 3)} %")

Accuracy : 70.657 %


In [9]:
acc_per_class = accuracy_per_class(pred, target, ds.classes)

for style, acc in acc_per_class.items():
    print(f'Accuracy for {style}: {acc:.1f} %')

Accuracy for 01_Byzantin_Iconography: 94.7 %
Accuracy for 02_Early_Renaissance: 79.7 %
Accuracy for 03_Northern_Renaissance: 83.6 %
Accuracy for 04_High_Renaissance: 72.9 %
Accuracy for 05_Baroque: 58.0 %
Accuracy for 06_Rococo: 38.6 %
Accuracy for 07_Romanticism: 60.6 %
Accuracy for 08_Realism: 75.1 %
Accuracy for 09_Impressionism: 74.9 %
Accuracy for 10_Post_Impressionism: 58.2 %
Accuracy for 11_Expressionism: 43.8 %
Accuracy for 12_Symbolism: 66.5 %
Accuracy for 13_Fauvism: 61.0 %
Accuracy for 14_Cubism: 79.5 %
Accuracy for 15_Surrealism: 63.5 %
Accuracy for 16_AbstractArt: 68.5 %
Accuracy for 17_NaiveArt: 69.8 %
Accuracy for 18_PopArt: 74.3 %
Accuracy for 19_ChineseArt: 77.9 %
Accuracy for 20_JapaneseArt: 98.6 %


In [10]:
confusion_matrix(pred, target, ds.classes)

In [11]:
x = sorted(list(acc_per_class.items()), key=lambda x : x[1])[:3]

print(f"Mean accuracy for min 3 styles : {sum([el[1] for el in x]) / 3:.1f} %")

for style, acc in x:
    print(f'Accuracy for {style}: {acc:.1f} %')

Mean accuracy for min 3 styles : 46.8 %
Accuracy for 06_Rococo: 38.6 %
Accuracy for 11_Expressionism: 43.8 %
Accuracy for 05_Baroque: 58.0 %


: 