In [None]:
import warnings
warnings.filterwarnings("ignore")

import torch
from transformers import AutoModelForImageClassification

# Seed BEFORE building the model: with ignore_mismatched_sizes=True the new 20-way
# classifier head is randomly initialised inside from_pretrained, so seeding afterwards
# (as this cell used to) left that initialisation, and thus training, unreproducible.
torch.manual_seed(42)

# Load the pretrained DeiT-base checkpoint with a fresh 20-way head; the checkpoint's
# original head weights do not fit that shape and are discarded.
model = AutoModelForImageClassification.from_pretrained(
    'facebook/deit-base-patch16-224',
    num_labels=20,
    ignore_mismatched_sizes=True,
)

# Linear probe: freeze every parameter that is not part of the classification head.
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
#image processing
import os
import json
import random
from PIL import UnidentifiedImageError
from datasets import load_dataset
from torchvision import transforms

from config import CACHE_DIR

IMAGE_SIZE = (224, 224)  # (width, height) handed to PIL's Image.resize
NUM_CLASSES = 20         # how many Food-101 classes are sampled
SEED = 42                # seed for the class-sampling RNG

# PIL image -> float tensor (C, H, W) -> per-channel standardisation using the
# conventional ImageNet mean / std values.
transform_to_tensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],  # mean (R, G, B)
            [0.229, 0.224, 0.225],  # std  (R, G, B)
        ),
    ]
)

# image processing
def process_image(image, image_size=IMAGE_SIZE):
    """Resize and normalise one image into a (3, H, W) float32 numpy array.

    Args:
        image: a PIL image, or anything ``np.array`` can turn into one.
        image_size: (width, height) passed to ``PIL.Image.resize``.

    Returns:
        The normalised array, or ``None`` if the image cannot be decoded/converted
        (UnidentifiedImageError, OSError or ValueError).
    """
    # Local imports: `Image` and `np` are otherwise only imported in later cells, so on
    # a fresh "Run All" this function raised NameError (which the except below misses).
    import numpy as np
    from PIL import Image

    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(np.array(image))
        # The dataset may contain grayscale / CMYK images; force 3 channels so the
        # 3-channel Normalize in transform_to_tensor does not raise on them.
        image = image.convert("RGB").resize(image_size)
        return transform_to_tensor(image).numpy()
    except (UnidentifiedImageError, OSError, ValueError):
        return None

# data loading
def load_food(image_size=IMAGE_SIZE, rand_seed=SEED, n_class=NUM_CLASSES):
    """Load a random ``n_class``-class subset of Food-101, preprocessed and relabelled.

    The chosen class names are cached in ``selected_classes.json`` so later runs reuse
    the same subset.

    Args:
        image_size: (width, height) each image is resized to.
        rand_seed: seed for ``random`` before sampling classes.
        n_class: number of classes to keep.

    Returns:
        (train_ds, val_ds, selected_classes). Labels are remapped to 0..n_class-1 in
        the order of ``selected_classes``. The "image" column holds the arrays
        produced by ``process_image``. Rows that fail preprocessing are dropped.

    Raises:
        ValueError: if the cached class file lists a different number of classes than
            ``n_class``, or names a class not present in the dataset.
    """
    random.seed(rand_seed)

    train_ds = load_dataset("ethz/food101", split="train", cache_dir=CACHE_DIR)
    val_ds = load_dataset("ethz/food101", split="validation", cache_dir=CACHE_DIR)

    class_names = train_ds.features["label"].names

    if os.path.exists("selected_classes.json"):
        with open("selected_classes.json", "r") as f:
            selected_classes = json.load(f)
        # A stale cache from a run with a different n_class would silently produce a
        # label space that does not match the model head.
        if len(selected_classes) != n_class:
            raise ValueError(
                f"selected_classes.json lists {len(selected_classes)} classes, "
                f"expected {n_class}; delete the file to resample"
            )
    else:
        selected_classes = random.sample(class_names, n_class)
        with open("selected_classes.json", "w") as f:
            json.dump(selected_classes, f)

    # class_names.index raises ValueError for an unknown cached class name.
    selected_indices = [class_names.index(c) for c in selected_classes]
    label_map = {old: new for new, old in enumerate(selected_indices)}

    # input_columns="label" passes only the label to the predicate, so images are not
    # decoded just to decide whether a row is kept. Dict membership is O(1).
    train_ds = train_ds.filter(lambda label: label in label_map, input_columns="label")
    val_ds = val_ds.filter(lambda label: label in label_map, input_columns="label")

    def process_and_relabel(example):
        # Preprocess the image and map the original label to its 0..n_class-1 index.
        img = process_image(example["image"], image_size)
        if img is None:
            return {"image": None, "label": -1}
        return {
            "image": img,
            "label": label_map[example["label"]]
        }

    train_ds = train_ds.map(process_and_relabel)
    val_ds = val_ds.map(process_and_relabel)
    # Drop rows whose image could not be processed.
    train_ds = train_ds.filter(lambda x: x["image"] is not None and x["label"] != -1)
    val_ds = val_ds.filter(lambda x: x["image"] is not None and x["label"] != -1)

    return train_ds, val_ds, selected_classes

