In [1]:
import torch
import torch_dct as dct
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import StepLR

from pytorch_wavelets import DWTForward, DWTInverse
import matplotlib.pyplot as plt

from multiprocessing import Pool, Queue, Process, set_start_method
import multiprocessing as mp_

import time
import pkbar
import sys
sys.path.append('../')
from common import *
from transform_based_network import *
from joblib import Parallel, delayed

Using numpy backend.


In [2]:
class T_Layer(nn.Module):
    def __init__(self, dct_w, dct_b):
        super(T_Layer, self).__init__()
        self.weights = nn.Parameter(dct_w)
        self.bias = nn.Parameter(dct_b)
        
    def forward(self, dct_x):
        x = torch.mm(self.weights, dct_x) + self.bias
        return x

    
class Frontal_Slice(nn.Module):
    def __init__(self, dct_w, dct_b):
        super(Frontal_Slice, self).__init__()
        self.device = dct_w.device
        self.dct_linear = nn.Sequential(
            T_Layer(dct_w, dct_b),
        )
        #nn.ReLU(inplace=True),
        #self.linear1 = nn.Linear(28, 28)
        #nn.ReLU(inplace=True),
        #self.classifier = nn.Linear(28, 10)
        
    def forward(self, x):
        #x = torch.transpose(x, 0, 1).to(self.device)
        x = self.dct_linear(x)
        #x = self.linear1(x)
        #x = self.classifier(x)
        #x = torch.transpose(x, 0, 1)
        return x

In [19]:
def train_slice(i, model, x_i, y, outputs, optimizer):
    s = time.time()
    
    criterion = nn.CrossEntropyLoss()
    o = torch.stack(outputs)
    o[i, ...] = outputs_grad[i]
    o = torch_apply(dct.idct, o)
    o = scalar_tubal_func(o)
    o = torch.transpose(o, 0, 1)
    
    optimizer.zero_grad()
    loss = criterion(o, y) 
    
    loss.backward()
    e = time.time()
    
    print(e - s)
    optimizer.step()

In [20]:
def rough_loader(loader, mode='bior1.3'):
    for (x, y) in loader:
        func = DWTForward(J=1, wave=mode, mode='zero')
        l = func(x)[0]
        yield (x, l, y)
        
def generate_model_list(shape):
    models, ops = [], []
    dct_w, dct_b = make_weights(shape, device=device)
    for i in range(shape[0]):
        w_i = dct_w[i, ...].clone()
        b_i = dct_b[i, ...].clone()

        w_i.requires_grad = True
        b_i.requires_grad = True

        model = Frontal_Slice(w_i, b_i)
        model.train()
        models.append(model.to(device))

        op = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
        ops.append(op)
        
    return models, ops

In [None]:
device = 'cpu'
batch_size = 100
trainloader, testloader = load_mnist_multiprocess(batch_size)
rough_train, rough_test = rough_loader(trainloader), rough_loader(testloader)

shape = (28, 28, batch_size)
rough_shape = (16, 16, batch_size)

models, ops = generate_model_list(shape)
rough_models, rough_ops = generate_model_list(rough_shape)

epochs = 10
acc_list = []
loss_list = []

global outputs_grad
for e in range(epochs):
    correct = 0
    total = 0
    losses = 0
    pbar = pkbar.Pbar(name='Epoch '+str(e), target=60000/batch_size)
    for batch_idx, (x, l, y) in enumerate(rough_train):  
        
        dct_x = torch_shift(x)
        dct_x = torch_apply(dct.dct, dct_x)
        dct_x = dct_x.to(device)
        
        dct_l = torch_shift(l)
        dct_l = torch_apply(dct.dct, dct_l)
        dct_l = dct_l.to(device)
        
        y = y.to(device)            

        outputs_grad = []
        outputs = []
        for i in range(len(models)):
            out = models[i](dct_x[i, ...])        
            outputs_grad.append(out)
            outputs.append(out.detach())

        Parallel(n_jobs=8, prefer="threads", verbose=0)(
            delayed(train_slice)(i, models[i], dct_x[i, ...], y, outputs, ops[i]) \
            for i in range(len(models)))

        outputs_grad = []
        rough_outputs = []
        for i in range(len(rough_models)):
            out = rough_models[i](dct_l[i, ...])
            outputs_grad.append(out)
            rough_outputs.append(out.detach())
            
        Parallel(n_jobs=8, prefer="threads", verbose=0)(
            delayed(train_slice)(i, rough_models[i], dct_l[i, ...], y, rough_outputs, rough_ops[i]) \
            for i in range(len(rough_models)))

        res = torch.empty(shape[0], 10, shape[2])
        for i in range(len(rough_models)):
            res[i, ...] = rough_models[i](dct_l[i, ...])

        res = torch_apply(dct.idct, res).to(device)
        res = scalar_tubal_func(res)
        res = torch.transpose(res, 0, 1)
        criterion = nn.CrossEntropyLoss()
        total_loss = criterion(res, y)

        _, predicted = torch.max(res, 1)
        total += y.size(0)
        correct += predicted.eq(y).sum().item()
        losses += total_loss
        
        print(total_loss)
        pbar.update(batch_idx)

    loss_list.append(losses / total)
    acc_list.append(correct / total)

