<a href="https://colab.research.google.com/github/Takumi-sekigen/gnn_share/blob/main/dglpytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Remove the preinstalled torch / torchvision / torchaudio packages (-y skips the
# confirmation prompt) so the next cell can install an explicitly pinned build.
!pip uninstall torch torchvision torchaudio -y

Found existing installation: torch 2.1.0+cu121
Uninstalling torch-2.1.0+cu121:
  Successfully uninstalled torch-2.1.0+cu121
Found existing installation: torchvision 0.16.0+cu121
Uninstalling torchvision-0.16.0+cu121:
  Successfully uninstalled torchvision-0.16.0+cu121
Found existing installation: torchaudio 2.1.0+cu121
Uninstalling torchaudio-2.1.0+cu121:
  Successfully uninstalled torchaudio-2.1.0+cu121


In [2]:
# Install pinned torch / torchvision / torchaudio versions, resolving packages only
# from the PyTorch cu121 wheel index given by --index-url.
!pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch==2.1.0
  Using cached https://download.pytorch.org/whl/cu121/torch-2.1.0%2Bcu121-cp310-cp310-linux_x86_64.whl (2200.6 MB)
Collecting torchvision==0.16.0
  Using cached https://download.pytorch.org/whl/cu121/torchvision-0.16.0%2Bcu121-cp310-cp310-linux_x86_64.whl (7.0 MB)
Collecting torchaudio==2.1.0
  Using cached https://download.pytorch.org/whl/cu121/torchaudio-2.1.0%2Bcu121-cp310-cp310-linux_x86_64.whl (3.3 MB)
Installing collected packages: torch, torchvision, torchaudio
Successfully installed torch-2.1.0+cu121 torchaudio-2.1.0+cu121 torchvision-0.16.0+cu121


In [3]:
# Install DGL, letting pip also look for wheels on the DGL page for torch 2.1 / cu121 (-f).
# NOTE(review): dgl itself is not version-pinned here, so a fresh runtime may resolve a
# different release — consider pinning once the working version is known.
!pip install  dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html

Looking in links: https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html


In [4]:
import torch
import torchvision
import torchaudio

# Report the installed library versions, the CUDA version torch was built with,
# and whether a CUDA device is currently usable.
version_report = (
    ("PyTorch version:", torch.__version__),
    ("Torchvision version:", torchvision.__version__),
    ("Torchaudio version:", torchaudio.__version__),
    ("CUDA version:", torch.version.cuda),
    ("CUDA available:", torch.cuda.is_available()),
)
for report_label, report_value in version_report:
    print(report_label, report_value)

PyTorch version: 2.1.0+cu121
Torchvision version: 0.16.0+cu121
Torchaudio version: 2.1.0+cu121
CUDA version: 12.1
CUDA available: False


In [6]:
# Mount Google Drive into the Colab filesystem at /content/drive; the data paths
# used later in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [31]:
import pandas as pd
import numpy as np
from scipy.spatial import Delaunay, Voronoi
import matplotlib.pyplot as plt
import skimage.io as io
import torch
import torch.nn as nn
import torch.optim as optim
import dgl
from dgl.nn import GraphConv

#===========================================================
# 1. Load the input tables
#===========================================================
# Input locations on the mounted Google Drive. image_file is only defined here;
# this section does not open it.
spots_file = "/content/drive/MyDrive/Colab Notebooks/data_demo/psRatio_demo512_spots.csv"
edges_file = "/content/drive/MyDrive/Colab Notebooks/data_demo/psRatio_demo512_edges.csv"
image_file = "/content/drive/MyDrive/Colab Notebooks/data_demo/psRatio_demo512_3frame.tif"

# Read both CSV tables with pandas' default parsing options.
spots_df, edges_df = (pd.read_csv(csv_path) for csv_path in (spots_file, edges_file))

#===========================================================
# 2. Keep only the required columns
#===========================================================
spot_columns = ["ID", "POSITION_X", "POSITION_Y", "FRAME", "MAX_INTENSITY_CH1"]
edge_columns = ["SPOT_SOURCE_ID", "SPOT_TARGET_ID"]

