In [1]:
import numpy as np
import matplotlib.pyplot as plt
from preprocess_images import data_from_folder
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
import cv2 
import wandb
from config import LMDB_USE_COMPRESSION

import lmdb
import os
import msgpack
import lz4.frame

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(DEVICE)  # show which device this kernel will train on

cuda


In [None]:
# data = data_from_folder("H:/latest_real_data/real_data/real", grayscale=True, target_size=(125, 125))
# data = data_from_folder("./data/laser_x4_y6")
# data = data_from_folder("./data/test")

In [None]:
# for k in data:
#     print(k)
#     plt.imshow(data[k], vmin=0, vmax=255)
#     # plt.show()
#     break

In [2]:
def imname_to_target(name:str) -> tuple[float, float]:
    """Parse a sample name of the form ``x{x_value}_y{y_value}[.ext]`` into ``(x, y)``.

    The name may end in ``.jpg``, ``.jpeg`` or ``.png`` (matched case-insensitively)
    or carry no extension at all, since both image filenames and LMDB keys are
    parsed here. Everything from the first such extension onward is discarded.
    The extension is located by text search rather than ``os.path.splitext``
    because the coordinates themselves contain dots.

    Args:
        name: e.g. ``"x-3.00_y-2.10.jpg"``.

    Returns:
        ``(x, y)`` as floats, e.g. ``(-3.0, -2.1)``.

    Raises:
        ValueError: if the name does not split into exactly two ``_``-separated
            parts, or a coordinate is not a valid float.
    """
    lowered = name.lower()
    cut = len(name)
    for ext in (".jpg", ".jpeg", ".png"):
        pos = lowered.find(ext)
        if pos != -1:
            cut = min(cut, pos)
    stem = name[:cut]

    parts = stem.split("_")
    if len(parts) != 2:
        raise ValueError(f"Expected a name like 'x<float>_y<float>', got {name!r}")
    x_part, y_part = parts
    # Drop the leading 'x' / 'y' tag character before converting.
    return float(x_part[1:]), float(y_part[1:])

In [3]:
# Hyperparameters and paths for this run; the whole dict is also passed to wandb.init as the run config.
config = {
    # wandb run name; also the prefix of the best-model checkpoint filename written by train().
    "experiment_name": "tuning_512_real_lmdb_512batch_002srate_0001lr15eploop_lmdb",
    "batch_size": 512,
    "lr": 0.001,
    # T_0 of CosineAnnealingWarmRestarts; train() steps the scheduler once per epoch, so this is the restart period in epochs.
    "lr_scheduler_loop": 15,
    "epochs": 30,
    # Currently unused: the autocast / GradScaler code paths are commented out.
    "use_amp": False,
    # Dataset location; for LMDBImageDataset this directory must also contain keys.txt.
    "data_folder": "/mnt/h/real_512_0_002step.lmdb",
    # "dataset_type": "FilesImageDataset",
    # Selects which Dataset class the data-loading cell instantiates.
    "dataset_type": "LMDBImageDataset",
    # "data_folder": "H:/real_512_0_001step.lmdb",
    # Only read by the FilesImageDataset filename filter, where 0.002 enables subsampling.
    "data_collection_step": 0.002,
    # Checkpoint (inside checkpoint_folder) loaded before training; set to None to train from scratch.
    "starting_checkpoint_fname": "512_real_lmdb_512batch_002lr_best_model.pth",
    # Folder the starting checkpoint is loaded from.
    "checkpoint_folder": "./saved_models/real"
}

