# Implementation of the paper "Fine-grained generalized zero-shot learning via dense attribute-based attention"

In [1]:
import os,sys
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models.resnet as models
from PIL import Image
import h5py
import numpy as np
import scipy.io as sio
import pickle
import pdb
import matplotlib.pyplot as plt
import pandas as pd
import gensim.downloader as api
import torch.optim as optim
import importlib


# SUN dataset
#images = 14340

#classes = 717, 645 seen classes and 72 unseen classes

Each class is described by 102 attributes that represent the class information.

In [2]:
# Where the SUN images live, and where the xlsa17 annotation file is stored.
#
# res101.mat includes the following fields:
#   - features: columns correspond to image instances
#   - labels: label number of a class is its row number in allclasses.txt
#   - image_files: image sources
img_dir = 'C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/'
file_paths = 'C:/Sushree/Jio_Institute/Dataset/data/xlsa17/data/SUN/res101.mat'

# Echo both locations, one per line, so the cell output records what was used.
print(img_dir, file_paths, sep='\n')

C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/
C:/Sushree/Jio_Institute/Dataset/data/xlsa17/data/SUN/res101.mat


# Let's visualize the data

In [None]:
def visualize_data_distribution(file_paths, num_bins=717):
    """Plot how many images each category contributes.

    Reads the ``.mat`` annotation file, takes one category name per image from
    the image's stored path, and draws a histogram of those names.

    Args:
        file_paths: Path of a ``.mat`` file with ``image_files`` and ``labels`` fields.
        num_bins: Number of histogram bins. Defaults to 717, the value used for SUN.

    Returns:
        None. The figure is drawn with matplotlib's pyplot interface.
    """
    matcontent = sio.loadmat(file_paths)
    # List the field names only; printing the dict itself dumps every stored array.
    print('fields', [key for key in matcontent if not key.startswith('__')])

    image_files = np.squeeze(matcontent['image_files'])
    print('image_files', image_files)

    labels = np.squeeze(matcontent['labels'])
    print('labels', labels)
    print('labels.size', labels.size)  # 14340 for SUN

    # Component 9 of the stored '/'-separated path is used as the category name.
    # NOTE(review): this depends on the absolute path layout recorded in the
    # .mat file -- confirm before reusing with another dataset.
    class_names = [entry[0].split('/')[9] for entry in image_files]

    print('len(class_names)', len(class_names))
    # Summarise instead of printing one name per image.
    print('distinct class_names', len(set(class_names)))

    plt.figure(figsize=(80,6))

    plt.title("Data Distribution: SUN")
    plt.xlabel("Categories")
    # Each bar counts the images that belong to a category.
    plt.ylabel("Number of Images")

    plt.xticks(rotation = 90)
    plt.grid(color = 'red', linestyle = '--', linewidth = 0.3)
    plt.hist(class_names, num_bins, align="mid")

# Draw the per-category histogram for the annotation file configured above.
visualize_data_distribution(file_paths) 

# Let's extract deep features (consider pre-trained ResNet 101 with no fine-tuning)

In [3]:
class CustomedDataset(Dataset):
    """Image dataset driven by the ``image_files`` field of a ``.mat`` file.

    Each item is ``(image, image_file)``: the image converted to RGB (and passed
    through ``transform`` when one is given) together with the path it was
    read from.
    """

    def __init__(self, img_dir , file_paths, transform=None):
        """Load the annotation file and remember where the images live.

        Args:
            img_dir: Directory that the stored image paths are re-rooted under.
            file_paths: Path of the ``.mat`` file holding ``image_files``.
            transform: Optional callable applied to each RGB PIL image.
        """
        self.matcontent = sio.loadmat(file_paths)
        self.image_files = np.squeeze(self.matcontent['image_files'])
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.image_files)

    def _resolve_path(self, idx):
        """Map the stored path of item ``idx`` to a path under ``img_dir``.

        The first seven '/'-separated components of the stored path are
        dropped and the remainder is joined onto ``img_dir``.
        """
        stored_path = self.image_files[idx][0]
        relative_path = '/'.join(stored_path.split('/')[7:])
        return os.path.join(self.img_dir, relative_path)

    def __getitem__(self, idx):
        """Return ``(image, image_file)`` for item ``idx``."""
        image_file = self._resolve_path(idx)

        # Image.open keeps the file handle open until the image is closed;
        # convert() loads the pixels into a new image, so the handle can be
        # released immediately instead of leaking one per sample.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, image_file

