In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class FWIForward(nn.Module):
    '''Forward modeling: velocity map(s) -> simulated shot gathers (2-D acoustic FD).

    Args:
        ctx: dictionary that contains parameters for forward modeling, see FWM() for details.
            Explicit 'sx'/'gx' entries are given in grid indices and converted to distance
            (multiplied by 'dx'); when absent they are spread evenly over [0, n_grid - 1]
            using 'ns'/'ng'. The dict is shallow-copied, so the caller's dict is never
            modified and may safely be reused to build several models.
        device: torch device
        sample_temporal: keep every `sample_temporal`-th time step in the output
        sample_spatial: fraction of 'ng' receivers to place when 'gx' is not given
        normalize: if True, denormalize the input velocity map with v_denorm_func and
            normalize the returned seismic data with s_norm_func
        v_denorm_func: denormalization function for velocity map
        s_norm_func: normalization function for seismic data
    '''
    def __init__(self, ctx, device, sample_temporal=1, sample_spatial=1.0, normalize=True, v_denorm_func=None, s_norm_func=None):
        super(FWIForward, self).__init__()
        self.device = device
        self.normalize = normalize
        if normalize:
            self.v_denorm_func = v_denorm_func
            self.s_norm_func = s_norm_func
        self.sample_temporal = sample_temporal
        # Work on a copy: rescaling the caller's dict in place meant that a second model
        # built from the same dict multiplied the positions by dx a second time.
        ctx = dict(ctx)
        dx = ctx['dx']
        # Compute the locations of sources and receivers (grid index -> distance)
        if 'sx' in ctx:
            ctx['sx'] = np.array(ctx['sx']) * dx
        else:
            ctx['sx'] = np.linspace(0, ctx['n_grid'] - 1, num=ctx['ns']) * dx
        if 'gx' in ctx:
            ctx['gx'] = np.array(ctx['gx']) * dx
        else:
            ctx['gx'] = np.linspace(0, ctx['n_grid'] - 1, num=int(sample_spatial * ctx['ng'])) * dx
        self.ctx = ctx

    # Source func
    def ricker(self, f, dt, nt):
        '''Ricker wavelet (1 - 2*(pi*f*t)^2) * exp(-(pi*f*t)^2), delayed and padded to nt samples.

        Args:
            f: peak frequency
            dt: time step
            nt: number of output samples
        Returns:
            numpy array of length nt; the wavelet is truncated if it is longer than nt.
        '''
        nw = 2.2/f/dt
        nw = 2*np.floor(nw/2)+1   # odd length so the wavelet has a centre sample
        nc = np.floor(nw/2)
        k = np.arange(nw)

        alpha = (nc-k)*f*dt*np.pi
        beta = alpha ** 2
        w0 = (1-beta*2)*np.exp(-beta)
        w = np.zeros(nt)
        # Previously `w[:len(w0)] = w0` raised a shape error whenever nt < len(w0).
        n = min(nt, len(w0))
        w[:n] = w0[:n]
        return w

    # Absorbing boundary condition
    # UPFWI paper Equation 16
    # Collino & Tsogka (2001) paper Equation 20&21
    def get_Abc(self, vp, nbc, dx):
        '''Damping-coefficient map for the absorbing layer.

        Args:
            vp (tensor): padded velocity map, (N, 1, Depth, Width)
            nbc (int): thickness of the absorbing layer in grid cells
            dx: grid spacing
        Returns:
            tensor shaped like vp: zero in the interior, growing quadratically towards
            every outer edge inside the nbc-cell layer.
        '''
        device = vp.device
        dimrange = 1.0*torch.unsqueeze(torch.arange(nbc, device=device), dim=-1)   # (nbc, 1)
        damp = torch.zeros_like(vp, requires_grad=False)

        # Per-model minimum velocity sets the damping strength.
        velmin, _ = torch.min(vp.reshape(vp.shape[0], -1), dim=-1, keepdim=False)   # (N,)

        a = (nbc-1)*dx   # layer thickness
        kappa = 3.0 * velmin * np.log(1e7) / (2.0 * a)
        kappa = torch.unsqueeze(kappa, dim=0)
        kappa = torch.repeat_interleave(kappa, nbc, dim=0)   # (nbc, N)

        damp1d = kappa * (dimrange*dx/a) ** 2   # (nbc, N): 0 at the inner edge, kappa at the outer edge
        damp1d = damp1d.permute(1, 0).unsqueeze(1)   # (N, 1, nbc)

        # Top/bottom rows first, then left/right columns.
        # NOTE(review): the column assignments overwrite the four corners, so the corners
        # only carry the horizontal profile — confirm this is intended.
        damp[:, :, :nbc, :] = torch.repeat_interleave(torch.flip(damp1d, dims=[-1]).unsqueeze(-1), vp.shape[-1], dim=-1)
        damp[:, :, -nbc:, :] = torch.repeat_interleave(damp1d.unsqueeze(-1), vp.shape[-1], dim=-1)
        damp[:, :, :, :nbc] = torch.repeat_interleave(torch.flip(damp1d, dims=[-1]).unsqueeze(-2), vp.shape[-2], dim=-2)
        damp[:, :, :, -nbc:] = torch.repeat_interleave(damp1d.unsqueeze(-2), vp.shape[-2], dim=-2)
        return damp

    # Adjust index after adding abc
    def adj_sr(self, sx, sz, gx, gz, dx, nbc):
        '''Convert source/receiver distances to grid indices in the padded model.

        Returns:
            (isx int array, isz int, igx int array, igz int)
        '''
        isx = np.around(sx/dx)+nbc
        isz = np.around(sz/dx)+nbc

        igx = np.around(gx/dx)+nbc
        igz = np.around(gz/dx)+nbc
        return isx.astype('int'), int(isz), igx.astype('int'), int(igz)

    def FWM(self, v, nbc, dx, nt, dt, f, sx, sz, gx, gz, **kwargs):
        ''' Forward modeling
        2nd-order central FD in time, 4th-order central FD in space

        Args:
            v (tensor): velocity map already padded by nbc cells per side, (N, 1, Depth, Width)
            nbc (int): # of grids for boundary condition
            dx (float): grid size
            nt (int): number of time steps
            dt (float): time interval
            f (float): source peak frequency
            sx (numpy array): source x-positions (same distance units as dx)
            sz (float): source depth (same distance units as dx)
            gx (numpy array): receiver x-positions (same distance units as dx)
            gz (float): receiver depth (same distance units as dx)
            **kwargs: remaining ctx entries (e.g. n_grid, ns, ng); ignored
        Returns:
            tensor (N, len(sx), ceil(nt / sample_temporal), len(gx)) of recorded wavefield
        '''
        device = v.device
        src = self.ricker(f, dt, nt)
        alpha = (v*dt/dx) ** 2

        abc = self.get_Abc(v, nbc, dx)
        kappa = abc*dt

        # 4th-order Laplacian stencil weights: centre, +-1 and +-2 neighbours.
        c1 = -2.5
        c2 = 4.0/3.0
        c3 = -1.0/12.0

        temp1 = 2+2*c1*alpha-kappa
        temp2 = 1-kappa
        beta_dt = (v*dt) ** 2

        ns = len(sx)
        isx, isz, igx, igz = self.adj_sr(sx, sz, gx, gz, dx, nbc)
        # Inject every source into its own shot (channel) in one indexed update.
        shot_idx = torch.arange(ns, device=device)
        src_col = torch.as_tensor(isx, device=device)
        src_scale = beta_dt[:, 0, isz, src_col]   # (N, ns), loop-invariant
        rec_rows = [igz]*len(igx)

        seis = []
        # No requires_grad on the initial wavefields: gradients w.r.t. v still flow through
        # alpha/temp1/temp2/beta_dt, but a plain (no-grad) v no longer builds an nt-step graph.
        shape = (v.shape[0], ns, v.shape[2], v.shape[3])
        p0 = torch.zeros(shape, device=device, dtype=v.dtype)
        p1 = torch.zeros(shape, device=device, dtype=v.dtype)
        for i in range(nt):
            p = (temp1*p1 - temp2*p0 + alpha *
                (c2*(torch.roll(p1, 1, dims=-2) + torch.roll(p1, -1, dims=-2) + torch.roll(p1, 1, dims=-1) + torch.roll(p1, -1, dims=-1))
                +c3*(torch.roll(p1, 2, dims=-2) + torch.roll(p1, -2, dims=-2) + torch.roll(p1, 2, dims=-1) + torch.roll(p1, -2, dims=-1))
                ))
            p[:, shot_idx, isz, src_col] = p[:, shot_idx, isz, src_col] + src_scale * src[i]
            if i % self.sample_temporal == 0:
                seis.append(torch.unsqueeze(p[:, :, rec_rows, igx], dim=2))
            p0 = p1
            p1 = p
        return torch.cat(seis, dim=2)


    def forward(self, v):
        '''Simulate shot gathers for velocity map(s) v of shape (N, 1, Depth, Width).'''
        if self.normalize:
            v = self.v_denorm_func(v)
        # Extend the model by nbc cells on every side to host the absorbing layer.
        v_pad = F.pad(v, (self.ctx['nbc'],) * 4, mode='replicate')
        s = self.FWM(v_pad, **self.ctx)
        return self.s_norm_func(s) if self.normalize else s


