In [None]:
!pip install -q timm torch-geometric albumentations pycocotools wandb einops

In [None]:
import os
import sys
import json
import random
import logging
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from PIL import Image

# PEAK STANDARD Libraries
import timm
from timm.data import create_transform, resolve_data_config
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from tqdm.notebook import tqdm
import wandb

In [None]:
# Notebook-wide logging: INFO and above, each record prefixed with a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def set_seed(seed: int = 42):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    When CUDA is available, also seeds every CUDA device and forces cuDNN into
    deterministic (non-benchmarking) mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_device():
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)

class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Call the instance once per evaluation with the latest score; it returns True once
    ``patience`` consecutive calls have failed to improve on the best score seen.

    Args:
        patience: number of consecutive non-improving calls tolerated before stopping.
        mode: ``"max"`` when larger scores are better (e.g. accuracy) or ``"min"``
            when smaller scores are better (e.g. loss).
        delta: minimum margin by which a score must beat the best score to count
            as an improvement.

    Raises:
        ValueError: if ``mode`` is neither ``"max"`` nor ``"min"``.
    """

    def __init__(self, patience=15, mode="max", delta=0):
        # Previously any mode other than "max" was accepted but silently never
        # triggered; reject it up front instead.
        if mode not in ("max", "min"):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best score seen so far (None before the first call)
        self.early_stop = False   # latched True once patience is exhausted
        self.mode = mode
        self.delta = delta

    def _is_improvement(self, current_score):
        """Return True if ``current_score`` beats ``best_score`` by more than ``delta``."""
        if self.mode == "max":
            return current_score > self.best_score + self.delta
        return current_score < self.best_score - self.delta

    def __call__(self, current_score):
        """Record ``current_score`` and return whether training should stop."""
        if self.best_score is None or self._is_improvement(current_score):
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

In [None]:
# PEAK STANDARD CONFIGURATION
# Nested settings for the vision pipeline: "model" (timm backbone and head),
# "data" (dataset location and loader options) and "training" (optimiser/schedule).
VISION_CONFIG = dict(
    model=dict(
        # timm model name; EVA-02 Large at 448px input resolution.
        backbone="eva02_large_patch14_448.mim_m38m_ft_in22k_in1k",
        pretrained=True,
        num_classes=12,
        drop_rate=0.2,
        drop_path_rate=0.2,
    ),
    data=dict(
        train_dir="/kaggle/input/garbage-classification/garbage_classification/",
        input_size=448,  # high resolution, matches the backbone's 448px variant
        num_workers=4,
        pin_memory=True,
    ),
    training=dict(
        batch_size=16,  # kept small because 448px inputs on a Large ViT are memory-heavy
        grad_accum_steps=4,
        learning_rate=1e-4,
        weight_decay=0.05,
        num_epochs=30,
        patience=10,
    ),
)

In [None]:
class VisionDataset(Dataset):
    """Image-classification dataset over a ``root_dir/<class_name>/<image>`` tree.

    Class indices follow the sorted names of ``root_dir``'s immediate sub-directories.
    Files whose suffix is ``.jpg``, ``.jpeg``, ``.png`` or ``.bmp`` (case-insensitive)
    are included, in sorted order, so the index -> file mapping is reproducible.

    Args:
        root_dir: directory whose sub-directories are the classes.
        transform: optional callable applied to the RGB ``PIL.Image``.
        fallback_shape: shape of the all-zeros tensor returned for a sample that
            fails to load or transform; should match the transform's output shape
            so the sample can still be collated into a batch.
    """

    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, root_dir, transform=None, fallback_shape=(3, 448, 448)):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.fallback_shape = tuple(fallback_shape)
        self.classes = sorted(d.name for d in self.root_dir.iterdir() if d.is_dir())
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.images = self._load_images()

    def _load_images(self):
        """Return ``(path, class_index)`` pairs for every image file, in sorted order."""
        images = []
        for class_name in self.classes:
            class_dir = self.root_dir / class_name
            # glob() order is filesystem-dependent; sorting keeps seeded splits reproducible.
            for img_path in sorted(class_dir.glob("*")):
                if img_path.is_file() and img_path.suffix.lower() in self.IMAGE_SUFFIXES:
                    images.append((img_path, self.class_to_idx[class_name]))
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is the transformed RGB image."""
        img_path, label = self.images[idx]
        try:
            # The context manager closes the file handle once the pixels are decoded.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            # Deliberate best effort: one corrupt file should not abort a long training
            # run, so the sample is replaced by zeros and the failure is logged.
            logging.getLogger(__name__).warning("Error loading %s: %s", img_path, e)
            return torch.zeros(self.fallback_shape), label

