# Import Libraries

In [1]:
%load_ext autoreload
%autoreload 2  
%autosave 10

Autosaving every 10 seconds


In [2]:
import torch, torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import pandas as pd
from PIL import Image
import random, os, sys, argparse
from pathlib import Path
from tqdm import tqdm
import pickle
from PIL import Image

In [3]:
sys.path.insert(0,'/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/toy_expts/')
from models import *

# Train Model

In [76]:
# user hyperparams
# Architecture name; the model-building cell dispatches on this string.
model = 'resnet18' # resnet18, vgg16, densenet121
# CSV files loaded into df_train / df_test by the data-loading cell.
train_csv = '/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/toy_expts/domino_expts/top_svhn_bot_fmnist_2class_ro_1p0.csv'
test_csv = '/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/toy_expts/domino_expts/top_blank_bot_fmnist_2class_ro_1p0.csv'
# Run identifier used in the checkpoint / pickle / plot file names.
# The random suffix is drawn before the seeding cell runs.
expt_name = f'temp_{int(random.random()*100000)}'# f'resnet18_top_mnist_bot_kmnist_{int(random.random()*100000)}'
# Directory that receives every file written by this notebook.
save_dir = '/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/toy_expts/domino_expts/output/'

num_epochs = 50 # number of train + test passes in the training loop
lr = 0.1 # initial learning rate handed to SGD
seed = 10 # seed applied to the numpy, random and torch RNGs
num_ch = 3 # num of channels in image
# Number of training images whose layer embeddings are extracted for KNN.
num_embs = 2000 # 1500 for mnist, 10k for cifar10
K = 29 # K neighbours
num_test_imgs = 20 # num of test images for plotting PD
lp_norm = 1 # for computing KNN
# The PD loop only supports both thresholds being 0.5; other values raise there.
knn_pos_thresh = 0.5
knn_neg_thresh = 0.5

In [77]:
# Setting the seed
# Seed every RNG source used in this notebook: NumPy, Python, PyTorch CPU and CUDA.
for _seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
cudnn.deterministic = True
cudnn.benchmark = False

In [78]:
# Full training table; the test table is down-sampled to 500 randomly drawn rows.
df_train = pd.read_csv(train_csv)
df_test = pd.read_csv(test_csv).sample(500)
# df_test = df_test[df_test['test_split']==0].sample(500)