In [20]:
import matplotlib.pyplot as plt
from scripts.data_utils import data_trans
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Simulation / acquisition settings shared by every scenario (previously copy-pasted 3x).
# Explicit 'sx' values are grid columns; FWIForward multiplies them by dx.
# 'sz'/'gz' are divided by dx when converted to grid rows, i.e. they are distances.
_ctx_base = {
    'n_grid': 70, # Number of grid points in one dimension (depth or width)
    'nt': 1000, # Number of time steps (1 s at dt = 1e-3 s)
    'dx': 10, # Grid spacing (meters)
    'nbc': 120, # Number of boundary condition layers
    'dt': 1e-3, # Time interval (seconds)
    'f': 15, # Source frequency (Hz)
    'sz': 10, # Source depth
    'gz': 10, # Receiver depth
    'ng': 70, # Number of receivers
}

# 5 sources spread evenly across the surface (positions computed by FWIForward).
ctx_5 = {**_ctx_base, 'ns': 5}
# 6 sources at hand-picked grid columns.
ctx_6 = {**_ctx_base, 'ns': 6, 'sx': [10, 15, 20, 34, 52, 69]}
# 10 sources spread evenly across the surface.
ctx_10 = {**_ctx_base, 'ns': 10}

# normalize=False: raw velocities in, raw seismic data out (the norm hooks are unused).
_forward_kwargs = dict(normalize=False,
                       v_denorm_func=data_trans.v_denormalize,
                       s_norm_func=data_trans.s_normalize_none)