# Select the columns, drop the first three rows (extra header-like rows that
# carry no measurements) and renumber the index from zero.
spots_df = spots_df[spot_columns].iloc[3:].reset_index(drop=True)
edges_df = edges_df[edge_columns].iloc[3:].reset_index(drop=True)

# Preview both tables right after loading.
print("=== 読み込み直後のspotsデータ ===")
print(spots_df.head())
print("=== 読み込み直後のedgesデータ ===")
print(edges_df.head())

#===========================================================
# 3. Type conversion
#===========================================================
# Cast every column used downstream to a numeric dtype in one pass per table:
# identifiers and frame numbers become integers, coordinates and intensity floats.
spots_df = spots_df.astype({
    "ID": int,
    "FRAME": int,
    "POSITION_X": float,
    "POSITION_Y": float,
    "MAX_INTENSITY_CH1": float,
})

edges_df = edges_df.astype({
    "SPOT_SOURCE_ID": int,
    "SPOT_TARGET_ID": int,
})

#===========================================================
# 4. Assign new_ID
#===========================================================
# Sorting by FRAME first is possible but is not done here.

# new_ID is simply the current row position; id_map translates an original
# spot ID into that row position.
spots_df["new_ID"] = spots_df.index
id_map = dict(zip(spots_df["ID"], spots_df["new_ID"]))

# Rewrite both edge endpoints from original IDs to new_IDs.
endpoint_columns = ["SPOT_SOURCE_ID", "SPOT_TARGET_ID"]
for endpoint_column in endpoint_columns:
    edges_df[endpoint_column] = edges_df[endpoint_column].map(id_map)

# Endpoints missing from id_map became NaN: drop those edges, then restore integers.
edges_df = edges_df.dropna(subset=endpoint_columns).astype(int)

# Keep only edges whose two endpoints do not exceed the largest assigned new_ID.
max_id = spots_df["new_ID"].max()
endpoints_in_range = edges_df[endpoint_columns].le(max_id).all(axis=1)
edges_df = edges_df[endpoints_in_range]

#===========================================================
# 5. Voronoi tessellation used to attach area and perimeter
#===========================================================
# One Voronoi diagram over every spot position in spots_df, with all frames pooled.
# NOTE(review): pooling frames means positions from other time points take part in
# the same tessellation — confirm this is what is wanted.
all_coords = spots_df.loc[:, ["POSITION_X", "POSITION_Y"]].to_numpy()
vor = Voronoi(all_coords)

def polygon_area_and_perimeter(coords):
    """Return ``(area, perimeter)`` of the closed polygon described by ``coords``.

    ``coords`` is indexed as a 2-D array whose column 0 holds x and column 1
    holds y, with the vertices listed in order around the polygon. The polygon
    is closed implicitly: the last vertex is joined back to the first.
    """
    xs, ys = coords[:, 0], coords[:, 1]

    # Shoelace formula; the absolute value makes the result independent of
    # whether the vertices run clockwise or counter-clockwise.
    signed_double_area = np.dot(xs, np.roll(ys, 1)) - np.dot(ys, np.roll(xs, 1))
    area = 0.5 * np.abs(signed_double_area)

    # Lengths of the edges between consecutive vertices ...
    consecutive_steps = np.diff(coords, axis=0)
    perimeter = np.sum(np.sqrt(np.sum(consecutive_steps**2, axis=1)))
    # ... plus the closing edge between the first and the last vertex.
    closing_step = coords[0] - coords[-1]
    perimeter += np.sqrt(np.sum(closing_step**2))

    return area, perimeter

