In [1]:
import sys

if "google.colab" in sys.modules:
    # Google Colab
    from google.colab import drive
    drive.mount('/content/drive')
    Dataset = "/content/drive/My Drive/Fred"

    !pip install torch_geometric ase mp-api
    import torch
    from torch_geometric.data import Data, DataLoader


else:
    Dataset = "Dataset"

In [None]:
# Import necessary libraries
import os
from getpass import getpass

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from pymatgen.core import Element
from pymatgen.analysis.local_env import MinimumDistanceNN, VoronoiNN
from pymatgen.analysis.molecule_structure_comparator import CovalentRadius

from mp_api.client import MPRester

from sklearn.preprocessing import StandardScaler

# Nearest-neighbour finders used to build graph edges.
mdnn = MinimumDistanceNN()
voronoi_nn = VoronoiNN()

# Materials Project API key. Never commit a key to the notebook: read it from
# the MP_API_KEY environment variable, or ask for it interactively.
# NOTE(review): the key that used to be hard-coded here should be revoked/rotated.
API_KEY = os.environ.get("MP_API_KEY") or getpass("Materials Project API key: ")

# Shared scaler used by the target normalisation helpers below.
scaler = StandardScaler()

In [None]:
# Read the data
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
data_df = pd.read_pickle(f"{Dataset}/absorption_mp_data.pkl")

# Get material ids. `.tolist()` does not depend on the index labels, whereas
# `data_df["id"][i]` raised KeyError for any non-default index.
df_ids = data_df["id"].tolist()

# Get pymatgen structure
with MPRester(API_KEY) as mpr:
  df_struct = mpr.materials.summary.search(
      material_ids = df_ids,
      fields = ["material_id","structure"])

# Add pymatgen structure to dataframe (None where the API returned no match)
structure_dict = {item.material_id: item.structure for item in df_struct}
data_df["pymatgen_structure"] = data_df["id"].map(structure_dict.get)
data_df.head()

## Preprocessing data

In [None]:
# Example pymatgen structure (row label 12) used by the exploratory cells below.
new_structure = data_df["pymatgen_structure"][12]
print(new_structure)

In [None]:
# Example entry of the "structure" column (row label 0); the cells below call
# ase-style accessors on it (get_positions, get_masses, ...).
a_structure = data_df["structure"][0]
print(a_structure)

