In [448]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [449]:
def load_data(label_names):
    dataset_path = './dataset/hand-drawn-image-dataset/'
    dataset_dict_in = {}
    dataset_dict_out = {}
    train_dataset_list_in, valid_dataset_list_in, test_dataset_list_in = [], [], []
    test_dataset_list_out = []
    for i, label_name in enumerate(label_names):
        data = np.load(dataset_path + label_name + '/' + label_name + '.npy')
        label = np.array([i] * len(data)).reshape(-1, 1)
        dataset  = np.hstack((label, data))
        if label_name in ['ambulance','apple','bear','bicycle','bird','bus','cat']:
            dataset_dict_in[label_name] = dataset
            train_dataset_list_in.append(dataset[:int(len(dataset)*0.7)])
            valid_dataset_list_in.append(dataset[int(len(dataset)*0.7):int(len(dataset)*0.9)])
            test_dataset_list_in.append(dataset[int(len(dataset)*0.9):])
        if label_name in ['foot','owl','pig']:
            dataset_dict_out[label_name] = dataset
            random_idx = np.random.randint(0, len(dataset), 100)
            test_dataset_list_out.append(dataset[random_idx])
    train_dataset_in = np.vstack(train_dataset_list_in)
    valid_dataset_in = np.vstack(valid_dataset_list_in)
    test_dataset_in = np.vstack(test_dataset_list_in)
    test_dataset_out = np.vstack(test_dataset_list_out)
    return dataset_dict_in, dataset_dict_out, train_dataset_in, valid_dataset_in, test_dataset_in, test_dataset_out

In [450]:
label_names=['ambulance','apple','bear','bicycle','bird','bus','cat','foot','owl','pig']
dataset_dict_in, dataset_dict_out, train_dataset_in, valid_dataset_in, test_dataset_in, test_dataset_out = load_data(label_names=label_names)
train_dataset_data_in = train_dataset_in[:, 1:].reshape(-1, 28, 28)
train_dataset_label_in = train_dataset_in[:, 0]
valid_dataset_data_in = valid_dataset_in[:, 1:].reshape(-1, 28, 28)
valid_dataset_label_in = valid_dataset_in[:, 0]
test_dataset_data_in = test_dataset_in[:, 1:].reshape(-1, 28, 28)
test_dataset_label_in = test_dataset_in[:, 0]
test_dataset_data_out = test_dataset_out[:, 1:].reshape(-1, 28, 28)
test_dataset_label_out = test_dataset_out[:, 0]
print(test_dataset_data_out.shape)

(300, 28, 28)


In [451]:
class Image_Dataset(Dataset):
    def __init__(self, data, label, transform=None):
        super(Image_Dataset, self).__init__()
        self.data = torch.tensor(data).float()
        self.label = torch.tensor(label)
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        data = self.data[index]
        label = self.label[index]
        if self.transform:
            data = self.transform(data)
        return data, label
    
class LeNet(torch.nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=0)
        self.s2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv3 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, padding=0)
        self.s4 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv5 = torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, padding=0)
        self.flatten = torch.nn.Flatten()
        self.f6 = torch.nn.Linear(120, 84)
        self.output = torch.nn.Linear(84, 10)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.conv1(x))
        x = self.s2(x)
        x = self.sigmoid(self.conv3(x))
        x = self.s4(x)
        x = self.sigmoid(self.conv5(x))
        x = self.flatten(x)
        x = self.f6(x)
        x = self.output(x)
        return x

In [452]:
batch_size = 128
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Pad(2),
    transforms.ToTensor()])
train_dataset_in = Image_Dataset(train_dataset_data_in, train_dataset_label_in, transform=transform)
valid_dataset_in = Image_Dataset(valid_dataset_data_in, valid_dataset_label_in, transform=transform)
test_dataset_in = Image_Dataset(test_dataset_data_in, test_dataset_label_in, transform=transform)
test_dataset_out = Image_Dataset(test_dataset_data_out, test_dataset_label_out, transform=transform)
train_dataloader_in = DataLoader(train_dataset_in, batch_size=batch_size, shuffle=True)
valid_dataloader_in = DataLoader(valid_dataset_in, batch_size=batch_size, shuffle=True)
test_dataloader_in = DataLoader(test_dataset_in, batch_size=batch_size, shuffle=True)
test_dataloader_out = DataLoader(test_dataset_out, batch_size=batch_size, shuffle=True)

In [453]:
def train(model, epochs, train_dataloader, valid_dataloader, optimizer, criterion):
    loss_list = []
    acc_list = []
    for epoch in range(1, epochs+1):
        loss = 0
        for data_batch, label_batch in train_dataloader:
            data_batch = data_batch.to(device)
            label_batch = label_batch.to(device)
            optimizer.zero_grad()
            output = model(data_batch)
            _loss = criterion(output, label_batch)
            loss += _loss.item()
            _loss.backward()
            optimizer.step()
        loss /= len(train_dataloader)
        loss_list.append(loss)
        accuracy = validate(model, valid_dataloader)
        acc_list.append(accuracy)
        if epoch % 2 == 0 or epoch ==  1:
            print('device: {}, epoch: {}, loss: {:.4f}, accuracy: {:.4f}'.format(device, epoch, loss, accuracy))

def validate(model, valid_dataloader):
    correct_pred_nums = 0
    with torch.no_grad():
        for data_batch, label_batch in valid_dataloader:
            data_batch = data_batch.to(device)
            label_batch = label_batch.to(device)
            output = model(data_batch)
            pred = torch.argmax(output, dim=1)
            correct_pred_nums += torch.sum(pred == label_batch)
        accuracy = correct_pred_nums / len(valid_dataloader.dataset)
        return accuracy

In [454]:
epochs = 1
lr = 0.1

lenet = LeNet().to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(lenet.parameters(), lr=lr)
train(model=lenet, epochs=epochs, train_dataloader=train_dataloader_in, valid_dataloader=valid_dataloader_in, optimizer=optimizer, criterion=criterion)
test_accuracy_in = validate(model=lenet, valid_dataloader=test_dataloader_in)
test_accuracy_out = validate(model=lenet, valid_dataloader=test_dataloader_out)
print(f"Test accuracy in: {test_accuracy_in}")
print(f"Test accuracy out: {test_accuracy_out}")

device: cpu, epoch: 1, loss: 1.9605, accuracy: 0.1429
Test accuracy in: 0.1428571492433548
Test accuracy out: 0.0
