In [4]:
import torch
from dataset import MNIST
from clip import CLIP
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import multiprocessing

def main():
    DEVICE='cuda' if torch.cuda.is_available() else 'cpu'   # 设备

    dataset=MNIST() # 数据集

    model=CLIP().to(DEVICE) # 模型

    optimzer=torch.optim.Adam(model.parameters(),lr=1e-3)   # 优化器

    '''
        训练模型
    '''
    ITER_BATCH_COUNT=100000    # 迭代次数
    BATCH_SIZE=64   # 从batch内选出10个不一样的数字
    TARGET_COUNT=10 # 共10种数字

    # 修改worker数量为0可以避免多进程问题
    dataloader=DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)    # 数据加载器

    for i in range(ITER_BATCH_COUNT):
        while True:
            imgs,labels=next(iter(dataloader))
            if torch.unique(labels).shape[0]<TARGET_COUNT:  # 未覆盖10种数字
                continue
            # 挑选出10个数字
            target=set()
            indexes=[]
            for j in range(BATCH_SIZE):
                if labels[j].item() in target:
                    continue
                target.add(labels[j].item())
                indexes.append(j)
                if len(target)==TARGET_COUNT:
                    break
            imgs=imgs[indexes]
            labels=labels[indexes]
            break

        ### ====== TODO: TASK2: 完成模型损失函数计算的代码（BEGIN）

        # 通过模型获取相似度矩阵
        logits = model(imgs.to(DEVICE), labels.to(DEVICE))

        # 对比学习中，第i个图像应该与第i个文本匹配
        labels_one_hot = torch.arange(TARGET_COUNT, device=DEVICE)

        # 计算双向对比损失（图像到文本 + 文本到图像）
        loss = (F.cross_entropy(logits, labels_one_hot) +
                F.cross_entropy(logits.t(), labels_one_hot)) / 2

        ### ====== TODO: TASK 2: 完成模型损失函数计算的代码（END）

        optimzer.zero_grad()
        loss.backward()
        optimzer.step()
        if i%1000==0:
            print('iter:{},loss:{}'.format(i,loss))
            torch.save(model.state_dict(),'.model.pth')
            os.replace('.model.pth','model.pth')

# 重要！添加程序入口保护
if __name__ == '__main__':
    main()

iter:0,loss:3.831463098526001
iter:1000,loss:0.5947991013526917
iter:2000,loss:0.10980270802974701
iter:3000,loss:0.016202563419938087
iter:4000,loss:0.14844641089439392
iter:5000,loss:0.05808998644351959
iter:6000,loss:0.030240148305892944
iter:7000,loss:0.06617588549852371
iter:8000,loss:0.010835248976945877
iter:9000,loss:0.2623428702354431
iter:10000,loss:0.009947315789759159
iter:11000,loss:0.05500946193933487
iter:12000,loss:0.02327805757522583
iter:13000,loss:0.02404194325208664
iter:14000,loss:0.010703436098992825
iter:15000,loss:0.0014321808703243732
iter:16000,loss:0.0544450618326664
iter:17000,loss:0.00264459615573287
iter:18000,loss:0.026783445850014687
iter:19000,loss:0.0451720654964447
iter:20000,loss:0.009238595142960548
iter:21000,loss:0.001791806542314589
iter:22000,loss:0.12152861058712006
iter:23000,loss:0.011750077828764915
iter:24000,loss:0.01743306964635849
iter:25000,loss:0.0005077047972008586
iter:26000,loss:0.0022241934202611446
iter:27000,loss:0.00489481631666

In [3]:
'''
CLIP能力演示

1、对图片做分类
2、对图片求相图片

'''

from dataset import MNIST
import matplotlib.pyplot as plt
import torch
from clip import CLIP
import torch.nn.functional as F

DEVICE='cuda' if torch.cuda.is_available() else 'cpu'   # 设备

dataset=MNIST() # 数据集

model=CLIP().to(DEVICE) # 模型
model.load_state_dict(torch.load('./model.pth'))

model.eval()    # 预测模式

'''
1、对图片分类
'''
image,label=dataset[0]
print('正确分类:',label)
plt.imshow(image.permute(1,2,0))
plt.show()


### TODO: TASK 3: 完成CLIP模型进行预测的代码 (BEGIN)