In [None]:
# Per-element lookup table: {element symbol: [value_0, value_1]}.
# NOTE(review): judging by the magnitudes (e.g. H -> [32.0, 4.50711]), value_0
# looks like a covalent radius in pm and value_1 a dipole polarizability in
# atomic units -- confirm against the data source before relying on this.
dip_rad = {'H': [32.0, 4.50711], 'He': [46.0, 1.38375], 'Li': [133.0, 164.1125],
           'Be': [102.0, 37.74], 'B': [85.0, 20.5], 'C': [75.0, 11.3],
           'N': [71.0, 7.4], 'O': [63.0, 5.3], 'F': [64.0, 3.74],
           'Ne': [67.0, 2.6611], 'Na': [155.0, 162.7], 'Mg': [139.0, 71.2],
           'Al': [126.0, 57.8], 'Si': [115.99999999999999, 37.3],
           'P': [111.00000000000001, 25.0], 'S': [103.0, 19.4], 'Cl': [99.0, 14.6],
           'Ar': [96.0, 11.083], 'K': [196.0, 289.7], 'Ca': [171.0, 160.8],
           'Sc': [148.0, 97.0], 'Ti': [136.0, 100.0], 'V': [134.0, 87.0],
           'Cr': [122.0, 83.0], 'Mn': [119.0, 68.0], 'Fe': [115.99999999999999, 62.0],
           'Co': [111.00000000000001, 55.0], 'Ni': [110.00000000000001, 49.0],
           'Cu': [112.00000000000001, 46.5], 'Zn': [118.0, 38.67], 'Ga': [124.0, 50.0],
           'Ge': [121.0, 40.0], 'As': [121.0, 30.0], 'Se': [115.99999999999999, 28.9],
           'Br': [113.99999999999999, 21.0], 'Kr': [117.0, 16.78], 'Rb': [210.0, 319.8],
           'Sr': [185.0, 197.2], 'Y': [163.0, 162.0], 'Zr': [154.0, 112.0],
           'Nb': [147.0, 98.0], 'Mo': [138.0, 87.0], 'Tc': [128.0, 79.0],
           'Ru': [125.0, 72.0], 'Rh': [125.0, 66.0], 'Pd': [120.0, 26.14],
           'Ag': [128.0, 55.0], 'Cd': [136.0, 46.0], 'In': [142.0, 65.0],
           'Sn': [140.0, 53.0], 'Sb': [140.0, 43.0], 'Te': [136.0, 38.0],
           'I': [133.0, 32.9], 'Xe': [131.0, 27.32], 'Cs': [231.99999999999997, 400.9],
           'Ba': [196.0, 272.0], 'La': [180.0, 215.0], 'Ce': [163.0, 205.0],
           'Pr': [176.0, 216.0], 'Nd': [174.0, 208.0], 'Pm': [173.0, 200.0],
           'Sm': [172.0, 192.0], 'Eu': [168.0, 184.0], 'Gd': [169.0, 158.0],
           'Tb': [168.0, 170.0], 'Dy': [167.0, 163.0], 'Ho': [166.0, 156.0],
           'Er': [165.0, 150.0], 'Tm': [164.0, 144.0], 'Yb': [170.0, 139.0],
           'Lu': [162.0, 137.0], 'Hf': [152.0, 103.0], 'Ta': [146.0, 74.0],
           'W': [137.0, 68.0], 'Re': [131.0, 62.0], 'Os': [129.0, 57.0],
           'Ir': [122.0, 54.0], 'Pt': [123.0, 48.0], 'Au': [124.0, 36.0],
           'Hg': [133.0, 33.91], 'Tl': [144.0, 50.0], 'Pb': [144.0, 47.0],
           'Bi': [151.0, 48.0], 'Po': [145.0, 44.0], 'At': [147.0, 42.0],
           'Rn': [142.0, 35.0], 'Fr': [223.0, 317.8], 'Ra': [200.99999999999997, 246.0],
           'Ac': [186.0, 203.0], 'Th': [175.0, 217.0], 'Pa': [169.0, 154.0],
           'U': [170.0, 129.0], 'Np': [171.0, 151.0], 'Pu': [172.0, 132.0],
           'Am': [166.0, 131.0], 'Cm': [166.0, 144.0], 'Bk': [168.0, 125.0],
           'Cf': [168.0, 122.0], 'Es': [165.0, 118.0], 'Fm': [167.0, 113.0],
           'Md': [173.0, 109.0], 'No': [176.0, 110.0], 'Lr': [161.0, 320.0],
           'Rf': [157.0, 112.0], 'Db': [149.0, 42.0], 'Sg': [143.0, 40.0],
           'Bh': [141.0, 38.0], 'Hs': [134.0, 36.0], 'Mt': [129.0, 34.0],
           'Ds': [128.0, 32.0], 'Rg': [121.0, 32.0], 'Cn': [122.0, 28.0],
           'Nh': [136.0, 29.0], 'Fl': [143.0, 31.0], 'Mc': [162.0, 71.0],
           'Lv': [175.0, 0.0], 'Ts': [165.0, 76.0], 'Og': [157.0, 58.0]}

### Nodes

In [None]:
# Node features built from the pymatgen structure: one feature row per site.

atomic_numbers = new_structure.atomic_numbers
carts  = new_structure.cart_coords
fracs = new_structure.frac_coords

the_nodes = []
for idx, (z_number, xyz) in enumerate(zip(atomic_numbers, carts)):
    symb = new_structure[idx].specie.symbol
    elem = Element(symb)  # looked up once per site instead of once per property
    a_node = [
        z_number,
        elem.atomic_mass.real,
        CovalentRadius.radius[symb],
        dip_rad[symb][0],
        dip_rad[symb][1],
        elem.atomic_radius.real,
        elem.van_der_waals_radius.real,
        elem.molar_volume.real,
        xyz[0], xyz[1], xyz[2],
    ]
    the_nodes.append(a_node)
