In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import terra
import meerkat as mk
import numpy as np
from meerkat.contrib.gqa import read_gqa_dps
from domino.data.visual_genome import ATTRIBUTE_GROUPS

## Create DataPanels for the Visual Genome

`image_dp` – each row corresponds to an image in the dataset  
`object_dp` – each row corresponds to an object (e.g. 'car') in one of the images  
`attr_dp` – each row maps an attribute (e.g. 'red') to an object

In [3]:
# Root directory of the GQA / Visual Genome data. Defaults to the original
# shared-filesystem location; set the GQA_DIR environment variable to point the
# notebook at a different copy without editing this cell.
gqa_dir = os.environ.get("GQA_DIR", "/home/common/datasets/gqa")

# `read_gqa_dps` returns a collection of DataPanels indexed by name; this
# notebook uses the "objects", "images" and "attributes" entries.
dps = read_gqa_dps(gqa_dir)
object_dp, image_dp, attr_dp = dps["objects"], dps["images"], dps["attributes"]

## Create a classification task
We're going to use Visual Genome to craft a binary image classification task.  
`target` = 1 if the object is a car and `target` = 0 otherwise.

We're also going to keep track of a potential confounder: "color". 

In [None]:
# IDs of objects that carry at least one color attribute.
# NOTE: the filter that would restrict `dp` to these objects is currently
# disabled (kept below as a trailing comment), so `colored_objects` is computed
# but not applied.
has_color_attr = attr_dp["attribute"].isin(ATTRIBUTE_GROUPS["colors"])
colored_objects = attr_dp["object_id"][has_color_attr]
dp = object_dp  # .lz[np.isin(object_dp["object_id"], colored_objects)]

# `dp` is the same object as `object_dp`, so the columns added below are
# written onto `object_dp` as well.

# target: 1 where the object's name is "car", 0 otherwise
is_car = dp["name"].isin(["car"])
dp["target"] = is_car.values.astype(int)

# correlate: 1 where the object has a "red" attribute, 0 otherwise
has_red_attr = attr_dp["attribute"].isin(["red"])
red_objects = attr_dp["object_id"][has_red_attr]
dp["correlate"] = np.isin(dp["object_id"], red_objects).astype(int)

In [None]:
# Select the rows where both flags are set (target == 1 and correlate == 1,
# i.e. red cars) and display their "object_image" column.
dp.lz[(dp["target"] & dp["correlate"]) == 1]["object_image"]

## Induce correlation
Let's induce an artificial correlation between cars and the color red.

In [None]:
from domino.evaluate.linear import induce_correlation

In [None]:
# Settings for the subsample. "target" (is car?) and "correlate" (is red?) are
# the two columns being related.
# NOTE(review): presumably `corr` is the desired correlation, `mu_a`/`mu_b` the
# desired means of the two columns and `n` the sample size -- confirm against
# `induce_correlation`.
sampling_kwargs = {
    "corr": 0.8,
    "mu_a": 0.05,
    "mu_b": 0.05,
    "attr_a": "target",
    "attr_b": "correlate",
    "n": 2e4,
}

# Pick the row indices, then take those rows from `dp`.
indices = induce_correlation(dp, **sampling_kwargs)
dataset_dp = dp.lz[indices]

In [None]:
from scipy.stats import pearsonr
# Sanity check: measure the correlation between the two columns in the
# subsample that was actually drawn.
pearsonr(dataset_dp["target"], dataset_dp["correlate"])

## Train a Model
Train a model on this subsampled dataset

In [None]:
from domino.utils import split_dp

In [None]:
# Split the subsample, keyed on "image_id".
# NOTE(review): presumably this adds the "split" column that the scoring cell
# filters on, and keeps all objects of one image in the same split -- confirm
# against `split_dp`.
dataset_dp = split_dp(dataset_dp, split_on="image_id")

In [None]:
import terra
from domino.vision import train
from torchvision import transforms

# Per-channel normalization statistics (the values match the commonly used
# ImageNet mean / std).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

# Pipeline applied to every object crop: resize, center-crop to 224,
# convert to a tensor, then normalize.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
preprocessing = transforms.Compose(preprocessing_steps)

# Expose the preprocessed crops as the "input" column consumed by training
# and scoring.
dataset_dp["input"] = dataset_dp["object_image"].to_lambda(preprocessing)

@terra.Task.make_task
def train_vg(dp, run_dir: str = None):
    """Train a classifier on `dp` with `domino.vision.train`.

    Args:
        dp: Data to train on. The call below reads the "input", "object_id"
            and "target" columns from it.
        run_dir: Forwarded unchanged to `train`. Defaults to None
            (presumably supplied by the terra task wrapper -- confirm).

    Returns:
        Whatever `train` returns for this configuration.
    """
    return train(
        # model configuration
        config={"pretrained": False},
        # data and column wiring
        dp=dp,
        input_column="input",
        id_column="object_id",
        target_column="target",
        # checkpointing / loop settings
        ckpt_monitor="valid_auroc",
        batch_size=128,
        run_dir=run_dir,
        val_check_interval=10,
        num_workers=6,
    )