In [4]:
# Prepare image filenames for FilesImageDataset (other dataset types leave this list empty).
FilesImageDataset_fnames = []
if config["dataset_type"] == "FilesImageDataset":
    # Sorted so the file list, and therefore dataset indexing, is identical across runs and OSes.
    for file in sorted(os.listdir(config["data_folder"])):
        if not file.endswith((".png", ".jpg", ".jpeg")):
            continue
        if config["data_collection_step"] == 0.002:
            # TODO: Attention! Mimicking a 0.002 step: keep only files whose x and y
            # are multiples of 0.02 in filename units.
            # NOTE(review): the config value is 0.002 but the test is on x*100 — confirm the filename units.
            x, y = imname_to_target(file)
            # round(), not int(): e.g. 0.29 * 100 == 28.999999999999996, which int()
            # truncates to 28 and would wrongly treat as even.
            if round(x * 100) % 2 != 0 or round(y * 100) % 2 != 0:
                continue
        FilesImageDataset_fnames.append(file)

In [5]:
class ImageDataset(Dataset):
    """In-memory dataset of images paired with regression targets.

    Each item is ``(image, label)``. The image is a float tensor scaled to [0, 1]
    in CHW layout, and a 2-D image gets a leading channel axis of size 1. The
    label is the target tuple as a float32 tensor.

    Args:
        images: indexable collection of NumPy arrays, either HxW or HxWxC.
        targets: indexable collection of target tuples aligned with ``images``.
        transform: optional callable applied to the image tensor after scaling.
    """

    def __init__(self, images, targets, transform=None):
        self.images = images
        self.targets = targets
        self.transform = transform

        # The channel count is inferred from the first image only. TODO check
        first_shape = self.images[0].shape
        self.in_channels = 1 if len(first_shape) == 2 else first_shape[2]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        raw_image = self.images[index]
        raw_target = self.targets[index]

        pixels = torch.from_numpy(raw_image)
        # HWC -> CHW for multi-channel images; add a channel axis for 2-D ones.
        pixels = pixels.permute(2, 0, 1) if pixels.dim() == 3 else pixels.unsqueeze(0)
        pixels = pixels.float() / 255.0

        target = torch.tensor(raw_target, dtype=torch.float32)

        if self.transform:
            pixels = self.transform(pixels)
        return pixels, target


class FlatGrayImageDataset(Dataset):
    """In-memory dataset that yields flattened images for fully-connected models.

    Each item is ``(image, label)``. The image is the array flattened to 1-D and
    scaled to [0, 1] as float32. The label is the target tuple as a float32 tensor.

    Args:
        images: indexable collection of NumPy arrays (any shape is flattened).
        targets: indexable collection of target tuples aligned with ``images``.
        exclude: accepted for backward compatibility; not used.
    """

    def __init__(self, images, targets, exclude=True):
        self.images, self.targets = images, targets

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        raw_image, raw_label = self.images[index], self.targets[index]
        flat = torch.from_numpy(raw_image).flatten().float() / 255.0
        return flat, torch.tensor(raw_label, dtype=torch.float32)
    

class FilesImageDataset(Dataset):
    """Dataset that reads grayscale images from disk on demand.

    Targets are parsed from the filenames with ``imname_to_target`` at
    construction time. Pixels are loaded lazily in ``__getitem__``.

    Args:
        data_dir: directory containing the image files.
        filenames: image filenames, relative to ``data_dir``, to expose.
    """

    def __init__(self, data_dir, filenames):
        self.data_dir = data_dir
        self.filenames = filenames
        self.targets = [imname_to_target(s) for s in filenames]

    def __len__(self):
        # One sample per filename (this class never sets ``self.images``).
        return len(self.filenames)

    def __getitem__(self, index):
        """Return ``(image, label)``: the flattened grayscale image scaled to [0, 1] and the target, both float32.

        Raises:
            FileNotFoundError: if the image file is missing or unreadable.
        """
        fname = self.filenames[index]
        path = os.path.join(self.data_dir, fname)
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals a missing/unreadable file by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image {path}")
        label = self.targets[index]

        # Convert image to Tensor, flatten, and normalize [0,1]
        image = torch.from_numpy(image).flatten()
        image = image.float() / 255.0

        # Convert label tuple to Tensor
        label = torch.tensor(label, dtype=torch.float32)

        return image, label
    

