<a href="https://colab.research.google.com/github/Howl06/practice/blob/main/imbalance_class_weights.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchsampler

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchsampler
  Downloading torchsampler-0.1.2-py3-none-any.whl (5.6 kB)
Installing collected packages: torchsampler
Successfully installed torchsampler-0.1.2


In [None]:
import os
import sys
import json

from torch import Tensor
from typing import List
from torchsampler import ImbalancedDatasetSampler
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm




# Training hyperparameters.
# "loss_function_name" holds the loss *class* (not an instance) so that it can be
# instantiated later with per-class weights:
#     hyper_params["loss_function_name"](weight=class_weights)
# Storing an instance (nn.CrossEntropyLoss()) made that call raise TypeError, because
# calling an nn.Module instance runs forward(), which has no `weight` argument.
hyper_params = {
    "learning_rate": 0.0001,
    "epochs": 10,
    "batch_size": 2,
    "num_classes": 27,   # presumably the number of class folders — TODO confirm against the dataset
    "alpha": 1,          # not used in the cells shown here
    "input_size": 224,   # NOTE(review): the transforms hard-code 224 instead of reading this
    "loss_function_name": nn.CrossEntropyLoss,
    "optimizer_name": optim.Adam,
}
    

# Locations on the mounted Google Drive (Colab).
root_path = "/content/drive/MyDrive"     # Drive root
data_set_dir = "data_set"                # dataset folder under the Drive root
flower_dir = "project_data"              # image folder holding the train/ and val/ splits
# Output file for the {class_index: class_name} mapping, stored next to the dataset.
classjson_path = f"{root_path}/{data_set_dir}/class_indices.json"
weight_path = ""                         # model weight path (currently unset)




# Per-channel ImageNet statistics used by every pipeline's Normalize step.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _eval_transform():
    """Build the deterministic validation/test pipeline: resize, center-crop, normalize."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),  # HWC uint8 in [0, 255] -> CHW float in [0, 1]
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


# Training uses random crop + flip augmentation; val and test share the eval pipeline
# (each split gets its own Compose instance).
data_transform = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # HWC uint8 in [0, 255] -> CHW float in [0, 1]
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    "val": _eval_transform(),
    "test": _eval_transform(),
}

data_root = os.path.abspath(os.path.join(os.getcwd(), root_path))  # get data root path
image_path = os.path.join(data_root, data_set_dir, flower_dir)  # dataset folder containing train/ and val/

assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# class_to_idx maps each class sub-folder name to its integer label,
# e.g. {'class_a': 0, 'class_b': 1, ...}.
flower_list = train_dataset.class_to_idx
# Inverted mapping {label: class_name}, used to decode predictions.
cla_dict = {idx: name for name, idx in flower_list.items()}

# Write the label -> class-name mapping to JSON (json turns the int keys into strings).
json_str = json.dumps(cla_dict, indent=4)
with open(classjson_path, 'w') as json_file:
    json_file.write(json_str)

# DataLoader worker processes: at most 8, never more than the CPU count, and 0
# (load in the main process) when batch_size is 1.
# os.cpu_count() returns None when the count cannot be determined; fall back to 1 so
# min() does not raise TypeError comparing None with ints.
n_cpus = os.cpu_count() or 1
nw = min([n_cpus, hyper_params["batch_size"] if hyper_params["batch_size"] > 1 else 0, 8])  # number of workers
print('Using {} dataloader workers every process'.format(nw))

# NOTE: to rebalance by sampling instead of loss weighting, pass
# sampler=ImbalancedDatasetSampler(train_dataset) and remove shuffle=True —
# DataLoader rejects a sampler combined with shuffle=True.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           shuffle=True,
                                           batch_size=hyper_params["batch_size"],
                                           num_workers=nw)

validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=hyper_params["batch_size"],
                                              shuffle=False,
                                              num_workers=nw)


Using 2 dataloader workers every process


In [None]:
# Display the dataset summary (number of images, root folder, transform pipeline) as a sanity check.
train_dataset

Dataset ImageFolder
    Number of datapoints: 68
    Root location: /content/drive/MyDrive/data_set/project_data/train
    StandardTransform
Transform: Compose(
               RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear, antialias=warn)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [None]:

from collections import Counter

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

# Count training samples per class label; ImageFolder.imgs holds (path, class_index) pairs.
label_nums_dic = Counter(class_idx for _, class_idx in train_dataset.imgs)

# Build the per-class counts in class-index order (0 .. n_classes-1) rather than relying
# on the Counter's insertion order: CrossEntropyLoss applies weight[c] to class c and
# expects one weight per class. Counter returns 0 for labels that never occur.
sample_num_list = [label_nums_dic[idx] for idx in range(len(train_dataset.classes))]
max_sample_num = max(sample_num_list)
# Inverse-frequency weights relative to the largest class (which gets 1.0); rarer
# classes get proportionally larger weights. max(count, 1) avoids ZeroDivisionError
# for a class folder that contributed no images.
weights = [max_sample_num / max(label_nums, 1) for label_nums in sample_num_list]
class_weights = torch.FloatTensor(weights).to(device)

# hyper_params["loss_function_name"] is expected to be a loss class; if an already
# constructed instance is stored instead, re-instantiate its type so `weight=` goes to
# the constructor (calling an instance would go to forward() and raise TypeError).
loss_cls = hyper_params["loss_function_name"]
if isinstance(loss_cls, nn.Module):
    loss_cls = type(loss_cls)
loss_function = loss_cls(weight=class_weights)

using cpu device.


TypeError: ignored

In [None]:
# Undo Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
# Normalize computes (x - m) / s per channel, so normalizing again with
# mean' = -m / s and std' = 1 / s gives (x + m / s) * s = x * s + m, i.e. the
# ToTensor range [0, 1] (up to floating-point error).
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)
reback_img = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(_norm_mean, _norm_std)],
        std=[1 / s for s in _norm_std],
    )
])

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Visual sanity check of the training pipeline: undo the normalization on every
# augmented training batch, show each image, and print its array shape.
to_pil = transforms.ToPILImage()  # loop-invariant; build once instead of per image
train_bar = tqdm(train_loader, file=sys.stdout)
for train_images, labels in train_bar:
    train_images = reback_img(train_images)
    for image in train_images:
        # De-normalization is only approximately exact in floating point. ToPILImage
        # converts float tensors by multiplying by 255 and casting to uint8, which does
        # not saturate out-of-range values, so clamp to [0, 1] first.
        img = to_pil(image.clamp(0.0, 1.0))
        img.show()
        img = np.array(img, dtype=np.uint8)
        print(img.shape)