# Imports


In [1]:
import numpy as np
import matplotlib.pyplot as plt
from preprocess_images import data_from_folder
from tqdm import tqdm
from math import log

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T
from torchsummary import summary
import cv2 
import wandb
from config import LMDB_USE_COMPRESSION

import lmdb
import os
import msgpack
import lz4.frame

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [2]:
def imname_to_target(name: str) -> tuple[float, float]:
    """Parse the (x, y) target encoded in an image name.

    Names follow the format ``x{x_value}_y{y_value}.jpg``, for example
    ``x0.40_y-1.96.jpg``.

    Args:
        name: Image file name (the ``.jpg`` suffix is optional).

    Returns:
        The ``(x, y)`` pair as floats.

    Raises:
        ValueError: If the name does not contain exactly one ``_`` separator
            or one of the values is not a valid float.
    """
    stem = name.split('.jpg')[0]
    x_part, y_part = stem.split("_")
    # Drop the leading axis letter and parse the *whole* remainder. Slicing
    # the y value to a fixed width would cut the last digit off negative
    # values ("y-1.96" -> -1.9).
    return float(x_part[1:]), float(y_part[1:])

def save_model(model: torch.nn.Module, fname="best_model.pth", path="./saved_models/real"):
    """Save the model's ``state_dict`` to ``path/fname``.

    Args:
        model: Module whose parameters and buffers are saved.
        fname: Checkpoint file name.
        path: Target directory; created (with parents) if it does not exist.
    """
    # A missing checkpoint folder would otherwise abort the save.
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(path, fname))

