In [None]:
import torch
import os, time, numbers, shutil
from torch import nn
import torch.optim as optim
import torch_optimizer as optim_ex
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torchvision import transforms, utils
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Function
import random
import math
import warnings

import seaborn as sns
import pandas as pd
import numpy as np
import scipy.signal
from sklearn.metrics import f1_score, confusion_matrix
import matplotlib.pyplot as plt


%matplotlib inline

In [None]:
%%capture
from tqdm.notebook import tqdm

In [None]:
# Record PyTorch's parallelism settings string (torch.__config__.parallel_info())
# in the notebook output for this run.
para_info = torch.__config__.parallel_info()
print(para_info)

In [None]:
# Run parameters. The next cell compares locals() against the snapshot taken on the
# last line of this cell and fails on any unexpected extra name (papermill injection check).
igpu = 0  # GPU index; in the visible cells it is only echoed by a print
batch_size = 512  # batch size of the FEFF DataLoaders; also rounds the samplers' draw count
lr = 0.01  # NOTE(review): presumably handed to Trainer(base_lr=...) further down -- confirm
max_epoch = 2000  # NOTE(review): presumably handed to Trainer(max_epoch=...) -- confirm

n_coord_num = 3  # number of leading label columns (CN_4, CN_5, CN_6) in the spectra tables
n_subclasses = 3  # used by Trainer.get_cluster_plot when labelling subplots
nstyle = 2  # NOTE(review): presumably the style-latent width passed to the networks -- confirm

# ReduceLROnPlateau settings (factor / patience) shared by every optimizer in Trainer.train.
sch_factor = 0.25
sch_patience = 300
# Shape of the gradient-reversal strength schedule computed per batch in Trainer.train:
# alpha = (2 / (1 + exp(-1e4 / alpha_flat_step * epoch / max_epoch)) - 1) * alpha_limit
alpha_flat_step = 250
alpha_limit = 1.0

# Per-optimizer learning-rate multipliers applied to the Trainer's base_lr
# (keys "RE", "I", "Smooth", "Cat", "G", "h" and "dann" in Trainer.train).
lr_ratio_Reconn = 2.0
lr_ratio_Mutual = 3.0
lr_ratio_Smooth = 0.1
lr_ratio_Supervise = 2.0
lr_ratio_Style = 0.5
lr_ratio_CR = 0.5
lr_ratio_domain = 1.0
match_target_dist = True  # True: weight FEFF domain sampling by XSpectra/FEFF class-count ratio
include_transfer_learning = True  # gates every XSpectra / domain-classifier code path

# Fractions used to split the tables by row position into train / validation / test.
train_ratio = 0.7
validation_ratio = 0.15
test_ratio = 0.15
sampling_exponent = 0.6  # exponent applied to the inverse class frequencies before normalising

# Snapshot of the names defined so far; must stay the last statement of this cell.
variable_list_before_papermill_injection = set(locals().keys())

In [None]:
# Guard against unknown injected parameters: any name that appeared in locals() since
# the snapshot at the end of the parameter cell is treated as unexpected, except
# underscore-prefixed names and the snapshot variable itself (it is assigned after the
# snapshot is taken, so it always shows up in the difference).
variable_list_after_papermill_injection = set(locals().keys())
excess_variables = variable_list_after_papermill_injection - variable_list_before_papermill_injection
# Iterate over a list copy because the set is mutated inside the loop.
for v in list(excess_variables):
    if v.startswith("_"):
        excess_variables.remove(v)
excess_variables.remove("variable_list_before_papermill_injection")
assert len(excess_variables) == 0, f"unexpected parameters:{excess_variables}"   
print(f"This notebook will use GPU:{igpu} if available")

In [None]:
# Load the FEFF table; the first two CSV columns form the (multi-)index.
feff_cn_spec_df = pd.read_csv("ti_feff_cn_spec.csv", index_col=[0,1])
# Split sizes by row count; the test split takes whatever rows remain after truncation.
n_feff_training_samples = int(len(feff_cn_spec_df) * train_ratio)
n_feff_validation_samples = int(len(feff_cn_spec_df) * validation_ratio)
n_feff_test_samples = len(feff_cn_spec_df) - n_feff_training_samples - n_feff_validation_samples
# Sanity check: the remainder must be within 3 rows of the nominal test share.
assert math.fabs(n_feff_test_samples - int(len(feff_cn_spec_df) * test_ratio)) < 3
# Column sums of the first n_coord_num columns over the training rows.
cn_feff_nums =feff_cn_spec_df.to_numpy()[:n_feff_training_samples, :n_coord_num].sum(axis=0)
print("Number of samples", cn_feff_nums)
# Inverse of the per-column mean over the training rows.
# NOTE(review): a column that is all zero in the training rows would give a division by zero.
cn_sampling_weights =  1.0 / feff_cn_spec_df.to_numpy()[:n_feff_training_samples, :n_coord_num].mean(axis=0)
print("Raw sampling weights", cn_sampling_weights)
# Temper the inverse-frequency weights with the exponent, then normalise them to sum to 1.
cn_sampling_weights **= sampling_exponent
cn_sampling_weights /= cn_sampling_weights.sum()
print("Sampling weights", cn_sampling_weights)
# weight * count per column, rescaled so the total equals the total of cn_feff_nums.
print("Sample numbers per epoch", (cn_sampling_weights*cn_feff_nums * (cn_feff_nums.sum()/(cn_sampling_weights*cn_feff_nums).sum())).astype('int'))
# Cell output: first three rows of the table.
feff_cn_spec_df[:3]

In [None]:
# XSpectra table and per-class domain sampling weights; only built for transfer learning.
if include_transfer_learning:
    xspectra_cn_spec_df = pd.read_csv("ti_xspectra_cn_spec.csv", index_col=[0,1])
    n_xspectra_training_samples = int(len(xspectra_cn_spec_df) * train_ratio)

    # No separate validation split here: everything after the training rows is "test",
    # and the check below compares it with the combined test + validation share.
    n_xspectra_test_samples = len(xspectra_cn_spec_df) - n_xspectra_training_samples
    assert math.fabs(n_xspectra_test_samples - int(len(xspectra_cn_spec_df) * (test_ratio+validation_ratio))) < 3
    # Column sums of the first n_coord_num columns over the XSpectra training rows.
    cn_xspectra_nums = xspectra_cn_spec_df.to_numpy()[:n_xspectra_training_samples, :n_coord_num].sum(axis=0)

    print("XSpectra coord number counts", cn_xspectra_nums)
    print("FEFF coord number counts", cn_feff_nums)
    if match_target_dist:
        # Ratio of XSpectra to FEFF column sums.
        # NOTE(review): a zero entry in cn_feff_nums would give a division by zero.
        domain_sampling_weights = cn_xspectra_nums / cn_feff_nums
    else:
        # Equal weight for every column.
        domain_sampling_weights = np.ones(n_coord_num)
    # Normalise to sum to 1.
    domain_sampling_weights /= domain_sampling_weights.sum()
    print("Domain sampling weights", domain_sampling_weights)

In [None]:
class CoordNumSpectraDataset(Dataset):
    """Dataset that pairs every spectrum row with its coordination-number columns.

    The first ``n_coord_num`` columns of ``df`` are kept as ``self.cn`` and the
    remaining columns as ``self.spec``. Indexing yields ``(spec, cn)``; when a
    ``transform`` is given it is applied to both parts and a list is returned
    instead of a tuple.
    """

    def __init__(self, df, n_coord_num=3, transform=None):
        # The table must start with the three CN columns followed by the first energy column.
        expected_head = ['CN_4', 'CN_5', 'CN_6', 'ENE_4965.000']
        leading_columns = df.columns.to_list()[:n_coord_num + 1]
        assert leading_columns == expected_head
        values = df.to_numpy()
        self.cn, self.spec = values[:, :n_coord_num], values[:, n_coord_num:]
        self.transform = transform

    def __len__(self):
        return len(self.cn)

    def __getitem__(self, idx):
        # Tensor indices are converted to plain Python ints/lists before indexing numpy.
        index = idx.tolist() if torch.is_tensor(idx) else idx
        pair = (self.spec[index], self.cn[index])
        if self.transform is None:
            return pair
        return [self.transform(part) for part in pair]
    
