In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import hash_model as image_hash_model
import pickle

top_k = 1000
batch_size = 10
epochs = 150
learning_rate = 0.001 #0.05
weight_decay = 10 ** -5

alpha = 0.05
beta = 0.01
lamda = 0.01 #50
gamma = 0.2
sigma = 0.2
code_length = 64
# 1. 加载CIFAR-10数据集
transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


hash_bits = 64 # 假设哈希编码的位数
model_name = "vgg11"
device = torch.device("mps")
model = image_hash_model.HASH_Net(model_name, hash_bits)
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.3, last_epoch=-1)

# 3. 加载哈希编码
with open('../labels/64_cifar10_10_class.pkl', 'rb') as f:
    label_code = torch.load(f)
label_code.to(device)

Files already downloaded and verified
Files already downloaded and verified


tensor([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,
          1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,
          1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,
          1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,
          1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,  1., -1.,
          1., -1.,  1., -1.,  1., -1.,  1., -1.],
        [ 1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,  1.,
         -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,
          1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,  1., -1., -1.,  1.,  1.,
         -1., -1.,  1.,  1., -1

In [30]:
# 5. 训练网络
for epoch in range(10):  # 循环遍历数据集多次
    scheduler.step()
    epoch_loss = 0.0
    epoch_loss_r = 0.0
    epoch_loss_e = 0.0
    for iter, traindata in enumerate(trainloader, 0):
        inputs, labels= traindata
        labels = torch.squeeze(labels)
        inputs =inputs.to(device)
        labels = labels.type(torch.FloatTensor).to(device)
        the_batch = 3
        hash_out = model(inputs)
        logit = hash_out.mm(label_code.t().to(device))
        print(label_code.size())
        print(logit.size())
        print(hash_out)

        print(logit)
        
        our_logit = torch.exp((logit - sigma * code_length) * gamma) * labels
        mu_logit = (torch.exp(logit * gamma) * (1 - labels)).sum(1).view(-1, 1).expand(the_batch, labels.size()[1]) + our_logit
        loss = - ((torch.log(our_logit / mu_logit + 1 - labels)).sum(1) / labels.sum(1)).sum()

        Bbatch = torch.sign(hash_out)
        regterm = (Bbatch - hash_out).pow(2).sum()
        loss_all = loss / the_batch + regterm * lamda / the_batch

        optimizer.zero_grad()
        loss_all.backward()
        optimizer.step()
        epoch_loss += loss_all.item()
        epoch_loss_e += loss.item() / the_batch
        epoch_loss_r += regterm.item() / the_batch
    print('[Train Phase][Epoch: %3d/%3d][Loss_i: %3.5f, Loss_e: %3.5f, Loss_r: %3.5f]' %
            (epoch + 1, epochs, epoch_loss / len(trainloader), epoch_loss_e / len(trainloader),
            epoch_loss_r / len(trainloader)))


print('Finished Training')

torch.Size([10, 64])
torch.Size([10, 10])
tensor([[ 5.6237e-01, -3.9726e-01,  5.2532e-01,  1.7990e-01,  6.6133e-01,
          3.5797e-01,  3.8393e-01, -2.1520e-01, -4.8338e-01,  3.0614e-01,
         -8.8840e-02,  2.9929e-01, -3.4418e-01,  2.8141e-01, -2.1158e-02,
          3.3469e-01,  2.8631e-01, -4.8569e-01, -8.2876e-01,  1.9018e-01,
          2.0906e-01, -4.7060e-01,  1.2424e-01,  2.0507e-01, -9.0063e-02,
          2.0617e-01, -2.6634e-01, -3.0706e-01, -2.7210e-01, -1.7099e-01,
          3.7904e-01,  5.7409e-01, -5.4631e-01,  3.4046e-02,  2.4665e-01,
         -4.1593e-01,  1.5301e-01, -6.8084e-02, -7.7951e-01,  2.0936e-02,
         -1.2844e-01, -1.2406e-01, -3.2081e-01, -2.4034e-01, -3.8700e-01,
         -1.5146e-03,  1.3396e-01,  5.2997e-01, -5.4221e-01,  1.8010e-01,
          3.5549e-01,  2.2619e-01, -4.5382e-01,  5.9925e-01, -2.0092e-01,
         -1.6110e-01, -5.1936e-01, -4.0850e-01,  1.8207e-01, -7.8246e-01,
         -2.3313e-01,  7.5020e-01, -1.1187e-01, -3.7288e-02],
        

IndexError: tuple index out of range