# Group 4 Project Version 2 Submission


## Paper Information and Our Information
### **Paper Title:** SeD Semantic-Aware Discriminator for Image Super-Resolution
### **Paper Authors:** Bingchen Li, Xin Li, Hanxin Zhu, Yeying Jin, Ruoyu Feng, Zhizheng Zhang, Zhibo Chen
### **Paper Description:** 
### Github Repository: [Link](https://github.com/YigitEkin/sed)

In this work, researchers highlight the use of Generative Adversarial Networks (GANs) for image super-resolution tasks, particularly focusing on texture recovery. They note a limitation in existing methods where a single discriminator is employed to teach the super-resolution network the distribution of high-quality real-world images, leading to coarse learning and unexpected output. To address this, they introduce a Semantic-aware Discriminator (SeD), which incorporates image semantics to guide the network in learning fine-grained image distributions.

The SeD leverages image semantics extracted from a trained semantic model, allowing the discriminator to discern real and fake images based on different semantic conditions. By integrating semantic features into the discriminator using spatial cross-attention modules, they aim to enhance the SR network's ability to generate more realistic and visually appealing images. The approach capitalizes on pretrained vision models and extensive datasets to enrich the understanding of image semantics and improve the fidelity of super-resolved images.


Authors suggest that Vanilla Discriminators ignore the important semantics of the inputs, hence giving  semantic features of an image ( extracted via a pretrained network ), enables a better discriminator and hence a better feedback for the generator. The situation is demonstrated better on Figure 1, giving semantic features as condition enables the discriminator to specialize by finding boundaries within classes.



<img src="img/fig1.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>



A classical setup of Super Resolution GAN Network is as below. The generator network takes the low resolution image as input and produces a high resolution image. The discriminator takes the generated and the ground truth high resolution images and classifies as real or fake. In our setup, our generator takes 64x64 low resolution images and generate 256x256 high resolution images.

<img src="img/sr_gan_setup.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>


The proposed setup of the paper is as below, now the semantic feature maps of the ground truth high resolution images is also given as input to the discriminator. 


<img src="img/sed_gan_setup.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>


Authors employ two discriminator types, a patch-based discriminator and a pixel-wise discriminator. 
The proposed architecture of the Patch-wise Semantic Aware Discriminator is shown below. 


<img src="img/patchwise_sed.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>


<details>
  <summary>Patchwise SED</summary>

  ```python
class DownSampler(nn.Module):
    # Downsamples 4 times in a conv, bn, leaky relu fashion that halves the spatial dimensions in each step and doubles the number of filters
    def __init__(self, input_channels, num_filters=64):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_filters, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.leaky_relu = nn.LeakyReLU(0.2)
        
        self.conv2 = nn.Conv2d(num_filters, num_filters * 2, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(num_filters * 2)
        
        self.conv3 = nn.Conv2d(num_filters * 2, num_filters * 4, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(num_filters * 4)
        
        self.conv4 = nn.Conv2d(num_filters * 4, num_filters * 8, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(num_filters * 8)
        
    def forward(self, x):
        x = self.leaky_relu(self.bn1(self.conv1(x)))
        x = self.leaky_relu(self.bn2(self.conv2(x)))
        x = self.leaky_relu(self.bn3(self.conv3(x)))
        x = self.bn4(self.conv4(x))
        return x
    
    
class PatchDiscriminatorWithSeD(nn.Module):
    # PatchGAN discriminator with semantic-aware fusion blocks 
    def __init__(self, input_channels, num_filters=64):
        super().__init__()
        #First downsample the input size from 256x256 to 16x16 to match the semantic feature map size
        self.downsampler = DownSampler(input_channels, num_filters)
        #Use 3 semantic-aware fusion blocks to fuse the semantic feature maps with the downsampled input
        self.semantic_aware_fusion_block1 = SemanticAwareFusionBlock()
        self.semantic_aware_fusion_block2 = SemanticAwareFusionBlock(channel_size_changer_input_nc=1024)
        self.semantic_aware_fusion_block3 = SemanticAwareFusionBlock(channel_size_changer_input_nc=1024)
        #Final convolution to get the output
        self.final_conv = nn.Conv2d(num_filters * 16, 1, kernel_size=4, stride=1, padding=1)
        
    def forward(self, semantic_feature_maps, fs):
        x = self.downsampler(fs)
        x = self.semantic_aware_fusion_block1(semantic_feature_maps, x)
        x = self.semantic_aware_fusion_block2(semantic_feature_maps, x)
        x = self.semantic_aware_fusion_block3(semantic_feature_maps, x)
        x = self.final_conv(x)
        return x
  ```
</details>


The discriminator has a specialized block called Semantic Aware Fusion Block. Semantic Aware Fusion Block takes the ground truth semantic features extracted by CLIP Feature Extractor, and applies cross attention between either ground truth or generated high resolution images as shown below. First the generated (or ground truth) feature maps are processed through normalization and self-attention mechanism , and then cross attention is applied. 


<img src="img/semantic_aware_fb.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>