def _frame_voronoi_area_perimeter(frame_coords):
    """Return an ``(n, 2)`` array of ``[area, perimeter]``, one row per input point.

    A Voronoi diagram is built from ``frame_coords`` alone and each point's cell is
    measured with ``polygon_area_and_perimeter``. A row stays NaN when the point has
    no usable bounded cell: fewer than four points in total, no region assigned to
    the point, an unbounded region (contains vertex index -1), or a region with
    fewer than three vertices. Other Qhull failures are not caught and propagate.
    """
    n_points = len(frame_coords)
    measures = np.full((n_points, 2), np.nan)
    # With fewer than four points no Voronoi cell can be bounded (and Qhull may
    # refuse to run at all), so every row stays NaN.
    if n_points < 4:
        return measures

    frame_vor = Voronoi(frame_coords)
    for point_idx, region_index in enumerate(frame_vor.point_region):
        # -1 marks a point without a region; using it as an index would silently
        # pick up the *last* region instead.
        if region_index < 0:
            continue
        vertices = frame_vor.regions[region_index]
        if len(vertices) < 3 or -1 in vertices:
            continue
        measures[point_idx] = polygon_area_and_perimeter(frame_vor.vertices[vertices])
    return measures


# Measure every cell inside its own frame. The spatial graph built later is also
# per-frame, and a tessellation pooled over all frames lets the same cell's
# positions at other time points carve up its region. The pooled `vor` computed
# above is left untouched but is not used for these features.
spots_df["area"] = np.nan
spots_df["perimeter"] = np.nan
for _, frame_spots in spots_df.groupby("FRAME"):
    frame_measures = _frame_voronoi_area_perimeter(
        frame_spots[["POSITION_X", "POSITION_Y"]].values
    )
    spots_df.loc[frame_spots.index, "area"] = frame_measures[:, 0]
    spots_df.loc[frame_spots.index, "perimeter"] = frame_measures[:, 1]

# Kept for compatibility: one (area, perimeter) tuple per row of spots_df.
regions = list(zip(spots_df["area"], spots_df["perimeter"]))

print("\n=== 面積・周長付与後のspots_df ===")
print(spots_df.head())

#===========================================================
# 6. Velocity from the temporal edges
#===========================================================
# Attach position and frame of the source spot (unsuffixed columns) and of the
# target spot ("_target" columns) to every edge.
spot_lookup = spots_df[["new_ID", "POSITION_X", "POSITION_Y", "FRAME"]]
merged = edges_df.merge(
    spot_lookup, how="left",
    left_on="SPOT_SOURCE_ID", right_on="new_ID", suffixes=("", "_source"),
)
merged = merged.merge(
    spot_lookup, how="left",
    left_on="SPOT_TARGET_ID", right_on="new_ID", suffixes=("", "_target"),
)

# Velocity is the plain displacement from source to target position.
# NOTE(review): it is not divided by the frame difference — confirm that every
# edge links consecutive frames.
merged = merged.assign(
    vel_x=merged["POSITION_X_target"] - merged["POSITION_X"],
    vel_y=merged["POSITION_Y_target"] - merged["POSITION_Y"],
)

print("\n=== 時間エッジ対応テーブル ===")
print(merged.head())

# One velocity per source spot: when a source has several outgoing edges, only
# its first row is kept.
velocity_df = merged[["SPOT_SOURCE_ID", "vel_x", "vel_y"]].drop_duplicates("SPOT_SOURCE_ID")
velocity_map_x = dict(zip(velocity_df["SPOT_SOURCE_ID"], velocity_df["vel_x"]))
velocity_map_y = dict(zip(velocity_df["SPOT_SOURCE_ID"], velocity_df["vel_y"]))

# Spots that are never an edge source get NaN velocities.
spots_df = spots_df.assign(
    vel_x=spots_df["new_ID"].map(velocity_map_x),
    vel_y=spots_df["new_ID"].map(velocity_map_y),
)

print("\n=== 速度ベクトル付与後のspots_df ===")
print(spots_df.head())

# Spots in the final frame are expected to have NaN velocity, and area / perimeter
# may contain NaN as well, so drop every row with a NaN in any feature column.
feature_columns = ["MAX_INTENSITY_CH1", "area", "perimeter", "vel_x", "vel_y"]
spots_df = spots_df.dropna(subset=feature_columns).reset_index(drop=True)

print("\n=== 全特徴量NaN除去後のspots_df ===")
print(spots_df.head())
print("NaNの有無確認:\n", spots_df[feature_columns].isna().sum())

