!pip install albumentations
!pip install segmentation-models-pytorch

In [1]:
# Standard library
import os
import gc
import time
import math
import random
import logging
import warnings
#logging.basicConfig(level=logging.ERROR)

# Scientific computing
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt

# Deep learning (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.parametrizations import weight_norm
from sklearn.metrics import pairwise_distances

# Computer vision
import cv2
import torchvision
from PIL import Image
from torchvision import transforms

# Miscellaneous utilities
from tqdm import tqdm
import timm
import einops
from einops import rearrange
from sklearn.model_selection import KFold
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import kagglehub

  data = fetch_version_info()


In [2]:
# Compute device: the GPU when PyTorch can see one, the CPU otherwise.
# (Assign torch.device('cpu') here instead to force CPU execution.)
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Notebook-wide constants, kept together so they can be tuned in one place.
NUM_EPOCHS, BATCH_SIZE = 5, 5
LEARNING_RATE = 1e-3
IMAGE_SIZE = 224
NUM_CLASSES = 3
SEED = 42
N_FOLDS = 5

In [3]:
# Set the PYTORCH_CUDA_ALLOC_CONF environment variable for this process.
# NOTE(review): presumably this has to be in place before the first CUDA
# allocation to have any effect — confirm it is not set too late.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [4]:
# Fetch the "khanhpt1999/data-1" dataset through kagglehub; whatever the call
# returns is kept in `path` and echoed below.
path = kagglehub.dataset_download("khanhpt1999/data-1")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/data-1


# Load one sample image and turn it into a normalised, batched tensor.
# NOTE(review): this is an absolute Kaggle input path; it only resolves where
# that dataset is attached to the session.
with Image.open("/kaggle/input/test-images/image_2025-05-05_142003010.png") as source_image:
    # convert() hands back a fully loaded copy, so the file can be closed right away.
    img = source_image.convert("RGB")
print("PIL image size:", img.size, "mode:", img.mode)

transform = transforms.Compose([
    # Follow the notebook-wide IMAGE_SIZE rather than repeating the literal 224.
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    ),
])

inpt: torch.Tensor = transform(img)
print("Tensor shape:", inpt.shape, "dtype:", inpt.dtype, "range:", inpt.min(),"→",inpt.max())

# Add the leading batch dimension.
inpt = inpt.unsqueeze(0)
print(f"Shape of Input: {inpt.shape}")

In [5]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over the pixels of a [B, C, H, W] feature map.

    Every spatial location is treated as one token of width ``C`` (which has to
    equal ``embed_dim``).  A learnable relative-position bias, looked up by the
    (row, column) offset between two pixels, is added to the attention logits
    whenever the input is exactly ``img_size`` x ``img_size``; inputs of any
    other size are attended without that bias.

    Args:
        embed_dim: Channel count of the input; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        img_size: Side length of the square grid the bias table is built for.
    """

    def __init__(self, embed_dim, num_heads=8, dropout=0.1, img_size=64):
        super().__init__()
        assert embed_dim % num_heads == 0

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.img_size = img_size

        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

        # Offsets along each axis lie in [-(img_size - 1), img_size - 1], i.e.
        # `span` distinct values, so the table holds span * span rows per head.
        span = 2 * img_size - 1
        self.relative_position_bias_table = nn.Parameter(torch.zeros(span * span, num_heads))

        # Pre-compute, for every (pixel_i, pixel_j) pair, the row of the table
        # that belongs to their relative offset.
        axis = torch.arange(img_size)
        grid = torch.stack(torch.meshgrid([axis, axis], indexing='ij')).flatten(1)      # [2, N]
        offsets = (grid[:, :, None] - grid[:, None, :]).permute(1, 2, 0).contiguous()  # [N, N, 2]
        offsets = offsets + (img_size - 1)                                              # shift to start at 0
        flat_index = (offsets * torch.tensor([span, 1])).sum(-1)                        # [N, N]
        self.register_buffer("relative_position_index", flat_index)

    def forward(self, x):
        """Attend over all H*W positions of ``x`` ([B, C, H, W]) and return the same shape."""
        B, C, H, W = x.shape
        N = H * W

        # [B, C, H, W] -> [B, N, C]
        tokens = x.reshape(B, C, N).transpose(1, 2)

        # [B, N, 3*C] -> [3, B, heads, N, head_dim]
        packed = self.qkv(tokens).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale

        # The bias table only fits the grid it was built for.
        if H == self.img_size and W == self.img_size:
            bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
            bias = bias.reshape(N, N, -1).permute(2, 0, 1).contiguous()  # [heads, N, N]
            scores = scores + bias.unsqueeze(0)

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Merge the heads again and apply the output projection.
        mixed = (weights @ v).transpose(1, 2).reshape(B, N, self.embed_dim)
        mixed = self.proj(mixed)

        # [B, N, C] -> [B, C, H, W]
        return mixed.transpose(1, 2).view(B, C, H, W)

class SelfAttentionConv(nn.Module):
    """Self-attention layer exposing a convolution-like constructor signature.

    The input is first brought to ``out_channels`` (1x1 convolution, or a no-op
    when the channel counts already agree), then a gated self-attention branch
    is added on top of it: ``gamma * attention(x) + x``.  ``gamma`` starts at
    zero.  A ``stride`` greater than one is emulated with average pooling.

    ``kernel_size`` and ``padding`` are stored on the module but are not read
    by ``forward``.  The attention block is built with its default ``img_size``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, num_heads=8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.num_heads = num_heads

        # Point-wise projection, needed only when the channel count changes.
        self.channel_proj = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, 1, 1, 0)
        )

        self.attention = MultiHeadSelfAttention(out_channels, num_heads)

        # Learnable gate on the attention branch.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return the gated attention output, average-pooled when ``stride > 1``."""
        projected = self.channel_proj(x)
        blended = self.gamma * self.attention(projected) + projected

        if self.stride <= 1:
            return blended
        return F.avg_pool2d(blended, kernel_size=self.stride, stride=self.stride)

