In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import hash_model as image_hash_model
import pickle
import time
import numpy as np



top_k = 1000
batch_size = 10
epochs = 150
lr = 0.001  # alternative value tried: 0.05
weight_decay = 10 ** -5

# Weight of the quantization loss term (mean((|output| - 1)^2)) in the training loss.
lambda1 = 0.01

# 1. Load the CIFAR-10 dataset.
# Images are resized to 256, center-cropped to 224 and normalized with the
# ImageNet mean/std (presumably because the backbone expects ImageNet-style input).
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# 2. Build the hashing model, optimizer and learning-rate schedule.
hash_bits = 64  # number of bits per hash code; presumably must match the label code file below
model_name = "vgg11"
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = image_hash_model.HASH_Net(model_name, hash_bits)
model.to(device)  # nn.Module.to moves parameters in place
# weight_decay is part of the configuration above, so actually apply it.
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
# Multiply the learning rate by 0.3 every 100 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.3, last_epoch=-1)

# 3. Load the per-class target hash codes.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load files from a trusted source.
with open('../labels/64_cifar10_10_class.pkl', 'rb') as f:
    label_hash_codes = torch.load(f)
# Tensor.to() returns a new tensor rather than moving in place, so reassign it.
label_hash_codes = label_hash_codes.to(device)

ModuleNotFoundError: No module named 'torch'

In [None]:
def test_accuracy(model, test_loader, label_hash_codes, device=device):
    model.eval()  # 设置模型为评估模式
    correct = 0
    total = 0

    with torch.no_grad():  # 禁用梯度计算
        for data, labels in test_loader:
            data = data.to(device)
            outputs = torch.sign(model(data)).to(device)  # 模型输出的哈希值 后续可以继续修改
            
            # 通过计算输出和每个类别哈希码之间的相似度来简化汉明距离的计算
            # 计算相似度
            similarities = torch.mm(outputs, label_hash_codes.t().to(device))
            
            # 相似度最高的类别即为预测类别
            _, predicted = similarities.max(dim=1)
            
            total += labels.size(0)
            correct += (predicted.to(device) == labels.to(device)).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

In [None]:

# Train the hashing model. Each sample's output is pulled towards its class's
# target hash code (BCE on values rescaled from [-1, 1] to [0, 1]), and a
# quantization penalty pushes output magnitudes towards 1.
criterion = nn.BCELoss().to(device)  # built once instead of on every iteration

start_time = time.time()
total_loss = []  # per-iteration loss values, across all epochs
for epoch in range(epochs):  # loop over the dataset multiple times
    if epoch % 5 == 0:
        accuracy = test_accuracy(model, testloader, label_hash_codes)
        print(f'Epoch {epoch}, Test Accuracy: {accuracy}%')
    # Evaluation may leave the model in eval mode; always train in train mode.
    model.train()
    current_lr = optimizer.param_groups[0]['lr']
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        hash_outputs = model(inputs.to(device))
        # Target hash code of each sample, looked up by its class index.
        cat_codes = label_hash_codes[labels.to(label_hash_codes.device)].to(device)
        # BCELoss needs values in [0, 1], so map outputs and targets from [-1, 1].
        # NOTE(review): assumes the model's outputs lie in [-1, 1] (e.g. tanh); confirm in hash_model.
        center_loss = criterion(0.5 * (hash_outputs + 1), 0.5 * (cat_codes + 1))
        # Quantization loss: penalize |output| deviating from 1.
        q_loss = torch.mean((torch.abs(hash_outputs) - 1.0) ** 2)
        loss = center_loss + lambda1 * q_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # plain Python float, detached from the graph
        total_loss.append(loss_value)
        if batch_idx % 100 == 0:
            elapsed = time.time() - start_time
            print('epoch: %d, lr: %.5f iter_num: %d, time: %.3f, loss: %.3f'
                  % (epoch, current_lr, batch_idx, elapsed, loss_value))
    # Advance the per-epoch LR schedule after the optimizer updates
    # (the order PyTorch expects since version 1.1).
    scheduler.step()

end_time = time.time()
epoch_loss = np.mean(total_loss)  # mean over every recorded iteration of all epochs
print('Finished Training')
print(f'Total time: {end_time - start_time:.1f}s, mean loss: {epoch_loss:.4f}')
accuracy = test_accuracy(model, testloader, label_hash_codes)
print(f'Final Test Accuracy: {accuracy}%')

Epoch 0, Test Accuracy: 12.93%
epoch: 0, lr: 0.00100 iter_num: 0, time: 71.707, loss: 0.758
epoch: 0, lr: 0.00100 iter_num: 100, time: 94.185, loss: 0.527
epoch: 0, lr: 0.00100 iter_num: 200, time: 116.296, loss: 0.468


KeyboardInterrupt: 