def load_model(model: torch.nn.Module, fname="best_model.pth", path="./saved_models/real"):
    """Load a ``state_dict`` checkpoint from ``path/fname`` into ``model``.

    Args:
        model: Module that receives the weights (modified in place).
        fname: Checkpoint file name.
        path: Directory containing the checkpoint.

    Returns:
        The same ``model`` instance, with the loaded weights.
    """
    # The checkpoint is a plain state_dict (see save_model), so the restricted
    # unpickler is sufficient and avoids executing arbitrary pickled code.
    # Loading onto the CPU first lets a GPU-saved checkpoint be restored on a
    # CPU-only machine; load_state_dict copies into the model's own device.
    state = torch.load(os.path.join(path, fname), map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    return model

# Config


In [3]:
# Experiment configuration. The whole dict is passed to wandb.init(config=...)
# and read by the data, model and training cells.
# NOTE(review): several model classes in this notebook read keys that are not
# defined here ("conv_kernel", "conv_kernel_1", "conv_stride", "conv_depth").
config = {
    # Run identification and optimisation settings.
    "experiment_name": "004step_ConfigCNN_DarkOnly512_lmdb_800bs_0001lr_aug+",
    "batch_size": 800,
    "lr": 0.001,
    "lr_scheduler_loop": 7,  # passed to CosineAnnealingWarmRestarts as its restart period
    "epochs": 28,
    "use_amp": False,

    # Dataset location and type; the keys files are resolved inside data_folder.
    "data_folder": "/mnt/h/real_512_0_001step.lmdb",
    # "data_folder": "/mnt/e/color.lmdb",
    "dataset_type": "LMDBImageDataset",
    "dataset_config_flatten": False,
    "dataset_train_keys_fname": "004_dark_train.txt",
    "dataset_val_keys_fname": "004_dark_val.txt",
    "dataset_offload_count": 0,

    # Training-time augmentation switches and strengths.
    "use_noise_transform": True,
    "noise_level": 0.1,  # sigma of T.GaussianNoise
    "use_jitter_transform": True,
    "jitter_brightness": 0.4, 
    "jitter_contrast": 0.1, 
    "jitter_saturation": 0.1, 
    "jitter_hue": 0.2,

    # Optional preprocessing transforms.
    "use_grayscale_transform": False,
    "use_clahegrad_transform": False,
    "clahe_clip_limit": 0.001,
    "clahe_gaussian_size": 15,
    "clahe_gaussian_sigma": 5,

    "use_high_pass_transform": False,
    "high_pass_transform_t": 0.35,

    # Data collection and checkpointing.
    "data_collection_step": 0.001,
    "starting_checkpoint_fname": None,
    "checkpoint_folder": "./saved_models/real",

    # Default arguments of the GradientMagnitude layer.
    "gradient_layer_kernel_size": 15,
    "gradient_layer_sigma": 5,

    # Optional weight initialisation applied after the model is built.
    "use_weight_initialization": False,
    "init_red_filter": False
}

# Data

In [None]:


class FilesImageDataset(Dataset):
    """Dataset that reads grayscale images from individual files on disk.

    Targets are parsed from the file names with ``imname_to_target``.

    Args:
        data_dir: Directory that contains the image files.
        filenames: Image file names relative to ``data_dir``.
    """

    def __init__(self, data_dir, filenames):
        self.data_dir = data_dir
        self.filenames = filenames
        self.targets = [imname_to_target(s) for s in filenames]

    def __len__(self):
        # One sample per file name (there is no ``images`` attribute).
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is a flattened float tensor scaled to [0, 1]; the label is
        a float32 tensor holding the (x, y) target.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        fname = self.filenames[index]
        img_path = os.path.join(self.data_dir, fname)
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image {img_path}")

        # Convert image to Tensor, flatten, and normalize [0,1]
        image = torch.from_numpy(image).flatten().float() / 255.0

        # Convert label tuple to Tensor
        label = torch.tensor(self.targets[index], dtype=torch.float32)

        return image, label

class InMemoryLMDBImageDataset(Dataset):
    """LMDB-backed image dataset that caches decoded images in memory.

    Each image is read from LMDB and decoded on first access, then served
    from ``self.images`` afterwards. Augmentations are applied on every
    access, after the cache lookup, so cached images stay un-augmented.

    Args:
        data_folder_path: Path of the LMDB environment; the keys file is
            looked up inside it.
        transforms: Optional callable applied to the image tensor.
        keys_fname: Text file with one LMDB key (image name) per line.
        flatten_data: If True, return the image as a flat 1-D float tensor.
        turn_to_grayscale: If True, colour images are converted to a single
            grayscale channel before caching.
    """

    def __init__(self, data_folder_path, transforms=None, keys_fname="keys.txt", flatten_data=True, turn_to_grayscale=True):
        # Data augmentation
        self.transforms = transforms
        self.data_folder_path = data_folder_path
        self.flatten_data = flatten_data
        self.turn_to_grayscale = turn_to_grayscale

        # Read text keys from file, dropping line endings and blank lines.
        with open(os.path.join(data_folder_path, keys_fname)) as f:
            text_keys = [line.rstrip("\r\n") for line in f]
        text_keys = [key for key in text_keys if key]

        # Get labels from text keys
        self.labels = []
        for i, key in enumerate(text_keys):
            try:
                x, y = imname_to_target(key)
            except Exception:
                # Report the offending entry, then propagate the original error.
                print("i:", i)
                print("name:", key)
                raise
            # Rescale the raw targets.
            # NOTE(review): presumably maps x in [-2, 3.7] and y in [-2, 2]
            # onto [0, 1] - confirm against the data collection range.
            label = ((x + 2) / 5.7, (y + 2) / 4)
            self.labels.append(torch.tensor(label, dtype=torch.float32))

        # LMDB keys are bytes.
        self.keys = [key.encode() for key in text_keys]

        # Decoded-image cache, filled lazily by __getitem__.
        self.images = [None] * len(self.keys)
        self.loaded_indexes = set()

    def open_lmdb(self):
        """Open the LMDB environment and a read transaction in this process."""
        self.env = lmdb.open(self.data_folder_path, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn = self.env.begin()
        self._owner_pid = os.getpid()

    def __len__(self):
        return len(self.keys)

    def get_index(self, key):
        """Return the index of the first entry equal to ``key`` (bytes), or None."""
        try:
            return self.keys.index(key)
        except ValueError:
            return None

    def _load_image(self, index):
        """Read, decode and normalise the image stored under ``self.keys[index]``."""
        # Open lazily and per process.
        # NOTE(review): LMDB handles are documented as not usable across
        # fork(); DataLoader workers therefore open their own - confirm.
        if not hasattr(self, 'txn') or getattr(self, '_owner_pid', os.getpid()) != os.getpid():
            self.open_lmdb()

        key = self.keys[index]
        img_bytes = self.txn.get(key)
        if img_bytes is None:
            raise KeyError(f"Image {key} not found in LMDB!")

        # Same storage convention as LMDBImageDataset.
        if LMDB_USE_COMPRESSION:
            img_bytes = lz4.frame.decompress(img_bytes)

        img = np.array(msgpack.unpackb(img_bytes, raw=False), dtype=np.uint8)
        # Only colour (H, W, C) arrays need the conversion; 2-D arrays are
        # already single-channel.
        if self.turn_to_grayscale and img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        img = torch.from_numpy(np.array(img, dtype=np.float32)) / 255.0
        if img.ndim == 2:
            img = torch.unsqueeze(img, 0)   # (H, W) -> (1, H, W)
        else:
            img = img.permute(2, 0, 1)      # (H, W, C) -> (C, H, W)
        return img

    def __getitem__(self, index):
        """Return ``(image, label)``; the label is a normalised float32 tensor.

        Raises:
            KeyError: If the key is missing from the LMDB database.
        """
        label = self.labels[index]

        if index in self.loaded_indexes:
            img = self.images[index]
        else:
            img = self._load_image(index)
            self.images[index] = img
            self.loaded_indexes.add(index)

        # Augmentation
        if self.transforms is not None:
            img = self.transforms(img)
        if self.flatten_data:
            img = img.flatten().float()
            self.debug_msg = f"image shape {img.shape}"
        elif isinstance(img, np.ndarray):
            # A transform may hand back a NumPy array; restore a channel-first tensor.
            img = torch.unsqueeze(torch.from_numpy(img), 0)

        return img, label

class LMDBImageDataset(Dataset):
    """Image dataset that reads msgpack-encoded images from an LMDB database.

    The LMDB environment is opened lazily on first item access, so the
    dataset object itself holds no open handles after construction.

    Args:
        lmdb_path: Path of the LMDB environment; the keys file is looked up
            inside it.
        transforms: Optional callable applied to the image tensor.
        keys_fname: Text file with one LMDB key (image name) per line.
        flatten_data: If True, return the image as a flat 1-D float tensor.
    """

    def __init__(self, lmdb_path, transforms=None, keys_fname="keys.txt", flatten_data=True):
        # Data augmentation
        self.transforms = transforms
        self.lmdb_path = lmdb_path
        self.flatten_data = flatten_data

        # Read text keys from file, dropping line endings and blank lines.
        with open(os.path.join(lmdb_path, keys_fname)) as f:
            text_keys = [line.rstrip("\r\n") for line in f]
        text_keys = [key for key in text_keys if key]

        # Get labels (raw (x, y) tuples) from text keys
        self.labels = []
        for i, key in enumerate(text_keys):
            try:
                self.labels.append(imname_to_target(key))
            except Exception:
                # Report the offending entry, then propagate the original error.
                print("i:", i)
                print("name:", key)
                raise

        # LMDB keys are bytes.
        self.keys = [key.encode() for key in text_keys]

    def open_lmdb(self):
        """Open the LMDB environment and a read transaction in this process."""
        self.env = lmdb.open(self.lmdb_path, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn = self.env.begin()
        self._owner_pid = os.getpid()

    def close(self):
        """Close the LMDB environment if it is open; safe to call repeatedly.

        The handles are removed so that the next item access reopens them.
        """
        txn = self.__dict__.pop("txn", None)
        env = self.__dict__.pop("env", None)
        self.__dict__.pop("_owner_pid", None)
        if txn is not None:
            txn.abort()
        if env is not None:
            env.close()

    def __len__(self):
        return len(self.keys)

    def get_index(self, key):
        """Return the index of the first entry equal to ``key`` (bytes), or None."""
        try:
            return self.keys.index(key)
        except ValueError:
            return None

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is a float tensor scaled to [0, 1], channel-first unless
        ``flatten_data`` is set; the label is a normalised float32 tensor.

        Raises:
            KeyError: If the key is missing from the LMDB database.
        """
        # Open lazily and per process.
        # NOTE(review): LMDB handles are documented as not usable across
        # fork(); a handle opened in the parent is therefore replaced in
        # DataLoader workers - confirm.
        if not hasattr(self, 'txn') or getattr(self, '_owner_pid', os.getpid()) != os.getpid():
            print("Opening lmdb txn")
            self.open_lmdb()
        key = self.keys[index]
        label = self.labels[index]

        img_bytes = self.txn.get(key)

        if img_bytes is None:
            raise KeyError(f"Image {key} not found in LMDB!")

        if LMDB_USE_COMPRESSION:
            img_bytes = lz4.frame.decompress(img_bytes)

        image = np.array(msgpack.unpackb(img_bytes, raw=False), dtype=np.uint8)

        # Convert image to Tensor, and normalize [0,1]
        image = torch.from_numpy(image).float()

        if image.ndim == 2:
            image = image.unsqueeze(0)        # (H, W) -> (1, H, W)
        elif image.ndim == 3:
            image = image.permute(2, 0, 1)    # (H, W, C) -> (C, H, W)
        image = image / 255.0

        # Augmentation
        if self.transforms is not None:
            image = self.transforms(image)
        if self.flatten_data:
            image = image.flatten().float()
            self.debug_msg = f"image shape {image.shape}"
        elif isinstance(image, np.ndarray):
            # A transform may hand back a NumPy array; restore a channel-first tensor.
            image = torch.unsqueeze(torch.from_numpy(image), 0)

        # Rescale the raw target and convert it to a Tensor.
        # NOTE(review): presumably maps x in [-2, 3.7] and y in [-2, 2] onto
        # [0, 1] - confirm against the data collection range.
        x, y = label
        label = torch.tensor(((x + 2) / 5.7, (y + 2) / 4), dtype=torch.float32)

        return image, label

In [7]:
# Build the augmentation pipelines from the config.
# tarr: transforms applied to training samples (in list order).
# varr: transforms applied to validation samples (no random augmentation).
tarr = []

if config["use_jitter_transform"]:
    tarr.append(
        T.ColorJitter(
            config["jitter_brightness"],
            config["jitter_contrast"],
            # config["jitter_saturation"],
            # config["jitter_hue"]
        )
    )

if config["use_noise_transform"]:
    tarr.append(
        T.GaussianNoise(sigma=config["noise_level"]),
    )

varr = []
# Disabled: CLAHEGradTransform is not defined in this notebook.
# if config["use_clahegrad_transform"]:
#     tarr.append(CLAHEGradTransform())
#     varr.append(CLAHEGradTransform())

if config["use_grayscale_transform"]:
    tarr.append(T.Grayscale())
    varr.append(T.Grayscale())

# Disabled: HighPassTransform is not defined in this notebook.
# if config["use_high_pass_transform"]:
#     tarr.append(HighPassTransform(config["high_pass_transform_t"]))
#     varr.append(HighPassTransform(config["high_pass_transform_t"]))

# T.Compose rejects an empty list, so an empty pipeline is represented by
# None for both splits (the datasets skip transforms when given None).
train_transforms = T.Compose(tarr) if tarr else None
val_transforms = T.Compose(varr) if varr else None


In [8]:
train_transforms

Compose(
      ColorJitter(brightness=(0.6, 1.4), contrast=(0.9, 1.1))
      GaussianNoise(mean=0.0, sigma=0.1, clip=True)
)

In [9]:
match config["dataset_type"]:
    case "LMDBImageDataset":
        train_dataset = LMDBImageDataset(config["data_folder"], transforms=train_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_train_keys_fname"])
        val_dataset = LMDBImageDataset(config["data_folder"], transforms=val_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_val_keys_fname"])
    # case "FilesImageDataset":
    #     train_dataset = FilesImageDataset(config["data_folder"], FilesImageDataset_fnames)
    #     val_dataset = FilesImageDataset(config["data_folder"], FilesImageDataset_fnames)
    case "InMemoryImageDataset":
        train_dataset = InMemoryLMDBImageDataset(config["data_folder"], transforms=train_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_train_keys_fname"])
        val_dataset = InMemoryLMDBImageDataset(config["data_folder"], transforms=val_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_val_keys_fname"])
    case _ :
        raise("Wrong dataset type")
train_data_loader = DataLoader(train_dataset, 
                         batch_size=config["batch_size"], 
                         shuffle=True, 
                         num_workers=8, 
                         pin_memory=True, 
                         prefetch_factor=4, 
                         persistent_workers=True
                        )
val_data_loader = DataLoader(val_dataset,
                             batch_size=config["batch_size"],
                             shuffle=False,
                             num_workers=4,
                             persistent_workers=True,
                             pin_memory=True
                            )

In [10]:
# Sanity check: shape of the first training image, number of training
# samples, and the augmentation pipeline attached to the dataset.
# Note: indexing the dataset here triggers its lazy LMDB open in this process.
print(train_dataset[0][0].shape)
print(len(train_dataset))
print(train_dataset.transforms)

Opening lmdb txn
torch.Size([1, 512, 512])
11440
Compose(
      ColorJitter(brightness=(0.6, 1.4), contrast=(0.9, 1.1))
      GaussianNoise(mean=0.0, sigma=0.1, clip=True)
)


In [11]:
train_dataset.keys[0]

b'x0.40_y-1.96.jpg'

In [12]:
# sample, _ = train_dataset[train_dataset.get_index(b'x0.00_y0.00.jpg')]
# sample = sample.squeeze().cpu().numpy().reshape((512, 512))

# print(sample.shape)
# plt.imsave("x0.00_y0.00.jpg", sample, cmap="gray")

In [13]:
# fig, axes = plt.subplots(3,1)
# print("Original")
# axes[0].imshow(train_dataset[0][0].permute(1,2,0).numpy())
# axes[0].set_title(train_dataset.keys[0])
# axes[1].imshow(train_dataset[1][0].permute(1,2,0).numpy())
# axes[1].set_title(train_dataset.keys[1])
# axes[2].imshow(train_dataset[10][0].permute(1,2,0).numpy())
# axes[2].set_title(train_dataset.keys[10])
# plt.show()

# Model

In [14]:
class GradientMagnitude(nn.Module):
    """Fixed (non-trainable) layer: Gaussian blur, then Sobel gradient magnitude.

    All kernels are registered as buffers with a single input and output
    channel, and the result is min-max normalised per sample.

    Args:
        kernel_size: Side length of the Gaussian blur kernel.
        sigma: Standard deviation of the Gaussian blur kernel.
    """

    def __init__(self, kernel_size=config["gradient_layer_kernel_size"], sigma=config["gradient_layer_sigma"]):
        super().__init__()
        # Sobel filters for the two image axes.
        horizontal = torch.tensor([[-1., 0., 1.],
                                   [-2., 0., 2.],
                                   [-1., 0., 1.]]).view(1, 1, 3, 3)
        vertical = torch.tensor([[-1., -2., -1.],
                                 [ 0.,  0.,  0.],
                                 [ 1.,  2.,  1.]]).view(1, 1, 3, 3)
        blur = self._create_gaussian_kernel(kernel_size, sigma)

        for buffer_name, tensor in (('weight_x', horizontal),
                                    ('weight_y', vertical),
                                    ('gaussian_kernel', blur)):
            self.register_buffer(buffer_name, tensor)

    def _create_gaussian_kernel(self, kernel_size, sigma):
        """Return a sum-normalised Gaussian kernel shaped (1, 1, k, k)."""
        offsets = torch.arange(kernel_size) - kernel_size // 2
        rows, cols = torch.meshgrid(offsets, offsets, indexing='ij')
        squared_distance = rows**2 + cols**2
        weights = torch.exp(-squared_distance / (2 * sigma**2))
        weights = weights / weights.sum()
        return weights.view(1, 1, kernel_size, kernel_size)

    def forward(self, x):
        conv2d = torch.nn.functional.conv2d

        # Smooth first; the padding is half the blur kernel size.
        blur_padding = self.gaussian_kernel.shape[-1] // 2
        smoothed = conv2d(x, self.gaussian_kernel, padding=blur_padding)

        # Sobel responses along both axes.
        dx = conv2d(smoothed, self.weight_x, padding=1)
        dy = conv2d(smoothed, self.weight_y, padding=1)

        # Gradient magnitude; the epsilon keeps the sqrt argument positive.
        magnitude = torch.sqrt(dx ** 2 + dy ** 2 + 1e-6)

        # Per-sample min-max normalisation to [0, 1].
        batch = magnitude.shape[0]
        per_sample = magnitude.view(batch, -1)
        lowest = per_sample.min(dim=1).values.view(batch, 1, 1, 1)
        highest = per_sample.max(dim=1).values.view(batch, 1, 1, 1)
        return (magnitude - lowest) / (highest - lowest + 1e-6)


In [15]:
def size_after_conv(input_size, kernel_size, stride, padding):
    """Return the output length of a convolution along one spatial axis.

    ``padding`` is the amount added on each side of the input.
    """
    padded_extent = input_size + 2 * padding
    return (padded_extent - kernel_size) // stride + 1
# Layer specification consumed by ConfigCNN: one dict per Conv2d layer.
conv_config = [
    dict(out_channels=2, kernel_size=50, stride=10),
    dict(out_channels=3, kernel_size=25, stride=5),
    dict(out_channels=4, kernel_size=12, stride=3),
    dict(out_channels=5, kernel_size=6, stride=2),
]

# Each layer is padded by half its kernel size.
for l in conv_config:
    l.update(padding=l['kernel_size'] // 2)


# Print the number of activations after each layer for a 512x512 input.
s = 512
print(s*s)
for l in conv_config:
    if l is None:
        continue
    s = size_after_conv(s, l['kernel_size'], l['stride'], l['padding'])
    print(s*s*l['out_channels'])

262144
5408
363
64
45


In [16]:
# Store the conv layer specification in the experiment config (the same list
# object, not a copy).
config['conv_config'] = conv_config

In [17]:
class ConfigCNN(nn.Module):
    """CNN whose convolutional stack is described by a list of layer dicts.

    Every entry produces ``Conv2d -> ReLU -> BatchNorm2d``; the flattened
    result feeds a single linear output layer.

    Args:
        output_size: Number of regression outputs.
        input_size: ``(channels, height, width)`` of the input images.
        layer_configs: Sequence of dicts with the keys ``out_channels``,
            ``kernel_size``, ``stride`` and ``padding``. Defaults to the
            module-level ``conv_config``.
    """

    def __init__(self, output_size=2, input_size=(1, 250, 250), layer_configs=None):
        super().__init__()
        if layer_configs is None:
            layer_configs = conv_config

        channels, height, width = input_size
        layers = []
        for layer_config in layer_configs:
            out_channels = layer_config['out_channels']
            kernel_size = layer_config['kernel_size']
            stride = layer_config['stride']
            padding = layer_config['padding']

            layers.append(nn.Conv2d(channels, out_channels, kernel_size, stride, padding=padding))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm2d(out_channels))

            # Track both spatial axes separately so that non-square inputs
            # give the correct flattened size for the linear layer.
            height = size_after_conv(height, kernel_size, stride, padding)
            width = size_after_conv(width, kernel_size, stride, padding)
            channels = out_channels

        flat_size = channels * height * width

        self.sec1 = nn.Sequential(*layers)

        self.sec2 = nn.Sequential(
            # nn.Linear(size, size//4),
            # nn.ReLU(),
            # nn.BatchNorm1d(size//4),
            nn.Linear(flat_size, output_size),
        )

    def forward(self, x):
        x = self.sec1(x)
        x = torch.flatten(x, 1)  # (B, C, H, W) -> (B, C*H*W)
        return self.sec2(x)

class SimpleCNN(nn.Module):
    """Three ``Conv2d -> ReLU -> MaxPool2d`` stages followed by a two-layer MLP.

    The first linear layer expects 128*3*3 features, which is what the
    convolutional stack yields for 250x250 inputs.

    Args:
        output_size: Number of outputs of the final linear layer.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, output_size, in_channels):
        super().__init__()

        # (input channels, output channels, kernel size); every conv uses stride 2.
        stage_specs = (
            (in_channels, 32, 5),
            (32, 64, 3),
            (64, 128, 3),
        )
        feature_layers = []
        for c_in, c_out, kernel in stage_specs:
            feature_layers.append(nn.Conv2d(c_in, c_out, kernel, 2))
            feature_layers.append(nn.ReLU())
            feature_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.sec1 = nn.Sequential(*feature_layers)

        self.sec2 = nn.Sequential(
            nn.Linear(128*3*3, 256),
            nn.ReLU(),
            nn.Linear(256, output_size),
        )

    def forward(self, x):
        features = self.sec1(x)
        flat = features.view(features.size(0), -1)  # one row per sample
        return self.sec2(flat)

class SpotLocalizer(nn.Module):
    """Resolution-preserving conv stack producing a one-channel map, then an MLP.

    The MLP input width is fixed at 512*512, i.e. the flattened map of a
    single-channel 512x512 image. The output has two values per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 7, padding=3),  # padding keeps the spatial size
            nn.LeakyReLU(),
            nn.Conv2d(8, 8, 3, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(8, 1, 1),  # one-channel heatmap
        )

        # Hidden blocks are Linear -> BatchNorm1d -> LeakyReLU.
        widths = (512*512, 256, 64, 16)
        mlp_layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            mlp_layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.LeakyReLU(),
            ]
        mlp_layers.append(nn.Linear(widths[-1], 2))
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        heatmap = self.conv(x)                  # (B, 1, H, W)
        flat_map = heatmap.flatten(start_dim=1)  # (B, H*W)
        return self.mlp(flat_map)               # (B, 2)

class SimpleFC(nn.Module):
    """Fully connected regressor: ``in_features -> 1024 -> 256 -> 32 -> out_features``.

    Each hidden layer is followed by BatchNorm1d and a shared ReLU module.

    Args:
        in_features: Length of the flattened input vector.
        out_features: Number of outputs.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Linear(in_features, 1024), # 262,144 -> 1024
            nn.BatchNorm1d(1024),
            self.relu,
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            self.relu,
            nn.Linear(256, 32),
            nn.BatchNorm1d(32),
            self.relu,
            nn.Linear(32, out_features),
        )

    def forward(self, x):
        # Call the container rather than its .forward so that hooks
        # registered on it are run.
        return self.layers(x)
    
class GradientSimpleFC(nn.Module):
    """``GradientMagnitude`` preprocessing followed by the SimpleFC-style MLP.

    The image batch passes through ``GradientMagnitude``, is flattened, and
    then goes through ``in_features -> 1024 -> 256 -> 32 -> out_features``.

    Args:
        in_features: Number of values per sample after flattening.
        out_features: Number of outputs.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            GradientMagnitude(),
            nn.Flatten(),
            nn.Linear(in_features, 1024), # 262,144 -> 1024
            nn.BatchNorm1d(1024),
            self.relu,
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            self.relu,
            nn.Linear(256, 32),
            nn.BatchNorm1d(32),
            self.relu,
            nn.Linear(32, out_features),
        )

    def forward(self, x):
        # Call the container rather than its .forward so that hooks
        # registered on it are run.
        return self.layers(x)
    
class GradientLargerFC(nn.Module):
    """``GradientMagnitude`` preprocessing followed by a deeper MLP.

    The image batch passes through ``GradientMagnitude``, is flattened, and
    then goes through
    ``in_features -> 1024 -> 512 -> 256 -> 32 -> out_features``.

    Args:
        in_features: Number of values per sample after flattening.
        out_features: Number of outputs.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            GradientMagnitude(),
            nn.Flatten(),
            nn.Linear(in_features, 1024), # 262,144 -> 1024
            nn.BatchNorm1d(1024),
            self.relu,
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            self.relu,
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            self.relu,
            nn.Linear(256, 32),
            nn.BatchNorm1d(32),
            self.relu,
            nn.Linear(32, out_features),
        )

    def forward(self, x):
        # Call the container rather than its .forward so that hooks
        # registered on it are run.
        return self.layers(x)
    
