In [41]:
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


In [42]:
# Spatial size (pixels) that every image is resized to before it is tensorised.
IMG_HEIGHT = 128
IMG_WIDTH = 128
# Not referenced by any other cell visible in this notebook.
CHANNELS = 3

# Training hyper-parameters.
BATCH_SIZE = 32
EPOCHS = 50

# The two directories that are paired up by file name when the dataset is built.
DATA_DIR1 = "train/real"
DATA_DIR2 = "train/comic"

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


In [43]:
class PairedImageDataset(Dataset):
    """Dataset of same-named image pairs taken from two directories.

    Item ``idx`` is the tuple ``(tensor_from_dir1, tensor_from_dir2)``. Both
    images are converted to RGB, resized to ``(IMG_HEIGHT, IMG_WIDTH)`` and
    turned into tensors by ``transforms.ToTensor``.

    Args:
        dir1: Directory holding the first image of every pair.
        dir2: Directory holding the second image of every pair, stored under
            the same file names as in ``dir1``.
    """

    def __init__(self, dir1, dir2):
        self.dir1 = dir1
        self.dir2 = dir2
        # Only names usable as a pair are indexed, in a reproducible order.
        self.images = self._paired_names(dir1, dir2)

        self.transform = transforms.Compose([
            transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
            transforms.ToTensor()   # [0,1]
        ])

    @staticmethod
    def _paired_names(dir1, dir2):
        """Return the sorted names that are regular files in both directories.

        ``os.listdir`` yields entries in arbitrary order and also returns
        sub-directories (for example ``.ipynb_checkpoints``). Entries that are
        not a regular file on both sides are dropped here so they cannot
        surface as an error in the middle of an epoch.
        """
        return sorted(
            name
            for name in os.listdir(dir1)
            if os.path.isfile(os.path.join(dir1, name))
            and os.path.isfile(os.path.join(dir2, name))
        )

    def __len__(self):
        """Number of indexed image pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, convert and transform the pair stored under ``self.images[idx]``."""
        name = self.images[idx]

        # ``convert`` reads the pixel data and returns a new image, so the
        # underlying file can be closed as soon as the ``with`` block ends.
        with Image.open(os.path.join(self.dir1, name)) as src1:
            img1 = src1.convert("RGB")
        with Image.open(os.path.join(self.dir2, name)) as src2:
            img2 = src2.convert("RGB")

        return self.transform(img1), self.transform(img2)


In [44]:
class UNet(nn.Module):
    """Small convolutional encoder-decoder mapping a 3-channel image to a 3-channel image.

    The encoder halves the spatial size twice (two ``MaxPool2d(2)``) and the
    decoder doubles it twice (two nearest-neighbour ``Upsample``), so inputs
    whose height and width are divisible by 4 come back at the same size.
    ``forward`` applies the stages strictly one after another: no encoder
    activation is fed into the decoder as a skip connection. The final
    ``Sigmoid`` keeps every output value in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 64 -> 128 channels, each stage followed by 2x pooling.
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Bottleneck: widest feature map (256 channels), resolution unchanged.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Decoder: 2x upsampling, then a stride-1 transposed convolution that
        # narrows the channels while leaving the spatial size as it is.
        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
        )

        # Output head: back to 3 channels, squashed into [0, 1].
        self.out = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run ``x`` through every stage in order and return the head's output."""
        pipeline = (
            self.enc1,
            self.enc2,
            self.bottleneck,
            self.up1,
            self.dec1,
            self.up2,
            self.dec2,
            self.out,
        )
        for stage in pipeline:
            x = stage(x)
        return x


In [45]:
# Pair the two training directories by file name and serve them in shuffled
# mini-batches of BATCH_SIZE.
dataset = PairedImageDataset(dir1=DATA_DIR1, dir2=DATA_DIR2)
loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)


In [46]:
# Fresh network on the selected device.
model = UNet()
model = model.to(DEVICE)

# Pixel-wise mean squared error between the prediction and the paired target.
criterion = nn.MSELoss()

# Adam over every model parameter with a fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [47]:
# Train for EPOCHS passes over `loader`, printing the mean per-batch loss of
# each pass.
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        # Clear the gradients left over from the previous step before this
        # batch's backward pass accumulates new ones.
        optimizer.zero_grad()

        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain Python number, so no graph is kept alive.
        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {epoch_loss/len(loader):.4f}")