==> Loading data..
Epoch 0
1.684023141860962
1.71051168441772461.71104693412780761.71038722991943361.71056795120239261.71124196052551271.7109260559082031
1.7118091583251953





1.7997148036956787
1.8081641197204591.6390078067779541.67082619667053221.67075395584106451.7597539424896241.67117381095886231.6714129447937012






1.6977579593658447
1.8544828891754151.68806004524230961.80538177490234381.68821716308593751.88077282905578611.75513505935668951.8058419227600098






0.9735040664672852
0.99721026420593261.09637975692749020.9976699352264404


1.0632870197296143
1.0639243125915527
1.0809361934661865
1.0815539360046387
1.097269058227539
1.0976550579071045
1.0974080562591553
1.0983638763427734
1.7078590393066406
1.77417302131652831.7741191387176514
1.7447121143341064
1.82173705101013181.759605884552002

1.80572700500488281.745422124862671


tensor(2.2607, grad_fn=<NllLossBackward>)
  1/600  [..............................] - 9.9s1.6159241199493408
1.64740085601806641.6472790241241455

  9/600  [..............................] - 90.6s1.731632947921753
1.76683282852172851.76737999916076661.7673599720001221.76776981353759771.76801323890686041.76801896095275881.7674171924591064






1.7171237468719482
1.74891591072082521.78206992149353031.85200572013854981.7828371524810791.74859118461608891.748736858367921.9840803146362305






1.920647144317627
1.79389309883117681.86342501640319821.86327815055847172.018425703048706
1.959639072418213


1.864192008972168

1.7953250408172607
1.1089317798614502
1.1362912654876711.22712707519531251.1360559463500977


1.109666109085083
1.1386642456054688
1.13861083984375
1.1393980979919434
1.13898277282714841.139901876449585

1.1401422023773193
1.1968209743499756
1.1373910903930664
1.16171002388000491.132425069808961.16175103187561041.28333592414855961.16165804862976071.22571706771850591.1848950386047363






tensor(1.8700, grad_fn=<NllLossBackward>)
 10/600  [..............................] - 100.5s1.72482204437255861.725172996520996

1.

tensor(1.7693, grad_fn=<NllLossBackward>)
 18/600  [..............................] - 182.2s1.7058379650115967
1.7476220130920411.74797582626342771.74816513061523441.74775385856628421.74759984016418461.74810695648193361.7483818531036377






1.6921710968017578
1.86313605308532711.69309806823730471.7245600223541261.86344385147094731.92465305328369141.72452306747436521.795738935470581






1.7146649360656738
1.87796187400817871.77746391296386721.96192002296447751.87798404693603521.7479059696197511.74663305282592771.877912998199463






1.0971591472625732
1.12269282341003421.17859292030334470.9956259727478027


1.11343789100646971.114104986190796

1.1357629299163818
1.1364860534667969
1.1371080875396729
1.1371939182281494
1.137887954711914
1.1911919116973877
1.1551640033721924
1.1797392368316651.099898099899292
1.1216809749603271
1.14206695556640621.1797637939453125
1.2361340522766113
1.1233069896697998


