In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import uuid

In [3]:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

print(f"Number of TPU cores available: {xm.xrt_world_size()}")
print(f"Current TPU ordinal: {xm.get_ordinal()}")

E0000 00:00:1747417818.089569    2990 common_lib.cc:621] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:232


Number of TPU cores available: 1
Current TPU ordinal: 0


In [4]:
# TPU-specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

In [5]:
# Data Loading and Preprocessing
base_path = '/kaggle/input/data'
image_dirs = [f'images_{str(i).zfill(3)}' for i in range(1, 13)]
data_entry_csv_path = os.path.join(base_path, 'Data_Entry_2017.csv')

# Load metadata
data_entry_df = pd.read_csv(data_entry_csv_path)
data_entry_df = data_entry_df.loc[:, ~data_entry_df.columns.str.contains('^Unnamed')]
print(f"Total data entries: {len(data_entry_df)}")

class SSLCustomDataset(Dataset):
    def __init__(self, df, image_dirs, base_path, target_size=(224, 224)):
        self.df = df
        self.image_dirs = image_dirs
        self.base_path = base_path
        self.target_size = target_size
        
        self.image_path_map = self._build_image_path_map()
        self.valid_indices = [i for i in range(len(df)) if self.image_path_map.get(df.iloc[i]['Image Index'])]
        print(f"Filtered dataset to {len(self.valid_indices)} valid images")

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomResizedCrop(target_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05)
            ], p=0.5),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])
    def _build_image_path_map(self):
        path_map = {}
        for dir_name in self.image_dirs:
            dir_path = os.path.join(self.base_path, dir_name, 'images')
            if os.path.exists(dir_path):
                for img_file in os.listdir(dir_path):
                    path_map[img_file] = os.path.join(dir_path, img_file)
        return path_map
    
    def __len__(self):
        return len(self.valid_indices)
    

    def __getitem__(self, idx):
        idx = self.valid_indices[idx]
        row = self.df.iloc[idx]
        img_path = self.image_path_map[row['Image Index']]
        
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        img1 = self.transform(img)
        img2 = self.transform(img)
        
        return img1, img2


Total data entries: 112120


In [6]:
# Model Definitions
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.resnet = models.resnet18(weights='IMAGENET1K_V1')
        self.resnet.fc = nn.Identity()

    def forward(self, x):
        return self.resnet(x)

class MLPHead(nn.Module):
    def __init__(self, in_dim=512, hidden_dim=512, out_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim, bias=False),
            nn.BatchNorm1d(out_dim, affine=False)
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        z = self.net(x)
        return F.normalize(z, dim=1)
        