<details>
  <summary>Semantic Aware Fusion Block</summary>

  ```python
class SemanticAwareFusionBlock(nn.Module):
    def __init__(self, channel_size_changer_input_nc=512):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, 1024) 

        self.channel_size_changer1 = nn.Conv2d(in_channels=channel_size_changer_input_nc, out_channels=128, kernel_size=1)
        self.reduce_channels2 = nn.Conv2d(in_channels=1024, out_channels=128, kernel_size=1)

        self.layer_norm_1 = nn.LayerNorm(128)
        self.layer_norm_2 = nn.LayerNorm(128)
        self.layer_norm_3 = nn.LayerNorm(128)

        self.self_attention = SelfAttention(128, num_heads=1, dimensionality=128)
        self.cross_attention = CrossAttention(128, heads=1, dim_head=128)

        self.GeLU = nn.GELU()

        #define 1x1 convolutions
        self.increase_channels1 = nn.Conv2d(256, 1024, 1)

    def forward(self, semantic_feature_maps, fs):
        # fs ( or sh for generated) have shape batch, 3 x 16 x 16
        #semantic feature maps  have shape batch x 1024 x 16 x 16
        final_permute_height = semantic_feature_maps.shape[2]
        final_permute_width = semantic_feature_maps.shape[3]
        
        #first handle S_h
        semantic_feature_maps = self.group_norm(semantic_feature_maps)

        #reduce the channel dimensions for the feature maps from 1024 to 128 for computation
        semantic_feature_maps = self.reduce_channels2(semantic_feature_maps)


        # Permute dimensions to rearrange the tensor
        semantic_feature_maps = semantic_feature_maps.permute(0, 2, 3, 1).contiguous().view(semantic_feature_maps.size(0), -1, semantic_feature_maps.size(1))

        #apply layer normalization
        semantic_feature_maps = self.layer_norm_1(semantic_feature_maps)

        #apply self attention
        semantic_feature_maps = self.self_attention(semantic_feature_maps) #returned has shape 1,196,128 for now
        #apply layer normalization
        query = self.layer_norm_2(semantic_feature_maps)

        #now handle fs or  sh
        #reduce the channel dimensions for the sh

        #make number of channels = 128 to be compatible with the semantic feature maps
        fs = self.channel_size_changer1(fs)

        #to use fs as residual, obtain a clone, 
        #note that gradient still accumulates in the original fs, so no problem
        fs_residual = fs.clone()

        #permute the dimensions
        fs = fs.permute(0, 2, 3, 1).contiguous().view(fs.size(0), -1, fs.size(1))

        #apply cross attention, query is the semantic feature maps and fs is the key and value
        out = self.cross_attention(query, fs)

        #apply layer normalization
        out = self.layer_norm_3(out)

        #apply GeLU
        out = self.GeLU(out)

        #permute the dimensions
        out = out.permute(0,2,1).contiguous().view(out.size(0), -1, final_permute_height, final_permute_width)

        #add the residual
        output = torch.cat((out,fs_residual), dim=1)

        #increase the channels back to 1024
        output = self.increase_channels1(output)
    
        return output
```
</details>



The architecture of the CLIP Feature Extractor is demonstrated below. The CLIP Feature Extractor has normally 4 layers, but authors suggest using the outputs of third layer as going further causes loss of spatial information which is problematic while restoring a high resolution image. The architecture is shown below.


<img src="img/clip_feature_extractor.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>

<details>
  <summary>CLIP Feature Extractor</summary>

  ```python

class CLIPRN50(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

        #add the openai-provided normalization
        #https://github.com/jianjieluo/OpenAI-CLIP-Feature/blob/01269a8fceb540d3b6477b43177ea33845c9514c/clip/clip.py#L82C9-L82C92
        self.preprocess = transforms.Compose([
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

        #load
        self.ckpt_path = "RN50"
        self.load_ckpt(self.ckpt_path)
        self.freeze()

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = self.preprocess(x)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        return x
```
</details>


The pixel-wise discriminator has U-Net architecture, which also employs the Semantic-Aware Fusion Block. The authors use both patch-based and pixel-wise based discriminators to demonstrate effectiveness of the Semantic-Aware Fusion Block. The architecture of the Pixel-wise Semantic Aware Discriminator is shown below.  

<img src="img/pixelwise_sed.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>

<details>
  <summary>Pixelwise SED</summary>

  ```python

class DownSamplerPx(nn.Module):
    #downsamples 4 times in a conv, bn, leaky relu fashion that halves the spatial dimensions in each step and doubles the number of filters
    def __init__(self, input_channels, num_filters=64):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_filters, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.leaky_relu = nn.LeakyReLU(0.2)
        
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(num_filters)
        
        self.conv3 = nn.Conv2d(num_filters, num_filters, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(num_filters)
        
        self.conv4 = nn.Conv2d(num_filters, num_filters, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(num_filters)
        
    def forward(self, x):
        x = self.leaky_relu(self.bn1(self.conv1(x)))
        x = self.leaky_relu(self.bn2(self.conv2(x)))
        x = self.leaky_relu(self.bn3(self.conv3(x)))
        x = self.bn4(self.conv4(x))
        return x
    
class UNetPixelDiscriminatorwithSed(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, num_filters=64):
        super(UNetPixelDiscriminatorwithSed, self).__init__()

        #downsampler takes 256x256 images and downsamples to the 16x16
        #to make dimensionality compatible with semantic feature maps
        self.downsampler = DownSamplerPx(in_channels, num_filters)
        
        # Semantic Aware Fusion Blocks
        self.semantic_aware_fusion_block1 = SemanticAwareFusionBlock(channel_size_changer_input_nc=64)
        self.semantic_aware_fusion_block2 = SemanticAwareFusionBlock(channel_size_changer_input_nc=1024)
        self.semantic_aware_fusion_block3 = SemanticAwareFusionBlock(channel_size_changer_input_nc=1024)
        
        self.upconv1 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1)
        self.upconv2 = nn.Conv2d(1024, 64, kernel_size=1, stride=1)
        self.upconv3 = nn.Conv2d(64, 3, kernel_size=1, stride=1)


    def forward(self,semantic_feature_maps, fs):
        x = self.downsampler(fs)
        enc1 = self.semantic_aware_fusion_block1(semantic_feature_maps, x)
        enc2 = self.semantic_aware_fusion_block2(semantic_feature_maps, enc1)
        enc3 = self.semantic_aware_fusion_block3(semantic_feature_maps, enc2)
        
        dec = self.upconv1(enc3 + enc2)
        dec = self.upconv2(dec + enc1)
        dec = self.upconv3(dec + x)
        
        return dec
```
</details>




