In [1]:
import sys
from collections import OrderedDict
from typing import Tuple


import numpy as np; np.random.seed(1)
import torch; torch.manual_seed(1)
from torch import nn, Tensor
import torch.nn.functional as F
from time import time
from tqdm import tqdm
from scipy import io
from torch.utils.data import DataLoader, TensorDataset


# Global Variables

In [2]:
DATASET = 'AWA2'
DATA_DIR = f'../../Datasets/{DATASET}'
# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Pass the class attributes through the ARN (`arn`) before ZSLModel's MLP.
USE_ATTRIBUTE_REFINEMENT = True
# Add 0.5 * circle loss (on the cosine logits) to the cross-entropy objective.
USE_CIRCLE_LOSS = True

# Attribute Refinement Network

In [3]:
def exist(x):
    """Return True unless ``x`` is ``None`` (falsy values such as 0 still count)."""
    return False if x is None else True

class Residual(nn.Module):
    """Skip connection: returns ``fn(x) + x`` for the wrapped callable ``fn``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x

class GatingUnit(nn.Module):
    """Gating unit used inside each ARN block.

    The last dimension of the input is split into halves ``(res, gate)``. The
    gate half is layer-normalised, passed through a kernel-size-1 ``Conv1d``
    with ``len_sen`` input/output channels, and multiplied element-wise into
    ``res``. The conv starts at weight 0 / bias 1, so the initial gate is all
    ones and the unit initially returns ``res`` unchanged.

    NOTE(review): assumes a 2-D input ``(batch, 2 * dim)`` with
    ``dim == len_sen`` (the gate is viewed as ``(batch, dim, 1)`` for the
    conv) -- confirm before using other shapes.
    """

    def __init__(self, dim, len_sen):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.proj = nn.Conv1d(len_sen, len_sen, 1)
        # Identity gate at initialisation: proj(.) == 1 everywhere.
        nn.init.zeros_(self.proj.weight)
        nn.init.ones_(self.proj.bias)

    def forward(self, x):
        res, gate = x.chunk(2, dim=-1)
        gate = self.proj(self.ln(gate).unsqueeze(-1)).squeeze(-1)
        return res * gate

class ARN(nn.Module):
    """Attribute Refinement Network: a stack of residual gated-MLP blocks.

    Block ``i`` computes ``x + fc2(GatingUnit(GELU(fc1(LayerNorm(x)))))``,
    where ``fc1`` widens ``dim`` to ``2 * d_ff`` and the GatingUnit halves it
    back to ``d_ff`` before ``fc2`` projects to ``dim`` again.

    Args:
        num_tokens: if given, inputs are integer ids embedded with
            ``nn.Embedding(num_tokens, dim)``; otherwise inputs pass through
            unchanged (``nn.Identity``).
        len_sen: channel count of each GatingUnit's Conv1d. NOTE(review): as
            GatingUnit is written, 2-D inputs need ``len_sen == d_ff``.
        dim: input/output feature size.
        d_ff: hidden width of each block.
        num_layers: number of residual blocks.
    """

    def __init__(self, num_tokens=None, len_sen=49, dim=512, d_ff=1024, num_layers=6):
        super().__init__()
        self.num_layers = num_layers
        self.embedding = nn.Embedding(num_tokens, dim) if exist(num_tokens) else nn.Identity()

        def make_block(idx):
            # Sub-layer names carry the block index (state_dict keys depend on them).
            layers = OrderedDict()
            layers['ln1_%d' % idx] = nn.LayerNorm(dim)
            layers['fc1_%d' % idx] = nn.Linear(dim, d_ff * 2)
            layers['gelu_%d' % idx] = nn.GELU()
            layers['sgu_%d' % idx] = GatingUnit(d_ff, len_sen)
            layers['fc2_%d' % idx] = nn.Linear(d_ff, dim)
            return Residual(nn.Sequential(layers))

        self.arn = nn.ModuleList([make_block(idx) for idx in range(num_layers)])

    def forward(self, x):
        hidden = self.embedding(x)
        # Apply the registered blocks in order (no throw-away nn.Sequential per call).
        for block in self.arn:
            hidden = block(hidden)
        return hidden

# Circle Loss

In [4]:
class CircleLoss(nn.Module):
    """Circle loss (Sun et al., CVPR 2020) over a matrix of similarity logits.

    Each row of ``logits`` holds one sample's similarities to every class.
    ``sp`` / ``sn`` are boolean masks selecting the positive / negative
    entries. The reshapes below require exactly one positive per row
    (one-hot targets).

    Args:
        m: relaxation margin. Positives are pushed toward ``1 + m`` with
            decision margin ``1 - m``; negatives toward ``-m`` with margin ``m``.
        gamma: scale factor applied to the weighted logits.
    """

    def __init__(self, m: float, gamma: float) -> None:
        super(CircleLoss, self).__init__()
        self.m = m
        self.gamma = gamma
        self.soft_plus = nn.Softplus()

    def forward(self, sp: Tensor, sn: Tensor, logits=None, N=None, C=None) -> Tensor:
        """Return the per-sample circle loss with shape ``(N,)``.

        Args:
            sp: bool mask shaped like ``logits``, True at each row's positive class.
            sn: bool mask shaped like ``logits``, True at each row's negatives.
            logits: similarity matrix of shape ``(N, C)``. Required.
            N: number of rows. Inferred from ``logits`` when None.
            C: number of classes. Inferred from ``logits`` when None.

        Raises:
            ValueError: if ``logits`` is not given.

        Note: earlier versions multiplied the result by an ``(N, 1)`` tensor of
        ones, which broadcast it to an ``(N, N)`` matrix. ``mean`` over either
        form is identical, but ``sum`` over the old form was N times too large.
        """
        if logits is None:
            raise ValueError("CircleLoss.forward requires `logits`")
        if N is None:
            N = logits.shape[0]
        if C is None:
            C = logits.shape[1]

        # Boolean-mask indexing flattens row-major. With one positive per row,
        # the selections reshape cleanly to (N, 1) and (N, C - 1).
        sp_logit = logits[sp].view(N, 1)
        sn_logit = logits[sn].view(N, C - 1)

        margin = self.m

        # Adaptive weights are computed from detached logits, so no gradient
        # flows through the weighting itself.
        ap = torch.clamp_min(-sp_logit.detach() + 1 + margin, min=0.)
        an = torch.clamp_min(sn_logit.detach() + margin, min=0.)

        delta_p = 1 - margin
        delta_n = margin

        logit_p = -ap * (sp_logit - delta_p) * self.gamma
        logit_n = an * (sn_logit - delta_n) * self.gamma

        # softplus(a + b) == log(1 + sum_j exp(logit_n_j) * sum_i exp(logit_p_i))
        return self.soft_plus(torch.logsumexp(logit_n, dim=1) + torch.logsumexp(logit_p, dim=1))

        

# Evaluation Function

In [5]:
def _mean_class_accuracy(pred, targets, classes):
    """Average of per-class top-1 accuracies over ``classes``.

    Args:
        pred: tensor of predicted class ids (flattened before use).
        targets: tensor of ground-truth class ids, same number of elements as ``pred``.
        classes: iterable of class ids to average over.

    Returns:
        ``np.float64`` mean of ``correct_c / count_c`` over ``classes``.

    Raises:
        ValueError: if a class in ``classes`` has no sample in ``targets``.
    """
    # A single device->host transfer instead of one sync per sample.
    pred = pred.reshape(-1).cpu()
    targets = targets.reshape(-1).cpu()
    per_class_acc = []
    for c in classes:
        c = int(c)
        in_class = targets == c
        count = int(in_class.sum())
        if count == 0:
            raise ValueError(f'class {c} has no test samples')
        correct = int((pred[in_class] == c).sum())
        per_class_acc.append(correct / count)
    return np.mean(np.array(per_class_acc))


def _split_accuracy(split_feats, split_labels, classes, class_attrs, reg):
    """Score one test split with the global ``model`` and return its mean per-class accuracy.

    The logits of seen classes (global ``seen_mask``) are multiplied by
    ``reg`` before the argmax over all classes.
    """
    logits = model(split_feats.to(DEVICE), class_attrs)
    logits[:, seen_mask] *= reg
    pred = logits.argmax(dim=1)
    return _mean_class_accuracy(pred, split_labels, classes)


def test(reg=0.95):
    """Evaluate the global ``model`` on the generalized zero-shot test splits.

    Reads these module-level globals: ``model``, ``attrs``, ``DEVICE``,
    ``seen_mask``, ``seen_classes``, ``unseen_classes``,
    ``test_seen_feat``/``test_seen_labels``, and
    ``test_unseen_feat``/``test_unseen_labels``. It leaves ``model`` in eval mode.

    Args:
        reg: factor multiplied into the seen-class logits before the argmax.
            For non-negative logits, values < 1 shift predictions toward
            unseen classes.

    Returns:
        ``(seen_acc, unseen_acc, h_mean_class_wise)``: the mean per-class
        accuracy on test-seen and on test-unseen (predicting over all classes),
        and their harmonic mean. The harmonic mean is 0.0 when both are 0.
    """
    model.eval()
    with torch.no_grad():
        class_attrs = attrs.to(DEVICE)
        seen_acc = _split_accuracy(test_seen_feat, test_seen_labels, seen_classes, class_attrs, reg)
        unseen_acc = _split_accuracy(test_unseen_feat, test_unseen_labels, unseen_classes, class_attrs, reg)

    # Guard the 0/0 case, which previously produced nan.
    denom = seen_acc + unseen_acc
    h_mean_class_wise = (2 * seen_acc * unseen_acc) / denom if denom > 0 else 0.0
    return seen_acc, unseen_acc, h_mean_class_wise

# Initialize the Attribute Refinement Network

In [6]:
# `dim` must equal the class-attribute width that ZSLModel feeds in (85 for
# AWA2). GatingUnit views its d_ff-wide gate as Conv1d channels, so len_sen
# must equal d_ff here.
arn = ARN(len_sen=256, dim=85, d_ff=256, num_layers=2).to(DEVICE)

# Re-seed after building the ARN, so later random draws (ZSLModel init,
# DataLoader shuffling) do not depend on how many numbers its init consumed.
np.random.seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x7feb706e4350>

# Prepare the Dataset

In [7]:
print(f'<=============== Loading data for {DATASET} ===============>')
data = io.loadmat(f'{DATA_DIR}/res101.mat')
attrs_mat = io.loadmat(f'{DATA_DIR}/att_splits.mat')

# Features are stored one column per image; transpose to one row per image.
feats = data['features'].T.astype(np.float32)

# The .mat labels and split locations are 1-based; shift all of them to 0-based.
labels = data['labels'].squeeze() - 1
train_idx, test_seen_idx, test_unseen_idx = (
    attrs_mat[split_key].squeeze() - 1
    for split_key in ('trainval_loc', 'test_seen_loc', 'test_unseen_loc')
)
test_idx = np.array([*test_seen_idx.tolist(), *test_unseen_idx.tolist()])

# Class ids present in each test split, in ascending order.
seen_classes = sorted(np.unique(labels[test_seen_idx]))
unseen_classes = sorted(np.unique(labels[test_unseen_idx]))



In [8]:
print(f'<=============== Preprocessing ===============>')
num_classes = len(seen_classes) + len(unseen_classes)
seen_mask = np.array([c in seen_classes for c in range(num_classes)])
unseen_mask = np.array([c in unseen_classes for c in range(num_classes)])

# Class-attribute matrix, one row per class. Each row is rescaled to L2 norm sqrt(attr_dim).
attrs = torch.from_numpy(attrs_mat['att'].T).to(DEVICE).float()
attrs = attrs / attrs.norm(dim=1, keepdim=True) * np.sqrt(attrs.shape[1])
attrs_seen = attrs[seen_mask]
attrs_unseen = attrs[unseen_mask]

train_labels = labels[train_idx]
# Position of each label within the sorted seen classes (-1 if not seen).
# This matches the row order of attrs[seen_mask].
seen_position = {cls: pos for pos, cls in enumerate(seen_classes)}
labels_remapped_to_seen = [seen_position.get(t, -1) for t in labels]
ds_train = [(feats[i], labels_remapped_to_seen[i]) for i in train_idx]



# Train Dataset

In [9]:
# Split the (feature, remapped label) pairs of ds_train into two parallel arrays.
total_feat = np.array([feat for feat, _ in ds_train])
total_labels = np.array([label for _, label in ds_train])

print("the total_feat shape is...", total_feat.shape)
print("the total_labels shape is...", total_labels.shape)

train_dataset = TensorDataset(
    torch.from_numpy(total_feat),
    torch.from_numpy(total_labels).long(),
)

the total_feat shape is... (23527, 2048)
the total_labels shape is... (23527,)


# Test Dataset

In [10]:
# 0-based indices of the two test splits (same values as in the data-loading cell).
test_seen_idx = attrs_mat['test_seen_loc'].squeeze() - 1
test_unseen_idx = attrs_mat['test_unseen_loc'].squeeze() - 1

# float32 feature rows and int64 global (0-based) class labels for each split.
# NOTE: indexes the global `feats` feature matrix from the data-loading cell.
test_seen_feat, test_seen_labels = (
    torch.from_numpy(feats[test_seen_idx]),
    torch.from_numpy(labels[test_seen_idx]).long(),
)
test_unseen_feat, test_unseen_labels = (
    torch.from_numpy(feats[test_unseen_idx]),
    torch.from_numpy(labels[test_unseen_idx]).long(),
)


In [11]:
class ClassNormalization(nn.Module):
    """Per-feature normalisation of a batch of class embeddings.

    Rows of ``class_feats`` are treated as samples, and each column is
    centred and scaled using statistics taken over the rows.
    - Training mode: uses batch statistics and updates exponential running
      estimates.
    - Eval mode: uses only the running estimates.

    NOTE(review): the scaling divides by ``var + eps``, not ``sqrt(var + eps)``
    as ``nn.BatchNorm1d`` does. This is kept as-is because the notebook's
    reported results were produced with it; confirm before changing.

    Args:
        feat_dim: number of columns of ``class_feats``.
        momentum: weight of the new batch statistic in the running update.
        eps: constant added to the variance before dividing.
    """

    def __init__(self, feat_dim: int, momentum: float = 0.1, eps: float = 1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps

        # Frozen Parameters (requires_grad=False): included in .parameters() and
        # state_dict() but never updated by autograd.
        self.running_mean = nn.Parameter(torch.zeros(feat_dim), requires_grad=False)
        self.running_var = nn.Parameter(torch.ones(feat_dim), requires_grad=False)

    def forward(self, class_feats):
        """Normalise ``class_feats`` column-wise; returns a tensor of the same shape."""
        if self.training:
            batch_mean = class_feats.mean(dim=0)
            batch_var = class_feats.var(dim=0)  # unbiased (N - 1) estimate

            # Normalizing the batch
            result = (class_feats - batch_mean.unsqueeze(0)) / (batch_var.unsqueeze(0) + self.eps)

            # Exponential moving average of the batch statistics.
            keep = 1 - self.momentum
            self.running_mean.data = keep * self.running_mean.data + self.momentum * batch_mean.detach()
            self.running_var.data = keep * self.running_var.data + self.momentum * batch_var.detach()
        else:
            # At test time use only the accumulated statistics: batch-wise
            # statistics would make inference transductive.
            result = (class_feats - self.running_mean.unsqueeze(0)) / (self.running_var.unsqueeze(0) + self.eps)

        return result


# ZSL Model

In [12]:
class ZSLModel(nn.Module):
    """Zero-shot classifier that maps class attributes to visual prototypes
    and scores image features against them by scaled cosine similarity.

    Reads the module-level globals ``arn`` and ``USE_ATTRIBUTE_REFINEMENT``.
    When the flag is set, ``attrs`` is passed through ``arn`` first.
    NOTE(review): ``arn`` is not a registered submodule, so:
    - it is not part of ``self.parameters()`` or ``self.state_dict()``;
    - ``self.train()`` / ``self.eval()`` do not reach it.

    Args:
        attr_dim: size of each class-attribute vector fed to the MLP.
        hid_dim: hidden width of the attribute-to-prototype MLP.
        proto_dim: size of each prototype; must match the feature size of ``x``.
    """

    def __init__(self, attr_dim: int, hid_dim: int, proto_dim: int):
        super().__init__()

        # Attribute to Feature Embedding
        self.model = nn.Sequential(
            nn.Linear(attr_dim, hid_dim),
            nn.ReLU(),

            nn.Linear(hid_dim, hid_dim),
            ClassNormalization(hid_dim),
            nn.ReLU(),

            ClassNormalization(hid_dim),
            nn.Linear(hid_dim, proto_dim),
            nn.ReLU(),
        )

        # Re-initialise the last Linear (model[-2]) from U(-b, b), whose
        # variance b**2 / 3 equals 1 / (hid_dim * proto_dim).
        weight_var = 1 / (hid_dim * proto_dim)
        b = np.sqrt(3 * weight_var)
        self.model[-2].weight.data.uniform_(-b, b)

    def forward(self, x, attrs, train=False):
        """Score features against class prototypes.

        Args:
            x: feature rows, shape (num_samples, proto_dim).
            attrs: class-attribute rows, one per class.
            train: if True, also return the prototypes and the unscaled cosines.

        Returns:
            ``logits`` of shape (num_samples, num_classes), equal to 25 times
            the cosine similarity. With ``train=True``, returns the tuple
            ``(logits, protos, cosine)``.
        """
        attrs = arn(attrs) if USE_ATTRIBUTE_REFINEMENT else attrs
        protos = self.model(attrs)

        # Feature-Prototype Combiner: cosine similarity of every feature row
        # with every prototype row. F.normalize clamps the norm at a tiny eps,
        # so an all-zero row (e.g. a prototype zeroed by the final ReLU) yields
        # 0 instead of NaN.
        x_unit = F.normalize(x, dim=1)
        protos_unit = F.normalize(protos, dim=1)
        logits_nn = x_unit @ protos_unit.t()

        # Equivalent to (5 * x_unit) @ (5 * protos_unit).T
        logits = 25.0 * logits_nn

        if train:
            return logits, protos, logits_nn
        return logits

In [13]:
print(f'\n<=============== Starting training ===============>')
# NOTE: the training cell reports elapsed time from here, so any pause before
# running it is included.
start_time = time()

# Take the feature width from the training tensors rather than the global
# `feats`, which may have been rebound to a mini-batch by a training loop.
model = ZSLModel(attrs.shape[1], 1024, train_dataset.tensors[0].shape[1]).to(DEVICE)

# ZSLModel does not register `arn` as a submodule, so its parameters are added
# explicitly. dict.fromkeys de-duplicates while keeping a deterministic order;
# a set's iteration order depends on object ids and varies between runs.
params = list(dict.fromkeys([*model.parameters(), *arn.parameters()]))

optim = torch.optim.Adam(params, lr=0.0005, weight_decay=0.0001)
# Learning rate x0.1 every 25 scheduler steps (stepped once per epoch in training).
scheduler = torch.optim.lr_scheduler.StepLR(optim, gamma=0.1, step_size=25)





# Training

In [14]:
mse = nn.MSELoss()  # NOTE: not used by the loop below
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
hm_best, seen_best, unseen_best = 0, 0, 0

# Loop-invariant setup, hoisted out of the batch loop: `attrs` is a fixed
# tensor and CircleLoss keeps no per-call state.
attrs_s = attrs[seen_mask]
C = attrs_s.shape[0]  # number of seen classes (was hard-coded as 40 in one_hot)
loss_circle = CircleLoss(m=0.4, gamma=0.5)

for epoch in tqdm(range(50)):
    model.train()

    # `batch_feats`, not `feats`: reusing `feats` would overwrite the global
    # feature matrix that the data cells index into, breaking re-runs.
    for batch_feats, targets in train_dataloader:
        batch_feats = batch_feats.to(DEVICE)
        targets = targets.to(DEVICE)
        N = batch_feats.shape[0]

        logits, _, logits_nn = model(batch_feats, attrs_s, train=True)

        loss = F.cross_entropy(logits, targets)

        if USE_CIRCLE_LOSS:
            # Positive = each row's target among the C seen classes; negatives = the rest.
            sp_bool = F.one_hot(targets, C) == 1
            sn_bool = ~sp_bool
            loss = loss + 0.5 * torch.mean(loss_circle(sp_bool, sn_bool, logits=logits_nn, N=N, C=C))

        optim.zero_grad()
        loss.backward()
        optim.step()

    # NOTE(review): the *_best values keep the epoch with the highest test-split
    # H, i.e. the printed numbers are selected on the test set itself.
    with torch.no_grad():
        test_seen_acc, test_unseen_acc, test_h_mean_class_wise = test()
        if hm_best < test_h_mean_class_wise:
            hm_best, seen_best, unseen_best = test_h_mean_class_wise, test_seen_acc, test_unseen_acc

        print("test result (u) : ", unseen_best)
        print("test result (s) : ", seen_best)
        print("test result (h) : ", hm_best)

    # Advance the StepLR schedule once per epoch.
    scheduler.step()

print(f'Training is done! Took time: {(time() - start_time): .1f} seconds')


  2%|▉                                           | 1/50 [00:03<02:33,  3.13s/it]

test result (u) :  0.4273653498842642
test result (s) :  0.8066752482271984
test result (h) :  0.5587256209061183


  4%|█▊                                          | 2/50 [00:05<02:09,  2.70s/it]

test result (u) :  0.5035173960087662
test result (s) :  0.7954187128560287
test result (h) :  0.6166695287021242


  6%|██▋                                         | 3/50 [00:07<02:00,  2.56s/it]

test result (u) :  0.5440895545643316
test result (s) :  0.796268863083431
test result (h) :  0.646456299038014


  8%|███▌                                        | 4/50 [00:10<01:55,  2.50s/it]

test result (u) :  0.5440895545643316
test result (s) :  0.796268863083431
test result (h) :  0.646456299038014


 10%|████▍                                       | 5/50 [00:12<01:51,  2.47s/it]

test result (u) :  0.5440895545643316
test result (s) :  0.796268863083431
test result (h) :  0.646456299038014


 12%|█████▎                                      | 6/50 [00:15<01:48,  2.47s/it]

test result (u) :  0.5616141956248955
test result (s) :  0.79337735571692
test result (h) :  0.6576749279605846


 14%|██████▏                                     | 7/50 [00:17<01:45,  2.45s/it]

test result (u) :  0.5880174863594708
test result (s) :  0.7812953776627969
test result (h) :  0.6710158885501863


 16%|███████                                     | 8/50 [00:20<01:42,  2.44s/it]

test result (u) :  0.6236923740969492
test result (s) :  0.7878867636397195
test result (h) :  0.6962400520057692


 18%|███████▉                                    | 9/50 [00:22<01:40,  2.44s/it]

test result (u) :  0.6236923740969492
test result (s) :  0.7878867636397195
test result (h) :  0.6962400520057692


 20%|████████▌                                  | 10/50 [00:24<01:37,  2.43s/it]

test result (u) :  0.6236923740969492
test result (s) :  0.7878867636397195
test result (h) :  0.6962400520057692


 22%|█████████▍                                 | 11/50 [00:27<01:35,  2.44s/it]

test result (u) :  0.6236923740969492
test result (s) :  0.7878867636397195
test result (h) :  0.6962400520057692


 24%|██████████▎                                | 12/50 [00:29<01:32,  2.43s/it]

test result (u) :  0.6236923740969492
test result (s) :  0.7878867636397195
test result (h) :  0.6962400520057692


 26%|███████████▏                               | 13/50 [00:32<01:30,  2.45s/it]

test result (u) :  0.6236923740969492
test result (s) :  0.7878867636397195
test result (h) :  0.6962400520057692


 28%|████████████                               | 14/50 [00:34<01:27,  2.44s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 30%|████████████▉                              | 15/50 [00:37<01:25,  2.45s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 32%|█████████████▊                             | 16/50 [00:39<01:23,  2.44s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 34%|██████████████▌                            | 17/50 [00:42<01:20,  2.45s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 36%|███████████████▍                           | 18/50 [00:44<01:18,  2.46s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 38%|████████████████▎                          | 19/50 [00:46<01:15,  2.44s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 40%|█████████████████▏                         | 20/50 [00:49<01:14,  2.48s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 42%|██████████████████                         | 21/50 [00:52<01:14,  2.58s/it]

test result (u) :  0.6365513174336928
test result (s) :  0.78237314079781
test result (h) :  0.701969228327063


 44%|██████████████████▉                        | 22/50 [00:54<01:12,  2.60s/it]

test result (u) :  0.6432803025807803
test result (s) :  0.8003648049762301
test result (h) :  0.7132762910011585


 46%|███████████████████▊                       | 23/50 [00:57<01:09,  2.59s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 48%|████████████████████▋                      | 24/50 [00:59<01:06,  2.54s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 50%|█████████████████████▌                     | 25/50 [01:02<01:03,  2.54s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 52%|██████████████████████▎                    | 26/50 [01:05<01:01,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 54%|███████████████████████▏                   | 27/50 [01:07<00:59,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 56%|████████████████████████                   | 28/50 [01:10<00:56,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 58%|████████████████████████▉                  | 29/50 [01:12<00:53,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 60%|█████████████████████████▊                 | 30/50 [01:15<00:51,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 62%|██████████████████████████▋                | 31/50 [01:18<00:48,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 64%|███████████████████████████▌               | 32/50 [01:20<00:46,  2.61s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 66%|████████████████████████████▍              | 33/50 [01:23<00:44,  2.59s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 68%|█████████████████████████████▏             | 34/50 [01:25<00:41,  2.58s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 70%|██████████████████████████████             | 35/50 [01:28<00:38,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 72%|██████████████████████████████▉            | 36/50 [01:30<00:35,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 74%|███████████████████████████████▊           | 37/50 [01:33<00:33,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 76%|████████████████████████████████▋          | 38/50 [01:36<00:30,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 78%|█████████████████████████████████▌         | 39/50 [01:38<00:28,  2.58s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 80%|██████████████████████████████████▍        | 40/50 [01:41<00:25,  2.58s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 82%|███████████████████████████████████▎       | 41/50 [01:43<00:23,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 84%|████████████████████████████████████       | 42/50 [01:46<00:20,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 86%|████████████████████████████████████▉      | 43/50 [01:48<00:17,  2.57s/it]

test result (u) :  0.6517592557628448
test result (s) :  0.7947091670838946
test result (h) :  0.7161705670243862


 88%|█████████████████████████████████████▊     | 44/50 [01:51<00:15,  2.56s/it]

test result (u) :  0.6522038739881286
test result (s) :  0.794426106839085
test result (h) :  0.71632385799371


 90%|██████████████████████████████████████▋    | 45/50 [01:54<00:12,  2.58s/it]

test result (u) :  0.6541470633487174
test result (s) :  0.792555156847705
test result (h) :  0.7167302588689155


 92%|███████████████████████████████████████▌   | 46/50 [01:56<00:10,  2.57s/it]

test result (u) :  0.6541470633487174
test result (s) :  0.792555156847705
test result (h) :  0.7167302588689155


 94%|████████████████████████████████████████▍  | 47/50 [01:59<00:07,  2.57s/it]

test result (u) :  0.6541470633487174
test result (s) :  0.792555156847705
test result (h) :  0.7167302588689155


 96%|█████████████████████████████████████████▎ | 48/50 [02:01<00:05,  2.57s/it]

test result (u) :  0.6541470633487174
test result (s) :  0.792555156847705
test result (h) :  0.7167302588689155


 98%|██████████████████████████████████████████▏| 49/50 [02:04<00:02,  2.57s/it]

test result (u) :  0.6541470633487174
test result (s) :  0.792555156847705
test result (h) :  0.7167302588689155


100%|███████████████████████████████████████████| 50/50 [02:06<00:00,  2.54s/it]

test result (u) :  0.6541470633487174
test result (s) :  0.792555156847705
test result (h) :  0.7167302588689155
Training is done! Took time:  126.9 seconds



