In [2]:
import torch
import transformers
import numpy as np
import tqdm
import wandb
import math
import random

## Some parameters


In [43]:
config = {
  # dataset
  "digits": 4,
  "highest_number": 33,
  "train_size": 10000,
  "test_size": 1000,
  # learning
  "learning_rate": 5e-5,
  "num_warmup_steps": 100,
  "num_epochs": 150,
  "batch_size": 64,
  "num_warmup_epochs": 40,
  # loops
  "eval_freq": 50,
  # model
  "num_attention_heads": 1,
  "num_hidden_layers": 6,
  "alignments2": [
        [[-1,1], "w"],
        [[-1,2], "x"],
        [[-1,3], "y"],
        [[2,0], "S1"],
        [[2,1], "C1"],
        [[4,0], "S2"],
    ],
  "alignments1": [
        [[-1,1], "w"],
        [[-1,2], "x"],
        [[-1,3], "y"],
        [[-1,4], "z"],
        [[2,0], "S1"],
        [[2,1], "C1"],
        [[2,2], "C2"],
        [[4,0], "S2"],
        [[4,1], "C3"],
        [[5,0], "O"]
    ]
}

num_training_steps_per_epoch =  int(math.ceil(config['train_size']/config['batch_size']))

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")

print(device)


cuda


In [44]:
  model_name = 'bert-base-uncased'

  bertConfig = transformers.BertConfig.from_pretrained(model_name)
  bertConfig

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.16.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

## Model & optimizers

In [45]:
def prepare_training():
  model_name = 'bert-base-uncased'

  bertConfig = transformers.BertConfig.from_pretrained(model_name)
  bertConfig.num_labels = config['highest_number'] * config['digits'] + 1
  bertConfig.num_hidden_layers = config['num_hidden_layers']
  bertConfig.num_attention_heads = config['num_attention_heads']

  tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
  # NOTE: did it initialize the weights from the model_name?
  model = transformers.BertForSequenceClassification(bertConfig)

  model.to(device);

  optimizer = transformers.AdamW(model.parameters(), lr=config['learning_rate'])
  lr_scheduler = transformers.get_scheduler("linear", optimizer=optimizer,num_warmup_steps=config['num_warmup_steps'], num_training_steps=num_training_steps_per_epoch*config['num_epochs'])

  return model, tokenizer, optimizer, lr_scheduler

def prepare_training2():
  model_name = 'bert-base-uncased'

  bertConfig = transformers.BertConfig.from_pretrained(model_name)
  bertConfig.num_labels = config['highest_number'] * config['digits'] + 1
  bertConfig.T2_num_labels = config['highest_number'] * (config['digits']-1) + 1
  bertConfig.num_hidden_layers = config['num_hidden_layers']
  bertConfig.num_attention_heads = config['num_attention_heads']

  tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
  # NOTE: did it initialize the weights from the model_name?
  model = BertForMultiSequenceClassification(bertConfig)

  model.to(device);

  optimizer = transformers.AdamW(model.parameters(), lr=config['learning_rate'])
  lr_scheduler = transformers.get_scheduler("linear", optimizer=optimizer,num_warmup_steps=config['num_warmup_steps'], num_training_steps=num_training_steps_per_epoch*config['num_epochs'])

  return model, tokenizer, optimizer, lr_scheduler

def prepare_training3():
  model_name = 'bert-base-uncased'

  bertConfig = transformers.BertConfig.from_pretrained(model_name)
  bertConfig.num_labels = config['highest_number'] * config['digits'] + 1
  bertConfig.T2_num_labels = config['highest_number'] * (config['digits']-1) + 1
  bertConfig.num_hidden_layers = config['num_hidden_layers']
  bertConfig.num_attention_heads = config['num_attention_heads']

  tokenizer = transformers.BertTokenizer.from_pretrained(model_name)
  # NOTE: did it initialize the weights from the model_name?
  neural_model = InterventionableTransformer(BertForMultiSequenceClassification(bertConfig))
  causal_model = Interventionable2(CausalArithmetic2(config))

  neural_model.model.to(device);
  causal_model.model.to(device);

  optimizer = transformers.AdamW(neural_model.model.parameters(), lr=config['learning_rate'])
  lr_scheduler = transformers.get_scheduler("linear", optimizer=optimizer,num_warmup_steps=config['num_warmup_steps'], num_training_steps=num_training_steps_per_epoch*config['num_epochs'])

  return neural_model, causal_model, tokenizer, optimizer, lr_scheduler

