In [None]:
# Clean up existing installations
!pip uninstall -y torch torch_xla torchvision torchaudio

# Install PyTorch 2.1.0 and matching XLA version
!pip install --no-cache-dir torch==2.1.0
!pip install --no-cache-dir "torch_xla[tpu]>=2.1.0" -f https://storage.googleapis.com/libtpu-releases/index.html

# Install supporting packages
!pip install torchvision==0.16.0
!pip install -q transformers==4.35.2 datasets==2.14.5 seqeval==1.2.2
!pip install -q pandas pyarrow

# Configure environment
import os
os.environ['XLA_USE_BF16'] = "1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = "100000000"

print("\nInstallation complete. Please restart the runtime now and run the verification code.")

# Check installation status
!python3 -c "import torch; import torch_xla; print(f'PyTorch version: {torch.__version__}'); print(f'XLA version: {torch_xla.__version__}')"

Found existing installation: torch 2.0.0
Uninstalling torch-2.0.0:
  Successfully uninstalled torch-2.0.0
Found existing installation: torch-xla 2.5.1
Uninstalling torch-xla-2.5.1:
  Successfully uninstalled torch-xla-2.5.1
[0mCollecting torch==2.1.0
  Downloading torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl.metadata (25 kB)
Collecting triton==2.1.0 (from torch==2.1.0)
  Downloading triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)