In [79]:
class ToyDataset(torch.utils.data.Dataset):
    """Image dataset backed by a DataFrame with a "path" column and a label column.

    Args:
        class_field: name of the column holding the integer label. A one-element
            list such as ``['bottom']`` is accepted as well.
        transform: callable applied to the opened PIL image, or None to return
            the image unchanged.
        csv_path: CSV file to read when ``df`` is not supplied.
        df: already-loaded DataFrame; takes precedence over ``csv_path``.

    Raises:
        ValueError: if neither ``df`` nor ``csv_path`` is provided.
    """

    def __init__(self, class_field, transform, csv_path=None, df=None):
        self.class_field = class_field
        self.transform = transform
        if df is not None:
            self.df = df
        elif csv_path is not None:
            # Attribute name `csvpath` is kept so existing readers of it still work.
            self.csvpath = csv_path
            self.df = pd.read_csv(self.csvpath)
        else:
            raise ValueError("ToyDataset needs either `df` or `csv_path`")

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with the (transformed) image, its label, the index and the file path."""
        row = self.df.iloc[idx]  # look the row up once and reuse it
        img_path = str(row["path"])
        lab_value = row[self.class_field]
        if isinstance(lab_value, pd.Series):
            # class_field was given as a list: unwrap the single selected value.
            lab_value = lab_value.item()
        lab = int(lab_value)
        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)

        return {"img":img, "lab":lab, "idx":idx, "file_name" : img_path}

In [80]:
# Convert each PIL image to a tensor, then resize it to (64, 32).
trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((64,32)),
])
# class_field is passed as a plain column name so the label lookup yields a scalar
# (a one-element list only worked through the deprecated int(Series) conversion).
trainset = ToyDataset(class_field='bottom', transform=trans, df=df_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=2)
testset = ToyDataset(class_field='bottom', transform=trans, df=df_test)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=2)

In [81]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy


# Model
print('==> Building model..')
if model=='resnet18':
    net = ResNet18(num_channels=num_ch)
elif model=='vgg16':
    net = VGG('VGG16',num_channels=num_ch)
elif model=='densenet121':
    net = DenseNet121()
else:
    # Fail loudly instead of silently reusing a `net` left over from an earlier run.
    raise ValueError(f'Unknown model {model!r}; expected resnet18, vgg16 or densenet121')
# Attach a 1024 -> 10 linear head under the attribute name `linear`.
# NOTE(review): this is applied to every architecture; confirm that vgg16 / densenet121
# from `models` actually route their forward pass through `linear` with 1024 features.
net.linear = nn.Linear(in_features=1024,out_features=10,bias=True)
net = net.to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    # cudnn.benchmark is deliberately left as configured by the seeding cell:
    # turning it on here would undo the deterministic setup made there.


criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr,
                      momentum=0.9, weight_decay=5e-4)
# NOTE(review): T_max=200 while num_epochs is configured separately; with fewer epochs
# only the first part of the cosine curve is traversed — confirm this is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


# Training
def train(epoch):
    """Run one optimisation pass over ``trainloader`` and print mean loss and accuracy.

    Uses the notebook-level ``net``, ``optimizer``, ``criterion`` and ``device``.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    for batch_idx, batch in enumerate(tqdm(trainloader)):
        inputs = batch['img'].to(device)
        targets = batch['lab'].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate running statistics for the end-of-epoch summary.
        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

    summary = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
        running_loss/(batch_idx+1), 100.*n_correct/n_seen, n_correct, n_seen)
    print(batch_idx, len(trainloader), summary)


def test(epoch):
    """Evaluate ``net`` on ``testloader``, print mean loss / accuracy and save a checkpoint.

    The checkpoint (weights, accuracy, epoch) is written to ``{save_dir}/{expt_name}.pt``
    on every call, overwriting the previous one; gating on best accuracy is disabled.
    """
    global best_acc
    net.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(testloader)):
            inputs = batch['img'].to(device)
            targets = batch['lab'].to(device)
            outputs = net(inputs)

            running_loss += criterion(outputs, targets).item()
            predicted = outputs.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()

    acc = 100.*n_correct/n_seen
    print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                 % (running_loss/(batch_idx+1), acc, n_correct, n_seen))

    # Save checkpoint (unconditionally; `best_acc` is not consulted).
    print('Saving..')
    ckpt_file = os.path.join(save_dir,f'{expt_name}.pt')
    torch.save({'net': net.state_dict(), 'acc': acc, 'epoch': epoch}, ckpt_file)


# Per epoch: one training pass, one evaluation pass (which also writes the checkpoint),
# then advance the learning-rate schedule.
for epoch in range(num_epochs):
    train(epoch)
    test(epoch)
    scheduler.step()

==> Building model..

Epoch: 0


100%|██████████| 38/38 [00:08<00:00,  4.58it/s]


37 38 Loss: 4.022 | Acc: 70.885% (6805/9600)


100%|██████████| 2/2 [00:03<00:00,  1.86s/it]


1 2 Loss: 0.532 | Acc: 89.400% (447/500)
Saving..

Epoch: 1


100%|██████████| 38/38 [00:08<00:00,  4.52it/s]


37 38 Loss: 0.228 | Acc: 93.625% (8988/9600)


100%|██████████| 2/2 [00:00<00:00,  3.49it/s]


1 2 Loss: 0.430 | Acc: 89.200% (446/500)
Saving..

Epoch: 2


100%|██████████| 38/38 [00:08<00:00,  4.56it/s]


37 38 Loss: 0.123 | Acc: 95.938% (9210/9600)


100%|██████████| 2/2 [00:00<00:00,  3.96it/s]


1 2 Loss: 0.240 | Acc: 94.000% (470/500)
Saving..

Epoch: 3


100%|██████████| 38/38 [00:08<00:00,  4.52it/s]