class ToTensor:
    """Callable transform that wraps an array-like in a ``torch.Tensor``."""

    def __call__(self, sample):
        as_tensor = torch.Tensor(sample)
        return as_tensor

# Single-step transform pipeline: convert each part of a sample to a torch.Tensor.
transform_list = transforms.Compose([ToTensor()])
# FEFF splits are taken by row position: first block = train, next block = validation,
# last n_feff_test_samples rows = test.
# NOTE(review): rows are not shuffled here; presumably the CSV is already in random order -- confirm.
dataset_feff_train = CoordNumSpectraDataset(feff_cn_spec_df[:n_feff_training_samples], 
                                            n_coord_num=n_coord_num, transform=transform_list)
dataset_feff_val = CoordNumSpectraDataset(feff_cn_spec_df[n_feff_training_samples:n_feff_training_samples+n_feff_validation_samples], 
                                          n_coord_num=n_coord_num, transform=transform_list)
dataset_feff_test = CoordNumSpectraDataset(feff_cn_spec_df[-n_feff_test_samples:], 
                                           n_coord_num=n_coord_num, transform=transform_list)
# XSpectra only has a train and a test split (no validation split).
if include_transfer_learning:
    dataset_xspectra_train = CoordNumSpectraDataset(xspectra_cn_spec_df[:n_xspectra_training_samples], 
                                                    n_coord_num=n_coord_num, transform=transform_list)
    dataset_xspectra_test = CoordNumSpectraDataset(xspectra_cn_spec_df[-n_xspectra_test_samples:], 
                                                   n_coord_num=n_coord_num, transform=transform_list)
    print(len(dataset_xspectra_train), len(dataset_xspectra_test))

# Cell output: sizes of the three FEFF splits.
len(dataset_feff_train), len(dataset_feff_val), len(dataset_feff_test)

In [None]:
# Index of the largest CN column for every FEFF training / validation sample
# (each sample is fetched through the dataset, so the transform runs once per sample).
feff_train_max_cn_list = [np.argmax(dataset_feff_train[i][1].numpy()) for i in range(len(dataset_feff_train))]
feff_val_max_cn_list = [np.argmax(dataset_feff_val[i][1].numpy()) for i in range(len(dataset_feff_val))]
# Per-sample weights looked up from the per-class weights via that index.
cn_train_sample_weights = [cn_sampling_weights[cn] for cn in feff_train_max_cn_list]
if include_transfer_learning:
    domain_train_sample_weights = [domain_sampling_weights[cn] for cn in feff_train_max_cn_list]
cn_val_sample_weights = [cn_sampling_weights[cn] for cn in feff_val_max_cn_list]

# Draw with replacement; the draw count is the training-set size rounded up to a whole
# number of batches, so every batch is full.
cn_train_sampler = WeightedRandomSampler(cn_train_sample_weights, replacement=True,
                                         num_samples=math.ceil(len(dataset_feff_train)/batch_size)*batch_size)
if include_transfer_learning:
    domain_train_sampler = WeightedRandomSampler(domain_train_sample_weights, replacement=True,
                                                 num_samples=math.ceil(len(dataset_feff_train)/batch_size)*batch_size)

In [None]:
# FEFF training loader driven by the class-weighted sampler.
cn_train_feff_loader = DataLoader(dataset_feff_train, 
                                  batch_size=batch_size,
                                  sampler=cn_train_sampler,
                                  num_workers=0, pin_memory=False)
if include_transfer_learning:
    # Second FEFF loader over the same dataset, driven by the domain-weighted sampler.
    domain_train_feff_loader = DataLoader(dataset_feff_train, 
                                          batch_size=batch_size,
                                          sampler=domain_train_sampler,
                                          num_workers=0, pin_memory=False)

    # XSpectra batch size is chosen so one pass takes about as many batches as the FEFF loader.
    # NOTE(review): ceil(N / k) does not always yield exactly k batches; Trainer.train asserts
    # that the loader lengths are equal, so a mismatch surfaces there.
    domain_train_xspectra_loader = DataLoader(dataset_xspectra_train, 
                                              batch_size=math.ceil(len(dataset_xspectra_train)/len(domain_train_feff_loader)),
                                              shuffle=True, num_workers=0, pin_memory=False)

# Validation loader: no sampler and no shuffling.
cn_val_loader = DataLoader(dataset_feff_val, batch_size=batch_size,
                           num_workers=0, pin_memory=False)

In [None]:
class EncodingBlock(nn.Module):
    """Residual 1-D convolution block with a dense "excitation" branch.

    ``forward`` adds up three branches that all start from the (optionally
    batch-normalised) input:

    * main path: conv -> PReLU -> batch norm -> (strided) conv -> PReLU;
    * shortcut: a grouped strided convolution + PReLU when the stride is > 1 or
      the channel count changes, otherwise the normalised input itself;
    * excitation: Linear layers acting on the length axis
      (``in_len -> excitation -> out_len``) with a learned offset that depends
      on the scalar ``domain`` flag, followed by a 1x1 convolution when the
      channel count changes.
    """

    def __init__(self, in_channels, out_channels, in_len, out_len, kernel_size=7, stride=2, excitation=4, dropout_rate=0.2):
        super().__init__()
        same_padding = (kernel_size - 1) // 2
        shared_groups = math.gcd(in_channels, out_channels)
        channels_change = in_channels != out_channels

        # --- main path (no input batch norm for single-channel input) ---
        self.bn1 = nn.BatchNorm1d(in_channels, affine=False) if in_channels > 1 else None
        self.relu1 = nn.PReLU(num_parameters=out_channels, init=0.01)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size,
                               padding=same_padding, padding_mode='replicate')
        self.bn2 = nn.BatchNorm1d(out_channels, affine=False)
        self.relu2 = nn.PReLU(num_parameters=out_channels, init=0.01)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size,
                               padding=same_padding, stride=stride)

        # --- excitation path (dropout only for inputs longer than 10 points) ---
        self.dropout_1 = nn.Dropout(p=dropout_rate) if in_len > 10 else None
        self.fc1 = nn.Linear(in_len, excitation)
        self.relu_excit_1 = nn.PReLU(num_parameters=in_channels, init=0.01)
        self.fc2 = nn.Linear(excitation, out_len)
        self.relu_excit_2 = nn.PReLU(num_parameters=in_channels, init=0.01)
        if channels_change:
            self.bn_excit = nn.BatchNorm1d(in_channels, affine=False)
            self.relu_excit_3 = nn.PReLU(num_parameters=out_channels, init=0.01)
            self.conv_excit = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1,
                                        groups=shared_groups)
        else:
            self.bn_excit = self.relu_excit_3 = self.conv_excit = None

        # --- shortcut path ---
        if stride > 1 or channels_change:
            self.conv_short = nn.Conv1d(in_channels, out_channels, kernel_size=stride, stride=stride,
                                        groups=shared_groups)
            self.relu_short = nn.PReLU(num_parameters=out_channels, init=0.01)
        else:
            self.conv_short = None

        # Maps the scalar domain flag to one offset per (channel, excitation unit).
        self.fc_domain = nn.Linear(1, excitation*in_channels)

    def forward(self, x, domain=0):
        normed = x if self.bn1 is None else self.bn1(x)

        # main path
        main = self.relu1(self.conv1(normed))
        main = self.relu2(self.conv2(self.bn2(main)))

        # shortcut path
        if self.conv_short is None:
            shortcut = normed
        else:
            shortcut = self.relu_short(self.conv_short(normed))

        # excitation path
        gate = normed if self.dropout_1 is None else self.dropout_1(normed)
        gate = self.relu_excit_1(self.fc1(gate))
        # One column holding the domain flag per sample, matching dtype/device of the features.
        domain_col = torch.full([gate.size()[0], 1], fill_value=domain, requires_grad=False,
                                dtype=gate.dtype, device=gate.device)
        gate += self.fc_domain(domain_col).reshape(*gate.size())
        gate = self.relu_excit_2(self.fc2(gate))
        if self.conv_excit is not None:
            gate = self.relu_excit_3(self.conv_excit(self.bn_excit(gate)))

        return main + shortcut + gate

