In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import precision_recall_curve, auc
import matplotlib.pyplot as plt
from utils.simulators import Steel_APT_Dataset

device = torch.device('cpu')
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
def preprocess_inputs(data: pd.DataFrame, k: int = 10) -> np.ndarray:
    knn = NearestNeighbors(n_neighbors=k + 1)
    knn.fit(data[["X", "Y", "Z"]])
    distances, indexes = knn.kneighbors(data[["X", "Y", "Z"]])
    return distances[:, 1:]  # 1st nieghbour is the point itself

class Training_Dataset(Dataset):

    def __init__(self, preprocessed_inputs: np.ndarray, cluster_labels: np.ndarray):
        self.X = preprocessed_inputs
        self.Y = cluster_labels
        self.n_examples = self.Y.shape[0]

    def __getitem__(self, index):
        return self.X[index], self.Y[index]

    def __len__(self):
        return self.n_examples

def get_auprc(model: torch.nn.Sequential, X: torch.Tensor, Y: torch.Tensor) -> float:
    Y_pred = model(X).detach().cpu().numpy()
    Y = Y.cpu().numpy()
    precision, recall, _ = precision_recall_curve(Y, Y_pred)
    return auc(x=recall, y=precision)

In [5]:
class LUNOT(nn.Module):
    def __init__(self, k: int, layer_size: int = 128, layer_count: int = 8, dropout: float = 0.0):
        super(LUNOT, self).__init__()

        layers = [nn.Linear(k, layer_size), nn.Dropout(dropout), nn.GELU()]
        for i in range(layer_count - 2):
            layers += [nn.Linear(layer_size, layer_size), nn.Dropout(dropout), nn.GELU()]
            if i % 2 == 0:
                layers += [nn.BatchNorm1d(layer_size, affine=False)]
        layers += [nn.Linear(layer_size, 1), nn.Sigmoid()]
        
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        out = self.network(x)
        return torch.squeeze(out, 1)

In [6]:
train_dataset = Steel_APT_Dataset(
    unit_cells_per_side=150,
    cluster_relative_density=200,
    cluster_atom_counts=np.random.randint(low=10, high=50, size=400),
)
train_data = train_dataset.data[train_dataset.data['Element']!='Fe']

val_dataset = Steel_APT_Dataset(
    unit_cells_per_side=100,
    cluster_relative_density=200,
    cluster_atom_counts=np.random.randint(low=10, high=50, size=120),
)
val_data = val_dataset.data[val_dataset.data['Element']!='Fe']

In [7]:
train_data['is cluster'].value_counts()

is cluster
0    68079
1     6590
Name: count, dtype: int64

In [8]:
val_data['is cluster'].value_counts()

is cluster
0    20053
1     1949
Name: count, dtype: int64

In [None]:
def get_model(train_data: pd.DataFrame, val_data: pd.DataFrame, lr=0.001, n_neighbours=10, layer_size=32, batch_size=64, layer_count=3, lr_schedule_k=1.0) -> nn.Module:
    X_val = preprocess_inputs(val_data, k=n_neighbours)
    Y_val = val_data['is cluster'].to_numpy()
    X_train = preprocess_inputs(train_data, k=n_neighbours)
    Y_train = train_data['is cluster'].to_numpy()
    
    X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
    Y_val = torch.tensor(Y_val, dtype=torch.float32).to(device)
    X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
    Y_train = torch.tensor(Y_train, dtype=torch.float32).to(device)

    training_dataset = Training_Dataset(X_train, Y_train)
    train_data_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True)

    model = LUNOT(k=n_neighbours, layer_size=layer_size, layer_count=layer_count).to(device, dtype=torch.float32)
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: lr_schedule_k**epoch)

    losses, train_scores, val_scores = [], [], []

    for epoch in range(1, 21):

        model.train()
        for X, Y in train_data_loader:
            optimizer.zero_grad()
            Y_pred = model(X)
            loss = criterion(Y_pred, Y)
            loss.backward()
            optimizer.step()
        
        losses.append(loss.item())
        scheduler.step()
        model.eval()
        with torch.no_grad():
            train_scores.append(get_auprc(model, X_train, Y_train))
            val_scores.append(get_auprc(model, X_val, Y_val))

        if epoch >= 3 and np.std(val_scores[-3:]) < 0.01:
            # print(f'stopping at epoch {epoch} due to low variance in validation scores')
            break

    print(f'losses: {losses}')
    print(f'train_scores: {train_scores}')
    print(f'val_scores: {val_scores}')

    return model

In [None]:
model = get_model(train_data, val_data, lr=0.02, n_neighbours=50, layer_size=150, batch_size=300,
                   layer_count=4, lr_schedule_k=0.2) #values from hyperparameter search

losses: [0.2537890672683716, 0.13283410668373108, 0.0869079977273941]
train_scores: [0.7381278863311806, 0.740625348432118, 0.7418195275067734]
val_scores: [0.7567422019333573, 0.7623879586608083, 0.7649024199899588]


In [None]:
scripted_model = torch.jit.script(model)
scripted_model.save('utils/LUNOT_finetuned')