![Practicum AI Logo image](https://github.com/PracticumAI/practicumai.github.io/blob/main/images/logo/PracticumAI_logo_250x50.png?raw=true) <img src="https://github.com/PracticumAI/practicumai.github.io/blob/84b04be083ca02e5c7e92850f9afd391fc48ae2a/images/icons/practicumai_computer_vision.png?raw=true" alt="Practicum AI: Computer Vision icon" align="right" width=50>
***

# Handling Data Imbalance in Computer Vision

In this notebook, we will explore techniques to handle data imbalance using PyTorch.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
import helpers_01
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from PIL import Image
import numpy as np


In [3]:
# Define a custom dataset class
class BeeWaspDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = os.listdir(data_path)
        for idx, class_name in enumerate(self.class_names):
            class_path = os.path.join(data_path, class_name)
            for img_name in os.listdir(class_path):
                self.images.append(os.path.join(class_path, img_name))
                self.labels.append(idx)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations
transform = transforms.Compose([
    transforms.Resize((80, 80)),
    transforms.ToTensor()
])

# Load the dataset
data_path = 'data/bee_vs_wasp'
dataset = BeeWaspDataset(data_path, transform=transform)
train_data, val_data = train_test_split(dataset, test_size=0.2, stratify=dataset.labels)

# Calculate class weights
class_weights = compute_class_weight('balanced', classes=np.unique(dataset.labels), y=dataset.labels)
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Create a weighted sampler
sample_weights = [class_weights[label] for label in train_data.labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

# Create data loaders
train_loader = DataLoader(train_data, batch_size=32, sampler=sampler)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
class BeeWaspDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = os.listdir(data_path)
        for idx, class_name in enumerate(self.class_names):
            class_path = os.path.join(data_path, class_name)
            for img_name in os.listdir(class_path):
                self.images.append(os.path.join(class_path, img_name))
                self.labels.append(idx)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Define transformations
transform = transforms.Compose([
    transforms.Resize((80, 80)),
    transforms.ToTensor()
])

# Load the dataset
data_path = 'data/bee_vs_wasp'
dataset = BeeWaspDataset(data_path, transform=transform)
train_data, val_data = train_test_split(dataset, test_size=0.2, stratify=dataset.labels)

# Calculate class weights
class_weights = compute_class_weight('balanced', classes=np.unique(dataset.labels), y=dataset.labels)
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Create a weighted sampler
sample_weights = [class_weights[label] for label in train_data.labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

# Create data loaders
train_loader = DataLoader(train_data, batch_size=32, sampler=sampler)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)


AttributeError: 'list' object has no attribute 'labels'

In [3]:
# Create the model
model = helpers_01.make_model()
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
model = helpers_01.compile_train_model(train_loader, val_loader, model, loss_fn, optimizer, num_epochs=10)


In [4]:
# Evaluate the model
helpers_01.evaluate_model(val_loader, model)
