In [None]:
import torch
from torch import nn
import random
import os
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Subset
from MatSciBERT.normalize_text import normalize
from transformers import AutoModel, AutoTokenizer, AutoConfig


def setup_seed(seed):
    """Seed every random generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch on
    the CPU and on all CUDA devices, all with the same ``seed``.
    """
    # The generators are independent of each other, so the order is irrelevant.
    random.seed(seed)
    np.random.seed(seed)
    for seed_torch in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_torch(seed)
## Seed all random number generators for reproducibility
setup_seed(42)

config = AutoConfig.from_pretrained('./MatSciBERT')
# Enlarge the position-embedding table to 900 positions so texts can be
# tokenized with max_length=900 (as get_prob does below). The pretrained table
# presumably has a different size, hence ignore_mismatched_sizes=True: that
# table is re-initialised here and is expected to be overwritten when the
# fine-tuned MgBERT.pth state dict is loaded.
# NOTE(review): confirm MgBERT.pth contains the 900-row position embeddings.
config.max_position_embeddings = 900
bert_model = AutoModel.from_pretrained('./MatSciBERT', config=config, ignore_mismatched_sizes=True)


class BertClassifier(nn.Module):
    """Text classifier: BERT pooled output -> dropout -> linear -> ReLU.

    Args:
        dropout: dropout probability applied to the pooled output.
        bert: encoder module to use. Defaults to the module-level MatSciBERT
            ``bert_model`` so ``BertClassifier()`` behaves as before. Passing one
            explicitly lets the class be reused or tested without that global.
        hidden_size: size of the encoder's pooled output (default 768).
        num_classes: number of output classes (default 3).

    The child module names (``bert``, ``dropout``, ``linear``, ``relu``) are kept
    unchanged so existing state-dict checkpoints still load.
    """

    def __init__(self, dropout=0.5, bert=None, hidden_size=768, num_classes=3):
        super().__init__()
        self.bert = bert_model if bert is None else bert
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        """Classify a batch of token ids.

        Args:
            input_id: token id tensor passed to the encoder as ``input_ids``.
            mask: attention mask passed to the encoder as ``attention_mask``.

        Returns:
            (scores, attentions): ``scores`` is the ReLU of the linear head's
            output, shape (batch, num_classes). ``attentions`` is the encoder's
            per-layer attention tuple (requested via output_attentions=True).
        """
        outputs = self.bert(input_ids=input_id, attention_mask=mask,
                            return_dict=True, output_attentions=True)
        logits = self.linear(self.dropout(outputs.pooler_output))
        # NOTE(review): ReLU clips negative logits to 0 before callers apply a
        # softmax. It is kept as-is so outputs stay consistent with the saved
        # head weights.
        return self.relu(logits), outputs.attentions


## Data loading
# Tokenizer matching the MatSciBERT encoder; used by get_prob and the attention cells.
tokenizer = AutoTokenizer.from_pretrained('./MatSciBERT')
def find_text(composition, directory='../description/', encoding='utf-8'):
    """Return the text description stored for an alloy composition.

    Args:
        composition: composition string, e.g. ``'Cu55Zr42.5Ga2.5'``. The file
            read is ``<directory>/<composition>.txt``.
        directory: folder holding the description files.
        encoding: text encoding of the files.

    Returns:
        The full file contents as a string.

    Raises:
        FileNotFoundError: if no description exists for ``composition``.
    """
    file_path = os.path.join(directory, composition + '.txt')
    # Decode explicitly. Without `encoding`, open() uses the platform locale
    # encoding, so the same file could decode differently (or fail) across machines.
    with open(file_path, 'r', encoding=encoding) as file:
        return file.read()
# Composition table. NOTE(review): `df` is not used before being overwritten
# in the later visible cells; possibly leftover.
df = pd.read_csv('../unique_compositions.csv')

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# device = torch.device("cpu")
## Load the fine-tuned model
from torch.serialization import load  # NOTE(review): unused; torch.load is called below
model_path = 'MgBERT.pth'
# torch.load unpickles the file, which can execute arbitrary code. Only load
# trusted checkpoints. Because this is a plain state dict (see load_state_dict
# below), weights_only=True (torch >= 1.13) would be a safer option.
model_data = torch.load(model_path, map_location=device)
model = BertClassifier()
model.to(device)
model.load_state_dict(model_data)
# Inference mode: puts the Dropout layers into evaluation mode.
model.eval()