## Define dataset and helpers


In [46]:
def get_dataset(digits=4, highest_number=33, size=30):
  x_numbers = np.random.randint(low=0, high=highest_number, size=(size,digits))
  y = np.sum(x_numbers, axis=1)

  def list_to_sentence(ls):
    return " ".join([str(i) for i in ls])

  x_sentences = [list_to_sentence(ls) for ls in x_numbers]
  
  return x_sentences, y

def get_dataset2(digits=4, highest_number=33, size=30):
  x_numbers = np.random.randint(low=0, high=highest_number, size=(size,digits))
  y = np.sum(x_numbers, axis=1)
  y2 = np.sum(x_numbers[:,:3], axis=1)

  def list_to_sentence(ls):
    return " ".join([str(i) for i in ls])

  x_sentences = [list_to_sentence(ls) for ls in x_numbers]
  
  return x_sentences, y, y2

def get_dataset3(digits=4, highest_number=33, size=30):
  x_numbers = np.random.randint(low=0, high=highest_number, size=(size,digits))
  y = np.sum(x_numbers, axis=1)
  y2 = np.sum(x_numbers[:,:3], axis=1)

  def list_to_sentence(ls):
    return " ".join([str(i) for i in ls])

  x_sentences = [list_to_sentence(ls) for ls in x_numbers]
  
  return x_numbers, x_sentences, y, y2

In [47]:
def tokenize_sample(x, y, tokenizer):
  x = tokenizer(list(x), return_tensors='pt').to(device)
  y = y.long().to(device)

  return x, y

def tokenize_sample2(x, y, y2, tokenizer):
  x = tokenizer(list(x), return_tensors='pt').to(device)
  y = y.long().to(device)
  y2 = y2.long().to(device)

  return x, y, y2

In [48]:
class ArithmeticDataset(torch.utils.data.Dataset):
  def __init__(self, digits=4, highest_number=33, size=30):
    super().__init__()

    self.size = size
    self.x, self.y = get_dataset(digits=digits, highest_number=highest_number, size=size)

  def __getitem__(self, index):
    return self.x[index], self.y[index]

  def __len__(self):
    return self.size

class ArithmeticDataset2(torch.utils.data.Dataset):
  def __init__(self, digits=4, highest_number=33, size=30):
    super().__init__()

    self.size = size
    self.x, self.y, self.y2 = get_dataset2(digits=digits, highest_number=highest_number, size=size)

  def __getitem__(self, index):
    return self.x[index], self.y[index], self.y2[index]

  def __len__(self):
    return self.size

class ArithmeticDataset3(torch.utils.data.Dataset):
  def __init__(self, digits=4, highest_number=33, size=30):
    super().__init__()

    self.size = size
    self.x_numbers, self.x, self.y, self.y2 = get_dataset3(digits=digits, highest_number=highest_number, size=size)

  def __getitem__(self, index):
    return  self.x_numbers[index], self.x[index], self.y[index], self.y2[index]

  def __len__(self):
    return self.size

In [49]:
def get_dataloader(ds):
  dl = iter(torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], shuffle=True))
  return dl

## Training stuff

In [50]:
def get_accuracy(logits, y):
  y_hat = torch.argmax(logits, dim=1)
  correct = sum(y_hat == y)
  acc = correct / len(y_hat) * 100
  return acc

def train_step(loss, optimizer, lr_scheduler):
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  lr_scheduler.step()

def train_log(epoch, step, loss, lr, pbar, logits, y):
  # print(f'{step:04}: {loss.item()}')
  wandb.log({"train loss": loss})
  wandb.log({"learning rate": lr})

  acc = get_accuracy(logits, y)
  wandb.log({"train acc": acc})


  pbar.set_postfix(loss=loss.item(), step=step)

def train_log2(epoch, step, loss1, loss2, lr, pbar, logits1, logits2, y1, y2):
  # print(f'{step:04}: {loss.item()}')
  wandb.log({"T1 train loss": loss1})
  wandb.log({"T2 train loss": loss2})
  wandb.log({"learning rate": lr})

  acc1 = get_accuracy(logits1, y1)
  acc2 = get_accuracy(logits2, y2)
  wandb.log({"T1 train acc": acc1})
  wandb.log({"T2 train acc": acc2})


  pbar.set_postfix(loss=loss1.item(), step=step)

