In [1]:
import torchvision
from torchvision.datasets.vision import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.error import URLError
import os
import numpy as np
import torch
import codecs
from PIL import Image

class myMNIST(VisionDataset):

    mirrors = [
        'http://yann.lecun.com/exdb/mnist/',
        'https://ossci-datasets.s3.amazonaws.com/mnist/',
    ]

    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]

    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(
            self,
            root: str,
            train: bool = True,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            download: bool = False,
    ) -> None:
        super(myMNIST, self).__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set

        if self._check_legacy_exist():
            self.data, self.targets = self._load_legacy_data()
            return

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        self.data, self.targets = self._load_data()
        
        import random 
        imbalance_list = random.sample(range(0,10), 5)

        new_data = []
        new_target = []

        for i in range(len(self.targets)):
            if self.targets[i] in imbalance_list:
                if random.random() >= 0.9:
                    new_data.append(self.data[i])
                    new_target.append(self.targets[i])
            else:
                new_data.append(self.data[i])
                new_target.append(self.targets[i])

        self.data = new_data
        self.targets = new_target


    def _check_legacy_exist(self):
        processed_folder_exists = os.path.exists(self.processed_folder)
        if not processed_folder_exists:
            return False

        return all(
            check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file)
        )

    def _load_legacy_data(self):
        # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
        # directly.
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def _load_data(self):
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def raw_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self) -> str:
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self) -> Dict[str, int]:
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self) -> bool:
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]))
            for url, _ in self.resources
        )

    def download(self) -> None:
        """Download the MNIST data if it doesn't exist already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)

        # download files
        for filename, md5 in self.resources:
            for mirror in self.mirrors:
                url = "{}{}".format(mirror, filename)
                try:
                    print("Downloading {}".format(url))
                    download_and_extract_archive(
                        url, download_root=self.raw_folder,
                        filename=filename,
                        md5=md5
                    )
                except URLError as error:
                    print(
                        "Failed to download (trying next):\n{}".format(error)
                    )
                    continue
                finally:
                    print()
                break
            else:
                raise RuntimeError("Error downloading {}".format(filename))

    def extra_repr(self) -> str:
        return "Split: {}".format("Train" if self.train is True else "Test")
    

class myFashionMNIST(myMNIST):
    mirrors = [
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
    ]

    resources = [
        ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"),
        ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"),
        ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310")
    ]
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def get_int(b: bytes) -> int:
    return int(codecs.encode(b, 'hex'), 16)

    
SN3_PASCALVINCENT_TYPEMAP = {
    8: (torch.uint8, np.uint8, np.uint8),
    9: (torch.int8, np.int8, np.int8),
    11: (torch.int16, np.dtype('>i2'), 'i2'),
    12: (torch.int32, np.dtype('>i4'), 'i4'),
    13: (torch.float32, np.dtype('>f4'), 'f4'),
    14: (torch.float64, np.dtype('>f8'), 'f8')
}

    
def read_image_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 3)
    return x


def read_label_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 1)
    return x.long()


def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor:
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
       Argument may be a filename, compressed filename, or file object.
    """
    # read
    with open(path, "rb") as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert 1 <= nd <= 3
    assert 8 <= ty <= 14
    m = SN3_PASCALVINCENT_TYPEMAP[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)

In [13]:
import torchvision.transforms as transforms

batch_size = 128

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((64, 64))])

# mymnist_trainset = myMNIST(root='./data1', train=True, download=True)
# mymnist_trainloader = torch.utils.data.DataLoader(mymnist_trainset, batch_size=batch_size, shuffle=True, num_workers=2 ,pin_memory=True)

