In [190]:
import pandas as pd
import numpy as np
# import mplhep as hep
from tqdm import tqdm
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

pd.set_option('display.max_columns', 150)

import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
# pd.set_option('display.max_columns', 150)



filecode = 'InfA_RD_DPrmvd'
# filecode = parquetcode


train_hp = {
    "lr":0.01,
    "batch_size":100000,
    "N_epochs":15,
    "seed":0,
}
nodes = [7]


def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
set_seed(train_hp['seed'])

def hess_to_tensor(H):
    hess_elements = []
    for i in range(len(H)):
        for j in range(len(H)):
            hess_elements.append(H[i][j].reshape(1))
    return torch.cat(hess_elements).reshape(len(H),len(H))


#Define the Net
class Net(nn.Module):
    def __init__(self, n_features=7, nodes=[7,7], output_nodes=7):
        super(Net, self).__init__()
        # Build network
        n_nodes = [n_features] + nodes + [output_nodes]
        self.layers = nn.ModuleList()
        for i in range(len(n_nodes)-1):
            l = nn.Linear(n_nodes[i], n_nodes[i+1])
            self.layers.append(l)
        self._initialize_weights()
    def _initialize_weights(self):
        with torch.no_grad():
            for l in self.layers:
                l.weight.data.copy_(torch.eye(7))
                l.bias.data.fill_(0)

    def forward(self, x):
        out = torch.tensor(self.layers[0](x),dtype=x.dtype)
        for layer in self.layers[1:]:
            out = nn.ReLU()(layer(out))
            print(out)
        return out
        # return torch.softmax(out, dim=0)



class InfAwareLoss(nn.Module):
    def __init__(self):
        super(InfAwareLoss, self).__init__()

    
    def forward(self,input,target,weight):
    
        # Input = torch.tensor(input)
        # Target = torch.tensor(target,dtype=torch.int8)
        
        label = torch.argmax(target,dim=1)
        pred = torch.argmax(input,dim=1)
        cm = torch.zeros(7,7)
        up = pred.unique()
        ul = label.unique()
        for p in up:
            for l in ul:
                cm[p,l] = weight[pred==p][label[pred==p]==l].sum()
        print(cm)
        cm =cm[1:, :]
        O = cm.sum(dim=1)
        def NLL(mu):
            mu0 =torch.tensor([1.0])
            theta = torch.cat((mu0,mu))
            return -(O@(torch.log(cm@theta))-(cm@theta).sum())
        mu = torch.tensor([1.0,1.0,1.0,1.0,1.0,1.0])
        hess = torch.func.hessian(NLL)(mu)
        I = torch.inverse(hess_to_tensor(hess))
        loss = torch.trace(I)**0.5/1000
        return loss.clone().detach().requires_grad_(True)
        

#Define the trainning function
from NNfunctions import get_batches, get_total_loss,get_total_lossW
def train_network_cross_entropy(model, X_train,X_test,y_train,y_test,w_train,w_test, train_hp={}):
    # optimiser = torch.optim.Adam(model.parameters(), lr=train_hp["lr"])
    optimiser = torch.optim.SGD(model.parameters(), lr=0.01)
    X_train =X_train.to_numpy()
    X_test = X_test.to_numpy()
    y_train = y_train.to_numpy()
    y_test = y_test.to_numpy()
    w_train = w_train.to_numpy()
    w_test = w_test.to_numpy()
    
    
    train_loss, test_loss = [], []
    ia_loss = InfAwareLoss()

    print(">> Training...")
    with tqdm(range(train_hp["N_epochs"])) as t:
        for i_epoch in t:
            model.train()
            # print(i)
            # "get_batches": function defined in statml_tools.py to separate the training data into batches
            batch_gen = get_batches([X_train, y_train, w_train], batch_size=train_hp['batch_size'],
                                    randomise=True, include_remainder=False
                                )
            
            for X_tensor, y_tensor, w_tensor in batch_gen:
                optimiser.zero_grad()
                print(X_tensor)
                output = model(X_tensor)
                print(output)
                # print(output)
                loss = ia_loss(output, y_tensor, w_tensor)
                if torch.isnan(loss):
                    raise ValueError("Loss is NaN, terminating training")

                loss.backward()
                optimiser.step()
                

            model.eval()
            
            Loss = ia_loss
            train_loss.append(get_total_lossW(model, Loss, X_train, y_train,w_train))
            test_loss.append(get_total_lossW(model, Loss, X_test, y_test,w_test))
            
            # "get_total_loss": function defined in statml_tools.py to evaluate the network in batches (useful for large datasets)
            
            t.set_postfix(train_loss=train_loss[-1], test_loss=test_loss[-1])


    print(">> Training finished")
    model.eval()

    return model, train_loss, test_loss