print(the_nodes)

In [None]:
# Node features from the ase structure: atomic number, mass and position per atom.
num_atoms = a_structure.get_global_number_of_atoms()
atomic_numbers = a_structure.get_atomic_numbers()
positions = a_structure.get_positions()
atomic_masses = a_structure.get_masses()


for z_number, mass, xyz in zip(atomic_numbers, atomic_masses, positions):
    print([z_number, mass, xyz[0], xyz[1], xyz[2]])

### Edges and edge attributes from pymatgen structure

In [None]:
# Nearest-neighbour information for every site (one list of neighbour dicts per site)
all_info = mdnn.get_all_nn_info(new_structure)
nums = len(all_info)

the_lattice = new_structure.lattice.matrix
pos = new_structure.cart_coords

# Structure as a list of sites
sites_list = new_structure.sites

# Edge endpoints (source / target site indices) and per-edge attribute rows
from_atoms, to_atoms, attr = [], [], []

for src_idx, neighbours in enumerate(all_info):
    for nn_entry in neighbours:
        dst_idx = nn_entry["site_index"]
        from_atoms.append(src_idx)
        to_atoms.append(dst_idx)

        # "image" multiplied by the lattice matrix gives a Cartesian offset;
        # adding it to the in-cell displacement gives the edge vector.
        shift = np.dot(nn_entry["image"], the_lattice)
        shift_vector = (pos[dst_idx] - pos[src_idx]) + shift

        bond_length = sites_list[src_idx].distance(nn_entry["site"])

        # Row layout: weight, distance, shift (x, y, z), edge vector (x, y, z)
        attr.append([nn_entry["weight"], bond_length, *shift, *shift_vector])


print(from_atoms)
print(to_atoms)
print(attr)

In [None]:
# Print, for every site index, what MinimumDistanceNN.get_nn_images returns
# for that site (exploratory output only; nothing is stored).
total_sites = new_structure.num_sites
for i in range(total_sites):
    # Get every image of a neighbouring site
    the_image = mdnn.get_nn_images(new_structure, i)
    print(the_image)

In [None]:
# From ase structure: full pairwise distance matrix
num_atoms = a_structure.get_global_number_of_atoms()
all_distances = a_structure.get_all_distances()

# Flatten the matrix row by row into one edge-feature list (includes i == j pairs)
edge_features = [
    all_distances[row][col]
    for row in range(num_atoms)
    for col in range(num_atoms)
]

### Global attributes

In [None]:
# Structure-level (global) attributes: cell volume, lattice lengths and angles,
# density, site count, space-group number, element count, then two 0/1 flags.
cell = new_structure.lattice
global_attr = [
    new_structure.volume,
    *cell.abc,
    *cell.angles,
    new_structure.density.real,
    new_structure.num_sites,
    new_structure.get_space_group_info()[1],
    new_structure.n_elems,
]

# Boolean flags encoded as 1 / 0
global_attr.append(1 if new_structure.is_3d_periodic == True else 0)
global_attr.append(1 if new_structure.is_ordered == True else 0)

print(global_attr)

In [None]:
# From ase structure: the only global attribute taken here is the cell volume.

print([a_structure.get_volume()])

### Target values

In [None]:
# Get the target values: the absorption-coefficient entry of row label 100,
# used as the running example in the normalisation cells below.
targ = data_df["absorption_coefficient"][100]

# plot the raw (un-normalised) target against its element index
plt.plot(targ, '-')
plt.ylabel("Absorption coefficient")
plt.show()

## Normalizing Target Value

The target values were so dispersed that the model could not identify any patterns. To solve this, you **normalize** the target values. Once the model has learned the patterns and made predictions, you **denormalize** the predicted values to get the actual predicted values

