In [6]:
from deep_dating.networks import DatingCNN
from deep_dating.datasets import DatingDataLoader, DatasetName, SetType
from deep_dating.metrics import DatingMetrics
from deep_dating.preprocessing import PreprocessRunner
from deep_dating.util import SEED, save_figure
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
import pickle

In [7]:
def merge_patches(all_labels, all_outputs, all_paths, p=None):
    """Average per-patch outputs into a single feature vector per source image.

    Patches are grouped by ``PreprocessRunner.get_base_img_name(path)``. Each
    image keeps the label of its first patch, and its feature vector is the
    mean of its patch outputs over axis 0.

    Args:
        all_labels: Per-patch labels, indexable in parallel with ``all_paths``.
        all_outputs: Per-patch outputs (feature vectors), indexable in parallel
            with ``all_paths``.
        all_paths: Per-patch file paths.
        p: Optional container of base image names supporting ``in`` and
            iteration, e.g. the ``preds`` dict returned by an earlier call.
            When given, only images named in ``p`` are kept, and output rows
            follow ``p``'s iteration order. This lets features from two runs
            line up row-for-row.

    Returns:
        ``(labels, features, preds)``:
        - ``labels`` is a numpy array with one entry per image.
        - ``features`` is a numpy array of the per-image mean outputs.
        - ``preds`` maps each image name to ``[label, mean_output]``.

    Raises:
        KeyError: if ``p`` names an image that has no patch in ``all_paths``.
    """
    preds = {}

    for i, path in enumerate(all_paths):
        img_name = PreprocessRunner.get_base_img_name(path)
        # Compare against None so that an empty `p` filters everything out,
        # consistent with the row-ordering step below.
        if p is not None and img_name not in p:
            continue
        if img_name not in preds:
            preds[img_name] = [all_labels[i], [all_outputs[i]]]
        else:
            preds[img_name][1].append(all_outputs[i])

    # Row order: the caller-supplied `p` if given, else first-seen order.
    order = preds if p is None else p

    missing = [name for name in order if name not in preds]
    if missing:
        raise KeyError(
            f"{len(missing)} image(s) requested via p have no patches in all_paths, "
            f"e.g. {missing[0]!r}"
        )

    features = []
    labels = []
    for name in order:
        # Replace the list of patch outputs with their mean (values only are
        # mutated, so iterating `preds` itself here is safe).
        preds[name][1] = np.mean(preds[name][1], axis=0)
        features.append(preds[name][1])
        labels.append(preds[name][0])

    return np.array(labels), np.array(features), preds

In [8]:
# Load per-patch features from two training runs and average them per image.
# The "low" run (Feb15, epoch 0) defines which images are kept and their row
# order. The "high" run (Feb18, epoch 1) is merged against that order via
# train_p / val_p so that rows line up image-for-image across the two runs.
# NOTE(review): pickle.load can execute arbitrary code; only load run
# artefacts that were produced locally.
LOW_RUN = "runs/Feb15-19-21-40/model_epoch_0"
HIGH_RUN = "runs/Feb18-10-52-17/model_epoch_1"


def load_merged_feats(pkl_path, p=None):
    """Load a pickled ``(labels, features, paths)`` triple and merge patches per image.

    Args:
        pkl_path: Path to a ``*_feats_<split>.pkl`` file.
        p: Optional image-name container forwarded to ``merge_patches`` to
            restrict and order the output rows.

    Returns:
        ``(labels, features, preds, all_paths)``. The first three are
        ``merge_patches``'s return values; ``all_paths`` is the raw per-patch
        path list stored in the pickle.
    """
    with open(pkl_path, "rb") as f:
        labels, features, all_paths = pickle.load(f)
    # Merge after the file is closed. Labels are flattened first
    # (presumably stored as (n_patches, 1); TODO confirm).
    labels, features, preds = merge_patches(labels.flatten(), features, all_paths, p)
    return labels, features, preds, all_paths


labels_train_low, features_train_low, train_p, all_paths_train_low = load_merged_feats(f"{LOW_RUN}_feats_train.pkl")
labels_val_low, features_val_low, val_p, all_paths_val_low = load_merged_feats(f"{LOW_RUN}_feats_val.pkl")
labels_train_high, features_train_high, _, all_paths_train_high = load_merged_feats(f"{HIGH_RUN}_feats_train.pkl", train_p)
labels_val_high, features_val_high, _, all_paths_val_high = load_merged_feats(f"{HIGH_RUN}_feats_val.pkl", val_p)

# Rows are aligned by image name, so the per-image labels of both runs should agree.
assert np.array_equal(labels_train_low, labels_train_high), "train labels differ between low/high runs"
assert np.array_equal(labels_val_low, labels_val_high), "val labels differ between low/high runs"

