<a href="https://colab.research.google.com/github/universome/class-norm-for-czsl/blob/master/class-norm-for-czsl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### 1. Defining the hyperparams

In [25]:
DATASET = 'AWA2' # One of ["AWA1", "AWA2", "APY", "CUB", "SUN"]
USE_CLASS_STANDARTIZATION = True # i.e. equation (9) from the paper
USE_PROPER_INIT = True # i.e. equation (10) from the paper

#### 2. Downloading GBU data from [the official GBU website](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/zero-shot-learning/zero-shot-learning-the-good-the-bad-and-the-ugly) (takes 1-2 minutes for the first time)

In [2]:
%%bash
if [ -d "./data" ] 
then
    echo "Files are already there."
else
    wget -q "http://datasets.d2.mpi-inf.mpg.de/xian/xlsa17.zip"
    unzip -q xlsa17.zip -d ./data
fi

#### 3. Running the code

In [20]:
DATASET = 'AWA2' # One of ["AWA1", "AWA2", "APY", "CUB", "SUN"]
USE_CLASS_STANDARTIZATION = True # i.e. equation (9) from the paper
USE_PROPER_INIT = True # i.e. equation (10) from the paper

import numpy as np; np.random.seed(1)
import torch; torch.manual_seed(1)
import torch.nn as nn
import torch.nn.functional as F
from time import time
from tqdm import tqdm
from scipy import io
from torch.utils.data import DataLoader


print(f'<=============== Loading data for {DATASET} ===============>')
DEVICE = 'cuda:1' # Set to 'cpu' if a GPU is not available
DATA_DIR = f'./IAB-GZSL/data/{DATASET}'
data = io.loadmat(f'{DATA_DIR}/vitL14.mat')
attrs_mat = io.loadmat(f'{DATA_DIR}/att_splits.mat')
feats = data['features'].T.astype(np.float32)
labels = data['labels'].squeeze() - 1 # Using "-1" here and for idx to normalize to 0-index
train_idx = attrs_mat['trainval_loc'].squeeze() - 1
test_seen_idx = attrs_mat['test_seen_loc'].squeeze() - 1
test_unseen_idx = attrs_mat['test_unseen_loc'].squeeze() - 1
test_idx = np.array(test_seen_idx.tolist() + test_unseen_idx.tolist())
seen_classes = sorted(np.unique(labels[test_seen_idx]))
unseen_classes = sorted(np.unique(labels[test_unseen_idx]))


print(f'<=============== Preprocessing ===============>')
num_classes = len(seen_classes) + len(unseen_classes)
seen_mask = np.array([(c in seen_classes) for c in range(num_classes)])
unseen_mask = np.array([(c in unseen_classes) for c in range(num_classes)])
attrs = attrs_mat['att'].T
attrs = torch.from_numpy(attrs).to(DEVICE).float()
attrs = attrs / attrs.norm(dim=1, keepdim=True) * np.sqrt(attrs.shape[1])
attrs_seen = attrs[seen_mask]
attrs_unseen = attrs[unseen_mask]
train_labels = labels[train_idx]
test_labels = labels[test_idx]
test_seen_idx = [i for i, y in enumerate(test_labels) if y in seen_classes]
test_unseen_idx = [i for i, y in enumerate(test_labels) if y in unseen_classes]
labels_remapped_to_seen = [(seen_classes.index(t) if t in seen_classes else -1) for t in labels]
test_labels_remapped_seen = [(seen_classes.index(t) if t in seen_classes else -1) for t in test_labels]
test_labels_remapped_unseen = [(unseen_classes.index(t) if t in unseen_classes else -1) for t in test_labels]
ds_train = [(feats[i], labels_remapped_to_seen[i]) for i in train_idx]
ds_test = [(feats[i], int(labels[i])) for i in test_idx]
train_dataloader = DataLoader(ds_train, batch_size=256, shuffle=True)
test_dataloader = DataLoader(ds_test, batch_size=2048)

