In [1]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HOME"] = "./cache"

import cv2
import torch
from transformers import SwinModel
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from transformers.models.swin.modeling_swin import (
    SwinLayer,
    SwinSelfAttention,
    SwinSelfOutput,
    SwinIntermediate,
    SwinOutput
)
from transformers import SwinConfig
import numpy as np

from tqdm import tqdm
import matplotlib.pyplot as plt



2025-05-19 19:44:12.155484: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-19 19:44:12.157647: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-05-19 19:44:12.197600: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
import inspect

print(inspect.signature(SwinSelfAttention.__init__))
print(inspect.signature(SwinConfig.__init__))

In [5]:
import pathlib

In [6]:
TRAIN_PATH = "ISIC2018_Task1-2_Training_Input"
MASK_PATH = "ISIC2018_Task1_Training_GroundTruth"

## Model Constructing

### Test now(2025.5.19) still in test but can run now

In [77]:
def window_partition(x, window_size, H, W):
    B, L, C = x.shape
    print(f'B: {B}, L: {L}, C: {C}')
    x = x.view(B, H, W, C)
    x = x.unfold(1, window_size, window_size).unfold(2, window_size, window_size)
    print(f"x_after unfold: {x.shape}")
    x = x.contiguous().view(-1, window_size*window_size, C)
    print(f"x_after view: {x.shape}")
    return x

