In [3]:
import os
import torch
import numpy as np
from matplotlib import pyplot as plt

In [12]:
from torch.utils.data import Dataset
import torchvision.transforms as T
import PIL.Image as Image
import numpy as np
import torch

class CustomDataset(Dataset):
    """Pairs of (noisy, clean) grayscale images for training a denoiser.

    Each item stacks two images along the channel axis: the image at ``idx``
    and a second image drawn uniformly at random from the whole list (it may
    be the same image).

    Args:
        img_list_file: Text file with one image path per line. Blank lines
            are ignored and surrounding whitespace is stripped.
        transforms: Callable applied separately to the clean PIL image and to
            the noisy float32 array. Defaults to ``ToTensor`` followed by a
            ``CenterCrop`` of ``crop_size``.
        std: Standard deviation of the additive Gaussian noise, in 0-255 pixel
            units. ``None`` draws a fresh value from U(0, 1) for every image;
            ``0`` means no noise.
        crop_size: Side length of the default center crop.
        device: Device the returned tensors are moved to.

    NOTE(review): in 0-255 units a std drawn from U(0, 1) is below one grey
    level, so the default noise is very weak. Confirm the intended units.

    NOTE(review): noise levels and pair partners come from numpy's global
    RNG. With DataLoader ``num_workers > 0``, workers may share that state
    unless seeded via ``worker_init_fn``. Moving tensors to a CUDA ``device``
    inside ``__getitem__`` also effectively requires ``num_workers == 0``.
    """

    def __init__(self, img_list_file='/mnt/data/ILSVRC/Data/train_img_paths.txt', transforms=None, std=None, crop_size=224, device='cpu') -> None:
        super().__init__()
        with open(img_list_file, 'r') as f:
            # Skip blank lines (e.g. a trailing newline) that would otherwise
            # become empty paths and fail in Image.open.
            self.img_paths = [line.strip() for line in f if line.strip()]
        self.crop_size = crop_size
        if transforms is not None:
            self.transform = transforms
        else:
            self.transform = T.Compose([
                T.ToTensor(),
                T.CenterCrop(size=crop_size),
            ])
        self.std = std
        self.device = device

    def __add_noise__(self, img):
        """Return ``img`` plus Gaussian noise as a float32 array in [0, 1].

        Noise is drawn in 0-255 pixel units and the sum is clipped to
        [0, 255]. The result is then divided by 255 so it matches the [0, 1]
        range that ``ToTensor`` produces for the clean PIL image. ``ToTensor``
        only rescales uint8/PIL input, so a float array must be normalised
        here.
        """
        # `is not None` so that an explicit std=0 means "no noise" instead of
        # falling through to a random std.
        std = self.std if self.std is not None else np.random.uniform(0, 1)
        pixels = np.asarray(img, dtype=np.float64)

        noise = np.random.normal(0, std, pixels.shape)
        noisy = np.clip(pixels + noise, 0, 255) / 255.0
        return noisy.astype(np.float32)

    def __get_img__(self, idx):
        """Load image ``idx`` as grayscale and return ``(clean, noisy)``.

        Both outputs have ``self.transform`` applied.
        """
        img_path = self.img_paths[idx].strip()
        # The context manager closes the underlying file. convert() returns a
        # converted copy that stays valid afterwards.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("L")
        noisy_img = self.__add_noise__(img)

        img = self.transform(img)
        noisy_img = self.transform(noisy_img)

        return img, noisy_img

    def __len__(self):
        """Number of image paths read from the list file."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` float tensors on ``self.device``.

        Each tensor concatenates, along dim 0, the transformed image ``idx``
        and a second, randomly chosen image.
        """
        img_1, noisy_1 = self.__get_img__(idx)

        # get a random image
        target_idx = np.random.randint(0, len(self.img_paths))
        img_2, noisy_2 = self.__get_img__(target_idx)

        noisy = torch.cat((noisy_1, noisy_2), dim=0).to(torch.float)
        clean = torch.cat((img_1, img_2), dim=0).to(torch.float)
        return noisy.to(device=self.device), clean.to(device=self.device)

imagenet_path = '/mnt/data/ILSVRC/Data'
train_path = os.path.join(imagenet_path, "train")
image_list = os.path.join(imagenet_path, "train_img_paths.txt")

# Fall back to CPU so this smoke test also runs without a GPU; the training
# cell below uses the same check.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = CustomDataset(image_list, device=device)

# Pull one sample to sanity-check the pipeline. next(iter(dataset)) resolves
# to dataset[0] through the sequence protocol, so index directly.
data = dataset[0]
noisy_sample, clean_sample = data
assert noisy_sample.shape == clean_sample.shape, "noisy/clean shape mismatch"


In [None]:
import torch
from torch.optim import Adam
from torch.nn import MSELoss
from torch.utils.data import DataLoader
from tqdm import tqdm

# from dataset import CustomDataset
from smaller_model import UNet

# Seed every RNG the pipeline touches. CustomDataset draws noise levels and
# pair partners from numpy's global RNG; shuffling and init use torch's.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32
dataset = CustomDataset(device=device)
# num_workers stays at the default 0. CustomDataset moves tensors to `device`
# inside __getitem__, which would need CUDA inside worker processes.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


model = UNet().to(device=device)

lr = 1e-3
optim = Adam(model.parameters(), lr=lr)
loss_fn = MSELoss()
log_every = 10  # print the training loss every `log_every` optimizer steps

n_epochs = 1

model.train()
for epoch in range(n_epochs):
    for step, (noisy, clean) in enumerate(tqdm(dataloader, desc=f'epoch {epoch}')):
        optim.zero_grad()
        pred = model(noisy.to(torch.float))

        loss = loss_fn(pred, clean)
        loss.backward()
        optim.step()

        # Log by step, not epoch: `epoch` is constant inside this loop.
        # tqdm.write keeps the progress bar intact.
        if step % log_every == 0:
            tqdm.write(f'Epoch {epoch} step {step} loss: {loss.item():.6f}')