In [None]:
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_dct as dct
from torch.optim.lr_scheduler import StepLR

In [None]:
def bcirc(A):
    """Build the block-circulant matrix of a third-order tensor.

    The frontal slices ``A[0], ..., A[l-1]`` (each ``m x n``) are arranged so
    that block ``(r, c)`` of the result equals ``A[(r - c) % l]``.

    Args:
        A: tensor of shape ``(l, m, n)``.

    Returns:
        Tensor of shape ``(l * m, l * n)``.
    """
    num_slices, rows, cols = A.shape
    # Block column k holds the slices cyclically shifted down by k positions.
    shifted = [torch.roll(A, shifts=k, dims=0) for k in range(num_slices)]
    # Placing the shifted copies side by side (last axis) gives every block row;
    # flattening the leading axis then stacks those block rows vertically.
    side_by_side = torch.cat(shifted, dim=2)
    return side_by_side.reshape(num_slices * rows, num_slices * cols)

def t_product(A, B):
    """Tensor-tensor (t-)product of two third-order tensors.

    Computes ``fold(bcirc(A) @ unfold(B))`` with frontal slices indexed by the
    first axis, i.e. ``C[i] = sum_k A[(i - k) % l] @ B[k]``.

    Args:
        A: tensor of shape ``(l, m, p)``.
        B: tensor of shape ``(l, p, n)``.

    Returns:
        Tensor of shape ``(l, m, n)``.

    Raises:
        AssertionError: if the leading axes or the inner sizes do not match.
    """
    assert A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1], (
        f"incompatible shapes for t-product: {tuple(A.shape)} and {tuple(B.shape)}"
    )
    num_slices, inner, out_cols = B.shape
    # unfold(B) -- the first block column of bcirc(B) -- is simply the frontal
    # slices stacked vertically. Reshaping avoids materialising the full
    # (l*p, l*n) block-circulant matrix only to discard all but n columns.
    unfolded_B = B.reshape(num_slices * inner, out_cols)
    prod = torch.mm(bcirc(A), unfolded_B)
    return prod.reshape(A.shape[0], A.shape[1], out_cols)

def h_func_dct(lateral_slice):
    """Softmax over the ``m`` axis in the DCT domain, then sum over the first axis.

    Steps: ``dct.dct`` -> softmax over ``m`` separately for every index along
    the first axis -> ``dct.idct`` -> sum over the first axis.

    Args:
        lateral_slice: tensor of shape ``(l, m, n)``. The reshape below only
            succeeds when ``n == 1`` (``scalar_tubal_func`` passes ``(l, m, 1)``).

    Returns:
        Tensor of shape ``(m, n)``.
    """
    l, m, n = lateral_slice.shape
    # NOTE(review): torch_dct transforms along the *last* axis, which here has
    # length n == 1, so the DCT/IDCT pair cannot mix values along the tube
    # axis (dim 0). Confirm that this is the intended transform axis.
    dct_slice = dct.dct(lateral_slice)
    # torch.softmax subtracts the maximum before exponentiating, so large
    # activations no longer overflow to inf/NaN the way exp(x) / sum(exp(x))
    # did. dim=1 normalises over m independently for every index along dim 0,
    # exactly like the former per-tube loop.
    h_tubes = torch.softmax(dct_slice[:, :, 0], dim=1)
    res_slice = h_tubes.reshape(l, m, n)
    idct_a = dct.idct(res_slice)
    return torch.sum(idct_a, dim=0)

def scalar_tubal_func(output_tensor):
    """Apply ``h_func_dct`` to every lateral slice and collect the results as columns.

    Lateral slice ``j`` (``output_tensor[:, :, j]``) is reshaped to
    ``(l, m, 1)`` before being transformed. In the training/test loops the
    last axis ``n`` is the batch axis.

    Args:
        output_tensor: tensor of shape ``(l, m, n)``.

    Returns:
        Tensor of shape ``(m, n)``.
    """
    l, m, n = output_tensor.shape
    columns = [
        h_func_dct(output_tensor[:, :, j].reshape(l, m, 1))
        for j in range(n)
    ]
    stacked = torch.stack(columns, dim=2)
    return stacked.reshape(m, n)

