# In [None]:
# Import necessary libraries
import torch
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
import os
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm import tqdm
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from Omodel import swin_small_patch4_window7_224, QuantityClassifier, QuantityClassifierV2, QuantityClassifierV3, BaggageClassifier
from Odataset import create_GPT_train_test_loader, PersonWithBaggageDataset
from PIL import Image
from kornia.losses import binary_focal_loss_with_logits, focal_loss
# Per-class weights, one entry per class (4 classes). Not referenced in this cell.
# NOTE(review): presumably loss weights used elsewhere in the notebook — confirm.
classes_weight = torch.tensor([1.49175036, 0.42052578, 1.08703607, 31.50757576])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model checkpoint.
# Alternative run kept for reference:
# BEST_MODEL_PATH = 'Oruns/GPT/focal_loss_no_weight/BiggerLR_07-17:17:51_E32_[2e-05,0.0001-min1e-05]/model_best.pth.tar'
BEST_MODEL_PATH = 'Oruns/GPT/focal_loss_no_weight/BiggerLR_07-17:18:00_E32_[1e-05,0.001-min1e-05]/model_best.pth.tar'
# map_location remaps tensors saved on a GPU onto `device`, so the checkpoint
# also loads on a CPU-only machine (the `device` fallback above).
ckpt = torch.load(BEST_MODEL_PATH, map_location=device)

# Evaluate (and optionally plot) a sample of images with true and predicted labels.
def plot_images_with_labels(data_loader, model, device, num_images=5,
                            max_batches=21, num_classes=4, show=False):
    """Tally per-class correct predictions over a sample of ``data_loader``.

    The first ``num_images`` samples of each of the first ``max_batches``
    batches are passed through ``model``. The tallies are:

    * ``targets[c]``: how many sampled items have true label ``c``.
    * ``correct[c]``: how many of those items were also predicted as ``c``.

    Both lists are printed ("correct" / "outof") and returned.

    Args:
        data_loader: Iterable yielding 6-tuples
            ``(images, targetTop1s, _, _, _, img_paths)``. Only the images,
            the integer labels and the image paths are used.
        model: Module whose output is reduced with ``max`` over dim 1 to get
            the predicted class. Assumes logits of shape
            (batch, num_classes); TODO confirm.
        device: Device the images are moved to before the forward pass.
        num_images: Maximum number of samples taken from each batch.
        max_batches: Number of batches to inspect. The default of 21 matches
            the original ``if i >= 20: break``. ``None`` means all batches.
        num_classes: Length of the per-class tallies. Labels must be in
            ``range(num_classes)``.
        show: If True, load each sampled image from its path and display it
            with its true and predicted label via matplotlib.

    Returns:
        Tuple ``(correct, targets)`` of ``num_classes``-length lists.
    """
    from itertools import islice

    model.eval()
    targets = [0] * num_classes
    correct = [0] * num_classes
    with torch.no_grad():
        # islice stops the loader after `max_batches` without fetching an extra batch.
        for images, targetTop1s, _, _, _, img_paths in islice(data_loader, max_batches):
            outputs = model(images.to(device))
            _, predicted = outputs.max(1)
            for j in range(min(num_images, len(images))):
                true_label = targetTop1s[j].item()
                pred_label = predicted[j].item()
                targets[true_label] += 1
                if pred_label == true_label:
                    correct[true_label] += 1
                if show:
                    # Only touch the disk when the image is actually displayed.
                    img = Image.open(img_paths[j]).convert("RGB")
                    plt.figure(figsize=(10, 5))
                    plt.imshow(img)
                    plt.title(f'True Label: {true_label}, Predicted Label: {pred_label}')
                    plt.axis('off')
                    plt.show()
    print(f"correct: {correct}")
    print(f"outof  : {targets}")
    return correct, targets
    

# Build the loaders (batch size 64) and the backbone + quantity-classifier
# model, then restore the trained weights from the checkpoint dict `ckpt`.
train_loader, test_loader = create_GPT_train_test_loader(BATCH_SIZE=64)
backbone = swin_small_patch4_window7_224()
classifier = QuantityClassifierV3()
model = BaggageClassifier(backbone, classifier)
model = model.to(device)  # Module.to returns the same module instance
model.load_state_dict(ckpt['state_dict'])
model.eval()

# Tally predictions for up to 10 samples per batch of the test split.
plot_images_with_labels(test_loader, model, device, num_images=10)

