In [1]:
import torch
import torch_dct as dct
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR

import matplotlib.pyplot as plt
plt.style.use(['science','no-latex', 'notebook'])

from multiprocessing import Pool, Queue, Process, set_start_method
import multiprocessing as mp_

import time
import pkbar
import sys
sys.path.append('../')
from common import *
from transform_based_network import *

In [45]:
class T_Layer(nn.Module):
    def __init__(self, dct_w, dct_b):
        super(T_Layer, self).__init__()
        w = torch.randn(dct_w.shape)
        b = torch.randn(dct_b.shape)
        self.weights = nn.Parameter(dct_w)
        self.bias = nn.Parameter(dct_b)
        
    def forward(self, dct_x):
        x = torch.mm(self.weights, dct_x)# + self.bias
        return x

    
class Frontal_Slice(nn.Module):
    def __init__(self, dct_w, dct_b):
        super(Frontal_Slice, self).__init__()
        self.device = dct_w.device
        self.dct_linear = nn.Sequential(
            T_Layer(dct_w, dct_b),
        )
        #nn.ReLU(inplace=True),
        #self.linear1 = nn.Linear(28, 28)
        #nn.ReLU(inplace=True),
        #self.linear2 = nn.Linear(28, 28)
        #nn.ReLU(inplace=True),
        #self.classifier = nn.Linear(28, 10)
        
    def forward(self, x):
        #x = torch.transpose(x, 0, 1).to(self.device)
        x = self.dct_linear(x)
        #x = self.linear1(x)
        #x = self.linear2(x)
        #x = self.classifier(x)
        #x = torch.transpose(x, 0, 1)
        return x
    
    
class Ensemble(nn.Module):
    def __init__(self, shape, device='cpu'):
        super(Ensemble, self).__init__()
        self.device = device    
        self.models = nn.ModuleList([])
        dct_w, dct_b = make_weights(shape, device, scale=0.001)
        self.weights = nn.Parameter(dct_w)
        self.bias = nn.Parameter(dct_b)
        for i in range(shape[0]):
            model = Frontal_Slice(self.weights[i, ...], self.bias[i, ...])
            self.models.append(model.to(device))
        
    def forward(self, x):
        self.res = torch.empty(x.shape[0], 10, x.shape[2])
        dct_x = torch_apply(dct.dct, x).to(self.device)
        self.tmp = []
        for i in range(len(self.models)):
            self.tmp.append(self.models[i](dct_x[i, ...]))
            self.res[i, ...] = self.tmp[i]
        self.result = torch_apply(dct.idct, self.res)
        self.softmax = scalar_tubal_func(self.result)
        return torch.transpose(self.softmax, 0, 1)

