In [1]:
#!/usr/bin/env python
# coding=utf-8
"""
Author: Kitiro
Date: 2020-11-03 17:32:57
LastEditTime: 2020-11-06 00:04:17
LastEditors: Kitiro
Description: visual -> semantic 
FilePath: /zzc/exp/Hierarchically_Learning_The_Discriminative_Features_For_Zero_Shot_Learning/main.py
"""

import numpy as np
import torch
import os
import argparse
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
from torch.optim import lr_scheduler
from torchsummary import summary
import json
import random
import scipy.io as sio
import torch.nn.functional as F
# Author Defined
from model import MyModel, APYModel
from utils import (
    compute_accuracy,
    CenterLoss,
    set_seed,
    cal_mean_feature,
    weights_init,
    export_log,
    plot_img,
    EarlyStopping,
)
from dataset import DataSet
from sklearn import preprocessing as P

In [2]:

def compute_accuracy(model, test_att, test_visual, test_id, test_label, center=None):
    """[summary]

    Args:
        model ([type]): [description]
        test_att ([type]): semantic feature
        test_visual ([type]): visual feature
        test_id ([type]): att2label
        test_label ([type]): x2label
    """
    model.eval()
    with torch.no_grad():
        test_att = torch.tensor(test_att).unsqueeze(1).unsqueeze(1).float().cuda()
        attr_feature = model(test_att).cpu().detach().numpy()
        if center: centers = center.centers.cpu().detach().numpy()[test_id]
        outpred = np.array([0] * test_visual.shape[0])
        test_label = test_label.astype("float32")
        # 将类的属性映射到visual space. 得到att_pred。然后每一张图片去找距离最近的attr of unseen
        for i in range(test_visual.shape[0]):
            target = test_visual[i, :]
            subtract_feat = np.tile(target, (len(test_id), 1))
            dist = np.sum((subtract_feat - attr_feature) ** 2, axis=1) ** 0.5
            if center:
                dist_with_centers = np.sum((subtract_feat - centers) ** 2, axis=1) ** 0.5
                dist += 0.5*dist_with_centers

            outpred[i] = test_id[np.argsort(dist)[0]]

        acc = np.equal(outpred, test_label).mean()
        
    return acc

# visual -> att
def compute_accuracy2(model, test_att, test_visual, test_id, test_label, center=None):
    """[summary]

    Args:
        model ([type]): [description]
        test_att ([type]): semantic feature
        test_visual ([type]): visual feature
        test_id ([type]): att2label
        test_label ([type]): x2label
    """
    model.eval()
    with torch.no_grad():
        test_visual = torch.tensor(test_visual).view(-1, 1, 1, 2048).float().cuda()
        out_visual = model(test_visual).cpu().detach().numpy()
        if center: centers = center.centers.cpu().detach().numpy()[test_id]
        outpred = np.array([0] * test_visual.shape[0])
        test_label = test_label.astype("float32")
        
        for i in range(test_visual.shape[0]):
            target = out_visual[i, :]
            subtract_feat = np.tile(target, (len(test_id), 1))
            dist = np.sum((subtract_feat - test_att) ** 2, axis=1) ** 0.5
            if center:
                dist_with_centers = np.sum((subtract_feat - centers) ** 2, axis=1) ** 0.5
                dist += 0.5*dist_with_centers

            outpred[i] = test_id[np.argsort(dist)[0]]

        acc = np.equal(outpred, test_label).mean()
        
    return acc

In [3]:
# load train data        
data_dir = 'data/AWA2_data/'
mat_visual = sio.loadmat(os.path.join(data_dir, "res101.mat"))
features, labels = mat_visual[
    'features'].T, mat_visual['labels'].astype(int).squeeze() - 1

mat_semantic = sio.loadmat(
    os.path.join(data_dir, "att_splits.mat"))
trainval_loc = mat_semantic['trainval_loc'].squeeze() - 1
test_seen_loc = mat_semantic['test_seen_loc'].squeeze() - 1
test_unseen_loc = mat_semantic['test_unseen_loc'].squeeze() - 1

attribute1 = mat_semantic['att'].T  # manual labeld attributes

attribute = attribute1

train_feature = features[trainval_loc]  # feature
train_label = labels[trainval_loc].astype(
    int)  # 23527 training samples。
train_att = attribute[train_label]  # 23527*85
train_label_unique = np.unique(train_label)

test_feature_unseen = features[test_unseen_loc]  # 7913 测试集中的未见类
test_label_unseen = labels[test_unseen_loc].astype(int)

test_feature_seen = features[test_seen_loc]  # 5882  测试集中的已见类
test_label_seen = labels[test_seen_loc].astype(int)