In [None]:
# One reference composition per class (BMG, Ribbon, NR). Each is read from
# ../description/<composition>.txt and explained by LIME against its own class.
bmg_text = find_text('Cu55Zr42.5Ga2.5')
ribbon_text = find_text('Ag20Al25La55')
nr_text = find_text('Al40Mn25Si35')

In [None]:
# Inspect the raw description that will be fed to the model.
print(bmg_text)

In [None]:
## Analyze the model's input-output relationship
from lime import lime_text
import math
def get_prob(input_text, batch_size=20, max_length=900):
    """Return class probabilities for one or more raw description texts.

    This function is passed to ``explainer.explain_instance`` as LIME's
    classifier function, which calls it with a list of perturbed strings.

    Args:
        input_text: a single string or a sequence of strings.
        batch_size: number of texts tokenized and run through ``model`` per
            forward pass.
        max_length: tokenizer padding/truncation length.

    Returns:
        np.ndarray of shape (len(input_text), n_classes): softmax over the
        model's outputs, one row per text. Each row sums to 1. An empty input
        gives an empty (0, n_classes) array.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    texts = [input_text] if isinstance(input_text, str) else list(input_text)
    if not texts:
        # torch.cat([]) would raise; return an empty, correctly shaped array instead.
        return np.empty((0, model.linear.out_features), dtype=np.float32)

    batch_outputs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]  # slicing clamps the last batch
            inputs = tokenizer([normalize(text) for text in batch],
                               padding='max_length',
                               max_length=max_length,
                               truncation=True,
                               return_tensors="pt").to(device)
            # The model returns (scores, attentions). Keep only the scores so
            # the per-layer attention tuple is not held until the next batch ends.
            scores = model(inputs['input_ids'], inputs['attention_mask'])[0]
            batch_outputs.append(scores)
        probs = torch.softmax(torch.cat(batch_outputs, dim=0), dim=1)
    return probs.cpu().numpy()

In [None]:
# Class probabilities for the BMG reference text (columns presumably ordered as class_names: BMG, Ribbon, NR).
print(get_prob(bmg_text))

In [None]:
# Class probabilities for the Ribbon reference text (columns presumably ordered as class_names: BMG, Ribbon, NR).
print(get_prob(ribbon_text))

In [None]:
# Class probabilities for the NR reference text (columns presumably ordered as class_names: BMG, Ribbon, NR).
print(get_prob(nr_text))

In [None]:
from lime.lime_text import LimeTextExplainer
# LIME label names, indexed by get_prob's output column (0 = BMG, 1 = Ribbon, 2 = NR).
class_names = ['BMG', 'Ribbon', 'NR']
# Fixed random_state so LIME's perturbation sampling is reproducible.
explainer = LimeTextExplainer(class_names=class_names, random_state=42)

In [None]:
import pandas as pd

# Explain each reference text with respect to its own class:
# BMG text -> label 0, ribbon text -> label 1, NR text -> label 2.
LIME_NUM_FEATURES = 50   # tokens kept per explanation
LIME_NUM_SAMPLES = 1000  # perturbed copies of each text scored by get_prob

store_list = []
for label, input_text in enumerate([bmg_text, ribbon_text, nr_text]):
    exp = explainer.explain_instance(input_text, get_prob,
                                     num_features=LIME_NUM_FEATURES,
                                     num_samples=LIME_NUM_SAMPLES,
                                     labels=[label])
    weights = exp.as_list(label=label)  # list of (token, weight) pairs
    store_list.append(weights)
    print('Explanation for class %s' % class_names[label])
    print('\n'.join(map(str, weights)))
    print()

# One row per class (in the order above), one column per (token, weight) tuple.
# The tuples end up in the sheet as text; the loading cell parses them back.
df = pd.DataFrame(store_list)
df.to_excel("lime_output_list.xlsx")


In [None]:
import pandas as pd
# index_col=0: the first sheet column is the row index written by to_excel
# (0, 1, 2). Without it that column is read as data ('Unnamed: 0'), and after
# .T its integer values become the first entry of every class column, which the
# parsing cell cannot eval.
# After .T: one column per class (0 = BMG, 1 = Ribbon, 2 = NR) and one row per
# LIME feature. Each cell is a "(token, weight)" string.
df = pd.read_excel('lime_output_list.xlsx', index_col=0).T

In [None]:
import ast


def _parse_lime_column(column):
    """Turn one class column of the LIME sheet into a list of (token, weight) tuples.

    Cells hold the string repr of a tuple. They are parsed with ast.literal_eval,
    which accepts literals only, instead of eval, which would execute any
    expression found in the spreadsheet. Non-string cells are skipped; examples
    are NaN padding when an explanation has fewer features, or the integer row
    index when the sheet was read without index_col.
    """
    return [ast.literal_eval(cell) for cell in column.tolist() if isinstance(cell, str)]


bmg_data = _parse_lime_column(df[0])
ribbon_data = _parse_lime_column(df[1])
nr_data = _parse_lime_column(df[2])

In [None]:
# Split each list of (token, weight) pairs into parallel lists:
# *_label = tokens (plot tick labels), *_value = LIME weights (bar heights).
bmg_label = [i[0] for i in bmg_data]
bmg_value = [i[1] for i in bmg_data]
ribbon_label = [i[0] for i in ribbon_data]
ribbon_value = [i[1] for i in ribbon_data]
nr_label = [i[0] for i in nr_data]
nr_value = [i[1] for i in nr_data]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Polar bar chart of the LIME token weights explaining the BMG prediction.
# NOTE(review): this chart starts at index 1, skipping the top-ranked token,
# while the Ribbon/NR charts start at 0. Confirm the asymmetry is intended.
labels = bmg_label[1:25]
values = bmg_value[1:25]

fontsize = 18

# One evenly spaced angle per bar. Unlike a radar (line) chart, a bar chart
# must not repeat the first point at the end; that only drew a duplicate bar
# on top of bar 0.
num_vars = len(labels)
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False)
bar_width = 2 * np.pi / num_vars * 0.8

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
ax.bar(angles, values, color='#a3bded', linewidth=2, width=bar_width)

# Zero angle at 12 o'clock, angles increasing clockwise.
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)

ax.set_xticks(angles)
ax.set_xticklabels(labels, fontsize=fontsize, fontweight='bold')
ax.tick_params(axis='y', labelsize=fontsize, rotation=-18)

# Align labels away from the plot: centred at top/bottom, left-aligned on the
# right half (0 < angle < pi with the clockwise layout), right-aligned on the
# left half. np.isclose avoids relying on exact float equality with linspace output.
for label, angle in zip(ax.get_xticklabels(), angles):
    if np.isclose(angle, 0) or np.isclose(angle, np.pi):
        label.set_horizontalalignment('center')
    elif 0 < angle < np.pi:
        label.set_horizontalalignment('left')
    else:
        label.set_horizontalalignment('right')

ax.xaxis.grid(True, color='grey', linestyle='--', linewidth=1)
ax.yaxis.grid(True, color='grey', linestyle='--', linewidth=1)

fig.savefig('bmg_lime.svg', dpi=600, bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Polar bar chart of the top 25 LIME token weights explaining the Ribbon prediction.
labels = ribbon_label[:25]
values = ribbon_value[:25]

fontsize = 18

# One evenly spaced angle per bar. Unlike a radar (line) chart, a bar chart
# must not repeat the first point at the end; that only drew a duplicate bar
# on top of bar 0.
num_vars = len(labels)
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False)
bar_width = 2 * np.pi / num_vars * 0.8

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
ax.bar(angles, values, color='#fcb69f', linewidth=2, width=bar_width)

# Zero angle at 12 o'clock, angles increasing clockwise.
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)

ax.set_xticks(angles)
ax.set_xticklabels(labels, fontsize=fontsize, fontweight='bold')
ax.tick_params(axis='y', labelsize=fontsize, rotation=-18)

# Align labels away from the plot: centred at top/bottom, left-aligned on the
# right half (0 < angle < pi with the clockwise layout), right-aligned on the
# left half. np.isclose avoids relying on exact float equality with linspace output.
for label, angle in zip(ax.get_xticklabels(), angles):
    if np.isclose(angle, 0) or np.isclose(angle, np.pi):
        label.set_horizontalalignment('center')
    elif 0 < angle < np.pi:
        label.set_horizontalalignment('left')
    else:
        label.set_horizontalalignment('right')

ax.xaxis.grid(True, color='grey', linestyle='--', linewidth=1)
ax.yaxis.grid(True, color='grey', linestyle='--', linewidth=1)

fig.savefig('ribbon_lime.svg', dpi=600, bbox_inches='tight')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Polar bar chart of the top 25 LIME token weights explaining the NR prediction.
labels = nr_label[:25]
values = nr_value[:25]

fontsize = 18

# One evenly spaced angle per bar. Unlike a radar (line) chart, a bar chart
# must not repeat the first point at the end; that only drew a duplicate bar
# on top of bar 0.
num_vars = len(labels)
angles = np.linspace(0, 2 * np.pi, num_vars, endpoint=False)
bar_width = 2 * np.pi / num_vars * 0.8

fig, ax = plt.subplots(figsize=(12, 12), subplot_kw=dict(polar=True))
ax.bar(angles, values, color='#43e97b', linewidth=2, width=bar_width)

# Zero angle at 12 o'clock, angles increasing clockwise.
ax.set_theta_offset(np.pi / 2)
ax.set_theta_direction(-1)

ax.set_xticks(angles)
ax.set_xticklabels(labels, fontsize=fontsize, fontweight='bold')
ax.tick_params(axis='y', labelsize=fontsize, rotation=-18)

# Align labels away from the plot: centred at top/bottom, left-aligned on the
# right half (0 < angle < pi with the clockwise layout), right-aligned on the
# left half. np.isclose avoids relying on exact float equality with linspace output.
for label, angle in zip(ax.get_xticklabels(), angles):
    if np.isclose(angle, 0) or np.isclose(angle, np.pi):
        label.set_horizontalalignment('center')
    elif 0 < angle < np.pi:
        label.set_horizontalalignment('left')
    else:
        label.set_horizontalalignment('right')

ax.xaxis.grid(True, color='grey', linestyle='--', linewidth=1)
ax.yaxis.grid(True, color='grey', linestyle='--', linewidth=1)

fig.savefig('nr_lime.svg', dpi=600, bbox_inches='tight')
plt.show()

In [None]:
# with torch.no_grad():
#     for test_input, test_label in data_loader:
#         # Use the GPU if available; the following steps are the same as in training
#         # test_label = test_label.to(device)
#         # token_type_ids = test_input['token_type_ids'].to(device)
#         # attention_mask = test_input['attention_mask'].to(device)
#         # input_ids = test_input['input_ids'].squeeze(1).to(device)
#         # Cu20Hf65Ni15.txt
#         input_text = "Composition Information: [Cu20Hf65Ni15 consists of 20% Copper, 65% Hafnium, and 15% Nickel]."
#         inputs = tokenizer(normalize(input_text),
#                                 padding='max_length', 
#                                 max_length = 900, 
#                                 truncation=True,
#                                 return_tensors="pt").to(device)
#         output, attention = model(inputs['input_ids'], inputs['attention_mask'])
#         # output, attention = model(input_ids, attention_mask)
#         break

In [None]:
# Global attention visualization
# should combine with model loading
# from bertviz import head_view, model_view
# tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
# # tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
# a = model_view(attention, tokens, html_action='return')
# b = head_view(attention, tokens, html_action='return')
# with open('short_model_view.html', 'w') as file:
#     file.write(a.data)

# with open('short_head_view.html', 'w') as file:
#     file.write(b.data)

In [None]:
## Last-layer attention visualization
# Tokenize the NR reference text the same way get_prob does (normalize, then
# pad/truncate to 900 tokens). Keep the token strings so attention positions
# can be mapped back to words.
inputs = tokenizer(normalize(nr_text),
                   padding='max_length',
                   max_length=900,
                   truncation=True,
                   return_tensors="pt").to(device)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
# Inference only. Without no_grad this 900-token forward pass, which also
# returns every layer's attention map, would record an autograd graph. That
# wastes memory and leaves `output`/`attention` attached to it.
with torch.no_grad():
    output, attention = model(inputs['input_ids'], inputs['attention_mask'])


In [None]:
# Locate segment boundaries in the tokenized nr_text: positions of the tokens
# 'element' and 'alloy' and of the [SEP] token.
print(tokens)
print(len(tokens))
p1 = [i for i, token in enumerate(tokens) if token == 'element']
p2 = [i for i, token in enumerate(tokens) if token == 'alloy']
p3 = [i for i, token in enumerate(tokens) if token == '[SEP]']
print(p1)
print(p2)
print(p3)
# Spot-check the hard-coded boundaries (31, 316) used by the segment-attention
# cell. NOTE(review): these indices presumably come from the printout above and
# are specific to nr_text.
print(tokens[31:33])
print(tokens[316:318])

In [None]:
# Attention probabilities of the final encoder layer (last element of the
# per-layer `attentions` tuple). NOTE(review): per the Hugging Face docs each
# entry has shape (batch, heads, seq, seq).
last_layer_attention = attention[-1]

In [None]:
# Average over dim=1 (the heads, assuming a (batch, heads, seq, seq) layout) and
# drop the size-1 batch dimension, giving a (seq, seq) matrix.
mean_last_layer_attention = last_layer_attention.mean(dim=1).squeeze(0)
print(mean_last_layer_attention.shape)
print(mean_last_layer_attention)
# Row 0: how the first token (presumably [CLS]) distributes its attention over
# all positions. NOTE: BERT's pooler_output, which BertClassifier classifies,
# is derived from this first token (per the HF BertModel docs).
cls_attention = mean_last_layer_attention[0]
print(cls_attention.shape)
# Sanity check: each attention row is a softmax, so this should be ~1.
print(cls_attention.sum())


In [None]:
# Split the [CLS] attention row into three description segments and compare
# their attention density (share of attention / share of tokens), so that
# longer segments are not favoured merely for containing more tokens.
# Boundaries are token indices for nr_text, presumably read off the p1/p2/p3
# printout of the previous cell. SEQ_END is the exclusive end of the scored span.
# NOTE(review): recompute these if the input text changes.
COMP_END, ELEM_END, SEQ_END = 31, 316, 370

comp_layer = cls_attention[:COMP_END].sum()
elem_layer = cls_attention[COMP_END:ELEM_END].sum()
alloy_layer = cls_attention[ELEM_END:SEQ_END].sum()
print(comp_layer, elem_layer, alloy_layer)

# Attention mass divided by each segment's fraction of the SEQ_END tokens.
c_e = comp_layer / (COMP_END / SEQ_END)
e_e = elem_layer / ((ELEM_END - COMP_END) / SEQ_END)
a_e = alloy_layer / ((SEQ_END - ELEM_END) / SEQ_END)
print(c_e, e_e, a_e)

# Normalise the three densities so they sum to 1.
total_density = c_e + e_e + a_e
ce_percentage = c_e / total_density
ee_percentage = e_e / total_density
ae_percentage = a_e / total_density
print(ce_percentage, ee_percentage, ae_percentage)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Confusion-matrix counts. The 'Unnamed: 0' column (presumably the row index
# written alongside the counts) is dropped.
df = pd.read_excel('confusion_matrix——5577.xlsx', usecols=lambda x: x != 'Unnamed: 0')
# r_* = true class (rows), p_* = predicted class (columns).
row_names = ['r_BMG', 'r_Ribbon', 'r_NR']
col_names = ['p_BMG', 'p_Ribbon', 'p_NR']

fig, ax = plt.subplots(figsize=(8, 6))
heatmap = sns.heatmap(df, annot=True, cmap='Blues', fmt='g', xticklabels=col_names,
                      yticklabels=row_names, annot_kws={"fontsize": 14}, ax=ax)
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=12)
ax.tick_params(axis='both', labelsize=12)
ax.set_xlabel('Predicted label', fontsize=12)
ax.set_ylabel('True label', fontsize=12)
ax.set_title('Confusion Matrix', fontsize=12)
# bbox_inches='tight', as for the other saved figures, so labels are not clipped.
fig.savefig('confusion_matrix_5577.svg', dpi=600, bbox_inches='tight')
plt.show()