In [78]:
def window_reverse(windows, window_size, H, W, B):
    num_windows = (H // window_size) * (W // window_size)
    x = windows.view(B, num_windows, window_size * window_size, -1)
    print(f"x_after view: {x.shape}")
    x = x.view(B, H//window_size, W//window_size, window_size, window_size, -1)
    print(f"x_after view: {x.shape}")
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    print(f"x_after permute: {x.shape}")
    return x.view(B, H*W, -1)

In [79]:
from transformers.models.swin.modeling_swin import SwinSelfAttention, SwinSelfOutput, SwinIntermediate, SwinOutput
import torch

class SimpleSwinBlock(torch.nn.Module):
    def __init__(self, config, dim, num_heads):
        super().__init__()
        self.config = config
        self.dim = dim
        self.num_heads = num_heads
        # self.attn = SwinSelfAttention(config, dim, num_heads, window_size)
        self.self_output = SwinSelfOutput(config, dim)
        self.intermediate = SwinIntermediate(config, dim)
        self.output = SwinOutput(config, dim)
        self.norm1 = torch.nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.norm2 = torch.nn.LayerNorm(dim, eps=config.layer_norm_eps)

    def forward(self, x, H, W):
        max_window = 7
        window_size = min(max_window, H, W)
        for ws in range(window_size, 0, -1):
            if H % ws == 0 and W % ws == 0:
                window_size = ws
                break
        print(f'H: {H}, W: {W}, window_size: {window_size}')
        attn = SwinSelfAttention(self.config, self.dim, self.num_heads, window_size)
        shortcut = x
        B = shortcut.shape[0]
        print(f'shortcut.shape: {shortcut.shape}')
        # (batch_size, seq_len, dim)
        x = self.norm1(x)
        print(f"self.norm1(x).shape: {x.shape}")
        # (batch_size, seq_len, dim)
        x_windows = window_partition(x, window_size, H, W)
        print(f"x_windows.shape: {x_windows.shape}")
        x_windows = attn(x_windows)
        if isinstance(x_windows, tuple):
            x_windows = x_windows[0]
        x = window_reverse(x_windows, window_size, H, W, B)
        print(f"x.shape: {x.shape}")
        # (batch_size, seq_len, dim)
        x = self.self_output(x, shortcut)
        print(f"self_output.shape: {x.shape}")
        # (batch_size, seq_len, dim)
        shortcut2 = x
        x = self.norm2(x)
        x = self.intermediate(x)
        print(f"intermediate.shape: {x.shape}")
        #(batch_size, seq_len, 4*dim)
        x = self.output(x)
        print(f"output.shape: {x.shape}")
        x = x + shortcut2
        print(f"x.shape: {x.shape}")
        #(batch_size, seq_len, dim)
        return x

In [80]:
class PatchExpand(torch.nn.Module):
    def __init__(self, input_dim, output_dim, scale=2):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, output_dim * scale * scale)
        self.scale = scale
        self.output_dim = output_dim

    def forward(self, x):
        B, L, C = x.shape
        H = W = int(L ** 0.5)
        x = self.proj(x)  # (B, L, output_dim*scale*scale)
        x = x.view(B, H, W, self.scale, self.scale, self.output_dim)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        # (B, H, scale, W, scale, output_dim)
        x = x.view(B, H * self.scale, W * self.scale, self.output_dim)
        return x.view(B, -1, self.output_dim)
        # (B, L*scale*scale, output_dim)

In [81]:
class SwinTransformerDecoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # 对应编码器输出的通道和分辨率
        self.dims = [768, 384, 192, 96]
        self.heads = [32, 16, 8, 4]
        self.config = config

        self.up_blocks = torch.nn.ModuleList()
        for i in range(3):
            self.up_blocks.append(PatchExpand(self.dims[i], self.dims[i+1], scale=2))
            self.up_blocks.append(torch.nn.Linear(self.dims[i+1]*2, self.dims[i+1]))  # 融合skip
            self.up_blocks.append(SimpleSwinBlock(config, self.dims[i+1], self.heads[i+1]))

        self.final_proj = torch.nn.Linear(self.dims[-1], 1)

    def forward(self, features):
        x = features[0]  # 7x7x768
        skips = features[1:]  # [14x14x384, 28x28x192, 56x56x96]
        for i in range(3):
            x = self.up_blocks[i*3](x)  # PatchExpand
            skip = skips[i]
            # 上采样skip到x的空间分辨率
            if x.shape[1] != skip.shape[1]:
                B, L_skip, C_skip = skip.shape
                H_skip = W_skip = int(L_skip ** 0.5)
                H_x = W_x = int(x.shape[1] ** 0.5)
                skip_ = skip.view(B, H_skip, W_skip, C_skip).permute(0, 3, 1, 2)
                skip_ = torch.nn.functional.interpolate(skip_, size=(H_x, W_x), mode='bilinear', align_corners=False)
                skip = skip_.permute(0, 2, 3, 1).reshape(B, H_x * W_x, C_skip)
            x = torch.cat([x, skip], dim=-1)
            x = self.up_blocks[i*3+1](x)
            B, L, C = x.shape
            H = W = int(L ** 0.5)
            x = self.up_blocks[i*3+2](x, H, W)
        # 后续同前
        B, L, C = x.shape
        H = W = int(L ** 0.5)
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)
        x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        x = x.permute(0, 2, 3, 1).reshape(B, -1, C)
        x = self.final_proj(x)
        x = x.view(B, 224, 224, 1).permute(0, 3, 1, 2)
        x = torch.sigmoid(x)
        return x

In [82]:
from transformers import SwinModel, SwinConfig

class SwinUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224", output_hidden_states=True)
        self.decoder = SwinTransformerDecoder(self.encoder.config)
        for name, param in self.encoder.named_parameters():
            if "layers.2" not in name and "layers.3" not in name:
                param.requires_grad = False

    def forward(self, x):
        outputs = self.encoder(x)
        hs = outputs.hidden_states
        print(f'hs[0].shape: {hs[0].shape}')
        print(f'hs[1].shape: {hs[1].shape}')
        print(f'hs[2].shape: {hs[2].shape}')
        print(f'hs[3].shape: {hs[3].shape}')
        print(f'hs[4].shape: {hs[4].shape}')
        # [7x7x768, 14x14x384, 28x28x192, 56x56x96]
        features = []
        for i in [3, 2, 1, 0]:
            feat = hs[i]
            B, L, C = feat.shape
            H = W = int(L ** 0.5)
            features.append(feat.view(B, H, W, C).reshape(B, -1, C))
        return self.decoder(features)

In [83]:
model = SwinUNet()
dummy_input = torch.randn(2, 3, 224, 224)
output = model(dummy_input)
print('输出shape:', output.shape)  # 应为 [2, 1, 224, 224]

