In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

# ------------------------------
# 1. Simple U-Net-like Model (AICNet-inspired)
# ------------------------------
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU.

    ``padding=1`` keeps height and width unchanged; only the channel count
    changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        # The attribute name and layer order fix the state_dict keys
        # (block.0.*, block.2.*) that saved/loaded checkpoints depend on.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        features = self.block(x)
        return features

class UNet(nn.Module):
    """Two-level U-Net that maps an image to an image of the same spatial size.

    Encoder: two ``DoubleConv`` stages (64, 128 channels), each followed by
    2x2 max-pooling. Bottleneck: a ``DoubleConv`` with 256 channels at 1/4
    resolution. Decoder: two 2x transposed-conv upsamplings, each concatenated
    with the matching encoder feature map (skip connection) and refined by a
    ``DoubleConv``. A final 1x1 convolution maps to ``out_channels``; no
    output activation is applied, so values are unbounded.

    Height and width only need to be at least 4. When they are not multiples
    of 4, floor-rounding in max-pooling leaves an upsampled map one pixel
    short of its skip connection; it is zero-padded on the bottom/right to
    match before concatenation. For multiples of 4 no padding happens, so the
    result is identical to a plain U-Net.

    Args:
        in_channels: Channels of the input image (3 for RGB).
        out_channels: Channels of the predicted image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(128, 256)

        # Decoder convs take 2x channels: upsampled features + skip connection.
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        self.final = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _match_spatial(x, ref):
        """Zero-pad ``x`` on the bottom/right so its H and W equal ``ref``'s."""
        diff_h = ref.shape[-2] - x.shape[-2]
        diff_w = ref.shape[-1] - x.shape[-1]
        if diff_h == 0 and diff_w == 0:
            return x
        # F.pad takes (left, right, top, bottom) for the last two dims.
        return nn.functional.pad(x, (0, diff_w, 0, diff_h))

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))

        d2 = self._match_spatial(self.up2(b), e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))

        d1 = self._match_spatial(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)

