# Advanced augmentation (Mixup & CutMix)

In [None]:
# --- Mixup and CutMix Functions ---
def mixup_data(x, y, alpha=1.0):
    '''Blend every sample of a batch with a randomly chosen partner (Mixup).

    Args:
        x: Batch of inputs; indexed along dimension 0 with a permutation.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution the
            mixing weight is drawn from. When ``alpha <= 0`` no draw is made
            and the weight is fixed to 1 (output equals ``x``).

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` itself, ``y_b`` is
        ``y`` reordered by the same permutation, and ``lam`` is the weight.
    '''
    # Mixing weight: random for positive alpha, otherwise the identity weight.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # The permutation is generated first and then moved to the input's device.
    partner = torch.randperm(x.size()[0]).to(x.device)

    partner_x = x[partner, :]
    blended = lam * x + (1 - lam) * partner_x
    return blended, y, y[partner], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''Combine the criterion evaluated against both target sets.

    Args:
        criterion: Callable invoked as ``criterion(pred, target)``.
        pred: Predictions passed unchanged to both criterion calls.
        y_a: First set of targets, weighted by ``lam``.
        y_b: Second set of targets, weighted by ``1 - lam``.
        lam: Mixing weight.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    '''
    loss_first = criterion(pred, y_a)
    loss_second = criterion(pred, y_b)
    return lam * loss_first + (1 - lam) * loss_second

def cutmix_data(x, y, alpha=1.0):
    '''Paste a rectangular patch from a shuffled copy of the batch (CutMix).

    The input tensor is left untouched: the patch is written into a clone, so
    callers that keep a reference to ``x`` do not see it modified.

    Args:
        x: Batch of inputs. Must have at least four dimensions, because the
            patch is addressed as ``x[:, :, a:b, c:d]``.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution the
            initial mixing weight is drawn from. When ``alpha <= 0`` no draw
            is made and the weight starts at 1 (empty patch).

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is the new
        tensor with the patch replaced, ``y_a`` is ``y`` itself, ``y_b`` is
        ``y`` reordered by the permutation used for the patch, and ``lam`` is
        the fraction of the last two dimensions' area NOT covered by the
        patch.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)
    y_a, y_b = y, y[index]

    # Patch corners along dimensions 2 and 3.
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)

    # Write into a clone instead of the caller's tensor, and gather only the
    # patch region from the permuted batch rather than a full shuffled copy.
    mixed_x = x.clone()
    mixed_x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    # The box is clipped at the borders, so recompute lambda from the area
    # that was actually replaced.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return mixed_x, y_a, y_b, lam

def rand_bbox(size, lam):
    '''Sample a random box inside the extents given by ``size[2]``/``size[3]``.

    Args:
        size: Indexable shape; only entries 2 and 3 are read.
        lam: Mixing weight; each side of the box is ``sqrt(1 - lam)`` times
            the corresponding extent (truncated to an integer) before
            clipping.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)``: start/end along dimension 2
        (``bbx*``) and start/end along dimension 3 (``bby*``), each clipped to
        ``[0, extent]``.
    '''
    extent_x, extent_y = size[2], size[3]

    # Side lengths scale with sqrt(1 - lam); only the half-lengths are needed.
    side_ratio = np.sqrt(1. - lam)
    half_x = int(extent_x * side_ratio) // 2
    half_y = int(extent_y * side_ratio) // 2

    # Box centre drawn uniformly (dimension 2 first, then dimension 3).
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    # Clip so the box never leaves the valid index range.
    x_start = np.clip(centre_x - half_x, 0, extent_x)
    y_start = np.clip(centre_y - half_y, 0, extent_y)
    x_stop = np.clip(centre_x + half_x, 0, extent_x)
    y_stop = np.clip(centre_y + half_y, 0, extent_y)

    return x_start, y_start, x_stop, y_stop

# Train loop

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device, p=0.5, mixup_alpha=1.0, cutmix_alpha=1.0, mixup_ratio=0.5, verbose=0):
    '''Run one training pass over ``train_loader`` with optional Mixup/CutMix.

    For every batch, mix augmentation is applied with probability ``p``. When
    it is applied, Mixup is chosen with probability ``mixup_ratio`` and CutMix
    otherwise. Mixed batches use the lambda-weighted loss and a
    lambda-weighted count of correct predictions.

    Args:
        model: Module to train; switched to training mode before the loop.
        train_loader: Sized iterable yielding ``(data, target)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, target)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        p: Probability of applying mix augmentation to a batch.
        mixup_alpha: ``alpha`` passed to ``mixup_data``.
        cutmix_alpha: ``alpha`` passed to ``cutmix_data``.
        mixup_ratio: Probability of picking Mixup (vs. CutMix) once mix
            augmentation has been selected.
        verbose: When equal to 1, iterate through a ``tqdm`` progress bar and
            show the current batch loss.

    Returns:
        Tuple ``(epoch_loss, epoch_accuracy)``: the mean of the per-batch
        losses and the (possibly lambda-weighted) fraction of correct
        predictions. Both are ``0.0`` when there is nothing to average over
        (empty loader / zero samples) instead of raising ZeroDivisionError.
    '''
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    num_batches = len(train_loader)
    show_progress = verbose == 1
    if show_progress:
        progress_bar = tqdm(enumerate(train_loader), total=num_batches)
    else:
        progress_bar = enumerate(train_loader)

    for _, batch in progress_bar:
        data, target = batch

        # Move data to device
        data = data.to(device)
        target = target.to(device)

        # --- Decide whether to apply mix augmentation, and which kind ---
        use_mix = np.random.rand() < p
        if use_mix:
            if np.random.rand() < (1 - mixup_ratio):
                data, targets_a, targets_b, lam = cutmix_data(data, target, cutmix_alpha)
            else:
                data, targets_a, targets_b, lam = mixup_data(data, target, mixup_alpha)

        # --- Single optimisation step shared by both paths ---
        optimizer.zero_grad()
        outputs = model(data)
        if use_mix:
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        else:
            loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        # --- Metrics ---
        # Extract the scalar once; it is reused for the progress bar below.
        batch_loss = loss.item()
        running_loss += batch_loss
        _, predicted = torch.max(outputs, 1)
        total_predictions += target.size(0)
        if use_mix:
            # Credit each prediction against both target sets, weighted by lam.
            hits_a = (predicted == targets_a).sum().item()
            hits_b = (predicted == targets_b).sum().item()
            correct_predictions += lam * hits_a + (1 - lam) * hits_b
        else:
            correct_predictions += (predicted == target).sum().item()

        if show_progress:
            # Update progress bar
            progress_bar.set_description(f"Train Loss: {batch_loss:.4f}")

    # Calculate epoch metrics, guarding against an empty loader.
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_accuracy = correct_predictions / total_predictions if total_predictions else 0.0

    return epoch_loss, epoch_accuracy