In [None]:
def get_vision_transforms(config, model, is_train=True):
    """Build a timm preprocessing pipeline matched to ``model``'s pretrained data config.

    The following come from the model's pretrained config, so inputs match what the
    backbone expects:
      - input size
      - interpolation
      - normalisation mean/std
      - for evaluation only, crop_pct / crop_mode

    Args:
        config: experiment config; not read here (kept for call-site compatibility).
        model: a timm model exposing a pretrained data config.
        is_train: if True return the augmenting training pipeline, otherwise the
            deterministic evaluation pipeline.

    Returns:
        A callable mapping a PIL image to a normalised tensor.
    """
    # ``{}`` means "no overrides": every field is resolved from the model's pretrained config.
    data_config = resolve_data_config({}, model=model)

    if is_train:
        # Random-resized crop (scale/ratio below), horizontal flip and RandAugment
        # ('rand-m9-mstd0.5-inc1': magnitude 9, magnitude std 0.5, increasing-severity ops).
        return create_transform(
            input_size=data_config['input_size'],
            is_training=True,
            use_prefetcher=False,
            no_aug=False,
            scale=(0.08, 1.0),
            ratio=(0.75, 1.33),
            hflip=0.5,
            vflip=0.0,
            color_jitter=0.4,
            auto_augment='rand-m9-mstd0.5-inc1',
            interpolation=data_config['interpolation'],
            mean=data_config['mean'],
            std=data_config['std'],
        )

    # Evaluation should reproduce the pretrained resize/crop geometry. Without these
    # arguments create_transform falls back to timm's default crop_pct and crop mode.
    crop_kwargs = {}
    if data_config.get('crop_pct') is not None:
        crop_kwargs['crop_pct'] = data_config['crop_pct']
    if data_config.get('crop_mode') is not None:
        # Only resolved by timm versions whose create_transform also accepts crop_mode.
        crop_kwargs['crop_mode'] = data_config['crop_mode']
    return create_transform(
        input_size=data_config['input_size'],
        is_training=False,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        **crop_kwargs,
    )

def create_vision_model(config):
    """Instantiate the timm classifier described by ``config["model"]``.

    Reads the backbone name, ``pretrained`` flag, ``num_classes`` (classifier head
    size), ``drop_rate`` and ``drop_path_rate`` from that sub-dict.
    """
    model_cfg = config["model"]
    return timm.create_model(
        model_cfg["backbone"],
        pretrained=model_cfg["pretrained"],
        num_classes=model_cfg["num_classes"],
        drop_rate=model_cfg["drop_rate"],
        drop_path_rate=model_cfg["drop_path_rate"],
    )

In [None]:
def _build_vision_loaders(config, train_transform, val_transform, train_fraction=0.85, seed=42):
    """Split the image folder into train/val DataLoaders that use independent transforms.

    Returns ``(None, None)`` when the split would leave the training set empty.
    """
    import copy  # local: only needed to clone the dataset wrapper

    data_cfg = config["data"]
    train_base = VisionDataset(data_cfg["train_dir"], transform=train_transform)
    # A shallow copy shares the scanned ``images`` list (so indices refer to the same
    # files) but owns its own ``transform`` attribute. Changing the transform of a
    # dataset shared by both splits would also strip augmentation from training.
    val_base = copy.copy(train_base)
    val_base.transform = val_transform

    num_items = len(train_base)
    train_size = int(train_fraction * num_items)
    if train_size == 0:
        return None, None

    order = torch.randperm(num_items, generator=torch.Generator().manual_seed(seed)).tolist()
    train_dataset = torch.utils.data.Subset(train_base, order[:train_size])
    val_dataset = torch.utils.data.Subset(val_base, order[train_size:])

    batch_size = config["training"]["batch_size"]
    num_workers = data_cfg.get("num_workers", 4)
    # Pinned host memory only speeds up host -> CUDA copies.
    pin_memory = data_cfg.get("pin_memory", True) and torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size * 2,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, val_loader


