In [1]:
import torch
import numpy as np
from matplotlib import pyplot as plt
from dataset import Dataset
from tqdm.notebook import tqdm

In [2]:
def to_param(x):
    """Wrap array-like data as a trainable float32 ``torch.nn.Parameter``.

    Args:
        x: a numpy array, or anything ``np.asarray`` can convert (nested
            lists, scalars, ...).

    Returns:
        A ``torch.nn.Parameter`` of dtype ``torch.float32`` with
        ``requires_grad=True`` holding a copy of the values in ``x``.
    """
    # np.asarray (rather than x.astype) also accepts plain Python sequences.
    # torch.tensor always copies, so the parameter never aliases the input.
    values = np.asarray(x, dtype=np.float32)
    return torch.nn.Parameter(torch.tensor(values), requires_grad=True)

In [28]:
class MFA(torch.nn.Module):
    """Additive factor model scored with ``logsumexp``.

    Two embedding tables are learned, ``x1`` with ``dims[0]`` rows and ``x2``
    with ``dims[1]`` rows, each ``d`` wide. The prediction for an index pair
    ``(i, j)`` is ``logsumexp(x1[i] + x2[j])`` over the embedding dimension.

    Args:
        d: embedding width.
        scale: entries are drawn uniformly and multiplied by
            ``sqrt(scale / d)``.
        dims: sizes of the two embedding tables.
    """

    def __init__(self, d=32, scale=1.0, dims=(10, 10)):
        super().__init__()
        amplitude = np.sqrt(scale / d)
        # x1 is drawn before x2, fixing the order the numpy RNG is consumed.
        left_init = np.random.uniform(size=(dims[0], d)) * amplitude
        right_init = np.random.uniform(size=(dims[1], d)) * amplitude
        self.x1 = to_param(left_init)
        self.x2 = to_param(right_init)

    def forward(self, idx):
        """Score a batch of index pairs.

        Args:
            idx: 2-D index tensor; column 0 selects rows of ``x1`` and
                column 1 selects rows of ``x2``.

        Returns:
            1-D tensor with one prediction per row of ``idx``.
        """
        left = self.x1[idx[:, 0]]
        right = self.x2[idx[:, 1]]
        return (left + right).logsumexp(dim=1)

In [3]:
class MF(torch.nn.Module):
    """Plain matrix factorization: dot product of two embedding rows.

    Two embedding tables are learned, ``x1`` with ``dims[0]`` rows and ``x2``
    with ``dims[1]`` rows, each ``d`` wide. The prediction for an index pair
    ``(i, j)`` is ``sum(x1[i] * x2[j])``.

    Args:
        d: embedding width.
        scale: entries are drawn uniformly and multiplied by
            ``sqrt(scale / d)``.
        dims: sizes of the two embedding tables.
    """

    def __init__(self, d=32, scale=1.0, dims=(10, 10)):
        super().__init__()
        spread = np.sqrt(scale / d)
        # x1 is drawn before x2, fixing the order the numpy RNG is consumed.
        row_init = np.random.uniform(size=(dims[0], d)) * spread
        col_init = np.random.uniform(size=(dims[1], d)) * spread
        self.x1 = to_param(row_init)
        self.x2 = to_param(col_init)

    def forward(self, idx):
        """Score a batch of index pairs.

        Args:
            idx: 2-D index tensor; column 0 selects rows of ``x1`` and
                column 1 selects rows of ``x2``.

        Returns:
            1-D tensor with one prediction per row of ``idx``.
        """
        rows = self.x1[idx[:, 0]]
        cols = self.x2[idx[:, 1]]
        return (rows * cols).sum(dim=1)


