# Experiment - do data augmentations make us learn the right feature?

Two features - one "advice", one "spurious"

# Setup

In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import transforms
# import pytorch_lightning as pl
import numpy as np
import torch.nn as nn

In [1]:
import pytorch_lightning as pl

In [3]:
class NN(pl.LightningModule):
    """Small MLP regressor: Linear -> ReLU -> Linear producing one scalar.

    Args:
        input_dim: Width of the input feature vector.
        hidden_dim: Number of units in the single hidden layer.
    """

    def __init__(self, input_dim, hidden_dim=8):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the MSE for one batch.

        Note that this calls ``self.model`` directly rather than ``forward``.
        """
        inputs, targets = batch
        loss = F.mse_loss(self.model(inputs), targets)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Optimize every parameter with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
    
class SepHeadsNN(NN):
    """Variant of ``NN`` that embeds the first input feature before the MLP.

    The first column of the input goes through its own ``Linear(1, 8)`` layer;
    the resulting 8 values are concatenated with the remaining raw column and
    fed to the same Linear -> ReLU -> Linear stack that ``NN`` uses.

    Args:
        input_dim: Must be 2 (one feature to embed plus one raw feature).
        hidden_dim: Number of units in the hidden layer of the MLP.
    """

    def __init__(self, input_dim, hidden_dim=8):
        assert input_dim == 2, "SepHeadsNN only supports exactly two input features"
        embed_dim = 8
        # The parent constructor must run before any submodule is assigned,
        # otherwise torch refuses the attribute assignment. It also builds
        # ``self.model`` sized for the embedded feature plus the one raw feature.
        super().__init__(input_dim=embed_dim + 1, hidden_dim=hidden_dim)
        self.input_preprocess = nn.Linear(1, embed_dim)

    def forward(self, x):
        """Embed column 0, append the remaining column(s), and run the MLP."""
        embedded = self.input_preprocess(x[:, :1])
        full_input = torch.cat([embedded, x[:, 1:]], dim=1)
        return self.model(full_input)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the MSE for one batch."""
        x, y = batch
        # Go through forward() so the preprocessing path is defined in one place.
        pred = self(x)
        loss = F.mse_loss(pred, y)
        self.log('train_loss', loss)
        return loss

In [23]:
from torch.utils.data import DataLoader, random_split, Dataset

class BaseDataset(Dataset):
    """Base class for synthetic regression datasets that are sampled on the fly.

    Subclasses implement ``get_x`` (build the feature vector for a target) and
    ``get_x_dim`` (number of features). Every ``__getitem__`` call draws a
    fresh standard-normal target, so the index is ignored and the same index
    yields different samples on different calls.
    """

    def __len__(self):
        """Nominal dataset size; samples are generated, not stored."""
        return 1000

    def get_x(self, y):
        """Return the feature tensor for target ``y``; must be overridden."""
        raise NotImplementedError

    def get_x_dim(self):
        """Return the number of features ``get_x`` produces; must be overridden."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return a freshly sampled ``(x, y)`` pair; ``idx`` is not used."""
        y = torch.randn(1)
        x = self.get_x(y)
        # Return CPU tensors and leave device placement to the training loop.
        # A hard-coded ``.cuda()`` here fails on machines without CUDA and
        # puts the data on a different device than a CPU-resident model.
        return x, y

class BothEasy(BaseDataset):
    """Two-feature dataset in which each feature alone determines the target.

    Feature 0 ("advice") is ``y`` plus noise; feature 1 ("spurious") is ``-y``
    plus noise. ``noise`` scales the standard-normal perturbation on both.
    """

    def __init__(self, noise=0):
        self.noise = noise

    def get_x_dim(self):
        """Number of features returned by ``get_x``."""
        return 2

    def get_x(self, y):
        """Build ``[advice, spurious]`` for target ``y``."""
        advice = y + self.noise * torch.randn(1)
        spurious = -y + self.noise * torch.randn(1)
        return torch.cat([advice, spurious])

class AdviceSum(BaseDataset):
    """Three-feature dataset where the advice is only useful with a helper.

    Features are ``[advice, helper, spurious]``: the helper is an independent
    standard-normal draw, the advice is ``y - helper`` plus noise (so advice +
    helper recovers ``y``), and the spurious feature is ``-y`` plus noise.
    """

    def __init__(self, noise=0):
        self.noise = noise

    def get_x_dim(self):
        """Number of features returned by ``get_x``."""
        return 3

    def get_x(self, y):
        """Build ``[advice, helper, spurious]`` for target ``y``."""
        helper = torch.randn(1)  # useful non-advice
        advice = y - helper + self.noise * torch.randn(1)
        spurious = -y + self.noise * torch.randn(1)
        return torch.cat([advice, helper, spurious])
    
    