def raw_img(img, batch_size, n):
    """Rearrange a batch of ``n x n`` images into an ``(n, n, batch_size)`` tensor.

    The result satisfies ``out[i, j, b] == image_b[j, i]``: each image is
    transposed and stored as lateral slice ``b``.

    Args:
        img: tensor holding ``batch_size * n * n`` values, image-major
            (e.g. ``(batch_size, 1, n, n)``).
        batch_size: number of images in ``img``.
        n: side length of each image.

    Returns:
        A new contiguous tensor of shape ``(n, n, batch_size)``.
    """
    images = img.reshape(batch_size, n, n)
    # Move the batch axis last and swap rows/columns in a single permute;
    # clone() materialises a fresh contiguous tensor, as torch.cat would.
    return images.permute(2, 1, 0).clone(memory_format=torch.contiguous_format)

In [None]:
def train_step_transform(epoch, train_acc, model, trainloader, optimizer, output_scale=5e2):
    """Train ``model`` for one epoch and append the epoch's training accuracy.

    Args:
        epoch: epoch index, used only for logging.
        train_acc: list of per-epoch accuracies; appended to and returned.
        model: network applied to the ``(28, 28, batch)`` tensor built by
            ``raw_img``; its output goes through ``scalar_tubal_func``.
        trainloader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        output_scale: divisor applied to the raw model output before the
            tubal softmax. Evaluation should use the same value.

    Returns:
        ``train_acc`` with this epoch's accuracy appended (0.0 if no batch
        completed).
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()

    print('\nEpoch: ', epoch)
    print('|', end='')
    for batch_idx, (inputs, labels) in enumerate(trainloader):
        # NOTE(review): Transform_Layer biases are sized for a fixed batch,
        # so a smaller final batch will not broadcast -- presumably the
        # loader yields only full batches; confirm.
        inputs = raw_img(inputs, inputs.size(0), 28)
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs) / output_scale
        # scalar_tubal_func returns (classes, batch); CrossEntropyLoss wants
        # (batch, classes).
        outputs = torch.transpose(scalar_tubal_func(outputs), 0, 1)

        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        # Check on the tensor itself instead of relying on a numpy import.
        if torch.isnan(loss).item():
            print('Training terminated due to instability')
            break
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        # The criterion returns the batch *mean*; weight it by the batch size
        # so that dividing by `total` below gives the per-sample mean loss.
        train_loss += loss.item() * batch_size
        _, predicted = torch.max(outputs, 1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()
        if batch_idx % 10 == 0:
            print('=', end='')
    # Guard against an empty loader or a NaN on the very first batch.
    accuracy = correct / total if total else 0.0
    mean_loss = train_loss / total if total else float('nan')
    print('|', 'Accuracy:', accuracy, 'Loss:', mean_loss)
    train_acc.append(accuracy)
    return train_acc

def test(test_acc, model, testloader, output_scale=5e2):
    """Evaluate ``model`` on ``testloader`` and append the test accuracy.

    Args:
        test_acc: list of per-epoch test accuracies; appended to and returned.
        model: network applied to the ``(28, 28, batch)`` tensor built by
            ``raw_img``; its output goes through ``scalar_tubal_func``.
        testloader: iterable of ``(inputs, targets)`` batches.
        output_scale: divisor applied to the raw model output before the
            tubal softmax. The default matches the 5e2 used during training.
            The tubal softmax is nonlinear, so a different divisor would
            evaluate a different function than the one that was trained.

    Returns:
        ``(test_acc, elapsed_seconds)`` -- the updated list and the wall-clock
        time spent in the evaluation loop.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()
    start = time.perf_counter()
    with torch.no_grad():
        print('|', end='')
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs = raw_img(inputs, inputs.size(0), 28)
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs) / output_scale
            # (classes, batch) -> (batch, classes) for CrossEntropyLoss.
            outputs = torch.transpose(scalar_tubal_func(outputs), 0, 1)
            loss = criterion(outputs, targets)
            batch_size = targets.size(0)
            # Weight the batch-mean loss so the final division by `total`
            # yields the per-sample mean.
            test_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()
            if batch_idx % 10 == 0:
                print('=', end='')
    elapsed = time.perf_counter() - start
    # Guard against an empty loader.
    accuracy = correct / total if total else 0.0
    mean_loss = test_loss / total if total else float('nan')
    print('|', ' Test accuracy:', accuracy, 'Test loss:', mean_loss)
    print('The inference time is', elapsed, 'seconds')
    test_acc.append(accuracy)
    return test_acc, elapsed
    
