In [1]:
!pip install pytorch_msssim

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch_msssim
Successfully installed pytorch_msssim-1.0.0
[0m

In [2]:
#from kaggle_secrets import UserSecretsClient

import numpy as np
import os
import wandb
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch.nn import functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_msssim import ssim
from torchvision.transforms.functional import to_pil_image

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [7]:
# Weights & Biases is initialised in a later cell, right before training.

# Mini-batch size handed to the data module (used by every dataloader).
BATCH_SIZE = 16

# Each split provides two folders: 32_32 (network input) and 256_256 (target).

# --- training split ---
train_lr_path = '/kaggle/input/fairfaceupsample/train/32_32'
train_hr_path = '/kaggle/input/fairfaceupsample/train/256_256'

# --- validation split ---
val_lr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/validation/32_32'
val_hr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/validation/256_256'

# --- test split ---
test_lr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/test/32_32'
test_hr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/test/256_256'

In [8]:
class UpscalingDataset(Dataset):
    """Paired low-/high-resolution image dataset read from two folders.

    Both directory listings are sorted and images are paired by position, so
    the two folders are expected to hold matching files in the same order.

    Args:
        lr_folder: Directory containing the low-resolution images.
        hr_folder: Directory containing the high-resolution images.

    Raises:
        ValueError: If the two folders contain a different number of entries.
            Positional pairing would then be wrong, or would run out of
            low-resolution files part-way through an epoch.
    """

    def __init__(self, lr_folder, hr_folder):
        self.hr_folder = hr_folder
        self.lr_folder = lr_folder
        self.hr_images = sorted(os.listdir(hr_folder))
        self.lr_images = sorted(os.listdir(lr_folder))
        if len(self.hr_images) != len(self.lr_images):
            raise ValueError(
                f"Mismatched dataset folders: {len(self.lr_images)} entries in "
                f"{lr_folder!r} vs {len(self.hr_images)} entries in {hr_folder!r}"
            )

    def __len__(self):
        """Return the number of (low-res, high-res) pairs."""
        return len(self.hr_images)

    @staticmethod
    def _load_image(path):
        """Load ``path`` as an RGB float tensor of shape (3, H, W) scaled to [0, 1]."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the array.
        with Image.open(path) as img:
            pixels = np.array(img.convert('RGB'), dtype=np.float32)
        pixels /= 255.
        # HWC -> CHW, the layout the convolutional model consumes.
        return torch.tensor(pixels.transpose([2, 0, 1]), dtype=torch.float)

    def __getitem__(self, index):
        """Return the ``(low_res, high_res)`` tensor pair stored at ``index``."""
        lr_img_path = os.path.join(self.lr_folder, self.lr_images[index])
        hr_img_path = os.path.join(self.hr_folder, self.hr_images[index])
        return self._load_image(lr_img_path), self._load_image(hr_img_path)


class UpscalingDataModule(pl.LightningDataModule):
    """Lightning data module wiring the train/val/test folders into dataloaders.

    The folder locations are read from the module-level ``*_lr_path`` /
    ``*_hr_path`` variables defined in the configuration cell.

    Args:
        batch_size: Batch size used by every dataloader.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Create the datasets required for ``stage`` (all of them when ``None``)."""
        if stage in ('fit', None):
            self.upscaling_train = UpscalingDataset(train_lr_path, train_hr_path)
        # The validation set is needed both while fitting and for a standalone
        # validation run.
        if stage in ('fit', 'validate', None):
            self.upscaling_val = UpscalingDataset(val_lr_path, val_hr_path)
        if stage in ('test', None):
            self.upscaling_test = UpscalingDataset(test_lr_path, test_hr_path)

    def train_dataloader(self):
        """Shuffled loader over the training pairs."""
        return DataLoader(self.upscaling_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation pairs in a fixed order."""
        # No shuffling: evaluation order stays deterministic, so the first
        # sample of the first batch is the same image every epoch.
        return DataLoader(self.upscaling_val, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Loader over the test pairs in a fixed order."""
        return DataLoader(self.upscaling_test, batch_size=self.batch_size, shuffle=False)


UDLRN (PixelShuffle upsampling)

In [12]:
class ConvReluBlock(nn.Module):
    """Two 3x3 convolution + ReLU stages wrapped in an identity skip connection.

    The block input is added to the output of the second ReLU, so the addition
    only works when ``channelsin`` equals ``channelsout``.

    Args:
        channelsin: Number of channels of the incoming feature map.
        channelsout: Number of channels produced by both convolutions.
    """

    def __init__(self, channelsin, channelsout):
        super().__init__()
        # Shape-preserving 3x3 convolution settings shared by both layers.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, dilation=1)
        layers = [
            nn.Conv2d(channelsin, channelsout, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(channelsout, channelsout, **conv_kwargs),
            nn.ReLU(),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``conv_block(x) + x``."""
        return self.conv_block(x) + x

class AttentionLikeBlock(nn.Module):
    """Channel gating computed from globally pooled features.

    The input is average-pooled to a 1x1 map, passed through three parallel
    dilated 3x3 convolutions (dilation 3, 5 and 7), concatenated, and mapped
    back to ``channel`` sigmoid gates that rescale the input channel-wise.

    Args:
        channel: Number of channels of the input (and output) feature map.
        reduct: Reduction factor; each branch produces ``channel // reduct``
            channels.
    """

    def __init__(self, channel, reduct=4):
        super().__init__()
        self.globPool = nn.AdaptiveAvgPool2d(1)
        # NOTE: the pooled map is 1x1 and padding equals dilation, so only the
        # centre tap of each dilated kernel sees non-padding input.
        self.convo1 = nn.Sequential(
            nn.Conv2d(channel, channel//reduct, kernel_size = 3, stride=1, padding = 3, dilation=3),
            nn.ReLU()
        )
        self.convo2 = nn.Sequential(
            nn.Conv2d(channel, channel//reduct, kernel_size = 3, stride=1, padding = 5, dilation=5),
            nn.ReLU()
        )
        self.convo3 = nn.Sequential(
            nn.Conv2d(channel, channel//reduct, kernel_size = 3, stride=1, padding = 7, dilation=7),
            nn.ReLU()
        )

        self.convsig = nn.Sequential(
            nn.Conv2d((channel//reduct)*3,channel , kernel_size = 3, stride=1, padding = 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gates."""
        gd = self.globPool(x)
        # Each branch uses its own convolution. Previously ``convo2`` was
        # applied three times, leaving ``convo1`` and ``convo3`` unused, so
        # checkpoints trained before this fix hold untrained weights for them.
        r3 = self.convo1(gd)
        r5 = self.convo2(gd)
        r7 = self.convo3(gd)
        gp = torch.cat([r3, r5, r7], 1)
        sigm = self.convsig(gp)
        # sigm has shape (N, channel, 1, 1) and broadcasts over H and W.
        return x * sigm

class DRLM(nn.Module):
    """Densely concatenated residual stages, 1x1 compression, then channel gating.

    Every stage output is concatenated with that stage's input, so the channel
    count doubles after each stage. A 1x1 convolution compresses the result to
    ``chanelsout`` channels, the attention block rescales it, and the module
    input is added back.

    Args:
        channelsin: Channels of the incoming feature map.
        chanelsout: Channels after compression. The final residual addition
            requires this to match ``channelsin``.
        deapth: Number of ``ConvReluBlock`` stages.
    """

    def __init__(self, channelsin, chanelsout, deapth=3):
        super().__init__()
        stages = []
        for level in range(deapth):
            # Stage ``level`` receives 2**level times the base width.
            width = 2 ** level
            stages.append(ConvReluBlock(channelsin * width, chanelsout * width))
        self.baseblock = nn.ModuleList(stages)
        self.compresion = nn.Sequential(
            nn.Conv2d(
                channelsin * (2 ** len(self.baseblock)),
                chanelsout,
                kernel_size=1,
                padding=0,
                stride=1,
                dilation=1,
            ),
            nn.ReLU(),
        )
        self.attention = AttentionLikeBlock(channelsin)

    def forward(self, x):
        """Run the stages, compress, gate, and add the module input back."""
        identity = x
        features = x
        for stage in self.baseblock:
            # Stage output first, stage input second, along the channel axis.
            features = torch.cat((stage(features), features), 1)
        features = self.compresion(features)
        gated = self.attention(features)
        return identity + gated

class BaseBLock(nn.Module):
    """One cascade step: DRLM refinement followed by fusion with earlier features.

    Args:
        channels: Width of the feature map processed by the DRLM and returned
            by the fusion convolution.
        channelsoutin: Channel count of the concatenated tensor fed to the
            fusion convolution (DRLM output plus the running concatenation).
    """

    def __init__(self, channels, channelsoutin):
        super().__init__()
        self.drlm = DRLM(channels, channels)
        self.convoDown = nn.Sequential(
            nn.Conv2d(channelsoutin, channels, kernel_size=3, padding=1, stride=1, dilation=1),
            nn.ReLU(),
        )

    def forward(self, x, cat):
        """Return ``(fused_features, updated_concatenation)``."""
        refined = self.drlm(x)
        # Prepend the fresh DRLM output to the running concatenation.
        history = torch.cat((refined, cat), 1)
        fused = self.convoDown(history)
        return fused, history



class CascadeBlock(nn.Module):
    """Chain of ``BaseBLock`` steps sharing a growing feature concatenation.

    Step ``i`` is built for ``channels * (i + 2)`` concatenated input channels,
    because the running concatenation gains ``channels`` channels per step.

    Args:
        channels: Width of the feature map entering and leaving the block.
        deapth: Number of ``BaseBLock`` steps.
    """

    def __init__(self, channels, deapth=3):
        super().__init__()
        steps = [BaseBLock(channels, channels * (idx + 2)) for idx in range(deapth)]
        self.baseBlocks = nn.ModuleList(steps)

    def forward(self, x):
        """Run every step in order and add the block input to the result."""
        features, history = x, x
        for step in self.baseBlocks:
            features, history = step(features, history)
        return features + x

class UpsampleBlock(nn.Module):
    """8x spatial upsampling: 3x3 convolution, ReLU, then ``PixelShuffle(8)``.

    The convolution expands ``n_channels`` to ``n_channels * 8`` channels, and
    the pixel shuffle trades every 64 channels for an 8x8 spatial block. The
    output therefore has ``n_channels / 8`` channels at 8 times the height and
    width.

    Args:
        n_channels: Channels of the incoming feature map.
    """

    def __init__(self, n_channels):
        super().__init__()
        expand = nn.Conv2d(n_channels, n_channels * 8, 3, 1, 1)
        self.upscale = nn.Sequential(expand, nn.ReLU(inplace=True), nn.PixelShuffle(8))

    def forward(self, x):
        """Return the upsampled feature map."""
        return self.upscale(x)

class UDLRN(LightningModule):
    """Super-resolution network: cascaded residual blocks and 8x PixelShuffle upsampling.

    The input passes through a 3x3 convolution and ``deapth`` cascade blocks.
    The input features, every cascade output and a final long-skip sum are
    concatenated, upsampled 8x, and projected to 3 channels. The output is
    squashed into [0, 1] with ``(tanh(x) + 1) / 2``.

    Args:
        channels: Base feature width used throughout the network.
        deapth: Number of cascade blocks. Only 6 is supported, because the
            upsampler is sized for ``channels * 8`` concatenated channels.
        channelsin: Channels of the input image.
        learning_rate: Initial learning rate for Adam.

    Raises:
        ValueError: If ``deapth`` is not 6.
    """

    def __init__(self, channels, deapth=6, channelsin=3, learning_rate=1e-3):
        super().__init__()
        # forward() concatenates deapth + 2 maps of `channels` channels each,
        # while the upsampler below expects channels * 8 inputs. Fail early
        # rather than with a channel-mismatch error in the first forward pass.
        if deapth != 6:
            raise ValueError(
                f"UDLRN currently supports deapth=6 only (got {deapth}): the "
                f"upsampler expects channels * 8 concatenated input channels"
            )
        self.learning_rate = learning_rate
        self.inputConv = nn.Sequential(
            nn.Conv2d(channelsin, channels, kernel_size = 3, stride=1, padding = 1, dilation = 1),
            nn.ReLU()
        )
        self.cascadeBLocks = nn.ModuleList([CascadeBlock(channels) for i in range(deapth)])
        self.upscale  = UpsampleBlock(channels*8)
        self.tail = nn.Conv2d(channels, 3, 3, 1, 1)

    def forward(self, x):
        """Upscale a batch of low-resolution images; the output lies in [0, 1]."""
        x = self.inputConv(x)
        longSkip = x
        longcat = [x]
        for block in self.cascadeBLocks:
            x = block(x)
            longcat.append(x)
        x = x + longSkip
        longcat.append(x)
        out = torch.cat(longcat, 1)
        out = self.upscale(out)
        out = self.tail(out)
        # Map the unbounded convolution output into the [0, 1] image range.
        return (torch.tanh(out) + 1) / 2

    def training_step(self, batch, batch_idx):
        """Log the ``train_*`` metrics and return the MSE as the training loss."""
        lr_img, hr_img = batch
        out = self(lr_img)
        metrics = self.calculate_metrics(hr_img, out)
        self.log_dict({f'train_{k}': v for k, v in metrics.items()}, on_step=True, on_epoch=True)
        return metrics['mse']

    def validation_step(self, batch, batch_idx):
        """Log the ``val_*`` metrics and, every 5th epoch, one example image triplet."""
        lr_img, hr_img = batch
        out = self(lr_img)
        metrics = self.calculate_metrics(hr_img, out)
        self.log_dict({f'val_{k}': v for k, v in metrics.items()}, on_step=True, on_epoch=True)
        # Log a single triplet per logging epoch (first batch only) instead of
        # one per validation batch, and only when a W&B logger is attached.
        should_log_images = (
            batch_idx == 0
            and self.current_epoch % 5 == 0
            and isinstance(self.logger, WandbLogger)
        )
        if should_log_images:
            self.logger.experiment.log({
                "hr_images": wandb.Image(to_pil_image(hr_img[0])),
                "lr_images": wandb.Image(to_pil_image(lr_img[0])),
                "out_images": wandb.Image(to_pil_image(out[0])),
            })

    def test_step(self, batch, batch_idx):
        """Log the metrics of a test batch under the ``test_`` prefix."""
        lr_img, hr_img = batch
        out = self(lr_img)
        metrics = self.calculate_metrics(hr_img, out)
        # Use a dedicated prefix so test results never overwrite the ``val_*``
        # metrics that the checkpoint callback and LR scheduler monitor.
        self.log_dict({f'test_{k}': v for k, v in metrics.items()}, on_step=True, on_epoch=True)

    def calculate_metrics(self, high_res, low_res):
        """Compute MSE, PSNR and SSIM between a target and a prediction.

        Args:
            high_res: Ground-truth high-resolution batch.
            low_res: Batch to compare against it. Despite the name, the step
                methods pass the model output here.

        Returns:
            Dict with keys ``'mse'``, ``'psnr'`` and ``'ssim'``.
        """
        mse_loss = F.mse_loss(high_res, low_res)
        # PSNR for a peak signal value of 1.0, matching data_range below.
        psnr = 10 * torch.log10(1 / mse_loss)
        ssim_val = ssim(high_res, low_res, data_range=1.0)
        return {'mse': mse_loss, 'psnr': psnr, 'ssim': ssim_val}

    def configure_optimizers(self):
        """Adam with a ReduceLROnPlateau scheduler driven by ``val_mse``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        scheduler = {
            'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min'),
            'monitor': 'val_mse',
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_mse"}



# Smoke test: a single 32x32 RGB image must come out as a 256x256 RGB image.
model = UDLRN(channels = 64)
input_tensor = torch.rand(1, 3, 32, 32)
# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    output_tensor = model(input_tensor)
assert output_tensor.shape == (1, 3, 256, 256), f"unexpected output shape {tuple(output_tensor.shape)}"
print("Output shape:", output_tensor.shape)  # Output shape: torch.Size([1, 3, 256, 256])


       


Output shape: torch.Size([1, 3, 256, 256])


In [13]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the single best checkpoint, where "best" means the lowest val_mse.
checkpoint_callback = ModelCheckpoint(
    dirpath='/kaggle/working/models/',  # where checkpoint files are written
    filename='upscaling-{epoch:02d}-{val_mse:.4f}-{val_psnr:.3f}',  # file name template
    monitor='val_mse',  # metric compared between checkpoints
    mode='min',  # lower is better
    save_top_k=1,  # keep just the best one
)


In [14]:
# Build the data module and eagerly create the datasets for every stage.
data_module = UpscalingDataModule(batch_size=BATCH_SIZE)
data_module.setup(stage=None)

In [15]:
# Weights & Biases: start the run, then create the Lightning logger with the
# same entity and project.
_wandb_target = dict(entity='upscale-dudes', project='csc-hackathon-2023')
wandb.init(**_wandb_target)
wandb_logger = WandbLogger(**_wandb_target)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


  rank_zero_warn(


In [None]:
# Fresh network for the actual training run.
model = UDLRN(channels=64)

# Train for up to 100 epochs, checkpointing the best model and logging to W&B.
trainer = Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback],
    logger=wandb_logger,
)
trainer.fit(model, data_module)
# Optional evaluation on the test split (with fine-tuning):
# trainer.test(model, data_module)

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
import shutil
shutil.make_archive('model', 'zip', '/kaggle/working/')
%cd /kaggle/working
from IPython.display import FileLink
FileLink(r'model.zip')

In [None]:
# Template: evaluate a saved checkpoint on the test split (with fine-tuning).
# The constructor arguments are not stored in the checkpoint, so pass them again.
# model = UDLRN.load_from_checkpoint("/path/to/checkpoint.ckpt", channels=64)
# trainer.test(model, data_module)