<a href="https://colab.research.google.com/github/Shihori/AI/blob/main/GNN_corona2_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [4]:
# Mount Google Drive inside the Colab VM so files stored there are reachable
# under /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Switch the working directory to the project folder on Drive; relative paths
# used afterwards resolve against this folder.
%cd /content/drive/MyDrive/GNN-corona2/
# List the folder contents (hidden entries included) as a quick sanity check.
%ls -a

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/GNN-corona2
[0m[01;34mdata[0m/  GNN-corona2.py


In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from PIL import Image
import numpy as np
import os
from sklearn.model_selection import train_test_split
from collections import defaultdict

def _grid_edge_index(height, width):
    """Build the 4-neighbour connectivity of a ``height`` x ``width`` pixel grid.

    Nodes are numbered row-major (``r * width + c``). Each horizontal and
    vertical neighbour pair is emitted in both directions.

    Returns:
        A ``torch.long`` tensor of shape ``(2, num_edges)``. A grid without
        any neighbouring pixels yields shape ``(2, 0)``.
    """
    edges = []
    for r in range(height):
        for c in range(width):
            node_idx = r * width + c

            if c + 1 < width:
                right = node_idx + 1
                edges.append((node_idx, right))
                edges.append((right, node_idx))

            if r + 1 < height:
                below = node_idx + width
                edges.append((node_idx, below))
                edges.append((below, node_idx))

    # Reshape with an explicit row count so an empty edge list still produces
    # a well-formed (2, 0) tensor instead of a 1-D tensor of shape (0,).
    pairs = torch.tensor(edges, dtype=torch.long).reshape(len(edges), 2)
    return pairs.t().contiguous()


def image_to_graph_data(image_path, label, size=(32, 32)):
    """Load an image and turn it into a pixel-grid graph.

    The image is converted to RGB, resized to ``size`` and scaled to [0, 1].
    Every pixel becomes one node whose features are its channel values, and
    each pixel is linked to its right and lower neighbours in both directions.

    Args:
        image_path: Path of the image file to read.
        label: Integer class label stored on the graph as ``y``.
        size: ``(width, height)`` passed to ``Image.resize``. Defaults to
            ``(32, 32)``, the value this function previously hard-coded.

    Returns:
        A ``Data`` object with ``x`` of shape ``(height * width, channels)``,
        ``edge_index`` of shape ``(2, num_edges)`` and ``y`` of shape ``(1,)``,
        or ``None`` when the image cannot be loaded (the error is printed).
    """
    try:
        # The context manager releases the underlying file handle as soon as
        # the pixels have been copied into the array.
        with Image.open(image_path) as img:
            img_array = np.array(img.convert('RGB').resize(size)) / 255.0
    except Exception as e:
        # Best-effort loading: report the problem and let the caller skip it.
        print(f"Error loading image {image_path}: {e}")
        return None

    height, width, channels = img_array.shape
    num_nodes = height * width

    # One feature row per pixel, in row-major order (matches the edge indexing).
    x = torch.tensor(img_array.reshape(num_nodes, channels), dtype=torch.float)
    edge_index = _grid_edge_index(height, width)
    y = torch.tensor([label], dtype=torch.long)

    return Data(x=x, edge_index=edge_index, y=y)

class GCNForImageClassification(torch.nn.Module):
    """Graph-level classifier: two GCN layers, mean pooling, one linear head.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Width of both graph-convolution layers.
        num_classes: Number of output logits per graph.
    """

    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        # Layers are created in the same order as they are applied.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, data):
        """Return one row of class logits for every graph in ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``.
        """
        node_features = data.x
        edges = data.edge_index
        graph_ids = data.batch

        # Each convolution is followed by ReLU and dropout (active only while
        # the module is in training mode).
        for conv in (self.conv1, self.conv2):
            node_features = F.relu(conv(node_features, edges))
            node_features = F.dropout(node_features, p=0.5, training=self.training)

        # Collapse node embeddings into one vector per graph, then classify.
        pooled = global_mean_pool(node_features, graph_ids)
        return self.classifier(pooled)

if __name__ == "__main__":
    # DataLoader lives in torch_geometric.loader since PyG 2.0; the alias in
    # torch_geometric.data is deprecated and emits a warning when used.
    from torch_geometric.loader import DataLoader

    base_image_dir = 'data'
    max_images = 100  # cap on how many images are converted to graphs
    seed = 42

    # Seed every source of randomness so sampling, weight initialisation and
    # batch shuffling are reproducible across runs.
    np.random.seed(seed)
    torch.manual_seed(seed)

    category_to_label = {}
    label_counter = 0
    all_image_paths = []
    all_labels = []

    if not os.path.isdir(base_image_dir):
        print(f"Error: Directory '{base_image_dir}' not found.")
        print("Please create subdirectories like 'data/category_A', 'data/category_B' and place images inside.")
        # SystemExit stops the script/cell; the interactive exit() helper
        # would ask a notebook kernel to shut down instead.
        raise SystemExit(1)

    # Every sub-directory is one class; labels follow the sorted directory order.
    for category_name in sorted(os.listdir(base_image_dir)):
        category_dir = os.path.join(base_image_dir, category_name)
        if os.path.isdir(category_dir):
            category_to_label[category_name] = label_counter
            print(f"Mapping '{category_name}' to label {label_counter}")
            # Sorted so the seeded subset below does not depend on the
            # arbitrary order in which the filesystem lists entries.
            for img_file in sorted(os.listdir(category_dir)):
                if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    all_image_paths.append(os.path.join(category_dir, img_file))
                    all_labels.append(label_counter)
            label_counter += 1

    if not all_image_paths:
        print("No image files found in the specified categories. Please check your data directory.")
        raise SystemExit(1)

    # Remember the full count before truncating so the message reports it.
    total_available = len(all_image_paths)
    if total_available > max_images:
        indices = np.random.choice(total_available, max_images, replace=False)
        all_image_paths = [all_image_paths[i] for i in indices]
        all_labels = [all_labels[i] for i in indices]
        print(f"Using a random subset of {max_images} images from {total_available} available.")
    else:
        print(f"Using all {total_available} images found.")

    graph_datasets = []
    for img_path, img_label in zip(all_image_paths, all_labels):
        graph_data = image_to_graph_data(img_path, img_label)
        # Unreadable images come back as None; test that explicitly rather
        # than relying on the truthiness of a Data object.
        if graph_data is not None:
            graph_datasets.append(graph_data)

    if not graph_datasets:
        print("No valid graph data could be created from images. Exiting.")
        raise SystemExit(1)

    print(f"Successfully converted {len(graph_datasets)} images to graph data.")

    train_data, test_data = train_test_split(graph_datasets, test_size=0.2, random_state=seed)
    print(f"Train data size: {len(train_data)}, Test data size: {len(test_data)}")

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

    in_channels = graph_datasets[0].x.shape[1]
    hidden_channels = 64
    num_classes = len(category_to_label)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move the model to its device before the optimizer captures its parameters.
    model = GCNForImageClassification(in_channels, hidden_channels, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    def train(model, train_loader, optimizer, criterion, device):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            ``(avg_loss, accuracy)`` computed over all samples seen in the
            epoch; ``(0.0, 0.0)`` if the loader yields nothing.
        """
        model.train()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()

            batch_size = data.y.size(0)
            # The criterion returns a per-batch mean; weight it by the batch
            # size so a smaller final batch does not skew the epoch average.
            total_loss += loss.item() * batch_size

            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total_samples += batch_size

        if total_samples == 0:
            return 0.0, 0.0
        return total_loss / total_samples, correct / total_samples

    def test(model, loader, device):
        """Return the classification accuracy of ``model`` on ``loader``.

        Gradients are disabled and the model is put in eval mode. An empty
        loader yields ``0.0``.
        """
        model.eval()
        correct = 0
        total_samples = 0
        with torch.no_grad():
            for data in loader:
                data = data.to(device)
                out = model(data)
                pred = out.argmax(dim=1)
                correct += (pred == data.y).sum().item()
                total_samples += data.y.size(0)
        return correct / total_samples if total_samples else 0.0

    epochs = 100
    print(f"\n--- Starting training on {device} ---")
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
        test_acc = test(model, test_loader, device)
        print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

    print("\n--- Training complete ---")
    print(f"Final Test Accuracy: {test(model, test_loader, device):.4f}")

    print("\nCategory to Label Mapping:")
    for category, label in category_to_label.items():
        print(f"  {category}: {label}")

Mapping 'category_A' to label 0
Mapping 'category_B' to label 1
Using all 100 images found.
Successfully converted 100 images to graph data.
Train data size: 80, Test data size: 20

--- Starting training on cpu ---




Epoch: 001, Train Loss: 0.7031, Train Acc: 0.4750, Test Acc: 0.4000
Epoch: 002, Train Loss: 0.6934, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 003, Train Loss: 0.6960, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 004, Train Loss: 0.6904, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 005, Train Loss: 0.6907, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 006, Train Loss: 0.6904, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 007, Train Loss: 0.6925, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 008, Train Loss: 0.6914, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 009, Train Loss: 0.6876, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 010, Train Loss: 0.6897, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 011, Train Loss: 0.6914, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 012, Train Loss: 0.6934, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 013, Train Loss: 0.6886, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 014, Train Loss: 0.6899, Train Acc: 0.5250, Test Acc: 0.4000
Epoch: 015, Train Loss: 0.6914, Train Acc: 0.525