In [None]:
!pip install -q timm torch-geometric torch-scatter torch-sparse albumentations wandb einops

In [None]:
import warnings
# Suppress Pydantic warnings (must be done before other imports)
warnings.filterwarnings("ignore", message=".*The 'repr' attribute with value False.*")
warnings.filterwarnings("ignore", message=".*The 'frozen' attribute with value True.*")
warnings.filterwarnings("ignore", module="pydantic")

import os
import sys
import json
import random
import logging
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, ConcatDataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from PIL import Image

# PEAK STANDARD Libraries
import timm
from timm.data import create_transform, resolve_data_config
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from tqdm.notebook import tqdm
import wandb

In [None]:
# Notebook-wide logging: INFO and above, each record prefixed with a timestamp
# and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def set_seed(seed: int = 42):
    """
    Seed Python's `random`, NumPy and PyTorch with the same value.

    When CUDA is available, all CUDA devices are seeded too and cuDNN is put
    into deterministic mode with benchmark auto-tuning switched off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    logger.info(f"Random seed set to {seed}")

def get_device():
    """
    Pick the best available torch device.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    # PyTorch builds that predate the MPS backend have no `torch.backends.mps`;
    # treat a missing backend as "not available" instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")

class EarlyStopping:
    """
    Signal when a monitored metric has stopped improving.

    Call the instance once per epoch with the current score; it returns True
    once `patience` consecutive calls have failed to improve on the best score.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            `early_stop` becomes True.
        mode: "max" when larger scores are better (e.g. accuracy), "min" when
            smaller scores are better (e.g. loss).
        delta: Margin by which a score must beat the best score to count as an
            improvement.

    Raises:
        ValueError: If `mode` is neither "max" nor "min".
    """
    def __init__(self, patience=15, mode="max", delta=0):
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best score seen so far (None before the first call)
        self.early_stop = False   # latched to True once patience is exhausted
        self.mode = mode
        self.delta = delta

    def _is_improvement(self, current_score):
        """Return True if `current_score` beats `best_score` by more than `delta`."""
        if self.mode == "max":
            return current_score > self.best_score + self.delta
        return current_score < self.best_score - self.delta

    def __call__(self, current_score):
        """
        Record `current_score` and report whether training should stop.

        Returns:
            bool: True once `patience` consecutive non-improving scores have
            been seen (and on every call after that).
        """
        if self.best_score is None:
            # First observation only establishes the baseline.
            self.best_score = current_score
        elif self._is_improvement(current_score):
            self.best_score = current_score
            self.counter = 0
        else:
            # Comparisons with NaN are False, so a NaN score lands here and is
            # counted as "no improvement" rather than replacing the best score.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

In [None]:
# PEAK STANDARD CONFIGURATION: DATA LAKE STRATEGY
# Combining 3 verified massive datasets

# 1. Define the Master Schema (30 Classes from 'Recyclable and Household Waste')
# Master label schema: the 30 category names that every data source is mapped
# onto. Entries are listed in alphabetical order.
TARGET_CLASSES = [
    "aerosol_cans",
    "aluminum_food_cans",
    "aluminum_soda_cans",
    "cardboard_boxes",
    "cardboard_packaging",
    "clothing",
    "coffee_grounds",
    "disposable_plastic_cutlery",
    "egg_shells",
    "food_waste",
    "glass_beverage_bottles",
    "glass_cosmetic_containers",
    "glass_food_jars",
    "magazines",
    "newspaper",
    "office_paper",
    "paper_cups",
    "plastic_cup_lids",
    "plastic_detergent_bottles",
    "plastic_food_containers",
    "plastic_shopping_bags",
    "plastic_soda_bottles",
    "plastic_straws",
    "plastic_trash_bags",
    "plastic_water_bottles",
    "shoes",
    "steel_food_cans",
    "styrofoam_cups",
    "styrofoam_food_containers",
    "tea_bags",
]

