In [None]:
import copy
import logging
import os
import random
import time

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from torch.optim import Adam
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from tqdm import tqdm

# Project-local modules. `Dataset` from dataset_preprocessing is imported after
# torch.utils.data.Dataset on purpose: it must remain the binding that wins.
from dataset_preprocessing import Paths, Dataset
from metrics import confusion_matrix, accuracy_per_class
from snn import ShallowNN
from utils import FocalLoss, MyDataset

In [None]:
# ---- Training hyper-parameters ----
BATCH_SIZE = 64
EPOCHS = 20

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# ---- File locations (everything lives under the Inception-V3 folder) ----
inc_v3_path =  Paths.pandora_18k + 'Conv_models/Inception-V3/'

# CSV files loaded into the train / validation / test frames.
train_path = inc_v3_path + 'train_full_emb.csv'
valid_path = inc_v3_path + 'valid_full_emb.csv'
test_path = inc_v3_path + 'test_full_emb.csv'

# The paths below are not referenced by the cells visible in this notebook section.
mistakes_path = inc_v3_path + 'mistakes.csv'

fake_Baroque_path = inc_v3_path + 'fake_emb_Baroque.csv'
fake_Rococo_path = inc_v3_path + 'fake_emb_Rococo.csv'
fake_Romanticism_path = inc_v3_path + 'fake_emb_Romanticism.csv'
fake_Expressionism_path = inc_v3_path + 'fake_emb_Expressionism.csv'
fake_PostImpressionism_path = inc_v3_path + 'fake_emb_Post-Impressionism.csv'

train_prob =  inc_v3_path + 'train_prob.csv'
valid_prob =  inc_v3_path + 'valid_prob.csv'
test_prob =  inc_v3_path + 'test_prob.csv'

# Where the best model weights are written after training.
model_save_path = inc_v3_path + 'snn.pth'

# Log file used by logging.basicConfig.
logging_file = "logs/snn.log"

In [None]:
# Project dataset descriptor; later cells read `ds.classes` from it.
ds = Dataset(Paths.pandora_18k)

# Read the three CSV splits in one pass (train, validation, test - in that order).
df_train, df_valid, df_test = (
    pd.read_csv(csv_path) for csv_path in (train_path, valid_path, test_path)
)

In [None]:
# Seed Python's RNG (kept from the original cell for anything that relies on it).
random.seed(42)

num_classes = len(ds.classes)


def _make_loader(dataset, *, reshuffle, drop_last):
    """Wrap `dataset` in a DataLoader with the notebook-wide batch size and worker count."""
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=BATCH_SIZE,
                                       shuffle=reshuffle,
                                       num_workers=4,
                                       drop_last=drop_last)


# sklearn.utils.shuffle draws from NumPy's RNG, which random.seed() does not touch,
# so the seed is passed explicitly to make the row order reproducible.
# NOTE(review): the training frame is train + validation rows, so every validation
# sample is also seen during training; validation metrics (and the "best model"
# chosen from them) are therefore optimistic. Confirm this is intentional.
df_train_full = shuffle(pd.concat([df_train, df_valid], axis=0), random_state=42)

dataset_train = MyDataset(df_train_full, num_classes=num_classes)
dataset_valid = MyDataset(df_valid, num_classes=num_classes)
dataset_test = MyDataset(df_test, num_classes=num_classes)

# Training: reshuffle every epoch and drop the ragged last batch.
dataloader_train = _make_loader(dataset_train, reshuffle=True, drop_last=True)

# Validation: fixed order so every epoch is scored on the same samples.
# NOTE(review): drop_last is kept, so the final partial batch is not evaluated;
# switch it to False once the training loop is confirmed to cope with a final
# batch of size 1.
dataloader_valid = _make_loader(dataset_valid, reshuffle=False, drop_last=True)

# Test: fixed order and no dropped batch, so accuracy covers every test sample.
dataloader_test = _make_loader(dataset_test, reshuffle=False, drop_last=False)

dataloaders = {"train" : dataloader_train, "validation" : dataloader_valid, "test" : dataloader_test}

In [None]:
SEED = 121

torch.manual_seed(SEED)

# logging.basicConfig opens the log file immediately and fails if its directory
# is missing, so make sure the directory exists first.
log_dir = os.path.dirname(logging_file)
if log_dir:
    os.makedirs(log_dir, exist_ok=True)

logging.basicConfig(level=logging.INFO, filename=logging_file,filemode="a",
                    format="%(asctime)s %(levelname)s %(message)s")

logging.info(f"Seed {SEED}")

Net = ShallowNN().to(DEVICE)

# Human-readable names used only in the log line below.
optimizer_name = "Adam"

lr = 0.003

criterion_name = "FocalLoss"

#criterion_name = "CrossEntropy"

# `capturable=True` is only valid when parameters live on a CUDA device; with the
# CPU fallback of DEVICE it makes the optimizer step fail, so enable it on CUDA only.
optimizer = Adam(Net.parameters(), lr=lr, capturable=(DEVICE.type == "cuda"))

criterion = FocalLoss(reduction="mean", gamma=2)

#criterion = nn.CrossEntropyLoss()

# Lowers the learning rate when the metric passed to scheduler.step() stops decreasing.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

# Note: this formats the bound `parameters` method (whose repr embeds the module
# repr), not the parameter tensors themselves.
logging.info(f"Net parameters {Net.parameters}")
logging.info(f"Optimizer :{optimizer_name}, lr : {lr}, criterion : {criterion_name}")

In [None]:
from torchinfo import summary

# Shape of one dummy batch fed through the network to build the layer summary.
summary_input_shape = (BATCH_SIZE, 1, 120)

