# In [1]:  (notebook cell marker)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from ultralytics import YOLO
from ultralytics.nn.modules import Conv, C2f, SPPF, Detect
import networkx as nx
from nltk.corpus import wordnet as wn
import json
import os
from PIL import Image
import numpy as np
from tqdm import tqdm

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class HierarchicalDataset(torch.utils.data.Dataset):
    """Object-detection dataset labelled with nodes of a class hierarchy.

    Every hierarchy node is identified by its slash-joined path from the top
    level (e.g. ``"Animal/Mammal/Dog"``) and numbered breadth-first, so
    top-level classes get the smallest indices.

    Args:
        img_dir: Directory that each annotation's ``"image_file"`` is
            relative to.
        ann_file: JSON file with a list of annotation records, each holding an
            ``"image_file"`` and a list of ``"objects"`` with ``"class_name"``
            and ``"bbox"``. Box entries 0/2 are scaled along x and 1/3 along y
            (presumably ``[xmin, ymin, xmax, ymax]`` in original pixels).
        hierarchy_file: JSON file with a nested ``{name: {child: {...}}}``
            tree; a falsy value (e.g. ``{}``) marks a leaf.
        img_size: Images are resized to ``img_size x img_size`` (aspect ratio
            is not preserved) and boxes are rescaled to match.
    """

    def __init__(self, img_dir, ann_file, hierarchy_file, img_size=640):
        self.img_dir = img_dir
        self.img_size = img_size
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(hierarchy_file, encoding="utf-8") as f:
            self.hierarchy = json.load(f)
        with open(ann_file, encoding="utf-8") as f:
            self.annotations = json.load(f)

        # Class mappings, numbered in breadth-first order.
        self.class_to_idx = {}
        self.idx_to_class = {}
        for idx, (_, full_path) in enumerate(self._iter_hierarchy(self.hierarchy)):
            self.class_to_idx[full_path] = idx
            self.idx_to_class[idx] = full_path

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return the resized image, scaled boxes and class-index labels of sample ``idx``.

        Objects whose ``class_name`` is not in the hierarchy are dropped.
        ``"boxes"`` always has two dimensions (``(0, 4)`` when nothing is kept).
        """
        ann = self.annotations[idx]
        img_path = os.path.join(self.img_dir, ann["image_file"])
        # ``with`` closes the underlying file even if decoding fails.
        with Image.open(img_path) as src:
            img = src.convert("RGB")

        # Resize to a square and compute the per-axis box scale factors.
        orig_w, orig_h = img.size
        img = img.resize((self.img_size, self.img_size))
        scale_x = self.img_size / orig_w
        scale_y = self.img_size / orig_h

        boxes = []
        labels = []
        for obj in ann["objects"]:
            cls_path = self.find_class_path(obj["class_name"])
            if cls_path is None:
                continue  # class not in the hierarchy: skip this object
            # Scale a copy: scaling obj["bbox"] in place would mutate the
            # cached annotations and compound the scale on every re-read.
            boxes.append(self._scale_box(obj["bbox"], scale_x, scale_y))
            labels.append(self.class_to_idx[cls_path])

        img_tensor = torch.tensor(np.array(img), dtype=torch.float32).permute(2, 0, 1) / 255.0
        # torch.tensor([]) would be 1-D (0,); keep a 2-D (0, 4) tensor instead.
        boxes_tensor = (
            torch.tensor(boxes, dtype=torch.float32)
            if boxes
            else torch.zeros((0, 4), dtype=torch.float32)
        )
        return {
            "image": img_tensor,
            "boxes": boxes_tensor,
            "labels": torch.tensor(labels, dtype=torch.long),
            "original_size": torch.tensor([orig_h, orig_w]),
        }

    def find_class_path(self, class_name):
        """Return the full path of the first node named ``class_name``, or ``None``.

        Matching is case-insensitive; the first match in breadth-first order
        wins, so a shallower node shadows a deeper one with the same name.
        """
        target = class_name.lower()
        for name, full_path in self._iter_hierarchy(self.hierarchy):
            if name.lower() == target:
                return full_path
        return None

    @staticmethod
    def _iter_hierarchy(hierarchy):
        """Yield ``(name, full_path)`` for every node of ``hierarchy``, breadth-first."""
        from collections import deque

        # ``None`` marks the top level; a string sentinel such as "root" would
        # collide with a real top-level class of that name.
        queue = deque([(None, hierarchy)])
        while queue:
            parent_path, node = queue.popleft()
            for name, children in node.items():
                full_path = name if parent_path is None else f"{parent_path}/{name}"
                yield name, full_path
                if children:
                    queue.append((full_path, children))

    @staticmethod
    def _scale_box(box, scale_x, scale_y):
        """Return a new list: ``box[0]``/``box[2]`` times ``scale_x``, ``box[1]``/``box[3]`` times ``scale_y``.

        Any extra trailing entries are copied unchanged; ``box`` is not modified.
        """
        scaled = list(box)
        scaled[0] *= scale_x  # xmin
        scaled[1] *= scale_y  # ymin
        scaled[2] *= scale_x  # xmax
        scaled[3] *= scale_y  # ymax
        return scaled

class HierarchyGNN(nn.Module):
    """Two-layer graph attention network over learned per-node embeddings.

    Each node id is embedded into ``input_dim`` features. The result goes
    through a 3-head GAT layer (heads concatenated), a ReLU, and a
    single-head GAT layer that outputs ``hidden_dim`` features per node.

    Args:
        num_classes: Number of embedding rows; node ids passed to ``forward``
            must lie in ``[0, num_classes)``.
        input_dim: Width of the node embeddings fed to the first GAT layer.
        hidden_dim: Per-head output width of both GAT layers.
    """

    def __init__(self, num_classes, input_dim=256, hidden_dim=128):
        super().__init__()
        self.gat1 = GATConv(input_dim, hidden_dim, heads=3)
        # gat1 concatenates its 3 heads -> hidden_dim * 3 input features here.
        self.gat2 = GATConv(hidden_dim * 3, hidden_dim)
        # The embedding width must equal gat1's input size (input_dim).
        # With hidden_dim here, gat1 would receive the wrong feature width.
        self.class_embed = nn.Embedding(num_classes, input_dim)

    def forward(self, x, edge_index):
        """Embed node ids ``x`` and run both GAT layers over ``edge_index``.

        Args:
            x: Integer node ids. GATConv needs 2-D node features, so this is
                presumably 1-D (one id per node).
            edge_index: Graph connectivity passed straight to both GAT layers.

        Returns:
            Node features with ``hidden_dim`` columns, one row per id in ``x``.
        """
        x = self.class_embed(x)
        x = F.relu(self.gat1(x, edge_index))
        x = self.gat2(x, edge_index)
        return x

class HierarchicalYOLOv8(nn.Module):
    """YOLOv8 network combined with a GNN over the class hierarchy.

    Graph node ``i`` is the ``i``-th hierarchy node in breadth-first order,
    keyed by its slash-joined path. This mirrors the breadth-first numbering
    ``HierarchicalDataset`` uses for ``class_to_idx``. Same-named classes
    under different parents therefore stay distinct nodes.

    Args:
        num_classes: Number of classes; must be at least the number of
            hierarchy nodes (the GNN embeds one row per node id).
        hierarchy: Nested ``{name: {child: {...}}}`` dict; a falsy value marks
            a leaf.
        weights: Passed to ``ultralytics.YOLO``. The default ``.yaml`` config
            builds the architecture without pretrained weights. Pass a
            checkpoint such as ``"yolov8n.pt"`` to start from pretrained ones.

    Raises:
        ValueError: If the hierarchy has more nodes than ``num_classes``.
    """

    def __init__(self, num_classes, hierarchy, weights="yolov8n.yaml"):
        super().__init__()
        # Build and validate the graph first: cheap, and fails before any
        # YOLO model is constructed.
        self.graph = self.build_hierarchy_graph(hierarchy)
        if self.graph.num_nodes > num_classes:
            raise ValueError(
                f"hierarchy has {self.graph.num_nodes} nodes but num_classes={num_classes}"
            )

        self.yolo = YOLO(weights).model
        self.yolo.to(device)

        # Freeze every YOLO parameter. This includes YOLO's own detection head,
        # not just the backbone.
        for param in self.yolo.parameters():
            param.requires_grad = False

        # GNN for hierarchy
        self.gnn = HierarchyGNN(num_classes).to(device)

        # Non-persistent buffer: moves with model.to(...) but stays out of the
        # state_dict.
        self.register_buffer("edge_index", self.graph.edge_index.to(device), persistent=False)

        # Modified detection head
        self.detect = HierarchicalDetect(num_classes).to(device)

    def build_hierarchy_graph(self, hierarchy):
        """Return a PyG ``Data`` graph with one node per hierarchy node and parent->child edges.

        There is no synthetic root node: top-level classes have no incoming
        edge, and the node count equals the number of hierarchy nodes.
        """
        paths, edges = self._hierarchy_edges(hierarchy)
        if edges:
            # (E, 2) pairs -> PyG's (2, E) COO layout.
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        return Data(edge_index=edge_index, num_nodes=len(paths))

    @staticmethod
    def _hierarchy_edges(hierarchy):
        """Return ``(paths, edges)`` for ``hierarchy``, numbering nodes breadth-first.

        ``paths[i]`` is the slash-joined path of node ``i``. ``edges`` lists
        ``(parent_idx, child_idx)`` pairs.
        """
        from collections import deque

        paths = []
        edges = []
        # Entries are (parent_path, parent_idx, children); None marks the top level.
        queue = deque([(None, None, hierarchy)])
        while queue:
            parent_path, parent_idx, node = queue.popleft()
            for name, children in node.items():
                idx = len(paths)
                path = name if parent_path is None else f"{parent_path}/{name}"
                paths.append(path)
                if parent_idx is not None:
                    edges.append((parent_idx, idx))
                if children:
                    queue.append((path, idx, children))
        return paths, edges

    def forward(self, x):
        # YOLOv8 forward pass
        x = self.yolo(x)

        # One node id per hierarchy node, on the same device as the graph.
        node_features = torch.arange(self.graph.num_nodes, device=self.edge_index.device)
        gnn_out = self.gnn(node_features, self.edge_index)

        # NOTE(review): ultralytics detection models return their Detect
        # head's outputs (a list of per-scale maps in train mode, a tuple in
        # eval), not a single tensor. If so, x.size(0) and this concat will
        # fail, and HierarchicalDetect expects a 256-channel 4-D input.
        # Confirm which feature map is meant to be tapped before training.
        combined = torch.cat([x, gnn_out.unsqueeze(0).repeat(x.size(0), 1, 1)], dim=1)

        # Detection
        return self.detect(combined)

class HierarchicalDetect(nn.Module):
    """1x1 convolution mapping a feature map to one channel per class.

    Args:
        num_classes: Number of output channels (one per class).
        in_channels: Channels of the input feature map (default 256).
    """

    def __init__(self, num_classes, in_channels=256):
        super().__init__()
        self.num_classes = num_classes
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        """Map an ``(N, in_channels, H, W)`` tensor to ``(N, num_classes, H, W)``."""
        return self.conv(x)

def train(model, dataloader, epochs=50, lr=0.001):
    """Train ``model`` with AdamW and cross-entropy, printing the mean loss per epoch.

    Args:
        model: Module whose output is passed directly to
            ``nn.CrossEntropyLoss`` together with ``batch["labels"]``.
        dataloader: Iterable of dict batches with ``"image"`` and ``"labels"``
            tensors; must support ``len()``.
        epochs: Number of passes over ``dataloader``.
        lr: AdamW learning rate.

    Returns:
        List of the mean batch loss for each epoch (0.0 for an empty loader).
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): CrossEntropyLoss expects class-index targets whose shape
    # matches the output's batch/spatial dims. Per-image object label lists
    # do not obviously satisfy that; confirm the intended target encoding.
    num_batches = len(dataloader)
    history = []

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            images = batch["image"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Avoid ZeroDivisionError when the dataloader yields no batches.
        epoch_loss = total_loss / num_batches if num_batches else 0.0
        history.append(epoch_loss)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")
    return history

# Example usage: train on the files under data/ and save the weights.
if __name__ == "__main__":
    # Example hierarchy; data/hierarchy.json presumably describes the same tree.
    hierarchy = {
        "Animal": {
            "Mammal": {
                "Dog": {"Labrador": {}, "Poodle": {}},
                "Cat": {"Persian": {}, "Siamese": {}}
            },
            "Bird": {"Eagle": {}, "Sparrow": {}}
        },
        "Vehicle": {
            "Car": {"Sedan": {}, "SUV": {}},
            "Truck": {"Pickup": {}, "Semi": {}}
        }
    }

    # Create dataset
    dataset = HierarchicalDataset(
        img_dir="data/images",
        ann_file="data/annotations.json",
        hierarchy_file="data/hierarchy.json",
    )

    # Label indices come from the hierarchy file the dataset loaded. The
    # model's graph must be built from that same tree, or graph node ids and
    # label ids will not line up.
    if dataset.hierarchy != hierarchy:
        print(
            "Warning: data/hierarchy.json differs from the in-code hierarchy; "
            "building the model from the file so it matches the dataset labels."
        )
    model = HierarchicalYOLOv8(len(dataset.class_to_idx), dataset.hierarchy).to(device)

    # NOTE(review): the default collate_fn stacks "boxes"/"labels" per key,
    # which likely fails when images in a batch have different object counts.
    # A custom collate_fn is probably needed.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
    train(model, dataloader)

    # Save model
    torch.save(model.state_dict(), "hierarchical_yolov8.pth")

# Notebook output of the cell above: ModuleNotFoundError: No module named 'torch'
# (PyTorch was not installed in the kernel that ran it.)