In [2]:
import os
import sys

current_dir = os.getcwd()

parent_dir = os.path.abspath(os.path.join(current_dir, ".."))

if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [41]:
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttentionUNetDecoder(nn.Module):
    def __init__(
        self,
        n_classes=1,
        depth=5,  # Ensure depth is sufficient for the desired output size
        wf=6,
        padding=True,
        batch_norm=False,
        up_mode="upconv",
        kernel_size=3,
    ):
        super(SelfAttentionUNetDecoder, self).__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth

        prev_channels = 2 ** (wf + depth - 1)
        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        # Final layer to convert to desired output channels (n_classes)
        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x, encoder_features):
        for i, up in enumerate(self.up_path):
            x = up(x, encoder_features[-i - 2])
        return self.last(x)

class SelfAttentionBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, kernel_size=3):
        super(SelfAttentionBlock, self).__init__()

        padding_ = int(padding) * (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_size,
            out_size,
            kernel_size=kernel_size,
            padding=padding_,
            padding_mode="reflect",
        )
        self.attention = nn.Conv2d(
            in_size,
            out_size,
            kernel_size=kernel_size,
            padding=padding_,
            padding_mode="reflect",
            bias=False,
        )
        with torch.no_grad():
            self.attention.weight.copy_(torch.zeros_like(self.attention.weight))

    def forward(self, x):
        features = self.conv(x)
        attention_mask = torch.sigmoid(self.attention(x))
        return features * attention_mask

class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm, kernel_size=3):
        super(UNetConvBlock, self).__init__()

        self.self_attention1 = SelfAttentionBlock(
            in_size, out_size, padding, kernel_size
        )
        self.self_attention2 = SelfAttentionBlock(
            out_size, out_size, padding, kernel_size
        )
        self.batch_norm = batch_norm
        if batch_norm:
            self.batch_norm1 = nn.BatchNorm2d(out_size)
            self.batch_norm2 = nn.BatchNorm2d(out_size)

    def forward(self, x):
        x = F.relu(self.self_attention1(x))
        if self.batch_norm:
            x = self.batch_norm1(x)

        x = F.relu(self.self_attention2(x))
        if self.batch_norm:
            x = self.batch_norm2(x)
        return x

