In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import matplotlib.pyplot as plt

In [2]:
class ComplexNetwork(nn.Module):
    """Three-layer fully convolutional image-to-image network.

    Maps a 3-channel image to a 3-channel image of the same spatial size
    (every convolution uses a 3x3 kernel, stride 1 and padding 1, so H and W
    are preserved).

    Args:
        n_neurons_l1: output channels of the first convolution.
        n_neurons_l2: output channels of the second convolution.
    """

    def __init__(self, n_neurons_l1, n_neurons_l2):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(3, n_neurons_l1, **conv_opts)
        self.conv2 = nn.Conv2d(n_neurons_l1, n_neurons_l2, **conv_opts)
        self.conv3 = nn.Conv2d(n_neurons_l2, 3, **conv_opts)

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        # ReLU followed by tanh squashes the output into [0, 1), the same
        # range as ToTensor images. Note that the final ReLU gives zero
        # gradient for negative pre-activations.
        return torch.tanh(x)

class Network(nn.Module):
    """Two-layer fully convolutional image-to-image network.

    Maps a 3-channel image to a 3-channel image of the same spatial size
    (3x3 kernels, stride 1, padding 1).

    Args:
        n_neurons_l1: number of channels of the hidden feature map.
    """

    def __init__(self, n_neurons_l1):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(3, n_neurons_l1, **conv_opts)
        self.conv2 = nn.Conv2d(n_neurons_l1, 3, **conv_opts)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        # ReLU then tanh maps every output value into [0, 1).
        return torch.tanh(F.relu(self.conv2(hidden)))
    

class ImageDataset(Dataset):
    """Paired image-to-image dataset.

    Every regular, non-hidden file in ``Xdata_folder`` is an input image; its
    target is the file with the same name in ``Ydata_folder``. Each sample is
    returned as an ``(imageX, imageY)`` pair of tensors that have already been
    moved to ``device`` and resized to ``size``.

    Args:
        Xdata_folder: directory containing the input images.
        Ydata_folder: directory containing the target images (same file names).
        transform: callable applied to each opened PIL image; it must return a
            tensor (the result is moved with ``.to``). When ``None``,
            ``transforms.ToTensor()`` is used.
        device: device the returned tensors are moved to. Because the move
            happens inside ``__getitem__``, DataLoaders over this dataset
            should keep ``num_workers=0`` when ``device`` is a GPU.
        size: ``(height, width)`` passed to ``transforms.Resize``.
    """

    def __init__(self, Xdata_folder, Ydata_folder, transform=None, device='cuda', size=(300, 512)):
        self.Xdata_folder = Xdata_folder
        self.Ydata_folder = Ydata_folder
        self.transform = transform
        # Sorted so the index -> file mapping is stable across runs and
        # platforms (os.listdir order is arbitrary).
        self.image_list = self._list_images(Xdata_folder)
        self.device = device
        self.resize = transforms.Resize(size)
        # Fallback so that transform=None still yields tensors (a PIL image
        # has no `.to`).
        self._to_tensor = transforms.ToTensor()

    @staticmethod
    def _list_images(folder):
        """Return the sorted names of the non-hidden regular files in ``folder``.

        Hidden entries (e.g. ``.DS_Store``, ``.ipynb_checkpoints``) and
        sub-directories are skipped, because they cannot be opened as images.
        """
        return sorted(
            name
            for name in os.listdir(folder)
            if not name.startswith('.') and os.path.isfile(os.path.join(folder, name))
        )

    def __len__(self):
        return len(self.image_list)

    def _load(self, folder, image_name):
        """Open one image, transform it, move it to ``self.device`` and resize it."""
        # `with` closes the underlying file once the transform (expected to
        # produce a tensor, i.e. to read the pixels) has run.
        with Image.open(os.path.join(folder, image_name)) as image:
            tensor = (self.transform or self._to_tensor)(image)
        return self.resize(tensor.to(self.device))

    def __getitem__(self, idx):
        image_name = self.image_list[idx]
        imageX = self._load(self.Xdata_folder, image_name)
        imageY = self._load(self.Ydata_folder, image_name)
        return imageX, imageY
    
folder_path = "./Screen_data"
# Paired input/target folders: X and Y files are matched by identical names.
X_data_folder = os.path.join(folder_path, "X_train")
Y_data_folder = os.path.join(folder_path, "Y_train")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Resizing to the network's working size is done inside ImageDataset,
# after this transform.
data_transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
])

