In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
from torchvision.models import efficientnet_v2_s
from torchvision import datasets, transforms 

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
MODEL_PATH = '/kaggle/input/efficientnetv2_shots_by_plan/pytorch/73.14/1/model 73.14.pth'

# The checkpoint is a whole pickled nn.Module (we call .eval() on it), not a state_dict.
# weights_only=False is required for that on recent PyTorch, whose default (True) refuses
# arbitrary pickled classes. NOTE: unpickling can execute arbitrary code -- only load
# trusted files such as this Kaggle model artifact.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only session.
model = torch.load(MODEL_PATH, map_location=device, weights_only=False)

# Put the model on the same device the inputs are sent to, and switch to inference mode.
# (Assigning the result keeps the cell from echoing the model's repr.)
model = model.to(device).eval()

In [None]:
data_dir = "/kaggle/input/movie-images-by-types-of-shooting-plans/data/validation"

# Per-channel normalization statistics (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Seed for the loader's shuffle so the sampled error gallery is reproducible.
SEED = 42

val_transforms = transforms.Compose([
    # NOTE(review): 380x380 presumably matches the checkpoint's training resolution -- confirm.
    transforms.Resize((380, 380)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

val_dataset = datasets.ImageFolder(data_dir, val_transforms)

# shuffle=True mixes classes within each batch (ImageFolder lists samples grouped by
# class folder), so the error gallery is not dominated by the first classes. The seeded
# generator makes that shuffle deterministic across runs.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# ImageFolder takes class names from sub-folder names; replace any URL-encoded
# spaces ("%20") with real spaces for display.
class_names = [folder_name.replace('%20', ' ') for folder_name in val_dataset.classes]

In [None]:
def unnormalize(img, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Invert ``transforms.Normalize(mean, std)`` so an image tensor can be displayed.

    Computes ``img * std + mean`` per channel. The defaults are the values passed
    to ``transforms.Normalize`` in ``val_transforms`` (the standard ImageNet stats).

    Args:
        img: tensor whose channel axis is third from the end, e.g. (C, H, W) or
            (N, C, H, W), with C == len(mean) == len(std).
        mean: per-channel means that were subtracted during normalization.
        std: per-channel standard deviations that were divided out.

    Returns:
        A new tensor on ``img``'s device (``img`` is not modified). Floating inputs
        keep their dtype; non-floating inputs are promoted to float32. Values are not
        clamped and may fall slightly outside [0, 1].
    """
    dtype = img.dtype if img.is_floating_point() else torch.float32
    # Build the constants on img's device so CUDA inputs work, shaped (C, 1, 1)
    # so they broadcast over H and W (and over a leading batch axis if present).
    mean_t = torch.as_tensor(mean, dtype=dtype, device=img.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=dtype, device=img.device).view(-1, 1, 1)
    return img * std_t + mean_t

def _draw_error_panel(ax, image, title, caption, back_color, font_alpha):
    """Render one misclassified image on ``ax`` with a title (top-left) and caption (top-right).

    Both text boxes use a ``back_color`` background and opacity ``font_alpha``.
    """
    ax.axis('off')
    ax.set_title(title,
                 pad=0,
                 y=0.97,
                 x=0,
                 color='goldenrod',
                 fontsize=25,
                 backgroundcolor=back_color,
                 verticalalignment='top',
                 horizontalalignment='left',
                 fontstyle='oblique',
                 alpha=font_alpha)
    ax.imshow(image)
    # With transform=ax.transAxes, (1, 0.97) is the top-right corner in axes fractions.
    ax.text(1,
            0.97,
            caption,
            color='seashell',
            backgroundcolor=back_color,
            fontsize=16,
            verticalalignment='top',
            horizontalalignment='right',
            transform=ax.transAxes,
            fontstyle='italic',
            alpha=font_alpha)


def visualize_model_errors(model, dataloader, device, class_names,  num_images=6, top_k=3):
    """Plot a two-column gallery of misclassified images with the model's top predictions.

    Runs ``model`` over ``dataloader`` in eval mode without gradients. For each sample
    whose argmax prediction differs from its label, it draws the image. The panel is
    titled with the true class (top-left). It is annotated with the ``top_k`` most
    probable classes and their softmax probabilities (top-right). Drawing stops after
    ``num_images`` errors or when the loader is exhausted, whichever comes first.

    Args:
        model: classifier whose output has classes along dim 1 (presumably logits).
            It is not moved here and must already live on ``device``.
        dataloader: yields ``(inputs, labels)`` batches. Images are assumed to be
            normalized with the defaults of ``unnormalize`` -- TODO confirm if reused.
        device: device that inputs and labels are moved to.
        class_names: sequence mapping a class index to its display name.
        num_images: maximum number of misclassified images to draw. Nothing is
            drawn if it is <= 0.
        top_k: number of predictions listed per panel, capped at the number of classes.

    The model's original train/eval mode is restored on exit, even if an error occurs.
    The figure is drawn through pyplot; nothing is returned.
    """
    if num_images <= 0:
        return

    n_cols = 2
    # Round up so an odd num_images still fits in the grid.
    n_rows = (num_images + n_cols - 1) // n_cols
    image_size = 9
    back_color = 'darkslategray'
    font_alpha = 0.8

    fig = plt.figure(figsize=(image_size * n_cols, image_size * n_rows))
    fig.patch.set_facecolor('black')
    fig.subplots_adjust(wspace=0.032, hspace=0.0025)

    was_training = model.training
    model.eval()
    images_so_far = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                wrong = (preds != labels).nonzero(as_tuple=True)[0].tolist()
                if not wrong:
                    continue

                # Per-batch work, done once rather than once per misclassified image.
                k = min(top_k, outputs.size(1))
                top_probs, top_labels = torch.topk(torch.softmax(outputs, dim=1), k, dim=1)
                top_probs = top_probs.cpu().tolist()
                top_labels = top_labels.cpu().tolist()
                true_labels = labels.cpu().tolist()
                images_cpu = inputs.cpu()

                for j in wrong:
                    images_so_far += 1
                    ax = fig.add_subplot(n_rows, n_cols, images_so_far)
                    # CHW -> HWC for imshow. Clamp to [0, 1] so float round-off from
                    # unnormalize does not trigger imshow's clipping warning.
                    image = unnormalize(images_cpu[j]).clamp(0, 1).numpy().transpose((1, 2, 0))
                    title = ' {}\n'.format(class_names[true_labels[j]])
                    caption = ''.join(
                        '{} ({:.2f}%)\n'.format(class_names[label], prob * 100)
                        for label, prob in zip(top_labels[j], top_probs[j])
                    )
                    _draw_error_panel(ax, image, title, caption, back_color, font_alpha)
                    if images_so_far == num_images:
                        return
    finally:
        # Restore the caller's mode on normal exit, early return, or exception.
        model.train(mode=was_training)

In [None]:
# Show up to 64 misclassified validation images.
visualize_model_errors(
    model=model,
    dataloader=val_loader,
    device=device,
    class_names=class_names,
    num_images=64,
)