class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def forward(self, x, bridge):
        up = self.up(x)
        # Handle padding to ensure dimensions match correctly
        diffY = bridge.size()[2] - up.size()[2]
        diffX = bridge.size()[3] - up.size()[3]
        up = F.pad(up, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
        out = torch.cat([up, bridge], 1)
        out = self.conv_block(out)

        return out

# Define the parameters
batch_size = 16
in_channels = 512
output_size = 256
feature_map_size = 8
n_classes = 1
depth = 4  # Ensure depth is sufficient for upsampling to 256x256
wf = 6
padding = True
batch_norm = False
up_mode = "upconv"
kernel_size = 3

# Create a random tensor simulating the feature maps from the encoder with correct input channels
x = torch.randn(batch_size, 2 ** (wf + depth - 1), feature_map_size, feature_map_size)
#x = torch.randn(batch_size, 512, 8, 8)

# Adjust the simulated encoder features to have correct shapes
encoder_features = [
    torch.randn(batch_size, 2 ** (wf + i), feature_map_size * (2 ** (depth - i - 1)), feature_map_size * (2 ** (depth - i - 1)))
    for i in range(depth)
]

# Initialize the decoder
decoder = SelfAttentionUNetDecoder(
    n_classes=n_classes,
    depth=depth,
    wf=wf,
    padding=padding,
    batch_norm=batch_norm,
    up_mode=up_mode,
    kernel_size=kernel_size
)

print(f"Input: Shape: {x.shape}")  # Debug print

for i, feature in enumerate(encoder_features):
    print(f"Encoder Feature {i}: Shape: {feature.shape}")  # Debug print

# Run the decoder
output = decoder(x, encoder_features)

# Check the shape of the output
print("Output shape:", output.shape)

# Expected output shape: (16, 1, 256, 256)


Input: Shape: torch.Size([16, 512, 8, 8])
Encoder Feature 0: Shape: torch.Size([16, 64, 64, 64])
Encoder Feature 1: Shape: torch.Size([16, 128, 32, 32])
Encoder Feature 2: Shape: torch.Size([16, 256, 16, 16])
Encoder Feature 3: Shape: torch.Size([16, 512, 8, 8])
Output shape: torch.Size([16, 1, 64, 64])


In [43]:

import torch
from huggingface_hub import hf_hub_download
from treemort.modeling.network.unet_mtd import smp_unet_mtd

class PretrainedUNetModel:
    def __init__(self, repo_id, filename, architecture="unet", encoder="resnet34", n_channels=4, n_classes=15, use_metadata=False):
        self.repo_id = repo_id
        self.filename = filename
        self.architecture = architecture
        self.encoder = encoder
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.use_metadata = use_metadata

        self.checkpoint_path = hf_hub_download(repo_id=self.repo_id, filename=self.filename)

        self.model = self._initialize_model()

        self._load_pretrained_weights()

    def _initialize_model(self):
        model = smp_unet_mtd(
            architecture=self.architecture,
            encoder=self.encoder,
            n_channels=self.n_channels,
            n_classes=self.n_classes,
            use_metadata=self.use_metadata
        )
        return model

    def _load_pretrained_weights(self):
        self.model.load_state_dict(torch.load(self.checkpoint_path), strict=False)

    def get_model(self):
        return self.model

class FeatureExtractor(nn.Module):
    def __init__(self, model, use_metadata=False):
        super(FeatureExtractor, self).__init__()
        self.model = model
        self.use_metadata = use_metadata
        self.features = []

        # Hook the layers of the encoder
        layers = [
            self.model.seg_model.encoder.layer1[-1],
            self.model.seg_model.encoder.layer2[-1],
            self.model.seg_model.encoder.layer3[-1],
            self.model.seg_model.encoder.layer4[-1],
        ]
        for layer in layers:
            layer.register_forward_hook(self.hook)

    def hook(self, module, input, output):
        self.features.append(output)

    def forward(self, x, met=None):
        self.features = []
        if self.use_metadata:
            self.model(x, met)
        else:
            self.model(x)
        return self.features
    
class CombinedModel(nn.Module):
    def __init__(self, pretrained_model, n_classes=3):
        super(CombinedModel, self).__init__()
        self.feature_extractor = FeatureExtractor(pretrained_model)
        self.decoder = SelfAttentionUNetDecoder(
            n_classes=n_classes,
            depth=4,
            wf=6,
            padding=True,
            batch_norm=False,
            up_mode="upconv",
            kernel_size=3
        )

    def forward(self, x):
        print(f"Input: Shape: {x.shape}")  # Debug print
        encoder_features = self.feature_extractor(x)
        for i, feature in enumerate(encoder_features):
            print(f"Encoder Feature {i}: Shape: {feature.shape}")  # Debug print
        output = self.decoder(encoder_features[-1], encoder_features)
        return output



repo_id = "IGNF/FLAIR-INC_rgbi_15cl_resnet34-unet"
filename = "FLAIR-INC_rgbi_15cl_resnet34-unet_weights.pth"

# Initialize the model
pretrained_model = PretrainedUNetModel(
    repo_id=repo_id,
    filename=filename,
    architecture="unet",
    encoder="resnet34",
    n_channels=4,
    n_classes=15,
    use_metadata=False
)

# Retrieve the model
pretrained_model = pretrained_model.get_model()

model = CombinedModel(pretrained_model=pretrained_model, n_classes=1)

x = torch.randn(16, 4, 256, 256)

outputs = model(x)

outputs.shape

Input: Shape: torch.Size([16, 4, 256, 256])
Encoder Feature 0: Shape: torch.Size([16, 64, 64, 64])
Encoder Feature 1: Shape: torch.Size([16, 128, 32, 32])
Encoder Feature 2: Shape: torch.Size([16, 256, 16, 16])
Encoder Feature 3: Shape: torch.Size([16, 512, 8, 8])


torch.Size([16, 1, 64, 64])

In [None]:
from treemort.main import run
from treemort.utils.config import setup

config_file_path = "../configs/flair_unet_bs8_cs256.txt"

conf = setup(config_file_path)

# Modified Config Variables for Local Execution
conf.data_folder = "/Users/anisr/Documents/AerialImageModel_ITD"
conf.output_dir = os.path.join("..", conf.output_dir)

print(conf)

eval_only = False

run(conf, eval_only)

In [20]:
import torch

from treemort.modeling.network.sa_unet import SelfAttentionUNet

model = SelfAttentionUNet(
    in_channels=4,
    n_classes=1,
    depth=4,
    wf=6,
    batch_norm=True,
)

dummy_input = torch.randn(8, 4, 256, 256)

class FeatureExtractor:
    def __init__(self):
        self.features = None
    
    def hook(self, module, input, output):
        self.features = output

model.encoder.layer4[-1].register_forward_hook(feature_extractor.hook)

feature_extractor = FeatureExtractor()

output = feature_extractor.features

print("Output shape:", output.shape)

AttributeError: 'SelfAttentionUNet' object has no attribute 'encoder'

In [27]:




import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import segmentation_models_pytorch as smp

# Define a hook function to capture the output
def hook_fn(module, input, output):
    global feature_maps
    feature_maps = output

# Load the pretrained model from Hugging Face
checkpoint_path = hf_hub_download(repo_id="IGNF/FLAIR-INC_rgbi_15cl_resnet34-unet", filename="FLAIR-INC_rgbi_15cl_resnet34-unet_weights.pth")

# Assuming smp.Unet to be the model class for the pretrained model
pretrained_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=4,
    classes=15
)