37 38 Loss: 0.092 | Acc: 96.979% (9310/9600)


100%|██████████| 2/2 [00:00<00:00,  3.41it/s]


1 2 Loss: 0.427 | Acc: 85.400% (427/500)
Saving..

Epoch: 4


100%|██████████| 38/38 [00:08<00:00,  4.55it/s]


37 38 Loss: 0.079 | Acc: 97.438% (9354/9600)


100%|██████████| 2/2 [00:00<00:00,  3.52it/s]


1 2 Loss: 0.218 | Acc: 93.400% (467/500)
Saving..

Epoch: 5


100%|██████████| 38/38 [00:08<00:00,  4.54it/s]


37 38 Loss: 0.069 | Acc: 97.781% (9387/9600)


100%|██████████| 2/2 [00:00<00:00,  3.63it/s]


1 2 Loss: 0.124 | Acc: 95.800% (479/500)
Saving..

Epoch: 6


 13%|█▎        | 5/38 [00:11<01:17,  2.36s/it]


KeyboardInterrupt: 

In [83]:
# Echo the configured training CSV path.
train_csv

'/jet/home/nmurali/asc170022p/nmurali/projects/shortcut_detection_and_mitigation/experiments/toy_expts/domino_expts/top_svhn_bot_fmnist_2class_ro_1p0.csv'

In [None]:
# loop the whole code

### Obtain Train Subset Embeddings

In [None]:
# load checkpoint
# expt_name = 'resnet18_top_cifar10_bot_fmnist_64x32_60744_ep6'
# Recomputed here so this cell does not depend on the training cell having been run.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Rebuild the network exactly the way the training cell does, so the state-dict keys match.
if model=='resnet18':
    net = ResNet18(num_channels=num_ch)
elif model=='vgg16':
    net = VGG('VGG16',num_channels=num_ch)
elif model=='densenet121':
    net = DenseNet121()
else:
    raise ValueError(f'Unknown model {model!r}; expected resnet18, vgg16 or densenet121')
# The training cell attaches this head to every architecture before saving.
net.linear = nn.Linear(in_features=1024,out_features=10,bias=True)
# Move to the target device *before* wrapping: DataParallel requires the module's
# parameters to live on its first device. The hook helpers rely on `net.module`.
net = nn.DataParallel(net.to(device))
# NOTE(review): the checkpoint keys are expected to carry the `module.` prefix, i.e. the
# checkpoint was saved from a DataParallel-wrapped model (training on CUDA).
# torch.load unpickles the file — only load checkpoints produced by this notebook.
ckpt = torch.load(os.path.join(save_dir,f'{expt_name}.pt'), map_location=device)
net.load_state_dict(ckpt['net'])
net.eval()

In [None]:
# adding hook function for resnet18
# def add_resnet18_hooks(net, hook):
#     net.module.bn1.register_forward_hook(hook)
    
#     net.module.layer1[0].bn1.register_forward_hook(hook)
#     net.module.layer1[0].shortcut.register_forward_hook(hook)
    
#     net.module.layer1[1].bn1.register_forward_hook(hook)
#     net.module.layer1[1].shortcut.register_forward_hook(hook)

#     net.module.layer2[0].bn1.register_forward_hook(hook)
#     net.module.layer2[0].shortcut.register_forward_hook(hook)
    
#     net.module.layer2[1].bn1.register_forward_hook(hook)
#     net.module.layer2[1].shortcut.register_forward_hook(hook)
    
#     net.module.layer3[0].bn1.register_forward_hook(hook)
#     net.module.layer3[0].shortcut.register_forward_hook(hook)
    
#     net.module.layer3[1].bn1.register_forward_hook(hook)
#     net.module.layer3[1].shortcut.register_forward_hook(hook)
    
#     net.module.layer4[0].bn1.register_forward_hook(hook)
#     net.module.layer4[0].shortcut.register_forward_hook(hook)
    
#     net.module.layer4[1].bn1.register_forward_hook(hook)
#     net.module.layer4[1].shortcut.register_forward_hook(hook)
        
#     return net

