# Analysis Notebook
This notebook is used to analyze the trained model and the dataset. It is used to visualize the dataset and the model's performance.

In [18]:
import json
import timm
from fastai.imports import *
from fastai.vision.all import *
from data import get_dls_from_images


In [19]:
with open("../config.json", "r") as config_file:
    config = json.load(config_file)
    
metrics_dict = {
            "f1_score": F1Score(),
            "precision": Precision(),
            "recall": Recall(),
            "accuracy": accuracy,
        }

In [15]:
def load_saved(timm_model_name: str, model_save: str = "output"):
    """
    Loads a saved model from disk.

    Parameters
    ----------
    model_name : str
        The name of the model to load, excluding `.pth`.

    Returns
    -------
    Timm model
        The loaded model.
    """
    model = timm.create_model(
        model_name=timm_model_name,
        checkpoint_path=model_save,
        num_classes=2,
    )
    model.eval()
    return model
model = load_saved("efficientnet_b0", "../output/efficientnet_b0-0.9324.pth")

In [16]:

# Load the data
config["data"]["image_dir"] = f"../{config['data']['image_dir']}"
train_dl, val_dl, test_dl = get_dls_from_images(config=config)
dls = DataLoaders(train_dl, val_dl)


In [24]:
# Load the saved learner
learn = Learner(dls, model)
learn.load("efficientnet_b0-0.9324")

# Assuming 'learn' is your Learner and 'test_dl' is your DataLoader for the test set
#preds, targets = learn.get_preds(dl=test_dl)

  elif with_opt: warn("Saved file doesn't contain an optimizer state.")


<fastai.learner.Learner at 0x235a12f8940>

In [None]:
# Convert predictions to actual class labels (assuming single-label classification)
predicted_labels = preds.argmax(dim=1)

# Find incorrect classifications
incorrects = predicted_labels != targets
incorrect_images = test_dl.dataset.items[incorrects]  # Assuming `items` holds file paths or PIL images
incorrect_preds = predicted_labels[incorrects]
actual_labels = targets[incorrects]


In [None]:
import matplotlib.pyplot as plt

def show_incorrect_images(incorrect_images, incorrect_preds, actual_labels, n_show=5):
    fig, axs = plt.subplots(1, n_show, figsize=(10, 2))
    for i, ax in enumerate(axs.flatten()):
        img = PILImage.create(incorrect_images[i])
        ax.imshow(img)
        ax.set_title(f'Pred: {incorrect_preds[i]}\nActual: {actual_labels[i]}')
        ax.axis('off')
    plt.show()

show_incorrect_images(incorrect_images, incorrect_preds, actual_labels)
