In [2]:
import torch
import os
from tqdm import tqdm
from S1_CNN_Model import CNN_Model
from S2_TimberDataset import TimberDataset, compile_image_df
from S3_intermediateDataset import build_intermediate_dataset_if_not_exists, intermediate_dataset

from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
# Build the image table for the held-out test folder. Only the first of the two
# values returned by compile_image_df is kept; the second is discarded.
# NOTE(review): presumably split_at=1.0 places every image in the first split,
# leaving the discarded one empty — confirm against S2_TimberDataset.
test_df, _ = compile_image_df("data/image/test", split_at=1.0)

def listdir_full(path: str) -> list[str]:
    """Return the entries of directory ``path``, each prefixed with ``path/``.

    ``os.listdir`` yields names in arbitrary, platform-dependent order, so the
    names are sorted here. This keeps the order in which callers process the
    files (and the relative order of items that tie in a later stable sort)
    reproducible between runs and machines.

    Args:
        path: Directory to list.

    Returns:
        One ``f"{path}/{name}"`` string per directory entry, sorted by name.
        An empty directory gives an empty list.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. ``path`` does not exist).
    """
    # A literal "/" separator is kept (rather than os.path.join) so the
    # returned strings look the same on every platform.
    return [f"{path}/{name}" for name in sorted(os.listdir(path))]

# Preprocessing applied to every test image: resize to 320x320, then convert
# to a tensor.
transform = transforms.Compose([transforms.Resize((320, 320)), transforms.ToTensor()])

# NOTE(review): is_train=True and shuffle=True are unusual for a test set;
# what is_train changes is defined in S2_TimberDataset — confirm both are
# intended here.
test_loader = DataLoader(
    TimberDataset(test_df, is_train=True, transform=transform),
    batch_size=12,
    shuffle=True,
)

# Create the intermediate dataset named "test" from the loader above (only if
# it does not exist yet, going by the helper's name), then rebind `test_loader`
# to that intermediate dataset. The first argument is an identity function;
# its role is defined in S3_intermediateDataset.
build_intermediate_dataset_if_not_exists(lambda batch: batch, "test", test_loader)
test_loader = intermediate_dataset("test")

In [3]:
def evaluate_model(model: CNN_Model, loader=None):
    """Compute the per-sample classification accuracy of ``model``.

    Args:
        model: Network to evaluate. It is switched to evaluation mode for the
            duration of the call (``torch.no_grad()`` alone does not disable
            dropout or batch-norm statistic updates) and its previous mode is
            restored afterwards.
            NOTE(review): relies on CNN_Model being a ``torch.nn.Module``
            (``training`` / ``eval()`` / ``train()``) — confirm in S1_CNN_Model.
        loader: Iterable of ``(images, labels)`` pairs. When omitted, the
            notebook-level ``test_loader`` is looked up at call time, which is
            what this function did before the parameter existed.

    Returns:
        Accuracy as a percentage (``n_correct / n_samples * 100``).

    Raises:
        ValueError: If the loader yields no samples at all.
    """
    if loader is None:
        loader = test_loader

    n_correct = 0
    n_samples = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad(), tqdm(loader, position=1, leave=False) as pb:
            for images, labels in pb:
                # Drop the leading axis of each item; this reshape only
                # succeeds when that axis has length 1.
                images = images.reshape(images.shape[1:])
                labels = labels.reshape(labels.shape[1:])

                # Call the module itself rather than .forward() so that any
                # registered hooks run as well.
                x = model(images)
                # Index of the largest output along dim 1 is the prediction.
                _, predictions = torch.max(x, 1)

                n_samples += labels.shape[0]
                n_correct += (predictions == labels).sum().item()

                pb.set_description(f"{n_correct}/{n_samples} correct predictions ({n_correct/n_samples*100 :.2f}%)")
    finally:
        # Leave the caller's model in the mode it arrived in.
        model.train(was_training)

    if n_samples == 0:
        raise ValueError("evaluate_model: the loader produced no samples")

    return n_correct/n_samples*100

# Score every file found in "ckpt" with evaluate_model and rank the
# (path, accuracy) pairs from best to worst.
# NOTE(review): torch.load unpickles the file, which can run arbitrary code —
# only point this at checkpoints you trust.
accuracies = sorted(
    (
        (model_path, evaluate_model(torch.load(model_path)))
        for model_path in tqdm(listdir_full("ckpt"), position=0)
    ),
    key=lambda entry: entry[1],
    reverse=True,
)
accuracies

100%|██████████| 17/17 [09:00<00:00, 31.82s/it]

[('ckpt/model_37.pt', 92.36111111111111), ('ckpt/model_36.pt', 91.39957264957265), ('ckpt/model.pt', 90.11752136752136), ('ckpt/model_23.pt', 90.11752136752136), ('ckpt/model_27.pt', 90.03739316239316), ('ckpt/model_19.pt', 89.95726495726495), ('ckpt/model_17.pt', 88.56837606837607), ('ckpt/model_16.pt', 87.79380341880342), ('ckpt/model_15.pt', 85.79059829059828), ('ckpt/model_13.pt', 85.04273504273505), ('ckpt/model_11.pt', 85.01602564102564), ('ckpt/model_14.pt', 84.58867521367522), ('ckpt/model_7.pt', 79.8076923076923), ('ckpt/model_6.pt', 79.27350427350427), ('ckpt/model_5.pt', 75.58760683760684), ('ckpt/model_4.pt', 70.86004273504274), ('ckpt/model_3.pt', 63.67521367521367)]





