In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.utils.data as data
import torch.backends.cudnn as cudnn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
import os
import torchvision.transforms as transforms

## Dataset

In [2]:
class UCSDADData(data.Dataset):
    def __init__(self, root_dir, seq_len=10, time_stride=1, transform=None):
        super(UCSDADData, self).__init__()
        self.root_dir=root_dir
        vids=[d for d in os.listdir(self.root_dir) if os.path.isdir(os.path.join(self.root_dir, d))]
        self.samples=[]
        #videos number
        for d in vids:
            #max possible time stride used for data augmentation
            for t in range(1, time_stride+1):
                #image index 001 ~ 200
                for i in range(1,200):
                    if i+(seq_len-1)*t>200:
                        break
                    self.samples.append((os.path.join(self.root_dir, d), range(i,i+(seq_len-1)*t+1, t)))
        
        self.pil_transform = transforms.Compose([
            transforms.Resize((227,227)),
            transforms.Grayscale(),
            transforms.ToTensor()])
        self.tensor_transform = transforms.Compose([
            transforms.Normalize(mean=(0.3750352255196134,), std=(0.20129592430286292,))])
        
    def __getitem__(self, index):
        sample = []
        pref = self.samples[index][0]
        for fr in self.samples[index][1]:
            with open(os.path.join(pref, '{0:03d}.tif'.format(fr)), 'rb') as fin:
                frame = Image.open(fin).convert('RGB')
                frame = self.pil_transform(frame)
                frame = self.tensor_transform(frame)
                sample.append(frame)
        sample = torch.stack(sample, dim=0)
        return sample
    def __len__(self):
        return len(self.samples)
        

## Model

In [3]:
class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()
        
        assert hidden_channels%2==0
        
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_features = 4
        
        self.padding = int((kernel_size-1)/2)
        
        self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        
        self.Wci = None
        self.Wcf = None
        self.Wco = None
    
    def forward(self, x, h, c):
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c*self.Wci)
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c*self.Wcf)
        cc = cf*c + ci*torch.tanh(self.Wxc(x) + self.Whc(h))
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc*self.Wco)
        ch = co*torch.tanh(cc)
        return ch, cc
        
    def init_hidden(self, batch_size, hidden, shape, use_cuda=False):
        if self.Wci is None:
            self.Wci = Variable(torch.zeros(1, hidden, shape[0], shape[1]))
            self.Wcf = Variable(torch.zeros(1, hidden, shape[0], shape[1]))
            self.Wco = Variable(torch.zeros(1, hidden, shape[0], shape[1]))
        else:
            assert shape[0]==self.Wci.size()[2], 'Input Height Mismatch!'
            assert shape[1]==self.Wci.size()[3], 'Input Width Mismatch!' 
        if use_cuda:
            self.Wci = self.Wci.cuda()
            self.Wcf = self.Wcf.cuda()
            self.Wco = self.Wco.cuda()
        h=Variable(torch.zeros(batch_size, hidden, shape[0], shape[1]))
        c=Variable(torch.zeros(batch_size, hidden, shape[0], shape[1]))
        if use_cuda:
            h,c = h.cuda(), c.cuda()
        return (h,c)

In [4]:
class ConvLSTM(nn.Module):
    def __init__(self, input_channels,
                hidden_channels,
                kernel_size,
                batch_first=False):
        super(ConvLSTM, self).__init__()
        self.input_channels=[input_channels]+hidden_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.batch_first = batch_first
        self._all_layers = nn.ModuleList()
        
        for i in range(self.num_layers):
            name = 'cell{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i],self.hidden_channels[i],self.kernel_size)
            setattr(self, name, cell)
            self._all_layers.append(cell)

    def forward(self, input):
        if not self.batch_first:
            input = input.permute(1,0,2,3,4) #batch first
        internal_state=[]
        outputs=[]
        n_steps = input.size(1) #seq length
        
        for t in range(n_steps):
            x = input[:,t,:,:,:]
            for i in range(self.num_layers):
                name = 'cell{}'.format(i)
                if t==0:
                    bsize,_,height,width=x.size()
                    (h,c)=getattr(self, name).init_hidden(batch_size=bsize,
                                                         hidden=self.hidden_channels[i],
                                                         shape=(height, width), use_cuda=input.is_cuda)
                    internal_state.append((h,c))
                
                (h,c)=internal_state[i]
                x, new_c = getattr(self, name)(x, h, c)
                internal_state[i] = (x, new_c)
            
            outputs.append(x)
        outputs=torch.stack(outputs, dim=1)
        
        return outputs, (x, new_c)

