In [2]:
#!/usr/bin/env python
# coding=utf-8
"""
Author: Kitiro
Date: 2020-11-03 17:32:57
LastEditTime: 2020-11-06 00:04:17
LastEditors: Kitiro
Description: visual -> semantic 
FilePath: /zzc/exp/Hierarchically_Learning_The_Discriminative_Features_For_Zero_Shot_Learning/main.py
"""

import numpy as np
import torch
import os
import argparse
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
from torch.optim import lr_scheduler
from torchsummary import summary
import json
import random
import scipy.io as sio
import torch.nn.functional as F
# Author Defined
from model import MyModel, APYModel
from utils import (
    compute_accuracy,
    CenterLoss,
    set_seed,
    cal_mean_feature,
    weights_init,
    export_log,
    plot_img,
    EarlyStopping,
)
from dataset import DataSet
from sklearn import preprocessing as P

In [3]:

def compute_accuracy(model, test_att, test_visual, test_id, test_label, center=None):
    """[summary]

    Args:
        model ([type]): [description]
        test_att ([type]): semantic feature
        test_visual ([type]): visual feature
        test_id ([type]): att2label
        test_label ([type]): x2label
    """
    model.eval()
    with torch.no_grad():
        test_att = torch.tensor(test_att).unsqueeze(1).unsqueeze(1).float().cuda()
        attr_feature = model(test_att).cpu().detach().numpy()
        if center: centers = center.centers.cpu().detach().numpy()[test_id]
        outpred = np.array([0] * test_visual.shape[0])
        test_label = test_label.astype("float32")
        # 将类的属性映射到visual space. 得到att_pred。然后每一张图片去找距离最近的attr of unseen
        for i in range(test_visual.shape[0]):
            target = test_visual[i, :]
            subtract_feat = np.tile(target, (len(test_id), 1))
            dist = np.sum((subtract_feat - attr_feature) ** 2, axis=1) ** 0.5
            if center:
                dist_with_centers = np.sum((subtract_feat - centers) ** 2, axis=1) ** 0.5
                dist += 0.5*dist_with_centers

            outpred[i] = test_id[np.argsort(dist)[0]]

        acc = np.equal(outpred, test_label).mean()
        
    return acc

# visual -> att
def compute_accuracy2(model, test_att, test_visual, test_id, test_label, center=None):
    """[summary]

    Args:
        model ([type]): [description]
        test_att ([type]): semantic feature
        test_visual ([type]): visual feature
        test_id ([type]): att2label
        test_label ([type]): x2label
    """
    model.eval()
    with torch.no_grad():
        test_visual = torch.tensor(test_visual).view(-1, 1, 1, 2048).float().cuda()
        out_visual = model(test_visual).cpu().detach().numpy()
        if center: centers = center.centers.cpu().detach().numpy()[test_id]
        outpred = np.array([0] * test_visual.shape[0])
        test_label = test_label.astype("float32")
        
        for i in range(test_visual.shape[0]):
            target = out_visual[i, :]
            subtract_feat = np.tile(target, (len(test_id), 1))
            dist = np.sum((subtract_feat - test_att) ** 2, axis=1) ** 0.5
            if center:
                dist_with_centers = np.sum((subtract_feat - centers) ** 2, axis=1) ** 0.5
                dist += 0.5*dist_with_centers

            outpred[i] = test_id[np.argsort(dist)[0]]

        acc = np.equal(outpred, test_label).mean()
        
    return acc

In [4]:
# load train data        
data_dir = 'data/AWA2_data/'
mat_visual = sio.loadmat(os.path.join(data_dir, "res101.mat"))
features, labels = mat_visual[
    'features'].T, mat_visual['labels'].astype(int).squeeze() - 1

mat_semantic = sio.loadmat(
    os.path.join(data_dir, "att_splits.mat"))
trainval_loc = mat_semantic['trainval_loc'].squeeze() - 1
test_seen_loc = mat_semantic['test_seen_loc'].squeeze() - 1
test_unseen_loc = mat_semantic['test_unseen_loc'].squeeze() - 1

attribute1 = mat_semantic['att'].T  # manual labeld attributes

attribute = attribute1

train_feature = features[trainval_loc]  # feature
train_label = labels[trainval_loc].astype(
    int)  # 23527 training samples。
train_att = attribute[train_label]  # 23527*85
train_label_unique = np.unique(train_label)

test_feature_unseen = features[test_unseen_loc]  # 7913 测试集中的未见类
test_label_unseen = labels[test_unseen_loc].astype(int)