# mi_series = pd.read_csv('/vols/cms/hw423/Week6/MI_balanced.csv')
# MIcol = mi_series.head(140)['Features']


# oc = np.load(f'/vols/cms/hw423/Data/Week14/octest_{filecode}.npy')
# df = pd.DataFrame(oc)
# dfx = df
# # mi_series = pd.read_csv('/vols/cms/hw423/Week6/MI_balanced.csv')
# # df = pd.read_parquet('/vols/cms/hw423/Data/Week14/df_InfA_RD_DPrmvd.parquet')

# # dfx=df[MIcol]
# label = pd.read_pickle('/vols/cms/hw423/Data/Week14/Label.pkl')
# dfy = pd.get_dummies(label)
# dfw = pd.read_pickle('/vols/cms/hw423/Data/Week14/weight.pkl')

# model_ia = Net(n_features=140, nodes=nodes, output_nodes=7)
# model_ia.load_state_dict(torch.load(f'/vols/cms/hw423/Data/Week14/model_b_u_x30x20_DPrmvd_100.pth'))



In [192]:
model_ia = Net(n_features=7, nodes=[7,7,7,7,7], output_nodes=7)

In [191]:
a = torch.tensor(np.load('/vols/cms/hw423/Data/Week14/octest_InfAwar_test.npy'))
a[0]

tensor([9.2121e-08, 9.3106e-01, 9.2121e-08, 5.8412e-03, 6.3094e-02, 9.2121e-08,
        9.2121e-08])

In [193]:
torch.softmax(model_ia(a[1]),dim=0)

tensor([0.0013, 0.3522, 0.1162, 0.0351, 0.2609, 0.0424, 0.1918],
       grad_fn=<ReluBackward0>)
tensor([0.0013, 0.3522, 0.1162, 0.0351, 0.2609, 0.0424, 0.1918],
       grad_fn=<ReluBackward0>)
tensor([0.0013, 0.3522, 0.1162, 0.0351, 0.2609, 0.0424, 0.1918],
       grad_fn=<ReluBackward0>)
tensor([0.0013, 0.3522, 0.1162, 0.0351, 0.2609, 0.0424, 0.1918],
       grad_fn=<ReluBackward0>)
tensor([0.0013, 0.3522, 0.1162, 0.0351, 0.2609, 0.0424, 0.1918],
       grad_fn=<ReluBackward0>)


  out = torch.tensor(self.layers[0](x),dtype=x.dtype)


tensor([0.1231, 0.1748, 0.1381, 0.1273, 0.1596, 0.1282, 0.1489],
       grad_fn=<SoftmaxBackward0>)

In [194]:
for layer in model_ia.layers:
    print(layer.weight.data)

tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.]])
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [

In [209]:
torch.sum(a,dim=1)

tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000])

In [210]:
a

tensor([[9.2121e-08, 9.3106e-01, 9.2121e-08,  ..., 6.3094e-02, 9.2121e-08,
         9.2121e-08],
        [1.3017e-03, 3.5224e-01, 1.1622e-01,  ..., 2.6093e-01, 4.2365e-02,
         1.9184e-01],
        [1.2141e-04, 6.4055e-01, 2.4761e-01,  ..., 7.4634e-02, 1.2141e-04,
         3.9260e-04],
        ...,
        [6.4752e-05, 5.7427e-01, 1.9477e-01,  ..., 1.7192e-01, 6.4752e-05,
         9.8250e-04],
        [1.0982e-08, 1.4988e-07, 1.0982e-08,  ..., 1.4480e-03, 6.3886e-08,
         1.1960e-06],
        [1.4982e-05, 2.4007e-02, 9.5210e-01,  ..., 1.4659e-02, 1.4982e-05,
         2.5381e-04]])