In [1]:
RUN_TRAIN = True # bfloat16 or float32 recommended
RUN_VALID = True
RUN_TEST  = False

import torch
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
    raise RuntimeError("Requires >= 2 GPUs with CUDA enabled.")

try: 
    import monai
except: 
    !pip install --no-deps monai -q

[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m2.7/2.7 MB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h

In [2]:
# Read the Weights & Biases API key from the Kaggle secret labelled "wandb_key".
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
wandb_key = user_secrets.get_secret("wandb_key")

# ConvNeXt Baseline Notebook

This notebook builds on the [ConvNeXt - Full Resolution Baseline](https://www.kaggle.com/code/brendanartley/convnext-full-resolution-baseline) notebook.


In [25]:
%%writefile _cfg.py

from types import SimpleNamespace
import torch

cfg = SimpleNamespace()
cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cfg.local_rank = 0
cfg.seed = 123
cfg.subsample = None           # None = full dataset; int n = first n rows per ("dataset", "fold") group
cfg.subsample_fraction = 0.3   # fraction of rows sampled per ("dataset", "fold") group, or None; takes precedence over `subsample`

cfg.backbone = "convnext_small.fb_in22k_ft_in1k"
cfg.ema = True
cfg.ema_decay = 0.99

cfg.epochs = 5
cfg.batch_size = 32
cfg.batch_size_val = 32

cfg.early_stopping = {"patience": 3, "streak": 0}
cfg.logging_steps = 100

cfg.use_wandb = True
cfg.wandb_project = "seismic-inversion"

# Run identifier
cfg.run_name = "convnext_0.3_[-2]layer"

# Precision
cfg.mixed_precision = False            # True for A100, False for T4
cfg.precision_dtype = torch.float32    # torch.float16 or torch.bfloat16 when mixed_precision=True
# Derived, not hand-set: loss scaling is only needed for float16 autocast
# (bfloat16 has float32's exponent range and does not need a GradScaler).
cfg.use_grad_scaler = cfg.mixed_precision and cfg.precision_dtype == torch.float16

cfg.compile_model = True

# GPU logic
cfg.distributed = True  # NOTE(review): presumably overridden by auto-detection in the training script (not in this file)

# 72x72 vs full-resolution mode
cfg.use_72x72_mode = False  # True = train on preprocessed 72x72 files, False = full resolution
cfg.data_path_72x72 = "/kaggle/input/openfwi-preprocessed-72x72/openfwi_72x72/"

# 256x72 "smart" preprocessing (crop early arrivals + resample, see _dataset.py)
cfg.use_smart_256x72_mode = False

cfg.smoothness_loss = False

# Augmentation settings
cfg.use_augmentations = False        # Enable/disable noise + scale training augmentations
cfg.aug_noise_prob = 0.4             # Probability of noise augmentation
cfg.aug_noise_level = 0.005          # Noise std as a fraction of the sample's std
cfg.aug_scale_prob = 0.3             # Probability of scale augmentation
cfg.aug_scale_range = (0.99, 1.01)   # Uniform scale range (min, max)

# TTA settings
cfg.enhanced_tta = False             # Enhanced TTA vs basic flip TTA
cfg.tta_noise_level = 0.005          # TTA noise level as a fraction of signal std
cfg.tta_scale_values = [0.98, 1.02]  # Fixed scale values for TTA

# Decoder layer selection
cfg.decoder_layer_index = -2    # -1 = last decoder output (x[-1]), -2 = second-to-last (x[-2]), etc.
cfg.use_feature_fusion = False  # True = fuse multiple layers, False = single layer

# Feature fusion settings (only used if use_feature_fusion = True)
cfg.fusion_layers = [-1, -2]     # Which layers to combine
cfg.fusion_weights = [0.7, 0.3]  # One weight per entry of fusion_layers

# Sanity checks on mutually dependent settings: fail at import time, not mid-training.
assert not (cfg.use_72x72_mode and cfg.use_smart_256x72_mode), \
    "use_72x72_mode and use_smart_256x72_mode are mutually exclusive"
assert len(cfg.fusion_layers) == len(cfg.fusion_weights), \
    "fusion_weights must have exactly one entry per fusion_layers entry"



Overwriting _cfg.py


### Dataset
 

The input data now comes from the [openfwi_float16_1](https://www.kaggle.com/datasets/egortrushin/open-wfi-1) and [openfwi_float16_2](https://www.kaggle.com/datasets/egortrushin/open-wfi-2) datasets; the fold assignment is still read from `folds.csv` in the openfwi-preprocessed-72x72 dataset.

In [26]:
%%writefile _dataset.py

import os
import glob

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch

from _cfg import cfg

import torch
import torch.nn.functional as F

def smart_preprocess_256x72(x, crop_samples=512, out_time=256, out_receivers=70, pad_receivers=1):
    """
    Early-arrival preprocessing: crop, area-resample, then replicate-pad receivers.

    With the defaults, a (sources, time, receivers) sample becomes
    (sources, 256, 72):
      1. keep the first ``min(512, time)`` time samples,
      2. resample (time, receivers) to 256 x 70 with area interpolation,
      3. replicate one column on each side of the receiver axis (70 -> 72).

    Args:
        x: numpy array, shape (sources, time, receivers) - a single sample.
           Any stride layout is accepted, including flipped views such as
           ``x[::-1, :, ::-1]``.
        crop_samples: number of leading time samples to keep.
        out_time: time-axis length after resampling.
        out_receivers: receiver-axis length after resampling (before padding).
        pad_receivers: replicated columns added on each side of the receiver axis.

    Returns:
        float16 numpy array, shape
        (sources, out_time, out_receivers + 2 * pad_receivers).
    """
    # torch.from_numpy rejects negative strides, e.g. views produced by x[::-1].
    # ascontiguousarray copies only when the array is not already contiguous.
    x_tensor = torch.from_numpy(np.ascontiguousarray(x)).float()

    # Keep only the early part of each trace.
    n_keep = min(crop_samples, x_tensor.shape[1])
    x_cropped = x_tensor[:, :n_keep, :]

    # Treat each source as a 1-channel image so interpolate resizes (time, receivers).
    x_resized = F.interpolate(
        x_cropped.unsqueeze(1), size=(out_time, out_receivers), mode='area'
    ).squeeze(1)

    # Pad the receiver (last) axis only, by edge replication.
    x_padded = F.pad(x_resized, (pad_receivers, pad_receivers, 0, 0), mode='replicate')

    return x_padded.numpy().astype(np.float16)

def inputs_files_to_output_files(input_files):
    """
    Map input (seismic) file paths to their label (velocity) file paths.

    Only the file name is rewritten: a leading ``seis`` becomes ``vel`` and a
    leading ``data`` becomes ``model``, e.g.
    ``.../FlatFault_A/seis2_1_0.npy`` -> ``.../FlatFault_A/vel2_1_0.npy`` and
    ``.../CurveVel_A/data1.npy`` -> ``.../CurveVel_A/model1.npy``.
    Directory components are left untouched, so a root such as ``/data/...``
    is not corrupted.

    Args:
        input_files: iterable of input file paths.

    Returns:
        list of label file paths, in the same order as ``input_files``.
    """
    renames = (("seis", "vel"), ("data", "model"))

    def _to_label_path(path):
        # Rewrite only the file-name prefix; keep the folder as-is.
        folder, name = os.path.split(path)
        for src, dst in renames:
            if name.startswith(src):
                name = dst + name[len(src):]
                break
        return os.path.join(folder, name)

    return [_to_label_path(f) for f in input_files]

def get_72x72_data_files(data_path):
    """
    Collect (input, label) file pairs for 72x72 mode and split them into train / valid.

    Inputs are ``<data_path>/<family>/*.npy`` files whose name starts with
    ``seis`` or ``data``. Labels are derived with ``inputs_files_to_output_files``.
    A pair goes to validation when its ``<family>/<file name>`` appears in the
    fixed list below (same split as HGNet). Inputs are sorted so every process
    (e.g. each distributed rank) sees the same order.

    Args:
        data_path: root folder containing one sub-folder per dataset family.

    Returns:
        (train_inputs, train_outputs, valid_inputs, valid_outputs), lists of paths.
    """
    # Filter on the file name only. A substring test on the whole path would
    # also admit label files whenever a directory contains "/data" or "/seis".
    all_inputs = sorted(
        f for f in glob.glob(os.path.join(data_path, "*", "*.npy"))
        if os.path.basename(f).startswith(("seis", "data"))
    )
    all_outputs = inputs_files_to_output_files(all_inputs)
    assert all(x != y for x, y in zip(all_inputs, all_outputs)), \
        "label path must differ from input path"

    # Validation filenames (same split as HGNet), keyed as "<family>/<file name>".
    val_fpaths = {
        'CurveFault_A/seis2_1_0.npy', 'CurveFault_A/seis2_1_1.npy',
        'CurveFault_B/seis6_1_0.npy', 'CurveFault_B/seis6_1_1.npy',
        'CurveVel_A/data1.npy', 'CurveVel_A/data10.npy',
        'CurveVel_B/data1.npy', 'CurveVel_B/data10.npy',
        'FlatFault_A/seis2_1_0.npy', 'FlatFault_A/seis2_1_1.npy',
        'FlatFault_B/seis6_1_0.npy', 'FlatFault_B/seis6_1_1.npy',
        'FlatVel_A/data1.npy', 'FlatVel_A/data10.npy',
        'FlatVel_B/data1.npy', 'FlatVel_B/data10.npy',
        'Style_A/data1.npy', 'Style_A/data10.npy',
        'Style_B/data1.npy', 'Style_B/data10.npy',
    }

    train_inputs, train_outputs = [], []
    valid_inputs, valid_outputs = [], []

    for inp, out in zip(all_inputs, all_outputs):
        # Exact family/name key, so "XCurveVel_A/data1.npy" does not match "CurveVel_A/data1.npy".
        key = os.path.basename(os.path.dirname(inp)) + "/" + os.path.basename(inp)
        if key in val_fpaths:
            valid_inputs.append(inp)
            valid_outputs.append(out)
        else:
            train_inputs.append(inp)
            train_outputs.append(out)

    return train_inputs, train_outputs, valid_inputs, valid_outputs

class CustomDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over memory-mapped OpenFWI ``.npy`` files.

    Every loaded file is treated as holding ``SAMPLES_PER_FILE`` (500) samples,
    so flat index ``idx`` addresses sample ``idx % 500`` of file ``idx // 500``.

    Args:
        cfg: configuration namespace (see ``_cfg.py``).
        mode: ``"train"`` selects the training split and enables augmentation;
            any other value selects the validation split without augmentation.
    """

    # Samples per .npy file; assumed constant across all files.
    SAMPLES_PER_FILE = 500

    def __init__(
        self,
        cfg,
        mode = "train",
    ):
        self.cfg = cfg
        self.mode = mode

        loader = self.load_72x72_metadata if cfg.use_72x72_mode else self.load_metadata
        self.data, self.labels, self.records = loader()

    def load_metadata(self):
        """
        Load the full-resolution files listed in ``folds.csv``.

        Fold 0 is the validation split; every other fold is training.
        Optional subsampling is applied per ("dataset", "fold") group:
        ``cfg.subsample_fraction`` (random fraction) takes precedence over
        ``cfg.subsample`` (first n rows).

        Returns:
            (data, labels, records): memory-mapped input arrays, memory-mapped
            label arrays, and the dataset-family name of each file.

        Raises:
            FileNotFoundError: if a listed file is in neither dataset root.
        """
        df = pd.read_csv("/kaggle/input/openfwi-preprocessed-72x72/folds.csv")

        if getattr(self.cfg, 'subsample_fraction', None) is not None:
            # Sample by fraction within each dataset family
            df = df.groupby(["dataset", "fold"]).apply(
                lambda x: x.sample(frac=self.cfg.subsample_fraction, random_state=self.cfg.seed)
            ).reset_index(drop=True)
        elif self.cfg.subsample is not None:
            # Fixed-number sampling
            df = df.groupby(["dataset", "fold"]).head(self.cfg.subsample)

        if self.mode == "train":
            df = df[df["fold"] != 0]
        else:
            df = df[df["fold"] == 0]

        roots = [
            "/kaggle/input/open-wfi-1/openfwi_float16_1/",
            "/kaggle/input/open-wfi-2/openfwi_float16_2/",
        ]
        data, labels, records = [], [], []
        mmap_mode = "r"

        for _, row in tqdm(df.iterrows(), total=len(df), disable=self.cfg.local_rank != 0):
            row = row.to_dict()
            rel_path = row["data_fpath"]
            family = rel_path.split("/")[0]
            fname = rel_path.split("/")[-1]

            # Per root, look at the listed path, then one sub-directory deeper.
            candidates = []
            for root in roots:
                candidates += glob.glob(os.path.join(root, rel_path))
                candidates += glob.glob(os.path.join(root, family, "*", fname))
            if not candidates:
                raise FileNotFoundError(f"No input file found for {rel_path!r} under {roots}")
            farr = candidates[0]

            # Label path: every 'seis'->'vel' and 'data'->'model' in the full path,
            # presumably covering a data/ -> model/ sub-directory as well as the
            # file name. The roots above must not contain these substrings.
            flbl = farr.replace('seis', 'vel').replace('data', 'model')

            data.append(np.load(farr, mmap_mode=mmap_mode))
            labels.append(np.load(flbl, mmap_mode=mmap_mode))
            records.append(row["dataset"])

        return data, labels, records

    def load_72x72_metadata(self):
        """
        Load the preprocessed 72x72 files (split by ``get_72x72_data_files``).

        Returns:
            (data, labels, records) as in ``load_metadata``; the record is the
            parent folder name of each input file.
        """
        train_inputs, train_outputs, valid_inputs, valid_outputs = get_72x72_data_files(self.cfg.data_path_72x72)

        if self.mode == "train":
            input_files, output_files = train_inputs, train_outputs
        else:
            input_files, output_files = valid_inputs, valid_outputs

        data, labels, records = [], [], []

        for input_file, output_file in tqdm(zip(input_files, output_files), total=len(input_files), disable=self.cfg.local_rank != 0):
            data.append(np.load(input_file, mmap_mode='r'))
            labels.append(np.load(output_file, mmap_mode='r'))
            records.append(input_file.split("/")[-2])

        return data, labels, records

    def _augment(self, x, y):
        """
        Training-time augmentation. Returns the (possibly new) ``(x, y)``.

        Random draws happen in a fixed order (flip, then noise, then scale),
        so results are reproducible under a fixed NumPy seed.
        """
        # Horizontal flip (50%): reverse the source and receiver axes of the input,
        # and the last (width) axis of the label. `y[..., ::-1]` flips the width for
        # both (H, W) and (1, H, W) labels. `y[:, ::-1]` would flip depth for 3-D labels.
        if np.random.random() < 0.5:
            x = x[::-1, :, ::-1]
            y = y[..., ::-1]

        if self.cfg.use_augmentations:
            # Noise augmentation, std relative to the sample's std.
            if np.random.random() < self.cfg.aug_noise_prob:
                noise_level = self.cfg.aug_noise_level * np.std(x)
                noise = np.random.normal(0, noise_level, x.shape).astype(x.dtype)
                x_aug = x + noise
                # Accept only finite results. Keeping the untouched x is the real
                # "revert"; subtracting the noise back cannot undo an inf.
                if np.isfinite(x_aug).all():
                    x = x_aug
                else:
                    print(f"Warning: NaN/inf detected in noise augmentation, reverting")

            # Global amplitude scaling.
            if np.random.random() < self.cfg.aug_scale_prob:
                scale = np.random.uniform(self.cfg.aug_scale_range[0], self.cfg.aug_scale_range[1])
                x_aug = x * scale
                if np.isfinite(x_aug).all():
                    x = x_aug
                else:
                    print(f"Warning: NaN/inf detected in scale augmentation, reverting")

        return x, y

    def __getitem__(self, idx):
        row_idx, col_idx = divmod(idx, self.SAMPLES_PER_FILE)

        x = self.data[row_idx][col_idx, ...]
        y = self.labels[row_idx][col_idx, ...]

        # Augmentations run before preprocessing.
        if self.mode == "train":
            x, y = self._augment(x, y)

        # Materialise: detach from the memory map and drop the negative strides left by flips.
        x = x.copy()
        y = y.copy()

        if self.cfg.use_smart_256x72_mode:
            x = smart_preprocess_256x72(x)

        if not self.cfg.mixed_precision:
            x = x.astype(np.float32)
            y = y.astype(np.float32)

        return x, y

    def __len__(self):
        return len(self.records) * self.SAMPLES_PER_FILE

Overwriting _dataset.py


# Model

The model uses the `ConvNeXt` backbone from timm. See more info on this backbone [here](https://huggingface.co/timm/convnext_small.fb_in22k_ft_in1k) and the original paper [here](https://arxiv.org/abs/2201.03545). We modify the stem to aggressively downsample the height (the time axis), replace the normalization layers with `InstanceNorm2d`, and replace all activations with `GELU`.

### Encoder

For the U-Net, we typically want the encoder to downsample by a factor of 2 at each stage. This works best when the input is square, so that as little padding as possible is needed. In the original notebook, we made the input square by interpolating the input data. This worked okay, but we lost a lot of detail as a result. Here we instead rely on the stem to downsample using convolutions. See the `Net._update_stem_full_res()` method (and its 72x72 / 256x72 variants) for more details.

#### Normalization

Most CNNs use `BatchNorm2d`, which relies on batch statistics when computing normalization. However, replacing this with a batch-independent normalization layer like `InstanceNorm2d` or `LayerNorm` can improve convergence speed and stabilize validation performance. ConvNeXt uses `LayerNorm` by default, but we use `InstanceNorm2d` instead.

Since normalization is now independent of batch statistics, smaller batch sizes can be used without a drop in performance. For example, a batch size of 16 uses ~9 GB of VRAM during training.

### Decoder

The decoder is mostly the same as in the original notebook. One small change is that we do not use intermediate convolutions here.

In [27]:
%%writefile _model.py
from copy import deepcopy
from types import MethodType

import torch
import torch.nn as nn
import torch.nn.functional as F

import timm
from timm.models.convnext import ConvNeXtBlock

from monai.networks.blocks import UpSample, SubpixelUpsample

####################
## EMA + Ensemble ##
####################

class ModelEMA(nn.Module):
    """
    Exponential moving average (EMA) of a model's weights.

    Holds a deep copy of ``model`` in eval mode. On ``update``, floating-point
    parameters and buffers become ``decay * ema + (1 - decay) * model``.
    Non-floating tensors (e.g. integer counters) are copied verbatim: averaging
    them and writing back into an integer tensor would silently truncate.

    Args:
        model: module to track.
        decay: EMA decay factor; higher means a slower-moving average.
        device: optional device for the EMA copy (incoming weights are moved there).
    """
    def __init__(self, model, decay=0.99, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        # Tensors are paired by position in state_dict, so `model` must share the
        # EMA copy's architecture (same entries in the same order).
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    ema_v.copy_(model_v)

    def update(self, model):
        """Blend ``model``'s current weights into the EMA copy."""
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        """Overwrite the EMA copy with ``model``'s current weights."""
        self._update(model, update_fn=lambda e, m: m)


class EnsembleModel(nn.Module):
    """
    Average the outputs of several models. All models are put in eval mode.

    Args:
        models: iterable of modules that accept the same input and return
            same-shaped outputs.

    Raises:
        ValueError: if ``models`` is empty.
    """
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models).eval()
        if len(self.models) == 0:
            raise ValueError("EnsembleModel requires at least one model.")

    def forward(self, x):
        # Accumulate out-of-place. `output += ...` would write into the tensor
        # returned by the first model, which may alias state that model owns.
        total = None
        for m in self.models:
            logits = m(x)
            total = logits if total is None else total + logits
        return total / len(self.models)
        

#############
## Decoder ##
#############

class ConvBnAct2d(nn.Module):
    """
    Bias-free ``nn.Conv2d`` followed by an optional norm layer and an activation.

    Args:
        in_channels, out_channels, kernel_size, padding, stride: as for ``nn.Conv2d``.
        norm_layer: norm class instantiated with ``out_channels``; passing
            ``nn.Identity`` (the default) disables normalisation.
        act_layer: activation class, constructed with ``inplace=True``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding: int = 0,
        stride: int = 1,
        norm_layer: nn.Module = nn.Identity,
        act_layer: nn.Module = nn.ReLU,
    ):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, bias=False,
        )
        if norm_layer is nn.Identity:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(out_channels)
        self.act = act_layer(inplace=True)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))


class SCSEModule2d(nn.Module):
    """
    Concurrent spatial and channel squeeze-and-excitation (scSE) gate.

    Output is ``x * cSE(x) + x * sSE(x)``, where:
      - ``cSE`` is a per-channel gate: global average pool, 1x1 bottleneck,
        Tanh, 1x1 expand, Sigmoid;
      - ``sSE`` is a per-pixel gate: 1x1 conv to one channel, then Sigmoid.

    Args:
        in_channels: number of input (and output) channels.
        reduction: bottleneck ratio of the channel gate. The hidden width is
            clamped to at least 1, so ``in_channels < reduction`` still builds
            a valid gate instead of a 0-channel convolution.
    """
    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = max(1, in_channels // reduction)
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden, 1),
            nn.Tanh(),
            nn.Conv2d(hidden, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid(),
            )

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)

class Attention2d(nn.Module):
    """
    Optional attention module selected by name.

    Args:
        name: ``None`` for a pass-through ``nn.Identity``, or ``"scse"`` for
            ``SCSEModule2d``.
        **params: forwarded to the chosen module (ignored by ``nn.Identity``).

    Raises:
        ValueError: for any other ``name``.
    """
    def __init__(self, name, **params):
        super().__init__()
        if name is not None and name != "scse":
            raise ValueError(f"Attention {name} is not implemented")
        self.attention = nn.Identity(**params) if name is None else SCSEModule2d(**params)

    def forward(self, x):
        return self.attention(x)

class DecoderBlock2d(nn.Module):
    """
    One U-Net decoder stage: upsample -> (optional) concat skip -> two 3x3 ConvBnAct2d.

    Args:
        in_channels: channels of the incoming (deeper) feature map. The monai
            ``UpSample`` path is configured with ``out_channels=in_channels``;
            the pixelshuffle path is assumed to preserve channels too (TODO confirm
            against monai ``SubpixelUpsample`` defaults).
        skip_channels: channels of the skip feature concatenated after upsampling
            (0 when the stage has no skip).
        out_channels: output channels of both 3x3 convs.
        norm_layer: norm class passed to both ConvBnAct2d layers.
        attention_type: ``None`` or ``"scse"``. It is applied after the concat
            (only when a skip is given) and after conv2.
        intermediate_conv: if True, two extra 3x3 ConvBnAct2d layers refine the skip,
            or the upsampled input when ``forward`` gets no skip.
        upsample_mode: ``"pixelshuffle"`` uses monai ``SubpixelUpsample``; any other
            value is passed as ``mode`` to monai ``UpSample``.
        scale_factor: spatial upsampling factor.

    Note: conv1 always expects ``in_channels + skip_channels`` inputs, so a block
    built with ``skip_channels > 0`` must be given a skip in ``forward``.
    """
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        intermediate_conv: bool = False,
        upsample_mode: str = "deconv",
        scale_factor: int = 2,
    ):
        super().__init__()

        # Upsample block
        if upsample_mode == "pixelshuffle":
            self.upsample= SubpixelUpsample(
                spatial_dims= 2,
                in_channels= in_channels,
                scale_factor= scale_factor,
            )
        else:
            self.upsample = UpSample(
                spatial_dims= 2,
                in_channels= in_channels,
                out_channels= in_channels,
                scale_factor= scale_factor,
                mode= upsample_mode,
            )

        # Optional refinement convs, sized for the skip if there is one, else for the input.
        if intermediate_conv:
            k= 3
            c= skip_channels if skip_channels != 0 else in_channels
            self.intermediate_conv = nn.Sequential(
                ConvBnAct2d(c, c, k, k//2),
                ConvBnAct2d(c, c, k, k//2),
                )
        else:
            self.intermediate_conv= None

        # Attention over the concatenated [upsampled, skip] tensor.
        self.attention1 = Attention2d(
            name= attention_type, 
            in_channels= in_channels + skip_channels,
            )

        self.conv1 = ConvBnAct2d(
            in_channels + skip_channels,
            out_channels,
            kernel_size= 3,
            padding= 1,
            norm_layer= norm_layer,
        )

        self.conv2 = ConvBnAct2d(
            out_channels,
            out_channels,
            kernel_size= 3,
            padding= 1,
            norm_layer= norm_layer,
        )
        self.attention2 = Attention2d(
            name= attention_type, 
            in_channels= out_channels,
            )

    def forward(self, x, skip=None):
        x = self.upsample(x)

        # The intermediate convs refine the skip when present, otherwise the upsampled x.
        if self.intermediate_conv is not None:
            if skip is not None:
                skip = self.intermediate_conv(skip)
            else:
                x = self.intermediate_conv(x)

        # Spatial sizes of the upsampled x and skip must already match; no resizing happens here.
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x


class UnetDecoder2d(nn.Module):
    """
    U-Net style decoder (https://arxiv.org/abs/1505.04597).

    Args:
        encoder_channels: channel counts of the encoder features, deepest first.
        skip_channels: channels concatenated at each stage. Defaults to
            ``encoder_channels[1:] + [0]``, i.e. the next-shallower encoder feature
            at each stage. A 0 disables the skip for that stage.
        decoder_channels: output channels per block. With exactly 4 encoder
            features the first entry is dropped, so only 3 blocks are built.
        scale_factors: upsampling factor of each block (indexed per block).
        norm_layer, attention_type, intermediate_conv, upsample_mode:
            forwarded unchanged to every ``DecoderBlock2d``.

    ``forward`` returns a list: the deepest input feature followed by each
    block's output, in order.
    """
    def __init__(
        self,
        encoder_channels: tuple[int],
        skip_channels: tuple[int] = None,
        decoder_channels: tuple = (256, 128, 64, 32),
        scale_factors: tuple = (2,2,2,2),
        norm_layer: nn.Module = nn.Identity,
        attention_type: str = None,
        intermediate_conv: bool = False,
        upsample_mode: str = "deconv",
    ):
        super().__init__()

        # Four encoder stages only need three upsampling blocks.
        self.decoder_channels = decoder_channels[1:] if len(encoder_channels) == 4 else decoder_channels

        # Kept on the instance: forward() decides per stage whether to pass a skip.
        self.skip_channels = (list(encoder_channels[1:]) + [0]) if skip_channels is None else skip_channels

        block_inputs = [encoder_channels[0], *self.decoder_channels[:-1]]
        self.blocks = nn.ModuleList()
        stage_specs = zip(block_inputs, self.skip_channels, self.decoder_channels)
        for stage, (in_ch, skip_ch, out_ch) in enumerate(stage_specs):
            block = DecoderBlock2d(
                in_ch, skip_ch, out_ch,
                norm_layer= norm_layer,
                attention_type= attention_type,
                intermediate_conv= intermediate_conv,
                upsample_mode= upsample_mode,
                scale_factor= scale_factors[stage],
            )
            self.blocks.append(block)

    def forward(self, feats: list[torch.Tensor]):
        outputs = [feats[0]]
        skips = feats[1:]

        for stage, block in enumerate(self.blocks):
            # A stage declared with 0 skip channels never receives a skip tensor.
            wants_skip = stage < len(self.skip_channels) and self.skip_channels[stage] > 0
            skip = skips[stage] if (wants_skip and stage < len(skips)) else None
            outputs.append(block(outputs[-1], skip=skip))

        return outputs

class SegmentationHead2d(nn.Module):
    """
    Final prediction head: conv to ``out_channels``, then monai ``UpSample``.

    The conv uses ``padding = kernel_size // 2``, which preserves spatial size for
    odd kernels. ``scale_factor`` and ``mode`` are passed straight to ``UpSample``.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factor: tuple[int] = (2,2),
        kernel_size: int = 3,
        mode: str = "nontrainable",
    ):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=half_kernel)
        self.upsample = UpSample(
            spatial_dims=2,
            in_channels=out_channels,
            out_channels=out_channels,
            scale_factor=scale_factor,
            mode=mode,
        )

    def forward(self, x):
        return self.upsample(self.conv(x))

#################################
## PERFECT SEGMENTATION HEAD  ##
#################################

class PerfectSegmentationHead2d(nn.Module):
    """
    Segmentation head ending in a learnable, size-growing transposed conv.

    Layers: 3x3 conv (in -> in//2) -> InstanceNorm2d(affine) -> GELU
            -> 3x3 conv (in//2 -> out) -> ConvTranspose2d(out -> out, k=7, s=1, p=0).

    The final layer grows each spatial dimension by exactly 6
    (H x W -> (H + 6) x (W + 6)). A 64x64 input therefore gives 70x70.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid_channels = in_channels // 2

        self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, padding=1)
        self.norm1 = nn.InstanceNorm2d(mid_channels, affine=True)
        self.act1 = nn.GELU()

        self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, padding=1)

        # Unpadded stride-1 transposed conv with kernel 7: output = input + 6 per dim.
        self.upsample = nn.ConvTranspose2d(
            out_channels, out_channels,
            kernel_size=7, stride=1, padding=0
        )

    def forward(self, x):
        for layer in (self.conv1, self.norm1, self.act1, self.conv2, self.upsample):
            x = layer(x)
        return x
        

#############
## Encoder ##
#############

def _convnext_block_forward(self, x):
    shortcut = x
    x = self.conv_dw(x)

    if self.use_conv_mlp:
        x = self.norm(x)
        x = self.mlp(x)
    else:
        x = self.norm(x)
        x = x.permute(0, 2, 3, 1)
        x = x.contiguous()
        x = self.mlp(x)
        x = x.permute(0, 3, 1, 2)
        x = x.contiguous()

    if self.gamma is not None:
        x = x * self.gamma.reshape(1, -1, 1, 1)

    x = self.drop_path(x) + self.shortcut(shortcut)
    return x


class Net(nn.Module):
    
    def __init__(
        self,
        backbone: str,
        pretrained: bool = True,
        use_72x72_mode: bool = False,
        use_smart_256x72_mode: bool = False,
    ):
        """
        ConvNeXt encoder (timm, ``features_only``) + U-Net decoder + segmentation head.

        Args:
            backbone: timm model name. The stem rewrites below only support ``convnext*``.
            pretrained: load pretrained timm weights.
            use_72x72_mode: build for 72x72 inputs (decoder without skips,
                scale factors (3, 3, 2, 2), stride-2 stem).
            use_smart_256x72_mode: build for 256x72 inputs with the custom stem and
                ``PerfectSegmentationHead2d``. Checked first, so it takes precedence
                over ``use_72x72_mode``.

        After the stem rewrite, all backbone activations become GELU, all norms
        become ``InstanceNorm2d`` and ConvNeXt block forwards are replaced.
        """
        super().__init__()
        
        self.use_72x72_mode = use_72x72_mode
        self.use_smart_256x72_mode = use_smart_256x72_mode
        
        # Encoder: 5 input channels (one per source); features_only returns one map per stage.
        self.backbone= timm.create_model(
            backbone,
            in_chans= 5,
            pretrained= pretrained,
            features_only= True,
            drop_path_rate=0.0,
            )
        # Encoder channel counts per stage, reversed so the deepest stage comes first.
        ecs= [_["num_chs"] for _ in self.backbone.feature_info][::-1]

        
        # Decoder - DIFFERENT for each mode
        if use_smart_256x72_mode:
            self.decoder= UnetDecoder2d(
                encoder_channels= ecs,
                scale_factors= (2, 2, 2, 2),  
            )
            
            # Transposed-conv head: adds 6 to each spatial dim of the decoder output.
            self.seg_head = PerfectSegmentationHead2d(
                in_channels=self.decoder.decoder_channels[-1],
                out_channels=1,
            )
            
            print("[Properly Engineered 256√ó72] Using PerfectSegmentationHead2d for exact 70√ó70 output")
            
        elif use_72x72_mode:
            # Modified decoder with different scale factors for 72x72
            # The fundamental issue is that 72×72 is too small for this U-Net architecture. 
            # We're trying to fit a small input into a model designed for large images. 
            # For 72×72, a simpler decoder without skip connections was tried, but unsuccessful.
            self.decoder= UnetDecoder2d(
                encoder_channels= ecs,
                skip_channels= [0, 0, 0, 0],
                scale_factors= (3, 3, 2, 2),  
            )
            
            self.seg_head= SegmentationHead2d(
                in_channels= self.decoder.decoder_channels[-1],
                out_channels= 1,
                scale_factor= 1,
            )
        else:
            # Original decoder for full resolution
            self.decoder= UnetDecoder2d(
                encoder_channels= ecs,
                scale_factors= (2, 2, 2, 2),  
            )
    
            # scale_factor=1: the head does not change spatial size here.
            self.seg_head= SegmentationHead2d(
                in_channels= self.decoder.decoder_channels[-1],
                out_channels= 1,
                scale_factor= 1,
            )

        
        
        # Stem modifications (same precedence as the decoder choice above)
        if use_smart_256x72_mode:
            self._update_stem_256x72_properly_engineered(backbone)
        elif use_72x72_mode:
            self._update_stem_72x72(backbone)
        else:
            self._update_stem_full_res(backbone)
        
        # Order matters: these run after the stem rewrite, so new stem layers are converted too.
        self.replace_activations(self.backbone, log=True)
        self.replace_norms(self.backbone, log=True)
        self.replace_forwards(self.backbone, log=True)

        # 1x1 convs mapping earlier decoder-output widths to the final width, keyed by
        # negative list index. Presumably consumed by get_decoder_features when
        # cfg.decoder_layer_index != -1 or feature fusion is on (TODO confirm).
        decoder_channels = self.decoder.decoder_channels
        self.channel_adapters = nn.ModuleDict({
            '-1': nn.Identity(),
            '-2': nn.Conv2d(decoder_channels[-2], decoder_channels[-1], 1) if len(decoder_channels) >= 2 else nn.Identity(),
            '-3': nn.Conv2d(decoder_channels[-3], decoder_channels[-1], 1) if len(decoder_channels) >= 3 else nn.Identity(),
            '-4': nn.Conv2d(decoder_channels[-4], decoder_channels[-1], 1) if len(decoder_channels) >= 4 else nn.Identity(),
        })
        
    
    def _update_stem_256x72_properly_engineered(self, backbone):
        """
        PROPERLY ENGINEERED: 256√ó72 ‚Üí 32√ó32 for perfect 70√ó70 output
        Works backwards from target to ensure exact dimensions
        """
        if backbone.startswith("convnext"):
            original_stem = self.backbone.stem_0
            original_weight = original_stem.weight
            original_bias = original_stem.bias
            out_channels = original_weight.shape[0]
            
            # Calculate exact padding needed
            # Target progression: 256√ó72 ‚Üí 280√ó96 ‚Üí 70√ó48 ‚Üí 33√ó48 ‚Üí 32√ó40 ‚Üí 32√ó32
            
            conv1 = nn.Conv2d(
                5, out_channels, 
                kernel_size=6, stride=(4, 2), padding=1
            )
            # 280√ó96 ‚Üí 70√ó48
            
            conv2 = nn.Conv2d(
                out_channels, out_channels,
                kernel_size=(5, 3), stride=(2, 1), padding=(0, 1)
            )
            # 70√ó48 ‚Üí 33√ó48
            
            conv3 = nn.Conv2d(
                out_channels, out_channels, 
                kernel_size=(2, 9), stride=(1, 1), padding=0
            )
            # 33√ó48 ‚Üí 32√ó40
            
            self.backbone.stem_0 = nn.Sequential(
                nn.ReflectionPad2d((12, 12, 12, 12)),  
                
                conv1,  # 280√ó96 ‚Üí 70√ó48
                nn.LayerNorm([out_channels, 70, 48]),
                nn.GELU(),
                
                conv2,  # 70√ó48 ‚Üí 33√ó48  
                nn.LayerNorm([out_channels, 33, 48]),
                nn.GELU(),
                
                conv3,  # 33√ó48 ‚Üí 32√ó40
                nn.LayerNorm([out_channels, 32, 40]),
                nn.GELU(),
                
                nn.AdaptiveAvgPool2d((32, 32))  # 32√ó40 ‚Üí 32√ó32
            )
            
            with torch.no_grad():
                conv1.weight.data.normal_(0, 0.02)
                conv1.bias.data.zero_()

                if original_weight.shape[2] <= 6 and original_weight.shape[3] <= 6:
                    conv1.weight[:, :, :original_weight.shape[2], :original_weight.shape[3]].copy_(original_weight)
                    conv1.bias.copy_(original_bias)
                    
                conv2.weight.data.normal_(0, 0.02)
                conv2.bias.data.zero_()
                conv3.weight.data.normal_(0, 0.02) 
                conv3.bias.data.zero_()
            
            print(f"[PROPERLY ENGINEERED] Complete stem flow:")
            print(f"  256√ó72 ‚Üí ReflectionPad(280√ó96) ‚Üí Conv1(70√ó48) ‚Üí Conv2(33√ó48) ‚Üí Conv3(32√ó40) ‚Üí Pool(32√ó32)")
            print(f"  RESULT: Perfect 32√ó32 ‚Üí Perfect skip connections ‚Üí Perfect 70√ó70 output!")
            
        else:
            raise ValueError("Properly engineered stem not implemented for this backbone.")

    def _update_stem_72x72(self, backbone):
        """Proper stem modifications for 72x72 mode"""
        if backbone.startswith("convnext"):
            self.backbone.stem_0.stride = (2, 2)  # 72√∑2 = 36
            self.backbone.stem_0.padding = (1, 1)  # Standard padding
        else:
            raise ValueError("Custom striding not implemented for 72x72 mode.")

    def _update_stem_full_res(self, backbone):
        """Original aggressive stem modifications for full resolution - EXACTLY AS ORIGINAL"""
        if backbone.startswith("convnext"):

            # Update stride
            self.backbone.stem_0.stride = (4, 1)
            self.backbone.stem_0.padding = (0, 2)

            # Duplicate stem layer (to downsample height)
            with torch.no_grad():
                w = self.backbone.stem_0.weight
                new_conv= nn.Conv2d(w.shape[0], w.shape[0], kernel_size=(4, 4), stride=(4, 1), padding=(0, 1))
                new_conv.weight.copy_(w.repeat(1, (128//w.shape[1])+1, 1, 1)[:, :new_conv.weight.shape[1], :, :])
                new_conv.bias.copy_(self.backbone.stem_0.bias)

            self.backbone.stem_0= nn.Sequential(
                nn.ReflectionPad2d((1,1,80,80)),
                self.backbone.stem_0,
                new_conv,
            )

        else:
            raise ValueError("Custom striding not implemented.")
        pass

    def replace_activations(self, module, log=False):
        """Recursively swap every supported activation under ``module`` for ``nn.GELU()``.

        Existing ``nn.GELU`` layers are also replaced, by a fresh default
        ``nn.GELU``. Children that are not activations are descended into.
        When ``log`` is true, one line tagged with the active input mode is
        printed. Recursive calls never log.
        """
        if log:
            if self.use_smart_256x72_mode:
                print("[256√ó72 Mode] Replacing all activations with GELU...")
            elif self.use_72x72_mode:
                print("[72x72 Mode] Replacing all activations with GELU...")
            else:
                print("[Full Res Mode] Replacing all activations with GELU...")

        activation_types = (
            nn.ReLU, nn.LeakyReLU, nn.Mish, nn.Sigmoid,
            nn.Tanh, nn.Softmax, nn.Hardtanh, nn.ELU,
            nn.SELU, nn.PReLU, nn.CELU, nn.GELU, nn.SiLU,
        )
        for child_name, child in module.named_children():
            if not isinstance(child, activation_types):
                self.replace_activations(child)
                continue
            setattr(module, child_name, nn.GELU())

    def replace_norms(self, mod, log=False):
        """Recursively swap normalisation layers under ``mod`` for ``nn.InstanceNorm2d(affine=True)``.

        The feature count is taken from:
        ``num_features`` for BatchNorm2d / InstanceNorm2d,
        ``num_channels`` for GroupNorm, and
        ``normalized_shape[0]`` for LayerNorm (and its subclasses).
        Any other child is descended into. Only the top-level call logs.
        """
        if log:
            if self.use_smart_256x72_mode:
                print("[256√ó72 Mode] Replacing all norms with InstanceNorm...")
            elif self.use_72x72_mode:
                print("[72x72 Mode] Replacing all norms with InstanceNorm...")
            else:
                print("[Full Res Mode] Replacing all norms with InstanceNorm...")

        for child_name, child in mod.named_children():
            if isinstance(child, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                width = child.num_features
            elif isinstance(child, nn.GroupNorm):
                width = child.num_channels
            elif isinstance(child, nn.LayerNorm):
                width = child.normalized_shape[0]
            else:
                self.replace_norms(child)
                continue
            setattr(mod, child_name, nn.InstanceNorm2d(width, affine=True))

    def replace_forwards(self, mod, log=False):
        """Recursively rebind ``forward`` on every ``ConvNeXtBlock`` under ``mod``.

        Each matching block gets ``_convnext_block_forward`` bound to itself
        via ``MethodType``. The search does not descend into a matched block;
        every other child is searched recursively. Only the top-level call logs.
        """
        if log:
            if self.use_smart_256x72_mode:
                print("[256√ó72 Mode] Replacing forward functions...")
            elif self.use_72x72_mode:
                print("[72x72 Mode] Replacing forward functions...")
            else:
                print("[Full Res Mode] Replacing forward functions...")

        for _, child in mod.named_children():
            if not isinstance(child, ConvNeXtBlock):
                self.replace_forwards(child)
                continue
            child.forward = MethodType(_convnext_block_forward, child)
        

    def get_decoder_features(self, decoder_outputs):
        """Select, or fuse, decoder feature maps for the segmentation head.

        ``cfg`` is read from ``_cfg`` at call time.

        - ``cfg.use_feature_fusion`` false: the single map
          ``decoder_outputs[cfg.decoder_layer_index]`` is returned after
          adaptation (see below).
        - ``cfg.use_feature_fusion`` true: each map in ``cfg.fusion_layers``
          is adapted and the maps are summed with per-layer weights. The
          weights are ``softmax(self.fusion_weights)`` when the model has a
          ``fusion_weights`` attribute, otherwise ``cfg.fusion_weights``.

        Adaptation means: apply ``self.channel_adapters[str(layer_idx)]`` if
        that key exists, then bilinearly resize to the spatial size of
        ``decoder_outputs[-1]`` if the size differs.

        Args:
            decoder_outputs: indexable sequence of decoder feature maps; the
                last entry defines the target spatial size.

        Returns:
            The selected or fused feature tensor.

        Raises:
            ValueError: if fusion is enabled but ``cfg.fusion_layers`` is
                empty (previously this silently returned ``None``), or if
                fewer weights than fusion layers are supplied.
        """
        from _cfg import cfg

        target_size = decoder_outputs[-1].shape[-2:]

        def _adapt(layer_idx):
            # Channel adaptation (adapters are keyed by the stringified layer
            # index), then resize to the final decoder resolution if needed.
            features = decoder_outputs[layer_idx]
            adapter_key = str(layer_idx)
            if adapter_key in self.channel_adapters:
                features = self.channel_adapters[adapter_key](features)
            if features.shape[-2:] != target_size:
                features = F.interpolate(features, size=target_size, mode='bilinear', align_corners=False)
            return features

        if not cfg.use_feature_fusion:
            # Single layer selection
            return _adapt(cfg.decoder_layer_index)

        fusion_layers = list(cfg.fusion_layers)
        if not fusion_layers:
            raise ValueError("cfg.fusion_layers must be non-empty when cfg.use_feature_fusion is True.")

        # Learnable weights (if the model defines them) take priority over config weights.
        if hasattr(self, 'fusion_weights'):
            weights = torch.softmax(self.fusion_weights, dim=0)
        else:
            weights = cfg.fusion_weights
        if len(weights) < len(fusion_layers):
            raise ValueError(
                f"Got {len(weights)} fusion weights for {len(fusion_layers)} fusion layers."
            )

        fused_features = None
        for i, layer_idx in enumerate(fusion_layers):
            contribution = weights[i] * _adapt(layer_idx)
            # Out-of-place accumulation keeps autograd / torch.compile happy.
            fused_features = contribution if fused_features is None else fused_features + contribution
        return fused_features

    def proc_flip(self, x_in):
        """Flip test-time-augmentation branch.

        Flips the input along dims -3 and -1 and runs it through
        backbone -> reversed feature list -> decoder -> selected features ->
        seg head. It then crops one border pixel (except in 256x72 mode),
        flips the prediction back along the last dim, and rescales with
        ``* 1500 + 3000``.
        """
        flipped_in = torch.flip(x_in, dims=[-3, -1])
        encoder_feats = self.backbone(flipped_in)[::-1]
        decoded = self.decoder(encoder_feats)
        pred = self.seg_head(self.get_decoder_features(decoded))

        # 256x72 mode already yields a 70x70 map, so only other modes are cropped.
        if not self.use_smart_256x72_mode:
            pred = pred[..., 1:-1, 1:-1]

        # Undo the last-axis flip, then map back to velocity units.
        pred = torch.flip(pred, dims=[-1])
        return pred * 1500 + 3000
    
    def forward(self, batch):
        """Predict velocity maps, with test-time augmentation (TTA) in eval mode.

        Every prediction goes through
        backbone -> reversed feature list -> decoder -> selected features ->
        seg head, then a one-pixel border crop (skipped in 256x72 mode), then
        ``* 1500 + 3000``.

        In training mode the plain prediction is returned. In eval mode,
        ``cfg.enhanced_tta`` (from ``_cfg``) selects:

        - false: mean of the plain prediction and ``self.proc_flip(batch)``;
        - true: mean of the plain prediction, the flip prediction, two
          predictions on noise-perturbed inputs (noise std
          ``cfg.tta_noise_level * std(batch)``), and one prediction per
          factor in ``cfg.tta_scale_values``.

        NOTE: the noise TTA draws from the global torch RNG, so enhanced-TTA
        outputs are not deterministic unless the RNG is seeded.
        """
        def predict(inp):
            feats = self.backbone(inp)[::-1]
            decoded = self.decoder(feats)
            out = self.seg_head(self.get_decoder_features(decoded))
            if not self.use_smart_256x72_mode:
                out = out[..., 1:-1, 1:-1]  # Only crop for other modes
            return out * 1500 + 3000

        x_in = batch
        x_seg = predict(x_in)

        if self.training:
            return x_seg

        from _cfg import cfg

        # Basic TTA: original + horizontally flipped prediction.
        predictions = [x_seg, self.proc_flip(x_in)]

        if cfg.enhanced_tta:
            # Noise TTA variants
            for _ in range(2):
                noise = torch.randn_like(x_in) * cfg.tta_noise_level * torch.std(x_in)
                predictions.append(predict(x_in + noise))

            # Scale TTA variants
            for scale in cfg.tta_scale_values:
                predictions.append(predict(x_in * scale))

        return torch.mean(torch.stack(predictions), dim=0)

Overwriting _model.py


### Utils

Same as starter notebook.

In [28]:
%%writefile _utils.py

import datetime

def format_time(elapsed):
    """Render a duration in seconds as ``H:MM:SS``, rounded to whole seconds.

    Uses ``datetime.timedelta``'s string form, so durations of a day or more
    look like ``"1 day, 1:01:01"``.
    """
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)

Overwriting _utils.py


# Train

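Adapted from the starter notebook. This version adds several features:

- Auto-detection of single-process vs. `torchrun` (DDP) mode.
- Optional mixed precision and `torch.compile`.
- EMA weights.
- An optional smoothness penalty on the predicted velocity map.
- Cosine-annealing LR scheduling.
- Optional Weights & Biases logging.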

In [32]:
%%writefile _train.py

import os
import time 
import random
import numpy as np
from tqdm import tqdm
from contextlib import nullcontext

import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

# Conditional imports for distributed. Only a missing module disables DDP;
# other exceptions (including KeyboardInterrupt) are no longer swallowed.
try:
    import torch.distributed as dist
    from torch.utils.data import DistributedSampler
    from torch.nn.parallel import DistributedDataParallel
    DIST_AVAILABLE = True
except ImportError:
    DIST_AVAILABLE = False

from _cfg import cfg
from _dataset import CustomDataset
from _model import ModelEMA, Net
from _utils import format_time

from torch.optim.lr_scheduler import CosineAnnealingLR

try:
    import wandb
except ImportError:
    wandb = None

# Best-effort W&B login: needs Kaggle's `kaggle_secrets` and a "wandb_key"
# secret. Any failure is reported instead of being silently swallowed, and
# only Exception subclasses are caught (Ctrl-C still interrupts).
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    wandb_key = user_secrets.get_secret("wandb_key")
    if wandb is not None and wandb_key:
        wandb.login(key=wandb_key)
except Exception as exc:
    # Report only the exception type so no secret material ends up in logs.
    print(f"[wandb] Login skipped ({type(exc).__name__}).")

def set_seed(seed=1234):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN.

    cuDNN is set up for speed (``benchmark=True``, ``deterministic=False``),
    so GPU runs are not guaranteed to be bit-reproducible.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

def setup(rank, world_size):
    """Bind this process to CUDA device ``rank`` and join the NCCL process group.

    NOTE(review): ``rank`` is used both as the device index and as the
    process-group rank, which only coincide on a single node.
    """
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

def cleanup():
    """Wait for every rank at a barrier, then destroy the default process group."""
    dist.barrier()
    dist.destroy_process_group()



def setup_gpu_mode(cfg):
    """Populate ``cfg.distributed``, ``cfg.local_rank`` and ``cfg.world_size`` from the environment.

    A torchrun launch is recognised by ``RANK`` and ``WORLD_SIZE`` both
    being set. In that case ``cfg.local_rank`` is taken from ``RANK`` (the
    global rank, equal to the local rank only on a single node) and
    ``cfg.device`` is left untouched. Otherwise single-process mode is
    chosen, and ``cfg.device`` becomes ``cuda:0`` if any GPU is visible, else CPU.
    """
    env = os.environ
    if "RANK" in env and "WORLD_SIZE" in env:
        cfg.distributed = True
        cfg.local_rank = int(env["RANK"])
        cfg.world_size = int(env["WORLD_SIZE"])
        print(f"[Distributed] Rank: {cfg.local_rank}, World size: {cfg.world_size}")
        return

    cfg.distributed = False
    cfg.local_rank = 0
    cfg.world_size = 1

    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if n_gpus:
        print(f"[Single GPU] Using 1 GPU out of {n_gpus} available")
        cfg.device = torch.device("cuda:0")
    else:
        print("[Single GPU] No GPUs available, using CPU")
        cfg.device = torch.device("cpu")

def main_single_gpu(cfg):
    """Train ``Net`` in a single process on ``cfg.device``.

    Each epoch does the following:
    - One pass over the training set with L1 loss, plus a smoothness penalty
      when ``cfg.smoothness_loss`` is set.
    - Gradient clipping at norm 3.0.
    - An EMA update after every step when ``cfg.ema`` is set.
    - A validation pass, using the EMA weights when present.
    - One cosine-annealing LR step.

    The checkpoint with the lowest validation MAE is saved to
    ``best_model_{cfg.seed}.pt``. Training stops early once the number of
    consecutive non-improving epochs exceeds ``cfg.early_stopping["patience"]``.

    Args:
        cfg: config namespace; ``cfg.early_stopping["streak"]`` is mutated.
    """
    print("="*25)
    print("Running on single GPU")
    print("="*25)

    # Print mode info
    mode_str = "72x72" if cfg.use_72x72_mode else "Full Resolution"
    print(f"Training Mode: {mode_str}")

    if cfg.use_wandb and wandb is not None:
        run_name = f"{cfg.run_name}_{cfg.seed}_single_gpu"
        if cfg.use_72x72_mode:
            run_name += "_72x72"
        wandb.init(
            project=cfg.wandb_project,
            config=vars(cfg),
            name=run_name
        )

    # Datasets
    print("Loading data..")
    train_ds = CustomDataset(cfg=cfg, mode="train")
    train_dl = torch.utils.data.DataLoader(
        train_ds,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )

    valid_ds = CustomDataset(cfg=cfg, mode="valid")
    valid_dl = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=cfg.batch_size_val,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
    )

    # Model
    model = Net(backbone=cfg.backbone, use_72x72_mode=cfg.use_72x72_mode, use_smart_256x72_mode=cfg.use_smart_256x72_mode)
    model = model.to(cfg.device)

    # torch.compile is controlled solely by cfg.compile_model (it is applied in
    # every mode, including 72x72).
    if cfg.compile_model:
        model = torch.compile(model, mode='default')

    if cfg.ema:
        print("Initializing EMA model..")
        ema_model = ModelEMA(model, decay=cfg.ema_decay, device=cfg.device)
    else:
        ema_model = None

    # Training setup
    criterion = nn.L1Loss()

    # Fused Adam is paired with torch.compile; otherwise the default implementation.
    if cfg.compile_model:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.epochs)

    # Define autocast context
    if cfg.mixed_precision:
        autocast_context = autocast(device_type='cuda', dtype=cfg.precision_dtype)
    else:
        autocast_context = nullcontext()

    if cfg.use_grad_scaler and cfg.mixed_precision:
        scaler = GradScaler()
    else:
        scaler = None

    def smoothness_loss(pred_vel, alpha=0.1):
        # Defined once here (it was re-created on every training step).
        # alpha * (mean |neighbour difference| along the last axis + along the second-to-last axis).
        grad_x = torch.abs(pred_vel[:, :, :, 1:] - pred_vel[:, :, :, :-1])
        grad_y = torch.abs(pred_vel[:, :, 1:, :] - pred_vel[:, :, :-1, :])
        return alpha * (grad_x.mean() + grad_y.mean())

    print("="*25)
    print("Starting training...")
    print("="*25)

    best_loss = 1_000_000
    val_loss = 1_000_000

    for epoch in range(1, cfg.epochs+1):
        tstart = time.time()

        # Train
        model.train()
        total_loss = []

        for i, (x, y) in enumerate(train_dl):
            x = x.to(cfg.device)
            y = y.to(cfg.device)

            with autocast_context:
                logits = model(x)

            if cfg.smoothness_loss:
                mae_loss = criterion(logits, y)
                smooth_loss = smoothness_loss(logits, alpha=0.1)
                loss = mae_loss + smooth_loss
            else:
                loss = criterion(logits, y)

            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                optimizer.step()
                optimizer.zero_grad()

            total_loss.append(loss.item())

            if ema_model is not None:
                ema_model.update(model)

            if len(total_loss) >= cfg.logging_steps or i == 0:
                train_loss = np.mean(total_loss)
                total_loss = []
                print(f"Epoch {epoch}: Train MAE: {train_loss:.2f} Val MAE: {val_loss:.2f} "
                      f"Time: {format_time(time.time() - tstart)} Step: {i+1}/{len(train_dl)}")

            if cfg.use_wandb and wandb is not None and i % 50 == 0:
                wandb.log({
                    "train_loss": loss.item(),
                    "epoch": epoch,
                    "step": epoch * len(train_dl) + i
                })

        # Validation
        model.eval()
        val_logits = []
        val_targets = []

        with torch.no_grad():
            for x, y in tqdm(valid_dl):
                x = x.to(cfg.device)
                y = y.to(cfg.device)

                with autocast_context:
                    if ema_model is not None:
                        out = ema_model.module(x)
                    else:
                        out = model(x)

                val_logits.append(out.cpu())
                val_targets.append(y.cpu())

            val_logits = torch.cat(val_logits, dim=0)
            val_targets = torch.cat(val_targets, dim=0)
            val_loss = criterion(val_logits, val_targets).item()

        if cfg.use_wandb and wandb is not None and cfg.local_rank == 0:
            wandb.log({
                "val_loss": val_loss,
                "epoch": epoch,
                # Include this epoch; best_loss itself is only updated below.
                "best_loss": min(best_loss, val_loss),
                "learning_rate": scheduler.get_last_lr()[0],
            })

        scheduler.step()

        # Save best model
        if val_loss < best_loss:
            print(f"New best: {best_loss:.2f} -> {val_loss:.2f}")
            print("Saved weights..")
            best_loss = val_loss

            if ema_model is not None:
                torch.save(ema_model.module.state_dict(), f'best_model_{cfg.seed}.pt')
            else:
                torch.save(model.state_dict(), f'best_model_{cfg.seed}.pt')

            cfg.early_stopping["streak"] = 0
        else:
            cfg.early_stopping["streak"] += 1
            if cfg.early_stopping["streak"] > cfg.early_stopping["patience"]:
                print("Ending training (early_stopping).")
                break

    if cfg.use_wandb and wandb is not None:
        wandb.finish()

def main_multi_gpu(cfg):
    """DDP training entry point: one process per GPU, launched via torchrun.

    Epoch 0 only runs validation, as a sanity check of the initial / EMA
    weights. Epochs 1..cfg.epochs train and then validate. Each rank
    validates its DistributedSampler shard, and the per-rank MAEs are
    averaged with ``all_reduce``. Rank 0 logs, saves the best checkpoint to
    ``best_model_{cfg.seed}.pt`` and decides on early stopping. That
    decision is broadcast so all ranks leave the loop together.

    Args:
        cfg: config namespace; ``cfg.early_stopping["streak"]`` is mutated on rank 0.
    """

    # Print mode info (only rank 0)
    if cfg.local_rank == 0:
        mode_str = "72x72" if cfg.use_72x72_mode else "Full Resolution"
        print(f"Training Mode: {mode_str}")

    if cfg.use_wandb and wandb is not None and cfg.local_rank == 0:
        run_name = f"{cfg.run_name}_{cfg.seed}"
        if cfg.use_72x72_mode:
            run_name += "_72x72"
        wandb.init(
            project=cfg.wandb_project,
            config=vars(cfg),
            name=run_name
        )

    # ========== Datasets / Dataloaders ==========
    if cfg.local_rank == 0:
        print("="*25)
        print("Loading data..")
    train_ds = CustomDataset(cfg=cfg, mode="train")
    train_sampler = DistributedSampler(
        train_ds,
        num_replicas=cfg.world_size,
        rank=cfg.local_rank,
    )
    train_dl = torch.utils.data.DataLoader(
        train_ds,
        sampler=train_sampler,
        batch_size=cfg.batch_size,
        num_workers=4,
    )

    valid_ds = CustomDataset(cfg=cfg, mode="valid")
    valid_sampler = DistributedSampler(
        valid_ds,
        num_replicas=cfg.world_size,
        rank=cfg.local_rank,
    )
    valid_dl = torch.utils.data.DataLoader(
        valid_ds,
        sampler=valid_sampler,
        batch_size=cfg.batch_size_val,
        num_workers=4,
    )

    # ========== Model / Optim ==========
    model = Net(backbone=cfg.backbone, use_72x72_mode=cfg.use_72x72_mode, use_smart_256x72_mode=cfg.use_smart_256x72_mode)
    model = model.to(cfg.local_rank)

    # torch.compile is controlled solely by cfg.compile_model (it is applied in
    # every mode, including 72x72).
    if cfg.compile_model:
        model = torch.compile(model, mode='default')

    if cfg.ema:
        if cfg.local_rank == 0:
            print("Initializing EMA model..")
        ema_model = ModelEMA(
            model,
            decay=cfg.ema_decay,
            device=cfg.local_rank,
        )
    else:
        ema_model = None

    model = DistributedDataParallel(
        model,
        device_ids=[cfg.local_rank],
        find_unused_parameters=True
        )

    criterion = nn.L1Loss()
    # Fused Adam is paired with torch.compile; otherwise the default implementation.
    if cfg.compile_model:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.epochs)

    if cfg.mixed_precision and cfg.use_grad_scaler:
        scaler = GradScaler()
    else:
        scaler = None

    def make_autocast():
        # Fresh context per use; a no-op when mixed precision is disabled.
        if cfg.mixed_precision:
            return autocast(device_type='cuda', dtype=cfg.precision_dtype)
        return nullcontext()

    def smoothness_loss(pred_vel, alpha=0.1):
        # Defined once here (it was re-created on every training step).
        # alpha * (mean |neighbour difference| along the last axis + along the second-to-last axis).
        grad_x = torch.abs(pred_vel[:, :, :, 1:] - pred_vel[:, :, :, :-1])
        grad_y = torch.abs(pred_vel[:, :, 1:, :] - pred_vel[:, :, :-1, :])
        return alpha * (grad_x.mean() + grad_y.mean())

    # ========== Training ==========
    if cfg.local_rank == 0:
        print("="*25)
        print("Give me warp {}, Mr. Sulu.".format(cfg.world_size))
        print("="*25)

    best_loss = 1_000_000
    val_loss = 1_000_000

    for epoch in range(0, cfg.epochs+1):
        # Epoch 0 skips training and only validates.
        if epoch != 0:
            tstart = time.time()
            train_dl.sampler.set_epoch(epoch)

            # Train loop
            model.train()
            total_loss = []
            for i, (x, y) in enumerate(train_dl):
                x = x.to(cfg.local_rank)
                y = y.to(cfg.local_rank)

                with make_autocast():
                    logits = model(x)

                if cfg.smoothness_loss:
                    mae_loss = criterion(logits, y)
                    smooth_loss = smoothness_loss(logits, alpha=0.1)
                    loss = mae_loss + smooth_loss
                else:
                    loss = criterion(logits, y)

                if scaler is not None:
                    scaler.scale(loss).backward()
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                    optimizer.step()
                    optimizer.zero_grad()

                total_loss.append(loss.item())

                if ema_model is not None:
                    ema_model.update(model)

                if cfg.local_rank == 0 and (len(total_loss) >= cfg.logging_steps or i == 0):
                    train_loss = np.mean(total_loss)
                    total_loss = []
                    print("Epoch {}:     Train MAE: {:.2f}     Val MAE: {:.2f}     Time: {}     Step: {}/{}".format(
                        epoch,
                        train_loss,
                        val_loss,
                        format_time(time.time() - tstart),
                        i+1,
                        len(train_dl),  # was len(train_dl)+1 (off by one)
                    ))

                if cfg.use_wandb and wandb is not None and cfg.local_rank == 0 and i % 50 == 0:
                    wandb.log({
                        "train_loss": loss.item(),
                        "epoch": epoch,
                        "step": epoch * len(train_dl) + i
                    })

        # ========== Valid ==========
        model.eval()
        val_logits = []
        val_targets = []

        with torch.no_grad():
            for x, y in tqdm(valid_dl, disable=cfg.local_rank != 0):
                x = x.to(cfg.local_rank)
                y = y.to(cfg.local_rank)

                with make_autocast():
                    if ema_model is not None:
                        out = ema_model.module(x)
                    else:
                        out = model(x)

                val_logits.append(out.cpu())
                val_targets.append(y.cpu())

            val_logits = torch.cat(val_logits, dim=0)
            val_targets = torch.cat(val_targets, dim=0)

            local_val_loss = criterion(val_logits, val_targets).item()

        # Average the per-rank shard MAEs. This is approximate:
        # DistributedSampler may pad shards with repeated samples.
        v = torch.tensor([local_val_loss], device=cfg.local_rank)
        dist.all_reduce(v, op=dist.ReduceOp.SUM)
        val_loss = (v[0] / cfg.world_size).item()

        if cfg.use_wandb and wandb is not None and cfg.local_rank == 0:
            wandb.log({
                "val_loss": val_loss,
                "epoch": epoch,
                # Include this epoch; best_loss itself is only updated below.
                "best_loss": min(best_loss, val_loss),
                "learning_rate": scheduler.get_last_lr()[0],
            })

        # Step the cosine schedule only after real training epochs. Stepping
        # after the validation-only epoch 0 shifted the whole schedule by one
        # epoch, so the final epoch trained at LR 0 (eta_min).
        if epoch != 0:
            scheduler.step()

        # ========== Weights / Early stopping ==========
        stop_train = torch.tensor([0], device=cfg.local_rank)
        if cfg.local_rank == 0:
            es = cfg.early_stopping
            if val_loss < best_loss:
                print("New best: {:.2f} -> {:.2f}".format(best_loss, val_loss))
                print("Saved weights..")
                best_loss = val_loss
                if ema_model is not None:
                    torch.save(ema_model.module.state_dict(), f'best_model_{cfg.seed}.pt')
                else:
                    # Save the DDP-unwrapped model so keys carry no "module." prefix.
                    torch.save(model.module.state_dict(), f'best_model_{cfg.seed}.pt')

                es["streak"] = 0
            else:
                es["streak"] += 1
                if es["streak"] > es["patience"]:
                    print("Ending training (early_stopping).")
                    stop_train = torch.tensor([1], device=cfg.local_rank)

        # Exits training on all ranks. Use break (not return) so that
        # wandb.finish() below still runs.
        dist.broadcast(stop_train, src=0)
        if stop_train.item() == 1:
            break

    if cfg.use_wandb and wandb is not None and cfg.local_rank == 0:
        wandb.finish()

    return

if __name__ == "__main__":

    set_seed(cfg.seed)

    # Auto-detect GPU configuration (torchrun env vars -> DDP, else single process)
    setup_gpu_mode(cfg)

    if not cfg.distributed:
        # Single GPU setup
        main_single_gpu(cfg)
    else:
        # Multi-GPU setup: bind this rank to its GPU and join the NCCL group.
        setup(cfg.local_rank, cfg.world_size)

        # Rank r sleeps r seconds before printing and (world_size - r) after,
        # which staggers the per-rank lines.
        _, total = torch.cuda.mem_get_info(device=cfg.local_rank)
        time.sleep(cfg.local_rank)
        gpu_mem_gb = total / 1024**3
        print(f"Rank: {cfg.local_rank}, World size: {cfg.world_size}, GPU memory: {gpu_mem_gb:.2f}GB", flush=True)
        time.sleep(cfg.world_size - cfg.local_rank)

        set_seed(cfg.seed + cfg.local_rank)  # Different seed per GPU

        main_multi_gpu(cfg)
        cleanup()

Overwriting _train.py


In [None]:
# Launch _train.py in a separate process. torchrun starts 2 worker processes
# and sets RANK / WORLD_SIZE, which _train.py's setup_gpu_mode() uses to pick DDP mode.
if RUN_TRAIN:
    print("Starting training..")
    # OMP_NUM_THREADS=1 caps OpenMP CPU threads in each worker process.
    !OMP_NUM_THREADS=1 torchrun --nproc_per_node=2 _train.py
    # single-GPU (no RANK/WORLD_SIZE set, so _train.py runs main_single_gpu):
    #!python _train.py

Starting training..
2025-06-30 16:14:49.452285: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-30 16:14:49.452418: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751300089.479634    6203 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751300089.479634    6202 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751300089.487307    6202 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
E0000 00:00:1751300089.487453   

### Pretrained Models

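Next, we load every `*.pt` checkpoint saved in `/kaggle/working` by the training run above and combine them into an ensemble. The original baseline loaded 2 pretrained models trained with a batch size of 16 for 50 epochs. Here the checkpoints come from this notebook's own training configuration.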

In [31]:
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F

from _cfg import cfg
from _model import Net, EnsembleModel

if RUN_VALID or RUN_TEST:

    # Key prefixes added by training-time wrappers: DistributedDataParallel
    # ("module.") and torch.compile ("_orig_mod."). They are stripped in this
    # order so keys like "module._orig_mod.backbone..." load into a plain Net.
    _WRAPPER_PREFIXES = ("module.", "_orig_mod.")

    def _clean_state_dict_keys(state_dict):
        """Return a copy of ``state_dict`` with leading wrapper prefixes removed from every key."""
        cleaned = {}
        for key, value in state_dict.items():
            for prefix in _WRAPPER_PREFIXES:
                key = key.removeprefix(prefix)
            cleaned[key] = value
        return cleaned

    checkpoint_paths = sorted(glob.glob("/kaggle/working/*.pt"))
    if not checkpoint_paths:
        # Fail here with a clear message instead of building an empty ensemble.
        raise FileNotFoundError("No *.pt checkpoints found in /kaggle/working; run training first.")

    # Load pretrained models
    models = []
    for f in checkpoint_paths:
        print("Loading: ", f)
        m = Net(
            backbone=cfg.backbone,
            pretrained=False,
            use_72x72_mode=cfg.use_72x72_mode,
            use_smart_256x72_mode=cfg.use_smart_256x72_mode
        )

        state_dict = torch.load(f, map_location=cfg.device, weights_only=True)
        m.load_state_dict(_clean_state_dict_keys(state_dict))
        if cfg.compile_model:
            m = torch.compile(m, mode='reduce-overhead')
        models.append(m)

    # Combine
    model = EnsembleModel(models)
    model = model.to(cfg.device)
    model = model.eval()
    print("n_models: {:_}".format(len(models)))

    # Sanity check: report the first parameter that contains NaNs, if any.
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"NaN detected in {name}")
            break

Loading:  /kaggle/working/best_model_123.pt


UnboundLocalError: cannot access local variable 'decoder_channels' where it is not associated with a value

# Valid

Next, we score the ensemble on the full validation set (no subsampling), reporting MAE per dataset and overall.

In [None]:
from tqdm import tqdm
import numpy as np

import torch
import torch.nn as nn
from torch.amp import autocast

from _dataset import CustomDataset
import copy


from _cfg import cfg


if RUN_VALID:
    # Score on the full validation set, regardless of the training-time subsample.
    cfg_full_val = copy.deepcopy(cfg)
    cfg_full_val.subsample_fraction = None  # Use full validation set

    valid_ds = CustomDataset(cfg=cfg_full_val, mode="valid")
    valid_dl = torch.utils.data.DataLoader(
        valid_ds,
        sampler=torch.utils.data.SequentialSampler(valid_ds),
        batch_size=cfg.batch_size_val,
        num_workers=4,
    )

    # Valid loop
    criterion = nn.L1Loss()
    val_logits = []
    val_targets = []
    # Per-sample keep flags, in dataset order. These keep the per-dataset labels
    # aligned with predictions when NaN batches are skipped.
    kept_flags = []

    with torch.no_grad():
        for x, y in tqdm(valid_dl):
            x = x.to(cfg.device)
            y = y.to(cfg.device)

            has_nan = bool(torch.isnan(x).any() or torch.isnan(y).any())
            kept_flags.append(np.full(len(x), not has_nan))
            if has_nan:
                print("NaN in input data")
                continue

            if cfg.mixed_precision:
                with autocast(device_type='cuda', dtype=cfg.precision_dtype):
                    out = model(x)
            else:
                out = model(x)

            val_logits.append(out.cpu())
            val_targets.append(y.cpu())

        if not val_logits:
            raise RuntimeError("Every validation batch was skipped (NaN inputs); nothing to score.")

        val_logits = torch.cat(val_logits, dim=0)
        val_targets = torch.cat(val_targets, dim=0)

        total_loss = criterion(val_logits, val_targets).item()

    kept = np.concatenate(kept_flags)

    # Dataset Scores. Assumes each entry of valid_ds.records covers 500
    # consecutive samples (TODO confirm against CustomDataset).
    ds_idxs = np.repeat(np.array([valid_ds.records]), repeats=500)
    assert len(ds_idxs) == len(kept), (
        f"Dataset labels ({len(ds_idxs)}) do not match validation samples ({len(kept)})."
    )
    ds_idxs = ds_idxs[kept]  # drop labels of skipped NaN batches

    print("="*25)
    with torch.no_grad():
        for idx in np.unique(ds_idxs):
            mask = torch.from_numpy(ds_idxs == idx)
            loss = criterion(val_logits[mask], val_targets[mask]).item()
            print("{:15} {:.2f}".format(idx, loss))
    print("="*25)
    print("Val MAE: {:.2f}".format(total_loss))
    print("="*25)

# Test

Finally, we make predictions on the test data.

In [None]:
import os

import numpy as np
import torch

class TestDataset(torch.utils.data.Dataset):
    """Loads ``.npy`` test inputs and pairs each one with its sample id.

    The sample id is the file name up to its first dot
    (``".../abc123.npy" -> "abc123"``).

    Args:
        test_files: sequence of paths to ``.npy`` files.
        cast_to_float32: cast loaded arrays to ``float32`` when True. When
            None (default), keep the previous rule: cast only if
            ``cfg.mixed_precision`` is False (``cfg`` from ``_cfg``).
    """

    def __init__(self, test_files, cast_to_float32=None):
        self.test_files = test_files
        if cast_to_float32 is None:
            # Same dtype conversion rule as CustomDataset.
            from _cfg import cfg
            cast_to_float32 = not cfg.mixed_precision
        self.cast_to_float32 = cast_to_float32

    def __len__(self):
        return len(self.test_files)

    def __getitem__(self, i):
        test_file = self.test_files[i]
        test_stem = os.path.basename(test_file).split(".")[0]
        x = np.load(test_file)

        if self.cast_to_float32:
            x = x.astype(np.float32)

        return x, test_stem

In [None]:
import contextlib
import csv
import glob
import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm

from _utils import format_time


if RUN_TEST:

    def _preprocess(x):
        """Map a batch onto the 72x72 grid used when cfg.use_72x72_mode is set.

        Area-resizes the last two dims to 70x70, then replicate-pads one pixel
        on every side (70 + 2 = 72).
        """
        x = F.interpolate(x, size=(70, 70), mode='area')
        x = F.pad(x, (1, 1, 1, 1), mode='replicate')
        return x

    ss = pd.read_csv("/kaggle/input/waveform-inversion/sample_submission.csv")
    row_count = 0
    t0 = time.time()

    test_files = sorted(glob.glob("/kaggle/input/open-wfi-test/test/*.npy"))
    # The submission keeps only the odd-indexed x columns: x_1, x_3, ..., x_69.
    x_cols = [f"x_{i}" for i in range(1, 70, 2)]
    fieldnames = ["oid_ypos"] + x_cols

    test_ds = TestDataset(test_files)
    test_dl = torch.utils.data.DataLoader(
        test_ds,
        sampler=torch.utils.data.SequentialSampler(test_ds),
        batch_size=cfg.batch_size_val,
        num_workers=4,
    )

    # Autocast only when mixed precision is requested, with the configured
    # dtype -- the same rule as the validation loop. A bare
    # torch.autocast("cuda") would run float16 even in float32 mode.
    amp_ctx = (
        torch.autocast(device_type=cfg.device.type, dtype=cfg.precision_dtype)
        if cfg.mixed_precision
        else contextlib.nullcontext()
    )

    # Switch modules such as dropout/normalization to inference behaviour.
    model.eval()

    with open("submission.csv", "wt", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        with torch.inference_mode(), amp_ctx:
            for inputs, oids_test in tqdm(test_dl, total=len(test_dl)):
                # Move the batch to the model's device in both precision modes;
                # cast to float32 only when mixed precision is off.
                inputs = inputs.to(cfg.device)
                if not cfg.mixed_precision:
                    inputs = inputs.float()

                if cfg.use_72x72_mode:
                    inputs = _preprocess(inputs)
                outputs = model(inputs)

                # .float() first: numpy has no bfloat16, so .numpy() would
                # fail on bfloat16 autocast outputs.
                y_preds = outputs[:, 0].float().cpu().numpy()

                for y_pred, oid_test in zip(y_preds, oids_test):
                    for y_pos in range(70):
                        # 1:70:2 selects exactly the columns named in x_cols.
                        row = dict(zip(x_cols, y_pred[y_pos, 1:70:2]))
                        row["oid_ypos"] = f"{oid_test}_y_{y_pos}"
                        writer.writerow(row)
                        row_count += 1

                        # Periodically hand buffered rows to the OS so an
                        # interrupted run still leaves partial output.
                        if row_count % 100_000 == 0:
                            csvfile.flush()

    # Sanity check: one submission row per (oid, y) pair in the sample file.
    if row_count != len(ss):
        print(f"WARNING: wrote {row_count} rows, sample_submission has {len(ss)}")

    t1 = format_time(time.time() - t0)
    print(f"Inference Time: {t1}")

We can also view a few samples to make sure things look reasonable.