In [1]:
import sys
from collections import OrderedDict
from typing import Tuple


import numpy as np; np.random.seed(1)
import torch; torch.manual_seed(1)
from torch import nn, Tensor
import torch.nn.functional as F
from time import time
from tqdm import tqdm
from scipy import io
from torch.utils.data import DataLoader, TensorDataset


# Global Variables

In [2]:
DATASET = 'AWA2'
DATA_DIR = f'../../Datasets/{DATASET}'
# Fall back to the CPU so the notebook still runs on machines without a CUDA GPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
USE_ATTRIBUTE_REFINEMENT = True  # refine class attributes with the ARN before the ZSL model
USE_CIRCLE_LOSS = True           # add the weighted circle-loss term to cross-entropy in training

# Attribute Refinement Network

In [3]:
def exist(x):
    """Return whether ``x`` is set, i.e. anything other than ``None``."""
    return False if x is None else True

class Residual(nn.Module):
    """Wrap a sub-module with an identity skip connection: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        # The attribute name `fn` is part of the state_dict keys, so it is kept.
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x

class GatingUnit(nn.Module):
    """Split the last dimension in half and gate one half with the other.

    The input's last dimension (size ``2 * dim``) is split into ``res`` and ``gate``.
    ``gate`` is layer-normalised, reshaped to ``(..., dim, 1)`` and passed through a
    1x1 ``Conv1d`` with ``len_sen`` channels, so the ``dim`` gate features act as conv
    channels and ``dim`` must equal ``len_sen``. Returns ``res * gate``.
    """

    def __init__(self, dim, len_sen):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.proj = nn.Conv1d(len_sen, len_sen, 1)

        # Zero weights + unit bias: the gate is exactly 1 at init, so the unit
        # initially passes `res` through unchanged.
        nn.init.zeros_(self.proj.weight)
        nn.init.ones_(self.proj.bias)

    def forward(self, x):
        res, gate = x.chunk(2, dim=-1)
        gated = self.proj(self.ln(gate).unsqueeze(-1)).squeeze(-1)
        return res * gated

class ARN(nn.Module):
    """Attribute Refinement Network: a stack of residual MLP blocks with a gating unit.

    Each of the ``num_layers`` blocks computes
    ``x + fc2(sgu(gelu(fc1(ln1(x)))))``, where ``fc1`` expands ``dim -> 2 * d_ff`` and the
    ``GatingUnit`` (``sgu``) reduces that back to ``d_ff`` before ``fc2`` maps to ``dim``.

    Args:
        num_tokens: vocabulary size. When given, inputs are integer ids passed through
            ``nn.Embedding(num_tokens, dim)``; when ``None`` (default), inputs are used as-is.
        len_sen: channel count of each ``GatingUnit`` projection. The gate half has
            ``d_ff`` features and is fed to that projection as channels, so this must
            equal ``d_ff``.
        dim: input/output feature size.
        d_ff: hidden width of each block.
        num_layers: number of residual blocks.
    """

    def __init__(self, num_tokens=None, len_sen=49, dim=512, d_ff=1024, num_layers=6):
        super().__init__()
        self.num_layers = num_layers
        self.embedding = nn.Embedding(num_tokens, dim) if exist(num_tokens) else nn.Identity()

        # Layer names carry the block index and are part of the state_dict keys.
        self.arn = nn.ModuleList([
            Residual(nn.Sequential(OrderedDict([
                ('ln1_%d' % i, nn.LayerNorm(dim)),
                ('fc1_%d' % i, nn.Linear(dim, d_ff * 2)),
                ('gelu_%d' % i, nn.GELU()),
                ('sgu_%d' % i, GatingUnit(d_ff, len_sen)),
                ('fc2_%d' % i, nn.Linear(d_ff, dim)),
            ])))
            for i in range(num_layers)
        ])

    def forward(self, x):
        # Apply the registered blocks in order, without building a throwaway
        # nn.Sequential wrapper on every forward call.
        y = self.embedding(x)
        for block in self.arn:
            y = block(y)
        return y

# Circle Loss

In [4]:
class CircleLoss(nn.Module):
    """Circle loss over a logit matrix with exactly one positive entry per row.

    Args:
        m: relaxation margin. Positive/negative optima are ``1 + m`` and ``-m``;
            decision margins are ``delta_p = 1 - m`` and ``delta_n = m``.
        gamma: scale applied to the weighted logits.
    """

    def __init__(self, m: float, gamma: float) -> None:
        super().__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, sp: Tensor, sn: Tensor, logits=None, N=None, C=None) -> Tensor:
        """Compute per-sample circle losses.

        Args:
            sp: boolean mask shaped like ``logits``; True at each row's single positive.
            sn: boolean mask shaped like ``logits``; True at the ``C - 1`` negatives per row.
            logits: ``(N, C)`` similarity scores (required).
            N, C: rows / columns of ``logits``; inferred from ``logits.shape`` when omitted.

        Returns:
            ``(N, 1)`` tensor of per-sample losses (its mean equals the mean of the
            ``(N, N)`` broadcast this method used to return).

        Raises:
            ValueError: if ``logits`` is not given.
        """
        if logits is None:
            raise ValueError('CircleLoss.forward requires `logits`')
        if N is None:
            N = logits.shape[0]
        if C is None:
            C = logits.shape[1]

        # Boolean-mask indexing returns entries in row-major order, so each row's
        # positive / negatives end up back in their own row after the view.
        sp_logit = logits[sp].view(N, 1)
        sn_logit = logits[sn].view(N, C - 1)

        margin = self.m

        # Adaptive pair weights, treated as constants (detached) for backprop.
        ap = torch.clamp_min(- sp_logit.detach() + 1 + margin, min=0.)
        an = torch.clamp_min(sn_logit.detach() + margin, min=0.)

        delta_p = 1 - margin
        delta_n = margin

        logit_p = - ap * (sp_logit - delta_p) * self.gamma
        logit_n = an * (sn_logit - delta_n) * self.gamma

        loss = self.soft_plus(torch.logsumexp(logit_n, dim=1) + torch.logsumexp(logit_p, dim=1))
        # Previously `ones(N).view(-1, 1) * loss` broadcast this (N,) vector into an
        # (N, N) matrix on the global DEVICE; a plain reshape is enough.
        return loss.view(-1, 1)

        

# Evaluation Function

In [5]:
def _per_class_accuracy(feats, labels, classes, reg):
    """Return the mean per-class top-1 accuracy of ``model`` on one test split.

    Args:
        feats: test features of the split (moved to ``DEVICE`` here).
        labels: integer class ids of ``feats`` (compared on the CPU).
        classes: class ids to average the per-class accuracies over.
        reg: multiplier applied to the seen-class columns of the logits before argmax.

    Classes with no samples in ``labels`` are skipped instead of dividing by zero;
    returns 0.0 when none of ``classes`` has samples.
    """
    # Score against every class (seen + unseen), then rescale the seen-class columns.
    logits = model(feats.to(DEVICE), attrs.to(DEVICE))
    logits[:, seen_mask] *= reg
    # One device->host copy instead of a GPU sync per sample.
    preds = logits.argmax(dim=1).cpu()
    labels = labels.cpu()
    hits = preds == labels

    per_class = []
    for c in classes:
        in_class = labels == int(c)
        n_samples = int(in_class.sum())
        if n_samples == 0:
            continue
        per_class.append(int(hits[in_class].sum()) / n_samples)
    return np.mean(np.array(per_class)) if per_class else 0.0


def test():
    """Evaluate ``model`` on the generalized zero-shot test splits.

    Reads the module-level ``model``, ``attrs``, ``seen_mask``, ``seen_classes``,
    ``unseen_classes``, ``test_seen_feat``/``test_seen_labels``,
    ``test_unseen_feat``/``test_unseen_labels`` and ``DEVICE``. Leaves ``model`` in
    eval mode.

    Returns:
        ``(seen_acc, unseen_acc, h)``: mean per-class accuracy on the seen and unseen
        test splits and their harmonic mean (0.0 when both accuracies are 0).
    """
    model.eval()
    reg = 0.95  # < 1 shrinks positive seen-class scores relative to unseen ones

    with torch.no_grad():
        seen_acc = _per_class_accuracy(test_seen_feat, test_seen_labels, seen_classes, reg)
        unseen_acc = _per_class_accuracy(test_unseen_feat, test_unseen_labels, unseen_classes, reg)

    total = seen_acc + unseen_acc
    h_mean_class_wise = (2 * seen_acc * unseen_acc) / total if total > 0 else 0.0
    return seen_acc, unseen_acc, h_mean_class_wise

# Initialize the Attribute Refinement Network

In [6]:
# GatingUnit feeds the d_ff-wide gate into a conv with `len_sen` channels, so both must match.
ARN_HIDDEN = 256
# NOTE(review): dim must equal the per-class attribute count (attrs.shape[1]). It is
# hard-coded because attrs is only loaded in a later cell; 85 works for AWA2, so update
# it when switching DATASET.
arn = ARN(len_sen=ARN_HIDDEN, dim=85, d_ff=ARN_HIDDEN, num_layers=2)
arn = arn.to(DEVICE)

# Re-seed so later randomness does not depend on how much RNG the ARN init consumed.
np.random.seed(1)
torch.manual_seed(1);  # trailing ';' suppresses the returned Generator repr

<torch._C.Generator at 0x7fcea86f2350>

# Prepare the Dataset

In [7]:
print(f'<=============== Loading data for {DATASET} ===============>')
data = io.loadmat(f'{DATA_DIR}/res101.mat')
attrs_mat = io.loadmat(f'{DATA_DIR}/att_splits.mat')

# Samples are stored column-wise in the .mat file; transpose to one row per sample.
feats = data['features'].T.astype(np.float32)

# The .mat labels and split locations are 1-based; shift everything to 0-based.
labels = data['labels'].squeeze() - 1
train_idx, test_seen_idx, test_unseen_idx = (
    attrs_mat[split_key].squeeze() - 1
    for split_key in ('trainval_loc', 'test_seen_loc', 'test_unseen_loc')
)
test_idx = np.array([*test_seen_idx.tolist(), *test_unseen_idx.tolist()])

# Sorted class ids that occur in each test split.
seen_classes, unseen_classes = (
    sorted(np.unique(labels[split_idx]))
    for split_idx in (test_seen_idx, test_unseen_idx)
)



In [8]:
print(f'<=============== Preprocessing ===============>')
num_classes = len(seen_classes) + len(unseen_classes)
# Boolean masks over class ids 0..num_classes-1.
seen_mask = np.isin(np.arange(num_classes), seen_classes)
unseen_mask = np.isin(np.arange(num_classes), unseen_classes)

attrs = attrs_mat['att'].T
attrs = torch.from_numpy(attrs).to(DEVICE).float()
# L2-normalise each class's attribute vector, then rescale it to norm sqrt(n_attrs).
attrs = attrs / attrs.norm(dim=1, keepdim=True) * np.sqrt(attrs.shape[1])
attrs_seen = attrs[seen_mask]
attrs_unseen = attrs[unseen_mask]

train_labels = labels[train_idx]
# Position of each label within `seen_classes` (-1 for unseen classes). A dict lookup
# replaces the O(num_seen) `list.index` scan previously done for every label.
seen_class_to_idx = {c: i for i, c in enumerate(seen_classes)}
labels_remapped_to_seen = [seen_class_to_idx.get(t, -1) for t in labels]
ds_train = [(feats[i], labels_remapped_to_seen[i]) for i in train_idx]



# Train Dataset

In [9]:
# Stack the (feature, seen-class index) pairs into dense arrays for a TensorDataset.
total_feat = np.array([feat for feat, _ in ds_train])
total_labels = np.array([lbl for _, lbl in ds_train])

print("the total_feat shape is...", total_feat.shape)
print("the total_labels shape is...", total_labels.shape)

train_dataset = TensorDataset(
    torch.from_numpy(total_feat),
    torch.from_numpy(total_labels).long(),
)

the total_feat shape is... (23527, 2048)
the total_labels shape is... (23527,)


# Test Dataset

In [10]:
# Recompute the 0-based sample indices of the seen / unseen test splits.
test_seen_idx = attrs_mat['test_seen_loc'].squeeze() - 1
test_unseen_idx = attrs_mat['test_unseen_loc'].squeeze() - 1

# Features keep the dtype of `feats`; labels are cast to int64 with .long().
(test_seen_feat, test_seen_labels), (test_unseen_feat, test_unseen_labels) = (
    (torch.from_numpy(feats[split]), torch.from_numpy(labels[split]).long())
    for split in (test_seen_idx, test_unseen_idx)
)


In [11]:
class ClassNormalization(nn.Module):
    """Per-feature normalisation across dim 0 (the class axis) with running statistics.

    In training mode each call normalises with the batch mean/variance over dim 0 and
    folds them into ``running_mean`` / ``running_var`` (momentum 0.1). In eval mode only
    the accumulated statistics are used, so test-time inference never uses batch-wise
    statistics (which would make it transductive).

    NOTE(review): the denominator is ``var + 1e-5``, not ``sqrt(var + eps)`` as in
    BatchNorm. It is kept as-is because changing it alters the trained model; confirm intent.
    """

    def __init__(self, feat_dim: int):
        super().__init__()
        # Frozen Parameters (not buffers): listed in parameters(), but never receive a gradient.
        self.running_mean = nn.Parameter(torch.zeros(feat_dim), requires_grad=False)
        self.running_var = nn.Parameter(torch.ones(feat_dim), requires_grad=False)

    def forward(self, class_feats):
        if not self.training:
            mean, var = self.running_mean, self.running_var
        else:
            mean = class_feats.mean(dim=0)
            var = class_feats.var(dim=0)
            # Exponential moving average of the batch statistics.
            self.running_mean.data = 0.9 * self.running_mean.data + 0.1 * mean.detach()
            self.running_var.data = 0.9 * self.running_var.data + 0.1 * var.detach()
        return (class_feats - mean.unsqueeze(0)) / (var.unsqueeze(0) + 1e-5)


# ZSL Model

In [12]:
class ZSLModel(nn.Module):
    """Map class attributes to prototypes and score input features against them.

    Args:
        attr_dim: size of each class attribute vector.
        hid_dim: hidden width of the attribute-to-prototype MLP.
        proto_dim: size of each prototype; must match the feature size of ``x``.
    """

    def __init__(self, attr_dim: int, hid_dim: int, proto_dim: int):
        super().__init__()

        # Attribute -> prototype MLP. Layer order determines the state_dict keys
        # (model.0 ... model.7), so it must not change.
        layers = [
            nn.Linear(attr_dim, hid_dim),
            nn.ReLU(),

            nn.Linear(hid_dim, hid_dim),
            ClassNormalization(hid_dim),
            nn.ReLU(),

            ClassNormalization(hid_dim),
            nn.Linear(hid_dim, proto_dim),
            nn.ReLU(),
        ]
        self.model = nn.Sequential(*layers)

        # Re-initialise the output Linear (index -2, just before the final ReLU) from
        # U(-b, b) with b = sqrt(3 * var), i.e. variance 1 / (hid_dim * proto_dim).
        weight_var = 1 / (hid_dim * proto_dim)
        bound = np.sqrt(3 * weight_var)
        self.model[-2].weight.data.uniform_(-bound, bound)

    def forward(self, x, attrs, train=False):
        """Score features ``x`` against prototypes built from ``attrs``.

        NOTE(review): reads the module-level ``arn`` and ``USE_ATTRIBUTE_REFINEMENT``,
        a hidden global dependency. Also, prototypes come out of a ReLU, so an
        all-zero prototype row would have zero norm and yield NaN logits.

        Returns:
            ``logits`` (cosine similarity scaled by 5 * 5 = 25), or, when ``train`` is
            True, ``(logits, protos, logits_nn)`` where ``logits_nn`` is the unscaled
            cosine similarity.
        """
        if USE_ATTRIBUTE_REFINEMENT:
            attrs = arn(attrs)
        protos = self.model(attrs)

        # Feature-prototype combiner: scaled cosine similarity.
        feat_scaled = 5 * x / x.norm(dim=1, keepdim=True)
        proto_scaled = 5 * protos / protos.norm(dim=1, keepdim=True)
        logits = feat_scaled @ proto_scaled.t()

        # Same similarity without the scale factor.
        feat_unit = 1.0 * x / x.norm(dim=1, keepdim=True)
        proto_unit = 1.0 * protos / protos.norm(dim=1, keepdim=True)
        logits_nn = feat_unit @ proto_unit.t()

        return (logits, protos, logits_nn) if train else logits

In [13]:
print(f'\n<=============== Starting training ===============>')
start_time = time()
model = ZSLModel(attrs.shape[1], 1024, feats.shape[1]).to(DEVICE)

# Optimise the ZSL model and the attribute refinement network jointly. Use an ordered,
# de-duplicated list instead of a set: set order depends on object ids, so the optimizer's
# parameter order (and its state_dict layout) would change from run to run.
params = list(dict.fromkeys([*model.parameters(), *arn.parameters()]))

optim = torch.optim.Adam(params, lr=0.0005, weight_decay=0.0001)
# Decay the learning rate by 10x every 25 scheduler steps (stepped once per epoch in training).
scheduler = torch.optim.lr_scheduler.StepLR(optim, gamma=0.1, step_size=25)





# Training

In [14]:
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
NUM_EPOCHS = 50
CIRCLE_LOSS_WEIGHT = 0.8
# CircleLoss holds no learnable state, so build it once rather than once per batch.
loss_circle = CircleLoss(m=0.7, gamma=0.3)

# Seen-class attributes are constant during training. C (number of seen classes) is the
# width of the training logits and the one-hot size used for the circle-loss masks.
attrs_s = attrs[seen_mask]
C = attrs_s.shape[0]

hm_best, seen_best, unseen_best = 0, 0, 0
for epoch in tqdm(range(NUM_EPOCHS)):
    model.train()

    # `batch_*` names keep the loop from overwriting the global `feats` array, which
    # earlier cells index into and would break on re-run.
    for batch_feats, batch_targets in train_dataloader:
        batch_feats = batch_feats.to(DEVICE)
        batch_targets = batch_targets.to(DEVICE)
        N = batch_feats.shape[0]

        logits, _, logits_nn = model(batch_feats, attrs_s, train=True)

        loss = F.cross_entropy(logits, batch_targets)
        if USE_CIRCLE_LOSS:
            sp_bool = F.one_hot(batch_targets, C) == 1  # each sample's own class
            sn_bool = ~sp_bool                          # every other seen class
            loss += torch.mean(loss_circle(sp_bool, sn_bool, logits=logits_nn, N=N, C=C)) * CIRCLE_LOSS_WEIGHT

        optim.zero_grad()
        loss.backward()
        optim.step()

    with torch.no_grad():
        test_seen_acc, test_unseen_acc, test_h_mean_class_wise = test()
        # Track the epoch with the best harmonic mean; the prints report that best-so-far result.
        if hm_best < test_h_mean_class_wise:
            hm_best, seen_best, unseen_best = test_h_mean_class_wise, test_seen_acc, test_unseen_acc

        print("test result (u) : ", unseen_best)
        print("test result (s) : ", seen_best)
        print("test result (h) : ", hm_best)

    scheduler.step()

print(f'Training is done! Took time: {(time() - start_time): .1f} seconds')


  2%|▉                                           | 1/50 [00:03<02:46,  3.40s/it]

test result (u) :  0.4457662843003991
test result (s) :  0.8072499262167945
test result (h) :  0.5743657537565339


  4%|█▊                                          | 2/50 [00:06<02:21,  2.96s/it]

test result (u) :  0.5075924460050565
test result (s) :  0.7940384372493746
test result (h) :  0.619296788007539


  6%|██▋                                         | 3/50 [00:08<02:12,  2.82s/it]

test result (u) :  0.5326411189943291
test result (s) :  0.7937084644595249
test result (h) :  0.6374816563279007


  8%|███▌                                        | 4/50 [00:11<02:06,  2.75s/it]

test result (u) :  0.5315517726408716
test result (s) :  0.8020152572651821
test result (h) :  0.6393568858917648


 10%|████▍                                       | 5/50 [00:14<02:02,  2.72s/it]

test result (u) :  0.5538790205036689
test result (s) :  0.7933785758557044
test result (h) :  0.6523410959730619


 12%|█████▎                                      | 6/50 [00:16<01:59,  2.72s/it]

test result (u) :  0.5836305150114282
test result (s) :  0.799867910765464
test result (h) :  0.674850526756501


 14%|██████▏                                     | 7/50 [00:19<01:56,  2.70s/it]

test result (u) :  0.5836305150114282
test result (s) :  0.799867910765464
test result (h) :  0.674850526756501


 16%|███████                                     | 8/50 [00:22<01:52,  2.69s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 18%|███████▉                                    | 9/50 [00:24<01:49,  2.67s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 20%|████████▌                                  | 10/50 [00:27<01:46,  2.66s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 22%|█████████▍                                 | 11/50 [00:29<01:43,  2.66s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 24%|██████████▎                                | 12/50 [00:32<01:41,  2.67s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 26%|███████████▏                               | 13/50 [00:35<01:43,  2.78s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 28%|████████████                               | 14/50 [00:38<01:41,  2.81s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 30%|████████████▉                              | 15/50 [00:41<01:38,  2.82s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 32%|█████████████▊                             | 16/50 [00:44<01:36,  2.83s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 34%|██████████████▌                            | 17/50 [00:47<01:33,  2.83s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 36%|███████████████▍                           | 18/50 [00:50<01:31,  2.87s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 38%|████████████████▎                          | 19/50 [00:53<01:30,  2.91s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 40%|█████████████████▏                         | 20/50 [00:56<01:27,  2.93s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 42%|██████████████████                         | 21/50 [00:57<01:15,  2.62s/it]

test result (u) :  0.6302360289515303
test result (s) :  0.792613553820883
test result (h) :  0.7021594196624711


 44%|██████████████████▉                        | 22/50 [00:59<01:06,  2.38s/it]

test result (u) :  0.6389622952386189
test result (s) :  0.7968852300869826
test result (h) :  0.7092391172142999


 46%|███████████████████▊                       | 23/50 [01:02<01:09,  2.58s/it]

test result (u) :  0.6496303428676387
test result (s) :  0.7978202476340237
test result (h) :  0.7161394584634563


 48%|████████████████████▋                      | 24/50 [01:05<01:10,  2.71s/it]

test result (u) :  0.6496303428676387
test result (s) :  0.7978202476340237
test result (h) :  0.7161394584634563


 50%|█████████████████████▌                     | 25/50 [01:08<01:10,  2.81s/it]

test result (u) :  0.6496303428676387
test result (s) :  0.7978202476340237
test result (h) :  0.7161394584634563


 52%|██████████████████████▎                    | 26/50 [01:10<01:00,  2.53s/it]

test result (u) :  0.6508963252312198
test result (s) :  0.7990342438845758
test result (h) :  0.7173977349764268


 54%|███████████████████████▏                   | 27/50 [01:13<01:00,  2.64s/it]

test result (u) :  0.6531574392393968
test result (s) :  0.7982444602891716
test result (h) :  0.718449256183087


 56%|████████████████████████                   | 28/50 [01:16<00:59,  2.71s/it]

test result (u) :  0.6531574392393968
test result (s) :  0.7982444602891716
test result (h) :  0.718449256183087


 58%|████████████████████████▉                  | 29/50 [01:19<00:57,  2.72s/it]

test result (u) :  0.6524376625855501
test result (s) :  0.8002541807951278
test result (h) :  0.7188254956773608


 60%|█████████████████████████▊                 | 30/50 [01:21<00:53,  2.68s/it]

test result (u) :  0.6524376625855501
test result (s) :  0.8002541807951278
test result (h) :  0.7188254956773608


 62%|██████████████████████████▋                | 31/50 [01:23<00:46,  2.46s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 64%|███████████████████████████▌               | 32/50 [01:25<00:41,  2.30s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 66%|████████████████████████████▍              | 33/50 [01:28<00:41,  2.41s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 68%|█████████████████████████████▏             | 34/50 [01:31<00:40,  2.56s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 70%|██████████████████████████████             | 35/50 [01:33<00:38,  2.59s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 72%|██████████████████████████████▉            | 36/50 [01:36<00:36,  2.62s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 74%|███████████████████████████████▊           | 37/50 [01:39<00:34,  2.66s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 76%|████████████████████████████████▋          | 38/50 [01:42<00:31,  2.67s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 78%|█████████████████████████████████▌         | 39/50 [01:44<00:29,  2.69s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 80%|██████████████████████████████████▍        | 40/50 [01:47<00:26,  2.69s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 82%|███████████████████████████████████▎       | 41/50 [01:50<00:24,  2.67s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 84%|████████████████████████████████████       | 42/50 [01:52<00:21,  2.70s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 86%|████████████████████████████████████▉      | 43/50 [01:55<00:19,  2.74s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 88%|█████████████████████████████████████▊     | 44/50 [01:58<00:16,  2.77s/it]

test result (u) :  0.654716640921807
test result (s) :  0.7999974218331264
test result (h) :  0.7201025111103597


 90%|██████████████████████████████████████▋    | 45/50 [02:01<00:14,  2.82s/it]

test result (u) :  0.6569628566913119
test result (s) :  0.796845988405741
test result (h) :  0.7201747583962458


 92%|███████████████████████████████████████▌   | 46/50 [02:04<00:11,  2.77s/it]

test result (u) :  0.6569628566913119
test result (s) :  0.796845988405741
test result (h) :  0.7201747583962458


 94%|████████████████████████████████████████▍  | 47/50 [02:07<00:08,  2.86s/it]

test result (u) :  0.6569628566913119
test result (s) :  0.796845988405741
test result (h) :  0.7201747583962458


 96%|█████████████████████████████████████████▎ | 48/50 [02:09<00:05,  2.80s/it]

test result (u) :  0.6569628566913119
test result (s) :  0.796845988405741
test result (h) :  0.7201747583962458


 98%|██████████████████████████████████████████▏| 49/50 [02:12<00:02,  2.82s/it]

test result (u) :  0.6569628566913119
test result (s) :  0.796845988405741
test result (h) :  0.7201747583962458


100%|███████████████████████████████████████████| 50/50 [02:15<00:00,  2.71s/it]

test result (u) :  0.6569628566913119
test result (s) :  0.796845988405741
test result (h) :  0.7201747583962458
Training is done! Took time:  135.6 seconds