In [None]:
# create Dataset and DataLoader
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 16


class Food101Dataset(Dataset):
    """Torch ``Dataset`` adapter over a preprocessed, indexable dataset.

    Each row must provide an ``"image"`` array-like and an integer ``"label"``.
    Items are returned as ``{"pixel_values": float32 tensor, "label": int64 tensor}``,
    the keys the training and evaluation loops read.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        pixel_values = torch.tensor(row["image"], dtype=torch.float32)
        target = torch.tensor(row["label"], dtype=torch.long)
        return {"pixel_values": pixel_values, "label": target}

# create DataLoader
def get_dataloaders(batch_size=BATCH_SIZE):
    """Build DataLoaders over the sampled Food-101 subset.

    Returns:
        (train_loader, val_loader, selected class names). The training loader
        shuffles; the validation loader keeps dataset order.
    """
    train_split, val_split, classes = load_food()
    loaders = [
        DataLoader(Food101Dataset(split), batch_size=batch_size, shuffle=is_train)
        for split, is_train in ((train_split, True), (val_split, False))
    ]
    return loaders[0], loaders[1], classes


# Loaders (and the class subset) shared by the training and evaluation cells below.
train_loader, val_loader, selected_classes = get_dataloaders()


In [None]:
import torch.nn as nn
from config import OUTPUT_DIR

# Training settings: only the classification head's parameters are optimised.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.classifier.parameters(), lr=5e-4)
num_epochs = 5
LOG_EVERY = 100  # batches between progress prints

# Train the model.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    print(f" epoch {epoch + 1} begin")

    for batch_idx, batch in enumerate(train_loader):
        images = batch["pixel_values"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values=images)
        loss = loss_fn(outputs.logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % LOG_EVERY == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}")

    # max(..., 1) avoids ZeroDivisionError if the loader happens to be empty.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1} finished. Average Loss: {avg_loss:.4f}")

# OUTPUT_DIR may not exist yet on a fresh checkout; torch.save would then fail.
os.makedirs(OUTPUT_DIR, exist_ok=True)
model_weight_path = os.path.join(OUTPUT_DIR, "deit_base_finetuned_food20.pt")
torch.save(model.state_dict(), model_weight_path)

In [None]:
from transformers import AutoModelForImageClassification
from config import OUTPUT_DIR
import torch

import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same path the training cell writes to. It is recomputed here so this cell can reload
# saved weights without re-running training.
model_weight_path = os.path.join(OUTPUT_DIR, "deit_base_finetuned_food20.pt")

# Build the architecture with a head sized for the sampled classes and with their real
# names in the config. Previously the config kept the checkpoint's original labels,
# so config.id2label did not describe the fine-tuned head.
id2label = {i: name for i, name in enumerate(selected_classes)}
loaded_model = AutoModelForImageClassification.from_pretrained(
    "facebook/deit-base-patch16-224",
    id2label=id2label,
    label2id={name: i for i, name in id2label.items()},
    ignore_mismatched_sizes=True,  # the checkpoint's head shape differs; it is replaced
)

# Fine-tuned weights. The training cell saved the complete state_dict, so load
# strictly: any missing or unexpected key is a real mismatch that strict=False hid.
state_dict = torch.load(model_weight_path, map_location=device)
loaded_model.load_state_dict(state_dict)

loaded_model.to(device)
loaded_model.eval()

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns


# Evaluate the reloaded fine-tuned model on the validation split.
model = loaded_model
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in val_loader:
        images = batch["pixel_values"].to(device)

        outputs = model(pixel_values=images)
        preds = torch.argmax(outputs.logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        # Labels are only needed on the CPU for sklearn; skip the device round trip.
        all_labels.extend(batch["label"].numpy())

# Macro averages weight every class equally; zero_division=0 reports 0 instead of
# warning for classes that are never predicted.
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='macro', zero_division=0)
recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

print(f"Accuracy: {accuracy*100:.2f}%")
print(f"Precision (Macro): {precision:.4f}")
print(f"Recall (Macro): {recall:.4f}")
print(f"F1 Score (Macro): {f1:.4f}")

# Pin the label set so the matrix is always n_classes x n_classes and lines up with the
# tick labels, even when some class never appears in labels or predictions.
class_ids = list(range(len(selected_classes)))
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
np.fill_diagonal(cm, 0)  # hide correct predictions so only confusions are coloured
# Fixed colour ceiling kept from the original plot; presumably chosen so figures are
# comparable across models. TODO confirm. Counts above it saturate.
CM_VMAX = 21

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    cmap="Blues",
    xticklabels=selected_classes,
    yticklabels=selected_classes,
    vmin=0,
    vmax=CM_VMAX,
    ax=ax,
)
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
ax.set_title("DeiT")
fig.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from torchvision import transforms
import os

from config import IMAGE_DIR, GRAD_CAM_DIR

# model
# Model under inspection: the reloaded fine-tuned classifier, in eval mode on `device`.
model = loaded_model
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Locate the transformer backbone (the module whose `.encoder.layer` is hooked below).
# The attribute name depends on which HF architecture class was instantiated.
for _backbone_attr in ("vit", "deit", "base_model"):
    if hasattr(model, _backbone_attr):
        backbone = getattr(model, _backbone_attr)
        break
else:
    raise ValueError("can't recognize")

# Hook storage
# Buffers filled by the hooks below and read (then cleared) by compute_gradcam.
activations, gradients = [], []


def forward_hook(module, inputs, outputs):
    """Forward hook: append the module's output hidden state to `activations`.

    Handles a tuple output (first element is taken), an object exposing
    ``last_hidden_state``, or a plain tensor.
    """
    if isinstance(outputs, tuple):
        activations.append(outputs[0])
        return
    activations.append(getattr(outputs, "last_hidden_state", outputs))


def backward_hook(module, grad_in, grad_out):
    """Backward hook: append the gradient w.r.t. the module's first output to `gradients`."""
    gradients.append(grad_out[0])