pretrained_model.load_state_dict(torch.load(checkpoint_path), strict=False)

# Register the hook to a specific layer, e.g., the last layer of layer4 in the encoder
target_layer = pretrained_model.encoder.layer2[-1]
target_layer.register_forward_hook(hook_fn)

# Create a dummy input tensor with shape (batch_size, num_channels, height, width)
dummy_input = torch.randn(16, 4, 256, 256)  # Adjusted for your input shape

# Pass the dummy input through the encoder to capture the feature maps
with torch.no_grad():
    _ = pretrained_model.encoder(dummy_input)  # Pass the input through the encoder part

# Print the shape of the feature maps from the hooked layer
print(f'Hooked layer output shape: {feature_maps.shape}')



Hooked layer output shape: torch.Size([16, 128, 32, 32])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttentionBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, kernel_size=3):
        super(SelfAttentionBlock, self).__init__()

        padding_ = int(padding) * (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_size,
            out_size,
            kernel_size=kernel_size,
            padding=padding_,
            padding_mode="reflect",
        )
        self.attention = nn.Conv2d(
            in_size,
            out_size,
            kernel_size=kernel_size,
            padding=padding_,
            padding_mode="reflect",
            bias=False,
        )
        with torch.no_grad():
            self.attention.weight.copy_(torch.zeros_like(self.attention.weight))

    def forward(self, x):
        features = self.conv(x)
        attention_mask = torch.sigmoid(self.attention(x))
        return features * attention_mask


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm, kernel_size=3):
        super(UNetConvBlock, self).__init__()

        self.self_attention1 = SelfAttentionBlock(
            in_size, out_size, padding, kernel_size
        )
        self.self_attention2 = SelfAttentionBlock(
            out_size, out_size, padding, kernel_size
        )
        self.batch_norm = batch_norm
        if batch_norm:
            self.batch_norm1 = nn.BatchNorm2d(out_size)
            self.batch_norm2 = nn.BatchNorm2d(out_size)

    def forward(self, x):
        x = F.relu(self.self_attention1(x))
        if self.batch_norm:
            x = self.batch_norm1(x)

        x = F.relu(self.self_attention2(x))
        if self.batch_norm:
            x = self.batch_norm2(x)
        return x


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out


class SelfAttentionUNet(nn.Module):
    def __init__(
        self,
        in_channels=15,  # Set this to 15 for your case
        n_classes=3,
        depth=3,
        wf=6,
        padding=True,
        batch_norm=False,
        up_mode="upconv",
        kernel_size=3,
    ):
        super(SelfAttentionUNet, self).__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(
                    prev_channels, 2 ** (wf + i), padding, batch_norm, kernel_size
                )
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)


In [None]:
class FeatureExtractor:
    def __init__(self):
        self.features = None

    def hook(self, module, input, output):
        self.features = output

    def get_features(self):
        return self.features


In [None]:
import torch
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# Download the pretrained weights
checkpoint_path = hf_hub_download(
    repo_id="IGNF/FLAIR-INC_rgbi_15cl_resnet34-unet",
    filename="FLAIR-INC_rgbi_15cl_resnet34-unet_weights.pth"
)

# Define and load the pretrained model
pretrained_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=4,
    classes=15
)