hs[0].shape: torch.Size([2, 3136, 96])
hs[1].shape: torch.Size([2, 784, 192])
hs[2].shape: torch.Size([2, 196, 384])
hs[3].shape: torch.Size([2, 49, 768])
hs[4].shape: torch.Size([2, 49, 768])
H: 14, W: 14, window_size: 7
shortcut.shape: torch.Size([2, 196, 384])
self.norm1(x).shape: torch.Size([2, 196, 384])
B: 2, L: 196, C: 384
x_after unfold: torch.Size([2, 2, 2, 384, 7, 7])
x_after view: torch.Size([8, 49, 384])
x_windows.shape: torch.Size([8, 49, 384])
x_after view: torch.Size([2, 4, 49, 384])
x_after view: torch.Size([2, 2, 2, 7, 7, 384])
x_after permute: torch.Size([2, 14, 14, 384])
x.shape: torch.Size([2, 196, 384])
self_output.shape: torch.Size([2, 196, 384])
intermediate.shape: torch.Size([2, 196, 1536])
output.shape: torch.Size([2, 196, 384])
x.shape: torch.Size([2, 196, 384])
H: 28, W: 28, window_size: 7
shortcut.shape: torch.Size([2, 784, 192])
self.norm1(x).shape: torch.Size([2, 784, 192])
B: 2, L: 784, C: 192
x_after unfold: torch.Size([2, 4, 4, 192, 7, 7])
x_after view:

### Old attempt(failed)

In [62]:
class PatchExpand(torch.nn.Module):
    def __init__(self, dim, dim_scale=2):
        super().__init__()
        self.dim = dim
        self.expand = torch.nn.Linear(dim, dim*dim_scale**2, bias=False)
        self.dim_scale = dim_scale

    def forward(self, x):
        '''
        x: (B, L, C)
        output: (B, L, C//4)
        '''
        B, L, C = x.shape
        H = W = int(L**0.5)
        x = self.expand(x)  # (B,L,4C)
        x = x.view(B, H, W, C*self.dim_scale**2)
        x = x.reshape(B, H*self.dim_scale, W*self.dim_scale, C)
        return x.reshape(B, (H*self.dim_scale)**2, C)

In [63]:
encoder_config = SwinConfig(
    image_size = 224,
    patch_size = 4,
    depths = [2, 2, 18, 2],
    num_heads = [4, 8, 16, 32],
    embed_dim = 96
)

class DecoderConfig(SwinConfig):
    def __init__(
            self,
            decoder_depths:list,
            hidden_size:int,
            encoder_resolution:list,
            qkv_bias:bool = True,
            attn_drop_rate:float = 0.0,
            proj_drop_rate:float = 0.0,
            **kwargs
    ):
        super().__init__(
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            **kwargs
        )
        self.decoder_depths = decoder_depths
        self.hidden_size = hidden_size
        self.encoder_resolution = encoder_resolution

decoder_config = DecoderConfig(
    decoder_depths=[2, 2, 2, 2],
    hidden_size=96,
    encoder_resolution=[
        (96, (56, 56)),
        (192, (28, 28)),
        (384, (14, 14)),
        (768, (7, 7))
    ],
    qkv_bias=True,
    attn_drop_rate=0.1,
    proj_drop_rate=0.1,
    num_heads=[32, 16, 8, 4],
    window_size=7,
    embed_dim=96
)

In [64]:
class SwinTransformerBlock(torch.nn.Module):
    def __init__(self, config: DecoderConfig, dim, input_resolution, num_heads, shift_size=0):
        super().__init__()
        self.attention = SwinSelfAttention(
            config, dim, num_heads, window_size=7
        )
        self.self_output = SwinSelfOutput(config, dim)
        self.intermediate = SwinIntermediate(config, dim)
        self.output = SwinOutput(config, dim)
        self.layernorm_before = torch.nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.layernorm_after = torch.nn.LayerNorm(dim, eps=config.layer_norm_eps)
    
    def forward(self, hidden_states):
        input = hidden_states
        hidden_states = self.layernorm_before(hidden_states)
        attention_output = self.attention(hidden_states)
        attention_output = self.self_output(attention_output, input)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        layer_output = self.layernorm_after(layer_output + input)
        return layer_output

        

