In [1]:
from wilds.common.data_loaders import get_eval_loader, get_train_loader
from wilds.common.grouper import CombinatorialGrouper
from models.initializer import get_dataset
import torchvision.transforms as transforms

In [2]:
model_transforms = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Lambda(lambda image: image.convert('RGB')),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

full_dataset = get_dataset(dataset='waterbirds',
                           root_dir='/media/SSD2/Dataset',
                           download=True,
                           split_scheme='official',
                           seed=11111111)

train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=['generic-spurious', 'y']
    )

data = full_dataset.get_subset('train',frac=1.0,transform=model_transforms)
# loader = get_train_loader(loader='standard', 
#                           dataset=data, 
#                           batch_size=10,
#                           uniform_over_groups=True, 
#                           grouper=train_grouper,
#                           n_groups_per_batch=4)

# print(len(data[0]))
                        
# for batch in loader:
#     print(batch[0].shape)
#     print(batch[1]) 
#     print(batch[2])
#     print() 
#     break                  

In [4]:
groups, group_counts = train_grouper.metadata_to_group(
                data.metadata_array,
                return_counts=True)
group_weights = 1 / group_counts
print(group_weights), print(group_counts)

tensor([0.0003, 0.0054, 0.0179, 0.0009])
tensor([3498.,  184.,   56., 1057.])


(None, None)

In [9]:
full_dataset = get_dataset(dataset='waterbirds_robust',
                           root_dir='/media/SSD2/Dataset',
                           download=True,
                           split_scheme='official',
                           seed=11111111)

train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=['generic-spurious', 'y']
    )

data = full_dataset.get_subset('train',frac=1.0,transform=model_transforms)

groups, group_counts = train_grouper.metadata_to_group(
                data.metadata_array,
                return_counts=True)
group_weights = 1 / group_counts
print(group_weights)

tensor([0.0007, 0.0007, 0.0009, 0.0009])


In [9]:
full_dataset.metadata_array

tensor([[1, 1, 0],
        [1, 1, 1],
        [0, 1, 0],
        ...,
        [0, 0, 0],
        [0, 0, 1],
        [1, 0, 0]])

In [8]:
loader = get_train_loader(loader='standard', 
                          dataset=data, 
                          batch_size=10,
                          uniform_over_groups=False, 
                          grouper=train_grouper,
                          n_groups_per_batch=4)

print(len(data[0]))
                        
for batch in loader:
    print(batch[0].shape)
    print(batch[1]) 
    print(batch[2])
    break      

3
torch.Size([10, 3, 224, 224])
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([[1, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1],
        [0, 0, 1]])


In [14]:
groups, group_counts = train_grouper.metadata_to_group(
                data.metadata_array,
                return_counts=True)
print(groups[:-10])
print(len(group_counts))

tensor([3, 3, 3,  ..., 0, 0, 1])
4


In [16]:
group_weights = 1 / group_counts
weights = group_weights[groups]
weights

tensor([0.0009, 0.0009, 0.0009,  ..., 0.0003, 0.0003, 0.0003])

In [17]:
len(groups),len(weights)

(4795, 4795)

In [1]:
import cv2
import numpy as np
import os
from glob import glob
import random


In [3]:
background_path = '/media/SSD2/Dataset/Places365'
land_backgrounds_cls = ["forest-broadleaf","bamboo_forest"]#["botanical_garden","desert-vegetation","topiary_garden"]
water_backgrounds_cls = ["ocean","lake-natural"]#["beach","canal_natural","river"]
water_backgrounds, land_backgrounds = [],[]
for i in range(2):
    water_backgrounds.extend(glob(os.path.join(background_path,water_backgrounds_cls[i],'*.jpg')))
    land_backgrounds.extend(glob(os.path.join(background_path,land_backgrounds_cls[i],'*.jpg')))
random.shuffle(water_backgrounds)
random.shuffle(land_backgrounds)
water_backgrounds[:5],land_backgrounds[:5]