In [4]:
# Same dataset and batch size as a shuffled loader would use, but with no
# `shuffle` argument, so DataLoader keeps the dataset's own order.
test_loader = DataLoader(
    TimberDataset(test_df, is_train=True, transform=transform),
    batch_size=12,
)

# Create the intermediate dataset named "test_full" from the loader above (only
# if it does not exist yet, going by the helper's name). The first argument is
# an identity function; its role is defined in S3_intermediateDataset.
build_intermediate_dataset_if_not_exists(lambda batch: batch, "test_full", test_loader)

# Rebinds the global `test_loader`: anything that reads it after this cell has
# run sees the "test_full" dataset.
test_loader = intermediate_dataset("test_full")

In [6]:
def full_images_evaluation(model: CNN_Model, loader=None):
    """Compute accuracy with one majority-voted prediction per loader item.

    Every item yielded by the loader must carry a single label value. The
    model's per-row predictions for that item are reduced to their most
    frequent class, and that one vote is compared with the label.
    NOTE(review): presumably each item holds the crops of one full image —
    confirm against how the "test_full" dataset is built.

    Args:
        model: Network to evaluate. Its ``image_size`` attribute is set to
            ``(320, 320)`` (this mutation persists after the call). The model
            is switched to evaluation mode for the duration of the call
            (``torch.no_grad()`` alone does not disable dropout or batch-norm
            statistic updates) and its previous mode is restored afterwards.
            NOTE(review): relies on CNN_Model being a ``torch.nn.Module``
            (``training`` / ``eval()`` / ``train()``) — confirm in S1_CNN_Model.
        loader: Iterable of ``(images, labels)`` pairs. When omitted, the
            notebook-level ``test_loader`` is looked up at call time, which is
            what this function did before the parameter existed.

    Returns:
        Accuracy as a percentage (``n_correct / n_samples * 100``), where
        ``n_samples`` counts loader items, not individual rows.

    Raises:
        AssertionError: If the labels within one item are not all equal.
        ValueError: If the loader yields no items at all.
    """
    if loader is None:
        loader = test_loader

    # What CNN_Model does with this attribute is defined in S1_CNN_Model.
    model.image_size = (320,320)
    n_correct = 0
    n_samples = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad(), tqdm(loader, position=1, leave=False) as pb:
            for images, labels in pb:
                # Drop the leading axis of each item; this reshape only
                # succeeds when that axis has length 1.
                images = images.reshape(images.shape[1:])
                labels = labels.reshape(labels.shape[1:])

                # One vote per item only makes sense if the item has one label.
                assert torch.all(labels == labels[0]).item()
                label = labels[0]

                # Call the module itself rather than .forward() so that any
                # registered hooks run as well.
                x = model(images)
                _, preds = torch.max(x, 1)
                # Majority vote: most frequent predicted class across the rows.
                pred = torch.mode(preds, 0).values

                n_samples += 1
                n_correct += int((pred == label).item())

                pb.set_description(f"{n_correct}/{n_samples} correct predictions ({n_correct/n_samples*100 :.2f}%)")
    finally:
        # Leave the caller's model in the mode it arrived in.
        model.train(was_training)

    if n_samples == 0:
        raise ValueError("full_images_evaluation: the loader produced no samples")

    return n_correct/n_samples*100

# Score every file found in "ckpt" with the majority-vote evaluation and rank
# the (path, accuracy) pairs from best to worst.
# NOTE(review): torch.load unpickles the file, which can run arbitrary code —
# only point this at checkpoints you trust.
full_accuracies = sorted(
    (
        (model_path, full_images_evaluation(torch.load(model_path)))
        for model_path in tqdm(listdir_full("ckpt"), position=0)
    ),
    key=lambda entry: entry[1],
    reverse=True,
)
print(full_accuracies)

100%|██████████| 17/17 [08:32<00:00, 30.14s/it]

[('ckpt/model_37.pt', 97.11538461538461), ('ckpt/model_23.pt', 96.15384615384616), ('ckpt/model_36.pt', 96.15384615384616), ('ckpt/model_27.pt', 95.51282051282051), ('ckpt/model_19.pt', 95.1923076923077), ('ckpt/model.pt', 94.23076923076923), ('ckpt/model_17.pt', 94.23076923076923), ('ckpt/model_16.pt', 93.26923076923077), ('ckpt/model_15.pt', 92.94871794871796), ('ckpt/model_11.pt', 91.98717948717949), ('ckpt/model_13.pt', 91.34615384615384), ('ckpt/model_14.pt', 91.02564102564102), ('ckpt/model_7.pt', 88.78205128205127), ('ckpt/model_6.pt', 86.85897435897436), ('ckpt/model_5.pt', 82.05128205128204), ('ckpt/model_4.pt', 78.84615384615384), ('ckpt/model_3.pt', 70.83333333333334)]





<zip at 0x286ec807c00>