### SOM (Self-Organizing Map)

In [1]:
import os
import time
import torch
import argparse
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
from som import SOM

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [3]:
parser = argparse.ArgumentParser(description='Self Organizing Map')
parser.add_argument('--color', dest='dataset', action='store_const',
                    const='color', default=None,
                    help='use color')
parser.add_argument('--mnist', dest='dataset', action='store_const',
                    const='mnist', default=None,
                    help='use mnist dataset')
parser.add_argument('--fashion_mnist', dest='dataset', action='store_const',
                    const='fashion_mnist', default=None,
                    help='use mnist dataset')
parser.add_argument('--train', action='store_const',
                    const=True, default=False,
                    help='train network')
parser.add_argument('--dataset', type=str, default='mnist', help='dataset name')
parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
parser.add_argument('--lr', type=float, default=0.3, help='input learning rate')
parser.add_argument('--epoch', type=int, default=100, help='input total epoch')
parser.add_argument('--data_dir', type=str, default='./datasets', help='set a data directory')
parser.add_argument('--res_dir', type=str, default='./results', help='set a result directory')
parser.add_argument('--model_dir', type=str, default='./model', help='set a model directory')
parser.add_argument('--row', type=int, default=20, help='set SOM row length')
parser.add_argument('--col', type=int, default=20, help='set SOM col length')
args = parser.parse_args([])

In [4]:
import torch
import torch.nn as nn
from torchvision.utils import save_image

class SOM(nn.Module):
    def __init__(self, input_size, out_size=(10, 10), lr=0.3, sigma=None):
        '''
        parameter들의 input size, output size, learning rate, sigma 이용
        
        '''
        super(SOM, self).__init__()
        self.input_size = input_size
        self.out_size = out_size

        self.lr = lr
        if sigma is None:
            self.sigma = max(out_size) / 2
        else:
            self.sigma = float(sigma)

        self.weight = nn.Parameter(torch.randn(input_size, out_size[0] * out_size[1]), requires_grad=False)
        self.locations = nn.Parameter(torch.Tensor(list(self.get_map_index())), requires_grad=False)
        self.pdist_fn = nn.PairwiseDistance(p=2)

    def get_map_index(self):
        '''
        2차원 매핑 함수의 이용
        '''
        for x in range(self.out_size[0]):
            for y in range(self.out_size[1]):
                yield (x, y)

    def _neighborhood_fn(self, input, current_sigma):
        '''
        e^(-(input / sigma^2))
        '''
        input.div_(current_sigma ** 2)
        input.neg_()
        input.exp_()

        return input

    def forward(self, input):
        '''
        best matching unit(bmu) 위치 탐색
        : parameter 입력값 : 데이터로 지정
        :return: location of best matching unit, loss
        '''
        batch_size = input.size()[0]
        input = input.view(batch_size, -1, 1)
        batch_weight = self.weight.expand(batch_size, -1, -1)

        dists = self.pdist_fn(input, batch_weight)
        # bmu 탐색
        losses, bmu_indexes = dists.min(dim=1, keepdim=True)
        bmu_locations = self.locations[bmu_indexes]

        return bmu_locations, losses.sum().div_(batch_size).item()

    def self_organizing(self, input, current_iter, max_iter):
        '''
        Self Oranizing Map(SOM)을 이용하여 학습 진행
        :param input: 학습 데이터
        :param current_iter: 전체 epoch 중 현재 epoch
        :param max_iter: 전체 epoch
        :return: loss (최소 거리 반환)
        '''
        batch_size = input.size()[0]
        # learning rate 설정
        iter_correction = 1.0 - current_iter / max_iter
        lr = self.lr * iter_correction
        sigma = self.sigma * iter_correction

        # best matching unit 탐색
        bmu_locations, loss = self.forward(input)

        # 마할라노비스 거리에 기반하여 계산
        distance_squares = self.locations.float() - bmu_locations.float()
        distance_squares.pow_(2)
        distance_squares = torch.sum(distance_squares, dim=2)
        
        # learning rate에 기반하여 각 노드의 위치 계산
        lr_locations = self._neighborhood_fn(distance_squares, sigma)
        lr_locations.mul_(lr).unsqueeze_(1)

        # 델타 계산
        delta = lr_locations * (input.unsqueeze(2) - self.weight)
        delta = delta.sum(dim=0)
        delta.div_(batch_size)
        self.weight.data.add_(delta)

        return loss

    def save_result(self, dir, im_size=(0, 0, 0)):
        
        '''
        Self Organizing Map(SOM)의 결과 시각화
        : parameter directory : 저장할 경로 지정
        : image size : 채널, x,y 크기 지정
        :return:
        
        '''
        # 이미지의 weight 구하기
        images = self.weight.view(im_size[0], im_size[1], im_size[2], self.out_size[0] * self.out_size[1])

        images = images.permute(3, 0, 1, 2)
        save_image(images, dir, normalize=True, padding=1, nrow=self.out_size[0])

