<a href="https://colab.research.google.com/github/Ak-N-G/conferenceproject/blob/main/conferenceproject.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Conference Project: Image Classification Models

**Models:** Vision Transformer, Swin Transformer, ConvNet, Compressed Vision Transformer

**Workflow:** Dataset loading, layer visualization, model training, saving graphs/outputs, evaluating metrics, running on a custom dataset.

**Deadline:**
- Abstract: 15th August
- Paper Completion: 30th August


## 1. Environment Setup & Library Installation

In [None]:
!pip install torch torchvision timm transformers scikit-learn pandas matplotlib seaborn openpyxl

## 2. Mount Google Drive (for saving weights, graphs, etc.)

In [None]:
# Mount Google Drive so that checkpoints, plots and spreadsheets written under
# /content/drive survive after the Colab runtime is recycled.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

## 3. Imports & Global Configs

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import timm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import numpy as np
from tqdm import tqdm

# Run on the GPU when Colab provides one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
BATCH_SIZE = 64

# Every artefact (weights, graphs, spreadsheets) is written beneath RESULTS_DIR
# on the mounted Drive; create the root folder and its sub-folders up front.
RESULTS_DIR = '/content/drive/MyDrive/ConferenceProject/'
for _sub in ('', 'weights', 'graphs'):
    os.makedirs(os.path.join(RESULTS_DIR, _sub), exist_ok=True)

## 4. Dataset Preparation

In [None]:
# CIFAR-10 is the default benchmark; replace with a custom dataset if needed.
# ViT/Swin need 224x224 inputs, so images are resized, converted to tensors and
# shifted from [0, 1] to [-1, 1] with mean 0.5 / std 0.5 per channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

train_data = datasets.CIFAR10(
    root='./data', train=True, transform=transform, download=True,
)
test_data = datasets.CIFAR10(
    root='./data', train=False, transform=transform, download=True,
)

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
)

# Human-readable class names, indexed by label id.
CLASSES = train_data.classes

### Visualize Some Sample Images

In [None]:
# Preview eight images from one shuffled training batch. The [-1, 1]
# normalisation is undone (x * 0.5 + 0.5) so the colours display correctly.
plt.figure(figsize=(10, 5))
preview_imgs, preview_labels = next(iter(train_loader))
for idx in range(8):
    plt.subplot(2, 4, idx + 1)
    plt.imshow((preview_imgs[idx].permute(1, 2, 0) * 0.5 + 0.5).numpy())
    plt.title(CLASSES[preview_labels[idx]])
    plt.axis('off')
plt.tight_layout()
plt.show()

## 5. Helper Functions (Training, Evaluation, Plotting, Saving)

In [None]:
def train_model(model, train_loader, val_loader, num_epochs=10, model_name='model'):
    """Fine-tune ``model`` and keep the checkpoint with the best validation accuracy.

    Optimises with AdamW (lr=3e-4) and a StepLR schedule that divides the
    learning rate by 10 every 5 epochs. After every epoch the model is scored on
    ``val_loader``; whenever validation accuracy improves, the weights are written
    to ``RESULTS_DIR/weights/{model_name}_best.pt``. Loss/accuracy curves are
    saved to ``RESULTS_DIR/graphs/{model_name}_loss_acc.png``.

    Args:
        model: network already moved to ``DEVICE``.
        train_loader: DataLoader yielding ``(images, labels)`` for training.
        val_loader: DataLoader yielding ``(images, labels)`` for validation.
        num_epochs: number of passes over ``train_loader``.
        model_name: prefix for the checkpoint and plot file names.

    Returns:
        A CPU snapshot of the best ``state_dict`` (independent of the model's
        live parameters), or ``None`` if validation accuracy never exceeded 0.
    """
    from collections import OrderedDict

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    ckpt_path = os.path.join(RESULTS_DIR, 'weights', f'{model_name}_best.pt')

    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    best_val_acc = 0.0
    best_weights = None

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        epoch_loss, correct, total = 0.0, 0, 0
        for imgs, labels in tqdm(train_loader):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
        # Mean per-batch loss and overall accuracy for the epoch.
        train_loss.append(epoch_loss / len(train_loader))
        train_acc.append(correct / total)

        # ---- validation pass ----
        model.eval()
        val_loss_epoch, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
                outputs = model(imgs)
                val_loss_epoch += criterion(outputs, labels).item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        val_loss.append(val_loss_epoch / len(val_loader))
        val_acc.append(correct / total)

        if val_acc[-1] > best_val_acc:
            best_val_acc = val_acc[-1]
            # state_dict() returns references to the live parameter tensors, so
            # keeping it directly would silently track later epochs and return
            # the *final* weights. Take a detached CPU copy instead (keeping the
            # version metadata that load_state_dict consults).
            live = model.state_dict()
            best_weights = OrderedDict(
                (name, tensor.detach().to('cpu', copy=True)) for name, tensor in live.items()
            )
            best_weights._metadata = getattr(live, '_metadata', None)
            torch.save(best_weights, ckpt_path)
        scheduler.step()
        print(f"Epoch {epoch+1}: Train Acc={train_acc[-1]:.3f} Val Acc={val_acc[-1]:.3f}")

    # Loss (left) and accuracy (right) curves for both splits.
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_loss, label='Train Loss')
    plt.plot(val_loss, label='Val Loss')
    plt.legend(); plt.title('Loss')
    plt.subplot(1, 2, 2)
    plt.plot(train_acc, label='Train Acc')
    plt.plot(val_acc, label='Val Acc')
    plt.legend(); plt.title('Accuracy')
    plt.savefig(os.path.join(RESULTS_DIR, 'graphs', f'{model_name}_loss_acc.png'))
    plt.show()
    return best_weights

def evaluate_and_export(model, data_loader, model_name, classes=None):
    """Evaluate ``model`` and export metrics, class probabilities and a confusion matrix.

    Prints accuracy and macro-averaged precision/recall/F1, writes the softmax
    probabilities to ``RESULTS_DIR/{model_name}_prob_matrix.xlsx`` (one column per
    class) and saves the confusion matrix to
    ``RESULTS_DIR/graphs/{model_name}_confusion.png``.

    Args:
        model: trained network on ``DEVICE``.
        data_loader: DataLoader yielding ``(images, labels)``.
        model_name: prefix for the exported file names.
        classes: class names indexed by label id. Defaults to
            ``data_loader.dataset.classes`` when available (so ImageNet or custom
            datasets are labelled with their own names), else the global ``CLASSES``.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    if classes is None:
        # The global CLASSES is the CIFAR-10 list; using it for a 1000-class
        # ImageNet model made the DataFrame export fail on a column mismatch.
        classes = getattr(data_loader.dataset, 'classes', None) or CLASSES
    classes = list(classes)

    model.eval()
    label_chunks, pred_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for imgs, labels in data_loader:
            outputs = model(imgs.to(DEVICE))
            prob_chunks.append(torch.softmax(outputs.float(), dim=1).cpu().numpy())
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    if not label_chunks:
        raise ValueError('data_loader yielded no batches; nothing to evaluate.')
    all_labels = np.concatenate(label_chunks)
    all_preds = np.concatenate(pred_chunks)
    all_probs = np.concatenate(prob_chunks)

    if all_probs.shape[1] != len(classes):
        print(f"Warning: model outputs {all_probs.shape[1]} classes but {len(classes)} "
              "class names were given; using numeric column names.")
        classes = [str(i) for i in range(all_probs.shape[1])]

    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0 gives the same numbers as sklearn's default, minus the warning.
    prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    print(f"Accuracy={acc:.3f} Precision={prec:.3f} Recall={recall:.3f} F1={f1:.3f}")

    # Export probability matrix to Excel (rows = samples, columns = classes).
    df = pd.DataFrame(all_probs, columns=classes)
    df.to_excel(os.path.join(RESULTS_DIR, f'{model_name}_prob_matrix.xlsx'), index=False)

    # Fix the label set so the matrix is always n_classes x n_classes and lines up
    # with the tick labels, even if some class never appears or is never predicted.
    n_classes = len(classes)
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(n_classes)))
    # Per-cell numbers and per-class tick names are only legible for small label
    # sets (e.g. CIFAR-10); for ImageNet-scale outputs draw the plain heatmap.
    small = n_classes <= 30
    plt.figure(figsize=(8, 6) if small else (20, 16))
    sns.heatmap(cm, annot=small, fmt='d', cmap='Blues',
                xticklabels=classes if small else 'auto',
                yticklabels=classes if small else 'auto')
    plt.xlabel('Predicted'); plt.ylabel('True'); plt.title('Confusion Matrix')
    plt.savefig(os.path.join(RESULTS_DIR, 'graphs', f'{model_name}_confusion.png'))
    plt.show()

    return {'accuracy': acc, 'precision': prec, 'recall': recall, 'f1': f1}

## 6. Vision Transformer (ViT)


In [None]:
# Fine-tune a pretrained ViT-B/16 with a new classification head sized to
# CLASSES, then reload the best checkpoint written by train_model and evaluate.
# NOTE(review): test_loader doubles as the validation set for checkpoint
# selection, so the reported test metrics are optimistically biased.
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=len(CLASSES))
vit = vit.to(DEVICE)
best_vit_weights = train_model(vit, train_loader, test_loader, num_epochs=10, model_name='vit')

vit_ckpt_path = os.path.join(RESULTS_DIR, 'weights', 'vit_best.pt')
vit.load_state_dict(torch.load(vit_ckpt_path))
evaluate_and_export(vit, test_loader, 'vit')
# Attention-map visualisation is optional; see the timm / transformers ViT docs.


## 7. Swin Transformer


In [None]:
# Swin-B gets its own loaders with a batch size of 32 (instead of BATCH_SIZE)
# to reduce the chance of running out of GPU memory on CIFAR-10.
swin = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=len(CLASSES)).to(DEVICE)

swin_train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
swin_test_loader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

best_swin_weights = train_model(swin, swin_train_loader, swin_test_loader, num_epochs=10, model_name='swin')

# Reload the best checkpoint and evaluate on the (unshuffled) test split.
swin.load_state_dict(torch.load(os.path.join(RESULTS_DIR, 'weights', 'swin_best.pt')))
evaluate_and_export(swin, swin_test_loader, 'swin')

## 8. ConvNet (ResNet as Example)

In [None]:
# ResNet-18 initialised with torchvision's ImageNet weights; the final fully
# connected layer is replaced by a fresh head sized to CLASSES.
# `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13);
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` loaded.
resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
resnet.fc = nn.Linear(resnet.fc.in_features, len(CLASSES))
resnet = resnet.to(DEVICE)
best_resnet_weights = train_model(resnet, train_loader, test_loader, num_epochs=10, model_name='resnet')

# Reload the best checkpoint written by train_model before evaluating.
resnet.load_state_dict(torch.load(os.path.join(RESULTS_DIR, 'weights', 'resnet_best.pt')))
evaluate_and_export(resnet, test_loader, 'resnet')

## 9. Compressed Vision Transformer (TinyViT/Slim ViT/Any Efficient Transformer)

In [None]:
# "Compressed" transformer baseline: timm's ViT-Tiny/16 (swap in any lighter ViT
# variant), trained and evaluated exactly like the full-size ViT.
compressed_vit = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=len(CLASSES))
compressed_vit = compressed_vit.to(DEVICE)
best_compressed_weights = train_model(compressed_vit, train_loader, test_loader, num_epochs=10, model_name='compressed_vit')

compressed_ckpt_path = os.path.join(RESULTS_DIR, 'weights', 'compressed_vit_best.pt')
compressed_vit.load_state_dict(torch.load(compressed_ckpt_path))
evaluate_and_export(compressed_vit, test_loader, 'compressed_vit')

## 10. Visualizing Feature Maps (ConvNet Example)


In [None]:
def visualize_conv_features(model, img_tensor, layer=None, num_maps=8):
    """Plot the first ``num_maps`` output channels of one layer for a single image.

    A forward hook captures ``layer``'s output during one no-grad forward pass;
    each captured channel is drawn as a heatmap in a 2-row grid.

    Args:
        model: the network to run. When ``layer`` is None it must expose
            ``model.layer1`` (as torchvision ResNets do); ``model.layer1[0]`` is used.
        img_tensor: a single image without a batch dimension; one is added here.
        layer: module whose forward output is visualised (default ``model.layer1[0]``).
        num_maps: number of channels to plot, capped at the channel count.
    """
    if layer is None:
        layer = model.layer1[0]
    # Run on whatever device the model lives on rather than assuming DEVICE.
    device = next(model.parameters()).device
    captured = {}

    def _hook(_module, _inputs, output):
        captured['features'] = output.detach()

    handle = layer.register_forward_hook(_hook)
    try:
        model.eval()
        with torch.no_grad():
            model(img_tensor.unsqueeze(0).to(device))
    finally:
        # Always remove the hook, even if the forward pass fails, so later
        # forward passes of this model are not silently intercepted.
        handle.remove()

    act = captured['features'][0].cpu()  # drop the batch dimension
    n_show = min(num_maps, act.shape[0])
    n_cols = max(1, -(-n_show // 2))  # ceil(n_show / 2) -> 4 columns for 8 maps

    plt.figure(figsize=(12, 6))
    for i in range(n_show):
        plt.subplot(2, n_cols, i + 1)
        plt.imshow(act[i].float().numpy(), cmap='viridis')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Show the feature maps that the hooked layer (layer1[0] by default) of the
# fine-tuned ResNet produces for the first training image.
sample_img = train_data[0][0]
visualize_conv_features(resnet, sample_img)

## 11. Custom Dataset: How to Use Your Own Data

Replace the dataset loading section above with code to read images from `/custom_dataset/train/class1/`, `/custom_dataset/val/class1/`, etc. using `datasets.ImageFolder`.

In [None]:
# Template for training on your own images. ImageFolder expects one
# sub-directory per class inside each split, e.g.
#   custom_dataset/train/<class_name>/<image files>
#   custom_dataset/val/<class_name>/<image files>
CUSTOM_TRAIN = '/content/drive/MyDrive/ConferenceProject/custom_dataset/train'
CUSTOM_VAL = '/content/drive/MyDrive/ConferenceProject/custom_dataset/val'

if os.path.isdir(CUSTOM_TRAIN) and os.path.isdir(CUSTOM_VAL):
    train_custom = datasets.ImageFolder(CUSTOM_TRAIN, transform=transform)
    val_custom = datasets.ImageFolder(CUSTOM_VAL, transform=transform)
    train_loader_custom = torch.utils.data.DataLoader(train_custom, batch_size=BATCH_SIZE, shuffle=True)
    val_loader_custom = torch.utils.data.DataLoader(val_custom, batch_size=BATCH_SIZE, shuffle=False)
else:
    # The paths above are placeholders; skip instead of aborting a
    # top-to-bottom run of the notebook with an ImageFolder error.
    print(f"Custom dataset not found at {CUSTOM_TRAIN} / {CUSTOM_VAL}; skipping this example.")

# To use it, pass train_loader_custom / val_loader_custom in place of
# train_loader / test_loader in the model sections above.

## 12. Using the Animal Image Classification Dataset

In [None]:
# Animal Image Classification Dataset (Kaggle), loaded with ImageFolder.
import kagglehub

# Download the dataset
path = kagglehub.dataset_download("borhanitrash/animal-image-classification-dataset")
print(f"Dataset downloaded to: {path}")

# Point the custom-dataset paths at the downloaded splits.
# NOTE(review): the 'animal_image_classification_dataset/{train,val}' layout is
# an assumption about the archive; if it is wrong, the error raised below lists
# what was actually downloaded.
CUSTOM_TRAIN = os.path.join(path, 'animal_image_classification_dataset', 'train')
CUSTOM_VAL = os.path.join(path, 'animal_image_classification_dataset', 'val')

missing_splits = [p for p in (CUSTOM_TRAIN, CUSTOM_VAL) if not os.path.isdir(p)]
if missing_splits:
    # ImageFolder cannot load these paths anyway; stop here with the download's
    # actual contents instead of printing "Error" and crashing inside ImageFolder.
    raise FileNotFoundError(
        f"Dataset split folder(s) not found: {missing_splits}. "
        f"Top-level contents of {path}: {sorted(os.listdir(path))}"
    )
print(f"Custom train path updated to: {CUSTOM_TRAIN}")
print(f"Custom validation path updated to: {CUSTOM_VAL}")

# Reload the custom datasets with the new paths
train_custom = datasets.ImageFolder(CUSTOM_TRAIN, transform=transform)
val_custom = datasets.ImageFolder(CUSTOM_VAL, transform=transform)

train_loader_custom = torch.utils.data.DataLoader(train_custom, batch_size=BATCH_SIZE, shuffle=True)
val_loader_custom = torch.utils.data.DataLoader(val_custom, batch_size=BATCH_SIZE, shuffle=False)

print(f"Number of training images in custom dataset: {len(train_custom)}")
print(f"Number of validation images in custom dataset: {len(val_custom)}")

# Class names for this dataset, kept separate from the CIFAR-10 CLASSES list;
# use these when labelling evaluation outputs for the custom data.
CLASSES_CUSTOM = train_custom.classes
print(f"Classes in custom dataset: {CLASSES_CUSTOM}")

## 13. Dataset Download (ImageNet from Kaggle)

In [None]:
import kagglehub

# Download latest version of the Kaggle ImageNet mirror.
path = kagglehub.dataset_download("mayurmadnani/imagenet-dataset")
print("Path to dataset files:", path)

# kagglehub has already placed the files on local disk, so no separate wget is
# needed (passing the literal string "path" as a URL could never download
# anything). List the top level so IMAGENET_TRAIN_DIR / IMAGENET_VAL_DIR in the
# next cell can be pointed at the right sub-folders.
for entry in sorted(os.listdir(path)):
    print(" -", os.path.join(path, entry))

In [None]:
# ImageFolder-style ImageNet splits (one sub-folder per class). Replace these
# placeholder paths with the real locations.
IMAGENET_TRAIN_DIR = '/content/drive/MyDrive/path/to/imagenet/train'
IMAGENET_VAL_DIR = '/content/drive/MyDrive/path/to/imagenet/val'

# Verify the paths exist
if not os.path.exists(IMAGENET_TRAIN_DIR):
    print(f"Error: ImageNet training directory not found at {IMAGENET_TRAIN_DIR}")
if not os.path.exists(IMAGENET_VAL_DIR):
    print(f"Error: ImageNet validation directory not found at {IMAGENET_VAL_DIR}")

# Indexing the full ImageNet can take a significant amount of time and memory;
# num_workers and other DataLoader settings may need tuning for the Colab instance.
try:
    imagenet_train_data = datasets.ImageFolder(IMAGENET_TRAIN_DIR, transform=transform)
    imagenet_val_data = datasets.ImageFolder(IMAGENET_VAL_DIR, transform=transform)
except OSError as e:
    # ImageFolder raises FileNotFoundError (an OSError) for a missing directory or
    # one without class folders/images. Catching only that keeps genuine bugs
    # from being misreported as a path problem.
    print(f"An error occurred while loading the ImageNet dataset: {e}")
    print("Please ensure the paths are correct and the dataset is organized in the specified structure.")
else:
    imagenet_train_loader = torch.utils.data.DataLoader(imagenet_train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    imagenet_val_loader = torch.utils.data.DataLoader(imagenet_val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # ImageNet class names (the CIFAR-10 CLASSES list is left untouched).
    IMAGENET_CLASSES = imagenet_train_data.classes
    print(f"Number of training images in ImageNet: {len(imagenet_train_data)}")
    print(f"Number of validation images in ImageNet: {len(imagenet_val_data)}")
    # Printing every class name would flood the output; show the count and a sample.
    print(f"Number of classes in ImageNet: {len(IMAGENET_CLASSES)} (first 5: {IMAGENET_CLASSES[:5]})")

## 14. Vision Transformer (ViT) - Training on ImageNet

In [None]:
# ViT-B/16 on ImageNet. Model size, epochs and learning rate may need tuning for
# the available hardware; for long runs consider checkpointing more often than
# once per epoch inside train_model.
# NOTE(review): timm's pretrained weights for this architecture may already have
# been trained on ImageNet, and the val split is also used for checkpoint
# selection; confirm before interpreting these numbers.
vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=len(IMAGENET_CLASSES))
vit = vit.to(DEVICE)
best_vit_weights = train_model(vit, imagenet_train_loader, imagenet_val_loader, num_epochs=10, model_name='vit_imagenet')

# Reload the best checkpoint and evaluate (a full ImageNet pass can be slow).
vit_imagenet_ckpt = os.path.join(RESULTS_DIR, 'weights', 'vit_imagenet_best.pt')
vit.load_state_dict(torch.load(vit_imagenet_ckpt))
evaluate_and_export(vit, imagenet_val_loader, 'vit_imagenet')

# Attention-map visualisation is optional; see the timm / transformers ViT docs.

## 15. Swin Transformer - Training on ImageNet

In [None]:
# Swin-B on ImageNet. Model size, epochs and learning rate may need tuning for
# the available hardware; for long runs consider more frequent checkpointing.
swin = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=len(IMAGENET_CLASSES))
swin = swin.to(DEVICE)

# Dedicated loaders with batch size 32 (instead of BATCH_SIZE) to ease VRAM pressure.
train_loader_swin = torch.utils.data.DataLoader(
    imagenet_train_data, batch_size=32, shuffle=True, num_workers=2,
)
val_loader_swin = torch.utils.data.DataLoader(
    imagenet_val_data, batch_size=32, shuffle=False, num_workers=2,
)

best_swin_weights = train_model(swin, train_loader_swin, val_loader_swin, num_epochs=10, model_name='swin_imagenet')

# Reload the best checkpoint and evaluate (a full ImageNet pass can be slow).
swin_imagenet_ckpt = os.path.join(RESULTS_DIR, 'weights', 'swin_imagenet_best.pt')
swin.load_state_dict(torch.load(swin_imagenet_ckpt))
evaluate_and_export(swin, val_loader_swin, 'swin_imagenet')

## 16. ConvNet (ResNet as Example) - Training on ImageNet

In [None]:
# ResNet-18 on ImageNet (a larger ResNet may perform better). The backbone is
# initialised with torchvision's ImageNet weights and the head is replaced by a
# fresh linear layer sized to IMAGENET_CLASSES.
# `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13);
# IMAGENET1K_V1 is the checkpoint that `pretrained=True` loaded.
resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
resnet.fc = nn.Linear(resnet.fc.in_features, len(IMAGENET_CLASSES))
resnet = resnet.to(DEVICE)

# Epochs / learning rate may need tuning for ImageNet; for long training
# consider checkpointing more often inside train_model.
best_resnet_weights = train_model(resnet, imagenet_train_loader, imagenet_val_loader, num_epochs=10, model_name='resnet_imagenet')

# Reload the best checkpoint written by train_model.
resnet.load_state_dict(torch.load(os.path.join(RESULTS_DIR, 'weights', 'resnet_imagenet_best.pt')))

# Evaluation on the full ImageNet validation set might take time.
evaluate_and_export(resnet, imagenet_val_loader, 'resnet_imagenet')

## 17. Compressed Vision Transformer (TinyViT/Slim ViT/Any Efficient Transformer) - Training on ImageNet

In [None]:
# Compressed transformer baseline on ImageNet: timm's ViT-Tiny/16 (swap in any
# lighter ViT). Epochs / learning rate may need tuning; for long runs consider
# more frequent checkpointing inside train_model.
compressed_vit = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=len(IMAGENET_CLASSES))
compressed_vit = compressed_vit.to(DEVICE)
best_compressed_weights = train_model(
    compressed_vit, imagenet_train_loader, imagenet_val_loader,
    num_epochs=10, model_name='compressed_vit_imagenet',
)

# Reload the best checkpoint and evaluate (a full ImageNet pass can be slow).
compressed_imagenet_ckpt = os.path.join(RESULTS_DIR, 'weights', 'compressed_vit_imagenet_best.pt')
compressed_vit.load_state_dict(torch.load(compressed_imagenet_ckpt))
evaluate_and_export(compressed_vit, imagenet_val_loader, 'compressed_vit_imagenet')

## 18. Visualizing Feature Maps (ConvNet Example) - After ImageNet Training

In [None]:
# Visualise ResNet feature maps for the first ImageNet validation image, but only
# if the ImageNet validation split was loaded earlier and is non-empty.
imagenet_ready = 'imagenet_val_data' in locals() and len(imagenet_val_data) > 0
if not imagenet_ready:
    print("ImageNet validation data not available or empty. Cannot visualize feature maps.")
else:
    sample_img_imagenet = imagenet_val_data[0][0]
    print("Visualizing feature maps for the ResNet model trained on ImageNet...")
    visualize_conv_features(resnet, sample_img_imagenet)

# Task
Write Python code to train a Swin Transformer model on a TPU, optimized for the TPU architecture. The code should include TPU environment setup, data parallelism using `pytorch/xla`, mixed precision training with `bfloat16`, and adaptations for XLA compatibility.

## TPU environment setup

### Subtask:
Add code to detect and initialize the TPU environment using `pytorch/xla`.


**Reasoning**:
The subtask requires setting up the TPU environment and initializing it using `pytorch/xla`. This involves importing necessary modules, setting environment variables, and defining a main function to be launched on TPU cores. This can be done in a single code block.



In [None]:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import os

# Per-core entry point executed by xmp.spawn on every TPU core.
def train_tpu(index, args):
    """Report the XLA device assigned to this spawned process.

    Args:
        index: process index supplied by ``xmp.spawn``.
        args: extra arguments forwarded by ``xmp.spawn`` (unused).
    """
    # xm.xla_device() is torch_xla's accessor for this process's TPU device.
    device = xm.xla_device()
    print(f"TPU is available. Using device: {device}")

# Main execution block
if __name__ == '__main__':
    # Store float tensors as bfloat16 on the TPU and set the tensor allocation size.
    os.environ['XLA_USE_BF16'] = '1'
    os.environ['XLA_TENSOR_ALLOC_MAXSIZE'] = '100000000'

    # nprocs=None lets the PJRT runtime start one process per local TPU device;
    # recent torch_xla releases accept only None or 1 here. 'fork' is needed in a
    # notebook because functions defined in cells cannot be re-imported by
    # 'spawn'ed children (avoid touching the XLA device in this process first).
    xmp.spawn(train_tpu, args=({},), nprocs=None, start_method='fork')


## Data parallelism with XLA

### Subtask:
Modify the data loaders and training loop to distribute the data and model across TPU cores using XLA utilities.


**Reasoning**:
The subtask requires modifying the existing `train_model` and `evaluate_and_export` functions to work with XLA devices and then updating the main execution block to use these modified functions within `xmp.spawn`. This involves adding a device argument to the functions, moving data to the device, replacing `optimizer.step()` with `xm.optimizer_step`, and moving results back to the CPU for metric calculations. I will implement these changes in the functions first and then show how to use them in the `xmp.spawn` block.



In [None]:
import torch_xla.core.xla_model as xm
# parallel_loader provides MpDeviceLoader, which feeds batches to each TPU core;
# there is no torch_xla.distributed.parallel module.
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def train_model_xla(model, train_loader, val_loader, device, num_epochs=10, model_name='model'):
    """Train ``model`` on one TPU core as part of a data-parallel ``xmp.spawn`` run.

    Must be called by every spawned process. Each core iterates its own
    ``train_loader``/``val_loader`` (give them a ``DistributedSampler`` so the
    cores see disjoint shards). ``xm.optimizer_step`` all-reduces gradients so
    the replicas keep identical weights, and per-epoch loss/accuracy counters are
    summed over all cores with ``xm.mesh_reduce``. Every core therefore records
    the same metrics and makes the same checkpoint decision; no collective call
    sits inside a master-only branch.

    The master core writes the best weights (by global validation accuracy) to
    ``RESULTS_DIR/weights/{model_name}_best.pt`` and the loss/accuracy plot to
    ``RESULTS_DIR/graphs/{model_name}_loss_acc.png``.

    Args:
        model: network already moved to ``device``.
        train_loader: per-core DataLoader of ``(images, labels)`` for training.
        val_loader: per-core DataLoader of ``(images, labels)`` for validation.
        device: this process's XLA device.
        num_epochs: number of training epochs.
        model_name: prefix for the checkpoint and plot file names.

    Returns:
        A CPU ``state_dict`` snapshot of the best epoch (identical on every core),
        or ``None`` if validation accuracy never exceeded 0.
    """
    from collections import OrderedDict

    import torch_xla.distributed.parallel_loader as pl

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()
    # Divide the learning rate by 10 every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    is_master = xm.is_master_ordinal()
    ckpt_path = os.path.join(RESULTS_DIR, 'weights', f'{model_name}_best.pt')

    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    best_val_acc = 0.0
    best_weights = None

    def _record_loss(store, loss_value):
        # Step closure: runs after the XLA step has executed, so .item() reads a
        # materialised value instead of forcing an extra device sync per step.
        store.append(loss_value.item())

    def _global_loss_and_acc(tag, step_losses, correct, total):
        # Sum [loss_sum, correct, total, n_batches] element-wise over all cores
        # and return (mean per-batch loss, accuracy) for the whole epoch.
        local = np.array([sum(step_losses), correct.item(), total, len(step_losses)], dtype=np.float64)
        g_loss, g_correct, g_total, g_batches = xm.mesh_reduce(
            tag, local, lambda parts: np.sum(parts, axis=0))
        return g_loss / max(g_batches, 1.0), g_correct / max(g_total, 1.0)

    for epoch in range(num_epochs):
        # A DistributedSampler must be told the epoch to reshuffle every epoch.
        sampler = getattr(train_loader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        # ---- training ----
        model.train()
        step_losses = []
        correct = torch.zeros((), dtype=torch.long, device=device)
        total = 0
        # MpDeviceLoader delivers every batch already on `device`.
        for imgs, labels in tqdm(pl.MpDeviceLoader(train_loader, device), disable=not is_master):
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            # All-reduce gradients across cores, then step the optimizer.
            xm.optimizer_step(optimizer)
            xm.add_step_closure(_record_loss, args=(step_losses, loss.detach()))
            # The accuracy counter stays on the device (no per-step host sync).
            correct += (outputs.argmax(dim=1) == labels).sum()
            total += labels.size(0)
        xm.mark_step()  # flush the final step so every step closure has run
        epoch_loss, epoch_acc = _global_loss_and_acc(
            f'{model_name}_train_{epoch}', step_losses, correct, total)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        # ---- validation ----
        model.eval()
        val_step_losses = []
        val_correct = torch.zeros((), dtype=torch.long, device=device)
        val_total = 0
        with torch.no_grad():
            for imgs, labels in pl.MpDeviceLoader(val_loader, device):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
                xm.add_step_closure(_record_loss, args=(val_step_losses, loss))
                val_correct += (outputs.argmax(dim=1) == labels).sum()
                val_total += labels.size(0)
        xm.mark_step()
        epoch_loss, epoch_acc = _global_loss_and_acc(
            f'{model_name}_val_{epoch}', val_step_losses, val_correct, val_total)
        val_loss.append(epoch_loss)
        val_acc.append(epoch_acc)

        # All cores see the same reduced val_acc, so they all take this branch.
        if val_acc[-1] > best_val_acc:
            best_val_acc = val_acc[-1]
            # state_dict() aliases the live parameters; snapshot them on the CPU.
            live = model.state_dict()
            best_weights = OrderedDict((k, v.detach().to('cpu')) for k, v in live.items())
            best_weights._metadata = getattr(live, '_metadata', None)
            if is_master:
                torch.save(best_weights, ckpt_path)

        scheduler.step()
        if is_master:
            print(f"Epoch {epoch+1}: Train Acc={train_acc[-1]:.3f} Val Acc={val_acc[-1]:.3f}")

    # Plot (only on master process)
    if is_master:
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.plot(train_loss, label='Train Loss')
        plt.plot(val_loss, label='Val Loss')
        plt.legend(); plt.title('Loss')
        plt.subplot(1, 2, 2)
        plt.plot(train_acc, label='Train Acc')
        plt.plot(val_acc, label='Val Acc')
        plt.legend(); plt.title('Accuracy')
        plt.savefig(os.path.join(RESULTS_DIR, 'graphs', f'{model_name}_loss_acc.png'))
        plt.show()

    return best_weights


def evaluate_and_export_xla(model, data_loader, device, model_name, classes):
    """Evaluate ``model`` across all TPU cores and export the combined results.

    Must be called by every spawned process. Each core runs inference on its
    ``data_loader``. Labels, predictions and softmax probabilities are gathered
    from all cores with ``xm.mesh_reduce``. The master core then:
    - prints accuracy and macro precision/recall/F1,
    - writes the probability matrix to ``RESULTS_DIR/{model_name}_prob_matrix.xlsx``,
    - saves the confusion matrix to ``RESULTS_DIR/graphs/{model_name}_confusion.png``.

    NOTE(review): a DistributedSampler pads shards by repeating samples, so a few
    images may be counted more than once in the gathered results.

    Args:
        model: network already on ``device``.
        data_loader: per-core DataLoader of ``(images, labels)``.
        device: this process's XLA device.
        model_name: prefix for the exported file names.
        classes: class names indexed by label id.
    """
    import torch_xla.distributed.parallel_loader as pl

    model.eval()
    n_classes = len(classes)
    label_chunks, pred_chunks, prob_chunks = [], [], []
    with torch.no_grad():
        for imgs, labels in pl.MpDeviceLoader(data_loader, device):
            outputs = model(imgs)
            probs = torch.softmax(outputs.float(), dim=1)
            label_chunks.append(labels.cpu().numpy())
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
            prob_chunks.append(probs.cpu().numpy())

    # Empty but correctly shaped arrays let a core with no batches still take part
    # in the collective gathers below.
    local_labels = np.concatenate(label_chunks) if label_chunks else np.empty(0, dtype=np.int64)
    local_preds = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.int64)
    local_probs = np.concatenate(prob_chunks) if prob_chunks else np.empty((0, n_classes), dtype=np.float32)

    # mesh_reduce passes the list of per-core arrays to np.concatenate, which
    # already yields the combined array; concatenating again would fail.
    all_labels = xm.mesh_reduce('eval_labels', local_labels, np.concatenate)
    all_preds = xm.mesh_reduce('eval_preds', local_preds, np.concatenate)
    all_probs = xm.mesh_reduce('eval_probs', local_probs, np.concatenate)

    if not xm.is_master_ordinal():
        return
    if len(all_labels) == 0:
        print("No evaluation samples were produced; skipping metrics and exports.")
        return

    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    print(f"Accuracy={acc:.3f} Precision={prec:.3f} Recall={recall:.3f} F1={f1:.3f}")

    # Export probability matrix to Excel (rows = samples, columns = classes).
    df = pd.DataFrame(all_probs, columns=classes)
    df.to_excel(os.path.join(RESULTS_DIR, f'{model_name}_prob_matrix.xlsx'), index=False)

    # Fixed label set keeps the matrix n_classes x n_classes, aligned with the
    # tick labels. Annotations are only legible for small label sets.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(n_classes)))
    small = n_classes <= 30
    plt.figure(figsize=(8, 6) if small else (20, 16))
    sns.heatmap(cm, annot=small, fmt='d', cmap='Blues',
                xticklabels=classes if small else 'auto',
                yticklabels=classes if small else 'auto')
    plt.xlabel('Predicted'); plt.ylabel('True'); plt.title('Confusion Matrix')
    plt.savefig(os.path.join(RESULTS_DIR, 'graphs', f'{model_name}_confusion.png'))
    plt.show()


# Main execution block with xmp.spawn
def _mp_fn(index, flags):
    """Per-core ``xmp.spawn`` entry point: fine-tune ViT-B/16 on ImageNet, then evaluate it.

    Reads the globals ``imagenet_train_data``, ``imagenet_val_data``,
    ``IMAGENET_CLASSES``, ``BATCH_SIZE`` and ``RESULTS_DIR``. They must exist in
    the notebook before spawning; forked children inherit them.

    Args:
        index: process index supplied by ``xmp.spawn``.
        flags: extra options forwarded by ``xmp.spawn`` (currently unused).
    """
    import torch_xla.runtime as xr

    # xm.xla_device() is torch_xla's accessor for this process's TPU device.
    device = xm.xla_device()
    print(f"TPU is available. Using device: {device}")
    world_size = xr.world_size()
    rank = xr.global_ordinal()

    # Let the master fetch the pretrained weights first, so the other cores read
    # them from the local cache instead of all downloading concurrently.
    if not xm.is_master_ordinal():
        xm.rendezvous('download_pretrained_vit')
    vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=len(IMAGENET_CLASSES)).to(device)
    if xm.is_master_ordinal():
        xm.rendezvous('download_pretrained_vit')

    # Shard both splits so every core processes a distinct 1/world_size slice;
    # without a sampler each core would iterate the whole dataset.
    train_sampler = torch.utils.data.DistributedSampler(
        imagenet_train_data, num_replicas=world_size, rank=rank, shuffle=True)
    val_sampler = torch.utils.data.DistributedSampler(
        imagenet_val_data, num_replicas=world_size, rank=rank, shuffle=False)
    # drop_last keeps every training batch the same shape, avoiding an XLA
    # recompilation for a smaller final batch.
    imagenet_train_loader_xla = torch.utils.data.DataLoader(
        imagenet_train_data, batch_size=BATCH_SIZE, sampler=train_sampler,
        num_workers=2, drop_last=True)
    imagenet_val_loader_xla = torch.utils.data.DataLoader(
        imagenet_val_data, batch_size=BATCH_SIZE, sampler=val_sampler, num_workers=2)

    # Train the ViT model on ImageNet using the XLA-compatible training function
    best_vit_weights = train_model_xla(vit, imagenet_train_loader_xla, imagenet_val_loader_xla, device, num_epochs=1, model_name='vit_imagenet_xla')  # Reduced epochs for testing

    # Every core must evaluate the same weights, so all cores (not only the
    # master) restore the best checkpoint.
    if best_vit_weights is not None:
        vit.load_state_dict(best_vit_weights)
    else:
        ckpt_path = os.path.join(RESULTS_DIR, 'weights', 'vit_imagenet_xla_best.pt')
        try:
            vit.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
        except FileNotFoundError:
            if xm.is_master_ordinal():
                print("Best weights not found; evaluating the final training weights instead.")

    # Evaluate the ViT model on ImageNet using the XLA-compatible evaluation function
    evaluate_and_export_xla(vit, imagenet_val_loader_xla, device, 'vit_imagenet_xla', IMAGENET_CLASSES)

# The following block should be executed to run the training on TPU
# if __name__ == '__main__':
#     os.environ['XLA_USE_BF16'] = '1'
#     os.environ['XLA_TENSOR_ALLOC_MAXSIZE'] = '100000000'
#
#     # You would typically load your dataset here before spawning processes
#     # For example:
#     # IMAGENET_TRAIN_DIR = '/content/drive/MyDrive/path/to/imagenet/train'
#     # IMAGENET_VAL_DIR = '/content/drive/MyDrive/path/to/imagenet/val'
#     # imagenet_train_data = datasets.ImageFolder(IMAGENET_TRAIN_DIR, transform=transform)
#     # imagenet_val_data = datasets.ImageFolder(IMAGENET_VAL_DIR, transform=transform)
#     # IMAGENET_CLASSES = imagenet_train_data.classes
#     # BATCH_SIZE = 64 # Define your batch size
#     # RESULTS_DIR = '/content/drive/MyDrive/ConferenceProject/' # Define your results directory
#
#     xmp.spawn(_mp_fn, args=({},), nprocs=None, start_method='fork')  # nprocs=None uses every local TPU device; 'fork' lets children see notebook-defined functions


## Mixed precision training

### Subtask:
Implement mixed precision training using `bfloat16` for improved performance on TPUs.


**Reasoning**:
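Implement mixed precision training using bfloat16 by wrapping the forward passes in the training and evaluation loops of `train_model_xla` and `evaluate_and_export_xla` in `torch_xla.amp.autocast(device, dtype=torch.bfloat16)`. The CUDA autocast context does not apply to XLA devices. Because bfloat16 has the same exponent range as float32, no gradient scaling is needed, and `xm.optimizer_step` still performs the cross-core gradient all-reduce.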



In [None]:
import torch_xla.amp as xamp # Import for XLA AMP

def train_model_xla(model, train_loader, val_loader, device, num_epochs=10, model_name='model'):
    """bfloat16 mixed-precision variant of the per-core TPU training loop.

    Must be called by every process started by ``xmp.spawn``. Forward passes
    and losses run under ``torch_xla.amp.autocast(device, dtype=torch.bfloat16)``.
    No GradScaler is used: bfloat16 shares float32's exponent range, and
    ``xm.optimizer_step`` is needed anyway to all-reduce gradients across cores
    (``GradScaler.step`` would skip that). Give the loaders a
    ``DistributedSampler`` so each core trains on its own shard. Per-epoch
    loss/accuracy are summed over all cores with ``xm.mesh_reduce``, so every
    core records the same metrics and makes the same checkpoint decision.

    The master core writes the best weights (by global validation accuracy) to
    ``RESULTS_DIR/weights/{model_name}_best.pt`` and the loss/accuracy plot to
    ``RESULTS_DIR/graphs/{model_name}_loss_acc.png``.

    Args:
        model: network already moved to ``device``.
        train_loader: per-core DataLoader of ``(images, labels)`` for training.
        val_loader: per-core DataLoader of ``(images, labels)`` for validation.
        device: this process's XLA device.
        num_epochs: number of training epochs.
        model_name: prefix for the checkpoint and plot file names.

    Returns:
        A CPU ``state_dict`` snapshot of the best epoch (identical on every core),
        or ``None`` if validation accuracy never exceeded 0.
    """
    from collections import OrderedDict

    import torch_xla.distributed.parallel_loader as pl

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    criterion = nn.CrossEntropyLoss()
    # Divide the learning rate by 10 every 5 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    is_master = xm.is_master_ordinal()
    ckpt_path = os.path.join(RESULTS_DIR, 'weights', f'{model_name}_best.pt')

    train_loss, val_loss = [], []
    train_acc, val_acc = [], []
    best_val_acc = 0.0
    best_weights = None

    def _record_loss(store, loss_value):
        # Step closure: runs after the XLA step has executed, so .item() reads a
        # materialised value instead of forcing an extra device sync per step.
        store.append(float(loss_value.float().item()))

    def _global_loss_and_acc(tag, step_losses, correct, total):
        # Sum [loss_sum, correct, total, n_batches] element-wise over all cores
        # and return (mean per-batch loss, accuracy) for the whole epoch.
        local = np.array([sum(step_losses), correct.item(), total, len(step_losses)], dtype=np.float64)
        g_loss, g_correct, g_total, g_batches = xm.mesh_reduce(
            tag, local, lambda parts: np.sum(parts, axis=0))
        return g_loss / max(g_batches, 1.0), g_correct / max(g_total, 1.0)

    for epoch in range(num_epochs):
        # A DistributedSampler must be told the epoch to reshuffle every epoch.
        sampler = getattr(train_loader, 'sampler', None)
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)

        # ---- training ----
        model.train()
        step_losses = []
        correct = torch.zeros((), dtype=torch.long, device=device)
        total = 0
        for imgs, labels in tqdm(pl.MpDeviceLoader(train_loader, device), disable=not is_master):
            optimizer.zero_grad()
            # torch_xla's autocast takes the target device as its first argument.
            with xamp.autocast(device, dtype=torch.bfloat16):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            loss.backward()
            # All-reduce gradients across cores, then step the optimizer.
            xm.optimizer_step(optimizer)
            xm.add_step_closure(_record_loss, args=(step_losses, loss.detach()))
            correct += (outputs.argmax(dim=1) == labels).sum()
            total += labels.size(0)
        xm.mark_step()  # flush the final step so every step closure has run
        epoch_loss, epoch_acc = _global_loss_and_acc(
            f'{model_name}_train_{epoch}', step_losses, correct, total)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc)

        # ---- validation ----
        # The per-batch accumulation now sits inside the loop; previously only the
        # last batch's loss and predictions were counted.
        model.eval()
        val_step_losses = []
        val_correct = torch.zeros((), dtype=torch.long, device=device)
        val_total = 0
        with torch.no_grad():
            for imgs, labels in pl.MpDeviceLoader(val_loader, device):
                with xamp.autocast(device, dtype=torch.bfloat16):
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)
                xm.add_step_closure(_record_loss, args=(val_step_losses, loss))
                val_correct += (outputs.argmax(dim=1) == labels).sum()
                val_total += labels.size(0)
        xm.mark_step()
        epoch_loss, epoch_acc = _global_loss_and_acc(
            f'{model_name}_val_{epoch}', val_step_losses, val_correct, val_total)
        val_loss.append(epoch_loss)
        val_acc.append(epoch_acc)

        # All cores see the same reduced val_acc, so they all take this branch.
        if val_acc[-1] > best_val_acc:
            best_val_acc = val_acc[-1]
            # state_dict() aliases the live parameters; snapshot them on the CPU.
            live = model.state_dict()
            best_weights = OrderedDict((k, v.detach().to('cpu')) for k, v in live.items())
            best_weights._metadata = getattr(live, '_metadata', None)
            if is_master:
                torch.save(best_weights, ckpt_path)

        scheduler.step()
        if is_master:
            print(f"Epoch {epoch+1}: Train Acc={train_acc[-1]:.3f} Val Acc={val_acc[-1]:.3f}")

    if is_master:
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.plot(train_loss, label='Train Loss')
        plt.plot(val_loss, label='Val Loss')
        plt.legend(); plt.title('Loss')
        plt.subplot(1, 2, 2)
        plt.plot(train_acc, label='Train Acc')
        plt.plot(val_acc, label='Val Acc')
        plt.legend(); plt.title('Accuracy')
        plt.savefig(os.path.join(RESULTS_DIR, 'graphs', f'{model_name}_loss_acc.png'))
        plt.show()

    return best_weights


def evaluate_and_export_xla(model, data_loader, device, model_name, classes):
    """Evaluate ``model`` on an XLA device and export metrics and plots.

    Runs inference over ``data_loader`` under bfloat16 autocast with gradients
    disabled. It gathers labels, predictions and softmax probabilities from
    every ordinal via ``xm.mesh_reduce``. On the master ordinal only, it:

    * prints accuracy and macro-averaged precision / recall / F1,
    * writes the per-sample probability matrix to
      ``RESULTS_DIR/<model_name>_prob_matrix.xlsx``,
    * saves the confusion-matrix heatmap to
      ``RESULTS_DIR/graphs/<model_name>_confusion.png`` and shows it.

    Args:
        model: classifier; argmax/softmax are taken over dim 1 of its output.
        data_loader: iterable yielding ``(imgs, labels)`` batches.
        device: XLA device to run inference on.
        model_name: prefix used for the exported file names.
        classes: class names. They are used as DataFrame columns and heatmap
            tick labels, so their count must equal the model's output width.

    Returns:
        None. ``mesh_reduce`` is a collective, so every ordinal must call
        this function.
    """
    # The XLA device loader lives in torch_xla's parallel_loader module;
    # xla_multiprocessing (``xmp``) provides no data-loader wrapper.
    import torch_xla.distributed.parallel_loader as pl

    model.eval()
    label_batches, pred_batches, prob_batches = [], [], []
    device_loader = pl.MpDeviceLoader(data_loader, device)

    with torch.no_grad():
        # torch_xla.amp.autocast takes the target device as its first argument.
        with xamp.autocast(device, dtype=torch.bfloat16):
            for imgs, labels in device_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                # Softmax on float32 logits; bfloat16 tensors cannot be turned
                # into NumPy arrays, hence the extra .float() before .numpy().
                probs = torch.softmax(outputs.float(), dim=1)
                preds = outputs.argmax(dim=1)
                label_batches.append(labels.cpu().numpy())
                pred_batches.append(preds.cpu().numpy())
                prob_batches.append(probs.float().cpu().numpy())

    # Per-ordinal arrays. The explicit empty shapes keep an ordinal that saw no
    # batches compatible with the others when the results are concatenated.
    num_classes = len(classes)
    if label_batches:
        local_labels = np.concatenate(label_batches)
        local_preds = np.concatenate(pred_batches)
        local_probs = np.concatenate(prob_batches, axis=0)
    else:
        local_labels = np.empty((0,), dtype=np.int64)
        local_preds = np.empty((0,), dtype=np.int64)
        local_probs = np.empty((0, num_classes), dtype=np.float32)

    # mesh_reduce hands the list of per-ordinal arrays to np.concatenate, so
    # its result is already the gathered array. Concatenating it again fails
    # on the 1-D label/pred arrays and flattens the 2-D probability matrix.
    # NOTE(review): if data_loader is not sharded per ordinal (no distributed
    # sampler), every ordinal sees the full dataset and the gathered arrays
    # contain duplicates -- confirm how the loader is built.
    all_labels = xm.mesh_reduce('eval_labels', local_labels, np.concatenate)
    all_preds = xm.mesh_reduce('eval_preds', local_preds, np.concatenate)
    all_probs = xm.mesh_reduce('eval_probs', local_probs, np.concatenate)

    if xm.is_master_ordinal():
        acc = accuracy_score(all_labels, all_preds)
        prec = precision_score(all_labels, all_preds, average='macro')
        recall = recall_score(all_labels, all_preds, average='macro')
        f1 = f1_score(all_labels, all_preds, average='macro')
        print(f"Accuracy={acc:.3f} Precision={prec:.3f} Recall={recall:.3f} F1={f1:.3f}")

        df = pd.DataFrame(all_probs, columns=classes)
        df.to_excel(os.path.join(RESULTS_DIR, f'{model_name}_prob_matrix.xlsx'), index=False)

        cm = confusion_matrix(all_labels, all_preds)
        plt.figure(figsize=(8,6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
        plt.xlabel('Predicted'); plt.ylabel('True'); plt.title('Confusion Matrix')
        plt.savefig(os.path.join(RESULTS_DIR, 'graphs', f'{model_name}_confusion.png'))
        plt.show()

# Uncomment the following block to run training and evaluation on a TPU (spawned across 8 cores via xmp.spawn).
# def _mp_fn(index, flags):
#     device = xm.xla_device()
#     print(f"TPU is available. Using device: {device}")
#
#     # Load dataset and define loaders and classes here
#     # ...
#
#     vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=len(IMAGENET_CLASSES)).to(device)
#
#     imagenet_train_loader_xla = torch.utils.data.DataLoader(imagenet_train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
#     imagenet_val_loader_xla = torch.utils.data.DataLoader(imagenet_val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
#
#     best_vit_weights = train_model_xla(vit, imagenet_train_loader_xla, imagenet_val_loader_xla, device, num_epochs=1, model_name='vit_imagenet_xla')
#
#     if best_vit_weights is not None:
#         vit.load_state_dict(best_vit_weights)
#     elif xm.is_master_ordinal():
#          try:
#              vit.load_state_dict(torch.load(os.path.join(RESULTS_DIR, 'weights', 'vit_imagenet_xla_best.pt')))
#          except FileNotFoundError:
#              print("Best weights not found on master and broadcasting failed.")
#
#     evaluate_and_export_xla(vit, imagenet_val_loader_xla, device, 'vit_imagenet_xla', IMAGENET_CLASSES)
#
# # if __name__ == '__main__':
# #     os.environ['XLA_USE_BF16'] = '1'
# #     os.environ['XLA_TENSOR_ALLOC_MAXSIZE'] = '100000000'
# #     xmp.spawn(_mp_fn, args=({},), nprocs=8)