In [6]:
class ElectricForceModel(nn.Module):

    def __init__(
        self,
        imsize = 224,
        v0 = 0,
        layers = 3,
        loops = 3,
        debug = False,
        beta = 0.8,
        use_self_attention = False
    ):
        super().__init__()

        self.register_buffer('v0', torch.tensor([v0], dtype=torch.float32))
        self.register_buffer('beta', torch.tensor([beta], dtype=torch.float32))

        self.adaptive_epsilon = nn.Parameter(torch.tensor([1.0]))
        self.adaptive_mass = nn.Parameter(torch.tensor([1.0]))
        self.adaptive_time = nn.Parameter(torch.tensor([0.1]))
        
        with torch.no_grad():
            self.adaptive_epsilon.data.fill_(1.0)
            self.adaptive_mass.data.fill_(1.0)
            self.adaptive_time.data.fill_(0.1)

        self.imsize = imsize
        self.layers = layers
        self.loops = loops
        self.debug = debug
        self.use_self_attention = use_self_attention

        if self.use_self_attention:
            self.input_attention = MultiHeadSelfAttention(3, num_heads=3, img_size=imsize)
            self.attention_scale = nn.Parameter(torch.zeros(1))
        
        if self.use_self_attention:
            self.mapping_to_vector_space = SelfAttentionConv(3, 2, num_heads=2)
            self.mapping_to_ipt = SelfAttentionConv(2, 3, num_heads=3)
        else:
            self.mapping_to_vector_space = nn.Conv2d(3, 2, 1, 1, 0)
            self.mapping_to_ipt = nn.Conv2d(2, 3, 1, 1, 0)
            
        self.space_norm = nn.Tanh()
        self.conv1 = nn.Conv2d(3, 3, 1, 1, 0)
        self.conv2 = nn.Conv2d(3, 3, 1, 1, 0)
         

    def compute_distance(self, ipt, chunk_size=512):
        
        """
        Compute the distance of each point with respect to every point in zeta

        Input:
            - ipt: Input Image | [B, 3, self.imsize, self.imsize]

        Outpt:
            - zeta: The coordinates map for each pixel in image (x, y) | [B, 2, self.imsize, self.imsize]
            - x: x-coordinates of every point in zeta | [B, N]
            - y: y-coordinates of every point in zeta | [B, N]
            - inverse_sq_dist: 1/r^2, the inverse squared distance of each point with respect to every point in zeta | [B, N, N]
            - inverse_dist: 1/r, the inverse distance of each point with respect to every point in zeta | [B, N, N]
            
        How the Function Works:
            - Find zeta using a point-wise-convolutional-layer, the number of channel output will be the number of dimensions in vector space
                            (2 -> (x, y) == Vector Space
                             3 -> (x, y, z) == 3D Space
                             ...
                             ...
                            )
            - Find x and y by slicing zeta along the 0th and 1st of the 1st Dimension
            - Calculate the Euclidean Distance between each point in zeta
            - inverse_sq_dist is D**(-2) and inverse_dist is D**(-1)
        """
                
        zeta = self.mapping_to_vector_space(ipt)
        zeta = self.space_norm(zeta) # [B, 2, self.imsize, self.imsize]
        if self.debug:
            print(f"zeta Shape: {zeta.shape}")

        B, C, H, W = zeta.shape
        N = H*W

        pts = zeta.reshape(B, 2, N) # [B, 2, N]
        if self.debug:
            print(f"pts Shape: {pts.shape}")

        x, y = pts[:, 0], pts[:, 1] # [B, N], [B, N]
        if self.debug:
            print(f"x Shape: {x.shape}")
            print(f"y Shape: {y.shape}")

        zflat = pts.permute(0, 2, 1) # [B, N, 2]
        if self.debug:
            print(f"zflat Shape: {zflat.shape}")

        if N <= chunk_size:
            D = torch.cdist(zflat, zflat, p=2) # [B, N, N]
            D = torch.where(D==0, float('inf'), D) # [B, N, N]
        else:
            if self.debug:
                print(f"Using chunking with chunk_size={chunk_size} for N={N}")
        
            D = torch.zeros(B, N, N, device=zflat.device, dtype=zflat.dtype) # [B, N, N]
            
            for i in range(0, N, chunk_size):
                end_i = min(i + chunk_size, N)
                chunk_points = zflat[:, i:end_i, :] # [B, chunk_size, 2]
                distances_chunk = torch.cdist(chunk_points, zflat, p=2) # [B, chunk_size, N]

                diag_mask = torch.zeros_like(distances_chunk, dtype=torch.bool)
                for idx in range(end_i - i):
                    diag_mask[:, idx, i + idx] = True
                    
                distances_chunk = torch.where(diag_mask, float('inf'), distances_chunk)
                D[:, i:end_i, :] = distances_chunk # [B, chunk_size, N]

        D = torch.where(D == 0, float('inf'), D)
        if self.debug:
            print(f"D Shape: {D.shape}") # [B, N, N]
            
        #inverse_sq_dist = D.pow(-2) # [B, N, N]
        #inverse_dist = D.pow(-1) # [B, N, N]
        #if self.debug:
            #print(f"inverse_sq_dist Shape: {inverse_sq_dist.shape}") # [B, N, N]
            #print(f"inverse_dist Shape: {inverse_dist.shape}") # [B, N, N]
            
        if self.debug:
            print("-"*59)
        
        return zeta, x, y, D #inverse_sq_dist, #inverse_dist

    def compute_angle(self, x, y, chunk_size=512):
        
        """
        Compute angle of forces acted on each point with respect to the x-axis
        Input: 
            - x: x-coordinates of every point in zeta | [B, N]
            - y: y-coordinates of every point in zeta | [B, N]

        Output:
            - theta_rad: Radian of angle of forces with respect to the x-axis | [B, N, N]

        How the Function Works:
            - Applying Inverse Function of tan to find the angle | tan = dy/dx
        """
        
        B, N = x.shape
        
        if self.debug:
            print(f"Running compute_angle...")
            print(f" ")
            
        if N <= chunk_size:
            theta_rad = torch.atan2( # [B, N, N]
                y.unsqueeze(2) - y.unsqueeze(1), # [B, N, 1] - [B, 1, N] -> [B, N, N]
                x.unsqueeze(2) - x.unsqueeze(1), # [B, N, 1] - [B, 1, N] -> [B, N, N]
            )
        else:
            if self.debug:
                print(f"Using chunking for angle computation with chunk_size={chunk_size} for N={N}")
            
            theta_rad = torch.zeros(B, N, N, device=x.device, dtype=x.dtype) # [B, N, N]
            
            for i in range(0, N, chunk_size):
                end_i = min(i + chunk_size, N)
                
                x_chunk = x[:, i:end_i] # [B, chunk_size]
                y_chunk = y[:, i:end_i] # [B, chunk_size]
                
                dy = y_chunk.unsqueeze(2) - y.unsqueeze(1) # [B, chunk_size, 1] - [B, 1, N] -> [B, chunk_size, N]
                dx = x_chunk.unsqueeze(2) - x.unsqueeze(1) # [B, chunk_size, 1] - [B, 1, N] -> [B, chunk_size, N]
                
                theta_chunk = torch.atan2(dy, dx) # [B, chunk_size, N]
                theta_rad[:, i:end_i, :] = theta_chunk # [B, chunk_size, N]
                
        theta_deg = theta_rad * (180.0 / math.pi) # [B, N, N]
        if self.debug:
            print(f"theta_rad Shape: {theta_rad.shape}")
            print(f"theta_deg Shape: {theta_deg.shape}")
            
        if self.debug:
            print("-"*59)
        return theta_rad

    def compute_electric_force(self, ipt, theta_rad, inverse_sq_dist, chunk_size = 512):
        
        """
        Compute the Net Electric Force acted on each point
        
        Input:
            - ipt: Input Image | [B, 3, self.imsize, self.imsize]
            - theta_rad: The angle of forces acted on each point with respect to the x-axis | [B, N, N]
            - inverse_sq_dist: 1/r^2, the inverse squared distance of each point with respect to every point in zeta | [B, N, N]
    
        Output:
            - Fx: Force acted on each point in the x-axis | [B, N]
            - Fy: Force acted on each point in the y-axis | [B, N]
    
        How the Function Works:
            - F_E = qE = qi (k_E sum(qj/r^2))
                + Since all variables are vectors, we split them into x and y axis, respectively
        """
        
        k_E = 1.0 / (4.0 * math.pi *torch.abs(self.adaptive_epsilon))
        
        B, C, H, W = ipt.shape
        N = H*W
        
        ipt_reshaped = ipt.reshape(B, N, C) # [B, N, C]
        
        if self.debug:
            print(f"Running compute_electric_force...")
            print(" ")
    
        if N <= chunk_size:
            # Small enough - use original approach
            qi_qj = ipt_reshaped.unsqueeze(2) * ipt_reshaped.unsqueeze(1) # [B, N, 1, C] * [B, 1, N, C] -> [B, N, N, C]
            qi_qj = torch.mean(qi_qj, dim = 3) # [B, N, N]
            if self.debug:
                print(f"qi_qj Shape: {qi_qj.shape}") # [B, N, N]
    
            Fx = k_E * qi_qj * inverse_sq_dist * torch.cos(theta_rad) # [B, N, N]
            Fy = k_E * qi_qj * inverse_sq_dist * torch.sin(theta_rad) # [B, N, N]
    
            Fx = torch.sum(Fx, dim = 2) # [B, N]
            Fy = torch.sum(Fy, dim = 2) # [B, N]
        else:
            # CHUNKED BROADCASTING - avoid creating large [B, N, N, C] tensor
            if self.debug:
                print(f"Using chunked broadcasting for force computation with chunk_size={chunk_size} for N={N}")
            
            Fx = torch.zeros(B, N, device=ipt.device, dtype=ipt.dtype) # [B, N]
            Fy = torch.zeros(B, N, device=ipt.device, dtype=ipt.dtype) # [B, N]
            
            for i in range(0, N, chunk_size):
                end_i = min(i + chunk_size, N)
                
                # Get chunk of qi values
                qi_chunk = ipt_reshaped[:, i:end_i, :] # [B, chunk_size, C]
                
                # Initialize qi_qj chunk for this row chunk
                qi_qj_chunk = torch.zeros(B, end_i - i, N, device=ipt.device, dtype=ipt.dtype) # [B, chunk_size, N]
                
                # CHUNKED COMPUTATION of qi_qj to avoid large broadcasting
                for j in range(0, N, chunk_size):
                    end_j = min(j + chunk_size, N)
                    
                    # Get chunk of qj values  
                    qj_chunk = ipt_reshaped[:, j:end_j, :] # [B, chunk_size_j, C]
                    
                    # Small broadcasting: [B, chunk_i, 1, C] * [B, 1, chunk_j, C] -> [B, chunk_i, chunk_j, C]
                    qi_expanded = qi_chunk.unsqueeze(2) # [B, chunk_i, 1, C]
                    qj_expanded = qj_chunk.unsqueeze(1) # [B, 1, chunk_j, C]
                    qi_qj_small = qi_expanded * qj_expanded # [B, chunk_i, chunk_j, C]
                    qi_qj_small = torch.mean(qi_qj_small, dim=3) # [B, chunk_i, chunk_j]
                    
                    # Store in the chunk
                    qi_qj_chunk[:, :, j:end_j] = qi_qj_small # [B, chunk_i, chunk_j]
                    
                    # Cleanup small tensors
                    del qi_expanded, qj_expanded, qi_qj_small
                
                # Get corresponding chunks for force computation
                inverse_sq_dist_chunk = inverse_sq_dist[:, i:end_i, :] # [B, chunk_size, N]
                theta_rad_chunk = theta_rad[:, i:end_i, :] # [B, chunk_size, N]
                
                # Compute forces for this chunk
                Fx_chunk = k_E * qi_qj_chunk * inverse_sq_dist_chunk * torch.cos(theta_rad_chunk) # [B, chunk_size, N]
                Fy_chunk = k_E * qi_qj_chunk * inverse_sq_dist_chunk * torch.sin(theta_rad_chunk) # [B, chunk_size, N]
                
                # Sum across interactions
                Fx[:, i:end_i] = torch.sum(Fx_chunk, dim=2) # [B, chunk_size]
                Fy[:, i:end_i] = torch.sum(Fy_chunk, dim=2) # [B, chunk_size]
                
                # Cleanup chunk tensors
                del qi_chunk, qi_qj_chunk, inverse_sq_dist_chunk, theta_rad_chunk, Fx_chunk, Fy_chunk
        
        if self.debug:
            print(f"Fx Shape: {Fx.shape}") # [B, N]
            print(f"Fy Shape: {Fy.shape}") # [B, N]
            print("-"*59)
            
        return Fx, Fy

    def compute_vector_translation(self, Fx, Fy, zeta):
        """
        Move each point in the vector space corresponding to forces acted on it.

        Input:
            - Fx: Force acted on each point in the x-axis | [B, N]
            - Fy: Force acted on each point in the y-axis | [B, N]
            - zeta: Coordinates of each point (x, y) | [B, 2, self.imsize, self.imsize]

        Output:
            - new_zeta: Fully updated Coordinates of each point (x, y) | [B, 2, self.imsize, self.imsize]

        How the Function Works:
            - Displacement = Velocity x time + Initial Position -> D =  vt + x0
            - v = Acceleration x time + Initial Velocity = at + v0
            - Acceleration = Force/Mass = F/m
                + Since all variables are vectors, we split them into x and y axis, respectively
        """
        
        if self.debug:
            print(f"Starting compute_vector_translation...")
            print(" ")
            
        ax = Fx / torch.abs(self.adaptive_mass) # [B, N]
        ay = Fy / torch.abs(self.adaptive_mass) # [B, N]
        if self.debug:
            print(f"ax Shape: {ax.shape}")
            print(f"ay Shape: {ay.shape}")
            
        B, N = Fx.shape
        v0 = self.v0.expand(B, N) # [B, N]
        
        if self.debug:
            print(f"v0 Shape: {v0.shape}")
            
        time_step = torch.abs(self.adaptive_time)
        vx = v0 + ax * time_step # [B, N]
        vy = v0 + ay * time_step # [B, N]
        vx = vx.reshape(B, self.imsize, self.imsize) # [B, self.imsize, self.imsize]
        vy = vy.reshape(B, self.imsize, self.imsize) # [B, self.imsize, self.imsize]
        if self.debug:
            print(f"vx Shape: {vx.shape}")
            print(f"vy Shape: {vy.shape}")
            
        if self.debug:
            print(f"zeta Shape: {zeta.shape}")
            print(f"zeta[:, 0] Shape: {zeta[:, 0].shape}")
            print(f"zeta[:, 1] Shape: {zeta[:, 1].shape}")
        new_zeta = zeta.clone()
        new_zeta[:, 0] = zeta[:, 0] + vx # [B, 2, self.imsize, self.imsize]
        new_zeta[:, 1] = zeta[:, 1] + vy # [B, 2, self.imsize, self.imsize]

        
        if self.debug:
            print(f"Updated zeta Shape: {new_zeta.shape}")
            print(f"-"*59)
            
        return new_zeta

    def update_input(self, ipt, zeta):
        
        """
        Update Input by Unmapping zeta

        Input:
            - ipt: Input Image | [B, 3, self.imsize, self.imsize]
            - zeta: Coordinates of each point (x, y) | [B, 2, self.imsize, self.imsize]

        Output:
            - new_ipt: New Input Image | [B, 3, self.imsize, self.imsize]

        How the Function Works:
            - You got the new zeta Coordinates
            - You unmap it using a point-wise-convolutional-layer
            - New Input will be the summation of the unmapped zeta and the Input Image
        """
        if self.debug:
            print(f"Running update_input...")
            
        map_to_ipt = self.mapping_to_ipt(zeta) # [B, 3, self.imsize, self.imsize]
        new_ipt = map_to_ipt*self.beta + (1/self.loops)*ipt*(1-self.beta) # [B, 3, self.imsize, self.imsize]
        return new_ipt

    def compute_energy_of_system(self, ipt, inverse_dist, chunk_size = 512):
    
        """
        Compute total Energy of the System
    
        Input:
            - ipt: Final Input Image | [B, 3, self.imsize, self.imsize]
            - inverse_dist: 1/r, the inverse distance of each point with respect to every point in zeta | [B, N, N]
        
        Output:
            - V: Total Energy of Final Input Image | [B]
    
        How the Function Works:
            - V = sum(qi * Vj) = qi (k_E sum(qj/r^2))
        """
        
        k_E = 1.0 / (4.0 * math.pi * torch.abs(self.adaptive_epsilon))
        
        B, C, H, W = ipt.shape
        N = H*W
        
        ipt_reshaped = ipt.reshape(B, N, C) # [B, N, C]
        
        if self.debug:
            print(f"Running compute_energy_of_system...")
            print(" ")
    
        if N <= chunk_size:
            # Small enough - use original approach
            qi_qj = ipt_reshaped.unsqueeze(2) * ipt_reshaped.unsqueeze(1) # [B, N, 1, C] * [B, 1, N, C] -> [B, N, N, C]
            qi_qj = torch.mean(qi_qj, dim = 3) # [B, N, N]
            if self.debug:
                print(f"qi_qj Shape: {qi_qj.shape}") # [B, N, N]
    
            V = k_E * qi_qj * inverse_dist # [B, N, N]
            V = torch.sum(V, dim = 2) # [B, N]
            V = torch.sum(V, dim = 1) # [B]
        else:
            # CHUNKED BROADCASTING - avoid creating large [B, N, N, C] tensor
            if self.debug:
                print(f"Using chunked broadcasting for energy computation with chunk_size={chunk_size} for N={N}")
            
            V = torch.zeros(B, device=ipt.device, dtype=ipt.dtype) # [B]
            
            for i in range(0, N, chunk_size):
                end_i = min(i + chunk_size, N)
                
                # Get chunk of qi values
                qi_chunk = ipt_reshaped[:, i:end_i, :] # [B, chunk_size, C]
                
                # Initialize qi_qj chunk for this row chunk
                qi_qj_chunk = torch.zeros(B, end_i - i, N, device=ipt.device, dtype=ipt.dtype) # [B, chunk_size, N]
                
                # CHUNKED COMPUTATION of qi_qj to avoid large broadcasting
                for j in range(0, N, chunk_size):
                    end_j = min(j + chunk_size, N)
                    
                    # Get chunk of qj values
                    qj_chunk = ipt_reshaped[:, j:end_j, :] # [B, chunk_size_j, C]
                    
                    # Small broadcasting: [B, chunk_i, 1, C] * [B, 1, chunk_j, C] -> [B, chunk_i, chunk_j, C]
                    qi_expanded = qi_chunk.unsqueeze(2) # [B, chunk_i, 1, C]
                    qj_expanded = qj_chunk.unsqueeze(1) # [B, 1, chunk_j, C]
                    qi_qj_small = qi_expanded * qj_expanded # [B, chunk_i, chunk_j, C]
                    qi_qj_small = torch.mean(qi_qj_small, dim=3) # [B, chunk_i, chunk_j]
                    
                    # Store in the chunk
                    qi_qj_chunk[:, :, j:end_j] = qi_qj_small # [B, chunk_i, chunk_j]
                    
                    # Cleanup small tensors
                    del qi_expanded, qj_expanded, qi_qj_small
                
                # Get corresponding distance chunk
                inverse_dist_chunk = inverse_dist[:, i:end_i, :] # [B, chunk_size, N]
                
                # Compute energy for this chunk
                V_chunk = k_E * qi_qj_chunk * inverse_dist_chunk # [B, chunk_size, N]
                V_chunk = torch.sum(V_chunk, dim=2) # [B, chunk_size]
                V_chunk = torch.sum(V_chunk, dim=1) # [B]
                
                # Accumulate energy
                V += V_chunk # [B]
                
                # Cleanup chunk tensors
                del qi_chunk, qi_qj_chunk, inverse_dist_chunk, V_chunk
                
        if self.debug:
            print(f"V Shape: {V.shape}") # [B]
            print("-"*59)
            
        return V

    def aggregate_to_one_function(self, ipt, chunk_size):
    
        k_E = 1.0 / (4.0 * math.pi * torch.abs(self.adaptive_epsilon))
        epsilon_dist_sq = 1e-12 # For squared distances to avoid division by zero
    
        if self.debug:
            print(f"Running aggregate_to_one_function (ULTRA EFFICIENT)")
            print(f"  ipt Shape: {ipt.shape}")
    
        # 1. Compute zeta (spatial coordinates)
        # These are computed once and reused. Their memory is relatively small.
        zeta_raw = self.mapping_to_vector_space(ipt)
        zeta = self.space_norm(zeta_raw)  # [B, 2, H, W]
        B, C_zeta, H, W = zeta.shape
        N = H * W
    
        if self.debug: print(f"  zeta Shape: {zeta.shape}")
    
        # Global accumulators for total forces and system potential
        Fx_total = torch.zeros(B, N, device=ipt.device, dtype=ipt.dtype)
        Fy_total = torch.zeros(B, N, device=ipt.device, dtype=ipt.dtype)
        #V_total_system = torch.zeros(B, device=ipt.device, dtype=ipt.dtype) # Scalar potential for the whole system
    
        # Flattened coordinates and charges (properties from ipt)
        # These are [B, N, Dims]
        pts_zeta_coords_all = zeta.reshape(B, 2, N)
        x_coords_all = pts_zeta_coords_all[:, 0, :] # [B, N]
        y_coords_all = pts_zeta_coords_all[:, 1, :] # [B, N]
        zflat_for_cdist_all = pts_zeta_coords_all.permute(0, 2, 1) # [B, N, 2] (for torch.cdist)
    
        B_ipt, C_ipt, H_ipt, W_ipt = ipt.shape
        ipt_flat_charges_all = ipt.reshape(B_ipt, C_ipt, N).permute(0, 2, 1) # [B, N, C_ipt]
    
        # Outer loop: Iterate over TARGET chunks (points experiencing the force)
        for i_start in range(0, N, chunk_size):
            i_end = min(i_start + chunk_size, N)
            current_i_chunk_size = i_end - i_start
    
            if self.debug: print(f"  Processing TARGET chunk: points {i_start} to {i_end-1}")
    
            # Data for current TARGET chunk_i (small tensors)
            target_coords_i_cdist = zflat_for_cdist_all[:, i_start:i_end, :]     # [B, current_i_chunk_size, 2]
            target_x_i = x_coords_all[:, i_start:i_end]                         # [B, current_i_chunk_size]
            target_y_i = y_coords_all[:, i_start:i_end]                         # [B, current_i_chunk_size]
            target_charges_qi = ipt_flat_charges_all[:, i_start:i_end, :]       # [B, current_i_chunk_size, C_ipt]
    
            # Accumulators for forces/potential ON this i-th target chunk
            Fx_on_i_accum = torch.zeros(B, current_i_chunk_size, device=ipt.device, dtype=ipt.dtype)
            Fy_on_i_accum = torch.zeros(B, current_i_chunk_size, device=ipt.device, dtype=ipt.dtype)
            #V_on_i_accum = torch.zeros(B, current_i_chunk_size, device=ipt.device, dtype=ipt.dtype) # Potential energy for each point in i-chunk
    
            # Inner loop: Iterate over SOURCE chunks (points exerting the force)
            for j_start in range(0, N, chunk_size):
                j_end = min(j_start + chunk_size, N)
                current_j_chunk_size = j_end - j_start
    
                # Data for current SOURCE chunk_j (small tensors)
                source_coords_j_cdist = zflat_for_cdist_all[:, j_start:j_end, :]   # [B, current_j_chunk_size, 2]
                source_x_j = x_coords_all[:, j_start:j_end]                       # [B, current_j_chunk_size]
                source_y_j = y_coords_all[:, j_start:j_end]                       # [B, current_j_chunk_size]
                source_charges_qj = ipt_flat_charges_all[:, j_start:j_end, :]     # [B, current_j_chunk_size, C_ipt]
    
                # --- Compute interactions between current chunk_i and current chunk_j ---
                # All tensors created here are [B, current_i_chunk_size, current_j_chunk_size]
    
                # A. Distances (r_ij)
                # distances_ij: [B, current_i_chunk_size, current_j_chunk_size]
                distances_ij = torch.cdist(target_coords_i_cdist, source_coords_j_cdist, p=2)
    
                # B. Angles (theta_ij of vector from j to i)
                # dx_ji = x_i - x_j ; dy_ji = y_i - y_j
                dx_ji_sub = target_x_i.unsqueeze(2) - source_x_j.unsqueeze(1)
                dy_ji_sub = target_y_i.unsqueeze(2) - source_y_j.unsqueeze(1)
                theta_ji_sub = torch.atan2(dy_ji_sub, dx_ji_sub) # Angle of vector R_ji (from source j to target i)
    
                # C. Charge interaction term (qi_qj_effective)
                qi_expanded = target_charges_qi.unsqueeze(2)             # [B, size_i, 1, C_ipt]
                qj_expanded = source_charges_qj.unsqueeze(1)             # [B, 1, size_j, C_ipt]
                qi_qj_channels = qi_expanded * qj_expanded               # [B, size_i, size_j, C_ipt]
                qi_qj_scalar_sub = torch.mean(qi_qj_channels, dim=3)   # [B, size_i, size_j]
    
                # D. Handle self-interaction if target and source chunks overlap
                if i_start == j_start: # Only diagonal sub-chunks can have self-interaction
                    diag_sub_mask = torch.eye(current_i_chunk_size, current_j_chunk_size,
                                              device=ipt.device, dtype=torch.bool)
                    # Ensure broadcasting if one chunk is smaller at the edge
                    min_dim_self = min(current_i_chunk_size, current_j_chunk_size)
                    diag_sub_mask_final = diag_sub_mask[:min_dim_self, :min_dim_self]

                    distances_ij_corrected = distances_ij.clone()
                    distances_ij_corrected[:, :min_dim_self, :min_dim_self] = torch.where(
                        diag_sub_mask_final,
                        float('inf'),
                        distances_ij_corrected[:, :min_dim_self, :min_dim_self]
                    )
                    distances_ij = distances_ij_corrected
    
    
                # E. Potential Energy V_ij = k_E * (qi_qj)_ij / r_ij
                
                #inv_dist_ij = torch.pow(distances_ij, -1) # Handles inf correctly -> 0.0
                #V_ij_sub_values = k_E * qi_qj_scalar_sub * inv_dist_ij
    
                # F. Force Components F_ij = (k_E * (qi_qj)_ij / r_ij^2) * unit_vector_ji
                # F_radial_signed = k_E * qi_qj_scalar_sub * (1/r_ij^2)
                # (1/r_ij^2) can be inv_dist_ij.pow(2) or 1.0 / (distances_ij.pow(2) + epsilon_dist_sq)
                distances_ij_inf = torch.where(distances_ij == 0, float('inf'), distances_ij)
                inv_sq_dist_ij = torch.pow(distances_ij_inf, -2) # Add epsilon for stability if r can be 0 for non-self
                #inv_sq_dist_ij = torch.where(torch.isinf(distances_ij), torch.zeros_like(inv_sq_dist_ij), inv_sq_dist_ij) # Ensure inf dist -> 0 force
    
                F_radial_signed_ij = k_E * qi_qj_scalar_sub * inv_sq_dist_ij
                
                normalization_factor = torch.pow(torch.tensor([N], dtype=torch.float32, device=ipt.device), 2)
                
                Fx_ij_sub = F_radial_signed_ij * torch.cos(theta_ji_sub)/normalization_factor
                Fy_ij_sub = F_radial_signed_ij * torch.sin(theta_ji_sub)/normalization_factor
    
                # G. Accumulate forces and potential for the current target_i_chunk
                Fx_on_i_accum = torch.sum(Fx_ij_sub, dim=2)
                Fy_on_i_accum = torch.sum(Fy_ij_sub, dim=2)

                Fx_on_i_accum = self.space_norm(Fx_on_i_accum)
                Fy_on_i_accum = self.space_norm(Fy_on_i_accum)
                #V_on_i_accum += torch.sum(V_ij_sub_values, dim=2) # Sum potential contributions for each point in i_chunk
    
                # H. Explicitly delete intermediate sub-chunk tensors
                del distances_ij, dx_ji_sub, dy_ji_sub, theta_ji_sub, qi_expanded, qj_expanded
                del qi_qj_channels, qi_qj_scalar_sub, #inv_dist_ij, V_ij_sub_values
                del inv_sq_dist_ij, F_radial_signed_ij, Fx_ij_sub, Fy_ij_sub
                if torch.cuda.is_available(): torch.cuda.empty_cache() # Aggressive cache clearing
    
            # Store results for the fully processed i-th target chunk
            Fx_total[:, i_start:i_end] = Fx_on_i_accum
            Fy_total[:, i_start:i_end] = Fy_on_i_accum
            #V_total_system += torch.sum(V_on_i_accum, dim=1) # Sum potential over points in i-chunk, then add to system total
    
            # Explicitly delete i-chunk accumulators
            del target_coords_i_cdist, target_x_i, target_y_i, target_charges_qi
            del Fx_on_i_accum, Fy_on_i_accum, #V_on_i_accum
            if torch.cuda.is_available(): torch.cuda.empty_cache()
    
        if self.debug:
            print(f"  aggregate_to_one_function (ULTRA EFFICIENT) finished.")
            print(f"  Fx_total: {Fx_total}")
            print(f"  Fy_total: {Fy_total}")
            #if V_total_system.numel() > 0: print(f"  V_total_system: {V_total_system.item()}")
            print("-" * 59)
    
        return Fx_total, Fy_total, zeta, #V_total_system # Make sure zeta is returned

    def forward(self, ipt, chunk_size=1024):
        """Run the model on a batch and return the updated input and zeta.

        Args:
            ipt: Input tensor with at least four dimensions; dimensions 2 and 3
                are checked against ``self.imsize`` after ``conv1``
                (presumably a (B, C, H, W) image batch -- TODO confirm).
            chunk_size: Chunk size forwarded to ``aggregate_to_one_function``.

        Returns:
            Tuple ``(ipt, zeta)``: the input after all layer updates and a
            final ``space_norm``, and the ``zeta`` value returned by a final
            call to ``aggregate_to_one_function`` on that input.

        Raises:
            AssertionError: If the height or width after ``conv1`` differs
                from ``self.imsize``.
        """

        ipt = self.conv1(ipt)
        ipt = self.space_norm(ipt)

        # Both spatial dimensions must equal imsize. The previous expression
        # (`shape[2] and shape[3] == imsize`) only tested that the height was
        # non-zero, so a wrong height was silently accepted.
        assert ipt.shape[2] == self.imsize and ipt.shape[3] == self.imsize, \
            "Input Shape's Height and Width must match parameter imsize"
        for layer in range(self.layers):

            if self.use_self_attention:
                # Residual-style mix of the attended features into the input.
                attended_ipt = self.input_attention(ipt)
                ipt = self.attention_scale * attended_ipt + ipt
            else:
                ipt = self.conv2(ipt)
                ipt = self.space_norm(ipt)
            if self.debug:
                # Printed for both branches, not only the attention branch.
                print(f"After self-attention ipt Shape: {ipt.shape}")

            if self.debug:
                print(f"Running Layer {layer + 1}/{self.layers}")
                print(f"-"*59)
            # NOTE(review): if self.loops is 0, `zeta` is never assigned before
            # the space_norm(zeta) call below.
            for loop in range(self.loops):
                if self.debug:
                    print(f"Starting {loop + 1}/{self.loops} for compute_electric_force...")
                    print(" ")
                Fx, Fy, zeta = self.aggregate_to_one_function(ipt, chunk_size)
                zeta = self.compute_vector_translation(Fx, Fy, zeta)

            zeta = self.space_norm(zeta)
            ipt = self.update_input(ipt, zeta)

        ipt = self.space_norm(ipt)
        # Recompute zeta for the final input; the forces are not returned.
        Fx, Fy, zeta = self.aggregate_to_one_function(ipt, chunk_size)

        return ipt, zeta

