In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from einops import rearrange
from tqdm import tqdm
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

# Pick the compute device once: CUDA when torch reports it as available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging switch that makes CUDA
# kernel launches synchronous, which slows training. It is also set here only after
# the availability probe above - confirm it is still wanted (and still takes effect)
# for real training runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

class PatchEmbedding(nn.Module):
    """Split an image batch into non-overlapping patches and embed each one.

    A single convolution whose kernel size and stride both equal ``patch_size``
    projects every patch to an ``embed_dim``-dimensional vector. The resulting
    feature map is then flattened row by row into a token sequence.

    ``img_size`` is only recorded on the module; the forward pass never reads it,
    and no positional information is added to the tokens here.
    """

    def __init__(self, img_size=256, patch_size=16, in_channels=3, embed_dim=512):
        super().__init__()
        self.img_size, self.patch_size = img_size, patch_size
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens shaped ``(b, h*w, c)`` for an input batch ``x``."""
        patch_grid = self.proj(x)
        # Collapse the spatial grid into one token axis and move channels last.
        return rearrange(patch_grid, 'b c h w -> b (h w) c')

class TransformerEncoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` blocks that also returns every layer's output.

    Args:
        embed_dim: Token width (``d_model``); the feed-forward width is four times this.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, embed_dim=512, num_heads=8, num_layers=3):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=num_heads,
                    dim_feedforward=4 * embed_dim,
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Encode batch-first tokens.

        Returns:
            A pair ``(encoded, skip_connections)``. ``encoded`` is the layer-normalised
            output of the last layer and ``skip_connections`` holds the (un-normalised)
            output of every layer in order; all are batch-first like the input.
        """
        # The layers are built without batch_first, so feed them sequence-first tokens.
        tokens = x.transpose(0, 1)
        skip_connections = []
        for block in self.layers:
            tokens = block(tokens)
            # Store each intermediate result back in batch-first layout.
            skip_connections.append(tokens.transpose(0, 1))
        normed = self.norm(tokens)
        return normed.transpose(0, 1), skip_connections

class SpatialChannelAttention(nn.Module):
    """Gate a feature map first per channel, then per spatial location.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction_ratio: Divisor for the bottleneck width of the channel gate.
    """

    def __init__(self, channels, reduction_ratio=8):
        super().__init__()
        bottleneck = channels // reduction_ratio

        # Channel gate: global average pool -> bottleneck 1x1 convs -> sigmoid weights.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, bottleneck, 1),
            nn.ReLU(),
            nn.Conv2d(bottleneck, channels, 1),
            nn.Sigmoid(),
        )

        # Spatial gate: 7x7 conv over the [mean, max] channel statistics -> sigmoid map.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by the channel gate and then by the spatial gate."""
        channel_gated = x * self.channel_attention(x)

        # Per-pixel statistics across channels feed the spatial gate.
        mean_map = channel_gated.mean(dim=1, keepdim=True)
        peak_map = channel_gated.max(dim=1, keepdim=True).values
        spatial_gate = self.spatial_attention(torch.cat((mean_map, peak_map), dim=1))

        return channel_gated * spatial_gate