In [4]:
class MFNN(torch.nn.Module):
    """Embedding pair fed through a small sigmoid MLP.

    The embeddings ``x1[i]`` and ``x2[j]`` (each ``d`` wide) are concatenated
    and passed through ``Linear(2d, 64) -> sigmoid -> Linear(64, 32) ->
    sigmoid -> Linear(32, 1)``.

    Args:
        d: embedding width.
        scale: entries are drawn uniformly and multiplied by
            ``sqrt(scale / d)``.
        dims: sizes of the two embedding tables.
    """

    def __init__(self, d=32, scale=1.0, dims=(10, 10)):
        super().__init__()
        spread = np.sqrt(scale / d)
        # x1 is drawn before x2, fixing the order the numpy RNG is consumed.
        first_init = np.random.uniform(size=(dims[0], d)) * spread
        second_init = np.random.uniform(size=(dims[1], d)) * spread
        self.x1 = to_param(first_init)
        self.x2 = to_param(second_init)

        # Layers are constructed in the order fc1, fc2, out.
        self.fc1 = torch.nn.Linear(2 * d, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.out = torch.nn.Linear(32, 1)

    def forward(self, idx):
        """Score a batch of index pairs.

        Args:
            idx: 2-D index tensor; column 0 selects rows of ``x1`` and
                column 1 selects rows of ``x2``.

        Returns:
            1-D tensor with one prediction per row of ``idx``.
        """
        pair = torch.cat((self.x1[idx[:, 0]], self.x2[idx[:, 1]]), dim=1)
        hidden = self.fc1(pair).sigmoid()
        hidden = self.fc2(hidden).sigmoid()
        return self.out(hidden).reshape(-1)


In [5]:
class MFNNPlus(torch.nn.Module):
    """MLP over a learned embedding plus fixed side features.

    Only the second axis has a learned embedding (``x2``, ``dims[1]`` rows of
    width ``d``); the first axis is described by the fixed ``features`` table.
    ``x2[j]`` and ``features[i]`` are concatenated and passed through
    ``Linear(d + n_features, 64) -> sigmoid -> Linear(64, 32) -> sigmoid ->
    Linear(32, 1)``.

    Args:
        features: 2-D array-like or tensor of side features, indexed by
            column 0 of ``idx``. Any numeric dtype is accepted; it is stored
            as float32.
        d: embedding width.
        scale: entries are drawn uniformly and multiplied by
            ``sqrt(scale / d)``.
        dims: only ``dims[1]`` is read (size of ``x2``); ``dims[0]`` is
            unused by this class.
    """

    def __init__(self, features, d=32, scale=1.0, dims=(10, 10)):

        super().__init__()

        x2 = np.random.uniform(size=(dims[1], d)) * np.sqrt(scale / d)
        self.x2 = to_param(x2)

        # Cast to float32: concat would otherwise promote float64 features
        # (e.g. PCA output) and the float32 Linear layer would reject the
        # result. clone() keeps the module from aliasing the caller's array.
        feats = torch.as_tensor(features, dtype=torch.float32).detach().clone()
        # Register as a buffer so the table follows .to(device) / .double();
        # persistent=False keeps state_dict() keys unchanged.
        self.register_buffer("features", feats, persistent=False)

        self.fc1 = torch.nn.Linear(d + feats.shape[1], 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.out = torch.nn.Linear(32, 1)

    def forward(self, idx):
        """Score a batch of index pairs.

        Args:
            idx: 2-D index tensor; column 0 selects rows of ``features`` and
                column 1 selects rows of ``x2``.

        Returns:
            1-D tensor with one prediction per row of ``idx``.
        """
        x = torch.concat(
            [self.x2[idx[:, 1]], self.features[idx[:, 0]]], axis=1)
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return self.out(x).reshape(-1)


In [26]:
class MFNNBoth(torch.nn.Module):
    """MLP over two learned embeddings plus fixed side features.

    ``x1[i]``, ``x2[j]`` (each ``d`` wide) and ``features[i]`` are
    concatenated and passed through ``Linear(2d + n_features, 64) -> sigmoid
    -> Linear(64, 32) -> sigmoid -> Linear(32, 1)``.

    Args:
        features: 2-D array-like or tensor of side features, indexed by
            column 0 of ``idx``. Any numeric dtype is accepted; it is stored
            as float32.
        d: embedding width.
        scale: entries are drawn uniformly and multiplied by
            ``sqrt(scale / d)``.
        dims: sizes of the two embedding tables.
    """

    def __init__(self, features, d=32, scale=1.0, dims=(10, 10)):

        super().__init__()

        x1 = np.random.uniform(size=(dims[0], d)) * np.sqrt(scale / d)
        x2 = np.random.uniform(size=(dims[1], d)) * np.sqrt(scale / d)
        self.x1 = to_param(x1)
        self.x2 = to_param(x2)

        # Cast to float32: concat would otherwise promote float64 features
        # (e.g. PCA output) and the float32 Linear layer would reject the
        # result. clone() keeps the module from aliasing the caller's array.
        feats = torch.as_tensor(features, dtype=torch.float32).detach().clone()
        # Register as a buffer so the table follows .to(device) / .double();
        # persistent=False keeps state_dict() keys unchanged.
        self.register_buffer("features", feats, persistent=False)

        self.fc1 = torch.nn.Linear(d * 2 + feats.shape[1], 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.out = torch.nn.Linear(32, 1)

    def forward(self, idx):
        """Score a batch of index pairs.

        Args:
            idx: 2-D index tensor; column 0 selects rows of ``x1`` and
                ``features``, column 1 selects rows of ``x2``.

        Returns:
            1-D tensor with one prediction per row of ``idx``.
        """
        inputs = [
            self.x1[idx[:, 0]], self.x2[idx[:, 1]], self.features[idx[:, 0]]]
        x = torch.concat(inputs, axis=1)
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return self.out(x).reshape(-1)

In [21]:
def train(ds, model, initlr=0.001, steps=100, batch=64, disable_pbar=False,
          epochs=100):
    """Train ``model`` on ``ds`` with Adam and record per-epoch losses.

    Each epoch runs ``steps`` optimizer steps on batches drawn with
    ``ds.sample(batch, split='train')``, then evaluates once on
    ``ds.splits['val']`` without gradients.

    Args:
        ds: dataset object providing ``sample(batch, split=...)``,
            ``loss(pred, idx=...)``, ``loss(pred, split=...)`` and a
            ``splits`` mapping with a ``'val'`` entry.
        model: ``torch.nn.Module`` mapping an index batch to predictions.
        initlr: initial Adam learning rate.
        steps: optimizer steps per epoch.
        batch: batch size passed to ``ds.sample``.
        disable_pbar: if true, hide the epoch progress bar.
        epochs: number of epochs; defaults to the previously hard-coded 100.

    Returns:
        ``(train_loss, val_loss)``: two numpy arrays of length ``epochs``
        holding the mean training loss and the validation loss per epoch.
    """
    train_loss = []
    val_loss = []

    opt = torch.optim.Adam(model.parameters(), lr=initlr)
    sch = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)

    for _epoch in tqdm(range(epochs), disable=disable_pbar):
        epoch_losses = []
        for _step in range(steps):

            opt.zero_grad()

            idx = ds.sample(batch, split='train')
            # Call the module itself (not .forward) so registered hooks run.
            pred = model(idx)
            loss = ds.loss(pred, idx=idx)

            loss.backward()

            opt.step()

            epoch_losses.append(loss.detach().cpu().numpy())

        train_loss.append(np.mean(epoch_losses))
        # The plateau scheduler is driven by the mean training loss.
        sch.step(train_loss[-1])

        with torch.no_grad():
            pred = model(ds.splits['val'])
            loss = ds.loss(pred, split='val')
            val_loss.append(loss.detach().cpu().numpy())

    return np.array(train_loss), np.array(val_loss)

In [22]:
def evaluate(data, model_constructor, sparsity=0.5, repeat=10):
    """Train freshly built models on repeated splits and collect loss curves.

    For each of ``repeat`` rounds a new ``Dataset`` is created from ``data``
    with ``val=1 - sparsity``, a new model is built with
    ``model_constructor()``, and ``train`` is run on the CPU with 100 steps
    per epoch and batch size 64.

    Args:
        data: forwarded to ``Dataset(data=...)``.
        model_constructor: zero-argument callable returning a new model.
        sparsity: ``1 - sparsity`` is passed to ``Dataset`` as ``val``.
        repeat: number of independent rounds.

    Returns:
        ``(train, val)``: two numpy arrays stacking the per-round training
        and validation loss curves along axis 0.
    """
    print("Sparsity={}".format(sparsity))
    cpu = torch.device('cpu')
    train_curves, val_curves = [], []
    for _ in tqdm(range(repeat)):
        split = Dataset(data=data, val=1 - sparsity, device=cpu)
        net = model_constructor().to(cpu)
        tr_curve, va_curve = train(
            split, net, batch=64, steps=100, disable_pbar=True)
        train_curves.append(tr_curve)
        val_curves.append(va_curve)
    return np.array(train_curves), np.array(val_curves)

In [29]:
from sklearn.decomposition import PCA

# Reference dataset; its rms, matrix shape and opcodes configure the model
# factories defined below.
ds = Dataset(data="../data/polybench/20-40.npz")

# Fit a full PCA on the opcode features, then keep only the first 8 columns
# of the transformed data.
_opcodes_projected = PCA().fit_transform(ds.opcodes)
opcodes_pca = _opcodes_projected[:, :8]

def _mf():
    """Build an ``MF`` model with ``d=1``, sized and scaled from ``ds``."""
    settings = dict(d=1, scale=ds.rms, dims=ds.matrix.shape)
    return MF(**settings)
def _mfa():
    """Build an ``MFA`` model with ``d=1``, sized and scaled from ``ds``."""
    settings = dict(d=1, scale=ds.rms, dims=ds.matrix.shape)
    return MFA(**settings)
def _mfnn():
    """Build an ``MFNN`` model with ``d=8``, sized and scaled from ``ds``."""
    settings = dict(d=8, scale=ds.rms, dims=ds.matrix.shape)
    return MFNN(**settings)
def _mfnnp():
    """Build an ``MFNNPlus`` model (``d=8``) using ``ds.opcodes`` as features."""
    settings = dict(
        features=ds.opcodes, d=8, scale=ds.rms, dims=ds.matrix.shape)
    return MFNNPlus(**settings)
def _mfnnpca():
    """Build an ``MFNNPlus`` model (``d=8``) using ``opcodes_pca`` as features."""
    settings = dict(
        features=opcodes_pca, d=8, scale=ds.rms, dims=ds.matrix.shape)
    return MFNNPlus(**settings)
def _mfnnboth():
    """Build an ``MFNNBoth`` model (``d=8``) using ``opcodes_pca`` as features."""
    settings = dict(
        features=opcodes_pca, d=8, scale=ds.rms, dims=ds.matrix.shape)
    return MFNNBoth(**settings)

In [30]:
# Axes of the saved `train` / `val` arrays:
#   [0] - sparsity (7)
#   [1] - replicate (10)
#   [2] - checkpoint (100)

# Uncomment an entry to include that method in the sweep.
methods = {
    # "mf": _mf,
    "mfa": _mfa,
    # "mfnn": _mfnn,
    # "mfnn_opcodes": _mfnnp,
    # "mfnn_pca": _mfnnpca,
    # "mfnn_both": _mfnnboth,
}
for method, func in methods.items():
    print(method)
    res = []
    for level in (0.95, 0.9, 0.75, 0.5, 0.25, 0.1, 0.05):
        res.append(evaluate(ds.data, func, sparsity=level, repeat=10))
    # Split the (train, val) pairs into one stacked array per kind.
    train_loss = np.array([pair[0] for pair in res])
    val_loss = np.array([pair[1] for pair in res])

    np.savez("{}.npz".format(method), train=train_loss, val=val_loss)

mfa
Sparsity=0.95


  0%|          | 0/10 [00:00<?, ?it/s]

Sparsity=0.9


  0%|          | 0/10 [00:00<?, ?it/s]

Sparsity=0.75


  0%|          | 0/10 [00:00<?, ?it/s]

Sparsity=0.5


  0%|          | 0/10 [00:00<?, ?it/s]

Sparsity=0.25


  0%|          | 0/10 [00:00<?, ?it/s]

Sparsity=0.1


  0%|          | 0/10 [00:00<?, ?it/s]

Sparsity=0.05


  0%|          | 0/10 [00:00<?, ?it/s]