class BothEasyRandomized(BothEasy):
    """``BothEasy`` with the spurious feature occasionally replaced by noise.

    With probability ``random_rate`` the spurious feature is an independent
    standard-normal draw instead of ``-y`` plus noise; the advice feature is
    always ``y`` plus noise.
    """

    def __init__(self, noise=0, random_rate=.2):
        self.noise = noise
        self.random_rate = random_rate

    def get_x_dim(self):
        """Number of features returned by ``get_x``."""
        return 2

    def get_x(self, y):
        """Build ``[advice, spurious]`` for target ``y``."""
        advice = y + self.noise * torch.randn(1)
        randomize_spurious = np.random.uniform() < self.random_rate
        if randomize_spurious:
            spurious = torch.randn(1)
        else:
            spurious = -y + self.noise * torch.randn(1)
        return torch.cat([advice, spurious])
    
    
class AdviceSumSpuriousRandomized(BaseDataset):
    """``AdviceSum``-style data with the spurious feature sometimes randomized.

    Features are ``[advice, helper, spurious]``. The advice and helper are
    built as ``y - helper`` plus noise and a standard-normal draw; with
    probability ``random_rate`` the spurious feature is replaced by an
    independent standard-normal draw instead of ``-y`` plus noise.
    """

    def __init__(self, noise=0, random_rate=.2):
        self.noise = noise
        self.random_rate = random_rate

    def get_x_dim(self):
        """Number of features returned by ``get_x``."""
        return 3

    def get_x(self, y):
        """Build ``[advice, helper, spurious]`` for target ``y``."""
        helper = torch.randn(1)  # useful non-advice
        advice = y - helper + self.noise * torch.randn(1)
        randomize_spurious = np.random.uniform() < self.random_rate
        if randomize_spurious:
            spurious = torch.randn(1)
        else:
            spurious = -y + self.noise * torch.randn(1)
        return torch.cat([advice, helper, spurious])
    
class AdviceSumFullRandomizedTogether(BaseDataset):
    """``AdviceSum``-style data where helper and spurious are randomized jointly.

    Features are ``[advice, helper, spurious]``. With probability
    ``random_rate`` a single coin flip replaces BOTH the spurious feature and
    the helper with independent standard-normal draws. The advice is computed
    from the original helper before that replacement, so in the randomized
    case the returned helper no longer pairs with the advice.
    """

    def __init__(self, noise=0, random_rate=.2):
        self.noise = noise
        self.random_rate = random_rate

    def get_x_dim(self):
        """Number of features returned by ``get_x``."""
        return 3

    def get_x(self, y):
        """Build ``[advice, helper, spurious]`` for target ``y``."""
        helper = torch.randn(1)  # useful non-advice
        advice = y - helper + self.noise * torch.randn(1)
        randomize_both = np.random.uniform() < self.random_rate
        if randomize_both:
            spurious = torch.randn(1)
            # Overwrite the helper only after the advice has been derived from it.
            helper = torch.randn(1)
        else:
            spurious = -y + self.noise * torch.randn(1)
        return torch.cat([advice, helper, spurious])
       
        
class AdviceSumFullRandomizedSolo(BaseDataset):
    """``AdviceSum``-style data with two independent randomization coin flips.

    Features are ``[advice, helper, spurious]``. One coin flip (probability
    ``random_rate``) replaces the advice with a standard-normal draw; a second,
    independent flip with the same probability replaces the spurious feature.
    The helper is always a plain standard-normal draw.

    NOTE(review): the feature randomized here is x1 (labelled "advice"), not
    the helper x2 -- confirm this matches the intended "drop out the helper"
    experiment.
    """

    def __init__(self, noise=0, random_rate=.2):
        self.noise = noise
        self.random_rate = random_rate

    def get_x_dim(self):
        """Number of features returned by ``get_x``."""
        return 3

    def get_x(self, y):
        """Build ``[advice, helper, spurious]`` for target ``y``."""
        helper = torch.randn(1)  # useful non-advice

        randomize_advice = np.random.uniform() < self.random_rate
        if randomize_advice:
            advice = torch.randn(1)
        else:
            advice = y - helper + self.noise * torch.randn(1)

        randomize_spurious = np.random.uniform() < self.random_rate
        if randomize_spurious:
            spurious = torch.randn(1)
        else:
            spurious = -y + self.noise * torch.randn(1)

        return torch.cat([advice, helper, spurious])
       
    
   
    
