In [220]:
import pickle
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
import loss
import inlp
import scipy
from scipy import linalg
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
with open("data.sample.25k.bert-large.pickle", "rb") as f:
    data = pickle.load(f)

In [216]:
class Dataset(torch.utils.data.Dataset):
    
    """Simple torch dataset class"""
    def __init__(self, data, device):

        self.data = data
        self.device = device

    def __len__(self):

        return len(self.data)

    def __getitem__(self, index):

        with torch.no_grad():
             
            vec1_np, vec2_np, str1, str2, _ = self.data[index]
            
            vec1, vec2, str1, str2, pair_id = self.data[index]            
            vec1, vec2 = torch.from_numpy(vec1_np).float(), torch.from_numpy(vec2_np).float()
            
            vec1 = vec1.to(self.device)
            vec2 = vec2.to(self.device)
            
            return (vec1, vec2, str1, str2, pair_id)
        
        
def get_nullspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix
    """
    nullspace_basis = scipy.linalg.null_space(W)  # orthogonal basis

    nullspace_basis = nullspace_basis * np.sign(nullspace_basis[0][0])  # handle sign ambiguity
    projection_matrix = nullspace_basis.dot(nullspace_basis.T)

    return projection_matrix
    
    
def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    :param W: the matrix over its nullspace to project
    :return: the projection matrix
    """

    w_basis = scipy.linalg.orth(W.T) # orthogonal basis
    w_basis * np.sign(w_basis[0][0]) # handle sign ambiguity
    P_W = w_basis.dot(w_basis.T) # orthogonal projection on W's rowspace
    
    return P_W

In [217]:
class Siamese(pl.LightningModule):

    def __init__(self, X_train, X_dev, dim, batch_size, dropout_rate, device):
        super(Siamese, self).__init__()
        self.l = torch.nn.Linear(1024, dim)
        
        self.train_data = Dataset(X_train, device)
        self.dev_data = Dataset(X_dev, device)
        self.train_gen = torch.utils.data.DataLoader(self.train_data, batch_size = batch_size, drop_last = False, shuffle=True)
        self.dev_gen = torch.utils.data.DataLoader(self.dev_data, batch_size = batch_size, drop_last = False, shuffle=True)
        self.loss_fn = loss.BatchHardTripletLoss(final = "softmax", device = device)
        self.dropout = torch.nn.Dropout(p = dropout_rate)
        self.acc = None
        self.optimizer = torch.optim.Adam(self.parameters(), weight_decay = 1e-6)
        
    def forward(self, x1, x2):

          h1 = self.l(x1)
          h2 = self.l(x2)
         
          return h1, h2
 
    def train_network(self, num_epochs):
    
      trainer = Trainer(max_nb_epochs = num_epochs, min_nb_epochs = num_epochs, show_progress_bar = True)
      trainer.fit(self)

      return self.acc   
      
    def get_weights(self):
    
        return self.l.weight.data.numpy()
    
    def training_step(self, batch, batch_nb):
        # REQUIRED
        x1, x2, str1, str2, ids = batch

        h1, h2 = self.forward(self.dropout(x1), self.dropout(x2))
        loss_val =  self.loss_fn(h1, h2, str1, str2, ids, index=0, evaluation = False)
        
        return {'loss': loss_val[0]}
        

    def validation_step(self, batch, batch_nb):
    
        # OPTIONAL
        x1, x2, str1, str2, ids = batch
        h1, h2 = self.forward(x1, x2)
        loss_val =  self.loss_fn(h1, h2, str1, str2, ids, index=batch_nb, evaluation = True)
        return {'val_loss': loss_val[0]}

    def validation_end(self, outputs):
        # OPTIONAL    
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        print("Loss is {}".format(avg_loss))
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        return torch.optim.Adam(self.parameters(), weight_decay = 1e-4)

    @pl.data_loader
    def train_dataloader(self):
        return self.train_gen

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # can also return a list of val dataloaders
        return self.dev_gen

In [218]:
device = "cuda"
net = Siamese(data[:20000], data[20000:], batch_size = 1024, dim = 256,dropout_rate = 0.1, device = device).to(device)
net.train_network(num_epochs = 1)