class LMDBImageDataset(Dataset):
    """Dataset serving images stored as msgpack-encoded pixel arrays in an LMDB database.

    ``<lmdb_path>/<keys_fname>`` lists one key per line. Each key is both the
    LMDB lookup key and an ``x{x}_y{y}`` name from which the (x, y) target is
    parsed.

    The LMDB environment is opened lazily on the first ``__getitem__`` call. The
    handles are dropped when the dataset is pickled, so a process that receives
    a pickled copy opens its own.

    Args:
        lmdb_path: LMDB directory; it also holds the keys file.
        keys_fname: name of the keys file inside ``lmdb_path``.
    """

    def __init__(self, lmdb_path, keys_fname="keys.txt"):
        text_keys = self._read_keys(os.path.join(lmdb_path, keys_fname))

        # Targets are parsed from the text form of the keys ...
        self.labels = [imname_to_target(key) for key in text_keys]
        # ... while LMDB lookups need bytes keys.
        self.keys = [key.encode() for key in text_keys]

        self.lmdb_path = lmdb_path

    @staticmethod
    def _read_keys(keys_path):
        """Return the keys listed in ``keys_path``: one per line, newline removed, blank lines skipped."""
        with open(keys_path) as f:
            keys = [line.rstrip("\n") for line in f]
        # Blank lines (e.g. a trailing empty line) would otherwise become '' keys
        # and fail target parsing; an empty file simply yields no keys.
        return [key for key in keys if key.strip()]

    def open_lmdb(self):
        """Open the LMDB environment read-only and start the read transaction used by ``__getitem__``."""
        self.env = lmdb.open(self.lmdb_path, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn = self.env.begin()

    def close(self):
        """Close the LMDB environment if it is open; safe to call repeatedly or before any access.

        Both handles are removed, so a later ``__getitem__`` reopens the database
        instead of using a transaction from a closed environment.
        """
        self.__dict__.pop("txn", None)
        env = self.__dict__.pop("env", None)
        if env is not None:
            env.close()

    def __getstate__(self):
        # LMDB handles cannot be pickled (needed e.g. for spawn-started DataLoader
        # workers); drop them and let the receiving process reopen lazily.
        state = self.__dict__.copy()
        state.pop("env", None)
        state.pop("txn", None)
        return state

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        """Return ``(image, label)``: the flattened image scaled to [0, 1] and the (x, y) target, both float32.

        Raises:
            KeyError: if the key is missing from the LMDB database.
        """
        if not hasattr(self, 'txn'):
            print("Opening lmdb txn")
            self.open_lmdb()
        key = self.keys[index]
        label = self.labels[index]

        img_bytes = self.txn.get(key)

        if img_bytes is None:
            raise KeyError(f"Image {key} not found in LMDB!")

        # Values are lz4-frame compressed when LMDB_USE_COMPRESSION is set
        # (presumably must match how the database was written).
        if LMDB_USE_COMPRESSION:
            img_bytes = lz4.frame.decompress(img_bytes)

        image = np.array(msgpack.unpackb(img_bytes, raw=False), dtype=np.uint8)

        # Convert image to Tensor, flatten, and normalize [0,1]
        image = torch.from_numpy(image).flatten()
        image = image.float() / 255.0

        # Convert label tuple to Tensor
        label = torch.tensor(label, dtype=torch.float32)

        return image, label

In [6]:
# Prepare dataset
# images = []
# keys = []
# exclude_x_min = -0.4
# exclude_x_max = -0.04
# for name, image in data.items():
#     x, y = imname_to_target(name)
#     if x >= exclude_x_min and x <= exclude_x_max:
#         continue
#     keys.append((x, y))
#     images.append(image)
# dataset = ImageDataset(images, targets)
# dataset = FlatGrayImageDataset(images, keys)

match config["dataset_type"]:
    case "LMDBImageDataset":
        dataset = LMDBImageDataset(config["data_folder"])
    case "FilesImageDataset":
        dataset = FilesImageDataset(config["data_folder"], FilesImageDataset_fnames)
    case _ :
        raise("Wrong dataset type")
data_loader = DataLoader(dataset, 
                         batch_size=config["batch_size"], 
                         shuffle=True, 
                         num_workers=2, 
                         pin_memory=True, 
                         prefetch_factor=2, 
                         persistent_workers=True
                        )

In [7]:
# print(dataset[0][0].shape)
# Number of samples in the selected dataset.
print(len(dataset))

58580


In [None]:
# for x in data_loader:
#     print(x)
#     break

In [8]:
class SimpleCNN(nn.Module):
    """Small convolutional regressor producing ``output_size`` values per image.

    The head's first Linear layer expects the conv stack to end at 128x3x3. That
    holds for 250x250 inputs (spatial sizes annotated below); other input sizes
    need a different head.

    Args:
        output_size: number of regression outputs.
        in_channels: number of channels in the input image.
    """

    def __init__(self, output_size, in_channels):
        super().__init__()
        # Annotations are (channels, spatial size) for a 250x250 input. Convs are
        # unpadded, so size = floor((in - kernel) / stride) + 1; pools halve (floor).
        self.sec1 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 5, 2), # in_channels, 250 -> 32, 123
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 32, 123 -> 32, 61

            nn.Conv2d(32, 64, 3, 2), # 32, 61 -> 64, 30
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 64, 30 -> 64, 15

            nn.Conv2d(64, 128, 3, 2), # 64, 15 -> 128, 7
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2), # 128, 7 -> 128, 3
        )

        self.sec2 = nn.Sequential(
            nn.Linear(128*3*3, 256),
            nn.ReLU(),
            nn.Linear(256, output_size),
        )

    def forward(self, x):
        x = self.sec1(x)
        x = torch.flatten(x, 1)  # (N, 128, 3, 3) -> (N, 1152)
        return self.sec2(x)
    
