In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTImageProcessor
import matplotlib.pyplot as plt

# ---- Data configuration ----------------------------------------------------
# NOTE(review): the test set below is read from the SAME folder as the
# train/validation data, so every "test" image was also used for training or for
# picking the best checkpoint — the reported test metrics are not held-out.
# Point TEST_ROOT at a separate folder before trusting them.
DATA_ROOT = '../hotdog_nothotdog'
TEST_ROOT = DATA_ROOT
SEED = 42

# Seed torch's global RNG so shuffled batch order (and later weight init) repeat.
torch.manual_seed(SEED)

# Normalise with the statistics of the pretrained ViT's own image processor
# rather than ImageNet CNN statistics, so inputs match what the backbone was trained on.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=vit_processor.image_mean, std=vit_processor.image_std)
])

# Load the datasets (one ImageFolder per root; labels come from sub-folder names)
train_val_dataset = datasets.ImageFolder(root=DATA_ROOT, transform=transform)
test_dataset = datasets.ImageFolder(root=TEST_ROOT, transform=transform)

# 80 / 20 train / validation split
train_size = int(0.8 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size

# A dedicated seeded generator makes the split reproducible across runs.
train_dataset, val_dataset = random_split(
    train_val_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Initialize the Vision Transformer model.
# `num_labels` is set to the number of class folders so the model config (and
# `model.num_labels`) agrees with the head built below instead of keeping the
# checkpoint's 1000 ImageNet classes; `ignore_mismatched_sizes` lets the
# checkpoint's 1000-way classifier weights be discarded.
model_name = "google/vit-base-patch16-224"
num_classes = len(train_val_dataset.classes)
image_processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(
    model_name, num_labels=num_classes, ignore_mismatched_sizes=True
)

# Replace the classifier head with a two-hidden-layer MLP ending in one logit per class.
model.classifier = nn.Sequential(
    nn.Linear(model.config.hidden_size, 4096),
    nn.ReLU(inplace=True),
    nn.Linear(4096, 4096),
    nn.ReLU(inplace=True),
    nn.Linear(4096, num_classes)
)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))

# Train, validating after every epoch and checkpointing the best model.
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
train_losses = []
val_losses = []
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; otherwise a run whose accuracy never exceeds 0 would leave no
# best_model.pth for the evaluation cells to load.
best_val_accuracy = -1.0

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        # The HF model returns an output object; the class scores live in `.logits`.
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean is exact even with a short last batch.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    train_losses.append(epoch_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

            preds = outputs.argmax(dim=1)
            # Accumulate as a Python int so accuracy is a plain float, not a device tensor.
            correct += (preds == labels).sum().item()

    val_loss = val_loss / len(val_loader.dataset)
    val_losses.append(val_loss)
    val_accuracy = correct / len(val_loader.dataset)
    print(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}")

    # Save the model whenever validation accuracy improves
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        print("Model saved!")

# Summarise training: report the best validation accuracy, then plot the
# per-epoch training and validation losses on one set of axes.
print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")

epoch_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots(figsize=(10, 5))
for loss_series, series_label in ((train_losses, 'Training Loss'),
                                  (val_losses, 'Validation Loss')):
    ax.plot(epoch_axis, loss_series, label=series_label)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()

# Reload the checkpoint with the best validation accuracy and score it on the
# test loader, collecting per-image predictions for the confusion matrix.
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

test_loss = 0.0
correct = 0
all_preds = []
all_labels = []

with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        logits = model(batch_images).logits
        # Sum per-sample loss so the final division yields a true dataset mean.
        test_loss += criterion(logits, batch_labels).item() * batch_images.size(0)

        _, preds = torch.max(logits, 1)
        correct += torch.sum(preds == batch_labels.data)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_labels.cpu().numpy())

num_test_images = len(test_loader.dataset)
test_loss = test_loss / num_test_images
test_accuracy = correct.double() / num_test_images

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

# Calculate, print and plot the confusion matrix of the test predictions
from sklearn.metrics import confusion_matrix
import numpy as np

# ImageFolder numbers classes by its sorted sub-folder names, so take the axis
# labels from the dataset instead of hard-coding an order that may not match
# the label indices.
class_names = test_dataset.classes
class_ids = np.arange(len(class_names))

