In [111]:
import pickle
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import loss
import scipy
from scipy import linalg
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
with open("data.sample.25k.bert-large.pickle", "rb") as f:
    data = pickle.load(f)

In [127]:
class Dataset(torch.utils.data.Dataset):
    
    """Simple torch dataset class"""
    def __init__(self, data):

        self.data = data

    def __len__(self):

        return len(self.data)

    def __getitem__(self, index):

        with torch.no_grad():
             
            vec1_np, vec2_np, str1, str2, _ = self.data[index]
            
            vec1, vec2, str1, str2, pair_id = self.data[index]            
            vec1, vec2 = torch.from_numpy(vec1_np).float(), torch.from_numpy(vec2_np).float()
            return (vec1, vec2, str1, str2, pair_id)
        
        
def get_nullspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix
    """
    nullspace_basis = scipy.linalg.null_space(W)  # orthogonal basis

    nullspace_basis = nullspace_basis * np.sign(nullspace_basis[0][0])  # handle sign ambiguity
    projection_matrix = nullspace_basis.dot(nullspace_basis.T)

    return projection_matrix
    
    
def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix
    """

    w_basis = scipy.linalg.orth(W.T) # orthogonal basis
    w_basis * np.sign(w_basis[0][0]) # handle sign ambiguity
    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace
    
    return P_W

In [74]:
class Siamese(pl.LightningModule):

    def __init__(self, X_train, X_dev, dim, batch_size):
        super(Siamese, self).__init__()
        self.l = torch.nn.Linear(1024, dim)
        
        self.train_data = Dataset(X_train)
        self.dev_data = Dataset(X_dev)
        self.train_gen = torch.utils.data.DataLoader(self.train_data, batch_size = batch_size, drop_last = False, shuffle=True)
        self.dev_gen = torch.utils.data.DataLoader(self.dev_data, batch_size = batch_size, drop_last = False, shuffle=True)
        self.loss_fn = loss.BatchHardTripletLoss(final = "softmax")
        
        self.acc = None
        self.optimizer = torch.optim.Adam(self.parameters(), weight_decay = 1e-6)
        
    def forward(self, x1, x2):

          h1 = self.l(x1)
          h2 = self.l(x2)
         
          return h1, h2
 
    def train_network(self, num_epochs):
    
      trainer = Trainer(max_nb_epochs = num_epochs, min_nb_epochs = num_epochs, show_progress_bar = True)
      trainer.fit(self)

      return self.acc   
      
    def get_weights(self):
    
        return self.l.weight.data.numpy()
    
    def training_step(self, batch, batch_nb):
        # REQUIRED
        x1, x2, str1, str2, ids = batch
        h1, h2 = self.forward(x1, x2)
        loss_val =  self.loss_fn(h1, h2, str1, str2, ids, index=0, evaluation = False)
        
        return {'loss': loss_val[0]}
        

    def validation_step(self, batch, batch_nb):
    
        # OPTIONAL
        x1, x2, str1, str2, ids = batch
        h1, h2 = self.forward(x1, x2)
        loss_val =  self.loss_fn(h1, h2, str1, str2, ids, index=batch_nb, evaluation = True)
        return {'val_loss': loss_val[0]}

    def validation_end(self, outputs):
        # OPTIONAL    
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        print("Loss is {}".format(avg_loss))
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        return torch.optim.Adam(self.parameters(), weight_decay = 1e-4)

    @pl.data_loader
    def train_dataloader(self):
        return self.train_gen

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # can also return a list of val dataloaders
        return self.dev_gen

In [82]:
net = Siamese(data[:20000], data[20000:], batch_size = 1024, dim = 256)
net.train_network(num_epochs = 15)

INFO:root:      Name                  Type Params
0        l                Linear  262 K
1  loss_fn  BatchHardTripletLoss    0  
Epoch 1:   0%|          | 0/25 [00:00<?, ?batch/s]                       

Loss is 0.3076070249080658


