In [1]:
import dgl
import torch
import torch.nn as nn
from dgl.nn import GINEConv,GATConv
import numpy as np
import time
from tqdm import tqdm

In [2]:
from torchvision.datasets import STL10
from torchvision import transforms

target_size=(224,224)
transform=transforms.Compose([
    transforms.Resize(target_size),
    transforms.ToTensor()
])

STL10_train = STL10("STL10", split='train', download=True, transform=transform)
 
STL10_test = STL10("STL10", split='test', download=True, transform=transform)


Files already downloaded and verified
Files already downloaded and verified


In [10]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [25]:
g=dgl.graph(([0,1,2,1],
             [1,2,1,0]))
in_feat=2
out_feat=2
graphs=[]
n_feat=torch.tensor([[1],[2],[3]])
e_feat=torch.randn(g.num_edges(),in_feat)
print(n_feat.shape)
print(e_feat.shape)

torch.Size([3, 1])
torch.Size([4, 2])


In [30]:
def get_nearest_neighbors(image, row, col):
    # 画像の形状を取得
    height, width = image.shape[:2]

    # 注目画素の周囲8画素の座標を計算
    neighbors_coords = [(row-1, col-1), (row-1, col), (row-1, col+1),
                        (row, col-1), (row, col+1),
                        (row+1, col-1), (row+1, col), (row+1, col+1)]

    # 注目画素の最近傍画素の値を抜き出す
    nearest_neighbors = []
    for r, c in neighbors_coords:
        # 座標が画像範囲内かチェック
        if 0 <= r < height and 0 <= c < width:
            pixel_value = image[r, c]
            nearest_neighbors.append(pixel_value)
        else:
            # 画像範囲外の場合は0を追加するなど適切な処理を行う
            #nearest_neighbors.append(0)
            pass

    return nearest_neighbors

In [21]:
def image_patch(image,num_patch):
    #画像サイズ
    size=image.shape[1]
    #1パッチ当たりの画素数
    patch_width=int(size/num_patch)
    #パッチ保存用配列
    data=[]

    for i in range(0,size,patch_width):
        for j in range(0,size,patch_width):
            data.append(image[:, i : i + patch_width, j : j + patch_width])
    
    return data

In [33]:
def make_graph(side_length):
    g=dgl.DGLGraph()
    g.add_nodes(side_length**2)
    square_list = np.arange(side_length**2).reshape((side_length, side_length))
    #ノード番号に対応したインデックスを取得
    inds=np.ndindex(square_list.shape)
    inds=[idx for idx in inds]
    #各ノードと最近傍ノード間にエッジを張る
    for i in range(side_length**2):
        x,y=inds[i]
        flatt_nh=get_nearest_neighbors(square_list,x,y)
        for j in flatt_nh:
            if i == j:
                continue
            else:
                g.add_edges(j,i)
    return g

In [22]:
st=time.time()
for image,label in tqdm(STL10_train):
    n_feat=image_patch(image,8)
print(time.time()-st)

100%|██████████| 5000/5000 [00:04<00:00, 1090.97it/s]

4.585525751113892





In [24]:
print(len(n_feat))
print(type(n_feat[0]))

64
<class 'torch.Tensor'>


In [32]:
print(n_feat)
print(e_feat)
conv=GINEConv(nn.Linear(in_feat,out_feat))
res=conv(g,n_feat,e_feat)
print(res)

tensor([[ 0.5888, -0.8675,  0.4436],
        [ 2.4353, -2.4259, -0.3092],
        [ 0.2220, -2.3916,  1.8348]])
tensor([[-0.4855,  0.0899, -1.2415],
        [ 1.0146, -0.0956,  0.1709],
        [ 0.9142,  0.9067, -0.5644],
        [-0.8427,  1.1041,  1.7460]])
tensor([[ 1.0899, -1.3193],
        [ 0.1359, -1.3522],
        [ 0.6151, -1.8083]], grad_fn=<AddmmBackward0>)


In [79]:
a=torch.randn((10,10,3))
b=torch.randn((10,10,3))
c=a.detach()

In [81]:
cos=nn.CosineSimilarity(0)
output=torch.cosine_similarity(a.flatten(),b.flatten(),dim=0)
print(output.item())

-0.09277930855751038


In [None]:
g=dgl.graph(([0,1,2,1],
             [1,2,1,0]))
in_feat=2
out_feat=2
n_feat=torch.randn(g.num_nodes(),in_feat)
e_feat=torch.randn(g.num_edges(),in_feat)

In [67]:
gat1=GATConv(in_feat,10,2)
pred=gat1(g,n_feat,e_feat,get_attention=True)