In [1]:
import torch
from pathlib import Path
import numpy as np
from tqdm import tqdm
import cv2
from cnn4 import CNN4
from datamodule import Datamodule
from params import LocationConfig, TrainingConfig

In [2]:
# Rebuild the network and load the best weights stored at LocationConfig.best_model.
cnn4 = CNN4(lr=TrainingConfig.lr)
# map_location="cpu": the freshly constructed model's parameters are on the CPU, and this lets a
# state_dict that was saved from GPU tensors load on a machine without CUDA.
state_dict = torch.load(LocationConfig.best_model, map_location="cpu")
cnn4.load_state_dict(state_dict)
cnn4.eval();  # switch layers like dropout/batch-norm to inference behaviour; `;` hides the module repr

In [3]:
# Datamodule over the `data/` dataset; only the validation split is prepared (`val_only=True`).
train_data_path, test_data_path = (Path('data') / split for split in ('train', 'test'))
dm = Datamodule(
    val_dir=test_data_path,
    train_dir=train_data_path,
    batch_size=TrainingConfig.batch_size,
)
dm.setup(val_only=True)

file: data/test/test_clselfie_v7.pickle


In [7]:
# Take the first validation batch from `dm` and check the input tensor layout.
# next() fails loudly on an empty loader, so X/Y can never be stale values left by an earlier cell.
batch = next(iter(dm.val_dataloader()))
X, Y = batch['normalized'], batch['label']
X.shape

torch.Size([128, 208, 208, 1])


In [8]:
# Raw values of the batch taken from `dm` (torch elides the middle of large tensors in its repr).
X

tensor([[[[0.2105],
          [0.2105],
          [0.2105],
          ...,
          [0.2672],
          [0.2632],
          [0.2632]],

         [[0.2146],
          [0.2146],
          [0.2105],
          ...,
          [0.2672],
          [0.2632],
          [0.2632]],

         [[0.2227],
          [0.2146],
          [0.2105],
          ...,
          [0.2713],
          [0.2672],
          [0.2672]],

         ...,

         [[0.0364],
          [0.0405],
          [0.0445],
          ...,
          [0.0769],
          [0.0769],
          [0.0769]],

         [[0.0364],
          [0.0405],
          [0.0405],
          ...,
          [0.0769],
          [0.0810],
          [0.0810]],

         [[0.0324],
          [0.0324],
          [0.0364],
          ...,
          [0.0810],
          [0.0810],
          [0.0769]]],


        [[[0.6584],
          [0.6683],
          [0.6881],
          ...,
          [0.5099],
          [0.4752],
          [0.4802]],

         [[0.5396],
    

In [5]:
# Datamodule over the `data_connected/` dataset, for comparison with `dm`;
# only the validation split is prepared (`val_only=True`).
train_data_path, test_data_path = (Path('data_connected') / split for split in ('train', 'test'))
dm_con = Datamodule(
    val_dir=test_data_path,
    train_dir=train_data_path,
    batch_size=TrainingConfig.batch_size,
)
dm_con.setup(val_only=True)

file: data_connected/test/test.pickle


In [6]:
# Take the first validation batch of the connected dataset and check its input layout.
# NOTE(review): this inspects the 'original' view (presumably un-normalized — confirm), unlike the
# 'normalized' view inspected for `dm`.
# next() fails loudly on an empty loader, so X/Y can never be stale values left by an earlier cell.
batch = next(iter(dm_con.val_dataloader()))
X, Y = batch['original'], batch['label']
X.shape

torch.Size([128, 208, 208, 1])


In [None]:
# Load the most recently written Lightning checkpoint from LocationConfig.checkpoints_dir.
checkpoints_dir = Path(LocationConfig.checkpoints_dir)
list_of_checkpoints = list(checkpoints_dir.glob("*.ckpt"))
if not list_of_checkpoints:
    raise FileNotFoundError(f"No *.ckpt files found in {checkpoints_dir}")
# st_mtime (last content write) rather than st_ctime, which on POSIX is the inode-change time
# and is also bumped by chmod/rename.
latest_checkpoint_path = max(list_of_checkpoints, key=lambda p: p.stat().st_mtime)

# NOTE(review): assumes CNN4 subclasses LightningModule (it is restored from .ckpt files).
# `load_from_checkpoint` is a classmethod that *returns* a new model; calling it on an existing
# instance and discarding the result leaves that instance with its initial weights.
lightning_model = CNN4.load_from_checkpoint(
    latest_checkpoint_path,
    map_location="cpu",   # allow GPU-saved checkpoints to load without CUDA
    lr=TrainingConfig.lr,  # passed through to CNN4(...) in case lr was not saved as a hyperparameter
)
lightning_model.eval();

In [None]:
# Per-label accuracy of `cnn4` on the validation split of `dm`, using a fixed decision threshold.
limit = 4/7  # score above which a label counts as predicted positive (author's choice; rationale not recorded)

correct_per_class = 0  # becomes an array of per-label hit counts after the first batch
n_samples = 0
with torch.no_grad():  # evaluation only: don't build an autograd graph for every batch
    for batch in tqdm(dm.val_dataloader(), desc="evaluating"):
        # Batches are dicts keyed by 'normalized' / 'original' / 'label'; tuple-unpacking a dict
        # would iterate over its keys, so index explicitly.
        # NOTE(review): assumes the model expects the 'normalized' view — confirm against training code.
        X, Y = batch['normalized'], batch['label']
        Y_pred = torch.as_tensor(cnn4.predict_step(X, None)).cpu().numpy()
        Y_pred = np.where(Y_pred > limit, 1, 0)
        Y_true = torch.as_tensor(Y).cpu().numpy()
        # presumably (batch, n_labels) for both arrays — axis=0 sums hits per label over the batch
        correct_per_class = correct_per_class + np.sum(Y_pred == Y_true, axis=0)
        n_samples += len(Y_true)

if n_samples == 0:
    raise ValueError("Validation dataloader yielded no samples; cannot compute accuracy.")
# Divide once by the total sample count so a smaller final batch is not over-weighted.
acc_class_global = correct_per_class / n_samples
print(acc_class_global)
print(acc_class_global.mean())