test_feature_seen = features[test_seen_loc]  # 5882  测试集中的已见类
test_label_seen = labels[test_seen_loc].astype(int)

test_id_unseen = np.unique(test_label_unseen)
test_att_map_unseen = attribute[test_id_unseen]
class_num = 50

In [9]:
# 生成属性为200时得到的在视觉空间的映射
attr_num = 200
g_a = np.load('generate_attributes/generated_attributes_glove/class_attribute_map_AWA2.npy')
g_attribute = np.hstack((P.scale(attribute, axis=1), P.scale(g_a[:,:attr_num], axis=1)))


In [5]:
from sklearn import preprocessing as P

dataset = DataSet(
    name='AWA2',
    generative_attribute_num=0,
    norm=True,
)
    
dataset_train = TensorDataset(
    torch.from_numpy(dataset.train_feature),
    torch.from_numpy(dataset.train_att),
    torch.from_numpy(dataset.train_label),
)

train_loader = DataLoader(
    dataset=dataset_train, batch_size=100, shuffle=True, num_workers=0
)


In [6]:
class DemModel(nn.Module):
    def __init__(self, attr_dim, output_dim):
        super(MyModel, self).__init__()
        self.attr_dim = attr_dim
        self.output_dim = output_dim
        
        channel_num = 5
        self.conv1 = nn.Conv2d(1, 50, kernel_size=1)
        self.conv2 = nn.Conv2d(1, channel_num, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(channel_num)
        self.conv_out_dim = 50 * (attr_dim)* channel_num
        self.fc = nn.Linear(self.conv_out_dim, self.output_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # 
        x = x.view(-1, 1, 50, self.attr_dim)  # batch, channel, height , width
        x = F.relu(self.bn1(self.conv2(x)))
        #x = x.view(-1, 50, 50, self.attr_dim)  
        #x = torch.mean(x, dim=1) # average the feature map

        x = x.view(-1, self.conv_out_dim)  # flatten

        x = self.fc(x)

        return x    

# MED Model : visual -> att. 2048 -> 85 on AwA2
class MedModel(nn.Module):
    def __init__(self, attr_dim, visual_dim):
        super(MedModel, self).__init__()
        self.in_dim = visual_dim
        self.output_dim = attr_dim
        
        self.expand_num = 5
        self.conv1 = nn.Conv2d(1, 5, kernel_size=1)
        self.conv2 = nn.Conv2d(1, 1, kernel_size=3, stride=1)

        self.conv_out_dim = 6138
        self.fc1 = nn.Linear(self.conv_out_dim, 1024)
        self.fc2 = nn.Linear(1024, self.output_dim)

    def forward(self, x):
        x = F.relu(self.conv1(x))  # ret:10*2048
        x = x.view(-1, 1, self.expand_num, self.in_dim)  # batch, channel, height , width
        x = F.relu(self.conv2(x))
        #x = x.view(-1, 50, 50, self.attr_dim)  
        #x = torch.mean(x, dim=1) # average the feature map

        x = x.view(-1, self.conv_out_dim)  # flatten

        x = self.fc2(F.relu(self.fc1(x)))

        return x    
    

In [7]:
model = MedModel(attribute.shape[-1], features.shape[-1])

model.cuda()
model.apply(weights_init)

summary(model, input_size=(1, 1, features.shape[-1]))

mse_loss = nn.MSELoss()

best_zsl = 0.0
best_h = 0.0
h_line = ""
patience = 7
delta = 0.1

# hyperparameters
#beta = 1e-5 # for center_loss. 0.003 in paper
alpha = 0.5 # for linear transferred
beta = 1e-5
lr = 1e-5

# opt = torch.optim.Adam(
#     model.parameters(), lr=lr, weight_decay=1e-2
# )  
# optimizer_centloss = torch.optim.Adam(center_loss.parameters(), lr=0.5)

opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-04)

scheduler = lr_scheduler.MultiStepLR(
    opt, milestones=list(range(20, 300, 20)), gamma=0.9
)

history = {
    "Acc_ZSL": [],
    "Loss_attr": [],
    "Loss_center": [],
    "Loss_Total": [],
}


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 10, 1, 2048]              20
            Conv2d-2           [-1, 5, 8, 2046]              50
       BatchNorm2d-3           [-1, 5, 8, 2046]              10
            Linear-4                 [-1, 1024]      83,805,184
            Linear-5                   [-1, 85]          87,125
Total params: 83,892,389
Trainable params: 83,892,389
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1.41
Params size (MB): 320.02
Estimated Total Size (MB): 321.45
----------------------------------------------------------------


