# Imports


In [1]:
import numpy as np
import matplotlib.pyplot as plt
from preprocess_images import data_from_folder
from tqdm import tqdm
from math import log

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as T
from torchsummary import summary
import cv2 
import wandb
from config import LMDB_USE_COMPRESSION

import lmdb
import os
import msgpack
import lz4.frame

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

torch.manual_seed(0)

cuda


<torch._C.Generator at 0x7f8e1d54ec90>

In [2]:
def imname_to_target(name:str) -> tuple[float]:
    """Parses image names of format x{x_value}_y{y_value}.jpg"""
    name = name.split('.jpg')[0]
    x, y = name.split("_")
    x = float(x[1:])
    y = float(y[1:5])
    return x, y

def save_model(model:torch.nn.Module, fname="best_model.pth", path="./saved_models/real"):
    torch.save(model.state_dict(), os.path.join(path,fname))

def load_model(model:torch.nn.Module, fname="best_model.pth", path="./saved_models/real"):
    model.load_state_dict(torch.load(os.path.join(path,fname), weights_only=False))
    return model

# Config


In [3]:
config = {
    "experiment_name": "004step_ConfigCNN_DarkOnly512_lmdb_400bs_0001lr_aug+",
    "batch_size": 400,
    "lr": 0.001,
    "lr_scheduler_loop": 7,
    "epochs": 28,
    "use_amp": False,

    "data_folder": "/mnt/h/real_512_0_001step.lmdb",
    # "data_folder": "/mnt/e/color.lmdb",
    "dataset_type": "LMDBImageDataset",
    "dataset_config_flatten": False,
    "dataset_train_keys_fname": "004_dark_train.txt",
    "dataset_val_keys_fname": "004_dark_val.txt",
    "dataset_offload_count": 0,

    "use_noise_transform": True,
    "noise_level": 0.1,
    "use_jitter_transform": True,
    "jitter_brightness": 0.4, 
    "jitter_contrast": 0.1, 
    "jitter_saturation": 0.1, 
    "jitter_hue": 0.2,

    "use_grayscale_transform": False,
    "use_clahegrad_transform": False,
    "clahe_clip_limit": 0.001,
    "clahe_gaussian_size": 15,
    "clahe_gaussian_sigma": 5,

    "use_high_pass_transform": False,
    "high_pass_transform_t": 0.35,

    "data_collection_step": 0.001,
    "starting_checkpoint_fname": None,
    "checkpoint_folder": "./saved_models/real",

    "gradient_layer_kernel_size": 15,
    "gradient_layer_sigma": 5,

    "use_weight_initialization": True,
    "init_red_filter": False
}

# Data

In [4]:


class FilesImageDataset(Dataset):
    def __init__(self, data_dir, filenames):
        self.data_dir = data_dir
        self.filenames = filenames
        self.targets = [imname_to_target(s) for s in filenames]

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        fname = self.filenames[index]
        image = cv2.imread(os.path.join(self.data_dir, fname), cv2.IMREAD_GRAYSCALE)
        label = self.targets[index]  # Get corresponding tuple
        
        # Convert image to Tensor, flatten, and normalize [0,1]
        image = torch.from_numpy(image).flatten()
        image = image.float() / 255.0
        
        # Convert label tuple to Tensor
        label = torch.tensor(label, dtype=torch.float32)

        return image, label