# Launch training on the subsampled, preprocessed DataPanel.
train_vg(dp=dataset_dp)

In [4]:
# Reload the inputs and best checkpoint recorded for an existing terra run
# rather than retraining.
# NOTE(review): presumably run 4675 is a `train_vg` run -- confirm.
run_id = 4675

run_inputs = terra.inp(run_id)
dataset_dp = run_inputs["dp"].load()

best_ckpt_artifacts = terra.get_artifacts(run_id, "best_chkpt")
model = best_ckpt_artifacts["model"]



## Score model
Let's score the model and evaluate on subgroups.

In [5]:
from domino.vision import score

# `model` is still an unloaded artifact at this point; load it first.
loaded_model = model.load()

# Score only the rows assigned to the "test" split.
test_dp = dataset_dp.lz[dataset_dp["split"] == "test"]

score_dp = score(
    model=loaded_model,
    dp=test_dp,
    input_column="input",
    batch_size=128,
)



HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))




In [6]:
# Keep the last column (index -1) of the predicted probabilities as one score
# per row. NOTE(review): presumably this is the positive ("car") class --
# confirm the class ordering of the model output.
score_dp["prob"] = score_dp["output"].probabilities().data[:, -1]

In [7]:
from sklearn.metrics import roc_auc_score

# AUROC over the whole scored set.
overall_auroc = roc_auc_score(score_dp["target"], score_dp["prob"])
print("Overall AUROC: {}".format(overall_auroc))

# AUROC on the two subgroups defined by whether target and correlate agree.
subgroup_masks = (
    (
        "AUROC where correlation holds: {}",
        score_dp["target"] == score_dp["correlate"],
    ),
    (
        "AUROC where correlation does not hold: {}",
        score_dp["target"] != score_dp["correlate"],
    ),
)
for template, mask in subgroup_masks:
    curr_dp = score_dp.lz[mask]
    print(template.format(roc_auc_score(curr_dp["target"], curr_dp["prob"])))

Overall AUROC: 0.916423677103131
AUROC where correlation holds: 0.9337805442892193
AUROC where correlation does not hold: 0.6737804878048781


## Embed in CLIP space to recover spurious correlate

In [8]:
from domino.clip import embed_images
# Embed every object crop from the "object_image" column. Later cells read an
# "emb" column from the result (presumably added by this call -- confirm).
score_dp = embed_images(dp=score_dp, img_column="object_image")

HBox(children=(FloatProgress(value=0.0, max=32.0), HTML(value='')))




In [9]:
from domino.clip import embed_words, get_wiki_words

# Reuse the output stored for an earlier `embed_words` run instead of
# recomputing it. To rebuild the embeddings from scratch:
#   words_dp = get_wiki_words()
#   words_dp = embed_words(words_dp).load()
cached_words_run = 4517
words_dp = embed_words.out(cached_words_run).load()

# Keep only the first 10,000 rows.
words_dp = words_dp.lz[:10_000]

In [10]:
from sklearn.metrics.pairwise import cosine_similarity

dp = score_dp
probs = dp["prob"].data.numpy()
image_embs = dp["emb"].data.numpy()

# Mean image embedding weighted by the model's predicted probability.
attr_emb = np.dot(probs, image_embs) / np.sum(probs)
# Unweighted mean embedding of the rows with target == 1.
ref_emb = dp.lz[dp["target"] == 1]["emb"].data.numpy().mean(axis=0)

# Score every word by cosine similarity with the difference of the two means.
direction = (attr_emb - ref_emb)[np.newaxis, :]
scores = cosine_similarity(words_dp["emb"].data.numpy(), direction).squeeze()
words_dp["score"] = scores

# Show the 20 highest-scoring words.
top_rows = np.argsort(-scores)[:20]
words_dp.lz[top_rows][["word", "score", "frequency"]]

Unnamed: 0,word (PandasSeriesColumn),score (NumpyArrayColumn),frequency (PandasSeriesColumn)
0,seeds,0.039737,50767.0
1,merchandise,0.038270,14959.0
2,roger,0.037758,71403.0
3,costume,0.037489,27696.0
4,cerambycidae,0.037062,23164.0
...,...,...,...
15,antarctica,0.029363,19341.0
16,highest,0.029151,237608.0
17,clothing,0.028741,54617.0
18,wisdom,0.028613,20440.0