# Smoke test: 1 -> 2 channels with stride 2 on a (32, 1, 256) batch; the cell shows the output shape.
t = torch.ones((32, 1, 256))
eb = EncodingBlock(1, 2, 256, 128, kernel_size=11, stride=2)
eb(t).shape

In [None]:
class DecodingBlock(nn.Module):
    """Upsampling counterpart of ``EncodingBlock``: the output is 4x longer than the input.

    ``forward`` batch-normalises the input and adds up three branches:

    * main path: two stride-2 transposed convolutions, each followed by PReLU;
    * shortcut: one grouped stride-4 transposed convolution + PReLU;
    * excitation: Linear layers on the length axis
      (``in_len -> excitation -> 4 * in_len``), followed by a 1x1 convolution
      when the channel count changes.
    """

    def __init__(self, in_channels, out_channels, in_len, excitation=4, dropout_rate=0.2):
        super().__init__()
        upsampled_len = in_len * 4
        shared_groups = math.gcd(in_channels, out_channels)

        # --- main path ---
        self.bn1 = nn.BatchNorm1d(in_channels, affine=False)
        self.relu1 = nn.PReLU(num_parameters=out_channels, init=0.01)
        self.conv1 = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=2, stride=2)
        self.bn2 = nn.BatchNorm1d(out_channels, affine=False)
        self.relu2 = nn.PReLU(num_parameters=out_channels, init=0.01)
        self.conv2 = nn.ConvTranspose1d(out_channels, out_channels, kernel_size=2, stride=2)

        # --- excitation path (dropout only for inputs longer than 10 points) ---
        self.dropout_1 = nn.Dropout(p=dropout_rate) if in_len > 10 else None
        self.fc1 = nn.Linear(in_len, excitation)
        self.relu_excit_1 = nn.PReLU(num_parameters=in_channels, init=0.01)
        self.fc2 = nn.Linear(excitation, upsampled_len)
        self.relu_excit_2 = nn.PReLU(num_parameters=in_channels, init=0.01)
        if in_channels != out_channels:
            self.bn_excit = nn.BatchNorm1d(in_channels, affine=False)
            self.relu_excit_3 = nn.PReLU(num_parameters=out_channels, init=0.01)
            self.conv_excit = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1,
                                        groups=shared_groups)
        else:
            self.bn_excit = self.relu_excit_3 = self.conv_excit = None

        # --- shortcut path ---
        self.conv_short = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=4, stride=4,
                                             groups=shared_groups)
        self.relu_short = nn.PReLU(num_parameters=out_channels, init=0.01)

    def forward(self, x):
        normed = self.bn1(x)

        # main path
        main = self.relu1(self.conv1(normed))
        main = self.relu2(self.conv2(self.bn2(main)))

        # shortcut path
        shortcut = self.relu_short(self.conv_short(normed))

        # excitation path
        gate = normed if self.dropout_1 is None else self.dropout_1(normed)
        gate = self.relu_excit_1(self.fc1(gate))
        gate = self.relu_excit_2(self.fc2(gate))
        if self.conv_excit is not None:
            gate = self.relu_excit_3(self.conv_excit(self.bn_excit(gate)))

        return main + shortcut + gate

# Smoke test: 14 -> 8 channels on a (32, 14, 1) batch; the cell shows the output shape.
t = torch.ones((32, 14, 1))
eb = DecodingBlock(14, 8, 1, excitation=1)
eb(t).shape

In [None]:
class GaussianSmoothing(nn.Module):
    """Depthwise Gaussian filter for 1-D, 2-D or 3-D signals.

    A normalised Gaussian kernel (it sums to 1) is built once and stored as the
    non-trainable buffer ``weight``; ``forward`` convolves every channel with it
    separately (``groups == channels``). No padding is applied, so the output is
    ``kernel_size - 1`` shorter than the input along each filtered axis.

    Arguments:
        channels (int): Number of channels of the input; the kernel is repeated once per channel.
        kernel_size (int or sequence): Kernel size, per dimension if a sequence.
        sigma (float or sequence): Gaussian standard deviation, per dimension if a sequence.
        dim (int): Number of filtered dimensions used to expand scalar ``kernel_size``/``sigma``.
        device: Device the kernel buffer is moved to.
    """

    def __init__(self, channels, kernel_size, sigma, dim=2, device='cpu'):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim

        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)

        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel.to(device))
        self.groups = channels

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on,
                shaped (batch, channels, *spatial) with 1 to 3 spatial dimensions.
        Returns:
            filtered (torch.Tensor): Filtered output.
        Raises:
            RuntimeError: If the input does not have 1, 2 or 3 spatial dimensions.
        """
        # Number of filtered axes = tensor rank minus the batch and channel axes.
        spatial_dims = len(input.size()) - 2
        if spatial_dims == 1:
            conv = nn.functional.conv1d
        elif spatial_dims == 2:
            conv = nn.functional.conv2d
        elif spatial_dims == 3:
            conv = nn.functional.conv3d
        else:
            # Report the rank actually received (the previous code formatted an
            # undefined name here and raised NameError instead).
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(spatial_dims)
            )
        return conv(input, weight=self.weight, groups=self.groups)

# Demo: plot three random 256-point curves before and after Gaussian smoothing.
t = torch.rand((3, 256))
sns.set_palette("husl", t.size()[0]*2)
for spec in t:
    plt.plot(spec, lw=1.0)
    
sm = GaussianSmoothing(channels=1, kernel_size=17, sigma=3.0, dim=1)
# Add a channel axis, then replicate-pad 8 points per side ((17 - 1) // 2) so the
# unpadded convolution returns the original length.
t = t.unsqueeze(dim=1)
t = nn.functional.pad(t, (8, 8), mode='replicate')
t = sm(t).squeeze(dim=1)
# Second figure: the smoothed curves.
plt.figure()
for spec in t:
    plt.plot(spec, lw=1.0)

In [None]:
class mySequential(nn.Sequential):
    """``nn.Sequential`` variant whose layers each receive an extra ``domain`` argument."""

    def forward(self, spec, domain):
        result = spec
        # Iterating the container visits the registered sub-modules in insertion order.
        for layer in self:
            result = layer(result, domain)
        return result
    
class Q_net(nn.Module):
    """Encoder: maps a 256-point spectrum to a style code and class probabilities.

    Five stride-2 ``EncodingBlock`` stages reduce the length 256 -> 8 with
    4 channels; the flattened 32 features feed two linear heads.
    ``forward`` returns ``(z_gauss, y)`` where ``z_gauss`` has ``nstyle`` columns
    and ``y`` is a softmax over ``nclasses`` columns.
    """

    def __init__(self, dropout_rate=0.2, nclasses=12, nstyle=2):
        super().__init__()

        # (in_channels, in_len, out_len, kernel_size, excitation) for each stage.
        stage_specs = [
            (1, 256, 128, 11, 4),
            (4, 128, 64, 11, 4),
            (4, 64, 32, 7, 2),
            (4, 32, 16, 7, 2),
            (4, 16, 8, 5, 1),
        ]
        stages = [
            EncodingBlock(in_channels=c_in, out_channels=4, in_len=l_in, out_len=l_out,
                          kernel_size=k, stride=2, excitation=exc, dropout_rate=dropout_rate)
            for c_in, l_in, l_out, k, exc in stage_specs
        ]
        self.main = mySequential(*stages)
        self.lin1 = nn.Linear(32, nclasses)
        self.lin3 = nn.Linear(32, nstyle)

    def forward(self, spec, domain=0):
        n_samples = spec.size()[0]
        # Add the channel axis, encode, then flatten 4 channels x 8 points -> 32 features.
        features = self.main(spec.unsqueeze(dim=1), domain).reshape(n_samples, 32)

        z_gauss = self.lin3(features)
        y = nn.functional.softmax(self.lin1(features), dim=1)
        return z_gauss, y
    
# Smoke test: encode a (64, 256) batch; the cell shows the sizes of both outputs.
t = torch.ones((64, 256))
eb = Q_net()
[x.size() for x in eb(t)]

In [None]:
class P_net(nn.Module):
    """Decoder: maps a (style, class) latent pair back to a 256-point spectrum.

    The concatenated latent vector is treated as ``nclasses + nstyle`` channels
    of length 1. Four ``DecodingBlock`` stages upsample it 1 -> 4 -> 16 -> 64 -> 256,
    five stride-1 ``EncodingBlock`` stages refine it, and a batch norm,
    1x1 convolution and Softplus produce the single output channel.

    ``gau_kernel_size``, ``sigma`` and ``device`` are accepted but currently
    unused: the optional replication-pad + ``GaussianSmoothing`` output stage
    is disabled.
    """

    def __init__(self, dropout_rate=0.2,nclasses=12, nstyle=2, gau_kernel_size=11, sigma=2.0, device='cpu'):
        super().__init__()

        # (in_channels, out_channels, in_len, excitation) for each upsampling stage.
        upsample_specs = [
            (nclasses+nstyle, 8, 1, 1),
            (8, 4, 4, 2),
            (4, 4, 16, 2),
            (4, 4, 64, 4),
        ]
        upsamplers = [
            DecodingBlock(in_channels=c_in, out_channels=c_out, in_len=l_in, excitation=exc,
                          dropout_rate=dropout_rate)
            for c_in, c_out, l_in, exc in upsample_specs
        ]

        # (in_channels, out_channels) for each full-length refinement stage.
        refine_channels = [(4, 4), (4, 4), (4, 2), (2, 2), (2, 2)]
        refiners = [
            EncodingBlock(in_channels=c_in, out_channels=c_out, in_len=256, out_len=256,
                          kernel_size=11, stride=1, excitation=2, dropout_rate=dropout_rate)
            for c_in, c_out in refine_channels
        ]

        head = [
            nn.BatchNorm1d(2, affine=False),
            nn.Conv1d(2, 1, kernel_size=1, stride=1),
            nn.Softplus(beta=2),
        ]
        self.main = nn.Sequential(*upsamplers, *refiners, *head)

        self.nclasses = nclasses
        self.nstyle = nstyle

    def forward(self, z_gauss, y):
        assert z_gauss.size()[1] == self.nstyle
        assert y.size()[1] == self.nclasses
        # Style columns first, class columns second; the trailing axis is the length-1 signal.
        latent = torch.cat([z_gauss, y], dim=1).unsqueeze(dim=2)
        return self.main(latent).squeeze(dim=1)
    
    
# Smoke test: decode a batch of 32 (style, class) pairs; the cell shows the output size.
tz, ty = torch.ones(32, 2), torch.ones(32, 12)
eb = P_net()
eb(tz, ty).size()

In [None]:
class ReverseLayerF(Function):
    """Gradient-reversal layer: identity in the forward pass, ``-alpha * grad`` backward.

    Call it as ``ReverseLayerF.apply(x, alpha)``. ``alpha`` is stored on the
    context as a plain attribute and receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Only the scaling factor is needed in backward; the input itself is not saved.
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale. The previous code cloned grad_output first and
        # then discarded the clone, which only cost an extra allocation.
        grad_input = grad_output.neg() * ctx.alpha
        # One gradient per forward input: x gets the reversed gradient, alpha gets none.
        return grad_input, None
    
    