class_indices_inside_test = {c: [i for i in range(len(test_idx)) if labels[test_idx[i]] == c] for c in range(num_classes)}


class ClassStandardization(nn.Module):
    """
    Class Standardization procedure from the paper.
    Conceptually, it is equivalent to nn.BatchNorm1d with affine=False,
    but for some reason nn.BatchNorm1d performs slightly worse.
    """
    def __init__(self, feat_dim: int):
        super().__init__()
        
        self.running_mean = nn.Parameter(torch.zeros(feat_dim), requires_grad=False)
        self.running_var = nn.Parameter(torch.ones(feat_dim), requires_grad=False)
    
    def forward(self, class_feats):
        """
        Input: class_feats of shape [num_classes, feat_dim]
        Output: class_feats (standardized) of shape [num_classes, feat_dim]
        """
        if self.training:
            batch_mean = class_feats.mean(dim=0)
            batch_var = class_feats.var(dim=0)
            
            # Normalizing the batch
            result = (class_feats - batch_mean.unsqueeze(0)) / (batch_var.unsqueeze(0) + 1e-5)
            
            # Updating the running mean/std
            self.running_mean.data = 0.9 * self.running_mean.data + 0.1 * batch_mean.detach()
            self.running_var.data = 0.9 * self.running_var.data + 0.1 * batch_var.detach()
        else:
            # Using accumulated statistics
            # Attention! For the test inference, we cant use batch-wise statistics,
            # only the accumulated ones. Otherwise, it will be quite transductive
            result = (class_feats - self.running_mean.unsqueeze(0)) / (self.running_var.unsqueeze(0) + 1e-5)
        
        return result


class CNZSLModel(nn.Module):
    def __init__(self, attr_dim: int, hid_dim: int, proto_dim: int):
        super().__init__()
        
        self.model = nn.Sequential(
            nn.Linear(attr_dim, hid_dim),
            nn.ReLU(),
            
            nn.Linear(hid_dim, hid_dim),
            ClassStandardization(hid_dim) if USE_CLASS_STANDARTIZATION else nn.Identity(),
            nn.ReLU(),
            
            ClassStandardization(hid_dim) if USE_CLASS_STANDARTIZATION else nn.Identity(),
            nn.Linear(hid_dim, proto_dim),
            nn.ReLU(),
        )
        
        if USE_PROPER_INIT:
            weight_var = 1 / (hid_dim * proto_dim)
            b = np.sqrt(3 * weight_var)
            self.model[-2].weight.data.uniform_(-b, b)
        
    def forward(self, x, attrs):
        protos = self.model(attrs)
        x_ns = 5 * x / x.norm(dim=1, keepdim=True) # [batch_size, x_dim]
        protos_ns = 5 * protos / protos.norm(dim=1, keepdim=True) # [num_classes, x_dim]
        logits = x_ns @ protos_ns.t() # [batch_size, num_classes]
        
        return logits
    

print(f'\n<=============== Starting training ===============>')
start_time = time()
model = CNZSLModel(attrs.shape[1],1024, feats.shape[1]).to(DEVICE)
optim = torch.optim.Adam(model.model.parameters(), lr=0.0005, weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optim, gamma=0.1, step_size=25)


for epoch in tqdm(range(100)):
    model.train()
    
    for i, batch in enumerate(train_dataloader):
        feats = torch.from_numpy(np.array(batch[0])).to(DEVICE)
        targets = torch.from_numpy(np.array(batch[1])).to(DEVICE)
        logits = model(feats, attrs[seen_mask])
        loss = F.cross_entropy(logits, targets)
        optim.zero_grad()
        loss.backward()
        optim.step()
    
    scheduler.step()

print(f'Training is done! Took time: {(time() - start_time): .1f} seconds')

