<a href="https://colab.research.google.com/github/BhardwajArjit/Research-Paper-Replication/blob/main/CaaM_Replication.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook replicates the paper "Causal Attention for Unbiased Visual Recognition" in PyTorch.

Paper: https://arxiv.org/abs/2108.08782

The **Causal Attention Module** (CaaM) iteratively generates data partitions and progressively self-annotates the confounders without supervision, which lets it overcome the over-adjustment problem.
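To make "data partition" more concrete, here is a minimal PyTorch sketch of one way to score a partition: an IRMv1-style invariance penalty (from Invariant Risk Minimization, Arjovsky et al., 2019), summed over partitions. This well-known penalty is swapped in for illustration. It is an assumption that it resembles the objective CaaM uses, and it is not the notebook's code. The names `irm_penalty` and `partitioned_invariance_penalty`, as well as the toy tensors, are hypothetical. The penalty stays near zero when one classifier is simultaneously optimal on every partition, so a large value means the current features do not behave the same way across partitions. The iterative partition update itself is not shown.

```python
import torch
import torch.nn.functional as F


def irm_penalty(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """IRMv1-style penalty: squared gradient of the risk w.r.t. a dummy classifier scale fixed at 1.0."""
    scale = torch.ones(1, requires_grad=True, device=logits.device)
    loss = F.cross_entropy(logits * scale, labels)
    # create_graph=True keeps the penalty differentiable w.r.t. the logits
    (grad,) = torch.autograd.grad(loss, [scale], create_graph=True)
    return grad.pow(2).sum()


def partitioned_invariance_penalty(
    logits: torch.Tensor, labels: torch.Tensor, partition_ids: torch.Tensor, num_partitions: int
) -> torch.Tensor:
    """Sum the penalty over every (hypothetical) data partition t that has samples."""
    total = logits.new_zeros(())
    for t in range(num_partitions):
        mask = partition_ids == t
        if mask.any():
            total = total + irm_penalty(logits[mask], labels[mask])
    return total


# Toy usage: 8 samples, 3 classes, split into 2 partitions
torch.manual_seed(0)
logits = torch.randn(8, 3, requires_grad=True)
labels = torch.randint(0, 3, (8,))
partition_ids = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

penalty = partitioned_invariance_penalty(logits, labels, partition_ids, num_partitions=2)
penalty.backward()  # gradients reach the logits (and, in a real model, the backbone)
print(f"invariance penalty: {penalty.item():.4f}")
```

In a full training loop, this term would be added to the usual cross-entropy loss and weighted by a hyperparameter.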

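The goal of CaaM is to help attention-based recognition models (both CNNs and Vision Transformers) learn causal features that stay robust out of distribution (OOD) by adjusting for confounders more accurately.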
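In causal inference, adjusting for a confounder is commonly done with backdoor adjustment: $P(Y \mid do(X)) = \sum_{t} P(Y \mid X, t)\,P(t)$, where $t$ ranges over the confounder's strata. Plain conditioning instead computes $P(Y \mid X) = \sum_{t} P(Y \mid X, t)\,P(t \mid X)$. The sketch below treats two data partitions as those strata. That mapping is an assumption made for illustration, and all numbers are made up to show the arithmetic.

```python
# Hypothetical toy numbers: two partitions (strata) t of a confounder
p_t = {"t1": 0.5, "t2": 0.5}            # P(t): marginal share of each partition
p_t_given_x = {"t1": 0.9, "t2": 0.1}    # P(t | X): X co-occurs mostly with t1
p_y_given_x_t = {"t1": 0.8, "t2": 0.4}  # P(Y | X, t)

# Plain conditioning weights each stratum by how often it co-occurs with X
p_y_given_x = sum(p_y_given_x_t[t] * p_t_given_x[t] for t in p_t)

# Backdoor adjustment weights each stratum by its marginal probability instead
p_y_do_x = sum(p_y_given_x_t[t] * p_t[t] for t in p_t)

print(f"P(Y | X)     = {p_y_given_x:.2f}")  # 0.76
print(f"P(Y | do(X)) = {p_y_do_x:.2f}")     # 0.60
```

The gap between the two values is the bias introduced when the confounder is left unadjusted.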

## 0. Get Set Up

In [17]:
```python
import torch
import torchvision

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

# timm is not always preinstalled, so install it only if the import fails
try:
    import timm
except ImportError:
    print("[INFO] Couldn't find timm... installing it.")
    !pip install -q timm
    import timm
    print("timm installed successfully...")

print(f"timm version: {timm.__version__}")
```

```
torch version: 2.1.0+cu118
torchvision version: 0.16.0+cu118
timm version: 0.9.8
```


In [18]:
```python
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# torchinfo provides summary() for inspecting model architectures
try:
    from torchinfo import summary
except ImportError:
    print("[INFO] Couldn't find torchinfo... installing it.")
    !pip install -q torchinfo
    from torchinfo import summary
    print("torchinfo installed successfully...")

# Helper functions (download_data, set_seeds, plot_loss_curves) live in a separate repository
try:
    from pytorch_utils import download_data, set_seeds, plot_loss_curves
except ImportError:
    # If the import fails, clone the repository and keep only the helper module
    !git clone https://github.com/BhardwajArjit/Helper-Functions.git
    !mv Helper-Functions/pytorch_utils.py .
    !rm -rf Helper-Functions
    from pytorch_utils import download_data, set_seeds, plot_loss_curves
```

In [19]:
```python
# Set up device-agnostic code
device = "cuda" if torch.cuda.is_available() else "cpu"
device
```

```
'cuda'
```
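Setting `device` once keeps the rest of the notebook hardware-agnostic: models are moved with `.to(device)`, and tensors are either moved the same way or created with `device=device`. Here is a minimal sketch of the pattern, with a hypothetical `nn.Linear` model and random batch standing in for the real network and data:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical tiny model and batch, used only to show the pattern
model = nn.Linear(in_features=4, out_features=2).to(device)  # moves the parameters to the device
x = torch.randn(3, 4, device=device)                         # creates the inputs directly on the device

# Model and inputs share a device, so the same code runs on GPU or CPU
logits = model(x)
print(logits.shape, logits.device)
```

Mixing devices, for example a CUDA model with a CPU tensor, raises a runtime error, so both sides need the same `device`.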