In [5]:
from collections import OrderedDict
class VideoAELSTM(nn.Module):
    def __init__(self, in_channels=1):
        super(VideoAELSTM, self).__init__()
        self.in_channels=in_channels
        self.conv_encoder=nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=self.in_channels, out_channels=128, kernel_size=11, stride=4, padding=0)),
            ('nonl1', nn.Tanh()),
            ('conv2', nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, stride=2,padding=0)),
            ('nonl2', nn.Tanh())
        ]))
        self.rnn_enc_dec = ConvLSTM(input_channels=64,
                                   hidden_channels=[64,32,64],
                                   kernel_size=3,
                                   batch_first=True)
        self.conv_decoder = nn.Sequential(OrderedDict([
            ('deconv1', nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=0)),
            ('nonl1', nn.Tanh()),
            ('deconv2', nn.ConvTranspose2d(in_channels=128, out_channels=self.in_channels, kernel_size=11, stride=4, padding=0)),
            ('nonl2', nn.Tanh())
        ]))
        
    def forward(self, x):
        b, t, c, h, w = x.size()
        x = x.view(b*t, c, h, w)
        x = self.conv_encoder(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3))
        x, _ = self.rnn_enc_dec(x)
        x = x.view(b*t, x.size(2), x.size(3), x.size(4))
        x = self.conv_decoder(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3))
        return x 
    def set_cuda(self):
        self.conv_encoder.cuda()
        self.rnn_enc_dec.cuda()
        self.conv_decoder.cuda()

## Train

In [6]:
root_dir='./UCSD_Anomaly_Dataset.v1p2/UCSDped1/Train'
train_ds = UCSDADData(root_dir, time_stride=1)
train_dl = data.DataLoader(train_ds, batch_size=32, shuffle=True)

In [7]:
model = VideoAELSTM()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e04, eps=1e-6, weight_decay=1e-5)

use_cuda=torch.cuda.is_available()
if use_cuda:
    cudnn.benchmark = True
    model.set_cuda()
    criterion.cuda()

In [None]:
model.train()
for epoch in range(5):
    for batch_idx, x in enumerate(train_dl):
        optimizer.zero_grad()
        if use_cuda:
            x = x.cuda()
        y=model(x)
        loss = criterion(y, x)
        loss.backward()
        optimizer.step()
        if batch_idx%20==0:
            print('Epoch {}, iter {}: Loss = {}'.format(epoch, batch_idx, loss.item()))
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()},
        './snapshot/checkpoint.epoch{}.pth.tar'.format(epoch))

Epoch 0, iter 0: Loss = 1.0079936981201172
Epoch 0, iter 20: Loss = 2.1018834114074707
Epoch 0, iter 40: Loss = 2.068948745727539
Epoch 0, iter 60: Loss = 1.8993468284606934
Epoch 0, iter 80: Loss = 2.0192394256591797
Epoch 0, iter 100: Loss = 2.0169758796691895


## Inference

In [None]:
model = VideoAELSTM()
model.load_state_dict(torch.load('./snapshot/checkpoint.epoch4.pth.tar')['state_dict'])
#model.set_cuda()
model.eval()

In [None]:
test_dir='./UCSD_Anomaly_Dataset.v1p2/UCSDped1/Inference'
test_ds = UCSDADData(test_dir)
test_dl = data.DataLoader(test_ds, batch_size=32, shuffle=False)

In [None]:
frames = []
errors = []
for batch_idx, x in enumerate(test_dl):
    y = model(x)
    mse = torch.norm(x.cpu().data.view(x.size(0), -1)-y.cpu().data.view(y.size(0),-1), dim=1)
    errors.append(mse)
errors=torch.cat(errors).numpy()

In [None]:
errors= errors.reshape(-1,191)
s=np.zeros(2,191)
s[0, :]=1-(errors[0, :]-np.min(errors[0,:]))/(np.max(errors[0,:] - np.min(errors[0,:])))
s[1, :]=1-(errors[1, :]-np.min(errors[0,:]))/(np.max(errors[1,:] - np.min(errors[1,:])))


In [None]:
#Test001
plt.plot(s[0,:])
plt.show()

In [None]:
#Test032
plt.plot(s[1,:])
plt.show()