VISION_CONFIG = {
    # Backbone and classifier-head settings passed to timm.create_model.
    "model": {
        "backbone": "eva02_large_patch14_448.mim_m38m_ft_in22k_in1k",  # EVA-02 Large
        "pretrained": True,
        "num_classes": 30,  # one output per TARGET_CLASSES entry
        "drop_rate": 0.3,
        "drop_path_rate": 0.2,
    },
    "data": {
        "input_size": 448,
        "num_workers": 2,
        "pin_memory": True,
        # Data lake: every source is a dataset root plus the label-mapping
        # strategy ("type") that folds it into the 30-class schema.
        "sources": [
            {"name": name, "path": path, "type": kind}
            for name, path, kind in (
                (
                    "master_30 (Alistair King)",
                    "/kaggle/input/recyclable-and-household-waste-classification/images",
                    "master",
                ),
                (
                    "garbage_12 (Mostafa Abla)",
                    "/kaggle/input/garbage-classification/garbage_classification",
                    "mapped_12",
                ),
                (
                    "waste_22k (TechSash)",
                    "/kaggle/input/waste-classification-data/DATASET",
                    "mapped_2",
                ),
            )
        ],
    },
    "training": {
        # batch_size * grad_accum_steps = 64 samples per optimizer step.
        "batch_size": 8,
        "grad_accum_steps": 8,
        "learning_rate": 5e-5,
        "weight_decay": 0.05,
        "num_epochs": 20,
        "patience": 5,  # early-stopping patience, in epochs
    },
}

