In [11]:
import argparse
import os
import sys
import time
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms

import utils
from MobileNetV2 import MobileNetV2
# Print arrays in full instead of abbreviating them with "...".
# NumPy >= 1.14 rejects ``threshold=np.nan`` with a ValueError;
# ``sys.maxsize`` is the documented way to disable truncation.
np.set_printoptions(threshold=sys.maxsize)
# Seed every RNG the script draws from so runs are reproducible:
# NumPy's global RNG drives the random train/validation split in main(),
# torch.manual_seed seeds torch's CPU generator, and manual_seed_all
# seeds every CUDA device.
np.random.seed(50)
torch.manual_seed(50)
torch.cuda.manual_seed_all(50)

class MobileNet(nn.Module):
    """MobileNetV2 initialised from a saved state dict, with a new classifier head.

    The state dict stored at ``weights_path`` is loaded into ``MobileNetV2()``.
    The last layer of ``net.classifier`` is then replaced by a freshly
    initialised ``nn.Linear`` that outputs ``num_classes`` logits.

    Args:
        weights_path: Path of the state dict to load. The default is the
            file the script has always used.
        num_classes: Number of logits produced by the replacement head.
    """

    def __init__(self, weights_path='./mobilenet_v2.pth.tar', num_classes=10):
        super().__init__()

        self.net = MobileNetV2()

        # Load onto the CPU. The freshly built module lives on the CPU, and
        # load_state_dict copies values into its existing parameters. Mapping
        # to "cuda:0" therefore only added a GPU round trip, and it made
        # construction fail on machines without a GPU. Move the finished
        # model with .cuda() afterwards.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.net.load_state_dict(state_dict)

        # Keep the feature width of the layer being replaced rather than
        # hard-coding it. 1280 (the previously hard-coded value) is the
        # fallback if that layer does not expose ``in_features``.
        in_features = getattr(self.net.classifier[-1], 'in_features', 1280)
        self.net.classifier[-1] = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the logits ``self.net`` produces for the input batch ``x``."""
        return self.net(x)
    
class Net(nn.Module):
    """MobileNet-v1-style classifier built from depthwise-separable convolutions.

    The network has three parts:

    * A stem: 3x3 conv -> BatchNorm -> ReLU (``conv_bn``).
    * Thirteen depthwise-separable blocks (``conv_dw``). Each is a 3x3
      depthwise conv plus a 1x1 pointwise conv, both followed by BatchNorm
      and ReLU.
    * Global average pooling to 1x1, then a linear layer.

    Because the pooling is adaptive, any input size whose final feature map
    is at least 1x1 is accepted. That covers 28x28 FashionMNIST images as
    well as 224x224 inputs.

    Args:
        num_classes: Number of output logits (10 by default).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_bn(inp, oup, stride):
            # Full 3x3 convolution -> BatchNorm -> ReLU.
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        def conv_dw(inp, oup, stride):
            # Depthwise 3x3 conv (groups=inp, one filter per channel), then a
            # pointwise 1x1 conv that mixes channels into ``oup`` outputs.
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                nn.BatchNorm2d(inp),
                nn.ReLU(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(3, 32, 2),
            conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            # The six stride-2 layers reduce 224x224 inputs to 7x7 but 28x28
            # inputs to 1x1. The former fixed AvgPool2d(7) therefore only fit
            # 224x224 images. Adaptive pooling collapses any remaining spatial
            # size to 1x1 and has no parameters, so state-dict keys are
            # unchanged.
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the image batch ``x``."""
        x = self.model(x)
        # Flatten per sample. view(-1, 1024) would silently fold any leftover
        # spatial positions into the batch dimension.
        x = x.view(x.size(0), -1)
        return self.fc(x)
        

def _build_loaders(dataset, valid_fraction=0.2, batch_size=32, num_workers=4):
    """Randomly split ``dataset`` into training and validation ``DataLoader``s.

    The indices are shuffled with NumPy's global RNG. The first
    ``valid_fraction`` of that permutation becomes the validation subset and
    the rest the training subset. Each loader visits its subset in random
    order through ``SubsetRandomSampler``.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    num_total = len(dataset)
    permutation = np.random.permutation(num_total)
    split_val = int(num_total * valid_fraction)
    train_idx, valid_idx = permutation[split_val:], permutation[:split_val]

    def make_loader(indices):
        # pin_memory=True yields page-locked host batches. That is what lets
        # the .cuda(non_blocking=True) copies in the train/eval helpers run
        # asynchronously.
        return DataLoader(dataset,
                          batch_size=batch_size,
                          sampler=SubsetRandomSampler(indices),
                          num_workers=num_workers,
                          pin_memory=True)

    return make_loader(train_idx), make_loader(valid_idx)


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass of ``model`` over every batch of ``loader``."""
    model.train()
    for images, targets in loader:
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        loss_value = criterion(model(images), targets)
        optimizer.zero_grad()
        loss_value.backward()
        optimizer.step()


def _evaluate(model, loader, criterion):
    """Return the mean per-batch loss of ``model`` over ``loader``, without gradients.

    Raises:
        ValueError: if ``loader`` yields no batches. An average over zero
            batches is undefined.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for images, targets in loader:
            images = images.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            total_loss += criterion(model(images), targets).item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("validation loader yielded no batches")
    return total_loss / num_batches


def main():
    """Train ``Net`` on FashionMNIST, checkpointing whenever validation loss improves.

    20% of the FashionMNIST training split is held out for validation. After
    each epoch the mean validation loss is computed and used for two things:

    * It drives a ``ReduceLROnPlateau`` scheduler.
    * When it beats the best loss so far, the model and optimizer state are
      saved to ``./MobileNet.pth.tar``.

    Progress is logged through ``utils.buildLogger('./MobileNet.log')``.
    """
    # NOTE(review): these mean/std values look like the commonly used CIFAR-10
    # channel statistics, not FashionMNIST's. Confirm they are intended.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    data_transforms = transforms.Compose([
        # Replicate the single grey channel so the 3-channel network accepts it.
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        normalize,
    ])

    dataset = torchvision.datasets.FashionMNIST(root="/home/kn15263s/workspace/FashionMNIST",
                                                transform=data_transforms)

    trainset_ld, validset_ld = _build_loaders(dataset, valid_fraction=0.2,
                                              batch_size=32, num_workers=4)

    modelname = './{}.pth.tar'.format("MobileNet")
    loggername = modelname.replace("pth.tar", "log")
    # NOTE(review): the captured output shows every log line twice. buildLogger
    # may attach a new handler on each call (e.g. when the notebook cell is
    # re-run). Confirm.
    logger = utils.buildLogger(loggername)

    # ---- hyperparameters ----
    lr = 0.01
    momentum = 0.5
    weight_decay = 1e-4
    factor = 0.1
    num_epochs = 50

    # -------------------- SETTINGS: NETWORK ARCHITECTURE
    # Swap in MobileNet() here to use the pretrained MobileNetV2 backbone instead.
    model = torch.nn.DataParallel(Net()).cuda()

    logger.info("Build Model Done")

    # -------------------- SETTINGS: OPTIMIZER & SCHEDULER --------------------
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(trainable_params,
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay,
                          nesterov=False)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=factor,
                                                     patience=10, mode='min')

    logger.info("Build Optimizer Done")

    # -------------------- SETTINGS: LOSS -------------------------
    criterion = nn.CrossEntropyLoss()

    loss_min = float('inf')

    # ---- TRAIN THE NETWORK
    for epoch_id in range(num_epochs):
        start_time = time.time()

        _train_one_epoch(model, trainset_ld, criterion, optimizer)
        val_loss = _evaluate(model, validset_ld, criterion)

        # ReduceLROnPlateau only needs the monitored metric; the ``epoch``
        # argument is deprecated.
        scheduler.step(val_loss)

        if val_loss < loss_min:
            loss_min = val_loss

            logger.info('Epoch [{}] [save] loss= {:.5f} ---- model: {} ---- time: {:.1f} s'.format(
                epoch_id + 1, val_loss, modelname, time.time() - start_time))

            # model is wrapped in DataParallel, so state_dict keys carry a
            # "module." prefix.
            torch.save({'epoch': epoch_id + 1,
                        'state_dict': model.state_dict(),
                        'best_loss': loss_min,
                        'optimizer': optimizer.state_dict(),
                        }, modelname)
        else:
            logger.info('Epoch [{}] [----] loss= {:.5f} ---- loss_min= {:.5f} ---- time: {:.1f} s'.format(
                epoch_id + 1, val_loss, loss_min, time.time() - start_time))


