In [80]:
import torch
import torch.nn.functional as F
from sophia import SophiaG
import torch.nn as nn
from utils import read_data, default_dataset_parameters
from torch.utils.data import TensorDataset, DataLoader


In [141]:
data_name = 'a9a'
dataset_path = './datasets/{}.txt'.format(data_name)


# regularization parameter
lmb = 1e-3

# number of nodes, size of local data, and dimension of the problem
# according to the paper
N = default_dataset_parameters[data_name]['N']# size of the whole data set
n = default_dataset_parameters[data_name]['n']# number of nodes
m = default_dataset_parameters[data_name]['m']# size of local data set
d = default_dataset_parameters[data_name]['d']# dimension of the problem

batch_size = 512

In [142]:
A, b = read_data(dataset_path=dataset_path, 
                 N=N, n=n, m=m, d=d, lmb=lmb,
                labels=['+1', '-1', '0'])

In [143]:
b[b==-1.]

array([-1., -1., -1., ..., -1., -1., -1.])

In [144]:
tensor_x = torch.Tensor(A)
tensor_y = torch.Tensor(b)


my_dataset = TensorDataset(tensor_x,tensor_y) # create your datset
my_dataloader = DataLoader(my_dataset, batch_size=batch_size,shuffle=False) # create your dataloader


In [145]:
for i, (x, y) in zip(range(1), my_dataloader):
    print(x.shape[0])

512


In [146]:
def loss_fn(outputs, labels):
    labels = labels.view(-1)
    return 1/labels.shape[0] * torch.sum(torch.log(1 + torch.exp(-labels * outputs)))

In [147]:


class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.c_fc    = nn.Linear(d, 1, bias=False)
        #self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        #self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        return x

In [148]:
model = MLP()

In [149]:
loss_fn(tensor_y, tensor_y * 5)

tensor(0.0067)

In [150]:

optim = torch.optim.SGD(model.parameters(), lr=1e-5)

In [151]:
#Training 
for epoch in range(10):
    for X, Y in my_dataloader:
        # standard training code
        logits = model(X)
        loss = loss_fn(Y, logits)
        loss.backward()
        optim.step()
        
        print(loss, ((model(X) > 0) == (Y >0)).float().mean())

