<a href="https://colab.research.google.com/github/Ruphai/UBS/blob/main/Techniques_for_dealing_with_imbalanced_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Sample code for dealing with imbalanced datasets using PyTorch's `WeightedRandomSampler` and scikit-learn's `resample` utility.

### Data Augmentation

The WiDS dataset is highly imbalanced: far more images are labeled as having no oil palm plantation (0) than as having one (1). To mitigate the problems this imbalance can cause during training, we use a technique commonly applied in computer vision to improve model performance: data augmentation. <a href="https://vinesmsuic.github.io/2020/08/11/cnn-dataaug/#reference">Data augmentation</a> increases the diversity of the data available for training a model without collecting new data. The most widely used augmentations include cropping, padding, horizontal and vertical flipping, and other affine transformations. torchvision's `transforms` module implements most of these operations, and the Albumentations library is a popular alternative.
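
To make the listed operations concrete, here is a minimal NumPy sketch of random flipping and pad-then-crop on a single image array. It is illustrative only: the helpers `random_flip` and `pad_and_random_crop` are hypothetical names, the image is random noise standing in for a 256 x 256 RGB tile, and in this notebook the real augmentations come from torchvision's `transforms`.

```python
import numpy as np

def random_flip(img, rng, p_h=0.5, p_v=0.5):
    """Flip an H x W x C image horizontally and/or vertically at random (hypothetical helper)."""
    if rng.random() < p_h:
        img = img[:, ::-1]  # horizontal flip: reverse the width axis
    if rng.random() < p_v:
        img = img[::-1, :]  # vertical flip: reverse the height axis
    return img

def pad_and_random_crop(img, pad, rng):
    """Zero-pad each spatial border by `pad` pixels, then crop back to the original size at a random offset."""
    h, w = img.shape[:2]
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * pad + 1)   # offset in [0, 2 * pad]
    left = rng.integers(0, 2 * pad + 1)
    return padded[top:top + h, left:left + w]

rng = np.random.default_rng(0)
image = rng.random((256, 256, 3))  # stand-in for one RGB satellite tile
augmented = pad_and_random_crop(random_flip(image, rng), pad=16, rng=rng)
print(augmented.shape)  # (256, 256, 3)
```

Each call produces a slightly different version of the same image while keeping its shape and label, which is what lets augmentation add diversity without new data.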
When faced with the challenge of imbalanced datasets, we can employ one or a combination of the following strategies:
- Undersampling
- Oversampling
- Class weighting
- Focal loss
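
Of these strategies, undersampling and focal loss are not otherwise demonstrated in this notebook. The NumPy sketch below illustrates both under stated assumptions: `random_undersample` keeps every minority sample and an equal-size random subset of each other class, and `binary_focal_loss` follows the focal loss of Lin et al. (RetinaNet), `-alpha_t * (1 - p_t)**gamma * log(p_t)`, with their default `alpha=0.25` and `gamma=2`. Both helper names are hypothetical; for training you would use a PyTorch implementation instead.

```python
import numpy as np

def random_undersample(labels, rng):
    """Indices that keep an equal number of samples per class, equal to the smallest class size."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    n_min = counts.min()
    # Sample n_min indices without replacement from every class (the minority class keeps all of its samples).
    keep = [rng.choice(np.flatnonzero(labels == c), size=n_min, replace=False) for c in classes]
    return np.sort(np.concatenate(keep))

def binary_focal_loss(p, y, alpha=0.25, gamma=2.0, eps=1e-7):
    """Mean focal loss for predicted positive-class probabilities p and 0/1 targets y."""
    p = np.clip(np.asarray(p, dtype=float), eps, 1 - eps)
    y = np.asarray(y)
    p_t = np.where(y == 1, p, 1 - p)              # probability assigned to the true class
    alpha_t = np.where(y == 1, alpha, 1 - alpha)  # class-balancing factor
    # (1 - p_t) ** gamma shrinks the loss of examples the model already classifies well.
    return np.mean(-alpha_t * (1 - p_t) ** gamma * np.log(p_t))

rng = np.random.default_rng(0)
labels = np.array([0] * 8 + [1] * 2)  # toy 4:1 imbalance
kept = random_undersample(labels, rng)
print(np.bincount(labels[kept]))  # [2 2]

# A confident correct prediction contributes far less loss than an uncertain one.
print(binary_focal_loss([0.95], [1]), binary_focal_loss([0.55], [1]))
```

Undersampling balances the classes by discarding majority-class data, whereas focal loss leaves the data untouched and instead reweights the loss toward hard examples.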

Here, we use oversampling: we increase the number of samples in the minority class until both classes are equal, or nearly equal, in size. This suits deep learning better than undersampling, because undersampling discards majority-class images the network could learn features from, whereas oversampling keeps every image. Since oversampling repeats minority-class images, it pairs well with the data augmentation described above, which makes the repeated copies differ slightly.
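
Random oversampling can be sketched at the level of sample indices. In this minimal NumPy version (the helper `random_oversample` is hypothetical), each smaller class keeps all of its original indices and is topped up with duplicates drawn with replacement until it matches the majority class.

```python
import numpy as np

def random_oversample(labels, rng):
    """Indices in which every class appears as often as the majority class."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    n_max = counts.max()
    idx = []
    for c in classes:
        members = np.flatnonzero(labels == c)
        deficit = n_max - len(members)
        # Duplicates are drawn with replacement; the majority class needs none.
        extra = rng.choice(members, size=deficit, replace=True) if deficit > 0 else members[:0]
        idx.append(np.concatenate([members, extra]))
    return np.concatenate(idx)

rng = np.random.default_rng(0)
labels = np.array([0] * 8 + [1] * 2)  # toy 4:1 imbalance
idx = random_oversample(labels, rng)
print(np.bincount(labels[idx]))  # [8 8]
```

Keeping the original indices and only adding duplicates guarantees that every minority sample appears at least once. Drawing all minority rows with replacement, as `resample` does later in this notebook, reaches the same class balance but may leave some original minority samples out.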

https://medium.com/analytics-vidhya/handling-imbalanced-dataset-in-image-classification-dc6f1e13aeee


In [1]:
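import torch
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler

# Class-weighting alternative: penalize errors on the minority class (label 1) 8 times more.
# loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 8.0]))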
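
The commented-out loss above uses class weighting instead of resampling. As documented for `nn.CrossEntropyLoss(weight=...)` with the default `reduction='mean'`, each sample's loss is multiplied by the weight of its target class and the sum is divided by the sum of those weights. The NumPy sketch below reproduces that arithmetic for illustration; `weighted_cross_entropy` is a hypothetical helper, not part of PyTorch.

```python
import numpy as np

def weighted_cross_entropy(logits, targets, weight):
    """Class-weighted cross-entropy, normalized by the summed weights of the target classes."""
    logits = np.asarray(logits, dtype=float)
    targets = np.asarray(targets)
    weight = np.asarray(weight, dtype=float)
    # Log-softmax, subtracting the row maximum for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]  # per-sample negative log-likelihood
    w = weight[targets]                                  # weight of each sample's true class
    return (w * nll).sum() / w.sum()

logits = np.array([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
targets = np.array([0, 1, 1])
print(weighted_cross_entropy(logits, targets, weight=[1.0, 1.0]))  # plain mean cross-entropy
print(weighted_cross_entropy(logits, targets, weight=[1.0, 8.0]))  # class 1 errors count 8 times more
```

With equal weights this reduces to ordinary mean cross-entropy; raising the minority-class weight shifts the average toward the minority samples' losses.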
 

In [None]:
import os

# project_path is assumed to be defined in an earlier cell of the notebook.
IMG_DIR = os.path.join(project_path, "train")
IMG_DIR

In [None]:
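import os

import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms

# project_path, train_df and ImageDataset are assumed to be defined in earlier cells.
IMG_DIR = os.path.join(project_path, "train")

def update_loader(IMG_DIR, dataframe, batch_size):
    # Augmentation and normalization applied to every training image.
    aug_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(p=0.25),
        transforms.RandomVerticalFlip(p=0.25),
        # Normalize with the ImageNet per-channel means and standard deviations.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Build the dataset from the dataframe passed in, not the global train_df.
    dataset = ImageDataset(dataframe, IMG_DIR, transform=aug_transforms)

    # Inverse-frequency weight for each class, indexed by label (0 or 1).
    # value_counts() orders classes by frequency, not by label, so np.bincount is used instead.
    labels = dataframe["has_oilpalm"].to_numpy().astype(int)
    class_weights = 1.0 / np.bincount(labels)

    # One weight per sample, read from the dataframe instead of loading every image.
    # Assumes ImageDataset returns samples in the same row order as the dataframe.
    sample_weights = class_weights[labels]

    # Draw len(dataset) indices per epoch, with replacement, so minority images repeat.
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    return loader

def main():
    loader = update_loader(IMG_DIR=IMG_DIR, dataframe=train_df, batch_size=16)

    # Count the labels the sampler draws over two epochs; they should be roughly balanced.
    num_hasoilpalm = 0
    num_nooilpalm = 0

    for epoch in range(2):
        for data, labels in loader:
            num_hasoilpalm += torch.sum(labels == 1).item()
            num_nooilpalm += torch.sum(labels == 0).item()

    print("has_oilpalm:", num_hasoilpalm)
    print("no_oilpalm:", num_nooilpalm)

if __name__ == "__main__":
    main()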
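
`WeightedRandomSampler` draws each index with probability proportional to its weight. With inverse-frequency weights, every class's weights sum to the same total, so each class is expected to make up an equal share of the draws. This NumPy sketch checks that arithmetic on a hypothetical 900/100 label split; it imitates the sampler with `numpy.random.Generator.choice` rather than calling PyTorch.

```python
import numpy as np

# Hypothetical 9:1 label vector standing in for train_df["has_oilpalm"].
labels = np.array([0] * 900 + [1] * 100)

class_weights = 1.0 / np.bincount(labels)   # inverse class frequency
sample_weights = class_weights[labels]      # one weight per sample, as passed to the sampler
p = sample_weights / sample_weights.sum()   # draw probability of each index

# Expected fraction of each class among the drawn indices.
expected_share = np.array([p[labels == c].sum() for c in (0, 1)])
print(expected_share)

# Empirical check: draw one "epoch" of indices with replacement.
rng = np.random.default_rng(0)
draws = rng.choice(len(labels), size=len(labels), replace=True, p=p)
print(np.bincount(labels[draws], minlength=2))
```

The expected shares come out at one half each, and the empirical counts fluctuate around 500 per class, which is the balance the two-epoch count in `main()` is meant to reveal.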

In [None]:
# Oversample the minority class (has_oilpalm == 1) with replacement until it matches the majority class.
import pandas as pd
from sklearn.utils import resample

RANDOM_SEED = 0
has_oilpalm_hp = train_df[train_df['has_oilpalm'] == 1].reset_index(drop=True)
no_oilpalm_hp = train_df[train_df['has_oilpalm'] == 0].reset_index(drop=True)

has_oilpalm_upsampled = resample(has_oilpalm_hp,
                                 replace=True,
                                 n_samples=len(no_oilpalm_hp),
                                 random_state=RANDOM_SEED)

# Combine into a balanced dataset and plot the class counts.
balanced_data = pd.concat([has_oilpalm_upsampled, no_oilpalm_hp]).reset_index(drop=True)
balanced_data["has_oilpalm"].value_counts().plot(kind="bar")

In [None]:
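import numpy as np
import torch

# Number of rows in the balanced dataframe (twice the majority-class count).
print(len(balanced_data))

# Inverse class frequencies of the original training set, usable as class weights.
class_counts = np.bincount(train_df['has_oilpalm'])
print(torch.tensor(1 / class_counts))

# Number of classes.
len(class_counts)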
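
The `1 / np.bincount(...)` weights above have a scale that depends on the dataset size. A common rescaling, the one scikit-learn's `compute_class_weight` uses with `class_weight="balanced"`, is `n_samples / (n_classes * count_c)`. The NumPy sketch below computes it on hypothetical toy labels; `balanced_class_weights` is an illustrative helper, not a library function.

```python
import numpy as np

def balanced_class_weights(labels):
    """n_samples / (n_classes * count_c) for integer labels 0..K-1 (illustrative helper)."""
    labels = np.asarray(labels)
    counts = np.bincount(labels)
    return len(labels) / (len(counts) * counts)

labels = np.array([0] * 8 + [1] * 2)  # hypothetical toy labels, 4:1 imbalance
weights = balanced_class_weights(labels)
print(weights)  # weights 0.625 and 2.5
```

The ratio between the classes is the same as with `1 / np.bincount(...)`, but here the per-sample weights average to exactly 1, so a weighted loss stays on roughly the same scale as the unweighted one.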