In [None]:
class UnifiedWasteDataset(Dataset):
    """
    A unified dataset that ingests images from multiple source folders and maps
    them onto a single target class schema.

    Each source's directory tree is walked. A directory's images are labelled
    from the directory's own name or, failing that, from the nearest enclosing
    folder (inside the source root) whose name maps to a target class. Images
    in directories with no mappable name are counted in `skipped_count`.

    Args:
        sources_config: Iterable of dicts with keys "name", "path" and "type".
            "type" selects the label-mapping strategy ('master', 'mapped_12'
            or 'mapped_2').
        target_classes: Iterable of target class names; it is sorted to build
            the class-index mapping.
        transform: Optional callable applied to each loaded RGB PIL image.
        fallback_size: Side length of the all-zero (3, H, W) placeholder tensor
            returned when an image cannot be loaded or transformed.
    """

    # Lower-cased file suffixes treated as images, used both when ingesting
    # samples and when counting skipped ones so the two tallies agree.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    # Folder name -> target class for the 'mapped_12' strategy. Folder names
    # without an entry (e.g. 'battery') are skipped.
    # NOTE(review): 'trash' -> 'food_waste' is a lossy approximation inherited
    # from the original mapping; confirm it is intended.
    GARBAGE_12_MAPPING = {
        'paper': 'office_paper',                 # approximation
        'cardboard': 'cardboard_boxes',
        'plastic': 'plastic_food_containers',    # generalised to most common
        'metal': 'aluminum_food_cans',           # generalised
        'glass': 'glass_food_jars',              # generalised
        'brown-glass': 'glass_beverage_bottles', # beer bottles
        'green-glass': 'glass_beverage_bottles',
        'white-glass': 'glass_food_jars',
        'clothes': 'clothing',
        'shoes': 'shoes',
        'biological': 'food_waste',
        'trash': 'food_waste',                   # often mixed/dirty
    }

    def __init__(self, sources_config, target_classes, transform=None, fallback_size=448):
        self.transform = transform
        self.fallback_size = fallback_size
        self.target_classes = sorted(target_classes)
        self.class_to_idx = {c: i for i, c in enumerate(self.target_classes)}
        self.samples = []        # list of (image Path, target class index)
        self.skipped_count = 0   # images found under unmappable folders

        for source in sources_config:
            self._ingest_source(source)

        logger.info(f"Unified Dataset Created: {len(self.samples)} images. Skipped {self.skipped_count} unmappable images.")

    def _ingest_source(self, source):
        """Resolve one source's root folder and add its images to `samples`."""
        path = Path(source["path"])
        if not path.exists():
            # Fallback search if exact path fails (common in Kaggle): use the
            # first non-empty sibling directory. Siblings are sorted so the
            # choice does not depend on filesystem listing order.
            parent = path.parent
            found = False
            if parent.exists():
                for child in sorted(parent.iterdir()):
                    if child.is_dir() and any(child.iterdir()):
                        # Very basic heuristic: any non-empty directory is accepted.
                        path = child
                        found = True
                        break

            if not found or not path.exists():
                logger.warning(f"Source {source['name']} not found at {source['path']}. Skipping.")
                return

        logger.info(f"Ingesting {source['name']} from {path}...")
        self._scan_tree(path, source['type'])

    def _scan_tree(self, path, source_type):
        """
        Walk `path` recursively, appending labelled images to `samples` and
        counting images that cannot be labelled in `skipped_count`.

        Directories and files are visited in sorted order so that the sample
        order (and therefore any seeded split of it) is reproducible.
        """
        path = Path(path)
        for root, dirs, files in os.walk(path):
            dirs.sort()  # in-place edit steers os.walk's traversal order
            images = sorted(f for f in files if f.lower().endswith(self.IMAGE_EXTENSIONS))
            if not images:
                continue

            root_path = Path(root)
            target_label = self._label_for_dir(root_path, path, source_type)
            if target_label:
                target_idx = self.class_to_idx[target_label]
                self.samples.extend((root_path / f, target_idx) for f in images)
            else:
                self.skipped_count += len(images)

    def _label_for_dir(self, directory, source_root, source_type):
        """
        Return the target label for `directory`, or None if it is unmappable.

        The directory's own name is tried first, then each enclosing folder up
        to and including `source_root`, so images kept in sub-folders of a
        class folder inherit that class.
        """
        candidates = (source_root.name, *directory.relative_to(source_root).parts)
        for folder_name in reversed(candidates):
            label = self._map_label(folder_name, source_type)
            if label:
                return label
        return None

    def _map_label(self, raw_label, source_type):
        """
        Map a raw folder name to a target class name for the given strategy.

        Returns:
            The target class name, or None when the folder should be skipped.
        """
        raw = raw_label.lower().strip()

        # Strategy 1: Master schema -- the folder name must already be a target class.
        if source_type == 'master':
            return raw if raw in self.class_to_idx else None

        # Strategy 2: 12-class garbage classification -- fixed lookup table.
        if source_type == 'mapped_12':
            return self.GARBAGE_12_MAPPING.get(raw)

        # Strategy 3: 2-class (organic vs recyclable). Only the organic part is
        # used, as 'food_waste'; 'recyclable' / 'r' is too broad to assign to a
        # specific class and is skipped.
        if source_type == 'mapped_2':
            if raw in ('organic', 'o'):
                return 'food_waste'
            return None

        return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load sample `idx` as (image, label index).

        If loading or transforming fails, the error is logged and an all-zero
        (3, fallback_size, fallback_size) tensor is returned with the label.
        """
        path, label_idx = self.samples[idx]
        try:
            # convert() returns a new, fully loaded image, so the file handle
            # can be closed as soon as the `with` block ends.
            with Image.open(path) as handle:
                img = handle.convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, label_idx
        except Exception as e:
            logger.error(f"Corrupt image {path}: {e}")
            size = getattr(self, "fallback_size", 448)
            return torch.zeros((3, size, size)), label_idx

    def get_labels(self):
        """Return the target class index of every sample, in sample order."""
        return [s[1] for s in self.samples]

In [None]:
def get_vision_transforms(config, model, is_train=True):
    """
    Build the timm preprocessing pipeline that matches `model`'s pretrained config.

    Input size, interpolation, mean and std are taken from the data config
    resolved from `model.default_cfg`.

    Args:
        config: Accepted for interface symmetry with the other builders; not
            read by this function.
        model: A timm model exposing `default_cfg`.
        is_train: True for the augmented training pipeline (random resized
            crop, horizontal flip, colour jitter, RandAugment, random erasing);
            False for the deterministic evaluation pipeline.

    Returns:
        The transform returned by `timm.data.create_transform`.

    Raises:
        Exception: Any error from timm is logged and re-raised unchanged.
    """
    try:
        data_config = resolve_data_config(model.default_cfg, model=model)

        # Settings common to both pipelines.
        shared = dict(
            input_size=data_config['input_size'],
            use_prefetcher=False,
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
        )

        if is_train:
            return create_transform(
                is_training=True,
                no_aug=False,
                scale=(0.08, 1.0),
                ratio=(0.75, 1.33),
                hflip=0.5,
                vflip=0.0,
                color_jitter=0.4,
                auto_augment='rand-m9-mstd0.5-inc1',
                re_prob=0.25,
                re_mode='pixel',
                re_count=1,
                **shared,
            )

        # Evaluate with the crop fraction the checkpoint was configured for;
        # without it create_transform falls back to its generic default crop.
        return create_transform(
            is_training=False,
            crop_pct=data_config.get('crop_pct'),
            **shared,
        )
    except Exception as e:
        logger.error(f"Failed to create transforms: {e}")
        raise

In [None]:
def create_vision_model(config):
    """
    Instantiate the timm backbone described by ``config["model"]``.

    Reads the keys "backbone", "pretrained", "num_classes", "drop_rate" and
    "drop_path_rate" and returns the model built by `timm.create_model`.
    """
    model_cfg = config["model"]
    logger.info(f"Creating model: {model_cfg['backbone']}")

    option_names = ("pretrained", "num_classes", "drop_rate", "drop_path_rate")
    options = {name: model_cfg[name] for name in option_names}
    return timm.create_model(model_cfg["backbone"], **options)

In [None]:
def _evaluate_vision_model(model, loader, criterion, device):
    """Return (mean per-sample loss, accuracy in percent) of `model` on `loader`."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch = labels.size(0)
            # Weight each batch by its size so a short final batch does not
            # skew the epoch average.
            loss_sum += criterion(outputs, labels).item() * batch
            n_correct += outputs.argmax(dim=1).eq(labels).sum().item()
            n_seen += batch
    if n_seen == 0:
        return float("nan"), 0.0
    return loss_sum / n_seen, 100 * n_correct / n_seen