batch_size = 16
# ImageDataset moves every sample to `device` inside __getitem__, so these
# loaders keep the default num_workers=0 (creating CUDA tensors in forked
# DataLoader worker processes is not supported).
train_dataset = ImageDataset(X_data_folder, Y_data_folder, transform=data_transform, device=device)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Xv_data_folder = os.path.join(folder_path, "X_val")
Yv_data_folder = os.path.join(folder_path, "Y_val")
val_dataset = ImageDataset(Xv_data_folder, Yv_data_folder, transform=data_transform, device=device)
# Validation only computes a loss, so shuffling is unnecessary; a fixed order
# keeps the reported validation loss deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

def get_vgg_features(images, vgg_model, size=(224, 224)):
    """Resize a batch of images and run it through ``vgg_model``.

    The resize is done directly on the tensor with ``F.interpolate`` so the
    autograd graph is preserved: gradients of the perceptual loss flow back
    into ``images`` and hence into the network that produced them. Converting
    to PIL images and back would detach the batch (making the perceptual loss
    a constant w.r.t. the model) and quantise it to 8 bits.

    Args:
        images: 4-D float tensor (B, C, H, W); values presumably in [0, 1]
            (ToTensor range) — TODO confirm for all callers.
        vgg_model: callable mapping the resized batch to features.
        size: ``(height, width)`` the batch is resized to before extraction.

    Returns:
        Whatever ``vgg_model`` returns for the resized batch (on the same
        device as ``images``).
    """
    # Bilinear with antialiasing, approximating PIL's bilinear downscaling.
    resized = F.interpolate(images, size=tuple(size), mode='bilinear',
                            align_corners=False, antialias=True)
    # NOTE(review): ImageNet mean/std normalisation is not applied, although
    # pretrained VGG weights expect normalised inputs.
    return vgg_model(resized)

def perceptual_loss(input, target, vgg_model):
    """Perceptual loss between two image batches.

    Both batches are mapped to feature space with ``get_vgg_features`` using
    the same ``vgg_model``; the loss is the mean squared error between the two
    feature tensors.
    """
    features_in, features_target = (
        get_vgg_features(batch, vgg_model) for batch in (input, target)
    )
    return F.mse_loss(features_in, features_target)

def tv_loss(adv_patch):
    """Normalised total variation of an image or a batch of images.

    Differences are taken along the last two dimensions, which are treated as
    (H, W). Both a single image ``(C, H, W)`` and a batch ``(B, C, H, W)`` are
    therefore handled correctly: adjacent channels or batch items are never
    compared with each other. The sum of absolute horizontal and vertical
    neighbour differences is divided by the total number of elements.

    Args:
        adv_patch: tensor with at least 2 dims whose trailing dims are (H, W).

    Returns:
        0-dim tensor with the normalised total variation.
    """
    # Horizontal neighbours (along W, the last dim).
    tv_w = torch.sum(torch.abs(adv_patch[..., :, 1:] - adv_patch[..., :, :-1]))
    # Vertical neighbours (along H, the second-to-last dim).
    tv_h = torch.sum(torch.abs(adv_patch[..., 1:, :] - adv_patch[..., :-1, :]))
    return (tv_h + tv_w) / torch.numel(adv_patch)


def _combined_loss(output, y, loss_fn, vgg_model, p_loss_weight, tv_loss_weight):
    """Weighted objective shared by the training and validation passes.

    With ``t = y[:, :3, ...]`` (only the first three target channels are used;
    NOTE(review): presumably Y images may carry an extra alpha channel):
        loss_fn(output, t)
        + p_loss_weight * perceptual_loss(output, t, vgg_model)
        + tv_loss_weight * |tv_loss(output) - tv_loss(t)|
    """
    target = y[:, :3, ...]
    mse_loss = loss_fn(output, target)
    p_loss = perceptual_loss(output, target, vgg_model)
    tv_diff_loss = torch.abs(tv_loss(output) - tv_loss(target))
    return mse_loss + p_loss_weight * p_loss + tv_loss_weight * tv_diff_loss


