Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/link_prediction/heterogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _setup_dataloaders(
main_loader = DistABLPLoader(
dataset=dataset,
num_neighbors=num_neighbors,
input_nodes=(query_node_type, main_input_nodes[query_node_type]),
input_nodes=(query_node_type, main_input_nodes[query_node_type]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
supervision_edge_type=supervision_edge_type,
num_workers=sampling_workers_per_process,
batch_size=main_batch_size,
Expand All @@ -156,7 +156,7 @@ def _setup_dataloaders(
random_negative_loader = DistNeighborLoader(
dataset=dataset,
num_neighbors=num_neighbors,
input_nodes=(labeled_node_type, dataset.node_ids[labeled_node_type]),
input_nodes=(labeled_node_type, dataset.node_ids[labeled_node_type]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
num_workers=sampling_workers_per_process,
batch_size=random_batch_size,
pin_memory_device=device,
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorial/KDD_2025/heterogeneous_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def inference(
),
nprocs=int(args.process_count),
join=True,
)
) # ty: ignore[call-non-callable] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions.

# Now let's load the embeddings to a dataframe
# Note in a "production" setting we have `gigl.common.data.export.load_embeddings_to_bigquery`
Expand Down
14 changes: 7 additions & 7 deletions examples/tutorial/KDD_2025/heterogeneous_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ def get_data_loader(
node_type = QUERY_NODE_TYPE
if split == "train":
assert isinstance(dataset.train_node_ids, Mapping)
input_nodes = (node_type, dataset.train_node_ids[node_type])
input_nodes = (node_type, dataset.train_node_ids[node_type]) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
elif split == "val":
assert isinstance(dataset.val_node_ids, Mapping)
input_nodes = (node_type, dataset.val_node_ids[node_type])
input_nodes = (node_type, dataset.val_node_ids[node_type]) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
elif split == "test":
assert isinstance(dataset.test_node_ids, Mapping)
input_nodes = (node_type, dataset.test_node_ids[node_type])
input_nodes = (node_type, dataset.test_node_ids[node_type]) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
else:
raise ValueError(f"Unknown split: {split}")

Expand Down Expand Up @@ -253,16 +253,16 @@ def train(
assert isinstance(dataset.train_node_ids, Mapping)
process_count = int(args.process_count)
for node_type, node_ids in dataset.train_node_ids.items():
logger.info(f"Training node type {node_type} has {node_ids.size(0)} nodes.")
max_training_batches = node_ids.size(0) // (
logger.info(f"Training node type {node_type} has {node_ids.size(0)} nodes.") # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
max_training_batches = node_ids.size(0) // ( # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
int(args.batch_size) * torch.distributed.get_world_size() * process_count
)
assert isinstance(dataset.val_node_ids, Mapping)
for node_type, node_ids in dataset.val_node_ids.items():
logger.info(f"Validation node type {node_type} has {node_ids.size(0)} nodes.")
logger.info(f"Validation node type {node_type} has {node_ids.size(0)} nodes.") # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
assert isinstance(dataset.test_node_ids, Mapping)
for node_type, node_ids in dataset.test_node_ids.items():
logger.info(f"Test node type {node_type} has {node_ids.size(0)} nodes.")
logger.info(f"Test node type {node_type} has {node_ids.size(0)} nodes.") # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
training_process_port = get_free_port()
logger.info(f"Will train for {max_training_batches} batches.")
if strtobool(args.use_local_saved_model):
Expand Down
2 changes: 1 addition & 1 deletion gigl/common/data/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _tf_tensor_to_torch_tensor(tf_tensor: tf.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: The converted PyTorch tensor.
"""
return torch.utils.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_tensor))
return torch.utils.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_tensor)) # ty: ignore[possibly-missing-submodule] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface.


def _build_example_parser(
Expand Down
2 changes: 1 addition & 1 deletion gigl/distributed/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _load_and_build_partitioned_dataset(
for supervision_edge_type in splitter._supervision_edge_types:
positive_label_edges[supervision_edge_type] = (
select_ssl_positive_label_edges(
edge_index=loaded_graph_tensors.edge_index[
edge_index=loaded_graph_tensors.edge_index[ # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
supervision_edge_type
],
positive_label_percentage=_ssl_positive_label_percentage,
Expand Down
2 changes: 1 addition & 1 deletion gigl/distributed/dist_ablp_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def __init__(
dataset_schema,
backend_key,
) = self._setup_for_graph_store(
input_nodes=input_nodes,
input_nodes=input_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
dataset=dataset,
num_workers=num_workers,
worker_concurrency=worker_concurrency,
Expand Down
13 changes: 8 additions & 5 deletions gigl/distributed/dist_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,10 @@ def _initialize_node_ids(
self._num_val = val_nodes.numel() # ty: ignore[unresolved-attribute]
self._num_test = test_nodes.numel() # ty: ignore[unresolved-attribute]
self._node_ids = _append_non_split_node_ids(
train_nodes, val_nodes, test_nodes, node_ids_on_machine
train_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
val_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
test_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
node_ids_on_machine,
)
else:
logger.info(
Expand Down Expand Up @@ -622,8 +625,8 @@ def _initialize_node_features(
# if it is not an edge type, since it must be one of the two.
assert not isinstance(node_type, EdgeType)
self._node_feature_info[node_type] = FeatureInfo(
dim=node_features_per_node_type.size(1),
dtype=node_features_per_node_type.dtype,
dim=node_features_per_node_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
dtype=node_features_per_node_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
)
logger.info(
f"Initialized node features for heterogeneous graph to dataset with node types: {node_features.keys()}"
Expand Down Expand Up @@ -705,8 +708,8 @@ def _initialize_edge_features(
for edge_type, edge_features_per_edge_type in edge_features.items():
assert isinstance(edge_type, EdgeType)
self._edge_feature_info[edge_type] = FeatureInfo(
dim=edge_features_per_edge_type.size(1),
dtype=edge_features_per_edge_type.dtype,
dim=edge_features_per_edge_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
dtype=edge_features_per_edge_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
)
logger.info(
f"Initialized edge features for heterogeneous graph to dataset with edge types: {edge_features.keys()}"
Expand Down
10 changes: 5 additions & 5 deletions gigl/distributed/dist_ppr_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,27 +696,27 @@ async def _sample_from_nodes(
ntype_to_flat_ids,
ntype_to_flat_weights,
ntype_to_valid_counts,
) = await self._compute_ppr_scores(seed_nodes, seed_type)
) = await self._compute_ppr_scores(seed_nodes, seed_type) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
assert isinstance(ntype_to_flat_ids, dict)
assert isinstance(ntype_to_flat_weights, dict)
assert isinstance(ntype_to_valid_counts, dict)

for ntype, flat_ids in ntype_to_flat_ids.items():
ppr_edge_type: EdgeType = (seed_type, "ppr", ntype)
valid_counts = ntype_to_valid_counts[ntype]
valid_counts = ntype_to_valid_counts[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
ppr_edge_type_to_flat_weights[ppr_edge_type] = (
ntype_to_flat_weights[ntype]
ntype_to_flat_weights[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
)

# Skip empty pairs; induce_next handles deduplication across
# seed types so a neighbor reachable from multiple seed types
# gets one consistent local index in node_dict[ntype].
if flat_ids.numel() > 0:
if flat_ids.numel() > 0: # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
nbr_dict[ppr_edge_type] = [
src_dict[seed_type],
flat_ids,
valid_counts,
]
] # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes.

# induce_next processes all PPR edge types in nbr_dict in one
# pass, assigning local indices to neighbors not yet registered and
Expand Down
2 changes: 1 addition & 1 deletion gigl/distributed/distributed_neighborloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def _setup_for_colocated(
)

curr_process_nodes = shard_nodes_by_process(
input_nodes=node_ids,
input_nodes=node_ids, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
local_process_rank=local_rank,
local_process_world_size=local_world_size,
)
Expand Down
2 changes: 1 addition & 1 deletion gigl/distributed/graph_store/dist_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _get_node_ids(
f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] "
f"(e.g. a heterogeneous dataset), got {type(nodes)}"
)
nodes = nodes[node_type]
nodes = nodes[node_type] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
elif not isinstance(nodes, torch.Tensor):
raise ValueError(
f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def apply_sparse_optimizer(
if not optimizer_cls and optimizer_kwargs:
optimizer_cls = RowWiseAdagrad
optimizer_kwargs = {"lr": 0.01}
apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs)
apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface.


def apply_dense_optimizer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def _assert_sampling_config_is_valid(self):
@property
def active_sampling_config(self) -> SamplingConfig:
if self.phase == ModelPhase.TRAIN:
return self.training_sampling_config
return self.training_sampling_config # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes.
elif self.phase == ModelPhase.VAL:
return self.validation_sampling_config
return self.validation_sampling_config # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes.
elif self.phase == ModelPhase.TEST:
return self.testing_sampling_config
return self.testing_sampling_config # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes.
elif (
self.phase == ModelPhase.INFERENCE_SRC
or self.phase == ModelPhase.INFERENCE_DST
Expand Down
6 changes: 3 additions & 3 deletions gigl/nn/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,18 +398,18 @@ def _weighted_layer_sum(
torch.Tensor: Weighted sum of all layer embeddings, shape [N, D].
"""
if len(all_layer_embeddings) != len(
self._layer_weights
self._layer_weights # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions.
): # https://github.com/Snapchat/GiGL/issues/408
raise ValueError(
f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # https://github.com/Snapchat/GiGL/issues/408
f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions.
)

# Stack all layer embeddings and compute weighted sum
# _layer_weights is already a tensor buffer registered in __init__
stacked = torch.stack(all_layer_embeddings, dim=0) # shape [K+1, N, D]
w = self._layer_weights.to(stacked.device) # shape [K+1], ensure on same device
out = (
stacked * w.view(-1, 1, 1)
stacked * w.view(-1, 1, 1) # ty: ignore[call-non-callable] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions.
).sum( # https://github.com/Snapchat/GiGL/issues/408
dim=0
) # shape [N, D], w_0*X_0 + w_1*X_1 + ...
Expand Down
2 changes: 1 addition & 1 deletion gigl/src/common/graph_builder/pyg_graph_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def get_global_edge_features_dict(self) -> FrozenDict[Edge, torch.Tensor]:
edge_feature = (
edge_attr[edge_number] if edge_attr is not None else None
)
global_edge_to_features_map[edge] = edge_feature
global_edge_to_features_map[edge] = edge_feature # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes.

return FrozenDict(global_edge_to_features_map)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def model(self) -> torch.nn.Module:
@model.setter
def model(self, model: torch.nn.Module) -> None:
self.__model = model
self.__model.graph_backend = GraphBackend.PYG
self.__model.graph_backend = GraphBackend.PYG # ty: ignore[unresolved-attribute] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions.

def init_model(
self,
Expand Down Expand Up @@ -339,7 +339,7 @@ def _process_batch(

pos_nodes: torch.LongTensor = main_batch.pos_supervision_edge_data[
CondensedEdgeType(0)
].root_node_to_target_node_id[root_node.item()]
].root_node_to_target_node_id[root_node.item()] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.

if pos_nodes.numel():
pos_scores = torch.mm(
Expand All @@ -349,7 +349,7 @@ def _process_batch(
hard_neg_nodes: torch.LongTensor = (
main_batch.hard_neg_supervision_edge_data[
CondensedEdgeType(0)
].root_node_to_target_node_id[root_node.item()]
].root_node_to_target_node_id[root_node.item()] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access.
) # shape=[num_hard_neg_nodes]

if hard_neg_nodes.numel():
Expand Down Expand Up @@ -517,9 +517,9 @@ def _compute_metrics(
mrr = 1.0 / pos_rank.float()

hit_rates = hit_rate_at_k(
pos_scores=pos_scores,
neg_scores=neg_scores,
ks=torch.tensor(ks, device=device, dtype=torch.long),
pos_scores=pos_scores, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization.
neg_scores=neg_scores, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization.
ks=torch.tensor(ks, device=device, dtype=torch.long), # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization.
)

total_mrr += mrr.mean().item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def validate(
hr_result = hit_rate_at_k(
pos_scores=batch_scores.pos_scores,
neg_scores=batch_scores.random_neg_scores,
ks=ks_for_evaluation,
ks=ks_for_evaluation, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization.
)
mrr_result = mean_reciprocal_rank(
pos_scores=batch_scores.pos_scores,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def train(
for epoch in range(self.__num_epochs):
logger.info(f"Batch training... for epoch {epoch}/{self.__num_epochs}")
train_loss = self._train(
data_loader=data_loaders.train_main, # type: ignore[arg-type]
data_loader=data_loaders.train_main, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface.
device=device,
)
train_loss_str = (
Expand Down
Loading