fwi_forward_10 = FWIForward(ctx_10, device, **_forward_kwargs)
fwi_forward_6 = FWIForward(ctx_6, device, **_forward_kwargs)
fwi_forward_5 = FWIForward(ctx_5, device, **_forward_kwargs)

Using device: cuda


In [14]:
import os
### here we select different velocity models
# Split each family's test set into consecutive 10-sample subsets saved as
# dataset_2 ... dataset_10 (dataset_1 is not written here — presumably created elsewhere).
family = ["CF","CV","FF","FV"]
samples_per_dataset = 10

# Load each family file once (it was previously re-read for every dataset index).
for fam in family:
    vm_all = np.load(f"data/velocity_model/{fam}_test.npy")
    for i in range(1, 10):
        start, stop = i * samples_per_dataset, (i + 1) * samples_per_dataset
        # Fail loudly instead of silently saving a truncated/empty subset.
        if vm_all.shape[0] < stop:
            raise ValueError(f"{fam}_test.npy has {vm_all.shape[0]} samples; "
                             f"dataset_{i+1} needs rows {start}:{stop}")
        os.makedirs(f"dataset_{i+1}/velocity_model", exist_ok=True)
        np.save(f"dataset_{i+1}/velocity_model/{fam}.npy", vm_all[start:stop])