pretrained_model.load_state_dict(torch.load(checkpoint_path), strict=False)

# Initialize the feature extractor
feature_extractor = FeatureExtractor()

# Register the hook with the specific layer
# Assuming you want the features from the last block of layer4 in ResNet34
pretrained_model.encoder.layer4[-1].register_forward_hook(feature_extractor.hook)


In [None]:
class EncoderDecoderModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(EncoderDecoderModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        # Forward pass through the encoder
        _ = self.encoder(x)
        # Extract features from the hook
        features = feature_extractor.get_features()
        # Pass features to the decoder
        output = self.decoder(features)
        return output

# Instantiate the self-attention U-Net decoder
decoder = SelfAttentionUNet(in_channels=256, n_classes=3)  # Adjust `in_channels` as needed

# Create the combined model
model = EncoderDecoderModel(pretrained_model, decoder)


In [None]:
# Create a dummy input batch of size (8, 4, 256, 256)
dummy_input = torch.randn(8, 4, 256, 256)

# Run a forward pass through the model
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    output = model(dummy_input)

# Print the output shape
print("Output shape:", output.shape)

In [None]:
'''
import torch

from huggingface_hub import hf_hub_download

from treemort.modeling.network.unet_mtd import smp_unet_mtd

def create_feature_extractor(model_name, model_type, model_filename):

    checkpoint_path = hf_hub_download(repo_id=model_name, filename=model_filename)

    feature_extractor = smp_unet_mtd(
        architecture="unet",
        encoder="resnet34",
        n_channels=4,
        n_classes=2,
        use_metadata=False,
    )

    feature_extractor.load_state_dict(torch.load(checkpoint_path), strict=False)

    return feature_extractor


feature_extractor = create_feature_extractor(
      "IGNF/FLAIR-INC_rgbi_15cl_resnet34-unet",
      model_type="resnet34_unet",
      model_filename="FLAIR-INC_rgbi_15cl_resnet34-unet_weights.pth",
)

print("[INFO] feature extractor created.")
'''

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# Define the Self-Attention U-Net components as provided
class SelfAttentionBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, kernel_size=3):
        super(SelfAttentionBlock, self).__init__()
        padding_ = int(padding) * (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            in_size,
            out_size,
            kernel_size=kernel_size,
            padding=padding_,
            padding_mode="reflect",
        )
        self.attention = nn.Conv2d(
            in_size,
            out_size,
            kernel_size=kernel_size,
            padding=padding_,
            padding_mode="reflect",
            bias=False,
        )
        with torch.no_grad():
            self.attention.weight.copy_(torch.zeros_like(self.attention.weight))

    def forward(self, x):
        features = self.conv(x)
        attention_mask = torch.sigmoid(self.attention(x))
        return features * attention_mask

class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm, kernel_size=3):
        super(UNetConvBlock, self).__init__()
        self.self_attention1 = SelfAttentionBlock(
            in_size, out_size, padding, kernel_size
        )
        self.self_attention2 = SelfAttentionBlock(
            out_size, out_size, padding, kernel_size
        )
        self.batch_norm = batch_norm
        if batch_norm:
            self.batch_norm1 = nn.BatchNorm2d(out_size)
            self.batch_norm2 = nn.BatchNorm2d(out_size)

    def forward(self, x):
        x = F.relu(self.self_attention1(x))
        if self.batch_norm:
            x = self.batch_norm1(x)
        x = F.relu(self.self_attention2(x))
        if self.batch_norm:
            x = self.batch_norm2(x)
        return x

class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)
        return out

class SelfAttentionUNet(nn.Module):
    def __init__(
        self,
        in_channels=15,
        n_classes=3,
        depth=3,
        wf=6,
        padding=True,
        batch_norm=False,
        up_mode="upconv",
        kernel_size=3,
    ):
        super(SelfAttentionUNet, self).__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(
                    prev_channels, 2 ** (wf + i), padding, batch_norm, kernel_size
                )
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)

class FeatureExtractor:
    def __init__(self):
        self.features = None

    def hook(self, module, input, output):
        self.features = output

    def get_features(self):
        return self.features

# Download the pretrained weights
checkpoint_path = hf_hub_download(
    repo_id="IGNF/FLAIR-INC_rgbi_15cl_resnet34-unet",
    filename="FLAIR-INC_rgbi_15cl_resnet34-unet_weights.pth"
)

