预处理并缓存对应的图像数据与表格数据（以 `.pt` 文件保存），以加快后续的读取速度

In [1]:
import os
import torch
from PIL import Image
import pandas as pd
import numpy as np
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


blood

In [2]:
# Deterministic preprocessing applied to every WDF image: resize to a fixed
# 240x240 grid, then convert the PIL image to a tensor.
transform_img = transforms.Compose(
    [
        transforms.Resize([240, 240]),
        transforms.ToTensor(),
    ]
)

In [3]:
def save_images(mode, opt_dict, save_dir, transform=None, label_map=None):
    """Decode every ``WDF.png`` of one split and cache them in ``{mode}_images.pt``.

    Expected layout: ``<dataset>/<group>/<class_dir>/<barcode>/WDF.png`` where
    ``<group>`` is ``建模组`` for ``mode == 'train'`` and ``验证组`` for any other
    mode. The saved object is ``{barcode: {'image': Tensor, 'label': int}}``.

    Args:
        mode: split name; chooses the group directory and the output file prefix.
        opt_dict: config dict; reads ``opt_dict['dataset_config']['dataset']``.
        save_dir: directory the ``.pt`` file is written to.
        transform: callable applied to each RGB PIL image. Defaults to the
            notebook-level ``transform_img``.
        label_map: remaps the integer class-directory name to the stored label;
            labels not in the map are kept unchanged. Defaults to ``{7: 6}``,
            the rule previously hard-coded here. NOTE(review): presumably this
            keeps labels contiguous because there is no class-6 directory - confirm.
    """
    if transform is None:
        transform = transform_img
    if label_map is None:
        label_map = {7: 6}

    base_path = opt_dict['dataset_config']['dataset']
    img_dir = os.path.join(base_path, '建模组' if mode == 'train' else '验证组')

    img_data = {}
    # Sorted traversal makes the cache's key order reproducible across filesystems.
    for class_dir in sorted(os.listdir(img_dir)):
        class_path = os.path.join(img_dir, class_dir)
        if not os.path.isdir(class_path):
            continue
        try:
            raw_label = int(class_dir)
        except ValueError:
            # A stray non-numeric folder (e.g. `.ipynb_checkpoints`) used to
            # abort the whole run with a ValueError; skip it instead.
            print(f"Skipping non-class directory: {class_path}")
            continue
        label = label_map.get(raw_label, raw_label)

        for barcode in sorted(os.listdir(class_path)):
            wdf_path = os.path.join(class_path, barcode, 'WDF.png')
            if not os.path.exists(wdf_path):
                continue
            if barcode in img_data:
                # The later entry still wins, but the collision is no longer silent.
                print(f"Duplicate barcode {barcode}: {wdf_path} overrides an earlier entry")
            # The context manager releases the file handle opened by Image.open.
            with Image.open(wdf_path) as image:
                rgb_image = image.convert('RGB')
            img_data[barcode] = {
                'image': transform(rgb_image),
                'label': label,
            }

    torch.save(img_data, os.path.join(save_dir, f'{mode}_images.pt'))

