In [1]:
import torch
import torch_dct as dct
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.init as init
from torch.optim.lr_scheduler import StepLR

import matplotlib.pyplot as plt
import time
import pkbar
import math

import sys
sys.path.append('../')
from common import *
from transform_based_network import *

In [14]:
def t_product_in_network(A, B):
    """Multiply the DCT of ``A`` with ``B``, slice by slice along dim 0.

    ``A`` is transformed with ``torch_apply(dct.dct, A)``; ``B`` is used as
    given (``new_tNN.forward`` passes activations it has already transformed).

    Args:
        A: 3-D tensor of shape ``(k, m, n)``.
        B: 3-D tensor of shape ``(k, n, p)``.

    Returns:
        Tensor of shape ``(k, m, p)`` whose ``i``-th slice is
        ``dct(A)[i] @ B[i]``. It lives on the same device and has the same
        dtype as the operands.

    Raises:
        AssertionError: if the leading or inner dimensions do not match.
    """
    assert A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1], (
        'incompatible shapes for slice-wise product: '
        '{} and {}'.format(tuple(A.shape), tuple(B.shape)))
    dct_A = torch_apply(dct.dct, A)
    # One batched matmul instead of filling a torch.zeros buffer slice by
    # slice: the buffer was always created on the default device/dtype, so
    # CUDA (or float64) inputs produced a mismatched result tensor.
    return torch.bmm(dct_A, B)

In [15]:
class tNN(nn.Module):
    """Stack of layers computing ``x <- relu(t_product(W_i, x) + B_i)``.

    ``t_product`` is not defined in this notebook; it comes from one of the
    star imports at the top (presumably the tensor t-product -- confirm).

    Args:
        num_layers: number of ``(W_i, B_i)`` pairs. Defaults to 10.
        n: edge length of the tensors; every weight is ``(n, n, n)`` and every
            bias ``(n, n, 1)``. Defaults to 28.
    """

    def __init__(self, num_layers=10, n=28):
        super(tNN, self).__init__()
        self.num_layers = num_layers
        # torch.empty only allocates; reset_parameters() fills the values.
        self.W = nn.ParameterList(
            [nn.Parameter(torch.empty(n, n, n)) for _ in range(num_layers)])
        self.B = nn.ParameterList(
            [nn.Parameter(torch.empty(n, n, 1)) for _ in range(num_layers)])
        self.reset_parameters()

    def forward(self, x):
        """Run ``x`` through every layer and return the final activation."""
        for i in range(self.num_layers):
            x = torch.add(t_product(self.W[i], x), self.B[i])
            x = F.relu(x)
        return x

    def reset_parameters(self):
        """Re-initialise all weights and biases in place.

        Weights use ``kaiming_uniform_`` with ``a=sqrt(5)``; biases are drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        for i in range(self.num_layers):
            init.kaiming_uniform_(self.W[i], a=math.sqrt(5))
            # For a 3-D tensor PyTorch computes fan_in as size(1) * size(2),
            # i.e. n * n for an (n, n, n) weight.
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i])
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.B[i], -bound, bound)

In [22]:
class new_tNN(nn.Module):
    """Layer stack that runs entirely on DCT-transformed activations.

    ``forward`` transforms the input once with ``torch_apply(dct.dct, ...)``,
    applies ``x <- relu(t_product_in_network(W_i, x) + B_i)`` for every layer,
    and transforms the result back with ``torch_apply(dct.idct, ...)``.

    Args:
        num_layers: number of ``(W_i, B_i)`` pairs. Defaults to 4.
        n: edge length of the tensors; every weight is ``(n, n, n)`` and every
            bias ``(n, n, 1)``. Defaults to 28.
    """

    def __init__(self, num_layers=4, n=28):
        super(new_tNN, self).__init__()
        self.num_layers = num_layers
        # torch.empty only allocates; reset_parameters() fills the values.
        self.W = nn.ParameterList(
            [nn.Parameter(torch.empty(n, n, n)) for _ in range(num_layers)])
        self.B = nn.ParameterList(
            [nn.Parameter(torch.empty(n, n, 1)) for _ in range(num_layers)])
        self.reset_parameters()

    def forward(self, x):
        """Return the inverse-DCT of the last layer's activation."""
        x = torch_apply(dct.dct, x)
        for i in range(self.num_layers):
            # The bias is added, and ReLU applied, on the transformed values.
            x = torch.add(t_product_in_network(self.W[i], x), self.B[i])
            x = F.relu(x)
        return torch_apply(dct.idct, x)

    def reset_parameters(self):
        """Re-initialise all weights and biases in place.

        Weights use ``kaiming_uniform_`` with ``a=sqrt(5)``; biases are drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        """
        for i in range(self.num_layers):
            init.kaiming_uniform_(self.W[i], a=math.sqrt(5))
            # For a 3-D tensor PyTorch computes fan_in as size(1) * size(2),
            # i.e. n * n for an (n, n, n) weight.
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W[i])
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.B[i], -bound, bound)

