In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import SwinForImageClassification, AutoFeatureExtractor
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
from data import load_food

In [None]:
# Single source of truth for the experiment's seed and class count.
seed = 42
num_classes = 20
# Seed torch so the freshly initialised classification head and the
# DataLoader shuffle order are reproducible across runs.
torch.manual_seed(seed)

train_ds, val_ds = load_food.load_food(image_size=(224, 224), rand_seed=seed, n_class=num_classes)
batch_size = 16
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

# Load Swin Transformer Base Model
path = "microsoft/swin-base-patch4-window7-224"
finetuned_model = SwinForImageClassification.from_pretrained(path)
feature_extractor = AutoFeatureExtractor.from_pretrained(path)

# Replace the checkpoint's classification head with a num_classes-way head
finetuned_model.classifier = nn.Linear(finetuned_model.config.hidden_size, num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
finetuned_model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()  # Cross-entropy loss
optimizer = optim.AdamW(finetuned_model.parameters(), lr=5e-5)  # AdamW optimizer

In [None]:
# Training function
def train(model, train_loader, optimizer, criterion, device, epochs=5):
    """Fine-tune ``model`` in place for ``epochs`` passes over ``train_loader``.

    Each batch is expected to be a mapping with ``"image"`` and ``"label"``
    tensors, and ``model(images)`` must return an object exposing ``.logits``.

    Args:
        model: classifier to train; switched to train mode.
        train_loader: iterable of training batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to (logits, labels). The per-sample averages
            below assume it returns a batch mean (the default
            ``reduction="mean"`` of ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.
        epochs: number of passes over ``train_loader``.

    Returns:
        list of dicts, one per epoch, with the per-sample mean training
        ``"loss"`` and ``"acc"`` (NaN for an epoch that saw no samples).
        Callers that ignore the return value are unaffected.
    """
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss, correct, total = 0.0, 0, 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")

        for batch in loop:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)

            optimizer.zero_grad()
            outputs = model(images).logits

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            n = labels.size(0)
            # Weight the batch-mean loss by batch size so the running average
            # is a true per-sample mean even when the last batch is short.
            total_loss += loss.item() * n
            correct += (outputs.argmax(1) == labels).sum().item()
            total += n

            loop.set_postfix(loss=loss.item(), avg_loss=total_loss / total, acc=correct / total)

        history.append({
            "loss": total_loss / total if total else float("nan"),
            "acc": correct / total if total else float("nan"),
        })
    return history

