In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# On a CUDA machine the line below reports a CUDA device.
print(device)

cpu


Création des masques 2D et 1D pour les convolutions


In [6]:
class MaskedConv2D(nn.Conv2d):
	"""2D convolution whose kernel is multiplied by a fixed binary mask.

	The mask keeps every kernel row above the centre row and, on the centre
	row, the columns to the left of the centre. Mask type 'A' additionally
	zeroes the centre tap itself, whereas type 'B' keeps it.

	Args:
		mask_type: 'A' or 'B'.
		*args, **kwargs: forwarded unchanged to ``nn.Conv2d`` (in_channels,
			out_channels, kernel_size, padding, bias, ...).

	Raises:
		AssertionError: if ``mask_type`` is neither 'A' nor 'B'.
	"""
	def __init__(self, mask_type, *args, **kwargs):
		assert mask_type in ['A', 'B'], "Unknown Mask Type"
		# The convolution arguments must reach nn.Conv2d: calling the parent
		# constructor without them raises a TypeError and no weight is created.
		super(MaskedConv2D, self).__init__(*args, **kwargs)
		self.mask_type = mask_type

		# Registered as a buffer so the mask follows .to(device) and is stored
		# in the state dict, but is never treated as a trainable parameter.
		self.register_buffer('mask', self.weight.data.clone())
		_, _, height, width = self.weight.size()
		self.mask.fill_(1)

		# On the centre row, type 'A' hides the centre tap and everything to
		# its right; type 'B' only hides what lies strictly to the right.
		first_hidden_col = width // 2 if mask_type == 'A' else width // 2 + 1
		self.mask[:, :, height // 2, first_hidden_col:] = 0
		# Every row below the centre is hidden for both mask types.
		self.mask[:, :, height // 2 + 1:, :] = 0

	def forward(self, x):
		"""Zero the masked kernel taps in place, then apply the convolution to ``x``."""
		# Re-applied on every call so optimiser updates cannot revive masked taps.
		self.weight.data *= self.mask
		return super(MaskedConv2D, self).forward(x)

In [None]:
class MaskedConv1d(nn.Conv1d):
    """1D convolution whose kernel is multiplied by a fixed binary mask.

    Kernel taps from a cut-off position onwards are zeroed. With mask 'A' the
    cut-off is the centre tap (the centre is hidden); with mask 'B' it is one
    past the centre (the centre is kept).

    Args:
        *args, **kargs: forwarded unchanged to ``nn.Conv1d``.
        mask: keyword-only, 'A' or 'B' (default 'B').
    """

    def __init__(self, *args, mask='B', **kargs):
        super().__init__(*args, **kargs)
        assert mask in {'A', 'B'}
        self.mask_type = mask

        # Index of the first kernel tap that gets hidden.
        centre = self.weight.size(-1) // 2
        cutoff = centre + 1 if mask == 'B' else centre

        # All-ones tensor matching the weight's shape, dtype and device.
        keep = self.weight.data.new_ones(self.weight.size())
        keep[..., cutoff:] = 0
        # Buffer: saved and moved with the module, but not trainable.
        self.register_buffer('mask', keep)

    def forward(self, x):
        """Zero the hidden kernel taps in place, then convolve ``x``."""
        self.weight.data.mul_(self.mask)
        return super().forward(x)

Définition de la classe PixelCNN

In [9]:
class PixelCNN(nn.Module):
	"""PixelCNN-style stack of masked convolutions.

	Pipeline: a 7x7 type-'A' masked convolution on a 3-channel input, then
	``nb_layer_block`` applications of a residual block (1x1 -> masked 3x3
	-> 1x1), then two 1x1 masked convolutions and a final 1x1 convolution
	producing 256 output channels per pixel.

	Args:
		nb_layer_block: how many times the residual block is applied.
		channels: bottleneck width; the trunk carries ``2 * channels`` features.
		device: stored on the instance as ``self.device``; not used by this class.

	NOTE(review): ``multiple_blocks`` is a single module applied repeatedly,
	so every residual iteration shares the same weights. Confirm this is
	intended rather than ``nb_layer_block`` independent blocks.
	"""
	def __init__(self, nb_layer_block=8, channels=64, device=None):
		super(PixelCNN, self).__init__()
		self.nb_block = nb_layer_block
		self.channels = channels
		self.layers = {}  # kept for compatibility; not read anywhere in this class
		self.device = device
		# First convolution (cf. Table 1). padding=3 is meant to preserve the
		# spatial size under the 7x7 kernel.
		self.ConvA = nn.Sequential(
			MaskedConv2D('A', 3, 2*channels, kernel_size=7, padding=3, bias=False),
			nn.ReLU(True)
		)
		# Residual block for PixelCNN (Figure 5). The 3x3 convolution needs
		# padding=1: without it the block output is smaller than its input
		# and the residual addition fails on a shape mismatch.
		self.multiple_blocks = nn.Sequential(
			nn.Conv2d(2*channels, channels, kernel_size=1, bias=False),
			nn.ReLU(True),
			MaskedConv2D('B', channels, channels, kernel_size=3, padding=1, bias=False),
			nn.ReLU(True),
			nn.Conv2d(channels, 2*channels, kernel_size=1, bias=False),
			nn.ReLU(True)
		)
		# Output head: 1x1 convolutions down to 256 channels per pixel.
		self.end = nn.Sequential(
			nn.ReLU(True),
			MaskedConv2D('B', 2*channels, 2*channels, kernel_size=1, bias=False),
			nn.ReLU(True),
			MaskedConv2D('B', 2*channels, 2*channels, kernel_size=1, bias=False),
			nn.Conv2d(2*channels, 256, 1)
		)

	def residual_block(self, x):
		"""Return ``x`` plus the output of ``self.multiple_blocks`` applied to ``x``."""
		return x + self.multiple_blocks(x)

	def forward(self, x):
		"""Run the first convolution, ``nb_block`` residual steps, then the output head."""
		x = self.ConvA(x)
		for _ in range(self.nb_block):
			# Must be called through ``self``: the bare name is undefined at
			# module level and raised NameError.
			x = self.residual_block(x)
		return self.end(x)

CODAGE DE PIXELRNN 

In [None]:
class RowLSTMCell(nn.Module):
    """One step of a convolutional row LSTM.

    The gate pre-activations are the sum of a masked 1D convolution over the
    current input row (input-to-state) and a plain 1D convolution over the
    previous hidden row (state-to-state).

    Args:
        h: number of hidden feature channels per position.
        image_size: row width (number of positions in a row).
        channel_in: number of channels of the input row.
        *args, **kargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, h, image_size, channel_in, *args, **kargs):
        super(RowLSTMCell, self).__init__(*args, **kargs)

        self._h = h
        self._image_size = image_size
        self._channel_in = channel_in
        self._num_units = self._h * self._image_size
        self._output_size = self._num_units
        self._state_size = self._num_units * 2

        # Each convolution's in_channels must match the tensor it receives in
        # forward(): the input row has `channel_in` channels and the hidden
        # row has `h` channels. They were swapped before, which only worked
        # when channel_in == h.
        self.conv_i_s = MaskedConv1d(self._channel_in, 4 * self._h, 3, mask='B', padding='same')
        self.conv_s_s = nn.Conv1d(self._h, 4 * self._h, 3, padding='same')

    def forward(self, inputs, states):
        """Advance the cell by one row.

        Args:
            inputs: tensor reshapeable to (batch, channel_in, image_size).
            states: pair ``(c_previous, h_previous)``; ``h_previous`` must be
                reshapeable to (batch, h, image_size) and ``c_previous`` must
                broadcast against (batch, h * image_size).

        Returns:
            ``(h, (c, h))`` where ``c`` and ``h`` have shape
            (batch, h * image_size).
        """
        c_previous, h_previous = states

        h_previous = h_previous.reshape(-1, self._h, self._image_size)
        inputs = inputs.reshape(-1, self._channel_in, self._image_size)

        s_s = self.conv_s_s(h_previous)  # (batch, 4*h, width)
        i_s = self.conv_i_s(inputs)      # (batch, 4*h, width)

        # Flatten to (batch, 4*h*width); consecutive h*width chunks are one gate each.
        pre_activations = (s_s + i_s).reshape(-1, 4 * self._num_units)
        i, g, f, o = torch.split(pre_activations, self._num_units, dim=1)

        # The input, forget and output gates are sigmoids; the content gate g
        # uses tanh. Applying sigmoid to all four chunks made g non-negative.
        i = torch.sigmoid(i)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        g = torch.tanh(g)

        c = f * c_previous + i * g
        h = o * torch.tanh(c)

        new_state = (c, h)
        return h, new_state

SAVING FITTED MODEL

In [None]:
# Checkpoint destination, relative to the notebook's working directory.
PATH = './model_1_Alex.pth'
# Only the state dict is written, not the module object itself.
# NOTE(review): `net` is not defined in the cells visible here; it must already exist in the kernel.
torch.save(obj=net.state_dict(), f=PATH)