In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

import torch
import torch.nn.functional as F
import ipywidgets as widgets
import matplotlib.pyplot as plt

from torchvision.transforms import v2
from PIL import Image

from weather_classification import MODELS_DIR, PROCESSED_DATA_DIR

#### Paths

In [None]:
# Test split of the processed dataset; each sub-directory name is used below
# as the ground-truth class label of the images it contains.
data_fpath = PROCESSED_DATA_DIR / "WeatherDataset" / "test"

In [None]:
# TorchScript archive of the classifier, stored in the project's models directory.
model_fname = "EfficientNet.torchscript"
model_fpath = (
    MODELS_DIR / model_fname
)

#### Load model

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
# `map_location` places the deserialized TorchScript module on that device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = torch.jit.load(model_fpath, map_location=device)

#### **Classification**: single example

In [None]:
# Class names stored as an attribute on the scripted model; indexed by the
# predicted class id further below.
class_labels = model.class_labels

# Evaluation-time preprocessing: resize to 256x256 (bicubic), center-crop to
# 224x224, convert to a float32 CHW tensor in [0, 1], then normalize with the
# standard ImageNet channel mean/std.
# NOTE(review): these values must match the training/validation transforms — confirm.
inference_transforms = v2.Compose([
    v2.Resize((256, 256), interpolation=v2.InterpolationMode.BICUBIC),
    v2.CenterCrop((224, 224)),
    # `v2.ToTensor()` is deprecated; `ToImage` + `ToDtype(..., scale=True)` is the
    # documented replacement (uint8 [0, 255] -> float32 [0, 1]).
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Class sub-directories of the test split. Sorted so that `idx_dir` selects the
# same class on every machine (directory listing order is filesystem-dependent);
# stray files are ignored.
lst_data = sorted(p for p in data_fpath.iterdir() if p.is_dir())

# Index of the class directory to browse: 0 .. len(lst_data) - 1
idx_dir = 5
assert 0 <= idx_dir < len(lst_data), f"idx_dir must be in [0, {len(lst_data) - 1}]"
lst_images = sorted(p for p in lst_data[idx_dir].iterdir() if p.is_file())
assert lst_images, f"No images found in {lst_data[idx_dir]}"

model.eval()


def _predict(img):
    """Classify one RGB PIL image.

    Returns:
        (label, prob): the predicted class label and its softmax probability
        as a Python float.
    """
    batch = inference_transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Single-image batch: take row 0, then the top-1 probability and its index.
    probs = F.softmax(logits, dim=1)[0].cpu()
    prob, pred_idx = probs.max(dim=0)
    return class_labels[pred_idx.item()], prob.item()


@widgets.interact
def show_image(img_idx=widgets.IntSlider(value=0, min=0, max=len(lst_images) - 1)):
    """Display image `img_idx` of the selected class with the model's top-1 prediction."""
    img_fpath = lst_images[img_idx]
    print(f"True label: {lst_data[idx_dir].name}")

    # Load eagerly so the file handle is closed, and force 3 channels so that
    # grayscale / RGBA / palette images match the 3-channel Normalize step.
    with Image.open(img_fpath) as raw_img:
        img = raw_img.convert("RGB")

    label, prob = _predict(img)

    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Pred cls: {label}; Prob: {prob:.2f}")
    plt.show()