# Grad-cam
def compute_gradcam(input_tensor, model, target_layer):
    """Grad-CAM heatmap for the top-1 class of a ViT-style classifier.

    The caller must already have registered ``forward_hook`` / ``backward_hook`` on
    ``target_layer``. The argument itself is not used here and is kept for interface
    compatibility.

    Args:
        input_tensor: model input, assumed (1, 3, H, W); only sample 0 is explained.
        model: classifier whose output has ``.logits`` or is a tuple starting with logits.
        target_layer: the hooked module (unused, see above).

    Returns:
        (cam, pred): ``cam`` is an (H, W) numpy array min-max scaled to [0, 1]. ``pred``
        is the argmax class index of sample 0.

    Raises:
        RuntimeError: if the hooks captured no activation or gradient.
    """
    import math

    activations.clear()
    gradients.clear()

    out = model(input_tensor)
    logits = out.logits if hasattr(out, "logits") else out[0]
    pred = int(logits[0].argmax(dim=-1).item())

    score = logits[0, pred]
    model.zero_grad()
    score.backward()

    if not activations or not gradients:
        raise RuntimeError(
            "Grad-CAM hooks captured nothing; register forward/backward hooks on the target layer first"
        )

    A = activations[0].detach()
    G = gradients[0].detach()
    # Per-channel weights: gradients averaged over the token dimension (dim 1).
    weights = G.mean(dim=1, keepdim=True)
    # ReLU keeps only evidence that increases the class score (standard Grad-CAM).
    cam_1d = F.relu((weights * A).sum(-1)).squeeze(0)

    # Drop the leading non-patch tokens: one CLS token, or CLS + distillation token on
    # distilled DeiT. Assumes the patch grid is square.
    n_tokens = cam_1d.size(0)
    side = math.isqrt(n_tokens)
    cam_1d = cam_1d[n_tokens - side * side:]

    cam_2d = cam_1d.reshape(side, side)
    cam_2d = F.interpolate(
        cam_2d.unsqueeze(0).unsqueeze(0), size=tuple(input_tensor.shape[-2:]),
        mode='bilinear', align_corners=False
    ).squeeze().cpu().numpy()

    # Min-max scale to [0, 1]; the epsilon guards against a constant map.
    cam_2d = (cam_2d - cam_2d.min()) / (cam_2d.max() - cam_2d.min() + 1e-8)
    return cam_2d, pred