Throughout the experiments, we use the RRDB Generator proposed in ESRGAN paper, whose building blocks are shown in below Figure.

<img src="img/rrdb_generator.png" style="width:700px; height:auto; display: flex; justify-content: center"/> <br/> <br/>

<details>
  <summary>RRDB Generator</summary>

  ```python
class DenseBlock(nn.Module):
  '''
  Dense Block structure from https://arxiv.org/pdf/1809.00219 Fig4 : Left
  '''
    def __init__(self, in_channels, out_channels, num_blocks=5, is_upsample=False):
        super().__init__()
        self.blocks = make_blocks(in_channels, out_channels, num_blocks, is_upsample)

    def forward(self, x):
        prev_features = x
        for block in self.blocks:
            current_output = block(prev_features)
            prev_features = torch.cat([prev_features, current_output], dim=1)
        return x + current_output * 0.2

class Residual_in_ResidualBlock(nn.Module):
  '''
  RRDB  structure from https://arxiv.org/pdf/1809.00219 Fig4 : Right
  consists of 3 Dense Blocks
  '''
    def __init__(self, in_channels, num_blocks=3, is_upsample=False):
        super().__init__()
        self.rrdb1 = DenseBlock(in_channels, in_channels, num_blocks, is_upsample)
        self.rrdb2 = DenseBlock(in_channels, in_channels, num_blocks, is_upsample)
        self.rrdb3 = DenseBlock(in_channels, in_channels, num_blocks, is_upsample)
        
    def forward(self, x):
        out1 = self.rrdb1(x)
        out2 = self.rrdb2(out1)
        out3 = self.rrdb3(out2)
        return x + out3 * 0.2

class RRDBNet(nn.Module):
    '''ESRGAN Generator, which consists of 23 Residual in Residual Dense Blocks
    paper : https://arxiv.org/pdf/1809.00219
    '''
    def __init__(self, in_channels=3, num_channels=64, num_blocks=23, clip_output=False):
        super().__init__()
        self.conv1 = get_layer(in_channels, num_channels)
        self.conv2 = get_layer(num_channels, num_channels)
        self.conv3 = get_layer(num_channels, num_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)
        self.output = get_layer(num_channels, in_channels)
        self.first_ups = get_layer(num_channels, num_channels, is_upsample=True)
        self.second_ups = get_layer(num_channels, num_channels, is_upsample=True)
        self.rrdb = nn.Sequential(*[Residual_in_ResidualBlock(num_channels) for _ in range(num_blocks)])
        self.clip_output = clip_output

    def forward(self, x):
        res = self.conv1(x)
        x = self.rrdb(res)
        x = self.conv2(x)
        x = x + res
        x = self.first_ups(x)
        x = self.second_ups(x)
        x = self.act(self.conv3(x))
        if self.clip_output:
            x = self.output(x).clip(-1, 1)
        else:
            x = self.output(x)
        return x
```
</details>


         
To see the effect of Semantic Aware Fusion Block , we also implemented Vanilla Patch-wise Discriminator and Vanilla Pixel-wise Discriminator.

<details>
  <summary>Patch-wise Discriminator</summary>

  ```python
#Vanilla patchgan discriminator
class PatchDiscriminator(nn.Module):
    def __init__(self, input_channels, num_filters=64):
        super().__init__()
        #Downsample the input size from 256x256 to 16x16
        self.downsampler = DownSampler(input_channels, num_filters)
        self.final_conv = nn.Conv2d(num_filters * 8, 1, kernel_size=4, stride=1, padding=1)
        
    def forward(self, fs):
        fs = self.downsampler(fs)
        fs = self.final_conv(fs)
        return fs
```
</details>



<details>
  <summary>Pixel-wise Discriminator </summary>

  ```python
class UNetPixelDiscriminator(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, num_filters=64):
        super(UNetPixelDiscriminator, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            self._conv_block(in_channels, num_filters),
            self._conv_block(num_filters, num_filters),
            self._conv_block(num_filters, num_filters * 2),
            self._conv_block(num_filters * 2, num_filters * 4),
        )

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(num_filters * 4, num_filters * 4, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Decoder
        self.decoder = nn.Sequential(
            self._upconv_block(num_filters * 4, num_filters * 4),
            self._upconv_block(num_filters * 4, num_filters * 2),
            self._upconv_block(num_filters * 2, num_filters),
            self._upconv_block(num_filters, num_filters),
            nn.Conv2d(num_filters, out_channels, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def _conv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(out_channels),
        )

    def _upconv_block(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, fs):
        # Encoder
        # fs = 64
        enc1 = self.encoder[0](fs) # 32x32x64
        enc2 = self.encoder[1](enc1) # 16x16x64
        enc3 = self.encoder[2](enc2) # 8x8x128
        enc4 = self.encoder[3](enc3) # 4x4x256

        #Bottleneck
        bottleneck = self.bottleneck(enc4) # 2x2x256

        #Decoder with skip connections using addition
        dec = self.decoder[0](bottleneck) # 4x4x256
        dec = self.decoder[1](dec + enc4) # 8x8x128
        dec = self.decoder[2](dec + enc3) # 16x16x64
        dec = self.decoder[3](dec + enc2) # 32x32x64
        dec = self.decoder[4](dec + enc1) # 64x64x1

        return dec
```
</details>