In [23]:
### Creating the 10 source scenario ###
import os
# Full 10-shot gathers (fwi_forward_10) for every velocity model in dataset_2 ... dataset_10.
for i in range(1,10):
    out_dir = f"dataset_{i+1}/seismic_data/10_sources"
    os.makedirs(out_dir, exist_ok=True)

    for stylies in ['CF','CV','FF','FV']:
        vm_sample = np.load(f"dataset_{i+1}/velocity_model/{stylies}.npy")
        print(f"Processing style: {stylies}, shape: {vm_sample.shape}")
        seismic_data = []
        # Pure data generation: no_grad stops autograd from recording the whole
        # nt-step time loop of every simulation (large memory and time cost).
        with torch.no_grad():
            for j in range(vm_sample.shape[0]):
                s_data = fwi_forward_10(torch.tensor(vm_sample[j:j+1,:,:,:], dtype=torch.float32, device=device))
                seismic_data.append(s_data.cpu().numpy())
        seismic_data = np.concatenate(seismic_data, axis=0)
        print(f"Seismic data shapes: {seismic_data.shape}")
        np.save(f"{out_dir}/{stylies}.npy", seismic_data)

Processing style: CF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: CV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: FF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: FV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: CF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: CV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: FF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: FV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: CF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: CV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: FF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 10, 1000, 70)
Processing style: FV, shape: (10, 1, 70, 70)
Seismic data shapes:

In [24]:
### Creating the 6 source scenario ###
import os
# Full 6-shot gathers (fwi_forward_6) for every velocity model in dataset_2 ... dataset_10.
for i in range(1,10):
    out_dir = f"dataset_{i+1}/seismic_data/6_sources"
    os.makedirs(out_dir, exist_ok=True)

    for stylies in ['CF','CV','FF','FV']:
        vm_sample = np.load(f"dataset_{i+1}/velocity_model/{stylies}.npy")
        print(f"Processing style: {stylies}, shape: {vm_sample.shape}")
        seismic_data = []
        # Pure data generation: no_grad stops autograd from recording the whole
        # nt-step time loop of every simulation (large memory and time cost).
        with torch.no_grad():
            for j in range(vm_sample.shape[0]):
                s_data = fwi_forward_6(torch.tensor(vm_sample[j:j+1,:,:,:], dtype=torch.float32, device=device))
                seismic_data.append(s_data.cpu().numpy())
        seismic_data = np.concatenate(seismic_data, axis=0)
        print(f"Seismic data shapes: {seismic_data.shape}")
        np.save(f"{out_dir}/{stylies}.npy", seismic_data)