In [12]:
# Balanced subset: every positive row plus an equal number of negative rows.
# The negatives are the first ones in row order (no shuffling).
pos_indices = np.where(score_dp["target"] == 1)[0]
n_pos = len(pos_indices)
neg_indices = np.where(score_dp["target"] == 0)[0][:n_pos]

balanced_rows = np.concatenate((pos_indices, neg_indices))
dp = score_dp.lz[balanced_rows]

In [36]:
from sklearn.decomposition import PCA

# NOTE: `cluster_dp` is an alias of `score_dp`, not a copy; columns written to
# one appear on the other.
cluster_dp = score_dp

X = cluster_dp["emb"].numpy()

# One-hot encode the binary target:
# column 0 <-> target == 0, column 1 <-> target == 1.
y = cluster_dp["target"].data
y_oh = np.zeros((len(y), 2))
y_oh[y == 0, 0] = 1
y_oh[y == 1, 1] = 1
y = y_oh

# Model-predicted class probabilities for the same rows.
y_hat = cluster_dp["output"].probs().numpy()

# Reduce the embeddings to 128 dimensions before clustering. The seed is fixed
# because PCA may select a randomized SVD solver depending on the data shape,
# in which case the projection (and every downstream cluster assignment)
# would otherwise change from run to run.
pca = PCA(n_components=128, random_state=0)
X = pca.fit_transform(X)

In [47]:
from domino.sdm.gmm import ErrorGMM

# Mixture model fit on the reduced embeddings together with the one-hot
# labels `y` and the model's predicted probabilities `y_hat`.
gmm = ErrorGMM(
    weight_y_log_likelihood=50,
    n_components=15,
    max_iter=100,
    tol=1e-4,
    covariance_type="diag",
)

# One component assignment per row, stored on the DataPanel for later cells.
out = gmm.fit_predict(X, y, y_hat)
cluster_dp["cluster"] = out

 64%|██████▍   | 64/100 [00:02<00:01, 24.29it/s]


In [48]:
# Column 1 of `gmm.y_probs`, one value per mixture component.
# NOTE(review): presumably each component's probability of the positive
# class -- confirm against ErrorGMM.
gmm.y_probs[:, 1]

array([0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [50]:
# Rank the components by (y_probs - y_hat_probs) in column 1, ascending: the
# first index is the component where the predicted value most exceeds the
# label value.
np.argsort(gmm.y_probs[:, 1] - gmm.y_hat_probs[:, 1])

array([ 4,  7,  5, 11,  8,  3,  0, 12, 14, 10,  2,  9, 13,  6,  1])

In [55]:
from sklearn.metrics.pairwise import cosine_similarity

dp = cluster_dp

# Component to describe, chosen by hand from the component ranking.
cluster_idx = 4

# Select both groups from the stored "cluster" column so they are guaranteed
# to be complementary. The free variable `out` was previously used for the
# in-cluster mask; it can go stale independently of the column.
in_emb = dp.lz[dp["cluster"] == cluster_idx]["emb"].data.numpy().mean(axis=0)
out_emb = dp.lz[dp["cluster"] != cluster_idx]["emb"].data.numpy().mean(axis=0)

# Score every word by cosine similarity with the direction pointing from the
# rest of the data towards the selected component.
scores = cosine_similarity(
    words_dp["emb"].data.numpy(),
    np.expand_dims(in_emb - out_emb, axis=0),
).squeeze()
words_dp["score"] = scores

# Show the 10 highest-scoring words.
words_dp.lz[(-scores).argsort()[:10]][["word", "score", "frequency"]]

Unnamed: 0,word (PandasSeriesColumn),score (NumpyArrayColumn),frequency (PandasSeriesColumn)
0,buses,0.119851,52234
1,bus,0.107675,149318
2,vehicles,0.106228,123799
3,vehicle,0.087376,115819
4,locomotives,0.082853,56369
5,automobile,0.082485,33268
6,drivers,0.081335,54109
7,ambulance,0.081285,17907
8,automotive,0.081184,21286
9,cars,0.078107,158080


In [54]:
# Inspect the rows assigned to the selected component.
# NOTE: relies on `cluster_idx` being set by another cell.
cluster_dp.lz[cluster_dp["cluster"] == cluster_idx]

Unnamed: 0,image_id (NumpyArrayColumn),h (NumpyArrayColumn),name (PandasSeriesColumn),object_id (NumpyArrayColumn),w (NumpyArrayColumn),x (NumpyArrayColumn),y (NumpyArrayColumn),index (PandasSeriesColumn),image (ImageColumn),height (NumpyArrayColumn),width (NumpyArrayColumn),object_image (LambdaColumn),target (NumpyArrayColumn),correlate (NumpyArrayColumn),split_hash (NumpyArrayColumn),split (PandasSeriesColumn),input (LambdaColumn),output (ClassificationOutputColumn),prob (TensorColumn),__embed_images_input__ (LambdaColumn),emb (TensorColumn),clusters (NumpyArrayColumn),cluster (NumpyArrayColumn)
0,2344591.0,40.0,stairs,2841877.0,48.0,448.0,200.0,131711,,315.0,500.0,,0.0,0.0,0.80888,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.9645),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
1,2321271.0,47.0,windows,3182838.0,33.0,221.0,42.0,402525,,333.0,500.0,,0.0,0.0,0.80436,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.9788),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
2,2341792.0,71.0,motorcycle,2145153.0,55.0,75.0,181.0,1315163,,375.0,500.0,,0.0,0.0,0.96439,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.6690),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
3,2386312.0,101.0,sign,1287422.0,441.0,28.0,115.0,979813,,374.0,500.0,,0.0,0.0,0.92777,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.9825),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
4,2410781.0,33.0,face,1085378.0,26.0,117.0,83.0,182993,,375.0,500.0,,0.0,0.0,0.94415,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.6487),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
50,2327616.0,39.0,train,2884367.0,184.0,95.0,215.0,257670,,333.0,500.0,,0.0,0.0,0.82574,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.8739),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
51,2350868.0,21.0,outfit,864829.0,26.0,219.0,158.0,1098693,,243.0,500.0,,0.0,0.0,0.98816,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.9993),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
52,2331885.0,64.0,pants,3162401.0,83.0,113.0,175.0,1043370,,333.0,500.0,,0.0,0.0,0.85784,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.9892),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0
53,2386431.0,34.0,ground,1286089.0,183.0,136.0,385.0,1143222,,500.0,333.0,,0.0,0.0,0.96835,test,"LambdaCell(fn=Compose(  Resize(size=256, interpolation=bilinear)  CenterCrop(size=(224, 224))  ToTensor()  Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ))",torch.Tensor(shape=torch.Size([2])),tensor(0.9967),"LambdaCell(fn=Compose(  Resize(size=224, interpolation=bicubic)  CenterCrop(size=(224, 224))  . at 0x7f48c14dca60>  ToTensor()  Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ))",torch.Tensor(shape=torch.Size([512])),6.0,4.0