if include_transfer_learning:
    class Domain_classifier_net(nn.Module):
        """Domain classifier placed on top of an encoder through a gradient-reversal layer.

        ``forward`` encodes ``x``, concatenates the encoder outputs into one
        latent vector of ``nclasses + nstyle`` columns, reverses the gradient
        with strength ``alpha`` and returns log-probabilities over 2 domains.
        """

        def __init__(self, encoder, hiden_size=50, dropout_rate=0.2, nclasses=12, nstyle=2):
            super().__init__()
            self.latent_size = nclasses + nstyle

            # Input layer, three identical hidden stages, then the 2-way output stage.
            layers = [
                nn.Linear(self.latent_size, hiden_size),
                nn.PReLU(num_parameters=hiden_size, init=0.01),
            ]
            for _ in range(3):
                layers.extend([
                    nn.BatchNorm1d(hiden_size, affine=False),
                    nn.Dropout(p=dropout_rate),
                    nn.Linear(hiden_size, hiden_size),
                    nn.PReLU(num_parameters=hiden_size, init=0.01),
                ])
            layers.extend([
                nn.BatchNorm1d(hiden_size, affine=False),
                nn.Dropout(p=dropout_rate),
                nn.Linear(hiden_size, 2),
                nn.LogSoftmax(dim=1),
            ])
            self.main = nn.Sequential(*layers)
            self.encoder = encoder

        def forward(self, x, alpha, domain):
            n_samples = x.size()[0]
            # The encoder returns a tuple of tensors; join them column-wise.
            latent = torch.cat(self.encoder(x, domain), dim=1)
            reversed_latent = ReverseLayerF.apply(latent, alpha).reshape(n_samples, self.latent_size)
            return self.main(reversed_latent)

    # Smoke test: classify a (64, 256) batch and print the output size.
    t = torch.ones((64, 256))
    eb = Domain_classifier_net(Q_net())
    print(eb(t, 0.3, 0).size())

In [None]:
class D_net_gauss(nn.Module):
    """Discriminator on the style code, fed through a gradient-reversal layer.

    ``forward`` reverses the gradient of ``x`` with strength ``alpha`` and
    returns log-probabilities over 2 classes.
    """

    def __init__(self, hiden_size=50, dropout_rate=0.2, nstyle=2):
        super().__init__()

        # Input layer, three identical hidden stages, then the 2-way output stage.
        layers = [
            nn.Linear(nstyle, hiden_size),
            nn.PReLU(num_parameters=hiden_size, init=0.01),
        ]
        for _ in range(3):
            layers.extend([
                nn.BatchNorm1d(hiden_size, affine=False),
                nn.Dropout(p=dropout_rate),
                nn.Linear(hiden_size, hiden_size),
                nn.PReLU(num_parameters=hiden_size, init=0.01),
            ])
        layers.extend([
            nn.BatchNorm1d(hiden_size, affine=False),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hiden_size, 2),
            nn.LogSoftmax(dim=1),
        ])
        self.main = nn.Sequential(*layers)
        self.nstyle = nstyle

    def forward(self, x, alpha):
        return self.main(ReverseLayerF.apply(x, alpha))

# Smoke test: run a (64, 2) batch through the discriminator; the cell shows the output size.
t = torch.ones((64, 2))
eb = D_net_gauss()
eb(t, 0.3).size()

In [None]:
class FakeDualAAE(nn.Module):
    """Encoder, decoder and style discriminator chained into one module.

    All three sub-networks are built with their default arguments. ``forward``
    returns the reconstructed spectrum and the discriminator output for the
    style code, using a fixed gradient-reversal strength of 0.3.
    """

    def __init__(self):
        super().__init__()
        self.q = Q_net()
        self.p = P_net()
        self.d = D_net_gauss()

    def forward(self, x):
        style_code, class_probs = self.q(x)
        reconstruction = self.p(style_code, class_probs)
        style_verdict = self.d(style_code, 0.3)
        return reconstruction, style_verdict
    
# Smoke test: run a (64, 256) batch through the chained model; the cell shows the reconstruction size.
t = torch.ones((64, 256))
eb = FakeDualAAE()
eb(t)[0].size()

In [None]:
# Silence every Python warning from this point on.
# NOTE(review): this also hides deprecation warnings from the libraries in use;
# consider narrowing the filter to the specific categories that are noisy.
warnings.filterwarnings("ignore")

