In [12]:
from torch.utils.data import Dataset
import csv
import pandas as pd
from PIL import Image

from torch.utils.data import DataLoader

In [13]:
def loadCsv(csv_filename, skip_header=False, encoding=None):
    """Load a two-column CSV of (image, label) rows into two parallel lists.

    Parameters
    ----------
    csv_filename : str or path-like
        Path to the CSV file. Every non-empty row must contain exactly two
        fields: the image entry (e.g. an image path) followed by its label.
    skip_header : bool, default False
        If True, the first row is discarded as a header. The split CSVs read
        elsewhere in this notebook carry an ``img_path,label`` header, so pass
        ``skip_header=True`` for them; otherwise the header is returned as if
        it were a data row.
    encoding : str, optional
        Text encoding forwarded to ``open``; None keeps the platform default.

    Returns
    -------
    image : list of str
        First field of every data row.
    label : list of str
        Second field of every data row, left as strings (not converted).

    Raises
    ------
    ValueError
        If a non-empty row does not have exactly two fields.
    """
    image, label = [], []
    # newline='' is what the csv module requires so that quoted fields with
    # embedded newlines and '\r\n' line endings are parsed correctly.
    with open(csv_filename, newline='', encoding=encoding) as f:
        reader = csv.reader(f)
        if skip_header:
            next(reader, None)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; unpacking it would raise.
                continue
            img, lab = row
            image.append(img)
            label.append(lab)
    return image, label

In [14]:
# Read the validation split and pull out its path and label columns.
val_dataset = pd.read_csv('E:/GBCDL/data/val.csv')
img_pth, label = val_dataset['img_path'], val_dataset['label']


In [15]:
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import torch

