In [None]:

!pip install -q timm torch-geometric torch-scatter torch-sparse albumentations wandb einops

In [None]:
import os
import sys
import json
import random
import logging
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from PIL import Image

# PEAK STANDARD Libraries
import timm
from timm.data import create_transform, resolve_data_config
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from tqdm.notebook import tqdm
import wandb

In [None]:
# Configure root logging for the whole notebook. Without force=True, basicConfig silently
# does nothing when the root logger already has handlers (some hosted notebook kernels
# install their own), and every logger.info below would be invisible. force=True
# (Python 3.8+) replaces any pre-existing root handlers with this configuration.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
logger = logging.getLogger(__name__)

def set_seed(seed: int = 42):
    """
    Seed every RNG used by the notebook so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch. When CUDA is present it
    also seeds all GPUs and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        seed: Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Give up cuDNN autotuning speed in exchange for run-to-run determinism.
        torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

    logger.info(f"Random seed set to {seed}")

def get_device():
    """
    Return the best available torch device: CUDA first, then Apple MPS, else CPU.
    """
    # Probes are lambdas so torch.backends.mps is only touched if CUDA is unavailable.
    preference = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, is_available in preference:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")

def find_dataset_path(preferred_path: str, search_root: str = "/kaggle/input") -> Path:
    """
    Locate the image-folder dataset, preferring ``preferred_path``.

    If ``preferred_path`` is a non-empty directory it is returned unchanged. Otherwise
    ``search_root`` is walked top-down, visiting sub-directories in sorted order so the
    result is deterministic. The first directory found that has at least two
    sub-directories, where the alphabetically first one contains an image file, is
    returned.

    Args:
        preferred_path: Expected dataset root (one sub-folder per class).
        search_root: Directory searched when ``preferred_path`` is unusable.
            Defaults to Kaggle's input mount.

    Returns:
        The discovered dataset root. If nothing suitable is found, returns
        ``Path(preferred_path)`` so the caller can detect the empty dataset later.
    """
    preferred = Path(preferred_path)
    # is_dir() rather than exists(): iterdir() on a regular file raises NotADirectoryError.
    if preferred.is_dir() and any(preferred.iterdir()):
        logger.info(f"Dataset found at: {preferred}")
        return preferred

    base = Path(search_root)
    logger.warning(f"Preferred path {preferred} not found. Searching {base}...")
    if not base.exists():
        logger.warning("Run is not on Kaggle or input directory missing.")
        return preferred  # Let the caller fail on the empty dataset.

    # Same extension set that VisionDataset accepts.
    image_exts = (".jpg", ".jpeg", ".png", ".bmp")

    # Top-down walk (os.walk yields a directory before its children).
    for root, dirs, _files in os.walk(base):
        # Sorting in place makes both the "first subfolder" choice and the order in
        # which os.walk descends deterministic across filesystems.
        dirs.sort()
        if len(dirs) < 2:  # a classification root needs at least two class folders
            continue
        first_sub = Path(root) / dirs[0]
        try:
            entries = os.listdir(first_sub)
        except OSError:  # unreadable folder: keep searching instead of crashing
            continue
        if any(name.lower().endswith(image_exts) for name in entries):
            found_path = Path(root)
            logger.info(f"Dataset discovered at: {found_path}")
            return found_path

    logger.error("Could not automatically locate a valid dataset.")
    return preferred

class EarlyStopping:
    """
    Signal when a monitored metric has stopped improving.

    Call the instance once per epoch with the current score. It returns True once
    ``patience`` consecutive calls have failed to improve on the best score seen so far,
    and keeps returning True after that.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        mode: "max" if larger scores are better (e.g. accuracy), "min" if smaller
            scores are better (e.g. loss).
        delta: Margin the score must beat the best score by to count as an improvement.

    Raises:
        ValueError: if ``mode`` is neither "max" nor "min".
    """

    def __init__(self, patience=15, mode="max", delta=0):
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best score so far; None until the first call
        self.early_stop = False   # latched to True once patience is exhausted
        self.mode = mode
        self.delta = delta

    def _is_improvement(self, score):
        """Return True if ``score`` beats ``best_score`` by more than ``delta``."""
        if self.mode == "max":
            return score > self.best_score + self.delta
        return score < self.best_score - self.delta

    def __call__(self, current_score):
        """Record ``current_score`` and return whether training should stop."""
        if self.best_score is None:
            self.best_score = current_score
        elif self._is_improvement(current_score):
            self.best_score = current_score
            self.counter = 0
        else:
            # Count stale epochs in both modes. Previously only "max" was handled, so any
            # other mode could never trigger a stop.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

In [None]:
# PEAK STANDARD CONFIGURATION
# Single source of truth for the vision run, grouped into "model", "data" and "training".
VISION_CONFIG = dict(
    model=dict(
        # timm model name of the EVA-02 Large backbone (448px variant).
        backbone="eva02_large_patch14_448.mim_m38m_ft_in22k_in1k",
        pretrained=True,
        # Should match the number of class folders under data.train_dir.
        num_classes=12,
        # Regularisation, passed straight through to timm.create_model.
        drop_rate=0.3,
        drop_path_rate=0.2,
    ),
    data=dict(
        train_dir="/kaggle/input/garbage-classification/garbage_classification/",
        # NOTE: get_vision_transforms takes the input size from the model's pretrained
        # data config; this value is not read there.
        input_size=448,
        num_workers=2,    # kept low: Kaggle containers often limit shared memory
        pin_memory=True,
    ),
    training=dict(
        batch_size=8,         # per-step micro-batch, small to fit a Large model at 448px
        grad_accum_steps=8,   # effective batch size = 8 * 8 = 64
        learning_rate=5e-5,   # low LR for fine-tuning a large pretrained model
        weight_decay=0.05,
        num_epochs=20,
        patience=5,           # early-stopping patience, in epochs
    ),
)

In [None]:
class VisionDataset(Dataset):
    """
    Image-folder dataset: ``root_dir/<class_name>/<image>`` -> ``(image, class_index)``.

    Class indices follow the sorted class-folder names. Images within a class are sorted
    by path, so sample order (and therefore any seeded random split) does not depend on
    filesystem listing order.

    Args:
        root_dir: Dataset root containing one sub-directory per class. A missing root
            yields an empty dataset; the caller decides how to abort.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
        fallback_shape: Shape of the all-zeros tensor returned for unreadable files.
    """

    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, root_dir, transform=None, fallback_shape=(3, 448, 448)):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.fallback_shape = tuple(fallback_shape)
        if not self.root_dir.is_dir():
            # Empty dataset if the path is wrong; handled by the caller.
            # class_to_idx is defined here too so the attribute always exists.
            self.classes = []
            self.class_to_idx = {}
            self.images = []
            return

        self.classes = sorted(d.name for d in self.root_dir.iterdir() if d.is_dir())
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.images = self._load_images()

    def _load_images(self):
        """Return ``(path, class_index)`` for every supported image file, in sorted order."""
        images = []
        for class_name in self.classes:
            label = self.class_to_idx[class_name]
            for img_path in sorted((self.root_dir / class_name).glob("*")):
                if img_path.suffix.lower() in self.IMG_EXTENSIONS:
                    images.append((img_path, label))
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path, label = self.images[idx]
        try:
            # The context manager closes the file handle. convert() forces the full
            # decode while the file is still open.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            # Unreadable or corrupt file: substitute a blank image instead of crashing training.
            logger.error(f"Error loading {img_path}: {e}")
            return torch.zeros(self.fallback_shape), label
        # Transform errors are deliberately not swallowed. A broken pipeline would
        # otherwise silently turn every sample into a blank image.
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
def get_vision_transforms(config, model, is_train=True):
    """
    Build the train or eval image transform for ``model`` with timm's transform factory.

    Input size, interpolation and normalisation statistics all come from the model's own
    data config, so preprocessing matches the pretrained checkpoint. ``config`` is
    accepted for interface compatibility but is not read here.

    Args:
        config: Experiment config (unused).
        model: timm model whose ``default_cfg`` seeds the data config.
        is_train: True for the augmenting training pipeline, False for the eval pipeline.

    Returns:
        The transform produced by ``create_transform``.

    Raises:
        Re-raises any exception from timm after logging it.
    """
    try:
        data_cfg = resolve_data_config(model.default_cfg, model=model)
        # Arguments shared by both pipelines. A plain PyTorch DataLoader is used,
        # not timm's prefetcher.
        common_kwargs = dict(
            input_size=data_cfg['input_size'],
            use_prefetcher=False,
            interpolation=data_cfg['interpolation'],
            mean=data_cfg['mean'],
            std=data_cfg['std'],
        )
        if not is_train:
            return create_transform(is_training=False, **common_kwargs)

        # Training augmentation: scale/ratio-randomised crop, horizontal flip,
        # RandAugment policy and random erasing (re_prob=0.25).
        # NOTE(review): timm typically ignores color_jitter when auto_augment is set;
        # confirm against the installed timm version before relying on it.
        return create_transform(
            is_training=True,
            no_aug=False,
            scale=(0.08, 1.0),
            ratio=(0.75, 1.33),
            hflip=0.5,
            vflip=0.0,
            color_jitter=0.4,
            auto_augment='rand-m9-mstd0.5-inc1',
            re_prob=0.25,
            re_mode='pixel',
            re_count=1,
            **common_kwargs,
        )
    except Exception as e:
        logger.error(f"Failed to create transforms: {e}")
        raise

In [None]:
def create_vision_model(config):
    """
    Instantiate the timm classifier described by ``config["model"]``.

    Reads ``backbone`` (timm model name), ``pretrained``, ``num_classes``, ``drop_rate``
    and ``drop_path_rate`` and passes them unchanged to ``timm.create_model``.

    Returns:
        The model returned by ``timm.create_model``.
    """
    model_cfg = config["model"]
    backbone_name = model_cfg["backbone"]
    logger.info(f"Creating model: {backbone_name}")
    return timm.create_model(
        backbone_name,
        pretrained=model_cfg["pretrained"],
        num_classes=model_cfg["num_classes"],
        drop_rate=model_cfg["drop_rate"],
        drop_path_rate=model_cfg["drop_path_rate"],
    )

In [None]:
def _autocast_context(use_amp):
    """Return a CUDA autocast context when ``use_amp`` is set, otherwise a no-op context."""
    import contextlib

    return torch.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()


def _init_wandb(config):
    """Best-effort W&B initialisation. Returns the run, or None if ``wandb.init`` raised.

    NOTE(review): without credentials wandb.init may prompt interactively rather than
    raise. For unattended runs, set WANDB_MODE=disabled (or offline).
    """
    try:
        return wandb.init(project="sustainability-vision", config=config)
    except Exception as e:
        logger.warning(f"W&B disabled, init failed: {e}")
        return None


def _train_one_epoch(model, loader, optimizer, criterion, device, scaler, accumulation_steps, desc):
    """Run one training epoch with gradient accumulation. Returns (mean_loss, accuracy_percent).

    AMP is used iff ``scaler`` is not None.
    """
    model.train()
    use_amp = scaler is not None
    running_loss, correct, total = 0.0, 0, 0
    num_batches = len(loader)

    pbar = tqdm(loader, desc=desc)
    optimizer.zero_grad()
    for i, (images, labels) in enumerate(pbar):
        images, labels = images.to(device), labels.to(device)
        with _autocast_context(use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels) / accumulation_steps
        if use_amp:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Step every `accumulation_steps` batches AND on the final batch. Otherwise a
        # trailing partial window's gradients would be wiped by next epoch's zero_grad().
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        running_loss += loss.item() * accumulation_steps
        total += labels.size(0)
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        pbar.set_postfix({'loss': f"{running_loss / (i + 1):.4f}"})

    return running_loss / max(num_batches, 1), 100 * correct / max(total, 1)


@torch.no_grad()
def _evaluate(model, loader, criterion, device, use_amp):
    """Return (mean_batch_loss, accuracy_percent) of ``model`` over ``loader``."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        # Autocast here too: the Large backbone at 448px is memory-hungry in fp32.
        with _autocast_context(use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)
        total_loss += loss.item()
        total += labels.size(0)
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
    return total_loss / max(len(loader), 1), 100 * correct / max(total, 1)