def train_transform(i, model, trainloader, testloader, optimizer, step_size=5, gamma=0.5):
    """Train for ``i`` epochs, evaluating on the test set after each epoch.

    The learning rate follows a ``StepLR`` schedule: it is multiplied by
    ``gamma`` every ``step_size`` epochs.

    Args:
        i: number of epochs to run.
        model: network to train.
        trainloader: training batches.
        testloader: evaluation batches.
        optimizer: optimizer over ``model``'s parameters.
        step_size: StepLR period in epochs.
        gamma: StepLR decay factor.

    Returns:
        ``(train_acc, test_acc)`` lists with one entry per epoch.
    """
    train_acc, test_acc = [], []
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

    for epoch in range(i):
        start = time.perf_counter()
        train_acc = train_step_transform(epoch, train_acc, model, trainloader, optimizer)
        test_acc, _ = test(test_acc, model, testloader)
        scheduler.step()
        elapsed = time.perf_counter() - start
        print('This epoch took', elapsed, 'seconds to train')
        print('Current learning rate: ', scheduler.get_last_lr()[0])
    # The list holds *test* accuracies; skip the summary when no epoch ran
    # (max() of an empty list would raise).
    if test_acc:
        print('Best test accuracy overall: ', max(test_acc))
    return train_acc, test_acc

In [None]:
class Transform_Layer(nn.Module):
    """Affine layer built on the t-product: ``t_product(weights, x) + bias``.

    Args:
        n: number of frontal slices (first axis) of the weight tensor.
        size_in: input feature size (second axis of ``x``).
        m: last-axis size the bias is created for.
        size_out: output feature size.

    Shapes:
        weights ``(n, size_out, size_in)``; bias ``(1, size_out, m)``, added by
        broadcasting and therefore shared by all ``n`` frontal slices.
        Input ``x`` is ``(n, size_in, k)``; ``t_product`` yields
        ``(n, size_out, k)`` before the bias is added.
    """

    def __init__(self, n, size_in, m, size_out):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        # Weights are drawn before the bias so RNG consumption order is fixed.
        self.weights = nn.Parameter(torch.randn(n, size_out, size_in), requires_grad=True)
        self.bias = nn.Parameter(torch.randn(1, size_out, m), requires_grad=True)

    def forward(self, x):
        return t_product(self.weights, x) + self.bias

In [None]:
class Transform_Net(nn.Module):
    """Three stacked t-product layers (28 -> 28 -> 28 -> 10 features) with ReLUs.

    Args:
        batch_size: last-axis size every ``Transform_Layer`` creates its bias
            for, so inputs are expected to carry exactly this many samples.
    """

    def __init__(self, batch_size):
        super(Transform_Net, self).__init__()
        self.features = nn.Sequential(
            Transform_Layer(28, 28, batch_size, 28),
            nn.ReLU(inplace=True),
            Transform_Layer(28, 28, batch_size, 28),
            nn.ReLU(inplace=True),
            Transform_Layer(28, 28, batch_size, 10),
        )

    def forward(self, x):
        """Run the layer stack on ``x`` (presumably ``(28, 28, batch_size)`` from ``raw_img``)."""
        # The input is not marked requires_grad: only the parameters need
        # gradients, and flipping the flag mutated the caller's tensor and
        # made autograd track gradients with respect to the input.
        return self.features(x)

In [None]:
# Seed the RNG so weight initialisation (torch.randn in Transform_Layer) is
# reproducible across Restart & Run All.
torch.manual_seed(0)
model = Transform_Net(100)
print(model)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# NOTE(review): load_mnist is not defined in this part of the notebook. The
# model's biases are sized for batches of exactly 100, so presumably it
# returns loaders with batch_size=100 and only full batches -- confirm.
trainloader, testloader = load_mnist()

