In [98]:
# First clone the repository
!git clone https://github.com/BloodAxe/Kaggle-Salt.git /kaggle/working/Kaggle-Salt

import sys
sys.path.append("/kaggle/working/Kaggle-Salt")

# standard libraries
import os
from typing import Optional
from collections import OrderedDict
from datetime import datetime
import json
import random
from lib.dataset import medium_augmentations, drop_some, normalize_image
import cv2
from albumentations.pytorch import ToTensorV2

# data processing libraries
import numpy as np
import pandas as pd
from skimage.morphology import remove_small_objects, remove_small_holes
from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split

# PyTorch and related
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.backends import cudnn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

# Albumentations for augmentations
import albumentations as A
import albumentations.augmentations.functional as AF

# progress bar
from tqdm import tqdm

#custom modules from the cloned repo
# Each custom import is attempted on its own: with a single shared try block the
# first failure aborted the whole group and left every later name unbound.
_custom_import_errors = []

try:
    from models.modules.abn import ABN, ACT_RELU
except ImportError as e:
    _custom_import_errors.append(e)

try:
    from models.resnext import try_index
except ImportError as e:
    _custom_import_errors.append(e)

try:
    import lib.augmentations as AA
except ImportError as e:
    _custom_import_errors.append(e)

try:
    from lib import torch_augmentation_functional as TAF
except ImportError as e:
    _custom_import_errors.append(e)

try:
    from lib import train_utils as U
except ImportError as e:
    _custom_import_errors.append(e)

try:
    from lib.common import find_in_dir, is_sorted
except ImportError as e:
    _custom_import_errors.append(e)

try:
    from nnn import ssim_cv
except ImportError as e:
    _custom_import_errors.append(e)

# Report every failure (same messages as before), then the hint once.
for _import_error in _custom_import_errors:
    print(f"Error importing custom modules: {_import_error}")
if _custom_import_errors:
    print("Please ensure the repository is properly cloned and paths are set")

# Augmentation pipeline.
# NOTE(review): this calls medium_augmentations without arguments and relies on the
# version imported from lib.dataset; the medium_augmentations defined later in this
# notebook requires target_size — confirm which signature is intended.
base_aug = medium_augmentations()

# Standalone wrapper so normalize_image can be used as an A.Lambda image callback.
def zscore_image(image, **kwargs):
    """Apply normalize_image to ``image``; extra keyword arguments are accepted and ignored."""
    return normalize_image(image)

# train_transform: normalize_image first, then the transforms of base_aug, then
# A.Normalize and conversion to a tensor.
train_transform = A.Compose([
    # per-image normalization via normalize_image (presumably z-score — confirm in lib.dataset)
    A.Lambda(image=zscore_image),
    # medium-strength augmentations
    *base_aug.transforms,
    # final normalize + to-tensor
    A.Normalize(mean=0.5, std=0.224, max_pixel_value=255.0),
    ToTensorV2()
])
# List the competition input directory.
print("Contents of input directory:")
print(os.listdir('/kaggle/input/tgs-salt-identification-challenge'))

fatal: destination path '/kaggle/working/Kaggle-Salt' already exists and is not an empty directory.
Error importing custom modules: cannot import name 'torch_augmentation_functional' from 'lib' (unknown location)
Please ensure the repository is properly cloned and paths are set
Contents of input directory:
['depths.csv', 'sample_submission.csv', 'train.zip', 'competition_data.zip', 'test.zip', 'train.csv', 'flamingo.zip']


In [None]:
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Display the number of CUDA devices visible to PyTorch.
torch.cuda.device_count()

In [None]:
# Root directory of the dataset; the readers below join 'train', 'test' and 'depths.csv' onto it.
DATA_ROOT = 'data'
# Default number of folds used by get_folds_vector.
N_FOLDS = 5
# Resize side length used by the '128pad' prepare mode (doubled for '224pad' / '256pad');
# presumably the native image side in pixels — TODO confirm.
ORIGINAL_SIZE = 101


def all_train_ids() -> np.ndarray:
    """
    Collect the ids of all train images.
    An id is the extension-less base name of each entry that find_in_dir returns for <DATA_ROOT>/train/images.
    :return: Sorted numpy array of ids
    """
    images_dir = os.path.join(DATA_ROOT, 'train', 'images')
    sample_ids = [id_from_fname(fname) for fname in find_in_dir(images_dir)]
    sample_ids.sort()
    return np.array(sample_ids)


def all_test_ids() -> np.ndarray:
    """
    Collect the ids of all test images.
    An id is the extension-less base name of each entry that find_in_dir returns for <DATA_ROOT>/test/images.
    :return: Sorted numpy array of ids
    """
    images_dir = os.path.join(DATA_ROOT, 'test', 'images')
    sample_ids = [id_from_fname(fname) for fname in find_in_dir(images_dir)]
    sample_ids.sort()
    return np.array(sample_ids)


def read_train_image(sample_id) -> np.ndarray:
    """
    Read one train image as a single-channel (grayscale) array.
    :param sample_id: Image id; the file read is <DATA_ROOT>/train/images/<sample_id>.png
    :return: Image array as returned by cv2.imread
    :raises FileNotFoundError: If the file is missing or cannot be decoded
    """
    path = os.path.join(DATA_ROOT, 'train', 'images', '%s.png' % sample_id)
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising
    if image is None:
        raise FileNotFoundError(f'Cannot read train image: {path}')
    return image


def read_test_image(sample_id) -> np.ndarray:
    """
    Read one test image as a single-channel (grayscale) array.
    :param sample_id: Image id; the file read is <DATA_ROOT>/test/images/<sample_id>.png
    :return: Image array as returned by cv2.imread
    :raises FileNotFoundError: If the file is missing or cannot be decoded
    """
    path = os.path.join(DATA_ROOT, 'test', 'images', '%s.png' % sample_id)
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising
    if image is None:
        raise FileNotFoundError(f'Cannot read test image: {path}')
    return image


def read_train_mask(sample_id) -> np.ndarray:
    """
    Read one train mask and binarize it.
    :param sample_id: Image id; the file read is <DATA_ROOT>/train/masks/<sample_id>.png
    :return: np.uint8 array with 1 where the stored mask is non-zero and 0 elsewhere
    :raises FileNotFoundError: If the file is missing or cannot be decoded
    """
    path = os.path.join(DATA_ROOT, 'train', 'masks', '%s.png' % sample_id)
    mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure; comparing None with 0 would raise an opaque TypeError
    if mask is None:
        raise FileNotFoundError(f'Cannot read train mask: {path}')
    return (mask > 0).astype(np.uint8)


