In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import yaml

In [8]:
from ITSRN.code import models

In [9]:
# ITSRN model/training configuration. Forward slashes work on Windows and POSIX
# alike and avoid the invalid "\c" escape sequence of the original literal.
CONFIG_PATH = "ITSRN/code/configs/train/train_itnsr.yaml"

with open(CONFIG_PATH, 'r') as f:
    # safe_load only builds plain YAML types (dict/list/str/number).
    # NOTE(review): assumes the config uses no Python-object tags — if it does,
    # switch back to yaml.load(f, Loader=yaml.FullLoader).
    config = yaml.safe_load(f)
print('config loaded.')

config loaded.


In [10]:
def make_coord(shape, ranges=None, flatten=True):
    """
    Build a grid of cell-center coordinates.

    Axis ``i`` is divided into ``shape[i]`` equal cells spanning ``ranges[i]``
    (``(-1, 1)`` when ``ranges`` is None), and the center of each cell is used.

    Args:
        shape: Sequence of per-axis sizes, e.g. ``(H, W)``.
        ranges: Optional sequence of ``(low, high)`` pairs, one per axis.
        flatten: If True the result has shape ``(prod(shape), len(shape))``;
            otherwise ``(*shape, len(shape))``.

    Returns:
        float32 tensor whose last dimension holds one coordinate per axis,
        laid out with ``'ij'`` (matrix) indexing.
    """
    def _axis_centers(axis, size):
        # Centers sit half a cell in from `low`, one full cell apart.
        low, high = (-1, 1) if ranges is None else ranges[axis]
        half_cell = (high - low) / (2 * size)
        return low + half_cell + (2 * half_cell) * torch.arange(size).float()

    axes = [_axis_centers(axis, size) for axis, size in enumerate(shape)]
    grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)
    return grid.reshape(-1, grid.shape[-1]) if flatten else grid

def get_coord(img, scale):
    """
    Get query coordinates for upsampling ``img`` by ``scale``.

    Args:
        img: Input image/feature tensor of shape (B, C, H, W).
        scale: Target upsampling factor; the target size is
            ``int(H * scale) x int(W * scale)`` (fractional parts truncated).

    Returns:
        coord: Tensor of shape (B, H_up * W_up, 2) on ``img.device``, holding
            cell-center coordinates in (-1, 1); the first component runs along
            H, the second along W (``make_coord`` uses 'ij' indexing).
    """
    B, _, H, W = img.shape

    # Calculate target size
    H_up = int(H * scale)
    W_up = int(W * scale)

    # Generate normalized coordinates for target size
    coord = make_coord((H_up, W_up))  # (H_up*W_up, 2)

    # Always add the batch dimension. Previously it was skipped for B == 1, which
    # produced a rank-2 tensor and an inconsistent shape contract for callers.
    coord = coord.unsqueeze(0).repeat(B, 1, 1)

    # Keep coordinates on the same device as the input.
    return coord.to(img.device)

In [11]:
def make_dense(nChannels, growthRate):
    """
    One dense layer of a residual dense block.

    A bias-free 3x3 convolution (padding 1, so spatial size is preserved)
    mapping ``nChannels`` -> ``growthRate`` channels, followed by an in-place ReLU.
    """
    conv = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, activation)