In [65]:
class TransformerUNetDecoder(torch.nn.Module):
    def __init__(self, config: DecoderConfig, encoder_resolutions):
        super().__init__()
        self.stages = torch.nn.ModuleList()
        for i, (dim, (h, w)) in enumerate(reversed(encoder_resolutions)):
            # Upsample
            if i != 0:
                self.stages.append(PatchExpand(dim, dim_scale=2))
            
            stage = torch.nn.Sequential(
                *[SwinTransformerBlock(
                    config,
                    dim=dim,
                    input_resolution=(h*(2**i), w*(2**i)),
                    num_heads=config.num_heads[i]
                ) for _ in range(config.depths[i])]
            )
            self.stages.append(stage)

        self.final = torch.nn.Conv2d(config.hidden_size, 1, kernel_size=1)
    
    def forward(self, x, skips):
        for i, stage in enumerate(self.stages):
            if isinstance(stage, PatchExpand):
                x = stage(x)
                
                x = torch.cat([x, skips[i]], dim=-1)
            else:
                x = stage(x)
        B, L, C = x.shape
        H = W = int(L**0.5)
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)
        return torch.sigmoid(self.final(x))

In [66]:
class SwinUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
        self.decoder = TransformerUNetDecoder(decoder_config, decoder_config.encoder_resolutions)

        for name, param in self.encoder.named_parameters():
            if "layers.3" not in name:
                param.requires_grad = False
    
    def forward(self, x):
        encoder_outputs = self.encoder(x, output_hidden_states=True).hidden_states
        skips = list(reversed(encoder_outputs[:-1]))

        return self.decoder(encoder_outputs[-1], skips)

In [67]:
model = SwinUNet()
dummy_input = torch.randn(2, 3, 224, 224)
output = model(dummy_input)
print(output.shape)  # 应输出 torch.Size([2, 1, 224, 224])

# 梯度检查
loss = torch.nn.BCELoss()(output, torch.rand(2,1,224,224))
loss.backward()
print([p.grad is not None for p in model.parameters()])  # 应显示部分梯度存在

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
SIZE = 224
LEARNING_RATE = 1e-4
NUM_EPOCHS = 20

In [None]:
# decoder = SwinUNetDecoder()

# dummy_features = [
#     # torch.randn(1, 768, 7, 7),    # Stage4
#     torch.randn(1, 384, 14, 14),  # Stage3
#     torch.randn(1, 192, 28, 28),  # Stage2
#     torch.randn(1, 96, 56, 56)    # Stage1
# ]

# output = decoder(dummy_features)
# print(f"Final output shape: {output.shape}")  # 应输出 [1,1,224,224]

## Data Preparation

In [None]:
class ISICDataset(Dataset):
    def __init__ (self, image_dir, mask_dir, size, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.ids = [image_dir[:-4] for image_dir in os.listdir(image_dir) if image_dir.endswith('.jpg')]
        self.size = size
    
    def __len__ (self):
        return len(self.ids)
    
    def __getitem__ (self, idx):
        image_path = os.path.join(self.image_dir, self.ids[idx] + ".jpg")
        mask_path = os.path.join(self.mask_dir, self.ids[idx] + "_segmentation.png")

        img = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        if self.transform is not None:
            seed = torch.random.seed()
            torch.random.manual_seed(seed)
            img = self.transform(img)
            torch.random.manual_seed(seed)
            mask = self.transform(mask)
        
        mask = np.array(mask)
        mask = np.expand_dims(mask, axis=0)
        mask = torch.from_numpy(mask).float()
        mask = (mask > 127).float() # 二值化处理

        img = transforms.functional.normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        return img, mask

#### Loading and Preparing

In [None]:
transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor()
])

dataset = ISICDataset(image_dir=TRAIN_PATH, mask_dir=MASK_PATH, size=SIZE, transform=transform)
train_size = int(0.8 * len(dataset))
train_dataset = Subset(dataset, range(train_size))
val_dataset = Subset(dataset, range(train_size, len(dataset)))

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)


## Build the training process