model.eval() # Important! Otherwise we would use unseen batch statistics
logits = [model(x.to(DEVICE), attrs).cpu() for x, _ in test_dataloader]
logits = torch.cat(logits, dim=0)
logits[:, seen_mask] *= (0.95 if DATASET != "CUB" else 1.0) # Trading a bit of gzsl-s for a bit of gzsl-u
preds_gzsl = logits.argmax(dim=1).numpy()
preds_zsl_s = logits[:, seen_mask].argmax(dim=1).numpy()
preds_zsl_u = logits[:, ~seen_mask].argmax(dim=1).numpy()
guessed_zsl_u = (preds_zsl_u == test_labels_remapped_unseen)
guessed_gzsl = (preds_gzsl == test_labels)
zsl_unseen_acc = np.mean([guessed_zsl_u[cls_idx].mean().item() for cls_idx in [class_indices_inside_test[c] for c in unseen_classes]]) 
gzsl_seen_acc = np.mean([guessed_gzsl[cls_idx].mean().item() for cls_idx in [class_indices_inside_test[c] for c in seen_classes]])
gzsl_unseen_acc = np.mean([guessed_gzsl[cls_idx].mean().item() for cls_idx in [class_indices_inside_test[c] for c in unseen_classes]])
gzsl_harmonic = 2 * (gzsl_seen_acc * gzsl_unseen_acc) / (gzsl_seen_acc + gzsl_unseen_acc)

print(f'ZSL-U: {zsl_unseen_acc * 100:.02f}')
# print(f'GZSL-U: {gzsl_unseen_acc * 100:.02f}')
# print(f'GZSL-S: {gzsl_seen_acc * 100:.02f}')
# print(f'GZSL-H: {gzsl_harmonic * 100:.02f}')

# Best Result on hidden dim=1024, and 100 epochs acc= 32.76 and lr=0.0005 for vitL14 features
# Best Result on hidden dim=1024, and 100 epochs acc= 17.09 and lr=0.0005 for vitG14 features
# Best Result on hidden dim=512, and 100 epochs acc= 28.50 and lr=0.0005  for vitL14 features




  feats = torch.from_numpy(np.array(batch[0])).to(DEVICE)
  targets = torch.from_numpy(np.array(batch[1])).to(DEVICE)
100%|██████████| 100/100 [01:03<00:00,  1.58it/s]

Training is done! Took time:  63.3 seconds
ZSL-U: 32.76





# Trying Different Normalizations

In [28]:
DATASET = 'AWA2' # One of ["AWA1", "AWA2", "APY", "CUB", "SUN"]
USE_CLASS_STANDARTIZATION = True # i.e. equation (9) from the paper
USE_PROPER_INIT = True # i.e. equation (10) from the paper

import numpy as np; np.random.seed(1)
import torch; torch.manual_seed(1)
import torch.nn as nn
import torch.nn.functional as F
from time import time
from tqdm import tqdm
from scipy import io
from torch.utils.data import DataLoader


print(f'<=============== Loading data for {DATASET} ===============>')
DEVICE = 'cuda:1' # Set to 'cpu' if a GPU is not available
DATA_DIR = f'./IAB-GZSL/data/{DATASET}'
data = io.loadmat(f'{DATA_DIR}/vitL14.mat')
attrs_mat = io.loadmat(f'{DATA_DIR}/att_splits.mat')
feats = data['features'].T.astype(np.float32)
labels = data['labels'].squeeze() - 1 # Using "-1" here and for idx to normalize to 0-index
train_idx = attrs_mat['trainval_loc'].squeeze() - 1
test_seen_idx = attrs_mat['test_seen_loc'].squeeze() - 1
test_unseen_idx = attrs_mat['test_unseen_loc'].squeeze() - 1
test_idx = np.array(test_seen_idx.tolist() + test_unseen_idx.tolist())
seen_classes = sorted(np.unique(labels[test_seen_idx]))
unseen_classes = sorted(np.unique(labels[test_unseen_idx]))