class RDB(nn.Module):
    """
    Residual dense block.

    ``nDenselayer`` dense layers are applied in sequence. Each receives the
    concatenation of the block input and all previous layers' outputs and adds
    ``growthRate`` channels. A 1x1 conv fuses everything back to ``nChannels``,
    and the block input is added as a residual.
    """

    def __init__(self, nChannels, nDenselayer, growthRate):
        super().__init__()
        # Layer k sees the block input plus the k feature maps produced before it.
        self.dense_layers = nn.ModuleList(
            [make_dense(nChannels + k * growthRate, growthRate) for k in range(nDenselayer)]
        )
        # Local feature fusion back to the input width so the residual add is valid.
        self.conv_1x1 = nn.Conv2d(
            nChannels + nDenselayer * growthRate,
            nChannels,
            kernel_size=1,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        collected = [x]
        for dense in self.dense_layers:
            collected.append(dense(torch.cat(collected, 1)))
        fused = self.conv_1x1(torch.cat(collected, 1))
        return fused + x

class EnhancedFeatureExtraction(nn.Module):
    """
    Convolutional feature extractor that preserves spatial size.

    Pipeline:
      1. Shallow features: two bias-free 3x3 convs (no activation between them).
      2. ``num_rdb`` residual dense blocks chained one after another.
      3. Fusion: the shallow features and every RDB output are concatenated, then
         reduced by two 1x1 conv + ReLU stages, first to half the stacked width
         and then to ``base_channels``.
      4. Context enhancement: 3x3 convs with dilation 2 and 4, each followed by a ReLU.
      5. A final bias-free 3x3 conv.

    Output has ``base_channels`` channels and the same H, W as the input.
    """

    def __init__(self, input_channels=3, base_channels=64, num_rdb=2, num_dense_layers=6, growth_rate=32):
        super().__init__()

        def conv3x3(c_in, c_out, dilation=1):
            # padding == dilation keeps H, W unchanged for a 3x3 kernel.
            return nn.Conv2d(c_in, c_out, kernel_size=3, padding=dilation, dilation=dilation, bias=False)

        # Shallow feature extraction
        self.sfe1 = conv3x3(input_channels, base_channels)
        self.sfe2 = conv3x3(base_channels, base_channels)

        # Local feature extraction
        self.rdbs = nn.ModuleList(
            [RDB(base_channels, num_dense_layers, growth_rate) for _ in range(num_rdb)]
        )

        # Fusion input = shallow features + one map per RDB.
        stacked_channels = base_channels * (num_rdb + 1)
        halved_channels = stacked_channels // 2

        self.fusion1 = nn.Sequential(
            nn.Conv2d(stacked_channels, halved_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
        )
        self.fusion2 = nn.Sequential(
            nn.Conv2d(halved_channels, base_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
        )

        # Dilated convs widen the receptive field without downsampling.
        self.context_enhancement = nn.Sequential(
            conv3x3(base_channels, base_channels, dilation=2),
            nn.ReLU(inplace=True),
            conv3x3(base_channels, base_channels, dilation=4),
            nn.ReLU(inplace=True),
        )

        self.conv_out = conv3x3(base_channels, base_channels)

    def forward(self, x):
        shallow = self.sfe2(self.sfe1(x))

        # Each RDB consumes the previous stage's output; every stage is kept for fusion.
        stages = [shallow]
        for rdb in self.rdbs:
            stages.append(rdb(stages[-1]))

        fused = self.fusion2(self.fusion1(torch.cat(stages, 1)))
        return self.conv_out(self.context_enhancement(fused))

In [12]:
class FFESR(nn.Module):
    """
    Feature extractor with two heads: an ITSRN upsampler and a small conv head.

    Args:
        input_channels, base_channels, num_rdb, num_dense_layers, growth_rate:
            Forwarded to ``EnhancedFeatureExtraction``.
        itsrn_spec: Model spec passed to ``models.make``. When None, falls back to
            the notebook-level ``config['model']`` (the original behaviour), so
            existing calls keep working.

    NOTE(review): the ITSRN head receives the extracted features, which have
    ``base_channels`` channels — presumably this must match the input channels
    the ITSRN spec expects; confirm against the config.
    """

    def __init__(self, input_channels=3, base_channels=64, num_rdb=2, num_dense_layers=6, growth_rate=32,
                 itsrn_spec=None):
        super(FFESR, self).__init__()
        self.enhanced_feature_extraction = EnhancedFeatureExtraction(
            input_channels, base_channels, num_rdb, num_dense_layers, growth_rate
        )
        # Same-resolution 3-channel head: base_channels -> 32 -> 16 -> 3.
        self.output_conv = nn.Sequential(
            nn.Conv2d(base_channels, 32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, kernel_size=3, padding=1, bias=True)
        )
        # Resolve the spec explicitly instead of silently depending on a global.
        if itsrn_spec is None:
            itsrn_spec = config['model']
        self.itsrn = models.make(itsrn_spec)

    def forward(self, x, scale=2):
        """
        Args:
            x: Input tensor (B, input_channels, H, W).
            scale: Upsampling factor for the ITSRN head (default 2, as before).
                It is used to build the query grid and is also passed to ITSRN
                unchanged — TODO confirm the type ITSRN expects for it.

        Returns:
            (y, z): y is the ITSRN output for the query coordinates; z is the
            conv-head output at the input resolution, shape (B, 3, H, W).
        """
        features = self.enhanced_feature_extraction(x)
        y = self.itsrn(features, get_coord(features, scale), scale)
        z = self.output_conv(features)
        return (y, z)

In [13]:
# Seed so the random weights and sample input are reproducible across runs.
torch.manual_seed(0)

model = FFESR(
    input_channels=3,
    base_channels=3,  # NOTE(review): presumably chosen to match the channels the ITSRN spec expects — confirm
    num_rdb=3,
    num_dense_layers=6,
    growth_rate=32,
)
model.eval()

# Create sample input
sample_input = torch.randn(1, 3, 6, 6)

# Forward pass — a shape check only, so no autograd graph is needed.
# NOTE(review): the recorded run failed with "Torch not compiled with CUDA enabled".
# This cell never requests CUDA, so presumably the ITSRN code moves tensors to the
# GPU internally; running it requires a CUDA build of torch or a patch there.
with torch.no_grad():
    output = model(sample_input)
print(f"Input shape: {sample_input.shape}")
print(f"Output shape: {output[0].shape}, {output[1].shape}")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


AssertionError: Torch not compiled with CUDA enabled