# Evaluation function
def evaluate(model, val_loader, criterion, device, class_names=None, cm_vmax=21):
    """Evaluate ``model`` on ``val_loader``, print metrics and plot misclassifications.

    Batches are expected to be mappings with ``"image"`` and ``"label"``
    tensors, and ``model(images)`` must return an object exposing ``.logits``.

    Args:
        model: classifier to evaluate; switched to eval mode.
        val_loader: iterable of validation batches.
        criterion: loss applied to (logits, labels). The average loss assumes
            it returns a batch mean (the default ``reduction="mean"`` of
            ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.
        class_names: optional list of names indexed by label id. When given,
            the confusion matrix always has ``len(class_names)`` rows/columns,
            even if some class is absent from this split.
        cm_vmax: upper bound of the heat-map colour scale; larger counts saturate.

    Returns:
        dict with ``accuracy``, macro ``precision``/``recall``/``f1`` and the
        per-sample mean ``loss``. Callers ignoring the result are unaffected.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss, total = 0.0, 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            images = batch["image"].to(device)
            labels = batch["label"].to(device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            n = labels.size(0)
            # Weight by batch size: a plain mean of batch means over-weights a short final batch.
            total_loss += loss.item() * n
            total += n

            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("val_loader yielded no samples; cannot compute metrics")

    avg_loss = total_loss / total

    # Compute evaluation metrics
    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    print(f"Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score (Macro): {f1:.4f}, Loss: {avg_loss:.4f}")

    # Pin the label set when names are known so the matrix shape always matches
    # the tick labels, even if a class never occurs in labels or predictions.
    label_ids = list(range(len(class_names))) if class_names is not None else None
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    # Zero the diagonal so the heat map shows only misclassifications.
    np.fill_diagonal(cm, 0)

    # seaborn does not accept None for tick labels; "auto" is its own default.
    tick_labels = class_names if class_names is not None else "auto"
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        cmap="Blues",
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        vmin=0,
        vmax=cm_vmax,
        ax=ax,
    )
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    ax.set_title("Swin Transformer")
    fig.tight_layout()
    plt.show()

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "loss": avg_loss,
    }

In [None]:
import os
from config import OUTPUT_DIR, IMAGE_DIR
class_names = [
    'ramen', 'carrot_cake', 'beef_carpaccio', 'strawberry_shortcake', 'escargots',
    'donuts', 'croque_madame', 'cheese_plate', 'caprese_salad', 'sashimi',
    'oysters', 'caesar_salad', 'pho', 'hot_and_sour_soup', 'beef_tartare',
    'creme_brulee', 'cup_cakes', 'miso_soup', 'pork_chop', 'paella'
]
# The head built in the setup cell has num_classes outputs; names must line up with them.
assert len(class_names) == num_classes, "class_names must have one entry per output class"

train(finetuned_model, train_loader, optimizer, criterion, device, epochs=5)
evaluate(finetuned_model, val_loader, criterion, device, class_names=class_names)

# Define the save path; torch.save does not create missing parent directories.
os.makedirs(OUTPUT_DIR, exist_ok=True)
save_path = os.path.join(OUTPUT_DIR, "swin_food101_finetuned.pth")
classifier_path = os.path.join(OUTPUT_DIR, "classifier.pth")

# Save the entire model's state_dict (including the classifier)
torch.save(finetuned_model.state_dict(), save_path)
# The head is also saved on its own; the reload cell restores it from this file.
torch.save(finetuned_model.classifier.state_dict(), classifier_path)

In [None]:
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from transformers import SwinForImageClassification, AutoFeatureExtractor
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget


In [None]:
# Map label index -> class name. Derived from `class_names` (defined in the
# training cell) rather than re-typed, so the two lists cannot drift apart.
id2label = dict(enumerate(class_names))

class HuggingfaceSwinWrapper(torch.nn.Module):
    """Thin adapter that returns only the ``logits`` of a Hugging Face classifier.

    The wrapped model returns an output object; this module unwraps its
    ``.logits`` attribute so the model can be handed to GradCAM, which works
    on the raw output tensor.
    """

    def __init__(self, model):
        super().__init__()
        # Registered as a submodule so parameters/hooks remain reachable.
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its ``logits``."""
        model_output = self.model(x)
        return model_output.logits

## label prediction function
def predict_single_image(image_path, model, feature_extractor, device):
    """Classify one image file and return the arg-max class index and raw logits.

    Args:
        image_path: path to an image file readable by PIL.
        model: Hugging Face image classifier called as ``model(**inputs)``.
            It is NOT switched to eval mode here; call ``model.eval()`` first.
        feature_extractor: preprocessor producing the model's inputs.
        device: device the inputs are moved to.

    Returns:
        (pred_index, logits): the predicted class index as a Python int and
        the logits tensor returned by the model.
    """
    # Close the file deterministically; convert() loads the pixels and returns
    # a new image, so it stays usable after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = feature_extractor(image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    pred_index = logits.argmax(dim=1).item()
    return pred_index, logits

def reshape_transform(tensor, height=None, width=None):
    """
    Swin Transformer-specific reshape function for Grad-CAM.
    Converts tensor from shape [B, (H*W), C] to [B, C, H, W].

    Args:
        tensor: token activations of shape [B, N, C] with N == height * width.
        height, width: patch-grid size. If both are omitted, the grid is
            assumed square and inferred as sqrt(N). If only one is given, the
            other is derived from N. (Previously these arguments were ignored.)

    Returns:
        Tensor of shape [B, C, height, width].

    Raises:
        ValueError: if the grid size does not account for exactly N tokens.
    """
    import math

    batch_size, num_patches, num_channels = tensor.shape

    if height is None and width is None:
        height = width = math.isqrt(num_patches)
    elif height is None:
        height = num_patches // width
    elif width is None:
        width = num_patches // height

    # Fail with a clear message instead of an opaque reshape error.
    if height * width != num_patches:
        raise ValueError(
            f"cannot arrange {num_patches} tokens into a {height}x{width} grid"
        )

    # [B, N, C] -> [B, H, W, C] -> [B, C, H, W]
    return tensor.reshape(batch_size, height, width, num_channels).permute(0, 3, 1, 2)