### **Authors:**  Yigit Ekin and Mustafa Utku Aydogdu
### **Mail:** e270207@metu.edu.tr e270206@metu.edu.tr

### **Our Assumptions:**
Implementing a super resolution model based solely on a paper, without access to the accompanying code, was challenging due to the complexities of understanding and implementing the loss function, architecture, and performance metrics described in the paper. Dealing with dimensionality inconsistencies in paper. Some are listed below.

* we assumed that the group normalization has 32 groups (not stated in the paper)
* we assumed that the conv block in patchwise discriminator is a  convolution block that doubles the channel size and with kernel_size of 4, stride=2 and padding=1 followed by a batch normalization block followed by a leaky relu block (not included in the last convolution block) which is not stated in the paper.
* They did not specified the adverserial loss function details. As a result, we have decided to go with wassertein loss with gradient penalty to achieve a more stable training.
* They did not specify how they have preprocessed the dataset. As a result, due to small number of images in the dataset, we have decided to conduct a literature survey on how different models have overcome this issue and found that ESRGAN does combine 2 datasets and crops random patches from each image to increase the number of images.
* We have decided to move with crop size of 400 for hr images and 100 for lr images. This means that during training our model inputs 100x100 crops and tries to generate 400x400 hr version of it.
* For cross attention, we have decided to use single head attention rather than multi head attention
* CLIP preprocessor normally downscales the image to 224x224 before extracting embeddings. We believed that this can downgrade the performance w.r.t hr images as a result, we did not use this preprocessor.
* To obtain same spatial dimensionality with the clip embeddings (for concatenation specified in the image below in part d), we added extra convolution layer that did not change the channel size but decreases the spatial dimensions.
* The authors did not describe the weight (lambda) values of the loss functions as a result, we have decided to go with 1 for mse and 10 for gradient penalty in wasserstein loss
* The authors did not specify whether they have used multi-head attention or single head attention. As a result, we have decided to go with single head attention because we thought it should be sufficient enough.
* The authors did not specify the dimensionlity of attention head. So, we have decided to go with 128 as this will result in 8 times less memory usage. 
* For the coefficients of the losses (i.e VGG, adverserial, MSE), we have conducted several experiments and the current setup in the config files are the ones that have achieved the best scores. One thing that we have tried to keep constant is the ratio between the losses. For example, if the VGG loss is 0.1 times the adverserial loss, we have tried to keep this ratio constant in all experiments by changing the coefficients.

## Hyper-parameters of your model

We aim to compare the effect of SeD discriminator with vanilla discriminator. As a result, we have two different training setups. Before reading the hyperparameters, please note that the hyperparameters are the same for both models except for the discriminator part. In addition, the losses used for the model can be seen from the image below where L_s is VGG  perceptual loss, L_p is the pixelwise MSE loss and L_adv is the adverserial loss.


<img src="img/losses.png"> <br/> <br/>
The hyperparameters of the models are as follows:

### Vanilla Discriminator
- **Accelerator**: 'gpu'
- **Device**: 'cuda'
- **PL Trainer**:
  - `max_epochs`: 1000
  - `accelerator`: 'gpu'
  - `log_every_n_steps`: 50
  - `strategy`: DDPStrategy(find_unused_parameters=True)
  - `devices`: Number of available CUDA devices (determined by `torch.cuda.device_count()`)
  - `sync_batchnorm`: True
- **Train Batch Size**: 16
- **Validation Batch Size**: 8
- **Test Batch Size**: 8
- **Image Size**: 256
- **Dataset Module**:
  - `num_workers`: 4
  - `train_batch_size`: 16
  - `val_batch_size`: 8
  - `test_batch_size`: 8
  - **Train Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/dataset_cropped/hr"
    - `image_dir_lr`: "data/dataset_cropped/lr"
    - `downsample_factor`: 4 (downsampling factor for low-resolution images)
    - `mirror_augment_prob`: 0.5 (probability of applying mirroring w.r.t. y axis as a data augmentation)
  - **Validation Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/evaluation/hr/manga109"
    - `image_dir_lr`: "data/evaluation/lr/manga109"
  - **Test Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/evaluation/hr/manga109"
    - `image_dir_lr`: "data/evaluation/lr/manga109"
- **Losses**:
  - **VGG**:
    - `weight`: 5e-5 
    - `model_config`:
      - `path`: "pretrained_models/vgg16.pth"
      - `output_layer_idx`: 23 (index of the layer to extract features from)
      - `resize_input`: False
  - **Adversarial_G**:
    - `weight`: 1.0
  - **MSE**:
    - `weight`: 1.0
  - **Adversarial_D**:
    - `r1_gamma`: 10.0 (constant for wasserstein GP)
    - `r2_gamma`: 0.0 (constant for wasserstein GP)
- **Super Resolution Module Configuration**:
  - `generator_learning_rate`: 1e-4
  - `discriminator_learning_rate`: 1e-5
  - `generator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `discriminator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `generator_decay_gamma`: 0.5
  - `discriminator_decay_gamma`: 0.5
  - `clip_generator_outputs`: False (whether to clip generator outputs to valid pixel range [-1,1])
  - `use_sed_discriminator`: False (whether to use SeD discriminator)