def add_resnet18_hooks(net, hook):
    """Register ``hook`` as a forward hook on the stem ``bn1`` and on ``conv1`` / ``conv2``
    of blocks 0 and 1 in ``layer1`` .. ``layer4`` of ``net.module``.

    Returns the same ``net`` object that was passed in.
    """
    backbone = net.module
    backbone.bn1.register_forward_hook(hook)

    for stage_name in ('layer1', 'layer2', 'layer3', 'layer4'):
        stage = getattr(backbone, stage_name)
        for block_idx in range(2):
            block = stage[block_idx]
            block.conv1.register_forward_hook(hook)
            block.conv2.register_forward_hook(hook)

    return net

def add_vgg16_hooks(net, hook):
    """Register ``hook`` as a forward hook on a fixed set of entries of ``net.module.features``.

    Returns the same ``net`` object that was passed in.
    """
    # Positions inside `features` that receive the hook.
    hooked_positions = (0, 3, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40)
    for pos in hooked_positions:
        net.module.features[pos].register_forward_hook(hook)
    return net

def add_densenet121_hooks(net, hook):
    """Register ``hook`` as a forward hook on every even-indexed element of
    ``dense1`` .. ``dense4`` of ``net.module``.

    Returns the same ``net`` object that was passed in.
    """
    for block_name in ('dense1', 'dense2', 'dense3', 'dense4'):
        dense_block = getattr(net.module, block_name)
        for position, layer in enumerate(dense_block):
            if position % 2:
                continue  # odd positions are skipped
            layer.register_forward_hook(hook)

    return net

In [None]:
feature_maps = []  # This will be a list of Tensors, each representing a feature map
def hook_feat_map(mod, inp, out):
    """Forward hook: flatten ``out`` to (batch, -1) and append it to the global ``feature_maps``.

    ``mod`` and ``inp`` are part of the forward-hook signature and are not used.
    """
    flattened = out.reshape(out.shape[0], -1)
    feature_maps.append(flattened)

# Attach the feature-map hook using the installer that matches the configured architecture.
# An unrecognised `model` leaves `net` untouched.
_HOOK_INSTALLERS = {
    'resnet18': add_resnet18_hooks,
    'vgg16': add_vgg16_hooks,
    'densenet121': add_densenet121_hooks,
}
if model in _HOOK_INSTALLERS:
    net = _HOOK_INSTALLERS[model](net, hook_feat_map)

In [None]:
def to_cpu(arr):
    """Replace every element of ``arr`` in place with its ``.to('cpu')`` copy and return ``arr``."""
    for pos in range(len(arr)):
        arr[pos] = arr[pos].to('cpu')
    return arr

def print_memory_profile(s):
    """Print the label ``s`` followed by the memory figures of CUDA device 0 in GiB.

    The three numbers are total device memory, memory reserved by PyTorch and memory
    currently allocated by PyTorch. When CUDA is unavailable a notice is printed
    instead, so that calling cells do not crash on CPU-only machines.
    """
    print(s)
    if not torch.cuda.is_available():
        print('CUDA not available; skipping GPU memory profile')
        print('\n')
        return
    gib = 1024**3
    t = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    print(t/gib,r/gib,a/gib)
    print('\n')


In [None]:
# Draw `num_embs` random training indices and materialise their images and labels.
perm = torch.randperm(len(trainset))
inds = perm[:num_embs]

labs = []
imgs = []
for i in tqdm(inds):
    # Fetch each item once: every lookup opens and transforms the image file.
    item = trainset[int(i)]
    labs.append(item['lab'])
    imgs.append(item['img'])
# Stack once at the end instead of growing the tensor with torch.cat on every step.
samples = torch.stack(imgs) if imgs else torch.empty((0,3,64,32))
labs = torch.tensor(labs)

In [None]:
# Wrap the sampled training images and labels in a TensorDataset.
samples_resized = samples  # plain alias; no resizing happens here
train_subset = torch.utils.data.TensorDataset(samples_resized,labs)
# NOTE: trainloader2 is rebuilt with a different batch size in the embedding-saving cell.
trainloader2 = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=2)