def read_train_images(ids) -> np.ndarray:
    """
    Load the train images for the given ids into one np.uint8 array (one entry per id, in the order of ids).
    :param ids: Sorted list of image ids.
    :return: Numpy array of stacked images
    :raises ValueError: If ids is not sorted
    """
    if is_sorted(ids):
        return np.array([read_train_image(sample_id) for sample_id in ids], dtype=np.uint8)

    raise ValueError('Array ids must be sorted')


def read_test_images(ids) -> np.ndarray:
    """
    Load the test images for the given ids into one np.uint8 array (one entry per id, in the order of ids).
    :param ids: Sorted list of image ids.
    :return: Numpy array of stacked images
    :raises ValueError: If ids is not sorted
    """
    if is_sorted(ids):
        return np.array([read_test_image(sample_id) for sample_id in ids], dtype=np.uint8)

    raise ValueError('Array ids must be sorted')


def read_train_masks(ids) -> np.ndarray:
    """
    Load the binarized train masks for the given ids into one np.uint8 array (one entry per id, in the order of ids).
    :param ids: Sorted list of image ids.
    :return: Numpy array with values {0,1}
    :raises ValueError: If ids is not sorted
    """
    if is_sorted(ids):
        return np.array([read_train_mask(sample_id) for sample_id in ids], dtype=np.uint8)

    raise ValueError('Array ids must be sorted')


def read_depths(ids):
    """
    Look up the depth of every id in <DATA_ROOT>/depths.csv.
    The 'z' column is cast to float32 and divided by its maximum before the lookup.
    :param ids: Sorted list of image ids.
    :return: Numpy array with one scaled depth per id, in the order of ids
    :raises ValueError: If ids is not sorted
    """
    if not is_sorted(ids):
        raise ValueError('Array ids must be sorted')

    depth_table = pd.read_csv(os.path.join(DATA_ROOT, 'depths.csv'))
    z_float = depth_table['z'].astype(np.float32)
    depth_table['z'] = z_float
    depth_table['z'] = depth_table['z'] / depth_table['z'].max()

    def depth_of(sample_id):
        # first row whose 'id' matches; an unknown id fails on .iloc[0]
        return depth_table[depth_table['id'] == sample_id].iloc[0]['z']

    return np.array([depth_of(sample_id) for sample_id in ids])


def get_selection_mask(ids: np.ndarray, query: np.ndarray):
    """
    Build a boolean mask over ids that marks the elements present in query.
    :param ids: Sorted array of all ids
    :param query: Sorted array of ids to select; every element must occur in ids
    :return: Boolean numpy array with the same length as ids
    :raises ValueError: If either array is not sorted or query contains ids missing from ids
    """
    if not is_sorted(ids):
        raise ValueError('Array ids must be sorted')

    if not is_sorted(query):
        raise ValueError('Array subset must be sorted')

    # np.isin replaces the deprecated np.in1d
    if not np.isin(query, ids, assume_unique=True).all():
        raise ValueError("Some elements of subset are not in ids")

    # vectorized membership test instead of a Python-level scan of query per id
    return np.isin(ids, query)


def drop_some(images, masks, drop_black=True, drop_vstrips=False, drop_empty=False, drop_few=None) -> np.ndarray:
    """
    Decide, per (image, mask) pair, whether the sample should be kept.
    :param images: Iterable of images
    :param masks: Iterable of masks, paired with images
    :param drop_black: Drop samples for which is_black is true
    :param drop_vstrips: Drop samples for which is_vertical_strips is true
    :param drop_empty: Drop samples whose mask sum is below 1
    :param drop_few: If set, drop samples whose mask sum is below int(drop_few) but not below 1
    :return: Numpy array of keep flags (True = keep), one per pair
    """
    dropped_blacks = dropped_vstrips = dropped_few = dropped_empty = 0
    keep_flags = []

    for image, mask in zip(images, masks):
        # every rule is evaluated (and counted) independently of the others
        black_hit = bool(drop_black and is_black(image, mask))
        vstrips_hit = bool(drop_vstrips and is_vertical_strips(image, mask))
        few_hit = bool(drop_few
                       and is_salt_less_than(image, mask, int(drop_few))
                       and not is_salt_less_than(image, mask, 1))
        empty_hit = bool(drop_empty and is_salt_less_than(image, mask, 1))

        dropped_blacks += black_hit
        dropped_vstrips += vstrips_hit
        dropped_few += few_hit
        dropped_empty += empty_hit

        keep_flags.append(not (black_hit or vstrips_hit or few_hit or empty_hit))

    print(f'Dropped {dropped_blacks} black images; {dropped_vstrips} vertical strips; {dropped_empty}  empty masks; {dropped_few} few-pixel salt')
    return np.array(keep_flags)


def cumsum(img, axis=0) -> np.ndarray:
    """
    Mean-centered cumulative sum of an image, standardized afterwards.
    Idea taken from https://www.kaggle.com/bguberfain/unet-with-depth#360485 : for seismic data the cumsum along
    the depth axis (remotely) approximates an inversion operation.
    :param img: Single-channel image
    :param axis: Axis along which the cumulative sum is taken
    :return: Cumulative sum with its mean removed, divided by max(1e-3, std)
    """
    centered = np.float32(img) - img.mean()
    accumulated = centered.cumsum(axis=axis)
    accumulated -= accumulated.mean()
    # the lower bound keeps a constant image from dividing by zero
    scale = max(1e-3, accumulated.std())
    accumulated /= scale
    return accumulated


def id_from_fname(fname):
    """
    Derive a sample id from a file name.
    :param fname: File name or path
    :return: Base name of fname with its last extension removed
    """
    base_name = os.path.basename(fname)
    stem, _extension = os.path.splitext(base_name)
    return stem


def harder_augmentations(target_size, border_mode=cv2.BORDER_CONSTANT):
    """
    Build the 'harder' augmentation pipeline.
    :param target_size: Output side of the random sized crop; the crop's lower size bound is int(target_size * 0.8)
    :param border_mode: Border mode, resolved through U.get_border_mode
    :return: Composed albumentations transform
    """
    border_mode = U.get_border_mode(border_mode)

    flip = A.HorizontalFlip(p=0.5)
    brightness = A.RandomBrightness(p=0.5)
    gamma = A.RandomGamma(gamma_limit=(80, 120), p=0.5)
    noise = A.OneOf([A.IAAAdditiveGaussianNoise(), A.GaussNoise()], p=0.3)
    blur = A.OneOf([A.MotionBlur(p=.1), A.MedianBlur(blur_limit=3, p=0.1), A.Blur(blur_limit=3, p=0.1)], p=0.2)
    shift_or_crop = A.OneOf([
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=10, p=0.5, border_mode=border_mode),
        A.RandomSizedCrop((int(target_size * 0.8), target_size), target_size, target_size, p=0.5),
    ], p=0.5)
    distortion = A.OneOf([
        AA.AxisShear(sx=0.1, sy=0.1, border_mode=border_mode, p=0.5),
        A.ElasticTransform(alpha=0.5, sigma=0.5, alpha_affine=10, border_mode=border_mode),
        A.GridDistortion(border_mode=border_mode),
        A.IAAPerspective(p=0.3),
    ], p=0.3)
    cutout = A.Cutout()

    return A.Compose([flip, brightness, gamma, noise, blur, shift_or_crop, distortion, cutout])