In [15]:
transfer_test_feature_seen = linear_model(torch.tensor(dataset.test_feature_seen).float().cuda()).detach().cpu().numpy()
transfer_test_feature_unseen = linear_model(torch.tensor(dataset.test_feature_unseen).float().cuda()).detach().cpu().numpy()

In [None]:
best_model = None
beta = 0
with torch.set_grad_enabled(True):
    for epoch in range(3000):
        model.train()
        loss_center = 0
        for visual_batch, attr_batch, label_batch in train_loader:
            visual_batch = visual_batch.float().view(-1, 1, 1, 2048).cuda()
            
            attr_batch = attr_batch.float().cuda()
            
            out_visual = model(visual_batch)  # semantic -> visual space

            loss = mse_loss(out_visual, attr_batch)

            opt.zero_grad()
            
            loss.backward()

            opt.step()

        
        scheduler.step()
        model.eval()
    
        # early_stopping(eval(model, eval_loader), model)

        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

        acc_zsl = compute_accuracy2(
            model,
            dataset.test_att_map_unseen,
            dataset.test_feature_unseen,
            # transfer_test_feature_unseen, 
            dataset.test_id_unseen,
            dataset.test_label_unseen,
        )
        acc_seen_gzsl = compute_accuracy2(
            model,
            dataset.attribute,
            dataset.test_feature_seen,
            #transfer_test_feature_seen, 
            np.arange(class_num),
            dataset.test_label_seen,
        )
        acc_unseen_gzsl = compute_accuracy2(
            model,
            dataset.attribute,
            dataset.test_feature_unseen,
            # transfer_test_feature_unseen, 
            np.arange(class_num),
            dataset.test_label_unseen,
        )
        H = 2 * acc_seen_gzsl * acc_unseen_gzsl / (acc_seen_gzsl + acc_unseen_gzsl)

        if acc_zsl > best_zsl:
            best_zsl = acc_zsl
            best_model = model
            torch.save(
                model.state_dict(),
                os.path.join(
                    "output/AWA2/MedModel.pth"
                ),
            )


        if H > best_h:
            best_h = H
            h_line = "gzsl: seen=%.4f, unseen=%.4f, h=%.4f" % (
                acc_seen_gzsl,
                acc_unseen_gzsl,
                H,
            )

        print("Epoch:", epoch, "--------")
        print("zsl:", acc_zsl)
        print("best_zsl:", best_zsl)

        print(
                "Total-loss:{}".format(
                    loss.item()
                )
            )

        print("lr:", opt.param_groups[0]["lr"])
        print(
            "gzsl: seen=%.4f, unseen=%.4f, h=%.4f"
            % (acc_seen_gzsl, acc_unseen_gzsl, H)
        )
        print(h_line)

        history["Acc_ZSL"].append(acc_zsl)
        history["Loss_Total"].append(loss.item())

        #export_log(log_info, log_path)

        # plot_img(history, os.path.join(args.output, args.dataset), postfix)
        if acc_zsl < best_zsl-delta:
            patience -= 1
            if patience == 0: 
                print('Early Stopping')
                break


Epoch: 0 --------
zsl: 0.4046505750031594
best_zsl: 0.4524200682421332
Total-loss:0.005509004462510347
lr: 7.290000000000001e-06
gzsl: seen=0.9430, unseen=0.0062, h=0.0123
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 1 --------
zsl: 0.40275496019208895
best_zsl: 0.4524200682421332
Total-loss:0.007405819371342659
lr: 7.290000000000001e-06
gzsl: seen=0.9442, unseen=0.0070, h=0.0138
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 2 --------
zsl: 0.40705168709718187
best_zsl: 0.4524200682421332
Total-loss:0.006299541797488928
lr: 7.290000000000001e-06
gzsl: seen=0.9434, unseen=0.0063, h=0.0126
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 3 --------
zsl: 0.40780993302161
best_zsl: 0.4524200682421332
Total-loss:0.006057158578187227
lr: 7.290000000000001e-06
gzsl: seen=0.9429, unseen=0.0059, h=0.0118
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 4 --------
zsl: 0.40995829647415644
best_zsl: 0.4524200682421332
Total-loss:0.005957984831184149
lr: 7.290000000000001e-06
gzsl: seen