In [None]:
# code for saving pkl file of layer embeddings
save_path = os.path.join(save_dir,f'{expt_name}.pkl')
trainloader2 = torch.utils.data.DataLoader(train_subset, batch_size=30000, shuffle=True, num_workers=2)

# One pickle record is appended per batch. The `with` block guarantees the file is
# closed even if a batch raises part-way through.
with open(save_path, "wb") as handle, torch.no_grad():
    net.eval()
    for b_idx,batch in enumerate(tqdm(trainloader2)):
        # print GPU memory
        print_memory_profile('Initial')

        imgs = batch[0].to('cuda')
        labels = batch[1]

        # Rebind the global list the forward hooks append to, so it holds this batch only.
        feature_maps = []
        out = net(imgs.float())

        # NOTE(review): the tensors in 'feats' are pickled on whatever device the hooked
        # layers produced them on — presumably CUDA; the PD loop relies on that.
        info_dict = {'batch_idx':b_idx,'num_batches':len(trainloader2),'feats':feature_maps,'labels':labels}
        pickle.dump(info_dict, handle)

        # print GPU memory
        print_memory_profile('After processing Batch')

        # free up GPU memory
        del feature_maps, info_dict
        torch.cuda.empty_cache()

        # print GPU memory
        print_memory_profile('After freeing GPU memory')
        

### Compute PD

In [None]:
# compute_pd.py

def compute_pred_depth(arr):
    """Return the 1-based position at which the trailing run equal to ``arr[-1]`` starts.

    Scans backwards from the second-to-last entry; the first entry that differs from
    the final one sits immediately before that run. Returns 1 when every entry equals
    the last one.
    """
    final_pred = arr[-1]
    for pos in range(len(arr) - 2, -1, -1):
        if arr[pos] != final_pred:
            # `pos` is 0-based; the run starts one element later, reported 1-based.
            return pos + 2
    return 1

In [None]:
# Draw `num_test_imgs` random test indices and materialise their images and labels.
perm = torch.randperm(len(testset))
inds = perm[:num_test_imgs]

labs = []
imgs = []
for i in tqdm(inds):
    # Fetch each item once: every lookup opens and transforms the image file.
    item = testset[int(i)]
    labs.append(item['lab'])
    imgs.append(item['img'])
# Stack once at the end instead of growing the tensor with torch.cat on every step.
samples = torch.stack(imgs) if imgs else torch.empty((0,3,64,32))
labs = torch.tensor(labs)

In [None]:
# Wrap the sampled test images and labels in a TensorDataset.
samples_resized = samples  # plain alias; no resizing happens here
test_subset = torch.utils.data.TensorDataset(samples_resized,labs)
# The PD loop iterates `test_subset` directly rather than through this loader.
testloader2 = torch.utils.data.DataLoader(test_subset, batch_size=128, shuffle=True, num_workers=2)

In [None]:
# ===================== Storing Batch Statistics =====================

train_embs_pkl_path = os.path.join(save_dir,f'{expt_name}.pkl')
ckpt_path = os.path.join(save_dir,f'{expt_name}.pt')

# Accumulator filled by the PD loop: one entry per test image in every list.
batch_info = {
    'readme': f'---- K={K} ---- ckpt_path={ckpt_path} ---- pkl_path={train_embs_pkl_path} ----',
    'imgs': [],             # test images
    'preds': [],            # corresponding model predictions
    'labels': [],           # labels of the test images
    'pd': [],               # corresponding prediction depths
    'layers_knn_prob': [],  # per test image: list of knn means for every layer
    'layers_knn_mode': [],  # per test image: list of knn modes for every layer
}

print_memory_profile('Initial')