print(f'<=============== Preprocessing ===============>')
num_classes = len(seen_classes) + len(unseen_classes)
seen_mask = np.array([(c in seen_classes) for c in range(num_classes)])
unseen_mask = np.array([(c in unseen_classes) for c in range(num_classes)])
attrs = attrs_mat['att'].T
attrs = torch.from_numpy(attrs).to(DEVICE).float()
attrs = attrs / attrs.norm(dim=1, keepdim=True) * np.sqrt(attrs.shape[1])
attrs_seen = attrs[seen_mask]
attrs_unseen = attrs[unseen_mask]
train_labels = labels[train_idx]
test_labels = labels[test_idx]
test_seen_idx = [i for i, y in enumerate(test_labels) if y in seen_classes]
test_unseen_idx = [i for i, y in enumerate(test_labels) if y in unseen_classes]
labels_remapped_to_seen = [(seen_classes.index(t) if t in seen_classes else -1) for t in labels]
test_labels_remapped_seen = [(seen_classes.index(t) if t in seen_classes else -1) for t in test_labels]
test_labels_remapped_unseen = [(unseen_classes.index(t) if t in unseen_classes else -1) for t in test_labels]
ds_train = [(feats[i], labels_remapped_to_seen[i]) for i in train_idx]
ds_test = [(feats[i], int(labels[i])) for i in test_idx]
train_dataloader = DataLoader(ds_train, batch_size=256, shuffle=True)
test_dataloader = DataLoader(ds_test, batch_size=2048)

class_indices_inside_test = {c: [i for i in range(len(test_idx)) if labels[test_idx[i]] == c] for c in range(num_classes)}


class ClassStandardization(nn.Module):
    """
    Class Standardization procedure from the paper.
    Conceptually, it is equivalent to nn.BatchNorm1d with affine=False,
    but for some reason nn.BatchNorm1d performs slightly worse.
    """
    def __init__(self, feat_dim: int):
        super().__init__()
        
        self.running_mean = nn.Parameter(torch.zeros(feat_dim), requires_grad=False)
        self.running_var = nn.Parameter(torch.ones(feat_dim), requires_grad=False)
    
    def forward(self, class_feats):
        """
        Input: class_feats of shape [num_classes, feat_dim]
        Output: class_feats (standardized) of shape [num_classes, feat_dim]
        """
        if self.training:
            batch_mean = class_feats.mean(dim=0)
            batch_var = class_feats.var(dim=0)
            
            # Normalizing the batch
            result = (class_feats - batch_mean.unsqueeze(0)) / (batch_var.unsqueeze(0) + 1e-5)
            
            # Updating the running mean/std
            self.running_mean.data = 0.9 * self.running_mean.data + 0.1 * batch_mean.detach()
            self.running_var.data = 0.9 * self.running_var.data + 0.1 * batch_var.detach()
        else:
            # Using accumulated statistics
            # Attention! For the test inference, we cant use batch-wise statistics,
            # only the accumulated ones. Otherwise, it will be quite transductive
            result = (class_feats - self.running_mean.unsqueeze(0)) / (self.running_var.unsqueeze(0) + 1e-5)
        
        return result


class CNZSLModel(nn.Module):
    def __init__(self, attr_dim: int, hid_dim: int, proto_dim: int):
        super().__init__()
        self.hid_dim = hid_dim
        self.model = nn.Sequential(
            nn.Linear(attr_dim, hid_dim),
            nn.ReLU(),
            nn.GroupNorm(8, hid_dim),
            
            nn.Linear(hid_dim, hid_dim),
            ClassStandardization(hid_dim) if USE_CLASS_STANDARTIZATION else nn.Identity(),
            nn.ReLU(),
            nn.GroupNorm(8, hid_dim),
            
            ClassStandardization(hid_dim) if USE_CLASS_STANDARTIZATION else nn.Identity(),
            nn.Linear(hid_dim, proto_dim),
            nn.ReLU(),
        )
        
        if USE_PROPER_INIT:
            weight_var = 1 / (hid_dim * proto_dim)
            b = np.sqrt(3 * weight_var)
            self.model[-2].weight.data.uniform_(-b, b)
        # Apply spectral normalization
        self.model[0] = torch.nn.utils.spectral_norm(self.model[0])
        self.model[3] = torch.nn.utils.spectral_norm(self.model[3])
        
    def forward(self, x, attrs):
        protos = self.model(attrs)
        protos = nn.LayerNorm(self.hid_dim).to(DEVICE)(protos)
        x = nn.InstanceNorm1d(1, affine=False).to(DEVICE)(x.unsqueeze(1)).squeeze(1)
        x_ns = 5 * x / x.norm(dim=1, keepdim=True) # [batch_size, x_dim]
        protos_ns = 5 * protos / protos.norm(dim=1, keepdim=True) # [num_classes, x_dim]
        logits = x_ns @ protos_ns.t() # [batch_size, num_classes]
        
        return logits
    