def train_vision_model(config):
    """
    Fine-tune the vision backbone on the unified multi-source dataset.

    Steps: seed everything, build the model and its timm transforms, ingest all
    sources into one dataset, split 85/15 into train/validation, then train
    with AdamW, cosine LR annealing, gradient accumulation, CUDA mixed
    precision (when on CUDA) and early stopping on validation accuracy.

    Args:
        config: Nested dict with "model", "data" and "training" sections (see
            VISION_CONFIG).

    Returns:
        The trained model carrying the weights of the epoch with the best
        validation accuracy, or None if the dataset is empty or too small to
        split.
    """
    import copy  # local import keeps this cell self-contained

    set_seed()
    device = get_device()
    logger.info(f"Using device: {device}")

    # 1. Initialize Model
    model = create_vision_model(config).to(device)

    # 2. Setup Data Pipeline
    train_transform = get_vision_transforms(config, model, is_train=True)
    val_transform = get_vision_transforms(config, model, is_train=False)

    # Unified Data Ingestion
    full_dataset = UnifiedWasteDataset(
        sources_config=config["data"]["sources"],
        target_classes=TARGET_CLASSES,
        transform=train_transform
    )

    if len(full_dataset) == 0:
        logger.error("Dataset is empty. Check paths.")
        return None

    # Split
    train_size = int(0.85 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    if train_size == 0 or val_size == 0:
        logger.error(f"Dataset too small to split ({len(full_dataset)} images).")
        return None
    train_dataset, val_split = torch.utils.data.random_split(full_dataset, [train_size, val_size])

    # Both subsets returned by random_split wrap the SAME dataset object, so
    # assigning the validation transform on it would also strip augmentation
    # from training. Give validation its own shallow copy instead: it shares
    # the sample list but carries its own transform.
    val_view = copy.copy(full_dataset)
    val_view.transform = val_transform
    val_dataset = torch.utils.data.Subset(val_view, val_split.indices)

    # No WeightedRandomSampler: class imbalance is left unhandled here, as this
    # is a fine-tuning job on pretrained weights.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config["training"]["batch_size"],
        shuffle=True,
        num_workers=config["data"]["num_workers"],
        pin_memory=config["data"]["pin_memory"]
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config["training"]["batch_size"] * 2,
        shuffle=False,
        num_workers=config["data"]["num_workers"]
    )

    # 3. Optimization
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config["training"]["learning_rate"],
        weight_decay=config["training"]["weight_decay"]
    )
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["training"]["num_epochs"])
    early_stopping = EarlyStopping(patience=config["training"]["patience"])

    # 4. Mixed Precision (CUDA only)
    use_amp = (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    def apply_optimizer_step():
        """Apply the accumulated gradients and reset them."""
        if use_amp:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad()

    # 5. Training Loop
    accumulation_steps = config["training"]["grad_accum_steps"]
    num_epochs = config["training"]["num_epochs"]
    num_batches = len(train_loader)

    # Experiment tracking is best-effort: training proceeds without it.
    wandb_run = None
    try:
        wandb_run = wandb.init(project="sustainability-vision-lake", config=config)
    except Exception as exc:
        logger.warning(f"wandb logging disabled: {exc}")

    best_val_acc = float("-inf")
    best_state = None

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        optimizer.zero_grad()

        for i, (images, labels) in enumerate(pbar):
            images, labels = images.to(device), labels.to(device)

            # Loss is divided by accumulation_steps so the summed gradients of
            # one accumulation group average over the group.
            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels) / accumulation_steps
                scaler.scale(loss).backward()
            else:
                outputs = model(images)
                loss = criterion(outputs, labels) / accumulation_steps
                loss.backward()

            # Also step on the last batch so the gradients of a trailing
            # partial group are applied rather than discarded by the
            # zero_grad() at the start of the next epoch.
            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                apply_optimizer_step()

            running_loss += loss.item() * accumulation_steps
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            current_loss = running_loss / (i + 1)
            pbar.set_postfix({'loss': f"{current_loss:.4f}"})

        scheduler.step()
        train_acc = 100 * correct / total

        # Validation
        val_loss, val_acc = _evaluate_vision_model(model, val_loader, criterion, device)

        print(f"Ep {epoch+1}: Train Acc {train_acc:.2f}%, Val Loss {val_loss:.4f}, Val Acc {val_acc:.2f}%")

        if wandb_run is not None:
            try:
                wandb.log({"train_acc": train_acc, "val_acc": val_acc, "val_loss": val_loss})
            except Exception as exc:
                logger.warning(f"wandb.log failed: {exc}")

        # Keep a CPU copy of the best weights so the returned model is the
        # best epoch rather than whichever epoch happened to run last.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        if early_stopping(val_acc):
            print("Early stopping triggered")
            break

    if best_state is not None:
        model.load_state_dict(best_state)

    if wandb_run is not None:
        try:
            wandb_run.finish()
        except Exception as exc:
            logger.warning(f"wandb finish failed: {exc}")

    return model

