Connected to base (Python 3.12.2)

In [1]:
"""
## Imports
"""

import torch
from torch import nn
from transformers import AutoImageProcessor, MobileNetV1Model

from PIL import Image
import os
import numpy as np
import torchvision.transforms as transforms
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
# import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import wandb
import torch

SEED = 0

# set seed for all possible random functions to ensure reproducibility
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic=True

DEVICE = "cuda:0"
print(f"Using {DEVICE} DEVICE",flush=True)

  from .autonotebook import tqdm as notebook_tqdm


Using cuda:0 DEVICE


In [2]:
""" ## Imports
"""

import torch
from torch import nn
from transformers import AutoImageProcessor, MobileNetV1Model

from PIL import Image
import os
import numpy as np
import torchvision.transforms as transforms
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
# import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import wandb
import torch

SEED = 0

# set seed for all possible random functions to ensure reproducibility
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic=True

DEVICE = "cuda:0"
print(f"Using {DEVICE} DEVICE",flush=True)

Using cuda:0 DEVICE


In [3]:
""" ## Model Definition
"""

def get_encoder_layers():
    """
    Retrieves the layers of the MobileNetV1 model for encoding images.

    Returns:
        mobilenet_seq_blocks (list): List containing the layers of the MobileNetV1 model
            divided into blocks.
        conv_stem (torch.nn.Module): The stem convolutional layer of the MobileNetV1 model.
        image_processor (transformers.AutoImageProcessor): Pretrained image processor for
            MobileNetV1.
    """
    # download the model
    image_processor = AutoImageProcessor.from_pretrained(
        "google/mobilenet_v1_1.0_224")
    model = MobileNetV1Model.from_pretrained("google/mobilenet_v1_1.0_224")

    mobilenet_seq_blocks = []
    # block 1 will contain 4 layer of model.layer
    block = nn.Sequential(*list(model.layer)[:2])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[2:4])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[4:8])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[8:12])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[12:])
    mobilenet_seq_blocks.append(block)
    
    return mobilenet_seq_blocks, model.conv_stem, image_processor

print("Getting encoder layers...",flush=True)
print(get_encoder_layers())
print("Encoder layers retrieved successfully.",flush=True)