def hard_augmentations(target_size, border_mode=cv2.BORDER_CONSTANT):
    """
    Build the 'hard' augmentation pipeline.
    :param target_size: Not used by this pipeline; kept so all factories share one signature
    :param border_mode: Border mode, resolved through U.get_border_mode
    :return: Composed albumentations transform
    """
    border_mode = U.get_border_mode(border_mode)

    flip = A.HorizontalFlip(p=0.5)
    tone = A.OneOf([A.RandomBrightness(p=0.5), A.RandomGamma(gamma_limit=(80, 120), p=0.5)])
    noise = A.OneOf([A.IAAAdditiveGaussianNoise(), A.GaussNoise()], p=0.3)
    blur = A.OneOf([A.MotionBlur(p=.1), A.MedianBlur(blur_limit=3, p=0.1), A.Blur(blur_limit=3, p=0.1)], p=0.2)
    shift = A.ShiftScaleRotate(shift_limit=0.075, scale_limit=0.075, rotate_limit=10, p=0.5, border_mode=border_mode)
    distortion = A.OneOf([
        A.ElasticTransform(alpha=0.5, sigma=0.5, alpha_affine=10, border_mode=border_mode, p=0.1),
        A.GridDistortion(border_mode=border_mode, p=0.1),
    ], p=0.3)
    cutout = A.Cutout()

    return A.Compose([flip, tone, noise, blur, shift, distortion, cutout])


def medium_augmentations(target_size, border_mode=cv2.BORDER_CONSTANT):
    """
    Build the 'medium' augmentation pipeline.
    :param target_size: Not used by this pipeline; kept so all factories share one signature
    :param border_mode: Border mode, resolved through U.get_border_mode
    :return: Composed albumentations transform
    """
    border_mode = U.get_border_mode(border_mode)

    flip = A.HorizontalFlip(p=0.5)
    brightness = A.RandomBrightness(p=0.25)
    gamma = A.RandomGamma(gamma_limit=(80, 120), p=0.25)
    contrast = AA.RandomContrastGray(p=0.25)
    shift = A.ShiftScaleRotate(shift_limit=0.10, scale_limit=0.10, rotate_limit=5, p=0.5, border_mode=border_mode)
    distortion = A.OneOf([
        A.ElasticTransform(alpha=0.5, sigma=0.5, alpha_affine=10, border_mode=border_mode, p=0.1),
        A.GridDistortion(border_mode=border_mode, p=0.1),
        A.NoOp(p=0.5)
    ])
    noise = A.OneOf([A.IAAAdditiveGaussianNoise(), A.GaussNoise()], p=0.3)

    return A.Compose([flip, brightness, gamma, contrast, shift, distortion, noise])


def light_augmentations(target_size, border_mode=cv2.BORDER_CONSTANT):
    """
    Build the 'light' augmentation pipeline: flip, shift/scale without rotation, optional mild distortion.
    :param target_size: Not used by this pipeline; kept so all factories share one signature
    :param border_mode: Border mode, resolved through U.get_border_mode
    :return: Composed albumentations transform
    """
    border_mode = U.get_border_mode(border_mode)

    flip = A.HorizontalFlip(p=0.5)
    shift = A.ShiftScaleRotate(shift_limit=0.10, scale_limit=0.10, rotate_limit=0, p=0.5, border_mode=border_mode)
    distortion = A.OneOf([
        A.ElasticTransform(alpha=0.5, sigma=0.5, alpha_affine=10, border_mode=border_mode, p=0.1),
        A.GridDistortion(border_mode=border_mode, p=0.1),
        A.NoOp(p=0.5)
    ])

    return A.Compose([flip, shift, distortion])


def flip_augmentations(target_size, border_mode=cv2.BORDER_CONSTANT):
    """
    Build the 'flip' augmentation: a horizontal flip applied with probability 0.5.
    :param target_size: Not used; kept so all factories share one signature
    :param border_mode: Passed through U.get_border_mode; the resolved value is not used by the flip
    :return: A.HorizontalFlip transform
    """
    # the call is kept even though its result is unused, so any effect it has on border_mode input is unchanged
    U.get_border_mode(border_mode)
    return A.HorizontalFlip(p=0.5)


def none_augmentations(target_size, border_mode=cv2.BORDER_CONSTANT):
    """
    Build the 'none' augmentation: a no-op transform.
    :param target_size: Not used; kept so all factories share one signature
    :param border_mode: Not used; kept so all factories share one signature
    :return: A.NoOp transform with p=1
    """
    identity = A.NoOp(p=1)
    return identity


class DatasetResizePad:
    """
    Resize-then-pad preparation step with a matching inverse.

    forward() resizes to (resize_size, resize_size) and pads up to at least (target_size, target_size);
    backward() crops the central (resize_size, resize_size) region back out.

    :param resize_size: Side length the input is resized to
    :param target_size: Minimum side length after padding
    :param border_mode: Border mode handed to A.PadIfNeeded
    :param border_fill: Stored on the instance and shown in repr; not forwarded to the padding transform
    :param interpolation: Interpolation flag handed to A.Resize
    :param kwargs: Accepted and ignored
    """

    def __init__(self, resize_size, target_size, border_mode=cv2.BORDER_CONSTANT, border_fill=0, interpolation=cv2.INTER_LINEAR, **kwargs):
        self.resize_size = resize_size
        self.target_size = target_size
        self.border_mode = border_mode
        self.border_fill = border_fill
        self.interpolation = interpolation

        self.t_forward = A.Compose([
            A.Resize(resize_size, resize_size, interpolation=interpolation),
            A.PadIfNeeded(min_height=target_size, min_width=target_size, border_mode=border_mode)])

    def forward(self, **kwargs):
        """Apply resize + pad to the given albumentations targets (e.g. image=..., mask=...)."""
        return self.t_forward(**kwargs)

    def backward(self, x):
        """
        Undo the padding by cropping the central resize_size x resize_size region.
        :param x: torch.Tensor or np.ndarray; any other type is returned unchanged
        :return: Cropped tensor / contiguous cropped array
        """
        if isinstance(x, torch.Tensor):
            x = TAF.central_crop(x, self.resize_size, self.resize_size)
        elif isinstance(x, np.ndarray):
            # ndarray has no .ascontiguousarray() method; the module-level function must be used
            x = np.ascontiguousarray(AF.center_crop(x, self.resize_size, self.resize_size))
        return x

    def __repr__(self):
        return f'DatasetResizePad(resize_size={self.resize_size}, target_size={self.target_size}, border_mode={self.border_mode}, border_fill={self.border_fill}, interpolation={self.interpolation})'