def train_log3(epoch, step, T1_loss, T2_loss2, iit_loss, lr, pbar, logits1, logits2, y1, y2):
  # print(f'{step:04}: {loss.item()}')
  wandb.log({"T1 train loss": T1_loss})
  wandb.log({"T2 train loss": T2_loss2})
  wandb.log({"learning rate": lr})

  acc1 = get_accuracy(logits1, y1)
  acc2 = get_accuracy(logits2, y2)
  wandb.log({"T1 train acc": acc1})
  wandb.log({"T2 train acc": acc2})

  wandb.log({"iit loss": iit_loss})

  pbar.set_postfix(loss=T1_loss.item(), step=step)

In [51]:
def split_input_dict(input_dict, halfway_point):
  first_half = {}
  second_half = {}
  for k,v in input_dict.items():
    first_half[k] = v[:halfway_point]
    second_half[k] = v[halfway_point:2*halfway_point]
  return first_half, second_half

In [52]:
def eval(model, tokenizer, test_ds):
  model.eval()

  eval_loss = 0.0
  eval_acc = 0.0
  test_dl = get_dataloader(test_ds)

  for step in range(len(test_dl)):
      x , y = test_dl.next()
      x, y = tokenize_sample(x,y,tokenizer)

      output = model(**x, labels=y)

      eval_loss += output.loss.item()
      eval_acc += get_accuracy(output.logits, y)

  eval_loss /= len(test_dl)
  eval_acc /= len(test_dl)

  wandb.log({"test loss": eval_loss})
  wandb.log({"test acc": eval_acc})

  model.train()

def eval2(model, tokenizer, test_ds):
  model.eval()

  eval_loss1 = 0.0
  eval_acc1 = 0.0
  eval_loss2 = 0.0
  eval_acc2 = 0.0
  test_dl = get_dataloader(test_ds)

  for step in range(len(test_dl)):
      x , y, y2 = test_dl.next()
      x, y, y2 = tokenize_sample2(x,y,y2,tokenizer)

      output1, output2 = model(**x, labels=y, labels2=y2)

      eval_loss1 += output1.loss.item()
      eval_acc1 += get_accuracy(output1.logits, y)
      eval_loss2 += output2.loss.item()
      eval_acc2 += get_accuracy(output2.logits, y2)

  eval_loss1 /= len(test_dl)
  eval_acc1 /= len(test_dl)
  eval_loss2 /= len(test_dl)
  eval_acc2 /= len(test_dl)

  wandb.log({"T1 test loss": eval_loss1})
  wandb.log({"T1 test acc": eval_acc1})
  wandb.log({"T2 test loss": eval_loss2})
  wandb.log({"T2 test acc": eval_acc2})

  model.train()

def eval3(model, tokenizer, test_ds):
  model.eval()

  eval_loss1 = 0.0
  eval_acc1 = 0.0
  eval_loss2 = 0.0
  eval_acc2 = 0.0
  test_dl = get_dataloader(test_ds)

  for step in range(len(test_dl)):
      _, x , y, y2 = test_dl.next()
      x, y, y2 = tokenize_sample2(x,y,y2,tokenizer)

      output1, output2 = model(**x, labels=y, labels2=y2)

      eval_loss1 += output1.loss.item()
      eval_acc1 += get_accuracy(output1.logits, y)
      eval_loss2 += output2.loss.item()
      eval_acc2 += get_accuracy(output2.logits, y2)

  eval_loss1 /= len(test_dl)
  eval_acc1 /= len(test_dl)
  eval_loss2 /= len(test_dl)
  eval_acc2 /= len(test_dl)

  wandb.log({"T1 test loss": eval_loss1})
  wandb.log({"T1 test acc": eval_acc1})
  wandb.log({"T2 test loss": eval_loss2})
  wandb.log({"T2 test acc": eval_acc2})

  model.train()