def train(model, loss_fn, optimizer, train_loader, val_loader, n_epochs, p_loss_weight, tv_loss_weight, fine_tune=False):
    """Train ``model`` and evaluate it on ``val_loader`` after every epoch.

    Args:
        model: image-to-image network; moved to CUDA when available.
        loss_fn: pixel loss, called as ``loss_fn(output, target)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(x, y)`` training batches.
        val_loader: iterable of ``(x, y)`` validation batches.
        n_epochs: number of passes over ``train_loader``.
        p_loss_weight: weight of the VGG perceptual term.
        tv_loss_weight: weight of the total-variation difference term.
        fine_tune: if True, ``conv1`` is frozen so only later layers update.

    Returns:
        ``(train_losses, val_losses)``: per-epoch means of the batch losses.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses = []
    val_losses = []

    # Frozen feature extractor for the perceptual loss: the first 9 modules of
    # VGG19's convolutional trunk.
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favour of `weights=...`.
    vgg_model = models.vgg19(pretrained=True).features[:9]
    for param in vgg_model.parameters():
        param.requires_grad = False
    vgg_model.to(device)
    vgg_model.eval()

    if fine_tune:
        # Freeze conv1 so only the later layers are fine-tuned. The optimizer
        # was created before this; frozen params simply receive no gradient.
        for name, param in model.named_parameters():
            if name in ('conv1.weight', 'conv1.bias'):
                param.requires_grad = False

    for epoch in range(n_epochs):
        model.train()
        running = 0.0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = _combined_loss(output, y, loss_fn, vgg_model, p_loss_weight, tv_loss_weight)
            loss.backward()
            optimizer.step()
            running += loss.item()
        train_loss = running / len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        running = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                output = model(x)
                running += _combined_loss(output, y, loss_fn, vgg_model,
                                          p_loss_weight, tv_loss_weight).item()
        val_loss = running / len(val_loader)
        val_losses.append(val_loss)

        # train_loss is already the per-batch mean; it must not be divided again.
        print("Epoch %d, Loss: %.4f, Val Loss: %.4f" % (epoch + 1, train_loss, val_loss))
    return train_losses, val_losses


In [None]:
network_type = 'simple'  # 'simple' or 'complex'
n_epochs = 50
SEED = 0
# Seed weight initialisation and the DataLoader shuffle order so runs are
# reproducible.
torch.manual_seed(SEED)

if network_type == 'complex':
    n_neurons_l1 = 128
    n_neurons_l2 = 64
    model = ComplexNetwork(n_neurons_l1, n_neurons_l2).to(device)
elif network_type == 'simple':
    n_neurons_l1 = 64
    model = Network(n_neurons_l1).to(device)
else:
    # Fail loudly instead of training a `model` left over from an earlier run.
    raise ValueError(f"Unknown network_type {network_type!r}; expected 'simple' or 'complex'")

loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
tv_loss_weight = 0.01  # weight of the total-variation difference term
p_loss_weight = 0.02   # weight of the perceptual (VGG feature) term
train_losses, val_losses = train(model, loss_fn, optimizer, train_loader, val_loader, n_epochs, p_loss_weight, tv_loss_weight)



In [None]:
# Plot training and validation losses per epoch.
epochs = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss', linewidth=3)
plt.plot(epochs, val_losses, label='Validation Loss', linewidth=3)
plt.xlabel('Epoch', fontsize=20)
plt.ylabel('Loss', fontsize=20)
plt.legend(prop={'size': 20})
plt.tick_params(axis='both', which='major', labelsize=18)
# The x-range follows the number of recorded epochs. The y-range is fixed for
# the saved figure, so loss values above 0.08 are clipped.
plt.xlim(0, len(train_losses) + 1)
plt.ylim(0, 0.08)
plt.grid()
# Saved as a PDF (vector format).
plt.savefig(os.path.join(folder_path, 'SIT-Net_train_Plot.pdf'), format='pdf', dpi=300)

plt.show()

In [4]:
Xtest_data_folder = os.path.join(folder_path, "X_test")
Ytest_data_folder = os.path.join(folder_path, "Y_test")
test_dataset = ImageDataset(Xtest_data_folder, Ytest_data_folder, transform=data_transform, device=device)
# Fixed order so the prediction figure shows the same test images on every run.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Visual check: inputs (X), targets (Y) and model predictions for one test batch.
to_pil = transforms.ToPILImage()

batch_X, batch_Y = next(iter(test_loader))
model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    batch_R = model(batch_X)

# Show at most 10 images; the batch can be smaller than that.
num_images = min(10, batch_X.shape[0])
rows = [("X Images", batch_X), ("Y Images", batch_Y), ("Prediction", batch_R)]

plt.figure(figsize=(5 * num_images, 11))
plt.subplots_adjust(hspace=0.4, wspace=0.4)
for row, (title, batch) in enumerate(rows):
    for i in range(num_images):
        plt.subplot(len(rows), num_images, row * num_images + i + 1)
        plt.imshow(to_pil(batch[i]))
        plt.axis('off')
        if i == 0:
            plt.title(title, fontsize=20, loc='left')
plt.suptitle('Model Predictions vs Actual Data', fontsize=30, y=0.98)
plt.tight_layout()
plt.savefig(os.path.join(folder_path, 'Output.pdf'), format='pdf', bbox_inches='tight', dpi=300)
plt.show()

In [6]:
# Save the trained model as a .pt file. torch.save(model, ...) pickles the
# whole nn.Module, so loading it later requires the model class to be
# importable; saving model.state_dict() would be the more portable option.
model_path = os.path.join(
    folder_path,
    f"screen_{network_type}_model_SD_closet.pt",
)
torch.save(model, model_path)