(['/media/SSD2/Dataset/Places365/ocean/Places365_val_00017396.jpg',
  '/media/SSD2/Dataset/Places365/lake-natural/Places365_val_00019078.jpg',
  '/media/SSD2/Dataset/Places365/ocean/Places365_val_00006739.jpg',
  '/media/SSD2/Dataset/Places365/ocean/Places365_val_00005456.jpg',
  '/media/SSD2/Dataset/Places365/ocean/Places365_val_00030862.jpg'],
 ['/media/SSD2/Dataset/Places365/forest-broadleaf/Places365_val_00027587.jpg',
  '/media/SSD2/Dataset/Places365/forest-broadleaf/Places365_val_00013329.jpg',
  '/media/SSD2/Dataset/Places365/bamboo_forest/Places365_val_00007602.jpg',
  '/media/SSD2/Dataset/Places365/forest-broadleaf/Places365_val_00000009.jpg',
  '/media/SSD2/Dataset/Places365/forest-broadleaf/Places365_val_00008121.jpg'])

In [4]:
import pandas as pd


metadata_df = pd.read_csv(
            os.path.join('/media/SSD2/Dataset/waterbirds_v1.0', 'metadata.csv'))
metadata_df.head()

Unnamed: 0,img_id,img_filename,y,split,place,place_filename
0,1,001.Black_footed_Albatross/Black_Footed_Albatr...,1,2,1,/o/ocean/00002178.jpg
1,2,001.Black_footed_Albatross/Black_Footed_Albatr...,1,0,1,/l/lake/natural/00000065.jpg
2,3,001.Black_footed_Albatross/Black_Footed_Albatr...,1,2,0,/b/bamboo_forest/00000131.jpg
3,4,001.Black_footed_Albatross/Black_Footed_Albatr...,1,0,1,/o/ocean/00001268.jpg
4,5,001.Black_footed_Albatross/Black_Footed_Albatr...,1,0,1,/o/ocean/00003147.jpg


In [4]:
set([" ".join(i.split("/")[:-1]) for i in metadata_df["place_filename"]])

{' b bamboo_forest', ' f forest broadleaf', ' l lake natural', ' o ocean'}

In [5]:
list(metadata_df["split"]).count(0),list(metadata_df["split"]).count(1),list(metadata_df["split"]).count(2)

(4795, 1199, 5794)

In [16]:
filtered_df = metadata_df[metadata_df["split"] == 0]
len(list(filtered_df[(filtered_df["y"] == 0) & (filtered_df["place"] == 0)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 1) & (filtered_df["place"] == 0)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 0) & (filtered_df["place"] == 1)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 1) & (filtered_df["place"] == 1)]["img_id"]))

3498

In [17]:
len(list(filtered_df[(filtered_df["y"] == 1) & (filtered_df["place"] == 0)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 0) & (filtered_df["place"] == 1)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 1) & (filtered_df["place"] == 1)]["img_id"]))

(56, 184, 1057)

In [5]:
new_df = pd.DataFrame(columns=metadata_df.columns)


In [6]:
vf = metadata_df.sample(frac=1).reset_index(drop=True)
vf.head()

Unnamed: 0,img_id,img_filename,y,split,place,place_filename
0,11132,190.Red_cockaded_Woodpecker/Red_Cockaded_Woodp...,0,0,0,/b/bamboo_forest/00002811.jpg
1,9612,164.Cerulean_Warbler/Cerulean_Warbler_0045_797...,0,0,0,/f/forest/broadleaf/00004521.jpg
2,11363,193.Bewick_Wren/Bewick_Wren_0124_184771.jpg,0,0,0,/f/forest/broadleaf/00001156.jpg
3,10598,180.Wilson_Warbler/Wilson_Warbler_0050_175573.jpg,0,1,1,/l/lake/natural/00004877.jpg
4,9185,157.Yellow_throated_Vireo/Yellow_Throated_Vire...,0,0,0,/f/forest/broadleaf/00000067.jpg


In [7]:
vf["img_id"].max()

11788

In [8]:
def generate_image(file,y):
    bird_img = cv2.imread(os.path.join('/media/SSD2/Dataset/waterbirds_v1.0',file))
    mask = cv2.imread(os.path.join('/media/SSD2/Dataset/waterbird_segmentation_mask/segmentations',file[:-4]+'.png'), cv2.IMREAD_GRAYSCALE)
    
    if y == 0:
        bg_file = np.random.choice(water_backgrounds)
    else:
        bg_file = np.random.choice(land_backgrounds)  # Randomly select a background
    
    

    background = cv2.imread(bg_file)
    background = cv2.resize(background, (bird_img.shape[1], bird_img.shape[0]))

    _, binary_mask = cv2.threshold(mask, 128, 255, cv2.THRESH_BINARY)

    # Extract bird part from the bird image
    bird_segment = cv2.bitwise_and(bird_img, bird_img, mask=binary_mask)

    # Invert mask for the background
    inverse_mask = cv2.bitwise_not(binary_mask)

    # Apply inverted mask to the background
    background_segment = cv2.bitwise_and(background, background, mask=inverse_mask)

    # Combine bird segment and background segment
    final_image = cv2.add(bird_segment, background_segment)
    return final_image, bg_file