class SSLModel(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.projector = MLPHead()

    def forward(self, x1, x2):
        h1 = self.encoder(x1)
        h2 = self.encoder(x2)
        z1 = self.projector(h1)
        z2 = self.projector(h2)
        return z1, z2
        
class NTXentLoss(nn.Module):
    def __init__(self, temperature=0.1):  
        super().__init__()
        self.temperature = temperature
        
    def forward(self, z1, z2):
        batch_size = z1.size(0)
        features = torch.cat([z1, z2], dim=0)
        similarity = torch.mm(features, features.t()) / self.temperature
        
        pos_sim = torch.diag(similarity, batch_size)
        if torch.rand(1).item() < 0.01:
            xm.master_print(f"Positive similarity: {pos_sim.mean().item():.4f}")
        
        mask = torch.eye(2 * batch_size, device=similarity.device)
        mask = 1 - mask
        similarity = similarity * mask
        
        labels = torch.zeros(2 * batch_size, device=similarity.device, dtype=torch.int64)
        labels[0:batch_size] = torch.arange(batch_size, 2*batch_size)
        labels[batch_size:2*batch_size] = torch.arange(0, batch_size)
        
        loss = F.cross_entropy(similarity, labels)
        return loss

In [13]:
# Training function adapted for TPU
def train():
    device = xm.xla_device()
    # Create dataset and dataloader
    batch_size = 128 * xm.xrt_world_size()  # Scale batch size with number of cores
    dataset = SSLCustomDataset(data_entry_df, image_dirs, base_path)
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=4,
        drop_last=True)
    
    # Wrap with parallel loader
    device_loader = pl.MpDeviceLoader(dataloader, device)

    # Instantiate model
    model = SSLModel(Encoder()).to(device)
    criterion = NTXentLoss(temperature=0.1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-3 * xm.xrt_world_size())  # Scale learning rate
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
    scaler = GradScaler()

    num_epochs = 5
    accum_steps = 16
    print_freq = 200
    
    for epoch in range(num_epochs):
        model.train()
        sampler.set_epoch(epoch)
        total_loss = 0
        batch_count = 0
        
        for batch_idx, (img1, img2) in enumerate(tqdm(device_loader, 
                                                    desc=f"Epoch {epoch + 1}/{num_epochs}",
                                                    disable=not xm.is_master_ordinal())):
            with autocast():
                z1, z2 = model(img1, img2)
                loss = criterion(z1, z2) / accum_steps

            scaler.scale(loss).backward()
            
            if (batch_idx + 1) % print_freq == 0 and xm.is_master_ordinal():
                # Removed cosine similarity computation and printing
                xm.master_print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}, Loss: {loss.item() * accum_steps:.4f}")
            
            if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(dataloader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                xm.mark_step()  # Important for TPU execution

            total_loss += loss.item() * accum_steps
            batch_count += 1

        scheduler.step()
        avg_loss = total_loss / batch_count
        xm.master_print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        if xm.is_master_ordinal():
            torch.save(model.state_dict(), f"ssl_model_epoch_{epoch + 1}.pth")

    if xm.is_master_ordinal():
        final_save_path = 'ssl_model_final.pth'
        torch.save({
            'model_state_dict': model.state_dict(),
            'encoder_architecture': 'resnet18',
            'projector_dims': [512, 512, 128],
        }, final_save_path)
        print(f"Training complete! Model saved to {final_save_path}")


In [14]:
# Start training
def _mp_fn(rank, flags):
    # Initialize TPU inside the spawned process
    device = xm.xla_device()
    print(f"TPU {rank}: Device {device} initialized")
    train()
# Launch TPU training
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs= 1, start_method='fork')


TPU 0: Device xla:0 initialized
Filtered dataset to 112120 valid images


Epoch 1/5:   2%|▏         | 14/875 [00:20<08:35,  1.67it/s] 

Positive similarity: 3.2744


Epoch 1/5:   3%|▎         | 26/875 [00:38<27:17,  1.93s/it]

Positive similarity: 7.7182


Epoch 1/5:   9%|▉         | 80/875 [01:35<06:13,  2.13it/s]

Positive similarity: 8.7845


Epoch 1/5:  10%|▉         | 84/875 [01:43<13:56,  1.06s/it]

Positive similarity: 8.9183


Epoch 1/5:  11%|█         | 93/875 [01:53<10:03,  1.30it/s]

Positive similarity: 8.7356


Epoch 1/5:  23%|██▎       | 199/875 [03:55<05:05,  2.21it/s]

Epoch 1/5, Batch 200, Loss: 0.4092


Epoch 1/5:  24%|██▍       | 213/875 [04:13<08:39,  1.27it/s]

Positive similarity: 9.1403


Epoch 1/5:  35%|███▍      | 302/875 [06:11<06:46,  1.41it/s]

Positive similarity: 9.1510


Epoch 1/5:  46%|████▌     | 399/875 [08:25<04:19,  1.83it/s]

Epoch 1/5, Batch 400, Loss: 0.3014


Epoch 1/5:  56%|█████▋    | 494/875 [10:38<04:19,  1.47it/s]

Positive similarity: 9.2466


Epoch 1/5:  68%|██████▊   | 599/875 [13:02<02:23,  1.93it/s]

Epoch 1/5, Batch 600, Loss: 0.1250


Epoch 1/5:  83%|████████▎ | 728/875 [15:58<01:03,  2.33it/s]

Positive similarity: 9.3909


Epoch 1/5:  85%|████████▍ | 740/875 [16:19<02:46,  1.24s/it]