Getting encoder layers...
([Sequential(
  (0): MobileNetV1ConvLayer(
    (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
    (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
    (activation): ReLU6()
  )
  (1): MobileNetV1ConvLayer(
    (convolution): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (normalization): BatchNorm2d(64, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
    (activation): ReLU6()
  )
), Sequential(
  (0): MobileNetV1ConvLayer(
    (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), groups=64, bias=False)
    (normalization): BatchNorm2d(64, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
    (activation): ReLU6()
  )
  (1): MobileNetV1ConvLayer(
    (convolution): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (normalization): BatchNorm2d(128, eps=0.001, momentum=0.9997, affine=True, t

In [4]:
""" ## Model Definition
"""

def get_encoder_layers():
    """
    Retrieves the layers of the MobileNetV1 model for encoding images.

    Returns:
        mobilenet_seq_blocks (list): List containing the layers of the MobileNetV1 model
            divided into blocks.
        conv_stem (torch.nn.Module): The stem convolutional layer of the MobileNetV1 model.
        image_processor (transformers.AutoImageProcessor): Pretrained image processor for
            MobileNetV1.
    """
    # download the model
    image_processor = AutoImageProcessor.from_pretrained(
        "google/mobilenet_v1_1.0_224")
    model = MobileNetV1Model.from_pretrained("google/mobilenet_v1_1.0_224")

    mobilenet_seq_blocks = []
    # block 1 will contain 4 layer of model.layer
    block = nn.Sequential(*list(model.layer)[:2])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[2:4])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[4:8])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[8:12])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[12:])
    mobilenet_seq_blocks.append(block)
    
    return mobilenet_seq_blocks, model.conv_stem, image_processor

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expansion=3, do_up_sampling=True):
        # TODO: adding skip connections
        """
        Decoder block module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            expansion (int, optional): Expansion factor. Default is 3.
        """
        super(DecoderBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)

        self.cnn1 = nn.Conv2d(in_channels, in_channels *
                              expansion, kernel_size=1, stride=1)
        self.bnn1 = nn.BatchNorm2d(in_channels*expansion)

        # nearest neighbor x2
        self.do_up_sampling = do_up_sampling
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # DW conv/ c_in*exp x 5 x 5 x c_in*exp
        self.cnn2 = nn.Conv2d(in_channels*expansion, in_channels *
                              expansion, kernel_size=5, padding=2, stride=1)
        self.bnn2 = nn.BatchNorm2d(in_channels*expansion)

        self.cnn3 = nn.Conv2d(in_channels*expansion,
                              out_channels, kernel_size=1, stride=1)
        self.bnn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """
        Forward pass through the decoder block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x = self.cnn1(x)
        x = self.bnn1(x)
        x = self.relu(x)

        if self.do_up_sampling:
            x = self.upsample(x)

        x = self.cnn2(x)
        x = self.bnn2(x)
        x = self.relu(x)

        x = self.cnn3(x)
        x = self.bnn3(x)

        return x

def get_decoder_layers(out_sizes=[512, 256, 128, 64, 32]):
    decoder_blocks = nn.ModuleList()
    for i, out_size in enumerate(out_sizes):
        if i == 0:
            decoder_blocks.append(DecoderBlock(out_size*2, out_size))
        elif i == len(out_sizes)-1:
            decoder_blocks.append(DecoderBlock(
                out_size*4, out_size, do_up_sampling=False))
        else:
            decoder_blocks.append(DecoderBlock(out_size*4, out_size))
    return decoder_blocks

print(get_decoder_layers())

ModuleList(
  (0): DecoderBlock(
    (relu): ReLU(inplace=True)
    (cnn1): Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))
    (bnn1): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (cnn2): Conv2d(3072, 3072, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (bnn2): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (cnn3): Conv2d(3072, 512, kernel_size=(1, 1), stride=(1, 1))
    (bnn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): DecoderBlock(
    (relu): ReLU(inplace=True)
    (cnn1): Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))
    (bnn1): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (cnn2): Conv2d(3072, 3072, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (bnn2): BatchNorm2d(30

In [5]:
""" ## Model Definition
"""

def get_encoder_layers():
    """
    Retrieves the layers of the MobileNetV1 model for encoding images.

    Returns:
        mobilenet_seq_blocks (list): List containing the layers of the MobileNetV1 model
            divided into blocks.
        conv_stem (torch.nn.Module): The stem convolutional layer of the MobileNetV1 model.
        image_processor (transformers.AutoImageProcessor): Pretrained image processor for
            MobileNetV1.
    """
    # download the model
    image_processor = AutoImageProcessor.from_pretrained(
        "google/mobilenet_v1_1.0_224")
    model = MobileNetV1Model.from_pretrained("google/mobilenet_v1_1.0_224")

    mobilenet_seq_blocks = nn.ModuleList()
    # block 1 will contain 4 layer of model.layer
    block = nn.Sequential(*list(model.layer)[:2])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[2:4])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[4:8])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[8:12])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[12:])
    mobilenet_seq_blocks.append(block)
    
    return mobilenet_seq_blocks, model.conv_stem, image_processor

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expansion=3, do_up_sampling=True):
        # TODO: adding skip connections
        """
        Decoder block module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            expansion (int, optional): Expansion factor. Default is 3.
        """
        super(DecoderBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)

        self.cnn1 = nn.Conv2d(in_channels, in_channels *
                              expansion, kernel_size=1, stride=1)
        self.bnn1 = nn.BatchNorm2d(in_channels*expansion)

        # nearest neighbor x2
        self.do_up_sampling = do_up_sampling
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # DW conv/ c_in*exp x 5 x 5 x c_in*exp
        self.cnn2 = nn.Conv2d(in_channels*expansion, in_channels *
                              expansion, kernel_size=5, padding=2, stride=1)
        self.bnn2 = nn.BatchNorm2d(in_channels*expansion)

        self.cnn3 = nn.Conv2d(in_channels*expansion,
                              out_channels, kernel_size=1, stride=1)
        self.bnn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """
        Forward pass through the decoder block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x = self.cnn1(x)
        x = self.bnn1(x)
        x = self.relu(x)

        if self.do_up_sampling:
            x = self.upsample(x)

        x = self.cnn2(x)
        x = self.bnn2(x)
        x = self.relu(x)

        x = self.cnn3(x)
        x = self.bnn3(x)

        return x

def get_decoder_layers(out_sizes=[512, 256, 128, 64, 32]):
    decoder_blocks = nn.ModuleList()
    for i, out_size in enumerate(out_sizes):
        if i == 0:
            decoder_blocks.append(DecoderBlock(out_size*2, out_size))
        elif i == len(out_sizes)-1:
            decoder_blocks.append(DecoderBlock(
                out_size*4, out_size, do_up_sampling=False))
        else:
            decoder_blocks.append(DecoderBlock(out_size*4, out_size))
    return decoder_blocks

print(get_decoder_layers())

ModuleList(
  (0): DecoderBlock(
    (relu): ReLU(inplace=True)
    (cnn1): Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))
    (bnn1): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (cnn2): Conv2d(3072, 3072, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (bnn2): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (cnn3): Conv2d(3072, 512, kernel_size=(1, 1), stride=(1, 1))
    (bnn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): DecoderBlock(
    (relu): ReLU(inplace=True)
    (cnn1): Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))
    (bnn1): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (cnn2): Conv2d(3072, 3072, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (bnn2): BatchNorm2d(30

In [6]:
""" ## Model Definition
"""

def get_encoder_layers():
    """
    Retrieves the layers of the MobileNetV1 model for encoding images.

    Returns:
        mobilenet_seq_blocks (list): List containing the layers of the MobileNetV1 model
            divided into blocks.
        conv_stem (torch.nn.Module): The stem convolutional layer of the MobileNetV1 model.
        image_processor (transformers.AutoImageProcessor): Pretrained image processor for
            MobileNetV1.
    """
    # download the model
    image_processor = AutoImageProcessor.from_pretrained(
        "google/mobilenet_v1_1.0_224")
    model = MobileNetV1Model.from_pretrained("google/mobilenet_v1_1.0_224")

    mobilenet_seq_blocks = nn.ModuleList()
    # block 1 will contain 4 layer of model.layer
    block = nn.Sequential(*list(model.layer)[:2])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[2:4])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[4:8])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[8:12])
    mobilenet_seq_blocks.append(block)

    block = nn.Sequential(*list(model.layer)[12:])
    mobilenet_seq_blocks.append(block)
    
    return mobilenet_seq_blocks, model.conv_stem, image_processor

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expansion=3, do_up_sampling=True):
        # TODO: adding skip connections
        """
        Decoder block module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            expansion (int, optional): Expansion factor. Default is 3.
        """
        super(DecoderBlock, self).__init__()
        self.relu = nn.ReLU(inplace=True)

        self.cnn1 = nn.Conv2d(in_channels, in_channels *
                              expansion, kernel_size=1, stride=1)
        self.bnn1 = nn.BatchNorm2d(in_channels*expansion)

        # nearest neighbor x2
        self.do_up_sampling = do_up_sampling
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # DW conv/ c_in*exp x 5 x 5 x c_in*exp
        self.cnn2 = nn.Conv2d(in_channels*expansion, in_channels *
                              expansion, kernel_size=5, padding=2, stride=1)
        self.bnn2 = nn.BatchNorm2d(in_channels*expansion)

        self.cnn3 = nn.Conv2d(in_channels*expansion,
                              out_channels, kernel_size=1, stride=1)
        self.bnn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """
        Forward pass through the decoder block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x = self.cnn1(x)
        x = self.bnn1(x)
        x = self.relu(x)

        if self.do_up_sampling:
            x = self.upsample(x)

        x = self.cnn2(x)
        x = self.bnn2(x)
        x = self.relu(x)

        x = self.cnn3(x)
        x = self.bnn3(x)

        return x

def get_decoder_layers(out_sizes=[512, 256, 128, 64, 32]):
    decoder_blocks = nn.ModuleList()
    for i, out_size in enumerate(out_sizes):
        if i == 0:
            decoder_blocks.append(DecoderBlock(out_size*2, out_size))
        elif i == len(out_sizes)-1:
            decoder_blocks.append(DecoderBlock(
                out_size*4, out_size, do_up_sampling=False))
        else:
            decoder_blocks.append(DecoderBlock(out_size*4, out_size))
    return decoder_blocks

print(get_decoder_layers())
print(get_encoder_layers())

ModuleList(
  (0): DecoderBlock(
    (relu): ReLU(inplace=True)
    (cnn1): Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))
    (bnn1): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (cnn2): Conv2d(3072, 3072, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (bnn2): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (cnn3): Conv2d(3072, 512, kernel_size=(1, 1), stride=(1, 1))
    (bnn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): DecoderBlock(
    (relu): ReLU(inplace=True)
    (cnn1): Conv2d(1024, 3072, kernel_size=(1, 1), stride=(1, 1))
    (bnn1): BatchNorm2d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (upsample): Upsample(scale_factor=2.0, mode='nearest')
    (cnn2): Conv2d(3072, 3072, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (bnn2): BatchNorm2d(30

In [7]:
""" ## Model Definition """

class UnetWithoutAT(nn.Module):
    def __init__(self, loss_weights=(10,5), lr=0.5):
        super(UnetWithoutAT, self).__init__()
        encoder_blocks, image_stem_layer, image_processor = get_encoder_layers()
        decoder_blocks = get_decoder_layers()

        self.encoder_blocks = encoder_blocks
        self.decoder_blocks = decoder_blocks

        self.image_processor = image_processor
        self.image_stem_layer = image_stem_layer
        
        self.out_image_stem_layer = nn.Sequential(
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2,
                               bias=False, padding=1, output_padding=1),
            nn.BatchNorm2d(3, eps=0.001, momentum=0.9997,
                           affine=True, track_running_stats=True),
            nn.ReLU()
        )

        self.loss_weights = loss_weights
        self.loss1 = nn.L1Loss()
        # self.loss2 = nn.
        
        self.metric = SSIM()
        

    def forward(self, x, process_image=False):
        """
        Performs forward pass through the U-Net model.

        Args:
            x (torch.Tensor): Input image tensor.
            process_image (bool): Whether to preprocess input image.

        Returns:
            torch.Tensor: Output image tensor.
        """
        if process_image:
            new_x = [self.image_processor(img)['pixel_values'][0] for img in x]
            x = torch.stack(new_x).permute(0, 3, 1, 2)
        assert x.shape[1] == 3, "Input image should have 3 channels(nx3x224x224)"

        temp_in_x = x
        x = self.image_stem_layer(x)
        enc_outputs = []
        for indx, enc_block in enumerate(self.encoder_blocks):
            x = enc_block(x)
            print(f"Encoder block {indx} output shape: {x.shape}")
            enc_outputs.append(x)
        for indx, dec_block in enumerate(self.decoder_blocks):
            if indx == 0:
                x = dec_block(x)
            else:
                x = dec_block(
                    torch.cat([x, enc_outputs[len(self.decoder_blocks) - indx - 1]], dim=1))
            print(f"Decoder block {indx} output shape: {x.shape}")
        return 0.5 * self.out_image_stem_layer(x) + temp_in_x
    
    def loss_layer(self,y_pred,y):
        '''
        Considering wieghted loss
        '''
        return self.loss_weights[0]*self.loss1(y_pred, y)
    
    def loss(self,x,y):
        with torch.no_grad():
            pred = self(x)
            return self.loss_layer(pred,y).item()

    def accuracy(self,x,y):
        with torch.no_grad():
            pred = self(x)
            return self.metric(pred,y)

    def fit(self, x_train, y_train, batch_size=32, n_epochs=3, validation=False, x_val=None, y_val=None):
        for epoch in tqdm(range(n_epochs)):
            #mini-batch gradient descent
            for i in range((x_train.shape[0] - 1) // batch_size + 1):
                start_i = i * batch_size
                end_i = start_i + batch_size
                xb = x_train[start_i:end_i]
                yb = y_train[start_i:end_i]
                pred = self(xb)

                self.optimizer.zero_grad()
                loss = self.loss_layer(pred,yb)
               
                metrices = {"train/loss": loss.item(),
                           "epoch": epoch+1, 
                }
                if validation:
                    metrices['val/loss'] = self.loss(x_val,y_val)
                    metrices['val/accuracy'] = self.accuracy(x_val,y_val)
                    
                wandb.log(metrices)
                
                loss.backward()
                self.optimizer.step()
        
        wandb.finish()
        return self.loss(x_train,y_train)

In [8]:
""" ## Model Definition """

class UnetWithoutAT(nn.Module):
    def __init__(self, loss_weights=(10,5), lr=0.5):
        super(UnetWithoutAT, self).__init__()
        encoder_blocks, image_stem_layer, image_processor = get_encoder_layers()
        decoder_blocks = get_decoder_layers()

        self.encoder_blocks = encoder_blocks
        self.decoder_blocks = decoder_blocks

        self.image_processor = image_processor
        self.image_stem_layer = image_stem_layer
        
        self.out_image_stem_layer = nn.Sequential(
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2,
                               bias=False, padding=1, output_padding=1),
            nn.BatchNorm2d(3, eps=0.001, momentum=0.9997,
                           affine=True, track_running_stats=True),
            nn.ReLU()
        )

        self.loss_weights = loss_weights
        self.loss1 = nn.L1Loss()
        # self.loss2 = nn.
        
        self.metric = SSIM()
        

    def forward(self, x, process_image=False):
        """
        Performs forward pass through the U-Net model.

        Args:
            x (torch.Tensor): Input image tensor.
            process_image (bool): Whether to preprocess input image.

        Returns:
            torch.Tensor: Output image tensor.
        """
        if process_image:
            new_x = [self.image_processor(img)['pixel_values'][0] for img in x]
            x = torch.stack(new_x).permute(0, 3, 1, 2)
        assert x.shape[1] == 3, "Input image should have 3 channels(nx3x224x224)"

        temp_in_x = x
        x = self.image_stem_layer(x)
        enc_outputs = []
        for indx, enc_block in enumerate(self.encoder_blocks):
            x = enc_block(x)
            print(f"Encoder block {indx} output shape: {x.shape}")
            enc_outputs.append(x)
        for indx, dec_block in enumerate(self.decoder_blocks):
            if indx == 0:
                x = dec_block(x)
            else:
                x = dec_block(
                    torch.cat([x, enc_outputs[len(self.decoder_blocks) - indx - 1]], dim=1))
            print(f"Decoder block {indx} output shape: {x.shape}")
        return 0.5 * self.out_image_stem_layer(x) + temp_in_x
    
    def loss_layer(self,y_pred,y):
        '''
        Considering wieghted loss
        '''
        return self.loss_weights[0]*self.loss1(y_pred, y)
    
    def loss(self,x,y):
        with torch.no_grad():
            pred = self(x)
            return self.loss_layer(pred,y).item()

    def accuracy(self,x,y):
        with torch.no_grad():
            pred = self(x)
            return self.metric(pred,y)

    def fit(self, train_loader, n_epochs=3, validation=False, val_loader=None, print_every_epoch=True):
        
        if print_every_epoch: print("Training started")
        
        for epoch in (range(n_epochs)):
            if print_every_epoch: print(f"Epoch {epoch+1}/{n_epochs}")
            
            for x_train, y_train in tqdm(train_loader):
                y_pred = self(x_train)
                loss = self.loss_layer(y_pred,y_train)
                loss.backward()
                self.optimizer.zero_grad()
                loss = self.loss_layer(y_pred,y_train)
                
                metrices = {"train/loss": loss.item(),
                           "epoch": epoch+1, 
                }
                if validation:
                    metrices['val/loss'] = 0
                    metrices['val/accuracy'] = 0
                    for x_val, y_val in val_loader:
                        metrices['val/loss'] += self.loss(x_val,y_val)
                        metrices['val/accuracy'] += self.accuracy(x_val,y_val)
                    
                    metrices['val/loss'] /= len(val_loader)
                    metrices['val/accuracy'] /= len(val_loader)
                    
                    if print_every_epoch: print(f"train_loss: {metrices['train/loss']}, val_loss: {metrices['val/loss']}, val_accuracy: {metrices['val/accuracy']}")

                wandb.log(metrices)                
                loss.backward()
                self.optimizer.step()
                
        if print_every_epoch: print("Training finished")
        
        wandb.finish()
        return self.loss(x_train,y_train)