def ii_accuracy(neural_model, causal_model, tokenizer, test_ds, alignment, task="1"):
  neural_model.model.eval()

  neural_node, causal_node = alignment
  test_dl = get_dataloader(test_ds)

  n = []
  correct = []
  for i in range(len(test_dl)):
    x_numbers, x, y, y2 = test_dl.next()
    x , y, y2= tokenize_sample2(x, y, y2,tokenizer)

    # split in source and base input
    halfway_point = math.floor(x['input_ids'].shape[0]/2)

    x_numbers_base, x_numbers_source = x_numbers[:halfway_point], x_numbers[halfway_point:2*halfway_point]
    x_base, x_source = split_input_dict(x, halfway_point)
    y_base, y_source = y[:halfway_point], y[halfway_point:2*halfway_point] 
    y2_base, y2_source = y2[:halfway_point], y2[halfway_point:2*halfway_point] 


    with torch.no_grad():
        if task == '1':
            _, _, predict_intervention, _, _, _ = neural_model.forward(
                x_source, x_base, neural_node)
            _, _, target_intervention, _, _, _ = causal_model.forward(
                x_numbers_source, x_numbers_base, causal_node)
        if task == '2':
            _, _, _, _, _, predict_intervention = neural_model.forward(
                x_source, x_base, neural_node)
            _, _, _, _, _, target_intervention = causal_model.forward(
                x_numbers_source, x_numbers_base, causal_node)

        predict_labels = torch.argmax(predict_intervention.logits, dim=1).cpu()

        correct.append(sum(predict_labels == target_intervention))
        n.append(halfway_point)
    
    correct = np.sum(correct)
    acc = 100 * correct / np.sum(n)

    neural_model.model.train()

    return correct, acc
    
def eval_ii(neural_model, causal_model, tokenizer, test_ds, config):
  for alignment in config['alignments1']:
    correct, acc = ii_accuracy(neural_model, causal_model, tokenizer, test_ds, alignment, task="1")
    wandb.log({f"T1 ii  accuracy {alignment}": acc})
  for alignment in config['alignments2']:
    correct, acc = ii_accuracy(neural_model, causal_model, tokenizer, test_ds, alignment, task="2")
    wandb.log({f"T2 ii  accuracy {alignment}": acc})

In [53]:
def train(model, tokenizer, optimizer, lr_scheduler, train_ds, test_ds):
  model.train()

  global_steps = 0
  for epoch in range(config['num_epochs']):

    train_dl = get_dataloader(train_ds)

    pbar = tqdm.trange(num_training_steps_per_epoch, unit="steps", position=0, leave=True)
    pbar.set_description(f"Epoch {epoch}")

    for step in range(len(train_dl)):
      x , y = train_dl.next()
      x, y = tokenize_sample(x,y,tokenizer)

      output = model(**x, labels=y)

      train_step(output.loss, optimizer, lr_scheduler)

      train_log(epoch, step, output.loss, lr_scheduler.get_last_lr()[0], pbar, output.logits, y)

      pbar.update(1)
      
      if global_steps % config['eval_freq'] == 0:
        eval(model, tokenizer, test_ds)
      
      global_steps += 1


    pbar.close()

def train2(model, tokenizer, optimizer, lr_scheduler, train_ds, test_ds):
  model.train()

  global_steps = 0
  for epoch in range(config['num_epochs']):

    train_dl = get_dataloader(train_ds)

    pbar = tqdm.trange(num_training_steps_per_epoch, unit="steps", position=0, leave=True)
    pbar.set_description(f"Epoch {epoch}")

    for step in range(len(train_dl)):
      x , y, y2 = train_dl.next()
      x , y, y2= tokenize_sample2(x , y, y2,tokenizer)

      output1, output2 = model(**x, labels=y, labels2=y2)

      train_step(output1.loss + output2.loss, optimizer, lr_scheduler)

      train_log2(epoch, step, output1.loss, output2.loss, lr_scheduler.get_last_lr()[0], pbar, output1.logits, output2.logits, y, y2)

      pbar.update(1)
      
      if global_steps % config['eval_freq'] == 0:
        eval2(model, tokenizer, test_ds)
      
      global_steps += 1


    pbar.close()


