In [1]:
import torch.nn.functional as F
from sklearn.decomposition import PCA, IncrementalPCA
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import torch
from torchvision import transforms, models


class PCANet(torch.nn.Module):
    def __init__(self, num_filters: list, filters_sizes: list, batch_size=256):
        self.params = {
            'num_filters': num_filters,
            'filters_sizes': filters_sizes,
        }
        self.W_1 = None
        self.W_2 = None
        self.batch_size = batch_size

    def forward(self, x):
        x = F.conv2d(x, self.W_1)
        N, C, H, W = x.shape
        x = x.view(-1, 1, H, W)
        
        x = F.conv2d(x, self.W_2)
        N, C, H, W = x.shape
        x = x.view(N*C, H, W)
        
        x_flat = x.view(N*C, H*W)
        print("N = {}, C = {}, H = {}, W = {}".format(N, C, H, W), x_flat.shape)
        x_flat = torch.nn.Linear(H*W, 2, bias=True)(x_flat)
        return x_flat
            
    @staticmethod        
    def _extract_image_patches(imgs: torch.Tensor, filter_size, stride=1, remove_mean=True):
        # imgs.shape = (N, C, H, W) -> (N, 1, H, W) 
        # так должно быть, но сюда могут прийти не grayscale изображения первого шага, а со второго
        # на котором применено L1 фильтров -> L1 каналов
        N, n_channels, H, W = imgs.shape
        
        if n_channels > 1:
            # изображение вида (N, C, H, W) - N C-канальных изображений
            # приводим к виду (N*C, 1, H, W) - N*C одно-канальных изображений
            imgs = imgs.view(-1, 1, H, W)
        print('images shape', imgs.shape)
            
        k = filter_size
        patches = torch.nn.functional.unfold(imgs, k, padding=k//2) # (N, k^2, H*W)
        print('patches_shape, ', patches.shape)
        print('should be patches shape, ', (imgs.shape[0], k**2, H*W))
        
        if remove_mean:
            patches -= patches.mean(dim=1, keepdim=True) # последнее измерение - количество патчей
        
        print('filter_size', k)
        X = patches.view(k**2, -1) # (k^2, N*H*W)

        return X
    
    def _convolve(self, imgs: torch.Tensor, filter_bank: torch.Tensor) -> torch.Tensor:
        weight = filter_bank
        output = F.conv2d(imgs, weight) #, padding=padding)
        return output
    
    def _first_stage(self, imgs: torch.Tensor, train: bool) -> torch.Tensor:
        # (N, C, H, W) image
        # (train_size, 1, H, W) - grayscale
        assert imgs.dim() == 4 and imgs.nelement() > 0

        print('PCANet first stage...')

        if train:
            # достаем все патчи из всех N изображений
            filter_size1 = self.params['filters_sizes'][0]
            X = self._extract_image_patches(
                imgs, filter_size1)
            
            n_filters = self.params['num_filters'][0]
            
            eigenvectors = self.get_pca_eigenvectors(X, n_components=n_filters, batch_size=self.batch_size)
            self.W_1 = torch.FloatTensor(eigenvectors).view(n_filters, 1, filter_size1, filter_size1)
         
        I = self._convolve(imgs, self.W_1)  # (N, 1, H, W) * (L1, k1, k1) -> (N, L1, H', W')
        return I
    
    @staticmethod
    def conv_output_size(w, filter_size, padding=0, stride=1):
        return int((w - filter_size + 2 * padding) / stride + 1)
    
    @staticmethod
    def get_pca_eigenvectors(X, n_components, batch_size=100):
        ipca = IncrementalPCA(n_components=n_components, batch_size=batch_size)
        print('pca fitting ...')
        ipca.fit(X @ X.t())
        eigenvectors = ipca.components_
        print('eigenvectors shape:', eigenvectors.shape)
        return eigenvectors
        
    def _second_stage(self, I: torch.Tensor, train):
        print('PCANet second stage...')
        # I: (N, L1, H, W)
        if train:
            N, L1, H, W = I.shape
            I = I.view(-1, 1, H, W)
            filter_size2 = self.params['filters_sizes'][1]
            n_filters = self.params['num_filters'][1]
            
            H_new = self.conv_output_size(I.shape[2], filter_size2)
            W_new = self.conv_output_size(I.shape[3], filter_size2)
            
            X = self._extract_image_patches(I, filter_size2)
            print('X_SHAPE ', X.shape)
            eigenvectors = self.get_pca_eigenvectors(X, n_components=n_filters, batch_size=self.batch_size)
            self.W_2 = torch.FloatTensor(eigenvectors).view(n_filters, 1, filter_size2, filter_size2)
        return self._convolve(I, self.W_2)
    
    def run(self, images):
        # Создаем фильтры
        # images: (N, 1, H, W)
        I = self._first_stage(images, train=True)
        II = self._second_stage(I, train=True)
        
        

In [2]:
net = PCANet([8, 8],[5, 3])
net.params

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net = net.to(device)

{'num_filters': [8, 8], 'filters_sizes': [5, 3]}

In [None]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

num_epochs = 100
batch_size = 50
lr = 1e-3

loss = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = torch.optim.SGD(net.parameters(), lr=lr)

In [None]:
train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transforms)
val_dataset = torchvision.datasets.ImageFolder(val_dir, val_transforms)