class SSMDecoderWithAttention(nn.Module):
    """Convolutional decoder that upsamples encoder tokens back to an image.

    Three conv/BatchNorm/GELU/upsample stages (x2, x2, x4) take the token grid
    to 16x its side length. After each stage an attention-gated projection of
    one encoder skip connection is added. A final 1x1 conv plus sigmoid
    produces ``out_channels`` maps in ``[0, 1]``.

    Only convolutional and attention modules are used here; despite the class
    name, no state-space layer is part of this block.

    Args:
        embed_dim: Channel width of the incoming tokens.
        out_channels: Number of output channels.
    """

    def __init__(self, embed_dim=512, out_channels=3):
        super().__init__()
        self.embed_dim = embed_dim

        # Skip branches: 1x1 conv down to the stage width, then attention gating.
        self.skip_adjust1 = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim//2, kernel_size=1),
            SpatialChannelAttention(embed_dim//2)
        )
        self.skip_adjust2 = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim//4, kernel_size=1),
            SpatialChannelAttention(embed_dim//4)
        )
        self.skip_adjust3 = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim//8, kernel_size=1),
            SpatialChannelAttention(embed_dim//8)
        )

        # Main path: halve the channels at every stage while upsampling.
        # NOTE(review): these stages upsample with align_corners=True while the skip
        # branches use align_corners=False; left unchanged so trained weights behave
        # the same - confirm the mismatch is intended.
        self.up1 = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim//2, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim//2),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.up2 = nn.Sequential(
            nn.Conv2d(embed_dim//2, embed_dim//4, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim//4),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
        self.up3 = nn.Sequential(
            nn.Conv2d(embed_dim//4, embed_dim//8, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim//8),
            nn.GELU(),
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        )

        self.final_conv = nn.Conv2d(embed_dim//8, out_channels, kernel_size=1)

    def _skip_branch(self, adjust, tokens, h, w, scale):
        """Project one skip to its stage width, upsample it by ``scale`` and gate it.

        The 1x1 projection runs BEFORE the bilinear upsampling. A 1x1 conv works
        per pixel and bilinear weights sum to one, so the result equals
        upsample-then-project up to floating-point rounding. Doing it in this
        order avoids materialising the full ``embed_dim``-channel map at the
        enlarged resolution.
        """
        B, _, C = tokens.shape
        feat = tokens.permute(0, 2, 1).reshape(B, C, h, w)
        project, attend = adjust[0], adjust[1]
        feat = project(feat)
        feat = F.interpolate(feat, scale_factor=scale, mode='bilinear', align_corners=False)
        return attend(feat)

    def forward(self, x, skip_connections):
        """Decode tokens into an image-shaped tensor.

        Args:
            x: Encoder output of shape ``(B, N, C)`` where ``N`` is a perfect square.
            skip_connections: Sequence of at least three ``(B, N, C)`` tensors; the
                last three are used, most recent first.

        Returns:
            Tensor of shape ``(B, out_channels, 16*h, 16*w)`` with values in ``[0, 1]``,
            where ``h = w = sqrt(N)``.

        Raises:
            ValueError: If ``N`` is not a perfect square or fewer than three skip
                connections are supplied.
        """
        B, N, C = x.shape
        h = w = int(round(N ** 0.5))
        if h * w != N:
            raise ValueError(
                f"Token count {N} is not a perfect square; the decoder needs a square patch grid"
            )
        if len(skip_connections) < 3:
            raise ValueError(
                f"Expected at least 3 skip connections, got {len(skip_connections)}"
            )
        x = x.permute(0, 2, 1).reshape(B, C, h, w)

        # Scales match the resolution reached after each decoder stage (x2, x4, x16).
        skip1 = self._skip_branch(self.skip_adjust1, skip_connections[-1], h, w, 2)
        skip2 = self._skip_branch(self.skip_adjust2, skip_connections[-2], h, w, 4)
        skip3 = self._skip_branch(self.skip_adjust3, skip_connections[-3], h, w, 16)

        # Decoder stages with skip connections
        x = self.up1(x) + skip1
        x = self.up2(x) + skip2
        x = self.up3(x) + skip3

        return torch.sigmoid(self.final_conv(x))

class HybridMamba(nn.Module):
    """End-to-end model: patch embedding -> transformer encoder -> attention decoder.

    The encoder's per-layer outputs are handed to the decoder as skip connections.
    ``num_heads`` is fixed at 8 here; the remaining arguments are forwarded to the
    sub-modules unchanged.
    """

    def __init__(self, img_size=256, patch_size=16, in_channels=3, embed_dim=512, num_layers=3):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )
        self.transformer = TransformerEncoder(
            embed_dim=embed_dim,
            num_heads=8,
            num_layers=num_layers,
        )
        self.decoder = SSMDecoderWithAttention(embed_dim=embed_dim)

    def forward(self, x):
        """Run the full pipeline on an image batch and return the decoder output."""
        tokens = self.patch_embed(x)                # (B, num_patches, embed_dim)
        encoded, skips = self.transformer(tokens)   # same token shape, plus per-layer skips
        return self.decoder(encoded, skips)         # image-shaped output

In [2]:
# Smoke test: push one random 3-channel 512x512 batch through a freshly built model
# and report the output shape.
model = HybridMamba().to(device)
dummy = torch.randn(1, 3, 512, 512).to(device)
# Only the shape is inspected, so skip autograd bookkeeping for this forward pass.
with torch.no_grad():
    out = model(dummy)
print("Output shape:", out.shape)

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Output shape: torch.Size([1, 3, 512, 512])


In [3]:
# ==================== TRAINING SCRIPT ====================
class SegmentationDataset(Dataset):
    """Paired grayscale image / binary mask dataset read from two directories.

    Files ending in ``.png`` or ``.tif`` are collected from each directory. Every
    image is matched to a mask through a shared key: the text before the first
    ``.`` of the file name, with ``_mask`` removed for mask files.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable applied to the image and, separately, to the
            mask. When omitted, both are converted to ``(1, H, W)`` float tensors
            scaled to ``[0, 1]``.

    Raises:
        AssertionError: If the directories hold different numbers of files or an
            image has no mask with the same key.
    """

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        extensions = ('.png', '.tif')
        self.image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(extensions))
        mask_names = sorted(f for f in os.listdir(mask_dir) if f.endswith(extensions))

        # Raised explicitly (rather than via `assert`) so the check survives `python -O`
        # while callers still see the same exception type.
        if len(self.image_files) != len(mask_names):
            raise AssertionError("Mismatched number of images and masks")

        # Pair by key, not by position: the two sorted listings do not always line up
        # (e.g. 'a.png' < 'a_1.png' but 'a_1_mask.png' < 'a_mask.png').
        masks_by_key = {}
        for name in mask_names:
            key = name.split('.')[0].replace('_mask', '')
            masks_by_key.setdefault(key, []).append(name)

        self.mask_files = []
        for img in self.image_files:
            candidates = masks_by_key.get(img.split('.')[0])
            if not candidates:
                raise AssertionError(f"Mismatched pairs: {img} has no matching mask")
            # Masks sharing a key are consumed in sorted order.
            self.mask_files.append(candidates.pop(0))

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_files)

    @staticmethod
    def _to_unit_tensor(pil_image):
        """Convert a single-channel PIL image to a ``(1, H, W)`` float tensor in ``[0, 1]``."""
        pixels = np.asarray(pil_image, dtype=np.float32) / 255.0
        return torch.from_numpy(pixels).unsqueeze(0)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(image, mask)`` with the mask binarised at 0.5."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        # `convert` returns an independent in-memory image, so the file handles can be
        # closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert('L')  # Grayscale
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert('L')

        if self.transform:
            # NOTE: the transform is called once per input, so a random transform
            # would augment image and mask differently.
            image = self.transform(image)
            mask = self.transform(mask)
        else:
            # Without a transform the PIL images must still become tensors, otherwise
            # the threshold below fails on a PIL image.
            image = self._to_unit_tensor(image)
            mask = self._to_unit_tensor(mask)

        mask = (mask > 0.5).float()  # Binarize mask

        return image, mask

