<a href="https://colab.research.google.com/github/ayagup/stablediffusion/blob/main/hf_tpu_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.runtime as xr
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoConfig,
    get_linear_schedule_with_warmup
)
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.metrics import accuracy_score
import time

# Check TPU availability: report library versions and pick the XLA device.
print(f"PyTorch version: {torch.__version__}")
print(f"XLA version: {torch_xla.__version__}")

# Module-level device handle; later cells (training, inference, the TPU
# smoke test) read this global directly.
device = xm.xla_device()
print(f"Using device: {device}")

# xr.world_size() may raise depending on the torch_xla version/runtime
# (assumption — hence the guard). Catch ordinary errors only: a bare
# `except:` would also swallow KeyboardInterrupt and SystemExit.
try:
    world_size = xr.world_size()
    print(f"Number of TPU cores: {world_size}")
except Exception:
    print("World size not available, but TPU is working")


In [None]:

# Custom dataset for text classification
class TextDataset(Dataset):
    """Map-style dataset that tokenizes a single text on every lookup.

    Args:
        texts: Indexable sequence of samples; each item is passed through
            ``str()`` before tokenization.
        labels: Indexable sequence of integer class ids aligned with
            ``texts``.
        tokenizer: Callable (e.g. a Hugging Face tokenizer). It is called with
            ``truncation=True``, ``padding='max_length'``, ``max_length=...``
            and ``return_tensors='pt'``, and must return a mapping holding
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Length every sample is padded or truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Size is taken from ``texts``; ``labels`` must be at least as long.
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text, class_id = str(self.texts[idx]), self.labels[idx]

        tokens = self.tokenizer(
            raw_text,
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )

        # The tokenizer output presumably carries a leading batch axis of
        # size 1 (single string + return_tensors='pt'); flatten() collapses
        # it to 1-D so the DataLoader's default collate can stack samples.
        sample = {key: tokens[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['label'] = torch.tensor(class_id, dtype=torch.long)
        return sample


In [None]:

# Generate synthetic text data for training
def create_synthetic_text_data(num_samples=1000):
    """Create synthetic text classification data.

    Each sample draws a class id uniformly from {0, 1, 2}
    (0 = negative, 1 = neutral, 2 = positive). It then draws 3-7 words, with
    replacement, from that class's vocabulary and wraps them in a
    class-specific sentence template.

    All randomness comes from NumPy's global RNG (``np.random``), so seeding
    it beforehand makes the output reproducible.

    Args:
        num_samples: Number of (text, label) pairs to generate.

    Returns:
        ``(texts, labels)``: two equal-length lists. The labels are the NumPy
        integer scalars produced by ``np.random.choice``.
    """
    # class id -> (vocabulary, text before the words, text after the words)
    templates = {
        0: (['bad', 'terrible', 'awful', 'horrible', 'hate', 'worst', 'disappointing', 'poor'],
            "This is ", " and not recommended."),
        1: (['okay', 'fine', 'average', 'normal', 'standard', 'typical', 'regular', 'common'],
            "This seems ", " to me."),
        2: (['good', 'great', 'excellent', 'amazing', 'wonderful', 'fantastic', 'love', 'perfect'],
            "This is ", " and highly recommended!"),
    }

    def draw_sample():
        # RNG call order matters for reproducibility under np.random.seed:
        # class id first, then the word count, then the words themselves.
        class_id = np.random.choice([0, 1, 2])
        vocabulary, prefix, suffix = templates[class_id]
        n_words = np.random.randint(3, 8)  # upper bound exclusive: 3..7 words
        picked = np.random.choice(vocabulary, size=n_words)
        return prefix + ' '.join(picked) + suffix, class_id

    samples = [draw_sample() for _ in range(num_samples)]
    texts = [text for text, _ in samples]
    labels = [class_id for _, class_id in samples]
    return texts, labels


In [None]:

# Download and setup model
def setup_model_and_tokenizer(model_name="distilbert-base-uncased", num_labels=3):
    """Load a pretrained sequence-classification model and its tokenizer.

    Everything is loaded with the Hugging Face ``from_pretrained`` helpers.
    The config is loaded first and its ``num_labels`` overridden, and the
    model is then built from that config so its classification head is
    sized for ``num_labels`` classes.

    Args:
        model_name: Hugging Face model identifier or local path.
        num_labels: Number of output classes for the classification head.

    Returns:
        ``(model, tokenizer, config)``: the model (still on the host), its
        tokenizer, and the config it was built from.
    """
    print(f"Downloading model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_config = AutoConfig.from_pretrained(model_name)
    model_config.num_labels = num_labels

    model = AutoModelForSequenceClassification.from_pretrained(model_name, config=model_config)

    print("✓ Model and tokenizer downloaded successfully")
    print(f"Model config: {model_config}")
    return model, tokenizer, model_config


In [None]:

# Training functions
def _train_one_epoch(model, device_loader, optimizer, scheduler, log_every=20):
    """Run one optimization pass over ``device_loader``.

    Running totals (loss sum, correct-prediction count) are kept as device
    tensors rather than calling ``.item()`` every step. Each ``.item()`` on
    an XLA tensor forces a device-to-host sync, so host reads happen only on
    logging steps and once at the end.

    Args:
        model: Sequence-classification model already on the device.
        device_loader: Iterable of batch dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'label'`` tensors that already live
            on the target device (e.g. a ParallelLoader's per-device loader).
        optimizer: Optimizer, stepped through ``xm.optimizer_step``.
        scheduler: LR scheduler, stepped once per batch.
        log_every: Print progress every ``log_every`` batches.

    Returns:
        ``(avg_loss, accuracy)`` for the epoch. Both are 0 when the loader
        yields no batches.
    """
    model.train()
    loss_sum = 0      # becomes a device tensor after the first step
    correct_sum = 0   # likewise
    seen = 0          # host int: batch sizes are static, reading them needs no sync
    num_batches = 0

    for batch_idx, batch in enumerate(device_loader):
        # The per-device loader already placed these tensors on the device.
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']

        optimizer.zero_grad()
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss

        loss.backward()
        xm.optimizer_step(optimizer)  # TPU-optimized step
        scheduler.step()

        predictions = torch.argmax(outputs.logits, dim=-1)
        correct_sum = correct_sum + (predictions == labels).sum()
        loss_sum = loss_sum + loss.detach()
        seen += labels.size(0)
        num_batches += 1

        if batch_idx % log_every == 0:
            current_acc = float(correct_sum) / seen if seen > 0 else 0
            print(f"Batch {batch_idx}, Loss: {loss.item():.4f}, Accuracy: {current_acc:.4f}")

    avg_loss = float(loss_sum) / num_batches if num_batches > 0 else 0.0
    accuracy = float(correct_sum) / seen if seen > 0 else 0
    return avg_loss, accuracy


def train_model():
    """Fine-tune DistilBERT for 3-class classification on synthetic data on TPU.

    Builds 800 training and 200 validation synthetic samples. Trains for
    ``num_epochs`` epochs with AdamW and a linear LR schedule (no warmup),
    and reports validation accuracy via ``evaluate_model`` after every epoch.

    Returns:
        ``(model, tokenizer)``: the trained model (on the XLA device) and its
        tokenizer.
    """
    model, tokenizer, _ = setup_model_and_tokenizer(
        model_name="distilbert-base-uncased",
        num_labels=3
    )

    model = model.to(device)
    print("✓ Model moved to TPU")

    print("Creating synthetic dataset...")
    train_texts, train_labels = create_synthetic_text_data(800)
    val_texts, val_labels = create_synthetic_text_data(200)

    print(f"Training samples: {len(train_texts)}")
    print(f"Validation samples: {len(val_texts)}")
    print(f"Sample text: {train_texts[0]}")
    print(f"Sample label: {train_labels[0]}")

    train_dataset = TextDataset(train_texts, train_labels, tokenizer)
    val_dataset = TextDataset(val_texts, val_labels, tokenizer)

    batch_size = 8
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    num_epochs = 3
    # Single device, so batches per epoch == len(DataLoader).
    num_training_steps = len(train_loader) * num_epochs

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps
    )

    print("Starting training...")

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # pl.ParallelLoader is single-use: its background workers walk the
        # wrapped DataLoader exactly once. Build a fresh one for every pass;
        # reusing one would leave epochs 2+ and later validations iterating
        # an exhausted loader.
        train_device_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
        avg_loss, accuracy = _train_one_epoch(model, train_device_loader, optimizer, scheduler)
        print(f"Epoch {epoch + 1} - Avg Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

        val_accuracy = evaluate_model(model, pl.ParallelLoader(val_loader, [device]), device)
        print(f"Validation Accuracy: {val_accuracy:.4f}")

    print("✓ Training completed!")
    return model, tokenizer


In [None]:

# Evaluation function
def evaluate_model(model, val_loader, device):
    """Evaluate model on validation set and return its classification accuracy.

    The running correct-prediction count is kept as a device tensor. This
    avoids a device-to-host copy (``.cpu().numpy()``) of every batch, which
    on XLA forces a sync per batch. Only one host read happens at the end.

    Args:
        model: Callable taking ``input_ids``/``attention_mask`` keyword
            arguments and returning an object with a ``.logits`` tensor of
            shape (batch, num_classes).
        val_loader: Object exposing ``per_device_loader(device)`` (e.g. a
            torch_xla ``ParallelLoader``) that yields dicts with
            ``'input_ids'``, ``'attention_mask'`` and ``'label'`` tensors.
            A ParallelLoader can only be iterated once, so pass a fresh one
            per call.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float, or 0.0 when the
        loader yields no samples.
    """
    model.eval()
    correct = 0  # becomes a device tensor after the first batch
    total = 0

    with torch.no_grad():
        for batch in val_loader.per_device_loader(device):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=-1)

            correct = correct + (predictions == labels).sum()
            total += labels.size(0)

    if total == 0:
        return 0.0
    return float(correct) / total


In [None]:

# Inference function
def perform_inference(model, tokenizer, texts, device=None, max_length=128):
    """Perform inference on new texts, printing and returning per-text predictions.

    Args:
        model: Sequence-classification model returning an object with
            ``.logits`` of shape (1, 3) for a single input.
        tokenizer: Callable (e.g. a Hugging Face tokenizer) returning
            ``'input_ids'`` and ``'attention_mask'`` tensors for one string.
        texts: Iterable of strings to classify.
        device: Device for the input tensors. Defaults to the device of the
            model's first parameter, so inputs always land next to the
            weights instead of depending on a module-level global (assumes
            the model has at least one parameter).
        max_length: Fixed pad/truncate length. A fixed length keeps input
            shapes static across texts.

    Returns:
        A list with one dict per text, with keys ``'text'``,
        ``'predicted_label'`` (one of Negative/Neutral/Positive),
        ``'confidence'`` (softmax probability of that label) and
        ``'all_scores'`` (NumPy array of all class probabilities).
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    label_map = {0: "Negative", 1: "Neutral", 2: "Positive"}
    results = []

    print("\n=== Performing Inference ===")

    with torch.no_grad():
        for text in texts:
            encoding = tokenizer(
                text,
                truncation=True,
                padding='max_length',
                max_length=max_length,
                return_tensors='pt'
            )

            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            # Copy the probabilities to the host once. The argmax, confidence
            # lookup and NumPy conversion then read that copy instead of each
            # forcing its own device sync.
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1).cpu()
            predicted_class = torch.argmax(probs, dim=-1).item()
            confidence = probs[0][predicted_class].item()

            results.append({
                'text': text,
                'predicted_label': label_map[predicted_class],
                'confidence': confidence,
                'all_scores': probs.numpy()[0]
            })

            print(f"Text: '{text}'")
            print(f"Prediction: {label_map[predicted_class]} (confidence: {confidence:.4f})")
            print("-" * 50)

    return results