def train3(neural_model, causal_model, tokenizer, optimizer, lr_scheduler, train_ds, test_ds):
  neural_model.model.train()

  loss_fct = torch.nn.CrossEntropyLoss()

  global_steps = 0
  for epoch in range(config['num_epochs']):

    train_dl = get_dataloader(train_ds)

    pbar = tqdm.trange(num_training_steps_per_epoch, unit="steps", position=0, leave=True)
    pbar.set_description(f"Epoch {epoch}")

    for step in range(len(train_dl)):
      x_numbers, x , y, y2 = train_dl.next()
      x_numbers = x_numbers.to(device)
      x , y, y2= tokenize_sample2(x, y, y2,tokenizer)

      # split in source and base input
      halfway_point = math.floor(x['input_ids'].shape[0]/2)

      x_numbers_base, x_numbers_source = x_numbers[:halfway_point], x_numbers[halfway_point:2*halfway_point]
      x_base, x_source = split_input_dict(x, halfway_point)
      y_base, y_source = y[:halfway_point], y[halfway_point:2*halfway_point] 
      y2_base, y2_source = y2[:halfway_point], y2[halfway_point:2*halfway_point] 

      # sample alignment
      neural_node, causal_node = random.choice(config['alignments2'])

      # run intervention
      source_logits_T1, base_logits_T1, _, source_logits_T2, base_logits_T2, counterfactual_logits_T2 = neural_model.forward(x_source, x_base, neural_node)
      with torch.no_grad():
          _, _, _, _, _, counterfactual_target_T2 = causal_model.forward(x_numbers_source, x_numbers_base, causal_node)
      # source_logits_T1, source_logits_T2 = neural_model.model(**x_source, labels=y_source, labels2=y2_source)
      # base_logits_T1, base_logits_T2 = neural_model.model(**x_base, labels=y_base, labels2=y2_base)

      # task loss on all seen examples
      T1_logits_all = torch.cat((source_logits_T1.logits, base_logits_T1.logits), dim=0)
      y_all = torch.cat((y_source, y_base), dim=0)

      T2_logits_all = torch.cat((source_logits_T2.logits, base_logits_T2.logits), dim=0)
      y2_all = torch.cat((y2_source, y2_base), dim=0)

      # T1_loss = loss_fct(source_logits_T1.logits.view(-1, neural_model.model.config.num_labels), y_source.view(-1)) + \
      #         loss_fct(base_logits_T1.logits.view(-1, neural_model.model.config.num_labels), y_base.view(-1))

      # T2_loss = loss_fct(source_logits_T2.logits.view(-1, neural_model.model.T2_num_labels), y2_source.view(-1)) + \
      #         loss_fct(base_logits_T2.logits.view(-1, neural_model.model.T2_num_labels), y2_base.view(-1))

      T1_loss = loss_fct(T1_logits_all.view(-1, neural_model.model.config.num_labels), y_all.view(-1)) 

      T2_loss = loss_fct(T2_logits_all.view(-1, neural_model.model.T2_num_labels), y2_all.view(-1))

      # T1_loss = (source_logits_T1.loss + base_logits_T1.loss ) / 2 
      # T2_loss = (source_logits_T2.loss + base_logits_T2.loss ) / 2 

      # T1_loss = source_logits_T1.loss 
      # T2_loss = source_logits_T2.loss 
      
      # iit loss
      iit_loss = loss_fct(counterfactual_logits_T2.logits, counterfactual_target_T2)
      # iit_loss = torch.zeros_like(T1_loss, device=T1_loss.device)
    
      # TODO: reimplement warmup training better
      alpha = min((step+epoch*len(train_dl))/(config['num_warmup_epochs']*len(train_dl)),1.0)
      wandb.log({"alpha":alpha})
      T1_loss = alpha * T1_loss
    
      train_step(T1_loss + T2_loss + iit_loss, optimizer, lr_scheduler)

      # train_log3(epoch, step, T1_loss, T2_loss, iit_loss, lr_scheduler.get_last_lr()[0], pbar, base_logits_T1.logits, base_logits_T2.logits, y_base, y2_base)
      train_log3(epoch, step, T1_loss, T2_loss, iit_loss, lr_scheduler.get_last_lr()[0], pbar, source_logits_T1.logits, source_logits_T2.logits, y_source, y2_source)

      pbar.update(1)
      
      if global_steps % config['eval_freq'] == 0:
        eval3(neural_model.model, tokenizer, test_ds)
        eval_ii(neural_model, causal_model, tokenizer, test_ds, config)
      
      global_steps += 1


    pbar.close()


## Run on the single task

In [54]:
# wandb.init(project="transformer-arithmetic-multiIIT", entity="stanford-causality", config=config, mode="online")
# wandb.config = config

# model, tokenizer, optimizer, lr_scheduler = prepare_training()

# train_ds = ArithmeticDataset(digits=config['digits'], highest_number=config['highest_number'], size=config['train_size'])
# test_ds = ArithmeticDataset(digits=config['digits'], highest_number=config['highest_number'], size=config['test_size'])



# train(model, tokenizer, optimizer, lr_scheduler, train_ds, test_ds)

# wandb.finish()

## Custom multitask model

In [55]:
# option 1: crude
# wrap bertforsequenceclassification
# add second classification head (no pooling) - TODO:consider removing pooling for sequence output for fairer comparison
# https://github.com/frankaging/Causal-Distill-XXS/blob/a704dbcb40440a24ae4415348eddbb9f377d49b4/src/modeling.py#L11 -> BertNOPooler
# add new output, define loss

