Trying a causual intervention strategy to determine how much the representations built for CCG supertagging are used in language modeling

In [1]:
# constants

model_path = "../CCGMultitask/models/augment/augment_.50_0_sgd_continue"
cuda = True

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import sys
import pickle
import tqdm

sys.path.insert(0, "../CCGMultitask/")
from model import MultiTaskModel
from train_augment import evaluate_lm, evaluate_ccg
from train_joint import evaluate_joint
from data import joint_tag_lm, augment_tag_lm, AugmentDataset, BatchSampler

In [6]:
with open(model_path + ".w2idx", "rb") as w2idx_f:
    w2idx = pickle.load(w2idx_f)
    
vocab = {i:w for (w, i) in w2idx.items()}
    
with open(model_path + ".c2idx", "rb") as c2idx_f:
    c2idx = pickle.load(c2idx_f)
    
categories = {i:c for (c, i) in c2idx.items()}
    
model = MultiTaskModel(len(w2idx.keys()), 650, 650, [len(w2idx.keys()), len(c2idx.keys())], 2)
model.load_state_dict(torch.load(model_path + ".pt", map_location = torch.device("cpu")))

model = model.eval()

if cuda: model.cuda()

In [7]:
test_data = joint_tag_lm("../CCGMultitask/data/ccg_supertags/ccg.23.common", 
                         "../CCGMultitask/data/ccg_supertags/categories", 35, w2idx=w2idx)
test_sampler = BatchSampler(test_data, 10)
test_loader = DataLoader(test_data, batch_sampler=test_sampler)

In [5]:
input = torch.tensor([w2idx[w.lower()] for w in "<eos> The key to the cabinet".split()])

if cuda: input = input.cuda()
    
hidden = model.init_hidden(1)

out, hidden = model.lstm(input.view(-1, 1), hidden)


In [6]:
state = out.squeeze()[-1]

# Generate outputs

logits = [decoder(state) for decoder in model.decoders]

lm_logits, ccg_logits = logits

lm_top5 = lm_logits.squeeze().topk(5)
for idx, p in zip(lm_top5.indices, lm_top5.values):
    print("{}:\t{}".format(vocab[idx.item()], p))
    
print("---")
ccg_top5 = ccg_logits.squeeze().topk(5)
for idx, p in zip(ccg_top5.indices, ccg_top5.values):
    print("{}:\t{}".format(categories[idx.item()], p))

's:	-1.6922204494476318
is:	-1.7930352687835693
was:	-2.329763174057007
of:	-2.6500298976898193
,:	-3.3813416957855225
---
N:	-0.01167889591306448
N/N:	-4.456457138061523
N/S[em]:	-11.797123908996582
(S[dcl]\NP)/(S[b]\NP):	-18.4254207611084
NP\NP:	-18.518659591674805


Ablate individual "neurons" in the final layer. (finding highly localized syntactic neurons used for one task but not the other).

In [22]:
lm_losses = {}
ccg_losses = {}

loss_f = nn.NLLLoss()

hidden = model.init_hidden(10)
with torch.no_grad():
    for (input, target_lm, target_ccg) in tqdm.tqdm(test_loader):
        if cuda:
            input = input.cuda()
            target_lm = target_lm.cuda()
            target_ccg = target_ccg.cuda()
                
        input = input.transpose(0,1).contiguous()
        target_lm = target_lm.transpose(0,1).contiguous()
        target_ccg = target_ccg.transpose(0,1).contiguous()
        
        state, hidden = model.lstm(input, hidden)
        for unit_idx in range(state.shape[-1]):
            state_ = state.clone().detach()
            state_[:, :, unit_idx] = 0 #ablate
            
            lm_probs, ccg_probs = [decoder(state_) for decoder in model.decoders]
            
            lm_losses[unit_idx] = lm_losses.get(unit_idx, 0) + loss_f(lm_probs, target_lm.view(-1)).item()
            ccg_losses[unit_idx] = ccg_losses.get(unit_idx, 0) + loss_f(ccg_probs, target_ccg.view(-1)).item()
        
        # unablated
        lm_probs, ccg_probs = [decoder(state) for decoder in model.decoders]
            
        lm_losses[-1] = lm_losses.get(-1, 0) + loss_f(lm_probs, target_lm.view(-1)).item()
        ccg_losses[-1] = ccg_losses.get(-1, 0) + loss_f(ccg_probs, target_ccg.view(-1)).item()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 165/165 [2:07:37<00:00, 46.41s/it]


In [23]:
with open("./ablated_rep_lm.loss", "wb") as lm_loss_f:
    pickle.dump(lm_losses, lm_loss_f)
    
with open("./ablated_rep_ccg.loss", "wb") as ccg_loss_f:
    pickle.dump(ccg_losses, ccg_loss_f)
    
print(len(lm_losses.items()))

loss_by_ablation = {i:(lm_losses[i]/len(test_loader), ccg_losses[i]/len(test_loader)) for i in range(len(lm_losses.items()))}

with open("./ablated_rep.loss", "wb") as loss_f:
    pickle.dump(loss_by_ablation, loss_f)

651


KeyError: 650