def run_exp(class_name, exp_name=None, max_epochs=20, noise=0, seed=None):
    """Train an ``NN`` on a freshly built dataset and return the trained model.

    Args:
        class_name: Dataset class to instantiate; called as
            ``class_name(noise=noise)`` and expected to provide ``get_x_dim()``.
        exp_name: Name used for the TensorBoard log directory
            (``logs/<exp_name>``). Defaults to the dataset class name.
        max_epochs: Number of epochs passed to the Lightning ``Trainer``.
        noise: Noise scale forwarded to the dataset constructor.
        seed: Optional integer. When given, the torch and NumPy global RNGs
            are seeded before the dataset and model are created, so repeated
            runs draw the same samples and initial weights.

    Returns:
        The trained ``NN`` instance.
    """
    if exp_name is None:
        exp_name = class_name.__name__
    print("running experiment", exp_name)

    if seed is not None:
        # The datasets draw from both torch and NumPy RNGs, so seed both.
        torch.manual_seed(seed)
        np.random.seed(seed)

    dataset = class_name(noise=noise)
    # Derive an 80/20 split from the dataset's reported length instead of
    # hard-coding sizes that must sum to it.
    n_val = len(dataset) // 5
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    model = NN(dataset.get_x_dim())
    logger = pl.loggers.TensorBoardLogger(f'logs/{exp_name}')
    trainer = pl.Trainer(max_epochs=max_epochs, logger=logger)
    # Both loaders use the default DataLoader settings.
    trainer.fit(model, DataLoader(train), DataLoader(val))
    return model


def check_gradients(model, num_trials, x_dim):
    """Print the gradient of the model output w.r.t. random probe inputs.

    Draws ``num_trials`` standard-normal inputs of width ``x_dim``, runs them
    through ``model`` and prints, rounded to two decimals, one row per probe
    holding d(output)/d(input feature). A separator line of 50 ``=`` follows.

    Args:
        model: A torch module mapping ``(num_trials, x_dim)`` inputs to one
            output per row.
        num_trials: Number of random probe inputs.
        x_dim: Number of input features.
    """
    # Create the probes on whatever device the model currently lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    x = torch.randn(num_trials, x_dim, device=device, requires_grad=True)

    pred = model(x)
    # Subtracting a constant target does not change d(pred)/dx, so none is
    # drawn. autograd.grad returns only the input gradient and, unlike
    # backward(), leaves the model parameters' .grad buffers untouched.
    (input_grad,) = torch.autograd.grad(pred.sum(), x)

    print(np.round(input_grad.cpu().numpy(), 2))
    print("=" * 50)

# Default performance

Expect: the agent will use both the advice and the spurious feature, with and without noise, on both datasets (BothEasy and AdviceSum).

Result: both used!

In [11]:
# Number of random probe inputs passed to check_gradients in the cells below.
num_trials = 10

In [18]:
# Both Easy: train one model without noise and one with noise=0.2,
# then print the input gradients of each trained model.
both_easy_n0 = run_exp(class_name=BothEasy, noise=0)
both_easy_n02 = run_exp(class_name=BothEasy, noise=.2)
check_gradients(model=both_easy_n0, num_trials=num_trials, x_dim=2)
check_gradients(model=both_easy_n02, num_trials=num_trials, x_dim=2)

[[ 0.91 -0.74]
 [ 0.71 -0.29]
 [ 0.37 -0.52]
 [ 0.53 -0.4 ]
 [ 0.9  -0.63]
 [ 0.71 -0.29]
 [ 0.41 -0.59]
 [ 0.28 -0.14]
 [ 0.41 -0.59]
 [ 0.71 -0.29]]
[[ 0.51 -0.48]
 [ 0.11 -0.21]
 [ 0.7  -0.63]
 [ 0.5  -0.49]
 [ 0.5  -0.49]
 [ 0.5  -0.49]
 [ 0.39 -0.58]
 [ 0.5  -0.49]
 [ 0.63 -0.68]
 [ 0.4  -0.57]]


In [20]:
# Advice Hard: same baseline on AdviceSum, where the advice must be combined
# with the helper feature. Train without noise and with noise=0.2, then print
# the input gradients of each trained model.
# The variable name `advide_sum_n02` (sic) is kept as originally spelled.
advice_sum_n0 = run_exp(class_name=AdviceSum, noise=0)
advide_sum_n02 = run_exp(class_name=AdviceSum, noise=.2)
check_gradients(model=advice_sum_n0, num_trials=num_trials, x_dim=3)
check_gradients(model=advice_sum_n02, num_trials=num_trials, x_dim=3)

