# Group Study Season 1-2: Learning Graph Embedding for Compositional Zero-shot Learning

#### The materaials for Group Study in Perception Lab Durham University
#### This is the simplified code for paper 'Learning Graph Embedding for Compositional Zero-shot Learning'. In this season, we will see what a graph looks like and how it pass through GNN.










> Today's takeaway (3/06/2022)


*   A view of a simple graph
*   Basic GCN theory
*   Compositional zero-shot in GNN

> Source code: https://github.com/ExplainableML/czsl

####Prerequisites: Please download ***'utzappos-graph.t7'*** from Teams -> Research Repositories -> file/Documents/tools/coding, and drag it into the files manager on the left.


In [9]:
# Torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
import torch.optim as optim
cudnn.benchmark = True

# Python imports
import tqdm
from tqdm import tqdm
import os
from os.path import join as ospj
import csv
import numpy as np
import scipy.sparse as sp

## Train
### Data preparation


We define that there are 16 attritutes and 12 objects in the data, compisiting totally 116 potential pairs, 83 of which are available in the training session.

In [2]:
# random generate 1024 image features and labels
img_feat = torch.rand(1024, 300)
attr = torch.randint(0,16,(1024,))
obj = torch.randint(0,12,(1024,))
pair = torch.randint(0,83,(1024,))
print('total attribute numbers:', 16)
print('seen attribute numbers:', len(set(attr.numpy())))
print('\ntotal object numbers:', 12)
print('seen object numbers:', len(set(obj.numpy())))
print('\ntotal pairs:', 116)
print('seen object numbers:', len(set(pair.numpy())))

total attribute numbers: 16
seen attribute numbers: 16

total object numbers: 12
seen object numbers: 12

total pairs: 116
seen object numbers: 83


In [3]:
# load graph
graph = torch.load('/content/utzappos-graph.t7')
embeddings = graph['embeddings']  # (144, 300)
adj = graph['adj']  # (144, 144)
print(type(adj))
print(type(embeddings))
print(adj)

<class 'scipy.sparse.coo.coo_matrix'>
<class 'torch.Tensor'>
  (0, 0)	1.0
  (0, 16)	1.0
  (0, 18)	1.0
  (0, 19)	1.0
  (0, 20)	1.0
  (0, 21)	1.0
  (0, 22)	1.0
  (0, 24)	1.0
  (0, 25)	1.0
  (0, 26)	1.0
  (1, 1)	1.0
  (1, 18)	1.0
  (1, 19)	1.0
  (1, 22)	1.0
  (1, 23)	1.0
  (1, 24)	1.0
  (1, 26)	1.0
  (2, 2)	1.0
  (2, 16)	1.0
  (2, 18)	1.0
  (2, 27)	1.0
  (3, 3)	1.0
  (3, 16)	1.0
  (3, 17)	1.0
  (3, 18)	1.0
  :	:
  (24, 137)	1.0
  (138, 14)	1.0
  (138, 25)	1.0
  (14, 138)	1.0
  (25, 138)	1.0
  (139, 14)	1.0
  (139, 26)	1.0
  (14, 139)	1.0
  (26, 139)	1.0
  (140, 14)	1.0
  (140, 27)	1.0
  (14, 140)	1.0
  (27, 140)	1.0
  (141, 15)	1.0
  (141, 21)	1.0
  (15, 141)	1.0
  (21, 141)	1.0
  (142, 15)	1.0
  (142, 26)	1.0
  (15, 142)	1.0
  (26, 142)	1.0
  (143, 15)	1.0
  (143, 27)	1.0
  (15, 143)	1.0
  (27, 143)	1.0


transfer sparse matrix to tenaor

In [4]:
def normt_spm(mx, method='in'):
    if method == 'in':
        mx = mx.transpose()
        rowsum = np.array(mx.sum(1))
        r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.
        r_mat_inv = sp.diags(r_inv)
        mx = r_mat_inv.dot(mx)
        return mx

In [5]:
def spm_to_tensor(sparse_mx):
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack(
            (sparse_mx.row, sparse_mx.col))).long()
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)

In [6]:
class GraphConv(nn.Module):

    def __init__(self, in_channels, out_channels, dropout=True, relu=True):
        super().__init__()

        self.dropout = nn.Dropout(p=0.5)
        self.layer = nn.Linear(in_channels, out_channels)
        self.relu = nn.ReLU()

    def forward(self, inputs, adj):
        if self.dropout is not None:
            inputs = self.dropout(inputs)

        outputs = torch.mm(adj, torch.mm(inputs, self.layer.weight.T)) + self.layer.bias  # △

        if self.relu is not None:
            outputs = self.relu(outputs)
        return outputs

Define GCN model. Two gnn layers.

In [7]:
class GCN(nn.Module):

    def __init__(self, adj):
        super().__init__()

        adj = normt_spm(adj, method='in')
        adj = spm_to_tensor(adj)
        self.train_adj = adj

        layers = []

        conv = GraphConv(300, 4096)
        self.add_module('conv0', conv)
        layers.append(conv)

        conv = GraphConv(4096, 300, relu=False)
        self.add_module('conv-last', conv)
        layers.append(conv)

        self.layers = layers

    def forward(self, x):
        for conv in self.layers:
            x = conv(x, self.train_adj)

        return F.normalize(x)

In [10]:
gcn = GCN(adj)
optimizer = optim.Adam(gcn.parameters(), lr=0.0001, weight_decay=0.0001)

In [11]:
training = True
current_embeddings = gcn(embeddings)  # (144, 300)
print(current_embeddings.shape)

torch.Size([144, 300])


the index of 83 seen pairs in 144 attr + obj + pairs 

In [12]:
train_idx = [ 28,  29,  30,  32,  35,  36,  37,  40,  41,  42,  43,  45,  46,  49,
        50,  51,  52,  53,  55,  56,  57,  58,  60,  61,  63,  64,  65,  66,
        67,  69,  70,  71,  73,  76,  78,  79,  72,  75,  77,  80,  82,  83,
        86,  87,  88,  89,  91,  93,  94,  95,  96,  98,  99, 101, 103, 104,
        105, 107, 108, 110, 111, 113, 114, 116, 117, 119, 121, 122, 124, 125,
        126, 118, 128, 131, 134, 135, 136, 137, 138, 139, 140, 141, 143]
print("number of seen pairs: ", len(train_idx))

number of seen pairs:  83


pick out the seen embeddings

do dot product with image feature

calculate the loss

In [13]:
pair_embed = current_embeddings[train_idx]  # (83,300)
pair_embed = pair_embed.permute(1,0)  # (300, 83)
pair_pred = torch.matmul(img_feat, pair_embed)  # (batch_size, 300) * (300, 83) -> (batch_size, 83)
loss = F.cross_entropy(pair_pred, pair)
print("loss = ", loss.data)

loss =  tensor(4.4417)


In [14]:
optimizer.zero_grad()
loss.backward()
optimizer.step()

## Test
### data preparation




In [15]:
test_feat = torch.rand(32, 300)

In [16]:
training = False
gcn.eval()
with torch.no_grad():
  current_embeddings = gcn(embeddings)
  pair_embeds = current_embeddings[28:144,:].permute(1,0)  # (300, 116)
  score = torch.matmul(img_feat, pair_embeds)

In [17]:
print(np.argmax(score, axis=1))

tensor([ 95,  89, 101,  ...,  30, 101, 101])
