In [31]:
pip install torch torchvision torch_geometric scikit-image plotly




In [32]:
import torch
from skimage.segmentation import slic
from skimage.color import rgb2lab
from skimage import graph
from torch_geometric.data import Data

def image_to_graph(image):
    segments = slic(image, n_segments=75, compactness=10)
    rag = graph.rag_mean_color(image, segments)

    edge_index = []
    for edge in rag.edges():
        edge_index.append([edge[0], edge[1]])
        edge_index.append([edge[1], edge[0]])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    x = torch.tensor([rag.nodes[n]['mean color'] for n in rag.nodes], dtype=torch.float)

    return Data(x=x, edge_index=edge_index)

In [33]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class FashionGCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 32)
        self.conv2 = GCNConv(32, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


In [34]:
import plotly.graph_objects as go
import numpy as np

def plot_3d_graph(data):
    coords = np.random.rand(data.x.size(0), 3)  # You can replace this with real node positions
    edge_list = data.edge_index.cpu().numpy()

    edge_x = []
    edge_y = []
    edge_z = []

    for i in range(edge_list.shape[1]):
        src, dst = edge_list[0][i], edge_list[1][i]
        edge_x.extend([coords[src][0], coords[dst][0], None])
        edge_y.extend([coords[src][1], coords[dst][1], None])
        edge_z.extend([coords[src][2], coords[dst][2], None])

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        line=dict(width=2, color='black'),
        hoverinfo='none',
        mode='lines')

    node_trace = go.Scatter3d(
        x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
        mode='markers',
        marker=dict(size=6, color='red'),
        hoverinfo='text')

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(showlegend=False)
    fig.show()


In [35]:
# Load an image (replace 'path/to/your/image.jpg' with the actual path)
from skimage import io

image = io.imread('/content/maxresdefault.jpg')

# Convert the image to a PyG Data object
graph_data = image_to_graph(image)

# Display the graph data
print(graph_data)

Data(x=[57, 3], edge_index=[2, 280])


  x = torch.tensor([rag.nodes[n]['mean color'] for n in rag.nodes], dtype=torch.float)