[[ 0.28  0.28 -0.72]
 [ 0.36  0.36 -0.64]
 [ 0.28  0.28 -0.72]
 [ 0.28  0.28 -0.72]
 [ 0.36  0.36 -0.64]
 [ 0.36  0.36 -0.64]
 [ 0.6   0.38 -0.69]
 [ 0.28  0.28 -0.72]
 [ 0.36  0.36 -0.64]
 [ 0.36  0.36 -0.64]]
[[ 0.53  0.51 -0.8 ]
 [ 0.46  0.48 -0.51]
 [ 0.46  0.49 -0.51]
 [ 0.46  0.48 -0.51]
 [ 0.46  0.48 -0.51]
 [ 0.46  0.49 -0.51]
 [ 0.46  0.49 -0.51]
 [ 0.38  0.47 -0.26]
 [ 0.49  0.51 -0.47]
 [ 0.44  0.35 -0.42]]


# With Dropout (spurious only)

Expect: When you drop out the spurious feature only, the agent will use the correct feature.

Result: With no noise, the agent uses only the advice. With noise, the agent uses the advice more, but still uses both features.

In [28]:
# Both Easy with the spurious feature randomized part of the time.
# The noisy run is trained for 50 epochs; the noise-free run uses the default.
both_easy_n0_dropout = run_exp(class_name=BothEasyRandomized, noise=0)
both_easy_n02_dropout = run_exp(class_name=BothEasyRandomized, noise=.2, max_epochs=50)
check_gradients(model=both_easy_n0_dropout, num_trials=num_trials, x_dim=2)
check_gradients(model=both_easy_n02_dropout, num_trials=num_trials, x_dim=2)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 33    
-------------------------------------
33        Trainable params
0         Non-trainable params
33        Total params
0.000     Total estimated model params size (MB)


running experiment BothEasyRandomized


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…


[[ 1.  0.]
 [ 1.  0.]
 [ 1.  0.]
 [ 1. -0.]
 [ 1. -0.]
 [ 1.  0.]
 [ 1.  0.]
 [ 1.  0.]
 [ 1.  0.]
 [ 1.  0.]]
[[ 1.01  0.07]
 [ 0.63 -0.32]
 [ 1.01  0.07]
 [ 1.01  0.07]
 [ 1.01  0.07]
 [ 1.03  0.1 ]
 [ 0.63 -0.32]
 [ 1.03  0.1 ]
 [ 0.63 -0.32]
 [ 1.01  0.07]]


In [27]:
# Advice Hard, randomizing only the spurious feature; both runs use 50 epochs.
# The variable name `advide_sum_n02_dropout_spurious` (sic) is kept as originally spelled.
advice_sum_n0_dropout_spurious = run_exp(
    class_name=AdviceSumSpuriousRandomized, noise=0, max_epochs=50
)
advide_sum_n02_dropout_spurious = run_exp(
    class_name=AdviceSumSpuriousRandomized, noise=.2, max_epochs=50
)
check_gradients(model=advice_sum_n0_dropout_spurious, num_trials=num_trials, x_dim=3)
check_gradients(model=advide_sum_n02_dropout_spurious, num_trials=num_trials, x_dim=3)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 41    
-------------------------------------
41        Trainable params
0         Non-trainable params
41        Total params
0.000     Total estimated model params size (MB)


running experiment AdviceSumSpuriousRandomized


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 41    
-------------------------------------
41        Trainable params
0         Non-trainable params
41        Total params
0.000     Total estimated model params size (MB)



running experiment AdviceSumSpuriousRandomized


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…


[[ 1.    1.   -0.  ]
 [ 1.    1.    0.  ]
 [ 1.    1.    0.  ]
 [ 1.    1.   -0.  ]
 [ 1.    1.   -0.  ]
 [ 1.    1.    0.  ]
 [ 1.    1.   -0.  ]
 [ 1.    1.    0.  ]
 [ 1.05  1.08  0.02]
 [ 1.    1.   -0.  ]]
[[ 0.74  0.62 -0.12]
 [ 0.62  0.69  0.05]
 [ 0.95  0.91 -0.1 ]
 [ 0.89  0.87 -0.11]
 [ 0.95  0.91 -0.1 ]
 [ 0.89  0.87 -0.1 ]
 [ 0.91  0.76 -0.15]
 [ 0.63  0.59 -0.1 ]
 [ 0.73  0.89  0.07]
 [ 0.95  0.91 -0.1 ]]


