diff --git a/examples/link_prediction/heterogeneous_training.py b/examples/link_prediction/heterogeneous_training.py index f0d58ca5e..8ed672b7c 100644 --- a/examples/link_prediction/heterogeneous_training.py +++ b/examples/link_prediction/heterogeneous_training.py @@ -133,7 +133,7 @@ def _setup_dataloaders( main_loader = DistABLPLoader( dataset=dataset, num_neighbors=num_neighbors, - input_nodes=(query_node_type, main_input_nodes[query_node_type]), + input_nodes=(query_node_type, main_input_nodes[query_node_type]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. supervision_edge_type=supervision_edge_type, num_workers=sampling_workers_per_process, batch_size=main_batch_size, @@ -156,7 +156,7 @@ def _setup_dataloaders( random_negative_loader = DistNeighborLoader( dataset=dataset, num_neighbors=num_neighbors, - input_nodes=(labeled_node_type, dataset.node_ids[labeled_node_type]), + input_nodes=(labeled_node_type, dataset.node_ids[labeled_node_type]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. num_workers=sampling_workers_per_process, batch_size=random_batch_size, pin_memory_device=device, diff --git a/examples/tutorial/KDD_2025/heterogeneous_inference.py b/examples/tutorial/KDD_2025/heterogeneous_inference.py index 5c2cc751f..f4d0616c6 100644 --- a/examples/tutorial/KDD_2025/heterogeneous_inference.py +++ b/examples/tutorial/KDD_2025/heterogeneous_inference.py @@ -175,7 +175,7 @@ def inference( ), nprocs=int(args.process_count), join=True, - ) + ) # ty: ignore[call-non-callable] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. # Now let's load the embeddings to a dataframe # Note in a "production" setting we have `gigl.common.data.export.load_embeddings_to_bigquery` diff --git a/examples/tutorial/KDD_2025/heterogeneous_training.py b/examples/tutorial/KDD_2025/heterogeneous_training.py index 80c41db94..08454ac51 100644 --- a/examples/tutorial/KDD_2025/heterogeneous_training.py +++ b/examples/tutorial/KDD_2025/heterogeneous_training.py @@ -118,13 +118,13 @@ def get_data_loader( node_type = QUERY_NODE_TYPE if split == "train": assert isinstance(dataset.train_node_ids, Mapping) - input_nodes = (node_type, dataset.train_node_ids[node_type]) + input_nodes = (node_type, dataset.train_node_ids[node_type]) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. elif split == "val": assert isinstance(dataset.val_node_ids, Mapping) - input_nodes = (node_type, dataset.val_node_ids[node_type]) + input_nodes = (node_type, dataset.val_node_ids[node_type]) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. elif split == "test": assert isinstance(dataset.test_node_ids, Mapping) - input_nodes = (node_type, dataset.test_node_ids[node_type]) + input_nodes = (node_type, dataset.test_node_ids[node_type]) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. else: raise ValueError(f"Unknown split: {split}") @@ -253,16 +253,16 @@ def train( assert isinstance(dataset.train_node_ids, Mapping) process_count = int(args.process_count) for node_type, node_ids in dataset.train_node_ids.items(): - logger.info(f"Training node type {node_type} has {node_ids.size(0)} nodes.") - max_training_batches = node_ids.size(0) // ( + logger.info(f"Training node type {node_type} has {node_ids.size(0)} nodes.") # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + max_training_batches = node_ids.size(0) // ( # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. int(args.batch_size) * torch.distributed.get_world_size() * process_count ) assert isinstance(dataset.val_node_ids, Mapping) for node_type, node_ids in dataset.val_node_ids.items(): - logger.info(f"Validation node type {node_type} has {node_ids.size(0)} nodes.") + logger.info(f"Validation node type {node_type} has {node_ids.size(0)} nodes.") # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. assert isinstance(dataset.test_node_ids, Mapping) for node_type, node_ids in dataset.test_node_ids.items(): - logger.info(f"Test node type {node_type} has {node_ids.size(0)} nodes.") + logger.info(f"Test node type {node_type} has {node_ids.size(0)} nodes.") # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. training_process_port = get_free_port() logger.info(f"Will train for {max_training_batches} batches.") if strtobool(args.use_local_saved_model): diff --git a/gigl/common/data/dataloaders.py b/gigl/common/data/dataloaders.py index 346014e5d..0ac83f0e7 100644 --- a/gigl/common/data/dataloaders.py +++ b/gigl/common/data/dataloaders.py @@ -197,7 +197,7 @@ def _tf_tensor_to_torch_tensor(tf_tensor: tf.Tensor) -> torch.Tensor: Returns: torch.Tensor: The converted PyTorch tensor. """ - return torch.utils.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_tensor)) + return torch.utils.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_tensor)) # ty: ignore[possibly-missing-submodule] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. def _build_example_parser( diff --git a/gigl/distributed/dataset_factory.py b/gigl/distributed/dataset_factory.py index 230aed66d..776ae2463 100644 --- a/gigl/distributed/dataset_factory.py +++ b/gigl/distributed/dataset_factory.py @@ -125,7 +125,7 @@ def _load_and_build_partitioned_dataset( for supervision_edge_type in splitter._supervision_edge_types: positive_label_edges[supervision_edge_type] = ( select_ssl_positive_label_edges( - edge_index=loaded_graph_tensors.edge_index[ + edge_index=loaded_graph_tensors.edge_index[ # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. supervision_edge_type ], positive_label_percentage=_ssl_positive_label_percentage, diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 2c72c9ceb..3e27e8098 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -324,7 +324,7 @@ def __init__( dataset_schema, backend_key, ) = self._setup_for_graph_store( - input_nodes=input_nodes, + input_nodes=input_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. dataset=dataset, num_workers=num_workers, worker_concurrency=worker_concurrency, diff --git a/gigl/distributed/dist_dataset.py b/gigl/distributed/dist_dataset.py index 2323c3fe1..12c361b87 100644 --- a/gigl/distributed/dist_dataset.py +++ b/gigl/distributed/dist_dataset.py @@ -488,7 +488,10 @@ def _initialize_node_ids( self._num_val = val_nodes.numel() # ty: ignore[unresolved-attribute] self._num_test = test_nodes.numel() # ty: ignore[unresolved-attribute] self._node_ids = _append_non_split_node_ids( - train_nodes, val_nodes, test_nodes, node_ids_on_machine + train_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + val_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + test_nodes, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + node_ids_on_machine, ) else: logger.info( @@ -622,8 +625,8 @@ def _initialize_node_features( # if it is not an edge type, since it must be one of the two. assert not isinstance(node_type, EdgeType) self._node_feature_info[node_type] = FeatureInfo( - dim=node_features_per_node_type.size(1), - dtype=node_features_per_node_type.dtype, + dim=node_features_per_node_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + dtype=node_features_per_node_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. ) logger.info( f"Initialized node features for heterogeneous graph to dataset with node types: {node_features.keys()}" @@ -705,8 +708,8 @@ def _initialize_edge_features( for edge_type, edge_features_per_edge_type in edge_features.items(): assert isinstance(edge_type, EdgeType) self._edge_feature_info[edge_type] = FeatureInfo( - dim=edge_features_per_edge_type.size(1), - dtype=edge_features_per_edge_type.dtype, + dim=edge_features_per_edge_type.size(1), # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + dtype=edge_features_per_edge_type.dtype, # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. ) logger.info( f"Initialized edge features for heterogeneous graph to dataset with edge types: {edge_features.keys()}" diff --git a/gigl/distributed/dist_ppr_sampler.py b/gigl/distributed/dist_ppr_sampler.py index e09a8f3ff..70aa4a743 100644 --- a/gigl/distributed/dist_ppr_sampler.py +++ b/gigl/distributed/dist_ppr_sampler.py @@ -696,27 +696,27 @@ async def _sample_from_nodes( ntype_to_flat_ids, ntype_to_flat_weights, ntype_to_valid_counts, - ) = await self._compute_ppr_scores(seed_nodes, seed_type) + ) = await self._compute_ppr_scores(seed_nodes, seed_type) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. assert isinstance(ntype_to_flat_ids, dict) assert isinstance(ntype_to_flat_weights, dict) assert isinstance(ntype_to_valid_counts, dict) for ntype, flat_ids in ntype_to_flat_ids.items(): ppr_edge_type: EdgeType = (seed_type, "ppr", ntype) - valid_counts = ntype_to_valid_counts[ntype] + valid_counts = ntype_to_valid_counts[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. ppr_edge_type_to_flat_weights[ppr_edge_type] = ( - ntype_to_flat_weights[ntype] + ntype_to_flat_weights[ntype] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. ) # Skip empty pairs; induce_next handles deduplication across # seed types so a neighbor reachable from multiple seed types # gets one consistent local index in node_dict[ntype]. - if flat_ids.numel() > 0: + if flat_ids.numel() > 0: # ty: ignore[unresolved-attribute] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. nbr_dict[ppr_edge_type] = [ src_dict[seed_type], flat_ids, valid_counts, - ] + ] # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. # induce_next processes all PPR edge types in nbr_dict in one # pass, assigning local indices to neighbors not yet registered and diff --git a/gigl/distributed/distributed_neighborloader.py b/gigl/distributed/distributed_neighborloader.py index 3d6d5a34b..db08d2328 100644 --- a/gigl/distributed/distributed_neighborloader.py +++ b/gigl/distributed/distributed_neighborloader.py @@ -478,7 +478,7 @@ def _setup_for_colocated( ) curr_process_nodes = shard_nodes_by_process( - input_nodes=node_ids, + input_nodes=node_ids, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. local_process_rank=local_rank, local_process_world_size=local_world_size, ) diff --git a/gigl/distributed/graph_store/dist_server.py b/gigl/distributed/graph_store/dist_server.py index 4566e6021..d42099abe 100644 --- a/gigl/distributed/graph_store/dist_server.py +++ b/gigl/distributed/graph_store/dist_server.py @@ -397,7 +397,7 @@ def _get_node_ids( f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] " f"(e.g. a heterogeneous dataset), got {type(nodes)}" ) - nodes = nodes[node_type] + nodes = nodes[node_type] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. elif not isinstance(nodes, torch.Tensor): raise ValueError( f"node_type was not provided, so node ids must be a torch.Tensor (e.g. a homogeneous dataset), got {type(nodes)}." diff --git a/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py b/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py index 5b39af6ac..f8726dcb1 100644 --- a/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py +++ b/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py @@ -137,7 +137,7 @@ def apply_sparse_optimizer( if not optimizer_cls and optimizer_kwargs: optimizer_cls = RowWiseAdagrad optimizer_kwargs = {"lr": 0.01} - apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs) + apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. def apply_dense_optimizer( diff --git a/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py b/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py index bdfd9f179..1c1e4d67e 100644 --- a/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py +++ b/gigl/experimental/knowledge_graph_embedding/lib/model/heterogeneous_graph_model.py @@ -95,11 +95,11 @@ def _assert_sampling_config_is_valid(self): @property def active_sampling_config(self) -> SamplingConfig: if self.phase == ModelPhase.TRAIN: - return self.training_sampling_config + return self.training_sampling_config # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. elif self.phase == ModelPhase.VAL: - return self.validation_sampling_config + return self.validation_sampling_config # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. elif self.phase == ModelPhase.TEST: - return self.testing_sampling_config + return self.testing_sampling_config # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. elif ( self.phase == ModelPhase.INFERENCE_SRC or self.phase == ModelPhase.INFERENCE_DST diff --git a/gigl/nn/models.py b/gigl/nn/models.py index d86920bdc..6016561c3 100644 --- a/gigl/nn/models.py +++ b/gigl/nn/models.py @@ -398,10 +398,10 @@ def _weighted_layer_sum( torch.Tensor: Weighted sum of all layer embeddings, shape [N, D]. """ if len(all_layer_embeddings) != len( - self._layer_weights + self._layer_weights # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ): # https://github.com/Snapchat/GiGL/issues/408 raise ValueError( - f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # https://github.com/Snapchat/GiGL/issues/408 + f"Got {len(all_layer_embeddings)} layer tensors but {len(self._layer_weights)} weights." # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ) # Stack all layer embeddings and compute weighted sum @@ -409,7 +409,7 @@ def _weighted_layer_sum( stacked = torch.stack(all_layer_embeddings, dim=0) # shape [K+1, N, D] w = self._layer_weights.to(stacked.device) # shape [K+1], ensure on same device out = ( - stacked * w.view(-1, 1, 1) + stacked * w.view(-1, 1, 1) # ty: ignore[call-non-callable] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ).sum( # https://github.com/Snapchat/GiGL/issues/408 dim=0 ) # shape [N, D], w_0*X_0 + w_1*X_1 + ... diff --git a/gigl/src/common/graph_builder/pyg_graph_data.py b/gigl/src/common/graph_builder/pyg_graph_data.py index 9f025a0ce..53204f5cb 100644 --- a/gigl/src/common/graph_builder/pyg_graph_data.py +++ b/gigl/src/common/graph_builder/pyg_graph_data.py @@ -193,7 +193,7 @@ def get_global_edge_features_dict(self) -> FrozenDict[Edge, torch.Tensor]: edge_feature = ( edge_attr[edge_number] if edge_attr is not None else None ) - global_edge_to_features_map[edge] = edge_feature + global_edge_to_features_map[edge] = edge_feature # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. return FrozenDict(global_edge_to_features_map) diff --git a/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py b/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py index 443a3d6a8..1ef84655c 100644 --- a/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py +++ b/gigl/src/common/modeling_task_specs/graphsage_template_modeling_spec.py @@ -129,7 +129,7 @@ def model(self) -> torch.nn.Module: @model.setter def model(self, model: torch.nn.Module) -> None: self.__model = model - self.__model.graph_backend = GraphBackend.PYG + self.__model.graph_backend = GraphBackend.PYG # ty: ignore[unresolved-attribute] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. def init_model( self, @@ -339,7 +339,7 @@ def _process_batch( pos_nodes: torch.LongTensor = main_batch.pos_supervision_edge_data[ CondensedEdgeType(0) - ].root_node_to_target_node_id[root_node.item()] + ].root_node_to_target_node_id[root_node.item()] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. if pos_nodes.numel(): pos_scores = torch.mm( @@ -349,7 +349,7 @@ def _process_batch( hard_neg_nodes: torch.LongTensor = ( main_batch.hard_neg_supervision_edge_data[ CondensedEdgeType(0) - ].root_node_to_target_node_id[root_node.item()] + ].root_node_to_target_node_id[root_node.item()] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. ) # shape=[num_hard_neg_nodes] if hard_neg_nodes.numel(): @@ -517,9 +517,9 @@ def _compute_metrics( mrr = 1.0 / pos_rank.float() hit_rates = hit_rate_at_k( - pos_scores=pos_scores, - neg_scores=neg_scores, - ks=torch.tensor(ks, device=device, dtype=torch.long), + pos_scores=pos_scores, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + neg_scores=neg_scores, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + ks=torch.tensor(ks, device=device, dtype=torch.long), # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) total_mrr += mrr.mean().item() diff --git a/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py b/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py index e61024642..c8fd60e21 100644 --- a/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py +++ b/gigl/src/common/modeling_task_specs/node_anchor_based_link_prediction_modeling_task_spec.py @@ -547,7 +547,7 @@ def validate( hr_result = hit_rate_at_k( pos_scores=batch_scores.pos_scores, neg_scores=batch_scores.random_neg_scores, - ks=ks_for_evaluation, + ks=ks_for_evaluation, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) mrr_result = mean_reciprocal_rank( pos_scores=batch_scores.pos_scores, diff --git a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py index 12e5a4e38..bfaba1fb0 100644 --- a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py +++ b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py @@ -252,7 +252,7 @@ def train( for epoch in range(self.__num_epochs): logger.info(f"Batch training... for epoch {epoch}/{self.__num_epochs}") train_loss = self._train( - data_loader=data_loaders.train_main, # type: ignore[arg-type] + data_loader=data_loaders.train_main, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. device=device, ) train_loss_str = ( diff --git a/gigl/src/common/modeling_task_specs/utils/infer.py b/gigl/src/common/modeling_task_specs/utils/infer.py index 17804a3ed..a85b3ce39 100644 --- a/gigl/src/common/modeling_task_specs/utils/infer.py +++ b/gigl/src/common/modeling_task_specs/utils/infer.py @@ -137,10 +137,10 @@ def infer_task_inputs( decoder = model.module.decode batch_result_types = model.module.tasks.result_types else: - decoder = model.decode # https://github.com/Snapchat/GiGL/issues/408 + decoder = model.decode # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[invalid-assignment] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. batch_result_types = ( - model.tasks.result_types - ) # https://github.com/Snapchat/GiGL/issues/408 + model.tasks.result_types # ty: ignore[unresolved-attribute] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. + ) # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[invalid-assignment] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. # If we only have losses which only require the input batch, don't forward here and return the # input batch immediately to minimize computation we don't need, such as encoding and decoding. @@ -220,7 +220,7 @@ def infer_task_inputs( random_neg_root_embeddings[condensed_node_type] = ( random_neg_embeddings[condensed_node_type][random_neg_root_node_indices] if random_neg_root_node_indices.numel() - else torch.FloatTensor([]).to(device=device) + else torch.FloatTensor([]).to(device=device) # ty: ignore[invalid-assignment] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) if ModelResultType.batch_scores in batch_result_types or should_eval: random_neg_scores[condensed_node_type] = ( @@ -228,7 +228,7 @@ def infer_task_inputs( query_embeddings, random_neg_root_embeddings[condensed_node_type] ) if random_neg_root_embeddings[condensed_node_type].numel() - else torch.FloatTensor([]).to(device=device) + else torch.FloatTensor([]).to(device=device) # ty: ignore[invalid-assignment] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) # Loop through all root nodes and populate ids, embeddings, and scores per condensed edge type @@ -247,14 +247,16 @@ def infer_task_inputs( ) = gbml_config_pb_wrapper.graph_metadata_pb_wrapper.condensed_edge_type_to_condensed_node_types[ condensed_supervision_edge_type ] - pos_nodes: torch.LongTensor = main_batch.pos_supervision_edge_data[ + pos_nodes: torch.LongTensor = main_batch.pos_supervision_edge_data[ # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. condensed_supervision_edge_type - ].root_node_to_target_node_id[root_node.item()] # shape=[num_pos_nodes] + ].root_node_to_target_node_id[ + root_node.item() + ] # shape=[num_pos_nodes] hard_neg_nodes: torch.LongTensor = ( main_batch.hard_neg_supervision_edge_data[ condensed_supervision_edge_type - ].root_node_to_target_node_id[root_node.item()] + ].root_node_to_target_node_id[root_node.item()] # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. ) # shape=[num_hard_neg_nodes] repeated_anchor_count[condensed_supervision_edge_type].append( @@ -263,7 +265,7 @@ def infer_task_inputs( if pos_nodes.numel(): _pos_embeddings[condensed_supervision_edge_type].append( - main_embeddings[condensed_supervision_target_node_type][pos_nodes] # type: ignore[arg-type] + main_embeddings[condensed_supervision_target_node_type][pos_nodes] # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) _positive_ids[condensed_supervision_edge_type].append(pos_nodes) @@ -271,7 +273,7 @@ def infer_task_inputs( _hard_neg_embeddings[condensed_supervision_edge_type].append( main_embeddings[condensed_supervision_target_node_type][ hard_neg_nodes - ] # type: ignore[arg-type] + ] # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) _hard_neg_ids[condensed_supervision_edge_type].append(hard_neg_nodes) @@ -301,9 +303,9 @@ def infer_task_inputs( condensed_supervision_target_node_type ][[root_node_idx], :].to(device=device) _batch_scores[condensed_supervision_edge_type] = BatchScores( - pos_scores=pos_scores, - hard_neg_scores=hard_neg_scores, - random_neg_scores=random_neg_scores_root, + pos_scores=pos_scores, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + hard_neg_scores=hard_neg_scores, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + random_neg_scores=random_neg_scores_root, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) if ModelResultType.batch_scores in batch_result_types or should_eval: @@ -325,12 +327,12 @@ def infer_task_inputs( pos_embeddings[condensed_supervision_edge_type] = ( torch.cat(tuple(_pos_embeddings[condensed_supervision_edge_type])) if len(_pos_embeddings[condensed_supervision_edge_type]) - else torch.tensor([]) + else torch.tensor([]) # ty: ignore[invalid-assignment] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) hard_neg_embeddings[condensed_supervision_edge_type] = ( torch.cat(tuple(_hard_neg_embeddings[condensed_supervision_edge_type])) if len(_hard_neg_embeddings[condensed_supervision_edge_type]) - else torch.tensor([]) + else torch.tensor([]) # ty: ignore[invalid-assignment] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) repeated_anchor_embeddings[condensed_supervision_edge_type] = ( @@ -339,7 +341,7 @@ def infer_task_inputs( device=device ), dim=0, - ) + ) # ty: ignore[invalid-assignment] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. ) # If needed, calculate task inputs for retrieval loss per condensed edge type @@ -435,18 +437,18 @@ def infer_task_inputs( batch_combined_scores[condensed_supervision_edge_type] = ( BatchCombinedScores( - repeated_candidate_scores=repeated_candidate_scores, - positive_ids=global_positive_ids, - hard_neg_ids=global_hard_neg_ids, - random_neg_ids=global_random_neg_ids, - repeated_query_ids=repeated_global_query_ids, + repeated_candidate_scores=repeated_candidate_scores, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + positive_ids=global_positive_ids, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + hard_neg_ids=global_hard_neg_ids, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + random_neg_ids=global_random_neg_ids, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. + repeated_query_ids=repeated_global_query_ids, # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. num_unique_query_ids=main_batch_root_node_indices.shape[0], ) ) # Populate all computed embeddings for task input batch_embeddings = BatchEmbeddings( - query_embeddings=query_embeddings, + query_embeddings=query_embeddings, # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. repeated_query_embeddings=repeated_anchor_embeddings, pos_embeddings=pos_embeddings, hard_neg_embeddings=hard_neg_embeddings, diff --git a/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py b/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py index afc017f86..9e634fce4 100644 --- a/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py +++ b/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py @@ -20,7 +20,7 @@ class TorchProfiler: def __init__(self, **kwargs) -> None: self.trace_handler = tensorboard_trace_handler( - dir_name=TMP_PROFILER_LOG_DIR_NAME, # type: ignore[arg-type] + dir_name=TMP_PROFILER_LOG_DIR_NAME, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. use_gzip=True, ) self.wait = int(kwargs.get("wait", 5)) diff --git a/gigl/src/common/models/layers/count_min_sketch.py b/gigl/src/common/models/layers/count_min_sketch.py index fb54657ff..2fffe129f 100644 --- a/gigl/src/common/models/layers/count_min_sketch.py +++ b/gigl/src/common/models/layers/count_min_sketch.py @@ -87,7 +87,7 @@ def estimate_torch_long_tensor(self, tensor: torch.LongTensor) -> torch.LongTens return torch.tensor( [self.estimate(item) for item in tensor_cpu], dtype=torch.long, - ) + ) # ty: ignore[invalid-return-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. def get_table(self) -> np.ndarray: """ @@ -116,5 +116,5 @@ def calculate_in_batch_candidate_sampling_probability( """ estimated_prob: torch.FloatTensor = ( batch_size * frequency_tensor.float() / total_cnt - ) + ) # ty: ignore[invalid-assignment] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. return estimated_prob.clamp(max=1.0) diff --git a/gigl/src/common/models/layers/feature_interaction.py b/gigl/src/common/models/layers/feature_interaction.py index afa025365..f0ccf2126 100644 --- a/gigl/src/common/models/layers/feature_interaction.py +++ b/gigl/src/common/models/layers/feature_interaction.py @@ -150,7 +150,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def reset_parameters(self): for layer in self._layers: if hasattr(layer, "reset_parameters") and callable(layer.reset_parameters): - layer.reset_parameters() + layer.reset_parameters() # ty: ignore[call-top-callable] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. def __repr__(self) -> str: return f"{self.__class__.__name__}(in_dim={self._in_dim}, num_layers={self._num_layers}, projection_dim={self._projection_dim}, diag_scale={self._diag_scale}, use_bias={self._use_bias})" diff --git a/gigl/src/common/models/layers/loss.py b/gigl/src/common/models/layers/loss.py index 4aab10dfc..08ef93c7a 100644 --- a/gigl/src/common/models/layers/loss.py +++ b/gigl/src/common/models/layers/loss.py @@ -65,7 +65,7 @@ def _calculate_margin_loss( input1=pos_scores_repeated, input2=all_neg_scores_repeated, target=ys, - margin=self.margin, + margin=self.margin, # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. reduction="sum", ) sample_size = pos_scores_repeated.numel() @@ -143,7 +143,7 @@ def _calculate_softmax_loss( loss = F.cross_entropy( input=all_scores - / self.softmax_temperature, # https://github.com/Snapchat/GiGL/issues/408 + / self.softmax_temperature, # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[unsupported-operator] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. target=ys, reduction="sum", ) diff --git a/gigl/src/common/models/layers/normalization.py b/gigl/src/common/models/layers/normalization.py index dc6f87812..9f0e93d1e 100644 --- a/gigl/src/common/models/layers/normalization.py +++ b/gigl/src/common/models/layers/normalization.py @@ -12,8 +12,10 @@ def l2_normalize_embeddings( if isinstance(node_typed_embeddings, dict): for node_type in node_typed_embeddings: node_typed_embeddings[node_type] = F.normalize( - node_typed_embeddings[node_type], p=2, dim=-1 - ) + node_typed_embeddings[node_type], # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + p=2, + dim=-1, + ) # ty: ignore[invalid-assignment] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. elif isinstance(node_typed_embeddings, torch.Tensor): node_typed_embeddings = F.normalize(node_typed_embeddings, p=2, dim=-1) else: diff --git a/gigl/src/common/models/layers/task.py b/gigl/src/common/models/layers/task.py index 5a5bb8e54..88515675a 100644 --- a/gigl/src/common/models/layers/task.py +++ b/gigl/src/common/models/layers/task.py @@ -228,9 +228,9 @@ def __init__( hid_dim = self.encoder.hid_dim out_dim = self.encoder.out_dim self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), + torch.nn.Linear(out_dim, hid_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), + torch.nn.Linear(hid_dim, out_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ) self.loss = GRACELoss(temperature=temperature) self.feat_drop_1 = feat_drop_1 @@ -296,8 +296,8 @@ def __init__( out_dim = self.encoder.out_dim self.loss = FeatureReconstructionLoss(alpha=alpha) self.reconstruction_decoder = GraphConv(out_dim, in_dim) - self.reconstruction_mask = torch.nn.Parameter(torch.zeros(1, in_dim)) - self.reconstruction_enc_dec = torch.nn.Linear(out_dim, out_dim, bias=False) + self.reconstruction_mask = torch.nn.Parameter(torch.zeros(1, in_dim)) # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. + self.reconstruction_enc_dec = torch.nn.Linear(out_dim, out_dim, bias=False) # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. self.edge_drop = edge_drop def forward( @@ -368,9 +368,9 @@ def __init__( out_dim = self.encoder.out_dim self.loss = WhiteningDecorrelationLoss(lambd=lambd) self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), + torch.nn.Linear(out_dim, hid_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), + torch.nn.Linear(hid_dim, out_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ) self.feat_drop_1 = feat_drop_1 self.edge_drop_1 = edge_drop_1 @@ -502,9 +502,9 @@ def __init__( param.requires_grad = False self.loss = BGRLLoss() self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), + torch.nn.Linear(out_dim, hid_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), + torch.nn.Linear(hid_dim, out_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ) self.feat_drop_1 = feat_drop_1 self.edge_drop_1 = edge_drop_1 @@ -585,9 +585,9 @@ def __init__( param.requires_grad = False self.loss = TBGRLLoss(neg_lambda=neg_lambda) self.head = torch.nn.Sequential( - torch.nn.Linear(out_dim, hid_dim), + torch.nn.Linear(out_dim, hid_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. torch.nn.ReLU(), - torch.nn.Linear(hid_dim, out_dim), + torch.nn.Linear(hid_dim, out_dim), # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ) self.feat_drop_1 = feat_drop_1 @@ -710,7 +710,7 @@ def _get_all_tasks( fn = self._task_to_fn_map[task] weight = self._task_to_weights_map[task] tasks_list.append( - (fn, weight) + (fn, weight) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. ) # https://github.com/Snapchat/GiGL/issues/408 return tasks_list diff --git a/gigl/src/common/models/pyg/heterogeneous.py b/gigl/src/common/models/pyg/heterogeneous.py index 8fb5cc18a..fa4e9b651 100644 --- a/gigl/src/common/models/pyg/heterogeneous.py +++ b/gigl/src/common/models/pyg/heterogeneous.py @@ -129,7 +129,7 @@ def forward( if self.should_l2_normalize_embedding_layer_output: node_typed_embeddings = l2_normalize_embeddings( node_typed_embeddings=node_typed_embeddings - ) + ) # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. return node_typed_embeddings @@ -285,4 +285,4 @@ def forward( node_typed_embeddings = l2_normalize_embeddings( node_typed_embeddings=node_typed_embeddings ) - return node_typed_embeddings + return node_typed_embeddings # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. diff --git a/gigl/src/common/models/pyg/homogeneous.py b/gigl/src/common/models/pyg/homogeneous.py index 5af61c82b..4cd46552a 100644 --- a/gigl/src/common/models/pyg/homogeneous.py +++ b/gigl/src/common/models/pyg/homogeneous.py @@ -146,11 +146,11 @@ def forward( if self.should_l2_normalize_embedding_layer_output: x = l2_normalize_embeddings(node_typed_embeddings=x) if self.return_emb: - return x + return x # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. if self.linear_layer: x = self.linear(x) - return x + return x # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. def init_conv_layers( self, diff --git a/gigl/src/common/models/pyg/link_prediction.py b/gigl/src/common/models/pyg/link_prediction.py index 274b8e436..bc266b41b 100644 --- a/gigl/src/common/models/pyg/link_prediction.py +++ b/gigl/src/common/models/pyg/link_prediction.py @@ -69,7 +69,7 @@ def decode( @property def tasks(self) -> NodeAnchorBasedLinkPredictionTasks: - return self.__tasks + return self.__tasks # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. @property def graph_backend(self) -> GraphBackend: diff --git a/gigl/src/common/models/pyg/nn/conv/hgt_conv.py b/gigl/src/common/models/pyg/nn/conv/hgt_conv.py index d651dae1c..5947ac390 100644 --- a/gigl/src/common/models/pyg/nn/conv/hgt_conv.py +++ b/gigl/src/common/models/pyg/nn/conv/hgt_conv.py @@ -150,8 +150,8 @@ def _construct_src_node_feat( ks.append(k_dict[src]) vs.append(v_dict[src]) - ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D) - vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D) + ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D) # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. + vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D) # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. type_vec = torch.cat(type_list, dim=1).flatten() k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1) diff --git a/gigl/src/common/translators/training_samples_protos_translator.py b/gigl/src/common/translators/training_samples_protos_translator.py index 56c7f7a49..57633535e 100644 --- a/gigl/src/common/translators/training_samples_protos_translator.py +++ b/gigl/src/common/translators/training_samples_protos_translator.py @@ -141,7 +141,7 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( ): condensed_supervision_edge_type_to_pos_edge_feats[ condensed_edge_type - ].append(pos_edge[1]) + ].append(pos_edge[1]) # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. for hard_neg_edge_pb in sample.hard_neg_edges: hard_neg_edge: Tuple[Edge, Optional[torch.Tensor]] = ( @@ -165,7 +165,7 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( ): condensed_supervision_edge_type_to_hard_neg_edge_feats[ condensed_edge_type - ].append(hard_neg_edge[1]) + ].append(hard_neg_edge[1]) # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. for condensed_edge_type in graph_metadata_pb_wrapper.condensed_edge_types: condensed_edge_type_to_supervision_edge_data[condensed_edge_type] = ( @@ -180,7 +180,7 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( torch.stack( condensed_supervision_edge_type_to_pos_edge_feats[ condensed_edge_type - ] + ] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. ) if len( condensed_supervision_edge_type_to_pos_edge_feats[ @@ -189,12 +189,12 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( ) > 0 else None - ), + ), # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. hard_neg_edge_features=( torch.stack( condensed_supervision_edge_type_to_hard_neg_edge_feats[ condensed_edge_type - ] + ] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. ) if len( condensed_supervision_edge_type_to_hard_neg_edge_feats[ @@ -203,7 +203,7 @@ def training_samples_from_NodeAnchorBasedLinkPredictionSamplePb( ) > 0 else None - ), + ), # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. ) ) diff --git a/gigl/src/common/utils/eval_metrics.py b/gigl/src/common/utils/eval_metrics.py index 147482fff..27f537f3d 100644 --- a/gigl/src/common/utils/eval_metrics.py +++ b/gigl/src/common/utils/eval_metrics.py @@ -43,7 +43,7 @@ def hit_rate_at_k( ) ks_adjusted = ks - 1 # subtract 1 since indices are 0-indexed hits_at_ks = torch.gather(input=hit_rates_padded, dim=0, index=ks_adjusted) - return hits_at_ks + return hits_at_ks # ty: ignore[invalid-return-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. def mean_reciprocal_rank( @@ -68,4 +68,4 @@ def mean_reciprocal_rank( adjusted_ranks = unadjusted_ranks + 1 # +1 since ranks are 0-indexed here reciprocal_ranks = 1.0 / adjusted_ranks # compute reciprocal mrr = torch.mean(reciprocal_ranks) - return mrr + return mrr # ty: ignore[invalid-return-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. diff --git a/gigl/src/common/utils/file_loader.py b/gigl/src/common/utils/file_loader.py index 701c286a6..e67a736bd 100644 --- a/gigl/src/common/utils/file_loader.py +++ b/gigl/src/common/utils/file_loader.py @@ -214,14 +214,14 @@ def load_from_filelike(self, uri: Uri, filelike: IO[AnyStr]) -> None: if isinstance(first, bytes): with open(uri.uri, "wb") as dest: shutil.copyfileobj( - filelike, - dest, # ty: ignore[invalid-argument-type] + filelike, # ty: ignore[invalid-argument-type] TODO(ty-io-anystr-copyfileobj): reconcile AnyStr IO with shutil.copyfileobj overloads. + dest, ) else: with open(uri.uri, "w", encoding="utf-8") as dest: shutil.copyfileobj( - filelike, - dest, # ty: ignore[invalid-argument-type] + filelike, # ty: ignore[invalid-argument-type] TODO(ty-io-anystr-copyfileobj): reconcile AnyStr IO with shutil.copyfileobj overloads. + dest, ) else: diff --git a/gigl/src/inference/v1/gnn_inferencer.py b/gigl/src/inference/v1/gnn_inferencer.py index 759aa6707..8e0b6901c 100644 --- a/gigl/src/inference/v1/gnn_inferencer.py +++ b/gigl/src/inference/v1/gnn_inferencer.py @@ -248,7 +248,7 @@ def __run( inferencer_instance: BaseInferencer = self.generate_inferencer_instance() graph_builder = GraphBuilderFactory.get_graph_builder( - backend_name=inferencer_instance.model.graph_backend + backend_name=inferencer_instance.model.graph_backend # ty: ignore[invalid-argument-type] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. ) inference_blueprint: BaseInferenceBlueprint = ( diff --git a/gigl/src/mocking/lib/pyg_to_training_samples.py b/gigl/src/mocking/lib/pyg_to_training_samples.py index 61c3ad85b..4bb92b52a 100644 --- a/gigl/src/mocking/lib/pyg_to_training_samples.py +++ b/gigl/src/mocking/lib/pyg_to_training_samples.py @@ -203,7 +203,7 @@ def _get_random_negative_samples_for_pos_edges( pos_node_ids = edge_index[0].repeat(num_negative_samples_per_pos_edge) neg_node_ids = torch.randint(low=0, high=num_nodes, size=[pos_node_ids.numel()]) - return torch.vstack((pos_node_ids, neg_node_ids)) + return torch.vstack((pos_node_ids, neg_node_ids)) # ty: ignore[invalid-return-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. def _build_rooted_node_neighborhood_samples_from_subgraphs( diff --git a/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py b/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py index b0dace301..19d548a7f 100644 --- a/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py +++ b/gigl/src/training/v1/lib/data_loaders/node_anchor_based_link_prediction_data_loader.py @@ -250,7 +250,7 @@ def get_default_data_loader( iterable_dataset=_iterable_training_dataset ) else: - iterable_training_dataset = _iterable_training_dataset + iterable_training_dataset = _iterable_training_dataset # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. collate_fn = partial( NodeAnchorBasedLinkPredictionBatch.collate_pyg_node_anchor_based_link_prediction_minibatch, diff --git a/gigl/src/training/v1/lib/data_loaders/supervised_node_classification_data_loader.py b/gigl/src/training/v1/lib/data_loaders/supervised_node_classification_data_loader.py index c03265d75..2d1a9ae31 100644 --- a/gigl/src/training/v1/lib/data_loaders/supervised_node_classification_data_loader.py +++ b/gigl/src/training/v1/lib/data_loaders/supervised_node_classification_data_loader.py @@ -196,7 +196,7 @@ def get_default_data_loader( iterable_dataset=_iterable_training_dataset ) else: - iterable_training_dataset = _iterable_training_dataset + iterable_training_dataset = _iterable_training_dataset # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. collate_fn = partial( SupervisedNodeClassificationBatch.collate_pyg_node_classification_minibatch, diff --git a/gigl/transforms/utils.py b/gigl/transforms/utils.py index 4904abe26..f8b244d36 100644 --- a/gigl/transforms/utils.py +++ b/gigl/transforms/utils.py @@ -36,7 +36,7 @@ def add_node_attr( for node_type, value in values.items(): if node_type not in data.node_types: continue - _set_node_attr_for_type(data, node_type, value, attr_name) + _set_node_attr_for_type(data, node_type, value, attr_name) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. return data # Otherwise, values is a tensor in homogeneous order - split by node type @@ -113,7 +113,7 @@ def add_edge_attr( for edge_type, value in values.items(): if edge_type not in data.edge_types: continue - _set_edge_attr_for_type(data, edge_type, value, attr_name) + _set_edge_attr_for_type(data, edge_type, value, attr_name) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. return data # Otherwise, values is a tensor in homogeneous order - split by edge type diff --git a/gigl/types/graph.py b/gigl/types/graph.py index 1e3219240..b96d6d6b9 100644 --- a/gigl/types/graph.py +++ b/gigl/types/graph.py @@ -503,7 +503,7 @@ def to_homogeneous( f"Expected a single value in the dictionary, but got multiple keys: {x.keys()}" ) n = next(iter(x.values())) - return n + return n # ty: ignore[invalid-return-type] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. return x diff --git a/gigl/utils/data_splitters.py b/gigl/utils/data_splitters.py index 4aa416f2e..1b88db00b 100644 --- a/gigl/utils/data_splitters.py +++ b/gigl/utils/data_splitters.py @@ -468,16 +468,21 @@ def __call__( splits: dict[NodeType, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} for node_type, nodes_to_split in node_ids_dict.items(): - _check_node_ids(nodes_to_split) + _check_node_ids(nodes_to_split) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. - hash_values = self._hash_function(nodes_to_split) # 1 x M + hash_values = self._hash_function( + nodes_to_split # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + ) # 1 x M # Create train, val, test splits using distributed coordination train, val, test = _create_distributed_splits_from_hash( - nodes_to_split, hash_values, self._num_val, self._num_test + nodes_to_split, # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + hash_values, + self._num_val, + self._num_test, ) - splits[node_type] = (train, val, test) + splits[node_type] = (train, val, test) # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. # Clean up memory del hash_values diff --git a/pyproject.toml b/pyproject.toml index 434feb022..258fee2d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -289,9 +289,10 @@ replace-imports-with-any = [ "kfp.**", "tensorflow.**", "tensorflow_metadata.**", - # TODO: (svij Torch has type information but we are doing some incorrect casting and using old api types thus type check fails - # Although, the code still runs fine. - "torch.**", + # Keep torch type checking enabled by not listing torch here. Temporary ty + # suppressions are tracked via TODO-tagged inline ignores documented in + # docs/ty_ignore_buckets.md. + # "torch.**", "torch_geometric.**", "torchrec.**", ] diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 722d7a1cb..2485d71c4 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -111,13 +111,13 @@ def _assert_global_seed_coverage( Gathers seen and expected seeds from all ranks and verifies full coverage. """ # Gather seen seeds from all ranks - all_seen: list[torch.Tensor] = [None] * cluster_info.compute_cluster_world_size # type: ignore[list-item] + all_seen: list[torch.Tensor] = [None] * cluster_info.compute_cluster_world_size # type: ignore[list-item] # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. torch.distributed.all_gather_object(all_seen, _to_long_cpu(local_seen)) globally_seen = _sorted_seed_tensor(_concat_seed_tensors(all_seen)) # Gather expected seeds from all ranks. In graph-store mode, input sharding is # per compute process, not per compute node. - all_expected: list[torch.Tensor] = [None] * cluster_info.compute_cluster_world_size # type: ignore[list-item] + all_expected: list[torch.Tensor] = [None] * cluster_info.compute_cluster_world_size # type: ignore[list-item] # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. torch.distributed.all_gather_object(all_expected, _to_long_cpu(local_expected)) globally_expected = _sorted_seed_tensor(_concat_seed_tensors(all_expected)) diff --git a/tests/unit/distributed/dist_ablp_neighborloader_test.py b/tests/unit/distributed/dist_ablp_neighborloader_test.py index 4575b7ad4..d32f020c8 100644 --- a/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -214,7 +214,7 @@ def _run_dblp_supervised( loader = DistABLPLoader( dataset=dataset, num_neighbors=num_neighbors, - input_nodes=(anchor_node_type, dataset.train_node_ids[anchor_node_type]), + input_nodes=(anchor_node_type, dataset.train_node_ids[anchor_node_type]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. supervision_edge_type=supervision_edge_type, pin_memory_device=torch.device("cpu"), ) @@ -230,7 +230,7 @@ def _run_dblp_supervised( local_positive_nodes < len(datum[supervision_node_type].node) ) count += 1 - assert count == dataset.train_node_ids[anchor_node_type].size(0) + assert count == dataset.train_node_ids[anchor_node_type].size(0) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. shutdown_rpc() @@ -259,7 +259,7 @@ def _run_toy_heterogeneous_ablp( loader = DistABLPLoader( dataset=dataset, num_neighbors=fanout, - input_nodes=(anchor_node_type, dataset.train_node_ids[anchor_node_type]), + input_nodes=(anchor_node_type, dataset.train_node_ids[anchor_node_type]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. supervision_edge_type=supervision_edge_type, # We set the batch size to the number of "user" nodes in the heterogeneous toy graph to guarantee that the dataloader completes an epoch in 1 batch batch_size=15, @@ -275,10 +275,11 @@ def _run_toy_heterogeneous_ablp( # Ensure that the node ids we should be fanout from are all found in the batch assert_tensor_equality( - dataset.train_node_ids[anchor_node_type], datum[anchor_node_type].batch + dataset.train_node_ids[anchor_node_type], # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + datum[anchor_node_type].batch, ) assert ( - dataset.train_node_ids[anchor_node_type].size(0) + dataset.train_node_ids[anchor_node_type].size(0) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. == datum[anchor_node_type].batch_size ) diff --git a/tests/unit/distributed/dist_ppr_sampler_test.py b/tests/unit/distributed/dist_ppr_sampler_test.py index b0879c306..d3a456d70 100644 --- a/tests/unit/distributed/dist_ppr_sampler_test.py +++ b/tests/unit/distributed/dist_ppr_sampler_test.py @@ -407,7 +407,7 @@ def _run_ppr_hetero_loader_correctness_check( loader = DistNeighborLoader( dataset=dataset, - input_nodes=(USER, node_ids[USER]), + input_nodes=(USER, node_ids[USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. num_neighbors=[], # Unused by PPR sampler; required by interface sampler_options=PPRSamplerOptions( alpha=alpha, @@ -488,7 +488,7 @@ def _run_ppr_ablp_loader_correctness_check( loader = DistABLPLoader( dataset=dataset, num_neighbors=[], # Unused by PPR sampler; required by interface - input_nodes=(USER, train_node_ids[USER]), + input_nodes=(USER, train_node_ids[USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. supervision_edge_type=USER_TO_STORY, sampler_options=PPRSamplerOptions( alpha=alpha, @@ -640,7 +640,7 @@ def _run_ppr_ablp_label_edges_do_not_affect_anchor_ppr(_: int) -> None: loader = DistABLPLoader( dataset=dataset, num_neighbors=[], - input_nodes=(USER, train_node_ids[USER]), + input_nodes=(USER, train_node_ids[USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. supervision_edge_type=USER_TO_STORY, sampler_options=PPRSamplerOptions( alpha=_TEST_ALPHA, diff --git a/tests/unit/distributed/distributed_dataset_test.py b/tests/unit/distributed/distributed_dataset_test.py index bbf421389..692f7816b 100644 --- a/tests/unit/distributed/distributed_dataset_test.py +++ b/tests/unit/distributed/distributed_dataset_test.py @@ -90,7 +90,7 @@ def assert_tensor_equal( if isinstance(actual, dict) and isinstance(expected, dict): self.assertEqual(actual.keys(), expected.keys()) for key in actual.keys(): - assert_close(actual[key], expected[key], atol=0, rtol=0) + assert_close(actual[key], expected[key], atol=0, rtol=0) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. elif isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor): assert_close(actual, expected, atol=0, rtol=0) diff --git a/tests/unit/distributed/distributed_neighborloader_test.py b/tests/unit/distributed/distributed_neighborloader_test.py index a1a9cca55..2205ca8bd 100644 --- a/tests/unit/distributed/distributed_neighborloader_test.py +++ b/tests/unit/distributed/distributed_neighborloader_test.py @@ -145,7 +145,7 @@ def _run_distributed_heterogeneous_neighbor_loader( assert isinstance(dataset.node_ids, Mapping) loader = DistNeighborLoader( dataset=dataset, - input_nodes=(NodeType("author"), dataset.node_ids[NodeType("author")]), + input_nodes=(NodeType("author"), dataset.node_ids[NodeType("author")]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. num_neighbors=[2, 2], pin_memory_device=torch.device("cpu"), ) @@ -240,7 +240,7 @@ def _run_distributed_neighbor_loader_with_node_labels_heterogeneous( user_loader = DistNeighborLoader( dataset=dataset, - input_nodes=(_USER, dataset.node_ids[_USER]), + input_nodes=(_USER, dataset.node_ids[_USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. num_neighbors=[2, 2], pin_memory_device=torch.device("cpu"), batch_size=batch_size, @@ -248,7 +248,7 @@ def _run_distributed_neighbor_loader_with_node_labels_heterogeneous( story_loader = DistNeighborLoader( dataset=dataset, - input_nodes=(_STORY, dataset.node_ids[_STORY]), + input_nodes=(_STORY, dataset.node_ids[_STORY]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. num_neighbors=[2, 2], pin_memory_device=torch.device("cpu"), batch_size=batch_size, diff --git a/tests/unit/distributed/distributed_partitioner_test.py b/tests/unit/distributed/distributed_partitioner_test.py index 89cecc6d3..0b02b8e2b 100644 --- a/tests/unit/distributed/distributed_partitioner_test.py +++ b/tests/unit/distributed/distributed_partitioner_test.py @@ -561,7 +561,7 @@ def _assert_label_outputs( assert isinstance(output_labeled_edge_index, abc.Mapping), ( "Homogeneous output detected from labels for heterogeneous input" ) - entity_iterable = list(output_labeled_edge_index.items()) + entity_iterable = list(output_labeled_edge_index.items()) # ty: ignore[invalid-assignment] TODO(ty-torch-container-shapes): fix ty false positives for torch container and return shapes. else: assert isinstance(output_labeled_edge_index, torch.Tensor), ( "Heterogeneous output detected from labels for homogeneous input" diff --git a/tests/unit/distributed/sampler_test.py b/tests/unit/distributed/sampler_test.py index a9dde44a4..2e62dafff 100644 --- a/tests/unit/distributed/sampler_test.py +++ b/tests/unit/distributed/sampler_test.py @@ -184,7 +184,7 @@ def test_prepare_ablp_inputs_dedupes_same_type_seeds_and_keeps_anchors_first( assert isinstance(nodes_to_sample, Mapping) self.assertEqual(set(nodes_to_sample.keys()), {_USER}) self.assert_tensor_equality( - nodes_to_sample[_USER], + nodes_to_sample[_USER], # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. torch.tensor([10, 11, 12, 13, 14]), ) self.assert_tensor_equality( @@ -223,11 +223,11 @@ def test_prepare_ablp_inputs_dedupes_cross_type_supervision_nodes(self) -> None: assert isinstance(nodes_to_sample, Mapping) self.assertEqual(set(nodes_to_sample.keys()), {_USER, _ITEM}) self.assert_tensor_equality( - nodes_to_sample[_USER], + nodes_to_sample[_USER], # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. torch.tensor([4, 5]), ) self.assert_tensor_equality( - nodes_to_sample[_ITEM], + nodes_to_sample[_ITEM], # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. torch.tensor([20, 21, 22]), ) @@ -259,5 +259,5 @@ def test_prepare_sample_loop_inputs_heterogeneous(self) -> None: self.assertIsInstance(result, SampleLoopInputs) assert isinstance(result.nodes_to_sample, Mapping) self.assertEqual(set(result.nodes_to_sample.keys()), {_USER}) - self.assert_tensor_equality(result.nodes_to_sample[_USER], torch.tensor([1, 2])) + self.assert_tensor_equality(result.nodes_to_sample[_USER], torch.tensor([1, 2])) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. self.assertEqual(result.metadata, {}) diff --git a/tests/unit/distributed/utils/degree_test.py b/tests/unit/distributed/utils/degree_test.py index 6836e4581..ffcb6e5a4 100644 --- a/tests/unit/distributed/utils/degree_test.py +++ b/tests/unit/distributed/utils/degree_test.py @@ -81,7 +81,7 @@ def test_heterogeneous_graph(self): for edge_type, edge_index in edge_indices.items(): num_nodes = int(edge_index[0].max().item() + 1) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[edge_type], expected) + self.assert_tensor_equality(result[edge_type], expected) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. def test_heterogeneous_graph_with_missing_topology(self): """Test that edge types with missing topology get empty tensors. @@ -268,7 +268,7 @@ def test_degree_tensor_heterogeneous(self): for edge_type, edge_index in edge_indices.items(): num_nodes = int(edge_index[0].max().item() + 1) expected = _compute_expected_degrees_from_edge_index(edge_index, num_nodes) - self.assert_tensor_equality(result[edge_type], expected) + self.assert_tensor_equality(result[edge_type], expected) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. class TestHelperFunctions(TestCase): diff --git a/tests/unit/nn/models_test.py b/tests/unit/nn/models_test.py index 2e1ff65cd..0d62bf104 100644 --- a/tests/unit/nn/models_test.py +++ b/tests/unit/nn/models_test.py @@ -85,7 +85,7 @@ def test_forward_heterogeneous_with_node_types(self): assert isinstance(result, dict) self.assertEqual(set(result.keys()), set(output_node_types)) for node_type in output_node_types: - self.assert_tensor_equality(result[node_type], torch.tensor([1.0, 2.0])) + self.assert_tensor_equality(result[node_type], torch.tensor([1.0, 2.0])) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. def test_forward_heterogeneous_missing_node_types(self): encoder = DummyEncoder() diff --git a/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py b/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py index 083625ffe..77567b146 100644 --- a/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py +++ b/tests/unit/src/common/modeling_task_spec_utils/early_stop_test.py @@ -93,7 +93,7 @@ def test_early_stopping( for step_num, value in enumerate(mocked_criteria_values): has_metric_improved, should_early_stop = early_stopper.step(value=value) if model is not None: - model.foo += 1 # https://github.com/Snapchat/GiGL/issues/408 + model.foo += 1 # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[unsupported-operator] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. if step_num in improvement_steps: self.assertTrue(has_metric_improved) else: diff --git a/tests/unit/src/common/models/layers/count_min_sketch_test.py b/tests/unit/src/common/models/layers/count_min_sketch_test.py index de837d135..244e98719 100644 --- a/tests/unit/src/common/models/layers/count_min_sketch_test.py +++ b/tests/unit/src/common/models/layers/count_min_sketch_test.py @@ -13,7 +13,7 @@ def test_count(self): # Initialize the CountMinSketch object cms = CountMinSketch(width=20, depth=5) candidate_ids = torch.tensor([1, 2, 2, 3, 3, 3, 4, 4, 4, 4], dtype=torch.long) - cms.add_torch_long_tensor(candidate_ids) + cms.add_torch_long_tensor(candidate_ids) # ty: ignore[invalid-argument-type] TODO(ty-torch-tensor-specialization): fix ty Tensor vs FloatTensor/LongTensor specialization. # Check the total count self.assertEqual(cms.total(), 10) # Check the estimated count diff --git a/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py b/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py index 9892ee745..0443e7508 100644 --- a/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py +++ b/tests/unit/src/training/lib/data_loaders/combined_iterable_dataset_test.py @@ -78,13 +78,14 @@ def preprocess_raw_sample_fn( loopy_dataset = LoopyIterableDataset(iterable_dataset=tf_dataset) loopy_datasets_map[condensed_node_type_str] = loopy_dataset - dataset = CombinedIterableDatasets(iterable_dataset_map=loopy_datasets_map) + dataset = CombinedIterableDatasets(iterable_dataset_map=loopy_datasets_map) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. dataset_iter = iter(dataset) for _ in range(15): dataset_sample = next(dataset_iter) self.assertEqual(self._node_types, list(dataset_sample.keys())) self.assertEqual( - self._node_types, [node.type for node in list(dataset_sample.values())] + self._node_types, + [node.type for node in list(dataset_sample.values())], # ty: ignore[unresolved-attribute] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. ) def test_can_load_non_loopy_data(self): @@ -107,13 +108,14 @@ def preprocess_raw_sample_fn( ) datasets_map[condensed_node_type_str] = tf_dataset - dataset = CombinedIterableDatasets(iterable_dataset_map=datasets_map) + dataset = CombinedIterableDatasets(iterable_dataset_map=datasets_map) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. dataset_iter = iter(dataset) for _ in range(10): dataset_sample = next(dataset_iter) self.assertEqual(self._node_types, list(dataset_sample.keys())) self.assertEqual( - self._node_types, [node.type for node in list(dataset_sample.values())] + self._node_types, + [node.type for node in list(dataset_sample.values())], # ty: ignore[unresolved-attribute] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. ) with self.assertRaises(RuntimeError): dataset_sample = next(dataset_iter) diff --git a/tests/unit/test_assets/test_dataset_test.py b/tests/unit/test_assets/test_dataset_test.py index 2a6d18643..8c3672965 100644 --- a/tests/unit/test_assets/test_dataset_test.py +++ b/tests/unit/test_assets/test_dataset_test.py @@ -84,8 +84,8 @@ def test_with_default_edge_indices(self) -> None: # Verify node counts (5 users, 5 stories in default graph) node_ids = dataset.node_ids assert isinstance(node_ids, dict) - self.assertEqual(node_ids[USER].shape[0], 5) - self.assertEqual(node_ids[STORY].shape[0], 5) + self.assertEqual(node_ids[USER].shape[0], 5) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + self.assertEqual(node_ids[STORY].shape[0], 5) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. # Verify feature info is None when no features provided self.assertIsNone(dataset.node_feature_info) @@ -119,8 +119,8 @@ def test_with_custom_parameters(self) -> None: # Verify node counts from custom edge indices node_ids = dataset.node_ids assert isinstance(node_ids, dict) - self.assertEqual(node_ids[USER].shape[0], 3) - self.assertEqual(node_ids[STORY].shape[0], 3) + self.assertEqual(node_ids[USER].shape[0], 3) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + self.assertEqual(node_ids[STORY].shape[0], 3) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. # Verify feature dimension from custom features expected_feature_info = { @@ -185,9 +185,9 @@ def test_basic_dataset_with_splits(self) -> None: assert isinstance(test_node_ids, dict) # Verify split sizes - self.assertEqual(train_node_ids[USER].shape[0], 3) - self.assertEqual(val_node_ids[USER].shape[0], 1) - self.assertEqual(test_node_ids[USER].shape[0], 1) + self.assertEqual(train_node_ids[USER].shape[0], 3) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + self.assertEqual(val_node_ids[USER].shape[0], 1) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + self.assertEqual(test_node_ids[USER].shape[0], 1) # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. def test_missing_positive_labels_raises_error(self) -> None: """Test that missing positive labels for split nodes raises an error.""" diff --git a/tests/unit/types_tests/graph_test.py b/tests/unit/types_tests/graph_test.py index fa68667f0..999db4299 100644 --- a/tests/unit/types_tests/graph_test.py +++ b/tests/unit/types_tests/graph_test.py @@ -303,7 +303,8 @@ def test_treat_supervision_edges_as_graph_edges( self.assertEqual(graph_tensors.edge_index.keys(), expected_edge_index.keys()) for edge_type, expected_tensor in expected_edge_index.items(): torch.testing.assert_close( - graph_tensors.edge_index[edge_type], expected_tensor + graph_tensors.edge_index[edge_type], # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + expected_tensor, ) def test_select_label_edge_types(self):