#### Imports

In [3]:
!pip install pytorch-msssim -q


[0m

In [4]:
!pip install --upgrade wandb -q


[0m

In [5]:
#from kaggle_secrets import UserSecretsClient
import wandb
import numpy as np
import os
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn
from torch.nn import functional as F
from pytorch_lightning import LightningModule, Trainer
from pytorch_msssim import ssim
from torchvision.transforms.functional import to_pil_image

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


#### Setups

In [6]:
# Dataset paths
#train
train_hr_path = '/kaggle/input/fairfaceupsample/train/256_256'
train_lr_path = '/kaggle/input/fairfaceupsample/train/32_32'

#val
val_hr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/validation/256_256'
val_lr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/validation/32_32'

#test
test_hr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/test/256_256'
test_lr_path = '/kaggle/input/fairface-lq-10/fairface_lq-lite_v2/test/32_32'


BATCH_SIZE = 32

#### Data modules

In [7]:
class UpscalingDataset(Dataset):
    def __init__(self, lr_folder, hr_folder):
        self.hr_folder = hr_folder
        self.lr_folder = lr_folder
        self.hr_images = sorted(os.listdir(hr_folder))
        self.lr_images = sorted(os.listdir(lr_folder))

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, index):
        hr_img_name = self.hr_images[index]
        lr_img_name = self.lr_images[index]
        hr_img_path = os.path.join(self.hr_folder, hr_img_name)
        lr_img_path = os.path.join(self.lr_folder, lr_img_name)

        hr_img = Image.open(hr_img_path).convert('RGB')
        lr_img = Image.open(lr_img_path).convert('RGB')

        lr_img = np.array(lr_img, dtype=np.float32)
        hr_img = np.array(hr_img, dtype=np.float32)
        
        lr_img /= 255.
        hr_img /= 255.
        lr_img = lr_img.transpose([2, 0, 1])
        hr_img = hr_img.transpose([2, 0, 1])

        return torch.tensor(lr_img, dtype=torch.float), torch.tensor(hr_img, dtype=torch.float)


class UpscalingDataModule(pl.LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            self.upscaling_train = UpscalingDataset(train_lr_path, train_hr_path)
            self.upscaling_val = UpscalingDataset(val_lr_path, val_hr_path)
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.upscaling_test = UpscalingDataset(test_lr_path, test_hr_path)

    def train_dataloader(self):
        return DataLoader(self.upscaling_train, batch_size=self.batch_size,shuffle = True)

    def val_dataloader(self):
        return DataLoader(self.upscaling_val, batch_size=self.batch_size,shuffle = True)

    def test_dataloader(self):
        return DataLoader(self.upscaling_test, batch_size=self.batch_size,shuffle = True)


In [8]:
class CNA(nn.Module):
    def __init__(self, in_nc, out_nc, stride=1):
        super().__init__()
        
        self.conv = nn.Conv2d(in_nc, out_nc, 3, stride=stride, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_nc)
        self.act = nn.GELU()
    
    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.act(out)
        
        return out

class UpscaleBlock(nn.Module):
    def __init__(self, in_nc, out_nc):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_nc, out_nc, kernel_size=4, stride=2, padding=1, bias=False)
        self.norm = nn.BatchNorm2d(out_nc)
        self.act = nn.GELU()
    
    def forward(self, x):
        out = self.upsample(x)
        out = self.norm(out)
        out = self.act(out)
        
        return out


In [9]:
class DeconvNet(LightningModule):
    def __init__(self, in_nc=3, nc=32, out_nc=3):
        super().__init__()
        self.learning_rate = 1e-3
        # Upscaling Blocks
        self.upscale_input = nn.Sequential(
            UpscaleBlock(in_nc, nc),
            CNA(nc, nc)  # Additional convolution layer for accuracy
        )
        self.upscale_blocks = nn.ModuleList()
        for i in range(2):
            self.upscale_blocks.append(UpscaleBlock(nc, nc))
            self.upscale_blocks.append(CNA(nc, nc))  # Additional convolution layer for accuracy
        self.cna_last = CNA(nc, nc)
        self.conv_last = nn.Conv2d(nc, out_nc, 3, padding=1)

    def forward(self, x):
        out  = self.upscale_input(x)
        for upscale_block in self.upscale_blocks:
            out = upscale_block(out)
        out = self.cna_last(out)
        out = self.conv_last(out)
        out = (torch.tanh(out)+1)/2
        return out

    def training_step(self, batch, batch_idx):
        lr_img, hr_img = batch
        out = self(lr_img)
        metrics = self.calculate_metrics(hr_img, out)
        self.log_dict({f'train_{k}': v for k, v in metrics.items()}, on_step=True, on_epoch=True)
        return metrics['mse']
    
    def validation_step(self, batch, batch_idx):
        lr_img, hr_img = batch
        out = self(lr_img)
        metrics = self.calculate_metrics(hr_img, out)
        self.log_dict({f'val_{k}': v for k, v in metrics.items()}, on_step=True, on_epoch=True)
        if self.current_epoch % 2 == 0:  # Log images every 5 epochs
            # Convert tensors to PIL Images
            hr_img = to_pil_image(hr_img[0])
            lr_img = to_pil_image(lr_img[0])
            out = to_pil_image(out[0])
            # Log images to wandb
            self.logger.experiment.log({
                "hr_images": wandb.Image(hr_img),
                "lr_images": wandb.Image(lr_img),
                "out_images": wandb.Image(out),
            })
            
#     def validation_step(self, batch, batch_idx):
#         lr_img, hr_img = batch
#         out = self(lr_img)
#         metrics = self.calculate_metrics(hr_img, out)
#         self.log_dict({f'val_{k}': v for k, v in metrics.items()}, on_step=True, on_epoch=True)

    def calculate_metrics(self, high_res, low_res):
        mse_loss = F.mse_loss(high_res, low_res)
        psnr = 10 * torch.log10(1 / mse_loss)
        ssim_val = ssim(high_res, low_res, data_range=1.0)
        return {'mse': mse_loss, 'psnr': psnr, 'ssim': ssim_val}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    
model = DeconvNet()
input_tensor = torch.rand(1, 3, 32, 32)
output_tensor = model(input_tensor)
print("Output shape:", output_tensor.shape)


Output shape: torch.Size([1, 3, 256, 256])


In [10]:
data_module = UpscalingDataModule(BATCH_SIZE)
data_module.setup('fit')
data_load = data_module.train_dataloader()
imagelr, imagehr = next(iter(data_load))
print(imagehr)

tensor([[[[0.2275, 0.2314, 0.2353,  ..., 0.6784, 0.6745, 0.6706],
          [0.2314, 0.2353, 0.2353,  ..., 0.6784, 0.6745, 0.6706],
          [0.2157, 0.2157, 0.2157,  ..., 0.6784, 0.6745, 0.6706],
          ...,
          [0.6000, 0.6000, 0.6000,  ..., 0.0314, 0.0275, 0.0235],
          [0.5961, 0.5961, 0.6039,  ..., 0.0275, 0.0275, 0.0235],
          [0.6000, 0.6000, 0.6078,  ..., 0.0275, 0.0275, 0.0275]],

         [[0.2392, 0.2431, 0.2471,  ..., 0.7412, 0.7373, 0.7333],
          [0.2431, 0.2471, 0.2471,  ..., 0.7412, 0.7373, 0.7333],
          [0.2392, 0.2392, 0.2392,  ..., 0.7412, 0.7373, 0.7333],
          ...,
          [0.3569, 0.3569, 0.3569,  ..., 0.0314, 0.0275, 0.0235],
          [0.3647, 0.3647, 0.3608,  ..., 0.0275, 0.0275, 0.0235],
          [0.3686, 0.3686, 0.3647,  ..., 0.0275, 0.0275, 0.0275]],

         [[0.3059, 0.3098, 0.3137,  ..., 0.8000, 0.7961, 0.7922],
          [0.3098, 0.3059, 0.3137,  ..., 0.8000, 0.7961, 0.7922],
          [0.2941, 0.2863, 0.2941,  ..., 0

#### Custom callbacks

In [12]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_mse',  # Specify the validation loss to monitor
    dirpath='/kaggle/working/models/',  # Directory where the models will be saved
    filename='upscaling-{epoch:02d}-{val_mse:.3f}-{val_psnr:.3f}',  # Template for the saved model's name
    save_top_k=1,  # Save only the best model
    mode='min',  # Minimize validation loss
)


#### Train

In [13]:

wandb_logger = WandbLogger(entity='upscale-dudes', project="csc-hackathon-2023")
# Wandb
wandb.init(entity='upscale-dudes', project='csc-hackathon-2023')
# Prepare data
data_module = UpscalingDataModule(BATCH_SIZE)

# Define your model
model = DeconvNet()

# Fit the model
trainer = pl.Trainer(max_epochs=70, callbacks=[checkpoint_callback], logger=wandb_logger)  # use wandb_logger for Weights & Biases logging
trainer.fit(model, data_module)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

#### Train template

In [14]:
import shutil
shutil.make_archive('model', 'zip', '/kaggle/working/')
%cd /kaggle/working
from IPython.display import FileLink
FileLink(r'model.zip')

/kaggle/working