# option 2: cleaner
# make an analogous class to bertforsequence classification
# have a new output class
# handle everything where it should be

In [56]:
# option 1:
class BertForMultiSequenceClassification(transformers.BertPreTrainedModel):
  def __init__(self, config):
    super().__init__(config)

    self.bert = transformers.BertForSequenceClassification(config)

    self.T2_num_labels = config.T2_num_labels
    self.config = config

    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout2 = torch.nn.Dropout(classifier_dropout)
    self.classifier2 = torch.nn.Linear(config.hidden_size, config.T2_num_labels)

    # Initialize weights and apply final processing
    self.post_init()

  def forward(
    self,
    input_ids=None,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    labels2=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
  ):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        labels=labels,
        output_hidden_states=True,
        return_dict=return_dict,
    )


    hidden_states = outputs.hidden_states
    layers = hidden_states[1:]

    position = (4,0)
    intermediate_output = layers[position[0]][:,position[1]]

    intermediate_output = self.dropout2(intermediate_output)
    logits = self.classifier2(intermediate_output)

    loss = None
    if labels2 is not None:
        if self.config.problem_type is None:
            if self.T2_num_labels == 1:
                self.config.problem_type = "regression"
            elif self.T2_num_labels > 1 and (labels2.dtype == torch.long or labels2.dtype == torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = torch.nn.MSELoss()
            if self.T2_num_labels == 1:
                loss = loss_fct(logits.squeeze(), labels2.squeeze())
            else:
                loss = loss_fct(logits, labels2)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.T2_num_labels), labels2.view(-1))
        elif self.config.problem_type == "multi_label_classification":
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels2)

    return outputs, transformers.modeling_outputs.SequenceClassifierOutput(
        loss=loss,
        logits=logits,
    )

## Run on multi task

In [57]:
# wandb.init(project="transformer-arithmetic-multiIIT", entity="stanford-causality", config=config, mode="disabled")
# wandb.config = config

# model, tokenizer, optimizer, lr_scheduler = prepare_training2()

# train_ds = ArithmeticDataset2(digits=config['digits'], highest_number=config['highest_number'], size=config['train_size'])
# test_ds = ArithmeticDataset2(digits=config['digits'], highest_number=config['highest_number'], size=config['test_size'])

# train2(model, tokenizer, optimizer, lr_scheduler, train_ds, test_ds)

# wandb.finish()

## Interventionable transformer

In [58]:
class Interventionable2():
    # NOTE: can probably be merged with Interventionable1
    def __init__(self, model):
        self.activation = {}
        self.model = model

        self.names_to_layers = dict(self.model.named_children())

    def _get_activation(self, name):
        def hook(model, input, output):
            self.activation[name] = output
        return hook

    def _set_activation(self, name):
        def hook(model, input, output):
            return self.activation[name]
        return hook

    def forward(self, source, base, layer_name):
        # NOTE: other ways that do not require constantly adding / removing hooks should exist
        assert layer_name in self.names_to_layers

        # set hook to get activation
        get_handler = self.names_to_layers[layer_name].register_forward_hook(
            self._get_activation(layer_name))

        # get output on source examples (and also capture the activations)
        # with torch.no_grad():
        #     source_logits = self.model(source)
        source_logits_T1, source_logits_T2 = self.model(source)

        # remove the handler (don't store activations on base)
        get_handler.remove()

        # get base logits
        base_logits_T1, base_logits_T2 = self.model(base)

        # set hook to do the intervention
        set_handler = self.names_to_layers[layer_name].register_forward_hook(
            self._set_activation(layer_name))

        # get counterfactual output on base examples
        counterfactual_logits_T1, counterfactual_logits_T2 = self.model(base)

        # remove the handler
        set_handler.remove()

        return source_logits_T1, base_logits_T1, counterfactual_logits_T1, source_logits_T2, base_logits_T2, counterfactual_logits_T2
        