# Image Un-normalization and label mapping
def denormalize(tensor_img):
    """Invert the ImageNet mean/std normalisation and clamp the result to [0, 1].

    ``tensor_img`` is channel-first with 3 channels in dim -3 (e.g. (3, H, W)).
    """
    stats = torch.tensor(
        [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]], device=tensor_img.device
    )
    mean, std = (row.view(3, 1, 1) for row in stats)
    return (tensor_img * std + mean).clamp(0, 1)

def get_label_name(cls_id, source_model=None):
    """Map a class index to a human-readable name via ``config.id2label``.

    Args:
        cls_id: class index (anything ``int()`` accepts).
        source_model: object with a ``config``; defaults to the module-level ``model``.

    Returns:
        The label name, or ``str(cls_id)`` when no mapping is available.
    """
    cls_id = int(cls_id)
    if source_model is None:
        source_model = model
    config = getattr(source_model, "config", None)
    label_map = getattr(config, "id2label", None)
    if isinstance(label_map, dict):
        # HF configs store id2label with int keys, so the old str-only lookup always
        # fell back to the number. String keys are still accepted for raw JSON-style dicts.
        if cls_id in label_map:
            return label_map[cls_id]
        return label_map.get(str(cls_id), str(cls_id))
    if isinstance(label_map, (list, tuple)):
        return label_map[cls_id] if 0 <= cls_id < len(label_map) else str(cls_id)
    return str(cls_id)

# Images to explain: every .jpg / .jpeg / .png file in IMAGE_DIR, in sorted order.
# The extension check is case-insensitive, so files such as "dish.JPG" are no longer
# skipped.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
image_paths = [
    os.path.join(IMAGE_DIR, fname)
    for fname in sorted(os.listdir(IMAGE_DIR))
    if fname.lower().endswith(IMAGE_EXTENSIONS)
]

# Inference preprocessing: resize to 224x224, then the same normalisation as training.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Grad-CAM hook points: the `intermediate` sub-module of the last three encoder
# layers, ordered last -> third-from-last (figure titles rely on this order).
target_layers = [backbone.encoder.layer[idx].intermediate for idx in (-1, -2, -3)]

#process and save images
def adjust_brightness(img_np, factor=0.7):
    """Return a brightness-adjusted copy of a float RGB image.

    Args:
        img_np: image array expected as (H, W, 3) with values in [0, 1].
        factor: PIL ``ImageEnhance.Brightness`` factor (1.0 leaves the image unchanged,
            smaller values darken, larger values brighten).

    Returns:
        Float array of the same shape with values in [0, 1].
    """
    # Imported here because ImageEnhance is otherwise only imported in a later cell,
    # which made this function raise NameError on a fresh "Run All".
    from PIL import Image, ImageEnhance

    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    img = Image.fromarray((np.clip(img_np, 0.0, 1.0) * 255).astype(np.uint8))
    img_adjusted = ImageEnhance.Brightness(img).enhance(factor)
    return np.asarray(img_adjusted) / 255.0

