In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt

import sys
sys.path.append('./src')
import warnings
warnings.filterwarnings("ignore")

from data import load_data, get_data_sl, get_data_cl
from visualize import plot_outcome_distribution, visualize_examples
from model import get_model
from causal import compute_ate
from train import train_model

## Load Data

In [None]:
# Load both dataset environments through the project's load_data helper,
# in the same order as before: supervised first, then unsupervised.
supervised, unsupervised = (
    load_data(environment=env_name)
    for env_name in ('supervised', 'unsupervised')
)

### Sanity Check

In [None]:
# Pick the pretrained encoder by name and fetch its (processor, model) pair
# from the project's get_model helper. Both names are reused by the
# sanity-check cells below.
encoder_name = "vit"  # presumably a Vision Transformer checkpoint — confirm in src/model.py
processor, model = get_model(encoder_name)

In [None]:
# Sanity check: run one supervised sample through the pretrained encoder,
# print its top-5 predicted labels, and show the image with its outcome.
idx = 15000
img = supervised[idx]['image']
outcome = supervised[idx]['outcome']

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    inputs = processor(images=img, return_tensors="pt")
    outputs = model(**inputs)
logits = outputs.logits

print("Top 5 predicted labels with associated probabilities:")
top_5 = torch.topk(logits, 5)
# Probabilities of the five highest-scoring classes for the single image.
probs = logits.softmax(-1)[0][top_5.indices[0]]
# The loop variables are named so they do not overwrite `idx`, the sample
# index chosen above (the previous loop rebound `idx` to a class-id tensor).
for rank, (label_id, prob) in enumerate(zip(top_5.indices[0], probs), 1):
    print(f"    {rank}. {model.config.id2label[label_id.item()]}: {prob.item():.2%}")

# Move the first axis last for imshow — assumes the dataset yields
# (channels, height, width) tensors; TODO confirm. `img` stays rebound to the
# permuted tensor, as before, because the next cell reads `img` again.
img = img.permute(1, 2, 0)
plt.title(f"Y2F: {int(outcome[0])}, B2F: {int(outcome[1])}")
plt.imshow(img);

In [None]:
# Second forward pass over the same image, this time asking the model to
# return the hidden states of every layer.
inputs = processor(images=img, return_tensors="pt")
outputs = model(**inputs, output_hidden_states=True)
final_layer = outputs.hidden_states[-1]
# Cell output: shape of the final layer's features at position 0 of the
# second axis (presumably the [CLS] token for a ViT — confirm).
final_layer[:, 0].shape
# Alternative left disabled in the original:
# outputs.hidden_states[-1].mean(dim=[2,3]).shape

## Supervised Learning

In [None]:
# Settings for the supervised-learning dataset, gathered in one place and
# forwarded as keyword arguments to the project's get_data_sl helper.
sl_data_options = {
    "environment": "supervised",
    "encoder_name": "dino",
    "task": "or",
    "split_criteria": "experiment_easy",
    "token": "mean",
}
X, y, split = get_data_sl(**sl_data_options)

In [None]:
# Train a model on (X, y) with the given split and hyperparameters.
# NOTE: this rebinds `model`, replacing the encoder returned by get_model in
# the sanity-check cells; re-run get_model before repeating those cells.
model = train_model(X, y,
                    split=split,
                    batch_size=256,
                    num_epochs=10,
                    lr=0.01,
                    verbose=True)

# Move the features to the model's device once and reuse them for both calls
# (previously X was transferred twice).
X_on_device = X.to(model.device)
# Prediction only: no autograd graph is needed. Later `.detach()` calls on
# these tensors remain valid.
with torch.no_grad():
    y_probs = model.probs(X_on_device).to("cpu")
    y_pred = model.pred(X_on_device).to("cpu")
# Drop the device copy so it does not linger as a notebook global.
del X_on_device

In [None]:
# Compare predicted probabilities, hard predictions and ground truth for the
# samples with experiment == 4 and position == 1.
# `exp` and `pos` keep their names because the next cell reuses them.
exp = (supervised["experiment"]==4)
pos = (supervised["position"]==1)

