In [9]:
import torch
from petsc4py import PETSc

def torch_coo_to_petsc(tensor: torch.Tensor) -> PETSc.Mat:
    """
    PyTorchのCOO形式の疎行列をPETScの疎行列に変換する関数。

    :param tensor: PyTorchのCOO形式の疎行列（torch.sparse_coo_tensor）
    :return: PETScの疎行列（PETSc.Mat）
    """
    tensor = tensor.coalesce()
    indices = tensor.indices().numpy().astype("int32")
    values = tensor.values().numpy()
    shape = tensor.shape

    A = PETSc.Mat().createAIJ(size=shape)
    A.setValues(indices[0], indices[1], values)
    A.assemble()

    return A

# 使用例
if __name__ == "__main__":
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    size = [2, 3]

    coo_tensor = torch.sparse_coo_tensor(indices, values, size)
    petsc_mat = torch_coo_to_petsc(coo_tensor)
    petsc_mat.view()

ValueError: incompatible array sizes: ni=3, nj=3, nv=3

In [10]:
import torch
import petsc4py
import numpy as np

# PETScの初期化
petsc4py.init()
from petsc4py import PETSc

# PyTorch Tensorを作成
tensor = torch.rand(5, 5)  # 5x5のランダムなTensor

# TensorをNumPy配列に変換
numpy_array = tensor.numpy()

# PETSc Matを作成
A = PETSc.Mat().create()
A.setSizes(numpy_array.shape)
A.setType('dense')  # デフォルトの形式（密行列）
A.setUp()

# NumPy配列をPETSc Matにコピー
A[:, :] = numpy_array

# PETSc Matを使用
A.assemble()

# 確認のために出力
print(A)

<petsc4py.PETSc.Mat object at 0x7f04ec405120>


In [11]:
A.view()

Mat Object: 1 MPI processes
  type: seqdense
6.7152512073516846e-01 2.1587842702865601e-01 3.9931350946426392e-01 7.5507098436355591e-01 4.4482547044754028e-01 
8.6070072650909424e-01 1.9142699241638184e-01 7.2929805517196655e-01 5.2156823873519897e-01 2.1504819393157959e-02 
6.4386016130447388e-01 5.8503919839859009e-01 7.6476496458053589e-01 9.4866359233856201e-01 1.4169079065322876e-01 
3.1605482101440430e-02 9.5142477750778198e-01 9.3237513303756714e-01 2.1590763330459595e-01 2.7173995971679688e-01 
9.1019421815872192e-01 2.9157775640487671e-01 7.4407005310058594e-01 6.0704129934310913e-01 7.2583508491516113e-01 


In [12]:
torch.Tensor(1.0)


TypeError: new(): data must be a sequence (got float)