class InterventionableTransformer():
    def __init__(self, model):
        self.activation = {}
        self.model = model

    # these functions are model dependent
    # they specify how the coordinate system works
    def _coordinate_to_getter(self, coord):
        layer, index = coord
        def hook(model, input, output):
            self.activation[f'{layer}-{index}'] = output[:,index]
        if layer == -1:
          handler = self.model.bert.bert.embeddings.register_forward_hook(hook)
        else:
          handler = self.model.bert.bert.encoder.layer[layer].output.register_forward_hook(hook)
        return handler

    def _coordinate_to_setter(self, coord):
        layer, index = coord
        def hook(model, input, output):
            # NOTE: This might lead to errors about inplace manipulations during the backprop.
            output[:,index] = self.activation[f'{layer}-{index}']
        if layer == -1:
          handler = self.model.bert.bert.embeddings.register_forward_hook(hook)
        else:
          handler = self.model.bert.bert.encoder.layer[layer].output.register_forward_hook(hook)
        return handler


    def forward(self, source, base, coord):
        # NOTE: other ways that do not require constantly adding / removing hooks should exist

        # set hook to get activation
        # get_handler = self.names_to_layers[layer_name].register_forward_hook(self._get_activation(layer_name))
        get_handler = self._coordinate_to_getter(coord)

        # get output on source examples (and also capture the activations)
        T1_source_output, T2_source_output = self.model(**source)

        # remove the handler (don't store activations on base) 
        get_handler.remove()

        # get base logits
        T1_base_output, T2_base_output = self.model(**base)
        
        # set hook to do the intervention
        set_handler = self._coordinate_to_setter(coord)

        # get counterfactual output on base examples
        assert base['input_ids'].shape == source['input_ids'].shape, f"shape mismatch! {base['input_ids'].shape}, {source['input_ids'].shape}"
        T1_counterfactual_output, T2_counterfactual_output = self.model(**base)

        # remove the handler
        set_handler.remove()

        return T1_source_output, T1_base_output, T1_counterfactual_output, T2_source_output,  T2_base_output,  T2_counterfactual_output


