In [1]:
import torch
import torchvision
from torch import nn
from torch import optim
from torchvision import datasets
from torchvision.transforms import v2 as T
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
def add_noise(x, steps: int = 50, sigma: float = 0.1, clip: bool = True):
    """Corrupt ``x`` by applying additive Gaussian noise ``steps`` times.

    Each step applies ``T.GaussianNoise`` (zero mean, standard deviation
    ``sigma``); with ``clip=True`` the transform clamps the result after every
    step. Ignoring clipping, the accumulated noise has standard deviation
    ``sigma * sqrt(steps)``.

    Args:
        x: image tensor to corrupt. T.GaussianNoise expects a floating-point
            tensor — presumably values in [0, 1] as produced by ``transform``.
        steps: number of noise applications; ``steps <= 0`` returns ``x`` unchanged.
        sigma: per-step noise standard deviation. The default 0.1 is
            T.GaussianNoise's own default, so existing calls behave as before.
        clip: forwarded to T.GaussianNoise (clip to [0, 1] after each step).

    Returns:
        The noisy tensor, same shape as ``x``.
    """
    # Additive noise (not a blur); build the transform once and reuse it for every step.
    noise = T.GaussianNoise(mean=0.0, sigma=sigma, clip=clip)
    for _ in range(steps):
        x = noise(x)
    return x