Epoch 1/50 | Loss: 0.0611
Epoch 2/50 | Loss: 0.0578
Epoch 3/50 | Loss: 0.0512
Epoch 4/50 | Loss: 0.0509
Epoch 5/50 | Loss: 0.0341
Epoch 6/50 | Loss: 0.0247
Epoch 7/50 | Loss: 0.0259
Epoch 8/50 | Loss: 0.0173
Epoch 9/50 | Loss: 0.0233
Epoch 10/50 | Loss: 0.0172
Epoch 11/50 | Loss: 0.0178
Epoch 12/50 | Loss: 0.0202
Epoch 13/50 | Loss: 0.0169
Epoch 14/50 | Loss: 0.0145
Epoch 15/50 | Loss: 0.0154
Epoch 16/50 | Loss: 0.0161
Epoch 17/50 | Loss: 0.0146
Epoch 18/50 | Loss: 0.0126
Epoch 19/50 | Loss: 0.0125
Epoch 20/50 | Loss: 0.0129
Epoch 21/50 | Loss: 0.0115
Epoch 22/50 | Loss: 0.0111
Epoch 23/50 | Loss: 0.0117
Epoch 24/50 | Loss: 0.0110
Epoch 25/50 | Loss: 0.0107
Epoch 26/50 | Loss: 0.0114
Epoch 27/50 | Loss: 0.0109
Epoch 28/50 | Loss: 0.0104
Epoch 29/50 | Loss: 0.0106
Epoch 30/50 | Loss: 0.0101
Epoch 31/50 | Loss: 0.0099
Epoch 32/50 | Loss: 0.0099
Epoch 33/50 | Loss: 0.0095
Epoch 34/50 | Loss: 0.0094
Epoch 35/50 | Loss: 0.0092
Epoch 36/50 | Loss: 0.0088
Epoch 37/50 | Loss: 0.0087
Epoch 38/5

In [48]:
# Persist only the parameters (the state dict), not the pickled module object.
torch.save(obj=model.state_dict(), f="reddot_unet.pth")
print("Model saved as reddot_unet.pth")


Model saved as reddot_unet.pth


In [49]:
# Rebuild the architecture, restore the saved weights onto DEVICE and switch to
# inference mode. `model.eval()` stays last so the cell displays the module.
model = UNet()
model = model.to(DEVICE)
model.load_state_dict(
    torch.load("reddot_unet.pth", map_location=DEVICE)
)
model.eval()


UNet(
  (enc1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (enc2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (bottleneck): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (up1): Upsample(scale_factor=2.0, mode='nearest')
  (dec1): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (up2): Upsample(scale_factor=2.0, mode='nearest')
  (dec2): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (out): Sequential(
    (0): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Sigm

In [50]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os

# Image size used by the inference preprocessing below.
IMG_HEIGHT = 128
IMG_WIDTH = 128

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


In [51]:
class UNet(nn.Module):
    """Convolutional encoder-decoder with a 3-channel input and a 3-channel output.

    Layer names, types and shapes must match the checkpoint that is loaded
    into this class, so they are spelled out one by one. The stages run
    strictly in sequence (no skip connections) and the closing ``Sigmoid``
    bounds the output to [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Down path: two conv + ReLU stages, each followed by 2x max-pooling.
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2)
        )

        # Widest layer, applied at the lowest resolution.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU()
        )

        # Up path: nearest-neighbour 2x upsampling, then a stride-1 transposed
        # convolution that reduces the channel count.
        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1), nn.ReLU()
        )
        self.up2 = nn.Upsample(scale_factor=2, mode="nearest")
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1), nn.ReLU()
        )

        # Projection back to 3 channels.
        self.out = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1), nn.Sigmoid()
        )

    def forward(self, x):
        """Encode, pass through the bottleneck, decode and project to 3 channels."""
        encoded = self.enc2(self.enc1(x))
        latent = self.bottleneck(encoded)
        decoded = self.dec1(self.up1(latent))
        decoded = self.dec2(self.up2(decoded))
        return self.out(decoded)


In [52]:
# Instantiate the network on DEVICE, restore the trained weights from disk and
# put the module into evaluation mode for inference.
model = UNet()
model = model.to(DEVICE)
model.load_state_dict(
    torch.load("reddot_unet.pth", map_location=DEVICE)
)
model.eval()
print("Model loaded")


Model loaded


In [53]:
# Inference preprocessing: resize to the configured size, then convert the PIL
# image to a tensor (ToTensor scales 8-bit pixel values into [0,1]).
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_HEIGHT, IMG_WIDTH)),
        transforms.ToTensor(),
    ]
)


In [54]:
@torch.no_grad()
def infer_image(input_path, output_path):
    """Run the loaded model on one image file and save the result as an image.

    Relies on the notebook-level ``model``, ``transform`` and ``DEVICE``.
    Gradient tracking is disabled for the whole call by the decorator.

    Args:
        input_path: Path of the image to read; it is converted to RGB before
            preprocessing.
        output_path: Path the generated image is written to via
            ``PIL.Image.save``.

    Returns:
        None. The result is written to ``output_path`` and the path is printed.
    """
    # ``convert`` reads the pixels into a new image, so the source file can be
    # closed right away instead of waiting for garbage collection.
    with Image.open(input_path) as source:
        img = source.convert("RGB")

    # Add the batch axis the model expects and move the tensor to the model's device.
    x = transform(img).unsqueeze(0).to(DEVICE)

    output = model(x)

    # Remove only the batch axis. A bare ``squeeze()`` would also drop any
    # other axis that happens to have size 1.
    output_img = transforms.ToPILImage()(output.squeeze(0).cpu())
    output_img.save(output_path)

    print("Saved:", output_path)


In [55]:
# Run the loaded model on one sample image and write the result next to it.
infer_image("face.6.jpg", "face2.png")


Saved: face2.png