Epoch: 40 --------
zsl: 0.4074308100593959
best_zsl: 0.4524200682421332
Total-loss:0.00423052441328764
lr: 5.904900000000001e-06
gzsl: seen=0.9449, unseen=0.0059, h=0.0118
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 41 --------
zsl: 0.4004802224188045
best_zsl: 0.4524200682421332
Total-loss:0.004171972628682852
lr: 5.904900000000001e-06
gzsl: seen=0.9436, unseen=0.0056, h=0.0111
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 42 --------
zsl: 0.3992164792114242
best_zsl: 0.4524200682421332
Total-loss:0.004949692636728287
lr: 5.904900000000001e-06
gzsl: seen=0.9434, unseen=0.0045, h=0.0091
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 43 --------
zsl: 0.40490332364463544
best_zsl: 0.4524200682421332
Total-loss:0.003964265342801809
lr: 5.904900000000001e-06
gzsl: seen=0.9437, unseen=0.0047, h=0.0093
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 44 --------
zsl: 0.4097055478326804
best_zsl: 0.4524200682421332
Total-loss:0.005289625376462936
lr: 5.904900000000001e-06
gzsl: 

Epoch: 78 --------
zsl: 0.4026285858713509
best_zsl: 0.4524200682421332
Total-loss:0.002994883805513382
lr: 4.782969000000001e-06
gzsl: seen=0.9451, unseen=0.0042, h=0.0083
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 79 --------
zsl: 0.4050296979653734
best_zsl: 0.4524200682421332
Total-loss:0.0024726244155317545
lr: 4.782969000000001e-06
gzsl: seen=0.9430, unseen=0.0038, h=0.0076
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 80 --------
zsl: 0.4066725641349678
best_zsl: 0.4524200682421332
Total-loss:0.003030942752957344
lr: 4.782969000000001e-06
gzsl: seen=0.9442, unseen=0.0049, h=0.0098
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 81 --------
zsl: 0.40844180462530016
best_zsl: 0.4524200682421332
Total-loss:0.0037769395858049393
lr: 4.782969000000001e-06
gzsl: seen=0.9449, unseen=0.0058, h=0.0116
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 82 --------
zsl: 0.3985846076077341
best_zsl: 0.4524200682421332
Total-loss:0.004355264827609062
lr: 4.782969000000001e-06
gzs

Epoch: 116 --------
zsl: 0.40515607228611145
best_zsl: 0.4524200682421332
Total-loss:0.0030146194621920586
lr: 3.8742048900000015e-06
gzsl: seen=0.9434, unseen=0.0054, h=0.0108
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 117 --------
zsl: 0.3998483508151144
best_zsl: 0.4524200682421332
Total-loss:0.0018829682376235723
lr: 3.8742048900000015e-06
gzsl: seen=0.9447, unseen=0.0040, h=0.0081
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 118 --------
zsl: 0.4036395804372552
best_zsl: 0.4524200682421332
Total-loss:0.0030180015601217747
lr: 3.8742048900000015e-06
gzsl: seen=0.9454, unseen=0.0034, h=0.0068
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 119 --------
zsl: 0.4038923290787312
best_zsl: 0.4524200682421332
Total-loss:0.002401673002168536
lr: 3.8742048900000015e-06
gzsl: seen=0.9446, unseen=0.0042, h=0.0083
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 120 --------
zsl: 0.4062934411727537
best_zsl: 0.4524200682421332
Total-loss:0.0036210473626852036
lr: 3.874204890000

Epoch: 154 --------
zsl: 0.4004802224188045
best_zsl: 0.4524200682421332
Total-loss:0.0019077197648584843
lr: 3.138105960900001e-06
gzsl: seen=0.9456, unseen=0.0044, h=0.0088
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 155 --------
zsl: 0.4050296979653734
best_zsl: 0.4524200682421332
Total-loss:0.0012892462546005845
lr: 3.138105960900001e-06
gzsl: seen=0.9463, unseen=0.0039, h=0.0078
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 156 --------
zsl: 0.4019967142676608
best_zsl: 0.4524200682421332
Total-loss:0.003924292977899313
lr: 3.138105960900001e-06
gzsl: seen=0.9447, unseen=0.0037, h=0.0073
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 157 --------
zsl: 0.40212308858839885
best_zsl: 0.4524200682421332
Total-loss:0.0025831758975982666
lr: 3.138105960900001e-06
gzsl: seen=0.9451, unseen=0.0042, h=0.0083
gzsl: seen=0.8767, unseen=0.0224, h=0.0436
Epoch: 158 --------
zsl: 0.39871098192847215
best_zsl: 0.4524200682421332
Total-loss:0.0020496139768511057
lr: 3.138105960900001

