In [25]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np


In [26]:
class MoireDataset(Dataset):
    def __init__(self, moire_folder, transform=None):
        self.moire_folder = moire_folder
        self.transform = transform
        # List all image files
        self.image_files = [f for f in os.listdir(moire_folder) if f.endswith('_moire.jpg')]
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        moire_filename = self.image_files[idx]
        gt_filename = moire_filename.replace('_moire.jpg', '_gt.jpg')
        
        moire_path = os.path.join(self.moire_folder, moire_filename)
        gt_path = os.path.join(self.moire_folder, gt_filename)
        
        # Load images
        moire_image = Image.open(moire_path).convert('RGB')
        gt_image = Image.open(gt_path).convert('RGB')
        
        if self.transform:
            moire_image = self.transform(moire_image)
            gt_image = self.transform(gt_image)
        
        return moire_image, gt_image


In [27]:
class MoireCNN(nn.Module):
    def __init__(self):
        super(MoireCNN, self).__init__()
        
        # Encoder part (downsampling)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        
        # Decoder part (upsampling)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)  # Adjusted kernel size
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)  # Adjusted kernel size
        self.deconv3 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)    # Adjusted kernel size

        # ReLU activation
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        
        # Upsampling layer to scale up the final output to 256x256
        self.upsample = nn.Upsample(size=(256, 256), mode='bilinear', align_corners=False)
    
    def forward(self, x):
        # Encoder path
        x = self.relu(self.conv1(x))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        
        # Decoder path (upsample)
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        x = self.deconv3(x)  # Final output layer to match 3 channels
        
        # Upsample to 256x256
        x = self.upsample(x)
        
        return x


In [28]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])


In [29]:
# Set your dataset directory
moire_folder = 'Dataset/train/train/pair_00/'

# Create the dataset and dataloader
train_dataset = MoireDataset(moire_folder, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)


In [30]:
# Setup device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the model
model = MoireCNN().to(device)

# Loss function and optimizer
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for moire_img, gt_img in train_loader:
        moire_img, gt_img = moire_img.to(device), gt_img.to(device)
        
        # Forward pass
        output = model(moire_img)
        
        # Print output size to debug
        print(f"Output shape: {output.shape}")
        print(f"Ground truth shape: {gt_img.shape}")
        
        # Calculate loss
        loss = criterion(output, gt_img)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")


Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([16, 3, 256, 256])
Ground truth shape: torch.Size([16, 3, 256, 256])
Output shape: torch.Size([10, 3, 256, 256])
Ground truth shape: torch.Size([10, 3, 256, 256])
Epoch [1/10], Loss: 0.4110
Output shape: torch.Size([16, 3, 

In [32]:
torch.save(model.state_dict(), 'moire_model.pth')