In [16]:
class GBCValDataset(Dataset):
    """Dataset yielding ``(image_tensor, label)`` pairs.

    Parameters
    ----------
    images : sequence
        One entry per sample. Strings / path-like objects are treated as image
        file paths and opened with PIL (converted to RGB); any other item
        (e.g. an ndarray or tensor) is converted with
        ``transforms.ToPILImage`` as before.
    labels : sequence
        Labels aligned with ``images``; returned unchanged by ``__getitem__``.
    resize : int
        Images are resized to a ``resize x resize`` square.
    augment : bool, default False
        If True, apply ``RandomRotation(15)`` before resizing. Off by default
        because random augmentation makes validation/test metrics
        non-deterministic; enable it explicitly for the training split.

    Raises
    ------
    ValueError
        If ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels, resize, augment=False):
        # list() makes indexing positional: indexing a pandas Series with an
        # int is label-based and fails on a filtered / non-default index.
        self.images = list(images)
        self.labels = list(labels)
        if len(self.images) != len(self.labels):
            raise ValueError(
                f"images and labels differ in length: "
                f"{len(self.images)} != {len(self.labels)}")
        self.resize = resize

        # Preprocessing, with augmentation only when requested.
        steps = [transforms.RandomRotation(15)] if augment else []
        steps += [
            # Honour the `resize` argument instead of a hard-coded 256.
            transforms.Resize((resize, resize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _load_image(item):
        """Return a PIL image for a file path, or convert an array-like item."""
        if isinstance(item, str) or hasattr(item, '__fspath__'):
            # ToPILImage cannot take a path string, so open the file here.
            # convert() loads the pixels before the file is closed and yields
            # the 3 channels the Normalize step expects.
            with Image.open(item) as img:
                return img.convert('RGB')
        return transforms.ToPILImage()(item)

    def __getitem__(self, idx):
        image = self._load_image(self.images[idx])
        label = self.labels[idx]
        return self.transform(image), label

In [17]:
def get_loaders(val_csv_dir, batch_size, resize=224, shuffle=True):
    """Build a DataLoader for the split described by one CSV file.

    Despite the parameter name, this notebook uses it for the train, test and
    val splits alike.

    Parameters
    ----------
    val_csv_dir : str
        Path to a CSV *file* (not a directory) with ``img_path`` and
        ``label`` columns.
    batch_size : int
        Samples per batch.
    resize : int, default 224
        Target image side length, forwarded to ``GBCValDataset``.
    shuffle : bool, default True
        Reshuffle every epoch. Keep True for training; pass ``shuffle=False``
        for validation/test so the evaluation order is reproducible.

    Returns
    -------
    torch.utils.data.DataLoader
        Loader over a ``GBCValDataset`` built from the CSV.
    """
    split_df = pd.read_csv(val_csv_dir)
    dataset = GBCValDataset(split_df['img_path'], split_df['label'], resize)
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)


In [18]:
batch_size = 32
val_csv_dir = 'E:/GBCDL/data/val.csv'

# Use the configured batch size rather than repeating the literal 32.
val_dataloader = get_loaders(val_csv_dir, batch_size)

# Number of samples in the loader's underlying dataset
data_length = len(val_dataloader.dataset)

# Number of batches the DataLoader yields per epoch
num_batches = len(val_dataloader)

# Sanity check: without drop_last, batches = ceil(samples / batch_size).
assert num_batches == -(-data_length // batch_size)

print(f"数据集长度: {data_length}")
print(f"总批次数: {num_batches}")

数据集长度: 3167
总批次数: 99


In [19]:
batch_size = 32
test_csv_dir = 'E:/GBCDL/data/test.csv'

# Use the configured batch size rather than repeating the literal 32.
test_dataloader = get_loaders(test_csv_dir, batch_size)

# Number of samples in the loader's underlying dataset
data_length = len(test_dataloader.dataset)

# Number of batches the DataLoader yields per epoch
num_batches = len(test_dataloader)

# Sanity check: without drop_last, batches = ceil(samples / batch_size).
assert num_batches == -(-data_length // batch_size)

print(f"数据集长度: {data_length}")
print(f"总批次数: {num_batches}")

数据集长度: 2538
总批次数: 80


In [20]:
batch_size = 32
train_csv_dir = 'E:/GBCDL/data/train.csv'

# Use the configured batch size rather than repeating the literal 32.
train_dataloader = get_loaders(train_csv_dir, batch_size)

# Number of samples in the loader's underlying dataset
data_length = len(train_dataloader.dataset)

# Number of batches the DataLoader yields per epoch
num_batches = len(train_dataloader)

# Sanity check: without drop_last, batches = ceil(samples / batch_size).
assert num_batches == -(-data_length // batch_size)

print(f"数据集长度: {data_length}")
print(f"总批次数: {num_batches}")

数据集长度: 16201
总批次数: 507


In [21]:
import csv
from collections import Counter


def count_labels(csv_files, encoding=None):
    """Count how often each label value appears across one or more CSV files.

    Parameters
    ----------
    csv_files : iterable of str / path-like, or a single str / path-like
        CSV file paths. Each file must have a header row containing a
        ``label`` column. A single path is treated as a one-element list
        (iterating a bare string would otherwise walk its characters).
    encoding : str, optional
        Text encoding forwarded to ``open``; None keeps the platform default.

    Returns
    -------
    dict
        Mapping of label value (the raw string read from the CSV) to the
        number of rows carrying it, in first-seen order.

    Raises
    ------
    KeyError
        If a file has no ``label`` column.
    """
    if isinstance(csv_files, str) or hasattr(csv_files, '__fspath__'):
        csv_files = [csv_files]

    label_counts = Counter()
    for file_name in csv_files:
        with open(file_name, 'r', newline='', encoding=encoding) as file:
            label_counts.update(row['label'] for row in csv.DictReader(file))

    # Plain dict keeps the original return type; insertion order is preserved.
    return dict(label_counts)


In [22]:
csv_files = ['E:/GBCDL/data/train.csv', 'E:/GBCDL/data/test.csv', 'E:/GBCDL/data/val.csv']
label_counts = count_labels(csv_files)

print("Label Counts:")
# `total`, not `sum`: assigning to `sum` would shadow the builtin for every
# later cell. `label_value` avoids clobbering the earlier global `label`.
total = 0
for label_value, count in label_counts.items():
    total += count
    print(f"{label_value}: {count}")

print(f"sum:{total}")
print("Label %:")
# Values printed below are fractions in [0, 1], not percentages.
# No zero-division risk: total is 0 only when label_counts is empty.
for label_value, count in label_counts.items():
    print(f"{label_value}: {count/total}")

Label Counts:
0: 9235
1: 8107
2: 4564
sum:21906
Label %:
0: 0.42157399799141787
1: 0.37008125627681915
2: 0.20834474573176298
