In [None]:
!pip install -q torch_xla[tpu] -f https://storage.googleapis.com/libtorch-xla-releases/wheels/tpuvm/colab.html

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# from sklearn.cluster import KMeans
from tqdm import tqdm
from torch.utils.data import Subset

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

"""Dataset part"""

# Per-channel normalisation applied to every CIFAR-10 image.
# NOTE(review): (0.2023, 0.1994, 0.2010) is a widely copied std triple; the per-channel
# pixel std of the CIFAR-10 training set is closer to (0.247, 0.243, 0.261) — confirm
# which one is intended (harmless either way for training, but affects comparability).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# Large-batch loader used by the cells below only to compute dataset statistics.
# It must NOT shuffle: q_values is filled position-by-position from this loader and its
# indices are later passed to Subset(train_dataset, ...), so iteration order has to be
# the dataset's own order.
train_loader = DataLoader(train_dataset, batch_size=10000, shuffle=False)
val_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)


# Importance score per training sample (same formula as lightweight coresets):
#   q(x) = 1/2 * 1/N + 1/2 * ||x - mu||^2 / sum_x' ||x' - mu||^2
# where mu is the mean image and N the number of training samples.
#
# q_values is written position-by-position and later indexed with train_dataset
# indices, so iterate the dataset in its natural order regardless of how
# train_loader itself was configured (a shuffled loader would scramble the mapping).
stats_loader = DataLoader(train_loader.dataset, batch_size=train_loader.batch_size, shuffle=False)

# Pass 1: mean image mu, shape (C, H, W).
mean_image = 0.0
total_samples = 0
for inputs, _ in tqdm(stats_loader, desc="Loading CIFAR-10"):
    mean_image += inputs.sum(dim=0)  # sum over the batch dimension
    total_samples += inputs.size(0)
assert total_samples > 0, "training set is empty"

mean_image /= total_samples
mu_flat = mean_image.view(1, -1)

# Pass 2: squared distance of every sample to mu, stored in dataset order.
# These distances give both the normaliser (denom) and q(x), so no third pass is needed.
sq_dists = torch.empty(total_samples)
start_idx = 0
for inputs, _ in tqdm(stats_loader, desc="Computing squared distances to the mean"):
    end_idx = start_idx + inputs.size(0)
    sq_dists[start_idx:end_idx] = ((inputs.flatten(start_dim=1) - mu_flat) ** 2).sum(dim=1)
    start_idx = end_idx

# Accumulate in float64 to limit rounding error over the whole dataset
denom = sq_dists.double().sum().item()

# Our q(x); q_values[i] belongs to train_dataset[i]
q_values = 0.5 * (1 / total_samples) + 0.5 * (sq_dists / denom)
assert abs(q_values.double().sum().item() - 1.0) < 1e-4, "q(x) should sum to 1"

Loading CIFAR-10: 100%|██████████| 5/5 [00:09<00:00,  1.99s/it]
Computing denominator: 100%|██████████| 5/5 [00:10<00:00,  2.13s/it]
Computing q(x) for all dataset indices: 100%|██████████| 5/5 [00:10<00:00,  2.12s/it]


In [None]:
# Sampling distribution over training indices: proportional to 1 / q(x).
# NOTE(review): in the lightweight-coreset construction that q(x) follows, points are
# drawn with probability proportional to q(x) itself and weighted by 1 / (m * q(x)).
# Sampling proportional to 1 / q(x) instead favours images close to the mean image,
# and no per-sample weights are applied during training — confirm this is intended.
sampling_probs = 1.0 / q_values
sampling_probs /= sampling_probs.sum()  # normalize to sum to 1

m = 20000  # coreset size; TODO: replace this fixed value with the general rule
assert m <= sampling_probs.numel(), "cannot draw more samples than exist without replacement"

# Fixed-seed generator so the same coreset is drawn on every Restart & Run All
sampling_generator = torch.Generator().manual_seed(0)
sample_indices = torch.multinomial(
    sampling_probs, num_samples=m, replacement=False, generator=sampling_generator
)

coreset = Subset(train_dataset, sample_indices.tolist())
# This loader feeds SGD training, so reshuffle every epoch instead of replaying
# identical batches in identical order.
coreset_loader = DataLoader(coreset, batch_size=2048, shuffle=True)

In [None]:
"""Training model part"""

# Model and batches live on the XLA device (xla:0 in the recorded output).
device = xm.xla_device()
print(f"Using device: {device}")

# Number of training epochs; the main cell passes this same count to train(epochs=...).
NUM_EPOCHS = 35

