In [137]:
from datasets import load_dataset
from datasets import DatasetDict

# 加载完整数据集
ds = load_dataset("ethz/food101")
train_num = ds["train"].num_rows
valid_num = ds["validation"].num_rows

# 假设想要取 1/20 的数据
fraction = 1/10

# 计算要选取的样本数
train_sample_size = int(train_num * fraction)
valid_sample_size = int(valid_num * fraction)

# 对 train 和 validation 数据集进行随机打乱并取子集
train_subset = ds["train"].shuffle(seed=42).select(range(train_sample_size))
valid_subset = ds["validation"].shuffle(seed=42).select(range(valid_sample_size))

# 将抽取后的子集组成新的 DatasetDict
ds = DatasetDict({
    "train": train_subset,
    "validation": valid_subset
})

print("Original train size:", train_num)
print("Original validation size:", valid_num)
print("Reduced train size:", ds["train"].num_rows)
print("Reduced validation size:", ds["validation"].num_rows)


Original train size: 75750
Original validation size: 25250
Reduced train size: 7575
Reduced validation size: 2525


In [138]:
# 查看数据集的结构
print(ds)

# 查看训练集的前几条数据
print(ds['train'][1])

# 查看测试集的前几条数据
print(ds['validation'][0])


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 7575
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 2525
    })
})
{'image': <PIL.Image.Image image mode=RGB size=512x512 at 0x3FF4C4230>, 'label': 71}
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x504 at 0x3FE2C4C50>, 'label': 28}


In [139]:
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize

# 数据增强
transform = Compose([
    RandomResizedCrop(224),
    RandomHorizontalFlip(),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def preprocess(example):
    # Check and handle unexpected data types
    if isinstance(example['image'], list):
        example['pixel_values'] = [transform(img) for img in example['image']]
    else:
        example['pixel_values'] = transform(example['image'])
    
    return example

ds.reset_format()  # Ensure dataset is in the original format
ds = ds.with_transform(preprocess)


In [140]:
# 访问单个样本
sample = ds['train'][0]
# 此时才会调用 preprocess 函数并打印出 Inside preprocess 信息


In [141]:
import torch
from torch.utils.data import DataLoader

def collate_fn(examples):
    pixel_values = torch.stack([example['pixel_values'] for example in examples])
    labels = torch.tensor([example['label'] for example in examples])
    return {'pixel_values': pixel_values, 'label': labels}

train_loader = DataLoader(ds['train'], batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(ds['validation'], batch_size=32, shuffle=False, collate_fn=collate_fn)


In [142]:
from transformers import AutoImageProcessor, ViTForImageClassification

# 下载预训练的ViT模型（ImageNet上预训练）
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=101,  # Food101有101个类别
)

model = model.to("cuda" if torch.cuda.is_available() else "cpu")


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [143]:
# 优化器和调度器
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = CosineAnnealingLR(optimizer, T_max=10)

In [144]:
import torch
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from transformers import AutoModelForImageClassification, AutoImageProcessor
from torch.utils.data import DataLoader

# 假设 train_loader 和 val_loader 已经初始化
device = "cuda" if torch.cuda.is_available() else "cpu"
scaler = GradScaler()

# 定义模型、优化器、损失函数
model = AutoModelForImageClassification.from_pretrained("nateraw/food").to(device)
optimizer = AdamW(model.parameters(), lr=0.0002)
criterion = torch.nn.CrossEntropyLoss()

# 超参数
epochs = 5
train_batch_size = 128
eval_batch_size = 128
log_every = 10  # 每多少个 batch 打印一次日志

# 存储结果
results = {"epoch": [], "step": [], "train_loss": [], "val_loss": [], "accuracy": []}

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")
    model.train()
    total_train_loss = 0
    total_steps = 0

    for batch_idx, batch in enumerate(train_loader):
        inputs = batch["pixel_values"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()

        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_train_loss += loss.item()
        total_steps += 1

        # 打印训练进度
        if batch_idx % log_every == 0:
            print(f"Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

    avg_train_loss = total_train_loss / total_steps
    print(f"Epoch {epoch} Completed. Average Training Loss: {avg_train_loss:.4f}")

    # 验证阶段
    model.eval()
    total_val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch["pixel_values"].to(device)
            labels = batch["label"].to(device)

            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)
            total_val_loss += loss.item()

            _, preds = torch.max(outputs.logits, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = total_val_loss / len(val_loader)
    accuracy = correct / total
    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}")

    # 保存结果
    results["epoch"].append(epoch)
    results["step"].append(total_steps)
    results["train_loss"].append(avg_train_loss)
    results["val_loss"].append(avg_val_loss)
    results["accuracy"].append(accuracy)

# 打印总结
print("\nTraining Summary:")
print(f"{'Epoch':<10}{'Training Loss':<15}{'Validation Loss':<20}{'Accuracy':<10}")
for i in range(epochs):
    print(f"{results['epoch'][i]:<10}{results['train_loss'][i]:<15.4f}{results['val_loss'][i]:<20.4f}{results['accuracy'][i]:<10.4f}")


  scaler = GradScaler()


Epoch 1/5


  with autocast():


Batch 0/237, Loss: 1.4607
Batch 10/237, Loss: 1.8996
Batch 20/237, Loss: 2.0092
Batch 30/237, Loss: 1.4960
Batch 40/237, Loss: 1.6364
Batch 50/237, Loss: 2.1042
Batch 60/237, Loss: 0.8125
Batch 70/237, Loss: 0.8102
Batch 80/237, Loss: 1.1118
Batch 90/237, Loss: 1.2422
Batch 100/237, Loss: 1.2459
Batch 110/237, Loss: 1.0719
Batch 120/237, Loss: 1.0350
Batch 130/237, Loss: 0.8642
Batch 140/237, Loss: 1.0745
Batch 150/237, Loss: 1.8133
Batch 160/237, Loss: 1.0905
Batch 170/237, Loss: 1.1453
Batch 180/237, Loss: 0.9390
Batch 190/237, Loss: 1.9917
Batch 200/237, Loss: 1.3051
Batch 210/237, Loss: 1.2332
Batch 220/237, Loss: 1.3495
Batch 230/237, Loss: 1.3919
Epoch 1 Completed. Average Training Loss: 1.3253
Validation Loss: 1.1283, Accuracy: 0.7172
Epoch 2/5
Batch 0/237, Loss: 0.9801
Batch 10/237, Loss: 0.6122
Batch 20/237, Loss: 1.0358
Batch 30/237, Loss: 0.8174
Batch 40/237, Loss: 1.2123
Batch 50/237, Loss: 0.9493
Batch 60/237, Loss: 1.2787
Batch 70/237, Loss: 0.8445
Batch 80/237, Loss: 1.3

In [145]:
print(batch.keys())


dict_keys(['pixel_values', 'label'])


In [146]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        # 获取输入和标签
        images = batch["pixel_values"].to(device)
        labels = batch["label"].to(device)
        
        # 获取模型输出并提取 logits
        outputs = model(images)
        logits = outputs.logits  # 提取 logits
        
        # 使用 logits 进行预测
        _, predicted = torch.max(logits, 1)
        
        # 计算总数和正确预测数
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# 打印验证集的准确率
print(f"Validation Accuracy: {100 * correct / total:.2f}%")


Validation Accuracy: 72.87%


In [136]:
# import os
# from transformers import AutoModelForImageClassification, AutoImageProcessor
# from PIL import Image
# import torch


# image_directory = "Images"


# model_name = "nateraw/food"
# model = AutoModelForImageClassification.from_pretrained(model_name)
# processor = AutoImageProcessor.from_pretrained(model_name)


# for filename in os.listdir(image_directory):
#     if filename.endswith(".jpeg") or filename.endswith(".png"):
#         image_path = os.path.join(image_directory, filename)


#         image = Image.open(image_path)
#         inputs = processor(images=image, return_tensors="pt")
        

#         with torch.no_grad():
#             outputs = model(**inputs)


#         predicted_class_idx = outputs.logits.argmax(-1).item()
#         predicted_label = model.config.id2label[predicted_class_idx]

#         print(predicted_label)