# `labels=` keeps the matrix square (one row/column per class) even when a class
# is absent from both the predictions and the ground truth.
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
print("Confusion Matrix:")
print(cm)

plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
plt.xticks(class_ids, class_names)
plt.yticks(class_ids, class_names)
plt.ylabel('True label')
plt.xlabel('Predicted label')

# Annotate each cell with its count; white text on the darker half for contrast.
thresh = cm.max() / 2.
for i, j in np.ndindex(cm.shape):
    plt.text(j, i, format(cm[i, j], 'd'),
             horizontalalignment="center",
             color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.show()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import glob
from matplotlib.colors import to_rgba
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchsummary import summary

def custom_target(output, target_label=None):
    """Grad-CAM target: the score of ``target_label`` in a model output.

    Args:
        output: logits for one sample (1-D, classes last) or for a batch
            (2-D, shape (batch, classes)).
        target_label: class index to explain. When omitted, a module-level
            ``target_label`` is looked up, which is what this helper originally
            relied on.

    Returns:
        For a single-logit (sigmoid-style) head: the logit when ``target_label``
        is 1, otherwise ``1 - logit``. For multi-logit heads: the logit of
        ``target_label`` (per sample for batched input).
    """
    if target_label is None:
        target_label = globals()["target_label"]
    if output.shape[-1] == 1:
        # Single-output head: class 1 is the raw score, class 0 its complement.
        score = output[..., 0]
        return score if target_label == 1 else 1 - score
    # Multi-logit head: select the requested class's logit.
    return output[..., target_label]


class _LogitsOnly(torch.nn.Module):
    """Wrap a model so its forward returns a plain logits tensor.

    GradCAM hands the model output to its targets as a tensor; Hugging Face
    classifiers return an output object instead, so unwrap ``.logits`` when present.
    """

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        out = self.inner(x)
        return out.logits if hasattr(out, "logits") else out


def _vit_reshape_transform(activations):
    """Return Grad-CAM activations/gradients as a (batch, channels, h, w) map.

    4-D feature maps pass through unchanged. 3-D token sequences
    (batch, tokens, channels) are laid out on a square patch grid, dropping a
    leading [CLS] token when the token count is one more than a perfect square.
    Any other rank has no spatial layout and raises ValueError.
    """
    if activations.dim() == 4:
        return activations
    if activations.dim() != 3:
        raise ValueError(
            "Grad-CAM needs a target layer with spatial activations; got shape "
            f"{tuple(activations.shape)}"
        )
    tokens = activations
    side = int(round(tokens.shape[1] ** 0.5))
    if side * side != tokens.shape[1]:
        tokens = tokens[:, 1:, :]  # drop the [CLS] token
        side = int(round(tokens.shape[1] ** 0.5))
        if side * side != tokens.shape[1]:
            raise ValueError(
                f"Cannot arrange {activations.shape[1]} tokens on a square grid"
            )
    batch, _, channels = tokens.shape
    return tokens.reshape(batch, side, side, channels).permute(0, 3, 1, 2)


def get_grad_cam_visualization(model, image_tensor, target_layer, target_label,
                               reshape_transform=None):
    """Overlay a Grad-CAM heatmap for one class on one image.

    Args:
        model: classifier to explain; if its forward returns an object with a
            ``.logits`` attribute it is unwrapped for GradCAM.
        image_tensor: a single image tensor laid out (C, H, W), no batch dimension.
        target_layer: module whose activations and gradients Grad-CAM uses.
        target_label: class index whose logit is explained.
        reshape_transform: optional callable passed to ``GradCAM``. Defaults to
            one that maps ViT token sequences to a patch grid and leaves 4-D
            feature maps unchanged.

    Returns:
        The RGB overlay produced by ``show_cam_on_image``.
    """
    if reshape_transform is None:
        reshape_transform = _vit_reshape_transform

    # Put the input on the model's device and add a batch dimension.
    model_device = next(model.parameters()).device
    input_tensor = image_tensor.unsqueeze(0).to(model_device)

    cam = GradCAM(
        model=_LogitsOnly(model),
        target_layers=[target_layer],
        reshape_transform=reshape_transform,
    )
    # Explain the logit at index `target_label` of the class dimension.
    targets = [ClassifierOutputTarget(target_label)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)

    # Min-max rescale the normalised image into [0, 1] for display; a flat image
    # would otherwise divide by zero.
    image_numpy = image_tensor.detach().cpu().numpy().transpose(1, 2, 0)
    low, high = image_numpy.min(), image_numpy.max()
    if high > low:
        image_numpy = (image_numpy - low) / (high - low)
    else:
        image_numpy = np.zeros_like(image_numpy)

    # Overlay the CAM on the image
    return show_cam_on_image(image_numpy, grayscale_cam[0], use_rgb=True)



# Pick one test image per confusion-matrix outcome to explain with Grad-CAM.
examples = {
    'TP': None,  # True Positive: True Hotdog predicted as Hotdog
    'FN': None,  # False Negative: True Hotdog predicted as Not Hotdog
    'FP': None,  # False Positive: True Not Hotdog predicted as Hotdog
    'TN': None   # True Negative: True Not Hotdog predicted as Not Hotdog
}
# map_location lets a checkpoint written on another device load on this one.
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Grad-CAM needs per-patch activations. The classifier's first Linear layer
# receives a single hidden_size vector per image, which has no spatial layout
# to build a heatmap from, so hook the last ViT encoder block instead.
target_layer = model.vit.encoder.layer[-1].layernorm_before


def _categorize(idx):
    """Predict test_dataset[idx] and return (outcome key, example tuple).

    The outcome key is one of 'TP', 'FN', 'FP', 'TN'; the tuple is
    (image, pred_class, true_class, pred_prob, idx) as stored in `examples`.
    """
    image, label = test_dataset[idx]
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(device)).logits
    # pred_prob is the softmax probability of class index 1, which this cell
    # treats as "Not Hotdog" (same 0.5 threshold the earlier code used).
    # NOTE(review): confirm index 1 is the not-hotdog folder via test_dataset.class_to_idx.
    pred_prob = torch.softmax(logits, dim=1)[0, 1].item()
    pred_class = "Not Hotdog" if pred_prob > 0.5 else "Hotdog"
    true_class = "Not Hotdog" if label == 1 else "Hotdog"
    if true_class == "Hotdog":
        category = 'TP' if pred_class == "Hotdog" else 'FN'
    else:
        category = 'FP' if pred_class == "Hotdog" else 'TN'
    return category, (image, pred_class, true_class, pred_prob, idx)