class SimpleFC(nn.Module):
    """MLP with BatchNorm: in_features -> 1024 -> 256 -> 32 -> out_features.

    Args:
        in_features: length of the flattened input vector (512*512 in this notebook).
        out_features: number of regression outputs.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        # One stateless ReLU module is reused after every hidden layer. The Sequential
        # layout (and therefore the state_dict keys) is unchanged so saved checkpoints keep loading.
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Linear(in_features, 1024), # 262,144 -> 1024 for a 512x512 input
            nn.BatchNorm1d(1024),
            self.relu,
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            self.relu,
            nn.Linear(256, 32),
            nn.BatchNorm1d(32),
            self.relu,
            nn.Linear(32, out_features),
        )
    def forward(self, x):
        # Call the module rather than .forward() so hooks registered on self.layers run.
        return self.layers(x)
    
class WideFC(nn.Module):
    """Wider MLP with BatchNorm: in_features -> 2048 -> 1024 -> 64 -> out_features.

    Args:
        in_features: length of the flattened input vector.
        out_features: number of regression outputs.
    """
    def __init__(self, in_features, out_features):
        # Zero-arg super(): the previous super(SimpleFC, self) raised TypeError,
        # because WideFC is not a subclass of SimpleFC.
        super().__init__()
        # One stateless ReLU module is reused after every hidden layer.
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Linear(in_features, 2048), # 262,144 -> 2048 for a 512x512 input
            nn.BatchNorm1d(2048),
            self.relu,
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            self.relu,
            nn.Linear(1024, 64),
            nn.BatchNorm1d(64),
            self.relu,
            nn.Linear(64, out_features),
        )
    def forward(self, x):
        # Call the module rather than .forward() so hooks registered on self.layers run.
        return self.layers(x)

# model = SimpleCNN(output_size=2, in_channels=data_loader.in_channels).to(DEVICE)
# summary(model, (data_loader.in_channels, 250, 250))


# Fully-connected baseline on flattened 512x512 single-channel images, predicting (x, y).
model = SimpleFC(512*512, 2).to(DEVICE)
# torchsummary defaults to device="cuda"; pass the actual device type so this also runs on CPU-only machines.
summary(model, (512*512,), config["batch_size"], device=DEVICE.type)
        

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [512, 1024]     268,436,480
       BatchNorm1d-2                [512, 1024]           2,048
              ReLU-3                [512, 1024]               0
              ReLU-4                [512, 1024]               0
            Linear-5                 [512, 256]         262,400
       BatchNorm1d-6                 [512, 256]             512
              ReLU-7                 [512, 256]               0
              ReLU-8                 [512, 256]               0
            Linear-9                  [512, 32]           8,224
      BatchNorm1d-10                  [512, 32]              64
             ReLU-11                  [512, 32]               0
             ReLU-12                  [512, 32]               0
           Linear-13                   [512, 2]              66
Total params: 268,709,794
Trainable par

In [9]:
optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=0)
criterion = nn.MSELoss()
# Cosine annealing that restarts every T_0 scheduler steps (train() steps once per epoch), never below eta_min.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=config["lr_scheduler_loop"],
    eta_min=0.00001,
)
# scaler = torch.cuda.amp.GradScaler("cuda", enabled=config["use_amp"])

In [10]:
def save_model(model:torch.nn.Module, fname="best_model.pth", path="./saved_models/real"):
    """Save ``model.state_dict()`` to ``path/fname``, creating ``path`` if it does not exist.

    Args:
        model: module whose parameters and buffers are saved.
        fname: checkpoint filename.
        path: directory for the checkpoint.
    """
    os.makedirs(path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(path, fname))

def load_model(model:torch.nn.Module, fname="best_model.pth", path="./saved_models/real", weights_only:bool=True):
    """Load a state dict saved by ``save_model`` into ``model`` in place and return ``model``.

    Args:
        model: module to load the weights into; its architecture must match the checkpoint.
        fname: checkpoint filename.
        path: directory holding the checkpoint.
        weights_only: forwarded to ``torch.load``. With True (the default), only
            tensors and plain containers are unpickled, which is all a state_dict
            needs, so a tampered checkpoint cannot execute code. Pass False only
            for trusted files that need full unpickling.

    Returns:
        The same ``model`` object.
    """
    # map_location="cpu" lets checkpoints written on a GPU machine load on a CPU-only one;
    # load_state_dict then copies each tensor onto the device of the model's own parameters.
    state_dict = torch.load(os.path.join(path, fname), map_location="cpu", weights_only=weights_only)
    model.load_state_dict(state_dict)
    return model

In [11]:
# Authenticate without embedding a secret. With no key argument, wandb.login() uses the
# WANDB_API_KEY environment variable or an existing ~/.netrc entry, and prompts only if neither exists.
# NOTE(review): the API key previously hard-coded in this cell is exposed in the notebook history — revoke and rotate it.
wandb.login()
wandb.init(
    project="multireflection",
    name=config["experiment_name"],
    config=config,
)
# Log gradients and parameters of `model` every 100 batches.
wandb.watch(model, log='all', log_freq=100)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/evv/.netrc
[34m[1mwandb[0m: Currently logged in as: [33me-venediktov[0m ([33me-venediktov-university-of-pittsburgh[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
def _train_one_epoch(model, data_loader, optimizer:optim.Optimizer, criterion) -> float:
    """Run one optimisation pass over ``data_loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(data_loader):
        # with torch.autocast(device_type=DEVICE, dtype=torch.float16, enabled=config["use_amp"]):
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # scaler.scale(loss).backward()
        # scaler.step(optimizer)
        # scaler.update()
        running_loss += loss.item()
    return running_loss / len(data_loader)