In [None]:
# Train the pure t-product network for 25 epochs, with one test pass per epoch.
train_acc, test_acc = train_transform(
    25,
    model,
    trainloader,
    testloader,
    optimizer,
)

In [None]:
class T_Net(nn.Module):
    """Hybrid network: t-product layer -> three 2-D convolutions -> t-product layer.

    Args:
        batch_size: last-axis size both ``Transform_Layer`` biases are created
            for, so inputs are expected to carry exactly this many samples.
    """

    def __init__(self, batch_size):
        super(T_Net, self).__init__()
        self.first = nn.Sequential(
            Transform_Layer(28, 28, batch_size, 28),
            nn.ReLU(inplace=True),
        )
        # NOTE(review): there is no activation between these convolutions, so
        # together they compose to a single affine map -- confirm intended.
        self.intermediate = nn.Sequential(
            nn.Conv2d(28, 28, kernel_size=3, padding=1),
            nn.Conv2d(28, 28, kernel_size=3, padding=1),
            nn.Conv2d(28, 28, kernel_size=1, padding=0),
        )
        self.last = Transform_Layer(28, 28, batch_size, 10)

    def forward(self, x):
        """Map an ``(28, 28, batch)`` tensor to ``(28, 10, batch)`` outputs."""
        x = self.first(x)  # (slices, features, batch)
        slices, features, batch = x.shape
        # (slices, features, batch) -> (batch, slices, features), then lay the
        # 28 features out as a 4x7 grid so each slice becomes a conv channel.
        # The batch size is read from the tensor instead of being hard-coded.
        x = x.permute(2, 0, 1).reshape(batch, slices, 4, 7)
        x = self.intermediate(x)
        # Undo the layout: (batch, slices, 4, 7) -> (slices, features, batch).
        x = x.reshape(batch, slices, features).permute(1, 2, 0)
        return self.last(x)

In [None]:
# Hybrid model and its optimiser (half the learning rate of the pure t-product run).
model = T_Net(100)
print(model)
optimizer = optim.SGD(
    model.parameters(),
    lr=0.05,
    momentum=0.9,
    weight_decay=5e-4,
)

In [None]:
# A single training epoch for the hybrid model, followed by one test pass.
train_acc, test_acc = train_transform(
    1,
    model,
    trainloader,
    testloader,
    optimizer,
)

In [None]:
# Shape check on random input. A separate instance is used so that the
# trained `model` from the cells above is not silently overwritten.
inputs = torch.randn(28, 28, 100)

probe_model = T_Net(100)
with torch.no_grad():
    probe_output = probe_model(inputs)
assert probe_output.shape == (28, 10, 100), probe_output.shape
probe_output.shape

In [None]:
# Step through T_Net's forward pass by hand, checking intermediate shapes.
# The batch size is taken from `inputs` instead of repeating the literal 100.
batch = inputs.shape[-1]

trans_layer = Transform_Layer(28, 28, batch, 28)
x = trans_layer(inputs)
assert x.shape == (28, 28, batch), x.shape  # (slices, features, batch)

x = torch.transpose(x, 0, 2)
x = torch.transpose(x, 1, 2)   # (batch, slices, features)
x = x.reshape(batch, 28, 4, 7)  # 28 features laid out as a 4x7 grid

conv1 = nn.Conv2d(28, 28, kernel_size=3, padding=1)
conv2 = nn.Conv2d(28, 28, kernel_size=3, padding=1)
conv3 = nn.Conv2d(28, 28, kernel_size=1, padding=0)
x = conv1(x)
x = conv2(x)
x = conv3(x)
assert x.shape == (batch, 28, 4, 7), x.shape
x = x.reshape(batch, 28, 28)
x = torch.transpose(x, 0, 2)
x = torch.transpose(x, 0, 1)   # back to (slices, features, batch)

last_layer = Transform_Layer(28, 28, batch, 10)
x = last_layer(x)
assert x.shape == (28, 10, batch), x.shape
x.shape