# Curated indices are tried first so the figure is stable across runs; each is
# filed under its ACTUAL outcome rather than an assumed one. A seeded shuffle of
# the test set (capped at MAX_EXTRA_SCAN images) fills any outcome they miss.
idxs = [872, 636, 1046, 1284]
MAX_EXTRA_SCAN = 500
test_indices = list(range(len(test_dataset)))
np.random.default_rng(0).shuffle(test_indices)

for idx in idxs + test_indices[:MAX_EXTRA_SCAN]:
    if idx >= len(test_dataset):
        continue  # curated index does not exist in this test set
    category, example = _categorize(idx)
    if examples[category] is None:
        examples[category] = example
    if all(value is not None for value in examples.values()):
        break

# One panel per confusion-matrix outcome, each showing its Grad-CAM overlay.
fig, axes = plt.subplots(1, 4, figsize=(20, 5))  # 1x4 grid layout
for ax, (category, example) in zip(axes, examples.items()):
    ax.axis('off')  # no axes/ticks on image panels
    # Check for a missing example BEFORE unpacking: unpacking None would raise.
    if example is None:
        ax.set_title(f'No example for {category}')
        continue

    image, pred_class, true_class, pred_prob, idx = example
    # Explain the true class (index 1 = "Not Hotdog" in this cell's convention).
    label = 1 if true_class == "Not Hotdog" else 0
    vis = get_grad_cam_visualization(model, image, target_layer, label)

    # `vis` already blends the heatmap with the rescaled image, so draw it alone;
    # the raw tensor is mean/std-normalised and would be clipped by imshow.
    ax.imshow(vis)
    ax.set_title(f'Predicted: {pred_class}\nTrue: {true_class}')

# Adjust the layout so subplots don't overlap
plt.tight_layout()

# Save the figure
plt.savefig('vggGrad.png', dpi=300, bbox_inches='tight')

# Show the figure
plt.show()