In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import os
import time, datetime

import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
from torchvision import datasets, transforms

class ConvNet(nn.Module):
    """Small VGG-style CNN classifier.

    Four stages of (conv -> conv + batch-norm -> 2x2 max-pool), each with 64
    channels and ReLU activations, followed by a 128-unit hidden layer with
    dropout (p=0.5) and a linear layer that outputs ``n_cls`` logits.

    Args:
        n_ch: Number of input channels.
        n_cls: Number of output classes (size of the logit vector).
        img_size: Side length of the square input images. Each of the four
            2x2 max-pools halves the spatial size (rounding down), so the
            flattened feature size is ``(img_size // 16) ** 2 * 64``.
            Defaults to 40, which gives a 2x2x64 feature map.

    Raises:
        ValueError: if ``img_size`` is below 16 (the feature map would be empty).
    """

    def __init__(self, n_ch, n_cls, img_size = 40):
        super().__init__()
        if img_size < 16:
            raise ValueError('img_size must be at least 16, got {}'.format(img_size))
        # Spatial side after four 2x2 max-pools. Floor-halving four times equals
        # img_size // 16; 40x40 inputs give 2x2 maps.
        side = img_size // 16
        self.flat_dim = side * side * 64

        # Stage 1: n_ch -> 64 feature maps; 3x3 kernels, stride 1, padding 1
        # keeps the spatial size until the max-pool.
        self.conv1_1 = nn.Conv2d(n_ch, 64, 3, 1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, padding=1)
        self.conv1_bn = nn.BatchNorm2d(64)
        self.maxp1 = nn.MaxPool2d(2, 2)

        # Stage 2
        self.conv2_1 = nn.Conv2d(64, 64, 3, 1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, 3, 1, padding=1)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.maxp2 = nn.MaxPool2d(2, 2)

        # Stage 3
        self.conv3_1 = nn.Conv2d(64, 64, 3, 1, padding=1)
        self.conv3_2 = nn.Conv2d(64, 64, 3, 1, padding=1)
        self.conv3_bn = nn.BatchNorm2d(64)
        self.maxp3 = nn.MaxPool2d(2, 2)

        # Stage 4
        self.conv4_1 = nn.Conv2d(64, 64, 3, 1, padding=1)
        self.conv4_2 = nn.Conv2d(64, 64, 3, 1, padding=1)
        self.conv4_bn = nn.BatchNorm2d(64)
        self.maxp4 = nn.MaxPool2d(2, 2)

        # Classifier head
        self.dense1 = nn.Linear(self.flat_dim, 128)
        self.dropout1 = nn.Dropout(0.5)
        self.dense2 = nn.Linear(128, n_cls)

    def forward(self, x):
        """Map a (batch, n_ch, img_size, img_size) tensor to (batch, n_cls) logits."""
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_bn(self.conv1_2(x)))
        x = self.maxp1(x)
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_bn(self.conv2_2(x)))
        x = self.maxp2(x)
        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_bn(self.conv3_2(x)))
        x = self.maxp3(x)
        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_bn(self.conv4_2(x)))
        x = self.maxp4(x)
        # Flatten per sample while keeping the batch dimension. An input whose
        # size does not match img_size now fails in dense1, instead of being
        # silently folded into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.dense1(x))
        x = self.dropout1(x)
        return self.dense2(x)

