# Environment Setup and Package Installation

In [8]:
!pip install -q segmentation_models_pytorch python-dotenv

## Set up Torch Hub Checkpoints for SE-Net

In [None]:
# Remove the existing checkpoints directory and create a new directory for torch hub checkpoints
!rm -r /root/.cache/torch/hub/checkpoints/
!mkdir -p /root/.cache/torch/hub/checkpoints/
# Copy pre-trained SE-Net weights to checkpoints directory
!cp /kaggle/input/se-net-pretrained-imagenet-weights/* /root/.cache/torch/hub/checkpoints/

rm: cannot remove '/root/.cache/torch/hub/checkpoints/': No such file or directory


# Import Libraries

In [9]:
import torch as tc
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import os,sys,cv2
from torch.cuda.amp import autocast # autocast module for automatic mixed-precision training
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel # DataParallel for parallel training
from glob import glob
import random

# Config for Model training

In [10]:
class CFG:
    # Set random seed for reproducibility
    seed = 42
    
    # Model-related configurations
    target_size = 1
    model_name = 'Unet'
    backbone = 'se_resnext50_32x4d'
    in_chans = 1
    
    tile_size = 512  # Size of the tiles used for input image cropping
    stride = tile_size // 2  # Stride for overlapping tiles during image processing

    # Training-related configurations
    image_size = 512
    input_size = 512
    train_batch_size = 24
    valid_batch_size = train_batch_size * 2
    epochs = 20
    lr = 8e-5
    chopping_percentile = 1e-3
    rotate_p = 0.5
    
    # Augmentation configurations
    train_aug_list = [
        A.Rotate(limit=270, p=rotate_p),
        A.RandomScale(scale_limit=(0.8, 1.25), interpolation=cv2.INTER_CUBIC, p=0.2),
        A.RandomCrop(input_size, input_size, p=1),
        A.Flip(p=0.5),
        A.RandomGamma(p=rotate_p * 2 / 3),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.2),
        A.GaussianBlur(p=0.1),
        A.MotionBlur(p=0.05),
        A.GridDistortion(num_steps=5, distort_limit=0.5, p=0.1),
        A.ElasticTransform(p=0.05, border_mode=cv2.BORDER_REFLECT_101, alpha_affine=4, sigma=6.0, alpha=120),
        ToTensorV2(transpose_mask=True),
    ]
    train_aug = A.Compose(train_aug_list)
    
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]
    valid_aug = A.Compose(valid_aug_list)
    

def set_seed(seed):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    tc.manual_seed(seed)
    tc.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    tc.backends.cudnn.deterministic = True
    tc.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed(CFG.seed)
print(f"Seed set for reproducibility: {CFG.seed}")

Seed set for reproducibility: 42


# Model

In [11]:
class SegmentationModel(nn.Module):
    def __init__(self, CFG, weight=None):
        super().__init__()
        
        # Create the segmentation model using Unet architecture
        self.model = smp.Unet(
            encoder_name=CFG.backbone,
            encoder_weights=weight,
            in_channels=CFG.in_chans,
            classes=CFG.target_size,
            activation=None,
        )

    def forward(self, image):
        """
        Forward pass of the model without applying activation.
        """
        output = self.model(image)
        # output = output.squeeze(-1) 
        return output[:, 0]  # .sigmoid()


def build_model(weight="imagenet"):
    """
    Build and return the SegmentationModel.
    """
    from dotenv import load_dotenv
    load_dotenv()

    print('model_name', CFG.model_name)
    print('backbone', CFG.backbone)

    model = SegmentationModel(CFG, weight)

    return model.cuda()

# Additional Functions

In [12]:
def min_max_normalization(x:tc.Tensor)->tc.Tensor:
    """input.shape=(batch,f1,...)"""
    shape=x.shape
    if x.ndim>2:
        x=x.reshape(x.shape[0],-1)

    min_=x.min(dim=-1,keepdim=True)[0]
    max_=x.max(dim=-1,keepdim=True)[0]
    if min_.mean()==0 and max_.mean()==1:
        return x.reshape(shape)

    x=(x-min_)/(max_-min_+1e-9)
    return x.reshape(shape)

def norm_with_clip(x:tc.Tensor,smooth=1e-5):
    dim=list(range(1,x.ndim))
    mean=x.mean(dim=dim,keepdim=True)
    std=x.std(dim=dim,keepdim=True)
    x=(x-mean)/(std+smooth)
    x[x>5]=(x[x>5]-5)*1e-3 +5
    x[x<-3]=(x[x<-3]+3)*1e-3-3
    return x

def add_noise(x:tc.Tensor,max_randn_rate=0.1,randn_rate=None,x_already_normed=False):
    """input.shape=(batch,f1,f2,...) output's var will be normalizate  """
    ndim=x.ndim-1
    if x_already_normed:
        x_std=tc.ones([x.shape[0]]+[1]*ndim,device=x.device,dtype=x.dtype)
        x_mean=tc.zeros([x.shape[0]]+[1]*ndim,device=x.device,dtype=x.dtype)
    else:
        dim=list(range(1,x.ndim))
        x_std=x.std(dim=dim,keepdim=True)
        x_mean=x.mean(dim=dim,keepdim=True)
    if randn_rate is None:
        randn_rate=max_randn_rate*np.random.rand()*tc.rand(x_mean.shape,device=x.device,dtype=x.dtype)
    cache=(x_std**2+(x_std*randn_rate)**2)**0.5
    #https://blog.csdn.net/chaosir1991/article/details/106960408

    return (x-x_mean+tc.randn(size=x.shape,device=x.device,dtype=x.dtype)*randn_rate*x_std)/(cache+1e-7)

# Custom Dataset

### Why Padding?
**Handling Non-divisible Dimensions:**

- Padding is crucial for images with dimensions not divisible by tile_size. It ensures consistency in processing and prevents issues when dividing images into tiles.

- Applying padding ensures that each tile, generated during image processing, has a consistent size without any remainder. This is essential for uniformity in subsequent operations.


*Example*:
```
pad0 would be calculated as (16 - 30 % 16) % 16, resulting in 2.
pad1 would be calculated as (16 - 25 % 16) % 16, resulting in 7.
```
**Explanation**
- `pad0` and `pad1` are calculated using double modulo.
- The inner modulo checks if padding is needed (non-zero remainder).
- The outer modulo limits the padding value to be less than self.tile_size, ensuring controlled padding.
- Resulting in padded dimensions that are multiples of self.tile_size, maintaining processing stability.

In [13]:
class CustomDataset(Dataset):
    def __init__(self, paths, is_label, do_sort=True):
        """
        Initialize the CustomDataset.

        Args:
            paths (list): List of file paths.
            is_label (bool): Flag indicating if the dataset contains labels.
            do_sort (bool, optional): Flag to sort file paths. Defaults to True.
        """
        self.paths = paths
        self.tile_size = CFG.tile_size
        
        # Sort paths if required
        if do_sort:
            self.paths.sort()
        self.is_label = is_label

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = cv2.imread(self.paths[index], cv2.IMREAD_GRAYSCALE)
        img = tc.from_numpy(img)

        # Compute padding for both dimensions
        pad0 = (self.tile_size - img.shape[0] % self.tile_size) % self.tile_size
        pad1 = (self.tile_size - img.shape[1] % self.tile_size) % self.tile_size

        # Apply replication padding
        img = np.pad(img, [(0, pad0), (0, pad1)], mode='edge')
        img = tc.from_numpy(img)

        if self.is_label:
            img = (img != 0).to(tc.uint8) * 255
        else:
            img = img.to(tc.uint8)

        return img

### Load data with clipping

- This function loads data using a `CustomDataset`, then performs thresholding and normalization on the data if it is not a label.
- Clip extreme values in the array, both the sides

In [14]:
def load_data(paths, is_label=False, do_sort=True):
    data_loader = CustomDataset(paths, is_label, do_sort)
    data_loader = DataLoader(data_loader, batch_size=16, num_workers=0)
    x = tc.cat([batch for batch in tqdm(data_loader)], dim=0)
    if not is_label:
        TH = x.reshape(-1).numpy()
        index = -int(len(TH) * CFG.chopping_percentile)
        TH: int = np.partition(TH, index)[index]
        x[x > TH] = int(TH)

        TH = x.reshape(-1).numpy()
        index = -int(len(TH) * CFG.chopping_percentile)
        TH: int = np.partition(TH, -index)[-index]
        x[x < TH] = int(TH)

        x = (min_max_normalization(x.to(tc.float16)[None])[0]*255).to(tc.uint8)
    return x

# Loss Functions

### Dice Loss PyTorch

In [15]:
class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        # comment out if your model contains a sigmoid or equivalent activation layer
        inputs = inputs.sigmoid()

        # flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth) / \
            (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

### Composite loss functions

**The function `select_loss_function(epoch)` dynamically adapts the loss function based on the current training epoch.**
- During early epochs (<= 15), Dice loss is prioritized for balanced training.
- In intermediate epochs (15 < epoch <= 17), the model transitions to Balanced BCE-Lovasz Loss for accuracy.
- Beyond epoch 17, Weighted BCE-Lovasz-Tversky Loss is used for reducing False positives.

In [16]:
bce_loss = nn.BCEWithLogitsLoss()
lovasz_loss = smp.losses.LovaszLoss(mode='binary', per_image=False)
tversky_loss = smp.losses.TverskyLoss(mode='binary', log_loss=False, from_logits=True)
dice_loss = DiceLoss()


def balanced_bce_lovasz_loss(output, target):
    return 0.5 * bce_loss(output, target) + 0.5 * lovasz_loss(output, target)

def weighted_bce_lovasz_tversky_loss(output, target):
    return 0.25 * bce_loss(output, target) + 0.25 * lovasz_loss(output, target) + 0.5 * tversky_loss(output, target)

# Define a function to get the appropriate loss based on the epoch
def select_loss_function(epoch):
    if epoch <= 15:
        return dice_loss
    elif 15 < epoch <= 17:
        return balanced_bce_lovasz_loss
    else:
        return weighted_bce_lovasz_tversky_loss

# Scores

In [17]:
def dice_coef(y_pred: tc.Tensor, y_true: tc.Tensor, thr=0.5, dim=(-1, -2), epsilon=0.001):
    y_pred = y_pred.sigmoid()
    y_true = y_true.to(tc.float32)
    y_pred = (y_pred > thr).to(tc.float32)
    
    intersection = (y_true * y_pred).sum(dim=dim)
    denominator = y_true.sum(dim=dim) + y_pred.sum(dim=dim)
    
    dice = ((2 * intersection + epsilon) / (denominator + epsilon)).mean()
    
    return dice

# Function to Load Training Dataset

**What is Tiling?**
- Tiling involves breaking down a large image into smaller, non-overlapping or overlapping, tiles.
- The stride determines the step size for moving through the image to extract each subsequent tile.

**Why i used Tiling?**
- **Edge Consistency:** Overlapping tiles maintain contextual information at edges, ensuring accurate segmentation by preventing the model from overlooking crucial details near tile borders.

- **Adaptation to Varied Resolution:** Tiling allows segmentation models to handle images with diverse resolutions. This adaptability ensures effective segmentation across different scales, optimizing performance in varied scenarios.

In [18]:
def load_train_dataset(x: list, y: list):
    image_size = CFG.image_size
    in_chans = CFG.in_chans

    train_x = []
    train_y = []

    for i in range(len(x)):
        print('dataset ', i)
        x_data = x[i]
        y_data = y[i]

        for index in range(x_data.shape[0] - in_chans + 1):
            x_slice = x_data[index:index+in_chans, :, :]
            y_slice = y_data[index:index+in_chans, :, :]

            x1_list = list(range(0, x_slice.shape[2] - CFG.tile_size + 1, CFG.stride))
            y1_list = list(range(0, x_slice.shape[1] - CFG.tile_size + 1, CFG.stride))

            for y1 in y1_list:
                for x1 in x1_list:
                    y2 = y1 + CFG.tile_size
                    x2 = x1 + CFG.tile_size
                    tile_x = x_slice[:, y1:y2, x1:x2]
                    tile_y = y_slice[:, y1:y2, x1:x2]

                    train_x.append(tile_x)
                    train_y.append(tile_y)

    return train_x, train_y

# Dataset Class with Data Augmentation

### Train Dataset Class with Data Augmentation and Tiled crop

In [19]:
class Custom_Train_Dataset(Dataset):
    def __init__(self, x: list, y: list, arg=False):
        super(Kaggld_Dataset, self).__init__()
        self.x = x  
        self.y = y 
        self.in_chans = CFG.in_chans
        self.arg = arg
        if arg:
            self.transform = CFG.train_aug
        else:
            self.transform = CFG.valid_aug

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, index):
        x = self.x[index]
        y = self.y[index]

        # Transform
        data = self.transform(image=x.numpy().transpose(1, 2, 0), mask=y.numpy().squeeze(0))

        x = data['image']
        y = data['mask'] >= 127

        if self.arg:
            i=np.random.randint(4)

            x=x.rot90(i,dims=(1,2))
            y=y.rot90(i,dims=(0,1))

            for i in range(2):
                if np.random.randint(2):
                    x=x.flip(dims=(i,))
                    if i>=1:
                        y=y.flip(dims=(i-1,))

        return x, y  # (uint8, uint8)

### Validation Dataset Class with Center Crop

In [20]:
class Custom_Val_Dataset(Dataset):
    def __init__(self,x:list,y:list):
        super(Dataset,self).__init__()
        self.x=x#list[(C,H,W),...]
        self.y=y#list[(C,H,W),...]
        self.image_size=CFG.image_size
        self.in_chans=CFG.in_chans
        self.transform=CFG.valid_aug

    def __len__(self) -> int:
        return sum([y.shape[0]-self.in_chans for y in self.y])

    def __getitem__(self,index):
        i=0
        for x in self.x:
            if index>x.shape[0]-self.in_chans:
                index-=x.shape[0]-self.in_chans
                i+=1
            else:
                break
        x=self.x[i]
        y=self.y[i]

        x_index= (x.shape[1]-self.image_size)//2
        y_index= (x.shape[2]-self.image_size)//2 
        x=x[index:index+self.in_chans, x_index:x_index+self.image_size, y_index:y_index+self.image_size]
        y=y[index+self.in_chans//2, x_index:x_index+self.image_size, y_index:y_index+self.image_size]
        data = self.transform(image=x.numpy().transpose(1,2,0), mask=y.numpy())

        x = data['image']
        y = data['mask']>=127

        return x,y#(uint8,uint8)

# Data Preparation

*Uncomment the kidney 2 loading section if using kidney 2 for training*

In [None]:
# Load kidney 1 dense data
train_x = []
train_y = []
root_path = "/kaggle/input/blood-vessel-segmentation/"
paths = [root_path + "/train/kidney_1_dense"]

for i, path in enumerate(paths):
    if path == root_path + "/train/kidney_3_dense":
        continue
    print('Loading kidney 1 images')
    x = load_data(glob(f"{path}/images/*.tif"), is_label=False)
    print('Loading kidney 1 labels')
    y = load_data(glob(f"{path}/labels/*.tif"), is_label=True)
    train_x.append(x)
    train_y.append(y)

    # Augmentation
    train_x.append(x.permute(1, 2, 0))
    train_y.append(y.permute(1, 2, 0))
    train_x.append(x.permute(2, 0, 1))
    train_y.append(y.permute(2, 0, 1))

# Load kidney 3 data
kidney_3s_label = sorted(glob(root_path + "/train/kidney_3_sparse/labels/*"))
kidney_3d_label = sorted(glob(root_path + "/train/kidney_3_dense/labels/*"))
kidney_3d_img = sorted(glob(root_path + "/train/kidney_3_sparse/images/*"))

# Loading kidney 3 images with specific exclusion from sparse
kidney_3_label = kidney_3s_label[0: 496] + kidney_3d_label + kidney_3s_label[997:]
kidney_3_img = kidney_3d_img

print('Loading kidney 3 images')
kid3_x = load_data(kidney_3_img, is_label=False, do_sort=False)
print('Loading kidney 3 labels')
kid3_y = load_data(kidney_3_label, is_label=True, do_sort=False)
train_x.append(kid3_x)
train_y.append(kid3_y)

train_x.append(kid3_x.permute(1, 2, 0))
train_y.append(kid3_y.permute(1, 2, 0))
train_x.append(kid3_x.permute(2, 0, 1))
train_y.append(kid3_y.permute(2, 0, 1))

# Uncomment the following section if using kidney 2 for training
# kid2_img = sorted(glob(root_path + "/train/kidney_2/images/*"))[900:]
# kid2_label = sorted(glob(root_path + "/train/kidney_2/labels/*"))[900:]
# print('Loading kidney 2 images')
# kid2_x = load_data(kid2_img, is_label=False, do_sort=False)
# print('Loading kidney 2 labels')
# kid2_y = load_data(kid2_label, is_label=True, do_sort=False)
# train_x.append(kid2_x)
# train_y.append(kid2_y)
# train_x.append(kid2_x.permute(1, 2, 0))
# train_y.append(kid2_y.permute(1, 2, 0))
# train_x.append(kid2_x.permute(2, 0, 1))
# train_y.append(kid2_y.permute(2, 0, 1))

# Load validation data from kidney 2
val_img = sorted(glob(root_path + "/train/kidney_2/images/*"))[900:]
val_label = sorted(glob(root_path + "/train/kidney_2/labels/*"))[900:]

print('Loading Validation data X')
val_x = load_data(val_img, is_label=False, do_sort=False)
print('Loading Validation data y')
val_y = load_data(val_label, is_label=True, do_sort=False)

Loading kidney 1 images


100%|██████████| 143/143 [02:01<00:00,  1.18it/s]


Loading kidney 1 labels


100%|██████████| 143/143 [01:08<00:00,  2.08it/s]


Loading kidney 3 images


 80%|████████  | 52/65 [01:11<00:17,  1.34s/it]

**Load padded images and mask**

In [None]:
train_image, train_mask = load_train_dataset(train_x, train_y)

dataset  0

dataset  1

dataset  2

dataset  3

dataset  4

dataset  5


## Scheduler

reference: https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py

In [None]:
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import torch

class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)

# Setup Training

In [None]:
tc.backends.cudnn.enabled = True
tc.backends.cudnn.benchmark = True

train_dataset = Kaggld_Dataset(train_image, train_mask, arg=True)
train_dataset = DataLoader(train_dataset, batch_size=CFG.train_batch_size, num_workers=2, shuffle=True, pin_memory=True)
# tc.save(train_dataset, '/content/drive/MyDrive/001_Projects/008_SenNet/SenNet/dataloader/DataLoader_train_k1k3_512.pth')

val_dataset = Val_Dataset([val_x], [val_y])
val_dataset = DataLoader(val_dataset, batch_size=CFG.valid_batch_size, num_workers=2, shuffle=False, pin_memory=True)
# tc.save(val_dataset, '/content/drive/MyDrive/001_Projects/008_SenNet/SenNet/dataloader/DataLoader_val_k2_512.pth')

model = build_model()
# loss_fc = DiceLoss()
# loss_fc = WeightedBCEandDiceLoss(alpha=1, beta=3)
optimizer = tc.optim.AdamW(model.parameters(), lr=CFG.lr)
scaler = tc.cuda.amp.GradScaler()


model_name Unet

backbone se_resnext50_32x4d


# Training

In [None]:
import gc
tc.cuda.empty_cache()
gc.collect()
with tc.no_grad():
    tc.cuda.empty_cache()

In [None]:
def calculate_metrics(predict, truth):
    p = (predict > 0.5)
    t = (truth > 0.5)
    hit = (p * t).sum()
    fp = (p * (1 - t)).sum()
    t_sum = t.sum()
    p_sum = p.sum()
    return hit, fp, t_sum, p_sum

In [None]:
device = tc.device("cuda" if tc.cuda.is_available() else "cpu")
model.cuda()

best_val_score = float('-inf')
best_val_hit_rate = float('-inf')
best_val_false_positives = float('inf')

# Training loop1
for epoch in range(CFG.epochs):
    # Training phase
    model.train()
    loss_fc=get_loss(epoch)
    train_time = tqdm(train_dataset, desc=f"Epoch {epoch}")
    train_loss = 0
    train_scores = 0
    train_total_hit = 0
    train_total_fp = 0
    train_total_t_sum = 0
    train_total_p_sum = 0

    for i, (x, y) in enumerate(train_time):
        x = x.to(device, dtype=tc.float32)
        y = y.to(device, dtype=tc.float32)
        x = norm_with_clip(x.reshape(-1, *x.shape[2:])).reshape(x.shape)
        x = add_noise(x, max_randn_rate=0.5, x_already_normed=True)

        with autocast():
            pred = model(x)
            # pred = pred.squeeze(1)
            loss = loss_fc(pred, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()
        score = dice_coef(pred.detach(), y)
        train_loss = (train_loss * i + loss.item()) / (i + 1)
        train_scores = (train_scores * i + score) / (i + 1)

        # Calculate training metrics
        train_hit, train_fp, train_t_sum, train_p_sum = calculate_metrics((pred.detach() > 0.5).cpu().numpy(), y.cpu().numpy())
        train_total_hit += train_hit
        train_total_fp += train_fp
        train_total_t_sum += train_t_sum
        train_total_p_sum += train_p_sum

        train_time.set_description(f"epoch:{epoch}, train_loss:{train_loss:.4f}, train_score:{train_scores:.4f}, lr{optimizer.param_groups[0]['lr']:.4e}, train_hit_rate:{train_total_hit / train_total_t_sum:.4f}, train_false_positives:{train_total_fp / train_total_p_sum:.4f}")
        train_time.update()
        del loss, pred

    train_time.close()

    # Validation phase
    model.eval()
    val_time = tqdm(val_dataset, desc=f"Epoch {epoch}")
    val_loss = 0
    val_scores = 0
    val_total_hit = 0
    val_total_fp = 0
    val_total_t_sum = 0
    val_total_p_sum = 0

    for i, (x, y) in enumerate(val_time):
        x = x.to(device, dtype=tc.float32)
        y = y.to(device, dtype=tc.float32)
        x = norm_with_clip(x.reshape(-1, *x.shape[2:])).reshape(x.shape)

        with autocast():
            with tc.no_grad():
                pred = model(x)
                # pred = pred.squeeze(1)
                loss = loss_fc(pred, y)

        score = dice_coef(pred.detach(), y)
        val_loss = (val_loss * i + loss.item()) / (i + 1)
        val_scores = (val_scores * i + score) / (i + 1)

        # Calculate validation metrics
        val_hit, val_fp, val_t_sum, val_p_sum = calculate_metrics((pred.detach() > 0.5).cpu().numpy(), y.cpu().numpy())
        val_total_hit += val_hit
        val_total_fp += val_fp
        val_total_t_sum += val_t_sum
        val_total_p_sum += val_p_sum

        val_time.set_description(f"epoch:{epoch}, val_loss:{val_loss:.4f}, val_score:{val_scores:.4f}, val_hit_rate:{val_total_hit / val_total_t_sum:.4f}, val_false_positives:{val_total_fp / val_total_p_sum:.4f}")
        val_time.update()

    val_time.close()

    # Update best model based on validation score
    if val_scores > best_val_score:
        best_val_score = val_scores
        tc.save(model.state_dict(), f"/SenNet/models/{CFG.backbone}_tile_best_val_score_model.pt")

    # Update best model based on validation false positives
    val_false_positives_rate = val_total_fp / val_total_p_sum
    if val_false_positives_rate < best_val_false_positives:
        best_val_false_positives = val_false_positives_rate
        tc.save(model.state_dict(), f"/SenNet/SenNet/models/{CFG.backbone}_tile_best_false_positives_model.pt")

epoch:0, train_loss:0.2891, train_score:0.8124, lr4.8210e-05, train_hit_rate:0.8000, train_false_positives:0.3553: 100%|██████████| 3598/3598 [40:17<00:00,  1.49it/s]

epoch:0, val_loss:0.0730, val_score:0.9410, val_hit_rate:0.9651, val_false_positives:0.0091: 100%|██████████| 21/21 [00:21<00:00,  1.02s/it]

epoch:1, train_loss:0.2089, train_score:0.8621, lr3.6697e-04, train_hit_rate:0.8688, train_false_positives:0.0662: 100%|██████████| 3598/3598 [39:52<00:00,  1.50it/s]

epoch:1, val_loss:0.0663, val_score:0.9452, val_hit_rate:0.9698, val_false_positives:0.0095: 100%|██████████| 21/21 [00:12<00:00,  1.67it/s]

epoch:2, train_loss:0.1893, train_score:0.8725, lr4.3303e-04, train_hit_rate:0.8831, train_false_positives:0.0629:  49%|████▊     | 1747/3598 [19:22<20:31,  1.50it/s]


KeyboardInterrupt: 

In [None]:
tc.save(model.state_dict(), f"SenNet/SenNet/models/{CFG.backbone}_tile_last_epoch.pt")

In [None]:
model.state_dict()