test_id_unseen = np.unique(test_label_unseen)
test_att_map_unseen = attribute[test_id_unseen]
class_num = 50

In [24]:
# 生成属性为200时得到的在视觉空间的映射
attr_num = 200
g_a = np.load('generate_attributes/generated_attributes_glove/class_attribute_map_AWA2.npy')
g_attribute = np.hstack((attribute, P.scale(g_a[:,:attr_num], axis=0)))


In [25]:
from sklearn import preprocessing as P

dataset = DataSet(
    name='AWA2',
    generative_attribute_num=attr_num,
    norm=True,
)
    
dataset_train = TensorDataset(
    torch.from_numpy(dataset.train_feature),
    torch.from_numpy(dataset.train_att),
    torch.from_numpy(dataset.train_label),
)

train_loader = DataLoader(
    dataset=dataset_train, batch_size=100, shuffle=True, num_workers=0
)


In [28]:
# MED Model : visual -> att. 2048 -> 85 on AwA2
class MedModel(nn.Module):
    def __init__(self, attr_dim, visual_dim):
        super(MedModel, self).__init__()
        self.in_dim = visual_dim
        self.output_dim = attr_dim
        
        self.expand_num = 5
        self.conv1 = nn.Conv2d(1, 5, kernel_size=1)
        self.conv2 = nn.Conv2d(1, 5, kernel_size=3, stride=2, padding=1)

        self.conv_out_dim = 15360
        self.fc = nn.Sequential(
                nn.Linear(self.conv_out_dim, 2048),
                nn.ReLU(),
                nn.Linear(2048, self.output_dim))
    def forward(self, x):
        x = F.relu(self.conv1(x))  # ret:10*2048
        x = x.view(-1, 1, self.expand_num, self.in_dim)  # batch, channel, height , width
        x = F.relu(self.conv2(x))
        #x = x.view(-1, 50, 50, self.attr_dim)  
        #x = torch.mean(x, dim=1) # average the feature map

        x = x.view(-1, self.conv_out_dim)  # flatten

        x = self.fc(x)

        return x    
    

In [29]:
model = MedModel(285, 2048).cuda()

summary(model, input_size=(1, 1, features.shape[-1]))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 5, 1, 2048]              10
            Conv2d-2           [-1, 5, 3, 1024]              50
            Linear-3                 [-1, 2048]      31,459,328
              ReLU-4                 [-1, 2048]               0
            Linear-5                  [-1, 285]         583,965
Total params: 32,043,353
Trainable params: 32,043,353
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.23
Params size (MB): 122.24
Estimated Total Size (MB): 122.47
----------------------------------------------------------------


In [30]:
model = MedModel(dataset.attribute.shape[-1], features.shape[-1])

model.cuda()
model.apply(weights_init)

summary(model, input_size=(1, 1, features.shape[-1]))

mse_loss = nn.MSELoss()

best_zsl = 0.0
best_h = 0.0
h_line = ""
delta = 0.1

# hyperparameters
#beta = 1e-5 # for center_loss. 0.003 in paper
alpha = 0.5 # for linear transferred
beta = 1e-5
lr = 1e-5

# opt = torch.optim.Adam(
#     model.parameters(), lr=lr, weight_decay=1e-2
# )  
# optimizer_centloss = torch.optim.Adam(center_loss.parameters(), lr=0.5)

opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-04)

scheduler = lr_scheduler.MultiStepLR(
    opt, milestones=list(range(20, 300, 20)), gamma=0.9
)

history = {
    "Acc_ZSL": [],
    "Loss_attr": [],
    "Loss_center": [],
    "Loss_Total": [],
}


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 5, 1, 2048]              10
            Conv2d-2           [-1, 5, 3, 1024]              50
            Linear-3                 [-1, 2048]      31,459,328
              ReLU-4                 [-1, 2048]               0
            Linear-5                  [-1, 285]         583,965
Total params: 32,043,353
Trainable params: 32,043,353
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.23
Params size (MB): 122.24
Estimated Total Size (MB): 122.47
----------------------------------------------------------------


