In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive (prompts for authorization on first run).
# The data directories and the checkpoint path used later in this notebook
# live under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

TRAIN_DIR = "/content/drive/MyDrive/train"
VAL_DIR   = "/content/drive/MyDrive/validation"


def list_npz(directory):
    """Return the sorted full paths of the ``.npz`` files directly inside `directory`."""
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(".npz")
    )


train_files = list_npz(TRAIN_DIR)
val_files   = list_npz(VAL_DIR)

# Fail fast: an empty list would otherwise only show up later as an empty
# dataset/DataLoader and a division by zero when averaging the epoch loss.
assert train_files and val_files, f"no .npz files found in {TRAIN_DIR} and/or {VAL_DIR}"

print(len(train_files), len(val_files))


200 20


In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset

class SeismicPatchDataset(Dataset):
    """Random cubic patches drawn from 3D seismic volumes stored as ``.npz`` files.

    Each file must contain a ``"seis"`` array (3D) and a ``"fault"`` array of the
    same shape; the fault array is binarised with a ``> 0.5`` threshold. A new
    random patch is drawn on every ``__getitem__`` call, so the same index
    returns different patches unless ``seed`` is set.

    Args:
        files: Paths of ``.npz`` files; item ``idx`` reads ``files[idx % len(files)]``.
        patch_size: Edge length of the returned cube.
        fault_ratio: Probability of centring the patch on a fault voxel whose full
            patch fits inside the volume. Otherwise (or if no such voxel exists)
            the centre is drawn uniformly from the positions where a full patch
            fits, or from the whole axis when the axis is shorter than a patch.
        samples_per_file: Items per file per epoch; ``__len__`` returns
            ``len(files) * samples_per_file``.
        seed: ``None`` (default) draws from NumPy's global RNG. An int makes item
            ``idx`` reproducible by drawing from ``np.random.RandomState(seed + idx)``,
            e.g. for a stable validation set.

    ``__getitem__`` returns ``(seis_patch, fault_patch)``, both float32 tensors of
    shape ``(1, patch_size, patch_size, patch_size)``. Parts of the cube that fall
    outside the volume are zero-padded at the high end of each axis.
    """

    def __init__(self, files, patch_size=64, fault_ratio=0.5, samples_per_file=20, seed=None):
        self.files = files
        self.ps = patch_size
        self.fr = fault_ratio
        self.samples_per_file = samples_per_file
        self.seed = seed

    def __len__(self):
        return len(self.files) * self.samples_per_file

    @staticmethod
    def _center_range(size, margin):
        """Half-open range ``[lo, hi)`` of allowed patch centres along one axis.

        Normally ``[margin, size - margin)`` so the whole patch fits. If the axis is
        too short for a full patch, the whole axis ``[0, size)`` is allowed and the
        extracted patch is zero-padded afterwards.
        """
        lo, hi = margin, size - margin
        if hi <= lo:
            return 0, size
        return lo, hi

    @staticmethod
    def _random_center(rng, lo, hi):
        """Uniform random integer in ``[lo, hi)``, or ``hi // 2`` if that range is empty."""
        if hi <= lo:
            return hi // 2
        return rng.randint(lo, hi)

    def __getitem__(self, idx):
        # Unseeded: NumPy's global RNG. Seeded: a per-index RandomState, so item
        # `idx` always yields the same patch.
        rng = np.random if self.seed is None else np.random.RandomState(self.seed + idx)

        # The context manager closes the .npz archive; indexing it has already
        # read the arrays into memory.
        with np.load(self.files[idx % len(self.files)]) as data:
            seis = data["seis"]
            fault = (data["fault"] > 0.5).astype(np.float32)

        if seis.ndim != 3 or fault.shape != seis.shape:
            raise ValueError(
                f"expected 3D 'seis' and 'fault' arrays of equal shape, "
                f"got {seis.shape} and {fault.shape}"
            )

        ps = self.ps
        margin = ps // 2
        shape = seis.shape

        # --- choose patch center ---
        center = None
        if rng.rand() < self.fr and fault.any():
            # Fault-biased sampling, restricted to fault voxels whose full patch
            # fits inside the volume (vectorised instead of a per-voxel loop).
            fault_idx = np.argwhere(fault == 1)
            fits = np.all(
                (fault_idx >= margin) & (fault_idx < np.array(shape) - margin),
                axis=1,
            )
            valid = fault_idx[fits]
            if len(valid):
                center = valid[rng.randint(len(valid))]
        if center is None:
            # Background sampling; also the fallback when no fault voxel is usable.
            center = [
                self._random_center(rng, *self._center_range(size, margin))
                for size in shape
            ]

        # --- extract patch (may be smaller than ps near borders / on short axes) ---
        window = tuple(
            slice(max(0, c - margin), min(size, c + margin))
            for c, size in zip(center, shape)
        )
        seis_patch = seis[window]
        fault_patch = fault[window]

        # Zero-pad at the high end of each axis up to ps x ps x ps. Each slice is at
        # most 2 * margin <= ps long, so the pad amounts are never negative.
        pad = [(0, ps - n) for n in seis_patch.shape]
        if any(after for _, after in pad):
            seis_patch = np.pad(seis_patch, pad, mode="constant")
            fault_patch = np.pad(fault_patch, pad, mode="constant")

        # --- torch tensors with a leading channel axis: (1, ps, ps, ps) ---
        seis_patch = torch.tensor(seis_patch, dtype=torch.float32).unsqueeze(0)
        fault_patch = torch.tensor(fault_patch, dtype=torch.float32).unsqueeze(0)
        return seis_patch, fault_patch