In [4]:
def save_tabular(mode, opt_dict, save_dir):
    """Cache one split of the blood tabular CSV as ``{barcode: {'features', 'label'}}``.

    Rows are selected by the ``train_val`` column: ``'j'`` for ``mode == 'train'``
    and ``'y'`` for any other mode. ``train_val`` and the index artefact
    ``'Unnamed: 0'`` (if present) are dropped. Every remaining column except
    ``barcode`` and ``target`` becomes a float32 feature, in CSV column order.
    The label is a 0-d ``torch.long`` tensor built from ``target``.

    Rows whose conversion fails are reported and skipped.

    Args:
        mode: split name; also used as the output file prefix.
        opt_dict: config dict; reads ``opt_dict['dataset_config']['dataset_tabular']``.
        save_dir: directory where ``{mode}_tabular.pt`` is written.

    Raises:
        KeyError: if the CSV has no ``barcode`` or ``target`` column.
    """
    # Read barcodes as strings so keys match the image cache, which is keyed by
    # directory names. Numeric parsing would drop leading zeros, and iterrows'
    # dtype upcasting could turn the barcodes into floats.
    df = pd.read_csv(opt_dict['dataset_config']['dataset_tabular'], dtype={'barcode': str})
    df = df[df['train_val'] == ('j' if mode == 'train' else 'y')]
    # errors='ignore' already tolerates a missing 'Unnamed: 0' column.
    df = df.drop(columns=['train_val', 'Unnamed: 0'], errors='ignore')

    missing = {'barcode', 'target'} - set(df.columns)
    if missing:
        raise KeyError(f"tabular CSV is missing required columns: {sorted(missing)}")
    feature_cols = [col for col in df.columns if col not in ('target', 'barcode')]

    tab_data = {}
    for _, row in df.iterrows():
        # Bind the barcode before the try-block. Otherwise the error message
        # could name an unbound variable or the previous row's barcode.
        barcode = row['barcode']
        try:
            features = torch.tensor(row[feature_cols].values.astype(np.float32))
            label = torch.tensor(row['target'], dtype=torch.long)
        except Exception as e:
            print(f"处理条码{barcode}时出错: {str(e)}")
            continue
        tab_data[barcode] = {'features': features, 'label': label}

    torch.save(tab_data, os.path.join(save_dir, f'{mode}_tabular.pt'))
    print(f"{mode}表格数据保存成功，样本数: {len(tab_data)}")

In [5]:
def save_dataset_files(opt_dict, save_dir):
    """Build the image caches and then the tabular caches for the train and test splits.

    Args:
        opt_dict: config dict forwarded unchanged to ``save_images`` and ``save_tabular``.
        save_dir: output directory; created if it does not exist yet.
    """
    os.makedirs(save_dir, exist_ok=True)
    splits = ('train', 'test')

    for split in splits:
        save_images(split, opt_dict, save_dir)
    print("image finish")

    for split in splits:
        save_tabular(split, opt_dict, save_dir)
    print("table finish")

In [6]:
if __name__ == '__main__':
    # Blood dataset sources: the WDF image tree and the imputed tabular CSV.
    opt_dict = dict(
        dataset_config=dict(
            dataset='/data/blood_dvm/data/blood/dataset/图像数据/',
            dataset_tabular='/data/blood_dvm/data/blood/blood_imputation_result/orign/mostfreq_mean.csv',
        )
    )
    # Both the image and the tabular caches are written directly under this directory.
    save_dir = "/data/blood_dvm/data/blood/"
    save_dataset_files(opt_dict, save_dir)

image finish
train表格数据保存成功，样本数: 655
test表格数据保存成功，样本数: 717
table finish


In [7]:
# Reload the cached training split for inspection.
# NOTE(review): despite its name, this is the *tabular* cache written by
# save_tabular ({barcode: {'features', 'label'}}), not train_images.pt.
train_img = torch.load(
    "/data/blood_dvm/data/blood/train_tabular.pt"
)

In [8]:
# Sanity check: show the cached feature vector of one training barcode.
train_img["11815997300"]["features"]

tensor([ 0.0000,  1.0000,  0.0000,  0.0000,  1.0000,  1.0000,  0.0000,  0.0000,
         1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  2.0000,
         1.0000,  1.0000,  0.0000,  5.6413,  0.2296, -0.1397, -0.3466, -0.9947,
        -0.7427, -1.5324,  0.7217,  0.9054,  1.5941,  1.2783,  1.2911,  1.1036,
         0.8196,  0.3292,  0.4868, -0.8531, -0.3064, -0.0658, -0.8670, -0.2937,
        -0.0788])

dvm

In [1]:
import os
import torch
from PIL import Image
import pandas as pd
import numpy as np
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Stochastic pipeline for training images. It applies a random resized crop
# to 240x240, a horizontal flip (p=0.5), colour jitter (applied with p=0.8)
# and random grayscale (p=0.2).
# NOTE(review): neither pipeline in this cell is referenced by process_data below.
transform_img_train = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=240, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8)],
            p=0.8,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
    ]
)
# Deterministic pipeline for evaluation images: fixed 240x240 resize, then tensor conversion.
transform_img_test = transforms.Compose([transforms.Resize([240, 240]), transforms.ToTensor()])