def get_prepare_fn(name, border_mode=cv2.BORDER_DEFAULT, **kwargs):
    """
    Create the DatasetResizePad that corresponds to a preset name.
    :param name: One of '128', '224', '256', '128pad', '224pad', '256pad'; None or 'None' disables preparation
    :param border_mode: Border mode, resolved through U.get_border_mode
    :param kwargs: Forwarded to DatasetResizePad
    :return: DatasetResizePad instance, or None when name is None / 'None'
    :raises ValueError: If name is not a known preset
    """
    border_mode = U.get_border_mode(border_mode)
    if name is None or name == 'None':
        return None

    # (preset name, resize_size, target_size) — checked in this order
    presets = (
        ('128', 128, 128),
        ('224', 224, 224),
        ('256', 256, 256),
        ('128pad', ORIGINAL_SIZE, 128),
        ('224pad', ORIGINAL_SIZE * 2, 224),
        ('256pad', ORIGINAL_SIZE * 2, 256),
    )
    for preset_name, resize_size, target_size in presets:
        if name == preset_name:
            return DatasetResizePad(resize_size=resize_size, target_size=target_size, border_mode=border_mode, **kwargs)

    raise ValueError('Unsupported prepare fn')


class ImageAndMaskDataset(Dataset):
    """
    Dataset that serves images together with optional masks and per-sample depths.
    :param ids - Sorted array of sample ids
    :param images - List of images
    :param masks - List of masks, or None when no masks are available
    :param depths - List of depths
    :param prepare_fn - Optional DatasetResizePad applied once to all images / masks at construction
    :param normalize - Transform applied to every sample after augmentation
    :param augment - Optional transform applied to every sample before normalization
    """

    def __init__(self, ids: np.ndarray, images: np.ndarray, masks: Optional[np.ndarray], depths: np.ndarray,
                 prepare_fn: DatasetResizePad = None,
                 normalize=A.Normalize(mean=0.5, std=0.224, max_pixel_value=255.0),
                 augment=None):

        if not is_sorted(ids):
            raise ValueError('Array ids must be sorted')

        self.ids = ids
        self.depths = depths
        self.augment = augment
        self.normalize = normalize
        self.num_channels = 1
        self.resize_fn = prepare_fn

        if prepare_fn is None:
            self.images = images
            self.masks = masks
        else:
            self.images = np.array([prepare_fn.forward(image=img)['image'] for img in images])
            if masks is None:
                self.masks = None
            else:
                # the mask is passed as both targets; only the 'mask' output is kept
                self.masks = np.array([prepare_fn.forward(image=m, mask=m)['mask'] for m in masks])

    def __getitem__(self, index):
        transformed = {'image': self.images[index].copy()}
        if self.masks is not None:
            transformed['mask'] = self.masks[index].copy()

        if self.augment is not None:
            transformed = self.augment(**transformed)
        transformed = self.normalize(**transformed)

        # add the leading channel axis
        image_tensor = torch.from_numpy(np.expand_dims(transformed['image'], 0)).float()

        sample = {
            'index': index,
            'id': self.ids[index],
            'image': image_tensor,
            'depth': self.depths[index],
        }

        mask = transformed.get('mask', None)
        if mask is None:
            return sample

        if not np.isin(mask, [0, 1]).all():
            raise RuntimeError(f'A mask after augmentation contains values other than {{0;1}}: {np.unique(mask)}')

        # BCE problem, so float target: 1.0 when the mask has any positive pixel
        has_salt = np.expand_dims(np.array((mask > 0).any(), dtype=np.float32), 0)

        sample['mask'] = torch.from_numpy(np.expand_dims(mask, 0)).float()
        sample['class'] = has_salt
        return sample

    def channels(self):
        return self.num_channels

    def __len__(self):
        return len(self.images)


def get_folds_vector(kind, images, masks, depths, n_folds=N_FOLDS, random_state=None):
    """
    Assign a fold index to every sample.

    Fold indices cycle 0..n_folds-1; they are then looked up through an ordering of the samples by the chosen criterion.
    :param kind: 'coverage' / 'area' (non-zero mask pixels), 'depth' (depths), 'resolution' (std of the image
                 Laplacian); any other value uses a random shuffle
    :param images: Images, used for kind == 'resolution'
    :param masks: Masks, used for kind == 'coverage' / 'area'
    :param depths: Depths; also defines the number of samples
    :param n_folds: Number of folds (must be >= 1)
    :param random_state: Seed / RandomState passed to sklearn's check_random_state
    :return: Numpy array of fold indices, one per sample
    :raises ValueError: If n_folds is smaller than 1
    """
    if n_folds < 1:
        raise ValueError('n_folds must be at least 1')

    n = len(depths)
    # 0, 1, ..., n_folds-1, 0, 1, ... without materializing n * n_folds elements
    folds = np.arange(n) % n_folds

    rnd = check_random_state(random_state)

    if kind in ('coverage', 'area'):
        # np.int was removed from NumPy; the builtin int is its exact replacement
        coverage = np.array([cv2.countNonZero(x) for x in masks], dtype=int)
        sorted_indexes = np.argsort(coverage)
    elif kind == 'depth':
        sorted_indexes = np.argsort(depths)
    elif kind == 'resolution':
        resolution = np.array([cv2.Laplacian(image, cv2.CV_32F, borderType=cv2.BORDER_REFLECT101).std() for image in images])
        sorted_indexes = np.argsort(resolution)
    else:
        sorted_indexes = list(range(n))
        rnd.shuffle(sorted_indexes)

    return folds[sorted_indexes]