In [12]:
dp = score_dp
probs = dp["prob"].data.numpy()
image_embs = dp["emb"].data.numpy()

# Mean image embedding weighted by the model's predicted probability.
attr_emb = np.dot(probs, image_embs) / np.sum(probs)
# Unweighted mean embedding of the rows with target == 1.
ref_emb = dp.lz[dp["target"] == 1]["emb"].data.numpy().mean(axis=0)

# Score every word by its raw (unnormalized) dot product with the difference
# of the two means.
direction = attr_emb - ref_emb
scores = np.dot(words_dp["emb"].data.numpy(), direction)
words_dp["score"] = scores

# Show the 20 highest-scoring words.
top_rows = np.argsort(-scores)[:20]
words_dp.lz[top_rows][["word", "score", "frequency"]]

Unnamed: 0,word (PandasSeriesColumn),score (NumpyArrayColumn),frequency (PandasSeriesColumn)
0,roger,0.361980,71403.0
1,seeds,0.352969,50767.0
2,costume,0.351968,27696.0
3,merchandise,0.351417,14959.0
4,islands,0.325014,230088.0
...,...,...,...
15,john,0.275752,1011147.0
16,clothing,0.274674,54617.0
17,cerambycidae,0.274526,23164.0
18,hiking,0.272910,18533.0


In [46]:
# Raise meerkat's display limit to 91 rows (global display setting).
mk.config.DisplayOptions.max_rows = 91

# Rows misclassified at a 0.5 probability threshold ...
predicted_pos = (score_dp["prob"] > 0.5).numpy()
actual_pos = score_dp["target"] == 1
error_dp = score_dp.lz[predicted_pos != actual_pos]

# ... restricted to true negatives, i.e. the false positives.
error_dp = error_dp.lz[error_dp["target"] == 0]

error_dp[["object_image", "prob", "target", "correlate"]]

Unnamed: 0,object_image (LambdaColumn),prob (TensorColumn),target (NumpyArrayColumn),correlate (NumpyArrayColumn)
0,,tensor(0.9645),0,0
1,,tensor(0.9788),0,0
2,,tensor(0.6690),0,0
3,,tensor(0.9825),0,0
4,,tensor(0.6487),0,0
5,,tensor(0.8852),0,0
6,,tensor(0.5159),0,0
7,,tensor(0.7781),0,0
8,,tensor(0.8611),0,0
9,,tensor(0.9998),0,0
