In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from tqdm.auto import tqdm

In [None]:
def calculate_delta_E(Lab1, Lab2):
    """Return the Euclidean distance between two colour coordinate arrays.

    When both inputs are single L*a*b* triplets this is the CIE76 colour
    difference. The squared differences are summed over *all* elements, so
    passing batched arrays yields one scalar for the whole batch.

    Args:
        Lab1: Array-like of colour coordinates (NumPy array, list or tuple).
        Lab2: Array-like broadcast-compatible with ``Lab1``.

    Returns:
        A NumPy floating-point scalar holding the distance.
    """
    lab1 = np.asarray(Lab1)
    lab2 = np.asarray(Lab2)
    # Integer (especially unsigned) inputs would wrap around on subtraction
    # or overflow on squaring; promote them to float64 first. Floating-point
    # inputs keep their original dtype.
    if not np.issubdtype(np.result_type(lab1, lab2), np.inexact):
        lab1 = lab1.astype(np.float64)
        lab2 = lab2.astype(np.float64)
    return np.sqrt(np.sum((lab2 - lab1) ** 2))

In [None]:
class LabDataset(Dataset):
    """Dataset backed by a 2-D array stored in a ``.npy`` file.

    Each row is split by column index into three float32 tensors:

    * columns ``6:``  -> ``inverse_input_sp``
    * columns ``0:5`` -> ``forword_input_mp``
    * columns ``49:`` -> ``forword_target_color``

    Column 5 is not part of any returned tensor.
    """

    def __init__(self, npy_file):
        """Load the whole array from ``npy_file`` into memory."""
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code. Only load trusted files; if the
        # array is purely numeric, allow_pickle=False is the safer setting.
        self.data = np.load(npy_file, allow_pickle=True)

    def __len__(self):
        """Return the number of rows (samples)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inverse_input_sp, forword_input_mp, forword_target_color)``
        for the integer row index ``idx`` as float32 tensors."""
        # Contiguous slices replace the previous piecewise concatenations:
        # [6:49] + [49:] == [6:]  and  [:3] + [3:4] + [4:5] == [:5].
        inverse_input_sp = self.data[idx, 6:]
        forword_input_mp = self.data[idx, :5]
        forword_target_color = self.data[idx, 49:]

        # torch.tensor copies, so the returned tensors never alias self.data.
        return torch.tensor(inverse_input_sp, dtype=torch.float32), \
               torch.tensor(forword_input_mp, dtype=torch.float32), \
               torch.tensor(forword_target_color, dtype=torch.float32)

In [None]:
def set_seed(seed, deterministic=False):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global RNG, and PyTorch's CPU
    and CUDA generators, then configures the cuDNN backend flags.

    Args:
        seed: Integer seed applied to all generators.
        deterministic: When ``False`` (the default, matching the previous
            hard-coded behaviour) cuDNN autotuning is enabled
            (``benchmark=True``, ``deterministic=False``), which favours
            speed but may select non-deterministic kernels. Pass ``True``
            to request deterministic cuDNN kernels and disable autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic


set_seed(0)

In [None]:
class InverseDesignNetwork(nn.Module):
    """Fully connected network mapping a 3-value input to a 5-value output.

    The stack is six ``Linear`` + ``ReLU`` pairs (3 -> 512, then five
    512 -> 512) followed by a final ``Linear(512, 5)`` with no activation.
    """

    def __init__(self):
        super().__init__()
        # Feature widths of the hidden stack; consecutive pairs give the
        # (in_features, out_features) of each Linear layer.
        hidden_dims = [3] + [512] * 6
        modules = []
        for fan_in, fan_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        # Output projection: no activation after it.
        modules.append(nn.Linear(hidden_dims[-1], 5))
        self.layers = nn.Sequential(*modules)
        self._initialize_weights()

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)

    def _initialize_weights(self):
        """Re-initialise every Linear layer: Kaiming-normal weights
        (fan_out mode, ReLU gain) and zero biases."""
        for layer in self.layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

class ForwardModelingNetwork(nn.Module):
    """Fully connected network mapping a 5-value input to a 3-value output.

    Seven ``Linear`` layers (5 -> 512, five 512 -> 512, 512 -> 3) with a
    ``ReLU`` after every layer except the last.
    """

    def __init__(self):
        super().__init__()
        # Layer widths from input to output.
        widths = (5,) + (512,) * 6 + (3,)
        final_index = len(widths) - 2
        stack = []
        for i in range(len(widths) - 1):
            stack.append(nn.Linear(widths[i], widths[i + 1]))
            # The output layer is left without an activation.
            if i != final_index:
                stack.append(nn.ReLU())
        self.layers = nn.Sequential(*stack)
        self._initialize_weights()

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.layers(x)

    def _initialize_weights(self):
        """Re-initialise every Linear layer: Kaiming-normal weights
        (fan_out mode, ReLU gain) and zero biases."""
        linear_layers = [mod for mod in self.modules() if isinstance(mod, nn.Linear)]
        for linear in linear_layers:
            nn.init.kaiming_normal_(linear.weight, mode='fan_out', nonlinearity='relu')
            if linear.bias is not None:
                nn.init.constant_(linear.bias, 0)


class TandemNetwork(nn.Module):
    """Chains an ``InverseDesignNetwork`` into a ``ForwardModelingNetwork``.

    The inverse network's output is fed directly to the forward network, and
    both intermediate and final outputs are returned.
    """

    def __init__(self):
        super(TandemNetwork, self).__init__()
        # Sub-module names are part of the state-dict key layout.
        self.inverse_design = InverseDesignNetwork()
        self.forward_modeling = ForwardModelingNetwork()

    def forward(self, x):
        """Return ``(inverse_output, forward_output)`` for input ``x``."""
        design = self.inverse_design(x)
        reconstructed = self.forward_modeling(design)
        return design, reconstructed