In [3]:
class Denoise(nn.Module):
    """Small fully-convolutional denoiser that maps an image to an image of the same shape.

    Architecture: 3x3 conv (in -> hidden) + ReLU, then ``num_hidden_layers``
    extra 3x3 conv (hidden -> hidden) + ReLU blocks, then a final 3x3 conv
    (hidden -> in) with no activation. Every conv uses padding=1, so the
    spatial size is preserved.

    Args:
        in_channels: channels of both the input and the output image (3 for RGB).
        hidden_channels: width of the intermediate feature maps.
        num_hidden_layers: number of extra hidden conv+ReLU blocks. The default
            of 0 reproduces the original two-conv model, including its
            state_dict keys. 3 corresponds to the formerly commented-out
            conv2..conv4 stack.
    """

    def __init__(self, in_channels: int = 3, hidden_channels: int = 64, num_hidden_layers: int = 0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)

        hidden_layers = []
        for _ in range(num_hidden_layers):
            hidden_layers.append(nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1))
            hidden_layers.append(nn.ReLU())
        # An empty Sequential has no parameters and returns its input unchanged,
        # so the default model is identical to the original one.
        self.hidden = nn.Sequential(*hidden_layers)

        # Name kept as `conv5` so checkpoints saved from the original model still load.
        self.conv5 = nn.Conv2d(hidden_channels, in_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the denoised prediction for ``x``, same shape as ``x``."""
        x = self.relu(self.conv1(x))
        x = self.hidden(x)
        # No final activation: the output is unconstrained.
        return self.conv5(x)

In [4]:
# Per-channel normalisation statistics (the widely used ImageNet values).
# NOTE(review): these are not CIFAR-10's own statistics; they only matter once
# the Normalize step below is re-enabled (and rev_transform is used to undo it).
mean, std = (
    torch.tensor([0.485, 0.456, 0.406]),
    torch.tensor([0.229, 0.224, 0.225]),
)

# Preprocessing: convert to an image tensor, resize to 32x32, then cast to
# float32 with value scaling.
transform = T.Compose(
    [
        T.ToImage(),
        T.Resize(size=(32, 32), antialias=True),
        T.ToDtype(torch.float32, scale=True),
        # Normalisation is currently disabled.
        # T.Normalize(mean=mean, std=std),
    ]
)

In [None]:
class CustomDataset(Dataset):
    """Placeholder for a dataset loaded from ``root_dir``.

    Not used in the visible notebook, which trains on torchvision's CIFAR10
    instead. Only the constructor is implemented. ``__len__`` and
    ``__getitem__`` raise NotImplementedError, so accidental use fails with a
    clear message rather than a generic "object has no len()" TypeError.

    Args:
        root_dir: directory the samples would be read from.
        transform: callable to apply to each sample once loading is implemented.
    """

    def __init__(self, root_dir, transform):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        raise NotImplementedError("CustomDataset.__len__ is not implemented yet")

    def __getitem__(self, index):
        raise NotImplementedError("CustomDataset.__getitem__ is not implemented yet")

# CIFAR-10 training split, stored under ./data (downloaded on first run);
# every sample is passed through `transform`.
train_data = datasets.CIFAR10(
    root="data",
    train=True,
    transform=transform,
    download=True,
)

# Mini-batches of 16, reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)

In [6]:
def plot_image(tensor):
    """Display a single image tensor with matplotlib.

    Args:
        tensor: image laid out as ``(C, H, W)``, or a batch of one ``(1, C, H, W)``.
            It may require grad and may live on the GPU, e.g. a raw model output.
            It is detached and copied to the CPU before plotting.

    Floating-point images are clamped to [0, 1]. matplotlib's imshow would
    clip such values anyway, but it warns when it does. Integer images are
    passed through untouched.
    """
    image = tensor.detach().cpu()
    if image.dim() == 4 and image.shape[0] == 1:
        image = image[0]  # drop the singleton batch dimension
    if image.is_floating_point():
        image = image.clamp(0, 1)

    plt.figure(figsize=(1.25, 1.25))
    plt.axis('off')
    # imshow wants channels last: (C, H, W) -> (H, W, C).
    plt.imshow(image.permute(1, 2, 0).numpy())

In [7]:
# TODO fix after implementing norm if needed

def rev_transform(image, mean=mean, std=std, scale: float = 255.0):
    """Undo per-channel normalisation: ``(image * std + mean) * scale``.

    Args:
        image: normalised tensor, either a single image ``(C, H, W)`` or a batch
            ``(N, C, H, W)``; C must equal the number of entries in ``mean``/``std``.
        mean, std: per-channel statistics used for normalisation (tensors or
            sequences). The defaults are the module-level ``mean``/``std``,
            bound when this function is defined.
        scale: factor applied after de-normalising. The default 255.0 keeps the
            original behaviour. If the un-normalised data was in [0, 1], pass
            1.0 to get values suitable for ``plot_image``.

    Returns:
        Tensor with the same shape as ``image``.
    """
    # Shape the stats as (C, 1, 1). That broadcasts against both (C, H, W) and
    # (N, C, H, W) without adding a batch dimension to a single image. Matching
    # dtype/device avoids CPU/GPU mismatches.
    mean = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)

    return (image * std + mean) * scale

In [8]:
# plot_image(rev_transform([train_data[7][0]]))

In [9]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# nn.Module.to() moves the parameters in place and returns the module itself.
model = Denoise().to(device)

# Pixel-wise reconstruction loss between the model output and the clean image.
criterion = nn.MSELoss()
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=1e-4,
    weight_decay=1e-3,
)

In [None]:
num_epochs = 10
for epoch in tqdm(range(num_epochs), desc="epochs"):
    model.train()
    running_loss = 0.0
    num_batches = 0

    # Each item from train_loader is (images, labels); labels are unused.
    for images, _labels in train_loader:
        # Tensor.to() is NOT in-place: rebind, otherwise the batch stays on the
        # CPU while the model may be on the GPU.
        images = images.to(device)
        # Corrupt the clean images; the model learns to map them back to the originals.
        noisy_images = add_noise(images)

        output = model(noisy_images)
        loss = criterion(output, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() extracts a plain float, so no autograd graph is kept alive
        # and no per-step output floods the notebook.
        running_loss += loss.item()
        num_batches += 1

    # tqdm.write prints without breaking the progress bar.
    tqdm.write(f"epoch {epoch + 1}/{num_epochs}: mean train loss {running_loss / max(num_batches, 1):.6f}")


In [None]:
# Run the trained denoiser on a random input and show its output.
model.eval()
with torch.no_grad():
    # The input must live on the same device as the model.
    sample = model(torch.randn((1, 3, 32, 32), device=device))
# Drop the batch dimension and move to the CPU so the tensor can be converted
# to numpy. Clamp to the [0, 1] range matplotlib expects for float RGB images.
plot_image(sample[0].cpu().clamp(0, 1))