Processing style: CF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: CV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: FF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: FV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: CF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: CV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: FF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: FV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: CF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: CV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: FF, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 1000, 70)
Processing style: FV, shape: (10, 1, 70, 70)
Seismic data shapes: (10, 6, 10

In [None]:
### Creating the 2C dataset ###
import os
# Two clients from the 6-shot survey (fwi_forward_6):
#   client1 = shots 0-2 with receivers 0-34, client2 = shots 3-5 with receivers 35-69.
# NOTE(review): these are the same simulations as the 6_sources cell; slicing the saved
# 6_sources/{style}.npy would avoid recomputing them.
for i in range(1,10):
    os.makedirs(f"dataset_{i+1}/seismic_data/2C/client1", exist_ok=True)
    os.makedirs(f"dataset_{i+1}/seismic_data/2C/client2", exist_ok=True)

    for stylies in ['CF','CV','FF','FV']:
        print(f"Processing style: {stylies}")
        vm_sample = np.load(f"dataset_{i+1}/velocity_model/{stylies}.npy")
        seismic_data_client1 = []
        seismic_data_client2 = []
        # Data generation only: no_grad keeps autograd from recording the time loop.
        with torch.no_grad():
            for j in range(vm_sample.shape[0]):
                s_data = fwi_forward_6(torch.tensor(vm_sample[j:j+1,:,:,:], dtype=torch.float32, device=device))
                s_data_numpy = s_data.cpu().numpy()
                seismic_data_client1.append(s_data_numpy[:,0:3,:,:35])
                seismic_data_client2.append(s_data_numpy[:,3:6,:,35:])
        seismic_data_client1 = np.concatenate(seismic_data_client1, axis=0)
        seismic_data_client2 = np.concatenate(seismic_data_client2, axis=0)
        print(f"Seismic data shapes: client1 {seismic_data_client1.shape}, client2 {seismic_data_client2.shape}")
        np.save(f"dataset_{i+1}/seismic_data/2C/client1/{stylies}.npy", seismic_data_client1)
        np.save(f"dataset_{i+1}/seismic_data/2C/client2/{stylies}.npy", seismic_data_client2)

Processing style: CF
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: CV
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: FF
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: FV
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: CF
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: CV
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: FF
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: FV
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: CF
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: CV
Seismic data shapes: client1 (10, 3, 1000, 35), client2 (10, 3, 1000, 35)
Processing style: FF
Seismic data shapes: client1 

In [None]:
### Creating the 3A dataset ###
import os
# Three clients from the 10-shot survey (fwi_forward_10); receivers split 24/23/23.
# Shot ranges overlap by one: shot 3 is shared by clients 1&2, shot 6 by clients 2&3.
# NOTE(review): 3B uses disjoint shots while 2A is disjoint and 2B overlapping —
# confirm the A/B naming is intended.
for i in range(1,10):
    os.makedirs(f"dataset_{i+1}/seismic_data/3A/client1", exist_ok=True)
    os.makedirs(f"dataset_{i+1}/seismic_data/3A/client2", exist_ok=True)
    os.makedirs(f"dataset_{i+1}/seismic_data/3A/client3", exist_ok=True)

    for stylies in ['CF','CV','FF','FV']:
        print(f"Processing style: {stylies}")
        vm_sample = np.load(f"dataset_{i+1}/velocity_model/{stylies}.npy")
        seismic_data_client1 = []
        seismic_data_client2 = []
        seismic_data_client3 = []
        # Data generation only: no_grad keeps autograd from recording the time loop.
        with torch.no_grad():
            for j in range(vm_sample.shape[0]):
                s_data = fwi_forward_10(torch.tensor(vm_sample[j:j+1,:,:,:], dtype=torch.float32, device=device))
                s_data_numpy = s_data.cpu().numpy()
                seismic_data_client1.append(s_data_numpy[:,0:4,:,:24])
                seismic_data_client2.append(s_data_numpy[:,3:7,:,24:47])
                seismic_data_client3.append(s_data_numpy[:,6:10,:,47:70])
        seismic_data_client1 = np.concatenate(seismic_data_client1, axis=0)
        seismic_data_client2 = np.concatenate(seismic_data_client2, axis=0)
        seismic_data_client3 = np.concatenate(seismic_data_client3, axis=0)
        # Implicit concatenation: a backslash inside the f-string leaked the next line's
        # indentation into the log message.
        print(f"Seismic data shapes: client1 {seismic_data_client1.shape}, "
              f"client2 {seismic_data_client2.shape}, "
              f"client3 {seismic_data_client3.shape}")
        np.save(f"dataset_{i+1}/seismic_data/3A/client1/{stylies}.npy", seismic_data_client1)
        np.save(f"dataset_{i+1}/seismic_data/3A/client2/{stylies}.npy", seismic_data_client2)
        np.save(f"dataset_{i+1}/seismic_data/3A/client3/{stylies}.npy", seismic_data_client3)

Processing style: CF
Seismic data shapes: client1 (10, 4, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 4, 1000, 23)
Processing style: CV
Seismic data shapes: client1 (10, 4, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 4, 1000, 23)
Processing style: FF
Seismic data shapes: client1 (10, 4, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 4, 1000, 23)
Processing style: FV
Seismic data shapes: client1 (10, 4, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 4, 1000, 23)
Processing style: CF
Seismic data shapes: client1 (10, 4, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 4, 1000, 23)
Processing style: CV
Seismic data shapes: client1 (10, 4, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 4, 1000, 23)
Processing style: FF
Seismic data shapes: client1 (10, 4, 1000, 24),        

In [None]:
### Creating the 3B dataset ###
import os
import numpy as np
# Three clients from the 10-shot survey (fwi_forward_10) with disjoint shots:
#   client1 = shots 0-2 / receivers 0-23, client2 = shots 3-6 / receivers 24-46,
#   client3 = shots 7-9 / receivers 47-69.
for i in range(1,10):
    os.makedirs(f"dataset_{i+1}/seismic_data/3B/client1", exist_ok=True)
    os.makedirs(f"dataset_{i+1}/seismic_data/3B/client2", exist_ok=True)
    os.makedirs(f"dataset_{i+1}/seismic_data/3B/client3", exist_ok=True)

    for stylies in ['CF','CV','FF','FV']:
        print(f"Processing style: {stylies}")
        vm_sample = np.load(f"dataset_{i+1}/velocity_model/{stylies}.npy")
        seismic_data_client1 = []
        seismic_data_client2 = []
        seismic_data_client3 = []
        # Data generation only: no_grad keeps autograd from recording the time loop.
        with torch.no_grad():
            for j in range(vm_sample.shape[0]):
                s_data = fwi_forward_10(torch.tensor(vm_sample[j:j+1,:,:,:], dtype=torch.float32, device=device))
                s_data_numpy = s_data.cpu().numpy()
                seismic_data_client1.append(s_data_numpy[:,0:3,:,:24])
                seismic_data_client2.append(s_data_numpy[:,3:7,:,24:47])
                seismic_data_client3.append(s_data_numpy[:,7:10,:,47:70])
        seismic_data_client1 = np.concatenate(seismic_data_client1, axis=0)
        seismic_data_client2 = np.concatenate(seismic_data_client2, axis=0)
        seismic_data_client3 = np.concatenate(seismic_data_client3, axis=0)
        # Implicit concatenation: a backslash inside the f-string leaked the next line's
        # indentation into the log message.
        print(f"Seismic data shapes: client1 {seismic_data_client1.shape}, "
              f"client2 {seismic_data_client2.shape}, "
              f"client3 {seismic_data_client3.shape}")
        np.save(f"dataset_{i+1}/seismic_data/3B/client1/{stylies}.npy", seismic_data_client1)
        np.save(f"dataset_{i+1}/seismic_data/3B/client2/{stylies}.npy", seismic_data_client2)
        np.save(f"dataset_{i+1}/seismic_data/3B/client3/{stylies}.npy", seismic_data_client3)

Processing style: CF
Seismic data shapes: client1 (10, 3, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 3, 1000, 23)
Processing style: CV
Seismic data shapes: client1 (10, 3, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 3, 1000, 23)
Processing style: FF
Seismic data shapes: client1 (10, 3, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 3, 1000, 23)
Processing style: FV
Seismic data shapes: client1 (10, 3, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 3, 1000, 23)
Processing style: CF
Seismic data shapes: client1 (10, 3, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 3, 1000, 23)
Processing style: CV
Seismic data shapes: client1 (10, 3, 1000, 24),                 client2 (10, 4, 1000, 23),                 client3 (10, 3, 1000, 23)
Processing style: FF
Seismic data shapes: client1 (10, 3, 1000, 24),        

KeyboardInterrupt: 

In [None]:
### Creating the 2A dataset ###
import os
# Two clients from the 5-shot survey (fwi_forward_5) with disjoint shots.
for i in range(1,10):
    os.makedirs(f"dataset_{i+1}/seismic_data/2A/client1", exist_ok=True)
    os.makedirs(f"dataset_{i+1}/seismic_data/2A/client2", exist_ok=True)

    for stylies in ['CF','CV','FF','FV']:
        print(f"Processing style: {stylies}")
        vm_sample = np.load(f"dataset_{i+1}/velocity_model/{stylies}.npy")
        seismic_data_client1 = []
        seismic_data_client2 = []
        # Data generation only: no_grad keeps autograd from recording the time loop.
        with torch.no_grad():
            for j in range(vm_sample.shape[0]):
                s_data = fwi_forward_5(torch.tensor(vm_sample[j:j+1,:,:,:], dtype=torch.float32, device=device))
                s_data_numpy = s_data.cpu().numpy()
                # 2A: client1 uses sources [:3], receivers [:35]; client2 uses [3:5], receivers [35:]
                seismic_data_client1.append(s_data_numpy[:,0:3,:,:35])
                seismic_data_client2.append(s_data_numpy[:,3:5,:,35:])
        seismic_data_client1 = np.concatenate(seismic_data_client1, axis=0)
        seismic_data_client2 = np.concatenate(seismic_data_client2, axis=0)
        print(f"Seismic data shapes: client1 {seismic_data_client1.shape}, client2 {seismic_data_client2.shape}")
        np.save(f"dataset_{i+1}/seismic_data/2A/client1/{stylies}.npy", seismic_data_client1)
        np.save(f"dataset_{i+1}/seismic_data/2A/client2/{stylies}.npy", seismic_data_client2)


In [None]:
### Creating the 2B dataset ###
import os
# Two clients from the 5-shot survey (fwi_forward_5); shot 2 is shared by both clients.
for i in range(1,10):
    os.makedirs(f"dataset_{i+1}/seismic_data/2B/client1", exist_ok=True)
    os.makedirs(f"dataset_{i+1}/seismic_data/2B/client2", exist_ok=True)

    for stylies in ['CF','CV','FF','FV']:
        print(f"Processing style: {stylies}")
        vm_sample = np.load(f"dataset_{i+1}/velocity_model/{stylies}.npy")
        seismic_data_client1 = []
        seismic_data_client2 = []
        # Data generation only: no_grad keeps autograd from recording the time loop.
        with torch.no_grad():
            for j in range(vm_sample.shape[0]):
                s_data = fwi_forward_5(torch.tensor(vm_sample[j:j+1,:,:,:], dtype=torch.float32, device=device))
                s_data_numpy = s_data.cpu().numpy()
                # 2B: client1 uses sources [:3], receivers [:35]; client2 uses [2:5], receivers [35:]
                seismic_data_client1.append(s_data_numpy[:,0:3,:,:35])
                seismic_data_client2.append(s_data_numpy[:,2:5,:,35:])
        seismic_data_client1 = np.concatenate(seismic_data_client1, axis=0)
        seismic_data_client2 = np.concatenate(seismic_data_client2, axis=0)
        print(f"Seismic data shapes: client1 {seismic_data_client1.shape}, client2 {seismic_data_client2.shape}")
        np.save(f"dataset_{i+1}/seismic_data/2B/client1/{stylies}.npy", seismic_data_client1)
        np.save(f"dataset_{i+1}/seismic_data/2B/client2/{stylies}.npy", seismic_data_client2)