class Trainer:

    def __init__(self, Q, P, D_gauss, Dann, device, cn_train_feff_loader, cn_val_loader, 
                 domain_train_feff_loader, domain_train_xspectra_loader, 
                 val_cn_weights, base_lr=0.0001, nclasses=12, nstyle=2, batch_size=111, max_epoch=300, 
                 tb_logdir="runs", zero_conc_thresh=0.05):
        """Move the networks to ``device`` and store the training configuration.

        Arguments:
            Q, P, D_gauss: Networks stored as ``self.q``, ``self.p`` and ``self.d``.
            Dann: Domain classifier; only used (and moved to ``device``) when the
                module-level ``include_transfer_learning`` flag is true.
            device: Device for the networks and the helper tensors created here.
            cn_train_feff_loader, cn_val_loader: Training and validation loaders.
            domain_train_feff_loader, domain_train_xspectra_loader: Domain
                loaders; only stored when ``include_transfer_learning`` is true.
            val_cn_weights: Converted to a float tensor on ``device``.
            base_lr: Stored as ``self.base_lr``; ``train`` scales it per optimizer.
            nclasses, nstyle, batch_size, max_epoch: Stored unchanged.
            tb_logdir: Log directory of the TensorBoard ``SummaryWriter``.
            zero_conc_thresh: Stored unchanged; ``train`` compares labels against it.

        Side effects:
            Creates a ``SummaryWriter`` and logs the graph of a freshly built
            ``FakeDualAAE()`` (not the networks passed in), traced with one
            batch drawn from ``cn_train_feff_loader``.
        """
        self.q = Q.to(device)
        self.p = P.to(device)
        self.d = D_gauss.to(device)
        if include_transfer_learning:
            self.dann = Dann.to(device)
            self.domain_train_feff_loader = domain_train_feff_loader
            self.domain_train_xspectra_loader = domain_train_xspectra_loader

        self.nclasses = nclasses
        self.nstyle = nstyle

        self.val_cn_weights = torch.tensor(val_cn_weights,  dtype=torch.float, device=device)
        self.batch_size = batch_size
        self.max_epoch = max_epoch
        self.base_lr = base_lr
        self.device = device
        self.cn_train_feff_loader = cn_train_feff_loader
        self.cn_val_loader = cn_val_loader

        # Range and number of style values used by train() for its fixed plotting codes.
        self.noise_test_range = (-2, 2)
        self.ntest_per_spectra = 10
        self.zero_conc_thresh = zero_conc_thresh
        # 17-point Gaussian filter plus matching replicate padding ((17 - 1) // 2 per side),
        # so that padding followed by smoothing preserves the signal length.
        gau_kernel_size = 17
        self.gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=gau_kernel_size, sigma=3.0, dim=1, device=device).to(device)
        self.padding4smooth = nn.ReplicationPad1d(padding=(gau_kernel_size-1)//2).to(device)

        self.tb_writer = SummaryWriter(log_dir=tb_logdir)

        # Use the built-in next(): the DataLoader iterator's own .next() method is
        # not available on current PyTorch releases.
        example_spec = next(iter(cn_train_feff_loader))[0]
        self.tb_writer.add_graph(FakeDualAAE(), example_spec)

    def sample_categorical(self):
        '''
         Sample from a categorical distribution
         of size batch_size and # of classes n_classes
         return: torch.autograd.Variable with the sample
        '''
        idx = np.random.randint(0, self.nclasses, self.batch_size)
        cat = np.eye(self.nclasses)[idx].astype('float32')
        cat = torch.tensor(cat, requires_grad=False, device=self.device)
        return cat, idx

    def zerograd(self):
        """Reset the gradients of every network held by the trainer."""
        networks = [self.q, self.p, self.d]
        # The domain classifier only exists when transfer learning is enabled.
        if include_transfer_learning:
            networks.append(self.dann)
        for net in networks:
            net.zero_grad()
            
    def d_entropy1(self, y):
        y1 = y.mean(dim=0) 
        y2 = torch.sum(-y1*torch.log(y1+1e-5))
        return y2   
        
    def d_entropy2(self, y):
        y1 = -y*torch.log(y+1e-5)
        y2 = torch.sum(y1)/self.batch_size
        return y2      
        
    def get_cluster_idx(self, Y_pred):
        return Y_pred.argmax(dim=1).cpu()
    
    def get_cluster_plot(self, spec_list, nsub=3):
        """Plot the generated spectra of every class on a grid of subplots.

        Arguments:
            spec_list: Array-like of shape ``(nclasses, ntest_per_spectra, ...)``;
                entry ``[i][j]`` is one curve drawn in subplot ``i``.
            nsub: Number of subplot columns; the grid has ``nclasses // nsub`` rows.

        Returns:
            The matplotlib figure.
        """
        assert spec_list.shape[0] == self.nclasses
        assert spec_list.shape[1] == self.ntest_per_spectra
        # squeeze=False always yields a 2-D Axes array, so .ravel() also works for a 1x1 grid.
        fig, ax_list = plt.subplots(self.nclasses//nsub, nsub, sharex=True, sharey=True,
                                    figsize=(9, 12), squeeze=False)
        colors = sns.color_palette("husl", self.ntest_per_spectra)
        for i, (sl, ax) in enumerate(zip(spec_list, ax_list.ravel())):
            for spec, color in zip(sl, colors):
                ax.plot(spec, lw=1.5, c=color)
            # Labels depend only on the subplot index, so they are set once per subplot
            # instead of once per curve.
            # NOTE(review): the column test and the subclass number hard-code 3 rather than
            # using nsub / n_subclasses, and the "+ 3" offset gives 3, 4, 5 while the label
            # columns are named CN_4..CN_6 -- confirm the intended numbering.
            if i % 3 == 0:
                ax.set_ylabel(f"{i//n_subclasses + 3} Folds Coordinated")
            if i >= (n_coord_num-1) *n_subclasses:
                ax.set_xlabel(f"Subclass {i%3 + 1}")
        title = "All Classes and Styles"
        fig.suptitle(title, y=0.91)
        return fig
    
    def get_style_distribution_plot(self, z):
        """Histogram every style dimension of ``z`` in its own subplot row.

        Arguments:
            z: 2-D array-like; column ``istyle`` holds the values of style
                dimension ``istyle``.

        Returns:
            The matplotlib figure.
        """
        # squeeze=False always yields a 2-D Axes array, so this also works for nstyle == 1,
        # where plt.subplots would otherwise return a single, non-iterable Axes.
        fig, ax_grid = plt.subplots(self.nstyle, 1, sharex=True, sharey=True,
                                    figsize=(9, 12), squeeze=False)
        for istyle, ax in enumerate(ax_grid.ravel()):
            # NOTE(review): sns.distplot is deprecated in recent seaborn releases; kept
            # here so the rendered figure does not change.
            sns.distplot(z[:, istyle], kde=False, color='blue', bins=np.arange(-3.0, 3.01, 0.2),
                         ax=ax)
        return fig
            

    def train(self):
        
        #loss function
        mse_dis = nn.MSELoss().to(self.device)
        criterionQ_dis = nn.NLLLoss().to(self.device)
        bce_loss = nn.BCELoss().to(self.device)
        nll_loss = nn.NLLLoss().to(self.device)
        
        
        solver_lr_ratio = {
            "RE": lr_ratio_Reconn,
            "I": lr_ratio_Mutual,
            "Smooth": lr_ratio_Smooth,
            "Cat": lr_ratio_Supervise,
            "G": lr_ratio_Style,
            "h": lr_ratio_CR,    
        }
        if include_transfer_learning:
            solver_lr_ratio["dann"] = lr_ratio_domain
            
        RE_solver = optim.AdamW([{'params':self.q.parameters()}, {'params':self.p.parameters()}], lr=solver_lr_ratio["RE"]*self.base_lr)
        I_solver = optim.AdamW([{'params':self.q.parameters()}, {'params':self.p.parameters()}], lr=solver_lr_ratio["I"]*self.base_lr)
        Smooth_solver = optim.AdamW([{'params':self.p.parameters()}], lr=solver_lr_ratio["Smooth"]*self.base_lr)
        Cat_solver = optim.AdamW([{'params':self.q.parameters()}], lr=solver_lr_ratio["Cat"]*self.base_lr)
        G_solver = optim.AdamW([{'params':self.d.parameters()}, {'params':self.q.parameters()}], lr=solver_lr_ratio["G"]*self.base_lr)
        h_solver = optim.AdamW([{'params':self.q.parameters()}], lr=solver_lr_ratio["h"]*self.base_lr)
        if include_transfer_learning:
            Dann_solver =  optim.AdamW([{'params':self.dann.parameters()}], lr=solver_lr_ratio["dann"]*self.base_lr)
        
        sol_list = [RE_solver, I_solver, Smooth_solver, Cat_solver, G_solver, h_solver]
        if include_transfer_learning:
            sol_list.append(Dann_solver)
        schedulers = [ReduceLROnPlateau(sol, factor=sch_factor, patience=sch_patience, cooldown=0, threshold=0.01, verbose=True)
                      for sol in sol_list]
        
        
        # fixed random variables, for plot spectra
        Idx = np.arange(self.nclasses).repeat(self.ntest_per_spectra)
        one_hot = np.zeros((self.nclasses * self.ntest_per_spectra, self.nclasses))
        one_hot[range(self.nclasses * self.ntest_per_spectra), Idx] = 1

        c = np.linspace(*self.noise_test_range, self.ntest_per_spectra).reshape(1,-1)
        c = np.repeat(c, self.nclasses, 0).reshape(-1, 1)
        c2 = np.hstack([np.zeros_like(c)] * (self.nstyle-1) + [c])

        dis_c = torch.tensor(one_hot, dtype=torch.float, device=self.device, requires_grad=False)
        con_c = torch.tensor(c2, dtype=torch.float, device=self.device, requires_grad=False)
        
        if include_transfer_learning:
                assert len(self.cn_train_feff_loader) == len(self.domain_train_feff_loader)
                assert len(self.cn_train_feff_loader) == len(self.domain_train_xspectra_loader)
                
        #train network
        last_best = 0.0
        chkpt_dir = "checkpoints"
        if not os.path.exists(chkpt_dir):
            os.makedirs(chkpt_dir, exist_ok=True)
        best_chk = None
        for epoch in tqdm(range(self.max_epoch), desc="Cluster"):            
            # Set the networks in train mode (apply dropout when needed)
            self.q.train()
            self.p.train()
            self.d.train()
            if include_transfer_learning:
                self.dann.train()
            # Loop through the labeled and unlabeled dataset getting one batch of samples from each
            # The batch size has to be a divisor of the size of the dataset or it will return
            # invalid samples
            
            iterator_cn_train = iter(self.cn_train_feff_loader)
            if include_transfer_learning:
                iterator_domain_feff = iter(self.domain_train_feff_loader)
                iterator_damain_xspectra = iter(self.domain_train_xspectra_loader)
            
            for _ in range(len(self.cn_train_feff_loader)):
                spec_in, cn_in = next(iterator_cn_train)
                if include_transfer_learning:
                    spec_src, _ = next(iterator_domain_feff)
                    spec_target, _ = next(iterator_damain_xspectra)
                alpha = (2. / (1. + np.exp(-1.0E4 / alpha_flat_step * epoch/self.max_epoch)) - 1) * alpha_limit
                
                zero_conc_selector = (cn_in < self.zero_conc_thresh)
                zero_conc_selector = zero_conc_selector.unsqueeze(dim=2)
                zero_conc_selector = zero_conc_selector.repeat(1, 1, self.nclasses//cn_in.size()[1])
                zero_conc_selector = zero_conc_selector.resize(cn_in.size()[0], self.nclasses)
                pure_selector = (cn_in.max(dim=-1).values > 1.0 - self.zero_conc_thresh)
        
                spec_in = spec_in.to(self.device)
                if include_transfer_learning:
                    spec_src = spec_src.to(self.device)
                    spec_target = spec_target.to(self.device)
                pure_selector = pure_selector.to(self.device)
                zero_conc_selector = zero_conc_selector.to(self.device)
                bce_eps = torch.full_like(zero_conc_selector[zero_conc_selector], fill_value=1.0E-3, dtype=torch.float, device=self.device)
                
                # Init gradients
                self.zerograd()
        
                #Domain loss
                if include_transfer_learning:
                    comp_domain_label = torch.zeros(spec_src.size()[0], dtype=torch.long, requires_grad=False, device=self.device)
                    comp_domain_pred = self.dann(spec_src, alpha, 0)

                    exp_domain_label = torch.ones(spec_target.size()[0], dtype=torch.long, requires_grad=False, device=self.device)
                    exp_domain_pred = self.dann(spec_target, alpha, 1)
                    domain_loss = nll_loss(comp_domain_pred, comp_domain_label) + nll_loss(exp_domain_pred, exp_domain_label)
                    domain_loss.backward()
                    Dann_solver.step()
        
                # Init gradients
                self.zerograd()
        
                # D loss
                z_real_gauss = torch.randn(self.batch_size, self.nstyle, requires_grad=True, device=self.device)
                z_fake_gauss, _ = self.q(spec_in)
                
                real_gauss_label = torch.ones(self.batch_size, dtype=torch.long, requires_grad=False, device=self.device)
                real_gauss_pred = self.d(z_real_gauss, alpha)
                
                fake_guass_lable = torch.zeros(spec_in.size()[0], dtype=torch.long, requires_grad=False, device=self.device)
                fake_gauss_pred = self.d(z_fake_gauss, alpha)
                       
                G_loss = nll_loss(real_gauss_pred, real_gauss_label) + nll_loss(fake_gauss_pred, fake_guass_lable)
                G_loss.backward()
                G_solver.step()
                
                
                # Init gradients
                self.zerograd()
        
                #H loss
                _, y = self.q(spec_in[pure_selector])
                h1_loss = self.d_entropy2(y)
                h2_loss = -self.d_entropy1(y)
                h_loss = 0.5 * h1_loss + h2_loss
                h_loss.backward()
                h_solver.step()
                
                # Init gradients
                self.zerograd()
                _, y = self.q(spec_in)
                cat_semisupervise_loss = bce_loss(y[zero_conc_selector], bce_eps)
                
                cat_semisupervise_loss.backward()
                Cat_solver.step()
                
                # Init gradients
                self.zerograd()
                
                #recon_x
                z, y = self.q(spec_in)
                spec_re = self.p(z, y)
                
                recon_loss = mse_dis(spec_re, spec_in)
                recon_loss.backward()
                RE_solver.step()
                
                # Init gradients
                self.zerograd()
                
                #recon y and d
                y, idx = self.sample_categorical()
                z = torch.randn(self.batch_size, self.nstyle, requires_grad=False, device=self.device)
                target = torch.tensor(idx, dtype=torch.long, requires_grad=False, device=self.device)
                
                X_sample = self.p(z, y)
                z_recon, y_recon = self.q(X_sample)
                
                I_loss = criterionQ_dis(torch.log(y_recon), target) + mse_dis(z_recon[:,:-1], z[:, :-1])
                
                I_loss.backward()
                I_solver.step()
                
                # Init gradients
                self.zerograd()
                
                X_sample = self.p(z, y)
                # smoothed regulation
                X_sample_padded = self.padding4smooth(X_sample.unsqueeze(dim=1))
                spec_smoothed = self.gaussian_smoothing(X_sample_padded).squeeze(dim=1)
                Smooth_loss = mse_dis(X_sample, spec_smoothed)
                
                Smooth_loss.backward()
                Smooth_solver.step()
                
                # Init gradients
                self.zerograd()
                
                # record losses   
                loss_dict = {
                    'recon_loss': recon_loss.item(),
                    'I_loss': I_loss.item(),   
                    'Smooth': Smooth_loss.item()
                }
                self.tb_writer.add_scalars("Recon/train", loss_dict, global_step=epoch)
                loss_dict = {
                    'h_loss': h_loss.item()
                }
                self.tb_writer.add_scalars("CR/train", loss_dict, global_step=epoch)
                if include_transfer_learning:
                    loss_dict = {
                        'domain_loss': domain_loss.item()
                    }
                    self.tb_writer.add_scalars("Domain/train", loss_dict, global_step=epoch)
                loss_dict = { 
                    'cat_loss': cat_semisupervise_loss.item()
                }
                self.tb_writer.add_scalars("Supervise/train", loss_dict, global_step=epoch)
                loss_dict = {
                    'G_loss': G_loss.item()       
                }
                self.tb_writer.add_scalars("Adversarial/train", loss_dict, global_step=epoch)
            
            
            
            self.q.eval()
            self.p.eval()
            self.d.eval()
            if include_transfer_learning:
                self.dann.eval()
            spec_in, cn_in = [torch.cat(x, dim=0) for x in zip(*list(self.cn_val_loader))]
            spec_in = spec_in.to(self.device)
            cn_in = cn_in.to(self.device)
            z, y = self.q(spec_in)
            spec_re = self.p(z, y)
            tw =  cn_in @ self.val_cn_weights
            tw /= tw.sum()
            spec_diff = ((spec_re - spec_in)**2).mean(dim=1)
            recon_loss = (spec_diff * tw).sum()
            loss_dict = {
                'recon_loss': recon_loss.item()
            }
            self.tb_writer.add_scalars("Recon/val", loss_dict, global_step=epoch) 
            
            class_probs = y.detach().cpu().numpy()
            class_pred = class_probs.argmax(axis=-1)//n_subclasses
            class_true = cn_in.detach().cpu().numpy().argmax(axis=-1)
            cat_accuracy = f1_score(class_true, class_pred, average='weighted')
            
            class_sum_pred = class_probs.reshape(class_probs.shape[0], cn_in.size()[1], n_subclasses).sum(axis=1).argmax(axis=-1)
            cat_sum_accuracy = f1_score(class_true, class_sum_pred, average='weighted')
            
            loss_dict = {
                'Max Divid': cat_accuracy,
                'Group Sum': cat_sum_accuracy
            }
            self.tb_writer.add_scalars("F1 Score/val", loss_dict, global_step=epoch)
            
            pure_selector = (cn_in.max(dim=-1).values > 1.0 - self.zero_conc_thresh)
            pure_selector = pure_selector.to(self.device)
            h1_loss = self.d_entropy2(y[pure_selector])
            h2_loss = -self.d_entropy1(y[pure_selector])
            h_loss = 0.5 * h1_loss + h2_loss   
            loss_dict = {
                    'h_loss': h_loss.item(),       
            }
            self.tb_writer.add_scalars("CR/val", loss_dict, global_step=epoch)
            
            zero_conc_selector = (cn_in < self.zero_conc_thresh)
            zero_conc_selector = zero_conc_selector.unsqueeze(dim=2)
            zero_conc_selector = zero_conc_selector.repeat(1, 1, self.nclasses//cn_in.size()[1])
            zero_conc_selector = zero_conc_selector.resize(cn_in.size()[0], self.nclasses)

            zero_conc_selector = zero_conc_selector.to(self.device)
            bce_eps = torch.full_like(zero_conc_selector[zero_conc_selector], fill_value=1.0E-3, dtype=torch.float, device=self.device)
            zero_conc_selector = zero_conc_selector.to(self.device)
            bce_eps = torch.full_like(zero_conc_selector[zero_conc_selector], fill_value=1.0E-3, dtype=torch.float, device=self.device)
            cat_semisupervise_loss = bce_loss(y[zero_conc_selector], bce_eps)
            
            loss_dict = { 
                'cat_loss': cat_semisupervise_loss.item()
            }
            self.tb_writer.add_scalars("Supervise/val", loss_dict, global_step=epoch)
            
            z_fake_gauss = z
            z_real_gauss = torch.randn_like(z, requires_grad=True, device=self.device)

            real_gauss_label = torch.ones(spec_in.size()[0], dtype=torch.long, requires_grad=False, device=self.device)
            real_gauss_pred = self.d(z_real_gauss, alpha)

            fake_guass_lable = torch.zeros(spec_in.size()[0], dtype=torch.long, requires_grad=False, device=self.device)
            fake_gauss_pred = self.d(z_fake_gauss, alpha)

            G_loss = nll_loss(real_gauss_pred, real_gauss_label) + nll_loss(fake_gauss_pred, fake_guass_lable)

            loss_dict = {
                'G_loss': G_loss.item()       
            }
            self.tb_writer.add_scalars("Adversarial/val", loss_dict, global_step=epoch)
            
            model_dict = {"Encoder": self.q, 
                          "Decoder": self.p, 
                          "Style Discriminator": self.d}
            if include_transfer_learning:
                model_dict["Domain Classifier"] = self.dann
            if cat_accuracy > last_best * 1.01:
                chk_fn = f"{chkpt_dir}/epoch_{epoch:06d}_loss_{cat_accuracy:05.4g}.pt"
                torch.save(model_dict, 
                           chk_fn)
                last_best = cat_accuracy
                best_chk = chk_fn
                
            for sch in schedulers:
                sch.step(torch.tensor(last_best))
                
            # plot images
            if epoch % 25 == 0:
                spec_out = self.p(con_c, dis_c).reshape(self.nclasses, self.ntest_per_spectra, -1).clone().cpu().detach().numpy()
                fig = self.get_cluster_plot(spec_out)
                self.tb_writer.add_figure("Spectra", fig, global_step=epoch) 
                
                spec_in, cn_in = [torch.cat(x, dim=0) for x in zip(*list(self.cn_val_loader))]
                spec_in = spec_in.to(self.device)
                cn_in = cn_in.to(self.device)
                z, _ = self.q(spec_in)
                fig = self.get_style_distribution_plot(z.clone().cpu().detach().numpy())
                self.tb_writer.add_figure("Style Value Distribution", fig, global_step=epoch)
            
        
        #save model
        model_dict = {"Encoder": self.q, 
                      "Decoder": self.p, 
                      "Style Discriminator": self.d}
        if include_transfer_learning:
            model_dict["Domain Classifier"] = self.dann
        torch.save(model_dict, 
                   'final.pt')
        if best_chk is not None:
            shutil.copy2(best_chk, 'best.pt')
        

# Training driver: pick the compute device, build the networks, and run the Trainer.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print("Use GPU")
    # Pinned host memory is switched off on these loaders when CUDA is used.
    # NOTE(review): domain_train_feff_loader is not touched here — confirm that is intentional.
    for loader in [cn_train_feff_loader, cn_val_loader]:
        loader.pin_memory = False
    if include_transfer_learning:
        domain_train_xspectra_loader.pin_memory = False
else:
    print("Use Slow CPU!")
    
# igpu comes from the parameter cell and selects which CUDA device is used.
device = torch.device(f"cuda:{igpu}" if use_cuda else "cpu")

# Encoder (Q), decoder (P) and style discriminator (D_gauss); the total number of
# classes is n_coord_num * n_subclasses.
Q = Q_net(nclasses=n_coord_num*n_subclasses, nstyle=nstyle)
P = P_net(nclasses=n_coord_num*n_subclasses, nstyle=nstyle, device=device)
D_gauss = D_net_gauss(nstyle=nstyle)
# The domain classifier is built around the encoder object Q and only exists
# when transfer learning is enabled.
if include_transfer_learning:
    Dann = Domain_classifier_net(Q, nclasses=n_coord_num*n_subclasses, nstyle=nstyle)
else:
    Dann = None

# Move every network that exists to the selected device (Dann may be None).
for i in [Q, P, D_gauss, Dann]:
    if i is not None:
        i.to(device)
        #i.apply(weights_init)

# Without transfer learning the Trainer still takes both domain loaders
# positionally, so pass None placeholders.
if not include_transfer_learning:
    domain_train_feff_loader, domain_train_xspectra_loader = None, None

trainer = Trainer(Q, P, D_gauss, Dann, device, 
                  cn_train_feff_loader, cn_val_loader, domain_train_feff_loader, domain_train_xspectra_loader,
                  val_cn_weights=cn_sampling_weights, nclasses=n_coord_num*n_subclasses, nstyle=nstyle,
                  max_epoch=max_epoch, base_lr=lr)

# Runs the full training loop; Trainer.train writes 'final.pt' and, when a best
# checkpoint was recorded, 'best.pt'.
trainer.train()

In [None]:
# Print PyTorch's parallel-backend report again now that training has finished
# (the same report is printed near the top of the notebook).
para_info = torch.__config__.parallel_info()
print(para_info)

In [None]:
def cluster_grid_plot(decoder, nspec_pc=10):
    """Plot how each class's decoded spectrum responds to every style dimension.

    For each style index ``istyle`` the decoder is evaluated for every class
    (one-hot class code) at ``nspec_pc`` evenly spaced values of that style in
    [-1, 1], with all other style dimensions held at zero. One figure per
    style is saved as ``reports/All Classes and Styles #<istyle>.pdf``; each
    panel shows one class and the line colour encodes the style value.

    Args:
        decoder: network called as ``decoder(style, class_code)`` that exposes
            an integer ``nstyle`` attribute. The inputs are created on the
            CPU, so the decoder is expected to accept CPU tensors.
        nspec_pc: number of style values (spectra) drawn per class.
    """
    decoder.eval()
    nclasses = n_coord_num * n_subclasses

    # Row r of the decoder input belongs to class r // nspec_pc.
    class_idx = np.arange(nclasses).repeat(nspec_pc)
    one_hot = np.zeros((nclasses * nspec_pc, nclasses))
    one_hot[np.arange(nclasses * nspec_pc), class_idx] = 1

    # Every class gets the same monotonic sweep of nspec_pc style values, laid
    # out class by class so it lines up with class_idx and with the colour ramp
    # below. This is the same construction Trainer.train uses for its test
    # grid; building it with the two sizes swapped would shift the values
    # cyclically inside each class.
    sweep = np.linspace(-1, 1, nspec_pc).reshape(1, -1)
    sweep = np.repeat(sweep, nclasses, 0).reshape(-1, 1)
    zeros = np.zeros_like(sweep)

    colors = sns.color_palette("coolwarm", nspec_pc)
    os.makedirs("reports", exist_ok=True)

    for istyle in range(decoder.nstyle):
        # Only column `istyle` varies; all other style columns stay at zero.
        c2 = np.hstack([zeros] * istyle + [sweep] + [zeros] * (decoder.nstyle - istyle - 1))

        dis_c = torch.tensor(one_hot, dtype=torch.float, requires_grad=False)
        con_c = torch.tensor(c2, dtype=torch.float, requires_grad=False)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            spec_out = decoder(con_c, dis_c)
        spec_out = spec_out.reshape(nclasses, nspec_pc, -1).detach().cpu().numpy()

        # squeeze=False keeps ax_list two-dimensional even for a 1xN or Nx1 grid.
        fig, ax_list = plt.subplots(n_coord_num, n_subclasses, sharex=True, sharey=True,
                                    figsize=(9, 12), squeeze=False)
        for i, (sl, ax) in enumerate(zip(spec_out, ax_list.ravel())):
            for spec, color in zip(sl, colors):
                ax.plot(spec, lw=1.5, c=color)
            # Label the first column with the coordination group.
            # NOTE(review): the offset here is 3 while compute_confusion_matrix
            # labels the same groups starting at CN4 — confirm which is right.
            if i % n_subclasses == 0:
                ax.set_ylabel(f"{i//n_subclasses + 3} Folds Coordinated")
            # Label the bottom row with the subclass number.
            if i >= (n_coord_num - 1) * n_subclasses:
                ax.set_xlabel(f"Subclass {i % n_subclasses + 1}")
        title = f"All Classes and Styles #{istyle}"
        fig.suptitle(title, y=0.91)
        fig.savefig(f"reports/{title}.pdf", dpi=600)


# Reload the checkpoint written at the end of Trainer.train onto the CPU and
# plot the style sweeps of its decoder.
# NOTE(review): 'final.pt' holds whole pickled network objects rather than state
# dicts, so only load files you trust; on PyTorch versions where torch.load
# defaults to weights_only=True this call may need weights_only=False — confirm
# against the installed version.
final_spuncat = torch.load('final.pt', map_location=torch.device('cpu')) 
cluster_grid_plot(final_spuncat["Decoder"])

In [None]:
def plot_cluster_centers(p, nclasses=n_coord_num*n_subclasses, title='final cluster center'):
    """Draw the decoder output of every class at the all-zero style vector.

    The decoder ``p`` is fed one one-hot class code per class together with a
    zero style vector; the resulting spectra are overlaid in one figure, which
    is saved as ``reports/<title>.pdf``.

    Args:
        p: decoder called as ``p(style, class_code)`` with an ``nstyle`` attribute.
        nclasses: number of classes, i.e. the width of the one-hot code.
        title: figure title, also used as the output file name.
    """
    p.eval()

    # One row per class: identity matrix as class codes, zeros as style.
    class_codes = torch.eye(nclasses)
    zero_style = torch.zeros(nclasses, p.nstyle)
    centers = p(zero_style, class_codes)
    centers = centers.cpu().detach().numpy()

    plt.figure()
    # Set the colour cycle before the first plot call creates the axes.
    sns.set_palette('husl', nclasses)
    for center in centers:
        plt.plot(center)
    plt.title(title)

    report_dir = "reports"
    if not os.path.exists(report_dir):
        os.makedirs(report_dir)
    plt.savefig(f'reports/{title}.pdf', dpi=300)
    
# Compare the class centres of the last-epoch model with those of the best
# checkpoint. 'best.pt' is only written by Trainer.train when a best checkpoint
# was recorded, so this load fails if none was.
best_spuncat = torch.load('best.pt', map_location=torch.device('cpu'))  
plot_cluster_centers(final_spuncat["Decoder"], title='final cluster center')
plot_cluster_centers(best_spuncat["Decoder"], title='best cluster center')

In [None]:
def compute_confusion_matrix(encoder, ds, title_base):
    """Score the encoder's coordination-number predictions on a dataset and plot them.

    Every sample of ``ds`` is encoded in a single batch. The predicted
    coordination group is the arg-max class index divided by ``n_subclasses``;
    the true group is the arg-max of the sample's target vector. The confusion
    matrix is drawn as a heat map and saved as ``reports/<title_base>.pdf``.

    Args:
        encoder: network called as ``encoder(spectra)`` that returns a pair
            whose second element holds the per-class scores.
        ds: iterable of ``(spectrum, target)`` tensor pairs.
        title_base: plot title prefix and output file name.

    Returns:
        Tuple ``(cm, cat_accuracy)``: the ``n_coord_num x n_coord_num``
        confusion matrix and the weighted F1 score.
    """
    encoder.eval()
    spec_in, cn_in = [torch.stack(x, dim=0) for x in zip(*list(ds))]
    # Pure inference over the whole dataset: skip building the autograd graph.
    with torch.no_grad():
        _, y = encoder(spec_in)
    class_probs = y.detach().cpu().numpy()
    # Subclasses of one coordination group occupy n_subclasses consecutive columns.
    class_pred = class_probs.argmax(axis=-1)//n_subclasses
    class_true = cn_in.detach().cpu().numpy().argmax(axis=-1)
    cat_accuracy = f1_score(class_true, class_pred, average='weighted')
    # Fix the label set so the matrix always has one row/column per coordination
    # group, even when a group is absent from this split; otherwise the matrix
    # would shrink and no longer line up with cn_labels.
    cm = confusion_matrix(class_true, class_pred, labels=list(range(n_coord_num)))
    cn_labels = [f'CN{i}' for i in range(4, 4+n_coord_num)]
    plt.figure()
    sns.set(font_scale=1.4)
    sns.heatmap(cm, xticklabels=cn_labels, yticklabels=cn_labels, cmap='Blues', fmt='d', annot=True, lw=1.0)
    title = f'{title_base} with F1 Score at {cat_accuracy:.2%}'
    plt.title(title)
    plt.xlabel("Prediction")
    plt.ylabel("True Value")
    os.makedirs("reports", exist_ok=True)
    # The file name uses title_base only, so the score does not end up in the path.
    plt.savefig(f'reports/{title_base}.pdf', dpi=300, bbox_inches='tight')
    return cm, cat_accuracy

# Confusion matrices of the final encoder on the three FEFF splits.
for _feff_split, _split_title in (
        (dataset_feff_train, "Confusion Matrix on FEFF Training Set"),
        (dataset_feff_val, "Confusion Matrix on FEFF Validation Set"),
        (dataset_feff_test, "Confusion Matrix on FEFF Test Set"),
):
    compute_confusion_matrix(final_spuncat["Encoder"], _feff_split, title_base=_split_title)

In [None]:
# The XSpectra test set is only evaluated when transfer learning is enabled.
# The trailing ';' keeps the returned (matrix, score) tuple out of the cell output.
if include_transfer_learning:
    compute_confusion_matrix(final_spuncat["Encoder"], dataset_xspectra_test, title_base="Confusion Matrix on XSpectra Target Set");