def train(model, data_loader, optimizer:optim.Optimizer, criterion, scheduler:optim.lr_scheduler.CosineAnnealingWarmRestarts, best_loss=None, epochs=None):
    """Train ``model`` in place, checkpointing whenever the epoch training loss improves.

    "Best" is judged on the mean training loss; no validation set is used. The
    scheduler is stepped once per epoch. Each epoch's loss, learning rate and
    best loss are logged to wandb.

    Args:
        model: module to train; it is updated in place.
        data_loader: yields ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function ``criterion(outputs, labels)``.
        scheduler: LR scheduler stepped after every epoch.
        best_loss: loss an epoch must beat before
            ``<experiment_name>_best_model.pth`` is written; None means no threshold.
        epochs: number of epochs; defaults to ``config['epochs']``.

    Returns:
        ``(model, best_model, best_loss)``. ``model`` is the trained input model.
        ``best_model`` is a separate copy holding the weights of the lowest-loss
        epoch, or None if no epoch beat ``best_loss``.
    """
    import copy

    if epochs is None:
        epochs = config['epochs']
    if best_loss is None:
        best_loss = float("inf")
    # Snapshot of the best weights. Assigning `best_model = model` would only alias the
    # live model, which keeps training, so the "best" model would end up being the final one.
    best_state = None

    for epoch in range(epochs):
        avg_train_loss = _train_one_epoch(model, data_loader, optimizer, criterion)

        # LR used during this epoch, read before the scheduler advances it.
        last_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        print(f"Epoch {epoch + 1}/{epochs}, Train Loss: {avg_train_loss:.4f}")

        if avg_train_loss < best_loss:
            best_loss = avg_train_loss
            # Kept on the CPU so GPU memory does not hold a second copy of the weights during training.
            best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
            save_model(model, fname=config["experiment_name"]+"_best_model.pth", path=config["checkpoint_folder"])

        # Log Training Loss
        wandb.log({"Train Loss": avg_train_loss, "LR": last_lr, "best_loss":best_loss})

    print("Best loss:", best_loss)

    best_model = None
    if best_state is not None:
        best_model = copy.deepcopy(model)
        best_model.load_state_dict(best_state)
    return model, best_model, best_loss


