<a href="https://colab.research.google.com/github/Shaobin675/Path_in_ML_model_training/blob/main/102_introduce_DDP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
'''
data_factory.py
a modular architecture. one module handles the data factory,
another the model's brain, and a third coordinates the distributed workers.
'''

def get_data_loaders(dataset, batch_size, world_size, rank,
                     num_workers=8, pin_memory=True, prefetch_factor=2):
    """Build a DataLoader whose sampler shards ``dataset`` across DDP ranks.

    Args:
        dataset: Map-style dataset to load from.
        batch_size: Per-process batch size.
        world_size: Total number of processes sharing the dataset.
        rank: Global rank of this process (0 <= rank < world_size).
        num_workers: Worker processes used for loading/augmentation.
            0 loads in the main process.
        pin_memory: Copy batches into page-locked memory, which speeds up
            host-to-GPU transfers.
        prefetch_factor: Batches each worker loads in advance. Ignored when
            ``num_workers == 0``.

    Returns:
        ``(loader, sampler)``. Call ``sampler.set_epoch(epoch)`` at the start of
        every epoch, otherwise each epoch reuses the same shuffle order.
    """
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
    )

    extra = {}
    if num_workers > 0:
        # DataLoader only accepts prefetch_factor when worker processes exist;
        # recent PyTorch raises ValueError for num_workers=0 with a prefetch value.
        extra["prefetch_factor"] = prefetch_factor

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,            # sampler handles sharding + shuffling
        num_workers=num_workers,
        pin_memory=pin_memory,
        **extra,
    )
    return loader, sampler

In [None]:
import torch
from torch.utils.data import Dataset
import cv2