def train_vision_model(config):
    """
    Fine-tune the timm backbone described by ``config`` on the discovered image-folder data.

    Pipeline: seed, locate data, build the model and its timm transforms, split 85/15 at
    random, then train with AdamW + cosine LR and gradient accumulation (AMP on CUDA).
    Each epoch is validated, with early stopping on validation accuracy.

    Args:
        config: Nested dict shaped like ``VISION_CONFIG`` ("model", "data", "training").

    Returns:
        The model loaded with the weights from the epoch with the best validation
        accuracy, or None if no training images were found.
    """
    import copy

    set_seed()
    device = get_device()
    logger.info(f"Using device: {device}")

    # 1. Index the data first, so a bad path aborts before the large pretrained
    #    backbone is downloaded and moved to the device.
    data_path = find_dataset_path(config["data"]["train_dir"])
    train_base = VisionDataset(data_path)
    train_size = int(0.85 * len(train_base))
    val_size = len(train_base) - train_size
    if train_size == 0:
        logger.error("Dataset is empty or path is incorrect. Aborting training.")
        return None

    # 2. Model, then the model-specific transforms.
    model = create_vision_model(config).to(device)
    train_base.transform = get_vision_transforms(config, model, is_train=True)
    # The Subsets returned by random_split share ONE underlying dataset, so assigning the
    # validation transform to it would also strip augmentation from the training split.
    # A shallow copy shares the same image list (indices line up) but has its own transform.
    val_base = copy.copy(train_base)
    val_base.transform = get_vision_transforms(config, model, is_train=False)

    train_dataset, val_split = torch.utils.data.random_split(train_base, [train_size, val_size])
    val_dataset = torch.utils.data.Subset(val_base, val_split.indices)

    batch_size = config["training"]["batch_size"]
    num_workers = config["data"]["num_workers"]
    pin_memory = config["data"]["pin_memory"]
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # 3. Optimisation setup.
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config["training"]["learning_rate"],
        weight_decay=config["training"]["weight_decay"],
    )
    criterion = nn.CrossEntropyLoss()
    num_epochs = config["training"]["num_epochs"]
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    early_stopping = EarlyStopping(patience=config["training"]["patience"])
    accumulation_steps = config["training"]["grad_accum_steps"]

    # 4. Mixed precision only on CUDA.
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    logger.info(f"Mixed Precision Enabled: {use_amp}")

    # 5. Training loop, tracking the best epoch so that epoch's weights are returned
    #    rather than the last one's.
    best_val_acc = float("-inf")
    best_state = None
    wandb_run = _init_wandb(config)
    try:
        for epoch in range(num_epochs):
            train_loss, train_acc = _train_one_epoch(
                model, train_loader, optimizer, criterion, device, scaler,
                accumulation_steps, desc=f"Epoch {epoch+1}/{num_epochs}",
            )
            scheduler.step()
            val_loss, val_acc = _evaluate(model, val_loader, criterion, device, use_amp)

            print(f"Ep {epoch+1}: Train Acc {train_acc:.2f}%, Val Loss {val_loss:.4f}, Val Acc {val_acc:.2f}%")

            if wandb_run is not None:
                try:
                    wandb_run.log({"train_loss": train_loss, "train_acc": train_acc,
                                   "val_acc": val_acc, "val_loss": val_loss})
                except Exception as e:
                    logger.warning(f"W&B logging failed: {e}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                # Snapshot on CPU so the copy does not compete with training for GPU memory.
                best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

            if early_stopping(val_acc):
                print("Early stopping triggered")
                break
    finally:
        if wandb_run is not None:
            wandb_run.finish()

    if best_state is not None:
        model.load_state_dict(best_state)
        logger.info(f"Restored best weights (val acc {best_val_acc:.2f}%)")
    return model

In [None]:
# PEAK STANDARD GNN
# Using Graph Attention Networks v2 (GATv2) for superior expressive power

def generate_synthetic_graph_data(num_nodes=5000, feat_dim=128):
    """
    Build a random graph for exercising the GNN training loop.

    Node features are i.i.d. standard-normal with shape (num_nodes, feat_dim). Edges are
    ``10 * num_nodes`` directed pairs whose endpoints are drawn uniformly from
    [0, num_nodes), so self-loops and duplicate edges can occur. A production pipeline
    would load the knowledge-graph export instead.

    Returns:
        ``Data(x=..., edge_index=...)`` where ``edge_index`` is a long tensor of shape
        (2, 10 * num_nodes).
    """
    node_features = torch.randn(num_nodes, feat_dim)
    num_edges = num_nodes * 10
    # Sources are drawn before destinations, in the same RNG order as two explicit calls.
    src, dst = (
        torch.randint(0, num_nodes, (num_edges,), dtype=torch.long) for _ in range(2)
    )
    return Data(x=node_features, edge_index=torch.stack([src, dst], dim=0))

class GATv2Model(nn.Module):
    """
    Stack of GATv2 graph-attention layers that produces node embeddings.

    Layout for ``num_layers`` = L:
      * L-1 hidden GATv2Conv layers with ``heads`` concatenated heads (output width
        ``hidden_channels * heads``), each followed by LayerNorm, GELU and dropout;
      * one single-head output GATv2Conv mapping to ``out_channels``.
    With L == 1 the model is just the output layer applied to the input features.

    Args:
        in_channels: Input node-feature size.
        hidden_channels: Per-head width of the hidden layers.
        out_channels: Output embedding size.
        num_layers: Total number of GATv2Conv layers (must be >= 1).
        heads: Attention heads in each hidden layer.
        dropout: Used both as GATv2Conv attention dropout and as feature dropout
            between layers.

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=4, heads=4, dropout=0.3):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.convs = nn.ModuleList()

        # Hidden layers: the first consumes in_channels, later ones consume the
        # concatenated head outputs. The output layer consumes whatever came last, so
        # num_layers == 1 no longer builds an extra layer without a matching LayerNorm.
        layer_in = in_channels
        for _ in range(num_layers - 1):
            self.convs.append(GATv2Conv(layer_in, hidden_channels, heads=heads, concat=True, dropout=dropout))
            layer_in = hidden_channels * heads

        # Output layer (embedding size = out_channels), single head, no concatenation.
        self.convs.append(GATv2Conv(layer_in, out_channels, heads=1, concat=False, dropout=dropout))

        self.dropout = dropout
        # One LayerNorm per hidden layer (empty when num_layers == 1).
        self.norm = nn.ModuleList([nn.LayerNorm(hidden_channels * heads) for _ in range(num_layers - 1)])

    def forward(self, x, edge_index):
        """Return node embeddings. Presumably x is (num_nodes, in_channels); verify against callers."""
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.norm[i](x)
            x = F.gelu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)

In [None]:
def _link_prediction_loss(z, edge_index, num_nodes):
    """
    Unsupervised dot-product link-prediction loss with uniformly sampled negatives.

    Positive pairs are the existing edges. The same number of negative pairs is drawn
    uniformly at random (they may occasionally coincide with real edges). The loss is
    mean(-log sigma(z_u . z_v)) over positives plus mean(-log(1 - sigma(z_u . z_v))) over
    negatives.
    """
    pos_src, pos_dst = edge_index
    pos_logits = (z[pos_src] * z[pos_dst]).sum(dim=1)

    num_neg = pos_src.size(0)
    neg_src = torch.randint(0, num_nodes, (num_neg,), device=z.device)
    neg_dst = torch.randint(0, num_nodes, (num_neg,), device=z.device)
    neg_logits = (z[neg_src] * z[neg_dst]).sum(dim=1)

    # -log(sigmoid(x)) == -logsigmoid(x) and -log(1 - sigmoid(x)) == -logsigmoid(-x).
    # Unlike log(sigmoid(x) + 1e-15), these stay exact and keep non-zero gradients when
    # large embedding dot products saturate the sigmoid.
    pos_loss = -F.logsigmoid(pos_logits).mean()
    neg_loss = -F.logsigmoid(-neg_logits).mean()
    return pos_loss + neg_loss


def train_gnn_model():
    """
    Train a GATv2Model on a synthetic graph with an unsupervised link-prediction objective.

    Returns:
        The model after the final epoch. No best-epoch tracking is done.
    """
    set_seed()
    device = get_device()
    logger.info(f"Using device: {device}")

    # Config
    num_nodes = 5000
    in_dim = 128        # node-feature size; also passed to the data generator
    hidden_dim = 256    # Tuned for T4
    out_dim = 256
    lr = 0.001
    epochs = 50

    data = generate_synthetic_graph_data(num_nodes=num_nodes, feat_dim=in_dim).to(device)

    model = GATv2Model(in_dim, hidden_dim, out_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    logger.info("Starting PEAK GNN Training...")
    model.train()

    for epoch in range(epochs):
        optimizer.zero_grad()
        z = model(data.x, data.edge_index)
        loss = _link_prediction_loss(z, data.edge_index, data.num_nodes)
        loss.backward()

        # Gradient clipping for stability.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()
        scheduler.step(loss_value)  # ReduceLROnPlateau expects a plain number

        if (epoch + 1) % 5 == 0:
            print(f"Ep {epoch+1}: Loss {loss_value:.4f}")

    return model

In [None]:
if __name__ == "__main__":
    # 1. Train the vision model (the dataset location is discovered automatically).
    vision_model = train_vision_model(VISION_CONFIG)
    # train_vision_model returns None when no data was found. Compare against None
    # explicitly rather than relying on nn.Module truthiness.
    if vision_model is not None:
        torch.save(vision_model.state_dict(), "best_vision_eva02.pth")
        print("Vision model artifacts saved successfully.")

    # 2. Train the GNN model.
    # NOTE: train_gnn_model returns final-epoch weights; the "best" in the filename is nominal.
    gnn_model = train_gnn_model()
    torch.save(gnn_model.state_dict(), "best_gnn_gatv2.pth")
    print("GNN model artifacts saved successfully.")