In [13]:
# Warm-start from a saved checkpoint when one is configured (load_model fills the model and returns it).
if config["starting_checkpoint_fname"] is not None:
    model = load_model(
        model,
        path=config["checkpoint_folder"],
        fname=config["starting_checkpoint_fname"],
    )

In [None]:
# Threshold handed to train(): "<experiment_name>_best_model.pth" is only written once an epoch's
# training loss beats it. The train() cell reads this variable, so run this cell first.
# NOTE(review): presumably the best loss reached by the starting checkpoint — confirm, and use None when training from scratch.
best_loss = 0.00259

In [15]:
# Train (model is updated in place); keep the best-epoch snapshot and the loss it reached.
model, best_model, best_loss = train(
    model,
    data_loader,
    optimizer,
    criterion,
    scheduler,
    best_loss=best_loss,
)

  0%|          | 0/115 [00:00<?, ?it/s]

100%|██████████| 115/115 [03:08<00:00,  1.64s/it]


Epoch 1/30, Train Loss: 0.0043


100%|██████████| 115/115 [03:06<00:00,  1.62s/it]


Epoch 2/30, Train Loss: 0.0046


100%|██████████| 115/115 [03:06<00:00,  1.62s/it]


Epoch 3/30, Train Loss: 0.0040


100%|██████████| 115/115 [03:22<00:00,  1.76s/it]


Epoch 4/30, Train Loss: 0.0040


100%|██████████| 115/115 [03:02<00:00,  1.59s/it]


Epoch 5/30, Train Loss: 0.0035


100%|██████████| 115/115 [03:00<00:00,  1.57s/it]