### SeD Discriminator
- **Accelerator**: 'gpu'
- **Device**: 'cuda'
- **PL Trainer**:
  - `max_epochs`: 1000
  - `accelerator`: 'gpu'
  - `log_every_n_steps`: 50
  - `strategy`: DDPStrategy(find_unused_parameters=True)
  - `devices`: Number of available CUDA devices (determined by `torch.cuda.device_count()`)
  - `sync_batchnorm`: True
- **Train Batch Size**: 16
- **Validation Batch Size**: 8
- **Test Batch Size**: 8
- **Image Size**: 256
- **Dataset Module**:
  - `num_workers`: 4
  - `train_batch_size`: 16
  - `val_batch_size`: 8
  - `test_batch_size`: 8
  - **Train Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/dataset_cropped/hr"
    - `image_dir_lr`: "data/dataset_cropped/lr"
    - `downsample_factor`: 4
    - `mirror_augment_prob`: 0.5 (probability of applying mirroring w.r.t. y axis as a data augmentation)
  - **Validation and Test Dataset Configuration**:
    - `image_size`: 256
    - `image_dir_hr`: "data/evaluation/hr/manga109"
    - `image_dir_lr`: "data/evaluation/lr/manga109"
- **Losses**:
  - **VGG**:
    - `weight`: 5e-5
    - `model_config`:
      - `path`: "pretrained_models/vgg16.pth"
      - `output_layer_idx`: 23 (index of the layer to extract features from)
      - `resize_input`: False
  - **Adversarial_G**:
    - `weight`: 1.0
  - **MSE**:
    - `weight`: 1.0
  - **Adversarial_D**:
    - `r1_gamma`: 10.0 (constant for wasserstein GP)
    - `r2_gamma`: 0.0 (constant for wasserstein GP)
- **Super Resolution Module Configuration**:
  - `generator_learning_rate`: 1e-4
  - `discriminator_learning_rate`: 1e-5
  - `generator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `discriminator_decay_steps`: [50_000, 100_000, 150_000, 200_000, 250_000]
  - `generator_decay_gamma`: 0.5
  - `discriminator_decay_gamma`: 0.5
  - `clip_generator_outputs`: False (whether to clip generator outputs to valid pixel range [-1,1])
  - `use_sed_discriminator`: True (whether to use SeD discriminator)

## Training and saving of the model.

### Training with SeD

#### **IMPORTANT NOTE:** the training of the model is done on a remote server where we have not used jupyter notebook. Normally, scripts in the first 3 cells are used to train the model. However, in order to not overly crowd the jupyter notebook for the reviewers, we have included the code that is responsible for training but the training logs will be displayed in the last cell of this section named as training loop which abstracts all this logic

PLEASE DO NOT CHANGE THE FILE STRUCTURE THAT THE SUBMISSION HAS PROVIDED. THIS CAN CAUSE ERRORS IN THE TRAINING OF THE MODEL.

### TRAINING CONFIG

```python
import torch
from pytorch_lightning.strategies import DDPStrategy

accelerator = 'gpu'
device = torch.device("cuda") if accelerator=="gpu" else torch.device("cpu")
if accelerator == 'cpu':
    pl_trainer = dict(max_epochs=1000, accelerator=accelerator, log_every_n_steps=50, strategy=DDPStrategy(find_unused_parameters=True), devices=1, sync_batchnorm=True) # CHECK sync_batchnorm in this and below part !!!
else:
    pl_trainer = dict(max_epochs=1000, accelerator=accelerator, log_every_n_steps=50, strategy=DDPStrategy(find_unused_parameters=True), devices=torch.cuda.device_count(), sync_batchnorm=True)  # CHECK strategy and find_unused_parameters!!!

train_batch_size = 16
val_batch_size = 8
test_batch_size = 8

image_size = 256


###########################
##### Dataset Configs #####
###########################

dataset_module = dict(
    num_workers=4,
    train_batch_size=train_batch_size,
    val_batch_size=val_batch_size,
    test_batch_size=test_batch_size,
    train_dataset_config=dict(image_size=256, image_dir_hr="data/dataset_cropped/hr", image_dir_lr="data/dataset_cropped/lr", downsample_factor=4,mirror_augment_prob=0.5),
    val_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
    test_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
)

##################
##### Losses #####
##################
vgg_ckpt_path="pretrained_models/vgg16.pth"
loss_dict = dict(
    VGG=dict(weight=5e-5, model_config=dict(path=vgg_ckpt_path, output_layer_idx=23, resize_input=False)),
    Adversarial_G=dict(weight=1.0),
    MSE=dict(weight=1.0),
    Adversarial_D=dict(r1_gamma=10.0, r2_gamma=0.0)
)

#########################
##### Model Configs #####
#########################

super_resolution_module_config = dict(loss_dict=loss_dict, 
    generator_learning_rate=1e-4, discriminator_learning_rate=1e-5, 
    generator_decay_steps=[50_000, 100_000, 150_000, 200_000, 250_000], 
    discriminator_decay_steps=[50_000, 100_000, 150_000, 200_000, 250_000], 
    generator_decay_gamma=0.5, discriminator_decay_gamma=0.5,
    clip_generator_outputs=False,
    use_sed_discriminator=True)

#######################
###### Callbacks ######
#######################

