# Test nowych funkcjonalności

Szybki test nowo dodanych elementów:

## Nowe modele
- **GIN** - Graph Isomorphism Network
- **GCNVirtualNode** - GCN z Virtual Node
- **GINVirtualNode** - GIN z Virtual Node

## Nowe datasety
- **ogbg-molpcba** - Multi-label molecular classification (128 tasks)
- **ogbg-ppa** - Protein-protein association networks (multi-class classification, 37 classes)

---

Ten notebook tylko sprawdza, czy wszystko działa - nie trenuje pełnych modeli.

In [None]:
# Make the project root (parent of this notebook's directory) importable.
import os
import sys

sys.path.insert(0, os.path.abspath(os.pardir))

import torch
import warnings

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings(action='ignore')

# Report the runtime environment.
for label, value in (
    ("PyTorch", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

## 1. Test importów nowych modeli

In [None]:
# Check that every model (including the newly added ones) can be imported.
print("Importuję modele...")

from models import (
    GOAT, Exphormer,           # Graph Transformers
    GCN, GAT, GIN, GraphMLP,   # Baselines
    GCNVirtualNode, GINVirtualNode  # Hybrid
)

print("✓ Wszystkie modele zaimportowane!")

# Names of all models imported above; keep this list in sync with the import.
all_models = [
    'GOAT', 'Exphormer',
    'GCN', 'GAT', 'GIN', 'GraphMLP',
    'GCNVirtualNode', 'GINVirtualNode'
]
print(f"\nDostępne modele ({len(all_models)}): {', '.join(all_models)}")

## 2. Test nowych modeli - forward pass

In [None]:
from torch_geometric.data import Data, Batch

def create_dummy_batch(num_graphs=4, num_nodes=20, num_edges=40, in_channels=9,
                       pe_dim=8, generator=None):
    """Create a synthetic batch of random graphs for smoke tests.

    Every graph gets:
      * ``x``: random node features of shape ``(num_nodes, in_channels)``,
      * ``edge_index``: ``num_edges`` random directed edges (self-loops and
        duplicate edges are possible - fine for a forward-pass check),
      * ``y``: a single random regression target,
      * ``pe``: a random positional encoding of shape ``(num_nodes, pe_dim)``.

    Args:
        num_graphs: Number of graphs in the batch.
        num_nodes: Number of nodes in each graph.
        num_edges: Number of edges in each graph.
        in_channels: Node-feature dimension.
        pe_dim: Positional-encoding dimension (defaults to the previous
            hard-coded value of 8).
        generator: Optional ``torch.Generator``; pass a seeded one to get a
            reproducible batch. ``None`` uses the global RNG as before.

    Returns:
        A ``Batch`` combining all generated ``Data`` graphs.
    """
    graphs = []
    for _ in range(num_graphs):
        x = torch.randn(num_nodes, in_channels, generator=generator)
        edge_index = torch.randint(0, num_nodes, (2, num_edges), generator=generator)
        y = torch.randn(1, generator=generator)  # regression target
        pe = torch.randn(num_nodes, pe_dim, generator=generator)  # positional encoding
        graphs.append(Data(x=x, edge_index=edge_index, y=y, pe=pe))
    return Batch.from_data_list(graphs)

# Build the dummy batch used by the forward-pass tests below.
batch = create_dummy_batch()
print(f"Dummy batch: {batch}")
print(f"  - Węzły: {batch.num_nodes}")
print(f"  - Krawędzie: {batch.num_edges}")
print(f"  - Grafy: {batch.num_graphs}")

In [None]:
# Smoke test: a single GIN forward pass on the dummy batch.
print("\n" + "="*50)
print("TEST: GIN (Graph Isomorphism Network)")
print("="*50)

gin = GIN(
    in_channels=batch.x.size(-1),  # match the dummy batch's feature width
    hidden_channels=64,
    out_channels=1,
    num_layers=3,
    dropout=0.1,
    train_eps=True,
)
# Inference mode: disables dropout (and any train-only layer behaviour) so
# the forward pass is deterministic for a given input.
gin.eval()

with torch.no_grad():
    out = gin(batch)

print("✓ Forward pass OK!")
print(f"  Input: {batch.x.shape}")
print(f"  Output: {out.shape}")
print(f"  Parametry: {sum(p.numel() for p in gin.parameters()):,}")

In [None]:
# Smoke test: a single GCNVirtualNode forward pass on the dummy batch.
print("\n" + "="*50)
print("TEST: GCNVirtualNode (Hybrid)")
print("="*50)

gcn_vn = GCNVirtualNode(
    in_channels=batch.x.size(-1),  # match the dummy batch's feature width
    hidden_channels=64,
    out_channels=1,
    num_layers=4,
    dropout=0.1,
)
# Inference mode: disables dropout (and any train-only layer behaviour) so
# the forward pass is deterministic for a given input.
gcn_vn.eval()

with torch.no_grad():
    out = gcn_vn(batch)

print("✓ Forward pass OK!")
print(f"  Input: {batch.x.shape}")
print(f"  Output: {out.shape}")
print(f"  Parametry: {sum(p.numel() for p in gcn_vn.parameters()):,}")

In [None]:
# Smoke test: a single GINVirtualNode forward pass on the dummy batch.
print("\n" + "="*50)
print("TEST: GINVirtualNode (Hybrid)")
print("="*50)

gin_vn = GINVirtualNode(
    in_channels=batch.x.size(-1),  # match the dummy batch's feature width
    hidden_channels=64,
    out_channels=1,
    num_layers=4,
    dropout=0.1,
)
# Inference mode: disables dropout (and any train-only layer behaviour) so
# the forward pass is deterministic for a given input.
gin_vn.eval()

with torch.no_grad():
    out = gin_vn(batch)

print("✓ Forward pass OK!")
print(f"  Input: {batch.x.shape}")
print(f"  Output: {out.shape}")
print(f"  Parametry: {sum(p.numel() for p in gin_vn.parameters()):,}")

## 3. Test nowych datasetów

In [None]:
# Dataset loaders: ogbg-molpcba and ogbg-ppa are the new ones under test;
# ZINC is used for the mini training run below.
from src.utils.data import (
    load_molhiv_dataset,
    load_molpcba_dataset,
    load_ppa_dataset,
    load_zinc_dataset,
)

print("✓ Funkcje ładowania datasetów zaimportowane!")

In [None]:
# Smoke test: load ogbg-molpcba, report split sizes and inspect one graph.
print("\n" + "="*50)
print("TEST: ogbg-molpcba (Multi-label classification)")
print("="*50)

print("Ładuję dataset... (może potrwać przy pierwszym uruchomieniu)")
try:
    dataset_pcba, split_pcba = load_molpcba_dataset()
    print("✓ Dataset załadowany!")
    print(f"  Grafów: {len(dataset_pcba):,}")
    print(f"  Train: {len(split_pcba['train']):,}")
    print(f"  Valid: {len(split_pcba['valid']):,}")
    print(f"  Test: {len(split_pcba['test']):,}")

    # Inspect the first graph of the dataset.
    sample = dataset_pcba[0]
    print("\n  Przykładowy graf:")
    print(f"    - Węzły: {sample.num_nodes}")
    print(f"    - Krawędzie: {sample.num_edges}")
    print(f"    - Features: {sample.x.shape}")
    print(f"    - Labels: {sample.y.shape} (128 binary tasks)")
except Exception as e:
    # Deliberately best-effort: report the failure (e.g. a missing download)
    # instead of aborting, so the remaining cells can still run.
    print(f"✗ Błąd: {e}")
    print("  (Dataset może wymagać pobrania - to normalne przy pierwszym uruchomieniu)")

In [None]:
# Smoke test: load ogbg-ppa, report split sizes and inspect one graph.
print("\n" + "="*50)
print("TEST: ogbg-ppa (Protein-protein association)")
print("="*50)

print("Ładuję dataset... (może potrwać przy pierwszym uruchomieniu)")
try:
    dataset_ppa, split_ppa = load_ppa_dataset()
    print("✓ Dataset załadowany!")
    print(f"  Grafów: {len(dataset_ppa):,}")
    print(f"  Train: {len(split_ppa['train']):,}")
    print(f"  Valid: {len(split_ppa['valid']):,}")
    print(f"  Test: {len(split_ppa['test']):,}")

    # Inspect the first graph of the dataset.
    sample = dataset_ppa[0]
    print("\n  Przykładowy graf:")
    print(f"    - Węzły: {sample.num_nodes}")
    print(f"    - Krawędzie: {sample.num_edges}")
    # Node features may be absent for this dataset, so only print them if set.
    sample_x = getattr(sample, 'x', None)
    if sample_x is not None:
        print(f"    - Features: {sample_x.shape}")
    print(f"    - Label: {sample.y} (37 classes)")
except Exception as e:
    # Deliberately best-effort: report the failure (e.g. a missing download)
    # instead of aborting, so the remaining cells can still run.
    print(f"✗ Błąd: {e}")
    print("  (Dataset może wymagać pobrania - to normalne przy pierwszym uruchomieniu)")

## 4. Mini trening - sprawdzenie, że wszystko działa

In [None]:
from torch_geometric.loader import DataLoader
import torch.nn.functional as F

print("\n" + "="*50)
print("MINI TRENING: GIN na małym podzbiorze ZINC")
print("="*50)

# Load ZINC together with its train/valid/test index splits.
print("Ładuję ZINC...")
dataset_zinc, split_zinc = load_zinc_dataset()

# Use only a small subset so the run finishes quickly.
small_train = dataset_zinc[split_zinc['train'][:100]]
small_val = dataset_zinc[split_zinc['valid'][:50]]

train_loader = DataLoader(small_train, batch_size=32, shuffle=True)
val_loader = DataLoader(small_val, batch_size=32)

print(f"Train: {len(small_train)} grafów, Val: {len(small_val)} grafów")

In [None]:
# Build a small GIN regressor whose input width matches the ZINC node
# features; 1-D feature tensors are treated as a single channel.
# NOTE(review): PyG's ZINC stores integer atom types in x - confirm that
# load_zinc_dataset / GIN handle that dtype.
in_channels = (
    dataset_zinc[0].x.shape[1]
    if dataset_zinc[0].x.dim() > 1
    else 1
)

model = GIN(in_channels=in_channels, hidden_channels=64, out_channels=1, num_layers=3)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print("Model: GIN")
print("Parametry: {:,}".format(sum(p.numel() for p in model.parameters())))

In [None]:
# Mini training run: a few epochs of GIN on the small ZINC subset.
num_epochs = 5

print(f"\nTrenuję przez {num_epochs} epok...")

for epoch in range(num_epochs):
    model.train()
    train_sq_err = 0.0
    train_count = 0

    # Loop variable is `data`, not `batch`, so the dummy `batch` used by the
    # forward-pass cells above is not overwritten.
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data)
        y = data.y.float().view(-1, 1)
        loss = F.mse_loss(out, y)
        loss.backward()
        optimizer.step()
        # mse_loss is a per-element mean; re-weight by batch size because the
        # last batch is usually smaller, and a plain mean of batch losses
        # would over-weight it.
        train_sq_err += loss.item() * y.numel()
        train_count += y.numel()

    # Validation: sum absolute errors over all targets to get the exact MAE
    # (averaging per-batch MAEs is biased when batch sizes differ).
    model.eval()
    val_abs_err = 0.0
    val_count = 0
    with torch.no_grad():
        for data in val_loader:
            out = model(data)
            y = data.y.float().view(-1, 1)
            val_abs_err += F.l1_loss(out, y, reduction='sum').item()
            val_count += y.numel()

    # max(..., 1) guards against an empty loader.
    train_loss = train_sq_err / max(train_count, 1)
    val_mae = val_abs_err / max(val_count, 1)
    print(f"  Epoch {epoch+1}/{num_epochs}: Train Loss={train_loss:.4f}, Val MAE={val_mae:.4f}")

print("\n✓ Mini trening zakończony pomyślnie!")

## 5. Podsumowanie

In [None]:
# Final summary. Dataset status reflects whether the loading cells above
# actually succeeded (their datasets only exist if the loader call did not
# raise), instead of being hard-coded as available. Models and the mini
# training run have no try/except, so reaching this cell in a top-to-bottom
# run means they passed.
pcba_ok = 'dataset_pcba' in globals()
ppa_ok = 'dataset_ppa' in globals()

print("\n" + "="*60)
print("PODSUMOWANIE TESTÓW")
print("="*60)

print("\n✓ Nowe modele:")
print("  - GIN (Graph Isomorphism Network) - działa")
print("  - GCNVirtualNode (Hybrid) - działa")
print("  - GINVirtualNode (Hybrid) - działa")

print(f"\n{'✓' if pcba_ok and ppa_ok else '✗'} Nowe datasety:")
print(f"  - ogbg-molpcba (Multi-label, 438K grafów) - {'dostępny' if pcba_ok else 'niedostępny'}")
print(f"  - ogbg-ppa (Multi-class, 158K grafów) - {'dostępny' if ppa_ok else 'niedostępny'}")

print("\n✓ Mini trening: GIN na ZINC - sukces")

print("\n" + "="*60)
if pcba_ok and ppa_ok:
    print("WSZYSTKO DZIAŁA! 🎉")
else:
    print("CZĘŚĆ TESTÓW NIE POWIODŁA SIĘ - sprawdź komunikaty powyżej")
print("="*60)
print("\nNastępne kroki:")
print("1. Uruchom experiments/compare_all_models.ipynb dla pełnego porównania")
print("2. Zmień EXPERIMENT_MODE na 'gpu' dla pełnych eksperymentów")