In [None]:
# loop over test images
invalid_counter = 0 # for invalid predictions (last layer mode != model output)
to_pil_trans = transforms.ToPILImage()  # built once and reused for every image
for test_id, (img,lab) in enumerate(tqdm(test_subset)):

    batch_info['imgs'].append(img)
    with torch.no_grad():
        # Round-trip the stored tensor through PIL and `trans` before the forward pass.
        img = to_pil_trans(img.to('cuda'))
        img = trans(img).unsqueeze(0)
        lab = int(lab)
        if img.shape[1]==4:
            img = img[:,0,:,:].unsqueeze(0)
        # Rebind the global list the forward hooks append to, so it holds this image only.
        feature_maps = []
        out = net(img)
        pred = int(out.argmax())
        print(f'Model output: {out}')
        batch_info['preds'].append(pred)
        batch_info['labels'].append(lab)

        print_memory_profile('Model forward pass')
        # Only the first record of the embeddings pickle is read.
        # pickle.load can execute arbitrary code — only open files written by this notebook.
        with open(train_embs_pkl_path, 'rb') as handle:
            info_dict = pickle.load(handle)
        print_memory_profile('Pickle load')

        # loop over the hooked layers, and compute KNN for this test image
        knn_preds_mode = []  # layer-wise final KNN classification preds
        knn_preds_prob = []
        for layer_id,feat in tqdm(enumerate(feature_maps)):
            X_i = feat.unsqueeze(1)  # (1, 1, D) embedding of the test image
            X_j = info_dict['feats'][layer_id].unsqueeze(0)  # (1, N_train, D) stored train embeddings
            if lp_norm==2:
                D_ij = ((X_i - X_j) ** 2).sum(-1)  # (1, N_train) squared L2 distances
            elif lp_norm==1:
                D_ij = (abs(X_i - X_j)).sum(-1)  # (1, N_train) L1 distances
            else:
                # A bare string cannot be raised; use a real exception type.
                raise ValueError('Invalid lp_norm in arguments!')

            ind_knn = torch.topk(-D_ij,K,dim=1)  # K smallest distances -> (values, indices), each (1, K)
            # The stored labels were never moved to the GPU, so index them with CPU indices.
            lab_knn = info_dict['labels'][ind_knn[1].cpu()]  # (1, K) labels of the nearest neighbours
            mode = int(lab_knn.squeeze().mode()[0])
            knn_preds_mode.append(mode)
            # Fraction of the K neighbours that agree with the modal label.
            knn_preds_prob.append(float((lab_knn==mode).float().mean()))

        print_memory_profile('Pickle batch processed')

        # free GPU memory
        del info_dict
        torch.cuda.empty_cache()
        print_memory_profile('After GPU memory freed')

        print('Test Image: %d' %(test_id))
        print(f'knn_preds_mode: {knn_preds_mode}')
        print(f'knn_preds_prob: {knn_preds_prob}')
        print(f'label: {lab}')
        print(f'pred: {pred}')
        print('\n')
        batch_info['layers_knn_prob'].append(knn_preds_prob)
        batch_info['layers_knn_mode'].append(knn_preds_mode)
        if pred==knn_preds_mode[-1]: # PD accurate
            if knn_pos_thresh==0.5 and knn_neg_thresh==0.5:
                batch_info['pd'].append(compute_pred_depth(knn_preds_mode))
            else:
                raise NotImplementedError('Code not ready yet! Compute pred arr function also has to be updated!')
        else: # PD inaccurate, KNN pred doesn't match model pred
            print('Invalid datapoint: last_layer_mode != model_output')
            invalid_counter += 1
            batch_info['pd'].append(-99)  # sentinel marking an invalid prediction depth
print(f'Invalid Counts Ratio: {invalid_counter}/{num_test_imgs}')

In [None]:
# ===================== Save results =====================

# Persist the accumulated per-image statistics next to the checkpoint.
with open(os.path.join(save_dir,expt_name+'_testPDinfo.pkl'), 'wb') as handle:
    pickle.dump(batch_info, handle)

In [None]:
# plot PDs histogram
# pickle.load can execute arbitrary code — only open files written by this notebook.
with open(os.path.join(save_dir,expt_name+'_testPDinfo.pkl'), 'rb') as handle:
    batch_info = pickle.load(handle)