We normalize by computing $\log(1 + x)$. We use this method because it is well defined at $x = 0$, where $\log(x)$ diverges.
Once you have the log of every value, you standardize them to zero mean and unit variance using `StandardScaler.fit_transform(<log results>)`.

In [None]:
def log_1p_norm(targ):
    """Normalise target values with log(1 + x) followed by standardisation.

    Parameters
    ----------
    targ : array-like
        Raw target values (must be > -1 for the logarithm to be defined).

    Returns
    -------
    numpy.ndarray
        1-D array of the standardised log1p values.

    Notes
    -----
    The module-level ``scaler`` is refit on every call, so afterwards it holds
    the statistics of the most recently normalised ``targ`` only.
    """
    # Fix: the body used the unrelated global `new_targ` and ignored `targ`.
    y_log = np.log1p(np.asarray(targ, dtype=float))
    y_scaled = scaler.fit_transform(y_log.reshape(-1, 1)).flatten()
    return y_scaled

# Normalize the target values
y_scaled = log_1p_norm(targ)

# Plot the normalized target values.
# NOTE(review): the y-axis label still says "Absorption coefficient" although
# the plotted values are the normalised ones.
plt.plot(y_scaled, '-')
plt.ylabel("Absorption coefficient")
plt.show()

## Denormalizing

We then denormalize by reversing the normalization process, i.e. applying the inverse transform to the predicted values using `StandardScaler.inverse_transform(<predicted values>)`.
Then use `np.expm1(<array_after_inverse_transform>)` to calculate ``exp(x) - 1`` for all elements in the array.

In [None]:
# Denormalize the normalized target values
def log_1p_denorm(the_scaled):
    """Invert ``log_1p_norm``: undo the standardisation, then apply exp(x) - 1.

    Parameters
    ----------
    the_scaled : array-like or torch.Tensor
        Normalised values. Tensors (including ones on a GPU) are detached and
        copied to host memory first.

    Returns
    -------
    numpy.ndarray
        1-D array of values on the original scale.

    Notes
    -----
    Uses the module-level ``scaler`` as it currently is, i.e. with the
    statistics of its most recent fit.
    """
    if hasattr(the_scaled, "detach"):
        # torch tensor: NumPy/sklearn cannot consume CUDA or grad-tracking tensors
        the_scaled = the_scaled.detach().cpu().numpy()
    scaled_column = np.asarray(the_scaled, dtype=float).reshape(-1, 1)
    y_pred_log = scaler.inverse_transform(scaled_column).flatten()
    y_pred_original = np.expm1(y_pred_log)  # inverse of log1p
    return y_pred_original

# Round trip: denormalising the values normalised above should reproduce the
# raw target curve, since `scaler` was last fitted on that same target.
y_pred_original = log_1p_denorm(y_scaled)

plt.plot(y_pred_original, '-')
plt.ylabel("Absorption coefficient")
plt.show()


## Now as a function that will iterate through the whole dataset

In [None]:
# These values were obtained through mendeleev
# {Element: [value_0, value_1]}
# NOTE(review): the previous comment gave the order as [dipole polarizability,
# radius], but the magnitudes (e.g. H -> [32.0, 4.50711]) suggest the opposite:
# value_0 looks like a radius in pm and value_1 a dipole polarizability in
# atomic units -- confirm against the mendeleev output.