In [7]:
# Smoke test: build the model in bfloat16 with debug printing enabled and run
# one forward pass on `inpt` (expected to be defined by an earlier cell).
model = ElectricForceModel(debug = True)
model = model.to(DEVICE)
model = model.to(torch.bfloat16)

# Use the type of the selected DEVICE instead of a hard-coded "cuda", so the
# autocast context stays active when the notebook runs on CPU.
with torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16):
    output, zeta = model(inpt.to(DEVICE).to(torch.bfloat16), chunk_size=512)

In [8]:
class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) pairs read from parallel lists of file paths.

    Images are loaded as RGB arrays. Masks are loaded as 8-bit grayscale and
    quantised into three class indices: values below 85 become 0, values below
    170 become 1, everything else becomes 2.

    Args:
        image_paths: Sequence of image file paths.
        mask_paths: Sequence of mask file paths, paired with ``image_paths``
            by position.
        transforms: Optional callable invoked as
            ``transforms(image=image, mask=mask)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.

    Raises:
        ValueError: If the two path sequences differ in length, since
            positional pairing would then be wrong or fail later.
    """

    def __init__(self, image_paths, mask_paths, transforms=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths and mask_paths must have the same length "
                f"(got {len(image_paths)} and {len(mask_paths)})"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read image and mask; the context managers close the file handles
        # right away instead of leaving that to the garbage collector.
        with Image.open(self.image_paths[idx]) as image_file:
            image = np.array(image_file.convert("RGB"))
        with Image.open(self.mask_paths[idx]) as mask_file:
            mask = np.array(mask_file.convert("L"))

        # Quantise gray levels into class ids 0/1/2. Stored as uint8 because
        # the values fit and OpenCV-backed transforms presumably do not accept
        # the int64 arrays np.where produces by default -- the final tensor is
        # converted to long below either way.
        mask = np.where(mask < 85, 0, np.where(mask < 170, 1, 2)).astype(np.uint8)

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # The transform may already have produced a tensor; handle both cases.
        mask = mask.long() if isinstance(mask, torch.Tensor) else torch.from_numpy(mask).long()
        return image, mask

In [9]:
def compute_multi_iou(outputs, masks, num_classes=3):
    """Mean intersection-over-union across the classes present in the batch.

    Args:
        outputs: Class scores with the class dimension at ``dim=1``.
        masks: Integer class-index tensor matching the shape of
            ``outputs.argmax(dim=1)``.
        num_classes: Number of class indices to evaluate (0..num_classes-1).

    Returns:
        Mean IoU as a float. Classes that appear in neither the prediction nor
        the target are skipped; previously they contributed an IoU of 0, so a
        perfect prediction on a batch missing one class scored below 1.
        Returns 0.0 when no class is present at all (e.g. empty tensors).
    """
    # argmax over raw scores selects the same class as argmax over softmax
    # probabilities, without the extra pass or low-precision rounding ties.
    preds = torch.argmax(outputs, dim=1)
    iou_per_class = []
    for cls in range(num_classes):
        pred_cls = (preds == cls)
        target_cls = (masks == cls)
        union = (pred_cls | target_cls).sum().float()
        if union.item() == 0:
            # Class absent from both prediction and target: undefined IoU.
            continue
        intersection = (pred_cls & target_cls).sum().float()
        iou_per_class.append((intersection / union).item())
    if not iou_per_class:
        return 0.0
    return float(np.mean(iou_per_class))

def compute_dice_score(outputs, masks, num_classes=3):
    """Mean Dice coefficient across the classes present in the batch.

    Args:
        outputs: Class scores with the class dimension at ``dim=1``.
        masks: Integer class-index tensor matching the shape of
            ``outputs.argmax(dim=1)``.
        num_classes: Number of class indices to evaluate (0..num_classes-1).

    Returns:
        Mean Dice score as a float. Classes that appear in neither the
        prediction nor the target are skipped; previously they contributed a
        score of 0 and dragged down the mean of otherwise perfect predictions.
        Returns 0.0 when no class is present at all (e.g. empty tensors).
    """
    # argmax over raw scores selects the same class as argmax over softmax
    # probabilities, without the extra pass or low-precision rounding ties.
    preds = torch.argmax(outputs, dim=1)
    dice_per_class = []
    for cls in range(num_classes):
        pred_cls = (preds == cls)
        target_cls = (masks == cls)
        denominator = (pred_cls.sum() + target_cls.sum()).float()
        if denominator.item() == 0:
            # Class absent from both prediction and target: undefined Dice.
            continue
        intersection = (pred_cls & target_cls).sum().float()
        dice_per_class.append((2. * intersection / denominator).item())
    if not dice_per_class:
        return 0.0
    return float(np.mean(dice_per_class))

def compute_precision(outputs, masks, num_classes=3):
    """Mean per-class precision across the classes present in the batch.

    Args:
        outputs: Class scores with the class dimension at ``dim=1``.
        masks: Integer class-index tensor matching the shape of
            ``outputs.argmax(dim=1)``.
        num_classes: Number of class indices to evaluate (0..num_classes-1).

    Returns:
        Mean precision as a float. A class that is neither predicted nor
        present in the target is skipped (previously it counted as 0). A class
        present in the target but never predicted still counts as 0.
        Returns 0.0 when no class is present at all (e.g. empty tensors).
    """
    # argmax over raw scores selects the same class as argmax over softmax
    # probabilities, without the extra pass or low-precision rounding ties.
    preds = torch.argmax(outputs, dim=1)
    precision_per_class = []
    for cls in range(num_classes):
        pred_cls = (preds == cls)
        target_cls = (masks == cls)
        predicted_positive = pred_cls.sum().float()
        if predicted_positive.item() == 0:
            if target_cls.any().item():
                # Class exists in the target but was never predicted.
                precision_per_class.append(0.0)
            continue
        true_positive = (pred_cls & target_cls).sum().float()
        precision_per_class.append((true_positive / predicted_positive).item())
    if not precision_per_class:
        return 0.0
    return float(np.mean(precision_per_class))

def compute_accuracy(outputs, masks):
    """Fraction of elements whose predicted class equals the target class.

    The predicted class is the argmax over ``dim=1`` of the softmax of
    ``outputs``; the result is returned as a Python float.
    """
    predicted = outputs.softmax(dim=1).argmax(dim=1)
    n_correct = predicted.eq(masks).sum().float()
    return (n_correct / masks.numel()).item()


def post_process(mask, kernel_size=3):
    """Clean a mask with a morphological opening followed by a closing.

    Args:
        mask: Tensor holding the mask; it is converted to a uint8 NumPy array
            before filtering.
        kernel_size: Side length of the square all-ones structuring element.

    Returns:
        A CPU tensor created from the filtered uint8 array.
    """
    import cv2
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    # detach() + cpu() make the conversion work for tensors that live on the
    # GPU or track gradients; a bare .numpy() raises for both.
    mask_np = mask.detach().cpu().numpy().astype(np.uint8)

    # Opening removes small isolated regions ...
    cleaned = cv2.morphologyEx(mask_np,
                              cv2.MORPH_OPEN, kernel)

    # ... and closing fills small holes.
    cleaned = cv2.morphologyEx(cleaned,
                              cv2.MORPH_CLOSE, kernel)

    return torch.from_numpy(cleaned)

In [10]:
class Mixup:
    """Batch-level mixup augmentation.

    With probability ``p`` every sample in the batch is blended with another
    sample chosen by a random permutation of the batch:
    ``lam * x + (1 - lam) * x[perm]`` with ``lam ~ Beta(alpha, alpha)``.
    The same blend is applied to the masks.

    Args:
        alpha: Both shape parameters of the Beta distribution for ``lam``.
        p: Probability of applying mixup to a given call.

    Note:
        Inputs are expected to be batched tensors or arrays that support
        ``len``, fancy indexing and arithmetic. Blending integer class-index
        masks produces fractional values, not class ids -- callers that need
        hard labels should convert masks to one-hot form first.
    """

    def __init__(self, alpha=0.4, p=0.5):
        self.alpha = alpha
        self.p = p

    @staticmethod
    def _reorder(batch, order):
        # Index torch tensors with a tensor on the same device; other inputs
        # (e.g. NumPy arrays) are indexed with the NumPy permutation directly.
        if isinstance(batch, torch.Tensor):
            return batch[torch.as_tensor(order, device=batch.device)]
        return batch[order]

    def __call__(self, images, masks):
        if np.random.rand() < self.p:
            # The batch size is the length of the leading dimension. The old
            # `isinstance(images, list)` test reported 1 for every tensor, so
            # mixup was silently never applied.
            try:
                batch_size = len(images)
            except TypeError:
                batch_size = 1  # unsized input (e.g. 0-dim tensor)
            if batch_size > 1:
                indices = np.random.permutation(batch_size)
                img2 = self._reorder(images, indices)
                mask2 = self._reorder(masks, indices)

                lam = np.random.beta(self.alpha, self.alpha)

                mixed_img = lam * images + (1 - lam) * img2
                mixed_mask = lam * masks + (1 - lam) * mask2

                return mixed_img, mixed_mask
        return images, masks

In [11]:
def get_train_transform():
    """Build the training augmentation pipeline.

    Resizes to IMAGE_SIZE x IMAGE_SIZE, applies random flips, 90-degree
    rotations, shift/scale/rotate and brightness/contrast jitter, then
    normalises with mean 0.5 / std 0.5 per channel and converts to tensors.
    """
    half = (0.5, 0.5, 0.5)
    steps = [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Normalize(mean=half, std=half),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def get_valid_transform():
    """Build the deterministic validation pipeline.

    Resizes to IMAGE_SIZE x IMAGE_SIZE, normalises with mean 0.5 / std 0.5
    per channel and converts to tensors; no random augmentation is applied.
    """
    half = (0.5, 0.5, 0.5)
    steps = [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(mean=half, std=half),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def visualize_predictions(model, dataloader, device, num_samples=4):
    """Plot input image, ground-truth mask and predicted mask side by side.

    Args:
        model: Segmentation model. It may return the class scores directly or
            a tuple/list whose first element holds them (the training loops in
            this notebook unpack ``outputs, zeta = model(images)``).
        dataloader: Iterable yielding ``(image_batch, mask_batch)`` pairs.
        device: Device the image batch is moved to before inference.
        num_samples: Maximum number of samples (figure rows) to show.

    Returns:
        The matplotlib figure, or ``None`` when the dataloader yielded no
        samples.
    """
    model.eval()
    images, masks, preds = [], [], []

    with torch.no_grad():
        for img, mask in dataloader:
            img = img.to(device)
            output = model(img)
            # Accept models that return (scores, extra) as well as plain
            # scores; previously the tuple was passed straight to softmax.
            if isinstance(output, (tuple, list)):
                output = output[0]
            # argmax over scores equals argmax over softmax probabilities.
            pred = torch.argmax(output, dim=1)

            for i in range(min(img.shape[0], num_samples - len(images))):
                images.append(img[i].cpu())
                masks.append(mask[i].cpu())
                preds.append(pred[i].cpu())

            if len(images) >= num_samples:
                break

    # Plot only what was collected: the loader may hold fewer than
    # num_samples items, which used to raise IndexError below.
    n_rows = len(images)
    if n_rows == 0:
        return None

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, n_rows * 5), squeeze=False)

    for i in range(n_rows):
        img = images[i].float().permute(1, 2, 0).numpy()
        # Undo the Normalize(mean=0.5, std=0.5) used by the notebook's
        # transforms (the old code used ImageNet statistics here).
        img = img * 0.5 + 0.5
        img = np.clip(img, 0, 1)

        axes[i, 0].imshow(img)
        axes[i, 0].set_title('Original Image')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(masks[i].numpy(), cmap='jet')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(preds[i].numpy(), cmap='jet')
        axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()
    return fig

def train_one_epoch(model, dataloader, optimizer, criterion, device, tau):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    The model is expected to return a pair whose first element is fed to
    ``criterion`` together with the masks. ``tau`` is accepted but unused.
    """
    model.train()
    running_loss = 0.0

    progress = tqdm(dataloader, desc='Training')
    for batch_images, batch_masks in progress:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)

        optimizer.zero_grad()
        logits, _ = model(batch_images)
        batch_loss = criterion(logits, batch_masks)

        batch_loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the total and the progress bar.
        loss_value = batch_loss.item()
        running_loss += loss_value
        progress.set_postfix({'loss': f'{loss_value:.4f}'})

    return running_loss / len(dataloader)