In [None]:
from torch.utils.data import DataLoader

PATCH_SIZE = 64

train_ds = SeismicPatchDataset(train_files, patch_size=PATCH_SIZE)
# NOTE: patches are re-drawn at random on every access, so the validation loss
# varies from epoch to epoch even for a fixed model.
val_ds = SeismicPatchDataset(val_files, patch_size=PATCH_SIZE, fault_ratio=0.5)

# One volume per batch, loaded in the main process; only training is shuffled.
loader_kwargs = dict(batch_size=1, num_workers=0)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)


In [None]:
# Sanity check on one training batch: inputs and labels are (batch, 1, ps, ps, ps)
# and the labels are binary.
x_sample, y_sample = next(iter(train_loader))
expected_shape = (train_loader.batch_size, 1, train_ds.ps, train_ds.ps, train_ds.ps)
assert x_sample.shape == expected_shape and y_sample.shape == expected_shape, (x_sample.shape, y_sample.shape)
assert set(y_sample.unique().tolist()) <= {0.0, 1.0}
print("Fault voxels:", y_sample.sum().item())


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
class DoubleConv(nn.Module):
    """Two stacked ``Conv3d(k=3, padding=1) -> BatchNorm3d -> ReLU`` stages.

    Maps ``in_ch`` channels to ``out_ch`` channels. With kernel 3, padding 1 and
    the default stride 1, the spatial size is unchanged. The six layers live in
    ``self.net`` at indices 0-5 (state_dict keys ``net.0`` ... ``net.5``).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage changes the channel count; the second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend((
                nn.Conv3d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_ch),
                nn.ReLU(inplace=True),
            ))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


In [None]:
class UNet3D(nn.Module):
    """3D U-Net producing per-voxel logits.

    Four encoder levels (each a DoubleConv followed by 2x max-pooling), a
    bottleneck, and four decoder levels (2x transposed-conv up-sampling,
    concatenation with the matching encoder features, DoubleConv), then a
    1x1x1 convolution to ``out_ch`` logit maps.

    Args:
        base_ch: Channels of the first encoder level; doubled at each level
            (base_ch, 2x, 4x, 8x, bottleneck 16x).
        in_ch: Input channels (default 1).
        out_ch: Output logit channels (default 1).

    Spatial sizes that are not multiples of 16 (four halvings) are zero-padded
    at the high end before the encoder and the output is cropped back to the
    input size. Sizes that are already multiples of 16 are not padded.
    """

    # Four MaxPool3d(2) stages halve each spatial axis four times.
    _SIZE_MULTIPLE = 2 ** 4

    def __init__(self, base_ch=16, in_ch=1, out_ch=1):
        super().__init__()

        self.enc1 = DoubleConv(in_ch, base_ch)
        self.enc2 = DoubleConv(base_ch, base_ch*2)
        self.enc3 = DoubleConv(base_ch*2, base_ch*4)
        self.enc4 = DoubleConv(base_ch*4, base_ch*8)

        self.pool = nn.MaxPool3d(2)

        self.bottleneck = DoubleConv(base_ch*8, base_ch*16)

        self.up4 = nn.ConvTranspose3d(base_ch*16, base_ch*8, 2, stride=2)
        self.dec4 = DoubleConv(base_ch*16, base_ch*8)

        self.up3 = nn.ConvTranspose3d(base_ch*8, base_ch*4, 2, stride=2)
        self.dec3 = DoubleConv(base_ch*8, base_ch*4)

        self.up2 = nn.ConvTranspose3d(base_ch*4, base_ch*2, 2, stride=2)
        self.dec2 = DoubleConv(base_ch*4, base_ch*2)

        self.up1 = nn.ConvTranspose3d(base_ch*2, base_ch, 2, stride=2)
        self.dec1 = DoubleConv(base_ch*2, base_ch)

        self.out = nn.Conv3d(base_ch, out_ch, kernel_size=1)

    def forward(self, x):
        # Pad every spatial axis up to a multiple of 16 so that pooled and
        # up-sampled feature maps have matching sizes for the skip concatenations.
        size_d, size_h, size_w = x.shape[-3:]
        pad_d, pad_h, pad_w = ((-s) % self._SIZE_MULTIPLE for s in (size_d, size_h, size_w))
        if pad_d or pad_h or pad_w:
            # F.pad lists the last axis first: (w_lo, w_hi, h_lo, h_hi, d_lo, d_hi).
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))

        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))

        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        # Crop away any padding added above (a no-op when none was added).
        return self.out(d1)[..., :size_d, :size_h, :size_w]


In [None]:
import torch.nn.functional as F

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss on raw logits.

    Applies a sigmoid to ``pred`` and returns
    ``1 - 2 * sum(p * t) / (sum(p) + sum(t) + eps)``, with the sums taken over
    the whole tensor (all batch items pooled together).

    Args:
        pred: Logits, any shape.
        target: Tensor of the same shape as ``pred`` (binary labels here).
        eps: Added to the denominator to avoid division by zero.

    Returns:
        A 0-dim float32 tensor.
    """
    # Work in float32: under autocast the logits can arrive as float16, and a
    # float16 sigmoid/sum over a whole 3D patch loses precision.
    prob = torch.sigmoid(pred.float())
    target = target.float()
    num = 2 * (prob * target).sum()
    den = prob.sum() + target.sum() + eps
    return 1 - num / den

