## Process Data to Torch Tensor

In [None]:
import json
from collections import Counter
import torch
from PIL import Image
import torchvision.transforms as transforms
import time
from get_label import get_label

In [None]:
def _read_json(path):
    """Open *path*, parse it as JSON and return the decoded object."""
    with open(path) as handle:
        return json.load(handle)


# Dataset annotations for the train and test splits.
data_dir = "../data/medium/"
train_data = _read_json(data_dir + "train_all.json")
test_data = _read_json(data_dir + "test_all.json")

# Lookup tables that are passed to get_label() when labels are computed.
json_dir = "./json/"
word2color = _read_json(json_dir + "word2color.json")
color2label = _read_json(json_dir + "color2label.json")
# Disabled: json_dir + "friend_label.json" is currently not loaded.
# friend_label = _read_json(json_dir + "friend_label.json")

In [None]:
# Flatten the annotations: every mapping found under an entry's 'imgs_tags'
# list is merged into one dict (a key seen again later overwrites the earlier
# value), and the merged dict is then exposed as a list of (key, value) pairs.
merged_tags = {}
for entry in train_data.values():
    for tag_map in entry['imgs_tags']:
        merged_tags.update(tag_map)
train_data_packed = list(merged_tags.items())

In [None]:
# Image conversion pipeline: a single PILToTensor step (per the torchvision
# documentation it keeps the image's own dtype and does not rescale values).
transform = transforms.Compose([transforms.PILToTensor()])

# Tensors produced by the image-conversion step; None means "no data yet".
img_tensor_all = img_tensor_sub = label_tensor_all = None

# Flags marking that the matching tensor above has not received data yet.
first = sub_first = label_first = True

# Number of dataset entries between two merges of buffered images (and between
# two progress printouts).
minibatch = 500

# Wall-clock start time, used for the elapsed-seconds progress printout.
st = time.time()

In [None]:
# Convert every labelled training image into one stacked image tensor
# (img_tensor_all, N x 3 x 224 x 224) and a matching int64 label tensor
# (label_tensor_all, N).  Entries for which get_label() returns -1 are skipped.
#
# NOTE: the earlier version of this cell (one torch.cat per item) took about
# 30 min to process 156403 labels; collecting into Python lists and
# concatenating once per chunk avoids re-copying the growing tensors.
label_values = []   # one int label per kept image, in dataset order
img_chunks = []     # finished chunks, each (<= minibatch) x 3 x 224 x 224
pending_imgs = []   # 1 x 3 x 224 x 224 tensors waiting for the next flush
last_index = len(train_data_packed) - 1

for i, (picdir, word) in enumerate(train_data_packed):
    label = get_label(word, word2color, color2label)
    if label != -1:
        label_values.append(label)
        # The context manager closes the file handle once the pixel data has
        # been read; convert('RGB') guarantees three channels so grayscale,
        # palette or RGBA files do not break the reshape below.
        with Image.open(data_dir + 'train/' + picdir[0:12] + '/' + picdir) as img:
            resized = img.convert('RGB').resize((224, 224))
        pending_imgs.append(transform(resized).reshape(1, 3, 224, 224))

    # The flush check must run for EVERY index, including skipped entries:
    # otherwise a skipped last entry would leave buffered images out of
    # img_tensor_all and desynchronise it from label_tensor_all.
    if i == last_index or (i % minibatch) == minibatch - 1:
        if pending_imgs:
            img_chunks.append(torch.cat(pending_imgs, 0))
            pending_imgs = []
        et = time.time()
        print('complete%d, %d s'%(i+1, int(et-st)))

# Both results stay None when no entry produced a valid label.
img_tensor_all = torch.cat(img_chunks, 0) if img_chunks else None
label_tensor_all = (
    torch.tensor(label_values, dtype=torch.int64) if label_values else None
)

In [None]:
# Persist the full (unbalanced) image and label tensors.
train_outputs = (
    (img_tensor_all, '../tensor_data/train_img.pt'),
    (label_tensor_all, '../tensor_data/train_label.pt'),
)
for payload, out_path in train_outputs:
    torch.save(payload, out_path)

## Balance Data

In [None]:
# Desired number of samples per class label after balancing
# (label id -> target count), listed from the largest target to the smallest.
target_samples = {
    0: 17_000,
    12: 14_000,
    1: 13_000, 6: 13_000,
    3: 12_000,
    4: 11_000,
    8: 10_000, 2: 10_000,
    10: 8_000,
    7: 7_000, 13: 7_000, 5: 7_000,
    11: 5_000,
    9: 3_800,
}

def sample(all_idx, num_samples):
    """Return up to *num_samples* distinct rows of *all_idx* in random order.

    A random permutation of the row positions is drawn and its first
    *num_samples* entries are used as the selection, so no row is picked
    twice.  When *num_samples* exceeds ``len(all_idx)`` the result is silently
    capped at ``len(all_idx)`` rows.
    """
    shuffled_positions = torch.randperm(len(all_idx))
    picked_positions = shuffled_positions[:num_samples]
    return all_idx[picked_positions]


# Result tensors of the balancing step and their per-class scratch
# counterparts; None means "no data yet".
img_tensor_all_balanced = img_tensor_sub_balanced = None
label_tensor_all_balanced = label_tensor_sub_balanced = None

# Reset the "nothing accumulated yet" flag before balancing starts.
first = True

In [2]:
# Build a class-balanced copy of the training set: for every label listed in
# target_samples, either draw a random subset (class is large enough) or keep
# every sample and duplicate some of them (class is too small) so that the
# class contributes exactly target_samples[label] rows.
#
# Iterating over the keys of target_samples (instead of a hard-coded range)
# makes sure every configured class, including the highest label id, is
# processed.
balanced_img_parts = []
balanced_label_parts = []
for label in sorted(target_samples):
    target = target_samples[label]

    # nonzero() yields one row per match, i.e. a (k, 1) column for a 1-D label
    # tensor.  NOTE(review): indexing with this (k, 1) index adds a singleton
    # axis to the selected images/labels; that shape is kept unchanged here so
    # existing consumers of the saved files keep working - confirm it is
    # intended.
    all_idx = (label_tensor_all == label).nonzero()
    available = all_idx.shape[0]

    if available == 0:
        # No sample of this class exists, so nothing can be duplicated; the
        # class contributes zero rows.
        chosen_idx = all_idx
    elif available >= target:  # downsample
        chosen_idx = sample(all_idx, target)
    else:  # duplicate more samples
        # sample() never returns more than `available` rows, so a class with
        # fewer than half of its target needs whole extra copies of the class
        # plus a random remainder to reach the target exactly.
        full_copies, remainder = divmod(target - available, available)
        chosen_idx = torch.cat(
            [all_idx] * (1 + full_copies) + [sample(all_idx, remainder)], 0
        )

    img_tensor_sub_balanced = img_tensor_all[chosen_idx]
    label_tensor_sub_balanced = label_tensor_all[chosen_idx]
    balanced_img_parts.append(img_tensor_sub_balanced)
    balanced_label_parts.append(label_tensor_sub_balanced)

# Concatenate once at the end instead of re-copying the growing result on
# every class.
img_tensor_all_balanced = torch.cat(balanced_img_parts, 0)
label_tensor_all_balanced = torch.cat(balanced_label_parts, 0)

In [3]:
# Persist the class-balanced image and label tensors.
balanced_outputs = (
    (img_tensor_all_balanced, '../tensor_data/train_img_balanced.pt'),
    (label_tensor_all_balanced, '../tensor_data/train_label_balanced.pt'),
)
for payload, out_path in balanced_outputs:
    torch.save(payload, out_path)