# Validation function
def validate(model, dataloader, criterion, device, tau):
    """Evaluate the model on ``dataloader`` without gradient tracking.

    Returns ``(mean_loss, mean_iou, mean_dsc, mean_precision, mean_acc)``,
    where each metric is the mean of its per-batch values. ``tau`` is accepted
    but unused.
    """
    model.eval()
    total_loss = 0.0
    per_batch = {'iou': [], 'dsc': [], 'precision': [], 'acc': []}

    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc='Validation'):
            images, masks = images.to(device), masks.to(device)

            logits, _ = model(images)
            total_loss += criterion(logits, masks).item()

            per_batch['iou'].append(compute_multi_iou(logits, masks))
            per_batch['dsc'].append(compute_dice_score(logits, masks))
            per_batch['precision'].append(compute_precision(logits, masks))
            per_batch['acc'].append(compute_accuracy(logits, masks))

    # Average the per-batch metrics before the loss, as the caller expects
    # them in the order loss, IoU, Dice, precision, accuracy.
    mean_iou = np.mean(per_batch['iou'])
    mean_dsc = np.mean(per_batch['dsc'])
    mean_precision = np.mean(per_batch['precision'])
    mean_acc = np.mean(per_batch['acc'])
    mean_loss = total_loss / len(dataloader)

    return mean_loss, mean_iou, mean_dsc, mean_precision, mean_acc