In [None]:
np.shape(features_train_high)  # rows = images kept after merging the "high" run

In [None]:
np.shape(features_train_low)  # rows = images kept after merging the "low" run

In [9]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler, Normalizer
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
import keras
from keras import layers
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder


# Min-max scaler targeting the [0, 1] range; it is fitted in a later cell.
scaler = MinMaxScaler(copy=True, feature_range=(0, 1))

2024-02-27 15:40:17.809187: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-27 15:40:17.809225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-27 15:40:17.827550: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-27 15:40:17.857629: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Scale features to [0, 1]. The scaler is fit on the training split only and
# the same transform is then applied to validation. Two reasons:
# - Fitting on train+val leaks validation statistics into training.
# - Stacking validation rows onto the training matrix would stop it matching
#   y_train (built from labels_train_low) row-for-row.
# NOTE(review): the previous single-array np.hstack([features_val_low]) looks
# like a leftover from stacking low and high features side by side. To use both
# runs, hstack features_*_low with features_*_high for BOTH splits here.
new_train_x = scaler.fit_transform(features_train_low)
new_val_x = scaler.transform(features_val_low)

In [None]:
np.shape(new_train_x)  # scaled training matrix fed to the dense model

In [13]:
# Baseline: a default-parameter SVC on the per-image "low" features.
# NOTE(review): this uses the unscaled features, not new_train_x / new_val_x;
# confirm that this is intended.
# Commented-out alternatives from earlier experiments:
#   RandomForestClassifier(random_state=43)
#   MLPClassifier(batch_size=32, solver="adam", hidden_layer_sizes=(1536, 1000, 512),
#                 verbose=True, early_stopping=True, n_iter_no_change=5)
svm = SVC()
labels_val_low_predict = svm.fit(features_train_low, labels_train_low).predict(features_val_low)

In [None]:
# One-hot encode labels, with class indices learned from the training split.
label_encoder = LabelEncoder()
y_train_idx = label_encoder.fit_transform(labels_train_low)
num_classes = len(label_encoder.classes_)

# Pass num_classes explicitly. Without it, to_categorical infers max(y) + 1 on
# each call, so y_val would get fewer columns than y_train whenever the
# highest class is absent from the validation split.
y_train = to_categorical(y_train_idx, num_classes=num_classes)
y_val = to_categorical(label_encoder.transform(labels_val_low), num_classes=num_classes)

y_train.shape, y_val.shape

In [None]:
# Dense classifier on the scaled per-image features.
# - The input width comes from new_train_x; the number of output classes comes
#   from the one-hot labels (y_train).
# - Dropout(0.3) / BatchNormalization layers and an extra Dense(2048) layer
#   were commented out and have been removed; re-add them here to experiment.
keras.utils.set_random_seed(SEED)  # reproducible weight init and shuffling

model = keras.Sequential(
    [
        keras.Input(shape=(new_train_x.shape[1],)),  # shape must be a tuple
        layers.Dense(1024, activation="relu"),
        layers.Dense(1024, activation="relu"),
        layers.Dense(512, activation="relu"),
        layers.Dense(y_train.shape[1], activation="softmax"),
    ]
)
batch_size = 32
epochs = 100
learning_rate = 1e-5

model.compile(
    loss="categorical_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    metrics=["accuracy"],
)
history = model.fit(
    new_train_x,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(new_val_x, y_val),
)


In [14]:
# Score the SVC's validation predictions with DatingMetrics.
# - vals[0] is unpacked as MAE and vals[1] as MSE.
# - The remaining values are presumably one cumulative score per entry in
#   `alphas`; confirm against DatingMetrics.
alphas = [0, 25, 50]
metrics = DatingMetrics(alphas=alphas)

vals = metrics.calc(labels_val_low, labels_val_low_predict)

mae, mse = vals[:2]
cs_ = vals[2:]

print(mae, mse)
print(alphas)  # previously hard-coded as [0, 25], which omitted the last alpha
print(cs_)

# Earlier recorded results. The "[0, 25]" lines come from the old hard-coded
# print; each run produced three CS values, presumably for alphas [0, 25, 50].
# 19.00423728813559 1980.247175141243
# [0, 25]
# [64.97175141242938, 80.64971751412429, 91.66666666666666]

# 19.326271186440678 2154.9138418079096
# [0, 25]
# [65.11299435028248, 80.9322033898305, 91.38418079096046]

# 32.978813559322035 5692.001412429378
# [0, 25]
# [58.47457627118644, 72.03389830508475, 83.75706214689266]

19.00423728813559 1980.247175141243
[0, 25]
[64.97175141242938, 80.64971751412429, 91.66666666666666]