Positive similarity: 9.3665


Epoch 1/5:  85%|████████▍ | 742/875 [16:19<01:33,  1.42it/s]

Positive similarity: 9.4453


Epoch 1/5:  91%|█████████▏| 799/875 [17:35<00:39,  1.92it/s]

Epoch 1/5, Batch 800, Loss: 0.0886


Epoch 1/5: 100%|██████████| 875/875 [19:14<00:00,  1.32s/it]


Epoch 1/5, Average Loss: 0.4438


Epoch 2/5:  23%|██▎       | 199/875 [03:48<04:55,  2.29it/s]

Epoch 2/5, Batch 200, Loss: 0.0774


Epoch 2/5:  23%|██▎       | 203/875 [03:56<15:06,  1.35s/it]

Positive similarity: 9.4385


Epoch 2/5:  27%|██▋       | 233/875 [04:32<29:25,  2.75s/it]

Positive similarity: 9.4773


Epoch 2/5:  34%|███▎      | 295/875 [05:39<04:27,  2.17it/s]

Positive similarity: 9.5065


Epoch 2/5:  39%|███▉      | 345/875 [06:43<24:12,  2.74s/it]

Positive similarity: 9.4292


Epoch 2/5:  46%|████▌     | 399/875 [07:39<03:39,  2.17it/s]

Epoch 2/5, Batch 400, Loss: 0.0673


Epoch 2/5:  50%|█████     | 441/875 [08:32<18:52,  2.61s/it]

Positive similarity: 9.4498


Epoch 2/5:  55%|█████▍    | 478/875 [09:11<03:54,  1.69it/s]

Positive similarity: 9.4767


Epoch 2/5:  55%|█████▍    | 479/875 [09:11<03:09,  2.09it/s]

Positive similarity: 9.4630


Epoch 2/5:  64%|██████▍   | 559/875 [10:42<02:19,  2.27it/s]

Positive similarity: 9.4101


Epoch 2/5:  67%|██████▋   | 587/875 [11:17<06:52,  1.43s/it]

Positive similarity: 9.4942


Epoch 2/5:  68%|██████▊   | 599/875 [11:27<02:07,  2.16it/s]

Epoch 2/5, Batch 600, Loss: 0.0715


Epoch 2/5:  88%|████████▊ | 771/875 [14:48<02:23,  1.38s/it]

Positive similarity: 9.4690


Epoch 2/5:  91%|█████████▏| 799/875 [15:16<00:34,  2.23it/s]

Epoch 2/5, Batch 800, Loss: 0.0607


Epoch 2/5:  94%|█████████▎| 819/875 [15:43<01:14,  1.33s/it]

Positive similarity: 9.4693


Epoch 2/5: 100%|██████████| 875/875 [16:41<00:00,  1.15s/it]


Epoch 2/5, Average Loss: 0.0738


Epoch 3/5:  16%|█▌        | 137/875 [02:46<31:26,  2.56s/it]

Positive similarity: 9.5516


Epoch 3/5:  19%|█▊        | 163/875 [03:13<16:25,  1.38s/it]

Positive similarity: 9.4889


Epoch 3/5:  23%|██▎       | 199/875 [03:50<05:12,  2.17it/s]

Epoch 3/5, Batch 200, Loss: 0.0598


Epoch 3/5:  23%|██▎       | 205/875 [03:59<08:36,  1.30it/s]

Positive similarity: 9.4556


Epoch 3/5:  25%|██▌       | 223/875 [04:18<04:57,  2.19it/s]

Positive similarity: 9.4853


Epoch 3/5:  31%|███       | 268/875 [05:13<10:20,  1.02s/it]

Positive similarity: 9.5762


Epoch 3/5:  46%|████▌     | 399/875 [07:39<03:36,  2.20it/s]

Epoch 3/5, Batch 400, Loss: 0.0539


Epoch 3/5:  68%|██████▊   | 599/875 [11:28<02:07,  2.17it/s]

Epoch 3/5, Batch 600, Loss: 0.0556