def train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, num_epochs, device, fold=None):
    """Train and validate for ``num_epochs``, saving the best model by Dice.

    Args:
        model: Model to train.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``validate``.
        optimizer: Optimizer stepped inside ``train_one_epoch``.
        criterion: Loss function.
        scheduler: Scheduler stepped once per epoch with the mean Dice score.
        num_epochs: Number of epochs to run.
        device: Device handed to the train/validate helpers.
        fold: Optional fold index used in the saved file name
            (``best_model_fold{fold}.pth``; ``best_model.pth`` when ``None``).

    Returns:
        ``(model, history, best_metrics)``. ``history`` holds per-epoch lists;
        ``best_metrics`` always contains ``mean_iou``, ``mean_dsc``,
        ``mean_precision`` and ``mean_acc`` taken from the epoch with the
        highest mean Dice (all 0.0 if no epoch exceeded a Dice of 0).
    """
    # Every key is initialised up front. Previously 'mean_precision' and
    # 'mean_acc' only existed after an improving epoch, so callers reading
    # them could hit a KeyError when the Dice score never rose above 0.
    best_metrics = {'mean_iou': 0.0, 'mean_dsc': 0.0, 'mean_precision': 0.0, 'mean_acc': 0.0}
    history = {
        'train_loss': [], 'val_loss': [],
        'iou': [], 'dsc': [], 'precision': [], 'accuracy': []
    }

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, tau = 0.2)
        history['train_loss'].append(train_loss)

        val_loss, mean_iou, mean_dsc, mean_precision, mean_acc = validate(model, val_loader, criterion, device, tau = 0.2)
        history['val_loss'].append(val_loss)
        history['iou'].append(mean_iou)
        history['dsc'].append(mean_dsc)
        history['precision'].append(mean_precision)
        history['accuracy'].append(mean_acc)

        # The scheduler monitors the validation Dice score.
        scheduler.step(mean_dsc)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        print(f"Accuracy: {mean_acc:.4f}, Mean IoU: {mean_iou:.4f}, Mean DSC/F1: {mean_dsc:.4f}, Mean Precision: {mean_precision:.4f}")

        if mean_dsc > best_metrics['mean_dsc']:
            best_metrics['mean_dsc'] = mean_dsc
            best_metrics['mean_iou'] = mean_iou
            best_metrics['mean_precision'] = mean_precision
            best_metrics['mean_acc'] = mean_acc

            model_name = f"best_model_fold{fold}.pth" if fold is not None else "best_model.pth"
            torch.save(model.state_dict(), model_name)
            print(f"Saved best model with DSC: {mean_dsc:.4f}")

        print(f"Best DSC: {best_metrics['mean_dsc']:.4f}, Best IoU: {best_metrics['mean_iou']:.4f}")
        print("-" * 50)

    return model, history, best_metrics