ckpt_callback = dict(every_n_train_steps=4000, save_top_k=1, save_last=True, monitor='fid_test', mode='min')
synthesize_callback_train = dict(num_samples=12, eval_every=2000) # TODO: 4000
synthesize_callback_test = dict(num_samples=6, eval_every=2000)
fid_callback = dict(eval_every=4000)
```

### PYTORCH LIGHTNING ALLOWS US TO USE CALLBACK FUNCTIONS DURING TRAINING. HENCE, WE HAVE USED CALLBACKS TO SAVE WEIGHTS OF THE MODEL. THE CALLBACK IS THE FOLLOWING:

```python
from pytorch_lightning.callbacks import ModelCheckpoint
```

### Training loop for the following Models:
- Patch-wise Discriminator with SeD
- Vanilla Patch-wise Discriminator
- Pixel-wise Discriminator with SeD
- Vanilla Pixel-wise Discriminator


In [6]:
CFG="configs/patchgan_sed.py" #Training of patchgan discriminator with SeD

!python train.py --config_file=$CFG  # --resume_from logs/sed

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | generator     | RRDBNet                   | 15.4 M
1 | discriminator | PatchDiscriminatorWithSeD | 4.7 M 
2 | clip          | CLIPRN50                  | 23.4 M
------------------------------------------------------------
20.1 M    Trainable params
23.4 M    Non-trainable params
43.5 M    Total params
173.867   Total estimated model para

In [7]:
CFG="configs/patchgan.py" #Training of patchgan discriminator without SeD

!python train.py --config_file=$CFG #--debug # --resume_from logs/sed

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | generator     | RRDBNet            | 15.4 M
1 | discriminator | PatchDiscriminator | 2.8 M 
2 | clip          | CLIPRN50           | 23.4 M
-----------------------------------------------------
18.2 M    Trainable params
23.4 M    Non-trainable params
41.5 M    Total params
166.180   Total estimated model params size (MB)
SLURM auto-requeueing enabled

In [8]:
CFG="configs/pixelwise_sed.py" #Training of pixelwise discriminator with SeD

!python train.py --config_file=$CFG #--debug # --resume_from logs/sed

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                          | Params
----------------------------------------------------------------
0 | generator     | RRDBNet                       | 15.4 M
1 | discriminator | UNetPixelDiscriminatorwithSed | 3.2 M 
2 | clip          | CLIPRN50                      | 23.4 M
----------------------------------------------------------------
18.6 M    Trainable params
23.4 M    Non-trainable params
42.0 M    Total params
167.802   To

In [None]:
CFG="configs/pixelwise.py" #Training of pixelwise discriminator without SeD

!python train.py --config_file=$CFG #--debug # --resume_from logs/sed

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                   | Params
---------------------------------------------------------
0 | generator     | RRDBNet                | 15.4 M
1 | discriminator | UNetPixelDiscriminator | 3.5 M 
2 | clip          | CLIPRN50               | 23.4 M
---------------------------------------------------------
19.0 M    Trainable params
23.4 M    Non-trainable params
42.3 M    Total params
169.295   Total estimated model params size (MB)
SLURM

<details>
  <summary><h3>Training Details:</h3></summary>

  <details>
      <summary><h4>Patchgan SeD:</h4></summary>
      <div style="display:flex; justify-content:center; align-items:center;">
      <img src="img/adv_l_patchgan_sed.png" style="width: 600px; height:auto;"/>
      <img src="img/adv_g_patchgan_sed.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/mse_l_patchgan_sed.png" style="width: 600px; height:auto;"/>
      <img src="img/vgg_l_patchgan_sed.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/lpips_patchgan_sed.png" style="width: 800px; height:auto;"/>
      </div>
  </details>

  <details>
      <summary><h4>Vanilla Patchgan:</h4></summary>
      <div style="display:flex; justify-content:center; align-items:center;">
      <img src="img/adv_l_patchgan.png" style="width: 600px; height:auto;"/>
      <img src="img/adv_g_patchgan.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/mse_l_patchgan.png" style="width: 600px; height:auto;"/>
      <img src="img/vgg_l_patchgan.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/lpips_patchgan.png" style="width: 800px; height:auto;"/>
      </div>
  </details>
  <details>
      <summary><h4>Pixelwise SeD:</h4></summary>
      <div style="display:flex; justify-content:center; align-items:center;">
      <img src="img/adv_l_px_sed.png" style="width: 600px; height:auto;"/>
      <img src="img/adv_g_px_sed.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/mse_l_px_sed.png" style="width: 600px; height:auto;"/>
      <img src="img/vgg_l_px_sed.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/lpips_px_sed.png" style="width: 800px; height:auto;"/>
      </div>
  </details>
  <details>
      <summary><h4>Vanilla Pixelwise Discriminator:</h4></summary>
      <div style="display:flex; justify-content:center; align-items:center;">
      <img src="img/adv_l_px.png" style="width: 600px; height:auto;"/>
      <img src="img/adv_g_px.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/mse_l_px.png" style="width: 600px; height:auto;"/>
      <img src="img/vgg_l_px.png" style="width: 600px; height:auto;"/>
      </div>
      <div style="display:flex; justify-content:center; align-items:center; margin-top:50px;">
      <img src="img/lpips_px.png" style="width: 800px; height:auto;"/>
      </div>
  </details>
</details>

## Loading a pre-trained model and computing qualitative samples/outputs from that model.

### **IMPORTANT! loss curves and outputs of the training process is logged under the directory /logs during training loop in order to display the results, one can execute the command below at the terminal in order to display the results**

```bash
tensorboard --logdir=logs/<experiment_name_under_the_directory>
```

### To make reviewer's job easier, we have provided the code needed to load a pretrained model and compute qualitative samples from the model but added the tensorboard logs from the training loop that had been executed from the training cells from the previous section.

In [2]:
from models.super_resolution_module import SuperResolutionModule
from datasets.dataset_module import DatasetModule
from tqdm import tqdm
from PIL import Image
import torch 
import numpy as np
import os

def postprocess_image(image, min_val=-1.0, max_val=1.0):
    image = image.astype(np.float64)
    image = np.clip(image, -1, 1)
    image = (image - min_val) * 255 / (max_val - min_val)
    image = image.astype(np.uint8)
    # image = np.clip(image + 0.5, 0, 255).astype(np.uint8)
    image = image.transpose(1, 2, 0)
    return image

def generate_results(ckpt, image_dir_hr, image_dir_lr, save_path):
    model = SuperResolutionModule.load_from_checkpoint(ckpt) 
    train_batch_size = 1  # given so that each image is processed by itself
    val_batch_size = 1 # given so that each image is processed by itself
    test_batch_size = 1 # given so that each image is processed by itself

    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_module = dict(
        num_workers=4,
        train_batch_size=train_batch_size,
        val_batch_size=val_batch_size,
        test_batch_size=test_batch_size,
        train_dataset_config=dict(image_size=256, image_dir_hr=image_dir_hr, image_dir_lr=image_dir_lr, downsample_factor=4),
        val_dataset_config=dict(image_size=256, image_dir_hr=image_dir_hr, image_dir_lr=image_dir_lr),
        test_dataset_config=dict(image_size=256, image_dir_hr=image_dir_hr, image_dir_lr=image_dir_lr),
    )


    data_module_gt = DatasetModule(**dataset_module)
    data_module_gt.setup('test')
    dataloader = data_module_gt.test_dataloader()

    os.makedirs("results", exist_ok=True)
    os.makedirs(f"results/{save_path}", exist_ok=True)
    cnt = 0
    for batch in tqdm(dataloader, desc=f"Calculating FID on SR images", total=len(dataloader)):
        sr_images = model.make_high_resolution(batch)
        sr_images = sr_images ['generated_super_resolution_image'].to(device)
        #save the sr images to the "sr_pngs" folder
        for i in range(len(sr_images)):
            img = sr_images[i]
            
            img = postprocess_image(img.detach().cpu().numpy())
            img = Image.fromarray(img)
            img.save(f"results/{save_path}/{cnt}.png")
            cnt += 1
    
torch.manual_seed(1256)
np.random.seed(1256)
ckpt="logs/2024-05-28_16-19-08_patchgan_sed/checkpoint/last.ckpt"
ckpt2="logs/2024-05-28_16-10-08_pixelwise_sed/checkpoint/last.ckpt"
ckpt3="logs/2024-05-28_16-12-06_pixelwise/checkpoint/last.ckpt"
ckpt4="logs/2024-05-28_16-17-05_patchgan/checkpoint/last.ckpt"

generate_results_dict = {
    "patchgan": ckpt4,
    "patchgan_sed": ckpt,
    "pixelwise": ckpt3,
    "pixelwise_sed": ckpt2,
} 

image_path_hr = "/kuacc/users/hpc-yekin/hpc_run/sed/test_images_fig3/downscaled"
image_path_lr = "/kuacc/users/hpc-yekin/hpc_run/sed/test_images_fig3/downscaled"

for key, item in generate_results_dict.items():
    generate_results(item, image_path_hr, image_path_lr, key)


Calculating FID on SR images: 100%|██████████| 6/6 [00:04<00:00,  1.44it/s]
Calculating FID on SR images: 100%|██████████| 6/6 [00:00<00:00,  7.53it/s]
Calculating FID on SR images: 100%|██████████| 6/6 [00:00<00:00,  7.95it/s]
Calculating FID on SR images: 100%|██████████| 6/6 [00:00<00:00,  7.20it/s]


### Results and logs of training SeD patchwise discriminator

### Results and logs of training patchwise discriminator


## Reproducing results

as explained in the goals.txt an updated version will be provided with the trained model weights as soon as possible. The incident that happened is known by gökberk hoca.

In [None]:
"""
StarGAN v2
Copyright (c) 2020-present NAVER Corp.
This work is licensed under the Creative Commons Attribution-NonCommercial
4.0 International License. To view a copy of this license, visit
http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""

#WE HAVE IMPLEMENTED THIS CODE BLOCK BY USING THE REFERENCE AT THE TOP AS A GUIDANCE

import torch
import numpy as np
from tqdm import tqdm
from datasets.dataset_module import DatasetModule
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
import torch
import numpy as np
from tqdm import tqdm
from losses.lpips.lpips import LPIPS
import torch.nn.functional as F

def print_metrics_given_path(path):
    print("calculating metrics for " + path)
    train_batch_size = 2  # given as temporary data
    val_batch_size = 2 # given as temporary data
    test_batch_size = 2 # given as temporary data
    
    
    ################ lpips
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lpips_model = LPIPS(net_type='alex', device=device).to('cpu')
    lpips_model.eval()
    image_size = 256
    dataset_module_gt = dict(
        num_workers=4,
        train_batch_size=train_batch_size,
        val_batch_size=val_batch_size,
        test_batch_size=test_batch_size,
        train_dataset_config=dict(image_size=256, image_dir_hr="data/dataset_cropped/hr", image_dir_lr="data/dataset_cropped/lr", downsample_factor=4,mirror_augment_prob=0),
        val_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
        test_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
    )
    
    dataset_module_gt = DatasetModule(**dataset_module_gt)
    dataset_module_gt.setup('test')
    first_dataloader = dataset_module_gt.test_dataloader()
    
    
    dataset_module_sr = dict( #UPDATE DIRS
        num_workers=4,
        train_batch_size=train_batch_size,
        val_batch_size=val_batch_size,
        test_batch_size=test_batch_size,
        train_dataset_config=dict(image_size=256, image_dir_hr="data/dataset_cropped/hr", image_dir_lr="data/dataset_cropped/lr", downsample_factor=4,mirror_augment_prob=0),
        val_dataset_config=dict(image_size=256, image_dir_hr="data/evaluation/hr/manga109", image_dir_lr="data/evaluation/lr/manga109"),
        test_dataset_config=dict(image_size=256, image_dir_hr=path, image_dir_lr="data/evaluation/lr/manga109/"),
    )
    
    data_module_sr = DatasetModule(**dataset_module_sr)
    data_module_sr.setup('test')
    second_dataloader = data_module_sr.test_dataloader()
    
    def get_lpips_mean(dataloader1,dataloader2,lpips_model,device,dataset_type):
        lpips_model.to(device)
        lpips_list = []
        with torch.no_grad():
            for batch1,batch2 in tqdm(zip(dataloader1,dataloader2), desc=f"Calculating {dataset_type} LPIPS on sr images", total=len(dataloader1)):
                gt_images = batch1["image_hr"].to(device) * 0.5 + 0.5
                sr_images = batch2["image_hr"].to(device) * 0.5 + 0.5
                lpips = lpips_model(gt_images, sr_images, return_similarity=True)
                lpips_list.append(lpips.cpu())
        lpips_list = torch.cat(lpips_list).numpy()
        lpips_mean = np.nanmean(lpips_list)
        lpips_model.to('cpu')
        return lpips_mean
    
    
    
    lpips_mean = get_lpips_mean(first_dataloader,second_dataloader,lpips_model,device,"lpips")
    
    print("lpips: ",lpips_mean)
    
    #### SSIM
    
    def ssim(img1, img2):
        # Calculate SSIM (Structural Similarity Index)
        ssim_val = torch.mean((2 * img1 * img2 + 1e-8) * (2 * torch.abs(img1 - img2) + 1e-8) / (img1**2 + img2**2 + 1e-8), dim=(1, 2, 3))
        return ssim_val
    
    def get_ssim_mean(dataloader1,dataloader2,ssim,device,dataset_type):
        ssim_list = []
        with torch.no_grad():
            for batch1,batch2 in tqdm(zip(dataloader1,dataloader2), desc=f"Calculating {dataset_type} SSIM on sr images", total=len(dataloader1)):
                gt_images = batch1["image_hr"].to(device) * 0.5 + 0.5
                sr_images = batch2["image_hr"].to(device) * 0.5 + 0.5
                ssim_val = ssim(sr_images, gt_images)
                ssim_list.append(ssim_val.cpu())
        ssim_list = torch.cat(ssim_list).numpy()
        ssim_mean = np.nanmean(ssim_list)
        return ssim_mean
    
    ssim_mean = get_ssim_mean(first_dataloader,second_dataloader,ssim,device,"ssim")
    print("ssim: ",ssim_mean)
    
    #### PSNR
    
    def psnr(img1, img2, max_val=1.0):
        # Convert images to float tensors
        img1 = img1.float()
        img2 = img2.float()
        
        max_val = img1.max()
        # Calculate MSE (Mean Squared Error)
        mse = F.mse_loss(img1, img2)
        
        # Calculate PSNR (Peak Signal-to-Noise Ratio)
        psnr = 20 * torch.log10(max_val / torch.sqrt(mse))
        
        return psnr.item()
    
    def get_psnr_mean(dataloader1,dataloader2,device,dataset_type):
        psnr_list = []
        with torch.no_grad():
            for batch1,batch2 in tqdm(zip(dataloader1,dataloader2), desc=f"Calculating {dataset_type} PSNR on sr images", total=len(dataloader1)):
                gt_images = batch1["image_hr"].to(device) * 0.5 + 0.5
                sr_images = batch2["image_hr"].to(device) * 0.5 + 0.5
                psnr_val = psnr(sr_images, gt_images)
                psnr_list.append(psnr_val)
        psnr_mean = np.nanmean(psnr_list)
        return psnr_mean
    
    
    psnr_mean = get_psnr_mean(first_dataloader,second_dataloader,device,"psnr")
    print("psnr: ",psnr_mean)

print_metrics_given_path("patchgan/")
print_metrics_given_path("patchgansed/")
print_metrics_given_path("pixelwise/")
print_metrics_given_path("pixelwise_sed/")


calculating metrics for patchgan/


Calculating lpips LPIPS on sr images: 100%|██████████| 55/55 [00:01<00:00, 41.71it/s]


lpips:  0.6864215


Calculating ssim SSIM on sr images: 100%|██████████| 55/55 [00:01<00:00, 42.31it/s]

ssim:  0.37364906



Calculating psnr PSNR on sr images: 100%|██████████| 55/55 [00:01<00:00, 41.70it/s]


psnr:  7.709990440715443
calculating metrics for patchgansed/


## Challenges we have encountered when implementing the paper