batch_info['pd'] = np.array(batch_info['pd'])
batch_info['labels'] = np.array(batch_info['labels'])
batch_info['preds'] = np.array(batch_info['preds'])
correct_preds_arr = (batch_info['preds']==batch_info['labels'])
# -99 is the sentinel the PD loop stores for invalid datapoints; leaving it in would
# stretch the histogram range and distort the bin edges.
valid_pd_arr = (batch_info['pd'] != -99)


plt.figure(figsize=(7,7))
plt.title(expt_name)
plt.ylabel('# of Images')
plt.xlabel('Layer')
if model=='resnet18':
    plt.xlim((0,18))
elif model=='vgg16':
    plt.xlim((0,16))
# plt.ylim((0,80))
# Histogram of prediction depths for correctly classified images with a valid PD.
plt.hist(batch_info['pd'][correct_preds_arr & valid_pd_arr],bins=10,color='g',alpha=0.55)
# plt.hist(batch_info['pd'][~correct_preds_arr ],bins=100,color='r',alpha=0.55)
plt.savefig(os.path.join(save_dir,expt_name+'_PDplot.png'))

In [None]:
# Default-binned histogram of the PD values of correctly classified test images
# (any -99 sentinel entries among them are included).
plt.hist(batch_info['pd'][correct_preds_arr])

# Test Model on Particular Images

In [None]:
counter = 0
total = 0
ckpt_path = '/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/toy_expts/domino_expts/output/resnet18_top_cifar10_bot_fmnist_64x32_65168.pt'

# Build the model and load the checkpoint once; neither depends on the image being scored.
net = ResNet18(num_channels=num_ch)
# Attach the head before moving to the device so the new layer is moved along with the rest.
net.linear = nn.Linear(in_features=1024,out_features=10,bias=True)
net = torch.nn.DataParallel(net.to(device))
# torch.load unpickles the file — only load checkpoints produced by this notebook.
net.load_state_dict(torch.load(ckpt_path, map_location=device)['net'])
net.eval()

T = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((64,32)),
])

# Score 20 randomly chosen images per (top, bot) folder and count how often the
# prediction equals the bottom class.
for top in tqdm([0,1,2,4,5,6,7,8,9]):
    for bot in [3]:
        for i in range(20):
            img_path = f'/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/data/domino_datasets/top_cifar10_bot_fmnist/train/top{top}_bot{bot}/{int(random.random()*1000)}.jpg'

            # Inference only: no autograd graph is needed.
            with Image.open(img_path) as im, torch.no_grad():
                out = net(T(im).unsqueeze(0))
            total+=1
            if int(out.argmax())==bot:
                counter+=1
print(f'counter:{counter},total:{total},accuracy:{counter*100/total}')

In [None]:
# Re-print the tallies accumulated by the preceding cell.
print(f'counter:{counter},total:{total}')

In [None]:
# Scratch notes (not executable code): 0:48/180, 26%,21%,80%,4.5,

# Create Domino CSV Files

In [None]:
# user hyperparams
# Dataset names interpolated into the domino image folder path
# (top_{top_dset_name}_bot_{bot_dset_name}); 'blank' selects a different folder layout.
top_dset_name = 'cifar10'
bot_dset_name = 'fmnist'
# Class indices; the CSV-building cell pairs them element-wise with zip().
top_inds = [0,1]
bot_inds = [0,1]

In [None]:
# Column order matches what row-wise appending onto the declared columns produced.
# NOTE(review): 'bottom_label' / 'top_label' are declared but every row is written under
# 'bot_lbl' / 'top_lbl', so the first pair stays empty — confirm which names consumers expect.
columns = ['path','bottom_label','top_label','val_train_split','test_split','all_zeros','bot_lbl','top_lbl']

# domino = 5000/800
# blank = 6000/1000
# A cell-local name is used for the per-class test count so that the `num_test_imgs`
# hyperparameter used by the PD cells is not overwritten.
if top_dset_name=='blank':
    num_train_val, num_val, num_test_per_class = 6000, 600, 1000
else:
    num_train_val, num_val, num_test_per_class = 5000, 500, 800

# Shuffle the image ids once; the first `num_val` ids become validation, the rest training.
arr = np.arange(0,num_train_val,1)
np.random.shuffle(arr)
val_arr = arr[:num_val]
train_arr = arr[num_val:]

