In [None]:
import sys, os
# Make the two sibling checkouts importable (each is appended at most once,
# so re-running this cell does not grow sys.path).
for _repo_rel in ('../../musco-pytorch-private', '../../maxvol_compression_pytorch'):
    _repo_abs = os.path.abspath(_repo_rel)
    if _repo_abs in sys.path:
        continue
    sys.path.append(_repo_abs)

In [None]:
from functools import partial
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from maxvolpy.maxvol import rect_maxvol

In [None]:
from maxvol_compression.sketch_matrix import RandomSums
from maxvol_compression.vmbf import EVBMF, weaken_rank
from maxvol_compression.layers import LinearMaxvol
from utils.dummy import DummyDatasetCifar10, DummyModelCifar10
from musco.pytorch.compressor.layers.conv1d_toeplitz import Conv1Dtoeplitz

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'

In [None]:
# Samples per test batch; also used as the leading dim of the Conv1Dtoeplitz input shape.
BATCH_SIZE: int = 10

In [None]:
# Build the baseline model and restore its pretrained weights.
# map_location remaps any CUDA tensors stored in the checkpoint onto `device`, so a
# checkpoint saved on a GPU still loads on a CPU-only machine (and vice versa).
model = DummyModelCifar10().to(device)
model.load_state_dict(torch.load('data/dummy.weights', map_location=device))
cifar10 = DummyDatasetCifar10(batch_size=BATCH_SIZE, data_root='data')

In [None]:
# Switch to inference mode (equivalent to model.eval(); returns the model for display).
model.train(False)

In [None]:
%%time
correct = 0
all_ = len(cifar10.testloader) * BATCH_SIZE
with torch.no_grad():
    for i, data in tqdm(enumerate(cifar10.testloader, 0)):
        inputs, labels = data
        inputs = inputs.to(device)
        _, predicted = torch.max(model(inputs).cpu(), 1)
        correct += (labels == predicted).sum().detach().numpy()
        
print(f'accuracy: {correct / all_}')

In [None]:
# Restore the precomputed conv1 sketch matrix from disk into a RandomSums sketcher.
sketch_path = 'data/conv1_sketch_matrix.npy'
rs = RandomSums(500, 13875, keep_original=True)
rs.load(sketch_path)

In [None]:
# Thin SVD of the sketch; the factors are handed to EVBMF via `pretrained_svd`
# (presumably so it can skip its own decomposition — confirm in vmbf.EVBMF).
_, sigma, Vt = np.linalg.svd(rs.sketch_matrix, full_matrices=False)
V = Vt.T
_, vbmf_s, _, vbmf_post = EVBMF(None, pretrained_svd=(None, sigma, Vt))

In [None]:
# Compare the full right-singular basis with the EVBMF-retained spectrum.
(Vt.shape, vbmf_s.shape)

In [None]:
# Upper bound = full rank of Vt; extreme rank = number of EVBMF-retained values.
# Keyword names must match weaken_rank's signature exactly (incl. `weakenen_factor`).
rank = weaken_rank(
    rank=min(Vt.shape),
    extreme_rank=len(vbmf_s),
    weakenen_factor=1.0,
)
rank

In [None]:
# Alternative index selection (disabled; the next cell uses a uniform stride instead):
# pick rows of V with rectangular maxvol, up to 1.7x the smaller dimension of V.
# idxs, _ = rect_maxvol(V, maxK=int(1.7*min(*V.shape)))
# len(idxs)

In [None]:
# Keep every `stride`-th index in [0, 13875) — a uniform subsample instead of maxvol.
stride = 10
idxs = np.arange(0, 13875, stride)
len(idxs)

In [None]:
from torch.nn.parameter import Parameter
from torch.nn import Module
import torch.nn.functional as F