def get_train_test_split_for_fold(stratify, fold, ids):
    """
    Split ids into train / validation masks for one fold.
    :param stratify: Name of the stratification; selects the file <DATA_ROOT>/folds_by_<stratify>.csv
    :param fold: Fold index used as the validation fold
    :param ids: Sample ids; each must have a row in the folds file
    :return: Tuple (train_mask, test_mask) of boolean arrays aligned with ids
    """
    # use the shared DATA_ROOT like every other reader instead of a hard-coded 'data'
    folds_table = pd.read_csv(os.path.join(DATA_ROOT, f'folds_by_{stratify}.csv'))
    sample_folds = np.array([folds_table[folds_table['id'] == sample_id].iloc[0]['fold'] for sample_id in ids])
    return sample_folds != fold, sample_folds == fold


def fix_mask(mask):
    """
    Tries to 'fix' a mask by filling small holes and removing small objects.
    :param mask: Mask array; non-zero values are treated as foreground
    :return: np.uint8 mask after remove_small_holes (area_threshold=12) and remove_small_objects (min_size=12)
    """
    # np.bool was removed from NumPy; the builtin bool is its exact replacement
    mask = mask.astype(bool)
    mask = remove_small_holes(mask, area_threshold=12, connectivity=1)
    mask = remove_small_objects(mask, min_size=12, connectivity=1)
    return mask.astype(np.uint8)


def fix_masks(masks, train_ids):
    """
    Run fix_mask over all masks and report which ones were altered.
    :param masks: Iterable of masks
    :param train_ids: Ids paired with masks
    :return: Tuple (array of fixed masks, list of ids whose mask differs from the input)
    """
    triples = [(image_id, mask, fix_mask(mask)) for image_id, mask in zip(train_ids, masks)]

    changed_ids = [image_id for image_id, mask, fixed in triples if not np.array_equal(fixed, mask)]
    fixed_masks = np.array([fixed for _, _, fixed in triples])
    return fixed_masks, changed_ids


def is_black(image, mask):
    """
    Tell whether the pixel values of image sum to zero.
    :param image: Image array
    :param mask: Not used; present so all predicates share the (image, mask) signature
    """
    pixel_total = image.sum()
    return pixel_total == 0


def is_vertical_strips(image, mask):
    """
    Tell whether every mask column is either completely empty or completely filled, with both kinds present.
    :param image: Not used; present so all predicates share the (image, mask) signature
    :param mask: Mask array with values {0,1}
    """
    column_totals = np.unique(np.sum(mask, axis=0))
    if len(column_totals) != 2:
        return False
    return column_totals.min() == 0 and column_totals.max() == mask.shape[0]


def is_salt_less_than(image, mask, threshold):
    """
    Tell whether the mask sum is strictly below threshold.
    :param image: Not used; present so all predicates share the (image, mask) signature
    :param mask: Mask array
    :param threshold: Exclusive upper bound for the mask sum
    """
    salt_total = mask.sum()
    return salt_total < threshold


def is_salt_greater_than(image, mask, threshold):
    """
    Tell whether the mask sum is strictly above threshold.
    :param image: Not used; present so all predicates share the (image, mask) signature
    :param mask: Mask array
    :param threshold: Exclusive lower bound for the mask sum
    """
    salt_total = mask.sum()
    return salt_total > threshold


# Lookup table from augmentation mode name to the factory that builds it.
AUGMENTATION_MODES = dict(
    harder=harder_augmentations,
    hard=hard_augmentations,
    medium=medium_augmentations,
    light=light_augmentations,
    flip=flip_augmentations,
    none=none_augmentations,
)

In [None]:

def random_contrast_gray(img, alpha):
    """
    Scale the contrast of an image around its mean gray level and saturate to the valid range.

    The `clipped` decorator used previously was never imported in this notebook, so the definition failed
    with NameError. The saturation is now done inline: the result is limited to [0, max] and cast back to
    the input dtype, where max is 255 / 65535 / 4294967295 for uint8 / uint16 / uint32 and 1.0 otherwise
    (intended to mirror the rule of albumentations' `clipped` helper).

    :param img: Image array
    :param alpha: Contrast factor; 1.0 leaves the image unchanged
    :return: Array with the same dtype as img
    """
    dtype = img.dtype
    max_value = {
        np.dtype('uint8'): 255,
        np.dtype('uint16'): 65535,
        np.dtype('uint32'): 4294967295,
        np.dtype('float32'): 1.0,
    }.get(dtype, 1.0)

    # (1 - alpha) * mean(img): keeps the mean gray level in place while the contrast changes
    gray = ((1.0 - alpha) / img.size) * np.sum(img)
    return np.clip(alpha * img + gray, 0, max_value).astype(dtype)