In [31]:
best_model = None
beta = 0
with torch.set_grad_enabled(True):
    for epoch in range(3000):
        model.train()
        loss_center = 0
        for visual_batch, attr_batch, label_batch in train_loader:
            visual_batch = visual_batch.float().view(-1, 1, 1, 2048).cuda()
            
            attr_batch = attr_batch.float().cuda()
            
            out_visual = model(visual_batch)  # semantic -> visual space

            loss = mse_loss(out_visual, attr_batch)

            opt.zero_grad()
            
            loss.backward()

            opt.step()

        
        scheduler.step()
        model.eval()
    
        # early_stopping(eval(model, eval_loader), model)

        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

        acc_zsl = compute_accuracy2(
            model,
            dataset.test_att_map_unseen,
            dataset.test_feature_unseen,
            # transfer_test_feature_unseen, 
            dataset.test_id_unseen,
            dataset.test_label_unseen,
        )
        acc_seen_gzsl = compute_accuracy2(
            model,
            dataset.attribute,
            dataset.test_feature_seen,
            #transfer_test_feature_seen, 
            np.arange(class_num),
            dataset.test_label_seen,
        )
        acc_unseen_gzsl = compute_accuracy2(
            model,
            dataset.attribute,
            dataset.test_feature_unseen,
            # transfer_test_feature_unseen, 
            np.arange(class_num),
            dataset.test_label_unseen,
        )
        H = 2 * acc_seen_gzsl * acc_unseen_gzsl / (acc_seen_gzsl + acc_unseen_gzsl)

        if acc_zsl > best_zsl:
            best_zsl = acc_zsl
            best_model = model
#             torch.save(
#                 model.state_dict(),
#                 os.path.join(
#                     "output/AWA2/MedModel.pth"
#                 ),
#             )


        if H > best_h:
            best_h = H
            h_line = "gzsl: seen=%.4f, unseen=%.4f, h=%.4f" % (
                acc_seen_gzsl,
                acc_unseen_gzsl,
                H,
            )

        print("Epoch:", epoch, "--------")
        print("zsl:", acc_zsl)
        print("best_zsl:", best_zsl)

        print(
                "Total-loss:{}".format(
                    loss.item()
                )
            )

        print("lr:", opt.param_groups[0]["lr"])
        print(
            "gzsl: seen=%.4f, unseen=%.4f, h=%.4f"
            % (acc_seen_gzsl, acc_unseen_gzsl, H)
        )
        print(h_line)

        history["Acc_ZSL"].append(acc_zsl)
        history["Loss_Total"].append(loss.item())

        #export_log(log_info, log_path)

        # plot_img(history, os.path.join(args.output, args.dataset), postfix)
        if acc_zsl < best_zsl-delta:
            patience -= 1
            if patience == 0: 
                print('Early Stopping')
                break


Epoch: 0 --------
zsl: 0.12372046000252748
best_zsl: 0.12372046000252748
Total-loss:0.006585697177797556
lr: 1e-05
gzsl: seen=0.1863, unseen=0.0538, h=0.0835
gzsl: seen=0.1863, unseen=0.0538, h=0.0835
Epoch: 1 --------
zsl: 0.1791987868065209
best_zsl: 0.1791987868065209
Total-loss:0.006156490184366703
lr: 1e-05
gzsl: seen=0.3631, unseen=0.0505, h=0.0887
gzsl: seen=0.3631, unseen=0.0505, h=0.0887
Epoch: 2 --------
zsl: 0.2169847087071907
best_zsl: 0.2169847087071907
Total-loss:0.003804041538387537
lr: 1e-05
gzsl: seen=0.4879, unseen=0.0379, h=0.0704
gzsl: seen=0.3631, unseen=0.0505, h=0.0887
Epoch: 3 --------
zsl: 0.23126500695058763
best_zsl: 0.23126500695058763
Total-loss:0.0029962894041091204
lr: 1e-05
gzsl: seen=0.5519, unseen=0.0279, h=0.0532
gzsl: seen=0.3631, unseen=0.0505, h=0.0887
Epoch: 4 --------
zsl: 0.23530898521420449
best_zsl: 0.23530898521420449
Total-loss:0.0031463350169360638
lr: 1e-05
gzsl: seen=0.5959, unseen=0.0262, h=0.0501
gzsl: seen=0.3631, unseen=0.0505, h=0.08

Epoch: 41 --------
zsl: 0.3093643371666877
best_zsl: 0.31340831543030456
Total-loss:0.002052977913990617
lr: 8.1e-06
gzsl: seen=0.7241, unseen=0.0061, h=0.0120
gzsl: seen=0.3631, unseen=0.0505, h=0.0887
Epoch: 42 --------
zsl: 0.3014027549601921
best_zsl: 0.31340831543030456
Total-loss:0.0024596625007689
lr: 8.1e-06
gzsl: seen=0.7375, unseen=0.0087, h=0.0172
gzsl: seen=0.3631, unseen=0.0505, h=0.0887
Epoch: 43 --------
zsl: 0.3031719954505245
best_zsl: 0.31340831543030456
Total-loss:0.0022073688451200724
lr: 8.1e-06
gzsl: seen=0.7435, unseen=0.0061, h=0.0120
gzsl: seen=0.3631, unseen=0.0505, h=0.0887


KeyboardInterrupt: 