class XRayDataset(Dataset):
    """Map-style dataset that reads images from disk with OpenCV.

    Args:
        paths: Sequence of image file paths; item ``i`` is loaded from ``paths[i]``.
        transform: Optional callable invoked as ``transform(image=img)`` that
            returns a dict with an ``'image'`` entry. This is the keyword-call /
            dict-return convention of albumentations-style pipelines (presumably
            what is used here — confirm).
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        img = cv2.imread(path)  # BGR array, or None if the file can't be read
        if img is None:
            # cv2.imread does not raise on failure; without this check the next
            # call dies inside cvtColor with an assertion that hides the path.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(image=img)['image']

        # Without a transform this is an HxWx3 RGB NumPy array; it is a tensor
        # only if the transform pipeline converts it.
        return img

In [1]:
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
"""
trainer.py
Provides a DDPTrainer wrapper that encapsulates PyTorch DistributedDataParallel (DDP)
initialization and epoch execution logic for multi-GPU training.
"""
class DDPTrainer:
    """
    Wraps a model with DistributedDataParallel (DDP) and manages training steps,
    checkpointing and embedding extraction.

    Assumes torch.distributed.init_process_group() has been called prior to instantiation.

    Args:
        model: nn.Module to train. It is moved to ``gpu_id`` and wrapped in DDP.
        gpu_id: Device index of this process. Used for ``.to()`` and as DDP's
            single entry in ``device_ids``; typically the launcher's LOCAL_RANK.
        lr: AdamW learning rate.
    """
    def __init__(self, model, gpu_id, lr=1e-4):
        self.gpu_id = gpu_id
        self.model = DDP(model.to(gpu_id), device_ids=[gpu_id])

        # uses AdamW for stability
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        # NOTE(review): not used yet; train_step still computes a placeholder loss.
        self.criterion = nn.CosineEmbeddingLoss()

    def _unwrapped_model(self):
        """Return the inner nn.Module (DDP keeps it in ``.module``); plain modules pass through."""
        return getattr(self.model, "module", self.model)

    def _is_main_process(self):
        """Return True on exactly one process so that only one process writes files.

        ``gpu_id`` is a per-node device index, so ``gpu_id == 0`` would be true on
        one process per node. The global rank is used whenever a process group exists.
        """
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return self.gpu_id == 0

    def train_step(self, images):
        '''
        Run one optimisation step on ``images`` and return the loss as a float.
        '''
        # generate_embeddings() (or a caller) may have switched to eval mode;
        # BatchNorm/Dropout must be in training mode while we update weights.
        self.model.train()
        images = images.to(self.gpu_id, non_blocking=True)  # no-op if already there

        self.optimizer.zero_grad()

        # Forward pass
        embeddings = self.model(images)

        # we might compare an image to its augmented version.
        # Here we just show the mechanical backward pass:
        loss = embeddings.sum()  # Simplified placeholder for the loss math

        loss.backward()          # DDP all-reduces gradients during backward
        self.optimizer.step()
        return loss.item()

    def save_checkpoint(self, epoch, path):
        """Write ``<path>/checkpoint_<epoch>.pt`` from the global-rank-0 process only.

        Other ranks return immediately, so concurrent processes never write the
        same file. ``path`` is created if it does not exist.
        """
        if not self._is_main_process():
            return
        os.makedirs(path, exist_ok=True)
        torch.save({
            'epoch': epoch,
            'model_state_dict': self._unwrapped_model().state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, os.path.join(path, f"checkpoint_{epoch}.pt"))

    def generate_embeddings(self, loader):
        '''
        Run the model without gradients over every batch in ``loader``.

        Returns the outputs concatenated along dim 0, on CPU. If ``loader`` uses
        a DistributedSampler, only this rank's shard is embedded. Raises
        RuntimeError (from torch.cat) if ``loader`` yields no batches. The
        model's train/eval mode is restored afterwards.
        '''
        was_training = self.model.training
        self.model.eval()
        # Inference needs no gradient synchronisation, so bypass the DDP wrapper.
        net = self._unwrapped_model()
        all_vecs = []

        try:
            with torch.no_grad():
                for batch in loader:
                    batch = batch.to(self.gpu_id, non_blocking=True)
                    # Move back to CPU so GPU memory doesn't grow with the dataset.
                    all_vecs.append(net(batch).cpu())
        finally:
            self.model.train(was_training)

        return torch.cat(all_vecs)


In [None]:
import torch.nn as nn
import torchvision.models as models
#model.py
class XRayEncoder(nn.Module):
    """Image encoder: a torchvision backbone with its classification head removed.

    Args:
        model_name: Backbone to use. Only ``'resnet50'`` (ImageNet-pretrained
            weights, which torchvision may download on first use) is supported.

    Raises:
        ValueError: If ``model_name`` is not a supported backbone.
    """
    SUPPORTED = ('resnet50',)

    def __init__(self, model_name='resnet50'):
        super().__init__()
        # Validate first: otherwise an unknown name leaves `base` unbound and
        # fails later with an UnboundLocalError.
        if model_name not in self.SUPPORTED:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected one of {self.SUPPORTED}"
            )
        base = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        # Drop the last child (the `fc` classifier) and keep everything up to
        # and including global average pooling.
        self.feature_extractor = nn.Sequential(*list(base.children())[:-1])

    def forward(self, x):
        features = self.feature_extractor(x)
        # Pooled features are (N, C, 1, 1); flatten to an (N, C) embedding.
        return features.flatten(1)

In [None]:
import os
import glob
import torch.distributed as dist
from model import XRayEncoder
from trainer import DDPTrainer
from data_factory import get_data_loaders
from augment import XRayAugmenter
from data_loader import XRayDataset

def main():
    """Train XRayEncoder with DDP. Run once per process under ``torchrun``.

    Reads LOCAL_RANK from the environment to choose this process's GPU. The
    global rank and world size come from the initialised process group.
    """
    import torch  # local import: this script's top-level imports don't include torch

    # 1. Initialize Distributed Backend
    gpu_id = int(os.environ["LOCAL_RANK"])  # GPU index on *this* node
    # Bind this process to its own GPU before NCCL sets up communicators;
    # otherwise ranks can end up creating contexts on cuda:0.
    torch.cuda.set_device(gpu_id)
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()  # unique across all nodes (LOCAL_RANK is not)
    world_size = dist.get_world_size()

    try:
        # 2. Prepare Data
        data_glob = "/path/to/big_data/*.jpg"
        # Sort so every rank sees the same index -> file mapping;
        # DistributedSampler shards by index, and glob order is unspecified.
        all_image_paths = sorted(glob.glob(data_glob))
        if not all_image_paths:
            raise FileNotFoundError(f"No training images matched {data_glob}")
        augmenter = XRayAugmenter(img_size=224)
        dataset = XRayDataset(all_image_paths, transform=augmenter.train_transform)

        loader, sampler = get_data_loaders(
            dataset,
            batch_size=64,
            world_size=world_size,
            rank=rank,  # global rank, so shards don't overlap across nodes
        )

        # 3. Initialize Model and Trainer
        model = XRayEncoder()
        trainer = DDPTrainer(model, gpu_id)

        # 4. Training Loop
        for epoch in range(10):
            sampler.set_epoch(epoch)  # Critical for shuffling in DDP
            for batch_idx, images in enumerate(loader):
                images = images.to(gpu_id, non_blocking=True)
                loss = trainer.train_step(images)

                if rank == 0 and batch_idx % 10 == 0:
                    print(f"Epoch {epoch} | Batch {batch_idx} | Loss: {loss:.4f}")
    finally:
        # Tear down the process group even if training raises.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

In [None]:
import torch
#config.py
class Config:
    """Static training settings exposed as class attributes (no instances needed)."""

    # --- Distributed settings ---
    # Evaluated once, when this module is imported: number of visible CUDA
    # devices (0 on a CPU-only host).
    WORLD_SIZE = torch.cuda.device_count()
    BACKEND = "nccl"

    # --- Hyperparameters ---
    BATCH_SIZE = 64
    IMG_SIZE = 224
    LEARNING_RATE = 1e-4
    # NOTE(review): the launcher script hard-codes 10 epochs / batch 64
    # rather than reading these values; confirm which is authoritative.
    EPOCHS = 50

    # --- Paths ---
    DATA_DIR = "./data/xrays/"
    CHECKPOINT_DIR = "./checkpoints/"