dip_rad = {'H': [32.0, 4.50711], 'He': [46.0, 1.38375], 'Li': [133.0, 164.1125],
           'Be': [102.0, 37.74], 'B': [85.0, 20.5], 'C': [75.0, 11.3],
           'N': [71.0, 7.4], 'O': [63.0, 5.3], 'F': [64.0, 3.74],
           'Ne': [67.0, 2.6611], 'Na': [155.0, 162.7], 'Mg': [139.0, 71.2],
           'Al': [126.0, 57.8], 'Si': [115.99999999999999, 37.3],
           'P': [111.00000000000001, 25.0], 'S': [103.0, 19.4], 'Cl': [99.0, 14.6],
           'Ar': [96.0, 11.083], 'K': [196.0, 289.7], 'Ca': [171.0, 160.8],
           'Sc': [148.0, 97.0], 'Ti': [136.0, 100.0], 'V': [134.0, 87.0],
           'Cr': [122.0, 83.0], 'Mn': [119.0, 68.0], 'Fe': [115.99999999999999, 62.0],
           'Co': [111.00000000000001, 55.0], 'Ni': [110.00000000000001, 49.0],
           'Cu': [112.00000000000001, 46.5], 'Zn': [118.0, 38.67], 'Ga': [124.0, 50.0],
           'Ge': [121.0, 40.0], 'As': [121.0, 30.0], 'Se': [115.99999999999999, 28.9],
           'Br': [113.99999999999999, 21.0], 'Kr': [117.0, 16.78], 'Rb': [210.0, 319.8],
           'Sr': [185.0, 197.2], 'Y': [163.0, 162.0], 'Zr': [154.0, 112.0],
           'Nb': [147.0, 98.0], 'Mo': [138.0, 87.0], 'Tc': [128.0, 79.0],
           'Ru': [125.0, 72.0], 'Rh': [125.0, 66.0], 'Pd': [120.0, 26.14],
           'Ag': [128.0, 55.0], 'Cd': [136.0, 46.0], 'In': [142.0, 65.0],
           'Sn': [140.0, 53.0], 'Sb': [140.0, 43.0], 'Te': [136.0, 38.0],
           'I': [133.0, 32.9], 'Xe': [131.0, 27.32], 'Cs': [231.99999999999997, 400.9],
           'Ba': [196.0, 272.0], 'La': [180.0, 215.0], 'Ce': [163.0, 205.0],
           'Pr': [176.0, 216.0], 'Nd': [174.0, 208.0], 'Pm': [173.0, 200.0],
           'Sm': [172.0, 192.0], 'Eu': [168.0, 184.0], 'Gd': [169.0, 158.0],
           'Tb': [168.0, 170.0], 'Dy': [167.0, 163.0], 'Ho': [166.0, 156.0],
           'Er': [165.0, 150.0], 'Tm': [164.0, 144.0], 'Yb': [170.0, 139.0],
           'Lu': [162.0, 137.0], 'Hf': [152.0, 103.0], 'Ta': [146.0, 74.0],
           'W': [137.0, 68.0], 'Re': [131.0, 62.0], 'Os': [129.0, 57.0],
           'Ir': [122.0, 54.0], 'Pt': [123.0, 48.0], 'Au': [124.0, 36.0],
           'Hg': [133.0, 33.91], 'Tl': [144.0, 50.0], 'Pb': [144.0, 47.0],
           'Bi': [151.0, 48.0], 'Po': [145.0, 44.0], 'At': [147.0, 42.0],
           'Rn': [142.0, 35.0], 'Fr': [223.0, 317.8], 'Ra': [200.99999999999997, 246.0],
           'Ac': [186.0, 203.0], 'Th': [175.0, 217.0], 'Pa': [169.0, 154.0],
           'U': [170.0, 129.0], 'Np': [171.0, 151.0], 'Pu': [172.0, 132.0],
           'Am': [166.0, 131.0], 'Cm': [166.0, 144.0], 'Bk': [168.0, 125.0],
           'Cf': [168.0, 122.0], 'Es': [165.0, 118.0], 'Fm': [167.0, 113.0],
           'Md': [173.0, 109.0], 'No': [176.0, 110.0], 'Lr': [161.0, 320.0],
           'Rf': [157.0, 112.0], 'Db': [149.0, 42.0], 'Sg': [143.0, 40.0],
           'Bh': [141.0, 38.0], 'Hs': [134.0, 36.0], 'Mt': [129.0, 34.0],
           'Ds': [128.0, 32.0], 'Rg': [121.0, 32.0], 'Cn': [122.0, 28.0],
           'Nh': [136.0, 29.0], 'Fl': [143.0, 31.0], 'Mc': [162.0, 71.0],
           'Lv': [175.0, 0.0], 'Ts': [165.0, 76.0], 'Og': [157.0, 58.0]}