In [4]:
# Side length (in pixels) of the square crop fed to the network.
input_size = 224
# Preprocessing pipeline: resize, centre-crop to input_size x input_size,
# convert to a tensor, then normalise each channel with the given mean / std.
# NOTE(review): the mean/std values look like the usual ImageNet statistics -- confirm.
data_transforms = transforms.Compose([
        transforms.Resize(input_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Dataset over every image listed in the annotation file at file_paths.
SUNDataset = CustomedDataset(img_dir, file_paths, data_transforms)

In [5]:
# Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception]
# NOTE(review): model_name is not read in this cell; resnet101 is hard-coded below.
model_name = "resnet"

# Batch size for feature extraction (change depending on how much memory you have)
batch_size = 32

#%%

# Pre-trained ResNet-101, switched to eval mode because it is only used for inference.
model_ref = models.resnet101(pretrained=True)
model_ref.eval()

# Keep every child module except the last two, so the network outputs a feature
# map rather than class scores.
# TODO confirm the two dropped modules are avgpool and fc for this torchvision version.
model_f = nn.Sequential(*list(model_ref.children())[:-2])
model_f.eval()

# Freeze the backbone: no parameter is fine-tuned.
for param in model_f.parameters():
    param.requires_grad = False
    
print(model_f)
        
# Layer-by-layer summary for a single 3x224x224 input.
from torchsummary import summary
summary(model_f, (3, 224, 224))    

# shuffle=False keeps the batches in annotation-file order.
dataset_loader = torch.utils.data.DataLoader(SUNDataset, batch_size=batch_size, shuffle=False, num_workers=0)    



Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [7]:
# Run every batch through the frozen backbone and collect the feature maps.
all_features = []
error_files = []
# Inference only: no_grad avoids building autograd state for any batch.
with torch.no_grad():
    for i_batch, package in enumerate(dataset_loader):

        imgs, image_files = package
        print(i_batch, imgs.size(1))
        if imgs.size(1) != 3:
            print('Error')
            # Placeholder with one row per image in the batch, so the
            # concatenated features stay aligned with the labels.
            features = torch.zeros((imgs.size(0), 2048, 7, 7))
            # Record each offending file individually rather than nesting the batch.
            error_files.extend(image_files)
        else:
            features = model_f(imgs)

        all_features.append(features.numpy())


print('err_counter {}'.format(error_files))
# Stack the per-batch arrays into one array along the image axis.
all_features = np.concatenate(all_features,axis=0)

0 3
1 3
2 3
3 3
4 3
5 3
6 3
7 3
8 3
9 3
10 3
11 3
12 3
13 3
14 3
15 3
16 3
17 3
18 3
19 3
20 3
21 3
22 3
23 3
24 3
25 3
26 3
27 3
28 3
29 3
30 3
31 3
32 3
33 3
34 3
35 3
36 3
37 3
38 3
39 3
40 3
41 3
42 3
43 3
44 3
45 3
46 3
47 3
48 3
49 3
50 3
51 3
52 3
53 3
54 3
55 3
56 3
57 3
58 3
59 3
60 3
61 3
62 3
63 3
64 3
65 3
66 3
67 3
68 3
69 3
70 3
71 3
72 3
73 3
74 3
75 3
76 3
77 3
78 3
79 3
80 3
81 3
82 3
83 3
84 3
85 3
86 3
87 3
88 3
89 3
90 3
91 3
92 3
93 3
94 3
95 3
96 3
97 3
98 3
99 3
100 3
101 3
102 3
103 3
104 3
105 3
106 3
107 3
108 3
109 3
110 3
111 3
112 3
113 3
114 3
115 3
116 3
117 3
118 3
119 3
120 3
121 3
122 3
123 3
124 3
125 3
126 3
127 3
128 3
129 3
130 3
131 3
132 3
133 3
134 3
135 3
136 3
137 3
138 3
139 3
140 3
141 3
142 3
143 3
144 3
145 3
146 3
147 3
148 3
149 3
150 3
151 3
152 3
153 3
154 3
155 3
156 3
157 3
158 3
159 3
160 3
161 3
162 3
163 3
164 3
165 3
166 3
167 3
168 3
169 3
170 3
171 3
172 3
173 3
174 3
175 3
176 3
177 3
178 3
179 3
180 3
181 3
182 3
183 3
184 3


# Let's extract semantic attributes of each category (consider pre-trained word2vec model with no fine-tuning)

In [10]:
print('Load pretrain w2v model')

# Pre-trained word vectors fetched through gensim's downloader.
model_name = 'word2vec-google-news-300'#best model
model = api.load(model_name)

# Dimensionality of one word vector.
dim_w2v = 300

#%%
# (old, new) rewrites applied to the attribute descriptions, for SUN.
replace_word = [('rockstone','rock stone'),('dirtsoil','dirt soil'),('man-made','man-made'),('sunsunny','sun sunny'),
                ('electricindoor','electric indoor'),('semi-enclosed','semi enclosed'),('far-away','faraway')] # for SUN


file_path = 'C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/attribute/attributes.mat'
matcontent = sio.loadmat(file_path)
des = matcontent['attributes'].flatten()

#%%
df = pd.DataFrame()
# Remove every '/' from each raw description.
new_des = [i.item().replace('/', '') for i in des]

#%% replace out of dictionary words
# Apply the rewrites in list order to each description in turn.
for idx, description in enumerate(new_des):
    for old, new in replace_word:
        description = description.replace(old, new)
    new_des[idx] = description
print('Done replace OOD words')

# Persist the cleaned descriptions next to the raw attribute file.
df['new_des']=new_des
df.to_csv('C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/attribute/new_des.csv')
print('Done preprocessing attribute des')

Load pretrain w2v model
Done replace OOD words
Done preprocessing attribute des


In [11]:

# Build one vector per attribute description by summing the word2vec vectors
# of its words.
all_w2v = []
for s in new_des:
    print(s)
    # split() with no argument drops leading/trailing/repeated whitespace, so
    # no empty "words" are looked up.
    words = s.split()
    w2v = np.zeros(dim_w2v)
    for w in words:
        try:
            w2v += model[w]
        except KeyError as e:
            # Out-of-vocabulary word: report it and leave it out of the sum.
            # Any other exception is a real failure and is allowed to surface.
            print(e)
    all_w2v.append(w2v[np.newaxis,:])

#%%
# Stack the per-attribute rows into a single (n_attributes, dim_w2v) array.
all_w2v=np.concatenate(all_w2v,axis=0)
#%%

with open('C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/w2v/SUN_attribute.pkl','wb') as f:
    pickle.dump(all_w2v,f)

sailing boating
driving
biking
transporting things or people
sunbathing
vacationing touring
hiking
climbing
camping
reading
studying learning
teaching training
research
diving
swimming
bathing
eating
cleaning
socializing
congregating
waiting in line queuing
competing
sports
exercise
playing
gaming
spectating being in an audience
farming
constructing building
shopping
medical activity
working
using tools
digging
conducting business
praying
fencing
railing
wire
railroad
trees
grass
vegetation
shrubbery
foliage
leaves
flowers
asphalt
pavement
shingles
carpet
brick
tiles
concrete
metal
paper
wood (not part of a tree)
"Key '(not' not present"
"Key 'of' not present"
"Key 'a' not present"
"Key 'tree)' not present"
vinyl linoleum
rubber plastic
cloth
sand
rock stone
dirt soil
marble
glass
waves surf
ocean
running water
still water
ice
snow
clouds
smoke
fire
natural light
direct sun sunny
electric indoor lighting
aged worn
glossy
matte
sterile
moist damp
dry
dirty
rusty
warm
cold
natural
man-ma

# Read the attributes and save as "w2v_att"

In [12]:
# Location of the pickled attribute word2vec matrix.
attribute_path = 'C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/w2v/SUN_attribute.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only load files produced by this notebook.
with open(attribute_path,'rb') as f:
    w2v_att = pickle.load(f)
assert w2v_att.shape == (102,300) # for SUN: 102 attributes x 300-d word2vec
# NOTE(review): nothing is saved in this cell despite the message; the array is only loaded into w2v_att.
print('save w2v_att')

print(w2v_att, w2v_att.shape)

save w2v_att
[[ 0.1484375   0.39794922 -0.20947266 ...  0.07128906  0.57324219
   0.06445312]
 [ 0.22070312  0.09863281  0.06738281 ...  0.03198242  0.27929688
   0.00640869]
 [-0.00921631 -0.15332031  0.06640625 ... -0.07910156  0.38476562
   0.24511719]
 ...
 [ 0.171875   -0.12402344  0.17480469 ... -0.15039062  0.15722656
  -0.00708008]
 [ 0.19726562  0.06542969  0.00866699 ... -0.07275391  0.22558594
   0.20410156]
 [ 0.02111816  0.00772095 -0.39257812 ...  0.13378906  0.01190186
   0.15136719]] (102, 300)


# Let's gather additional information (training, validation, and test indexes)

In [13]:
#%% get remaining metadata
# Reuse the annotation dict already loaded by the dataset object.
matcontent = SUNDataset.matcontent
# Shift the stored labels down by one (presumably 1-based -> 0-based; confirm).
labels = matcontent['labels'].astype(int).squeeze() - 1

split_path = 'C:/Sushree/Jio_Institute/Dataset/data/xlsa17/data/SUN/att_splits.mat'
print(split_path)
    
# att_splits.mat includes the following fields:
# - att: columns correspond to class attribute vectors normalized to have unit l2 norm, following the class order in allclasses.txt
# - original_att: the original class attribute vectors without normalization
# - trainval_loc: instance indices of train+val set features (for only seen classes) in res101.mat
# - test_seen_loc: instance indices of test set features for seen classes
# - test_unseen_loc: instance indices of test set features for unseen classes


C:/Sushree/Jio_Institute/Dataset/data/xlsa17/data/SUN/att_splits.mat


In [14]:
def get_index_details(split_path):
    """Read the split indices and attribute matrices from a ``.mat`` file.

    The three ``*_loc`` fields are squeezed and shifted down by one; the two
    attribute matrices are transposed. Everything read is printed along the way.

    Args:
        split_path: Path of the ``.mat`` file to load.

    Returns:
        Tuple ``(trainval_loc, test_seen_loc, test_unseen_loc, att, original_att)``.
    """
    split_data = sio.loadmat(split_path)
    print(split_data)

    # Index fields: flatten and subtract one.
    index_arrays = []
    for field in ('trainval_loc', 'test_seen_loc', 'test_unseen_loc'):
        indices = split_data[field].squeeze() - 1
        print(indices, len(indices))
        index_arrays.append(indices)

    # Attribute fields: transpose the stored matrix.
    attribute_arrays = []
    for field in ('att', 'original_att'):
        matrix = split_data[field].T
        print(matrix, matrix.shape)
        attribute_arrays.append(matrix)

    return (*index_arrays, *attribute_arrays)
    
# Load the split indices and both attribute matrices for the configured split file.
trainval_loc, test_seen_loc, test_unseen_loc, att, original_att = get_index_details(split_path)    

{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Fri Aug 21 10:35:38 2020', '__version__': '1.0', '__globals__': [], 'allclasses_names': array([[array(['abbey'], dtype='<U5')],
       [array(['access_road'], dtype='<U11')],
       [array(['airfield'], dtype='<U8')],
       [array(['airlock'], dtype='<U7')],
       [array(['airplane_cabin'], dtype='<U14')],
       [array(['airport_airport'], dtype='<U15')],
       [array(['airport_entrance'], dtype='<U16')],
       [array(['airport_terminal'], dtype='<U16')],
       [array(['airport_ticket_counter'], dtype='<U22')],
       [array(['alcove'], dtype='<U6')],
       [array(['alley'], dtype='<U5')],
       [array(['amphitheater'], dtype='<U12')],
       [array(['amusement_arcade'], dtype='<U16')],
       [array(['amusement_park'], dtype='<U14')],
       [array(['anechoic_chamber'], dtype='<U16')],
       [array(['apartment_building_outdoor'], dtype='<U26')],
       [array(['apse_indoor'], dtype='<U11')],
       [array(['

# Save the feature map that includes ResNet-101 features, labels, training and test (seen and unseen) data indexes, semantic attributes, and w2v attributes

In [15]:

save_path = 'C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/feature_map_ResNet_101_SUN.hdf5'

# Write every array needed for training into one HDF5 file, each as a
# gzip-compressed dataset. The context manager closes the file even when one
# of the writes raises, so a failed run does not leave the handle open.
with h5py.File(save_path, "w") as f:
    f.create_dataset('feature_map', data=all_features,compression="gzip")
    f.create_dataset('labels', data=labels,compression="gzip")
    f.create_dataset('trainval_loc', data=trainval_loc,compression="gzip")
    # Left disabled, as in the original cell (train_loc / val_unseen_loc are not defined here):
    #    f.create_dataset('train_loc', data=train_loc,compression="gzip")
    #    f.create_dataset('val_unseen_loc', data=val_unseen_loc,compression="gzip")
    f.create_dataset('test_seen_loc', data=test_seen_loc,compression="gzip")
    f.create_dataset('test_unseen_loc', data=test_unseen_loc,compression="gzip")
    f.create_dataset('att', data=att,compression="gzip")
    f.create_dataset('original_att', data=original_att,compression="gzip")
    f.create_dataset('w2v_att', data=w2v_att,compression="gzip")

In [16]:
# Re-open the saved file read-only as a sanity check.
# NOTE(review): hf is never closed; the next cell still reads from it.
hf = h5py.File('C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/feature_map_ResNet_101_SUN.hdf5', 'r')
# Copy the whole 'feature_map' dataset into an in-memory array.
features = np.array(hf.get('feature_map'))
print(features.shape)

(14340, 2048, 7, 7)


In [17]:
# Eyeball the stored values (numpy abbreviates large arrays when printing).
print(features)
# Rebinds the global `att` to the copy read back from the HDF5 file.
att = np.array(hf.get('att'))
print(att)


[[[[0.00000000e+00 1.96410865e-01 6.76221728e-01 ... 9.90564227e-01
    4.70250785e-01 2.34905377e-01]
   [2.21763458e-03 5.41731834e-01 1.18846798e+00 ... 4.94249791e-01
    2.96695948e-01 0.00000000e+00]
   [0.00000000e+00 1.74594030e-01 0.00000000e+00 ... 0.00000000e+00
    0.00000000e+00 0.00000000e+00]
   ...
   [8.81760865e-02 3.25672001e-01 4.37725484e-01 ... 4.11111265e-01
    0.00000000e+00 0.00000000e+00]
   [2.23469660e-01 3.81422281e-01 6.01043701e-01 ... 2.20551789e-01
    0.00000000e+00 0.00000000e+00]
   [6.44565463e-01 4.24261063e-01 5.55662751e-01 ... 0.00000000e+00
    2.94744670e-02 0.00000000e+00]]

  [[3.04151326e-01 7.70970523e-01 9.21629131e-01 ... 1.52987623e+00
    1.60744679e+00 1.11083770e+00]
   [6.79018438e-01 5.99694729e-01 9.15595651e-01 ... 1.38449216e+00
    1.43896592e+00 9.86185789e-01]
   [5.74687600e-01 6.54377818e-01 8.89177799e-01 ... 7.79943526e-01
    7.14233458e-01 6.23846650e-01]
   ...
   [5.19757032e-01 7.13034093e-01 1.18239522e+00 ... 1.27

# Train the DAZLE model for SUN

In [1]:
import os,sys
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models.resnet as models
from PIL import Image
import h5py
import numpy as np
import scipy.io as sio
import pickle
import pdb
import matplotlib.pyplot as plt
import pandas as pd
import gensim.downloader as api
import torch.optim as optim
import importlib


In [2]:
from DAZLE import DAZLE
from SUNDataLoader import SUNDataLoader
from helper_func import eval_zs_gzsl,visualize_attention#,get_attribute_attention_stats

In [3]:
# Dataset root and the directory passed to the loader as feature_path
# (presumably where the extracted-feature HDF5 file lives -- confirm in SUNDataLoader).
data_path = 'C:/Sushree/Jio_Institute/Dataset/'
feature_path = 'C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/'
# NOTE(review): SUNDataLoader is project-local; the meaning of device=None,
# is_scale and is_balance is defined there and is not visible in this notebook.
dataloader = SUNDataLoader(data_path, feature_path, device = None, is_scale=False, is_balance = True)


C:/Sushree/Jio_Institute/Dataset/
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
SUN
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
Balance dataloader
_____
C:/Sushree/Jio_Institute/Dataset/SUNAttributeDB_Images/feature_map_ResNet_101_SUN.hdf5
Expert Attr


In [5]:

#%%
# Fix torch's CPU and CUDA random seeds.
# NOTE(review): numpy / Python `random` are not seeded here -- confirm the loader does not rely on them.
seed = 214
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

print('Randomize seed {}'.format(seed))
#%%
batch_size = 50
nepoches = 100
# Total number of optimisation steps: ntrain * nepoches samples consumed in batches of batch_size.
niters = dataloader.ntrain * nepoches//batch_size
dim_f = 2048  # passed to DAZLE as dim_f; presumably the feature-map channel count -- confirm
dim_v = 300  # passed to DAZLE as dim_v; presumably the word2vec dimensionality -- confirm
init_w2v_att = dataloader.w2v_att
att = dataloader.att#dataloader.normalize_att#
normalize_att = dataloader.normalize_att
#assert (att.min().item() == 0 and att.max().item() == 1)

# NOTE(review): device is None, so model.to(device) below does not select a specific device -- confirm this is intended.
device = None

# Hyper-parameters forwarded to DAZLE / the evaluation helper; their exact
# semantics are defined in the project-local DAZLE and helper_func modules.
trainable_w2v = True
lambda_ = 0.1
bias = 0.
prob_prune = 0
uniform_att_1 = False
uniform_att_2 = True

seenclass = dataloader.seenclasses
unseenclass = dataloader.unseenclasses
desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))
# Evaluate every niters // nepoches steps, i.e. nepoches evaluations over the run.
report_interval = niters//nepoches
#%%
model = DAZLE(dim_f,dim_v,init_w2v_att,att,normalize_att,
            seenclass,unseenclass,
            lambda_,
            trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,
            uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,
            prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,
            is_bias=True,non_linear_act=False)
model.to(device)
#%%
# Collect (and list) only the parameters that require gradients; these are what the optimizer updates.
params_to_update = []
for name,param in model.named_parameters():
    if param.requires_grad == True:
        params_to_update.append(param)
        print("\t",name)
#%%
lr = 0.0001
weight_decay = 0.0001#0.000#0.#
momentum = 0.9#0.#
optimizer  = optim.RMSprop( params_to_update ,lr=lr,weight_decay=weight_decay, momentum=momentum)
#%%
# Echo the configuration so it is recorded alongside the training log.
print('-'*30)
print('learing rate {}'.format(lr))
print('trainable V {}'.format(trainable_w2v))
print('lambda_ {}'.format(lambda_))
print('optimized seen only')
print('optimizer: RMSProp with momentum = {} and weight_decay = {}'.format(momentum,weight_decay))
print('-'*30)


Randomize seed 214
------------------------------
Configuration
loss_type CE
no constraint V
normalize F
training to exclude unseen class [seen upperbound]
Init word2vec
Linear model
loss_att BCEWithLogitsLoss()
Bilinear attention module
******************************
Measure w2v deviation
Compute Pruning loss Parameter containing:
tensor(0)
Add one smoothing
Second layer attenion conditioned on image features
------------------------------
No sigmoid on attr score
	 V
	 W_1
	 W_2
	 W_3
------------------------------
learing rate 0.0001
trainable V True
lambda_ 0.1
optimized seen only
optimizer: RMSProp with momentum = 0.9 and weight_decay = 0.0001
------------------------------


In [6]:
# Best [acc_seen, acc_novel, H, acc_zs] seen so far, selected by H.
best_performance = [0,0,0,0]
for i in range(niters):
    model.train()
    optimizer.zero_grad()

    batch_label, batch_feature, batch_att = dataloader.next_batch(batch_size)

    # The forward output doubles as the loss input once the labels are attached.
    in_package = model(batch_feature)
    in_package['batch_label'] = batch_label

    out_package = model.compute_loss(in_package)
    loss = out_package['loss']
    loss_CE = out_package['loss_CE']
    loss_cal = out_package['loss_cal']

    loss.backward()
    optimizer.step()

    # Only evaluate and report on every report_interval-th step.
    if i % report_interval != 0:
        continue

    print('-'*30)
    acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(dataloader,model,device,bias_seen=-bias,bias_unseen=bias)

    if H > best_performance[2]:
        best_performance = [acc_seen, acc_novel, H, acc_zs]

    # Losses are from the current step; accuracies are the best ones so far.
    stats_package = {'iter':i, 'loss':loss.item(), 'loss_CE':loss_CE.item(),
                     'loss_cal': loss_cal.item(),
                     'acc_seen':best_performance[0], 'acc_novel':best_performance[1], 'H':best_performance[2], 'acc_zs':best_performance[3]}

    print(stats_package)


------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 0, 'loss': 6.955271244049072, 'loss_CE': 6.874741077423096, 'loss_cal': 0.8053020238876343, 'acc_seen': 0, 'acc_novel': 0, 'H': 0, 'acc_zs': 0}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 206, 'loss': 3.608905076980591, 'loss_CE': 3.5211551189422607, 'loss_cal': 0.8774999976158142, 'acc_seen': 0.08294573426246643, 'acc_novel': 0.5333333611488342, 'H': 0.1435639390547079, 'acc_zs': 0.5562499761581421}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 412, 'loss': 3.182274580001831, 'loss_CE': 3.0780272483825684, 'loss_cal': 1.0424728393554688, 'acc_seen': 0.12441860139369965, 'acc_novel': 0.53125, 'H': 0.20161826218277126, 'acc_zs': 0.5645833015441895}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 618, 'loss': 3.271792411804199, 'loss_CE': 3.158285140991211, 'loss_cal': 1.1350719928741455, 'acc_seen': 0.15155038237571716, 'acc_novel': 0.54861110

------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 5974, 'loss': 1.2781357765197754, 'loss_CE': 1.139394760131836, 'loss_cal': 1.3874105215072632, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 6180, 'loss': 1.0732533931732178, 'loss_CE': 0.9266790747642517, 'loss_cal': 1.4657433032989502, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 6386, 'loss': 1.1802022457122803, 'loss_CE': 1.0414183139801025, 'loss_cal': 1.3878393173217773, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 6592, 'loss': 1.0793715715408325, 'loss_CE': 0.9162287116050

------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 11948, 'loss': 0.5198255777359009, 'loss_CE': 0.3515304625034332, 'loss_cal': 1.6829513311386108, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 12154, 'loss': 0.6738852858543396, 'loss_CE': 0.5231189131736755, 'loss_cal': 1.5076634883880615, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 12360, 'loss': 0.7248835563659668, 'loss_CE': 0.5698680877685547, 'loss_cal': 1.550154447555542, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 12566, 'loss': 0.5048877596855164, 'loss_CE': 0.318034410

{'iter': 17716, 'loss': 0.5162560343742371, 'loss_CE': 0.3317318856716156, 'loss_cal': 1.8452415466308594, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 17922, 'loss': 0.49585211277008057, 'loss_CE': 0.32777515053749084, 'loss_cal': 1.6807695627212524, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 18128, 'loss': 0.534538984298706, 'loss_CE': 0.36025282740592957, 'loss_cal': 1.7428616285324097, 'acc_seen': 0.25697675347328186, 'acc_novel': 0.5020833611488342, 'H': 0.33995661117094206, 'acc_zs': 0.6006944179534912}
------------------------------
bias_seen -0.0 bias_unseen 0.0
{'iter': 18334, 'loss': 0.5499625205993652, 'loss_CE': 0.3793228268623352, 'loss_cal': 1.7063966989517212, 'acc_seen': 0.2569