In [None]:
# PEAK STANDARD GNN
# Using Graph Attention Networks v2 (GATv2) for superior expressive power

def generate_structured_knowledge_graph(num_classes=30, feat_dim=128):
    """
    Build a synthetic Item -> Material -> Bin knowledge graph.

    Nodes are laid out as `num_classes` item nodes, then 8 material nodes, then
    4 bin nodes, each with a random `feat_dim`-dimensional feature vector.
    Every relation is stored as a pair of directed edges (one per direction):
      1. material <-> bin, from a fixed rule table;
      2. item <-> material, assigning item i to material i % 8;
      3. item <-> item, linking item i with item (i + 8) % num_classes.

    Returns:
        torch_geometric.data.Data with `x`, `edge_index` and `num_nodes`.
    """
    logger.info("Generating structured Knowledge Graph...")

    n_materials, n_bins = 8, 4
    total_nodes = num_classes + n_materials + n_bins
    x = torch.randn(total_nodes, feat_dim)  # Node features (embeddings)

    # Material node indices follow directly after the item nodes.
    mat_base = num_classes
    (mat_plastic, mat_paper, mat_glass, mat_metal,
     mat_organic, mat_fabric, mat_ewaste, mat_misc) = range(mat_base, mat_base + n_materials)

    # Bin node indices follow the material nodes.
    bin_base = mat_base + n_materials
    bin_recycle, bin_compost, bin_haz, bin_landfill = range(bin_base, bin_base + n_bins)

    # 1. Material -> Bin (knowledge rules)
    relations = [
        (mat_plastic, bin_recycle),
        (mat_paper, bin_recycle),
        (mat_glass, bin_recycle),
        (mat_metal, bin_recycle),
        (mat_organic, bin_compost),
        (mat_fabric, bin_landfill),
        (mat_ewaste, bin_haz),
        (mat_misc, bin_landfill),
    ]
    # 2. Item -> Material (simulated classification knowledge)
    relations += [(item, mat_base + (item % 8)) for item in range(num_classes)]
    # 3. Item -> Item (similarity)
    relations += [(item, (item + 8) % num_classes) for item in range(num_classes)]

    # Expand each relation into a forward and a reverse directed edge.
    edge_sources, edge_targets = [], []
    for a, b in relations:
        edge_sources += [a, b]
        edge_targets += [b, a]

    edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)

    logger.info(f"Graph generated: {total_nodes} nodes, {len(edge_sources)} edges.")

    return Data(x=x, edge_index=edge_index, num_nodes=total_nodes)