In [3]:
def process_data(mode, save_dir, opt_dict=None):
    """Cache one DVM split as ``{Adv_ID: {'image_path', 'features', 'label'}}``.

    Rows are selected from the tabular CSV by ``train_val``
    (``'train'`` -> ``'j'``, ``'test'`` -> ``'y'``, ``'val'`` -> ``'v'``). Only the
    path of each image is stored, not the decoded image. ``features`` is a
    float32 tensor over ``need_col`` (in that order). ``label`` is a 0-d
    ``torch.long`` tensor built from ``Genmodel_ID``.

    Args:
        mode: ``'train'``, ``'test'`` or ``'val'``. It selects the rows, the image
            sub-directory and the output file prefix. Any other value raises KeyError.
        save_dir: directory where ``{mode}_tabular.pt`` is written.
        opt_dict: config dict providing ``['dataset_config']['dataset_tabular']``
            and, optionally, ``['dataset_config']['dataset']`` (image root). When
            omitted, the notebook-global ``opt_dict`` is used, as before.
    """
    if opt_dict is None:
        # Backward compatibility: this function used to read the global config implicitly.
        opt_dict = globals()['opt_dict']
    dataset_config = opt_dict['dataset_config']

    df = pd.read_csv(dataset_config['dataset_tabular'])
    train_val_map = {'train': 'j', 'test': 'y', 'val': 'v'}
    df = df[df['train_val'] == train_val_map[mode]]

    need_col = [
        'Adv_year', 'Adv_month', 'Reg_year', 'Runned_Miles', 'Price',
        'Seat_num', 'Door_num', 'Entry_price', 'Engine_size', 'Color',
        'Bodytype', 'Gearbox', 'Fuel_type', 'Wheelbase',
        'Length', 'Width', 'Height'
    ]
    # Prefer the configured image root. The fallback is the path that used to be hard-coded here.
    img_base_path = dataset_config.get('dataset', "/data/blood_dvm/data/dvm/dvm_img/")

    # Convert all feature columns in one vectorised call instead of once per row.
    feature_matrix = df[need_col].to_numpy(dtype=np.float32)

    tab_data = {}
    for adv_id, genmodel_id, image_name, feature_row in zip(
        df['Adv_ID'], df['Genmodel_ID'], df['Image_name'], feature_matrix
    ):
        label = torch.tensor(genmodel_id, dtype=torch.long)
        # Image files are addressed as <root>/<mode>/<label>_<Image_name>.
        image_path = os.path.join(img_base_path, mode, str(label.item()) + '_' + image_name)
        tab_data[adv_id] = {
            'image_path': image_path,
            # torch.tensor copies the row, so each sample owns its storage
            # instead of being a view into feature_matrix.
            'features': torch.tensor(feature_row),
            'label': label,
        }

    torch.save(tab_data, os.path.join(save_dir, f'{mode}_tabular.pt'))
    print(f'Saved {len(tab_data)} {mode} tabular samples')

In [4]:
def save_dvm_dataset(opt_dict, save_dir):
    """Write the train, test and val tabular caches for the DVM dataset into ``save_dir``.

    NOTE(review): ``opt_dict`` is not forwarded. ``process_data`` is called as
    ``process_data(mode, save_dir)`` and so takes its CSV path from the
    notebook-global ``opt_dict``; keep that global in sync with this argument.
    """
    os.makedirs(save_dir, exist_ok=True)
    splits = ('train', 'test', 'val')
    for split_name in splits:
        process_data(split_name, save_dir)

In [5]:
if __name__ == '__main__':
    # DVM sources: the image directory and the standardised tabular CSV, which
    # is split into train/test/val by its `train_val` column.
    # Keep the global name `opt_dict`: process_data(mode, save_dir) reads its
    # config from this notebook-global variable.
    opt_dict = dict(
        dataset_config=dict(
            dataset='/data/blood_dvm/data/dvm/dvm_img/',
            dataset_tabular='/data/blood_dvm/data/dvm/dvm_table/dvm_data/dvm_orig_standard.csv',
        )
    )
    save_dvm_dataset(opt_dict, '/data/blood_dvm/data/dvm/')

Saved 70565 train tabular samples
Saved 88208 test tabular samples
Saved 17641 val tabular samples