In [None]:
def delta_e_loss(pred, target):
    """Per-sample Euclidean distance between ``pred`` and ``target``.

    The distance is taken along the last dimension, so inputs of shape
    ``(..., C)`` produce an output of shape ``(...)``.

    ``torch.linalg.vector_norm`` is used instead of ``sqrt(sum(d ** 2))``
    because the latter back-propagates NaN when a prediction matches its
    target exactly (the derivative of ``sqrt`` at 0 is infinite and gets
    multiplied by a zero difference). The norm's backward pass yields a zero
    gradient in that case, which keeps training stable.

    Args:
        pred: Predicted values.
        target: Reference values, broadcast-compatible with ``pred``.

    Returns:
        Tensor of distances with the last dimension reduced.
    """
    return torch.linalg.vector_norm(pred - target, ord=2, dim=-1)

In [None]:
# Train the tandem network five times (seeds 0-4) against a frozen, pre-trained
# forward model, keep the checkpoint with the lowest validation loss of each
# run, and report the mean test-set Delta E per run plus summary statistics.
all_results = []

epochs = 4000

# Train on the GPU when one is present; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pre-trained weights for the forward modeling network.
forward_weights_path = 'best_forward_color_.pth'

# The datasets do not change between runs, so read them from disk only once.
train_dataset = LabDataset('dataset/npy/raw_random_train_sp_color.npy')
val_dataset = LabDataset('dataset/npy/raw_random_val_sp_color.npy')

# Only the target colour (columns 49 onward) is needed for evaluation: the
# tandem network takes the target colour as its input.
test_data = np.load('dataset/npy/raw_random_test_sp_color.npy')
test_target_color = torch.tensor(test_data[:, 49:], dtype=torch.float32)
test_dataset = TensorDataset(test_target_color)
test_dataloader = DataLoader(test_dataset, batch_size=1)

for run_time in tqdm(range(5)):
    set_seed(run_time)

    train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False)

    model = TandemNetwork()

    # Load pre-trained weights for the forward modeling network.
    # map_location keeps this working on machines without CUDA.
    model.forward_modeling.load_state_dict(
        torch.load(forward_weights_path, map_location='cpu'))
    model = model.to(device)

    # Freeze the forward modeling network
    for param in model.forward_modeling.parameters():
        param.requires_grad = False

    # Alternative MSE objective; not used by the Delta E training below.
    loss_fn = nn.MSELoss().to(device)
    # Only the inverse network is trainable, so only its parameters are
    # handed to the optimizer.
    optimizer = optim.AdamW(model.inverse_design.parameters(), lr=5e-3, weight_decay=1e-3)
    # NOTE(review): T_max covers 70% of the epochs; CosineAnnealingLR is not
    # clamped after T_max, so the learning rate starts rising again for the
    # remaining epochs. Confirm this is intended.
    lr_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs*0.7, eta_min=5e-4)

    best_val_loss = float('inf')

    for epoch in tqdm(range(epochs)):
        model.train()
        # Accumulate on-device to avoid a host sync on every batch.
        train_loss_sum = torch.zeros((), device=device)
        for _, _, forword_target_color in train_dataloader:
            forword_target_color = forword_target_color.to(device)

            # Forward pass: the target colour is the input, and the forward
            # model's reconstruction of it is compared against itself.
            inverse_output_mp, forward_output_color = model(forword_target_color)
            loss = delta_e_loss(forward_output_color, forword_target_color).mean()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss_sum += loss.detach()

        model.eval()  # Set the model to evaluation mode
        val_losses = []
        with torch.no_grad():
            for _, _, forword_target_color in val_dataloader:
                forword_target_color = forword_target_color.to(device)

                inverse_output_mp, forward_output_color = model(forword_target_color)
                val_loss = delta_e_loss(forward_output_color, forword_target_color).mean()

                val_losses.append(val_loss.item())

        # Mean of the per-batch means.
        avg_val_loss = sum(val_losses) / len(val_losses)
        lr_schedule.step()

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), f'best_tandem_{run_time}.pth')  # save model weights

        if epoch % 100 == 0:
            # Report the epoch-mean training loss. The training and
            # validation losses are tracked in separate variables so the
            # value printed here is not overwritten by the validation loop.
            avg_train_loss = (train_loss_sum / len(train_dataloader)).item()
            print(f'Epoch {epoch}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}, Best Loss: {best_val_loss}')

    print(f'Best Loss: {best_val_loss}')

    # Evaluate the best checkpoint of this run on the CPU. The checkpoint
    # already contains the frozen forward-model weights that were used during
    # training, so they are not reloaded from a separate file; this keeps
    # evaluation consistent with training.
    model = TandemNetwork()
    model.load_state_dict(torch.load(f'best_tandem_{run_time}.pth', map_location='cpu'))
    model.eval()

    all_delta_e = 0

    with torch.no_grad():
        for (forword_target_color,) in test_dataloader:
            inverse_output_mp, forward_output_color = model(forword_target_color)

            # batch_size is 1, so index 0 is the only sample in the batch.
            predicted_color = forward_output_color[0].numpy()
            target_color = forword_target_color[0].numpy()

            all_delta_e += calculate_delta_E(predicted_color, target_color)

    mean_delta_e = np.round(all_delta_e / len(test_dataloader), 2)
    print(f'Test Delta E {run_time}: {mean_delta_e}')

    all_results.append(mean_delta_e)
    print('-----------------------------------------------')

print(f'All: {all_results}')
print(f'Average: {np.round(np.mean(all_results), 2)}')
print(f'STD: {np.round(np.std(all_results), 2)}')
print(f'Min: {np.min(all_results)}')
print(f'Max: {np.max(all_results)}')