# Define and load the pretrained model
pretrained_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=4,
    classes=15
)

pretrained_model.load_state_dict(torch.load(checkpoint_path), strict=False)

# Initialize the feature extractor
feature_extractor = FeatureExtractor()

# Register the hook with the specific layer
pretrained_model.encoder.layer4[-1].register_forward_hook(feature_extractor.hook)

# Define the Self-Attention U-Net decoder
decoder = SelfAttentionUNet(in_channels=256, n_classes=3)  # Adjust `in_channels` as needed

# Create the combined model
class EncoderDecoderModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(EncoderDecoderModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        _ = self.encoder(x)
        features = feature_extractor.get_features()
        output = self.decoder(features)
        return output

model = EncoderDecoderModel(pretrained_model, decoder)

# Create a dummy input batch of size (8, 4, 256, 256)
dummy_input = torch.randn(8, 4, 256, 256)

# Run a forward pass through the model
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    output = model(dummy_input)

# Print the output shape
print("Output shape:", output.shape)


In [51]:
model

EncoderDecoderModel(
  (encoder): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          

Output shape: torch.Size([16, 1, 256, 256])


In [69]:
# Example tensors with different feature map sizes
x = torch.randn(16, 512, 8, 8)  # Example tensor from encoder
bridge = torch.randn(16, 128, 16, 16)  # Example tensor from previous layer

# Create an upsampling block
up_block = UNetUpBlock(512, 256, "upconv", 3, batch_norm=True)

# Perform upsampling
up = up_block.up(x)
print(f"Up sample output shape: {up.shape}")  # Expected: [16, 512, 16, 16]

# Perform center crop on the bridge tensor to match the size of up
crop1 = center_crop(bridge, up.shape[2:])
print(f"Crop1 shape: {crop1.shape}")  # Should match up's height and width

# Concatenate along channel dimension
out = torch.cat([up, crop1], 1)
print(f"Concatenated shape: {out.shape}")  # Channels should be in_size + out_size


Up sample output shape: torch.Size([16, 512, 16, 16])
Crop1 shape: torch.Size([16, 128, 16, 16])
Concatenated shape: torch.Size([16, 640, 16, 16])


In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F


from huggingface_hub import hf_hub_download
import segmentation_models_pytorch as smp


class SelfAttentionUNet(nn.Module):
    def __init__(self, in_channels=512, n_classes=1, depth=3, wf=6, batch_norm=False, up_mode="upconv", kernel_size=3):
        super(SelfAttentionUNet, self).__init__()
        assert up_mode in ("upconv", "upsample")
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), kernel_size, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, kernel_size, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)


def center_crop(layer, target_size):
    _, _, layer_height, layer_width = layer.size()
    target_height, target_width = target_size

    # Calculate cropping offsets
    diff_y = (layer_height - target_height) // 2
    diff_x = (layer_width - target_width) // 2

    # Ensure the calculated indices are within bounds
    diff_y = max(0, diff_y)
    diff_x = max(0, diff_x)

    return layer[:, :, diff_y:diff_y + target_height, diff_x:diff_x + target_width]

class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, kernel_size, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, in_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=True),
                nn.Conv2d(in_size, in_size, kernel_size=1),
            )

        # Modify in_size to match the number of channels after concatenation
        self.conv_block = UNetConvBlock(in_size + out_size, out_size, kernel_size, batch_norm)

    def forward(self, x, bridge):
        up = self.up(x)
        print(f"Up sample output shape: {up.shape}")  # Expected: [16, in_size, H', W']
        
        # Use the updated center_crop function
        crop1 = center_crop(bridge, up.shape[2:])
        print(f"Crop1 shape: {crop1.shape}")  # Should match up's height and width
        
        out = torch.cat([up, crop1], 1)  # Concatenate along channel dimension
        print(f"Concatenated shape: {out.shape}")  # Channels should be in_size + out_size
        
        out = self.conv_block(out)
        return out





class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, batch_norm=False):
        super(UNetConvBlock, self).__init__()
        self.self_attention1 = SelfAttentionBlock(in_size, out_size, kernel_size, batch_norm)
        self.self_attention2 = SelfAttentionBlock(out_size, out_size, kernel_size, batch_norm)

    def forward(self, x):
        x = F.relu(self.self_attention1(x))
        x = F.relu(self.self_attention2(x))
        return x




class SelfAttentionBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, batch_norm=False):
        super(SelfAttentionBlock, self).__init__()

        padding_ = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_size, out_size, kernel_size=kernel_size, padding=padding_, padding_mode="reflect")
        self.attention = nn.Conv2d(in_size, out_size, kernel_size=kernel_size, padding=padding_, padding_mode="reflect", bias=False)
        with torch.no_grad():
            self.attention.weight.copy_(torch.zeros_like(self.attention.weight))

        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = nn.BatchNorm2d(out_size)

    def forward(self, x):
        features = self.conv(x)
        attention_mask = torch.sigmoid(self.attention(x))
        if self.batch_norm:
            features = self.bn(features)
        return features * attention_mask





class FeatureExtractor:
    def __init__(self):
        self.features = None
    
    def hook(self, module, input, output):
        self.features = output

# Download the pretrained weights
checkpoint_path = hf_hub_download(
    repo_id="IGNF/FLAIR-INC_rgbi_15cl_resnet34-unet",
    filename="FLAIR-INC_rgbi_15cl_resnet34-unet_weights.pth"
)

# Define and load the pretrained model
pretrained_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=4,
    classes=15
)

pretrained_model.load_state_dict(torch.load(checkpoint_path), strict=False)

feature_extractor = FeatureExtractor()
pretrained_model.encoder.layer2[-1].register_forward_hook(feature_extractor.hook)

decoder = SelfAttentionUNet(in_channels=512, n_classes=1, depth=3)

class EncoderDecoderModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(EncoderDecoderModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        _ = self.encoder(x)
        features = feature_extractor.features
        print(features.shape)
        return self.decoder(features)


model = EncoderDecoderModel(pretrained_model, decoder)

dummy_input = torch.randn(16, 4, 256, 256)

model.eval()
with torch.no_grad():
    output = model(dummy_input)

print("Output shape:", output.shape)




torch.Size([16, 128, 32, 32])


RuntimeError: Given groups=1, weight of size [64, 512, 3, 3], expected input[16, 128, 34, 34] to have 512 channels, but got 128 channels instead

In [42]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp

# Define the feature extractor class if needed
class FeatureExtractor:
    def __init__(self):
        self.features = None
    
    def hook(self, module, input, output):
        self.features = output

# Download the pretrained weights
checkpoint_path = hf_hub_download(
    repo_id="IGNF/FLAIR-INC_rgbi_15cl_resnet34-unet",
    filename="FLAIR-INC_rgbi_15cl_resnet34-unet_weights.pth"
)

# Define and load the pretrained model
pretrained_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=4,
    classes=15
)

pretrained_model.load_state_dict(torch.load(checkpoint_path), strict=False)

# Initialize the feature extractor
feature_extractor = FeatureExtractor()

pretrained_model.encoder.layer4[-1].register_forward_hook(feature_extractor.hook)

# Define the decoder model
decoder = SelfAttentionUNet(in_channels=512, n_classes=1)  # Adjust in_channels as needed

# Create combined model
class EncoderDecoderModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(EncoderDecoderModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        _ = self.encoder(x)  # Run encoder forward to trigger the hook
        features = feature_extractor.features
        # Pass the features from encoder to decoder
        return self.decoder(features)

model = EncoderDecoderModel(pretrained_model, decoder)

# Create dummy input
dummy_input = torch.randn(8, 4, 256, 256)

# Run the forward pass
model.eval()
with torch.no_grad():
    output = model(dummy_input)

# Print output shape
print("Output shape:", output.shape)  # Should be (8, 1, 256, 256)


UNetConvBlock input shape: torch.Size([8, 512, 8, 8])
UNetConvBlock output shape: torch.Size([8, 64, 8, 8])
UNetConvBlock input shape: torch.Size([8, 64, 4, 4])
UNetConvBlock output shape: torch.Size([8, 128, 4, 4])
UNetConvBlock input shape: torch.Size([8, 128, 2, 2])
UNetConvBlock output shape: torch.Size([8, 256, 2, 2])
UNetConvBlock input shape: torch.Size([8, 384, 4, 4])


RuntimeError: Given groups=1, weight of size [256, 256, 3, 3], expected input[8, 384, 6, 6] to have 256 channels, but got 384 channels instead