In [68]:
import torch
from PIL import Image
from easydict import EasyDict
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms


def make_dataset(image_list, labels):
    if labels:
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        if len(image_list[0].split()) > 2:
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images


def rgb_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


def l_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('L')

def image_train(resize_size=256, crop_size=224):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])


def image_test(resize_size=256, crop_size=224):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        normalize
    ])


def data_load(args):
    ## prepare data
    dsets = {}
    dset_loaders = {}
    train_bs = args.batch_size

    all_txt_src = ImageFolder('data/office-home/Art', loader=lambda x: x)

    txt_src = list()
    for x in all_txt_src:
        path = x[0]
        cls = x[1]
        txt_src.append(f"{path} {cls}")

    all_txt_test = ImageFolder('data/office-home/Clipart', loader=lambda x: x)
    txt_test = list()
    for x in all_txt_test:
        path = x[0]
        cls = x[1]
        txt_test.append(f"{path} {cls}")

    if not args.da == 'uda':
        label_map_s = {}
        for i in range(len(args.src_classes)):
            label_map_s[args.src_classes[i]] = i

        new_src = []
        for i in range(len(txt_src)):
            rec = txt_src[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.src_classes:
                line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
                new_src.append(line)
        txt_src = new_src.copy()

        new_tar = []
        for i in range(len(txt_test)):
            rec = txt_test[i]
            reci = rec.strip().split(' ')
            if int(reci[1]) in args.tar_classes:
                if int(reci[1]) in args.src_classes:
                    line = reci[0] + ' ' + str(label_map_s[int(reci[1])]) + '\n'
                    new_tar.append(line)
                else:
                    line = reci[0] + ' ' + str(len(label_map_s)) + '\n'
                    new_tar.append(line)
        txt_test = new_tar.copy()
    if args.trte == "val":
        dsize = len(txt_src)
        tr_size = int(0.9 * dsize)
        tr_txt, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
    else:
        dsize = len(txt_src)
        tr_size = int(0.9 * dsize)
        _, te_txt = torch.utils.data.random_split(txt_src, [tr_size, dsize - tr_size])
        tr_txt = txt_src

    # dsets["source_tr"] = ImageList(tr_txt, transform=image_train())
    # dset_loaders["source_tr"] = DataLoader(dsets["source_tr"], batch_size=train_bs, shuffle=True,
    #                                        num_workers=args.worker, drop_last=False)
    #
    # dsets["source_te"] = ImageList(te_txt, transform=image_test())
    # dset_loaders["source_te"] = DataLoader(dsets["source_te"], batch_size=train_bs, shuffle=True,
    #                                        num_workers=args.worker, drop_last=False)
    #
    # dsets["test"] = ImageList(txt_test, transform=image_test())
    # dset_loaders["test"] = DataLoader(dsets["test"], batch_size=train_bs * 2, shuffle=True, num_workers=args.worker,
    #                                   drop_last=False)
    # print(len(dsets['source_tr']))
    # print(len(dsets['source_te']))
    # print(len(dsets['test']))
    # return dset_loaders


if __name__ == '__main__':
    args = EasyDict({
        "worker": 4,
        "da": "oda",
        "trte": "val",
        "batch_size": 64,
        "src_classes": [i for i in range(25)],
        "tar_classes": [i for i in range(65)],
    })
    dl = data_load(args)

    print(len(dl['source_tr']))
    print(len(dl['source_te']))
    print(len(dl['test']))

980
109
4365
16
2
35


In [77]:
import numpy as np
def make_dataset2(image_list, labels):
    if labels:
        len_ = len(image_list)
        images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
    else:
        if len(image_list[0].split()) > 2:
            images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
        else:
            images = [(val.split()[0], int(val.split()[1])) for val in image_list]
    return images

class ImageList(Dataset):
    def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
        print(image_list[0])
        imgs = make_dataset2(image_list, labels)
        print(imgs[0])

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)

all_txt_src = ImageFolder('data/office-home/Art', loader=lambda x: x)

txt_src = list()
for x in all_txt_src:
    path = x[0]
    cls = x[1]
    txt_src.append(f"{path} {cls}")

ids = ImageList(txt_src, transform=image_train())
# print(ids[0])

data/office-home/Art/Alarm_Clock/00001.jpg 0
?
('data/office-home/Art/Alarm_Clock/00001.jpg', 0)