In [None]:
# Combining
def get_graph(row):
    """Build one ``torch_geometric.data.Data`` graph from a dataframe row.

    Parameters
    ----------
    row : mapping (e.g. pandas.Series)
        Must provide ``"pymatgen_structure"`` and ``"absorption_coefficient"``.

    Returns
    -------
    Data
        x          : (num_sites, 7) node features, standardised within the structure
        edge_index : (2, num_edges) source / target site indices
        edge_attr  : (num_edges, 8) edge features, standardised within the structure
        u          : (1, 7) RAW global attributes (scaled across the dataset by the caller)
        y          : (1, n_targets) normalised target as float32

    Raises
    ------
    ValueError
        If the row has no pymatgen structure.
    """
    py_struct = row["pymatgen_structure"]
    if py_struct is None:
        raise ValueError(f"No pymatgen structure for material {row.get('id')!r}")

    the_lattice = py_struct.lattice.matrix
    sites_list = py_struct.sites
    atomic_nums = py_struct.atomic_numbers
    carts = py_struct.cart_coords

    # Neighbour lists, one per site
    all_info = mdnn.get_all_nn_info(py_struct)
    # all_info = voronoi_nn.get_all_nn_info(py_struct)

    all_nodes, from_atoms, to_atoms, edge_attrs = [], [], [], []

    for i, site in enumerate(sites_list):
        symb = site.specie.symbol
        # Node row: Z, atomic mass, two dip_rad entries, Cartesian x/y/z
        all_nodes.append([atomic_nums[i],
                          Element(symb).atomic_mass.real,
                          dip_rad[symb][0],
                          dip_rad[symb][1],
                          carts[i][0], carts[i][1], carts[i][2]])

        for nn_entry in all_info[i]:
            to_index = nn_entry["site_index"]
            from_atoms.append(i)
            to_atoms.append(to_index)

            # Cartesian offset of the neighbour's image, and the resulting edge vector
            shift = np.dot(nn_entry["image"], the_lattice)
            shift_vector = (carts[to_index] - carts[i]) + shift

            # Edge row: weight, distance, shift x/y/z, edge vector x/y/z
            edge_attrs.append([nn_entry["weight"],
                               site.distance(nn_entry["site"]),
                               shift[0], shift[1], shift[2],
                               shift_vector[0], shift_vector[1], shift_vector[2]])

    global_attr = [py_struct.volume,
                   py_struct.density.real,
                   py_struct.num_sites,
                   py_struct.get_space_group_info()[1],
                   py_struct.n_elems,
                   int(py_struct.is_3d_periodic == True),
                   int(py_struct.is_ordered == True)]
    # global_attr.append(row['bandgap'])

    # Use fresh scalers so the shared `scaler` is not refit here with
    # differently shaped data.
    # NOTE(review): standardising within a single structure removes absolute
    # values (a single-element structure gets an all-zero atomic-number column);
    # consider dataset-level scaling for node/edge features as well.
    node_feats = StandardScaler().fit_transform(all_nodes)
    if edge_attrs:
        edge_feats = StandardScaler().fit_transform(edge_attrs)
    else:
        # No neighbours found: keep a well-formed, empty edge-feature matrix
        edge_feats = np.empty((0, 8))

    # Global attributes are intentionally left unscaled here: standardising a
    # single sample (as before) always produced a vector of zeros.
    normalized = log_1p_norm(row["absorption_coefficient"])

    graph_data = Data(x=torch.tensor(node_feats, dtype=torch.float),
                      edge_index=torch.tensor([from_atoms, to_atoms], dtype=torch.long),
                      edge_attr=torch.tensor(edge_feats, dtype=torch.float),
                      u=torch.tensor(global_attr, dtype=torch.float).unsqueeze(0),
                      # float32 so the target dtype matches the model output in the loss
                      y=torch.tensor(normalized, dtype=torch.float).unsqueeze(0))

    return graph_data

# Rows whose structure could not be fetched cannot be turned into graphs.
graph_rows = data_df.dropna(subset=["pymatgen_structure"])
if len(graph_rows) < len(data_df):
    print(f"Skipping {len(data_df) - len(graph_rows)} rows without a pymatgen structure")