class SmallFC(nn.Module):
    """Small fully connected regressor: ``in_features -> 128 -> 32 -> out_features``.

    Each hidden layer is followed by BatchNorm1d and a shared ReLU module.

    Args:
        in_features: Length of the flattened input vector.
        out_features: Number of outputs.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Linear(in_features, 128), # 512*512 -> 128
            nn.BatchNorm1d(128),
            self.relu,
            nn.Linear(128, 32),
            nn.BatchNorm1d(32),
            self.relu,
            nn.Linear(32, out_features),
        )

    def forward(self, x):
        # Call the container rather than its .forward so that hooks
        # registered on it are run.
        return self.layers(x)
    
class SmallExt(nn.Module):
    """One convolutional layer followed by the SmallFC-style MLP.

    Args:
        in_features: Side length of the square single-channel input image
            (the conv output is assumed to be ``flat x flat``).
        out_features: Number of outputs.
        kernel_size: Conv kernel size; defaults to ``config["conv_kernel"]``.
        stride: Conv stride; defaults to ``config["conv_stride"]``.
        depth: Number of conv output channels; defaults to
            ``config["conv_depth"]``.
    """

    def __init__(self, in_features, out_features, *, kernel_size=None, stride=None, depth=None):
        super().__init__()
        # The config is only consulted for values that were not passed in, so
        # the model can be built without those keys being present.
        k_size = config["conv_kernel"] if kernel_size is None else kernel_size
        step = config["conv_stride"] if stride is None else stride
        depth = config["conv_depth"] if depth is None else depth
        pad = k_size // 2
        # Side length of the conv output for a square input.
        flat = int((in_features - k_size + 2 * pad) // step + 1)
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Conv2d(1, depth, k_size, step, padding=pad),
            self.relu,
            nn.Flatten(),
            nn.Linear(flat*flat * depth, 128), # 512*512 -> 128
            nn.BatchNorm1d(128),
            self.relu,
            nn.Linear(128, 32),
            nn.BatchNorm1d(32),
            self.relu,
            nn.Linear(32, out_features),
        )

    def forward(self, x):
        # Call the container rather than its .forward so that hooks
        # registered on it are run.
        return self.layers(x)
    
class CnnExtractor(nn.Module):
    """Single learnable 3->1 channel convolution followed by an MLP.

    Input.shape = (3, 256, 256); the MLP input width is fixed at 256*256.

    Args:
        output_size: Number of outputs.
        kernel_size: Conv kernel size; defaults to ``config["conv_kernel"]``.
    """

    def __init__(self, output_size, kernel_size=None):
        super().__init__()
        # The config is only consulted when no kernel size is passed in.
        ksize = config["conv_kernel"] if kernel_size is None else kernel_size
        pad = ksize // 2
        self.sec1 = nn.Sequential(
            nn.Conv2d(3, 1, ksize, 1, pad), # (3, 256, 256) -> (1, 256, 256)
            nn.BatchNorm2d(1),
            nn.LeakyReLU(),
        )

        self.sec2 = nn.Sequential(
            nn.Linear(256*256, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(),
            nn.Linear(32, output_size),
        )

    def forward(self, x: torch.Tensor):
        x = self.sec1(x)
        # Flatten everything except the batch axis. An unqualified squeeze()
        # would also remove the batch axis when the batch holds one sample.
        x = torch.flatten(x, 1)
        x = self.sec2(x)

        return x
    
class LearnableNormalizer(nn.Module):
    """Bias-free 3->1 channel convolution with Gaussian-shaped initial weights, then sigmoid.

    Every input channel starts with the same kernel ``gaussian * a - d``.

    Args:
        kernel_size: Side length of the convolution kernel.
        sigma: Standard deviation of the Gaussian used for initialisation.
        a: Scale applied to the Gaussian.
        d: Offset subtracted from the scaled Gaussian.
    """

    def __init__(self, kernel_size=15, sigma=3.5, a=3, d=1):
        super().__init__()
        half = kernel_size // 2
        self.conv = nn.Conv2d(3, 1, kernel_size=kernel_size, padding=half, bias=False)
        self.init_weights(kernel_size, sigma, a, d)
        self.activaltion = nn.Sigmoid()

    def gaussian_kernel(self, kernel_size, sigma):
        """Return an unnormalised 2D Gaussian whose peak is at the kernel centre."""
        reach = (kernel_size - 1) / 2.
        coords = torch.linspace(-reach, reach, kernel_size)
        rows, cols = torch.meshgrid(coords, coords, indexing='ij')
        squared_distance = rows**2 + cols**2
        return torch.exp(-squared_distance / (2. * sigma**2))

    def init_weights(self, kernel_size, sigma, a, d):
        """Overwrite the conv weights with the scaled, shifted Gaussian."""
        template = self.gaussian_kernel(kernel_size, sigma) * a - d
        with torch.no_grad():
            self.conv.weight.zero_()
            # Broadcast the same (k, k) kernel over all three input channels.
            self.conv.weight[0, :] = template

    def forward(self, x):
        response = self.conv(x)
        return self.activaltion(response)

class CnnExtractorDeepColor(nn.Module):
    """Two-stage convolutional extractor (3 -> 6 -> 1 channels) followed by an MLP.

    Input.shape = (3, 256, 256); the MLP input width is fixed at 256*256.

    Args:
        output_size: Number of outputs.
        kernel_size: Kernel size of the first conv; defaults to
            ``config["conv_kernel"]``.
        kernel_size_1: Kernel size of the second conv; defaults to
            ``config["conv_kernel_1"]``.
    """

    def __init__(self, output_size, kernel_size=None, kernel_size_1=None):
        super().__init__()
        # The config is only consulted for values that were not passed in.
        ksize0 = config["conv_kernel"] if kernel_size is None else kernel_size
        ksize1 = config["conv_kernel_1"] if kernel_size_1 is None else kernel_size_1
        pad0 = ksize0 // 2
        pad1 = ksize1 // 2
        self.sec1 = nn.Sequential(
            nn.Conv2d(3, 6, ksize0, 1, pad0), # (3, 256, 256) -> (6, 256, 256)
            nn.BatchNorm2d(6),
            nn.Tanh(),
            nn.MaxPool2d(3, 1, 1),
            nn.Conv2d(6, 1, ksize1, 1, pad1), # (6, 256, 256) -> (1, 256, 256)
            nn.BatchNorm2d(1),
            nn.Tanh(),
            nn.MaxPool2d(3, 1, 1),
        )

        self.sec2 = nn.Sequential(
            nn.Linear(256*256, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(),
            nn.Linear(32, output_size),
        )

    def forward(self, x: torch.Tensor):
        x = self.sec1(x)
        # Flatten everything except the batch axis. An unqualified squeeze()
        # would also remove the batch axis when the batch holds one sample.
        x = torch.flatten(x, 1)
        x = self.sec2(x)

        return x

class CnnShallowColor(nn.Module):
    """Frozen ``LearnableNormalizer`` front end followed by a trainable MLP.

    Input.shape = (3, 256, 256); the MLP input width is fixed at 256*256.

    Args:
        output_size: Number of outputs.
        kernel_size: Kernel size of the normalizer; defaults to
            ``config["conv_kernel"]``.
    """

    def __init__(self, output_size, kernel_size=None):
        super().__init__()
        # The config is only consulted when no kernel size is passed in.
        ksize0 = config["conv_kernel"] if kernel_size is None else kernel_size
        self.sec1 = LearnableNormalizer(ksize0)
        # The front end keeps its initial weights: exclude it from training.
        for param in self.sec1.parameters():
            param.requires_grad = False

        self.sec2 = nn.Sequential(
            nn.Linear(256*256, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.LeakyReLU(),
            nn.Linear(32, output_size),
        )

    def forward(self, x: torch.Tensor):
        x = self.sec1(x)
        # Flatten everything except the batch axis. An unqualified squeeze()
        # would also remove the batch axis when the batch holds one sample.
        x = torch.flatten(x, 1)
        x = self.sec2(x)

        return x
 
    

# model = CoordWideConv().to(DEVICE)
# model = SimpleFC(512*512, 2).to(DEVICE)
# model = SmallFC(512*512, 2).to(DEVICE)
# model = SmallExt(512, 2).to(DEVICE)
# model = GradientLargerFC(512*512, 2).to(DEVICE)
# model = SpotLocalizer().to(DEVICE)
# model = CnnExtractor(2).to(DEVICE)

# model = CnnExtractorDeepColor(2).to(DEVICE)
# model = CnnShallowColor(2).to(DEVICE)

# Build the model selected for this experiment and move it to DEVICE.
model = ConfigCNN(2, input_size=(1, 512, 512)).to(DEVICE)

# Optional Kaiming initialisation of all conv and linear layers.
if config["use_weight_initialization"]:
    for m in model.modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                          nn.Linear)):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

# Optionally seed the first filter of the first conv layer with a Gaussian.
if config["init_red_filter"]:
    def gaussian_kernel(kernel_size, sigma):
        """Create 2D Gaussian kernel centered in the middle."""
        ax = np.linspace(-(kernel_size - 1) / 2., (kernel_size - 1) / 2., kernel_size)
        xx, yy = np.meshgrid(ax, ax)
        kernel = np.exp(-(xx**2 + yy**2) / (2. * sigma**2))
        return torch.tensor(kernel, dtype=torch.float32)

    first_conv = model.sec1[0]
    # Take the kernel size from the layer itself: the Gaussian has to match
    # the weight slice it is written into, and the config defines no
    # "conv_kernel" entry.
    gauss = gaussian_kernel(first_conv.kernel_size[0], 3)
    with torch.no_grad():
        first_conv.weight[0, 0] = gauss.to(first_conv.weight.device)

summary(model, (1,512,512), config["batch_size"])
# summary(model, (512*512,), config["batch_size"])
        

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [800, 2, 52, 52]           5,002
              ReLU-2           [800, 2, 52, 52]               0
       BatchNorm2d-3           [800, 2, 52, 52]               4
            Conv2d-4           [800, 3, 11, 11]           3,753
              ReLU-5           [800, 3, 11, 11]               0
       BatchNorm2d-6           [800, 3, 11, 11]               6
            Conv2d-7             [800, 4, 4, 4]           1,732
              ReLU-8             [800, 4, 4, 4]               0
       BatchNorm2d-9             [800, 4, 4, 4]               8
           Conv2d-10             [800, 5, 3, 3]             725
             ReLU-11             [800, 5, 3, 3]               0
      BatchNorm2d-12             [800, 5, 3, 3]              10
           Linear-13                   [800, 2]              92
Total params: 11,332
Trainable params: 

- ensemble:
- transfer learning

# Train


In [18]:
# Optimiser, loss and learning-rate schedule for the training loop.
criterion = nn.MSELoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=config["lr"],
    weight_decay=0.001,
)
# Cosine schedule that restarts every config["lr_scheduler_loop"] scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=config["lr_scheduler_loop"],
    eta_min=0.00001,
)
# scheduler = optim.lr_scheduler.ConstantLR(optimizer, 1, 0, )
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
# scaler = torch.cuda.amp.GradScaler("cuda", enabled=config["use_amp"])

In [19]:
# Authenticate with Weights & Biases and start the run for this experiment.
# The API key must not live in the notebook: it is read from the
# WANDB_API_KEY environment variable instead.
# NOTE(review): when the variable is unset, key=None is passed and wandb is
# expected to fall back to stored credentials (e.g. ~/.netrc) - confirm.
# The key that used to be hard-coded here should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(
    project="multireflection",
    name=config["experiment_name"],
    config=config,
    resume="allow",
)
# Log gradients and parameters of the model every 100 steps.
wandb.watch(model, log='all', log_freq=100)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/evv/.netrc
[34m[1mwandb[0m: Currently logged in as: [33me-venediktov[0m ([33me-venediktov-university-of-pittsburgh[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [20]:
import copy

from torch.amp import GradScaler, autocast


def train(model, train_loader, val_loader, optimizer: optim.Optimizer, criterion, scheduler: optim.lr_scheduler.CosineAnnealingWarmRestarts, best_loss=None):
    """Train ``model`` for ``config['epochs']`` epochs, validating after each one.

    Each epoch performs one optimisation pass over ``train_loader`` and one
    evaluation pass over ``val_loader``, steps ``scheduler`` once, and logs the
    metrics to wandb. Whenever the mean validation loss improves on
    ``best_loss`` the weights are written with ``save_model`` and a snapshot of
    the model is kept in memory.

    Args:
        model: Network to train. It is not moved by this function; only the
            batches are moved to ``DEVICE``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        optimizer: Optimiser that updates the parameters of ``model``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        scheduler: LR scheduler, stepped once per epoch.
        best_loss: Validation loss that has to be beaten before a checkpoint is
            written. ``None`` means no previous best exists.

    Returns:
        A tuple ``(model, best_model, best_loss)``. ``model`` holds the weights
        of the last epoch; ``best_model`` is an independent copy taken at the
        epoch with the lowest validation loss, or ``None`` when no epoch beat
        the incoming ``best_loss``; ``best_loss`` is the lowest mean validation
        loss seen (or the incoming value if it was never beaten).
    """
    # Mixed precision is driven by the experiment config and only makes sense
    # on CUDA. With AMP disabled the scaler and autocast are transparent no-ops.
    use_amp = bool(config["use_amp"]) and DEVICE.type == "cuda"
    scaler = GradScaler("cuda", enabled=use_amp)

    if best_loss is None:
        best_loss = 1000000000
    best_model = None

    for epoch in range(config['epochs']):
        # ---- Training pass ----
        model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad(set_to_none=True)

            with autocast("cuda", dtype=torch.float16, enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        # Read the LR before scheduler.step() so the logged value is the one
        # that was actually used during this epoch.
        last_lr = scheduler.get_last_lr()[0]
        avg_train_loss = running_loss / len(train_loader)

        # ---- Validation pass ----
        model.eval()
        val_loss = 0.0
        with torch.inference_mode():
            for images, labels in tqdm(val_loader):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                with autocast("cuda", dtype=torch.float16, enabled=use_amp):
                    out = model(images)
                    loss = criterion(out, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        scheduler.step()

        if avg_val_loss < best_loss:
            # Deep-copy the model: a plain assignment would only alias `model`,
            # and the "best" weights would be overwritten by later epochs.
            best_model = copy.deepcopy(model)
            best_loss = avg_val_loss
            save_model(model, fname=config["experiment_name"] + "_best_model.pth")

        print(f"Epoch {epoch + 1}/{config['epochs']}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Log raw and log-scale losses plus an 80/20 train/val blend.
        # NOTE: math.log raises ValueError if a loss is exactly 0.
        log_train_loss = log(avg_train_loss)
        log_val_loss = log(avg_val_loss)
        avg_total_loss = avg_train_loss * 0.8 + avg_val_loss * 0.2
        log_total_loss = log(avg_total_loss)
        wandb.log({
            "Train Loss": avg_train_loss,
            "Val Loss": avg_val_loss,
            "LR": last_lr,
            "best_loss": best_loss,
            "log_train_loss": log_train_loss,
            "log_val_loss": log_val_loss,
            "avg_total_loss": avg_total_loss,
            "log_total_loss": log_total_loss,
            # "ds_train_loaded": len(train_dataset.loaded_indexes),
            # "ds_val_loaded": len(val_dataset.loaded_indexes),
        })

    print("Best loss:", best_loss)
    return model, best_model, best_loss


In [21]:
# Optionally warm-start from an existing checkpoint named in the config.
resume_fname = config["starting_checkpoint_fname"]
if resume_fname is not None:
    model = load_model(
        model,
        fname=resume_fname,
        path=config["checkpoint_folder"],
    )

In [22]:
# Validation loss that train() must beat before it writes a checkpoint.
# None means "no previous best": train() then starts from a very large value.
best_loss = None
# Use a concrete value instead to only checkpoint runs that beat an earlier result:
# best_loss = 0.0115

In [None]:
# Run the training loop; rebinding best_loss lets a re-run of this cell
# continue from the best validation loss reached so far.
model, best_model, best_loss = train(
    model=model,
    train_loader=train_data_loader,
    val_loader=val_data_loader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    best_loss=best_loss,
)

100%|██████████| 15/15 [01:32<00:00,  6.14s/it]
  0%|          | 0/4 [00:00<?, ?it/s]

Opening lmdb txnOpening lmdb txnOpening lmdb txn


Opening lmdb txn


100%|██████████| 4/4 [00:28<00:00,  7.12s/it]


Epoch 1/28, Train Loss: 0.1774, Val Loss: 0.2770


100%|██████████| 15/15 [01:18<00:00,  5.22s/it]
100%|██████████| 4/4 [00:40<00:00, 10.12s/it]


Epoch 2/28, Train Loss: 0.0285, Val Loss: 0.0669


100%|██████████| 15/15 [01:52<00:00,  7.49s/it]
100%|██████████| 4/4 [00:38<00:00,  9.74s/it]


Epoch 3/28, Train Loss: 0.0198, Val Loss: 0.0350


  0%|          | 0/15 [00:00<?, ?it/s]

In [None]:
# train_dataset.debug_msg

In [None]:
# Mark the active W&B run as finished so that its remaining data is uploaded.
wandb.finish()

# Test


In [None]:
# Prepare test images: the same samples are loaded twice, once through the
# training dataset and once straight from the image files, so the predictions
# for both variants can be compared.
keys = [
    b"x0.00_y0.00.jpg",
    b"x1.00_y0.00.jpg",
]

# Variant 1: samples as produced by the dataset (the target is discarded).
ds_imgs = [train_dataset[train_dataset.get_index(key)][0] for key in keys]
ds_test_input = torch.stack(ds_imgs).to(DEVICE)

# Variant 2: raw grayscale files, scaled to [0, 1] and flattened.
# NOTE(review): these tensors are flattened while the dataset samples may not
# be (config["dataset_config_flatten"] is False) - confirm the model accepts both.
data_folder = "/mnt/h/dark512"
file_imgs = []
for key in keys:
    path = os.path.join(data_folder, key.decode())
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None
        # instead of raising; fail here with a clear message.
        raise FileNotFoundError(f"Could not read test image: {path}")
    img = torch.from_numpy(img).float().flatten() / 255
    file_imgs.append(img)
file_test_input = torch.stack(file_imgs).to(DEVICE)

In [None]:
# Run the model on both input variants.
# NOTE(review): this evaluates `model` as it currently is in memory. Load the
# "<experiment_name>_best_model.pth" checkpoint first if the best weights are wanted.
model.eval()

# No gradients are needed for these predictions, so skip autograd bookkeeping.
# .forward() is called directly, exactly as before, which bypasses any hooks
# registered on the module.
with torch.inference_mode():
    ds_output = model.forward(ds_test_input)
ds_predictions = ds_output.cpu().numpy()
print(ds_predictions)

with torch.inference_mode():
    file_output = model.forward(file_test_input)
file_predictions = file_output.cpu().numpy()
print(file_predictions)

In [None]:
from math import ceil

In [None]:
# Visualize
def visualize(keys, imgs, predictions, ncols=2):
    """Show each image in a grid, titled with its key and overlaid with its prediction.

    Args:
        keys: Sequence of image names (``bytes`` or ``str``), one per image.
        imgs: Sequence of tensors; each is reshaped to 512x512 for display.
        predictions: Per-image sequences whose first two entries are the
            predicted x and y values.
        ncols: Number of grid columns. The number of rows is derived from it
            so that every image gets its own axes.
    """
    if len(keys) == 0:
        # plt.subplots cannot create a grid with zero rows.
        return

    # Derive the row count from the actual column count; previously it was
    # computed for 3 columns while the grid had 2, which dropped images.
    nrows = ceil(len(keys) / ncols)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(7, 7), squeeze=False)

    for i, ax in enumerate(axes.flat):
        # Hide the frame and ticks on every cell, including unused trailing ones.
        ax.axis("off")
        if i >= len(keys):
            continue

        ax.imshow(imgs[i].reshape((512, 512)).detach().cpu().numpy(), cmap='gray')
        key = keys[i]
        # Decode bytes keys so the title does not render as b'...'.
        ax.set_title(key.decode() if isinstance(key, bytes) else key)

        # Display prediction
        ax.text(35, 15, f"x{predictions[i][0]:.2f}_y{predictions[i][1]:.2f}", color="white")

    plt.show()

In [None]:
# Show the samples loaded through the dataset together with their predictions.
visualize(keys, ds_imgs, ds_predictions)

In [None]:
# Show the same samples read directly from the image files, with their predictions.
visualize(keys, file_imgs, file_predictions)

In [None]:
# display conv filter
# conv_weight = model.sec1[0].weight.data.clone()  # shape: (1, 3, kH, kW)

# # Normalize weights for visualization
# # conv_weight = (conv_weight - conv_weight.min()) / (conv_weight.max() - conv_weight.min())
# print(conv_weight)

# out_channels, in_channels, kH, kW = conv_weight.shape

# # Plot each input channel of the filter
# fig, axs = plt.subplots(1, in_channels, figsize=(in_channels * 2, 2))

# for j in range(in_channels):
#     ax = axs[j] if in_channels > 1 else axs
#     ax.imshow(conv_weight[0, j].cpu().numpy())
#     ax.axis('off')
#     ax.set_title(f'Channel {j}')

# plt.tight_layout()
# plt.show()

In [None]:
# Collect the dataset's targets into a NumPy array and report its shape.
ntargets = np.array(train_dataset.targets)
print(ntargets.shape)

# Unshuffled loader over the training set, so batches follow the dataset order.
evaluate_on_train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False)