class SemblexTraining:
    """Train a ``ConvNet`` classifier from per-class image arrays.

    ``Xs`` is a sequence with one array per class. The position of an array
    in ``Xs`` is used as that class's integer label. Each array is expected
    in channel-last layout ``(sample #, height, width, channel)`` and is
    transposed to channel-first for PyTorch. ``Run`` splits each class 90/10
    into training and validation data. Every minibatch draws ``batch_size``
    samples from *each* class, so batches are class-balanced.

    Typical use::

        trainer = SemblexTraining(Xs, lr=1e-5, n_iter=5000, model_name='net',
                                  GPU_idx=0, n_batch=100)
        trainer.Run('./model/', './results/npy/')
    """

    def __init__(self, Xs, lr, n_iter, model_name, GPU_idx, n_batch = 100):
        """
        Args:
            Xs: one channel-last image array per class (see class docstring).
            lr: Adam learning rate.
            n_iter: number of training iterations.
            model_name: prefix for checkpoint and history file names.
            GPU_idx: CUDA device index used when CUDA is available.
            n_batch: samples drawn per class for each minibatch.
        """
        self.Xs = Xs
        self.lr = lr
        self.n_batch = n_batch
        # Run() reads the per-class batch size under this name.
        self.batch_size = n_batch
        self.n_iter = n_iter
        self.model_name = model_name
        self.GPU_idx = GPU_idx

    def Reshape4torch(self, img):
        """
        (sample #, height, width, channel) -> (sample #, channel, height, width)
        """
        img = np.transpose(img, (0, 3, 1, 2))
        return img

    def GenerateLabel(self, data, cls):
        """Return a float array filled with ``cls``, one entry per sample in ``data``."""
        label = cls*np.ones([data.shape[0]])
        return label

    def Split2TV(self, data, label, rate_t_v = 0.9):
        """Randomly split ``(data, label)`` into train/validation parts.

        ``int(rate_t_v * len(data))`` samples go to training; the rest go to
        validation.

        Returns:
            ``(train_data, train_label, valid_data, valid_label)``
        """
        data_num = len(data)
        train_idx = np.random.choice(data_num, int(rate_t_v*data_num), replace = False)
        valid_idx = np.setdiff1d(np.arange(data_num), train_idx)
        return data[train_idx], label[train_idx], data[valid_idx], label[valid_idx]

    def Get_device(self, GPU_idx = 3):
        """Select ``cuda:<GPU_idx>`` if CUDA is available, else the CPU.

        Stores the choice in ``self.device`` and prints the device name. When
        CUDA is used, the selected GPU is also made the current CUDA device.
        This keeps helpers that use the current device (e.g. torchsummary)
        consistent with ``self.device``.

        Returns:
            The selected ``torch.device``.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:{}".format(GPU_idx) if use_cuda else "cpu")
        if use_cuda:
            torch.cuda.set_device(self.device)
            print("Device:", torch.cuda.get_device_name(self.device))
        else:
            print("Device: CPU")
        return self.device

    def Define_model_opt(self, n_ch, n_cls, lr = 0.00001):
        """Create the model, loss and Adam optimizer, then print a model summary.

        The model is placed on ``self.device``. ``Get_device`` is called first
        if no device has been selected yet. The summary assumes 40x40 inputs.
        """
        if not hasattr(self, 'device'):
            self.Get_device(self.GPU_idx)
        self.model = ConvNet(n_ch, n_cls).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = lr)
        # torchsummary expects the device as the string 'cuda' or 'cpu',
        # not a torch.device object.
        summary(self.model, (n_ch, 40, 40), device = self.device.type)

    def RandomMinibatch(self, data, label, n_batch = 100):
        """Draw up to ``n_batch`` distinct random samples together with their labels.

        If ``data`` holds fewer than ``n_batch`` samples, all of them are
        returned in random order instead of raising.
        """
        n_take = min(n_batch, len(data))
        batch_idx = np.random.choice(len(data), n_take, replace = False)
        return data[batch_idx], label[batch_idx]

    @staticmethod
    def Shuffle(x1, x2 = None):
        """
        random shuffle of two paired data -> x, y = Shuffle(x, y)
        a single array can be shuffled too -> x = Shuffle(x, None)
        """
        idx = np.arange(len(x1))
        np.random.shuffle(idx)
        if x2 is None:
            return x1[idx]
        return x1[idx], x2[idx]

    def Torch_Minibatch_Load(self, Xs, Ys, batch_size = 100, shuffle = False):
        """Build a class-balanced minibatch as tensors on ``self.device``.

        Draws ``batch_size`` random samples from each (data, label) pair in
        ``Xs``/``Ys`` (one pair per class when called from training) and
        concatenates them. If ``shuffle`` is true, the result is shuffled so
        the classes are interleaved.

        Returns:
            ``(x, y)``: a float32 data tensor and an int64 label tensor.
        """
        x, y = [], []
        for X, Y in zip(Xs, Ys):
            x_i, y_i = self.RandomMinibatch(X, Y, batch_size)
            x.append(x_i)
            y.append(y_i)
        x, y = np.concatenate(x), np.concatenate(y)
        if shuffle:
            x, y = self.Shuffle(x, y)
        x = torch.as_tensor(x, dtype = torch.float32, device = self.device)
        y = torch.as_tensor(y, dtype = torch.long, device = self.device)
        return x, y

    def Training_Process(self, n_iter, batch_size, model_name, save_path = './model/'):
        """Train ``self.model`` for ``n_iter`` iterations with periodic validation.

        Every 10 iterations, the following are appended to ``self.loss_hist``,
        ``self.accr_hist``, ``self.val_loss_hist`` and ``self.val_accr_hist``
        and printed:

        - the loss and accuracy of the current training batch;
        - the loss and accuracy on a fresh validation batch.

        From iteration 100 on, whenever the latest validation loss is the
        lowest recorded so far, the model's ``state_dict`` is saved under
        ``save_path``.

        Requires:
            - ``self.model``, ``self.criterion`` and ``self.optimizer``
              (see ``Define_model_opt``);
            - ``self.device``;
            - the per-class lists ``self.train_Xs``/``self.train_Ys``/
              ``self.valid_Xs``/``self.valid_Ys`` (see ``Run``).
        """
        self.loss_hist, self.accr_hist = [], []
        self.val_loss_hist, self.val_accr_hist = [], []

        self.model.train()
        for iter_i in range(1, n_iter + 1):
            train_x, train_y = self.Torch_Minibatch_Load(self.train_Xs, self.train_Ys, batch_size, shuffle = True)

            output = self.model(train_x)
            loss = self.criterion(output, train_y)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if iter_i % 10 != 0:
                continue

            _, pred = torch.max(output, 1)
            self.loss_hist.append(loss.item())
            self.accr_hist.append(torch.sum(pred == train_y).item() / len(train_y))

            # Validate in eval mode: dropout is off, and batch-norm uses (but
            # does not update) its running statistics.
            self.model.eval()
            with torch.no_grad():
                valid_x, valid_y = self.Torch_Minibatch_Load(self.valid_Xs, self.valid_Ys, batch_size)
                valid_output = self.model(valid_x)
                valid_loss = self.criterion(valid_output, valid_y)
                _, valid_pred = torch.max(valid_output, 1)
            self.model.train()

            self.val_loss_hist.append(valid_loss.item())
            self.val_accr_hist.append(torch.sum(valid_pred == valid_y).item() / len(valid_y))

            print("{:05d} | train_loss: {:.5f}, train_accr: {:.3f} | val_loss: {:.5f}, val_accr: {:.3f}".format(iter_i,
                                                                                                                self.loss_hist[-1],
                                                                                                                self.accr_hist[-1],
                                                                                                                self.val_loss_hist[-1],
                                                                                                                self.val_accr_hist[-1]))

            # Checkpoint whenever validation loss reaches a new minimum.
            if iter_i >= 100 and self.val_loss_hist[-1] == np.min(self.val_loss_hist):
                nowDatetime = datetime.datetime.now().strftime('%y%m%d%H%M')
                model_full_name = '{}_AB_{}_{:05d}_loss_{:.6f}_val_loss_{:.6f}.pt'.format(model_name, nowDatetime, iter_i,
                                                                                     np.mean(self.loss_hist[-3:]),
                                                                                     np.mean(self.val_loss_hist[-3:]))
                os.makedirs(save_path, exist_ok = True)
                torch.save(self.model.state_dict(), os.path.join(save_path, model_full_name))

    def PlotHIST(self, model_name, save_path = './results/npy/'):
        """Plot the training histories and save them as ``.npy`` files.

        The top panel shows training and validation accuracy; the bottom panel
        shows training and validation loss. Both are plotted against
        iteration number (one point every 10 iterations). The four history
        lists are saved in ``save_path`` as
        ``<model_name>_{accr,val_accr,loss,val_loss}_hist.npy``.
        """
        plt.figure(figsize = (15,30))
        plt.suptitle('Training hist', y = 0.92, fontsize = 20)

        # Histories are recorded every 10 iterations, starting at iteration 10.
        x_axis = np.arange(10, 10*len(self.accr_hist)+1, 10)

        plt.subplot(2, 1, 1)
        plt.plot(x_axis, self.accr_hist, 'b-', label = 'Training Accuracy')
        plt.plot(x_axis, self.val_accr_hist, 'r-', label = 'Validation Accuracy')
        plt.xlabel('Iteration', fontsize = 15)
        plt.ylabel('Accuracy', fontsize = 15)
        plt.legend(fontsize = 10)
        plt.grid(True)
        plt.subplot(2, 1, 2)
        plt.plot(x_axis, self.loss_hist, 'b-', label = 'Training Loss')
        plt.plot(x_axis, self.val_loss_hist, 'r-', label = 'Validation Loss')
        plt.xlabel('Iteration', fontsize = 15)
        plt.ylabel('Loss', fontsize = 15)
        plt.legend(fontsize = 12)
        plt.grid(True)
        plt.show()

        os.makedirs(save_path, exist_ok = True)
        hists = {
            '{}_accr_hist': self.accr_hist,
            '{}_val_accr_hist': self.val_accr_hist,
            '{}_loss_hist': self.loss_hist,
            '{}_val_loss_hist': self.val_loss_hist,
        }
        for pattern, hist in hists.items():
            np.save(os.path.join(save_path, pattern.format(model_name)), hist)

    def Run(self, model_dir = './model/', hist_dir = './results/npy/'):
        """Run the whole pipeline: split the data, build the model, train, then plot.

        Args:
            model_dir: directory for model checkpoints.
            hist_dir: directory for the saved ``.npy`` training histories.

        Raises:
            ValueError: if ``self.Xs`` is empty.
        """
        if len(self.Xs) == 0:
            raise ValueError('Xs must contain at least one class array')
        self.n_cls = len(self.Xs)

        train_Xs, train_Ys, valid_Xs, valid_Ys = [], [], [], []
        for cls, X in enumerate(self.Xs):
            X = self.Reshape4torch(X)
            Y = self.GenerateLabel(X, cls)
            train_X, train_Y, valid_X, valid_Y = self.Split2TV(X, Y, rate_t_v = 0.9)
            train_Xs.append(train_X)
            train_Ys.append(train_Y)
            valid_Xs.append(valid_X)
            valid_Ys.append(valid_Y)
        # Keep one array per class: Torch_Minibatch_Load samples batch_size
        # examples from each, which keeps every minibatch class-balanced.
        self.train_Xs, self.train_Ys = train_Xs, train_Ys
        self.valid_Xs, self.valid_Ys = valid_Xs, valid_Ys
        # Channel count of the channel-first data (axis 1 after Reshape4torch).
        self.n_ch = train_Xs[0].shape[1]

        self.Get_device(self.GPU_idx)
        self.Define_model_opt(self.n_ch, self.n_cls, self.lr)
        self.Training_Process(self.n_iter, self.batch_size, self.model_name, save_path = model_dir)
        self.PlotHIST(self.model_name, save_path = hist_dir)