In [None]:
import subprocess
import sys

print("Installing dependencies for Sustainability AI Model...")
print("="*60)

# Extra arguments for each `pip install -q` call, executed in this order.
# check_call raises subprocess.CalledProcessError on the first failing install,
# which stops the cell before the success message is printed.
# NOTE(review): timm and torch-geometric use --no-deps, presumably to keep the
# preinstalled torch stack untouched; the torch-scatter/torch-sparse wheel index
# targets torch 2.5.0 + CUDA 12.1 — confirm it matches the kernel's torch build.
PIP_INSTALL_ARGS = [
    ["--upgrade", "pip"],
    ["numpy<2.0"],
    ["scipy<1.15.0"],
    ["--no-deps", "timm==1.0.12"],
    ["albumentations==1.4.22"],
    ["einops==0.8.0"],
    ["wandb==0.19.1"],
    ["--no-deps", "torch-geometric==2.6.1"],
    ["torch-scatter", "torch-sparse", "-f", "https://data.pyg.org/whl/torch-2.5.0+cu121.html"],
]

for pip_args in PIP_INSTALL_ARGS:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *pip_args])

print("="*60)
print("✅ Dependencies installed successfully!")
print("="*60)



In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import sys
import json
import random
import logging
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image

import timm
from timm.data import create_transform, resolve_data_config
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from tqdm.notebook import tqdm
import wandb

In [None]:
# force=True replaces any root handlers installed earlier in this kernel (by an
# imported library or a previous run of this cell). Without it basicConfig is a
# silent no-op in that case and the INFO messages below are never shown.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)

def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available this also seeds every GPU and switches cuDNN to
    deterministic, non-benchmarking mode. The chosen seed is logged.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic, cudnn.benchmark = True, False

    logger.info(f"Random seed set to {seed}")