In [46]:
def train_ensemble(x, y, i=50, device='cuda:0'):
    x = torch_shift(x).to(device)
    y = y.to(device)
    ensemble = Ensemble(x.shape, device).to(device)
    optimizer = optim.SGD(ensemble.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    pbar = pkbar.Pbar(name='progress', target=i)
    for j in range(i):
        outputs = ensemble(x)
        print(outputs.shape, y.shape)
        optimizer.zero_grad()
        loss = criterion(outputs.to(device), y)
        loss.backward()
        optimizer.step()
        pbar.update(j)
        
    print(loss.item())
    return ensemble

## 16, 10, 10, 100 iterations
# cpu, for loop: 4.1s
# gpu, for loop: 5.5s

In [47]:
x0 = []
y0 = []
for i in range(100):
    x0.append(torch.randn(16, 29, 28))
    y0.append(torch.randint(10, (16,)))

for i in range(1):
    train_ensemble(x0[i], y0[i], i=2, device='cpu')

progress
torch.Size([16, 10]) torch.Size([16])
2.179849624633789


In [48]:
batch_size = 10
trainloader, testloader = load_mnist_multiprocess(batch_size)

==> Loading data..


In [49]:
device = 'cpu'
for epoch in range(10):
    pbar = pkbar.Pbar(name='Epoch'+str(epoch), target=60000/batch_size)
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        '''
        dct_x = torch_apply(dct.dct, x.squeeze())
        y_cat = to_categorical(y, 10) 

        dct_y_cat = torch.randn(y_cat.shape[0], dct_x.shape[1], 10)
        for i in range(10):
            dct_y_cat[:, i, :] = y_cat
        dct_y_cat = torch_apply(dct.dct, dct_y_cat)
        dct_x.to(device)
        dct_y_cat.to(device)
        '''
        correct = 0
        train_loss = 0
        total = 0
        inputs = torch_shift(inputs).to(device)
        ensemble = Ensemble(inputs.shape, device).to(device)
        optimizer = optim.SGD(ensemble.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        criterion = nn.CrossEntropyLoss()

        outputs = ensemble(inputs) 
        optimizer.zero_grad()
        loss = criterion(outputs.to(device), targets.to(device))
        loss.backward()
        optimizer.step()
            
        _, predicted = torch.max(outputs, 1)
        correct += predicted.eq(targets).sum().item()
        train_loss += loss.item()
        total += batch_size
        print(loss)
        
        pbar.update(batch_idx)
    print(correct/total, train_loss/total)
    

'''
    models = []
    for i in range(16):
        dct_w, dct_b = make_weights(dct_x.shape, device=device)
        model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])
        models.append(model.to(device))

    for i in range(len(models)):
        train_slice(models[i], dct_x[i, ...], dct_y_cat[i, ...])
    print()
    pbar.update(batch_idx)
    
    tmp = torch_mp.get_context('spawn')
    for model in models:
        model.share_memory()
    processes = []

    for i in range(len(models)):
        p = tmp.Process(target=train_slice, 
                        args=(models[i], dct_x[i, ...], dct_y_cat[i, ...]))
        p.start()
        processes.append(p)
    for p in processes: 
        p.join()
    '''

Epoch0
tensor(2.2876, grad_fn=<NllLossBackward>)
   1/6000  [..............................] - 0.0stensor(2.2945, grad_fn=<NllLossBackward>)
   2/6000  [..............................] - 0.1stensor(2.3091, grad_fn=<NllLossBackward>)
   3/6000  [..............................] - 0.1stensor(2.3509, grad_fn=<NllLossBackward>)
   4/6000  [..............................] - 0.2stensor(2.3103, grad_fn=<NllLossBackward>)
   5/6000  [..............................] - 0.3stensor(2.3188, grad_fn=<NllLossBackward>)
   6/6000  [..............................] - 0.3stensor(2.2883, grad_fn=<NllLossBackward>)
   7/6000  [..............................] - 0.4stensor(2.3286, grad_fn=<NllLossBackward>)
   8/6000  [..............................] - 0.4stensor(2.2519, grad_fn=<NllLossBackward>)
   9/6000  [..............................] - 0.5stensor(2.3762, grad_fn=<NllLossBackward>)
  10/6000  [..............................] - 0.5stensor(2.2961, grad_fn=<NllLossBackward>)
  11/6000  [...................

  89/6000  [..............................] - 4.4stensor(2.2559, grad_fn=<NllLossBackward>)
  90/6000  [..............................] - 4.4stensor(2.3373, grad_fn=<NllLossBackward>)
  91/6000  [..............................] - 4.5stensor(2.2490, grad_fn=<NllLossBackward>)
  92/6000  [..............................] - 4.5stensor(2.2938, grad_fn=<NllLossBackward>)
  93/6000  [..............................] - 4.6stensor(2.2821, grad_fn=<NllLossBackward>)
  94/6000  [..............................] - 4.6stensor(2.2545, grad_fn=<NllLossBackward>)
  95/6000  [..............................] - 4.7stensor(2.3990, grad_fn=<NllLossBackward>)
  96/6000  [..............................] - 4.7stensor(2.2276, grad_fn=<NllLossBackward>)
  97/6000  [..............................] - 4.7stensor(2.4132, grad_fn=<NllLossBackward>)
  98/6000  [..............................] - 4.8s

KeyboardInterrupt: 

In [25]:
def train_slice(model, x_i, y_i):
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.5, momentum=0.9, weight_decay=5e-4)
    outputs = model(x_i)
    # print(outputs.shape, y_i.shape)
    optimizer.zero_grad()
    loss = criterion(outputs, y_i)
    loss.backward()
    optimizer.step()

In [56]:
for batch_idx, (x, y) in enumerate(trainloader):  
    device = 'cpu'
    x = torch_shift(x)
    dct_x = torch_apply(dct.dct, x.squeeze())
    y_cat = to_categorical(y, 10) 

    dct_y_cat = torch.randn(28, dct_x.shape[2], 10) #y_cat.shape[0]
    for i in range(28):
        dct_y_cat[i, :, :] = y_cat
    dct_y_cat = torch_apply(dct.dct, dct_y_cat)
    dct_x.to(device)
    dct_y_cat.to(device)
        
    models = []
    dct_w, dct_b = make_weights(dct_x.shape, device=device)
    for i in range(28):
        model = Frontal_Slice(dct_w[i, ...], dct_b[i, ...])
        models.append(model.to(device))

    for i in range(len(models)):
        train_slice(models[i], dct_x[i, ...], dct_y_cat[i, ...])

In [61]:
y = torch.eye(10)

In [74]:
dct_yy = torch.empty(28, 10, 10)
for i in range(28):
    dct_yy[i, ...] = y * 1
dct_yy = torch_apply(dct.dct, dct_yy)

In [75]:
result = torch_apply(dct.idct, dct_yy)
softmax = scalar_tubal_func(result)
torch.transpose(softmax, 0, 1)

tensor([[12.6239,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,
          1.7085,  1.7085],
        [ 1.7085, 12.6239,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,
          1.7085,  1.7085],
        [ 1.7085,  1.7085, 12.6239,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,
          1.7085,  1.7085],
        [ 1.7085,  1.7085,  1.7085, 12.6239,  1.7085,  1.7085,  1.7085,  1.7085,
          1.7085,  1.7085],
        [ 1.7085,  1.7085,  1.7085,  1.7085, 12.6239,  1.7085,  1.7085,  1.7085,
          1.7085,  1.7085],
        [ 1.7085,  1.7085,  1.7085,  1.7085,  1.7085, 12.6239,  1.7085,  1.7085,
          1.7085,  1.7085],
        [ 1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085, 12.6239,  1.7085,
          1.7085,  1.7085],
        [ 1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085, 12.6239,
          1.7085,  1.7085],
        [ 1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,  1.7085,
         12.6239,  1.7085],
        [ 1.7085,  