Downloading torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m218.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (89.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.2/89.2 MB[0m [31m207.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: triton, torch
  Attempting uninstall: triton
    Found existing instal

Looking in links: https://storage.googleapis.com/libtpu-releases/index.html
Collecting torch_xla>=2.1.0 (from torch_xla[tpu]>=2.1.0)
  Downloading torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (17 kB)
Downloading torch_xla-2.5.1-cp310-cp310-manylinux_2_28_x86_64.whl (90.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.6/90.6 MB[0m [31m209.7 MB/s[0m eta [36m0:00:00[0m
[?25hTraceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/pkg_resources/__init__.py", line 3070, in _dep_map
    return self.__dep_map
  File "/usr/local/lib/python3.10/dist-packages/pip/_vendor/pkg_resources/__init__.py", line 2863, in __getattr__
    raise AttributeError(attr)
AttributeError: _DistInfoDistribution__dep_map

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/pip/_internal/cli/base_command.py", line 179, in exc_logging_wr

In [1]:
import os

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def verify_tpu_setup():
    """Smoke-test the XLA installation with a tiny matrix product.

    Prints the torch / torch_xla versions, acquires the XLA device, multiplies
    a random 3x3 tensor by itself on that device and prints the result.

    Returns:
        bool: True when every step succeeded; False when any step raised, in
        which case the error and a short troubleshooting list are printed
        instead of propagating the exception.
    """
    troubleshooting = (
        "\nTroubleshooting steps:",
        "1. Verify TPU runtime is selected in Runtime -> Change runtime type",
        "2. Make sure you've restarted the runtime after installation",
        "3. Try the following command to check TPU availability:",
        "   !python -c 'import torch_xla; print(torch_xla.__version__)'",
    )

    try:
        print(f"PyTorch version: {torch.__version__}")
        print(f"PyTorch XLA version: {torch_xla.__version__}")

        # Acquire the XLA device
        xla_dev = xm.xla_device()
        print(f"XLA device: {xla_dev}")

        # Square a small random matrix on the device
        host_matrix = torch.randn(3, 3)
        on_device = host_matrix.to(xla_dev)
        product = on_device @ on_device
        xm.mark_step()

        print("\nTest computation result:")
        print(product)

        print("\nTPU setup successful!")
        succeeded = True
    except Exception as failure:
        print(f"\nError during TPU setup: {str(failure)}")
        for hint in troubleshooting:
            print(hint)
        succeeded = False

    return succeeded

# Additional TPU system information
def print_tpu_info():
    """Print TPU-related environment variables and the XLA world size.

    Reads ``TPU_RUNTIME_VERSION`` and ``XRT_TPU_CONFIG`` from the process
    environment (printing 'Not available' when unset) and reports the number
    of participating devices when it is greater than one.

    Requires ``os`` to be imported in this cell: after the runtime restart
    requested by the install cell, names imported there no longer exist.

    Raises:
        Exception: whatever ``xm.xrt_world_size()`` raises when the XLA
        runtime cannot be queried; the caller treats this as best-effort.
    """
    print("\nTPU System Information:")
    print(f"TPU Runtime Version: {os.environ.get('TPU_RUNTIME_VERSION', 'Not available')}")
    # NOTE(review): XRT_TPU_CONFIG looks like a legacy XRT-runtime variable and
    # may simply be unset on newer runtimes - confirm whether it is still useful.
    print(f"XRT TPU Config: {os.environ.get('XRT_TPU_CONFIG', 'Not available')}")

    # Query once; the value is used both for the check and the message.
    world_size = xm.xrt_world_size()
    if world_size > 1:
        print(f"Number of TPU cores: {world_size}")

# Run verification
verify_tpu_setup()
try:
    print_tpu_info()
except Exception as info_error:
    # Best-effort extra diagnostics: keep the cell running, but show why the
    # lookup failed. A bare ``except:`` would also trap KeyboardInterrupt and
    # hide genuine programming errors such as a missing import.
    print("Could not retrieve additional TPU information")
    print(f"Reason: {info_error}")

ModuleNotFoundError: No module named 'torch_xla'

In [None]:
"""
Azerbaijani Named Entity Recognition (NER) Pipeline - TPU Version
License: CC BY-NC-ND 4.0
"""

import os
import multiprocessing
import logging
import warnings
import json
import ast
from datetime import datetime
from typing import List, Dict, Tuple, Optional
from pathlib import Path
import sys
import traceback
import numpy as np
import pandas as pd

# TPU-specific imports
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
from torch.utils.data import DataLoader

from datasets import Dataset, DatasetDict, Features, Sequence, Value
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from seqeval.metrics import f1_score, precision_score, recall_score

# Disable wandb
os.environ["WANDB_DISABLED"] = "true"

# Filter warnings
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)

# TPU-specific constants
TPU_CORES = 8
OPTIMAL_NUM_WORKERS = 4  # Adjusted for TPU

# Set up logging
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('training.log')
    ]
)

# Entity type definitions
# The position of each name in the tuple is its integer tag id (0..24).
ENTITY_TYPES = dict(enumerate((
    "O",              # 0: outside any named entity
    "PERSON",         # 1: names of individuals
    "LOCATION",       # 2: geographical locations
    "ORGANISATION",   # 3: names of companies, institutions
    "DATE",           # 4: dates or periods
    "TIME",           # 5: times of the day
    "MONEY",          # 6: monetary values
    "PERCENTAGE",     # 7: percentage values
    "FACILITY",       # 8: buildings, airports, etc.
    "PRODUCT",        # 9: products and goods
    "EVENT",          # 10: events and occurrences
    "ART",            # 11: artworks, titles of books, songs
    "LAW",            # 12: legal documents
    "LANGUAGE",       # 13: languages
    "GPE",            # 14: countries, cities, states
    "NORP",           # 15: nationalities or religious or political groups
    "ORDINAL",        # 16: ordinal numbers
    "CARDINAL",       # 17: cardinal numbers
    "DISEASE",        # 18: diseases and medical conditions
    "CONTACT",        # 19: contact information
    "ADAGE",          # 20: proverbs, sayings
    "QUANTITY",       # 21: measurements and quantities
    "MISCELLANEOUS",  # 22: miscellaneous entities
    "POSITION",       # 23: professional or social positions
    "PROJECT",        # 24: names of projects or programs
)))

def get_training_args(output_dir: str) -> TrainingArguments:
    """Build the ``TrainingArguments`` used for the TPU fine-tuning run.

    Args:
        output_dir: Directory for checkpoints; logs go to ``<output_dir>/logs``.

    Returns:
        A ``TrainingArguments`` instance configured with the settings below.
    """
    # Evaluation, checkpointing and logging cadence.
    # save_steps equals eval_steps so that the best checkpoint (selected by
    # "f1") can be restored at the end of training.
    schedule = dict(
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=2,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_dir=os.path.join(output_dir, 'logs'),
        logging_steps=50,
        report_to="none",
        push_to_hub=False,
    )

    # Optimiser and batch settings.
    optimisation = dict(
        learning_rate=2e-5,
        weight_decay=0.01,
        warmup_steps=500,
        num_train_epochs=5,
        per_device_train_batch_size=32,  # Increased for TPU
        per_device_eval_batch_size=32,   # Increased for TPU
        gradient_accumulation_steps=1,   # Adjusted for TPU
        max_grad_norm=1.0,
        optim="adamw_torch",
    )

    # Hardware / data-loading settings.
    runtime = dict(
        fp16=False,  # TPU doesn't need fp16
        dataloader_num_workers=OPTIMAL_NUM_WORKERS,
        group_by_length=True,
        tpu_num_cores=TPU_CORES,
    )

    return TrainingArguments(output_dir=output_dir, **schedule, **optimisation, **runtime)

def setup_tpu():
    """Acquire the XLA device for the current process.

    Returns:
        The device object returned by ``xm.xla_device()``.

    Raises:
        Exception: any error raised while acquiring the device is logged at
        ERROR level and then re-raised unchanged.
    """
    try:
        xla_dev = xm.xla_device()
        logging.info(f"TPU device initialized: {xla_dev}")
    except Exception as init_error:
        logging.error(f"Failed to initialize TPU: {str(init_error)}")
        raise
    return xla_dev


class AzerbaijaniNERPipeline:
    """Token-classification fine-tuning pipeline for Azerbaijani NER on XLA devices.

    The pipeline loads a parquet file with ``tokens`` / ``ner_tags`` columns,
    cleans it, splits it 80/10/10, tokenizes it and fine-tunes a Hugging Face
    token-classification model with ``Trainer``.

    Process model: this class never spawns processes itself. For multi-core
    training, construct the pipeline and call its methods *inside* each process
    started by ``xmp.spawn``; the constructor acquires the XLA device.
    """

    def __init__(self, model_name: str = "bert-base-multilingual-cased", output_dir: str = "az-ner-model"):
        """Initialize the TPU-enabled NER pipeline.

        Args:
            model_name: Hugging Face model id used for tokenizer and model.
            output_dir: Directory for checkpoints and results (created if missing).

        Raises:
            Exception: propagated from ``setup_tpu()`` when no XLA device can
            be acquired.
        """
        self.model_name = model_name
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            model_max_length=512,
            padding_side='right'
        )

        # Initialize label mappings
        self.label2id = {entity_type: idx for idx, entity_type in ENTITY_TYPES.items()}
        self.id2label = {idx: entity_type for idx, entity_type in ENTITY_TYPES.items()}

        # Set up TPU device
        self.device = setup_tpu()
        logging.info(f"Using device: {self.device}")

        # Initialize statistics tracking
        self.stats = self._new_stats()

    def _new_stats(self) -> Dict:
        """Return a zeroed statistics record with one counter per known tag id."""
        return {
            'total_examples': 0,
            'valid_examples': 0,
            'invalid_examples': 0,
            'tag_distribution': {idx: 0 for idx in self.id2label}
        }

    def _coerce_tag(self, tag) -> int:
        """Map one raw tag value to a known integer tag id, or 0 ("O").

        Accepts Python ints, numpy integers (what pandas yields for parquet
        list columns) and digit strings. Anything else - including negative
        numbers, booleans and ids outside the label set - becomes 0.
        """
        if isinstance(tag, str):
            tag = tag.strip()
        # str(np.int64(3)) == "3", so the digit check covers numpy ints too;
        # str(True) == "True" keeps booleans out.
        if isinstance(tag, (int, np.integer, str)) and str(tag).isdigit():
            tag_id = int(tag)
            if tag_id in self.id2label:
                return tag_id
        return 0

    @staticmethod
    def _to_iob2(labels: List[str]) -> List[str]:
        """Convert plain per-token type names into IOB2 tags for seqeval.

        A token starts a new span ("B-") whenever its type differs from the
        previous token's type, and continues it ("I-") otherwise. The source
        tags carry no boundary information, so two adjacent entities of the
        same type are necessarily scored as one span.
        """
        tagged = []
        previous = 'O'
        for label in labels:
            if label == 'O':
                tagged.append('O')
            elif label == previous:
                tagged.append(f"I-{label}")
            else:
                tagged.append(f"B-{label}")
            previous = label
        return tagged

    def process_row(self, row: Dict) -> Optional[Dict]:
        """Validate and normalise a single data row.

        Args:
            row: Mapping with ``tokens`` and ``ner_tags``; each may be a
                sequence (list / numpy array) or its string representation.

        Returns:
            ``{'tokens': list[str], 'ner_tags': list[int]}`` for a usable row,
            or ``None`` when the row is malformed or the two sequences differ
            in length. ``self.stats`` is updated in either case; the tag
            distribution only counts rows that were accepted.
        """
        self.stats['total_examples'] += 1
        try:
            tokens = row['tokens']
            tags = row['ner_tags']

            if isinstance(tokens, str):
                tokens = ast.literal_eval(tokens)
            if isinstance(tags, str):
                tags = ast.literal_eval(tags)

            if len(tokens) != len(tags):
                self.stats['invalid_examples'] += 1
                return None

            cleaned_tags = [self._coerce_tag(tag) for tag in tags]
            # Plain Python strings keep the later Arrow conversion predictable
            # when the column arrived as a numpy array.
            tokens = [str(token) for token in tokens]
        except Exception as e:
            # Deliberately best-effort: a malformed row is counted and skipped
            # rather than aborting the whole load.
            self.stats['invalid_examples'] += 1
            logging.debug(f"Skipping malformed row: {e}")
            return None

        for tag in cleaned_tags:
            self.stats['tag_distribution'][tag] += 1
        self.stats['valid_examples'] += 1
        return {
            'tokens': tokens,
            'ner_tags': cleaned_tags
        }

    def load_dataset(self, data_path: str) -> "DatasetDict":
        """Load the parquet file and return train/validation/test splits.

        Rows are processed in the current process. A worker pool would run
        ``process_row`` on pickled copies of this object, so every statistics
        update would be lost and the logged statistics would always be zero.

        Args:
            data_path: Path to a parquet file with ``tokens`` and ``ner_tags``.

        Returns:
            ``DatasetDict`` with ``train`` (80%), ``validation`` (10%) and
            ``test`` (10%) splits; both splits use ``seed=42``.

        Raises:
            ValueError: when no row survives validation.
            Exception: any read/conversion error is logged and re-raised.
        """
        logging.info(f"Loading dataset from {data_path}")

        self.stats = self._new_stats()

        try:
            df = pd.read_parquet(data_path)
            logging.info(f"Loaded {len(df)} rows from {data_path}")

            processed_data = [
                item
                for item in map(self.process_row, df.to_dict('records'))
                if item is not None
            ]
            if not processed_data:
                raise ValueError(f"No valid examples found in {data_path}")

            dataset = Dataset.from_pandas(
                pd.DataFrame(processed_data),
                features=Features({
                    'tokens': Sequence(Value('string')),
                    'ner_tags': Sequence(Value('int64'))
                })
            )

            train_test = dataset.train_test_split(test_size=0.2, seed=42)
            test_valid = train_test['test'].train_test_split(test_size=0.5, seed=42)

            dataset_dict = DatasetDict({
                'train': train_test['train'],
                'validation': test_valid['train'],
                'test': test_valid['test']
            })

            self._log_statistics(dataset_dict)
            return dataset_dict

        except Exception as e:
            logging.error(f"Error loading dataset: {str(e)}")
            raise

    def _log_statistics(self, dataset_dict: "DatasetDict"):
        """Log row counts, the per-tag distribution and the size of each split."""
        logging.info("\nDataset Statistics:")
        logging.info(f"Total examples processed: {self.stats['total_examples']}")
        logging.info(f"Valid examples: {self.stats['valid_examples']}")
        logging.info(f"Invalid examples: {self.stats['invalid_examples']}")

        logging.info("\nTag Distribution:")
        for tag_id, count in self.stats['tag_distribution'].items():
            logging.info(f"{self.id2label[tag_id]}: {count}")

        logging.info("\nDataset Splits:")
        for split, ds in dataset_dict.items():
            logging.info(f"{split} set size: {len(ds)}")

    def tokenize_and_align_labels(self, examples: Dict) -> Dict:
        """Tokenize pre-split words and align word-level tags to sub-tokens.

        Every example is padded/truncated to 512 tokens. Only the first
        sub-token of each word keeps the word's tag; special tokens, padding
        and continuation sub-tokens get ``-100`` so the loss ignores them.

        Args:
            examples: Batch with ``tokens`` (list of word lists) and
                ``ner_tags`` (list of tag-id lists).

        Returns:
            The tokenizer output with an added ``labels`` entry.
        """
        tokenized_inputs = self.tokenizer(
            examples["tokens"],
            truncation=True,
            is_split_into_words=True,
            max_length=512,  # TPU optimized length
            padding="max_length"
        )

        labels = []
        for i, label in enumerate(examples["ner_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []

            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                elif word_idx != previous_word_idx:
                    label_ids.append(label[word_idx])
                else:
                    label_ids.append(-100)
                previous_word_idx = word_idx

            labels.append(label_ids)

        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    def compute_metrics(self, eval_preds: Tuple[np.ndarray, np.ndarray]) -> Dict[str, float]:
        """Compute span-level precision/recall/F1 plus one F1 per entity type.

        Labels are converted to IOB2 before scoring. seqeval reads the first
        character of an un-prefixed label as the chunk tag, so names such as
        "ORGANISATION" or "ORDINAL" would otherwise be treated as "O".

        Per-entity scores are obtained by masking every other type to "O" and
        scoring the remaining spans, under the keys ``"<TYPE>_f1"``.

        Args:
            eval_preds: ``(logits, label_ids)``; positions labelled ``-100``
                are ignored.

        Returns:
            Mapping of metric name to float.
        """
        predictions, labels = eval_preds
        predictions = np.argmax(predictions, axis=2)

        true_predictions = []
        true_labels = []
        for prediction, label in zip(predictions, labels):
            kept = [(p, l) for (p, l) in zip(prediction, label) if l != -100]
            true_predictions.append(self._to_iob2([self.id2label[int(p)] for p, _ in kept]))
            true_labels.append(self._to_iob2([self.id2label[int(l)] for _, l in kept]))

        metrics = {
            "precision": precision_score(true_labels, true_predictions),
            "recall": recall_score(true_labels, true_predictions),
            "f1": f1_score(true_labels, true_predictions)
        }

        # Iterate in label-id order so every process produces the same key
        # sequence for the reduction below.
        for entity_type in self.id2label.values():
            if entity_type == 'O':
                continue
            entity_preds = [
                [tag if tag[2:] == entity_type else 'O' for tag in seq]
                for seq in true_predictions
            ]
            entity_labels = [
                [tag if tag[2:] == entity_type else 'O' for tag in seq]
                for seq in true_labels
            ]

            try:
                metrics[f"{entity_type}_f1"] = f1_score(entity_labels, entity_preds)
            except Exception as e:
                logging.warning(f"Could not compute F1 for {entity_type}: {e}")
                metrics[f"{entity_type}_f1"] = 0.0

        # Sync metrics across TPU cores
        # NOTE(review): presumably Trainer already gathers predictions from all
        # processes before calling this hook, making the reduction a no-op on
        # identical values - confirm before removing it.
        metrics = {k: float(xm.mesh_reduce('metrics', v, np.mean)) for k, v in metrics.items()}
        return metrics

    def train(self, dataset_dict: "DatasetDict") -> "Trainer":
        """Fine-tune the model in the current process and save the artifacts.

        Training runs directly via ``Trainer.train()``. ``xmp.spawn`` is not
        called here: it does not hand worker results back to the caller, and
        the XLA device has already been acquired by the constructor, so
        spawning from this method cannot work. For multi-core training call
        this method from inside each spawned process.

        Args:
            dataset_dict: Splits from ``load_dataset``; ``train`` and
                ``validation`` are used.

        Returns:
            The ``Trainer`` after training, ready for ``evaluate``.

        Raises:
            Exception: any training/saving error is logged and re-raised.
        """
        logging.info("Initializing model for TPU training...")

        model = AutoModelForTokenClassification.from_pretrained(
            self.model_name,
            num_labels=len(self.id2label),
            id2label=self.id2label,
            label2id=self.label2id
        ).to(self.device)

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logging.info(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")

        tokenized_datasets = dataset_dict.map(
            self.tokenize_and_align_labels,
            batched=True,
            remove_columns=dataset_dict["train"].column_names,
            num_proc=OPTIMAL_NUM_WORKERS,
            load_from_cache_file=False
        )

        training_args = get_training_args(str(self.output_dir))

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["validation"],
            tokenizer=self.tokenizer,
            data_collator=DataCollatorForTokenClassification(
                self.tokenizer,
                pad_to_multiple_of=8
            ),
            compute_metrics=self.compute_metrics,
            callbacks=[
                EarlyStoppingCallback(
                    early_stopping_patience=3,
                    early_stopping_threshold=0.01
                )
            ]
        )

        try:
            train_result = trainer.train()
            xm.mark_step()
            metrics = train_result.metrics

            trainer.save_metrics("train", metrics)
            xm.save(model.state_dict(), str(self.output_dir / "pytorch_model.bin"))
            self.tokenizer.save_pretrained(str(self.output_dir))

            logging.info(f"Training metrics: {metrics}")

        except Exception as e:
            logging.error(f"Training error: {str(e)}")
            raise

        return trainer

def _run_pipeline(index: int, data_path: str = "train-00000-of-00001.parquet") -> None:
    """Per-process worker: train, evaluate and (on the master) save results.

    Everything that touches the XLA device happens here, i.e. inside the
    process started by ``xmp.spawn``. The dataset splits use fixed seeds, so
    each process derives the same train/validation/test partition.

    Args:
        index: Process ordinal supplied by ``xmp.spawn`` (unused directly).
        data_path: Parquet file with ``tokens`` and ``ner_tags`` columns.

    Raises:
        RuntimeError: when the XLA device is not a TPU.
        Exception: any pipeline error is logged with traceback and re-raised.
    """
    start_time = datetime.now()

    try:
        # xla_device_hw() needs the device whose hardware type is queried.
        if xm.xla_device_hw(xm.xla_device()) != 'TPU':
            raise RuntimeError("TPU device not found. Please ensure you're running in a TPU runtime.")

        pipeline = AzerbaijaniNERPipeline()
        dataset_dict = pipeline.load_dataset(data_path)
        trainer = pipeline.train(dataset_dict)

        test_dataset = dataset_dict["test"].map(
            pipeline.tokenize_and_align_labels,
            batched=True,
            remove_columns=dataset_dict["test"].column_names
        )
        test_results = trainer.evaluate(test_dataset)
        xm.mark_step()

        # Only one process writes the results file to avoid concurrent writes.
        if xm.is_master_ordinal():
            test_results["timestamp"] = datetime.now().isoformat()
            test_results["training_duration"] = str(datetime.now() - start_time)

            results_path = pipeline.output_dir / f"test_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            with open(results_path, "w", encoding="utf-8") as f:
                json.dump(test_results, f, indent=2, ensure_ascii=False)

            logging.info(f"Training completed successfully.")
            logging.info(f"Results saved to {results_path}")
            logging.info(f"Total training time: {datetime.now() - start_time}")

    except Exception as e:
        logging.error(f"Pipeline error: {str(e)}", exc_info=True)
        raise


def main():
    """Main function to run the TPU pipeline.

    Launches ``_run_pipeline`` once per XLA device. The parent process
    deliberately does not touch the XLA device before spawning, and nothing is
    read from the return value of ``xmp.spawn``: worker results are not handed
    back to the caller, so results are written from inside the worker.
    'fork' is used so that functions defined in a notebook cell are available
    in the child processes.
    """
    xmp.spawn(_run_pipeline, args=("train-00000-of-00001.parquet",), start_method='fork')


if __name__ == "__main__":
    main()