In [None]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from torch import nn
import numpy as np

In [None]:
# Training configuration.
BATCH_SIZE, NUM_EPOCHS = 128, 50  # images per mini-batch, passes over the training set
NUM_CLASSES = 9                   # size of the confusion meter / heatmap axes
IMAGE_SIZE = 100                  # resize + center-crop side length used by `preprocess`

In [None]:
# Image pipeline applied by ImageFolder to every sample:
# resize the shorter edge to IMAGE_SIZE, center-crop to IMAGE_SIZE x IMAGE_SIZE,
# convert to a tensor and normalize with the ImageNet per-channel mean/std
# (3 channels, matching Tnet's 3-channel input).
# `transforms.Scale` was deprecated in favour of `transforms.Resize` (identical
# behaviour for an int size) and is gone from recent torchvision releases; fall
# back to `Scale` only on very old versions that lack `Resize`.
preprocess = transforms.Compose([
        (getattr(transforms, 'Resize', None) or transforms.Scale)(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

In [None]:
class Lenet(nn.Module):
    """LeNet-style CNN for single-channel (grayscale) images.

    Two conv -> ReLU -> 2x2 max-pool stages (1 -> 6 -> 16 channels) feed a
    two-layer linear classifier. The classifier expects ``11*11*16`` features,
    i.e. an 11x11 feature map, which the conv stack produces for square inputs
    of 52-55 pixels (e.g. 54 -> 54 -> 27 -> 23 -> 11). Other input sizes fail
    at the first Linear layer.

    Args:
        num_classes: number of output classes (default 9, the value that was
            previously hard-coded).
    """

    def __init__(self, num_classes=9):
        super(Lenet, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, 3, stride=1, padding=1),   # 3x3, padding 1: keeps H x W
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5, stride=1, padding=0),  # unpadded 5x5: H -> H - 4
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        # NOTE(review): there is no non-linearity between the two Linear layers,
        # so they compose to a single affine map. Inserting one would shift the
        # Sequential indices (fc.1 -> fc.2) and break saved state_dicts, so the
        # layer layout is deliberately left unchanged here.
        self.fc = nn.Sequential(
            nn.Linear(11 * 11 * 16, 120),
            nn.Linear(120, num_classes)
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, num_classes)``."""
        x = self.conv(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 16 * 11 * 11)
        x = self.fc(x)
        # dim=1 made explicit: the implicit-dim form is deprecated and warns;
        # for this 2-D input the implicit choice was also dim 1.
        return F.log_softmax(x, dim=1)

In [None]:
class Tnet(nn.Module):
    """One inception-style block followed by a fully connected classifier.

    Three parallel branches apply 1x1, 3x3 and 5x5 convolutions (all
    size-preserving) to the RGB input, each followed by ReLU and a 2x2 max-pool.
    Their outputs are concatenated along the channel axis (4 + 16 + 8 = 28
    channels) and max-pooled once more. An ``image_size x image_size`` input
    therefore yields ``28 * (image_size // 4) ** 2`` features for the classifier.

    Args:
        num_classes: number of output classes (default 9, as before).
        image_size: side length of the square input images. The default of 100
            reproduces the previously hard-coded ``25*25*28`` classifier input.
    """

    def __init__(self, num_classes=9, image_size=100):
        super(Tnet, self).__init__()

        # 1x1 branch: 3 -> 4 channels, spatial size halved by the pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 4, 1, stride=1, padding=0),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        # 3x3 branch: 3 -> 16 channels, spatial size halved by the pool.
        self.conv3 = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        # 5x5 branch: 3 -> 8 channels, spatial size halved by the pool.
        self.conv5 = nn.Sequential(
            nn.Conv2d(3, 8, 5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )

        # Two floor-halving 2x2 pools (branch pool + post-concat pool).
        pooled = image_size // 4
        # NOTE(review): no non-linearity between the Linear layers, so the head
        # is effectively one affine map. Adding activations would renumber the
        # Sequential (fc.N) and break saved checkpoints, so it is left as is.
        self.fc = nn.Sequential(
            nn.Linear(pooled * pooled * 28, 1000),
            nn.Linear(1000, 120),
            nn.Linear(120, num_classes)
        )

    def inception(self, x):
        """Run the three branches, concatenate on channels, and pool by 2.

        For a ``(N, 3, S, S)`` input the result is ``(N, 28, S//4, S//4)``.
        """
        a = self.conv1(x)
        b = self.conv3(x)
        c = self.conv5(x)
        x = torch.cat((a, b, c), 1)
        # Functional pool: same op as nn.MaxPool2d(2, 2) without building a new
        # module on every forward call.
        return F.max_pool2d(x, 2, 2)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, num_classes)``."""
        x = self.inception(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 28 * (S//4)**2)
        x = self.fc(x)
        # dim=1 made explicit (implicit dim is deprecated; it was 1 for 2-D input).
        return F.log_softmax(x, dim=1)


In [None]:
# if __name__ == '__main__':
from torch.autograd import Variable
from torch.optim import Adam
from torchnet.engine import Engine
from torchnet.logger import VisdomPlotLogger, VisdomLogger
from torchvision.utils import make_grid
from torchvision.datasets.mnist import MNIST
from tqdm import tqdm
from tqdm import tqdm_notebook
import torchnet as tnt
from torch.utils.data import DataLoader
from torchvision import datasets

In [None]:
from PIL import Image
def grayimage_loader(path):
    """Load the image file at ``path`` and return it as a single-channel ('L') PIL image.

    The file is opened explicitly in a ``with`` block, the same pattern
    torchvision's default ``pil_loader`` uses. The handle is therefore closed as
    soon as ``convert`` has read the pixels, instead of being left to PIL or the
    garbage collector.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() forces the pixel data to load while `f` is still open.
        return img.convert('L')

In [None]:
# Build the classifier, move it to the GPU, and report its parameter count.
model = Tnet()
model = model.cuda()
print('# parameters: ', sum(p.numel() for p in model.parameters()))

In [None]:
# Adam over all model parameters with learning rate 1e-3.
optimizer = Adam(params=model.parameters(), lr=0.001)
# NOTE(review): both models already return log_softmax outputs, and
# CrossEntropyLoss applies log_softmax again. That is harmless (log_softmax is
# idempotent), but nn.NLLLoss is the conventional pairing.
criterion = nn.CrossEntropyLoss()

In [None]:
# Meters are reset and refilled by the engine hooks during each train/test pass.
meter_loss = tnt.meter.AverageValueMeter()                  # running mean of the batch loss
meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)   # top-1 accuracy (%)
confusion_meter = tnt.meter.ConfusionMeter(NUM_CLASSES, normalized=True)

# torchnet driver that runs the training/testing loops and fires the hooks.
engine = Engine()


In [None]:
# Visdom plots updated once per epoch by on_end_epoch.
train_loss_logger = VisdomPlotLogger('line', opts=dict(title='Train Loss'))
# NOTE: despite its name, this logger plots training *accuracy*.
train_error_logger = VisdomPlotLogger('line', opts=dict(title='Train Accuracy'))
test_loss_logger = VisdomPlotLogger('line', opts=dict(title='Test Loss'))
test_accuracy_logger = VisdomPlotLogger('line', opts=dict(title='Test Accuracy'))
# Confusion-matrix heatmap; rows and columns are labelled by class index.
confusion_logger = VisdomLogger('heatmap', opts=dict(
    title='Confusion Matrix',
    columnnames=list(range(NUM_CLASSES)),
    rownames=list(range(NUM_CLASSES)),
))
# ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
# reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})


In [None]:
def augmentation(x, max_shift=2):
    """Randomly translate a batch of images by up to ``max_shift`` pixels.

    One vertical and one horizontal offset are drawn uniformly from
    ``[-max_shift, max_shift]`` via ``np.random`` and applied to the whole batch.
    Pixels shifted in from outside the image are zero.

    Args:
        x: 4-D tensor whose last two dims are height and width (the code
            unpacks exactly four dims).
        max_shift: maximum absolute shift, in pixels, along each spatial axis.

    Returns:
        A new float tensor with the same shape and device as ``x``. ``x`` itself
        is not modified.
    """
    _, _, height, width = x.size()
    dh, dw = np.random.randint(-max_shift, max_shift + 1, size=2)

    # Output pixel (i, j) takes input pixel (i - dh, j - dw). These slices are
    # the overlapping window in output (dst) and input (src) coordinates.
    dst_rows = slice(max(0, dh), dh + height)
    dst_cols = slice(max(0, dw), dw + width)
    src_rows = slice(max(0, -dh), -dh + height)
    src_cols = slice(max(0, -dw), -dw + width)

    # zeros_like keeps x's device, so CUDA inputs work. The previous
    # torch.zeros(*x.size()) always allocated on the CPU, so the assignment
    # failed for CUDA x.
    shifted = torch.zeros_like(x)
    shifted[:, :, dst_rows, dst_cols] = x[:, :, src_rows, src_cols]
    return shifted.float()

In [None]:
def get_iterator(mode, root='D:/Anaconda3/Scripts/lwz/category_new_9/', num_workers=4):
    """Return a DataLoader over the training or test split of an ImageFolder tree.

    Images are read with ``datasets.ImageFolder`` from ``<root>/train`` (when
    ``mode`` is truthy) or ``<root>/test`` and transformed with ``preprocess``.
    Only the training loader is shuffled.

    Args:
        mode: truthy for the training split (shuffled), falsy for the test split.
        root: directory containing ``train/`` and ``test/`` sub-folders. The
            default keeps the previously hard-coded path.
        num_workers: DataLoader worker processes (default 4, as before).
    """
    import os  # local import: only needed to build the split path

    split = 'train' if mode else 'test'
    # Build only the split that is requested. The previous version scanned both
    # folders on every call even though one dataset was always discarded.
    dataset = datasets.ImageFolder(os.path.join(root, split), preprocess)
    return DataLoader(dataset,
                      num_workers=num_workers,
                      batch_size=BATCH_SIZE,
                      shuffle=mode)

In [None]:
def processor(sample):
    """Forward one batch through ``model`` and return ``(loss, outputs)`` for the Engine.

    ``sample`` unpacks to ``(data, labels, training)``; the ``training`` flag is
    the one appended by ``on_sample``. Tensors are moved to the GPU, the model is
    switched to train or eval mode to match the flag, and ``criterion`` is
    applied to the model outputs.
    """
    inputs, targets, is_training = sample

    inputs = Variable(inputs.cuda())
    targets = Variable(targets.cuda())

    model.train(bool(is_training))

    outputs = model(inputs)
    return criterion(outputs, targets), outputs
    

In [None]:
def reset_meters():
    """Clear the accuracy, loss and confusion-matrix meters before a new pass."""
    for meter in (meter_accuracy, meter_loss, confusion_meter):
        meter.reset()

def on_sample(state):
    """Append the engine's train/test flag to the current batch (read by ``processor``)."""
    batch = state['sample']
    batch.append(state['train'])
    
def on_forward(state):
    """Accumulate accuracy, confusion-matrix and loss statistics for one batch.

    Reads the model output (``state['output']``), the batch labels
    (``state['sample'][1]``) and the batch loss (``state['loss']``), and adds
    them to the module-level meters.
    """
    loss = state['loss']
    # `.data[0]` raises on 0-dim tensors in PyTorch >= 0.5. Use `.item()` where
    # it exists and fall back to the old indexing only on versions without it.
    loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]

    outputs = state['output'].data
    targets = torch.LongTensor(state['sample'][1])
    meter_accuracy.add(outputs, targets)
    confusion_meter.add(outputs, targets)
    meter_loss.add(loss_value)
    
def on_start_epoch(state):
    """Reset all meters and wrap this epoch's batch iterator in a notebook progress bar."""
    reset_meters()
    wrapped = tqdm_notebook(state['iterator'])
    state['iterator'] = wrapped

def on_end_epoch(state):
    """Report the finished training epoch, then evaluate on the test split.

    The function runs these steps in order:

    - Print and plot (Visdom) the training loss and accuracy collected by the meters.
    - Reset the meters and run ``engine.test`` over the test loader.
    - Plot the test loss, test accuracy and confusion matrix, and print the test summary.
    - Every 10th epoch, save the model weights to ``epochs/epoch_<N>.pt``.
    """
    import os  # local import: only needed to create the checkpoint directory

    epoch = state['epoch']

    # Training metrics (filled by on_forward during the epoch).
    train_loss = meter_loss.value()[0]
    train_acc = meter_accuracy.value()[0]
    print('[Epoch {}] Training Loss: {:.4f} (Acc: {:.2f})'.format(
        epoch, train_loss, train_acc
    ))
    train_loss_logger.log(epoch, train_loss)
    train_error_logger.log(epoch, train_acc)  # this plot is titled 'Train Accuracy'

    # Evaluation: the meters are refilled by on_forward during engine.test.
    reset_meters()
    engine.test(processor, get_iterator(False))
    test_loss = meter_loss.value()[0]
    test_acc = meter_accuracy.value()[0]
    test_loss_logger.log(epoch, test_loss)
    test_accuracy_logger.log(epoch, test_acc)
    confusion_logger.log(confusion_meter.value())

    # '{:.4f}' (4 decimals), matching the training line. This was '{:4f}', a
    # width of 4, which printed 6 decimals.
    print('[Epoch {}] Testing Loss: {:.4f} (Acc: {:.2f})'.format(
        epoch, test_loss, test_acc
    ))

    if epoch % 10 == 0:
        # torch.save does not create missing directories.
        os.makedirs('epochs', exist_ok=True)
        torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % epoch)

    # The former `test_sample = next(iter(get_iterator(False)))` was removed: its
    # only consumer was commented-out reconstruction code, and it rebuilt the
    # datasets and spawned loader workers every epoch for nothing.

In [None]:
# Wire the callbacks into the torchnet engine (hooks is a plain name -> function dict).
engine.hooks.update(
    on_sample=on_sample,
    on_forward=on_forward,
    on_start_epoch=on_start_epoch,
    on_end_epoch=on_end_epoch,
)

# Train for NUM_EPOCHS epochs on the shuffled training loader.
engine.train(processor, get_iterator(True), maxepoch=NUM_EPOCHS, optimizer=optimizer)