# ResNet18 from torchvision with a 10-way classifier head
torch.manual_seed(0)  # reproducible weight initialisation across Restart & Run All
model = resnet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
# train() calls scheduler.step() once per epoch, so T_max must equal the number of
# epochs actually run; otherwise the cosine schedule is cut short
# (e.g. T_max=300 over 35 epochs leaves the LR near 0.0097 at the end).
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Training loop
def train(model, train_loader, epochs=300):
    """Train ``model`` on ``train_loader`` for ``epochs`` epochs on the XLA device.

    Relies on the notebook-level ``device``, ``criterion``, ``optimizer`` and
    ``scheduler``; ``scheduler.step()`` is called once at the end of every epoch.
    A tqdm bar per epoch shows the running mean loss and accuracy (%).

    Args:
        model: network already placed on ``device``; it is updated in place.
        train_loader: DataLoader yielding ``(inputs, targets)`` batches; it is
            wrapped in ``MpDeviceLoader`` to move batches onto ``device``.
        epochs: number of passes over ``train_loader``.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0  # sum of per-sample losses seen so far this epoch
        correct = 0
        total = 0

        para_loader = pl.MpDeviceLoader(train_loader, device)
        loop = tqdm(para_loader, desc=f"Epoch [{epoch+1}/{epochs}]")

        for inputs, targets in loop:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            xm.optimizer_step(optimizer)

            batch_size = targets.size(0)
            # criterion (default reduction='mean') returns the batch mean; weight it by
            # the batch size so the epoch average stays exact when the last batch is
            # smaller than the others.
            running_loss += loss.item() * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

            loop.set_postfix(loss=running_loss / total, acc=100. * correct / total)

        scheduler.step()
    return model

# Validation loop
def validate(model, val_loader):
    """Evaluate the top-1 accuracy of ``model`` on ``val_loader`` and print it.

    Runs on the notebook-level XLA ``device`` with gradients disabled.

    Args:
        model: network already placed on ``device``.
        val_loader: DataLoader yielding ``(inputs, targets)`` batches.

    Returns:
        float: accuracy in percent (also printed).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    # Keep the correct-count on the device and read it back once after the loop:
    # calling .item() per batch forces an XLA execution and host transfer every step.
    correct = 0
    total = 0
    with torch.no_grad():
        para_loader = pl.MpDeviceLoader(val_loader, device)
        for inputs, targets in para_loader:
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum()
    if total == 0:
        raise ValueError("val_loader yielded no samples")
    acc = 100. * correct.item() / total
    print(f"Validation Accuracy: {acc:.2f}%")
    return acc

# Main entry: fit on the sampled coreset, then score on the CIFAR-10 test split.
# Keep `epochs` in step with the CosineAnnealingLR T_max set when the optimizer was built.
model = train(model=model, train_loader=coreset_loader, epochs=35)
validate(model=model, val_loader=val_loader)

Using device: xla:0


Epoch [1/35]: 100%|██████████| 10/10 [01:02<00:00,  6.27s/it, acc=20.4, loss=1.71]
Epoch [2/35]: 100%|██████████| 10/10 [00:04<00:00,  2.42it/s, acc=36.4, loss=1.37]
Epoch [3/35]: 100%|██████████| 10/10 [00:04<00:00,  2.19it/s, acc=45.3, loss=1.19]
Epoch [4/35]: 100%|██████████| 10/10 [00:04<00:00,  2.38it/s, acc=51.5, loss=1.05]
Epoch [5/35]: 100%|██████████| 10/10 [00:04<00:00,  2.31it/s, acc=58.2, loss=0.92]
Epoch [6/35]: 100%|██████████| 10/10 [00:04<00:00,  2.38it/s, acc=64, loss=0.805]
Epoch [7/35]: 100%|██████████| 10/10 [00:04<00:00,  2.35it/s, acc=69.3, loss=0.701]
Epoch [8/35]: 100%|██████████| 10/10 [00:04<00:00,  2.20it/s, acc=71.7, loss=0.653]
Epoch [9/35]: 100%|██████████| 10/10 [00:04<00:00,  2.35it/s, acc=73.5, loss=0.603]
Epoch [10/35]: 100%|██████████| 10/10 [00:04<00:00,  2.39it/s, acc=78.2, loss=0.508]
Epoch [11/35]: 100%|██████████| 10/10 [00:04<00:00,  2.33it/s, acc=83.8, loss=0.397]
Epoch [12/35]: 100%|██████████| 10/10 [00:04<00:00,  2.37it/s, acc=87.4, loss=0.3

Validation Accuracy: 52.75%
