In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import math
import os
import time

from AliasMethod import AliasMethod

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0
start_epoch = 0
low_dim = 128
nce_k = 1024 # defult 4096
nce_t = 0.5 # 温度
nce_m = 0.5 # SGD 动量参数
learning_rate = 0.001
batch_size = 64
num_workers = 0
num_epochs = 200

In [3]:
class CIFAR10Instance(torchvision.datasets.CIFAR10):
    """CIFAR10Instance Dataset."""
    def __getitem__(self, index):
        if self.train:
            img, target = self.data[index], self.targets[index]
        else:
            img, target = self.data[index], self.targets[index]

        img = Image.fromarray(img) # 使用 Image.fromarray(img) 将从数据集中获取的 NumPy 数组（代表图像数据）转换为 PIL.Image.Image 对象，以便后续的图像处理操作

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None: # 如果 self.target_transform 不为 None，则对标签应用相应的转换。这可以用于标签的转换，例如将标签从一个格式转换到另一个格式
            target = self.target_transform(target)

        return img, target, index

In [4]:
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    # transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# d2l-zh/pytorch/MyExercises/data
trainset = CIFAR10Instance(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

testset = CIFAR10Instance(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=num_workers)

ndata = len(trainset)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [5]:
import torch
from torch.autograd import Function
from torch import nn
# from .alias_multinomial import AliasMethod
import math

class NCEFunction(Function):
    @staticmethod
    def forward(self, x, y, memory, idx, params): # features.shape=(batch_size, feature_size), indexes.shape=torch.Size([128]), memory bank, idx.shape(128,4097), params
        K = int(params[0].item())  # 负样本的数量
        T = params[1].item()  # 温度参数
        Z = params[2].item()  # 归一化常数
        momentum = params[3].item()  # 动量参数
        batchSize = x.size(0)  # 当前批次的大小
        outputSize = memory.size(0)  # memory bank 的大小
        inputSize = memory.size(1)  # 输入特征的维度

        # sample positives & negatives
        idx.select(1,0).copy_(y.data) # 将正样本索引放入 idx 的第一列，这样idx的第一列就都是正样本，后面4096个都是负样本

        # 采样相应的特征向量（正样本和负样本）
        weight = torch.index_select(memory, 0, idx.view(-1)) # 从memory中提取出所有第0维（行），再按照idx.view(-1)的索引提取出指定的行 - (len(idx.view(-1)), inputSize)
        # 感觉也不能说是weights吧，许多images的特征，用features似乎更合适一点
        weight.resize_(batchSize, K+1, inputSize) # resize成跟idx一样的shape - (batchSize, K+1, inputSize) - (128, 4097, 128)

        # inner product
        out = torch.bmm(weight, x.data.resize_(batchSize, inputSize, 1)) # x.shape=(128,128), 所以这里的out就是(128, 4097, 1)，每个正样本和4097个样本（一正4096副）的内积
        # 非参数softmax
        # print("1", out)
        out.div_(T).exp_() # batchSize * self.K+1
        # print("2", out)
        x.data.resize_(batchSize, inputSize)

        if Z < 0: # Z 的设置: 如果 Z 小于 0，表示还没有初始化，因此在第一次计算时设置 Z 的值 - Z就是非参数softmax的分母
            params[2] = out.mean() * outputSize # out.mean()生成单个数
            Z = params[2].item() 
            print("normalization constant Z is set to {:.1f}".format(Z))

        out.div_(Z).resize_(batchSize, K+1) # 从(batchSize, K+1, 1) -> (batchSize, K+1)

        # 保存用于反向传播的张量
        self.save_for_backward(x, memory, y, weight, out, params)

        return out

    @staticmethod
    def backward(self, gradOutput):
        x, memory, y, weight, out, params = self.saved_tensors
        K = int(params[0].item())
        T = params[1].item()
        Z = params[2].item()
        momentum = params[3].item()
        batchSize = gradOutput.size(0)
        
        # gradients d Pm / d linear = exp(linear) / Z????????? # 此时开始更新特征v，原文中equation2
        gradOutput.data.mul_(out.data) # out.shape=(batchSize, K+1) # 应该是有一个近似，具体看平板
        # add temperature # d exp(linear) / d (v_i)^T v = (1/T) * exp((v_i)^T v)
        gradOutput.data.div_(T)

        gradOutput.data.resize_(batchSize, 1, K+1)
        
        # gradient of linear 
        # print(gradOutput.shape, weight.shape) # torch.Size([128, 4097]) torch.Size([128, 4097, 128])
        gradOutput = gradOutput.reshape(batchSize, 1, K+1) # 这一步源代码中没有
        # d exp((v_i)^T v) / d v_i = exp((v_i)^T v) * v 吗？？？？？？？？？
        # \exp((v_i)^T v) \) 对 \( v_i \) 的导数是 exp(v_i^T v) \cdot v - GPT
        # v就是weight - 果然有近似
        gradInput = torch.bmm(gradOutput.data, weight) # (batchSize, 1, K+1) mm (batchSize, K+1, inputSize) -> (batchSize, 1, inputSize)
        gradInput.resize_as_(x) # x.shape=(batch_size, feature_size=inputSize)

        # update the non-parametric data - 更新memory bank中的特征v
        # weight.shape =(batchSize, K+1, inputSize)
        weight_pos = weight.select(1, 0).resize_as_(x) # 见下面的test - (batchSize, inputSize) 及所有批次的第一个样本的特征，及所有正样本的特征
        weight_pos.mul_(momentum)
        weight_pos.add_(torch.mul(x.data, 1-momentum)) # v_i' = \text{momentum} \cdot v_i + (1 - \text{momentum}) \cdot x
        w_norm = weight_pos.pow(2).sum(1, keepdim=True).pow(0.5)
        updated_weight = weight_pos.div(w_norm) # 将更新后的向量归一化，确保它的模为 1
        memory.index_copy_(0, y, updated_weight)
        
        return gradInput, None, None, None, None

In [6]:
class NCEAverage(nn.Module):

    def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5, Z=None):
        super(NCEAverage, self).__init__()
        self.nLem = outputSize
        self.unigrams = torch.ones(self.nLem)
        self.multinomial = AliasMethod(self.unigrams)
        self.multinomial.cuda()
        self.K = K

        self.register_buffer('params',torch.tensor([K, T, -1, momentum]));
        stdv = 1. / math.sqrt(inputSize/3)
        self.register_buffer('memory', torch.rand(outputSize, inputSize).mul_(2*stdv).add_(-stdv))
 
    def forward(self, x, y): # lemniscate(features, indexes); x - (batch_size, feature_size); y - (batch_size)
        batchSize = x.size(0)
        idx = self.multinomial.draw(batchSize * (self.K+1)).view(batchSize, -1)
        # print("!!!", idx.shape, idx) # torch.Size([128, 4097])
        out = NCEFunction.apply(x, y, self.memory, idx, self.params)
        return out