# With Dropout (spurious + helper)

Expect: When you drop out the helper feature too, the agent will still use the correct feature more, but not much more.

Result: When the features are dropped out together, the advice is used only slightly more. When they are dropped out individually (solo), the result is a bit worse.

In [25]:
# Advice Hard, randomizing the spurious and helper features with one shared
# coin flip; both runs use 50 epochs.
# The variable name `advide_sum_n02_dropout_full_together` (sic) is kept as originally spelled.
advice_sum_n0_dropout_full_together = run_exp(
    class_name=AdviceSumFullRandomizedTogether, noise=0, max_epochs=50
)
advide_sum_n02_dropout_full_together = run_exp(
    class_name=AdviceSumFullRandomizedTogether, noise=.2, max_epochs=50
)
check_gradients(model=advice_sum_n0_dropout_full_together, num_trials=num_trials, x_dim=3)
check_gradients(model=advide_sum_n02_dropout_full_together, num_trials=num_trials, x_dim=3)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 41    
-------------------------------------
41        Trainable params
0         Non-trainable params
41        Total params
0.000     Total estimated model params size (MB)


running experiment AdviceSumFullRandomizedTogether


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 41    
-------------------------------------
41        Trainable params
0         Non-trainable params
41        Total params
0.000     Total estimated model params size (MB)



running experiment AdviceSumFullRandomizedTogether


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…


[[ 0.66  0.61 -0.26]
 [ 0.55  0.14 -0.43]
 [ 0.91  0.75  0.07]
 [ 0.33  0.16 -0.48]
 [ 0.55  0.14 -0.43]
 [ 0.69  0.62 -0.25]
 [ 0.69  0.62 -0.25]
 [ 0.21  0.12 -0.22]
 [ 0.21  0.12 -0.22]
 [ 0.18 -0.03 -0.51]]
[[ 0.59  0.41 -0.37]
 [ 0.59  0.5  -0.21]
 [ 0.59  0.5  -0.21]
 [ 0.63  0.54 -0.33]
 [ 0.59  0.41 -0.37]
 [ 0.59  0.5  -0.21]
 [ 0.63  0.54 -0.33]
 [ 0.44  0.44 -0.22]
 [ 0.63  0.54 -0.33]
 [ 0.59  0.41 -0.37]]


In [26]:
# Advice Hard, using the dataset with two independent randomization coin
# flips; both runs use 50 epochs.
# The variable name `advide_sum_n02_dropout_full_solo` (sic) is kept as originally spelled.
advice_sum_n0_dropout_full_solo = run_exp(
    class_name=AdviceSumFullRandomizedSolo, noise=0, max_epochs=50
)
advide_sum_n02_dropout_full_solo = run_exp(
    class_name=AdviceSumFullRandomizedSolo, noise=.2, max_epochs=50
)
check_gradients(model=advice_sum_n0_dropout_full_solo, num_trials=num_trials, x_dim=3)
check_gradients(model=advide_sum_n02_dropout_full_solo, num_trials=num_trials, x_dim=3)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 41    
-------------------------------------
41        Trainable params
0         Non-trainable params
41        Total params
0.000     Total estimated model params size (MB)


running experiment AdviceSumFullRandomizedSolo


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 41    
-------------------------------------
41        Trainable params
0         Non-trainable params
41        Total params
0.000     Total estimated model params size (MB)



running experiment AdviceSumFullRandomizedSolo


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…


[[ 0.08 -0.27 -0.69]
 [ 0.04 -0.32 -0.66]
 [ 0.   -0.31 -0.76]
 [ 0.34  0.22 -0.67]
 [ 0.63  0.51 -0.35]
 [ 0.59  0.44 -0.33]
 [ 0.5   0.46 -0.46]
 [ 0.16  0.03 -0.61]
 [ 0.5   0.46 -0.46]
 [ 0.08 -0.27 -0.69]]
[[ 0.57  0.25 -0.56]
 [ 0.34  0.2  -0.19]
 [ 0.22  0.01 -0.35]
 [ 0.32  0.32 -0.61]
 [ 0.37  0.29 -0.63]
 [ 0.22  0.04 -0.7 ]
 [ 0.32  0.32 -0.61]
 [ 0.32  0.32 -0.61]
 [ 0.37  0.29 -0.63]
 [ 0.32  0.32 -0.61]]