In [9]:
def save_image_with_folders(file_path, image):
    """
    Saves an image to the specified file path, creating any missing directories in the path.

    Args:
        file_path (str): The full path (including filename) where the image should be saved.
        image (numpy.ndarray): The image to save (OpenCV format).
    """
    # Extract the directory path from the file path
    directory = os.path.dirname(file_path)
    
    # Create the directory if it doesn't exist
    if not os.path.exists(directory):
        os.makedirs(directory)
    
    # Save the image to the file path
    cv2.imwrite(file_path, image)
   

In [10]:
import shutil
def copy_paste(source,dest):
    # Extract the directory path from the file path
    directory = os.path.dirname(dest)
    
    # Create the directory if it doesn't exist
    if not os.path.exists(directory):
        os.makedirs(directory)
    shutil.copy(source,dest)

In [11]:
def new_entry(id,img_file,y,split,place,place_filename):
    global new_df
    new_entry = {
    "img_id": id,
    "img_filename": img_file,
    "y": y,
    "split": split,
    "place": place,
    "place_filename": place_filename,
    }
    new_df.loc[len(new_df)] = new_entry


In [12]:
from matplotlib import pyplot as plt

final = '/media/SSD2/Dataset/waterbirds_robust.3'
final_2 = '/media/SSD2/Dataset/waterbirds_robust.4'

land_b_l,land_b_w, water_b_l, water_b_w, idx = 0,0,0,0,0
for file, y, split, place, place_file in zip(vf["img_filename"],vf["y"],vf["split"],vf["place"],vf["place_filename"]):
    
    if split == 0 and y==0:

        if place == 0 and random.random() > 0.5 and land_b_l < 1341:
            land_b_l += 1
            source = os.path.join('/media/SSD2/Dataset/waterbirds_v1.0',file)
            dest = os.path.join(final,file)
            copy_paste(source,dest)
            new_entry(idx,file,y,split,place,place_file)
            idx += 1
        elif place == 1 :
            if land_b_w < 1341:
                land_b_w += 1
                source = os.path.join('/media/SSD2/Dataset/waterbirds_v1.0',file)
                dest = os.path.join(final,file)
                place = 1
                copy_paste(source,dest)
                new_entry(idx,file,y,split,place,place_file)
                idx += 1

        if land_b_w < 1341:
            land_b_w += 1
            place = 1
            img, bg_file = generate_image(file,y)
            new_path = os.path.join(final,file.split('/')[-2],'n_'+file.split('/')[-1])
            n_place_file = os.path.join(bg_file.split('/')[-2],bg_file.split('/')[-1])
            save_image_with_folders(new_path,img)
            save_image_with_folders(os.path.join(final_2,file.split('/')[-1]),img)
            new_entry(idx,os.path.join(file.split('/')[-2],'n_'+file.split('/')[-1]),y,split,place,n_place_file)
            idx += 1
    
    elif split == 0 and y==1:

        if place == 1 and water_b_w < 1057:
            water_b_w += 1
            source = os.path.join('/media/SSD2/Dataset/waterbirds_v1.0',file)
            dest = os.path.join(final,file)
            copy_paste(source,dest)
            new_entry(idx,file,y,split,place,place_file)
            idx += 1
        else:
            if water_b_l < 1057:
                water_b_l += 1
                place = 0
                source = os.path.join('/media/SSD2/Dataset/waterbirds_v1.0',file)
                dest = os.path.join(final,file)
                copy_paste(source,dest)
                new_entry(idx,file,y,split,place,place_file)
                idx += 1

        if water_b_l < 1057:
            water_b_l += 1
            place = 0
            img, bg_file = generate_image(file,y)
            new_path = os.path.join(final,file.split('/')[-2],'n_'+file.split('/')[-1])
            n_place_file = os.path.join(bg_file.split('/')[-2],bg_file.split('/')[-1])
            save_image_with_folders(new_path,img)
            save_image_with_folders(os.path.join(final_2,file.split('/')[-1]),img)
            new_entry(idx,os.path.join(file.split('/')[-2],'n_'+file.split('/')[-1]),y,split,place,n_place_file)
            idx += 1

    elif split != 0:
        source = os.path.join('/media/SSD2/Dataset/waterbirds_v1.0',file)
        dest = os.path.join(final,file)
        copy_paste(source,dest)
        new_entry(idx,file,y,split,place,place_file)
        idx += 1

    else:
        print("Error")
        break
    
    if idx % 200 == 0:
        print(f"Land Birds: {land_b_l} {land_b_w} Water Birds: {water_b_l} {water_b_w}")
    



    
        