Epoch 3/5:  85%|████████▍ | 743/875 [14:14<00:59,  2.20it/s]

Positive similarity: 9.6012


Epoch 3/5:  91%|█████████▏| 799/875 [15:18<00:34,  2.19it/s]

Epoch 3/5, Batch 800, Loss: 0.0500


Epoch 3/5: 100%|██████████| 875/875 [16:45<00:00,  1.15s/it]


Epoch 3/5, Average Loss: 0.0540


Epoch 4/5:   7%|▋         | 61/875 [01:14<10:03,  1.35it/s] 

Positive similarity: 9.5679


Epoch 4/5:  23%|██▎       | 199/875 [03:48<05:08,  2.19it/s]

Epoch 4/5, Batch 200, Loss: 0.0475


Epoch 4/5:  30%|██▉       | 262/875 [05:01<05:46,  1.77it/s]

Positive similarity: 9.4946


Epoch 4/5:  43%|████▎     | 376/875 [07:10<03:03,  2.72it/s]

Positive similarity: 9.5175


Epoch 4/5:  46%|████▌     | 399/875 [07:37<03:40,  2.16it/s]

Positive similarity: 9.5629
Epoch 4/5, Batch 400, Loss: 0.0409


Epoch 4/5:  58%|█████▊    | 511/875 [09:47<02:48,  2.16it/s]

Positive similarity: 9.5960


Epoch 4/5:  68%|██████▊   | 591/875 [11:17<02:09,  2.19it/s]

Positive similarity: 9.5614


Epoch 4/5:  68%|██████▊   | 599/875 [11:26<02:05,  2.19it/s]

Epoch 4/5, Batch 600, Loss: 0.0435


Epoch 4/5:  74%|███████▎  | 644/875 [12:20<03:52,  1.01s/it]

Positive similarity: 9.5426


Epoch 4/5:  89%|████████▉ | 780/875 [14:52<01:36,  1.02s/it]

Positive similarity: 9.5514


Epoch 4/5:  91%|█████████▏| 799/875 [15:11<00:35,  2.17it/s]

Epoch 4/5, Batch 800, Loss: 0.0442


Epoch 4/5:  94%|█████████▍| 822/875 [15:39<00:31,  1.68it/s]

Positive similarity: 9.5485


Epoch 4/5:  96%|█████████▌| 840/875 [15:59<00:13,  2.60it/s]

Positive similarity: 9.5262


Epoch 4/5:  97%|█████████▋| 851/875 [16:17<00:32,  1.37s/it]

Positive similarity: 9.5741


Epoch 4/5: 100%|██████████| 875/875 [16:39<00:00,  1.14s/it]


Epoch 4/5, Average Loss: 0.0456


Epoch 5/5:   3%|▎         | 28/875 [00:39<14:57,  1.06s/it] 

Positive similarity: 9.5745


Epoch 5/5:  23%|██▎       | 199/875 [03:53<05:08,  2.19it/s]

Epoch 5/5, Batch 200, Loss: 0.0453


Epoch 5/5:  23%|██▎       | 201/875 [04:02<30:01,  2.67s/it]

Positive similarity: 9.5783


Epoch 5/5:  32%|███▏      | 276/875 [05:26<10:11,  1.02s/it]

Positive similarity: 9.5925


Epoch 5/5:  46%|████▌     | 399/875 [07:44<03:39,  2.17it/s]

Epoch 5/5, Batch 400, Loss: 0.0460


Epoch 5/5:  68%|██████▊   | 599/875 [11:34<02:06,  2.18it/s]

Epoch 5/5, Batch 600, Loss: 0.0378


Epoch 5/5:  72%|███████▏  | 628/875 [12:11<04:14,  1.03s/it]

Positive similarity: 9.5025


Epoch 5/5:  91%|█████████▏| 799/875 [15:26<00:35,  2.15it/s]

Epoch 5/5, Batch 800, Loss: 0.0416


Epoch 5/5: 100%|██████████| 875/875 [16:52<00:00,  1.16s/it]


Epoch 5/5, Average Loss: 0.0427
Training complete! Model saved to ssl_model_final.pth