the_graphs = graph_rows.apply(lambda row: get_graph(row), axis=1).to_list()

# Standardise the global attributes across the whole dataset.
# NOTE(review): this is fitted on all graphs before the train/test split;
# fit on the training subset only if strict separation is required.
if the_graphs:
    global_scaler = StandardScaler()
    u_scaled = global_scaler.fit_transform(torch.cat([g.u for g in the_graphs]).numpy())
    for g, u_row in zip(the_graphs, u_scaled):
        g.u = torch.tensor(u_row, dtype=torch.float).unsqueeze(0)

the_graphs[0]

## Splitting the data

In [None]:
# Split the data: 10% held out for testing, then 10% of the remainder for
# validation (about 81% / 9% / 10% train / val / test). Fixed random_state
# keeps the split reproducible.
from sklearn.model_selection import train_test_split
train_graphs, test_graphs = train_test_split(the_graphs, test_size=0.1, random_state=42) # For original
train_graphs, val_graphs = train_test_split(train_graphs, test_size=0.1, random_state=42)


## Building Dataloaders

In [None]:
# Create data loaders (batches of 32 graphs). Only the training loader is
# shuffled; validation and test order stay fixed so results are comparable.
train_loader = DataLoader(train_graphs, batch_size=32, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=32, shuffle=False)
test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False)

In [None]:
# Sanity check: print the first training batch only, then stop.
for data in train_loader:
    print(data)
    break

## Libraries to create GNN model

In [None]:
# Import dependencies for the model
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import NNConv, global_mean_pool

## Model architecture

In [None]:
# The model
class GNNabsorption(nn.Module):
    """Edge-conditioned graph network that maps one graph to a vector of outputs.

    Pipeline: one ``NNConv`` layer (edge attributes generate the message
    weights) -> mean pooling over the nodes of each graph -> concatenation with
    the graph-level attributes -> three fully connected layers.

    Parameters
    ----------
    node_dim : int
        Number of features per node.
    edge_dim : int
        Number of features per edge.
    global_dim : int
        Number of graph-level features concatenated after pooling.
    hidden_dim : int
        Width of the node embedding produced by the convolution.
    out_dim : int, optional
        Number of predicted values per graph. Defaults to 2001, the width that
        was previously hard-coded.
    """

    def __init__(self, node_dim, edge_dim, global_dim, hidden_dim, out_dim=2001):
        super().__init__()
        # Edge NN maps edge_attr to a flattened (node_dim x hidden_dim) weight matrix
        self.edge_nn = nn.Sequential(
            nn.Linear(edge_dim, 64),
            nn.ReLU(),
            nn.Linear(64, node_dim * hidden_dim)
        )

        self.conv1 = NNConv(node_dim, hidden_dim, self.edge_nn, aggr='mean')
        self.fc1 = nn.Linear(hidden_dim + global_dim, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, out_dim)

    def forward(self, x, edge_index, edge_attr,batch, global_attr):
        """Return one prediction row per graph in the batch.

        ``batch`` assigns each node to its graph for pooling; ``global_attr``
        must have one row per graph so it can be concatenated along dim 1.
        """
        x = F.relu(self.conv1(x, edge_index, edge_attr))
        x = global_mean_pool(x, batch)              # node embeddings -> graph embeddings
        x = torch.cat([x, global_attr], dim=1)      # append graph-level features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)                             # no activation: regression output
        return x

# run the model in the gpu if the device has one
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instance of the model. The dimensions must match the feature counts built in
# get_graph: 7 node features, 8 edge features and 7 global features.
model = GNNabsorption(node_dim=7, edge_dim=8, global_dim=7, hidden_dim=256).to(device) # global features are included (global_dim=7)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Loss function: mean absolute error between predicted and normalised targets
loss_fn = nn.L1Loss()

## Training, Evaluation and Training Loss

