In [None]:
!pip install pyyaml
!pip install torch torchvision

In [2]:
import os

import torch
import yaml
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms

class DomainAdaptationDataset(Dataset):
    """Image dataset restricted to one domain and a subset of labels.

    Reads an index file in which each non-blank line has the form
    ``<image path>\\t<integer label>`` and keeps only the entries whose path
    contains ``domain`` and whose label is a member of ``label_range``.

    Args:
        data_file: Path to the tab-separated index file.
        domain: Substring that must occur in an image path for it to be kept.
            NOTE(review): this is a plain substring test, so it also matches the
            domain name anywhere else in the path (e.g. a parent directory);
            confirm the dataset paths make it unambiguous.
        label_range: Container of allowed integer labels (tested with ``in``).
        transform: Optional callable applied to each loaded PIL image. When
            omitted, images are resized to 224x224 and converted with
            ``transforms.ToTensor()``.

    Raises:
        ValueError: If a non-blank line does not split into exactly two
            tab-separated fields, or its label is not an integer.
    """

    def __init__(self, data_file, domain, label_range, transform=None):
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
            ])
        self.transform = transform
        self.data = []  # list of (path, label) pairs that passed the filters
        with open(data_file, 'r') as f:
            for line in f:
                line = line.strip()
                # Blank / whitespace-only lines carry no entry; skip them rather
                # than failing to unpack an empty split.
                if not line:
                    continue
                path, label = line.split('\t')
                label = int(label)
                if domain in path and label in label_range:
                    self.data.append((path, label))

    def __len__(self):
        """Return the number of (path, label) pairs kept after filtering."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load entry ``idx`` as an RGB image, transform it, and return ``(image, label)``."""
        path, label = self.data[idx]
        # The context manager closes the source file handle deterministically;
        # `convert` returns an independent, fully loaded image.
        with Image.open(path) as img:
            image = img.convert('RGB')
        return self.transform(image), label

# Helper to load a YAML configuration file.
def load_yaml(file_path):
    """Parse the YAML file at ``file_path`` and return the loaded object.

    Uses ``yaml.safe_load``, which only constructs standard YAML types rather
    than arbitrary Python objects.

    Args:
        file_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document (a dict for a
        top-level mapping such as the config files read in this notebook).
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

In [12]:
# Sanity check: is there a `config` entry in the current working directory?
os.path.exists(os.path.join(os.curdir, "config"))

True

In [20]:
# Load the experiment configuration.
config_path = 'config/office.yaml'
config = load_yaml(config_path)
dataset_cfg = config['data']['dataset']

# Source and target domains. These are later matched as substrings of image paths.
# Prefer the correctly spelled 'source' key, but fall back to the legacy
# misspelling 'souce' so existing config files keep working.
# TODO: rename the key to 'source' in config/office.yaml and drop the fallback.
source_domain = dataset_cfg['source'] if 'source' in dataset_cfg else dataset_cfg['souce']
target_domain = dataset_cfg['target']
batch_size = config['data']['dataloader']['batch_size']

# Label sets, laid out as consecutive integer ranges:
# [source-private | shared | target-private]
n_source_private = dataset_cfg['n_source_private']
n_share = dataset_cfg['n_share']
n_target_private = dataset_cfg['n_target_private']

source_private_labels = set(range(n_source_private))
shared_labels = set(range(n_source_private, n_source_private + n_share))
target_private_labels = set(range(n_source_private + n_share,
                                  n_source_private + n_share + n_target_private))

In [21]:
# Build the source and target datasets from the shared image/label index file.
DATA_FILE = 'data/office/images_and_labels.txt'
source_dataset = DomainAdaptationDataset(DATA_FILE, source_domain, source_private_labels.union(shared_labels))
target_dataset = DomainAdaptationDataset(DATA_FILE, target_domain, target_private_labels.union(shared_labels))

# Split the target data 80/20 into train and test subsets. A seeded generator
# makes the split identical on every run (e.g. after a kernel restart).
SPLIT_SEED = 0
target_train_size = int(0.8 * len(target_dataset))
target_test_size = len(target_dataset) - target_train_size
target_train_dataset, target_test_dataset = random_split(
    target_dataset,
    [target_train_size, target_test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Data loaders; only the training loaders shuffle.
source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)
target_train_loader = DataLoader(target_train_dataset, batch_size=batch_size, shuffle=True)
target_test_loader = DataLoader(target_test_dataset, batch_size=batch_size, shuffle=False)

# Smoke test: print the batch shapes of the first batch from each training loader.
for loader in (source_loader, target_train_loader):
    for images, labels in loader:
        print(images.size(), labels.size())
        break

torch.Size([36, 3, 224, 224]) torch.Size([36])
torch.Size([36, 3, 224, 224]) torch.Size([36])


In [42]:
# Sanity check: sum of all pixel values of the first source image tensor.
torch.sum(source_dataset[0][0]).item()


88272.703125