INFO:root:      Name                  Type Params
0        l                Linear  262 K
1  loss_fn  BatchHardTripletLoss    0  
2  dropout               Dropout    0  
Validation sanity check:  80%|████████  | 4/5 [00:00<00:00,  5.10batch/s]


[A[A[A                            
[A                                  

Epoch 1:   0%|          | 0/25 [00:00<?, ?batch/s]                          

Loss is 0.3180398643016815


Epoch 1:  80%|████████  | 20/25 [00:03<00:00,  6.00batch/s, batch_nb=19, loss=0.276, v_nb=41]
Validating:   0%|          | 0/5 [00:00<?, ?batch/s][A
Epoch 1:  84%|████████▍ | 21/25 [00:03<00:00,  6.75batch/s, batch_nb=19, loss=0.276, v_nb=41]
Epoch 1:  88%|████████▊ | 22/25 [00:03<00:00,  6.57batch/s, batch_nb=19, loss=0.276, v_nb=41]
Epoch 1:  92%|█████████▏| 23/25 [00:03<00:00,  6.44batch/s, batch_nb=19, loss=0.276, v_nb=41]
Epoch 1:  96%|█████████▌| 24/25 [00:04<00:00,  6.37batch/s, batch_nb=19, loss=0.276, v_nb=41]
Epoch 1: 100%|██████████| 25/25 [00:04<00:00,  7.00batch/s, batch_nb=19, loss=0.276, v_nb=41]

[A[A
                                                            [A
[A
Epoch 1: 100%|██████████| 25/25 [00:04<00:00,  6.04batch/s, batch_nb=19, loss=0.276, v_nb=41]

Loss is 0.21725492179393768





In [183]:
W = net.l.weight.detach().cpu().numpy()
P_Rw = get_rowspace_projection(W)
I = np.eye(P_Rw.shape[0])
P_Nw = I - P_Rw
P_Nw2 = get_nullspace_projection(W)
vecs = np.array([d[0] for d in data[20000:]])
strings = [d[2] for d in data[20000:]]
vecs_transformed = W.dot(vecs.T).T
vecs_transformed_nullspace = P_Nw2.dot(vecs.T).T
print(P_Nw2-P_Nw)

[[ 1.00582838e-06 -5.45755029e-07  1.57393515e-07 ... -1.42492354e-07
   3.40864062e-07 -1.39698386e-08]
 [-5.45755029e-07 -3.72529030e-08  7.79982656e-07 ...  3.34344804e-07
   7.39237294e-08 -9.26665962e-08]
 [ 1.57393515e-07  7.79982656e-07  2.08616257e-07 ...  3.58559191e-08
  -4.84287739e-08  2.05589458e-07]
 ...
 [-1.42492354e-07  3.34344804e-07  3.58559191e-08 ... -8.94069672e-08
   1.87195837e-07 -6.42612576e-08]
 [ 3.40864062e-07  7.39237294e-08 -4.84287739e-08 ...  1.87195837e-07
   1.78813934e-07  8.42846930e-08]
 [-1.39698386e-08 -9.26665962e-08  2.05589458e-07 ... -6.42612576e-08
   8.42846930e-08 -1.78813934e-07]]


In [184]:
W_r_basis = scipy.linalg.orth(W.T)
print(W.shape, W_r_basis.shape)

(900, 1024) (1024, 900)


In [185]:

def to_string(np_array):
        return "\t".join(["%0.4f" % float(x) for x in np_array])
    
with open("vecs.transformed.nullspace.tsv", "w") as f:
    for v in vecs_transformed_nullspace:
        f.write(to_string(v) + "\n")

with open("labels.transformed.nullspace.tsv", "w") as f:
    for s in strings:
        f.write(s + "\n")
        
with open("vecs.transformed.tsv", "w") as f:
    for v in vecs_transformed:
        f.write(to_string(v) + "\n")

with open("labels.transformed.tsv", "w") as f:
    for s in strings:
        f.write(s + "\n")

In [186]:
np.linalg.matrix_rank(W)
print(np.linalg.matrix_rank(vecs))
print(np.linalg.matrix_rank(vecs_transformed_nullspace))
print(np.linalg.matrix_rank(P_Rw.dot(vecs.T).T))

1023
124
900


In [None]:
P, rowspace_projections, Ws = inlp.get_debiasing_projection(6, 1024, is_autoregressive = False, X_train = data, X_dev = data, dropout_rate = 0.05, device = "cuda")





  0%|          | 0/6 [00:00<?, ?it/s][A[A[A[AINFO:root:      Name                  Type Params
0        l                Linear  102 K
1  loss_fn  BatchHardTripletLoss    0  
2  dropout               Dropout    0  


Loss is 0.28423216938972473
Loss is 0.22954709827899933
Loss is 0.20683249831199646
Loss is 0.18630759418010712
Loss is 0.17049802839756012
Loss is 0.16003009676933289
Loss is 0.15152859687805176
Loss is 0.14518693089485168
Loss is 0.14100699126720428
Loss is 0.13558314740657806
Loss is 0.13317079842090607
Loss is 0.12940062582492828
Loss is 0.12703098356723785
Loss is 0.12439966201782227
Loss is 0.12251582741737366
Loss is 0.12137825042009354
Loss is 0.11912421882152557
Loss is 0.11834836006164551
Loss is 0.1167420968413353
Loss is 0.1148914247751236






iteration: 0, accuracy: None:   0%|          | 0/6 [02:06<?, ?it/s][A[A[A[A



iteration: 0, accuracy: None:  17%|█▋        | 1/6 [02:06<10:32, 126.59s/it][A[A[A[AINFO:root:      Name                  Type Params
0        l                Linear  102 K
1  loss_fn  BatchHardTripletLoss    0  
2  dropout               Dropout    0  


Loss is 0.113502137362957
Loss is 0.2785147726535797
Loss is 0.22862054407596588
Loss is 0.20555132627487183
Loss is 0.1855544000864029
Loss is 0.1697065681219101
Loss is 0.1592424511909485
Loss is 0.15148137509822845
Loss is 0.1455279439687729
Loss is 0.1396777629852295
Loss is 0.13603438436985016
Loss is 0.1318574994802475
Loss is 0.12972413003444672
Loss is 0.127724289894104
Loss is 0.12453426420688629
Loss is 0.12258563935756683
Loss is 0.12045826017856598
Loss is 0.11883234977722168
Loss is 0.11782550811767578
Loss is 0.1165626123547554
Loss is 0.11476627737283707






iteration: 1, accuracy: None:  17%|█▋        | 1/6 [04:16<10:32, 126.59s/it][A[A[A[A



iteration: 1, accuracy: None:  33%|███▎      | 2/6 [04:16<08:30, 127.65s/it][A[A[A[AINFO:root:      Name                  Type Params
0        l                Linear  102 K
1  loss_fn  BatchHardTripletLoss    0  
2  dropout               Dropout    0  


Loss is 0.11292089521884918
Loss is 0.27768227458000183
Loss is 0.22878669202327728
Loss is 0.20566412806510925
Loss is 0.1848350316286087
Loss is 0.17024558782577515
Loss is 0.15844789147377014
Loss is 0.15101775527000427
Loss is 0.14527741074562073
Loss is 0.13967938721179962
Loss is 0.13666410744190216
Loss is 0.13273440301418304
Loss is 0.12939713895320892
Loss is 0.12783721089363098


In [197]:
P, _, _ = P

In [199]:
vecs_transformed_nullspace = P.dot(vecs.T).T

In [200]:
with open("vecs.transformed.nullspace.tsv", "w") as f:
    for v in vecs_transformed_nullspace:
        f.write(to_string(v) + "\n")

with open("labels.transformed.nullspace.tsv", "w") as f:
    for s in strings:
        f.write(s + "\n")
        
with open("vecs.transformed.tsv", "w") as f:
    for v in vecs_transformed:
        f.write(to_string(v) + "\n")

with open("labels.transformed.tsv", "w") as f:
    for s in strings:
        f.write(s + "\n")