In [1]:
import logging
import os
import copy
import math
import pathlib
from datetime import datetime
import csv

# 3rd Party
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms


In [2]:
# Plot font. The intent is a Japanese-capable font so that CJK character
# labels do not render as empty boxes. The font is prepended to matplotlib's
# existing sans-serif list instead of replacing it, so the default fonts stay
# available as fallbacks on machines where it is not installed. The filter
# keeps the list from growing if this cell is re-run.
# NOTE(review): Microsoft YaHei is a Chinese-market font; confirm it covers
# every glyph (e.g. kana) used in the labels, or prefer a Japanese font.
CJK_FONT = "Microsoft YaHei"
plt.rcParams['font.sans-serif'] = [CJK_FONT] + [
    name for name in plt.rcParams['font.sans-serif'] if name != CJK_FONT
]

# Set to True to run on the CPU even when a CUDA device is available.
FORCE_CPU = False

# CUDA_LAUNCH_BLOCKING is read when CUDA is initialised, so it has to be set
# before the first torch.cuda call below; setting it afterwards has no effect.
# "0" keeps kernel launches asynchronous; use "1" while debugging to get
# synchronous launches and accurate error locations.
os.environ['CUDA_LAUNCH_BLOCKING'] = "0"

# Computation device setup
use_cuda = torch.cuda.is_available() and not FORCE_CPU
device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
    print(f"Using device: {torch.cuda.get_device_name(0)}")
else:
    print("Using device: CPU")

Using device: NVIDIA GeForce GTX 1060 6GB


In [3]:
# Split the dataset CSV into a train CSV and a test CSV.
# Column 0 of each row is used as the label; the remaining columns
# (presumably the flattened pixel values) are copied through unchanged.
# The first TEST_EXAMPLES_PER_CLASS rows seen for each label go to the test
# set and every later row of that label goes to the training set. Rows are not
# shuffled, so the split is deterministic and follows the file's row order.
DATASET_DIR = ''
DATASET_FILE = 'ETL8B2C1.csv'
# Encoding used to read DATASET_FILE (the split files are written as UTF-8).
# Set explicitly so the result does not depend on the OS locale.
# NOTE(review): assumes the source CSV is UTF-8 — confirm against the script
# that produced it.
DATASET_ENCODING = 'utf-8'
TEST_EXAMPLES_PER_CLASS = 16
TRAIN_SET_FILE = 'ETL8B2C1_train.csv'
TEST_SET_FILE = 'ETL8B2C1_test.csv'

train_rows = []
test_rows = []
# label -> number of rows of that label placed in the test set so far
label_count = {}
# newline='' is required by the csv module for files it reads or writes.
with open(os.path.join(DATASET_DIR, DATASET_FILE), 'r',
          encoding=DATASET_ENCODING, newline='') as csv_file:
    for row in csv.reader(csv_file):
        if not row:
            # csv.reader yields [] for a blank line; there is no label to use.
            continue
        label = row[0]
        label_count.setdefault(label, 0)
        if label_count[label] < TEST_EXAMPLES_PER_CLASS:
            label_count[label] += 1
            test_rows.append(row)
        else:
            train_rows.append(row)

for out_file, rows in ((TRAIN_SET_FILE, train_rows), (TEST_SET_FILE, test_rows)):
    with open(os.path.join(DATASET_DIR, out_file), 'w', encoding="utf-8", newline='') as f:
        csv.writer(f).writerows(rows)