In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np

# Compute device: use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


In [30]:
class UNetBlock(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    the channel count at ``out_channels``. Both convolutions share the same
    ``kernel_size`` and ``padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding)

        # Stage 1: change the number of channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Stage 2: refine at a constant number of channels
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run ``x`` through both conv/norm/ReLU stages and return the result."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x

class DenoisingUNet(nn.Module):
    """Three-level U-Net built from ``UNetBlock`` stages.

    The encoder applies three blocks with 2x2 max pooling between them, a
    bottleneck block follows a third pooling step, and the decoder upsamples
    bilinearly, concatenates the matching encoder output (skip connection) and
    applies another block at each level. A final 1x1 convolution maps to
    ``out_channels``; no activation is applied to the output.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        init_features: Channel width of the first encoder level; deeper levels
            use 2x, 4x and (bottleneck) 8x this value.
        kernel_size: Kernel size forwarded to every ``UNetBlock``.
        padding: Padding forwarded to every ``UNetBlock``.
    """

    def __init__(self, in_channels, out_channels, init_features=64, kernel_size=3, padding=1):
        super(DenoisingUNet, self).__init__()

        features = init_features
        self.enc1 = UNetBlock(in_channels, features, kernel_size, padding)
        self.enc2 = UNetBlock(features, features * 2, kernel_size, padding)
        self.enc3 = UNetBlock(features * 2, features * 4, kernel_size, padding)

        self.bottleneck = UNetBlock(features * 4, features * 8, kernel_size, padding)

        # Decoder inputs are the upsampled deeper features concatenated with
        # the skip connection, hence the summed channel counts.
        self.dec3 = UNetBlock(features * 8 + features * 4, features * 4, kernel_size, padding)
        self.dec2 = UNetBlock(features * 4 + features * 2, features * 2, kernel_size, padding)
        self.dec1 = UNetBlock(features * 2 + features, features, kernel_size, padding)

        self.final_conv = nn.Conv2d(features, out_channels, kernel_size=1)

    @staticmethod
    def _upsample_and_merge(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s spatial size, then concatenate on channels.

        Resizing to the skip tensor's size (instead of a fixed factor of 2)
        keeps the concatenation valid when a pooled dimension was odd and max
        pooling rounded it down. When the pooled dimension was even, the target
        size is the same doubled size a factor of 2 would give.
        """
        x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
        return torch.cat([x, skip], dim=1)

    def forward(self, x):
        """Run the U-Net on a 4-D ``(N, C, H, W)`` batch and return the final conv output."""
        # Encoder
        enc1_out = self.enc1(x)
        enc2_out = self.enc2(F.max_pool2d(enc1_out, 2))
        enc3_out = self.enc3(F.max_pool2d(enc2_out, 2))

        # Bottleneck
        bottleneck_out = self.bottleneck(F.max_pool2d(enc3_out, 2))

        # Decoder
        dec3_out = self.dec3(self._upsample_and_merge(bottleneck_out, enc3_out))
        dec2_out = self.dec2(self._upsample_and_merge(dec3_out, enc2_out))
        dec1_out = self.dec1(self._upsample_and_merge(dec2_out, enc1_out))

        # Output layer
        output = self.final_conv(dec1_out)
        return output


In [31]:
# Smoke test: build a small U-Net (32 base features; kernel size and padding
# are passed explicitly at their default values) and check that the output
# has the same shape as the input.
model = DenoisingUNet(in_channels=1, out_channels=1, init_features=32, kernel_size=3, padding=1)

# Named `sample_input` because `input` would shadow the Python builtin for
# the rest of the notebook session.
sample_input = torch.randn(1, 1, 128, 128)  # Example input tensor
with torch.no_grad():  # shape check only, no autograd graph needed
    output = model(sample_input)

assert output.shape == sample_input.shape, "U-Net must preserve the input shape"
print(output.shape)  # Should output torch.Size([1, 1, 128, 128])


torch.Size([1, 1, 128, 128])


In [34]:
# Dataset pairing each noised latent representation with its clean latent.
# NOTE: the data itself has not been created yet; the code below this point is a draft.

class DenoiseDataset(Dataset):
    """Paired dataset of noisy and clean latents.

    Item ``i`` is the tuple ``(loud_latents[i], latents[i])``.

    Args:
        loud_latents: Indexable, sized collection of noisy latents.
        latents: Indexable, sized collection of the corresponding clean
            latents; must have the same length as ``loud_latents``.

    Raises:
        ValueError: If the two collections differ in length, since the
            noisy/clean pairing would then be ill-defined.
    """

    def __init__(self, loud_latents, latents):
        # Fail early: a length mismatch would otherwise either silently drop
        # samples or surface as an IndexError in the middle of an epoch.
        if len(loud_latents) != len(latents):
            raise ValueError(
                f"loud_latents and latents must have the same length, "
                f"got {len(loud_latents)} and {len(latents)}"
            )
        self.loud_latents = loud_latents
        self.latents = latents

    def __len__(self):
        """Number of (noisy, clean) pairs."""
        return len(self.loud_latents)

    def __getitem__(self, idx):
        """Return the ``(noisy, clean)`` pair stored at position ``idx``."""
        noisy = self.loud_latents[idx]
        clean = self.latents[idx]
        return noisy, clean

# `loud_latents` (noisy) and `latents` (clean) must already exist as tensors
# before this cell runs; they are wrapped pairwise and served in shuffled
# mini-batches of 16.
dataset = DenoiseDataset(loud_latents=loud_latents, latents=latents)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
)

In [36]:
def train_Unet(model, epochs, loader=None, save_every=10, checkpoint_path='denoising_unet.pth'):
    """Train a denoising model with MSE loss and the Adam optimizer.

    Args:
        model: ``nn.Module`` mapping a batch of noisy latents to denoised ones.
        epochs: Number of passes over the data.
        loader: Iterable yielding ``(noisy, clean)`` tensor batches. When
            ``None``, the notebook-level ``dataloader`` variable is used.
        save_every: Save a checkpoint every this many epochs; the final epoch
            is always saved. ``0`` or ``None`` disables the periodic saves.
        checkpoint_path: File the model's ``state_dict`` is written to.

    Returns:
        A list with the sample-weighted mean loss of every epoch (empty when
        ``epochs`` is 0).
    """
    if loader is None:
        # Backward-compatible fallback to the global defined in the notebook.
        loader = dataloader

    # define optimizer and loss metric
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Batches are moved to wherever the model's parameters live, so the
    # function works unchanged on CPU and GPU.
    device = next(model.parameters()).device

    epoch_losses = []
    for epoch in range(epochs):

        model.train()
        running_loss = 0.0
        samples_seen = 0

        for loud_latents, latents in loader:
            loud_latents = loud_latents.to(device)
            latents = latents.to(device)

            # Forward pass
            outputs = model(loud_latents)
            loss = criterion(outputs, latents)

            # Optimizer Step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so a smaller last batch does not skew the mean
            batch_size = loud_latents.size(0)
            running_loss += loss.item() * batch_size
            samples_seen += batch_size

        # Reported once per epoch (inside the epoch loop)
        epoch_loss = running_loss / max(samples_seen, 1)
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss:.4f}")

        # Save the model every few epochs, and always after the last one so
        # the final weights are never lost.
        is_last_epoch = epoch + 1 == epochs
        is_periodic = bool(save_every) and (epoch + 1) % save_every == 0
        if is_periodic or is_last_epoch:
            torch.save(model.state_dict(), checkpoint_path)

    return epoch_losses
    