if __name__ == '__main__':
    main()



2019-03-05 13:28:52,658 - root - INFO - Build Model Done
2019-03-05 13:28:52,658 - root - INFO - Build Model Done
2019-03-05 13:28:52,660 - root - INFO - Build Optimizer Done
2019-03-05 13:28:52,660 - root - INFO - Build Optimizer Done
2019-03-05 13:30:02,668 - root - INFO - Epoch [1] [save] loss= 0.44280 ---- model: ./MobileNet.pth.tar ---- time: 70.0 s
2019-03-05 13:30:02,668 - root - INFO - Epoch [1] [save] loss= 0.44280 ---- model: ./MobileNet.pth.tar ---- time: 70.0 s
2019-03-05 13:31:12,729 - root - INFO - Epoch [2] [save] loss= 0.37773 ---- model: ./MobileNet.pth.tar ---- time: 70.0 s
2019-03-05 13:31:12,729 - root - INFO - Epoch [2] [save] loss= 0.37773 ---- model: ./MobileNet.pth.tar ---- time: 70.0 s
2019-03-05 13:32:22,722 - root - INFO - Epoch [3] [----] loss= 0.38345 ---- loss_min= 0.37773 ---- time: 69.9 s
2019-03-05 13:32:22,722 - root - INFO - Epoch [3] [----] loss= 0.38345 ---- loss_min= 0.37773 ---- time: 69.9 s
2019-03-05 13:33:32,874 - root - INFO - Epoch [4] [save]

KeyboardInterrupt: 