In [5]:
args

Namespace(batch_size=32, col=20, data_dir='./datasets', dataset=None, epoch=100, lr=0.3, model_dir='./model', res_dir='./results', row=20, train=False)

In [6]:
dataset = args.dataset
batch_size = args.batch_size
total_epoch = args.epoch
row = args.row
col = args.col
train = args.train

#### 1. Dataset : MNIST

In [7]:
args.dataset = 'mnist'
# Hyper-parameters
DATA_DIR = args.data_dir + '/' + args.dataset
RES_DIR = args.res_dir + '/' + args.dataset
MODEL_DIR = args.model_dir + '/' + args.dataset

In [8]:
# Create results dir
if not os.path.exists(args.res_dir):
    os.makedirs(args.res_dir)

# Create results/datasetname dir
if not os.path.exists(RES_DIR):
    os.makedirs(RES_DIR)

# Create model dir
if not os.path.exists(args.model_dir):
    os.makedirs(args.model_dir)

# Create model/datasetname dir
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)

In [9]:
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [10]:
# train_data.train_data = train_data.train_data[:5000]
# train_data.train_labels = train_data.train_labels[:5000]

print('Building Model...')
som = SOM(input_size=28 * 28 * 1, out_size=(row, col))
if os.path.exists('%s/som.pth' % MODEL_DIR):
    som.load_state_dict(torch.load('%s/som.pth' % MODEL_DIR))
    print('Model Loaded!')
else:
    print('Create Model!')
som = som.to(device)

Building Model...
Model Loaded!


In [11]:
if train == True:
    losses = list()
    for epoch in range(total_epoch):
        running_loss = 0
        start_time = time.time()
        for idx, (X, Y) in enumerate(train_loader):
            X = X.view(-1, 28 * 28 * 1).to(device)    # flatten
            loss = som.self_organizing(X, epoch, total_epoch)    # train som
            running_loss += loss

        losses.append(running_loss)
        print('epoch = %d, loss = %.2f, time = %.2fs' % (epoch + 1, running_loss, time.time() - start_time))

        if epoch % 5 == 0:
            # model save
            som.save_result('%s/som_epoch_%d.png' % (RES_DIR, epoch), (1, 28, 28))
            torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)

    torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)
    plt.title('SOM loss')
    plt.plot(losses)
    plt.show()

som.save_result('%s/som_result.png' % (RES_DIR), (1, 28, 28))
torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)

In [12]:
from IPython.display import Image
Image(url='./results/mnist/som_animation.gif')  

#### 2. Dataset : Fashion_MNIST

In [13]:
args.dataset = 'fashion_mnist'
# Hyper-parameters
DATA_DIR = args.data_dir + '/' + args.dataset
RES_DIR = args.res_dir + '/' + args.dataset
MODEL_DIR = args.model_dir + '/' + args.dataset

In [14]:
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.FashionMNIST(DATA_DIR, train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [15]:
# from som import SOM

# train data:5000까지 cutting 진행
# model load

# train_data.train_data = train_data.train_data[:5000]
# train_data.train_labels = train_data.train_labels[:5000]

print('Building Model...')
som = SOM(input_size=28 * 28 * 1, out_size=(row, col))
if os.path.exists('%s/som.pth' % MODEL_DIR):
    som.load_state_dict(torch.load('%s/som.pth' % MODEL_DIR))
    print('Model Loaded!')
else:
    print('Create Model!')
som = som.to(device)

Building Model...
Model Loaded!


In [16]:
if train == True:
    losses = list()
    for epoch in range(total_epoch):
        running_loss = 0
        start_time = time.time()
        for idx, (X, Y) in enumerate(train_loader):
            X = X.view(-1, 28 * 28 * 1).to(device)    # flatten
            loss = som.self_organizing(X, epoch, total_epoch)    # train som
            running_loss += loss

        losses.append(running_loss)
        print('epoch = %d, loss = %.2f, time = %.2fs' % (epoch + 1, running_loss, time.time() - start_time))

        if epoch % 5 == 0:
            # model save
            som.save_result('%s/som_epoch_%d.png' % (RES_DIR, epoch), (1, 28, 28))
            torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)

    torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)
    plt.title('SOM loss')
    plt.plot(losses)
    plt.show()

som.save_result('%s/som_result.png' % (RES_DIR), (1, 28, 28))
torch.save(som.state_dict(), '%s/som.pth' % MODEL_DIR)

In [17]:
from IPython.display import Image
Image(url='./results/fashion_mnist/som_animation.gif')  

##### Reference List
- https://ratsgo.github.io/machine%20learning/2017/05/01/SOM/
- https://github.com/FlorisHoogenboom/som-anomaly-detector