# ------------------------------
# 2. Dataset with a Single Image Pair
# ------------------------------
class SingleImageDataset(Dataset):
    """A one-sample dataset serving a fixed (input, target) image pair.

    Both images are read once at construction and converted to RGB. Each
    access applies ``transform`` (or a plain ``ToTensor`` when none is given)
    to the input and the target separately.

    NOTE(review): because the transform is called twice, a random transform
    (e.g. a random crop or flip) would draw independent parameters for input
    and target and misalign them. Only deterministic transforms are safe here.

    Args:
        input_path: Path of the input image.
        target_path: Path of the target image.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, input_path, target_path, transform=None):
        self.input_img = self._load_rgb(input_path)
        self.target_img = self._load_rgb(target_path)
        self.transform = transform

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return 1  # only one sample

    def __getitem__(self, idx):
        # Without this IndexError, Python's fallback iteration protocol
        # (``list(dataset)``, ``for pair in dataset``) would never terminate.
        if idx not in (0, -1):
            raise IndexError(f"SingleImageDataset index out of range: {idx}")
        transform = self.transform if self.transform else transforms.ToTensor()
        return transform(self.input_img), transform(self.target_img)

# ------------------------------
# 3. Training Script
# ------------------------------
def train_single_image(input_path, target_path, save_path="model.pth", epochs=10,
                       image_size=128, lr=1e-3):
    """Overfit a ``UNet`` on a single (input, target) image pair and save it.

    Both images are resized to ``image_size`` and converted to tensors. The
    network is trained with MSE loss and Adam, one sample per step. Its
    ``state_dict`` is then written to ``save_path``.

    Args:
        input_path: Path of the input image.
        target_path: Path of the image the model should learn to produce.
        save_path: Destination of the saved ``state_dict``.
        epochs: Number of passes over the one-sample dataset.
        image_size: Resize target, either an int (square) or an ``(h, w)``
            pair. The UNet pools twice by 2, so multiples of 4 are safest.
        lr: Adam learning rate.

    Returns:
        The trained ``UNet``, on the training device and still in train mode.
    """
    if isinstance(image_size, int):
        image_size = (image_size, image_size)
    transform = transforms.Compose([
        # Pass an explicit (h, w): a bare int would only fix the shorter edge
        # and keep the aspect ratio.
        transforms.Resize(tuple(image_size)),
        transforms.ToTensor(),
    ])

    dataset = SingleImageDataset(input_path, target_path, transform)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(in_channels=3, out_channels=3).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Report the mean over the epoch's batches. This also avoids reading
        # an unbound `loss` if the loader were ever empty.
        epoch_loss = running_loss / max(num_batches, 1)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}")
    return model

# ------------------------------
# 4. Run Training
# ------------------------------
if __name__ == "__main__":
    # Replace these with your image paths. Using the same file for both makes
    # the network learn to reproduce its input (an identity mapping).
    input_image_path = "images.jpeg"
    target_image_path = "images.jpeg"

    # dict.fromkeys de-duplicates while keeping order, so a shared path is reported once.
    missing = [
        path for path in dict.fromkeys((input_image_path, target_image_path))
        if not os.path.exists(path)
    ]
    if missing:
        print(f"❌ Image file(s) not found: {', '.join(missing)}. "
              "Put your input and target images in this folder or update the paths above.")
    else:
        train_single_image(input_image_path, target_image_path, save_path="unet_model.pth", epochs=20)


Epoch [1/20], Loss: 0.2476
Epoch [2/20], Loss: 0.2107
Epoch [3/20], Loss: 0.1323
Epoch [4/20], Loss: 0.0204
Epoch [5/20], Loss: 0.1381
Epoch [6/20], Loss: 0.0140
Epoch [7/20], Loss: 0.0381
Epoch [8/20], Loss: 0.0512
Epoch [9/20], Loss: 0.0488
Epoch [10/20], Loss: 0.0362
Epoch [11/20], Loss: 0.0206
Epoch [12/20], Loss: 0.0117
Epoch [13/20], Loss: 0.0169
Epoch [14/20], Loss: 0.0281
Epoch [15/20], Loss: 0.0276
Epoch [16/20], Loss: 0.0177
Epoch [17/20], Loss: 0.0114
Epoch [18/20], Loss: 0.0124
Epoch [19/20], Loss: 0.0161
Epoch [20/20], Loss: 0.0184
Model saved to unet_model.pth


In [3]:
import torch

# UNet is defined in the first cell of this notebook. Only fall back to the
# standalone training script when this cell runs in a kernel without it;
# importing unconditionally would shadow the notebook's definition.
if "UNet" not in globals():
    from train_unet import UNet   # 👈 change if your file name is different

# initialize the model with the same channel configuration used for training
model = UNet(in_channels=3, out_channels=3)

# Load the trained weights onto the CPU. weights_only=True restricts
# unpickling to tensors and plain containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load("unet_model.pth", map_location="cpu", weights_only=True)
)
# Eval mode is the standard state for inference/export. This architecture has
# no dropout or batch-norm, so outputs are unaffected. The cell displays the repr.
model.eval()


UNet(
  (enc1): DoubleConv(
    (block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (enc2): DoubleConv(
    (block): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (bottleneck): DoubleConv(
    (block): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
 

In [4]:
# Random example input used only to trace the graph during ONNX export.
# Layout is (batch, channels, height, width): one RGB image at the 128x128
# resolution the training cell resizes to.
dummy_input = torch.randn((1, 3, 128, 128))


In [6]:
!pip install onnx onnxruntime

torch.onnx.export(
    model,                         # model
    dummy_input,                   # dummy input
    "unet_model.onnx",             # output file name
    export_params=True,            # store trained weights inside ONNX file
    opset_version=11,              # ONNX opset version (11 is widely supported)
    do_constant_folding=True,      # optimize constant folding
    input_names=['input'],         # input tensor name
    output_names=['output'],       # output tensor name
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
print("✅ Exported to unet_model.onnx")


Collecting onnx
  Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m61.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m54.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 

  torch.onnx.export(


✅ Exported to unet_model.onnx
