In [None]:
import torch
import torch.nn as nn
import clip
from torchvision import datasets, transforms
import config

# 1. Load the pretrained CLIP ViT-B/32 model on the GPU (only its vision part is used).
clip_model, preprocess = clip.load("ViT-B/32", device="cuda")
image_encoder = clip_model.visual  # vision backbone

# 2. Classifier definition: CLIP vision module + classification head
class CLIPClassifier(nn.Module):
    """Image classifier made of a frozen vision encoder and a small MLP head.

    The encoder is used purely as a fixed feature extractor: its forward pass
    runs under ``torch.no_grad()`` and it is kept in eval mode even while the
    classification head is in training mode.

    Args:
        image_encoder: Vision backbone that maps an image batch to a 2-D
            feature tensor. The feature width is read from
            ``image_encoder.output_dim`` when that attribute exists, otherwise
            from ``image_encoder.proj.shape[1]``.
        num_classes: Number of logits produced by the head.
        hidden_dim: Width of the hidden layer of the head (default 512).
    """

    def __init__(self, image_encoder, num_classes, hidden_dim=512):
        super().__init__()
        self.encoder = image_encoder

        # Prefer the explicit ``output_dim`` attribute so that backbones
        # without a ``proj`` matrix also work; fall back to the projection
        # matrix, whose second dimension is the embedding width.
        feature_dim = getattr(image_encoder, "output_dim", None)
        if feature_dim is None:
            feature_dim = image_encoder.proj.shape[1]

        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def train(self, mode=True):
        """Switch the head's mode while always leaving the encoder in eval mode.

        The encoder never receives gradients, so mode-dependent layers inside
        it (e.g. BatchNorm running statistics, dropout) must not behave as if
        it were being trained.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        """Return class logits (float32) for the image batch ``x``."""
        with torch.no_grad():  # the encoder is not trained
            x = self.encoder(x)
        # The head holds float32 weights, so cast the features whatever
        # precision the encoder produced.
        x = x.to(torch.float32)
        return self.classifier(x)

# Wrap the CLIP vision backbone with the classification head, then move
# the whole module to the GPU.
model = CLIPClassifier(image_encoder, num_classes=config.NUM_CLASSES)
model = model.to("cuda")


In [2]:
# Let cuDNN benchmark its convolution algorithms and reuse the fastest one.
# NOTE(review): this pays off when input shapes stay the same between batches — confirm
# that the transforms produce fixed-size images.
torch.backends.cudnn.benchmark = True

In [3]:
from stanford_cars_dataset import StanfordCarsDataset,get_transforms
# Transform pipelines for the two splits, as built by get_transforms()
# with train=True and train=False respectively.
train_transform_pretrain = get_transforms(train=True)
test_transform_pretrain = get_transforms(train=False)

In [4]:
# Train and test splits read from the same data directory, each with its
# own transform pipeline.
train_dataset = StanfordCarsDataset(
    data_dir="standford_cars",
    split='train',
    transform=train_transform_pretrain,
)
test_dataset = StanfordCarsDataset(
    data_dir="standford_cars",
    split='test',
    transform=test_transform_pretrain,
)

In [5]:
# Batch loaders: only the training loader shuffles; both use 8 worker
# processes with pinned memory and a prefetch factor of 4.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=4,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    prefetch_factor=4,
)

In [6]:
# Cross-entropy objective; Adam optimises only the classification head's
# parameters, so the encoder weights are never passed to the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.classifier.parameters(), lr=1e-4)

In [7]:
def test(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    top1_correct = 0
    top5_correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device).half()
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            total += labels.size(0)

            # Top-1
            _, predicted = torch.max(outputs, 1)
            top1_correct += (predicted == labels).sum().item()

            # Top-5
            _, top5 = torch.topk(outputs, k=5, dim=1)
            top5_correct += (top5 == labels.unsqueeze(1)).any(dim=1).sum().item()

    epoch_loss = running_loss / len(val_loader.dataset)
    top1_acc = top1_correct / total
    top5_acc = top5_correct / total
    return epoch_loss, top1_acc, top5_acc


In [8]:
def train(model, train_loader, test_loader, criterion, optimizer, num_epochs=10):
    """Train ``model`` and periodically evaluate and checkpoint it.

    Runs on CUDA when available, otherwise on CPU. After every epoch the
    sample-weighted mean training loss is printed. On epochs where the
    zero-based index satisfies ``epoch % 5 == 0`` (i.e. the 1st, 6th, 11th,
    ... epoch) the model is evaluated with ``test`` and its ``state_dict`` is
    saved to a file whose name embeds the accuracies and the epoch index.

    Args:
        model: Module to optimise; called as ``model(images)``.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        test_loader: Loader handed to ``test`` for the periodic evaluation.
        criterion: Loss callable returning the batch-mean loss.
        optimizer: Optimizer stepping the trainable parameters.
        num_epochs: Number of passes over ``train_loader``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        model.train()
        running_loss = 0.0
        seen = 0

        for images, labels in train_loader:
            # Images are cast to half precision before the forward pass.
            # NOTE(review): assumes the model accepts fp16 input — confirm.
            images = images.to(device).half()
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the batch-mean loss weighted by the batch size.
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            seen += batch_size

        # Average over the samples actually processed this epoch, so the
        # figure stays correct even if the loader drops or subsamples data.
        epoch_loss = running_loss / max(seen, 1)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        if epoch % 5 == 0:
            # Evaluate on the test set, on the same device used for training.
            test_loss, t_top1_acc, t_top5_acc = test(model, test_loader, criterion, device)
            print(f"test Loss: {test_loss:.4f}, Top-1 Accuracy: {t_top1_acc:.4f}, Top-5 Accuracy: {t_top5_acc:.4f}")
            # Save a checkpoint of the model weights.
            torch.save(model.state_dict(), f"accu{t_top1_acc} top5 {t_top5_acc}Pretain-epoch_{epoch}.pth")