def get_device():
    """Return the preferred torch device: CUDA, else Apple MPS, else CPU.

    For CUDA, the name and total memory of GPU 0 are logged.
    """
    if not torch.cuda.is_available():
        fallback = "mps" if torch.backends.mps.is_available() else "cpu"
        return torch.device(fallback)

    gpu_name = torch.cuda.get_device_name(0)
    logger.info(f"CUDA available: {gpu_name}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    logger.info(f"CUDA memory: {total_gb:.2f} GB")
    return torch.device("cuda")

def optimize_memory():
    """Apply CUDA memory/throughput settings; does nothing without CUDA.

    - Requests the expandable-segments allocator through PYTORCH_CUDA_ALLOC_CONF.
      The variable is set before any other CUDA call in this function and is
      not overwritten if it already exists (user config or an earlier call).
      If CUDA is already initialised a warning is logged, because the
      allocator may already have read its configuration.
    - Enables TF32 for matmul and cuDNN.
    - Enables cuDNN benchmark mode only when cudnn.deterministic is off, so it
      does not undo the deterministic mode set by set_seed().
    """
    if not torch.cuda.is_available():
        return

    if "PYTORCH_CUDA_ALLOC_CONF" not in os.environ:
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
        if torch.cuda.is_initialized():
            logger.warning(
                "PYTORCH_CUDA_ALLOC_CONF was set after CUDA initialisation; "
                "the allocator may ignore it for this process."
            )
        else:
            logger.info("Memory optimization enabled: expandable_segments=True")
    else:
        logger.info(f"Keeping existing PYTORCH_CUDA_ALLOC_CONF={os.environ['PYTORCH_CUDA_ALLOC_CONF']}")

    torch.cuda.empty_cache()
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Benchmark mode picks kernels non-deterministically; respect set_seed().
    if not torch.backends.cudnn.deterministic:
        torch.backends.cudnn.benchmark = True

class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Args:
        patience: number of consecutive non-improving calls before stopping.
        mode: "max" when larger scores are better (e.g. accuracy), "min" when
            smaller scores are better (e.g. loss).
        delta: minimum change beyond the best score that counts as improvement.

    Raises:
        ValueError: if mode is not "max" or "min".
    """

    def __init__(self, patience=15, mode="max", delta=0):
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.mode = mode
        self.delta = delta

    def _is_improvement(self, score):
        """True if score beats best_score by more than delta in the configured direction."""
        if self.mode == "max":
            return score > self.best_score + self.delta
        return score < self.best_score - self.delta

    def __call__(self, current_score):
        """Record current_score and return True once patience is exhausted.

        The stop flag is sticky: once set it stays True on later calls.
        """
        if self.best_score is None or self._is_improvement(current_score):
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

In [None]:
# Label vocabulary of the vision model. The dataset sorts these names to
# assign class indices, so list order here does not matter.
TARGET_CLASSES = [
    'aerosol_cans', 'aluminum_food_cans', 'aluminum_soda_cans', 'cardboard_boxes', 'cardboard_packaging',
    'clothing', 'coffee_grounds', 'disposable_plastic_cutlery', 'egg_shells', 'food_waste',
    'glass_beverage_bottles', 'glass_cosmetic_containers', 'glass_food_jars', 'magazines',
    'newspaper', 'office_paper', 'paper_cups', 'plastic_cup_lids', 'plastic_detergent_bottles',
    'plastic_food_containers', 'plastic_shopping_bags', 'plastic_soda_bottles', 'plastic_straws',
    'plastic_trash_bags', 'plastic_water_bottles', 'shoes', 'steel_food_cans', 'styrofoam_cups',
    'styrofoam_food_containers', 'tea_bags'
]
assert len(set(TARGET_CLASSES)) == len(TARGET_CLASSES), "TARGET_CLASSES contains duplicates"

VISION_CONFIG = {
    "model": {
        "backbone": "eva02_large_patch14_448.mim_m38m_ft_in22k_in1k",
        "pretrained": True,
        # Derived from the label list so the classifier head cannot drift from it.
        "num_classes": len(TARGET_CLASSES),
        "drop_rate": 0.3,
        "drop_path_rate": 0.2
    },
    "data": {
        # NOTE(review): not read by the visible training code; transforms take
        # their size from the model's pretrained data config.
        "input_size": 448,
        "num_workers": 2,
        "pin_memory": False,  # Reduce memory pressure
        "sources": [
            {
                "name": "master_30",
                "path": "/kaggle/input/recyclable-and-household-waste-classification/images",
                "type": "master"
            },
            {
                "name": "garbage_12",
                "path": "/kaggle/input/garbage-classification/garbage_classification",
                "type": "mapped_12"
            },
            {
                "name": "waste_22k",
                "path": "/kaggle/input/waste-classification-data/DATASET",
                "type": "mapped_2"
            },
            {
                "name": "garbage_v2_10",
                "path": "/kaggle/input/garbage-classification-v2",
                "type": "mapped_10"
            },
            {
                # NOTE(review): this path is a parent of garbage_12's path, so a
                # recursive walk reaches the same files twice unless the loader
                # de-duplicates them — confirm which dataset is mounted here.
                "name": "garbage_6",
                "path": "/kaggle/input/garbage-classification",
                "type": "mapped_6"
            },
            {
                "name": "garbage_balanced",
                "path": "/kaggle/input/garbage-dataset-classification",
                "type": "mapped_6"
            },
            {
                "name": "warp_industrial",
                "path": "/kaggle/input/warp-waste-recycling-plant-dataset",
                "type": "industrial"
            },
            {
                "name": "multiclass_garbage",
                "path": "/kaggle/input/multi-class-garbage-classification-dataset",
                "type": "multiclass"
            }
        ]
    },
    "training": {
        "batch_size": 4,  # Reduced from 8 to fit in 14.74 GB GPU
        "grad_accum_steps": 16,  # Increased to maintain effective batch size of 64
        "learning_rate": 5e-5,
        "weight_decay": 0.05,
        "num_epochs": 20,
        "patience": 5
    }
}

In [None]:
class UnifiedWasteDataset(Dataset):
    """
    A unified dataset that ingests data from multiple sources and maps them
    to a single 30-class target schema.

    Every source directory is walked recursively in sorted order, so the sample
    list (and therefore any seeded split of it) is reproducible. An image's
    label comes from the name of the folder containing it, translated according
    to the source's ``type`` (see ``_map_label``). Unmappable images are counted
    in ``skipped_count`` / ``skipped_labels``. A file already ingested from an
    earlier source (same real path) is counted in ``duplicate_count`` and not
    added again, so overlapping source roots cannot leak one image into both
    train and val.

    Args:
        sources_config: list of dicts with ``name``, ``path`` and ``type`` keys.
        target_classes: class names; they are sorted to define label indices.
        transform: optional callable applied to each RGB PIL image.
        image_size: side length of the black placeholder image returned for an
            unreadable file (the placeholder is passed through ``transform``).
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    # Exact folder-name -> target-class tables for the coarse public datasets.
    _SOURCE_MAPPINGS = {
        'mapped_12': {
            'paper': 'office_paper',
            'cardboard': 'cardboard_boxes',
            'plastic': 'plastic_food_containers',
            'metal': 'aluminum_food_cans',
            'glass': 'glass_food_jars',
            'brown-glass': 'glass_beverage_bottles',
            'green-glass': 'glass_beverage_bottles',
            'white-glass': 'glass_food_jars',
            'clothes': 'clothing',
            'shoes': 'shoes',
            'biological': 'food_waste',
            'trash': 'food_waste'
        },
        'mapped_10': {
            'metal': 'aluminum_food_cans',
            'glass': 'glass_food_jars',
            'biological': 'food_waste',
            'paper': 'office_paper',
            'battery': 'aerosol_cans',
            'trash': 'food_waste',
            'cardboard': 'cardboard_boxes',
            'shoes': 'shoes',
            'clothes': 'clothing',
            'plastic': 'plastic_food_containers'
        },
        'mapped_6': {
            'cardboard': 'cardboard_boxes',
            'glass': 'glass_food_jars',
            'metal': 'aluminum_food_cans',
            'paper': 'office_paper',
            'plastic': 'plastic_food_containers',
            'trash': 'food_waste'
        },
        'industrial': {
            'pet': 'plastic_food_containers',
            'hdpe': 'plastic_food_containers',
            'pvc': 'plastic_food_containers',
            'ldpe': 'plastic_food_containers',
            'pp': 'plastic_food_containers',
            'ps': 'plastic_food_containers',
            'metal': 'aluminum_food_cans',
            'glass': 'glass_food_jars',
            'paper': 'office_paper',
            'cardboard': 'cardboard_boxes',
            'trash': 'food_waste'
        },
        'multiclass': {
            'plastic': 'plastic_food_containers',
            'metal': 'aluminum_food_cans',
            'glass': 'glass_food_jars',
            'paper': 'office_paper',
            'cardboard': 'cardboard_boxes',
            'trash': 'food_waste',
            'organic': 'food_waste',
            'battery': 'aerosol_cans',
            'clothes': 'clothing',
            'shoes': 'shoes'
        },
    }

    # Generic keyword table, matched as substrings of the folder name. It is
    # used for unknown source types and as a second chance for 'industrial'
    # sources, whose WaRP-style folder names ('bottle-transp', 'canister',
    # 'glass-dark', ...) appear only here. The LONGEST matching key wins, so
    # specific entries such as 'bottle-yogurt' or 'canister' beat 'bottle'/'can'.
    _FALLBACK_MAPPING = {
        # Recyclables
        'recyclable': 'plastic_food_containers',
        'recycle': 'plastic_food_containers',
        'recycling': 'plastic_food_containers',
        # Waste types
        'waste': 'food_waste',
        'garbage': 'food_waste',
        'rubbish': 'food_waste',
        'refuse': 'food_waste',
        # Organic
        'compost': 'food_waste',
        'food': 'food_waste',
        'kitchen': 'food_waste',
        'biological': 'food_waste',
        # Paper products
        'newspaper': 'newspaper',
        'magazine': 'magazines',
        'book': 'office_paper',
        'document': 'office_paper',
        # Plastic types
        'bottle': 'plastic_water_bottles',
        'bottle-transp': 'plastic_water_bottles',
        'bottle-blue': 'plastic_water_bottles',
        'bottle-dark': 'plastic_water_bottles',
        'bottle-green': 'plastic_water_bottles',
        'bottle-blue5l': 'plastic_water_bottles',
        'bottle-milk': 'plastic_water_bottles',
        'bottle-oil': 'plastic_water_bottles',
        'bottle-yogurt': 'plastic_food_containers',
        'bottle-multicolor': 'plastic_water_bottles',
        'bottle-transp-full': 'plastic_water_bottles',
        'bottle-blue-full': 'plastic_water_bottles',
        'bottle-green-full': 'plastic_water_bottles',
        'bottle-dark-full': 'plastic_water_bottles',
        'bottle-milk-full': 'plastic_water_bottles',
        'bottle-multicolorv-full': 'plastic_water_bottles',
        'bottle-blue5l-full': 'plastic_water_bottles',
        'bottle-oil-full': 'plastic_water_bottles',
        'bag': 'plastic_shopping_bags',
        'container': 'plastic_food_containers',
        'cup': 'paper_cups',
        'straw': 'plastic_straws',
        # Detergents (plastic containers)
        'detergent-white': 'plastic_food_containers',
        'detergent-color': 'plastic_food_containers',
        'detergent-transparent': 'plastic_food_containers',
        'detergent-box': 'cardboard_boxes',
        # Metal
        'can': 'aluminum_soda_cans',
        'cans': 'aluminum_soda_cans',
        'tin': 'steel_food_cans',
        'aluminum': 'aluminum_food_cans',
        'steel': 'steel_food_cans',
        'canister': 'aluminum_food_cans',
        'battery': 'aerosol_cans',  # Hazardous, map to aerosol as closest
        # Glass
        'jar': 'glass_food_jars',
        'glass-transp': 'glass_food_jars',
        'glass-dark': 'glass_beverage_bottles',
        'glass-green': 'glass_beverage_bottles',
        'white-glass': 'glass_food_jars',
        'brown-glass': 'glass_beverage_bottles',
        'green-glass': 'glass_beverage_bottles',
        # Cardboard
        'milk-cardboard': 'cardboard_boxes',
        'juice-cardboard': 'cardboard_boxes',
        # Textiles
        'fabric': 'clothing',
        'textile': 'clothing',
        # Foam
        'foam': 'styrofoam_cups',
        'styrofoam': 'styrofoam_cups',
        'polystyrene': 'styrofoam_cups',
    }

    def __init__(self, sources_config, target_classes, transform=None, image_size=448):
        self.transform = transform
        self.image_size = image_size
        self.target_classes = sorted(target_classes)
        self.class_to_idx = {c: i for i, c in enumerate(self.target_classes)}
        self.samples = []

        self.skipped_count = 0
        self.skipped_labels = {}  # folder name -> number of unmappable images
        self.duplicate_count = 0
        self._seen_paths = set()  # real paths of files already in self.samples

        for source in sources_config:
            self._ingest_source(source)

        logger.info(f"Unified Dataset Created: {len(self.samples)} images. Skipped {self.skipped_count} unmappable images.")
        if self.duplicate_count:
            logger.info(f"Ignored {self.duplicate_count} images already ingested from an earlier source.")

        # Log skipped labels for debugging
        if self.skipped_labels:
            logger.warning("Skipped labels breakdown:")
            for label, count in sorted(self.skipped_labels.items(), key=lambda x: x[1], reverse=True):
                logger.warning(f"  '{label}': {count} images")

    def _resolve_source_path(self, source):
        """Return the directory to ingest for ``source``, or None if it is missing.

        If the configured path does not exist, only a sibling whose name matches
        case-insensitively (treating '-' and '_' alike) is accepted. Arbitrary
        siblings are never used: under a shared root such as /kaggle/input they
        would be a different dataset with a different label scheme.
        """
        path = Path(source["path"])
        if path.exists():
            return path

        def _normalise(name):
            return name.lower().replace('-', '_')

        wanted = _normalise(path.name)
        try:
            candidates = sorted(path.parent.iterdir())
        except OSError:  # parent missing or unreadable
            return None
        for child in candidates:
            if child.is_dir() and _normalise(child.name) == wanted:
                return child
        return None

    def _ingest_source(self, source):
        """Walk one source directory and append (path, class index) samples."""
        path = self._resolve_source_path(source)
        if path is None:
            logger.warning(f"Source {source['name']} not found at {source['path']}. Skipping.")
            return

        logger.info(f"Ingesting {source['name']} from {path}...")

        for root, dirs, files in os.walk(path):
            dirs.sort()  # deterministic traversal order -> reproducible sample list
            root_path = Path(root)

            new_images = []
            for name in sorted(files):
                if not name.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue
                file_path = root_path / name
                key = os.path.realpath(file_path)
                if key in self._seen_paths:
                    self.duplicate_count += 1
                else:
                    new_images.append((file_path, key))
            if not new_images:
                continue

            target_label = self._label_for_dir(root_path, path, source['type'])
            if target_label is not None:
                target_idx = self.class_to_idx[target_label]
                for file_path, key in new_images:
                    self._seen_paths.add(key)
                    self.samples.append((file_path, target_idx))
            else:
                folder_name = root_path.name.lower()
                self.skipped_count += len(new_images)
                self.skipped_labels[folder_name] = self.skipped_labels.get(folder_name, 0) + len(new_images)

    def _label_for_dir(self, directory, source_root, source_type):
        """Map a folder that holds images to a target class name, or None.

        The folder's own name is tried first. For 'master' sources, whose class
        folders may contain further sub-folders, the ancestors between the
        folder and the source root are then checked for an exact class name.
        """
        label = self._map_label(directory.name, source_type)
        if label is None and source_type == 'master':
            try:
                parts = directory.relative_to(source_root).parts
            except ValueError:
                parts = ()
            for part in reversed(parts[:-1]):
                if part.lower() in self.class_to_idx:
                    label = part.lower()
                    break
        # Guard against a mapping table naming a class absent from target_classes.
        if label is not None and label not in self.class_to_idx:
            return None
        return label

    def _map_label(self, raw_label, source_type):
        """Translate a folder name into a target class name (or None) for a source type."""
        raw = raw_label.lower().strip()

        if source_type == 'master':
            if raw in self.class_to_idx:
                return raw
            # Fallback: try to find closest match
            for target in self.target_classes:
                if raw in target or target in raw:
                    return target
            return None

        if source_type == 'mapped_2':
            # Organic waste
            if raw in ('organic', 'o'):
                return 'food_waste'
            # Recyclable waste (paper, plastic, metal, glass mix)
            if raw in ('recyclable', 'r'):
                return 'plastic_food_containers'  # Generic recyclable
            return None

        table = self._SOURCE_MAPPINGS.get(source_type)
        if table is None:
            return self._fallback_label(raw)
        label = table.get(raw)
        if label is None and source_type == 'industrial':
            label = self._fallback_label(raw)
        return label

    def _fallback_label(self, raw):
        """Return the value of the longest _FALLBACK_MAPPING key contained in raw, or None."""
        matches = [key for key in self._FALLBACK_MAPPING if key in raw]
        if not matches:
            return None
        return self._FALLBACK_MAPPING[max(matches, key=len)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image, label_idx); an unreadable file yields a black placeholder."""
        path, label_idx = self.samples[idx]
        try:
            # `with` closes the file handle Image.open keeps open; convert() copies.
            with Image.open(path) as raw_img:
                img = raw_img.convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, label_idx
        except Exception as e:
            logger.error(f"Corrupt image {path}: {e}")
        # Run the placeholder through the same transform so its type, size and
        # normalisation match real samples.
        placeholder = Image.new('RGB', (self.image_size, self.image_size))
        if self.transform:
            placeholder = self.transform(placeholder)
        return placeholder, label_idx

    def get_labels(self):
        """Return the class index of every sample, in sample order."""
        return [s[1] for s in self.samples]

In [None]:
def get_vision_transforms(config, model, is_train=True):
    """Build timm train/eval transforms that match the model's pretrained data config.

    Args:
        config: unused; kept so existing call sites keep working.
        model: timm model whose pretrained config supplies input size,
            interpolation, mean/std and the eval crop settings.
        is_train: if True, random-resized-crop + flips + RandAugment + colour
            jitter + random erasing; otherwise the deterministic eval transform.

    Raises:
        Any exception from timm, after logging it.
    """
    try:
        data_config = resolve_data_config(model.default_cfg, model=model)
        shared = dict(
            input_size=data_config['input_size'],
            use_prefetcher=False,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
        )
        if is_train:
            return create_transform(
                is_training=True,
                no_aug=False,
                scale=(0.08, 1.0),
                ratio=(0.75, 1.33),
                hflip=0.5,
                vflip=0.0,
                color_jitter=0.4,
                auto_augment='rand-m9-mstd0.5-inc1',
                re_prob=0.25,
                re_mode='pixel',
                re_count=1,
                **shared,
            )
        # Pass the checkpoint's own eval crop settings; without them timm uses its
        # default crop_pct and centre-crops away part of every validation image.
        # .get(): None falls back to timm's defaults if the config lacks a key.
        return create_transform(
            is_training=False,
            crop_pct=data_config.get('crop_pct'),
            crop_mode=data_config.get('crop_mode'),
            **shared,
        )
    except Exception as e:
        logger.error(f"Failed to create transforms: {e}")
        raise

In [None]:
def create_vision_model(config):
    """Instantiate the timm backbone described by ``config["model"]``.

    Uses the configured backbone name, pretrained flag, number of output
    classes, head dropout (drop_rate) and stochastic depth (drop_path_rate).
    """
    model_cfg = config["model"]
    logger.info(f"Creating model: {model_cfg['backbone']}")
    return timm.create_model(
        model_cfg["backbone"],
        pretrained=model_cfg["pretrained"],
        num_classes=model_cfg["num_classes"],
        drop_rate=model_cfg["drop_rate"],
        drop_path_rate=model_cfg["drop_path_rate"],
    )

In [None]:
def _make_loader(dataset, batch_size, config, shuffle):
    """Wrap ``dataset`` in a DataLoader using the worker/pinning settings in config["data"]."""
    num_workers = config["data"]["num_workers"]
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=config["data"]["pin_memory"],
        persistent_workers=num_workers > 0,
    )


def _split_with_transforms(full_dataset, train_transform, val_transform, train_fraction=0.85, seed=42):
    """Randomly split ``full_dataset`` into train/val Subsets with independent transforms.

    random_split() returns two Subsets of the *same* dataset object, so setting
    ``.dataset.transform`` on one silently changes the other. Here each split
    wraps its own shallow copy of the dataset (sharing the read-only sample
    list), so train really gets augmentation and val the eval transform.
    """
    import copy

    n = len(full_dataset)
    train_size = int(train_fraction * n)
    order = torch.randperm(n, generator=torch.Generator().manual_seed(seed)).tolist()

    train_base = copy.copy(full_dataset)
    train_base.transform = train_transform
    val_base = copy.copy(full_dataset)
    val_base.transform = val_transform

    return (
        torch.utils.data.Subset(train_base, order[:train_size]),
        torch.utils.data.Subset(val_base, order[train_size:]),
    )


def _train_one_epoch(model, loader, optimizer, criterion, scaler, device, accumulation_steps, desc):
    """Train for one epoch with gradient accumulation; return (mean batch loss, accuracy %).

    The optimizer steps after every ``accumulation_steps`` batches AND after the
    last batch, so a trailing partial window is applied instead of being wiped by
    the next epoch's zero_grad(). Each batch's loss is divided by the size of its
    window, so every step applies the window-average gradient. AMP (autocast +
    GradScaler) is used when ``scaler`` is not None.
    """
    import contextlib

    use_amp = scaler is not None
    num_batches = len(loader)
    running_loss, correct, total = 0.0, 0, 0

    model.train()
    optimizer.zero_grad()
    pbar = tqdm(loader, desc=desc)
    for i, (images, labels) in enumerate(pbar):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        window_start = (i // accumulation_steps) * accumulation_steps
        window_size = min(accumulation_steps, num_batches - window_start)

        amp_ctx = torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
        with amp_ctx:
            outputs = model(images)
            loss = criterion(outputs, labels) / window_size

        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            if use_amp:
                scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        running_loss += loss.item() * window_size  # undo scaling -> per-batch mean loss
        total += labels.size(0)
        correct += outputs.detach().argmax(dim=1).eq(labels).sum().item()
        pbar.set_postfix({'loss': f"{running_loss / (i + 1):.4f}", 'acc': f"{100 * correct / total:.2f}%"})

    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy


def _evaluate(model, loader, criterion, device, use_amp):
    """Return (mean per-batch loss, accuracy %) of ``model`` on ``loader`` without gradients."""
    import contextlib

    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation", leave=False):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            amp_ctx = torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
            with amp_ctx:
                outputs = model(images)
                loss = criterion(outputs, labels)
            total_loss += loss.item()
            total += labels.size(0)
            correct += outputs.argmax(dim=1).eq(labels).sum().item()

    num_batches = len(loader)
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy


def train_vision_model(config):
    """Fine-tune the timm vision backbone on the unified multi-source waste dataset.

    Steps: seed + CUDA settings; build the model (gradient checkpointing when
    supported); build the dataset and a seeded 85/15 train/val split with
    separate train/eval transforms; train with AdamW, per-epoch cosine LR,
    gradient accumulation, AMP on CUDA, and early stopping on val accuracy.
    Metrics go to W&B when it initialises (otherwise it runs disabled).

    Returns:
        The model with the weights of the best-validation-accuracy epoch
        restored, or None if the dataset is empty.
    """
    set_seed()
    optimize_memory()
    device = get_device()
    logger.info(f"Using device: {device}")

    model = create_vision_model(config).to(device)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # Enable gradient checkpointing to save memory
    if hasattr(model, 'set_grad_checkpointing'):
        model.set_grad_checkpointing(enable=True)
        logger.info("Gradient checkpointing enabled")

    train_transform = get_vision_transforms(config, model, is_train=True)
    val_transform = get_vision_transforms(config, model, is_train=False)

    full_dataset = UnifiedWasteDataset(
        sources_config=config["data"]["sources"],
        target_classes=TARGET_CLASSES,
        transform=None
    )

    if len(full_dataset) == 0:
        logger.error("Dataset is empty. Check paths.")
        return None

    train_dataset, val_dataset = _split_with_transforms(full_dataset, train_transform, val_transform)
    logger.info(f"Split: {len(train_dataset)} train / {len(val_dataset)} val images")

    train_cfg = config["training"]
    train_loader = _make_loader(train_dataset, train_cfg["batch_size"], config, shuffle=True)
    val_loader = _make_loader(val_dataset, train_cfg["batch_size"] * 2, config, shuffle=False)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=train_cfg["learning_rate"],
        weight_decay=train_cfg["weight_decay"]
    )
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_cfg["num_epochs"])
    early_stopping = EarlyStopping(patience=train_cfg["patience"], mode="max")

    use_amp = (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp) if use_amp else None
    num_epochs = train_cfg["num_epochs"]

    best_val_acc = float("-inf")
    best_state = None  # CPU copy of the best weights seen so far

    try:
        wandb.init(project="sustainability-vision-lake", config=config, mode="online")
    except Exception as e:
        logger.warning(f"W&B initialization failed: {e}. Continuing without logging.")
        wandb.init(mode="disabled")

    try:
        for epoch in range(num_epochs):
            train_loss, train_acc = _train_one_epoch(
                model, train_loader, optimizer, criterion, scaler, device,
                train_cfg["grad_accum_steps"], desc=f"Epoch {epoch+1}/{num_epochs}",
            )
            scheduler.step()

            val_loss, val_acc = _evaluate(model, val_loader, criterion, device, use_amp)
            logger.info(f"Epoch {epoch+1}/{num_epochs}: Train Acc {train_acc:.2f}%, Val Loss {val_loss:.4f}, Val Acc {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

            try:
                wandb.log({
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_acc": val_acc,
                    "val_loss": val_loss,
                    "learning_rate": optimizer.param_groups[0]['lr']
                })
            except Exception as e:
                logger.warning(f"W&B logging failed: {e}")

            if early_stopping(val_acc):
                logger.info("Early stopping triggered")
                break

            # Clear GPU cache after each epoch to prevent memory fragmentation
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        # Close the W&B run even if training raised.
        try:
            wandb.finish()
        except Exception as e:
            logger.warning(f"W&B finish failed: {e}")

    if best_state is not None:
        model.load_state_dict(best_state)
        logger.info(f"Restored weights from best epoch (val acc {best_val_acc:.2f}%)")

    return model

In [None]:
# Knowledge-graph GNN
# Uses Graph Attention Network v2 (GATv2) layers, whose dynamic attention is more
# expressive than the static attention of the original GAT.

def generate_structured_knowledge_graph(num_classes=30, feat_dim=128):
    """
    Build a small synthetic item -> material -> bin graph.

    Node layout: indices [0, num_classes) are items, the next 8 are materials
    (plastic, paper, glass, metal, organic, fabric, e-waste, misc) and the last
    4 are bins (recycle, compost, hazardous, landfill). Node features come from
    torch.randn, so they depend on the global torch seed. Items are assigned to
    material i % 8 and linked to item (i + 8) % num_classes, which is placeholder
    topology rather than curated domain knowledge. Every edge is stored in both
    directions.

    Returns:
        torch_geometric Data with x (total_nodes, feat_dim), a long edge_index
        of shape (2, E) and num_nodes.
    """
    logger.info("Generating structured Knowledge Graph...")

    num_materials, num_bins, neighbor_offset = 8, 4, 8
    total_nodes = num_classes + num_materials + num_bins
    x = torch.randn(total_nodes, feat_dim)  # node features (embeddings)

    first_material = num_classes
    first_bin = first_material + num_materials
    recycle, compost, hazardous, landfill = (first_bin + k for k in range(num_bins))
    # Bin of each material, in the material order listed in the docstring.
    material_to_bin = [recycle, recycle, recycle, recycle, compost, landfill, hazardous, landfill]

    undirected = [(first_material + m, b) for m, b in enumerate(material_to_bin)]
    undirected += [(item, first_material + item % num_materials) for item in range(num_classes)]
    undirected += [(item, (item + neighbor_offset) % num_classes) for item in range(num_classes)]

    sources, targets = [], []
    for a, b in undirected:
        sources += (a, b)
        targets += (b, a)
    edge_index = torch.tensor([sources, targets], dtype=torch.long)

    logger.info(f"Graph generated: {total_nodes} nodes, {len(sources)} edges.")

    return Data(x=x, edge_index=edge_index, num_nodes=total_nodes)

class GATv2Model(nn.Module):
    """Stack of GATv2Conv layers that produces node embeddings.

    Hidden layers use ``heads`` attention heads with concatenated outputs
    (width ``hidden_channels * heads``), each followed by LayerNorm, GELU and
    dropout. The final layer uses a single head and ``out_channels`` features.

    Args:
        in_channels, hidden_channels, out_channels: feature widths.
        num_layers: total number of GATv2Conv layers; must be >= 2.
        heads: attention heads in the hidden layers.
        dropout: used as attention dropout inside every conv and as feature
            dropout between layers.

    Raises:
        ValueError: if num_layers < 2.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=4, heads=8, dropout=0.3):
        super().__init__()
        # The constructor always builds a first and a last conv plus
        # num_layers - 1 norms; below 2 layers forward() would index a missing norm.
        if num_layers < 2:
            raise ValueError(f"num_layers must be >= 2, got {num_layers}")
        hidden_width = hidden_channels * heads  # width of concatenated multi-head output
        self.convs = nn.ModuleList()
        self.convs.append(GATv2Conv(in_channels, hidden_channels, heads=heads, concat=True, dropout=dropout))
        for _ in range(num_layers - 2):
            self.convs.append(GATv2Conv(hidden_width, hidden_channels, heads=heads, concat=True, dropout=dropout))
        self.convs.append(GATv2Conv(hidden_width, out_channels, heads=1, concat=False, dropout=dropout))
        self.dropout = dropout
        self.norm = nn.ModuleList([nn.LayerNorm(hidden_width) for _ in range(num_layers - 1)])

    def forward(self, x, edge_index):
        """Return per-node embeddings from the final single-head conv."""
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.norm[i](x)
            x = F.gelu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)