#===========================================================
# 7. Re-assign new_ID (the NaN filter removed rows, so ids must be compacted)
#===========================================================
# On entry spots_df["new_ID"] still holds the first-pass ids, and the endpoints in
# edges_df are expressed in those same first-pass ids (they were already translated
# from the original "ID" values in section 4). They therefore have to be translated
# first-pass id -> compacted id. Looking them up in a dict keyed by the original
# "ID" would drop valid edges or attach them to unrelated spots.
first_pass_ids = spots_df["new_ID"].tolist()
spots_df["new_ID"] = np.arange(len(spots_df))
new_id_remap = dict(zip(first_pass_ids, spots_df["new_ID"]))

# Original spot ID -> compacted new_ID, kept under the same name as before.
id_map = dict(zip(spots_df["ID"], spots_df["new_ID"]))

edges_df = edges_df.assign(
    SPOT_SOURCE_ID=edges_df["SPOT_SOURCE_ID"].map(new_id_remap),
    SPOT_TARGET_ID=edges_df["SPOT_TARGET_ID"].map(new_id_remap),
)
# An endpoint whose spot was removed by the NaN filter maps to NaN: drop that edge.
edges_df = edges_df.dropna(subset=["SPOT_SOURCE_ID","SPOT_TARGET_ID"]).astype(int)

max_id = spots_df["new_ID"].max()
edges_df = edges_df[(edges_df["SPOT_SOURCE_ID"] <= max_id) & (edges_df["SPOT_TARGET_ID"] <= max_id)]

#===========================================================
# 8. Recompute the Delaunay edges (on the finalized spots_df)
#===========================================================
all_frames = np.sort(spots_df["FRAME"].unique())

# Vertex-position pairs forming the three sides of a triangle.
triangle_sides = ((0, 1), (0, 2), (1, 2))

delaunay_edges = []
for fr in all_frames:
    frame_spots = spots_df.loc[spots_df["FRAME"] == fr]
    coords = frame_spots[["POSITION_X", "POSITION_Y"]].values
    # Frames with two spots or fewer are skipped.
    if len(coords) <= 2:
        continue
    tri = Delaunay(coords)
    # Simplex entries index into this frame's rows; translate them to new_IDs.
    frame_new_ids = frame_spots["new_ID"].values
    for simplex in tri.simplices:
        delaunay_edges.extend(
            (frame_new_ids[simplex[a]], frame_new_ids[simplex[b]])
            for a, b in triangle_sides
        )

# Remove exact duplicate (src, dst) pairs.
delaunay_edges = list(set(delaunay_edges))
print("\n=== 再計算後のDelaunayエッジ例 ===")
print(delaunay_edges[:10])

# Range check: no spatial or temporal edge endpoint may exceed max_id.
all_edge_ids = (
    [src for src, _ in delaunay_edges]
    + [dst for _, dst in delaunay_edges]
    + edges_df["SPOT_SOURCE_ID"].tolist()
    + edges_df["SPOT_TARGET_ID"].tolist()
)
assert max(all_edge_ids) <= max_id

num_nodes = max_id + 1

#===========================================================
# 9. Build the graph for the GNN
#===========================================================
# Spatial edges come from the Delaunay triangulations, temporal edges from edges_df.
space_edges_src = [src for src, _ in delaunay_edges]
space_edges_dst = [dst for _, dst in delaunay_edges]
time_edges_src = edges_df["SPOT_SOURCE_ID"].values
time_edges_dst = edges_df["SPOT_TARGET_ID"].values

# Concatenate both edge sets into one source list and one destination list.
all_src = torch.tensor(space_edges_src + list(time_edges_src), dtype=torch.int64)
all_dst = torch.tensor(space_edges_dst + list(time_edges_dst), dtype=torch.int64)

assert all_src.max().item() <= max_id
assert all_dst.max().item() <= max_id

g = dgl.graph((all_src, all_dst), num_nodes=num_nodes)
# If this DGL version has no to_bidirected, add_reverse_edges() is the alternative.
g = dgl.to_bidirected(g)  # or: g = dgl.add_reverse_edges(g)