# One figure per image: the darkened input plus a Grad-CAM overlay per target layer.
os.makedirs(GRAD_CAM_DIR, exist_ok=True)

for img_path in image_paths:
    # `with` closes the file handle; convert() loads the pixel data before it closes.
    with Image.open(img_path) as pil_img:
        image = pil_img.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    img_dn = denormalize(input_tensor.squeeze(0))
    img_np = img_dn.permute(1, 2, 0).cpu().numpy()
    img_np_dark = adjust_brightness(img_np, factor=0.7)

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(img_np_dark)
    axs[0].set_title("Original (Darkened)")
    axs[0].axis('off')

    for i, layer in enumerate(target_layers):
        fwd_handle = layer.register_forward_hook(forward_hook)
        # register_full_backward_hook replaces the deprecated register_backward_hook.
        bwd_handle = layer.register_full_backward_hook(backward_hook)
        try:
            cam, pred_class = compute_gradcam(input_tensor, model, layer)
        finally:
            # Always detach the hooks. A leftover hook would push stale entries into
            # the shared activation/gradient buffers when later layers are explained.
            fwd_handle.remove()
            bwd_handle.remove()

        axs[i+1].imshow(img_np_dark)
        axs[i+1].imshow(cam, cmap='jet', alpha=0.35)  # low alpha so the red (hot) regions stand out over the image
        axs[i+1].set_title(f"Layer -{i+1}\nPred: {get_label_name(pred_class)}")
        axs[i+1].axis('off')

    plt.tight_layout()

    base_name = os.path.splitext(os.path.basename(img_path))[0]
    save_path = os.path.join(GRAD_CAM_DIR, f"{base_name}_styled_gradcam.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight', pad_inches=0.05)
    plt.close(fig)

In [None]:
#adjust colors
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import ImageEnhance

def adjust_brightness(img_np, factor=0.85):
    """Scale the brightness of a float RGB image in [0, 1] using PIL and return floats.

    NOTE(review): this redefines `adjust_brightness` from the previous cell and differs
    only in the default `factor` (0.85 vs 0.7). Both call sites pass `factor`
    explicitly, so one shared definition would suffice.
    """
    as_uint8 = (img_np * 255).astype(np.uint8)
    enhanced = ImageEnhance.Brightness(Image.fromarray(as_uint8)).enhance(factor)
    return np.array(enhanced) / 255.0

# Variant figures: lightened input plus sharpened Grad-CAM overlays from pytorch_grad_cam.
os.makedirs(GRAD_CAM_DIR, exist_ok=True)

for img_path in image_paths:
    # `with` closes the file handle; convert() loads the pixel data before it closes.
    with Image.open(img_path) as pil_img:
        image = pil_img.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    img_dn = denormalize(input_tensor.squeeze(0))
    img_np = img_dn.permute(1, 2, 0).cpu().numpy()
    img_np_bright = adjust_brightness(img_np, factor=0.85)

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    axs[0].imshow(img_np_bright)
    axs[0].set_title("Original (Lightened)")
    axs[0].axis('off')

    for i, layer in enumerate(target_layers):
        fwd_handle = layer.register_forward_hook(forward_hook)
        # register_full_backward_hook replaces the deprecated register_backward_hook.
        bwd_handle = layer.register_full_backward_hook(backward_hook)
        try:
            cam, pred_class = compute_gradcam(input_tensor, model, layer)
        finally:
            # Always detach the hooks so a failure cannot leave stale hooks that
            # pollute the shared buffers for later layers.
            fwd_handle.remove()
            bwd_handle.remove()

        # Gamma > 1 suppresses weak responses, then rescale back to [0, 1].
        cam = cam ** 1.5
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

        cam_rgb = show_cam_on_image(img_np_bright, cam, use_rgb=True)

        axs[i+1].imshow(cam_rgb)
        axs[i+1].set_title(f"Layer -{i+1}\nPred: {get_label_name(pred_class)}")
        axs[i+1].axis('off')

    plt.tight_layout()
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    save_path = os.path.join(GRAD_CAM_DIR, f"{base_name}_styled_gradcam_v2.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight', pad_inches=0.05)
    plt.close(fig)