In [None]:

import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from datetime import datetime
from Omodel import swin_small_patch4_window7_224, BaggageClassifier #,QuantityClassifier
from Odataset import PersonWithBaggageDataset, TRAIN_CSV_FILE, TEST_CSV_FILE, ROOT_DIR, TRAIN_TRANSFORM, VAL_TRANSFORM
from torchvision.transforms import v2 as T
import torchvision.transforms.functional as F

class UnNormalize:
    """Invert a per-channel ``Normalize(mean, std)`` step, i.e. ``x * std + mean``.

    Args:
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W), or a batch of
                size (N, C, H, W), to be un-normalized.
        Returns:
            Tensor: A new, un-normalized tensor. The input is not modified.
        """
        # Integer tensors cannot hold fractional statistics, so promote them.
        dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
        # Shape the statistics as (C, 1, 1) so they broadcast over H and W,
        # and over a leading batch dimension when one is present.
        mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device).reshape(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device).reshape(-1, 1, 1)
        # Out-of-place on purpose: the previous in-place mul_/add_ silently
        # overwrote the caller's tensor, so the same batch could not be reused.
        return tensor.to(dtype) * std + mean

# Per-channel (R, G, B) statistics handed to UnNormalize further down.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

class QuantityClassifier(nn.Module):
    '''
    Classification head: global spatial pooling followed by a small MLP.

    num_classes: 0, 1, 2, >=3

    Args:
        c_in: Channel count of the incoming feature map; also the input width
            of the first linear layer.
        num_classes: Number of output logits.
        pool: 'avg' for AdaptiveAvgPool2d(1), 'max' for AdaptiveMaxPool2d(1),
            or a callable (e.g. an nn.Module) used as the pooling layer as-is.

    Raises:
        ValueError: if ``pool`` is neither 'avg', 'max', nor callable.
    '''
    def __init__(self,c_in=768,num_classes = 4, pool='avg'):
        super().__init__()
        if pool == 'avg':
            self.pool = nn.AdaptiveAvgPool2d(1)
        elif pool == 'max':
            self.pool = nn.AdaptiveMaxPool2d(1)
        elif callable(pool):
            # A ready-made pooling layer was supplied; use it unchanged.
            self.pool = pool
        else:
            # Fail here rather than in forward(), where an unknown string
            # used to surface as "'str' object is not callable".
            raise ValueError(f"pool must be 'avg', 'max', or a callable, got {pool!r}")

        hidden = c_in // 4
        bottleneck = hidden // 4
        # Layer order is kept as-is so existing checkpoints (keys logits.0,
        # logits.2, logits.5) still load.
        self.logits = nn.Sequential(
            nn.Linear(c_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, bottleneck),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(bottleneck, num_classes),
        )

    def forward(self, feature, label=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            feature: Feature map whose first dimension is the batch; after
                pooling and flattening it must have ``c_in`` values per sample.
            label: Unused; kept so existing callers that pass it keep working.
        """
        # reshape (rather than view) also copes with a non-contiguous result
        # from a user-supplied pooling callable.
        feat = self.pool(feature).reshape(feature.size(0), -1)
        return self.logits(feat)
    
# Device configuration: use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): machine-specific absolute path; point this at your own run
# directory before executing the cell elsewhere.
BEST_MODEL_PATH = '/home/deepvisionpoc/Desktop/Jeans/SOLIDER_exp/SOLIDER-PersonAttributeRecognition/Oxygen_runs/run_20240712_162058/checkpoint.pth.tar'
# Built from the module-level mean/std; not called anywhere in the visible
# code (visualize_predictions undoes the 0.5/0.5 normalization inline).
unnormalize = UnNormalize(mean, std)

# Evaluation preprocessing: resize to 224x224, convert to a tensor, then map
# each channel with (x - 0.5) / 0.5.
FAKE_VAL_TRANSFORM = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Prepare the dataset
test_ds = PersonWithBaggageDataset(TEST_CSV_FILE, ROOT_DIR, FAKE_VAL_TRANSFORM)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)

# Prepare the model
backbone = swin_small_patch4_window7_224()
classifier = QuantityClassifier()
model = BaggageClassifier(backbone, classifier).to(device)

# Load the best model.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when this cell runs on a CPU-only machine.
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
ckpt = torch.load(BEST_MODEL_PATH, map_location=device)
model.load_state_dict(ckpt['state_dict'])
print(f"Loaded Model Details:\nEpoch: {ckpt['epoch']} Acc: {ckpt['best_acc']}")

# Function to visualize images and predictions
def visualize_predictions(images, labels, predictions, class_names):
    """Show each image in a grid, titled with its true and predicted class.

    Args:
        images: Indexable batch of (C, H, W) tensors. Pixel values are mapped
            back with (x + 1) / 2, i.e. the inverse of
            Normalize(mean=0.5, std=0.5).
        labels: Per-image tensors holding the true class index.
        predictions: Per-image tensors holding the predicted class index.
        class_names: Sequence mapping a class index to its display name.
    """
    n_cols = 4
    # Up to 16 images keep the original 4x4 grid and 14x16 figure; larger
    # batches add rows instead of making plt.subplot raise.
    n_rows = max(4, -(-len(images) // n_cols))
    plt.figure(figsize=(14, 4 * n_rows))
    for i in range(len(images)):
        plt.subplot(n_rows, n_cols, i + 1)
        # (C, H, W) -> (H, W, C), the layout imshow expects.
        img = images[i].detach().cpu().numpy().transpose((1, 2, 0))
        # Undo the 0.5/0.5 normalization and clamp into the [0, 1] range that
        # imshow accepts for float RGB data.
        img = ((img + 1) / 2).clip(0.0, 1.0)
        plt.imshow(img)
        plt.title(f"True: {class_names[labels[i].item()]}\nPred: {class_names[predictions[i].item()]}")
        plt.axis('off')
    plt.show()

# Perform inference and visualize results
model.eval()
class_names = ['No Bag', '1 Bag', '2 Bag', 'At Least 3 Bag']  # Update this with actual class names
# Number of batches to render before stopping; raise it to visualize more.
MAX_VIS_BATCHES = 1
count = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Inference"):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        # Index of the largest logit per sample. Taken from the named result
        # instead of unpacking into `_`, which at notebook top level would
        # shadow IPython's last-output variable.
        predictions = outputs.max(1).indices

        visualize_predictions(images, labels, predictions, class_names)
        count += 1
        if count >= MAX_VIS_BATCHES:
            break



In [1]:
import matplotlib.pyplot as plt
import torch
import os
import numpy as np
from datetime import datetime
from math import exp
from torch.utils.data import DataLoader
from Omodel import swin_small_patch4_window7_224, BaggageClassifier  # ,QuantityClassifier
from Odataset import GPTDataset, TRAIN_CSV_FILE, TEST_CSV_FILE, ROOT_DIR, TRAIN_TRANSFORM, VAL_TRANSFORM
from torchvision.transforms import v2 as T

# Preprocessing for display: resize to 224x224, convert to a tensor, then map
# every channel with (x - 0.5) / 0.5.
FAKE_VAL_TRANSFORM = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Folder that receives the rendered figure grids; created if it is missing.
output_dir = 'output/fig-gpt'
os.makedirs(output_dir, exist_ok=True)

# Unshuffled loader that yields the GPTDataset samples eight at a time.
dataset = GPTDataset(transform=FAKE_VAL_TRANSFORM, use_expected_value=False)
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=False)

# Iterate over the batches, saving one 2x4 figure grid per batch.
LAST_BATCH_INDEX = 100  # the loop stops after this batch index, i.e. 101 figures
count = 0  # running image number across all batches, shown in each title
# The fifth batch field is not used here; it is bound to a named throwaway
# rather than `_` so IPython's last-output variable is left alone.
for batch_index, (imgs, targetTop1s, logProbTop1s, targetEVs, _unused, img_paths) in enumerate(dataloader):
    # Convert the batch of images to a NumPy array
    imgs = imgs.numpy().transpose(0, 2, 3, 1)  # (batch_size, height, width, channels)
    # Undo the 0.5/0.5 normalization and clamp into imshow's [0, 1] float range.
    imgs = np.clip((imgs + 1.0) / 2.0, 0.0, 1.0)

    # Create a figure with subplots for each image
    fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(20, 10))
    # Hide every panel up front so a short final batch leaves blank cells
    # instead of empty axes with ticks.
    for ax in axs.flat:
        ax.axis('off')

    # Iterate over the images and labels
    for i, (img, top, logprob, ev, img_name) in enumerate(zip(imgs, targetTop1s, logProbTop1s, targetEVs, img_paths)):
        count += 1
        ax = axs.flat[i]  # row-major order, same cell as (i // 4, i % 4)
        ax.imshow(img)
        # Built outside the f-string: reusing double quotes inside a
        # double-quoted f-string is a SyntaxError before Python 3.12.
        short_name = "_".join(img_name.split("_")[1:-3])
        ax.set_title(f"{short_name}\n{count} ({exp(logprob):.3f})(Top1, EV): {top}, {ev}")

    # Save the plot
    fig_filename = os.path.join(output_dir, f"fig_{batch_index + 1}.png")
    fig.savefig(fig_filename)
    plt.close(fig)  # Close the figure to free up memory

    # Stop once the batch with index LAST_BATCH_INDEX has been saved.
    if batch_index >= LAST_BATCH_INDEX:
        break




Error processing 2024-07-01T10-25-48_FLAG_P_FID_290_OID_148g0_ctw-cf-2c-094_1719804300_1719804348_290.jpg: 'choices'
Error processing 2024-07-01T10-37-15_FLAG_P_FID_816_OID_574g0_ctw-cf-3c-163_1719804900_1719805035_816.jpg: 'choices'
Error processing 2024-07-01T10-48-00_FLAG_P_FID_1083_OID_1185g0_ctw-cf-1c-049_1719805500_1719805680_1083.jpg: 'choices'
Error processing 2024-07-01T11-20-54_FLAG_P_FID_324_OID_530g0_ctw-cf-1e-072_1719807601_1719807654_324.jpg: 'choices'
Error processing 2024-07-01T12-13-34_FLAG_P_FID_1280_OID_286g0_ctw-cf-3c-169_1719810601_1719810814_1280.jpg: 'choices'
Error processing 2024-07-01T13-04-02_FLAG_P_FID_1456_OID_109g0_ctw-cf-2b-146_1719813600_1719813842_1456.jpg: 'choices'
Error processing 2024-07-01T13-53-25_FLAG_P_FID_1230_OID_1099g0_ctw-cf-3c-156_1719816601_1719816805_1230.jpg: 'choices'
Error processing 2024-07-01T14-27-11_FLAG_P_FID_793_OID_1158g0_ctw-cf-2a-138_1719818700_1719818831_793.jpg: 'choices'
Error processing 2024-07-01T14-42-14_FLAG_P_FID_810_O