myfmnist_trainset = myFashionMNIST(root='./data1', train=True, download=True, transform=transform)
myfmnist_trainloader = torch.utils.data.DataLoader(myfmnist_trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

In [14]:
# print(len(mymnist_trainset.data))
print(len(myfmnist_trainset.data))

32961


In [15]:
# num_of_class_MNIST = [0,0,0,0,0,0,0,0,0,0]
num_of_class_fMNIST = [0,0,0,0,0,0,0,0,0,0]

# for i in mymnist_trainset.targets:
#     num_of_class_MNIST[i] += 1
for i in myfmnist_trainset.targets:
    num_of_class_fMNIST[i] += 1
    
# print(num_of_class_MNIST)
print(num_of_class_fMNIST)

[6000, 6000, 618, 6000, 6000, 592, 6000, 596, 534, 621]


In [16]:
# mymnist_testset = torchvision.datasets.MNIST(root='./data1', train=False, download=True)
# mymnist_testloader = torch.utils.data.DataLoader(mymnist_testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

myfmnist_testset = torchvision.datasets.FashionMNIST(root='./data1', train=False, download=True, transform=transform)
myfmnist_testloader = torch.utils.data.DataLoader(myfmnist_testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [17]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 5)
        self.conv2 = nn.Conv2d(64, 64, 5)
        
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.conv4 = nn.Conv2d(128, 128, 3)
        
        self.conv5 = nn.Conv2d(128, 128, 3)
        self.conv6 = nn.Conv2d(128, 128, 3)
        
        self.pool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(128 * 4 * 4, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 128)
        self.fc5 = nn.Linear(128, 10)
        self.dropout = nn.Dropout(0.5)


    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.pool(x)
        x = torch.flatten(x, 1) # 배치를 제외한 모든 차원을 평탄화(flatten)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x


net = Net()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

Net(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (conv6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=2048, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=512, bias=True)
  (fc4): Linear(in_features=512, out_features=128, bias=True)
  (fc5): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [18]:
import torch.optim as optim

# criterion = nn.CrossEntropyLoss()
def FL(predict, label, alpha=.25, gamma=2.0):
    predict_sm = torch.softmax(predict, dim=1)
    label = F.one_hot(label, num_classes=len(predict[0]))
    factor = alpha * (1-predict_sm)**gamma
    return torch.mean((factor * (label * -torch.log(predict_sm))).sum(dim=1))
# def CEE(predict, label):
#     delta = 1e-7
#     predict_sm = torch.softmax(predict, dim=1)
#     label = F.one_hot(label, num_classes=len(predict[0]))
#     return torch.mean((label * -torch.log(predict_sm + delta)).sum(dim=1))
optimizer = optim.Adam(net.parameters(), lr=0.0001)

In [19]:
train_loss_history, train_acc_history = [], []
# valid_loss_history, valid_acc_history = [], []

In [20]:
epochs = 30

for epoch in range(epochs):   # 데이터셋을 수차례 반복합니다.
    train_loss = 0.0
    train_acc = 0.0
    valid_loss = 0.0
    valid_acc = 0.0
    
    train_samples = 0
    valid_samples = 0
    
    for inputs, labels in myfmnist_trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # 변화도(Gradient) 매개변수를 0으로 만들고
        optimizer.zero_grad()

        # 순전파 + 역전파 + 최적화를 한 후
        outputs = net(inputs)
        loss = FL(outputs, labels)
        loss.backward()
        optimizer.step()
#         scheduler.step()

        _, preds = torch.max(outputs, 1)
        train_loss += loss.item()
        train_acc += torch.sum(preds == labels.data)
        train_samples += len(inputs)
    
#     else:
#         # 훈련팔 필요가 없으므로 메모리 절약
#         with torch.no_grad():
#             for valid_input, valid_label in valid_loader:
#                 valid_input, valid_label = valid_input.to(device), valid_label.to(device)
#                 valid_outputs = net(valid_input)
#                 valid_loss = criterion(valid_outputs, valid_label)

#                 _, valid_preds = torch.max(valid_outputs, 1)
#                 valid_loss += valid_loss.item()
#                 valid_acc += torch.sum(valid_preds == valid_label.data)
#                 valid_samples += len(valid_input)
                
    epoch_loss = train_loss / len(myfmnist_trainloader)
    epoch_acc = train_acc.float() / train_samples * 100
    train_loss_history.append(epoch_loss)
    train_acc_history.append(epoch_acc)

#     valid_epoch_loss = valid_loss * 10 / len(valid_loader)
#     valid_epoch_acc = valid_acc.float() / valid_samples * 100
#     valid_loss_history.append(valid_epoch_loss)
#     valid_acc_history.append(valid_epoch_acc)

#     if (epoch + 1) % 5 == 0:
    print(f"epoch: {epoch + 1} || tl: {epoch_loss:.3f} | ta: {epoch_acc:.3f}")

print('Finished Training')

epoch: 1 || tl: 0.249 | ta: 39.434
epoch: 2 || tl: 0.090 | ta: 72.637
epoch: 3 || tl: 0.074 | ta: 75.526
epoch: 4 || tl: 0.066 | ta: 77.583
epoch: 5 || tl: 0.061 | ta: 78.948
epoch: 6 || tl: 0.057 | ta: 80.295
epoch: 7 || tl: 0.052 | ta: 81.454
epoch: 8 || tl: 0.048 | ta: 82.616
epoch: 9 || tl: 0.045 | ta: 83.650
epoch: 10 || tl: 0.042 | ta: 84.694
epoch: 11 || tl: 0.040 | ta: 85.161
epoch: 12 || tl: 0.037 | ta: 86.208
epoch: 13 || tl: 0.035 | ta: 86.593
epoch: 14 || tl: 0.034 | ta: 87.000
epoch: 15 || tl: 0.032 | ta: 87.461
epoch: 16 || tl: 0.030 | ta: 88.074
epoch: 17 || tl: 0.029 | ta: 88.495
epoch: 18 || tl: 0.027 | ta: 88.917
epoch: 19 || tl: 0.025 | ta: 89.303
epoch: 20 || tl: 0.024 | ta: 89.733
epoch: 21 || tl: 0.023 | ta: 89.982
epoch: 22 || tl: 0.021 | ta: 90.613
epoch: 23 || tl: 0.020 | ta: 90.880
epoch: 24 || tl: 0.019 | ta: 91.193
epoch: 25 || tl: 0.018 | ta: 91.630
epoch: 26 || tl: 0.017 | ta: 92.066
epoch: 27 || tl: 0.016 | ta: 92.327
epoch: 28 || tl: 0.015 | ta: 92.840
e

In [21]:
correct = 0
total = 0
# 학습 중이 아니므로, 출력에 대한 변화도를 계산할 필요가 없습니다
with torch.no_grad():
    for images, labels in myfmnist_testloader:
        images, labels = images.to(device), labels.to(device)
        
        # 신경망에 이미지를 통과시켜 출력을 계산합니다
        outputs = net(images)
        # 가장 높은 값(energy)를 갖는 분류(class)를 정답으로 선택하겠습니다
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

Accuracy of the network on the 10000 test images: 88 %


In [22]:
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# 변화도는 여전히 필요하지 않습니다
with torch.no_grad():
    for images, labels in myfmnist_testloader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # 각 분류별로 올바른 예측 수를 모읍니다
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# 각 분류별 정확도(accuracy)를 출력합니다
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(accuracy, end=' ')
#     print("Accuracy for class {:5s} is: {:.1f} %".format(classname, accuracy))

89.8 97.8 69.1 92.5 89.9 94.4 69.2 93.6 92.4 90.3 

In [12]:
# num_of_class_MNIST = [0,0,0,0,0,0,0,0,0,0]
num_of_class_fMNIST = [0,0,0,0,0,0,0,0,0,0]

# for i in mymnist_trainset.targets:
#     num_of_class_MNIST[i] += 1
for i in myfmnist_trainset.targets:
    num_of_class_fMNIST[i] += 1
    
# print(num_of_class_MNIST)
print(num_of_class_fMNIST)

[577, 627, 6000, 606, 6000, 589, 6000, 6000, 6000, 579]


In [None]:
# import matplotlib.pyplot as plt

# plt.figure(figsize=(14,5))
# plt.subplot(1, 2, 1)  
# plt.title("Training and Validation Loss")
# plt.plot(valid_loss_history,label="val")
# plt.plot(train_loss_history,label="train")
# plt.xlabel("Epoch")
# plt.ylabel("Loss")
# plt.legend()

# plt.subplot(1, 2, 2) 
# plt.title("Training and Validation Acc")
# plt.plot(valid_acc_history,label="val")
# plt.plot(train_acc_history,label="train")
# plt.xlabel("Epoch")
# plt.ylabel("Acc")
# plt.legend()
# plt.show()

In [None]:
PATH = './myfmnist_net.pth'
torch.save(net.state_dict(), PATH)

# net = Net()
# net.load_state_dict(torch.load(PATH))
# net.to(device)