In [None]:
def train():
    """Run one optimisation pass over ``train_loader``.

    Returns the average of the per-batch losses, weighted by the number of
    graphs in each batch.
    """
    model.train()
    running_loss = 0
    for batch in train_loader:
        batch = batch.to(device)  # move the batch to the model's device
        optimizer.zero_grad()

        prediction = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, batch.u)
        batch_loss = loss_fn(prediction, batch.y)

        # back propagation and parameter update
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item() * batch.num_graphs
    return running_loss / len(train_loader.dataset)

def evaluate(loader):
    """Compute the graph-weighted average loss over ``loader`` without gradient tracking."""
    model.eval()
    running_loss = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            prediction = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, batch.u)
            running_loss += loss_fn(prediction, batch.y).item() * batch.num_graphs
    return running_loss / len(loader.dataset)

# Run training and testing
train_losses, val_losses = [], []
for epoch in range(1, 20):  # epochs numbered 1..19
    train_loss = train()
    val_loss = evaluate(val_loader)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f'Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

# Final check on the held-out test set
test_loss = evaluate(test_loader)
print(f'Test Loss: {test_loss:.4f}')

## Visualizing Train-Valid Loss

In [None]:
import matplotlib.pyplot as plt
# Training vs. validation loss per epoch, drawn on a single explicit axes
fig, ax = plt.subplots()
ax.plot(train_losses, label='Train Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
plt.show()

## Model deployment

In [None]:
# Let's get the predicted values from the model
actual_values = []
predicted_values = []

# Make sure the model is in inference mode even if this cell is run on its own.
model.eval()
with torch.no_grad():
    for data in test_loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch, data.u)
        # Keep results on the CPU: the NumPy / SciPy / matplotlib code below
        # cannot consume CUDA tensors, and this frees GPU memory.
        actual_values.append(data.y.cpu())
        predicted_values.append(out.cpu())

In [None]:
# Focus on one material predicted: graph index 10 of the first test batch.
# Converted to NumPy on the CPU so that scipy, matplotlib and the scaler can
# use the values regardless of the device the model ran on.
attempt_act = actual_values[0][10].detach().cpu().numpy()
attempt_pred = predicted_values[0][10].detach().cpu().numpy()

In [None]:
# Did the model learn?
# Pearson correlation between the actual and predicted curve of the one
# selected material; the second return value (the p-value) is discarded.
from scipy.stats import pearsonr
corr, _ = pearsonr(attempt_act, attempt_pred)
print(f"Correlation: {corr}")

# If corr is close to 0, the predicted curve does not follow the actual one
# If corr is close to 1, the predicted curve follows the actual one closely
# NOTE(review): this is a single-sample check of linear agreement only.

In [None]:
# Side-by-side comparison of the normalized curves
fig, (ax_actual, ax_predicted) = plt.subplots(1, 2)

# Left panel: actual normalized values
ax_actual.plot(attempt_act, label="actual")
ax_actual.set_title("Actual normalized values")

# Right panel: predicted normalized values
ax_predicted.plot(attempt_pred, label="predicted")
ax_predicted.set_title("Predicted normalized values")

plt.show()

In [None]:
# Denormalize the values
# NOTE(review): log_1p_denorm uses the module-level `scaler` with whatever
# statistics its most recent fit left behind. Those are not necessarily the
# statistics of this material, so the curves below may be on the wrong scale
# -- confirm which data the scaler was last fitted on.
new_attempt_act = log_1p_denorm(attempt_act)
new_attempt_pred = log_1p_denorm(attempt_pred)

In [None]:
# Side-by-side comparison of the denormalized curves
fig, (ax_actual, ax_predicted) = plt.subplots(1, 2)

# Left panel: actual denormalized values
ax_actual.plot(new_attempt_act, label="actual")
ax_actual.set_title("Actual denormalized values")

# Right panel: predicted denormalized values
ax_predicted.plot(new_attempt_pred, label="predicted")
ax_predicted.set_title("predicted denormalized values")

plt.show()

Next step would be to try to refine the model to increase its accuracy score.

This would be by choosing the best features, adding, or removing features.