def visualize_attention(image_path):
    """Show an image next to its Grad-CAM heat map for the model's top prediction.

    Uses the notebook globals ``model``, ``feature_extractor``, ``device`` and
    ``id2label``; they must be defined before this function is called.

    Args:
        image_path: path to an image file readable by PIL.
    """
    # Step 1: Load the image once, closing the file handle afterwards
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Step 2: Preprocess the image for the model
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    input_tensor = inputs["pixel_values"]
    # Ensure the input is a 4D tensor
    if input_tensor.dim() == 3:
        input_tensor = input_tensor.unsqueeze(0)

    # Build the [0, 1] RGB overlay at the exact spatial size fed to the model,
    # which is the size Grad-CAM resizes its map to (PIL takes (width, height)).
    in_h, in_w = input_tensor.shape[-2], input_tensor.shape[-1]
    rgb_img = np.array(image.resize((in_w, in_h))).astype(np.float32) / 255.0

    # Step 3: Predict from the same tensor Grad-CAM will use, instead of
    # re-reading and re-preprocessing the file.
    with torch.no_grad():
        logits = model(pixel_values=input_tensor).logits
    pred_index = logits.argmax(dim=1).item()
    pred_label = id2label[pred_index]

    # Step 4: Grad-CAM visualization
    # Select the layernorm_after layer of the last block in the last stage
    target_layers = [model.swin.encoder.layers[-1].blocks[-1].layernorm_after]
    # Define the target class for Grad-CAM
    targets = [ClassifierOutputTarget(pred_index)]

    # The context manager releases Grad-CAM's forward/backward hooks on exit.
    with GradCAM(
        model=HuggingfaceSwinWrapper(model),
        target_layers=target_layers,
        reshape_transform=reshape_transform,
    ) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # Step 5: Display the results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    ax1.imshow(image)
    ax1.set_title("Original Image")
    ax1.axis("off")

    ax2.imshow(visualization)
    ax2.set_title(f"Attention Map\nPredicted: {pred_label}")
    ax2.axis("off")

    plt.tight_layout()
    plt.show()
    # Called once per image in a loop: free the figure to avoid accumulating memory.
    plt.close(fig)


## Load Swin base Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinForImageClassification.from_pretrained(path, ignore_mismatched_sizes=True).to(device)
# Head size follows the label map so the two cannot disagree.
model.classifier = torch.nn.Linear(model.config.hidden_size, len(id2label)).to(device)

## Load backbone parameters
# weights_only=True: the checkpoint is a plain tensor state_dict, so there is
# no need to unpickle arbitrary objects.
state_dict = torch.load(save_path, map_location="cpu", weights_only=True)
state_dict = {k: v for k, v in state_dict.items() if "classifier" not in k}
load_result = model.load_state_dict(state_dict, strict=False)
# strict=False is only needed because classifier keys were filtered out above.
# Any other missing/unexpected key means the checkpoint does not match this
# architecture; fail loudly instead of silently keeping pretrained weights.
missing = [k for k in load_result.missing_keys if "classifier" not in k]
unexpected = list(load_result.unexpected_keys)
if missing or unexpected:
    raise RuntimeError(
        f"Backbone checkpoint mismatch: missing={missing}, unexpected={unexpected}"
    )
## Load classification head
model.classifier.load_state_dict(torch.load(classifier_path, map_location="cpu", weights_only=True))
model.eval()
## Load feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(path)

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
# Case-insensitive match so files such as "photo.JPG" are not silently skipped.
image_paths = [
    os.path.join(IMAGE_DIR, fname)
    for fname in sorted(os.listdir(IMAGE_DIR))
    if fname.lower().endswith(IMAGE_EXTENSIONS)
]

for image_path in image_paths:
    visualize_attention(image_path)