class RandomContrastGray(A.ImageOnlyTransform):
    """Randomly change contrast of the input image.

    Args:
        limit ((float, float) or float): factor range for changing contrast. If limit is a single float, the range
            will be (-limit, limit). Default: 0.2.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(self, limit=.2, p=.5):
        # p is passed by keyword: on albumentations versions whose base constructor is
        # (always_apply, p) a positional p would be taken as always_apply.
        super(RandomContrastGray, self).__init__(p=p)
        self.limit = A.to_tuple(limit)

    def apply(self, img, alpha=0.2, **params):
        """Apply the contrast change with the sampled factor alpha."""
        return random_contrast_gray(img, alpha)

    def get_params(self):
        """Sample the contrast factor as 1.0 plus a uniform offset from the configured limit range."""
        return {'alpha': 1.0 + random.uniform(self.limit[0], self.limit[1])}


class AxisShear(A.DualTransform):
    """
    Random shear around a random center point.

    :param sx: Limit of the shear coefficient placed at [0, 1] of the affine matrix; sampled from (-sx, sx)
    :param sy: Limit of the shear coefficient placed at [1, 0] of the affine matrix; sampled from (-sy, sy)
    :param p: Probability of applying the transform
    :param interpolation: OpenCV interpolation flag stored on the instance
    :param border_mode: OpenCV border mode used by the warp
    """

    def __init__(self, sx=0.1, sy=0.1, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT):
        # p is passed by keyword: on albumentations versions whose base constructor is
        # (always_apply, p) a positional p would be taken as always_apply.
        super(AxisShear, self).__init__(p=p)
        self.sx = sx
        self.sy = sy
        self.interpolation = interpolation
        self.border_mode = border_mode

    def get_params(self):
        """Sample the relative center (cx, cy) in [0, 1] and the two shear coefficients."""
        return {"cx": random.uniform(0, 1),
                "cy": random.uniform(0, 1),
                "sx": random.uniform(- self.sx, self.sx),
                # sy has its own limit; it was previously sampled from the sx limit
                "sy": random.uniform(- self.sy, self.sy)}

    def apply(self, img, cx=0.5, cy=0.5, sx=0, sy=0, interpolation=cv2.INTER_LINEAR, **params):
        """Warp img with the shear matrix expressed relative to the sampled center."""
        # translate the center to the origin, shear, translate back
        center = np.eye(3, 3)
        center[0, 2] = cx * img.shape[1]
        center[1, 2] = cy * img.shape[0]

        inv_center = np.eye(3, 3)
        inv_center[0, 2] = -center[0, 2]
        inv_center[1, 2] = -center[1, 2]

        shear = np.eye(3, 3)
        shear[0, 1] = sx
        shear[1, 0] = sy

        m = np.matmul(np.matmul(center, shear), inv_center)
        return cv2.warpAffine(img, m[:2, ...], dsize=(img.shape[1], img.shape[0]), flags=interpolation, borderMode=self.border_mode)


class AxisScale(A.DualTransform):
    """
    Random independent scaling of both axes around a random center point.

    :param sx: Limit of the scale factor placed at [0, 0] of the affine matrix; sampled from (1 - sx, 1 + sx)
    :param sy: Limit of the scale factor placed at [1, 1] of the affine matrix; sampled from (1 - sy, 1 + sy)
    :param p: Probability of applying the transform
    :param interpolation: OpenCV interpolation flag stored on the instance
    :param border_mode: OpenCV border mode used by the warp
    """

    def __init__(self, sx=0.1, sy=0.1, p=0.5, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT):
        # p is passed by keyword: on albumentations versions whose base constructor is
        # (always_apply, p) a positional p would be taken as always_apply.
        super(AxisScale, self).__init__(p=p)
        self.sx = sx
        self.sy = sy
        self.interpolation = interpolation
        self.border_mode = border_mode

    def get_params(self):
        """Sample the relative center (cx, cy) in [0, 1] and the two scale factors."""
        return {"cx": random.uniform(0, 1),
                "cy": random.uniform(0, 1),
                "sx": random.uniform(1 - self.sx, 1 + self.sx),
                # sy has its own limit; it was previously sampled from the sx limit
                "sy": random.uniform(1 - self.sy, 1 + self.sy)}

    def apply(self, img, cx=0.5, cy=0.5, sx=1, sy=1, interpolation=cv2.INTER_LINEAR, **params):
        """Warp img with the scale matrix expressed relative to the sampled center."""
        # translate the center to the origin, scale, translate back
        center = np.eye(3, 3)
        center[0, 2] = cx * img.shape[1]
        center[1, 2] = cy * img.shape[0]

        inv_center = np.eye(3, 3)
        inv_center[0, 2] = -center[0, 2]
        inv_center[1, 2] = -center[1, 2]

        scale = np.eye(3, 3)
        scale[0, 0] = sx
        scale[1, 1] = sy

        m = np.matmul(np.matmul(center, scale), inv_center)
        return cv2.warpAffine(img, m[:2, ...], dsize=(img.shape[1], img.shape[0]), flags=interpolation, borderMode=self.border_mode)

In [None]:


def test_ssim_normalization():
    """Sanity checks for ssim_cv on constant 101x101 uint8 images and on a single-pixel difference."""

    def filled(value):
        # fresh 101x101 uint8 image with every pixel set to value
        return np.full((101, 101), value, dtype=np.uint8)

    # identical images score exactly 1
    assert ssim_cv(filled(0), filled(0)) == 1.0
    assert ssim_cv(filled(255), filled(255)) == 1.0

    # black against white scores (almost) zero
    assert ssim_cv(filled(0), filled(255)) < 0.0001

    # mid-gray against white scores above one half
    assert ssim_cv(filled(127), filled(255)) > 0.5

    # one black pixel barely changes the score
    one_black = filled(255)
    one_black[1, 1] = 0
    assert ssim_cv(one_black, filled(255)) > 0.99

In [82]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import device

# Clone and set up repository (the clone is skipped when the target path already exists)
repo_path = '/kaggle/working/Kaggle-Salt'
if not os.path.exists(repo_path):
    !git clone https://github.com/BloodAxe/Kaggle-Salt.git {repo_path}

# Add to Python path; inserted at position 0 so these locations are searched first,
# with repo_path ending up ahead of '/kaggle/working'
sys.path.insert(0, '/kaggle/working')
sys.path.insert(0, repo_path)

# First try importing the compiled version
try:
    from models.modules.abn import ABN
    from models.modules.functions import inplace_abn, inplace_abn_sync
    print("Successfully imported compiled ABN modules")
except ImportError as e:
    print(f"Import failed: {e}. Using fallback implementation...")

    def _apply_activation(x, activation, slope):
        """Apply the activation selected by name; unknown names (e.g. "none") leave x unchanged."""
        if activation == "relu":
            return F.relu(x)
        if activation == "leaky_relu":
            return F.leaky_relu(x, negative_slope=slope)
        if activation == "elu":
            return F.elu(x)
        return x

    # Fallback implementation
    class ABN(nn.Module):
        """
        Pure-PyTorch stand-in for the compiled ABN layer: BatchNorm2d followed by an activation.

        :param num_features: Number of input channels
        :param eps: BatchNorm epsilon
        :param momentum: BatchNorm momentum
        :param affine: Whether BatchNorm has learnable scale / shift
        :param activation: "relu", "leaky_relu" or "elu"; any other value applies no activation
        :param slope: Negative slope used by "leaky_relu"
        """

        def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                    activation="leaky_relu", slope=0.01):
            super().__init__()
            self.bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=affine)
            self.activation = activation
            self.slope = slope
            self.eps = eps
            self.momentum = momentum

        def forward(self, x):
            # previously only "leaky_relu" was honoured, so e.g. "relu" silently applied no activation
            return _apply_activation(self.bn(x), self.activation, self.slope)

    def inplace_abn(x, weight, bias, running_mean, running_var, training, momentum, eps, activation, slope):
        """
        Functional batch norm + activation using the parameters and statistics supplied by the caller.
        (The previous fallback built a brand-new layer per call and ignored all of them.)
        """
        x = F.batch_norm(x, running_mean, running_var, weight, bias, training, momentum, eps)
        return _apply_activation(x, activation, slope)

    inplace_abn_sync = inplace_abn

# Implement InPlaceABN with proper attribute inheritance
class InPlaceABN(ABN):
    """
    ABN layer that uses the in-place functional implementation when the compiled ABN base is in use.

    With the pure-PyTorch fallback base (recognized by its `bn` sub-module) the base forward is used instead.
    Constructor arguments are forwarded unchanged to ABN.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 activation="leaky_relu", slope=0.01):
        super().__init__(num_features, eps, momentum, affine, activation, slope)

    def forward(self, x):
        if hasattr(self, 'bn'):  # Fallback case
            return super().forward(x)

        # Compiled case: there is no `bn` sub-module here, so the parameters are read from the layer
        # itself (the previous code dereferenced self.bn and could only raise AttributeError).
        # NOTE(review): assumes the compiled ABN exposes weight / bias / running_mean / running_var /
        # momentum / eps directly — confirm against models.modules.abn.
        return inplace_abn(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.training,
            self.momentum,
            self.eps,
            self.activation,
            self.slope
        )