In [8]:
with open("./ablated_rep_lm.loss", "rb") as lm_loss_f:
    lm_losses = pickle.load(lm_loss_f)
    
with open("./ablated_rep_ccg.loss", "rb") as ccg_loss_f:
    ccg_losses = pickle.load(ccg_loss_f)

loss_by_ablation = {(i-1):(lm_losses[i-1]/len(test_loader), ccg_losses[i-1]/len(test_loader)) for i in range(len(lm_losses.items()))}

print(loss_by_ablation)

csv_losses = []

for i, (lm_loss, ccg_loss) in loss_by_ablation.items():
    csv_losses.append({"ablated_dim": i,
                       "lm_loss": lm_loss,
                       "ccg_loss": ccg_loss})
    
import csv

with open("ablated_rep.csv", "w") as csv_out_f:
    writer = csv.DictWriter(csv_out_f, fieldnames = csv_losses[0].keys())
    writer.writeheader()
    writer.writerows(csv_losses)

{-1: (4.920966253858624, 1.2028521180152894), 0: (4.92141254309452, 1.2025727788607279), 1: (4.920021029674646, 1.1998506853074737), 2: (4.920931123964714, 1.2011630336443584), 3: (4.920268097790805, 1.2018451770146688), 4: (4.921880266883156, 1.200346716967496), 5: (4.921704280737675, 1.2013453443845112), 6: (4.921463424509222, 1.2015910350915158), 7: (4.921231868050315, 1.2020735614227527), 8: (4.920827848261053, 1.1993766112761064), 9: (4.922034499139497, 1.2021556124542698), 10: (4.920778635776404, 1.201355913552371), 11: (4.921153397993608, 1.2022711865829698), 12: (4.922376184752493, 1.2024511969450749), 13: (4.921757426406398, 1.2023827509446579), 14: (4.921260865529378, 1.2019453207651773), 15: (4.922083253571482, 1.2025107661883037), 16: (4.9258659955227015, 1.2001202016165762), 17: (4.921839412053426, 1.2033789587743355), 18: (4.921325328133323, 1.202247925238176), 19: (4.920799788561735, 1.197259164940227), 20: (4.920733034249508, 1.2018260988322171), 21: (4.921380095048384,

Use a non-iterated version of iNLP 

In [14]:
model_lmccg_path = "../CCGMultitask/models/augment/augment_.50_0_sgd_continue"
model_lmonly_path = "../CCGMultitask/models/augment/augment_1.00_0_sgd_continue"
model_ccgprobe_path = "./models/augment_1.00_0_sgd_continue_ccgfrozen"

In [15]:
def load_model(model_path):
    with open(model_path + ".w2idx", "rb") as w2idx_f:
        w2idx = pickle.load(w2idx_f)
    
    vocab = {i:w for (w, i) in w2idx.items()}
    
    with open(model_path + ".c2idx", "rb") as c2idx_f:
        c2idx = pickle.load(c2idx_f)
    
    categories = {i:c for (c, i) in c2idx.items()}
    
    model = MultiTaskModel(len(w2idx.keys()), 650, 650, [len(w2idx.keys()), len(c2idx.keys())], 2)
    model.load_state_dict(torch.load(model_path + ".pt", map_location = torch.device("cpu")))

    model = model.eval()

    if cuda: model.cuda()
        
    return model, w2idx, vocab, c2idx, categories

model_lmccg, w2idx_lmccg, vocab_lmccg, c2idx_lmccg, categories_lmccg = load_model(model_lmccg_path)
model_lmonly, w2idx_lmonly, vocab_lmonly, c2idx_lmonly, categories_lmonly = load_model(model_lmonly_path)
model_ccgprobe, w2idx_ccgprobe, vocab_ccgprobe, c2idx_ccgprobe, categories_ccgprobe = load_model(model_ccgprobe_path)

In [18]:
test_data_lmccg = joint_tag_lm("../CCGMultitask/data/ccg_supertags/ccg.23.common", 
                         "../CCGMultitask/data/ccg_supertags/categories", 35, w2idx=w2idx_lmccg)
test_sampler_lmccg = BatchSampler(test_data_lmccg, 10)
test_loader_lmccg = DataLoader(test_data_lmccg, batch_sampler=test_sampler_lmccg)

test_data_lmonly = joint_tag_lm("../CCGMultitask/data/ccg_supertags/ccg.23.common", 
                         "../CCGMultitask/data/ccg_supertags/categories", 35, w2idx=w2idx_lmonly)
test_sampler_lmonly = BatchSampler(test_data_lmonly, 10)
test_loader_lmonly = DataLoader(test_data_lmonly, batch_sampler=test_sampler_lmonly)

test_data_ccgprobe = joint_tag_lm("../CCGMultitask/data/ccg_supertags/ccg.23.common", 
                         "../CCGMultitask/data/ccg_supertags/categories", 35, w2idx=w2idx_ccgprobe)
test_sampler_ccgprobe = BatchSampler(test_data_ccgprobe, 10)
test_loader_ccgprobe = DataLoader(test_data_ccgprobe, batch_sampler=test_sampler_ccgprobe)

In [7]:
import numpy as np
import scipy.linalg as linalg

ccg_decoder = model.decoders[1]

W = ccg_decoder.linear.weight.detach().cpu().numpy()

basis = linalg.null_space(W)

P = basis.dot(basis.T)

sum(abs(W.dot(P.dot(np.random.rand(650)))) < 1e-5) == 427

P = torch.Tensor(P)

if cuda: P = P.cuda()

P.shape

NameError: name 'model' is not defined

In [8]:
def get_P(weight):
    W = weight.detach().cpu().numpy()
    basis = linalg.null_space(W)
    P = basis.dot(basis.T)

    P = torch.Tensor(P)
    if cuda: P = P.cuda()
        
    return P

def nsp_losses(model, P, test_loader):
    lm_loss = 0
    ccg_loss = 0
    ccg_correct = 0

    nsp_lm_loss = 0
    nsp_ccg_loss = 0
    nsp_ccg_correct = 0

    total_examples = 0
    
    loss_f = nn.NLLLoss()

    hidden = model.init_hidden(10)
    with torch.no_grad():
        for (input, target_lm, target_ccg) in tqdm.tqdm(test_loader):
            if cuda:
                input = input.cuda()
                target_lm = target_lm.cuda()
                target_ccg = target_ccg.cuda()
                
            input = input.transpose(0,1).contiguous()
            target_lm = target_lm.transpose(0,1).contiguous()
            target_ccg = target_ccg.transpose(0,1).contiguous()
        
            state, hidden = model.lstm(input, hidden)
            state_ = state.clone().detach()
            state_ = torch.matmul(P, state_.view(state_.shape[0], state_.shape[1], state_.shape[2], 1))
            state_ = state_.squeeze()

            num_examples = len(target_lm.view(-1))
            total_examples += num_examples
        
            nsp_lm_probs, nsp_ccg_probs = [decoder(state_) for decoder in model.decoders]
            nsp_lm_loss += num_examples * loss_f(nsp_lm_probs, target_lm.view(-1)).item()
            nsp_ccg_loss += num_examples * loss_f(nsp_ccg_probs, target_ccg.view(-1)).item()
            nsp_ccg_correct += (nsp_ccg_probs.argmax(dim=1) == target_ccg.view(-1)).sum().item()
        
            lm_probs, ccg_probs = [decoder(state) for decoder in model.decoders]
            lm_loss += num_examples * loss_f(lm_probs, target_lm.view(-1)).item()
            ccg_loss += num_examples * loss_f(ccg_probs, target_ccg.view(-1)).item()
            ccg_correct += (ccg_probs.argmax(dim=1) == target_ccg.view(-1)).sum().item()
        
        
    lm_loss = lm_loss/total_examples
    ccg_loss = ccg_loss/total_examples
    ccg_accuracy = ccg_correct/total_examples
    
    nsp_lm_loss = nsp_lm_loss/total_examples
    nsp_ccg_loss = nsp_ccg_loss/total_examples
    nsp_ccg_accuracy = nsp_ccg_correct/total_examples
    
    return lm_loss, ccg_loss, ccg_accuracy, nsp_lm_loss, nsp_ccg_loss, nsp_ccg_accuracy    

In [9]:
_, _, _, rand_lm_loss, rand_ccg_loss, rand_ccg_accuracy = nsp_losses(model_lmccg, get_P(torch.rand((650, 650))).cuda(), test_loader_lmccg)
print("no intervention: lm {}, ccg {}/{}\nintervention: lm {}, ccg {}/{}\nrandom proj: lm {}, ccg {}/{}".format(
    *nsp_losses(model_lmccg, get_P(model_lmccg.decoders[1].linear.weight), test_loader_lmccg), 
    rand_lm_loss, rand_ccg_loss, rand_ccg_accuracy))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 165/165 [00:25<00:00,  6.39it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 165/165 [00:26<00:00,  6.26it/s]

no intervention: lm 4.920966253858624, ccg 1.2028521180152894/0.8448658008658009
intervention: lm 7.180473087773178, ccg 6.509100890882087/0.21229437229437229
random proj: lm 7.83287026087443, ccg 6.464027540611498/0.21229437229437229





In [16]:
print("no intervention: lm {}, ccg {}/{}\nintervention: lm {}, ccg {}/{}".format(
    *nsp_losses(model_lmonly, get_P(model_lmonly.decoders[1].linear.weight), test_loader_lmonly)))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 165/165 [00:25<00:00,  6.40it/s]

no intervention: lm 4.325432572220311, ccg 6.098769101229581/0.002718614718614719
intervention: lm 5.966211801586729, ccg 6.056785089319402/0.07601731601731601





In [19]:
print("no intervention: lm {}, ccg {}/{}\nintervention: lm {}, ccg {}/{}".format(
    *nsp_losses(model_ccgprobe, get_P(model_ccgprobe.decoders[1].linear.weight), test_loader_ccgprobe)))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 165/165 [00:25<00:00,  6.43it/s]

no intervention: lm 4.325432572220311, ccg 0.5187777609536142/0.8429783549783549
intervention: lm 7.181911335569439, ccg 4.455898727070202/0.21229437229437229