# Node feature matrix, one row per spot, columns in the order listed below.
feature_matrix = spots_df[["MAX_INTENSITY_CH1", "area", "perimeter", "vel_x", "vel_y"]].values
node_features = torch.tensor(feature_matrix, dtype=torch.float32)
# The regression targets are the last two feature columns (vel_x, vel_y).
# NOTE(review): the targets are therefore also part of the model input — confirm
# that this is intended and not target leakage.
targets = node_features[:, -2:]

# Sanity check once more that no NaN reached the tensors.
print("Check NaN in node_features:", torch.isnan(node_features).any().item())
print("Check NaN in targets:", torch.isnan(targets).any().item())

g.ndata['feat'] = node_features

class SimpleGNN(nn.Module):
    """Two stacked ``GraphConv`` layers with a ReLU between them.

    Args:
        in_feats: size of each input node feature vector.
        hidden_feats: output size of the first layer.
        out_feats: output size of the second layer.
    """

    def __init__(self, in_feats, hidden_feats, out_feats):
        super().__init__()
        self.conv1 = GraphConv(in_feats, hidden_feats)
        self.conv2 = GraphConv(hidden_feats, out_feats)
        self.relu = nn.ReLU()

    def forward(self, graph, feat):
        """Apply conv1, ReLU and conv2 on ``graph`` to ``feat``; no activation on the output."""
        hidden = self.relu(self.conv1(graph, feat))
        return self.conv2(graph, hidden)

# Seed torch's RNG before the model is created so that weight initialisation,
# and therefore the printed losses and predictions, are repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

# The loss is computed only on nodes of the second-to-last frame.
# Note: all_frames is taken from spots_df as it is now, i.e. after the NaN
# filtering above, so "second-to-last" refers to the frames still present.
all_frames = np.sort(spots_df["FRAME"].unique())
second_last_frame = all_frames[-2]
frame_nodes = spots_df["FRAME"].values
frame_nodes_tensor = torch.tensor(frame_nodes, dtype=torch.int64)
mask = (frame_nodes_tensor == second_last_frame)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
in_feats = node_features.shape[1]
hidden_feats = 32
out_feats = 2

# Move the model, the graph and every tensor to the same device.
model = SimpleGNN(in_feats, hidden_feats, out_feats).to(device)
g = g.to(device)
node_features = node_features.to(device)
targets = targets.to(device)
mask = mask.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

n_epochs = 1000
# The mask never changes during training, so check it once up front instead of
# once per epoch after a forward pass that would be thrown away.
if mask.sum().item() == 0:
    print("No nodes in second_last_frame, cannot train on this frame.")
else:
    for epoch in range(n_epochs):
        model.train()
        optimizer.zero_grad()
        pred = model(g, node_features)
        # Loss over the masked nodes only.
        loss = criterion(pred[mask], targets[mask])
        loss.backward()
        optimizer.step()

        if (epoch+1) % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Loss: {loss.item()}")

# Final full-graph forward pass without gradients; show the first ten masked rows.
model.eval()
with torch.no_grad():
    pred = model(g, node_features)
    print("Predicted velocities for second last frame:", pred[mask][:10])

print("処理完了。")

=== 読み込み直後のspotsデータ ===
     ID          POSITION_X          POSITION_Y FRAME MAX_INTENSITY_CH1
0  4096  430.75051124744374   308.3865030674847     0             221.0
1  3976   431.1983606557377   308.7120218579235     1             236.0
2  4185   431.6177707676131   309.8470031545741     2             234.0
3  2048  326.16029723991505  3.3322717622080678     1             179.0
4  2050  325.92579075425795    2.66301703163017     0             182.0
=== 読み込み直後のedgesデータ ===
  SPOT_SOURCE_ID SPOT_TARGET_ID
0           3976           4185
1           4096           3976
2           2048           2031
3           2050           2048
4           3978           4186

=== 面積・周長付与後のspots_df ===
     ID  POSITION_X  POSITION_Y  FRAME  MAX_INTENSITY_CH1  new_ID  \
0  4096  430.750511  308.386503      0              221.0       0   
1  3976  431.198361  308.712022      1              236.0       1   
2  4185  431.617771  309.847003      2              234.0       2   
3  2048  326.160297    3.