In [7]:
class NCECriterion(nn.Module):
    def __init__(self, nLem):
        super(NCECriterion, self).__init__()
        self.nLem = nLem

    def forward(self, x, targets):
        batchSize = x.size(0)
        K = x.size(1) - 1
        Pnt = 1 / float(self.nLem)
        Pns = 1 / float(self.nLem)

        Pmt = x.select(1, 0)
        Pmt_div = Pmt.add(K * Pnt + 1e-7)
        lnPmt = torch.div(Pmt, Pmt_div)

        Pon_div = x.narrow(1, 1, K).add(K * Pns + 1e-7)
        # print("###########################", Pon_div)
        Pon = Pon_div.clone().fill_(K * Pns)
        lnPon = torch.div(Pon, Pon_div)

        lnPmt.log_()
        lnPon.log_()

        lnPmtsum = lnPmt.sum(0)
        lnPonsum = lnPon.view(-1, 1).sum(0)

        ######################################################
        # print("###", lnPmtsum.item(), "###", lnPonsum.item(), "###",)
        loss = - (lnPmtsum + lnPonsum) / batchSize
        return loss

In [8]:
print('==> Building model..')
net = torchvision.models.resnet18(num_classes=low_dim)  # Using ResNet18 with the low_dim output
lemniscate = NCEAverage(low_dim, ndata, nce_k, nce_t, nce_m)

