In [None]:
镜像：flowmm:0612

机型：c4_m15_1 * NVIDIA T4

本文介绍使用 baseline 模型进行比赛的提交流程。Baseline 模型是 DiffCSP 模型，该模型没有使用 PXRD 信息，只根据晶胞内的原子成分预测晶体结构。模型的技术细节请参考论文：https://arxiv.org/abs/2309.04475

模型代码参见https://github.com/jiaor17/DiffCSP

这里的baseline 模型直接使用了论文作者提供的mpts_52_csp checkpoint.

比赛的A榜数据在数据集PXRD-test-expo(v2)中

参赛注意事项
后台打分时使用的环境
网络：后台打分时不会联网，因此 notebook 中下载软件/包的命令都无法执行。请提前创建好你需要的环境。
环境：选手提交 notebook 的时候，可以在右上角选择镜像。可以选择 baseline notebook 的镜像，也可以使用自定义镜像。
数据：AB 榜的数据（晶胞内原子成分和谱）不需要额外挂载数据集，按照 baseline notebook 的示例读取即可。如果你的 notebook 会用到额外的数据集，也可以自行定义并挂载数据集。

In [None]:
# Path configuration for the baseline run.
import os

# Root of the DiffCSP checkout; later cells run the model from this directory
# and read the checkpoint from beneath it.
MODEL_PATH = "/bohr/test-3h1m/v5/DiffCSP"

# DATA_PATH is only defined when the DATA_PATH environment variable is set to
# a non-empty value. Otherwise it is deliberately left undefined and a
# bilingual notice is printed so the resulting downstream error is expected.
if not os.environ.get('DATA_PATH'):
    print("Baseline运行时，因为无法读取测试集，所以会有此条报错，属于正常现象")
    print("When the baseline is running, this error message will appear because the test set cannot be read, which is a normal phenomenon.")
else:
    DATA_PATH = os.environ["DATA_PATH"]

# Echo the key path so it can be checked in the cell output.
print("MODEL_PATH:", MODEL_PATH)

In [None]:
import os

# Ensure the directory the inference step writes to exists; exist_ok lets this
# cell be re-run without raising when the directory is already there.
os.makedirs(name="/baseline", exist_ok=True)

**使用 baseline 模型进行预测**

In [None]:
import os
import sys
project_root = MODEL_PATH
script_name = "scripts/generate_cif_from_model.py"
ckpt_path = f"{MODEL_PATH}/mpts_52_ckpt"
command = f"""
cd {project_root} && \
PYTHONPATH={project_root} python scripts/evaluate.py \
--model_path {ckpt_path} \
--dataset mpts_52 \
--output_dir /baseline \
--composition_file_path {DATA_PATH}/composition.json
"""
!{command}


上面的单元格运行结束之后，/baseline 路径下会生成一个 eval_diff.pt 文件，其中存储了解析得到的 2000 个晶体结构。

接下来演示如何准备submission.csv文件

In [None]:
import torch
from pymatgen.core import Structure, Lattice, Element
from tqdm import tqdm
import pandas as pd

# Load the prediction file written by the inference step, mapping every tensor
# onto the CPU.
data = torch.load("/baseline/eval_diff.pt", map_location='cpu')

# Per the notebook author's notes, each stored tensor carries a leading batch
# axis of size 1 (frac_coords [1, 27820, 3], num_atoms [1, 2000],
# atom_types [1, 27820], lattices [1, 2000, 3, 3]) -- not re-verified here.
# squeeze(0) drops that axis; .cpu() keeps the tensors convertible to numpy.
frac_coords = data['frac_coords'].squeeze(0).cpu()
# Atom counts are cast to integers because they are used as slice lengths.
num_atoms = data['num_atoms'].squeeze(0).cpu().long()
atom_types = data['atom_types'].squeeze(0).cpu()
lattices = data['lattices'].squeeze(0).cpu()

# Sanity checks: the per-crystal atom counts must add up to the number of rows
# in the flat coordinate and atom-type tensors.
total_atoms = num_atoms.sum().item()
assert total_atoms == frac_coords.shape[0], "原子总数不匹配"
assert len(atom_types) == total_atoms, "原子类型数量不匹配"

# Rebuild one pymatgen Structure per crystal from the flat per-atom tensors.
structures = []
start = 0  # row offset of the current crystal's first atom
for crystal_index, atom_count in enumerate(num_atoms.tolist()):
    stop = start + atom_count

    # Rows start:stop of the flat tensors belong to this crystal.
    site_coords = frac_coords[start:stop].numpy()
    atomic_numbers = atom_types[start:stop].numpy().astype(int)

    structures.append(
        Structure(
            Lattice(lattices[crystal_index].numpy()),
            # Atomic numbers are mapped to pymatgen Element objects.
            [Element.from_Z(z) for z in atomic_numbers],
            site_coords,
        )
    )

    # The next crystal starts where this one ended.
    start = stop

# One Structure has been collected per entry of num_atoms.
print(f"成功生成{len(structures)}个晶体结构")


# Build the submission table: one row per structure with columns ID and cif.
header = ["ID", "cif"]
rows = []

# English gloss of the note below: the competition has separate A and B
# leaderboards whose structure IDs use different prefixes. The notebook must
# work out the prefix itself, otherwise the submission receives no score.
"""
!!!!!!!重要!!!!!!!!!!:
比赛区分A,B榜，因此结构的ID的前缀不同，notebook 需要可以判断出结构ID前缀是A还是B, 否则会导致没有分数!!!
"""
import json
with open(f"{DATA_PATH}/composition.json", 'r') as composition_file:
    composition_dict = json.load(composition_file)
# The prefix is the first character of the first key in composition.json.
first_key = next(iter(composition_dict))
prefix = first_key[0]
print(prefix)

# Serialise every structure to CIF text; IDs are "<prefix>-<n>" with n counting
# from 1 in the order the structures were generated.
for position, crystal in enumerate(tqdm(structures, desc="Generating CIF files"), start=1):
    rows.append([f"{prefix}-{position}", crystal.to(fmt="cif")])

# Assemble the collected rows into a two-column DataFrame.
df = pd.DataFrame(rows, columns=header)

# English gloss of the note below: submission.csv must be written to the
# current directory ("."), otherwise the submission receives no score.
"""
!!!!!!!!!重要!!!!!!!
产生的submission.csv 必须在. 目录下，否则没有得分
"""
df.to_csv(f"submission.csv", index=False)