with torch.no_grad():
    # 为每个数字(0-9)生成文本嵌入
    all_text_embeddings = model.text_enc(torch.arange(10).to(DEVICE))
    # 获取图像嵌入
    image_embedding = model.img_enc(image.unsqueeze(0).to(DEVICE))

    # 规范化嵌入
    image_embedding = image_embedding / image_embedding.norm(dim=1, keepdim=True)
    all_text_embeddings = all_text_embeddings / all_text_embeddings.norm(dim=1, keepdim=True)

    # 计算相似度
    similarity = torch.mm(image_embedding, all_text_embeddings.t())
    print(similarity)

    # 选择相似度最高的作为预测结果
    predicted_label = similarity.argmax(dim=1).item()

### TODO: TASK 3: 完成CLIP模型进行预测的代码 (END)

print('CLIP分类:', predicted_label)

'''
2、图像相似度
'''
other_images=[]
other_labels=[]
for i in range(1,101):
    other_image,other_label=dataset[i]
    other_images.append(other_image)
    other_labels.append(other_label)

### TODO: TASK 4: 使用CLIP的image encoder，从other_images里检索和image最相似的5张图像 (BEGIN)

with torch.no_grad():
    # 将其他图像转换为张量
    other_images_tensor = torch.stack(other_images).to(DEVICE)

    # 使用图像编码器获取图像嵌入
    query_embedding = model.img_enc(image.unsqueeze(0).to(DEVICE))
    other_embeddings = model.img_enc(other_images_tensor)

    # 规范化嵌入
    query_embedding = query_embedding / query_embedding.norm(dim=1, keepdim=True)
    other_embeddings = other_embeddings / other_embeddings.norm(dim=1, keepdim=True)

    # 计算相似度
    similarities = torch.mm(query_embedding, other_embeddings.t())

    # 获取相似度最高的5个索引
    indexs = similarities[0].topk(5).indices.cpu().numpy().tolist()

### TODO: TASK 4: 使用CLIP的image encoder，从other_images里检索和image最相似的5张图像 (END)

plt.figure(figsize=(15,15))
for i,img_idx in enumerate(indexs):
    plt.subplot(1,5,i+1)
    plt.imshow(other_images[img_idx].permute(1,2,0))
    plt.title(other_labels[img_idx])
    plt.axis('off')
plt.savefig(f"output/similarity{label}.pdf")
plt.show()

# 在文件末尾添加

'''
3、在整个MNIST数据集上评估CLIP模型性能
'''
print("\n在整个MNIST数据集上评估CLIP模型性能:")

# 使用tqdm创建进度条(如果没有安装可以使用pip install tqdm安装)
from tqdm import tqdm

correct = 0
total = 0
class_correct = [0] * 10
class_total = [0] * 10

with torch.no_grad():
    # 为每个数字(0-9)生成文本嵌入(只需计算一次)
    all_text_embeddings = model.text_enc(torch.arange(10).to(DEVICE))
    all_text_embeddings = all_text_embeddings / all_text_embeddings.norm(dim=1, keepdim=True)
    
    # 遍历数据集中的所有样本
    for i in tqdm(range(len(dataset))):
        image, label = dataset[i]
        
        # 获取图像嵌入
        image_embedding = model.img_enc(image.unsqueeze(0).to(DEVICE))
        image_embedding = image_embedding / image_embedding.norm(dim=1, keepdim=True)
        
        # 计算相似度
        similarity = torch.mm(image_embedding, all_text_embeddings.t())
        
        # 选择相似度最高的作为预测结果
        predicted_label = similarity.argmax(dim=1).item()
        
        # 统计正确预测的数量
        total += 1
        if predicted_label == label:
            correct += 1
            class_correct[label] += 1
        class_total[label] += 1

# 计算并打印总体准确率
accuracy = 100 * correct / total
print(f'模型在MNIST数据集上的总体准确率: {accuracy:.2f}%')
print(f'正确预测: {correct}/{total}')

# 打印每个类别的准确率
print("\n各数字类别的准确率:")
for i in range(10):
    class_acc = 100 * class_correct[i] / class_total[i]
    print(f'数字 {i}: {class_acc:.2f}% ({class_correct[i]}/{class_total[i]})')

  model.load_state_dict(torch.load('./model.pth'))


RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.