Epoch 6/30, Train Loss: 0.0033


100%|██████████| 115/115 [02:56<00:00,  1.54s/it]


Epoch 7/30, Train Loss: 0.0032


100%|██████████| 115/115 [02:59<00:00,  1.56s/it]


Epoch 8/30, Train Loss: 0.0030


100%|██████████| 115/115 [02:57<00:00,  1.55s/it]


Epoch 9/30, Train Loss: 0.0027


100%|██████████| 115/115 [03:01<00:00,  1.58s/it]


Epoch 10/30, Train Loss: 0.0027


100%|██████████| 115/115 [03:03<00:00,  1.60s/it]


Epoch 11/30, Train Loss: 0.0027


100%|██████████| 115/115 [02:55<00:00,  1.52s/it]


Epoch 12/30, Train Loss: 0.0025


100%|██████████| 115/115 [03:00<00:00,  1.57s/it]


Epoch 13/30, Train Loss: 0.0025


100%|██████████| 115/115 [02:59<00:00,  1.56s/it]


Epoch 14/30, Train Loss: 0.0024


100%|██████████| 115/115 [03:07<00:00,  1.63s/it]


Epoch 15/30, Train Loss: 0.0021


100%|██████████| 115/115 [03:09<00:00,  1.65s/it]


Epoch 16/30, Train Loss: 0.0036


100%|██████████| 115/115 [03:08<00:00,  1.64s/it]


Epoch 17/30, Train Loss: 0.0036


100%|██████████| 115/115 [02:58<00:00,  1.55s/it]


Epoch 18/30, Train Loss: 0.0036


100%|██████████| 115/115 [02:56<00:00,  1.53s/it]


Epoch 19/30, Train Loss: 0.0033


100%|██████████| 115/115 [02:54<00:00,  1.52s/it]


Epoch 20/30, Train Loss: 0.0033


100%|██████████| 115/115 [02:51<00:00,  1.49s/it]


Epoch 21/30, Train Loss: 0.0027


100%|██████████| 115/115 [02:52<00:00,  1.50s/it]


Epoch 22/30, Train Loss: 0.0029


100%|██████████| 115/115 [02:50<00:00,  1.48s/it]


Epoch 23/30, Train Loss: 0.0025


100%|██████████| 115/115 [02:54<00:00,  1.52s/it]


Epoch 24/30, Train Loss: 0.0023


100%|██████████| 115/115 [02:52<00:00,  1.50s/it]


Epoch 25/30, Train Loss: 0.0025


100%|██████████| 115/115 [02:50<00:00,  1.48s/it]


Epoch 26/30, Train Loss: 0.0021


100%|██████████| 115/115 [02:50<00:00,  1.48s/it]


Epoch 27/30, Train Loss: 0.0021


100%|██████████| 115/115 [02:50<00:00,  1.48s/it]


Epoch 28/30, Train Loss: 0.0021


100%|██████████| 115/115 [02:49<00:00,  1.47s/it]


Epoch 29/30, Train Loss: 0.0019


100%|██████████| 115/115 [02:56<00:00,  1.53s/it]


Epoch 30/30, Train Loss: 0.0019
Best loss: 0.0018502317541076437


In [16]:
# Close the wandb run opened by wandb.init() so any remaining logged data is uploaded.
wandb.finish()

0,1
LR,███▇▆▅▄▃▁▁██▇▇▆▅▄▃▃▁██▇▇▆▅▄▃▃▂█▇▇▆▅▄▃▃▂▁
Train Loss,█▇▆▆▅▅▄▄▄▃▅▄▅▄▄▃▃▃▂▂▅▅▄▄▃▃▃▂▂▂▁▄▄▃▂▂▂▁▁▁
best_loss,████▇▆▆▆▅▅▅▅▅▅▅▄▄▄▃▃▆▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁

0,1
LR,2e-05
Train Loss,0.00185
best_loss,0.00185