def _train_one_epoch(model, loader, optimizer, criterion, scaler, device, accumulation_steps, epoch):
    """Train ``model`` for one pass over ``loader``; return (mean loss, accuracy in %)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = len(loader)

    optimizer.zero_grad()
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}")
    for i, (images, labels) in enumerate(pbar):
        images, labels = images.to(device), labels.to(device)

        # Dividing by accumulation_steps makes the accumulated gradient a group average.
        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(images)
                loss = criterion(outputs, labels) / accumulation_steps
            scaler.scale(loss).backward()
        else:
            outputs = model(images)
            loss = criterion(outputs, labels) / accumulation_steps
            loss.backward()

        # Step every `accumulation_steps` batches AND on the final batch, so the
        # gradients of a trailing partial group are applied rather than discarded.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad()

        running_loss += loss.item() * accumulation_steps
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        pbar.set_postfix({'loss': running_loss / (i + 1)})

    return running_loss / max(num_batches, 1), 100 * correct / max(total, 1)


def _evaluate_vision(model, loader, criterion, device, use_amp):
    """Return (per-sample mean loss, accuracy in %) of ``model`` over ``loader``."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
            # Weight by batch size so a smaller final batch does not skew the mean.
            loss_sum += loss.item() * labels.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return loss_sum / max(total, 1), 100 * correct / max(total, 1)


def train_vision_model(config):
    """Fine-tune the configured timm backbone on an image-folder dataset.

    Data split:
      - 85% of the images (seeded shuffle) train with the timm training transform.
      - 15% validate with the evaluation transform.

    Training setup:
      - AdamW with per-epoch cosine LR annealing.
      - Gradient accumulation over ``config["training"]["grad_accum_steps"]`` batches.
      - Mixed precision on CUDA.
      - Early stopping on validation accuracy.

    Returns:
        The model loaded with the weights of the epoch that reached the highest
        validation accuracy, or ``None`` if there are too few images to form a
        training split.
    """
    set_seed()
    device = get_device()
    logger.info(f"Using device: {device}")

    model = create_vision_model(config).to(device)

    # Transforms resolved from this model's pretrained data config.
    train_transform = get_vision_transforms(config, model, is_train=True)
    val_transform = get_vision_transforms(config, model, is_train=False)

    train_loader, val_loader = _build_vision_loaders(config, train_transform, val_transform)
    if train_loader is None:
        print("No data found.")
        return None

    train_cfg = config["training"]
    optimizer = optim.AdamW(
        model.parameters(),
        lr=train_cfg["learning_rate"],
        weight_decay=train_cfg["weight_decay"],
    )
    criterion = nn.CrossEntropyLoss()
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=train_cfg["num_epochs"])
    early_stopping = EarlyStopping(patience=train_cfg["patience"])

    # Loss scaling for float16 autocast; only used on CUDA.
    scaler = torch.cuda.amp.GradScaler() if device.type == "cuda" else None
    accumulation_steps = max(1, int(train_cfg["grad_accum_steps"]))

    best_val_acc = None
    best_state = None

    logger.info(f"Starting PEAK Vision Training with {config['model']['backbone']}...")
    for epoch in range(train_cfg["num_epochs"]):
        _, train_acc = _train_one_epoch(
            model, train_loader, optimizer, criterion, scaler, device, accumulation_steps, epoch
        )
        scheduler.step()

        val_loss, val_acc = _evaluate_vision(
            model, val_loader, criterion, device, use_amp=scaler is not None
        )
        print(f"Ep {epoch+1}: Train Acc {train_acc:.2f}%, Val Loss {val_loss:.4f}, Val Acc {val_acc:.2f}%")

        # Snapshot the best weights on CPU. The returned model is then the best epoch,
        # not the final one, which can be up to `patience` epochs past the peak.
        if best_val_acc is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

        if early_stopping(val_acc):
            print("Early stopping triggered")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
        logger.info(f"Restored weights with best Val Acc {best_val_acc:.2f}%")
    return model

In [None]:
def generate_synthetic_graph_data(num_nodes=1000, feat_dim=128):
    """Build a random graph as a torch_geometric ``Data`` object.

    Structure:
      - Node features: ``x`` of shape ``(num_nodes, feat_dim)``, drawn from a
        standard normal.
      - Edges: ``edge_index`` of shape ``(2, 10 * num_nodes)``. Row 0 holds source
        nodes and row 1 destination nodes, both drawn uniformly from
        ``[0, num_nodes)``, so self-loops and duplicate edges can occur.

    No node or edge labels are attached.
    """
    num_edges = num_nodes * 10
    node_features = torch.randn(num_nodes, feat_dim)
    # Draw sources first, then destinations (same RNG order as a two-step draw).
    endpoints = [torch.randint(0, num_nodes, (num_edges,)) for _ in range(2)]
    return Data(x=node_features, edge_index=torch.stack(endpoints, dim=0))