In [9]:
# Number of epochs passed to train() as num_epochs in the next cell.
PRETRAIN_EPOCHES =1000

In [10]:
# Run training for PRETRAIN_EPOCHES epochs; train() prints the per-epoch loss and
# evaluates and saves a checkpoint on epochs where the zero-based index is a multiple of 5.
train(model, train_loader, test_loader, criterion, optimizer, PRETRAIN_EPOCHES)
    

Epoch 1/1000
Epoch [1/1000], Loss: 5.2793
test Loss: 5.2666, Top-1 Accuracy: 0.0066, Top-5 Accuracy: 0.0341
Epoch 2/1000
Epoch [2/1000], Loss: 5.2612
Epoch 3/1000
Epoch [3/1000], Loss: 5.2437
Epoch 4/1000
Epoch [4/1000], Loss: 5.2253
Epoch 5/1000
Epoch [5/1000], Loss: 5.2072
Epoch 6/1000
Epoch [6/1000], Loss: 5.1860
test Loss: 5.1686, Top-1 Accuracy: 0.0598, Top-5 Accuracy: 0.1846
Epoch 7/1000
Epoch [7/1000], Loss: 5.1641
Epoch 8/1000
Epoch [8/1000], Loss: 5.1393
Epoch 9/1000
Epoch [9/1000], Loss: 5.1118
Epoch 10/1000
Epoch [10/1000], Loss: 5.0824
Epoch 11/1000
Epoch [11/1000], Loss: 5.0508
test Loss: 5.0212, Top-1 Accuracy: 0.1249, Top-5 Accuracy: 0.3684
Epoch 12/1000
Epoch [12/1000], Loss: 5.0156
Epoch 13/1000
Epoch [13/1000], Loss: 4.9769
Epoch 14/1000
Epoch [14/1000], Loss: 4.9370
Epoch 15/1000
Epoch [15/1000], Loss: 4.8935
Epoch 16/1000
Epoch [16/1000], Loss: 4.8490
test Loss: 4.8034, Top-1 Accuracy: 0.1971, Top-5 Accuracy: 0.5374
Epoch 17/1000
Epoch [17/1000], Loss: 4.7997
Epoch 

KeyboardInterrupt: 