tensor(0.7501, grad_fn=<MulBackward0>) tensor(0.3233)
tensor(0.7426, grad_fn=<MulBackward0>) tensor(0.3445)
tensor(0.7377, grad_fn=<MulBackward0>) tensor(0.3561)
tensor(0.7454, grad_fn=<MulBackward0>) tensor(0.3321)
tensor(0.7409, grad_fn=<MulBackward0>) tensor(0.3408)
tensor(0.7378, grad_fn=<MulBackward0>) tensor(0.3265)
tensor(0.7345, grad_fn=<MulBackward0>) tensor(0.3425)
tensor(0.7476, grad_fn=<MulBackward0>) tensor(0.3071)
tensor(0.7452, grad_fn=<MulBackward0>) tensor(0.3178)
tensor(0.7430, grad_fn=<MulBackward0>) tensor(0.3312)
tensor(0.7463, grad_fn=<MulBackward0>) tensor(0.3511)
tensor(0.7414, grad_fn=<MulBackward0>) tensor(0.3367)
tensor(0.7467, grad_fn=<MulBackward0>) tensor(0.3346)
tensor(0.7450, grad_fn=<MulBackward0>) tensor(0.3257)
tensor(0.7442, grad_fn=<MulBackward0>) tensor(0.3257)
tensor(0.7420, grad_fn=<MulBackward0>) tensor(0.3165)
tensor(0.7505, grad_fn=<MulBackward0>) tensor(0.3123)
tensor(0.7321, grad_fn=<MulBackward0>) tensor(0.3522)
tensor(0.7530, grad_fn=<MulB

tensor(0.6664, grad_fn=<MulBackward0>) tensor(0.6334)
tensor(0.6698, grad_fn=<MulBackward0>) tensor(0.6321)
tensor(0.6640, grad_fn=<MulBackward0>) tensor(0.6545)
tensor(0.6590, grad_fn=<MulBackward0>) tensor(0.6604)
tensor(0.6666, grad_fn=<MulBackward0>) tensor(0.6363)
tensor(0.6628, grad_fn=<MulBackward0>) tensor(0.6581)
tensor(0.6651, grad_fn=<MulBackward0>) tensor(0.6386)
tensor(0.6548, grad_fn=<MulBackward0>) tensor(0.6821)
tensor(0.6627, grad_fn=<MulBackward0>) tensor(0.6702)
tensor(0.6600, grad_fn=<MulBackward0>) tensor(0.6442)
tensor(0.6682, grad_fn=<MulBackward0>) tensor(0.6418)
tensor(0.6589, grad_fn=<MulBackward0>) tensor(0.6775)
tensor(0.6556, grad_fn=<MulBackward0>) tensor(0.6751)
tensor(0.6642, grad_fn=<MulBackward0>) tensor(0.6548)
tensor(0.6605, grad_fn=<MulBackward0>) tensor(0.6610)
tensor(0.6573, grad_fn=<MulBackward0>) tensor(0.6547)
tensor(0.6539, grad_fn=<MulBackward0>) tensor(0.6665)
tensor(0.6469, grad_fn=<MulBackward0>) tensor(0.6991)
tensor(0.6510, grad_fn=<MulB

tensor(0.5438, grad_fn=<MulBackward0>) tensor(0.7617)
tensor(0.5522, grad_fn=<MulBackward0>) tensor(0.7480)
tensor(0.5353, grad_fn=<MulBackward0>) tensor(0.7754)
tensor(0.5374, grad_fn=<MulBackward0>) tensor(0.7715)
tensor(0.5554, grad_fn=<MulBackward0>) tensor(0.7500)
tensor(0.5445, grad_fn=<MulBackward0>) tensor(0.7656)
tensor(0.5690, grad_fn=<MulBackward0>) tensor(0.7285)
tensor(0.5285, grad_fn=<MulBackward0>) tensor(0.7910)
tensor(0.5329, grad_fn=<MulBackward0>) tensor(0.7715)
tensor(0.5463, grad_fn=<MulBackward0>) tensor(0.7500)
tensor(0.5412, grad_fn=<MulBackward0>) tensor(0.7578)
tensor(0.5397, grad_fn=<MulBackward0>) tensor(0.7617)
tensor(0.5455, grad_fn=<MulBackward0>) tensor(0.7578)
tensor(0.5260, grad_fn=<MulBackward0>) tensor(0.7812)
tensor(0.5266, grad_fn=<MulBackward0>) tensor(0.7773)
tensor(0.5488, grad_fn=<MulBackward0>) tensor(0.7480)
tensor(0.5298, grad_fn=<MulBackward0>) tensor(0.7715)
tensor(0.5623, grad_fn=<MulBackward0>) tensor(0.7383)
tensor(0.5160, grad_fn=<MulB

tensor(0.4907, grad_fn=<MulBackward0>) tensor(0.7832)
tensor(0.5778, grad_fn=<MulBackward0>) tensor(0.7168)
tensor(0.5238, grad_fn=<MulBackward0>) tensor(0.7578)
tensor(0.5402, grad_fn=<MulBackward0>) tensor(0.7480)
tensor(0.5067, grad_fn=<MulBackward0>) tensor(0.7695)
tensor(0.4970, grad_fn=<MulBackward0>) tensor(0.7793)
tensor(0.5073, grad_fn=<MulBackward0>) tensor(0.7637)
tensor(0.5862, grad_fn=<MulBackward0>) tensor(0.7188)
tensor(0.5010, grad_fn=<MulBackward0>) tensor(0.7754)
tensor(0.5154, grad_fn=<MulBackward0>) tensor(0.7637)
tensor(0.5095, grad_fn=<MulBackward0>) tensor(0.7715)
tensor(0.5469, grad_fn=<MulBackward0>) tensor(0.7383)
tensor(0.5801, grad_fn=<MulBackward0>) tensor(0.7129)
tensor(0.5417, grad_fn=<MulBackward0>) tensor(0.7422)
tensor(0.4840, grad_fn=<MulBackward0>) tensor(0.7829)
tensor(0.4937, grad_fn=<MulBackward0>) tensor(0.7793)
tensor(0.5351, grad_fn=<MulBackward0>) tensor(0.7520)
tensor(0.5704, grad_fn=<MulBackward0>) tensor(0.7246)
tensor(0.5258, grad_fn=<MulB

In [None]:
# init the optimizer
optimizer = SophiaG(model.parameters(), lr=2e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)

total_bs = len(data_loader)
bs = total_bs * block_size
k = 10
iter_num = -1

# training loop
for epoch in range(epochs):
    for X, Y in mydata_loader:
        # standard training code
        logits, loss = model(X, Y)
        loss.backward()
        optimizer.step(bs=bs)
        optimizer.zero_grad(set_to_none=True)
        iter_num += 1

        if iter_num % k != k - 1:
            continue
        else:
            # update hessian EMA
            logits, _ = model(X, None)
            samp_dist = torch.distributions.Categorical(logits=logits)
            y_sample = samp_dist.sample()
            loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1), ignore_index=-1)
            loss_sampled.backward()
            optimizer.update_hessian()
            optimizer.zero_grad(set_to_none=True)
            model.zero_grad()