class HybridReconstructionLoss(nn.Module):
    """Weighted sum of MSE, SSIM and Dice losses between a prediction and a target.

    Args:
        weights: Mapping with the keys ``'mse'``, ``'ssim'`` and ``'dice'`` giving the
            weight of each term. ``None`` (the default) selects
            ``{'mse': 0.5, 'ssim': 0.3, 'dice': 0.2}``; a fresh dict is built per
            instance so editing one loss object's weights never affects another.
        eps: Smoothing constant added to the Dice numerator and denominator.
    """

    def __init__(self, weights=None, eps=1e-6):
        super().__init__()
        if weights is None:
            # Built here, not in the signature, so the default is not shared.
            weights = {'mse': 0.5, 'ssim': 0.3, 'dice': 0.2}
        self.weights = weights
        self.eps = eps
        self.mse_loss = nn.MSELoss()

    def forward(self, pred, target):
        """Return the scalar weighted loss for ``pred`` against ``target``."""
        # Clamp predictions to valid range
        pred = torch.clamp(pred, 0, 1)

        # 1. MSE Loss (Pixel-level accuracy)
        mse = self.mse_loss(pred, target)

        # 2. SSIM Loss (Structural similarity)
        ssim_loss = 1 - ssim(pred, target, data_range=1.0, size_average=True)

        # 3. Dice Loss (Region overlap), computed over the whole batch at once
        intersection = (pred * target).sum()
        dice_loss = 1 - (2. * intersection + self.eps) / (pred.sum() + target.sum() + self.eps)

        # Weighted combination
        total_loss = (self.weights['mse'] * mse +
                      self.weights['ssim'] * ssim_loss +
                      self.weights['dice'] * dice_loss)

        return total_loss
        