# if device == 'cuda':
#     net = nn.DataParallel(net).to(device)
#     cudnn.benchmark = True
net.to(device)
lemniscate.to(device)

criterion = NCECriterion(ndata).to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
# optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=5e-4)

==> Building model..


In [9]:
def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 10 every 80 epochs"""
    lr = learning_rate
    if epoch >= 80:
        lr *= 0.1 ** ((epoch - 80) // 40)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [10]:
from save_load_checkpoint import save_checkpoint

def train(epoch, model_path, optimizer_path, best_loss=None):
    print(f'\nEpoch: {epoch}')
    adjust_learning_rate(optimizer, epoch)
    net.train()
    
    train_loss = 0
    for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
        inputs, targets, indexes = inputs.to(device), targets.to(device), indexes.to(device)
        optimizer.zero_grad()
        
        features = net(inputs)
        outputs = lemniscate(features, indexes)
        loss = criterion(outputs, indexes)
        
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
        # 每 100 个 batch 打印一次损失
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(trainloader)}: Loss: {loss.item():.4f}')
    
    # 计算当前 epoch 的平均损失
    avg_loss = train_loss / len(trainloader)
    
    # 如果当前损失低于最小损失，则保存模型
    if best_loss is None or avg_loss < best_loss:
        best_loss = avg_loss
        save_checkpoint(net, optimizer, epoch, model_path, optimizer_path)
        print(f'New best loss: {best_loss:.4f} - Model saved.')
    else:
        print(f'Epoch {epoch} completed with loss: {avg_loss:.4f}, no improvement from best loss: {best_loss:.4f}.')

In [11]:
def test(epoch):
    global best_acc
    net.eval()
    acc = kNN(epoch, net, lemniscate, trainloader, testloader, 200, nce_t, 0)
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'lemniscate': lemniscate,
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    print(f'Best Accuracy: {best_acc * 100:.2f}%')

In [12]:
best_loss = None
model_path = 'checkpoints/cifar10/cifar10.pth'
optimizer_path = 'checkpoints/cifar10/cifar10.pth'

for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch, model_path, optimizer_path, best_loss)
    # test(epoch)


Epoch: 0
normalization constant Z is set to 296483.4


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Batch 0/782: Loss: 9.5734
Batch 100/782: Loss: 9.1923
Batch 200/782: Loss: 8.7931
Batch 300/782: Loss: 9.0037
Batch 400/782: Loss: 9.0685
Batch 500/782: Loss: 8.8031
Batch 600/782: Loss: 9.0971
Batch 700/782: Loss: 8.8185


  return F.conv2d(input, weight, bias, self.stride,


Model and optimizer states saved to checkpoints/cifar10/cifar10.pth and checkpoints/cifar10/cifar10.pth.
New best loss: 9.0421 - Model saved.

Epoch: 1
Batch 0/782: Loss: 8.7285
Batch 100/782: Loss: 8.1204
Batch 200/782: Loss: 8.2326
Batch 300/782: Loss: 7.8456
Batch 400/782: Loss: 8.3071
Batch 500/782: Loss: 7.9908
Batch 600/782: Loss: 8.0225
Batch 700/782: Loss: 8.1245
Model and optimizer states saved to checkpoints/cifar10/cifar10.pth and checkpoints/cifar10/cifar10.pth.
New best loss: 8.0599 - Model saved.

Epoch: 2
Batch 0/782: Loss: 6.6091
Batch 100/782: Loss: 5.7945
Batch 200/782: Loss: 6.2444
Batch 300/782: Loss: 6.6357
Batch 400/782: Loss: 6.4408
Batch 500/782: Loss: 5.9939
Batch 600/782: Loss: 6.5664
Batch 700/782: Loss: 6.6244
Model and optimizer states saved to checkpoints/cifar10/cifar10.pth and checkpoints/cifar10/cifar10.pth.
New best loss: 6.1298 - Model saved.

Epoch: 3
Batch 0/782: Loss: 5.0156
Batch 100/782: Loss: 5.8160



KeyboardInterrupt