# Indices of the selected samples. squeeze(-1) drops only the trailing axis
# added by nonzero(), so one match still gives a 1-D tensor; bare squeeze()
# would collapse it to 0-d and break len(). The name also no longer shadows
# the builtin `filter`.
selected = (exp & pos).nonzero().squeeze(-1)

# Index and detach each series once instead of repeating it per use.
probs_sel = y_probs[selected].detach()
pred_sel = y_pred[selected].detach()
true_sel = y[selected].detach()
positions = range(len(selected))

plt.scatter(positions, probs_sel, s=1, c="blue", alpha=0.5, label="y_probs")
# The (-1)**v term shifts 0-valued points down and 1-valued points up, so the
# red and green markers do not sit exactly on top of each other.
plt.scatter(positions, pred_sel-(-1)**pred_sel*0.04, s=1, c="red", alpha=0.5, label="y_pred")
plt.scatter(positions, true_sel-(-1)**true_sel*0.02, s=1, c="green", alpha=0.5, label="y")
plt.legend()
plt.show()

In [None]:
# Find the one sample matching the experiment/position masks from the
# previous cell and frame 2220, then display it with its outcome.
frame = supervised["frame"] == 2220
match_mask = exp & pos & frame
# .item() raises unless exactly one sample matches the combined mask.
idx = match_mask.nonzero().item()
img, outcome = (supervised[idx][field] for field in ("image", "outcome"))

# Move the first axis last before handing the image to imshow.
img = img.permute(1, 2, 0)
plt.title(f"Y2F: {int(outcome[0])}, B2F: {int(outcome[1])}")
plt.imshow(img);

In [None]:
# Options forwarded to the project's visualize_examples helper; `model` is
# the model trained in the supervised-learning section above.
example_plot_options = dict(
    n=36,
    encoder_name="dino",
    model=model,
    save=True,
    data_dir="./data",
    results_dir="./results",
)
visualize_examples(**example_plot_options)

#### Fine-tuning Results

- **Encoder**: mae, vit_large
- **Token**: not clear
- **Task**: all
- **Learning Rate**: not clear

In [None]:
# Read the per-run results, drop incomplete rows, average over seeds for each
# (encoder, token, task, lr) combination, rank by validation accuracy, and
# keep only the rows for the "all" task.
run_keys = ["encoder", "token", "task", "lr", "seed"]
results = (
    pd.read_csv("results/head_training_results.csv", index_col=0)
    .dropna()
    .set_index(run_keys)
    .groupby(["encoder", "token", "task", "lr"])
    .mean()
    .sort_values("val_accuracy", ascending=False)
    .xs("all", level="task")
)
# Cell output: the 20 best configurations.
results.head(20)

## Causal Inference

In [None]:
# Load the supervised environment again (it is also loaded in the
# "Load Data" section), so this cell does not rely on the earlier binding.
supervised = load_data(environment='supervised')
# Plot the outcome distribution of the supervised data. save=True is passed
# to the helper — presumably it writes the figure to disk; confirm in
# src/visualize.py.
plot_outcome_distribution(supervised, save=True)

In [None]:
# Estimate average treatment effects with several estimators and tabulate
# them, one row per method.
# NOTE: this rebinds `X` (the feature matrix in the supervised-learning
# section) and `results` (the head-training table); re-run those sections
# before reusing their cells.
Y, T, X = get_data_cl(aggregate=True)
methods = ["ead", "aipw", "slearner", "tlearner", "xlearner", "drlearner", "causalforest"]
ate_columns = ["ATE_B_blue", "ATE_inf_blue", "ATE_B_yellow", "ATE_inf_yellow"]

# Collect one row per method and build the frame once, rather than filling a
# pre-allocated object-dtype frame row by row with .loc.
ate_rows = {}
for method in methods:
    # Same call order as before: yellow first, then blue.
    ATE_B_yellow, ATE_inf_yellow = compute_ate(Y, T, X, color="yellow", method=method, verbose=False)
    ATE_B_blue, ATE_inf_blue = compute_ate(Y, T, X, color="blue", method=method, verbose=False)
    ate_rows[method] = [ATE_B_blue, ATE_inf_blue, ATE_B_yellow, ATE_inf_yellow]

results = pd.DataFrame.from_dict(ate_rows, orient="index", columns=ate_columns)
results