In [None]:
import os
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data as Data
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import torchmetrics
import numpy as np
import gift_64 as gift

In [None]:
class LSTM(pl.LightningModule):
    """LSTM followed by a single sigmoid unit, trained as a binary classifier.

    ``forward`` flattens every sample and feeds it to the LSTM as a sequence of
    length 1, so ``n_features`` must equal the flattened width of one sample.

    Args:
        n_features: ``input_size`` of the LSTM.
        hidden_size: hidden size of each LSTM layer (per direction).
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability passed to ``nn.LSTM``.
        bidirectional: whether the LSTM runs in both directions.
        learning_rate: learning rate handed to Adam in ``configure_optimizers``.
        criterion: callable ``criterion(prediction, target)`` returning the loss.
    """

    def __init__(self,
                 n_features,
                 hidden_size,
                 num_layers,
                 dropout,
                 bidirectional,
                 learning_rate,
                 criterion):
        super().__init__()
        self.n_features = n_features
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.ce = criterion
        self.learning_rate = learning_rate

        self.lstm = nn.LSTM(input_size=n_features,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            bidirectional=bidirectional,
                            batch_first=True)
        # A bidirectional LSTM concatenates both directions, doubling the
        # feature width that reaches the output head.
        head_in = hidden_size * 2 if bidirectional else hidden_size
        self.linear = nn.Sequential(
            nn.Linear(head_in, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return sigmoid scores of shape ``(batch, 1)`` as float32."""
        # Flatten each sample and present it as a one-step sequence.
        x = x.view(len(x), 1, -1)
        lstm_out, _ = self.lstm(x.float())
        # Use the LSTM output at the last (here: only) time step.
        y_pred = self.linear(lstm_out[:, -1])
        return y_pred.float()

    def configure_optimizers(self):
        """Adam over all parameters with ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def _loss_and_accuracy(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``.

        The accuracy thresholds the sigmoid score at 0.5. Casting the raw score
        with ``.long()`` would truncate every value below 1.0 to class 0 and
        make the metric meaningless.
        """
        x, y = batch
        y_hat = self(x)
        # Align the label shape with the (batch, 1) prediction so the loss
        # never broadcasts (batch, 1) against (batch,).
        target = y.reshape(y_hat.shape).float()
        loss = self.ce(y_hat, target)
        predicted = (y_hat.detach() >= 0.5).long()
        acc = (predicted == target.long()).float().mean()
        return loss, acc

    def training_step(self, train_batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss."""
        loss, acc = self._loss_and_accuracy(train_batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` and return the loss."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log('val_acc', acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` / ``test_acc`` for one test batch."""
        loss, acc = self._loss_and_accuracy(batch)
        self.log('test_loss', loss)
        self.log('test_acc', acc)


In [None]:
%%time
start_round=1
cipher='GIFT_64'
num_rounds=4
data_train= 2**25 #2**25
data_val= 2**22 #2**22
difference=(0x0044,0x0000,0x0011,0x0000)
pre_trained_model='fresh'
if (cipher == "GIFT_64"):
    wdir = './gift_64_nets/'
    # print(difference)
    if not os.path.exists(wdir):
        os.makedirs(wdir)
    

In [None]:
# %%time
# Disabled: builds X/Y and X_eval/Y_eval by calling gift.make_train_data
# directly; the next cell loads these arrays from .npy files instead.
# X, Y = gift.make_train_data(data_train,
#                             num_rounds,
#                             diff=difference,
#                             r_start=start_round)
# X_eval, Y_eval = gift.make_train_data(data_val,
#                                         num_rounds,
#                                         diff=difference,
#                                         r_start=start_round)

In [None]:
%%time
import numpy
ver = '2'
print(os.getcwd())
X = numpy.load(ver + '_X.npy')
Y = numpy.load(ver + '_Y.npy')
X_eval = numpy.load(ver + '_Xv.npy')
Y_eval = numpy.load(ver + '_Yv.npy')

X = numpy.reshape(X, (X.shape[0], -1))
X_eval = numpy.reshape(X_eval, (X_eval.shape[0], -1))

# print(Y.shape)
Y = numpy.reshape(Y, (Y.shape[0], 1))
# print(Y.shape)
Y_eval = numpy.reshape(Y_eval, (Y_eval.shape[0], 1))


In [None]:
%%time
# print(type(X))
print(X.shape, X.dtype)
print(Y)

p = dict(
    criterion = nn.MSELoss(),
#     max_epochs = 10,
    n_features = 256*1,
    hidden_size = 128*8, # 128,
    num_layers = 2,
    dropout = 0.2,
    learning_rate = 0.001,
    bidirectional = True,
    # bidirectional = False,
)

net = LSTM(
    n_features = p['n_features'],
    hidden_size = p['hidden_size'],
    criterion = p['criterion'],
    num_layers = p['num_layers'],
    dropout = p['dropout'],
    bidirectional = p['bidirectional'],
    learning_rate = p['learning_rate']
)

# net = PU(prior1, prior2)

train_loader = Data.TensorDataset(*(torch.tensor(X.astype('float32')), torch.tensor(Y.astype('float32'))))
#     train_loader = Data.TensorDataset(torch.Tensor(X), torch.Tensor(Y))
val_loader = Data.TensorDataset(*(torch.tensor(X_eval.astype('float32')), torch.tensor(Y_eval.astype('float32'))))
#     test_loader = Data.TensorDataset(torch.Tensor(X_eval), torch.Tensor(Y_eval))
train_loader = DataLoader(train_loader, num_workers=2, batch_size=2**10, pin_memory=True) #, shuffle=True)
val_loader = DataLoader(val_loader, num_workers=1, batch_size=2**7, pin_memory=True) #, shuffle=True)
del(X, Y, X_eval, Y_eval)

In [None]:
from pl_bolts.callbacks import PrintTableMetricsCallback
from pytorch_lightning.utilities.model_summary import ModelSummary
# NOTE(review): `callback` is not part of the `callbacks=[...]` list given to
# pl.Trainer later in this notebook, so it appears to be unused — confirm.
callback = PrintTableMetricsCallback()
# Last expression of the cell: displays the ModelSummary of `net`, built with max_depth=10.
ModelSummary(net, max_depth=10)

In [None]:
# Shape of the first tensor (the inputs) stored in the training dataset.
train_loader.dataset.tensors[0].shape

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

pl.seed_everything(42, workers=True)

# Progress bar that refreshes every 64 batches.
bar = pl.callbacks.progress.TQDMProgressBar(refresh_rate=64)

# Keep the single best checkpoint by validation accuracy, plus the latest one.
checkpoint_callback = ModelCheckpoint(
    filename='gift64-{epoch:02d}-{val_acc:.3f}',
    monitor='val_acc',
    mode='max',
    save_top_k=1,
    save_last=True,
)

# NOTE(review): this callback is constructed but is not in the Trainer's
# `callbacks` list below, so early stopping is not active — confirm intent.
early_stop_callback = EarlyStopping(
    monitor="val_acc",
    mode="max",
    min_delta=0.0,
    patience=4,
    verbose=False,
)

trainer = pl.Trainer(
    max_epochs=201,
    gpus=[0],
    precision=16,
    deterministic=True,
    auto_lr_find=True,
    callbacks=[bar, checkpoint_callback],
)

# Learning-rate sweep of 233 steps with an upper bound of 0.1.
lr_finder = trainer.tuner.lr_find(
    net,
    train_loader,
    val_loader,
    max_lr=0.1,
    num_training=233,
)


In [None]:
import matplotlib
# The raw sweep data is available via lr_finder.results:
# print(lr_finder.results)

# Plot the sweep; suggest=True asks the plot to include the suggested point.
fig = lr_finder.plot(suggest=True)
fig.show()

# Either pick a learning rate by eye from the plot or take the finder's suggestion.
# NOTE(review): suggestion() may return None when no point can be chosen — confirm
# against the installed Lightning version before using new_lr blindly.
new_lr = lr_finder.suggestion()
print(new_lr)

In [None]:
# The LR finder yields None when it cannot pick a point; assigning that would
# make configure_optimizers() build Adam with lr=None. Keep the learning rate
# the model was created with in that case.
if new_lr is not None:
    net.learning_rate = new_lr
else:
    print(f"LR finder gave no suggestion; keeping learning_rate={net.learning_rate}")

trainer.fit(net, train_loader, val_loader)

In [None]:
# %%time
# Disabled: generate the test set with gift.make_train_data instead of loading it.
# xt, yt = gift.make_train_data(data_val,
#                                         num_rounds,
#                                         diff=difference,
#                                         r_start=start_round)
# print(yt)

# Use the top-level `np` alias throughout so this cell does not depend on the
# `import numpy` statement located in the array-loading cell.
xt = np.load(ver+'_Xt.npy')
yt = np.load(ver+'_Yt.npy')

# Same layout as the training data: flattened inputs, one label column.
xt = xt.reshape(xt.shape[0], -1)
yt = yt.reshape(yt.shape[0], 1)

# astype() already allocates a new float32 array; torch.from_numpy wraps it
# without the additional copy torch.tensor() would make.
test_dataset = Data.TensorDataset(torch.from_numpy(xt.astype('float32')),
                                  torch.from_numpy(yt.astype('float32')))
test_loader = DataLoader(test_dataset, num_workers=1, batch_size=2**7, pin_memory=True)

# del(xt, yt)
# Evaluates the weights currently held by `net`.
ret = trainer.test(net, test_loader)

In [None]:
# 'test_acc' is the key logged by LSTM.test_step; ret[0] is the first result dict from trainer.test.
print(ret[0]['test_acc'])