In [None]:
import torch
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Seed the RNGs the notebook draws from (numpy: NetE corruption masks;
# torch: weight init, DataLoader shuffling, bernoulli mask sampling) so a
# Restart & Run All is reproducible. cuDNN kernels may still introduce small
# run-to-run differences on GPU.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

sample_rate = 0.2  # target fraction of pixels kept by the sampling masks
n_threads = 2 # 4 #number of threads for data loader to use
batchSize = 64 # 32 # 128
nEpochs = 400 # 800
learning_rate = 0.002 #0.0002
beta1 = 0.5 # Adam momentum term
nef = 64 # number of encoder filters in first conv layer
imageSize = 32 # 64; NetE needs a multiple of 16
modelE_name = "model_best.pth"  # NetE checkpoint (in epochs_NetE/) loaded by NetME
B_residual_block = 8 # number of residual blocks in NetM
train_size_p = 0.9  # fraction of the subset used for training
val_size_p = 1-train_size_p

In [None]:
import torch
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms

# Map ToTensor's [0, 1] pixels to [-1, 1] per channel (matches NetE's tanh output).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load CIFAR-10. With download=True torchvision only verifies the files when
# they are already in ./data, and fetches them on a fresh machine instead of
# failing.
download = True
train_set = datasets.CIFAR10(root='./data', train=True, download=download, transform=transform)
test_set = datasets.CIFAR10(root='./data', train=False, download=download, transform=transform)

# Combine train and test sets, then keep only the first n_samples images.
# NOTE: ConcatDataset puts train_set first, so while n_samples <= len(train_set)
# the subset contains CIFAR-10 training images only.
n_samples = 10000
full_dataset = torch.utils.data.ConcatDataset([train_set, test_set])
subset = Subset(full_dataset, range(n_samples))

# Split the dataset (train_size_p for training, the rest for validation) with a
# fixed seed. Checkpoints trained here are reloaded in later sessions, so the
# split must be identical on every run; otherwise validation would contain
# images the saved models were trained on.
split_seed = 0
train_size = int(train_size_p * len(subset))
val_size = len(subset) - train_size
train_dataset, val_dataset = random_split(
    subset, [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
assert len(train_dataset) + len(val_dataset) == len(subset)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batchSize, shuffle=True, num_workers=n_threads)
val_loader = DataLoader(val_dataset, batch_size=batchSize, shuffle=False, num_workers=n_threads)

print("CIFAR-10 dataset loaded successfully!")
print(f"Number of training images: {len(train_dataset)}")
print(f"Number of validation images: {len(val_dataset)}")

# Example: Access a batch
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Draw one batch up front and show its first nine images in a 3x3 grid.
batch = next(iter(train_loader))
imgs, _ = batch

fig, axes = plt.subplots(3, 3, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
    # Normalize() put pixels in [-1, 1]; shift back to [0, 1] for imshow.
    ax.imshow((imgs[i].permute(1, 2, 0) + 1.0) / 2.0)
    ax.axis("off")
plt.tight_layout()
plt.show()


In [None]:
import torch.nn as nn
#import torch.nn.functional as F
import torch.nn.init as init
import torch


class AttrProxy(object):
    """Translate index lookups into attribute lookups.

    ``AttrProxy(module, "prefix_")[i]`` returns ``getattr(module, "prefix_" + str(i))``.
    NetE uses it to index its individually registered per-channel Linear
    layers (``channel_wise_layers_0``, ``channel_wise_layers_1``, ...).
    """

    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        return getattr(self.module, self.prefix + str(i))


class NetE(nn.Module):
    """Encoder / channel-wise FC / decoder network that reconstructs images.

    Four stride-2 convolutions shrink a (3, S, S) input to (nef*8, S/16, S/16).
    A separate fully connected layer then mixes the spatial positions inside
    each channel. Four stride-2 transposed convolutions bring it back to
    (3, S, S), with a tanh output.

    Args:
        nef: number of filters in the first encoder convolution.
        image_size: input side length S. Defaults to the notebook-level
            ``imageSize``. The decoder only restores the input size when S is
            divisible by 16.
    """

    def __init__(self, nef, image_size=None):
        super(NetE, self).__init__()
        if image_size is None:
            image_size = imageSize  # notebook-level setting (previous behaviour)

        # state size: 3 x S x S
        self.conv1 = nn.Conv2d(3, nef, (4, 4), (2, 2), (1, 1), bias=False)
        self.conv1_bn = nn.BatchNorm2d(nef)
        self.conv1_relu = nn.LeakyReLU(0.2, inplace=False)
        # state size: (nef) x S/2 x S/2
        self.conv2 = nn.Conv2d(nef, nef*2, (4, 4), (2, 2), (1, 1), bias=False)
        self.conv2_bn = nn.BatchNorm2d(nef*2)
        self.conv2_relu = nn.LeakyReLU(0.2, inplace=False)
        # state size: (nef*2) x S/4 x S/4
        self.conv3 = nn.Conv2d(nef*2, nef*4, (4, 4), (2, 2), (1, 1), bias=False)
        self.conv3_bn = nn.BatchNorm2d(nef*4)
        self.conv3_relu = nn.LeakyReLU(0.2, inplace=False)
        # state size: (nef*4) x S/8 x S/8
        self.conv4 = nn.Conv2d(nef*4, nef*8, (4, 4), (2, 2), (1, 1), bias=False)
        self.conv4_bn = nn.BatchNorm2d(nef*8)
        self.conv4_relu = nn.LeakyReLU(0.2, inplace=False)
        # state size: (nef*8) x S/16 x S/16

        # Channel-wise fully connected layer: one Linear over the flattened
        # (S/16)^2 spatial map of each of the nef*8 encoder channels. The
        # channel count was hard-coded to 512, which only matched nef=64.
        fla = (image_size // 16) ** 2
        for i in range(nef * 8):
            self.add_module('channel_wise_layers_' + str(i), nn.Linear(fla, fla))
        self.channel_wise_layers = AttrProxy(self, 'channel_wise_layers_')

        # state size: (nef*8) x S/16 x S/16
        self.dconv1 = nn.ConvTranspose2d(nef*8, nef*4, (4, 4), (2, 2), (1, 1), bias=False)
        self.dconv1_bn = nn.BatchNorm2d(nef*4)
        self.dconv1_relu = nn.ReLU(inplace=True)
        # state size: (nef*4) x S/8 x S/8
        self.dconv2 = nn.ConvTranspose2d(nef*4, nef*2, (4, 4), (2, 2), (1, 1), bias=False)
        self.dconv2_bn = nn.BatchNorm2d(nef*2)
        self.dconv2_relu = nn.ReLU(inplace=True)
        # state size: (nef*2) x S/4 x S/4
        self.dconv3 = nn.ConvTranspose2d(nef*2, nef, (4, 4), (2, 2), (1, 1), bias=False)
        self.dconv3_bn = nn.BatchNorm2d(nef)
        self.dconv3_relu = nn.ReLU(inplace=True)
        # state size: (nef) x S/2 x S/2
        self.dconv4 = nn.ConvTranspose2d(nef, 3, (4, 4), (2, 2), (1, 1), bias=False)
        self.dconv4_tanh = nn.Tanh()
        # state size: 3 x S x S

        self._initialize_weights()

    def forward(self, x):
        """Reconstruct a (N, 3, S, S) batch; the result passes through tanh."""
        x = self.conv1_relu(self.conv1_bn(self.conv1(x)))
        x = self.conv2_relu(self.conv2_bn(self.conv2(x)))
        x = self.conv3_relu(self.conv3_bn(self.conv3(x)))
        x = self.conv4_relu(self.conv4_bn(self.conv4(x)))

        x = self._channel_wise_fc(x)

        x = self.dconv1_relu(self.dconv1_bn(self.dconv1(x)))
        x = self.dconv2_relu(self.dconv2_bn(self.dconv2(x)))
        x = self.dconv3_relu(self.dconv3_bn(self.dconv3(x)))
        x = self.dconv4_tanh(self.dconv4(x))
        return x

    def _channel_wise_fc(self, x):
        """Apply ``channel_wise_layers_i`` to the flattened spatial map of channel i.

        Same result as calling each nn.Linear per channel, but done as a
        single batched contraction over the stacked layer parameters instead
        of one Linear call plus one in-place slice write per channel. The
        channel count is read from ``x``, so models pickled with the old class
        layout still work.
        """
        n, c, h, w = x.shape
        layers = [self.channel_wise_layers[i] for i in range(c)]
        weight = torch.stack([layer.weight for layer in layers])  # (c, out=h*w, in=h*w)
        bias = torch.stack([layer.bias for layer in layers])      # (c, h*w)
        flat = x.reshape(n, c, h * w)
        # nn.Linear computes y_k = sum_j W[k, j] * x_j + b_k, per channel c.
        out = torch.einsum('ncj,ckj->nck', flat, weight) + bias
        return out.reshape(n, c, h, w)

    def _initialize_weights(self):
        """Init BN scales ~ N(1, 0.02), BN shifts = 0 and conv weights ~ N(0, 0.02).

        The channel-wise Linear layers keep PyTorch's default initialisation.
        """
        init.normal_(self.conv1_bn.weight,  1.0, 0.02)
        init.normal_(self.conv2_bn.weight,  1.0, 0.02)
        init.normal_(self.conv3_bn.weight,  1.0, 0.02)
        init.normal_(self.conv4_bn.weight,  1.0, 0.02)
        init.normal_(self.dconv1_bn.weight, 1.0, 0.02)
        init.normal_(self.dconv2_bn.weight, 1.0, 0.02)
        init.normal_(self.dconv3_bn.weight, 1.0, 0.02)

        init.constant_(self.conv1_bn.bias,    0.0)
        init.constant_(self.conv2_bn.bias,    0.0)
        init.constant_(self.conv3_bn.bias,    0.0)
        init.constant_(self.conv4_bn.bias,    0.0)
        init.constant_(self.dconv1_bn.bias,   0.0)
        init.constant_(self.dconv2_bn.bias,   0.0)
        init.constant_(self.dconv3_bn.bias,   0.0)

        init.normal_(self.conv1.weight,  0.0, 0.02)
        init.normal_(self.conv2.weight,  0.0, 0.02)
        init.normal_(self.conv3.weight,  0.0, 0.02)
        init.normal_(self.conv4.weight,  0.0, 0.02)
        init.normal_(self.dconv1.weight, 0.0, 0.02)
        init.normal_(self.dconv2.weight, 0.0, 0.02)
        init.normal_(self.dconv3.weight, 0.0, 0.02)
        init.normal_(self.dconv4.weight, 0.0, 0.02)

In [None]:
!pip install tensorboard_logger

In [None]:
from tensorboard_logger import configure

# Point tensorboard_logger at this run's log directory. The log_value calls in
# the training cells are commented out, so nothing is written here yet.
configure(
    "tensorBoardRuns/"
    "on-demand-learn-p-02-zero-corrupt-0-conv-bias-0-cwfc-epoch-800"
)

## Train NetE

In [None]:
import torch.optim as optim
from math import log10
#from tensorboard_logger import log_value
import os
from tqdm import tqdm
import numpy as np

# Build NetE from the config-cell `nef` (it was hard-coded to 64 here).
model = NetE(nef=nef)
criterion = nn.MSELoss()

if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

print('===> Total Model NetE Parameters:', sum(param.numel() for param in model.parameters()))

print('===> Initialize Optimizer...')
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, 0.999))

os.makedirs("epochs_NetE", exist_ok=True)
os.makedirs("tensorBoardRuns", exist_ok=True)

# Per-epoch history consumed by the plotting cells.
train_loss = []
train_psnr = []
val_loss = []
val_psnr = []
def train(epoch):
    """Train NetE for one epoch on randomly corrupted training images.

    Every pixel (all three channels together) is zeroed with probability
    ``1 - sample_rate``, and NetE learns to reconstruct the clean image. Loss
    and PSNR are computed after mapping both images from [-1, 1] to [0, 1].
    Appends the epoch's mean loss/PSNR to ``train_loss`` / ``train_psnr``.
    """
    epoch_loss = 0
    epoch_psnr = 0
    lr = learning_rate  # constant: no learning-rate decay is applied

    # val() may have left the model in eval mode; BatchNorm must use batch
    # statistics while training.
    model.train()

    # Reuse the module-level Adam `optimizer`. Re-creating it each epoch would
    # reset Adam's running moment estimates at every epoch boundary.
    for target, _ in tqdm(train_loader):
        n, _, h, w = target.shape
        # Corrupt the target image: True marks a dropped pixel. The
        # (n, 1, h, w) mask broadcasts over the RGB channels.
        corrupt_mask = torch.from_numpy(
            np.random.binomial(1, 1 - sample_rate, (n, 1, h, w)).astype(bool))
        image = target.masked_fill(corrupt_mask, 0.0)

        if torch.cuda.is_available():
            image = image.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        loss = criterion((model(image) + 1.0) / 2.0, (target + 1.0) / 2.0)
        epoch_loss += loss.item()
        epoch_psnr += 10 * log10(1 / loss.item())
        loss.backward()
        optimizer.step()

    n_batches = len(train_loader)
    print("===> Epoch {} Complete: lr: {}, Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB".format(epoch, lr, epoch_loss / n_batches, epoch_psnr / n_batches))

    train_loss.append(epoch_loss / n_batches)
    train_psnr.append(epoch_psnr / n_batches)

PSNR_best = 0

def val(epoch):
    """Evaluate NetE on randomly corrupted validation images.

    Uses the same corruption as ``train`` (each pixel zeroed with probability
    ``1 - sample_rate``). Appends the mean loss/PSNR to ``val_loss`` /
    ``val_psnr``. When the mean PSNR beats ``PSNR_best``, saves the whole model
    to ``epochs_NetE/model_best.pth``.
    """
    global PSNR_best

    avg_psnr = 0
    avg_mse = 0

    # Evaluate with BatchNorm running statistics (the mode NetE is used in
    # later) and without building autograd graphs. Restore the previous mode
    # afterwards so training is unaffected.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        for target, _ in val_loader:
            n, _, h, w = target.shape
            # True marks a dropped pixel; the mask broadcasts over RGB.
            corrupt_mask = torch.from_numpy(
                np.random.binomial(1, 1 - sample_rate, (n, 1, h, w)).astype(bool))
            image = target.masked_fill(corrupt_mask, 0.0)

            if torch.cuda.is_available():
                image = image.cuda()
                target = target.cuda()

            prediction = model(image)
            mse = criterion((prediction + 1.0) / 2.0, (target + 1.0) / 2.0)
            avg_psnr += 10 * log10(1 / mse.item())
            avg_mse += mse.item()
    model.train(was_training)

    n_batches = len(val_loader)
    print("===> Epoch {} Validation: Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB".format(epoch, avg_mse / n_batches, avg_psnr / n_batches))

    val_loss.append(avg_mse / n_batches)
    val_psnr.append(avg_psnr / n_batches)

    mean_psnr = avg_psnr / n_batches
    if mean_psnr > PSNR_best:
        PSNR_best = mean_psnr
        model_out_path = "epochs_NetE/" + "model_best.pth"
        torch.save(model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

def checkpoint(epoch):
    """Pickle the whole NetE model to epochs_NetE/model_epoch_<epoch>.pth every 100 epochs."""
    if epoch % 100:
        return
    model_out_path = "epochs_NetE/model_epoch_{}.pth".format(epoch)
    torch.save(model, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

# Epoch 0 is a baseline (validate + checkpoint only); epochs 1..nEpochs train first.
for epoch in range(0, nEpochs + 1):
    if epoch > 0:
        train(epoch)
    val(epoch)
    checkpoint(epoch)

## Load the trained NetE (best checkpoint)

In [None]:
import os

# Reload the best NetE checkpoint written by val(). It is a whole pickled
# module (torch.save(model, ...)), so weights_only=False is required on
# PyTorch >= 2.6, where weights_only defaults to True. Only load trusted
# files, because unpickling can execute arbitrary code. map_location lets a
# GPU-saved checkpoint load on a CPU-only machine.
if os.path.exists("epochs_NetE/" + "model_best.pth"):
    model = torch.load(
        "epochs_NetE/" + "model_best.pth",
        map_location="cuda" if torch.cuda.is_available() else "cpu",
        weights_only=False,
    )
    print("Model loaded")

In [None]:
import matplotlib.pyplot as plt

def plot_loss(train_losses, val_losses,titre):
    """Plot a per-epoch training curve against its validation curve.

    Args:
        train_losses: per-epoch training values; the first entry is plotted at epoch 1.
        val_losses: per-epoch validation values, also starting at epoch 1.
            They may be shorter or longer than train_losses (e.g. an interrupted run).
        titre: metric name ("loss", "psnr", "sparsity", ...). Used in the
            title, the y-axis label and the legend.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(range(1, len(train_losses) + 1), train_losses, label="Train " + titre)
    ax.plot(range(1, len(val_losses) + 1), val_losses, label="Val " + titre)

    ax.set_xlabel("Epochs")
    ax.set_ylabel(titre)
    ax.set_title("Évolution de la "+titre+" au fil des Epochs")
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
# val(0) ran before any training epoch, so drop it to align with train epochs.
for train_series, val_series, metric in ((train_loss, val_loss, "loss"),
                                          (train_psnr, val_psnr, "psnr")):
    plot_loss(train_series, val_series[1:], metric)

# Test NetE on a validation image

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFile

# Pick a random validation image. np.random.randint excludes its upper bound,
# so pass len(val_dataset) to make the last image eligible as well.
tensor, _ = val_dataset[np.random.randint(0, len(val_dataset))]
original_np = tensor.permute(1, 2, 0).numpy()  # (H, W, C), values in [-1, 1]

# Corrupt it the way training does: each pixel (all channels together) is
# zeroed with probability 1 - sample_rate.
corrupted_np = original_np.copy()
corrupt_mask = np.random.binomial(1, (1 - sample_rate), original_np.shape[:2]).astype(bool)
corrupted_np[corrupt_mask] = 0.0

# Move model and input to the GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
corrupted_tensor = torch.tensor(corrupted_np).permute(2, 0, 1).unsqueeze(0).to(device)

# Perform inference
model.eval()
with torch.no_grad():
    predicted_tensor = model(corrupted_tensor)
    predicted_tensor = predicted_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()  # (H, W, C)

# Display results. (x + 1) / 2 maps the [-1, 1] range back to [0, 1] for imshow.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (original_np, "Original Image"),
    (corrupted_np, "Corrupted Image"),
    (predicted_tensor, "Predicted Image"),
]
for axis, (img, title) in zip(ax, panels):
    axis.imshow((img + 1.0) / 2.0)
    axis.set_title(title)
    axis.axis("off")

plt.show()


# NetM

In [None]:
import torch.nn as nn
#import torch.nn.functional as F
import torch.nn.init as init
import torch
import math
class Mean_Shift(nn.Module):
    """Rescale each (image, channel) map so its spatial mean equals ``sample_rate``.

    ``out = x / mean_over_HW(x) * sample_rate``. Inputs are presumably
    positive (NetM feeds sigmoid outputs); a zero spatial mean would divide
    by zero.
    """

    def __init__(self, sample_rate=0.2):
        super(Mean_Shift, self).__init__()
        # Non-persistent buffer: it follows the module through .cuda()/.to()
        # (the old Variable was pinned to one device at construction time) and
        # stays out of state_dict, so existing checkpoints still load strictly.
        self.register_buffer("sample_rate", torch.tensor(sample_rate), persistent=False)

    def forward(self, x):
        # Per-(N, C) spatial mean; keepdim lets it broadcast back over H and W.
        x_mean = x.mean(dim=(2, 3), keepdim=True)
        return x / x_mean * self.sample_rate


class _Residual_Block(nn.Module):
    def __init__(self):
        super(_Residual_Block, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in1 = nn.InstanceNorm2d(64, affine=True)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in2 = nn.InstanceNorm2d(64, affine=True)

    def forward(self, x):
        identity_data = x
        output = self.relu(self.in1(self.conv1(x)))
        output = self.in2(self.conv2(output))
        output = torch.add(output,identity_data)
        return output


class NetM(nn.Module):
    """Mask network: predicts a single-channel sampling-probability map.

    Pipeline:
        9x9 conv -> residual blocks -> 3x3 conv + InstanceNorm with a long skip
        connection -> 9x9 conv to 1 channel -> BatchNorm -> sigmoid.
    The map is then rescaled by Mean_Shift so its per-image mean equals
    ``sample_rate``. It is then repeatedly clamped to [0, 1] and re-shifted,
    so it stays a valid probability map with a mean close to ``sample_rate``.

    Args:
        nef: unused (all widths are 64); kept for interface compatibility.
        sample_rate: target mean of the output mask.
        num_residual_blocks: number of _Residual_Block's. Defaults to the
            notebook-level ``B_residual_block`` setting.
        mean_clamp_iters: number of clamp / mean-shift refinement rounds.
    """

    def __init__(self, nef, sample_rate, num_residual_blocks=None, mean_clamp_iters=25):
        super(NetM, self).__init__()
        if num_residual_blocks is None:
            num_residual_blocks = B_residual_block  # notebook-level setting

        self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        self.residual = self.make_layer(_Residual_Block, num_residual_blocks)

        self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        # NOTE: despite the name this is an InstanceNorm; the attribute name is
        # kept so saved state_dicts still load.
        self.bn_mid = nn.InstanceNorm2d(64, affine=True)

        self.conv_output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=9, stride=1, padding=4, bias=False)
        self.conv_output_bn = nn.BatchNorm2d(1)
        self.conv_output_sig = nn.Sigmoid()

        self.mean_shift = Mean_Shift(sample_rate=sample_rate)
        self.mean_clamp_iters = mean_clamp_iters

        # Every conv (including those in the residual blocks) gets
        # N(0, sqrt(2 / (k_h * k_w * out_channels))) weights and zero bias.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` fresh instances of ``block`` in an nn.Sequential."""
        return nn.Sequential(*[block() for _ in range(num_of_layer)])

    def forward(self, x):
        """Map a (N, 3, H, W) batch to a (N, 1, H, W) mask with values in [0, 1]."""
        out = self.relu(self.conv_input(x))
        residual = out
        out = self.residual(out)
        out = self.bn_mid(self.conv_mid(out))
        out = torch.add(out, residual)

        out = self.conv_output_sig(self.conv_output_bn(self.conv_output(out)))
        out = self.mean_shift(out)

        # Iterative mean-clamp: clipping values above 1 lowers the mean, so
        # re-shift and clip again to keep the sample rate close to the target.
        for _ in range(self.mean_clamp_iters):
            out = torch.clamp(out, min=0.0, max=1.0)
            out = self.mean_shift(out)
        return torch.clamp(out, min=0.0, max=1.0)

class NetME(nn.Module):
    """NetM (mask generator) chained with a pretrained NetE (reconstructor).

    Args:
        nef: forwarded to NetM.
        NetE_name: path to a NetE checkpoint saved with ``torch.save(model, ...)``
            (a whole pickled module).
        sample_rate: target mean of NetM's mask.
    """

    def __init__(self, nef, NetE_name, sample_rate):
        super(NetME, self).__init__()
        self.netM = NetM(nef=nef, sample_rate=sample_rate)
        # The checkpoint already contains a complete NetE, so no fresh NetE is
        # built first. map_location="cpu" lets GPU-saved checkpoints load on
        # any machine; callers move the whole NetME with .cuda()/.to().
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        self.netE = torch.load(NetE_name, map_location="cpu", weights_only=False)

    def forward(self, x):
        """Return ``(mask, x_recon)``.

        ``mask`` is NetM's continuous (N, 1, H, W) mask. ``x_recon`` is NetE's
        reconstruction of ``mask * x``; the single-channel mask broadcasts
        over the image channels.
        """
        mask = self.netM(x)
        x_recon = self.netE(mask * x)
        return mask, x_recon


## Train NetM (NetE frozen)


In [None]:
#from tensorboard_logger import configure
#print('===> Initialize Logger...')
#configure("tensorBoardRuns/mask-train-conti-on-demand-learn-p-02-zero-corrupt-zero-conv-bias-conti-ber-train-v4-cwfc-one-net-eval-h5-val-sig-M-res_net-clip-mean-iter-switch-epoch-800")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from math import log10
#from tensorboard_logger import configure, log_value, log_images
import os
from tqdm import tqdm
import numpy as np
from math import ceil
import gc

torch.cuda.empty_cache()

print('===> Building ME model...')
modelME = NetME(nef = nef, NetE_name = 'epochs_NetE/' + modelE_name, sample_rate = sample_rate)
if torch.cuda.is_available():
    modelME = modelME.cuda()

# NetM is the network being trained; the pretrained NetE stays in eval mode.
modelME.netM.train()
modelME.netE.eval()

criterion = nn.MSELoss()
if torch.cuda.is_available():
    criterion = criterion.cuda()

print('===> Total Model NetME Parameters:', sum(p.numel() for p in modelME.parameters()))

print('===> Initialize Optimizer...')
# NetE gets its own parameter group with lr 0, so Adam never moves its weights.
optimizer = optim.Adam(
    [
        {'params': modelME.netM.parameters(), 'lr': learning_rate},
        {'params': modelME.netE.parameters(), 'lr': 0.0},
    ],
    lr=learning_rate,
)

os.makedirs("epochs_NetME", exist_ok=True)
os.makedirs("tensorBoardRuns", exist_ok=True)

# Per-epoch NetME history (re-binding these names discards the NetE curves).
val_loss, val_psnr, val_sparsity = [], [], []
train_loss, train_psnr, train_sparsity = [], [], []

def train(epoch):
    """Train NetM for one epoch against the frozen, pretrained NetE.

    For each batch, NetME produces a continuous sampling mask (NetM) and a
    reconstruction of ``mask * image`` (NetE). The MSE between the
    reconstruction and the original, both mapped from [-1, 1] back to [0, 1],
    is minimised. Only NetM moves: NetE's parameter group in the module-level
    ``optimizer`` has lr 0.

    Appends the epoch's mean loss, PSNR and mask sparsity (mean mask value) to
    ``train_loss`` / ``train_psnr`` / ``train_sparsity``.
    """
    epoch_loss = 0.0
    epoch_psnr = 0.0
    epoch_sparsity = 0.0

    # train/eval modes matter for the normalization layers: NetM trains with
    # batch statistics, NetE keeps its pretrained running statistics.
    modelME.netM.train()
    modelME.netE.eval()

    # No learning-rate decay is applied.
    for target, _ in tqdm(train_loader):
        image = target.clone()
        if torch.cuda.is_available():
            image = image.cuda()
            target = target.cuda()

        optimizer.zero_grad()

        # Generate the continuous corruption mask and the reconstructed image.
        corrupt_mask_conti, image_recon = modelME(image)

        # Keep a plain float. Accumulating the grad-carrying tensor chained
        # every batch's autograd graph into the running sum for the whole epoch.
        mask_sparsity = corrupt_mask_conti.mean().item()

        # Undo Normalize(mean=0.5, std=0.5) so loss/PSNR are in [0, 1] units.
        loss = criterion(image_recon * 0.5 + 0.5, target * 0.5 + 0.5)
        epoch_loss += loss.item()
        epoch_psnr += 10 * log10(1 / loss.item())
        epoch_sparsity += mask_sparsity
        loss.backward()
        optimizer.step()

    n_batches = len(train_loader)
    train_loss.append(epoch_loss / n_batches)
    train_psnr.append(epoch_psnr / n_batches)
    train_sparsity.append(epoch_sparsity / n_batches)

    print("===> Epoch {} Complete: lr: {}, Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Mask Sparsity: {:.4f}".format(epoch, learning_rate, epoch_loss / n_batches, epoch_psnr / n_batches, epoch_sparsity / n_batches))

PSNR_best = 0

def reshape_4D_array(array_4D, width_num):
    """Tile a (num, cha, height, width) batch into one (1, cha, H, W) grid image.

    Images are laid out row-major, ``width_num`` per grid row. Unused grid
    cells stay zero. The result is float64 (the ``np.zeros`` default).
    """
    num, cha, height, width = array_4D.shape
    grid_rows = ceil(float(num) / width_num)
    grid = np.zeros((1, cha, height * grid_rows, width * width_num))
    for idx, tile in enumerate(array_4D):
        row, col = divmod(idx, width_num)
        top, left = row * height, col * width
        grid[:, :, top:top + height, left:left + width] = tile
    return grid

def val(epoch):
    """Validate NetME with a binarised NetM mask and keep the best checkpoint.

    NetM's continuous mask is Bernoulli-sampled into a binary mask and applied
    to the image, and NetE reconstructs the result. Appends the mean loss,
    PSNR and sparsity (fraction of kept pixels) to ``val_loss`` / ``val_psnr``
    / ``val_sparsity``. Saves ``modelME.state_dict()`` to
    ``epochs_NetME/model_best.pth`` whenever the mean PSNR improves.
    """
    global PSNR_best

    avg_psnr = 0.0
    avg_mse = 0.0
    avg_sparsity = 0.0

    modelME.eval()  # recursively switches netM and netE to eval mode

    with torch.no_grad():
        for target, _ in val_loader:
            if torch.cuda.is_available():
                target = target.cuda()

            # Only the mask is needed from NetM. Calling it directly skips the
            # reconstruction that modelME(...) would compute and discard.
            corrupt_mask_conti = modelME.netM(target)

            # Binarize the mask by Bernoulli sampling, then feed the masked
            # image to NetE. The 1-channel mask broadcasts over RGB.
            corrupt_mask = corrupt_mask_conti.bernoulli()
            mask_image = corrupt_mask * target
            restored_image = modelME.netE(mask_image)

            # Undo Normalize(mean=0.5, std=0.5) so loss/PSNR are in [0, 1] units.
            mse = criterion(restored_image * 0.5 + 0.5, target * 0.5 + 0.5)
            avg_psnr += 10 * log10(1 / mse.item())
            avg_mse += mse.item()
            avg_sparsity += corrupt_mask.mean().item()

    n_batches = len(val_loader)
    val_loss.append(avg_mse / n_batches)
    val_psnr.append(avg_psnr / n_batches)
    val_sparsity.append(avg_sparsity / n_batches)

    print("===> Epoch {} Validation: Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Mask Sparsity: {:.4f}".format(epoch, avg_mse / n_batches, avg_psnr / n_batches, avg_sparsity / n_batches))

    mean_psnr = avg_psnr / n_batches
    if mean_psnr > PSNR_best:
        PSNR_best = mean_psnr
        model_out_path = "epochs_NetME/" + "model_best.pth"
        torch.save(modelME.state_dict(), model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

def val_rand(epoch):
    """Baseline: validate NetE with uniformly random binary masks.

    Each pixel (shared across RGB) is kept with probability ``sample_rate``.
    Reports the mean loss, the PSNR and the measured fraction of kept pixels,
    computed the same way as ``val`` so the two are directly comparable.
    """
    avg_psnr = 0.0
    avg_mse = 0.0
    avg_sparsity = 0.0

    modelME.eval()  # recursively switches netM and netE to eval mode

    with torch.no_grad():
        for target, _ in val_loader:
            n, _, h, w = target.shape
            # Random binary mask: 1 = keep, sampled with probability sample_rate.
            corrupt_mask = (torch.ones(n, 1, h, w) * sample_rate).bernoulli()
            # Measure sparsity on the sampled mask. Measuring the probability
            # map instead always reports exactly sample_rate.
            avg_sparsity += corrupt_mask.mean().item()

            if torch.cuda.is_available():
                target = target.cuda()
                corrupt_mask = corrupt_mask.cuda()

            # Generate the corrupted image (the mask broadcasts over RGB).
            mask_image = corrupt_mask * target

            # Undo Normalize(mean=0.5, std=0.5) so loss/PSNR are in [0, 1] units.
            mse = criterion(modelME.netE(mask_image) * 0.5 + 0.5, target * 0.5 + 0.5)
            avg_psnr += 10 * log10(1 / mse.item())
            avg_mse += mse.item()

    n_batches = len(val_loader)
    print("===> Epoch {} Random Validation: Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Mask Sparsity: {:.4f}".format(epoch, avg_mse / n_batches, avg_psnr / n_batches, avg_sparsity / n_batches))

def checkpoint(epoch):
    """Save modelME's state_dict to epochs_NetME/model_epoch_<epoch>.pth every 100 epochs."""
    if epoch % 100:
        return
    model_out_path = "epochs_NetME/model_epoch_{}.pth".format(epoch)
    torch.save(modelME.state_dict(), model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

def cleanup_memory():
    """Run Python garbage collection and release cached CUDA memory.

    The CUDA calls are skipped on machines without a GPU, because
    ``torch.cuda.ipc_collect()`` initialises CUDA and raises there.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()




# Epoch 0 is a baseline (validation + checkpoint only). Every later epoch
# trains first and frees memory at the end.
for epoch in range(0, nEpochs + 1):
    if epoch > 0:
        train(epoch)
    val(epoch)
    val_rand(epoch)
    checkpoint(epoch)
    if epoch > 0:
        cleanup_memory()

In [None]:
import torch
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("CUDNN version:", torch.backends.cudnn.version())
# get_device_name() raises on machines without a CUDA device.
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU: none (CPU only)")



In [None]:
# val(0) ran before any training epoch, so drop it to align with train epochs.
for train_series, val_series, metric in ((train_loss, val_loss, "loss"),
                                          (train_psnr, val_psnr, "psnr"),
                                          (train_sparsity, val_sparsity, "sparsity")):
    plot_loss(train_series, val_series[1:], metric)

In [None]:
import os

if os.path.exists("epochs_NetME/" + "model_best.pth"):
    modelME = NetME(nef = nef, NetE_name = 'epochs_NetE/' + modelE_name, sample_rate = sample_rate)
    # The state_dict may hold CUDA tensors. Map it to the CPU so it also loads
    # on CPU-only machines; the following cells move modelME to the GPU when
    # one is available.
    modelME.load_state_dict(torch.load("epochs_NetME/" + "model_best.pth", map_location="cpu"))
    print("Model loaded")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFile

# Pick a random validation image. np.random.randint excludes its upper bound,
# so pass len(val_dataset) to make the last image eligible as well.
tensor, _ = val_dataset[np.random.randint(0, len(val_dataset))]
original_np = tensor.permute(1, 2, 0).numpy()  # (H, W, C), values in [-1, 1]

# Back to (1, C, H, W), on the same device as the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image = torch.tensor(original_np).permute(2, 0, 1).unsqueeze(0).to(device)
modelME = modelME.to(device)
modelME.eval()  # recursively switches netM and netE to eval mode

# Inference only: no autograd graph, so results can go straight to numpy.
with torch.no_grad():
    # NetM's continuous mask, binarized by Bernoulli sampling, then fed to NetE.
    corrupt_mask_conti = modelME.netM(image)
    corrupt_mask = corrupt_mask_conti.bernoulli()
    corrupt_mask = corrupt_mask.expand(corrupt_mask.shape[0], image.shape[1], corrupt_mask.shape[2], corrupt_mask.shape[3])
    mask_image = corrupt_mask * image
    restored_image = modelME.netE(mask_image)

# Convert to (H, W, C) numpy arrays for display.
corrupted_tensor = mask_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
restored_image = restored_image.squeeze(0).permute(1, 2, 0).cpu().numpy()

# Display results. (x + 1) / 2 maps the [-1, 1] range back to [0, 1] for imshow.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (original_np, "Original Image"),
    (corrupted_tensor, "Generated mask Image"),
    (restored_image, "Predicted Image"),
]
for axis, (img, title) in zip(ax, panels):
    axis.imshow((img + 1.0) / 2.0)
    axis.set_title(title)
    axis.axis("off")

plt.show()


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFile

# Pick one random validation image. np.random.randint's upper bound is exclusive,
# so passing len(val_dataset) (not len - 1) keeps the last image eligible.
tensor, _ = val_dataset[np.random.randint(len(val_dataset))]
original_np = tensor.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow

# Model and input must live on the same device; use the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Add a batch dimension -> (1, C, H, W). clone() keeps original_np independent of the model input.
image = tensor.unsqueeze(0).clone().to(device)
modelME = modelME.to(device)

# Generate the corruption mask and reconstructed image (inference only; no_grad
# also means the outputs can be converted with .numpy() below).
modelME.eval()
modelME.netM.eval()
modelME.netE.eval()

with torch.no_grad():
    corrupt_mask_conti, _ = modelME(image)

    # Binarize the continuous mask by Bernoulli sampling, then expand it to the
    # 3 image channels before masking.
    corrupt_mask = corrupt_mask_conti.bernoulli()
    corrupt_mask = corrupt_mask.expand(-1, 3, -1, -1)

    # Generate the corrupted image and let netE restore it.
    mask_image = corrupt_mask * image
    restored_image = modelME.netE(mask_image)

# Back to (H, W, C) numpy arrays for matplotlib. mask_image is 4-D (1, C, H, W),
# so drop the batch dim before permuting.
corrupted_tensor = mask_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
restored_image = restored_image.squeeze(0).permute(1, 2, 0).cpu().numpy()

# Display results. (x + 1) / 2 undoes Normalize((0.5,)*3, (0.5,)*3), mapping [-1, 1]
# to [0, 1]. Clipping keeps imshow from warning about out-of-range model outputs.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (original_np, "Original Image"),
    (corrupted_tensor, "Corrupted Image"),
    (restored_image, "Predicted Image"),
]
for axis, (img, title) in zip(ax, panels):
    axis.imshow(np.clip((img + 1.0) / 2.0, 0.0, 1.0))
    axis.set_title(title)
    axis.axis("off")

plt.show()


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageFile

# Pick one random validation image. np.random.randint's upper bound is exclusive,
# so passing len(val_dataset) (not len - 1) keeps the last image eligible.
tensor, _ = val_dataset[np.random.randint(len(val_dataset))]
original_np = tensor.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C) for imshow

# Model and input must live on the same device; use the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Add a batch dimension -> (1, C, H, W). clone() keeps original_np independent of the model input.
image = tensor.unsqueeze(0).clone().to(device)
modelME = modelME.to(device)

# Generate the corruption mask and reconstructed image (inference only; no_grad
# also means the outputs can be converted with .numpy() below).
modelME.eval()
modelME.netM.eval()
modelME.netE.eval()

with torch.no_grad():
    corrupt_mask_conti, _ = modelME(image)

    # Binarize the continuous mask by Bernoulli sampling, then expand it to the
    # 3 image channels before masking.
    corrupt_mask = corrupt_mask_conti.bernoulli()
    corrupt_mask = corrupt_mask.expand(-1, 3, -1, -1)

    # Generate the corrupted image and let netE restore it.
    mask_image = corrupt_mask * image
    restored_image = modelME.netE(mask_image)

# Back to (H, W, C) numpy arrays for matplotlib. mask_image is 4-D (1, C, H, W),
# so drop the batch dim before permuting.
corrupted_tensor = mask_image.squeeze(0).permute(1, 2, 0).cpu().numpy()
restored_image = restored_image.squeeze(0).permute(1, 2, 0).cpu().numpy()

# Display results. (x + 1) / 2 undoes Normalize((0.5,)*3, (0.5,)*3), mapping [-1, 1]
# to [0, 1]. Clipping keeps imshow from warning about out-of-range model outputs.
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (original_np, "Original Image"),
    (corrupted_tensor, "Corrupted Image"),
    (restored_image, "Predicted Image"),
]
for axis, (img, title) in zip(ax, panels):
    axis.imshow(np.clip((img + 1.0) / 2.0, 0.0, 1.0))
    axis.set_title(title)
    axis.axis("off")

plt.show()


In [None]:
# Sanity check: shape of the masked image left over from the previous cell.
masked_shape = mask_image.shape
print(masked_shape)

In [None]:
def reshape_4D_array(array_4D, width_num):
    """Tile a batch of images into a single grid image.

    Args:
        array_4D: array-like of shape (num, cha, height, width).
        width_num: number of tiles placed in each grid row.

    Returns:
        np.ndarray of shape (1, cha, height * rows, width * width_num) where
        rows = ceil(num / width_num). Tiles fill rows left to right; unused
        cells of a final partial row stay zero.
    """
    num, cha, height, width = array_4D.shape
    # Round up: floor division allocates too few rows whenever num is not a
    # multiple of width_num, and the last tiles would not fit.
    height_num = -(-num // width_num)
    target_array_4D = np.zeros((1, cha, height * height_num, width * width_num))
    for index in range(num):
        row, col = divmod(index, width_num)
        target_array_4D[:, :, row * height:(row + 1) * height, col * width:(col + 1) * width] = array_4D[index, :, :, :]
    return target_array_4D

def val(epoch):
    """Validate modelME on ``val_loader``, log metrics and sample images, and
    checkpoint the model whenever the average validation PSNR improves.

    For every batch, the continuous mask returned by ``modelME`` is binarized
    by Bernoulli sampling and applied to the input. ``modelME.netE`` then
    restores the image from the masked input. Loss and PSNR are computed after
    undoing the Normalize((0.5,)*3, (0.5,)*3) transform, i.e. on [0, 1]-scaled images.

    Uses notebook globals: ``modelME``, ``val_loader``, ``criterion``,
    ``log_value``, ``log_images``, ``reshape_4D_array`` and ``PSNR_best``.
    ``PSNR_best`` is read and updated, and holds the best *average* PSNR.

    Args:
        epoch: epoch index used in the printed summary and as the logging step.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    import math
    import os

    global PSNR_best

    n_batches = len(val_loader)
    if n_batches == 0:
        raise ValueError("val_loader is empty; nothing to validate")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Per-channel mean/std of the Normalize transform, shaped to broadcast over (N, 3, H, W).
    mean_image = torch.full((1, 3, 1, 1), 0.5, device=device)
    std_image = torch.full((1, 3, 1, 1), 0.5, device=device)

    modelME.eval()
    modelME.netM.eval()
    modelME.netE.eval()

    sum_psnr = 0.0
    sum_mse = 0.0
    sum_sparsity = 0.0

    # Validation needs no gradients; skipping autograd bookkeeping saves memory.
    with torch.no_grad():
        for target, _ in val_loader:
            target = target.to(device)
            # Separate copies for the model input and the masked input, in case
            # modelME modifies its input.
            image = target.clone()
            image_clone = image.clone()

            # Generate the corruption mask and reconstructed image
            corrupt_mask_conti, _ = modelME(image)

            # Binarize the corruption mask using Bernoulli sampling, then feed into netE
            corrupt_mask = corrupt_mask_conti.bernoulli()
            # Fraction of mask entries equal to 1 (sum / number of elements).
            mask_sparsity = corrupt_mask.mean().item()
            corrupt_mask = corrupt_mask.expand(-1, 3, -1, -1)

            # Generate the corrupted image
            mask_image = corrupt_mask * image_clone
            restored_image = modelME.netE(mask_image)

            mse = criterion(restored_image * std_image + mean_image,
                            target * std_image + mean_image).item()
            # Guard the 1 / mse term against a perfect reconstruction.
            psnr = 10 * math.log10(1 / mse) if mse > 0 else float("inf")
            sum_psnr += psnr
            sum_mse += mse
            sum_sparsity += mask_sparsity

    avg_mse = sum_mse / n_batches
    avg_psnr = sum_psnr / n_batches
    avg_sparsity = sum_sparsity / n_batches

    print("===> Epoch {} Validation: Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Mask Sparsity: {:.4f}".format(epoch, avg_mse, avg_psnr, avg_sparsity))

    log_value('val_loss', avg_mse, epoch)
    log_value('val_psnr', avg_psnr, epoch)
    log_value('val_sparsity', avg_sparsity, epoch)

    # Log tiled grids (10 per row) of the last validation batch at this epoch's step.
    corrupt_mask_conti = corrupt_mask_conti.expand(-1, 3, -1, -1)
    log_images('original_image', reshape_4D_array((image * std_image + mean_image).cpu().numpy(), 10), step=epoch)
    log_images('conti_mask', reshape_4D_array(corrupt_mask_conti.cpu().numpy(), 10), step=epoch)
    log_images('binar_mask', reshape_4D_array(corrupt_mask.cpu().numpy(), 10), step=epoch)
    log_images('restored_image', reshape_4D_array((restored_image * std_image + mean_image).cpu().numpy(), 10), step=epoch)

    if avg_psnr > PSNR_best:
        PSNR_best = avg_psnr
        model_out_path = "epochs_NetME/" + "model_best.pth"
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(model_out_path), exist_ok=True)
        torch.save(modelME.state_dict(), model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

# Reset the best-PSNR tracker that val() reads and updates. Then evaluate once on
# the validation split, using epoch index 0 for printing and logging.
# NOTE(review): val() saves epochs_NetME/model_best.pth whenever it beats
# PSNR_best, so this run overwrites that file.
PSNR_best = 0
val(epoch=0)

In [None]:
import torch
import numpy as np
from math import ceil
def reshape_4D_array(array_4D, width_num):
    """Tile a (num, cha, height, width) batch into one (1, cha, H_total, W_total) grid.

    Tiles fill rows left to right with ``width_num`` tiles per row. The grid has
    ceil(num / width_num) rows, and any unused cells of the last row remain zero.
    """
    num, cha, height, width = array_4D.shape
    rows = ceil(num / width_num)
    grid = np.zeros((1, cha, height * rows, width * width_num))
    for idx in range(num):
        r, c = divmod(idx, width_num)
        top, left = r * height, c * width
        grid[:, :, top:top + height, left:left + width] = array_4D[idx, :, :, :]
    return grid



# Smoke test: 3 tiles at 10 per row fit in one partial row, so the grid is (1, 3, 64, 640).
test = torch.zeros(3, 3, 64, 64)
reshape_4D_array(test, width_num=10).shape