In [1]:
import warnings

warnings.filterwarnings("ignore")

In [2]:
# Render matplotlib figures directly beneath the cells that produce them.
%matplotlib inline

In [3]:
import os
import sys

In [4]:
# Make the sibling ../src directory importable, resolved against the kernel's
# current working directory; only add it once so re-running the cell is a no-op.
module_path = os.path.abspath(os.path.join(os.pardir, "src"))

already_on_path = module_path in sys.path
if not already_on_path:
    sys.path.append(module_path)

In [5]:
import matplotlib.pyplot as plt

from PIL import Image
from natsort import natsorted

import scipy

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torch.nn.init import kaiming_normal_ as HeNormal
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

import torchvision
from torchvision import transforms


In [6]:
# Record the installed PyTorch build in the output so results can be tied to a version.
print(f"{torch.__version__}")

2.7.0+cu128


In [7]:
# Use the first CUDA device when one is visible; otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0") if cuda_ok else torch.device("cpu")
print(f'Using GPU: {device}' if cuda_ok else "Using CPU")

Using GPU: cuda:0


In [8]:
class KodakDataset(Dataset):
    """Image dataset over the image files found directly inside one directory.

    File names are listed once at construction time and ordered with
    ``natsorted`` (so ``kodim2.png`` precedes ``kodim10.png``).

    Args:
        root: Directory holding the images (not searched recursively).
        transform: Optional callable applied to each loaded PIL image.
        extensions: File suffix (or tuple of suffixes) to include, matched
            case-insensitively. Defaults to ``(".png", ".jpg")``.
    """

    def __init__(self, root, transform=None, extensions=(".png", ".jpg")):
        self.root = root
        self.transform = transform
        if isinstance(extensions, str):
            extensions = (extensions,)
        # str.endswith accepts a tuple; lower-case both sides so "IMG.PNG" is not
        # silently skipped.
        suffixes = tuple(ext.lower() for ext in extensions)
        self.images = [
            os.path.join(root, name)
            for name in natsorted(os.listdir(root))
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(root, name))
        ]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # Image.open is lazy and keeps the file handle open until pixel data is
        # read. Loading inside the context manager guarantees the handle is
        # released even when no transform forces a load; the loaded pixel data
        # stays usable after the with-block.
        with Image.open(img_path) as image:
            image.load()
        if self.transform:
            image = self.transform(image)
        return image

In [9]:
class ResidualBlock(nn.Module):
    """Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus an identity skip.

    The convolutions map ``in_channels -> out_channels -> out_channels -> in_channels``,
    so the result has the same channel count and spatial size as the input,
    which the identity addition requires.

    Args:
        in_channels: Channels of the input (and of the output).
        out_channels: Width of the inner bottleneck convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        # 1x1 convolutions must not be padded: padding=1 grows H and W by 2 each,
        # and the skip connection in forward() would then fail on a shape mismatch.
        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=(1,1), padding=0)
        # 3x3 with padding=1 (stride 1) keeps H and W unchanged.
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=(3,3), padding=1)
        self.conv2 = nn.Conv2d(out_channels, in_channels, kernel_size=(1,1), padding=0)

    def forward(self, x):
        """Return ``relu(x + conv2(relu(conv1(relu(conv0(x))))))``; same shape as ``x``."""
        residual = x

        x = F.relu(self.conv0(x))
        x = F.relu(self.conv1(x))
        x = self.conv2(x)

        out = torch.add(residual, x)

        return F.relu(out)

In [None]:
class MaskedConvolution(nn.Module):
    """Placeholder module: no submodules or ``forward`` are defined yet.

    TODO: implement; ``in_channels`` and ``out_channels`` are accepted but unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

In [10]:
class AttentionModule(nn.Module):
    """Residual attention block: a trunk branch gated by a sigmoid mask branch.

    Computes ``x + trunk(x) * sigmoid(conv1x1(mask(x)))``, where ``trunk`` and
    ``mask`` are each three stacked ``ResidualBlock(in_channels, out_channels)``.
    The element-wise product and the outer skip connection require every branch
    to keep the input's shape, so the 1x1 mask convolution is left unpadded.

    Args:
        in_channels: Channels of the input and of the output.
        out_channels: Bottleneck width passed through to each ``ResidualBlock``.
    """

    def __init__(self, in_channels, out_channels):
        super(AttentionModule, self).__init__()
        # Trunk branch (attribute names kept for state_dict compatibility).
        self.trunc0 = ResidualBlock(in_channels, out_channels)
        self.trunc1 = ResidualBlock(in_channels, out_channels)
        self.trunc2 = ResidualBlock(in_channels, out_channels)

        # Mask branch.
        self.atten0 = ResidualBlock(in_channels, out_channels)
        self.atten1 = ResidualBlock(in_channels, out_channels)
        self.atten2 = ResidualBlock(in_channels, out_channels)

        # padding=0: a padded 1x1 conv would grow H and W by 2 and break torch.mul below.
        self.conv0 = nn.Conv2d(in_channels, in_channels, kernel_size=(1,1), padding=0)

    def forward(self, x):
        """Gate the trunk features with the mask in (0, 1) and add the input back."""
        residual = x

        trunc = self.trunc0(residual)
        trunc = self.trunc1(trunc)
        trunc = self.trunc2(trunc)

        atten = self.atten0(residual)
        atten = self.atten1(atten)
        atten = self.atten2(atten)
        atten = self.conv0(atten)
        # torch.sigmoid replaces the deprecated F.sigmoid (identical values).
        atten = torch.sigmoid(atten)

        out = torch.mul(trunc, atten)
        out = torch.add(residual, out)

        return out

In [None]:
class GDN(nn.Module):
    """Placeholder module (presumably generalized divisive normalization — TODO confirm).

    No parameters or ``forward`` are defined yet; ``in_channels`` is accepted but unused.
    """

    def __init__(self, in_channels):
        super().__init__()

In [None]:
class AnalysisTransform(nn.Module):
    """Placeholder module: no submodules or ``forward`` are defined yet.

    TODO: implement; ``in_channels`` and ``out_channels`` are accepted but unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

In [None]:
class HyperAnalysisTransform(nn.Module):
    """Placeholder module: no submodules or ``forward`` are defined yet.

    TODO: implement; ``in_channels`` and ``out_channels`` are accepted but unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

In [None]:
class EntropyParameters(nn.Module):
    """Placeholder module: no submodules or ``forward`` are defined yet.

    TODO: implement; ``in_channels`` and ``out_channels`` are accepted but unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()


# Backward-compatible alias for the original, misspelled class name.
EntorpyParameters = EntropyParameters

In [None]:
class SynthesisTransform(nn.Module):
    """Placeholder module: no submodules or ``forward`` are defined yet.

    TODO: implement; ``in_channels`` and ``out_channels`` are accepted but unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

In [None]:
class HyperSynthesisTransform(nn.Module):
    """Placeholder module: no submodules or ``forward`` are defined yet.

    TODO: implement; ``in_channels`` and ``out_channels`` are accepted but unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

In [None]:
class Compressor(nn.Module):
    """Placeholder module: no submodules or ``forward`` are defined yet.

    TODO: implement; ``in_channels`` and ``out_channels`` are accepted but unused.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

In [None]:
class Decompressor(nn.Module):
    """Decompressor module (work in progress).

    NOTE(review): in the lines visible here, ``__init__`` only initialises
    ``nn.Module``; ``in_channels`` and ``out_channels`` are not used yet.
    """
    def __init__(self, in_channels, out_channels):
        super(Decompressor, self).__init__()