In [2]:
import os
import time
import os.path as osp

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import CIFAR10
from torchvision import datasets
from torchvision import transforms
import torchvision

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from PIL import Image
from clip import clip

In [3]:
# # random seed
# SEED = 1 
# NUM_CLASS = 10

# Training
BATCH_SIZE = 128
# NUM_EPOCHS = 30
# EVAL_INTERVAL=1
# SAVE_DIR = './log'

# # Optimizer
# LEARNING_RATE = 1e-1
# MOMENTUM = 0.9
# STEP=5
# GAMMA=0.5

# CLIP
VISUAL_BACKBONE = 'RN50' # RN50, ViT-B/32, ViT-B/16


In [None]:
import torch
import clip
from torchvision import datasets
from torchvision.transforms import Compose, Resize
from torch.utils.data import DataLoader
from tqdm import tqdm

transform_cifar10_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


# Load the CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device)

transform_FM_test = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load Fashion-MNIST dataset
test_dataset = datasets.FashionMNIST(root='./data', test=True, download=True, transform=transform_FM_test)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=2)

# Define text labels for Fashion-MNIST classes
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
text_inputs = torch.cat([clip.tokenize(f"This is a photo of a {c}") for c in class_names]).to(device)
dataset_name = 'FashionMNIST'



 72%|████████████████████████████▊           | 176M/244M [13:50<04:10, 285kiB/s]

In [7]:
from torch.optim.lr_scheduler import StepLR

# Optimizer and learning rate scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # Slightly higher starting learning rate
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)  # Reduce learning rate over epochs

num_epochs = 10  # Increase number of epochs

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        image_features = model.encode_image(images)
        text_features = model.encode_text(text_inputs)

        # Calculate similarity
        image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = torch.matmul(image_features_norm, text_features_norm.T)

        # Loss
        loss = torch.nn.functional.cross_entropy(similarity, labels)
        total_loss += loss.item()

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    scheduler.step()  # Update learning rate
    print(f"Epoch {epoch+1} finished with average loss: {total_loss / len(train_loader)}")


Epoch 1: 100%|██████████| 1875/1875 [04:39<00:00,  6.72it/s]


Epoch 1 finished with average loss: nan


Epoch 2: 100%|██████████| 1875/1875 [07:03<00:00,  4.43it/s]


Epoch 2 finished with average loss: nan


Epoch 3: 100%|██████████| 1875/1875 [06:15<00:00,  4.99it/s]


Epoch 3 finished with average loss: nan


Epoch 4: 100%|██████████| 1875/1875 [06:01<00:00,  5.19it/s]


Epoch 4 finished with average loss: nan


Epoch 5:   8%|▊         | 156/1875 [01:43<18:58,  1.51it/s]  


KeyboardInterrupt: 

In [1]:
# 加载测试数据集
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model.eval()  # 设置模型为评估模式
total_correct = 0
total_images = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # 预测
        image_features = model.encode_image(images)
        text_features = model.encode_text(text_inputs)

        # 计算相似性
        image_features_norm = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
        similarity = torch.matmul(image_features_norm, text_features_norm.T)

        # 取得最高相似性的类别
        preds = similarity.argmax(dim=1)
        total_correct += preds.eq(labels).sum().item()
        total_images += labels.size(0)

print(f"测试集上的准确率: {total_correct / total_images:.2%}")


NameError: name 'datasets' is not defined

In [9]:
# 进一步降低学习率
optimizer = torch.optim.Adam(model.parameters(), lr=5e-7)

# 添加小的正值以增加数值稳定性
epsilon = 1e-9

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        image_features = model.encode_image(images)
        text_features = model.encode_text(text_inputs)

        # Calculate similarity with numerical stability
        image_features_norm = image_features / (image_features.norm(dim=-1, keepdim=True) + epsilon)
        text_features_norm = text_features / (text_features.norm(dim=-1, keepdim=True) + epsilon)
        similarity = torch.matmul(image_features_norm, text_features_norm.T)

        # Loss
        loss = torch.nn.functional.cross_entropy(similarity, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()

        # Check for NaN
        if torch.isnan(loss).any():
            raise ValueError("Loss is NaN")

        total_loss += loss.item()

    scheduler.step()
    print(f"Epoch {epoch+1} finished with average loss: {total_loss / len(train_loader)}")


Epoch 1:   0%|          | 0/1875 [00:03<?, ?it/s]


ValueError: Loss is NaN