tensor(1.6373, grad_fn=<NllLossBackward>)
 19/600  [.............................

tensor(1.5719, grad_fn=<NllLossBackward>)
 27/600  [>.............................] - 269.1s1.83834910392761231.8387069702148438

1.87155413627624511.87166690826416021.8718919754028321.8716409206390381.8722112178802491.8719627857208252





2.0049798488616943
1.95434308052062992.01727604866027832.072201967239381.8884627819061282.1002669334411621.8876810073852541.8888661861419678






2.0010550022125244
2.03579592704772951.90046501159667971.96862387657165532.0360927581787111.9008271694183351.93206405639648442.1263790130615234






1.1579310894012451
0.99493098258972171.15252399444580080.9948809146881104


1.084136962890625
1.1083967685699463
1.1087596416473389
1.1091041564941406
1.1094231605529785
1.1094200611114502
1.125859260559082
1.126194953918457
1.0993869304656982
1.21783781051635741.2351830005645752

1.186765193939209
1.1412527561187744
1.14502501487731931.19069123268127441.1918830871582031


tensor(1.6115, grad_fn=<NllLossBackward>)
 28/600  [>.............................] - 

tensor(1.3088, grad_fn=<NllLossBackward>)
 36/600  [>.............................] - 356.6s1.6662580966949463
1.7033011913299561.70359086990356451.70330715179443361.70386505126953121.7037649154663086


1.7040867805480957

1.7048208713531494

1.7079308032989502
1.73955988883972171.92126297950744631.86515617370605471.80378794670104981.80408596992492681.95167708396911621.7686340808868408






1.7660601139068604
1.89422798156738281.74570584297180181.808499813079834
1.71154522895812991.68363070487976071.7464277744293213



1.6841630935668945

1.0231781005859375
1.07701992988586431.1343007087707521.0488829612731934


1.0873289108276367
1.11229896545410161.11257600784301761.11232995986938481.11185479164123541.1126980781555176
1.1131508350372314

1.1137280464172363



1.0814688205718994
1.10275411605834961.11737108230590821.13356065750122071.2019910812377931.22104191780090331.15058207511901861.1025898456573486






tensor(1.3445, grad_fn=<NllLossBackward>)
 37/600  [>.......................

tensor(1.1682, grad_fn=<NllLossBackward>)
 45/600  [=>............................] - 444.4s1.7935452461242676
1.82198095321655271.82227396965026861.82284712791442871.82301330566406251.8229839801788331.8228571414947511.8230361938476562






1.984267234802246
1.8503959178924561.85079884529113772.04445385932922361.878823995590211.87867903709411621.85083818435668951.8511061668395996






1.9108569622039795
2.0843040943145751.88768267631530762.05601000785827641.88813591003417971.94776892662048341.9177851676940918





2.0893988609313965
1.1076600551605225
1.21880912780761721.13288402557373051.0481669902801514


1.1531238555908203
1.17434406280517581.1745910644531251.17428708076477051.17523503303527831.17464375495910641.17505979537963871.1754190921783447






1.1521971225738525
1.19223618507385251.25428295135498051.27184176445007321.23867392539978031.19227290153503421.1631729602813721.177232265472412






tensor(1.2994, grad_fn=<NllLossBackward>)
 46/600  [=>............................

tensor(1.2674, grad_fn=<NllLossBackward>)
 54/600  [=>............................] - 533.2s1.71443414688110351.7149670124053955

1.7485330104827881.74896407127380371.74977302551269531.74965524673461911.750108003616333


1.7500190734863281


1.7017028331756592
1.90588903427124021.87636923789978031.79104781150817871.82008409500122071.70600700378417971.76142406463623051.7906701564788818






1.9522409439086914
1.95691704750061041.82128405570983891.79465603828430181.821249246597291.87673187255859381.82121706008911131.876626968383789






1.1044859886169434
1.051396131515503
1.0246262550354004
1.1589000225067139
1.120452880859375
1.1418480873107911.14263582229614261.14230179786682131.1430218219757081.14319014549255371.14305305480957031.14280104637146






1.2237210273742676
1.1697709560394287
1.13703489303588871.1224529743194581.12288093566894531.18948101997375491.22500896453857421.1371262073516846





tensor(1.2937, grad_fn=<NllLossBackward>)
 55/600  [=>............................] 

tensor(0.9941, grad_fn=<NllLossBackward>)
 63/600  [==>...........................] - 619.5s1.7176871299743652
1.7559039592742921.75684213638305661.7568731307983398
1.7565140724182131.75705099105834961.7576200962066651.757396936416626





1.709484577178955
1.88352680206298831.71198821067810061.74070286750793461.797276258468628
1.7406389713287354

1.6867659091949463

1.8549299240112305

1.82340407371521
1.7556412220001221.90775179862976071.85787701606750491.9346899986267091.83298301696777341.80817270278930661.7829291820526123






0.9709508419036865
1.1577010154724121.1294558048248291.1022918224334717


1.1431009769439697
1.16512680053710941.16582202911376951.16613817214965821.16635513305664061.16622090339660641.16662597656251.16550612449646






1.0919301509857178
1.22640490531921391.1375460624694824

1.1062650680541992
1.1220669746398926
1.17311477661132811.208305835723877

1.1744911670684814
tensor(1.0773, grad_fn=<NllLossBackward>)
 64/600  [==>...........................] - 629.

In [13]:
if __name__ == '__main__':
    run()

==> Loading data..
Epoch 0
  1/600  [..............................] - 0.5sEpoch 1
  1/600  [..............................] - 0.5sEpoch 2
  1/600  [..............................] - 0.5sEpoch 3
  1/600  [..............................] - 0.5sEpoch 4
  1/600  [..............................] - 0.5sEpoch 5
  1/600  [..............................] - 0.5sEpoch 6
  1/600  [..............................] - 0.5sEpoch 7
  1/600  [..............................] - 0.5sEpoch 8
  1/600  [..............................] - 0.5sEpoch 9
  1/600  [..............................] - 0.5s

In [None]:
'''tmp = torch_mp.get_context('spawn')
            for model in models:
                model.share_memory()
            processes = []

            for i in range(len(models)):
                p = tmp.Process(target=train_slice, 
                    args=(i, models[i], dct_x[i, ...], y, outputs, ops[i]))
                p.start()
                processes.append(p)
            for p in processes: 
                p.join()'''

In [None]:
#plt.imshow(l.squeeze(), interpolation="nearest", cmap=plt.cm.gray)