def find_optimal_k(all_images, all_masks, k_values=[3, 5, 7, 10], num_epochs=10, tau = 0.2):
    """Compare several K-fold settings and return the K chosen by find_elbow_point.

    For every ``k`` in ``k_values`` the data is split with a shuffled,
    seeded ``KFold``; a fresh ``ElectricForceModel`` is trained on each fold
    for ``num_epochs`` epochs and the best per-epoch mean Dice score of the
    fold is kept. Per ``k`` the function records the mean and variance of the
    fold scores and the wall-clock time for all folds, prints a summary,
    plots the three series and saves the figure to
    ``k_fold_elbow_analysis.png``.

    Args:
        all_images: Indexable collection of image paths.
        all_masks: Indexable collection of mask paths, paired by position.
        k_values: Candidate numbers of folds.
        num_epochs: Epochs trained per fold.
        tau: Unused; only referenced by commented-out code.

    Returns:
        The value returned by ``find_elbow_point(k_values, k_scores, k_variances)``.
    """

    k_scores = []
    k_variances = []
    train_times = []
    
    for k in k_values:
        fold_scores = []
        
        kf = KFold(n_splits=k, shuffle=True, random_state=SEED)
        # Timer covers every fold of this k.
        start_time = time.time()
        
        for fold, (train_idx, val_idx) in enumerate(kf.split(all_images)):
            print(f"  Fold {fold+1}/{k}")
            
            train_img = [all_images[i] for i in train_idx]
            train_mask = [all_masks[i] for i in train_idx]
            val_img = [all_images[i] for i in val_idx]
            val_mask = [all_masks[i] for i in val_idx]
            
            # Build the datasets
            train_dataset = SegmentationDataset(train_img, train_mask, get_train_transform())
            val_dataset = SegmentationDataset(val_img, val_mask, get_valid_transform())
            
            # Build the dataloaders
            train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
            
            # Initialise a fresh model for every fold
            model = ElectricForceModel().to(DEVICE)
            
            # Set up optimizer, loss and scheduler
            optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
            criterion = nn.CrossEntropyLoss()
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3)
            
            best_metrics = {'mean_dsc': 0.0}
            
            for epoch in range(num_epochs):
                model.train()
                for images, masks in train_loader:
                    images, masks = images.to(DEVICE), masks.to(DEVICE)
                    optimizer.zero_grad()
                    outputs, zeta= model(images)
                    loss = criterion(outputs, masks)#*(1-tau) + V*tau
                    loss.backward()
                    optimizer.step()
                
                model.eval()
                dsc_scores = []
                with torch.no_grad():
                    for images, masks in val_loader:
                        images, masks = images.to(DEVICE), masks.to(DEVICE)
                        outputs, zeta= model(images)
                        #outputs = conconvert_model_output_to_segmentation(outputs)
                        dsc_scores.append(compute_dice_score(outputs, masks))
                
                mean_dsc = np.mean(dsc_scores)
                # The scheduler monitors the validation Dice score.
                scheduler.step(mean_dsc)
                
                if mean_dsc > best_metrics['mean_dsc']:
                    best_metrics['mean_dsc'] = mean_dsc
            
            # Record the best Dice score of this fold
            fold_scores.append(best_metrics['mean_dsc'])
            print(f"    Dice score: {best_metrics['mean_dsc']:.4f}")
        
        train_time = time.time() - start_time
        train_times.append(train_time)
        
        mean_score = np.mean(fold_scores)
        score_variance = np.var(fold_scores)
        
        k_scores.append(mean_score)
        k_variances.append(score_variance)
        
        print(f"K={k}, Điểm trung bình: {mean_score:.4f}, Phương sai: {score_variance:.6f}")
        print(f"Thời gian training: {train_time:.2f} giây")
    
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
    
    # Plot the mean Dice score per K
    ax1.plot(k_values, k_scores, 'o-', linewidth=2, color='blue')
    ax1.set_xlabel('Số lượng Fold (K)')
    ax1.set_ylabel('Điểm DSC trung bình')
    ax1.set_title('Hiệu suất theo giá trị K')
    ax1.grid(True)
    
    # Plot the variance of the fold scores per K
    ax2.plot(k_values, k_variances, 'o-', linewidth=2, color='orange')
    ax2.set_xlabel('Số lượng Fold (K)')
    ax2.set_ylabel('Phương sai')
    ax2.set_title('Phương sai theo giá trị K')
    ax2.grid(True)
    
    # Plot the training time per K
    ax3.plot(k_values, train_times, 'o-', linewidth=2, color='green')
    ax3.set_xlabel('Số lượng Fold (K)')
    ax3.set_ylabel('Thời gian training (giây)')
    ax3.set_title('Thời gian training theo giá trị K')
    ax3.grid(True)
    
    plt.tight_layout()
    plt.savefig('k_fold_elbow_analysis.png')
    plt.show()
    
    optimal_k = find_elbow_point(k_values, k_scores, k_variances)

    
    return optimal_k

def find_elbow_point(k_values, k_scores, k_variances):
    """Determine the elbow point from percentage improvement and variance.

    The relative improvement (in percent) of each score over the previous one
    is computed in order. The first K whose improvement over its predecessor
    is below 1% is returned. If every step improves by at least 1%, the K
    with the smallest variance is returned instead.

    Args:
        k_values: Candidate K values, in evaluation order.
        k_scores: Mean score for each K (same order as ``k_values``).
        k_variances: Score variance for each K (same order as ``k_values``).

    Returns:
        The selected element of ``k_values``.

    Raises:
        ValueError: If ``k_values`` is empty.
    """
    if len(k_values) == 0:
        raise ValueError("k_values must contain at least one candidate")

    for i in range(1, len(k_values)):
        baseline = k_scores[i - 1]
        gain = k_scores[i] - baseline
        if baseline == 0:
            # A relative improvement over a zero baseline is undefined.
            # Mirror IEEE semantics (+inf / -inf / nan) explicitly so plain
            # Python floats no longer raise ZeroDivisionError and NumPy
            # floats no longer emit runtime warnings.
            if gain > 0:
                improvement = float('inf')
            elif gain < 0:
                improvement = float('-inf')
            else:
                improvement = float('nan')  # nan < 1.0 is False -> keep looking
        else:
            improvement = gain / baseline * 100

        if improvement < 1.0:
            return k_values[i]

    # No step fell below the 1% threshold: fall back to the most stable K.
    var_index = int(np.argmin(k_variances))
    return k_values[var_index]