def _loss_terms(criterion, outputs, masks):
    """Return ``(total, mse, ssim_loss, dice_loss)`` tensors for one batch.

    Uses ``criterion``'s MSE module, weights and epsilon so the individual terms
    can be logged alongside the weighted total.
    """
    pred = torch.clamp(outputs, 0, 1)

    mse = criterion.mse_loss(pred, masks)
    ssim_loss = 1 - ssim(pred, masks, data_range=1.0, size_average=True)
    intersection = (pred * masks).sum()
    dice_loss = 1 - (2. * intersection + criterion.eps) / (pred.sum() + masks.sum() + criterion.eps)
    total_loss = (criterion.weights['mse'] * mse +
                  criterion.weights['ssim'] * ssim_loss +
                  criterion.weights['dice'] * dice_loss)
    return total_loss, mse, ssim_loss, dice_loss


def _atomic_save(obj, path):
    """Save ``obj`` with ``torch.save`` to a temp file, then rename it over ``path``.

    An interrupt while writing then leaves the previous file at ``path`` intact
    instead of a truncated one.
    """
    tmp_path = f"{path}.tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def train():
    """Train ``HybridMamba`` on the image/mask folders, resuming from a checkpoint if present.

    Each epoch logs the mean total/MSE/SSIM/Dice training losses, saves the model
    weights whenever the mean total loss improves, and writes a full resume
    checkpoint (model, optimizer, loss history, best loss, epoch index).
    """
    # Config
    config = {
        'image_dir': 'augmented_images',
        'mask_dir': 'augmented_masks',
        'batch_size': 10,
        'num_epochs': 200,
        'lr': 1e-4,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'save_path': 'advanced_mamba_L.pth',
        'state_file': 'training_state.pth'
    }

    # Initialize model, optimizer, loss
    model = HybridMamba().to(config['device'])
    optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    criterion = HybridReconstructionLoss()

    # Initialize tracking variables
    start_epoch = 0
    best_loss = float('inf')
    loss_history = {
        'total': [],
        'mse': [],
        'ssim': [],
        'dice': []
    }

    if os.path.exists(config['state_file']):
        print(f"Loading checkpoint from {config['state_file']}...")
        checkpoint = torch.load(config['state_file'], map_location=config['device'])
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        loss_history = checkpoint['loss_history']
        best_loss = checkpoint['best_loss']
        start_epoch = checkpoint['epoch'] + 1  # Start from next epoch
        print(f"Resumed from epoch {start_epoch}")

    # Data Transforms and DataLoader
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    dataset = SegmentationDataset(config['image_dir'], config['mask_dir'], transform)
    loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0)

    # Training loop
    for epoch in range(start_epoch, config['num_epochs']):
        model.train()
        epoch_losses = {'total': 0, 'mse': 0, 'ssim': 0, 'dice': 0}

        pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{config["num_epochs"]}')
        for images, masks in pbar:
            images, masks = images.to(config['device']), masks.to(config['device'])
            optimizer.zero_grad()

            outputs = model(images)
            total_loss, mse, ssim_loss, dice_loss = _loss_terms(criterion, outputs, masks)

            total_loss.backward()
            optimizer.step()

            # Read each scalar off the device once and reuse it for both the running
            # sums and the progress bar.
            step_values = {
                'total': total_loss.item(),
                'mse': mse.item(),
                'ssim': ssim_loss.item(),
                'dice': dice_loss.item()
            }
            for key, value in step_values.items():
                epoch_losses[key] += value

            pbar.set_postfix({
                'loss': step_values['total'],
                'mse': step_values['mse'],
                'ssim': step_values['ssim'],
                'dice': step_values['dice']
            })

        # Average losses
        for key in epoch_losses:
            epoch_losses[key] /= len(loader)
            loss_history[key].append(epoch_losses[key])

        # Print summary
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"Total Loss: {epoch_losses['total']:.4f}")
        print(f"MSE Loss: {epoch_losses['mse']:.4f}")
        print(f"SSIM Loss: {epoch_losses['ssim']:.4f}")
        print(f"Dice Loss: {epoch_losses['dice']:.4f}")

        # Save best model (judged on the mean training loss of the epoch)
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            _atomic_save(model.state_dict(), config['save_path'])
            print(f"Saved new best model with loss: {best_loss:.4f}")

        # Save checkpoint; written atomically so an interrupted save cannot corrupt
        # the only resume file.
        _atomic_save({
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'loss_history': loss_history,
            'best_loss': best_loss
        }, config['state_file'])

# Entry point: run training only when this code executes as the main module.
if __name__ == '__main__':
    train()


Loading checkpoint from training_state.pth...
Resumed from epoch 102


Epoch 103/200: 100%|██████████| 607/607 [19:34<00:00,  1.94s/it, loss=0.00986, mse=0.00174, ssim=0.00845, dice=0.0323] 