In [None]:
def prepare_test_input(data_folder:str, fnames:list[str], device):
    """Load grayscale test images and stack them into one flattened input batch.

    Args:
        data_folder: directory holding the images.
        fnames: image filenames in ``x{x}_y{y}.jpg`` form; targets are parsed from them.
        device: device for the returned batch tensor.

    Returns:
        ``(batch, targets, original_images)``:
        - ``batch``: a float tensor of shape ``(len(fnames), H*W)`` scaled to [0, 1] on ``device``.
        - ``targets``: the list of parsed ``(x, y)`` targets.
        - ``original_images``: the raw images, kept for plotting.

    Raises:
        ValueError: if ``fnames`` is empty.
        FileNotFoundError: if an image cannot be read.
    """
    if not fnames:
        raise ValueError("fnames must contain at least one image filename")
    tensors = []
    targets = []
    original_images = [] # for visualization
    for fname in fnames:
        # os.path.join works whether or not data_folder ends with a path separator.
        path = os.path.join(data_folder, fname)
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None instead of raising for a missing/unreadable file.
        if img is None:
            raise FileNotFoundError(f"Could not read image {path}")
        original_images.append(img)
        tensors.append(torch.from_numpy(img).flatten().float() / 255.0)
        targets.append(imname_to_target(fname))
    # torch.stack requires every image to have the same size.
    result_tensor = torch.stack(tensors).to(device)

    return result_tensor, targets, original_images

In [None]:
# Prepare test images: a fixed sample of filenames whose (x, y) targets are parsed from the names.
# NOTE(review): the folder name suggests 125x125 images, but the current model is SimpleFC(512*512, 2)
# and expects 512*512 input features; confirm the image size, or the prediction cell will fail with a shape mismatch.
data_folder = "data/125x125_laser_x4_y6/"
fnames = [
    "x-3.00_y-2.10.jpg",
    "x-2.90_y1.40.jpg",
    "x-1.70_y2.00.jpg",

    "x-1.10_y1.40.jpg",
    "x-0.60_y2.10.jpg",
    "x1.30_y-0.90.jpg",

    "x-0.00_y-0.10.jpg",
    "x-0.00_y0.00.jpg",
    "x0.80_y0.10.jpg",
]
# keys holds the (x, y) targets parsed from fnames.
test_input, keys, original_images = prepare_test_input(data_folder, fnames, DEVICE)
print(test_input.shape)
print(keys)


In [None]:
# Try best model: evaluate the best-epoch snapshot returned by train(). train() returns
# None for best_model when no epoch beat its loss threshold; then use the live model.
eval_model = best_model if best_model is not None else model
eval_model.eval()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    predictions: torch.Tensor = eval_model(test_input)
predictions = predictions.cpu().numpy()
print(predictions)

In [None]:
# Visualize: each test image titled with its filename (the ground truth), with the
# model's predicted (x, y) overlaid in white.
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(10, 10))
# zip stops at the shortest input, so fewer than 9 images no longer raises IndexError.
for ax, image, fname, pred in zip(axes.flat, original_images, fnames, predictions):
    ax.imshow(image, cmap='gray')
    ax.set_title(fname)

    # Display prediction; the text position is in image pixel coordinates.
    ax.text(35, 15, f"x{pred[0]:.2f}_y{pred[1]:.2f}", color="white")

# Hide the axes of every panel, including unused grid cells.
for ax in axes.flat:
    ax.axis("off")

plt.show()

In [None]:
# Targets of the training dataset as an (N, 2) array. FilesImageDataset/ImageDataset
# store them as `.targets`, LMDBImageDataset as `.labels`.
ntargets = np.array(dataset.targets if hasattr(dataset, "targets") else dataset.labels)
print(ntargets.shape)

# Same dataset in a fixed (unshuffled) order, presumably to compare predictions against ntargets.
evaluate_on_train_loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=False)