all_train = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_size, shuffle=False, num_workers=batch_size)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=batch_size)

val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=batch_size)

In [None]:
val_loss = np.zeros(num_epochs)
train_loss = np.zeros(num_epochs)

val_acc = np.zeros(num_epochs)
train_acc = np.zeros(num_epochs)


def train_model(model, loss, optimizer, scheduler, num_epochs):
    best_val_loss = np.inf
    best_train_loss = np.inf
    best_val_model = {}
    best_train_model = {}
    
    model.run(train_all_dataloader)

    for epoch in range(num_epochs):

        print('Epoch {}/{}:'.format(epoch, num_epochs - 1), flush=True)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                dataloader = train_dataloader
                scheduler.step()
                model.train()  # training mode
                history_acc = train_acc
                history_loss = train_loss
            else:
                dataloader = val_dataloader
                model.eval()   # evaluate mode (dropout + bn)
                history_acc = val_acc
                history_loss = val_loss

            running_loss = 0.
            running_acc = 0.

            # Iterate over data.
            for inputs, labels in dataloader:
                #print(inputs, labels)
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # forward and backward
                with torch.set_grad_enabled(phase=='train'):
                    preds = model(inputs)
                    loss_value = loss(preds, labels)
                    preds_class = preds.argmax(dim=1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss_value.backward()
                        optimizer.step()

                # statistics
                running_loss += loss_value.cpu().item()
                running_acc += (preds_class.cpu() == labels.cpu().data).float().mean()
            
            epoch_loss = running_loss / len(dataloader)
            epoch_acc = running_acc / len(dataloader)
            history_acc[epoch] = epoch_acc
            history_loss[epoch] = epoch_loss

            # запоминаем модель по лоссу
            if phase == 'val' and best_val_loss > epoch_loss:
                best_val_model = model.state_dict()
                best_val_loss = epoch_loss
            if phase == 'train' and best_train_loss > epoch_loss:
                best_train_model = model.state_dict()
                best_train_loss = epoch_loss

            print('{} Loss: {:.4f} Acc: {:.4f}\n best: {}'.format(phase, epoch_loss, epoch_acc, best_val_loss), flush=True)
    result_dict = {'best_val_model': best_val_model,
                   'best_train_model': best_train_model}
    return result_dict

In [None]:
result_dict = train_model(model, loss, optimizer, scheduler, num_epochs=num_epochs)

In [None]:
# Loss
plt.plot(num_epochs, train_loss, val_loss)
plt.legend(['train', 'val'])
plt.xlabel('epoch')
plt.ylabel('cross entropy loss')
plt.title('loss')

In [None]:
# accuracy
plt.plot(num_epochs, train_acc, val_acc)
plt.legend(['train', 'val'])
plt.xlabel('epoch')
plt.ylabel('cross entropy loss')
plt.title('accuracy')

In [3]:
imgs = torch.randn(10, 1, 10, 10)

In [4]:
I = net._first_stage(imgs=imgs, train=True)
I.shape

PCANet first stage...
images shape torch.Size([10, 1, 10, 10])
patches_shape,  torch.Size([10, 25, 100])
should be patches shape,  (10, 25, 100)
filter_size 5
pca fitting ...
eigenvectors shape: (8, 25)


torch.Size([10, 8, 6, 6])

In [5]:
net._second_stage(I, train=True).shape

PCANet second stage...
images shape torch.Size([80, 1, 6, 6])
patches_shape,  torch.Size([80, 9, 36])
should be patches shape,  (80, 9, 36)
filter_size 3
X_SHAPE  torch.Size([9, 2880])
pca fitting ...
eigenvectors shape: (8, 9)


torch.Size([80, 8, 4, 4])

In [6]:
net.forward(imgs)

N = 80, C = 8, H = 4, W = 4 torch.Size([640, 16])


tensor([[-0.7040,  0.3876],
        [ 0.5826,  1.0300],
        [ 0.2970,  0.7326],
        ...,
        [ 0.5815, -0.1781],
        [-0.7312,  1.2732],
        [-1.1729,  0.5596]], grad_fn=<AddmmBackward>)