summary(Net, input_size=summary_input_shape)

In [None]:
torch.manual_seed(SEED)

# Per-epoch history; losses and accuracies are stored as plain Python floats.
statistics_data = {
            'number of epochs' : range(1,EPOCHS+1),
            'training loss' : [],
            'validation loss' : [],
            'training accuracy' : [],
            'validation accuracy' : []
        }

start_time = time.time()

# Track the weights that achieved the highest validation accuracy so far.
best_acc = 0.0
best_model_wts = copy.deepcopy(Net.state_dict())

for epoch in range(1, EPOCHS+1):
    print(f'Epoch {epoch}/{EPOCHS}')
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'validation']:
        is_training = phase == 'train'
        if is_training:
            Net.train()  # Set model to training mode
        else:
            Net.eval()   # Set model to evaluate mode

        running_loss = 0.0
        running_corrects = 0
        # Count the samples actually processed: a loader built with drop_last=True
        # yields fewer samples than len(dataset), which would bias the averages.
        samples_seen = 0

        # Iterate over data.
        with tqdm(dataloaders[phase], unit='batch') as tepoch:
            tepoch.set_description(f"Epoch {epoch}")
            for inputs, labels in tepoch:
                inputs = inputs.to(DEVICE)
                labels = labels.to(DEVICE, dtype=torch.long)

                # zero the parameter gradients
                if is_training:
                    optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(is_training):
                    # Drop only the singleton dim 1; a bare squeeze() would also
                    # collapse the batch dim for a batch of size 1.
                    outputs = Net(inputs).squeeze(1)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # statistics
                batch_len = inputs.size(0)
                running_loss += loss.item() * batch_len
                running_corrects += (preds == labels).sum().item()
                samples_seen += batch_len

        denominator = max(samples_seen, 1)  # guard against an empty loader
        epoch_loss = running_loss / denominator
        epoch_acc = running_corrects / denominator

        if is_training:
            statistics_data['training loss'].append(epoch_loss)
            statistics_data['training accuracy'].append(epoch_acc)
        else:
            statistics_data['validation loss'].append(epoch_loss)
            statistics_data['validation accuracy'].append(epoch_acc)

            # ReduceLROnPlateau must be fed the monitored metric; a constant value
            # never "improves", so the LR would decay on a fixed timer instead.
            scheduler.step(epoch_loss)

            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(Net.state_dict())

        print(f'{phase} loss : {epoch_loss:.4f} {phase} accuracy: {epoch_acc*100:.2f}%')
        logging.info(f'{phase} loss : {epoch_loss:.4f} {phase} accuracy: {epoch_acc*100:.2f}%')

        print()

time_elapsed = time.time() - start_time
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
logging.info(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best validation accuracy: {best_acc*100:.2f}%')
logging.info(f'Best validation accuracy: {best_acc*100:.2f}%')

# Restore the best-scoring weights before persisting the model.
Net.load_state_dict(best_model_wts)
torch.save(Net.state_dict(), model_save_path)

In [None]:
# Two independent, initially empty frames for mistakes on the train / validation splits.
Mistakes_train, Mistakes_valid = pd.DataFrame(), pd.DataFrame()

In [None]:
torch.manual_seed(SEED)

Net.eval()

# Per-batch results are collected on the CPU and concatenated once at the end,
# instead of growing a device tensor with torch.cat on every iteration.
target_batches = []
pred_batches = []

# NOTE(review): accuracy covers exactly the batches the "test" loader yields; if
# that loader drops its last partial batch, those samples are not scored.
with torch.no_grad():  # inference only - no autograd graph needed
    for images, labels in dataloaders["test"]:
        images = images.to(DEVICE)

        outputs = Net(images)
        # Class index along dim 2, then drop the singleton dim 1 -> one label per sample.
        predictions = outputs.argmax(dim=2).squeeze(1)

        target_batches.append(labels.cpu())
        pred_batches.append(predictions.cpu())

target = torch.cat(target_batches).to(torch.int32)
pred = torch.cat(pred_batches).to(torch.int32)

test_accuracy = round(accuracy_score(target, pred) * 100, 3)

print(f"Accuracy : {test_accuracy} %")
logging.info(f"Accuracy : {test_accuracy} %")

In [None]:
# Accuracy per style, computed by the project helper from the test predictions.
acc_per_class = accuracy_per_class(pred, target, ds.classes)

# Report one line per style.
for style in acc_per_class:
    acc = acc_per_class[style]
    print(f'Accuracy for {style}: {acc:.1f} %')

In [None]:
# Confusion matrix for the test predictions, produced by the project's own
# `metrics.confusion_matrix` helper (not sklearn's function of the same name).
# NOTE(review): the (pred, target, class names) argument order is taken as-is -
# confirm it matches the helper's signature.
confusion_matrix(pred, target, ds.classes)

In [None]:
# How many of the lowest-accuracy styles to report.
N_WORST = 3

# (style, accuracy) pairs sorted ascending by accuracy, truncated to the N_WORST lowest.
worst_styles = sorted(acc_per_class.items(), key=lambda item: item[1])[:N_WORST]
x = worst_styles  # original variable name, kept as an alias

if worst_styles:
    # Divide by the number of styles actually selected, which can be smaller than
    # N_WORST when there are fewer classes.
    mean_worst = sum(style_acc for _, style_acc in worst_styles) / len(worst_styles)
    print(f"Mean accuracy for min {len(worst_styles)} styles : {mean_worst:.1f} %")

for style, acc in worst_styles:
    print(f'Accuracy for {style}: {acc:.1f} %')