dset_root = f'/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/data/domino_datasets/top_{top_dset_name}_bot_{bot_dset_name}'

# Collect plain dicts and build the frame once: DataFrame.append was removed in
# pandas 2.0 and copied the whole frame on every call.
records = []
for (top_idx,bot_idx) in zip(top_inds,bot_inds):
    # 'blank' datasets are laid out per bottom class only.
    class_dir = f'{bot_idx}' if top_dset_name=='blank' else f'top{top_idx}_bot{bot_idx}'
    # (image ids, folder on disk, val_train_split flag, test_split flag)
    splits = [
        (train_arr, 'train', 1, 0),
        (val_arr, 'train', 0, 0),
        (range(num_test_per_class), 'test', 2, 1),
    ]
    for img_ids, folder, val_train_split, test_split in splits:
        for img_id in tqdm(img_ids):
            records.append({
                'path': f'{dset_root}/{folder}/{class_dir}/{img_id}.jpg',
                'bot_lbl': bot_idx,
                'top_lbl': top_idx,
                'val_train_split': val_train_split,
                'test_split': test_split,
                'all_zeros': 0,
            })

df = pd.DataFrame(records, columns=columns)
     

In [None]:
# csv_save_path = '/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/toy_expts/domino_expts/top_cifar10_bot_fmnist_2class_ro1p0.csv'
# df.to_csv(csv_save_path, index=False)   

# Misc

In [None]:
# Load the Waterbirds metadata table. This rebinds `df`, replacing the frame built
# in the domino CSV cell.
df = pd.read_csv('/xxx/home/xxx/xxxp/shared/Projects/shortcut_learning/WaterBirds_MetaData.csv')

In [None]:
# Show the rows with bird_type == 1 and land_type == 0 (the filtered frame is the cell output).
arr = (df['bird_type']==1) & (df['land_type']==0)
# arr = (df['test_split']==0) 
df[arr]

In [None]:
# Build the Waterbirds training CSV: draw a fixed number of rows from each
# (bird_type, land_type) combination, concatenate, shuffle, and write to disk.
# NOTE(review): no random_state is passed to sample(), so the draw is only repeatable
# if the global NumPy RNG is in the same state — confirm this is acceptable.
arr1 = (df['bird_type']==0) & (df['land_type']==0)
df1 = df[arr1].sample(5620)

arr2 = (df['bird_type']==0) & (df['land_type']==1)
df2 = df[arr2].sample(300)

arr3 = (df['bird_type']==1) & (df['land_type']==0)
df3 = df[arr3].sample(80)

arr4 = (df['bird_type']==1) & (df['land_type']==1)
df4 = df[arr4].sample(1652)

# sample(frac=1) shuffles the concatenated rows.
df_final = pd.concat([df1,df2,df3,df4]).sample(frac=1)
df_final.to_csv('/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/vision_expts/waterbirds/train.csv',index=False)

In [None]:
# Build the Waterbirds test CSV: draw a fixed number of rows from each
# (bird_type, land_type) combination, concatenate, shuffle, and write to disk.
# NOTE(review): these rows are sampled from the same `df` as the train-split cell and
# independently of it, so a row may end up in both train.csv and test.csv — confirm
# whether the two splits are meant to be disjoint.
arr1 = (df['bird_type']==0) & (df['land_type']==0)
df1 = df[arr1].sample(600)

arr2 = (df['bird_type']==0) & (df['land_type']==1)
df2 = df[arr2].sample(2605)

arr3 = (df['bird_type']==1) & (df['land_type']==0)
df3 = df[arr3].sample(751)

arr4 = (df['bird_type']==1) & (df['land_type']==1)
df4 = df[arr4].sample(180)

# sample(frac=1) shuffles the concatenated rows.
df_final = pd.concat([df1,df2,df3,df4]).sample(frac=1)
df_final.to_csv('/xxx/home/xxx/xxxp/xxx/projects/shortcut_detection_and_mitigation/experiments/vision_expts/waterbirds/test.csv',index=False)