# PEAK STANDARD GNN: GATv2
class GATv2Model(nn.Module):
    """Stack of GATv2 attention layers producing node embeddings.

    Hidden layers use multi-head attention with concatenated heads, followed by
    LayerNorm, GELU and dropout. The output layer is single-head.

    Args:
        in_channels: input node-feature size.
        hidden_channels: per-head hidden size; hidden width is hidden_channels * heads.
        out_channels: output embedding size.
        num_layers: total number of GATv2Conv layers (>= 1). With 1, the model is a
            single single-head layer mapping in_channels -> out_channels.
        heads: number of attention heads in the hidden layers.
        dropout: passed to every GATv2Conv and applied to hidden activations.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=4, heads=8, dropout=0.3):
        super().__init__()
        # num_layers <= 1 used to build two convs but no LayerNorm, crashing in forward().
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        hidden_width = hidden_channels * heads
        self.convs = nn.ModuleList()

        # Hidden multi-head layers: the first reads raw features, the rest read concatenated heads.
        for layer in range(num_layers - 1):
            layer_in = in_channels if layer == 0 else hidden_width
            self.convs.append(GATv2Conv(layer_in, hidden_channels, heads=heads, concat=True, dropout=dropout))

        # Output layer: single head; reads raw features when there are no hidden layers.
        last_in = in_channels if num_layers == 1 else hidden_width
        self.convs.append(GATv2Conv(last_in, out_channels, heads=1, concat=False, dropout=dropout))

        self.dropout = dropout
        # One LayerNorm per hidden layer (attribute name kept for state_dict compatibility).
        self.norm = nn.ModuleList([nn.LayerNorm(hidden_width) for _ in range(num_layers - 1)])

    def forward(self, x, edge_index):
        """Return output-layer embeddings for every node given features and edges."""
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = self.norm[i](x)
            x = F.gelu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self.convs[-1](x, edge_index)

In [None]:
def train_gnn_model(num_nodes: int = 5000, epochs: int = 100):
    """Train a GATv2 encoder for link prediction on a random synthetic graph.

    Scoring and samples:
      - An edge (u, v) is scored by the dot product of node embeddings ``z[u] · z[v]``.
      - Positives are the graph's edges.
      - Negatives are uniformly sampled node pairs: one per positive edge,
        resampled every epoch.
      - The loss is binary cross-entropy on the scores.

    Args:
        num_nodes: number of nodes in the synthetic graph.
        epochs: number of full-graph optimisation steps.

    Returns:
        The trained ``GATv2Model`` (weights after the final epoch).
    """
    set_seed()
    device = get_device()
    logger.info(f"Using device: {device}")

    # Config
    in_dim = 128
    hidden_dim = 512    # per-head hidden size
    out_dim = 256
    lr = 0.0005

    data = generate_synthetic_graph_data(num_nodes=num_nodes, feat_dim=in_dim).to(device)

    model = GATv2Model(in_dim, hidden_dim, out_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)

    pos_src, pos_dst = data.edge_index
    num_pos = data.edge_index.size(1)

    logger.info("Starting PEAK GNN Training...")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        z = model(data.x, data.edge_index)

        # -log(sigmoid(s)) computed as -logsigmoid(s). The old `+ 1e-15` form loses all
        # gradient once sigmoid(s) rounds to 0 or 1 in floating point.
        pos_scores = (z[pos_src] * z[pos_dst]).sum(dim=1)
        pos_loss = -F.logsigmoid(pos_scores).mean()

        # Random (not hard) negative sampling: uniform node pairs, as many as positives.
        neg_src = torch.randint(0, data.num_nodes, (num_pos,), device=device)
        neg_dst = torch.randint(0, data.num_nodes, (num_pos,), device=device)
        neg_scores = (z[neg_src] * z[neg_dst]).sum(dim=1)
        # -log(1 - sigmoid(s)) == -logsigmoid(-s)
        neg_loss = -F.logsigmoid(-neg_scores).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step(loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Ep {epoch+1}: Loss {loss.item():.4f}")

    return model

In [None]:
if __name__ == "__main__":
    # 1. Train Vision Model (only when the Kaggle dataset is attached)
    if os.path.exists(VISION_CONFIG["data"]["train_dir"]):
        vision_model = train_vision_model(VISION_CONFIG)
        # train_vision_model returns None when the directory yields no training split.
        if vision_model is not None:
            torch.save(vision_model.state_dict(), "best_vision_eva02.pth")
            print("Vision model saved.")
        else:
            print("Vision training skipped: no usable images found.")
    else:
        print("Dataset not found. Please attach 'Garbage Classification' dataset.")

    # 2. Train GNN Model
    gnn_model = train_gnn_model()
    torch.save(gnn_model.state_dict(), "best_gnn_gatv2.pth")
    print("GNN model saved.")