class InPlaceABNSync(ABN):
    """
    Synchronized variant of InPlaceABN.

    With the pure-PyTorch fallback base (recognized by its `bn` sub-module) the base forward is used.
    :param devices: Device ids stored on the instance; defaults to [0]
    Remaining constructor arguments are forwarded unchanged to ABN.
    """

    def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1,
                 affine=True, activation="leaky_relu", slope=0.01):
        super().__init__(num_features, eps, momentum, affine, activation, slope)
        self.devices = devices or [0]

    def forward(self, x):
        if hasattr(self, 'bn'):  # Fallback case
            return super().forward(x)

        # Compiled case: there is no `bn` sub-module here, so the parameters are read from the layer
        # itself (the previous code dereferenced self.bn and could only raise AttributeError).
        # NOTE(review): assumes the compiled ABN exposes weight / bias / running_mean / running_var /
        # momentum / eps directly — confirm against models.modules.abn.
        return inplace_abn_sync(
            x,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            {},
            self.training,
            self.momentum,
            self.eps,
            self.activation,
            self.slope
        )

# Test: push one random batch through each layer class and print the output shape and a few values.
if __name__ == "__main__":
    # Rebinds the name `device`, which this cell also imported via `from torch import device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Test both implementations
    for cls in [InPlaceABN, InPlaceABNSync]:
        print(f"\nTesting {cls.__name__}:")
        layer = cls(64).to(device)
        # random input; the second dimension (64) matches the num_features the layer was built with
        x = torch.randn(16, 64, 32, 32).to(device)
        out = layer(x)
        print("Output shape:", out.shape)
        print("First values:", out[0, 0, 0, :5])

Import failed: /root/.cache/torch_extensions/py311_cu124/inplace_abn/inplace_abn.so: cannot open shared object file: No such file or directory. Using fallback implementation...
Using device: cpu

Testing InPlaceABN:
Output shape: torch.Size([16, 64, 32, 32])
First values: tensor([-0.0021, -0.0063, -0.0087,  0.1946, -0.0033], grad_fn=<SliceBackward0>)

Testing InPlaceABNSync:
Output shape: torch.Size([16, 64, 32, 32])
First values: tensor([ 1.0436, -0.0130,  1.5882,  0.4891, -0.0073], grad_fn=<SliceBackward0>)


In [83]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import resnet34
from torchvision.models.resnet import ResNet34_Weights

def get_my_encoder(pretrained=True):
    """
    Build a ResNet-34 feature extractor.
    :param pretrained: Load the IMAGENET1K_V1 weights when True, otherwise use no pretrained weights
    :return: nn.Sequential of all ResNet-34 children except the last two
    """
    if pretrained:
        weights = ResNet34_Weights.IMAGENET1K_V1
    else:
        weights = None
    backbone = resnet34(weights=weights)

    # Remove last fully connected layer and avgpool
    feature_layers = list(backbone.children())[:-2]
    return nn.Sequential(*feature_layers)

# Define decoder
def get_my_decoder(mid_channels=256):
    """
    Build the decoder head: two conv/BN/ReLU stages with one 2x upsampling in between, then a 1x1 conv.
    :param mid_channels: Number of channels of the incoming feature map
    :return: nn.Sequential producing a single-channel map
    """
    layers = [
        # stage 1: mid_channels -> 128
        nn.Conv2d(mid_channels, 128, kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
        nn.Upsample(scale_factor=2),
        # stage 2: 128 -> 64
        nn.Conv2d(128, 64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        # projection to one output channel
        nn.Conv2d(64, 1, kernel_size=1),
    ]
    return nn.Sequential(*layers)

class SaltUNetWithAux(nn.Module):
    """
    Encoder/decoder segmentation model with an auxiliary classification head on the encoder features.
    :param encoder: Module mapping the input to a feature map
    :param decoder: Module mapping the feature map to the mask prediction
    :param mid_channels: Channel count of the feature map (input size of the auxiliary linear layer)
    """

    def __init__(self, encoder, decoder, mid_channels):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        # Auxiliary classifier: global average pool -> flatten -> linear -> sigmoid
        aux_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(mid_channels, 1),
            nn.Sigmoid(),
        ]
        self.aux_classifier = nn.Sequential(*aux_layers)

    def forward(self, x):
        """Return (mask prediction from the decoder, auxiliary prediction from the classifier head)."""
        features = self.encoder(x)
        return self.decoder(features), self.aux_classifier(features)

def lovasz(pred, target):
    """
    Placeholder segmentation loss: binary cross-entropy on logits (not an actual Lovasz loss).
    :param pred: Predicted logits
    :param target: Target tensor of the same shape
    :return: Loss tensor
    """
    # Placeholder - use actual Lovasz implementation if needed
    bce_on_logits = F.binary_cross_entropy_with_logits(input=pred, target=target)
    return bce_on_logits

encoder = get_my_encoder(pretrained=True)
# The ResNet-34 trunk (everything before avgpool / fc) outputs 512 channels, so both the decoder
# input and the auxiliary head must be built for 512; the previous value of 256 fails on the first forward pass.
decoder = get_my_decoder(mid_channels=512)
model = SaltUNetWithAux(encoder, decoder, mid_channels=512)  # Removed .cuda()

# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Example training loop
def train_epoch(model, trainloader, optimizer=None):
    """
    Run one training pass over trainloader.

    :param model: Module returning (mask logits, auxiliary probability) for a batch of images
    :param trainloader: Iterable yielding (images, masks) pairs
    :param optimizer: Optimizer to step. When omitted, the notebook-level `optimizer` variable is used
                      (previous behaviour); passing it explicitly avoids stepping an optimizer that
                      belongs to a different model.
    """
    if optimizer is None:
        # the parameter shadows the notebook-level name, hence the explicit globals() lookup
        optimizer = globals()['optimizer']

    # run the batch on whatever device the model lives on
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    model.train()
    for imgs, masks in trainloader:
        imgs, masks = imgs.float(), masks.float()
        if model_device is not None:
            imgs, masks = imgs.to(model_device), masks.to(model_device)

        # Forward pass
        mask_pred, aux_pred = model(imgs)

        # Losses
        loss_mask = lovasz(mask_pred.squeeze(1), masks)
        # auxiliary target: 1.0 for samples whose mask has any positive pixel
        salt_present = (masks.view(masks.size(0), -1).sum(dim=1) > 0).float()
        loss_aux = F.binary_cross_entropy(aux_pred.squeeze(1), salt_present)

        # Combined loss
        total_loss = loss_mask + 0.5 * loss_aux

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        print(f"Seg Loss: {loss_mask.item():.4f}, Aux Loss: {loss_aux.item():.4f}")

#

In [84]:


class _ConvBatchNormReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            relu=True,
    ):
        super(_ConvBatchNormReLU, self).__init__()
        self.add_module(
            "conv",
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=False,
            ),
        )
        self.add_module(
            "bn",
            nn.BatchNorm2d(
                num_features=out_channels, eps=1e-5, momentum=0.999, affine=True
            ),
        )

        if relu:
            self.add_module("relu", nn.ReLU())

    def forward(self, x):
        return super(_ConvBatchNormReLU, self).forward(x)


