<a href="https://colab.research.google.com/github/TianyiRnj/dino-classification-extent/blob/Fine_tune/11785_TeamProject_PostTrainingDinoV2HAM10000_MidtermReport.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**(1) 下载数据集 (Download the dataset)**

In [None]:
# Install PyTorch and helper libraries from the PyTorch CUDA 11.8 wheel index.
# NOTE(review): only torchaudio is pinned (2.6.0); the other packages float, so
# reruns may resolve different versions — consider pinning them for reproducibility.
!pip install torch torchvision pandas numpy matplotlib tqdm torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m848.7/848.7 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, nvidia-cusolver-cu11, nvidia-cudnn-cu11, torch


In [None]:
# Mount Google Drive at /content/drive (Colab-only API).
# NOTE(review): no later cell reads from /content/drive — this mount may be unnecessary.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install --upgrade kaggle==1.6.17 --force-reinstall --no-deps
!mkdir /root/.kaggle

with open("/root/.kaggle/kaggle.json", "w+") as f:

    f.write('{"username":"","key":""}')

!chmod 600 /root/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d kmader/skin-cancer-mnist-ham10000
!unzip -q skin-cancer-mnist-ham10000.zip -d /content/ham10000

In [None]:
# Inspect the top-level layout of the extracted archive
import os

data_dir = "/content/ham10000"
top_level_entries = os.listdir(data_dir)
print("Files in dataset:", top_level_entries)

In [None]:
# Reorganize the images into one sub-directory per diagnosis label — the
# class-per-folder layout that torchvision's ImageFolder expects.
import shutil
import pandas as pd

# Metadata maps each image_id to its diagnosis label ("dx" column)
metadata = pd.read_csv(os.path.join(data_dir, "HAM10000_metadata.csv"))

# Create the output root and one folder per label
processed_data_dir = "/content/ham10000_processed"
os.makedirs(processed_data_dir, exist_ok=True)

categories = metadata["dx"].unique()
for category in categories:
    os.makedirs(os.path.join(processed_data_dir, category), exist_ok=True)

# The images are spread across two source folders; each image is looked up in both.
image_dir_1 = os.path.join(data_dir, "HAM10000_images_part_1")
image_dir_2 = os.path.join(data_dir, "HAM10000_images_part_2")
source_dirs = (image_dir_1, image_dir_2)

# Copy each image into its label folder. Images listed in the metadata but
# missing from both source folders are collected and reported rather than
# skipped silently, so an incomplete download is noticed before training.
missing_image_ids = []
for image_id, category in zip(metadata["image_id"], metadata["dx"]):
    file_name = f"{image_id}.jpg"
    src_path = next(
        (
            os.path.join(src_dir, file_name)
            for src_dir in source_dirs
            if os.path.exists(os.path.join(src_dir, file_name))
        ),
        None,
    )
    if src_path is None:
        missing_image_ids.append(image_id)
        continue
    shutil.copy(src_path, os.path.join(processed_data_dir, category, file_name))

if missing_image_ids:
    print(
        f"Warning: {len(missing_image_ids)} image(s) listed in the metadata were not found, "
        f"e.g. {missing_image_ids[:5]}"
    )

print("数据已整理完毕！")

**(2) 数据预处理 (Data preprocessing)**

In [None]:
# 使用 ImageFolder 读取数据
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
# Preprocessing. DINOv2 was pretrained on inputs normalized with the ImageNet
# channel mean/std, so the same statistics are used here. The previous 0.5/0.5
# values fed the pretrained backbone inputs from a different distribution than
# the one it was trained on.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed square input size for the ViT backbone
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Root of the class-per-folder tree built by the reorganization cell
processed_data_dir = "/content/ham10000_processed"

# Load the dataset; ImageFolder derives labels from the sub-folder names
dataset = ImageFolder(root=processed_data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Summarize what was loaded
print(f"Dataset loaded with {len(dataset)} images")
print(f"Classes: {dataset.classes}")



**(3) 加载 DINOv2 进行特征提取 (Load DINOv2 for feature extraction)**

In [None]:
import torch
import torch.nn as nn
from transformers import Dinov2Model

class DINOv2Classifier(nn.Module):
    """Pretrained DINOv2 backbone plus a small MLP head applied to the [CLS] token.

    Args:
        num_classes: number of output logits. NOTE(review): the default of 8
            differs from the `num_classes = 7` used for the fine-tuned model
            later in this notebook — pass the class count explicitly.
        pretrain_choice: 'frozen' turns off gradients for every backbone
            parameter; any other value leaves the backbone trainable.
    """

    def __init__(self, num_classes=8, pretrain_choice='frozen'):
        super().__init__()

        # Pretrained DINOv2 (base) backbone from the Hugging Face Hub
        self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base")

        freeze_backbone = pretrain_choice == 'frozen'
        if freeze_backbone:
            # Equivalent to setting requires_grad = False on each parameter
            self.dinov2.requires_grad_(False)

        # Two-layer MLP head sized from the backbone's feature width
        feature_dim = self.dinov2.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits computed from the first ([CLS]) token of the last hidden state."""
        # last_hidden_state is indexed as (batch, token, hidden); token 0 is taken
        cls_embedding = self.dinov2(x).last_hidden_state[:, 0]
        return self.classifier(cls_embedding)

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One output logit per class folder found by ImageFolder, instead of relying on
# the class's default of 8 outputs.
model = DINOv2Classifier(num_classes=len(dataset.classes), pretrain_choice='frozen').to(device)

# Parameter summary. torchsummary's `summary` is not used here: it builds a CUDA
# input by default (fails on CPU runtimes) and its forward hooks call `.size()`
# on every submodule output, which breaks on the Hugging Face output objects
# returned inside the DINOv2 backbone.
total_param_count = sum(p.numel() for p in model.parameters())
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model.classifier)
print(f"Total parameters: {total_param_count:,}")
print(f"Trainable parameters: {trainable_param_count:,}")


In [None]:
import torch

def extract_features(model, images, device="cuda"):
    """Run `model` on a batch without gradients and return one feature vector per image.

    Args:
        model: callable returning a tensor (e.g. the torch.hub / GitHub DINOv2
            backbone). It is not moved to `device`; the caller must do that.
        images: input batch tensor; moved to `device` before the forward pass.
        device: target device for `images`.

    Returns:
        The model output unchanged when it is 2-D, or `output[:, 0, :]` (the
        first token, presumably [CLS]) when it is 3-D.

    Raises:
        ValueError: if the output has any other number of dimensions.
    """
    batch = images.to(device)
    with torch.no_grad():
        output = model(batch)  # DINOv2 GitHub (torch.hub) version
        print("Model Output Shape:", output.shape)

        rank = output.dim()
        if rank == 3:
            return output[:, 0, :]
        if rank != 2:
            raise ValueError(f"Unexpected model output shape: {output.shape}")
        return output



In [None]:
# Debug: show the container type and length of one batch (check the batch format)
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    batch = first_batch
    print(type(batch), len(batch))


In [None]:
# Debug: report the Python types inside one batch
sample_batch = next(iter(dataloader), None)
if sample_batch is not None:
    batch = sample_batch
    print("Batch Type:", type(batch))
    print("Batch Length:", len(batch))
    print("Type of batch[0]:", type(batch[0]))  # images
    print("Type of batch[1]:", type(batch[1]))  # labels


In [None]:
# Take a single (images, labels) batch for manual inspection.
# The previous `for` loop loaded and transformed the entire dataset only to keep
# the final batch; next(iter(...)) loads just one batch.
images, labels = next(iter(dataloader))


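**(4) 部分微调 (Partial fine-tuning): freeze DINOv2, unfreeze its last 3 encoder layers, and train a linear head**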

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from transformers import AutoModel

# Load the pretrained DINOv2 backbone.
# Fix: the Hugging Face Hub id is "facebook/dinov2-base" (the same id that
# DINOv2Classifier loads); "facebook/dino-v2-base" is not a valid repository.
dino_model = AutoModel.from_pretrained("facebook/dinov2-base")

# Number of final encoder blocks to fine-tune; all other weights stay frozen.
NUM_UNFROZEN_LAYERS = 3

# Freeze every backbone parameter ...
dino_model.requires_grad_(False)

# ... then re-enable gradients for the last NUM_UNFROZEN_LAYERS encoder blocks.
# The start index is clamped so that 0 unfreezes nothing; a plain [-0:] slice
# would select every block.
encoder_blocks = dino_model.encoder.layer
first_unfrozen = max(len(encoder_blocks) - NUM_UNFROZEN_LAYERS, 0)
for block in encoder_blocks[first_unfrozen:]:
    block.requires_grad_(True)
# NOTE(review): modules outside encoder.layer (e.g. a final layer norm applied to
# last_hidden_state, if the model has one) remain frozen — confirm this is intended.

print("解冻的层：", [name for name, param in dino_model.named_parameters() if param.requires_grad])


In [None]:
# Fine-tuned DINOv2 model definition
class FineTunedDINOv2(nn.Module):
    """Backbone plus a single linear classification head on the first ([CLS]) token.

    Args:
        base_model: backbone whose forward returns an object exposing
            `last_hidden_state`, indexed as (batch, token, hidden).
        num_classes: number of output logits.
        hidden_size: width of the backbone features. When None (default) it is
            read from `base_model.config.hidden_size`, falling back to 768 (the
            value previously hard-coded) if the backbone exposes no config.
    """

    def __init__(self, base_model, num_classes, hidden_size=None):
        super(FineTunedDINOv2, self).__init__()
        self.dino = base_model
        if hidden_size is None:
            # Derive the head width from the backbone instead of assuming 768
            config = getattr(base_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) computed from token 0."""
        features = self.dino(x).last_hidden_state[:, 0, :]
        output = self.fc(features)
        return output

# Instantiate the fine-tuning model. The head size is derived from the class
# folders ImageFolder found rather than hard-coded, so it always matches the data.
num_classes = len(dataset.classes)
model = FineTunedDINOv2(dino_model, num_classes)

# Only parameters that still require gradients are optimized: the backbone
# blocks unfrozen in the previous cell plus the new fc head.
trainable_params = [p for p in model.parameters() if p.requires_grad]
print(f"需要训练的参数量: {sum(p.numel() for p in trainable_params)}")


In [None]:
# Training hyper-parameters
learning_rate = 1e-5
batch_size = 32
num_epochs = 10
SEED = 42            # makes the train/test split reproducible
TEST_FRACTION = 0.2  # share of images held out for evaluation

# Optimizer over the trainable (unfrozen) parameters only
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)

# Cross-entropy loss for multi-class classification
criterion = nn.CrossEntropyLoss()

# Preprocessing matching DINOv2 pretraining: ImageNet channel mean/std
# (the previous 0.5/0.5 values do not match the pretrained backbone).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = ImageFolder(root="/content/ham10000_processed", transform=transform)

# Hold out a test split so evaluation does not score the model on images it was
# trained on. `dataloader` keeps its name and now yields only the training split.
# NOTE(review): this is a per-image random split; if several images come from the
# same lesion, consider grouping the split by lesion to avoid leakage.
from torch.utils.data import random_split

n_test = int(len(dataset) * TEST_FRACTION)
train_dataset, test_dataset = random_split(
    dataset,
    [len(dataset) - n_test, n_test],
    generator=torch.Generator().manual_seed(SEED),
)
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [None]:
# Train: one optimizer step per mini-batch; after each epoch report the mean
# per-batch loss and the running training accuracy.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        batch_predictions = torch.max(logits, 1).indices
        n_correct += (batch_predictions == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    mean_loss = running_loss / len(dataloader)
    epoch_accuracy = n_correct / n_seen
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")


In [None]:
# Evaluate accuracy without gradients. The held-out split is used when the data
# cell defined `test_dataloader`. Otherwise the only loader available yields
# training images, and the printout says so instead of calling it "Test" accuracy.
eval_loader = globals().get("test_dataloader")
split_name = "Test"
if eval_loader is None:
    eval_loader = dataloader
    split_name = "Train (no held-out split defined)"

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in eval_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"{split_name} Accuracy: {correct / total:.4f}")