def run_training():
    """Run the K-fold training pipeline end to end.

    Collects the image/mask paths, asks ``find_optimal_k`` for the number of
    folds, then trains a fresh ``ElectricForceModel`` per fold with
    ``train_model``, visualises predictions, saves the learning curves to
    ``learning_curves_fold{fold}.png`` and finally prints the metrics
    averaged over all folds.
    """
    print(f"Using device: {DEVICE}")

    root_dir = Path('/kaggle/input/data-1/content')

    def _collect(subdir, pattern):
        # Sorted by full path first, then (stably) by file stem.
        by_path = sorted(list((root_dir / subdir).glob(pattern)))
        return sorted(by_path, key=lambda p: p.stem)

    all_images = _collect('images', '*.bmp')
    all_masks = _collect('masks', '*.png')

    print(f"Total images: {len(all_images)}")
    print(f"Total masks: {len(all_masks)}")

    optimal_k = find_optimal_k(
        all_images, all_masks,
        k_values=[3, 5, 7, 10],
        num_epochs=10
    )

    # (title, [(history key, legend label), ...]) for the 2x2 curve grid.
    curve_panels = [
        ('Loss Curves', [('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')]),
        ('IoU and DSC Curves', [('iou', 'IoU'), ('dsc', 'DSC/F1')]),
        ('Precision Curve', [('precision', 'Precision')]),
        ('Accuracy Curve', [('accuracy', 'Accuracy')]),
    ]

    fold_metrics = []
    splitter = KFold(n_splits=optimal_k, shuffle=True, random_state=SEED)

    for fold, (train_idx, val_idx) in enumerate(splitter.split(all_images)):
        print(f"Training Fold {fold+1}/{optimal_k}")

        train_img = [all_images[j] for j in train_idx]
        train_mask = [all_masks[j] for j in train_idx]
        val_img = [all_images[j] for j in val_idx]
        val_mask = [all_masks[j] for j in val_idx]

        # Datasets and loaders for this fold
        train_loader = DataLoader(
            SegmentationDataset(train_img, train_mask, get_train_transform()),
            batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
        )
        val_loader = DataLoader(
            SegmentationDataset(val_img, val_mask, get_valid_transform()),
            batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
        )

        # Fresh model, optimizer, loss and scheduler per fold
        model = ElectricForceModel().to(DEVICE)
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
        criterion = nn.CrossEntropyLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, factor=0.5)

        model, history, best_metrics = train_model(
            model, train_loader, val_loader, optimizer, criterion,
            scheduler, NUM_EPOCHS, DEVICE, fold=fold,
        )
        fold_metrics.append(best_metrics)

        visualize_predictions(model, val_loader, DEVICE)

        # Learning curves: one subplot per panel definition
        plt.figure(figsize=(12, 8))
        for position, (title, series) in enumerate(curve_panels, start=1):
            plt.subplot(2, 2, position)
            for key, label in series:
                plt.plot(history[key], label=label)
            plt.legend()
            plt.title(title)

        plt.tight_layout()
        plt.savefig(f'learning_curves_fold{fold}.png')
        plt.show()

    # Overall metrics, averaged over the folds
    print("\nAverage metrics across all folds:")
    averages = {
        key: np.mean([metrics[key] for metrics in fold_metrics])
        for key in ('mean_iou', 'mean_dsc', 'mean_precision', 'mean_acc')
    }

    print(f"Avg IoU: {averages['mean_iou']:.4f}")
    print(f"Avg DSC/F1: {averages['mean_dsc']:.4f}")
    print(f"Avg Precision: {averages['mean_precision']:.4f}")
    print(f"Avg Accuracy: {averages['mean_acc']:.4f}")

# Entry point: start the training pipeline when this code is executed as the
# main module (which is the case when the cell is run in a notebook kernel).
if __name__ == "__main__":
    run_training()

In [13]:
from torch.amp import autocast, GradScaler
import time
from pathlib import Path
import json
from datetime import datetime
import logging

# Setup logging
def setup_logging():
    """Setup detailed logging for training.

    Creates the ``training_logs`` directory if needed and configures the root
    logger at INFO level with two handlers: a file handler writing to
    ``training_logs/training_<YYYYmmdd_HHMMSS>.log`` and a stream handler.

    Returns:
        The logger named after this module (``logging.getLogger(__name__)``).
    """
    log_dir = Path("training_logs")
    log_dir.mkdir(exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"training_{timestamp}.log"

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit UTF-8 so non-ASCII log messages are written reliably
            # regardless of the platform's default encoding.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ],
        # basicConfig does nothing when the root logger already has handlers
        # (e.g. on a second call in the same kernel); force=True replaces them
        # so every call really attaches the new log file.
        force=True,
    )
    return logging.getLogger(__name__)

def train_one_epoch(model, dataloader, optimizer, criterion, device, epoch, logger):
    """Train for one epoch under bfloat16 autocast, without gradient scaling.

    Args:
        model: Model to train; called as ``model(images, chunk_size=512)`` and
            expected to return a pair whose first element is fed to the loss.
        dataloader: Iterable of ``(images, masks)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function applied to the model output and the masks.
        device: Device (``torch.device`` or string) the batches are moved to;
            its type also selects the autocast backend.
        epoch: Zero-based epoch index, used for the progress-bar label.
        logger: Accepted for interface symmetry; not used here.

    Returns:
        ``(mean_loss, batch_losses)``: the mean batch loss (0.0 if the loader
        yielded no batches) and the list of per-batch losses.
    """
    model.train()

    # Derive the autocast backend from the requested device instead of
    # hard-coding 'cuda', so CPU runs keep autocast enabled.
    device_type = torch.device(device).type

    epoch_loss = 0.0
    epoch_steps = 0
    batch_losses = []

    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1} - Training')

    for batch_idx, (images, masks) in enumerate(pbar):
        step_start_time = time.time()

        # Move data to device and cast the images to the model dtype
        images = images.to(device).to(torch.bfloat16)
        masks = masks.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass with autocast
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            outputs, zeta = model(images, chunk_size=512)
            loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        # Track metrics
        batch_loss = loss.item()
        epoch_loss += batch_loss
        batch_losses.append(batch_loss)
        epoch_steps += 1

        step_time = time.time() - step_start_time

        # Update progress bar
        pbar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'avg_loss': f'{epoch_loss/epoch_steps:.4f}',
            'step_time': f'{step_time:.2f}s'
        })

    # Average over the batches actually processed; avoids a ZeroDivisionError
    # for an empty loader.
    mean_loss = epoch_loss / epoch_steps if epoch_steps else 0.0
    return mean_loss, batch_losses


def validate_one_epoch(model, dataloader, criterion, device, epoch, logger):
    """Validate for one epoch with detailed metrics.

    Args:
        model: Model to evaluate; called as ``model(images, chunk_size=512)``
            and expected to return a pair whose first element holds the class
            scores.
        dataloader: Iterable of ``(images, masks)`` batches.
        criterion: Loss function applied to the model output and the masks.
        device: Device (``torch.device`` or string) the batches are moved to;
            its type also selects the autocast backend.
        epoch: Zero-based epoch index, used in labels and log lines.
        logger: Logger that receives the validation summary.

    Returns:
        ``(avg_val_loss, mean_iou, mean_dsc, mean_precision, mean_acc)``,
        each metric being the mean of its per-batch values.
    """
    model.eval()

    # Derive the autocast backend from the requested device instead of
    # hard-coding "cuda", so CPU runs keep autocast enabled.
    device_type = torch.device(device).type

    val_loss = 0.0
    iou_scores, dsc_scores, precision_scores, acc_scores = [], [], [], []

    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1} - Validation')

    with torch.no_grad():
        for batch_idx, (images, masks) in enumerate(pbar):
            # Move data to device
            images = images.to(device).to(torch.bfloat16)
            masks = masks.to(device)

            # Forward pass with mixed precision
            with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                outputs, zeta = model(images, chunk_size=512)
                loss = criterion(outputs, masks)

            val_loss += loss.item()

            # Compute metrics on float32 scores: bfloat16 keeps only a few
            # significant bits, which can create artificial ties between
            # classes when the scores are compared.
            scores = outputs.float()
            iou = compute_multi_iou(scores, masks)
            dsc = compute_dice_score(scores, masks)
            precision = compute_precision(scores, masks)
            accuracy = compute_accuracy(scores, masks)

            iou_scores.append(iou)
            dsc_scores.append(dsc)
            precision_scores.append(precision)
            acc_scores.append(accuracy)

            # Update progress bar
            pbar.set_postfix({
                'val_loss': f'{loss.item():.4f}',
                'iou': f'{iou:.4f}',
                'dsc': f'{dsc:.4f}'
            })

    # Calculate averages
    avg_val_loss = val_loss / len(dataloader)
    mean_iou = np.mean(iou_scores)
    mean_dsc = np.mean(dsc_scores)
    mean_precision = np.mean(precision_scores)
    mean_acc = np.mean(acc_scores)

    # Calculate standard deviations (across batches)
    std_iou = np.std(iou_scores)
    std_dsc = np.std(dsc_scores)
    std_precision = np.std(precision_scores)
    std_acc = np.std(acc_scores)

    logger.info(f"Epoch {epoch+1} Validation Summary:")
    logger.info(f"  Val_Loss={avg_val_loss:.4f}")
    logger.info(f"  IoU: {mean_iou:.4f} ± {std_iou:.4f}")
    logger.info(f"  DSC: {mean_dsc:.4f} ± {std_dsc:.4f}")
    logger.info(f"  Precision: {mean_precision:.4f} ± {std_precision:.4f}")
    logger.info(f"  Accuracy: {mean_acc:.4f} ± {std_acc:.4f}")

    return avg_val_loss, mean_iou, mean_dsc, mean_precision, mean_acc