class InMemoryLMDBImageDataset(Dataset):
    def __init__(self, data_folder_path, transforms=None, keys_fname="keys.txt", flatten_data=True, turn_to_grayscale=True):
        self.keys = None

        # Data augmentation
        self.transforms = transforms

        # Read text keys from file
        with open(os.path.join(data_folder_path, keys_fname)) as f:
            self.keys = f.readlines()
            if self.keys[-1] == '':
                self.keys = self.keys[:-1]
        for i in range(len(self.keys)):
            self.keys[i] = self.keys[i].replace("\n", "")

        # Get labels from text keys
        self.labels = []
        for i, key in enumerate(self.keys):
            try:
                label = imname_to_target(key)

                # Convert label tuple to Tensor
                x, y = label
                x = (x + 2) / 5.7
                y = (y + 2) / 4
                label = (x, y)
                label = torch.tensor(label, dtype=torch.float32)
                self.labels.append(label)
            except Exception as e:
                print("i:", i)
                print("name:", key)
                raise e
            
        # Encode keys
        for i in range(len(self.keys)):
            self.keys[i] = self.keys[i].encode()

        # Load images
        self.env = lmdb.open(data_folder_path, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn = self.env.begin()

        self.images = [None]*len(self.keys)
        self.loaded_indexes = set()
        self.flatten_data = flatten_data
        self.turn_to_grayscale = turn_to_grayscale

    def __len__(self):
        return len(self.keys)

    def get_index(self, key):
        for i, k in enumerate(self.keys):
            if k == key:
                return i
        
        return None
    
    def __getitem__(self, index):
        label = self.labels[index]

        if index in self.loaded_indexes:
            img = self.images[index]     
        else:
            key = self.keys[index]
            img_bytes = self.txn.get(key)
        
            if img_bytes is None:
                raise KeyError(f"Image {key} not found in LMDB!")

            img = np.array(msgpack.unpackb(img_bytes, raw=False), dtype=np.uint8)
            if self.turn_to_grayscale:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

            img = np.array(img, dtype=np.float32)
            img = torch.from_numpy(img)
            img = img / 255.0
            if not self.turn_to_grayscale:
                img = img.permute(2, 0, 1)
            else:
                img = torch.unsqueeze(img, 0)
            self.images[index] = img
            self.loaded_indexes.add(index)

        # DEBUG
        # print(img.shape)

        # Augmenation
        if self.transforms is not None:
            img = self.transforms(img)
        if self.flatten_data:
            img = img.flatten().float()
            self.debug_msg = f"image shape {img.shape}"
        elif isinstance(img, np.ndarray):
            img = torch.unsqueeze(torch.from_numpy(img), 0)

        return img, label

class LMDBImageDataset(Dataset):
    def __init__(self, lmdb_path, transforms=None, keys_fname="keys.txt", flatten_data=True):
        self.keys = None

        # Data augmentation
        self.transforms = transforms

        # Read text keys from file
        with open(os.path.join(lmdb_path, keys_fname)) as f:
            self.keys = f.readlines()
            if self.keys[-1] == '':
                self.keys = self.keys[:-1]
        for i in range(len(self.keys)):
            self.keys[i] = self.keys[i].replace("\n", "")

        # Get labels from text keys
        self.labels = []
        # self.labels = [imname_to_target(key) for key in self.keys]
        for i, key in enumerate(self.keys):
            try:
                self.labels.append(imname_to_target(key))
            except Exception as e:
                print("i:", i)
                print("name:", key)
                raise e

        # Encode keys
        for i in range(len(self.keys)):
            self.keys[i] = self.keys[i].encode()

        self.lmdb_path = lmdb_path
        self.flatten_data = flatten_data

    def open_lmdb(self):
        self.env = lmdb.open(self.lmdb_path, readonly=True, create=False, lock=False, readahead=False, meminit=False)
        self.txn = self.env.begin()

    def close(self):
        self.env.close()

    def __len__(self):
        return len(self.keys)

    def get_index(self, key):
        for i, k in enumerate(self.keys):
            if k == key:
                return i
        
        return None
    
    def __getitem__(self, index):
        if not hasattr(self, 'txn'):
            print("Opening lmdb txn")
            self.open_lmdb()
        key = self.keys[index]  # Get corresponding tuple
        label = self.labels[index]
        
        img_bytes = self.txn.get(key)
        
        if img_bytes is None:
            raise KeyError(f"Image {key} not found in LMDB!")

        if LMDB_USE_COMPRESSION:
            img_bytes = lz4.frame.decompress(img_bytes)

        image = np.array(msgpack.unpackb(img_bytes, raw=False), dtype=np.uint8)
        

        # Convert image to Tensor, and normalize [0,1]
        
        image = torch.from_numpy(image).float()

        if len(image.shape)==2:
            image = image.unsqueeze(0)
        elif len(image.shape)==3:
            image = image.permute(2, 0, 1)
        image = image / 255.0

        # Augmenation
        if self.transforms is not None:
            image = self.transforms(image)
            # print(f"image shaep after transforms: {image.shape}")
        if self.flatten_data:
            image = image.flatten().float()
            self.debug_msg = f"image shape {image.shape}"
        elif isinstance(image, np.ndarray):
            image = torch.unsqueeze(torch.from_numpy(image), 0)
            # print(image.shape)

        # Convert label tuple to Tensor
        x, y = label
        x = (x + 2) / 5.7
        y = (y + 2) / 4
        label = (x, y)
        label = torch.tensor(label, dtype=torch.float32)

        return image, label

In [5]:
tarr = []

if config["use_jitter_transform"]:
    tarr.append(
        T.ColorJitter(
            config["jitter_brightness"],
            config["jitter_contrast"],
            # config["jitter_saturation"],
            # config["jitter_hue"]
        )
    )

if config["use_noise_transform"]:
    tarr.append(
        T.GaussianNoise(sigma=config["noise_level"]),
    )

varr = []

if config["use_grayscale_transform"]:
    tarr.append(T.Grayscale())
    varr.append(T.Grayscale())

train_transforms = T.Compose(tarr)
val_transforms = T.Compose(varr) if len(varr)>0 else None


In [6]:
train_transforms

Compose(
      ColorJitter(brightness=(0.6, 1.4), contrast=(0.9, 1.1))
      GaussianNoise(mean=0.0, sigma=0.1, clip=True)
)

In [7]:
match config["dataset_type"]:
    case "LMDBImageDataset":
        train_dataset = LMDBImageDataset(config["data_folder"], transforms=train_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_train_keys_fname"])
        val_dataset = LMDBImageDataset(config["data_folder"], transforms=val_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_val_keys_fname"])

    case "InMemoryImageDataset":
        train_dataset = InMemoryLMDBImageDataset(config["data_folder"], transforms=train_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_train_keys_fname"])
        val_dataset = InMemoryLMDBImageDataset(config["data_folder"], transforms=val_transforms, flatten_data=config["dataset_config_flatten"], keys_fname=config["dataset_val_keys_fname"])
    case _ :
        raise("Wrong dataset type")
train_data_loader = DataLoader(train_dataset, 
                         batch_size=config["batch_size"], 
                         shuffle=True, 
                         num_workers=8, 
                         pin_memory=True, 
                         prefetch_factor=4, 
                         persistent_workers=True
                        )
val_data_loader = DataLoader(val_dataset,
                             batch_size=config["batch_size"],
                             shuffle=False,
                             num_workers=4,
                             persistent_workers=True,
                             pin_memory=True
                            )

In [8]:
print(train_dataset[0][0].shape)
print(len(train_dataset))
print(train_dataset.transforms)

Opening lmdb txn
torch.Size([1, 512, 512])
11440
Compose(
      ColorJitter(brightness=(0.6, 1.4), contrast=(0.9, 1.1))
      GaussianNoise(mean=0.0, sigma=0.1, clip=True)
)


In [9]:
train_dataset.keys[0]

b'x0.40_y-1.96.jpg'

In [10]:
# sample, _ = train_dataset[train_dataset.get_index(b'x0.00_y0.00.jpg')]
# sample = sample.squeeze().cpu().numpy().reshape((512, 512))

# print(sample.shape)
# plt.imsave("x0.00_y0.00.jpg", sample, cmap="gray")

In [11]:
# fig, axes = plt.subplots(3,1)
# print("Original")
# axes[0].imshow(train_dataset[0][0].permute(1,2,0).numpy())
# axes[0].set_title(train_dataset.keys[0])
# axes[1].imshow(train_dataset[1][0].permute(1,2,0).numpy())
# axes[1].set_title(train_dataset.keys[1])
# axes[2].imshow(train_dataset[10][0].permute(1,2,0).numpy())
# axes[2].set_title(train_dataset.keys[10])
# plt.show()

# Model

In [25]:
def size_after_conv(input_size, kernel_size, stride, padding):
    return (input_size - kernel_size + 2 * padding) // stride + 1
conv_config = [
    {'out_channels':3, 'kernel_size':150, 'stride':5},
    {'out_channels':4, 'kernel_size':75, 'stride':3},
    {'out_channels':8, 'kernel_size':30, 'stride':3},
    {'out_channels':16, 'kernel_size':6, 'stride':2},
    {'out_channels':32, 'kernel_size':3, 'stride':2},
]

for l in conv_config:
    l['padding'] = l['kernel_size'] // 2


s = 512
print(s*s)
for l in conv_config:
    if l is not None:
        s = size_after_conv(s, l['kernel_size'], l['stride'], l['padding'])
        print(s*s*l['out_channels'])

262144
31827
4900
1152
784
512


In [26]:
config['conv_config'] = conv_config

In [27]:
class ConfigCNN(nn.Module):
    def __init__(self, output_size = 2, input_size=(1, 250, 250)):
        super(ConfigCNN, self).__init__()
        c, h, w = input_size
        layers = []
        prev_channels = c
        size = c * h * w
        for layer_config in conv_config:
            layers.append(
                nn.Conv2d(prev_channels, 
                          layer_config['out_channels'], 
                          layer_config['kernel_size'], 
                          layer_config['stride'], 
                          padding=layer_config['padding']
                          )
            )
            prev_channels = layer_config['out_channels']
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm2d(layer_config['out_channels']))

            h = size_after_conv(h, layer_config['kernel_size'], layer_config['stride'], layer_config['padding'])
            w = h
            c = layer_config['out_channels']
            size = c * h * w

        self.sec1 = nn.Sequential(
            *layers
        )

        self.sec2 = nn.Sequential(
            # nn.Linear(size, size//4),
            # nn.ReLU(),
            # nn.BatchNorm1d(size//4),
            nn.Linear(size, output_size),
        )

    def forward(self, x):
        x = self.sec1(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.sec2(x)

        return x


class SimpleFC(nn.Module):
    def __init__(self, in_features, out_features):
        super(SimpleFC, self).__init__()
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(
            nn.Linear(in_features, 1024), # 262,144 -> 1024
            nn.BatchNorm1d(1024),
            self.relu,
            nn.Linear(1024, 256),
            nn.BatchNorm1d(256),
            self.relu,
            nn.Linear(256, 32),
            nn.BatchNorm1d(32),
            self.relu,
            nn.Linear(32, out_features),
        )
    def forward(self, x):
        return self.layers.forward(x)
    
    

# model = SimpleFC(512*512, 2).to(DEVICE)
model = ConfigCNN(2, input_size=(1, 512, 512)).to(DEVICE)

if config["use_weight_initialization"]:
    for m in model.modules():
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                          nn.Linear)):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


summary(model, (1,512,512), config["batch_size"])
# summary(model, (512*512,), config["batch_size"])
        

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [400, 3, 103, 103]          67,503
              ReLU-2         [400, 3, 103, 103]               0
       BatchNorm2d-3         [400, 3, 103, 103]               6
            Conv2d-4           [400, 4, 35, 35]          67,504
              ReLU-5           [400, 4, 35, 35]               0
       BatchNorm2d-6           [400, 4, 35, 35]               8
            Conv2d-7           [400, 8, 12, 12]          28,808
              ReLU-8           [400, 8, 12, 12]               0
       BatchNorm2d-9           [400, 8, 12, 12]              16
           Conv2d-10            [400, 16, 7, 7]           4,624
             ReLU-11            [400, 16, 7, 7]               0
      BatchNorm2d-12            [400, 16, 7, 7]              32
           Conv2d-13            [400, 32, 4, 4]           4,640
             ReLU-14            [400, 3

- ansemble: 
- transfer learning

# Train


In [28]:
optimizer = optim.AdamW(model.parameters(), config["lr"], weight_decay=0.001)
criterion = nn.MSELoss()
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, config["lr_scheduler_loop"], eta_min=0.00001)
# scheduler = optim.lr_scheduler.ConstantLR(optimizer, 1, 0, )
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)
# scaler = torch.cuda.amp.GradScaler("cuda", enabled=config["use_amp"])

In [29]:
wandb.login(key="a41d74c58ab2f0d2c2bbdb317450ab14a8ad9d4e")
wandb.init(
    project="multireflection",
    name=config["experiment_name"],
    config=config,
    resume="allow",
)
wandb.watch(model, log='all', log_freq=100)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/evv/.netrc


In [30]:
from torch.amp import GradScaler, autocast

def train(model, train_loader, val_loader, optimizer: optim.Optimizer, criterion, scheduler: optim.lr_scheduler.CosineAnnealingWarmRestarts, best_loss=None):
    scaler = GradScaler(DEVICE)
    if best_loss is None:
        best_loss = 1000000000
    best_model = None
    for epoch in range(config['epochs']):
        model.train()
        running_loss = 0.0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad(set_to_none=True)

            with autocast("cuda", dtype=torch.float16, enabled=False):
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()

        last_lr = scheduler.get_last_lr()[0]
        avg_train_loss = running_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.inference_mode():
            for images, labels in tqdm(val_loader):
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                with autocast("cuda", dtype=torch.float16, enabled=False):
                    out = model(images)
                    loss = criterion(out, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        scheduler.step()

        if avg_val_loss < best_loss:
            best_model = model
            best_loss = avg_val_loss
            save_model(model, fname=config["experiment_name"] + "_best_model.pth")

        print(f"Epoch {epoch + 1}/{config['epochs']}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # ✅ Log Training Loss
        log_train_loss = log(avg_train_loss)
        log_val_loss = log(avg_val_loss)
        avg_total_loss = avg_train_loss * 0.8 + avg_val_loss * 0.2
        log_total_loss = log(avg_total_loss)
        wandb.log({
            "Train Loss": avg_train_loss,
            "Val Loss": avg_val_loss,
            "LR": last_lr,
            "best_loss": best_loss,
            "log_train_loss": log_train_loss,
            "log_val_loss": log_val_loss,
            "avg_total_loss": avg_total_loss,
            "log_total_loss": log_total_loss,
            # "ds_train_loaded": len(train_dataset.loaded_indexes),
            # "ds_val_loaded": len(val_dataset.loaded_indexes),
        })

    print("Best loss:", best_loss)
    return model, best_model, best_loss


In [31]:
if config["starting_checkpoint_fname"] is not None:
    model = load_model(model, fname=config["starting_checkpoint_fname"], path=config["checkpoint_folder"])

In [32]:
best_loss = None
# best_loss = 0.0115

In [33]:
model, best_model, best_loss = train(model, train_data_loader, val_data_loader, optimizer, criterion, scheduler, best_loss)

100%|██████████| 29/29 [01:45<00:00,  3.64s/it]
100%|██████████| 8/8 [00:13<00:00,  1.68s/it]


Epoch 1/28, Train Loss: 0.4513, Val Loss: 0.1645


100%|██████████| 29/29 [01:50<00:00,  3.80s/it]
100%|██████████| 8/8 [00:13<00:00,  1.65s/it]


Epoch 2/28, Train Loss: 0.0505, Val Loss: 0.0475


100%|██████████| 29/29 [01:52<00:00,  3.88s/it]
100%|██████████| 8/8 [00:15<00:00,  1.89s/it]


Epoch 3/28, Train Loss: 0.0340, Val Loss: 0.0570


100%|██████████| 29/29 [01:49<00:00,  3.79s/it]
100%|██████████| 8/8 [00:12<00:00,  1.61s/it]


Epoch 4/28, Train Loss: 0.0264, Val Loss: 0.0542


100%|██████████| 29/29 [01:47<00:00,  3.70s/it]
100%|██████████| 8/8 [00:14<00:00,  1.78s/it]


Epoch 5/28, Train Loss: 0.0234, Val Loss: 0.0509


100%|██████████| 29/29 [01:47<00:00,  3.71s/it]
100%|██████████| 8/8 [00:13<00:00,  1.73s/it]


Epoch 6/28, Train Loss: 0.0224, Val Loss: 0.0511


100%|██████████| 29/29 [01:55<00:00,  4.00s/it]
100%|██████████| 8/8 [00:11<00:00,  1.44s/it]


Epoch 7/28, Train Loss: 0.0201, Val Loss: 0.0502


100%|██████████| 29/29 [01:50<00:00,  3.83s/it]
100%|██████████| 8/8 [00:15<00:00,  1.88s/it]


Epoch 8/28, Train Loss: 0.0194, Val Loss: 0.0468


100%|██████████| 29/29 [01:50<00:00,  3.82s/it]
100%|██████████| 8/8 [00:22<00:00,  2.77s/it]


Epoch 9/28, Train Loss: 0.0195, Val Loss: 0.0471


100%|██████████| 29/29 [01:59<00:00,  4.11s/it]
100%|██████████| 8/8 [00:17<00:00,  2.19s/it]


Epoch 10/28, Train Loss: 0.0166, Val Loss: 0.0488


100%|██████████| 29/29 [02:00<00:00,  4.16s/it]
100%|██████████| 8/8 [00:16<00:00,  2.08s/it]


Epoch 11/28, Train Loss: 0.0146, Val Loss: 0.0378


100%|██████████| 29/29 [01:50<00:00,  3.80s/it]
100%|██████████| 8/8 [00:13<00:00,  1.64s/it]


Epoch 12/28, Train Loss: 0.0123, Val Loss: 0.0423


100%|██████████| 29/29 [01:49<00:00,  3.79s/it]
100%|██████████| 8/8 [00:17<00:00,  2.14s/it]


Epoch 13/28, Train Loss: 0.0111, Val Loss: 0.0393


100%|██████████| 29/29 [01:51<00:00,  3.86s/it]
100%|██████████| 8/8 [00:17<00:00,  2.23s/it]


Epoch 14/28, Train Loss: 0.0108, Val Loss: 0.0405


100%|██████████| 29/29 [01:49<00:00,  3.77s/it]
100%|██████████| 8/8 [00:12<00:00,  1.62s/it]


Epoch 15/28, Train Loss: 0.0123, Val Loss: 0.0391


100%|██████████| 29/29 [01:51<00:00,  3.84s/it]
100%|██████████| 8/8 [00:11<00:00,  1.38s/it]


Epoch 16/28, Train Loss: 0.0106, Val Loss: 0.0381


100%|██████████| 29/29 [01:51<00:00,  3.83s/it]
100%|██████████| 8/8 [00:17<00:00,  2.24s/it]


Epoch 17/28, Train Loss: 0.0112, Val Loss: 0.0372


100%|██████████| 29/29 [01:51<00:00,  3.85s/it]
100%|██████████| 8/8 [00:17<00:00,  2.20s/it]


Epoch 18/28, Train Loss: 0.0096, Val Loss: 0.0374


100%|██████████| 29/29 [01:55<00:00,  3.97s/it]
100%|██████████| 8/8 [00:12<00:00,  1.51s/it]


Epoch 19/28, Train Loss: 0.0086, Val Loss: 0.0369


100%|██████████| 29/29 [01:52<00:00,  3.89s/it]
100%|██████████| 8/8 [00:14<00:00,  1.83s/it]


Epoch 20/28, Train Loss: 0.0073, Val Loss: 0.0355


100%|██████████| 29/29 [01:51<00:00,  3.86s/it]
100%|██████████| 8/8 [00:15<00:00,  1.88s/it]


Epoch 21/28, Train Loss: 0.0073, Val Loss: 0.0357


100%|██████████| 29/29 [01:49<00:00,  3.76s/it]
100%|██████████| 8/8 [00:15<00:00,  1.89s/it]


Epoch 22/28, Train Loss: 0.0097, Val Loss: 0.0451


100%|██████████| 29/29 [01:50<00:00,  3.81s/it]
100%|██████████| 8/8 [00:13<00:00,  1.64s/it]


Epoch 23/28, Train Loss: 0.0093, Val Loss: 0.0337


100%|██████████| 29/29 [01:50<00:00,  3.82s/it]
100%|██████████| 8/8 [00:13<00:00,  1.72s/it]


Epoch 24/28, Train Loss: 0.0071, Val Loss: 0.0394


100%|██████████| 29/29 [02:02<00:00,  4.21s/it]
100%|██████████| 8/8 [00:15<00:00,  1.99s/it]


Epoch 25/28, Train Loss: 0.0078, Val Loss: 0.0340


100%|██████████| 29/29 [01:48<00:00,  3.73s/it]
100%|██████████| 8/8 [00:11<00:00,  1.43s/it]


Epoch 26/28, Train Loss: 0.0062, Val Loss: 0.0301


100%|██████████| 29/29 [03:06<00:00,  6.44s/it]
100%|██████████| 8/8 [00:23<00:00,  2.91s/it]


Epoch 27/28, Train Loss: 0.0056, Val Loss: 0.0316


100%|██████████| 29/29 [02:20<00:00,  4.86s/it]
100%|██████████| 8/8 [00:13<00:00,  1.63s/it]


Epoch 28/28, Train Loss: 0.0055, Val Loss: 0.0313
Best loss: 0.030137872556224465


In [34]:
# train_dataset.debug_msg

In [35]:
wandb.finish()

0,1
LR,██▇▅▃▂▁██▇▅▃▂▁██▇▅▃▂▁██▇▅▃▂▁
Train Loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Val Loss,█▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁
avg_total_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
best_loss,█▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
log_total_loss,█▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁
log_train_loss,█▅▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▂▂▁▂▁▁▁
log_val_loss,█▃▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▃▁▂▂▁▁▁

0,1
LR,6e-05
Train Loss,0.00554
Val Loss,0.0313
avg_total_loss,0.01069
best_loss,0.03014
log_total_loss,-4.53856
log_train_loss,-5.19626
log_val_loss,-3.46429