print(f'\n<=============== Starting training ===============>')
start_time = time()
model = CNZSLModel(attrs.shape[1],1024, feats.shape[1]).to(DEVICE)
optim = torch.optim.AdamW(model.model.parameters(), lr=0.0005, weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optim, gamma=0.1, step_size=25)


for epoch in tqdm(range(101)):
    model.train()
    
    for i, batch in enumerate(train_dataloader):
        feats = torch.from_numpy(np.array(batch[0])).to(DEVICE)
        targets = torch.from_numpy(np.array(batch[1])).to(DEVICE)
        logits = model(feats, attrs[seen_mask])
        loss = F.cross_entropy(logits, targets)
        optim.zero_grad()
        loss.backward()
        optim.step()
    
    scheduler.step()

print(f'Training is done! Took time: {(time() - start_time): .1f} seconds')
print(f"Epoch {epoch} Results")
model.eval() # Important! Otherwise we would use unseen batch statistics
logits = [model(x.to(DEVICE), attrs).cpu() for x, _ in test_dataloader]
logits = torch.cat(logits, dim=0)
logits[:, seen_mask] *= (0.95 if DATASET != "CUB" else 1.0) # Trading a bit of gzsl-s for a bit of gzsl-u
preds_gzsl = logits.argmax(dim=1).numpy()
preds_zsl_s = logits[:, seen_mask].argmax(dim=1).numpy()
preds_zsl_u = logits[:, ~seen_mask].argmax(dim=1).numpy()
guessed_zsl_u = (preds_zsl_u == test_labels_remapped_unseen)
guessed_gzsl = (preds_gzsl == test_labels)
zsl_unseen_acc = np.mean([guessed_zsl_u[cls_idx].mean().item() for cls_idx in [class_indices_inside_test[c] for c in unseen_classes]]) 
gzsl_seen_acc = np.mean([guessed_gzsl[cls_idx].mean().item() for cls_idx in [class_indices_inside_test[c] for c in seen_classes]])
gzsl_unseen_acc = np.mean([guessed_gzsl[cls_idx].mean().item() for cls_idx in [class_indices_inside_test[c] for c in unseen_classes]])
gzsl_harmonic = 2 * (gzsl_seen_acc * gzsl_unseen_acc) / (gzsl_seen_acc + gzsl_unseen_acc)
print(f'ZSL-U: {zsl_unseen_acc * 100:.02f}')


# Best Result on hidden dim=1024, and 100 epochs acc= 32.76 and lr=0.0005 for vitL14 features
# Best Result on hidden dim=1024, and 100 epochs acc= 17.09 and lr=0.0005 for vitG14 features
# Best Result on hidden dim=512, and 100 epochs acc= 28.50 and lr=0.0005  for vitL14 features




  feats = torch.from_numpy(np.array(batch[0])).to(DEVICE)
  targets = torch.from_numpy(np.array(batch[1])).to(DEVICE)
100%|██████████| 101/101 [01:18<00:00,  1.29it/s]

Training is done! Took time:  78.2 seconds
Epoch 100 Results
ZSL-U: 31.78