def save_checkpoint(model, optimizer, scaler, epoch, metrics, filename):
    """Save model checkpoint with training state.

    Args:
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scaler: Gradient scaler whose ``state_dict()`` is stored, or ``None``
            when training runs without one (the entry is then ``None``).
        epoch: Epoch index stored under ``'epoch'``.
        metrics: Metrics object stored under ``'metrics'``.
        filename: Destination passed to ``torch.save``.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Training without a gradient scaler is valid (e.g. bfloat16), so a
        # missing scaler must not abort checkpointing.
        'scaler_state_dict': scaler.state_dict() if scaler is not None else None,
        'metrics': metrics,
        'timestamp': datetime.now().isoformat()
    }
    torch.save(checkpoint, filename)

def plot_training_curves(history, save_path):
    """Plot the training history on a 2x3 grid and save it to ``save_path``.

    Panels: loss (train and validation), IoU, DSC/F1, precision, accuracy and,
    if ``history`` has a ``'learning_rate'`` entry, the learning rate;
    otherwise the last panel is hidden.
    """
    epochs = range(1, len(history['train_loss']) + 1)

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    def _finish(ax, title, ylabel):
        # Shared decoration applied to every populated panel.
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # Loss panel holds two series
    loss_ax = axes[0, 0]
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    _finish(loss_ax, 'Loss Curves', 'Loss')

    # Single-series panels: (axis, history key, format, label, title, y-label)
    single_series = (
        (axes[0, 1], 'iou', 'g-', 'IoU', 'IoU Curve', 'IoU'),
        (axes[0, 2], 'dsc', 'purple', 'DSC/F1', 'DSC/F1 Curve', 'DSC'),
        (axes[1, 0], 'precision', 'orange', 'Precision', 'Precision Curve', 'Precision'),
        (axes[1, 1], 'accuracy', 'brown', 'Accuracy', 'Accuracy Curve', 'Accuracy'),
    )
    for ax, key, fmt, label, title, ylabel in single_series:
        ax.plot(epochs, history[key], fmt, label=label)
        _finish(ax, title, ylabel)

    # Learning-rate panel is optional
    lr_ax = axes[1, 2]
    if 'learning_rate' not in history:
        lr_ax.axis('off')
    else:
        lr_ax.plot(epochs, history['learning_rate'], 'cyan', label='Learning Rate')
        _finish(lr_ax, 'Learning Rate', 'LR')

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def train_model(model, train_loader, val_loader, optimizer, criterion, scheduler, 
                num_epochs, device, save_dir="checkpoints"):
    """Main training loop with comprehensive logging.

    Args:
        model: Model to train; it is converted to bfloat16 here.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``validate_one_epoch``.
        optimizer: Optimizer used for training and stored in checkpoints.
        criterion: Loss function.
        scheduler: Scheduler stepped once per epoch with the mean Dice score.
        num_epochs: Number of epochs to run.
        device: Device handed to the train/validate helpers.
        save_dir: Directory for checkpoints, ``training_history.json`` and
            ``training_curves.png``.

    Returns:
        ``(model, history, best_metrics)`` where ``history`` holds per-epoch
        lists and ``best_metrics`` has ``best_dsc``, ``best_iou`` and
        ``best_epoch`` (1-based) of the epoch with the highest mean Dice.
    """
    
    # Setup logging
    logger = setup_logging()
    
    # Create save directory
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True)
    
    # bfloat16 training needs no loss scaling, but save_checkpoint expects a
    # scaler object. A disabled GradScaler satisfies that interface (its
    # state_dict is empty) and fixes the NameError raised by the previously
    # undefined `scaler` on the first checkpoint save.
    scaler = GradScaler(torch.device(device).type, enabled=False)

    # Convert model to bfloat16 for memory efficiency
    model = model.to(torch.bfloat16)
    logger.info("Model converted to bfloat16 for memory efficiency")
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'iou': [],
        'dsc': [],
        'precision': [],
        'accuracy': [],
        'learning_rate': [],
        'epoch_times': []
    }
    
    # Best metrics tracking
    best_metrics = {
        'best_dsc': 0.0,
        'best_iou': 0.0,
        'best_epoch': 0
    }
    
    logger.info(f"Starting training for {num_epochs} epochs")
    logger.info(f"Device: {device}")
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    logger.info(f"Training batches: {len(train_loader)}")
    logger.info(f"Validation batches: {len(val_loader)}")
    
    total_start_time = time.time()
    
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        
        logger.info(f"{'='*50}")
        logger.info(f"Starting Epoch {epoch+1}/{num_epochs}")
        logger.info(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
        
        # Training phase
        train_loss, batch_losses = train_one_epoch(
            model, train_loader, optimizer, criterion, device, epoch, logger
        )
        
        # Validation phase
        val_loss, mean_iou, mean_dsc, mean_precision, mean_acc = validate_one_epoch(
            model, val_loader, criterion, device, epoch, logger
        )
        
        # Update learning rate scheduler (monitors the validation Dice score)
        scheduler.step(mean_dsc)
        
        # Record metrics; the learning rate is read after the scheduler step
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['iou'].append(mean_iou)
        history['dsc'].append(mean_dsc)
        history['precision'].append(mean_precision)
        history['accuracy'].append(mean_acc)
        history['learning_rate'].append(optimizer.param_groups[0]['lr'])
        
        epoch_time = time.time() - epoch_start_time
        history['epoch_times'].append(epoch_time)
        
        # Log epoch summary
        logger.info(f"Epoch {epoch+1} Summary:")
        logger.info(f"  Train Loss: {train_loss:.4f}")
        logger.info(f"  Val Loss: {val_loss:.4f}")
        logger.info(f"  IoU: {mean_iou:.4f}")
        logger.info(f"  DSC: {mean_dsc:.4f}")
        logger.info(f"  Precision: {mean_precision:.4f}")
        logger.info(f"  Accuracy: {mean_acc:.4f}")
        logger.info(f"  Epoch Time: {epoch_time:.2f}s")
        
        # Save best model
        if mean_dsc > best_metrics['best_dsc']:
            best_metrics['best_dsc'] = mean_dsc
            best_metrics['best_iou'] = mean_iou
            best_metrics['best_epoch'] = epoch + 1
            
            # Save best model checkpoint
            save_checkpoint(
                model, optimizer, scaler, epoch,
                {'dsc': mean_dsc, 'iou': mean_iou},
                save_dir / "best_model.pth"
            )
            
            logger.info(f"🎉 New best model saved! DSC: {mean_dsc:.4f}, IoU: {mean_iou:.4f}")
        
        # Save regular checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            save_checkpoint(
                model, optimizer, scaler, epoch,
                {'dsc': mean_dsc, 'iou': mean_iou},
                save_dir / f"checkpoint_epoch_{epoch+1}.pth"
            )
            logger.info(f"Checkpoint saved at epoch {epoch+1}")
        
        # Save training history (rewritten every epoch so a crash loses little)
        with open(save_dir / "training_history.json", 'w') as f:
            json.dump(history, f, indent=2)
    
    total_time = time.time() - total_start_time
    
    # Final summary
    logger.info(f"{'='*50}")
    logger.info("Training Completed!")
    logger.info(f"Total Training Time: {total_time:.2f}s ({total_time/3600:.2f}h)")
    logger.info(f"Average Epoch Time: {np.mean(history['epoch_times']):.2f}s")
    logger.info(f"Best DSC: {best_metrics['best_dsc']:.4f} at epoch {best_metrics['best_epoch']}")
    logger.info(f"Best IoU: {best_metrics['best_iou']:.4f}")
    
    # Plot and save training curves
    plot_training_curves(history, save_dir / "training_curves.png")
    
    return model, history, best_metrics

def run_training(root_dir='/kaggle/input/data-1/content'):
    """Build the data pipeline, model and optimizer, then run `train_model`.

    Args:
        root_dir: Dataset directory containing an ``images/`` folder of
            ``*.bmp`` files and a ``masks/`` folder of ``*.png`` files.
            Defaults to the Kaggle input path this notebook was written for.

    Returns:
        The ``(model, history, best_metrics)`` tuple returned by
        ``train_model``, so later cells can reuse the trained model.

    Raises:
        ValueError: If the number of images and masks differs, or if there
            are too few samples to produce a non-empty training split.
    """
    print(f"Using device: {DEVICE}")

    # Load data. Images and masks are paired by position after sorting on the
    # file stem, so both lists must line up one-to-one.
    root_dir = Path(root_dir)
    all_images = sorted((root_dir / 'images').glob('*.bmp'), key=lambda x: x.stem)
    all_masks = sorted((root_dir / 'masks').glob('*.png'), key=lambda x: x.stem)

    print(f"Total images: {len(all_images)}")
    print(f"Total masks: {len(all_masks)}")

    # A count mismatch would silently shift every subsequent image/mask pair.
    if len(all_images) != len(all_masks):
        raise ValueError(
            f"Found {len(all_images)} images but {len(all_masks)} masks under "
            f"{root_dir}; image/mask pairs cannot be aligned."
        )

    # Differing stems are not necessarily wrong (masks may use a suffix), so
    # this is only reported, not treated as an error.
    mismatched = sum(
        image_path.stem != mask_path.stem
        for image_path, mask_path in zip(all_images, all_masks)
    )
    if mismatched:
        print(
            f"Warning: {mismatched} image/mask pairs have different file stems; "
            "pairing relies on sorted order."
        )

    # Train/validation split (80/20), taken in sorted order (not shuffled).
    split_idx = int(0.8 * len(all_images))
    if split_idx == 0:
        # Covers both an empty dataset and a single sample; otherwise the
        # training DataLoader fails later with an opaque sampler error.
        raise ValueError(
            f"Need at least 2 image/mask pairs under {root_dir} for an 80/20 "
            f"split, found {len(all_images)}."
        )
    train_images = all_images[:split_idx]
    train_masks = all_masks[:split_idx]
    val_images = all_images[split_idx:]
    val_masks = all_masks[split_idx:]

    print(f"Training samples: {len(train_images)}")
    print(f"Validation samples: {len(val_images)}")

    # Create datasets
    train_dataset = SegmentationDataset(train_images, train_masks, get_train_transform())
    val_dataset = SegmentationDataset(val_images, val_masks, get_valid_transform())

    # Create dataloaders (only the training set is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Initialize model
    # NOTE(review): the weights are cast to bfloat16 here and the recorded run
    # shows loss=nan; if `train_model` already uses autocast, keeping float32
    # master weights may be the fix -- confirm against `train_model`.
    model = ElectricForceModel().to(DEVICE).to(torch.bfloat16)

    # Setup optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    # mode='max': the scheduler expects a metric where higher is better.
    # The deprecated `verbose` argument is intentionally not passed.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', patience=5, factor=0.5
    )

    # Train model
    model, history, best_metrics = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        num_epochs=NUM_EPOCHS,
        device=DEVICE,
        save_dir="training_results"
    )

    # Final evaluation on validation set
    print("\nFinal Evaluation:")
    print(f"Best DSC: {best_metrics['best_dsc']:.4f}")
    print(f"Best IoU: {best_metrics['best_iou']:.4f}")
    print(f"Best Epoch: {best_metrics['best_epoch']}")

    return model, history, best_metrics

# Run the training
# Entry guard: `run_training()` is called only when this code executes as the
# main module (true when the notebook cell is run), not when it is imported.
if __name__ == "__main__":
    run_training()

Using device: cuda
Total images: 300
Total masks: 300
Training samples: 240
Validation samples: 60


Epoch 1 - Training:  19%|█▉        | 9/48 [25:13<1:49:16, 168.12s/it, loss=nan, avg_loss=nan, step_time=157.16s]      


KeyboardInterrupt: 