Land Birds: 21 57 Water Birds: 15 15
Land Birds: 67 159 Water Birds: 46 44
Land Birds: 111 270 Water Birds: 80 74
Land Birds: 131 322 Water Birds: 95 89
Land Birds: 163 374 Water Birds: 108 102
Land Birds: 190 440 Water Birds: 114 108
Land Birds: 207 489 Water Birds: 134 126
Land Birds: 229 535 Water Birds: 153 145
Land Birds: 248 575 Water Birds: 176 164
Land Birds: 294 676 Water Birds: 210 196
Land Birds: 319 729 Water Birds: 234 216
Land Birds: 346 788 Water Birds: 245 227
Land Birds: 394 897 Water Birds: 276 254
Land Birds: 411 938 Water Birds: 299 275
Land Birds: 435 989 Water Birds: 310 284
Land Birds: 455 1032 Water Birds: 335 305
Land Birds: 482 1096 Water Birds: 348 318
Land Birds: 507 1145 Water Birds: 362 330
Land Birds: 529 1197 Water Birds: 378 346
Land Birds: 547 1239 Water Birds: 401 365
Land Birds: 571 1294 Water Birds: 412 376
Land Birds: 589 1340 Water Birds: 427 389
Land Birds: 619 1341 Water Birds: 449 405
Land Birds: 649 1341 Water Birds: 472 426
Land Birds: 649 13

In [13]:
new_df["img_id"].max()

11788

In [14]:
filtered_df = new_df[new_df["split"] == 0]
len(list(filtered_df[(filtered_df["y"] == 0) & (filtered_df["place"] == 0)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 1) & (filtered_df["place"] == 0)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 0) & (filtered_df["place"] == 1)]["img_id"])),len(list(filtered_df[(filtered_df["y"] == 1) & (filtered_df["place"] == 1)]["img_id"]))

(1341, 1057, 1341, 1057)

In [15]:
new_df.head()

Unnamed: 0,img_id,img_filename,y,split,place,place_filename
0,0,190.Red_cockaded_Woodpecker/Red_Cockaded_Woodp...,0,0,0,/b/bamboo_forest/00002811.jpg
1,1,190.Red_cockaded_Woodpecker/n_Red_Cockaded_Woo...,0,0,1,lake-natural/Places365_val_00015415.jpg
2,2,164.Cerulean_Warbler/n_Cerulean_Warbler_0045_7...,0,0,1,lake-natural/Places365_val_00033412.jpg
3,3,193.Bewick_Wren/Bewick_Wren_0124_184771.jpg,0,0,0,/f/forest/broadleaf/00001156.jpg
4,4,193.Bewick_Wren/n_Bewick_Wren_0124_184771.jpg,0,0,1,lake-natural/Places365_val_00003818.jpg


In [17]:
new_df.to_csv('/media/SSD2/Dataset/waterbirds_robust.3/metadata.csv',index=False)

In [18]:
from waterbirds_robust_dataset_similar import WaterbirdsRobustSimilarDataset
import torchvision.transforms as transforms
model_transforms = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Lambda(lambda image: image.convert('RGB')),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
full_dataset = WaterbirdsRobustSimilarDataset(root_dir='/media/SSD2/Dataset',split_scheme='official')

In [19]:
train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=['generic-spurious', 'y']
    )

data = full_dataset.get_subset('test',frac=1.0,transform=model_transforms)
loader = get_train_loader(loader='standard', 
                          dataset=data, 
                          batch_size=10,
                          uniform_over_groups=True, 
                          grouper=train_grouper,
                          n_groups_per_batch=4)

print(len(data[0]))

NameError: name 'CombinatorialGrouper' is not defined