In [4]:
import os
import sys
import shutil
import glob

import math
import numpy as np

import torch
import torch.utils.data as torch_data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

In [25]:
class LSTMDataLoader(torch_data.Dataset):
    def __init__(self, data_list ):
        self.data_list = data_list
        self.image_list, self.label_list = [], []
        self.read_lists()
        
    def read_lists(self):
        for item in self.data_list:
            datum = np.load(item).item()
            self.image_list.append(datum["data"])
            label = datum["class"]
            
            if label == "cr":
                self.label_list.append(1)
            else:
                self.label_list.append(0)
                
    def __getitem__(self, index):
        return tuple([self.image_list[index], self.label_list[index]])


    def __len__(self):
        return len(self.image_list)

In [27]:
data_list = glob.glob('../../../data/multilabel/all/*.npy')
train_dataloader = LSTMDataLoader(data_list[:100])

In [28]:
class ConvLSTMCell(nn.Module):
    def __init__(
            self,
            input_size, input_channel, hidden_channel,
            kernel_size, stride=1, padding=0
    ):
        """
        Initializations
        :param input_size: (int, int): height, width tuple of the input
        :param input_channel: int: number of channels of the input
        :param hidden_channel: int: number of channels of the hidden state
        :param kernel_size: int: size of the filter
        :param stride: int: stride
        :param padding: int: width of the 0 padding
        """

        super(ConvLSTMCell, self).__init__()
        self.n_h, self.n_w = input_size
        self.n_c = input_channel
        self.hidden_channel = hidden_channel

        self.conv_xi = nn.Conv2d(
            self.n_c,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_xf = nn.Conv2d(
            self.n_c,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_xo = nn.Conv2d(
            self.n_c,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_xg = nn.Conv2d(
            self.n_c,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_hi = nn.Conv2d(
            self.hidden_channel,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_hf = nn.Conv2d(
            self.hidden_channel,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_ho = nn.Conv2d(
            self.hidden_channel,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_hg = nn.Conv2d(
            self.hidden_channel,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        self.conv_hi = nn.Conv2d(
            self.hidden_channel,
            self.hidden_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, hidden_states):
        """
        Forward prop.
        reference: https://arxiv.org/pdf/1506.04214.pdf (3.1)
        :param x: input tensor of shape (n_batch, n_c, n_h, n_w)
        :param hidden_states: (tensor, tensor) for hidden and cell states.
                              Each of shape (n_batch, n_hc, n_hh, n_hw)
        :return: (hidden_state, cell_state)
        """

        hidden_state, cell_state = hidden_states

        xi = self.conv_xi(x)
        hi = self.conv_hi(hidden_state)
        xf = self.conv_xf(x)
        hf = self.conv_hf(hidden_state)
        xo = self.conv_xo(x)
        ho = self.conv_ho(hidden_state)
        xg = self.conv_xg(x)
        hg = self.conv_hg(hidden_state)

        i = torch.sigmoid(xi + hi)
        f = torch.sigmoid(xf + hf)
        o = torch.sigmoid(xo + ho)
        g = torch.tanh(xg + hg)

        cell_state = f * cell_state + i * g
        hidden_state = o * torch.tanh(cell_state)

        return hidden_state, cell_state

    def init_hidden(self, batch_size):
        return (
                torch.zeros(batch_size, self.hidden_channel, self.n_h, self.n_w).cuda(),
                torch.zeros(batch_size, self.hidden_channel, self.n_h, self.n_w).cuda()
               )

In [12]:
def train(model, dataset_loader, epoch, device, optimizer, criterion):
    model.train()
    running_loss = 0.0
    
    for i , (data, target) in enumerate(dataset_loader):
        inputs, labels = data, target
        inputs = inputs.float()
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = F.nll_loss(outputs, labels)

        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        """
        print('Train: [Epoch: {}/{}, Batch: {} ({:.0f}%)]'
              .format(
                  epoch,
                  NUM_EPOCHS,
                  i + 1,
                  i*100/len(train_loader)
              ), end='\r')
        """
    return running_loss


def test(model, dataset_loader, device, criterion):
    model.eval()
    correct = 0
    total = 0
    valid_loss = 0

    with torch.no_grad():
        for i, (data, target) in enumerate(dataset_loader):

            images, labels = data, target
            images = images.float()
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            valid_loss += F.nll_loss(outputs, labels).item()

            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(labels.view_as(pred)).sum().item()

    accuracy = 100 * correct / len(test_loader.dataset)
    return valid_loss, accuracy

In [13]:
if __name__ == '__main__':
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LeNet()
    
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, amsgrad=False)
    
    epochs = 30
    curr_epoch = 0
    
    model.to('cuda')
    criterion.to('cuda')
    
    loss_history = []
    acc_history = []
    
    best_accuracy = 0.0
    best_model = -1
    
    train_loader = torch_data.DataLoader(
            BinaryWMLoader(
                data_dir='../data/binary_wl/',split='train'
            ), 
            batch_size=1, shuffle=True, num_workers=1
    )

    val_loader = torch_data.DataLoader(
            BinaryWMLoader(
                data_dir='../data/binary_wl/', split='train'
            ), 
        batch_size=1, shuffle=False, num_workers=1
    )
    
    is_best = True
    best_score = 0
    best_epoch = 0
    
    print("Epoch\tTrain Loss\tValidation Loss\tValidation Acc")
    while curr_epoch <= epochs:
        running_loss = train(
            model, train_loader, 
            curr_epoch, device, optimizer, criterion
        )
        valid_loss, accuracy = test(
            model, test_loader, 
            device, criterion
        )
        # record all the models that we have had so far.
        loss_history.append((
            running_loss, valid_loss
        ))
        
        acc_history.append(accuracy)
        # write model to disk.

        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        
        'models/model-lenet-epoch'
        torch.save(
            state, 
            MODEL_PATH_PREFIX + '-{}.'.format(epoch) + MODEL_PATH_EXT
        )

        print('{}\t{:.5f}\t{:.5f}\t{:.3f}\t'.format(
            curr_epoch,
            running_loss,
            valid_loss,
            accuracy
        ))
        curr_epoch+=1


Epoch	Train Loss	Validation Loss	Validation Acc


BrokenPipeError: [Errno 32] Broken pipe