class GATv2Model(nn.Module):
    """
    Stack of GATv2Conv layers producing node embeddings.

    The first `num_layers - 1` convolutions use `heads` attention heads with
    concatenated outputs (width `hidden_channels * heads`), each followed by
    LayerNorm, GELU and dropout. The final convolution uses a single head and
    emits `out_channels` features per node.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=4, heads=8, dropout=0.3):
        super().__init__()
        wide = hidden_channels * heads  # width of a concatenated multi-head output

        layers = [GATv2Conv(in_channels, hidden_channels, heads=heads, concat=True, dropout=dropout)]
        layers.extend(
            GATv2Conv(wide, hidden_channels, heads=heads, concat=True, dropout=dropout)
            for _ in range(num_layers - 2)
        )
        layers.append(GATv2Conv(wide, out_channels, heads=1, concat=False, dropout=dropout))

        self.convs = nn.ModuleList(layers)
        self.dropout = dropout
        self.norm = nn.ModuleList([nn.LayerNorm(wide) for _ in range(num_layers - 1)])

    def forward(self, x, edge_index):
        """Return the output of the final convolution for node features `x`."""
        last = len(self.convs) - 1
        for layer_idx in range(last):
            x = self.convs[layer_idx](x, edge_index)
            x = self.norm[layer_idx](x)
            x = F.gelu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[last](x, edge_index)

In [None]:
def train_gnn_model(in_dim=128, hidden_dim=512, out_dim=256, lr=0.001, epochs=50, num_classes=30):
    """
    Train the GATv2 encoder on the synthetic knowledge graph via link prediction.

    Each epoch scores every graph edge (positive) and an equal number of
    uniformly sampled node pairs (negative) with the dot product of the two
    node embeddings, and minimises the binary logistic loss.

    Args:
        in_dim: Node feature dimension; used both for the generated graph
            features and the model's input width.
        hidden_dim: Per-head hidden width of the GATv2 layers.
        out_dim: Output embedding dimension.
        lr: AdamW learning rate.
        epochs: Number of full-graph training steps.
        num_classes: Number of item nodes in the generated graph.

    Returns:
        The trained GATv2Model.
    """
    set_seed()
    device = get_device()
    logger.info(f"Using device: {device}")

    data = generate_structured_knowledge_graph(num_classes=num_classes, feat_dim=in_dim).to(device)

    model = GATv2Model(in_dim, hidden_dim, out_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    logger.info("Starting PEAK GNN Training...")
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        z = model(data.x, data.edge_index)

        # Positive pairs: the graph's own edges.
        pos_src, pos_dst = data.edge_index
        pos_logits = (z[pos_src] * z[pos_dst]).sum(dim=1)

        # Negative pairs: uniformly random node pairs, one per positive edge.
        num_pairs = pos_src.size(0)
        neg_src = torch.randint(0, data.num_nodes, (num_pairs,), device=device)
        neg_dst = torch.randint(0, data.num_nodes, (num_pairs,), device=device)
        neg_logits = (z[neg_src] * z[neg_dst]).sum(dim=1)

        # logsigmoid is the numerically stable form of log(sigmoid(x)), and
        # log(1 - sigmoid(x)) == logsigmoid(-x). The explicit
        # log(sigmoid(x) + eps) form saturates in float32 for large |x| and
        # then yields zero gradient.
        pos_loss = -F.logsigmoid(pos_logits).mean()
        neg_loss = -F.logsigmoid(-neg_logits).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step(loss.item())

        if (epoch + 1) % 5 == 0:
            print(f"Ep {epoch+1}: Loss {loss.item():.4f}")

    return model

In [None]:
# Entry point: runs both training phases in order and writes each model's
# state dict to the current working directory.
if __name__ == "__main__":
    # Phase 1: fine-tune the vision backbone on the unified dataset.
    logger.info("Phase 1: Multi-Source Data Lake Training")
    vision_model = train_vision_model(VISION_CONFIG)
    # train_vision_model returns None when no images were ingested; nothing is
    # saved in that case.
    if vision_model:
        torch.save(vision_model.state_dict(), "best_vision_eva02_lake.pth")
        print("Vision model artifacts saved successfully.")

    # Phase 2: train the GATv2 encoder. This runs regardless of how Phase 1 ended.
    logger.info("Phase 2: GNN Model Training")
    gnn_model = train_gnn_model()
    torch.save(gnn_model.state_dict(), "best_gnn_gatv2.pth")
    print("GNN model artifacts saved successfully.")