In [None]:
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as util_data
import time
import numpy as np
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, models, datasets
from tools import *
from make_dataset import *
from ST_tools import *

In [None]:
# Number of target classes. Used as the one-hot label length in MyDataset and
# stored as config["n_class"].
# NOTE: `xxxxxx` is a placeholder and must be replaced by an integer before running.
num_cls=xxxxxx
# Hash code length in bits; passed to the model as `hash_bit` and to DSHLoss.
# NOTE: `xxxxx` is a placeholder and must be replaced by an integer before running.
bit=xxxxx

In [None]:
class PENet(nn.Module):
    """Staged window-attention backbone that ends in a linear hashing layer.

    The network embeds the input into patches, runs it through
    ``len(depths)`` ``BasicLayer`` stages (every stage but the last is given
    ``PatchMerging`` as its downsample module), normalises, average-pools over
    the token axis and projects the pooled features to ``hash_bit`` outputs.

    ``PatchEmbed``, ``BasicLayer`` and ``PatchMerging`` are not defined in this
    cell; presumably they come from one of the notebook's star imports.

    Args:
        hash_bit: Output width of the final ``nn.Linear`` hash layer.
        patch_size: Forwarded to ``PatchEmbed``.
        in_chans: Forwarded to ``PatchEmbed`` as ``in_c``.
        num_classes: Stored on the instance only; no classification head is
            built from it.
        embed_dim: Channel width of the first stage; stage ``i`` uses
            ``embed_dim * 2**i``.
        depths: Number of blocks per stage.
        num_heads: Attention heads per stage.
        window_size: Forwarded to every ``BasicLayer``.
        mlp_ratio: Forwarded to every ``BasicLayer``.
        qkv_bias: Forwarded to every ``BasicLayer``.
        drop_rate: Dropout probability after patch embedding and inside stages.
        attn_drop_rate: Forwarded to every ``BasicLayer`` as ``attn_drop``.
        drop_path_rate: Upper end of the linearly increasing drop-path schedule.
        norm_layer: Normalisation layer class.
        patch_norm: When true, ``PatchEmbed`` receives ``norm_layer``; otherwise
            it receives ``None``.
        use_checkpoint: Forwarded to every ``BasicLayer``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 hash_bit,
                 patch_size=4,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True,
                 use_checkpoint=False,
                 **kwargs):
        super().__init__()

        stage_count = len(depths)
        self.num_classes = num_classes
        self.num_layers = stage_count
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # The channel width doubles from one stage to the next, so the last
        # stage is 2**(stage_count - 1) times wider than the first.
        self.num_features = int(embed_dim * 2**(stage_count - 1))
        self.mlp_ratio = mlp_ratio

        embed_norm = norm_layer if patch_norm else None
        self.patch_embed = PatchEmbed(patch_size=patch_size,
                                      in_c=in_chans,
                                      embed_dim=embed_dim,
                                      norm_layer=embed_norm)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: one drop-path rate per block, rising
        # linearly from 0 to drop_path_rate over all blocks of all stages.
        drop_path_schedule = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        self.layers = nn.ModuleList()
        first_block = 0  # index of this stage's first block in the schedule
        for stage_idx, stage_depth in enumerate(depths):
            is_last_stage = stage_idx == stage_count - 1
            stage = BasicLayer(
                dim=int(embed_dim * 2**stage_idx),
                depth=stage_depth,
                num_heads=num_heads[stage_idx],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=drop_path_schedule[first_block:first_block + stage_depth],
                norm_layer=norm_layer,
                downsample=None if is_last_stage else PatchMerging,
                use_checkpoint=use_checkpoint)
            self.layers.append(stage)
            first_block += stage_depth

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.hash = nn.Linear(self.num_features, hash_bit)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one sub-module; called for every module via ``self.apply``.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and a zero bias
        when a bias exists; ``nn.LayerNorm`` gets bias 0 and weight 1. Any other
        module type is left untouched.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """Map an input batch to real-valued hash logits.

        Args:
            x: Input accepted by ``self.patch_embed`` (assumed to be an image
                batch of shape (B, C, H, W) -- TODO confirm against PatchEmbed).

        Returns:
            Output of the hash layer; its last dimension is ``hash_bit``.
        """
        tokens, H, W = self.patch_embed(x)
        tokens = self.pos_drop(tokens)
        for stage in self.layers:
            tokens, H, W = stage(tokens, H, W)
        tokens = self.norm(tokens)
        # Swap the last two axes so the 1-D pooling runs over the token axis
        # (assumes tokens are laid out as (B, L, C) -- TODO confirm).
        pooled = self.avgpool(tokens.transpose(1, 2))
        features = torch.flatten(pooled, 1)
        return self.hash(features)

In [None]:
class MyDataset(ImageFolder):
    """Image-folder dataset whose items are ``(image, one_hot_label, index)``.

    The sample index is returned as a third element so the loss can address
    its per-sample memory. The one-hot length is the notebook-level ``num_cls``.
    """

    def __getitem__(self, index):
        """Load sample ``index`` and return it with a one-hot label and its index."""
        img_path, class_idx = self.samples[index]
        image = self.loader(img_path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            class_idx = self.target_transform(class_idx)
        # Select the row of the identity matrix that matches the class index.
        one_hot = np.eye(num_cls, dtype=np.int8)[np.array(class_idx)]

        return image, one_hot, index

In [None]:
def get_data(config):
    """Build the datasets and data loaders for the three splits.

    Each split is read from ``<config["data_dir"]>/<split>`` with the matching
    entry of the notebook-level ``data_transforms``.

    Args:
        config: Dict with ``"data_dir"`` and, optionally, ``"batch_size"``
            (falls back to 12, the value that used to be hard-coded here).

    Returns:
        A 6-tuple ``(train_loader, test_loader, valid_loader, n_train, n_test,
        n_valid)`` -- note that "test" comes before "valid".
    """
    batch_size = config.get("batch_size", 12)
    image_datasets = {}
    dataloaders = {}
    for split in ("train", "test", "valid"):
        image_datasets[split] = MyDataset(os.path.join(config["data_dir"], split),
                                          transform=data_transforms[split])
        # Only the training split is shuffled. The evaluation splits keep the
        # dataset order so codes computed from them line up with
        # `dataset.samples` and are reproducible between runs.
        dataloaders[split] = util_data.DataLoader(image_datasets[split],
                                                  batch_size=batch_size,
                                                  shuffle=(split == "train"))
    return (dataloaders["train"], dataloaders["test"], dataloaders["valid"],
            len(image_datasets["train"]), len(image_datasets["test"]), len(image_datasets["valid"]))


In [None]:
def get_config():
    """Return the experiment configuration dictionary.

    Values written as ``xxxx`` / ``"xxxxxx"`` are placeholders that must be
    filled in before the notebook can run.

    Returns:
        dict: A fresh configuration dict on every call.
    """
    config = {
        # Weight of the quantization term in DSHLoss.
        "alpha": 0.1,
        # Optimizer class plus the keyword arguments it is constructed with.
        "optimizer": {"type": optim.AdamW, "optim_params": {"lr": xxxx, "weight_decay": 5E-2}, "lr_type": "step"},
        # Number of training epochs.
        "epoch": 50,
        # Evaluate mAP every `test_map` epochs.
        "test_map": 1,
        # Directory that receives the best checkpoint and database codes.
        "save_path": "xxxxxx",
        "device": torch.device("cuda:0"),
        # Passed to CalcTopMap as its top-K argument.
        "topK": -1,
        "n_class":num_cls,
        # Root directory containing one sub-folder per split.
        "data_dir" : "xxxxxx",
        # The key used to be " batch_size" (leading space), which no lookup of
        # "batch_size" could ever match.
        "batch_size": 12,
    }
    return config


In [None]:
class DSHLoss(torch.nn.Module):
    """Pairwise contrastive hashing loss with a quantization regularizer.

    The module keeps a memory of the latest network output (``U``) and label
    vector (``Y``) of every training sample, so each batch is contrasted
    against the whole training set.

    Args:
        config: Dict providing ``"num_train"``, ``"n_class"`` and ``"device"``.
        bit: Hash code length; the dissimilar-pair margin is ``2 * bit``.
    """

    def __init__(self, config, bit):
        super(DSHLoss, self).__init__()
        self.m = 2 * bit
        self.U = torch.zeros(config["num_train"], bit).float().to(config["device"])
        self.Y = torch.zeros(config["num_train"], config["n_class"]).float().to(config["device"])

    def forward(self, u, y, ind, config):
        """Compute the loss for one batch.

        Args:
            u: Network outputs for the batch, one row per sample.
            y: Label vectors for the batch (one-hot / multi-hot, float).
            ind: Indices of the batch samples in the training set; used to
                update the stored outputs and labels.
            config: Dict providing ``"alpha"``, the regularizer weight.

        Returns:
            Scalar tensor: mean pairwise loss plus the weighted regularizer.
        """
        # Refresh the memory with detached outputs so gradients only flow
        # through the current batch.
        self.U[ind, :] = u.data
        self.Y[ind, :] = y.float()

        # Squared Euclidean distance between every batch row and every stored row.
        dist = (u.unsqueeze(1) - self.U.unsqueeze(0)).pow(2).sum(dim=2)
        # 1 where a pair shares no label (dissimilar), 0 where it shares one.
        dissimilar = (y @ self.Y.t() == 0).float()

        # Similar pairs are pulled together; dissimilar pairs are pushed apart
        # until their distance reaches the margin.
        loss = (1 - dissimilar) / 2 * dist + dissimilar / 2 * (self.m - dist).clamp(min=0)
        loss1 = loss.mean()
        # Quantization term | |u| - 1 |, which drives every output towards +1
        # or -1. The earlier form `(1 - u.sign()).abs()` had zero gradient
        # everywhere and charged a constant penalty of 2 to exact -1 codes.
        loss2 = config["alpha"] * (1 - u.abs()).abs().mean()

        return loss1 + loss2

In [None]:
# Per-split preprocessing pipelines, keyed by the split / sub-folder name.
# Every pipeline ends with the same per-channel normalisation (the widely used
# ImageNet mean / std values).
data_transforms = {
    # Training: RandomRotation is given a range of 0 degrees, followed by a
    # random resized crop to 224 pixels.
    'train': transforms.Compose([
        transforms.RandomRotation(0),
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    # 'valid' and 'test' share one recipe: a 224-pixel centre crop with no
    # random augmentation. Each split still gets its own Compose instance.
    **{
        split_name: transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        for split_name in ('valid', 'test')
    },
}

In [None]:
def get_model(hash_bit,num_classes, **kwargs):
    """Build a PENet and initialise it from a pretrained checkpoint.

    Args:
        hash_bit: Width of the hash layer.
        num_classes: Forwarded to ``PENet``.
        **kwargs: Extra keyword arguments forwarded to ``PENet``.

    Returns:
        The model, on the CPU; the caller moves it to the target device.
    """
    # trained ImageNet-1K
    # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth
    model = PENet(in_chans=3,
                  hash_bit=hash_bit,
                  patch_size=4,
                  window_size=7,
                  embed_dim=96,
                  depths=(2, 2, 6, 2),
                  num_heads=(3, 6, 12, 24),
                  num_classes=num_classes,
                  **kwargs)
    # Load onto the CPU so this function does not depend on a notebook-global
    # `device`; load_state_dict copies the values into the model's parameters.
    weights_dict = torch.load('./xxxxxxx.pth', map_location="cpu")['model']
    # Drop only the classification-head weights. The loop used to delete every
    # key, so no pretrained weight was ever loaded.
    for k in list(weights_dict.keys()):
        if "head" in k:
            del weights_dict[k]
    # strict=False: the hash layer has no counterpart in the checkpoint and
    # keeps its fresh initialisation.
    model.load_state_dict(weights_dict, strict=False)
    return model



In [None]:
# Build the experiment configuration once; later cells read it and add to it
# (the training cell stores "num_train" in it).
config = get_config()
# Echo the configuration so the run's settings are recorded in the notebook output.
print(config)

In [None]:
# Device the model and the batches are moved to, taken from the configuration.
device = config["device"]
# Build the hashing network with `bit` outputs and move it to that device.
net = get_model(hash_bit=bit,num_classes=num_cls).to(device)

## model training

In [None]:
# Build the optimizer from the class and keyword arguments stored in the config.
optimizer = config["optimizer"]["type"](net.parameters(), **(config["optimizer"]["optim_params"]))
# get_data returns the splits in the order (train, test, valid): the "test"
# folder is used as the retrieval database and "valid" as the query set.
train_loader,  dataset_loader,test_loader, num_train, num_dataset, num_test = get_data(config)
# DSHLoss sizes its per-sample memory from config["num_train"].
config["num_train"] = num_train
criterion = DSHLoss(config, bit)
Best_mAP = 0
for epoch in range(config["epoch"]):
    net.train()
    train_loss = 0
    for image, label, ind in train_loader:
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        u = net(image)

        # `ind` tells the loss which rows of its memory belong to this batch.
        loss = criterion(u, label.float(), ind, config)
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Mean loss per batch for this epoch.
    train_loss = train_loss / len(train_loader)

    print("\b\b\b\b\b\b\b loss:%.3f" % (train_loss))

    if (epoch + 1) % config["test_map"] == 0:
        # NOTE(review): `.numpy()` is called on both results below, so this cell
        # expects a compute_result that returns tensors (presumably the one from
        # the `tools` star import). The compute_result defined in the
        # img_to_hash section returns a NumPy array of codes; re-running this
        # cell after that definition would fail.
        # Evaluation needs no gradients, so skip building autograd graphs.
        with torch.no_grad():
            tst_binary, tst_label = compute_result(test_loader, net, device=device)
            trn_binary, trn_label = compute_result(dataset_loader, net, device=device)
        mAP = CalcTopMap(trn_binary.numpy(), tst_binary.numpy(), trn_label.numpy(), tst_label.numpy(),
                         config["topK"])

        if mAP > Best_mAP:
            Best_mAP = mAP
            if "save_path" in config:
                os.makedirs(config["save_path"], exist_ok=True)
                print("save in ", config["save_path"])
                # Both artifacts share the prefix "<name><bit>-<mAP>-". The code
                # file used to fuse bit and mAP ("64" + "0.85" -> "640.85").
                artifact_prefix = "xxxxx" + str(bit) + "-" + str(mAP) + "-"
                np.save(os.path.join(config["save_path"], artifact_prefix + "trn_binary.npy"),
                        trn_binary.numpy())
                torch.save(net.state_dict(),
                           os.path.join(config["save_path"], artifact_prefix + "model.pt"))
        print(" epoch:%d, bit:%d, MAP:%.3f, Best MAP: %.3f" % (
            epoch + 1, bit, mAP, Best_mAP))


## img_to_hash

In [None]:
# Dataset and loader for the images to be encoded into hash codes.
# The deterministic "test" pipeline is used here: the "train" pipeline applies
# a random resized crop, which would make the codes differ from run to run.
image_dataset_test = MyDataset('./path',
                               transform=data_transforms["test"])
# shuffle=False keeps the codes in the same order as `image_dataset_test.samples`.
test_dataloaders = util_data.DataLoader(image_dataset_test,
                                        batch_size=12,
                                        shuffle=False)

In [None]:
def compute_result(dataloader, net, device):
    """Encode every sample of ``dataloader`` into a 0/1 hash code.

    NOTE(review): this definition replaces any ``compute_result`` already in
    scope and returns the codes as a NumPy array in {0, 1}; the training cell
    calls ``.numpy()`` on the result and therefore expects a tensor.

    Args:
        dataloader: Iterable yielding ``(image_batch, label_batch, index_batch)``.
        net: Model mapping an image batch to real-valued hash outputs.
        device: Device the image batches are moved to before the forward pass.

    Returns:
        tuple: ``(hash_codes, labels)``. ``hash_codes`` is a NumPy array with
        one row per sample, 1 where the output is positive and 0 otherwise;
        ``labels`` is the concatenation of all label batches as a tensor.
    """
    outputs, labels = [], []
    net.eval()
    # Inference only: disabling autograd avoids keeping a graph per batch.
    with torch.no_grad():
        for img, cls, _ in tqdm(dataloader):
            labels.append(cls)
            outputs.append(net(img.to(device)).detach().cpu())
    hash_codes = torch.cat(outputs).sign().numpy()
    # sign() yields -1 / 0 / +1; fold the negatives into 0 to get bit values.
    hash_codes[hash_codes < 0] = 0
    return hash_codes, torch.cat(labels)

In [None]:
# Encode the images of `test_dataloaders` with the current `net`:
# `hash_codes` holds one 0/1 code per image, `label` the matching label vectors.
hash_codes, label = compute_result(test_dataloaders, net, device=device)