In [59]:
class CausalArithmetic2(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.w = torch.nn.Identity()
        self.x = torch.nn.Identity()
        self.y = torch.nn.Identity()
        self.z = torch.nn.Identity()

        self.S1 = torch.nn.Identity()
        self.S2 = torch.nn.Identity()

        self.C1 = torch.nn.Identity()
        self.C2 = torch.nn.Identity()
        self.C3 = torch.nn.Identity()

        self.O = torch.nn.Identity()

    def forward(self, input):
        w = torch.clone(input[:, 0])
        x = torch.clone(input[:, 1])
        y = torch.clone(input[:, 2])
        z = torch.clone(input[:, 3])

        w = self.w(w)
        x = self.x(x)
        y = self.y(y)
        z = self.z(z)

        S1 = self.S1(w + x)
        C1 = self.C1(y)
        C2 = self.C2(z)
        S2 = self.S2(S1 + C1)
        C3 = self.C3(C2)
        O = self.O(S2 + C3)
        return O, S2

In [60]:
with wandb.init(project="transformer-arithmetic-multiIIT", entity="stanford-causality", config=config, mode="online"):
  wandb.config = config

  neural_model, causal_model, tokenizer, optimizer, lr_scheduler = prepare_training3()

  train_ds = ArithmeticDataset3(digits=config['digits'], highest_number=config['highest_number'], size=config['train_size'])
  test_ds = ArithmeticDataset3(digits=config['digits'], highest_number=config['highest_number'], size=config['test_size'])

  train3(neural_model, causal_model, tokenizer, optimizer, lr_scheduler, train_ds, test_ds)

  torch.save(neural_model.model.state_dict(), wandb.run.name + '.pt')


Epoch 0: 100%|██████████| 157/157 [00:10<00:00, 14.69steps/s, loss=0.108, step=156] 
Epoch 1: 100%|██████████| 157/157 [00:10<00:00, 15.58steps/s, loss=0.176, step=156]
Epoch 2: 100%|██████████| 157/157 [00:10<00:00, 15.65steps/s, loss=0.251, step=156]
Epoch 3: 100%|██████████| 157/157 [00:10<00:00, 15.60steps/s, loss=0.292, step=156]
Epoch 4: 100%|██████████| 157/157 [00:10<00:00, 15.51steps/s, loss=0.359, step=156]
Epoch 5: 100%|██████████| 157/157 [00:10<00:00, 15.17steps/s, loss=0.479, step=156]
Epoch 6: 100%|██████████| 157/157 [00:10<00:00, 15.42steps/s, loss=0.495, step=156]
Epoch 7: 100%|██████████| 157/157 [00:10<00:00, 14.89steps/s, loss=0.516, step=156]
Epoch 8: 100%|██████████| 157/157 [00:10<00:00, 15.63steps/s, loss=0.608, step=156]
Epoch 9: 100%|██████████| 157/157 [00:10<00:00, 15.46steps/s, loss=0.605, step=156]
Epoch 10: 100%|██████████| 157/157 [00:09<00:00, 15.73steps/s, loss=0.635, step=156]
Epoch 11: 100%|██████████| 157/157 [00:10<00:00, 15.61steps/s, loss=0.7, s

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
"T1 ii accuracy [[-1, 1], 'w']",▁▁▃▂▃▂▄▃▄▃▃▅▆▄▄▆▆▃▅▅▆▇▇▇▆▆▆▇▇▇█████▇██▇▇
"T1 ii accuracy [[-1, 2], 'x']",▁▁▂▂▅▃▃▃▄▃▄▅▆▅▅▄▅▄▆▄▆▆▆▅▆▆▆▇▇▆▇▇▇▆▇▇█▇██
"T1 ii accuracy [[-1, 3], 'y']",▁▂▂▂▄▃▃▃▅▄▄▅▆▅▅▅▆▅▆▅▆▆▇▆▇▆▅▆█▆▇▇██████▇▇
"T1 ii accuracy [[-1, 4], 'z']",▁▂▁▂▂▃▄▃▃▄▃▄▅▄▅▅▅▃▆▅▆▇▆▆▆▆▅▇▇▆█▇▇▇▇▇▇███
"T1 ii accuracy [[2, 0], 'S1']",▁▁▁▂▃▃▄▃▃▅▄▄▆▅▅▅▆▄▇▅▅▅▄▅▆▆▅▆▇▆▇▇█▆▇▇▇█▇▇
"T1 ii accuracy [[2, 1], 'C1']",▁▁▂▃▃▃▅▃▄▅▄▃▆▇▅▂▆▅▄▆█▆▅▅▄▆▄▅▆▅▆▆▃▅▆▅▄▅▄▅
"T1 ii accuracy [[2, 2], 'C2']",█▆▆▃▆▆▃▆▃▃▁▁▃▆▃▁▃▃▃▁▆▁█▁▆█▁▃▃▃▁▁▃▃▁▃▁▃▃▆
"T1 ii accuracy [[4, 0], 'S2']",▁▂▂▂▄▂▃▄▄▂▄█▆▄▄▃▃▄▄▂▅▃▁▁▄▃▅▃▄▃▂▁▅▃▂▂▂▄▂▃
"T1 ii accuracy [[4, 1], 'C3']",▁█▁▆▃▆▆▃▁▆▁▁▆▃▁▁▁▆▁▁▃▆▁▁▃▃▁▁▃█▃▃▃▁▁▁▁▆▃▃
"T1 ii accuracy [[5, 0], 'O']",▁▁▁▃▃▄▅▄▄▅▄▆▆▅▆▆▆▃▇▆▇▅▆▇▇▇▇▇█▇▇█▇███████

0,1
"T1 ii accuracy [[-1, 1], 'w']",93.75
"T1 ii accuracy [[-1, 2], 'x']",96.875
"T1 ii accuracy [[-1, 3], 'y']",96.875
"T1 ii accuracy [[-1, 4], 'z']",84.375
"T1 ii accuracy [[2, 0], 'S1']",90.625
"T1 ii accuracy [[2, 1], 'C1']",43.75
"T1 ii accuracy [[2, 2], 'C2']",3.125
"T1 ii accuracy [[4, 0], 'S2']",6.25
"T1 ii accuracy [[4, 1], 'C3']",6.25
"T1 ii accuracy [[5, 0], 'O']",100.0


## Bert without a pooler

In [21]:
# class BertModelNoPooler(transformers.BertPreTrainedModel):
#     def __init__(self, config):
#         super(BertModelNoPooler, self).__init__(config)
#         self.embeddings = transformers.BertEmbeddings(config)
#         self.encoder = transformers.BertEncoder(config)
#         self.apply(self.init_bert_weights)

#     def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
#         if attention_mask is None:
#             attention_mask = torch.ones_like(input_ids)
#         if token_type_ids is None:
#             token_type_ids = torch.zeros_like(input_ids)

#         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

#         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
#         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

#         embedding_output = self.embeddings(input_ids, token_type_ids)
#         encoded_layers = self.encoder(embedding_output,
#                                       extended_attention_mask,
#                                       output_all_encoded_layers=output_all_encoded_layers)
#         return encoded_layers

In [22]:
neural_model.model.bert.bert.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

AttributeError: 'NoneType' object has no attribute 'name'

In [None]:
neural_model.model

In [24]:
wandb.run