In [4]:
import os
import random
import sys
sys.path.append('/Image-Dehazing/src')

import h5py
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from transmission_model import TransmissionModel

In [2]:
# Location of the training HDF5 file. Set the DEHAZE_TRAIN_DATA environment
# variable to point elsewhere; the author's local path is kept as the default.
file = os.environ.get('DEHAZE_TRAIN_DATA', '/home/shkraboom/Рабочий стол/Image Dehazing/train_data.hdf5')

# Read each dataset fully into NumPy arrays, then let the context manager close
# the HDF5 handle instead of leaving it open for the life of the kernel.
with h5py.File(file, 'r') as train_dataset:
    haze_images = np.array(train_dataset['haze_image'])
    clear_images = np.array(train_dataset['clear_image'])
    transmission_value = np.array(train_dataset['transmission_value'])

assert len(haze_images) == len(transmission_value), 'each haze image needs exactly one transmission value'

# Append three singleton axes after the sample axis, e.g. (N,) -> (N, 1, 1, 1).
# NOTE(review): presumably so the targets broadcast against the model's
# per-image prediction shape — confirm against TransmissionModel's output.
transmission_value = np.expand_dims(transmission_value, axis=(1, 2, 3))

In [3]:
SEED = 42

# Seed torch *before* building the model. train_trans_model re-seeds only after
# the model already exists, so without this the initial weights (presumably
# drawn from torch's global RNG by the nn layers) differ on every run.
torch.manual_seed(SEED)

model = TransmissionModel()
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr = 0.001)
# NOTE(review): StepLR divides the LR by 10 every 10 epochs. Over 100 epochs that
# reaches 1e-3 * 0.1**9 = 1e-12 for the final ten epochs — confirm this is intended.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer = optimizer, step_size = 10, gamma = 0.1)
num_epochs = 100

In [4]:
def set_seed(seed):
    """
    Seed every random number generator used by this notebook.

    Seeds Python's `random` module, NumPy's global RNG and torch (CPU and CUDA).
    It also switches cuDNN to deterministic, non-benchmarking mode, so repeated
    runs with the same seed produce the same results.

    Args:
        seed: integer seed applied to all generators.
    """

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # cuDNN autotuning (benchmark mode) may pick different algorithms between
    # runs, so disable it and request deterministic algorithms instead.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_trans_model(num_epochs, model, loss_fn, optimizer, haze_images, transmission_value, device = 'cuda', seed = 42, lr_scheduler = None, batch_size = 30):
    """
    Train the transmission model to predict transmission values from hazy images.

    Args:
        num_epochs: number of full passes over the training data.
        model: torch module, called as model(batch_of_images).
        loss_fn: loss callable, invoked as loss_fn(prediction, target). This
            follows the torch.nn loss convention (input first, target second).
        optimizer: optimizer over the model's parameters; stepped once per batch.
        haze_images: 4-D NumPy array. Its axes are permuted with
            transpose(0, 3, 1, 2), i.e. it is assumed to be channels-last
            (N, H, W, C) and is fed to the model as (N, C, H, W).
        transmission_value: NumPy array of targets, same length N along axis 0.
        device: device the model and every batch are moved to.
        seed: seed passed to set_seed before the data loader is built, so the
            shuffle order is reproducible.
        lr_scheduler: optional scheduler, stepped once at the end of every epoch.
        batch_size: mini-batch size for the data loader (default 30).

    Prints the mean per-batch loss after every epoch and returns None.
    """

    set_seed(seed)

    model.to(device)

    # Tensors stay on the CPU and each batch is moved to `device` inside the loop,
    # so the whole dataset never has to fit in accelerator memory at once.
    haze_images_tensor = torch.tensor(haze_images.transpose(0, 3, 1, 2)).float()
    transmission_value_tensor = torch.tensor(transmission_value).float()

    dataset = TensorDataset(haze_images_tensor, transmission_value_tensor)
    batches = DataLoader(dataset = dataset, batch_size = batch_size, shuffle = True)

    for epoch in range(num_epochs):
        model.train()

        loss_epoch = 0.0

        for haze_images_batch, transmission_value_batch in batches:
            haze_images_batch = haze_images_batch.to(device)
            transmission_value_batch = transmission_value_batch.to(device)

            optimizer.zero_grad()

            predict = model(haze_images_batch)

            # torch.nn losses take (input, target). The order matters for
            # asymmetric losses such as BCELoss or KLDivLoss.
            loss = loss_fn(predict, transmission_value_batch)

            loss.backward()

            optimizer.step()

            loss_epoch += loss.item()

        print(f'Epoch: {epoch + 1} Loss epoch: {loss_epoch / len(batches)}')

        if lr_scheduler is not None:
            lr_scheduler.step()


In [None]:
# Train with the objects configured above; device and seed keep their defaults.
train_trans_model(
    model = model,
    optimizer = optimizer,
    loss_fn = loss_fn,
    lr_scheduler = lr_scheduler,
    haze_images = haze_images,
    transmission_value = transmission_value,
    num_epochs = num_epochs,
)

In [142]:
# Persist only the learned parameters (the state_dict), not the module object itself.
torch.save(obj = model.state_dict(), f = 'trans_model_weights.pt')