In [None]:
def train_gnn_model(in_dim=128, hidden_dim=512, out_dim=256, lr=0.001, epochs=50, num_classes=30):
    """Train a GATv2Model on the synthetic knowledge graph with a link-prediction loss.

    Positive pairs are the graph's edges. Negatives are uniformly random node
    pairs, which may occasionally coincide with real edges. The loss is binary
    cross-entropy on dot-product scores, computed with logsigmoid for numerical
    stability. All hyper-parameters default to the values previously hard-coded.

    Returns:
        The model holding the parameters that produced the lowest training loss
        seen. That loss is measured in train mode (dropout, freshly sampled
        negatives), so "best" is approximate.
    """
    set_seed()
    optimize_memory()
    device = get_device()
    logger.info(f"Using device: {device}")

    data = generate_structured_knowledge_graph(num_classes=num_classes, feat_dim=in_dim).to(device)

    model = GATv2Model(in_dim, hidden_dim, out_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

    pos_src, pos_dst = data.edge_index  # positive pairs do not change across epochs
    num_pos = pos_src.size(0)

    logger.info("Starting GNN Training...")
    best_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        z = model(data.x, data.edge_index)

        # -log(sigmoid(s)) == -logsigmoid(s); stable even for large |s|.
        pos_logits = (z[pos_src] * z[pos_dst]).sum(dim=1)
        pos_loss = -F.logsigmoid(pos_logits).mean()

        neg_src = torch.randint(0, data.num_nodes, (num_pos,), device=device)
        neg_dst = torch.randint(0, data.num_nodes, (num_pos,), device=device)
        neg_logits = (z[neg_src] * z[neg_dst]).sum(dim=1)
        # -log(1 - sigmoid(s)) == -logsigmoid(-s)
        neg_loss = -F.logsigmoid(-neg_logits).mean()

        loss = pos_loss + neg_loss
        loss.backward()
        loss_value = loss.item()

        # Snapshot *before* optimizer.step(): these parameters produced loss_value.
        if loss_value < best_loss:
            best_loss = loss_value
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step(loss_value)

        if (epoch + 1) % 5 == 0:
            logger.info(f"Epoch {epoch+1}/{epochs}: Loss {loss_value:.4f}, Best Loss {best_loss:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)

    return model

In [None]:
if __name__ == "__main__":
    # Phases that did not produce a model; reported at the end instead of
    # unconditionally announcing success.
    failed_phases = []
    try:
        logger.info("="*80)
        logger.info("Phase 1: Multi-Source Data Lake Vision Training")
        logger.info("="*80)

        vision_model = train_vision_model(VISION_CONFIG)

        if vision_model is not None:
            save_path = "best_vision_eva02_lake.pth"
            torch.save(vision_model.state_dict(), save_path)
            logger.info(f"Vision model saved to {save_path}")

            del vision_model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            logger.error("Vision model training failed")
            failed_phases.append("vision")

        logger.info("="*80)
        logger.info("Phase 2: GNN Knowledge Graph Training")
        logger.info("="*80)

        gnn_model = train_gnn_model()

        if gnn_model is not None:
            save_path = "best_gnn_gatv2.pth"
            torch.save(gnn_model.state_dict(), save_path)
            logger.info(f"GNN model saved to {save_path}")

            del gnn_model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            logger.error("GNN model training failed")
            failed_phases.append("gnn")

        logger.info("="*80)
        if failed_phases:
            logger.warning(f"Training finished with failures in: {', '.join(failed_phases)}")
        else:
            logger.info("Training completed successfully!")
        logger.info("="*80)

    except Exception as e:
        # logger.exception records the message together with the full traceback.
        logger.exception(f"Training failed with error: {e}")
        raise