class LinearMaxvol_v2(Module):
    """Linear layer that computes only a subset of its output features.

    Keeps rows ``idxs`` of ``linear.weight`` (and of ``linear.bias`` when present),
    computes just those outputs, then places them back at positions ``idxs`` of a
    tensor of width ``out_features``. Every non-selected output is zero (unlike a
    V-based reconstruction ``V @ pinv(V[idxs, :]) @ x``).

    Args:
        linear: source layer; its ``in_features``, ``out_features``, ``weight`` and
            ``bias`` are read.
        idxs: 1-D integer indices of the output features to keep.
        device: device on which the zero-padding matrix is created.
    """

    def __init__(self, linear, idxs, device):
        super().__init__()
        self.in_features = linear.in_features
        self.out_features = linear.out_features

        self.idxs = idxs
        with torch.no_grad():
            self.weight = Parameter(linear.weight[idxs].detach())
            if linear.bias is not None:
                self.bias = Parameter(linear.bias[idxs].detach())
            else:
                self.register_parameter('bias', None)
        # Selected columns of the identity: column j is the one-hot vector for idxs[j].
        # Registered as a buffer so module.to()/.cuda()/.double() move and cast it
        # together with the parameters; non-persistent so state_dict keys are unchanged.
        pad = torch.eye(self.out_features, dtype=self.weight.dtype, device=device)[:, idxs]
        self.register_buffer('pad', pad, persistent=False)

    def forward(self, input):
        x = F.linear(input, self.weight, self.bias)
        # Same as (pad @ x.T).T for 2-D x, but also correct for extra leading dims
        # (Tensor.T reverses *all* dimensions of a >2-D tensor).
        return F.linear(x, self.pad)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}, idxs_len={}'.format(
            self.in_features, self.out_features, self.bias is not None, len(self.idxs)
        )

In [None]:
# Wrap model.conv1 in Conv1Dtoeplitz (input shape (BATCH_SIZE, 3, 1024)), then swap
# its dense_layer for a LinearMaxvol built from the chosen idxs and V.
# Alternative for comparison: LinearMaxvol_v2(conv1d_tplz.dense_layer, idxs, device=device)
conv1d_tplz = Conv1Dtoeplitz(model.conv1, (BATCH_SIZE, 3, 1024))
reduced_dense = LinearMaxvol(conv1d_tplz.dense_layer, idxs, V, device=device)
conv1d_tplz.dense_layer = reduced_dense
model.conv1 = conv1d_tplz

In [None]:
%%time
correct = 0
all_ = len(cifar10.testloader) * BATCH_SIZE
with torch.no_grad():
    for i, data in tqdm(enumerate(cifar10.testloader, 0)):
        inputs, labels = data
        inputs = inputs.to(device)
        _, predicted = torch.max(model(inputs).cpu(), 1)
        correct += (labels == predicted).sum().detach().numpy()
        
print(f'accuracy: {correct / all_}')

In [None]:
# Number of output indices actually kept by the compressed layer.
n_kept = len(idxs)
n_kept

In [None]:
# Accuracy vs. number of retained conv1 indices. The values are hard-coded —
# presumably copied from earlier runs of the evaluation cell; re-record if they change.
n_indexes = [850, 1388, 3469, 4625, 6937]
acc_without_vvt = [0.2839, 0.4527, 0.53, 0.6, 0.63]
acc_with_vvt = [0.62, 0.62, 0.6331, 0.64, 0.63]

fig, ax = plt.subplots(figsize=(6, 6))
ax.set_title('Conv1d layer decomposition')
for series, fmt, series_label in (
    (acc_without_vvt, 'bs', 'regular without VVt'),
    (acc_with_vvt, 'rs', 'regular with VVt'),
):
    ax.plot(n_indexes, series, fmt, label=series_label)
ax.legend(loc='lower right')
ax.set_xlabel('remained indexes')
ax.set_ylabel('accuracy on 10 classes')
fig.savefig('conv1d_reducedorder.jpg')
plt.show()