In [28]:
lr_rate = 0.001
epochs_num = 20
device = 'cpu' # 'cuda:0' if torch.cuda.is_available() else 'cpu'
batch_size = 100
train_loader, test_loader = load_mnist_multiprocess(batch_size)

module = new_tNN()
module = module.to(device)

Loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(module.parameters(), lr=lr_rate)

# Per-epoch history: mean loss, accuracy in percent, wall-clock seconds.
test_loss_epoch = []
test_acc_epoch = []
train_loss_epoch = []
train_acc_epoch = []
time_list = []

# begin training
for epoch in range(epochs_num):
    since = time.time()
    running_loss = 0.0
    running_acc = 0.0
    module.train()

    num_train_batches = len(train_loader)
    pbar_train = pkbar.Pbar(name='Epoch '+str(epoch+1)+' training:', target=num_train_batches)
    for i, data in enumerate(train_loader):
        img, label = data
        img = raw_img(img, batch_size, n=28)
        img = img.to(device)
        label = label.to(device)

        # forward
        out = module(img)

        # Reduce the network output with scalar_tubal_func and transpose it so
        # dim 0 is the batch, which is what the loss and the argmax below use.
        out = torch.transpose(scalar_tubal_func(out), 0, 1)
        loss = Loss_function(out, label)
        running_loss += loss.item()
        _, pred = torch.max(out, 1)
        # .item() keeps the accumulator a plain float rather than a tensor.
        running_acc += (pred == label).float().mean().item()

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar_train.update(i)

    # Average over the number of batches; the last loop index `i` is one less
    # than that (and is 0 when there is a single batch).
    train_loss = running_loss / num_train_batches
    train_acc = running_acc / num_train_batches
    print('[{Epoch}/{Epochs_num}] Loss:{Running_loss} Acc:{Running_acc}'
          .format(Epoch=epoch + 1, Epochs_num=epochs_num, Running_loss=train_loss,
                  Running_acc=train_acc))
    train_loss_epoch.append(train_loss)
    train_acc_epoch.append(train_acc * 100)

    module.eval()
    eval_loss = 0.0
    eval_acc = 0.0

    num_test_batches = len(test_loader)
    pbar_test = pkbar.Pbar(name='Epoch '+str(epoch+1)+' test', target=num_test_batches)
    for i, data in enumerate(test_loader):
        img, label = data
        # Same preprocessing as the training loop; the CIFAR helper that was
        # called here does not match how the model is fed during training.
        img = raw_img(img, batch_size, n=28)
        img = img.to(device)
        label = label.to(device)

        with torch.no_grad():
            out = module(img)
            out = torch.transpose(scalar_tubal_func(out), 0, 1)
            loss = Loss_function(out, label)
        eval_loss += loss.item()
        _, pred = torch.max(out, 1)
        eval_acc += (pred == label).float().mean().item()

        pbar_test.update(i)

    print('Test Loss: {Eval_loss}, Acc: {Eval_acc}'
          .format(Eval_loss=eval_loss / num_test_batches,
                  Eval_acc=eval_acc / num_test_batches))
    test_loss_epoch.append(eval_loss / num_test_batches)
    test_acc_epoch.append((eval_acc / num_test_batches) * 100)
    time_list.append(time.time() - since)

    # eval_loss is a Python float, so math.isnan suffices and avoids relying
    # on `np` being leaked into the namespace by a star import.
    if math.isnan(eval_loss):
        print('invalid loss')
        break

==> Loading data..
Epoch 1 training:
 23/600  [>.............................] - 8.4s

KeyboardInterrupt: 