def combined_loss(pred, target, bce_weight=1.0, dice_weight=1.0):
    """Weighted sum of BCE-with-logits and soft Dice loss.

    With the default weights this is exactly ``bce + dice``.

    Args:
        pred: Logits.
        target: Labels with the same shape as ``pred``.
        bce_weight: Multiplier for the binary cross-entropy term.
        dice_weight: Multiplier for the Dice term.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)
    dice = dice_loss(pred, target)
    return bce_weight * bce + dice_weight * dice

In [None]:
import torch
from torch.amp import autocast, GradScaler

# Fall back to CPU when no GPU is attached; AMP loss scaling is only enabled on CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed NumPy (patch sampling) and torch (weight init, DataLoader shuffling)
# so training runs are reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

model = UNet3D(base_ch=16).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

scaler = GradScaler("cuda", enabled=(device == "cuda"))


In [None]:
def train_one_epoch(model, loader):
    """Train `model` for one pass over `loader` and return the mean per-batch loss.

    Relies on the notebook globals `device`, `optimizer`, `scaler` and
    `combined_loss`. The forward pass and loss run under CUDA autocast; the
    backward pass and optimizer step go through the AMP GradScaler.
    """
    model.train()
    loss_sum = 0

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()

        with autocast("cuda"):
            batch_loss = combined_loss(model(inputs), targets)

        # Scaled backward, then step + scale update.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_sum += batch_loss.item()

    return loss_sum / len(loader)


In [None]:
@torch.no_grad()
def validate(model, loader):
    """Return the mean per-batch `combined_loss` of `model` over `loader`.

    Runs in eval mode with gradients disabled and no autocast. Relies on the
    notebook globals `device` and `combined_loss`.
    """
    model.eval()
    batch_losses = []

    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_losses.append(combined_loss(model(inputs), targets).item())

    return sum(batch_losses) / len(loader)


In [None]:
epochs = 15
# Per-epoch losses, kept for plotting/inspection after training.
history = []

for epoch in range(epochs):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = validate(model, val_loader)
    history.append({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss})

    print(f"Epoch {epoch+1:03d} | Train: {train_loss:.4f} | Val: {val_loss:.4f}")


Epoch 001 | Train: 0.9923 | Val: 0.6356
Epoch 002 | Train: 0.4955 | Val: 0.4833
Epoch 003 | Train: 0.4200 | Val: 0.4574
Epoch 004 | Train: 0.3959 | Val: 0.4498
Epoch 005 | Train: 0.3785 | Val: 0.4364
Epoch 006 | Train: 0.3650 | Val: 0.4224
Epoch 007 | Train: 0.3549 | Val: 0.4293
Epoch 008 | Train: 0.3447 | Val: 0.4288
Epoch 009 | Train: 0.3389 | Val: 0.4230
Epoch 010 | Train: 0.3316 | Val: 0.4123
Epoch 011 | Train: 0.3239 | Val: 0.4106
Epoch 012 | Train: 0.3203 | Val: 0.4034
Epoch 013 | Train: 0.3142 | Val: 0.4214
Epoch 014 | Train: 0.3097 | Val: 0.4045
Epoch 015 | Train: 0.3039 | Val: 0.3940


In [None]:
save_path = "/content/drive/MyDrive/unet3d_fault.pth"

# Besides the weights, store optimizer and AMP-scaler state so training can be
# resumed. "epoch" keeps its original meaning: the 0-based index of the last
# iteration of the training loop.
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scaler_state_dict": scaler.state_dict(),
    "epoch": epoch,
}, save_path)