class _ASPPModule(nn.Module):
    """
    Atrous Spatial Pyramid Pooling with image pool.

    Branches: an image-level pooling branch, a 1x1 conv branch ("c0") and one 3x3 dilated conv branch per
    entry of pyramids ("c1", "c2", ...). Their outputs are concatenated along the channel dimension.
    :param in_channels: Input channels
    :param out_channels: Output channels of every branch
    :param pyramids: Dilation rates; each rate is also used as the padding of its branch
    """

    def __init__(self, in_channels, out_channels, pyramids):
        super(_ASPPModule, self).__init__()
        self.stages = nn.Module()
        self.stages.add_module("c0", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1))
        for branch_no, rate in enumerate(pyramids, start=1):
            branch = _ConvBatchNormReLU(in_channels, out_channels, 3, 1, rate, rate)
            self.stages.add_module("c{}".format(branch_no), branch)

        pool_layers = OrderedDict()
        pool_layers["pool"] = nn.AdaptiveAvgPool2d(1)
        pool_layers["conv"] = _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1)
        self.imagepool = nn.Sequential(pool_layers)

    def forward(self, x):
        # image-level features are resized back to the spatial size of x
        pooled = F.interpolate(self.imagepool(x), size=x.shape[2:], mode="bilinear")
        branches = [pooled]
        branches.extend(stage(x) for stage in self.stages.children())
        return torch.cat(branches, dim=1)


class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling block built from plain convolutions and ABN layers.

    Four parallel convolutions (one 1x1 and three dilated 3x3) are concatenated, normalized and reduced to
    out_channels; a globally pooled branch is reduced separately and added before the final normalization.

    :param in_channels: Input channels
    :param out_channels: Output channels
    :param hidden_channels: Channels of every parallel branch
    :param dilations: Three dilation rates of the 3x3 branches (also used as their padding)
    :param abn_block: Normalization + activation layer class
    :param activation: Activation name handed to abn_block
    :param pooling_size: Optional pooling window used in eval mode; None means global average pooling
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=256,
                 dilations=(12, 24, 36),
                 abn_block=ABN,
                 activation=ACT_RELU,
                 pooling_size=None):
        super(ASPP, self).__init__()
        self.pooling_size = pooling_size

        self.map_convs = nn.ModuleList([
            nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
            nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
        ])
        self.map_bn = abn_block(hidden_channels * 4, activation=activation)

        self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.global_pooling_bn = abn_block(hidden_channels, activation=activation)

        self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
        self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
        self.red_bn = abn_block(out_channels, activation=activation)

        self.reset_parameters(self.map_bn.activation, self.map_bn.slope)

    def reset_parameters(self, activation, slope):
        """Xavier-initialize all convolutions with the gain of the activation; reset ABN weight / bias to 1 / 0."""
        gain = nn.init.calculate_gain(activation, slope)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data, gain)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, ABN):
                if hasattr(m, "weight") and m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Map convolutions
        out = torch.cat([m(x) for m in self.map_convs], dim=1)
        out = self.map_bn(out)
        out = self.red_conv(out)

        # Global pooling
        pool = self._global_pooling(x)
        pool = self.global_pooling_conv(pool)
        pool = self.global_pooling_bn(pool)
        pool = self.pool_red_conv(pool)
        if self.training or self.pooling_size is None:
            # the pooled map is 1x1 here; tile it to the spatial size of x
            pool = pool.repeat(1, 1, x.size(2), x.size(3))

        out += pool
        out = self.red_bn(out)
        return out

    def _global_pooling(self, x):
        """Global average pool (training, or no pooling_size) or a sliding average pool padded back to size."""
        if self.training or self.pooling_size is None:
            pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
            pool = pool.view(x.size(0), x.size(1), 1, 1)
        else:
            pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
                          min(try_index(self.pooling_size, 1), x.shape[3]))
            # (left, right, top, bottom); even window sizes get one extra pixel on the right / bottom
            padding = (
                (pooling_size[1] - 1) // 2,
                (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
                (pooling_size[0] - 1) // 2,
                (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
            )

            # torch.nn.functional is imported as F in this notebook; the name `functional` is undefined
            pool = F.avg_pool2d(x, pooling_size, stride=1)
            pool = F.pad(pool, pad=padding, mode="replicate")
        return pool

In [None]:
# Install tensorboardX (no version pinned) so SummaryWriter can be imported in the next cell.
!pip install tensorboardX

In [None]:
try:
    from tensorboardX import SummaryWriter
except ImportError:
    import sys
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tensorboardX"])
    from tensorboardX import SummaryWriter


try:
    from lib import dataset as D
    from lib import train_utils as U
    from lib.common import count_parameters, is_sorted, compute_mask_class, to_numpy
    from lib.metrics import JaccardIndex, AverageMeter, PixelAccuracy, threshold_mining, do_kaggle_metric
    from lib.train_utils import logit_to_prob
    from test import generate_model_submission
except ImportError as e:
    print(f"Error importing custom modules: {e}")
    print("Please ensure your custom modules are in the Python path")

tqdm.monitor_interval = 0  # Workaround for https://github.com/tqdm/tqdm/issues/481

def main():
    """Placeholder entry point; the body has not been filled in yet."""
    # Rest of your code remains the same...
    pass

if __name__ == '__main__':
    # Turn on the cuDNN benchmark flag before calling the entry point.
    cudnn.benchmark = True
    main()