Epoch 103 Summary:
Total Loss: 0.0170
MSE Loss: 0.0028
SSIM Loss: 0.0088
Dice Loss: 0.0648


Epoch 104/200: 100%|██████████| 607/607 [19:30<00:00,  1.93s/it, loss=0.0976, mse=0.049, ssim=0.0772, dice=0.249]       



Epoch 104 Summary:
Total Loss: 0.0227
MSE Loss: 0.0041
SSIM Loss: 0.0111
Dice Loss: 0.0866


Epoch 105/200:  44%|████▍     | 268/607 [08:38<10:56,  1.94s/it, loss=0.00785, mse=0.00132, ssim=0.00714, dice=0.0252] 


KeyboardInterrupt: 

In [1]:
import torch

# Load the checkpoint onto the GPU when one is available, otherwise onto the CPU,
# so this inspection cell also runs on machines without CUDA.
checkpoint = torch.load(
    'training_state.pth',
    map_location='cuda' if torch.cuda.is_available() else 'cpu',
)

# Extract all saved items
epoch = checkpoint['epoch']  # index of the last completed epoch, as stored by the training loop
model_state = checkpoint['model_state']
optimizer_state = checkpoint['optimizer_state']
loss_history = checkpoint['loss_history']
best_loss = checkpoint['best_loss']

# Print or inspect
print(f"Epoch: {epoch}")
print(f"Best Loss: {best_loss}")
print(f"Loss History: {loss_history.keys()}")


Epoch: 103
Best Loss: 0.01350906465158075
Loss History: dict_keys(['total', 'mse', 'ssim', 'dice'])


In [2]:
# Load the saved training state (GPU when available, otherwise CPU, so the cell
# also works on machines without CUDA)
checkpoint = torch.load(
    'training_state.pth',
    map_location='cuda' if torch.cuda.is_available() else 'cpu',
)

# Extract loss history; an empty dict is used when the checkpoint has none
loss_history = checkpoint.get('loss_history', {})


In [3]:
# Print available loss types
print("Loss types:", list(loss_history.keys()))

# Print the value of every loss type for each epoch (epochs shown 1-based).
# The loop variable is deliberately not named `epoch`, so it cannot overwrite a
# notebook-level `epoch` value as a side effect of printing.
for loss_type, losses in loss_history.items():
    print(f"\n{loss_type.upper()} Losses:")
    for epoch_number, loss_value in enumerate(losses, start=1):
        print(f"Epoch {epoch_number}: {loss_value:.4f}")


Loss types: ['total', 'mse', 'ssim', 'dice']

TOTAL Losses:
Epoch 1: 0.2078
Epoch 2: 0.1621
Epoch 3: 0.1506
Epoch 4: 0.1427
Epoch 5: 0.1368
Epoch 6: 0.1319
Epoch 7: 0.1250
Epoch 8: 0.1186
Epoch 9: 0.1144
Epoch 10: 0.1088
Epoch 11: 0.1081
Epoch 12: 0.1010
Epoch 13: 0.0964
Epoch 14: 0.0930
Epoch 15: 0.0908
Epoch 16: 0.0866
Epoch 17: 0.0830
Epoch 18: 0.0821
Epoch 19: 0.0786
Epoch 20: 0.0753
Epoch 21: 0.0707
Epoch 22: 0.0681
Epoch 23: 0.0689
Epoch 24: 0.0645
Epoch 25: 0.0638
Epoch 26: 0.0604
Epoch 27: 0.0560
Epoch 28: 0.0563
Epoch 29: 0.0529
Epoch 30: 0.0563
Epoch 31: 0.0517
Epoch 32: 0.0563
Epoch 33: 0.0506
Epoch 34: 0.0491
Epoch 35: 0.0471
Epoch 36: 0.0464
Epoch 37: 0.0468
Epoch 38: 0.0407
Epoch 39: 0.0425
Epoch 40: 0.0414
Epoch 41: 0.0410
Epoch 42: 0.0396
Epoch 43: 0.0373
Epoch 44: 0.0347
Epoch 45: 0.0341
Epoch 46: 0.0377
Epoch 47: 0.0361
Epoch 48: 0.0328
Epoch 49: 0.0351
Epoch 50: 0.0316
Epoch 51: 0.0332
Epoch 52: 0.0339
Epoch 53: 0.0340
Epoch 54: 0.0286
Epoch 55: 0.0309
Epoch 56: 0.02