Epoch 1:  80%|████████  | 20/25 [00:04<00:01,  4.80batch/s, batch_nb=19, loss=0.236, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 1:  84%|████████▍ | 21/25 [00:04<00:00,  5.70batch/s, batch_nb=19, loss=0.236, v_nb=24]
Epoch 1:  88%|████████▊ | 22/25 [00:04<00:00,  6.21batch/s, batch_nb=19, loss=0.236, v_nb=24]
Epoch 1:  92%|█████████▏| 23/25 [00:04<00:00,  6.63batch/s, batch_nb=19, loss=0.236, v_nb=24]
Epoch 1:  96%|█████████▌| 24/25 [00:04<00:00,  7.00batch/s, batch_nb=19, loss=0.236, v_nb=24]
Epoch 1: 100%|██████████| 25/25 [00:04<00:00,  7.58batch/s, batch_nb=19, loss=0.236, v_nb=24]
Epoch 2:   0%|          | 0/25 [00:00<00:03,  7.58batch/s, batch_nb=19, loss=0.236, v_nb=24] 

Loss is 0.21157684922218323


Epoch 2:  80%|████████  | 20/25 [00:04<00:01,  4.88batch/s, batch_nb=19, loss=0.215, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 2:  84%|████████▍ | 21/25 [00:04<00:00,  5.63batch/s, batch_nb=19, loss=0.215, v_nb=24]
Epoch 2:  88%|████████▊ | 22/25 [00:04<00:00,  6.18batch/s, batch_nb=19, loss=0.215, v_nb=24]
Epoch 2:  92%|█████████▏| 23/25 [00:04<00:00,  6.22batch/s, batch_nb=19, loss=0.215, v_nb=24]
Epoch 2: 100%|██████████| 25/25 [00:04<00:00,  6.27batch/s, batch_nb=19, loss=0.215, v_nb=24]
Epoch 3:   0%|          | 0/25 [00:00<00:03,  6.27batch/s, batch_nb=19, loss=0.215, v_nb=24] 

Loss is 0.19035322964191437


Epoch 3:  80%|████████  | 20/25 [00:04<00:01,  4.88batch/s, batch_nb=19, loss=0.201, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 3:  84%|████████▍ | 21/25 [00:04<00:00,  5.61batch/s, batch_nb=19, loss=0.201, v_nb=24]
Epoch 3:  88%|████████▊ | 22/25 [00:04<00:00,  6.23batch/s, batch_nb=19, loss=0.201, v_nb=24]
Epoch 3:  92%|█████████▏| 23/25 [00:04<00:00,  6.27batch/s, batch_nb=19, loss=0.201, v_nb=24]
Epoch 3: 100%|██████████| 25/25 [00:04<00:00,  6.78batch/s, batch_nb=19, loss=0.201, v_nb=24]
Epoch 4:   0%|          | 0/25 [00:00<00:03,  6.78batch/s, batch_nb=19, loss=0.201, v_nb=24] 

Loss is 0.17495617270469666


Epoch 4:  80%|████████  | 20/25 [00:04<00:01,  4.73batch/s, batch_nb=19, loss=0.189, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 4:  84%|████████▍ | 21/25 [00:04<00:00,  5.44batch/s, batch_nb=19, loss=0.189, v_nb=24]
Epoch 4:  88%|████████▊ | 22/25 [00:04<00:00,  6.03batch/s, batch_nb=19, loss=0.189, v_nb=24]
Epoch 4:  92%|█████████▏| 23/25 [00:04<00:00,  6.13batch/s, batch_nb=19, loss=0.189, v_nb=24]
Epoch 4: 100%|██████████| 25/25 [00:04<00:00,  6.52batch/s, batch_nb=19, loss=0.189, v_nb=24]
Epoch 5:   0%|          | 0/25 [00:00<00:03,  6.52batch/s, batch_nb=19, loss=0.189, v_nb=24] 

Loss is 0.16759291291236877


Epoch 5:  80%|████████  | 20/25 [00:04<00:00,  5.20batch/s, batch_nb=19, loss=0.180, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 5:  84%|████████▍ | 21/25 [00:04<00:00,  5.93batch/s, batch_nb=19, loss=0.180, v_nb=24]
Epoch 5:  88%|████████▊ | 22/25 [00:04<00:00,  6.49batch/s, batch_nb=19, loss=0.180, v_nb=24]
Epoch 5:  92%|█████████▏| 23/25 [00:04<00:00,  6.43batch/s, batch_nb=19, loss=0.180, v_nb=24]
Epoch 5: 100%|██████████| 25/25 [00:04<00:00,  6.40batch/s, batch_nb=19, loss=0.180, v_nb=24]
Epoch 6:   0%|          | 0/25 [00:00<00:03,  6.40batch/s, batch_nb=19, loss=0.180, v_nb=24] 

Loss is 0.16256919503211975


Epoch 6:  80%|████████  | 20/25 [00:04<00:01,  4.52batch/s, batch_nb=19, loss=0.159, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 6:  84%|████████▍ | 21/25 [00:04<00:00,  5.32batch/s, batch_nb=19, loss=0.159, v_nb=24]
Epoch 6:  88%|████████▊ | 22/25 [00:04<00:00,  5.87batch/s, batch_nb=19, loss=0.159, v_nb=24]
Epoch 6:  92%|█████████▏| 23/25 [00:04<00:00,  5.90batch/s, batch_nb=19, loss=0.159, v_nb=24]
Epoch 6: 100%|██████████| 25/25 [00:04<00:00,  5.85batch/s, batch_nb=19, loss=0.159, v_nb=24]
Epoch 7:   0%|          | 0/25 [00:00<00:04,  5.85batch/s, batch_nb=19, loss=0.159, v_nb=24] 

Loss is 0.1599402278661728


Epoch 7:  80%|████████  | 20/25 [00:04<00:01,  4.66batch/s, batch_nb=19, loss=0.146, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 7:  84%|████████▍ | 21/25 [00:04<00:00,  5.58batch/s, batch_nb=19, loss=0.146, v_nb=24]
Epoch 7:  88%|████████▊ | 22/25 [00:04<00:00,  5.74batch/s, batch_nb=19, loss=0.146, v_nb=24]
Epoch 7:  92%|█████████▏| 23/25 [00:04<00:00,  5.87batch/s, batch_nb=19, loss=0.146, v_nb=24]
Epoch 7: 100%|██████████| 25/25 [00:04<00:00,  6.44batch/s, batch_nb=19, loss=0.146, v_nb=24]
Epoch 8:   0%|          | 0/25 [00:00<00:03,  6.44batch/s, batch_nb=19, loss=0.146, v_nb=24] 

Loss is 0.15479394793510437


Epoch 8:  80%|████████  | 20/25 [00:03<00:01,  4.96batch/s, batch_nb=19, loss=0.136, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 8:  84%|████████▍ | 21/25 [00:04<00:00,  5.70batch/s, batch_nb=19, loss=0.136, v_nb=24]
Epoch 8:  88%|████████▊ | 22/25 [00:04<00:00,  5.88batch/s, batch_nb=19, loss=0.136, v_nb=24]
Epoch 8:  92%|█████████▏| 23/25 [00:04<00:00,  5.99batch/s, batch_nb=19, loss=0.136, v_nb=24]
Epoch 8:  96%|█████████▌| 24/25 [00:04<00:00,  6.02batch/s, batch_nb=19, loss=0.136, v_nb=24]
Epoch 8: 100%|██████████| 25/25 [00:04<00:00,  6.82batch/s, batch_nb=19, loss=0.136, v_nb=24]
Epoch 9:   0%|          | 0/25 [00:00<00:03,  6.82batch/s, batch_nb=19, loss=0.136, v_nb=24] 

Loss is 0.1537805050611496


Epoch 9:  80%|████████  | 20/25 [00:04<00:01,  4.82batch/s, batch_nb=19, loss=0.129, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 9:  84%|████████▍ | 21/25 [00:04<00:00,  5.45batch/s, batch_nb=19, loss=0.129, v_nb=24]
Epoch 9:  88%|████████▊ | 22/25 [00:04<00:00,  5.85batch/s, batch_nb=19, loss=0.129, v_nb=24]
Epoch 9:  92%|█████████▏| 23/25 [00:04<00:00,  5.88batch/s, batch_nb=19, loss=0.129, v_nb=24]
Epoch 9: 100%|██████████| 25/25 [00:04<00:00,  6.39batch/s, batch_nb=19, loss=0.129, v_nb=24]
Epoch 10:   0%|          | 0/25 [00:00<00:03,  6.39batch/s, batch_nb=19, loss=0.129, v_nb=24]

Loss is 0.15229171514511108


Epoch 10:  80%|████████  | 20/25 [00:04<00:01,  4.45batch/s, batch_nb=19, loss=0.123, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 10:  84%|████████▍ | 21/25 [00:04<00:00,  5.36batch/s, batch_nb=19, loss=0.123, v_nb=24]
Epoch 10:  88%|████████▊ | 22/25 [00:04<00:00,  5.92batch/s, batch_nb=19, loss=0.123, v_nb=24]
Epoch 10:  92%|█████████▏| 23/25 [00:04<00:00,  6.00batch/s, batch_nb=19, loss=0.123, v_nb=24]
Epoch 10:  96%|█████████▌| 24/25 [00:04<00:00,  6.37batch/s, batch_nb=19, loss=0.123, v_nb=24]
Epoch 10: 100%|██████████| 25/25 [00:04<00:00,  7.07batch/s, batch_nb=19, loss=0.123, v_nb=24]
Epoch 11:   0%|          | 0/25 [00:00<00:03,  7.07batch/s, batch_nb=19, loss=0.123, v_nb=24] 

Loss is 0.1506134569644928


Epoch 11:  80%|████████  | 20/25 [00:04<00:01,  4.56batch/s, batch_nb=19, loss=0.118, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 11:  84%|████████▍ | 21/25 [00:04<00:00,  5.36batch/s, batch_nb=19, loss=0.118, v_nb=24]
Epoch 11:  88%|████████▊ | 22/25 [00:04<00:00,  5.65batch/s, batch_nb=19, loss=0.118, v_nb=24]
Epoch 11:  92%|█████████▏| 23/25 [00:04<00:00,  5.87batch/s, batch_nb=19, loss=0.118, v_nb=24]
Epoch 11: 100%|██████████| 25/25 [00:04<00:00,  6.37batch/s, batch_nb=19, loss=0.118, v_nb=24]
Epoch 12:   0%|          | 0/25 [00:00<00:03,  6.37batch/s, batch_nb=19, loss=0.118, v_nb=24] 

Loss is 0.15026356279850006


Epoch 12:  80%|████████  | 20/25 [00:04<00:01,  4.61batch/s, batch_nb=19, loss=0.114, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 12:  84%|████████▍ | 21/25 [00:04<00:00,  5.44batch/s, batch_nb=19, loss=0.114, v_nb=24]
Epoch 12:  88%|████████▊ | 22/25 [00:04<00:00,  6.10batch/s, batch_nb=19, loss=0.114, v_nb=24]
Epoch 12:  92%|█████████▏| 23/25 [00:04<00:00,  6.19batch/s, batch_nb=19, loss=0.114, v_nb=24]
Epoch 12: 100%|██████████| 25/25 [00:04<00:00,  6.71batch/s, batch_nb=19, loss=0.114, v_nb=24]
Epoch 13:   0%|          | 0/25 [00:00<00:03,  6.71batch/s, batch_nb=19, loss=0.114, v_nb=24] 

Loss is 0.1508251428604126


Epoch 13:  80%|████████  | 20/25 [00:04<00:01,  4.41batch/s, batch_nb=19, loss=0.110, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 13:  84%|████████▍ | 21/25 [00:04<00:00,  5.27batch/s, batch_nb=19, loss=0.110, v_nb=24]
Epoch 13:  88%|████████▊ | 22/25 [00:04<00:00,  5.89batch/s, batch_nb=19, loss=0.110, v_nb=24]
Epoch 13:  92%|█████████▏| 23/25 [00:04<00:00,  5.93batch/s, batch_nb=19, loss=0.110, v_nb=24]
Epoch 13: 100%|██████████| 25/25 [00:04<00:00,  5.97batch/s, batch_nb=19, loss=0.110, v_nb=24]
Epoch 14:   0%|          | 0/25 [00:00<00:04,  5.97batch/s, batch_nb=19, loss=0.110, v_nb=24] 

Loss is 0.1509326994419098


Epoch 14:  80%|████████  | 20/25 [00:04<00:01,  4.60batch/s, batch_nb=19, loss=0.107, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 14:  84%|████████▍ | 21/25 [00:04<00:00,  5.38batch/s, batch_nb=19, loss=0.107, v_nb=24]
Epoch 14:  88%|████████▊ | 22/25 [00:04<00:00,  6.08batch/s, batch_nb=19, loss=0.107, v_nb=24]
Epoch 14:  92%|█████████▏| 23/25 [00:04<00:00,  6.21batch/s, batch_nb=19, loss=0.107, v_nb=24]
Epoch 14: 100%|██████████| 25/25 [00:04<00:00,  6.70batch/s, batch_nb=19, loss=0.107, v_nb=24]
Epoch 15:   0%|          | 0/25 [00:00<00:03,  6.70batch/s, batch_nb=19, loss=0.107, v_nb=24] 

Loss is 0.15037532150745392


Epoch 15:  80%|████████  | 20/25 [00:04<00:01,  4.55batch/s, batch_nb=19, loss=0.104, v_nb=24]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 15:  84%|████████▍ | 21/25 [00:04<00:00,  5.49batch/s, batch_nb=19, loss=0.104, v_nb=24]
Epoch 15:  88%|████████▊ | 22/25 [00:04<00:00,  5.74batch/s, batch_nb=19, loss=0.104, v_nb=24]
Epoch 15:  92%|█████████▏| 23/25 [00:04<00:00,  6.33batch/s, batch_nb=19, loss=0.104, v_nb=24]
Epoch 15: 100%|██████████| 25/25 [00:04<00:00,  6.36batch/s, batch_nb=19, loss=0.104, v_nb=24]
Epoch 15: 100%|██████████| 25/25 [00:04<00:00,  5.11batch/s, batch_nb=19, loss=0.104, v_nb=24]

Loss is 0.15040801465511322





In [175]:
W = net.l.weight.detach().cpu().numpy()
P_Rw = get_rowspace_projection(W)
I = np.eye(P_Rw.shape[0])
P_Nw = I - P_Rw
P_Nw2 = get_nullspace_projection(W)
vecs = np.array([d[0] for d in data[20000:]])
strings = [d[2] for d in data[20000:]]
vecs_transformed = W.dot(vecs.T).T
vecs_transformed_nullspace = P_Nw2.dot(vecs.T).T
print(P_Nw2-P_Nw)

[[-1.49011612e-07  3.25962901e-09 -6.42612576e-08 ... -9.48784873e-09
  -1.86264515e-09 -4.59840521e-09]
 [ 3.25962901e-09 -5.96046448e-08  4.65661287e-08 ... -2.04890966e-08
  -1.07102096e-08 -2.09547579e-08]
 [-6.42612576e-08  4.65661287e-08  8.94069672e-08 ...  2.60770321e-08
  -2.79396772e-09 -1.25728548e-08]
 ...
 [-9.48784873e-09 -2.04890966e-08  2.60770321e-08 ... -1.19209290e-07
   9.31322575e-10 -2.98023224e-08]
 [-1.86264515e-09 -1.07102096e-08 -2.79396772e-09 ...  9.31322575e-10
   1.49011612e-07  2.79396772e-09]
 [-4.59840521e-09 -2.09547579e-08 -1.25728548e-08 ... -2.98023224e-08
   2.79396772e-09  0.00000000e+00]]


In [173]:
W_r_basis = scipy.linalg.orth(W.T)
print(W.shape, W_r_basis.shape)

(256, 1024) (1024, 256)


In [176]:

def to_string(np_array):
        return "\t".join(["%0.4f" % float(x) for x in np_array])
    
with open("vecs.transformed.nullspace.tsv", "w") as f:
    for v in vecs_transformed_nullspace:
        f.write(to_string(v) + "\n")

with open("labels.transformed.nullspace.tsv", "w") as f:
    for s in strings:
        f.write(s + "\n")
        
with open("vecs.transformed.tsv", "w") as f:
    for v in vecs_transformed:
        f.write(to_string(v) + "\n")

with open("labels.transformed.tsv", "w") as f:
    for s in strings:
        f.write(s + "\n")

In [119]:
np.linalg.matrix_rank(W)
print(np.linalg.matrix_rank(vecs))
print(np.linalg.matrix_rank(vecs_transformed_nullspace))
print(np.linalg.matrix_rank(P_Rw.dot(vecs.T).T))

1023
768
256


In [97]:
I.shape

(1024, 1024)

array([[ 0.2558343 ,  0.00257222, -0.00626908, ...,  0.00080806,
        -0.00819083,  0.00097386],
       [ 0.00257222,  0.25096208,  0.00496867, ...,  0.01734442,
        -0.00780143,  0.00727562],
       [-0.00626908,  0.00496867,  0.25201002, ..., -0.01009724,
         0.00422214, -0.00510023],
       ...,
       [ 0.00080806,  0.01734442, -0.01009724, ...,  0.26606488,
         0.01120324, -0.02924984],
       [-0.00819083, -0.00780143,  0.00422214, ...,  0.01120324,
         0.2691891 ,  0.00959322],
       [ 0.00097386,  0.00727562, -0.00510023, ..., -0.02924984,
         0.00959322,  0.2728414 ]], dtype=float32)

In [105]:
P_Rw.dot(P_Rw) - P_Rw
P_Nw.dot(P_Nw) - P_Nw

array([[-5.96046448e-08,  3.21306288e-08, -3.21306288e-08, ...,
         7.39237294e-09, -4.65661287e-09,  4.13274392e-09],
       [ 3.21306288e-08, -2.98023224e-08,  3.49245965e-08, ...,
        -2.79396772e-08,  0.00000000e+00,  1.25728548e-08],
       [-3.21306288e-08,  3.49245965e-08, -5.96046448e-08, ...,
         1.21071935e-08,  3.72529030e-09, -5.58793545e-09],
       ...,
       [ 7.39237294e-09, -2.79396772e-08,  1.21071935e-08, ...,
        -2.98023224e-08,  2.79396772e-09, -2.23517418e-08],
       [-4.65661287e-09,  0.00000000e+00,  3.72529030e-09, ...,
         2.79396772e-09,  0.00000000e+00,  9.31322575e-10],
       [ 4.13274392e-09,  1.25728548e-08, -5.58793545e-09, ...,
        -2.23517418e-08,  9.31322575e-10,  1.19209290e-07]], dtype=float32)

In [162]:
w = np.random.rand(10,100) - 0.5
p1 = get_nullspace_projection(w)
p2 = np.eye(100) - get_rowspace_projection(w)

In [163]:
p1.shape, p2.shape

((100, 100), (100, 100))

In [164]:
p1

array([[ 0.82945399,  0.0125107 , -0.0786709 , ..., -0.05030625,
         0.02868598,  0.01301795],
       [ 0.0125107 ,  0.89848579,  0.01230274, ...,  0.00552447,
        -0.01376449,  0.02839011],
       [-0.0786709 ,  0.01230274,  0.94018492, ..., -0.02453937,
        -0.0141308 ,  0.02401205],
       ...,
       [-0.05030625,  0.00552447, -0.02453937, ...,  0.87408413,
        -0.00841796, -0.00286106],
       [ 0.02868598, -0.01376449, -0.0141308 , ..., -0.00841796,
         0.86127804, -0.00816   ],
       [ 0.01301795,  0.02839011,  0.02401205, ..., -0.00286106,
        -0.00816   ,  0.94351214]])

In [165]:
p2

array([[ 0.82945399,  0.0125107 , -0.0786709 , ..., -0.05030625,
         0.02868598,  0.01301795],
       [ 0.0125107 ,  0.89848579,  0.01230274, ...,  0.00552447,
        -0.01376449,  0.02839011],
       [-0.0786709 ,  0.01230274,  0.94018492, ..., -0.02453937,
        -0.0141308 ,  0.02401205],
       ...,
       [-0.05030625,  0.00552447, -0.02453937, ...,  0.87408413,
        -0.00841796, -0.00286106],
       [ 0.02868598, -0.01376449, -0.0141308 , ..., -0.00841796,
         0.86127804, -0.00816   ],
       [ 0.01301795,  0.02839011,  0.02401205, ..., -0.00286106,
        -0.00816   ,  0.94351214]])

In [168]:
p1 = get_nullspace_projection(W)
p2 = np.eye(1024) - get_rowspace_projection(W)

In [170]:
p1

array([[ 0.74416554, -0.00257222,  0.00626902, ..., -0.00080807,
         0.00819083, -0.00097387],
       [-0.00257222,  0.74903786, -0.00496862, ..., -0.01734444,
         0.00780141, -0.00727564],
       [ 0.00626902, -0.00496862,  0.7479901 , ...,  0.01009727,
        -0.00422214,  0.00510022],
       ...,
       [-0.00080807, -0.01734444,  0.01009727, ...,  0.733935  ,
        -0.01120324,  0.02924981],
       [ 0.00819083,  0.00780141, -0.00422214, ..., -0.01120324,
         0.73081106, -0.00959321],
       [-0.00097387, -0.00727564,  0.00510022, ...,  0.02924981,
        -0.00959321,  0.7271586 ]], dtype=float32)

In [171]:
p2

array([[ 0.74416569, -0.00257222,  0.00626908, ..., -0.00080806,
         0.00819083, -0.00097386],
       [-0.00257222,  0.74903792, -0.00496867, ..., -0.01734442,
         0.00780143, -0.00727562],
       [ 0.00626908, -0.00496867,  0.74798998, ...,  0.01009724,
        -0.00422214,  0.00510023],
       ...,
       [-0.00080806, -0.01734442,  0.01009724, ...,  0.73393512,
        -0.01120324,  0.02924984],
       [ 0.00819083,  0.00780143, -0.00422214, ..., -0.01120324,
         0.73081091, -0.00959322],
       [-0.00097386, -0.00727562,  0.00510023, ...,  0.02924984,
        -0.00959322,  0.72715861]])