In [None]:

# Main execution function
def main():
    """Train the demo classifier on TPU, then run inference on sample sentences.

    Returns:
        ``(trained_model, tokenizer, inference_results)`` on success. On any
        exception the error and its traceback are printed and
        ``(None, None, None)`` is returned. Callers that unpack three values
        then do not fail with an unrelated ``TypeError`` that would bury the
        real error.
    """
    try:
        print("=== Hugging Face Model Training on TPU ===")

        trained_model, tokenizer = train_model()

        test_texts = [
            "This movie is absolutely wonderful and amazing!",
            "The product is terrible and disappointing.",
            "It's an okay product, nothing special.",
            "I love this book, it's fantastic!",
            "The service was bad and horrible.",
            "This is a normal and average experience."
        ]

        inference_results = perform_inference(trained_model, tokenizer, test_texts)

        print("\n=== Inference Results Summary ===")
        for i, result in enumerate(inference_results):
            print(f"{i+1}. '{result['text']}' -> {result['predicted_label']} ({result['confidence']:.3f})")

        print("\n✓ Hugging Face PyTorch TPU example completed successfully!")

        return trained_model, tokenizer, inference_results

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        # Keep the return shape stable so `a, b, c = main()` still works.
        return None, None, None


In [None]:

# Test basic TPU functionality first
def test_basic_tpu():
    """Smoke-test that a small matmul actually executes on the XLA device.

    XLA tensors are lazy. Building ``z`` only records the op, and reading
    ``z.shape`` runs nothing. Copying the result to the host with ``.cpu()``
    forces the graph to compile and execute, so a broken device raises here
    instead of producing a misleading success message.

    Returns:
        True once the computation has run.
    """
    print("=== Testing Basic TPU Operations ===")

    x = torch.randn(3, 3, device=device)
    y = torch.randn(3, 3, device=device)
    z = torch.matmul(x, y)
    z_host = z.cpu()  # materialize: forces actual execution on the device

    print("✓ Basic tensor operations work on TPU")
    print(f"Device: {device}")
    print(f"Result shape: {z_host.shape}")

    return True


In [None]:
# Standalone TPU smoke test. In a notebook, the returned True is displayed as
# this cell's output.
test_basic_tpu()

In [None]:

if __name__ == "__main__":
    # Smoke-test the TPU first. In a notebook kernel __name__ is also
    # "__main__", so this cell runs there too.
    test_basic_tpu()

    # Run main training and inference. Fall back to a triple of Nones if
    # main() returns None (e.g. from an error path that returns nothing),
    # so this unpacking cannot raise a TypeError that hides the real error.
    model, tokenizer, results = main() or (None, None, None)