In [26]:
# back up
best_model = None
with torch.set_grad_enabled(True):
    for epoch in range(300):
        model.train()
        for visual_batch, attr_batch, label_batch in train_map_loader:
            visual_batch = visual_batch.float().cuda()

            attr_batch = (
                attr_batch.float()
                .reshape(visual_batch.shape[0], 1, 1, semantic_feature_dim)
                .cuda()
            )

            out_visual = model(attr_batch)  # semantic -> visual space

            loss_attr = mse_loss(out_visual, visual_batch)
            loss = loss_attr

            opt.zero_grad()

            loss.backward()

            opt.step()

        scheduler.step()
        model.eval()
 
        acc_zsl = compute_accuracy(
            model,
            P.scale(dataset.test_att_map_unseen, axis=1),
            P.scale(dataset.test_feature_unseen, axis=1),
            # dataset.test_att_map_unseen, 
            #dataset.test_feature_unseen,
            dataset.test_id_unseen,
            dataset.test_label_unseen,
        )
        acc_seen_gzsl = compute_accuracy(
            model,
            # dataset.attribute,
            # dataset.test_feature_seen,
            P.scale(dataset.attribute, axis=1),
            P.scale(dataset.test_feature_seen, axis=1),
            np.arange(class_num),
            dataset.test_label_seen,
        )
        acc_unseen_gzsl = compute_accuracy(
            model,
            #dataset.attribute,
            #dataset.test_feature_unseen,
            P.scale(dataset.attribute, axis=1),
            P.scale(dataset.test_feature_unseen, axis=1),
            np.arange(class_num),
            dataset.test_label_unseen,
        )
        H = 2 * acc_seen_gzsl * acc_unseen_gzsl / (acc_seen_gzsl + acc_unseen_gzsl)

        if acc_zsl > best_zsl:
            best_zsl = acc_zsl
            best_model = model
#             torch.save(
#                 model.state_dict(),
#                 os.path.join(
#                     args.output, "{}/Model_{}.pth".format(args.dataset, postfix)
#                 ),
#             )
#         if best_center > loss_center.item():
#             best_center = loss_center.item()
#             torch.save(
#                 center_loss.state_dict(),
#                 os.path.join(
#                     args.output,
#                     "{}/CenterLoss_{}.pth".format(args.dataset, postfix),
#                 ),
#             )

        if H > best_h:
            best_h = H
            h_line = "gzsl: seen=%.4f, unseen=%.4f, h=%.4f" % (
                acc_seen_gzsl,
                acc_unseen_gzsl,
                H,
            )

        print("Epoch:", epoch, "--------")
        print("zsl:", acc_zsl)
        print("best_zsl:", best_zsl)

        print("loss:{}".format(loss.item()))

        print("lr:", opt.param_groups[0]["lr"])
        print(
            "gzsl: seen=%.4f, unseen=%.4f, h=%.4f"
            % (acc_seen_gzsl, acc_unseen_gzsl, H)
        )
        print(h_line)

        history["Acc_ZSL"].append(acc_zsl)
        history["Loss_Total"].append(loss.item())

        #export_log(log_info, log_path)

        # plot_img(history, os.path.join(args.output, args.dataset), postfix)
        if acc_zsl < best_zsl-delta:
            patience -= 1
            if patience == 0: 
                print('Early Stopping')
                break


Epoch: 0 --------
zsl: 0.39112852268419057
best_zsl: 0.39112852268419057
loss:0.2824324667453766
lr: 1e-05
gzsl: seen=0.3574, unseen=0.1860, h=0.2447
gzsl: seen=0.3574, unseen=0.1860, h=0.2447
Epoch: 1 --------
zsl: 0.3820295715910527
best_zsl: 0.39112852268419057
loss:0.1497512012720108
lr: 1e-05
gzsl: seen=0.2837, unseen=0.2225, h=0.2494
gzsl: seen=0.2837, unseen=0.2225, h=0.2494
Epoch: 2 --------
zsl: 0.3635789207633009
best_zsl: 0.39112852268419057
loss:0.15182745456695557
lr: 1e-05
gzsl: seen=0.2564, unseen=0.2164, h=0.2347
gzsl: seen=0.2837, unseen=0.2225, h=0.2494
Epoch: 3 --------
zsl: 0.354353595349425
best_zsl: 0.39112852268419057
loss:0.16483403742313385
lr: 1e-05
gzsl: seen=0.2367, unseen=0.2075, h=0.2211
gzsl: seen=0.2837, unseen=0.2225, h=0.2494
Epoch: 4 --------
zsl: 0.34891949955768986
best_zsl: 0.39112852268419057
loss:0.12380704283714294
lr: 1e-05
gzsl: seen=0.2338, unseen=0.2030, h=0.2173
gzsl: seen=0.2837, unseen=0.2225, h=0.2494
Epoch: 5 --------
zsl: 0.33969417414