In [1]:
from C_0E_run import *

In [2]:
# Identify the training run this notebook works on: the run name plus a tag
# (presumably a timestamp — TODO confirm) form the run folder under `model_save_`.
train_name, ts = "C_0E", "0526172101-1"
model_save_dir = os.path.join(
    model_save_,
    f"{train_name}-{ts}",
)

In [3]:
# Root directory of the selected run; per-model / per-condition outputs live below it.
# It is derived from the run identifiers instead of being copied from
# `model_save_dir`, because a later cell rebinds `model_save_dir` to a sub-folder
# of this directory. Copying it would make this cell non-idempotent: re-running it
# after that cell would nest the path one level deeper each time.
hyper_dir = os.path.join(model_save_, f"{train_name}-{ts}")
model_type = "mtl"    # selects the model branch built further down
condition = "u"       # last path component of the output directory

In [4]:
# All artefacts of this notebook are written to <hyper_dir>/<model_type>/<condition>.
model_save_dir = os.path.join(hyper_dir, model_type, condition)
mk(model_save_dir)

# Loss Recording
train_losses = ListRecorder(os.path.join(model_save_dir, "train.loss"))
train_recon_losses = ListRecorder(os.path.join(model_save_dir, "train.recon.loss"))
train_embedding_losses = ListRecorder(os.path.join(model_save_dir, "train.embedding.loss"))
train_commitment_losses = ListRecorder(os.path.join(model_save_dir, "train.commitment.loss"))

valid_losses = ListRecorder(os.path.join(model_save_dir, "valid.loss"))
valid_recon_losses = ListRecorder(os.path.join(model_save_dir, "valid.recon.loss"))
valid_embedding_losses = ListRecorder(os.path.join(model_save_dir, "valid.embedding.loss"))
valid_commitment_losses = ListRecorder(os.path.join(model_save_dir, "valid.commitment.loss"))

# In C we take onlyST to record the phenomenon-target dataset
onlyST_valid_losses = ListRecorder(os.path.join(model_save_dir, "valid_onlyST.loss"))
onlyST_valid_recon_losses = ListRecorder(os.path.join(model_save_dir, "valid_onlyST.recon.loss"))
onlyST_valid_embedding_losses = ListRecorder(os.path.join(model_save_dir, "valid_onlyST.embedding.loss"))
onlyST_valid_commitment_losses = ListRecorder(os.path.join(model_save_dir, "valid_onlyST.commitment.loss"))

text_hist = HistRecorder(os.path.join(model_save_dir, "trainhist.txt"))

# Recording Directory
phone_rec_dir = train_cut_phone_
word_rec_dir = train_cut_word_
train_guide_path = os.path.join(src_, "guide_train.csv")
valid_guide_path = os.path.join(src_, "guide_validation.csv")

# Load TokenMap to map the phoneme to the index
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a dictionary file that comes from a trusted source.
with open(os.path.join(src_, "no-stress-seg.dict"), "rb") as file:
    mylist = pickle.load(file)
# The file is only needed for loading, so the token list is extended after the
# handle has been closed. "BLANK" is prepended (it is the CTC blank token below)
# and "SIL" is appended.
mylist = ["BLANK"] + mylist + ["SIL"]   # this is to fit STV vs #TV

# Build the token <-> index map; its size is the class dimension of the CTC head.
mymap = TokenMap(mylist)
class_dim = mymap.token_num()
ctc_size_list = {'hid': INTER_DIM_2, 'class': class_dim}

# Initialize Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
masked_loss = MaskedLoss(loss_fn=nn.MSELoss(reduction="none"))
ctc_loss = nn.CTCLoss(blank=mymap.encode("BLANK"))
# NOTE(review): AlphaCombineLoss is also defined in a later cell of this notebook.
# On a fresh kernel this line only works if the star import already provides the
# name — confirm, or move the class definition above this cell.
model_loss = AlphaCombineLoss(masked_loss, ctc_loss, alpha=0.2)
if model_type == "mtl":
    model = AEPPV1(enc_size_list=ENC_SIZE_LIST, 
                dec_size_list=DEC_SIZE_LIST, 
                ctc_decoder_size_list=ctc_size_list,
                num_layers=NUM_LAYERS, dropout=DROPOUT)
else: 
    # ValueError is the conventional type for a bad configuration value and is
    # still caught by any `except Exception` handler.
    raise ValueError("Model type not supported! ")

model.to(device)
initialize_model(model)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Keep a textual description of the architecture next to the other outputs.
model_str = str(model)
model_txt_path = os.path.join(model_save_dir, "model.txt")
with open(model_txt_path, "w") as f:
    f.write(model_str)
    f.write("\n")

In [5]:
# Load Data
# Guide files for the phenomenon loader are read from the run directory.
guide_path = os.path.join(hyper_dir, "guides")
# train_loader = load_data_general(TrainDataset, 
#                                     word_rec_dir, train_guide_path, load="train", select=0.3, sampled=False)

# Validation loader built from the word recordings and the validation guide.
valid_loader = load_data_general(
    TrainDataset,
    word_rec_dir,
    valid_guide_path,
    load="valid",
    select=0.3,
    sampled=False,
)

# Phenomenon-target ("onlyST") validation loader built from the phone recordings.
onlyST_valid_loader = load_data_phenomenon(
    TestDataset,
    phone_rec_dir,
    guide_path,
    load="valid",
    select="both",
    sampled=True,
    word_guide_=valid_guide_path,
)

# Epoch count and the weights applied to the embedding / commitment loss terms.
num_epochs, l_w_embedding, l_w_commitment = 100, 1, 0.25

In [6]:
class AlphaCombineLoss:
    """Combine a reconstruction loss with an alpha-weighted prediction loss.

    The total is ``reconstruction + alpha * prediction``.

    Args:
        recon_loss: object exposing ``get_loss(y_hat, y, mask)``.
        pred_loss: callable invoked as
            ``pred_loss(y_hat_pred, y_pred, x_lens, y_pred_lens)``.
        alpha: weight applied to the prediction loss before it is added.
    """

    def __init__(self, recon_loss, pred_loss, alpha=0.1):
        self.recon_loss, self.pred_loss, self.alpha = recon_loss, pred_loss, alpha

    def get_loss(self, y_hat_recon, y_recon, y_hat_pred, y_pred, x_lens, y_pred_lens, mask):
        """Return ``(total, (reconstruction, prediction))``.

        ``total`` is the reconstruction loss plus ``alpha`` times the prediction
        loss; the tuple carries the two unweighted components.
        """
        recon_part = self.recon_loss.get_loss(y_hat_recon, y_recon, mask)
        pred_part = self.pred_loss(y_hat_pred, y_pred, x_lens, y_pred_lens)

        # Only the prediction part is scaled; the reconstruction part enters as is.
        combined = recon_part + self.alpha * pred_part
        return combined, (recon_part, pred_part)

In [7]:
# for epoch in range(num_epochs):
#     text_hist.print("Epoch {}".format(epoch))
#     model.train()
#     train_loss = 0.
#     train_cumulative_l_reconstruct = 0.
#     train_cumulative_l_embedding = 0.
#     train_cumulative_l_commitment = 0.
#     train_num = len(train_loader.dataset)    # train_loader
#     for idx, (x, x_lens, y_preds, y_preds_lens) in enumerate(train_loader):
#         current_batch_size = x.shape[0]
#         # y_lens should be the same as x_lens
#         optimizer.zero_grad()
#         x_mask = generate_mask_from_lengths_mat(x_lens, device=device)
#         y_recon = x
#         x = x.to(device)
#         y_recon = y_recon.to(device)
#         y_preds = y_preds.to(device)
#         y_preds = y_preds.long()

#         (x_hat_recon, y_hat_preds), (attn_w_recon, attn_w_preds), (ze, zq) = model(x, x_lens, x_mask)
#         y_hat_preds = y_hat_preds.permute(1, 0, 2)

        

#         l_alpha, (l_reconstruct, l_prediction) = model_loss.get_loss(x_hat_recon, y_recon, 
#                                                                         y_hat_preds, y_preds, 
#                                                                         x_lens, y_preds_lens, 
#                                                                         x_mask)
#         if model_type == "vqvae":
#             l_embedding = model_loss.get_loss(ze.detach(), zq, x_mask)
#             l_commitment = model_loss.get_loss(ze, zq.detach(), x_mask)
#             loss = l_alpha + \
#                 l_w_embedding * l_embedding + l_w_commitment * l_commitment
#         elif model_type == "mtl":
#             l_embedding = l_prediction
#             l_commitment = l_prediction
#             loss = l_alpha
#         else: 
#             l_embedding = torch.tensor(0)
#             l_commitment = torch.tensor(0)
#             loss = l_alpha

#         loss.backward()
#         optimizer.step()

#         train_loss += loss.item() * current_batch_size
#         train_cumulative_l_reconstruct += l_reconstruct.item() * current_batch_size
#         train_cumulative_l_embedding += l_embedding.item() * current_batch_size
#         train_cumulative_l_commitment += l_commitment.item() * current_batch_size

#         if idx % 100 == 0:
#             text_hist.print(f"""Training step {idx} loss {loss: .3f} \t recon {l_reconstruct: .3f} \t embed {l_embedding: .3f} \t commit {l_commitment: .3f}""")

In [8]:
# # Valid (ST + T)
# model.eval()
# valid_loss = 0.
# valid_cumulative_l_reconstruct = 0.
# valid_cumulative_l_embedding = 0.
# valid_cumulative_l_commitment = 0.
# valid_num = len(valid_loader.dataset)
# for idx, (x, x_lens, y_preds, y_preds_lens) in enumerate(valid_loader):
#     current_batch_size = x.shape[0]
#     x_mask = generate_mask_from_lengths_mat(x_lens, device=device)

#     y_recon = x
#     x = x.to(device)
#     y_recon = y_recon.to(device)
#     y_preds = y_preds.to(device)
#     y_preds = y_preds.long()

#     (x_hat_recon, y_hat_preds), (attn_w_recon, attn_w_preds), (ze, zq) = model(x, x_lens, x_mask)
#     y_hat_preds = y_hat_preds.permute(1, 0, 2)

#     l_alpha, (l_reconstruct, l_prediction) = model_loss.get_loss(x_hat_recon, y_recon, 
#                                                                     y_hat_preds, y_preds, 
#                                                                     x_lens, y_preds_lens, 
#                                                                     x_mask)
#     if model_type == "vqvae":
#         l_embedding = model_loss.get_loss(ze.detach(), zq, x_mask)
#         l_commitment = model_loss.get_loss(ze, zq.detach(), x_mask)
#         loss = l_alpha + \
#             l_w_embedding * l_embedding + l_w_commitment * l_commitment
#     elif model_type == "mtl":
#         l_embedding = l_prediction
#         l_commitment = l_prediction
#         loss = l_alpha
#     else: 
#         l_embedding = torch.tensor(0)
#         l_commitment = torch.tensor(0)
#         loss = l_alpha

#     valid_loss += loss.item() * current_batch_size
#     valid_cumulative_l_reconstruct += l_reconstruct.item() * current_batch_size
#     valid_cumulative_l_embedding += l_embedding.item() * current_batch_size
#     valid_cumulative_l_commitment += l_commitment.item() * current_batch_size

In [10]:
# Validation pass over the phenomenon-target ("onlyST") loader: accumulate
# batch-size-weighted sums of the total loss and of its components.
model.eval()
onlyST_valid_loss = 0.
onlyST_valid_cumulative_l_reconstruct = 0.
onlyST_valid_cumulative_l_embedding = 0.
onlyST_valid_cumulative_l_commitment = 0.
onlyST_valid_num = len(onlyST_valid_loader.dataset)
for idx, ((x, y_preds), (x_lens, y_preds_lens), pt, sn) in enumerate(onlyST_valid_loader):
    current_batch_size = x.shape[0]

    # Evaluation only: no backward pass follows, so disable autograd for the
    # forward pass and the loss computation to avoid building a graph.
    with torch.no_grad():
        x_mask = generate_mask_from_lengths_mat(x_lens, device=device)

        # The reconstruction target is the input itself.
        y_recon = x
        x = x.to(device)
        y_recon = y_recon.to(device)
        y_preds = y_preds.to(device)
        y_preds = y_preds.long()

        (x_hat_recon, y_hat_preds), (attn_w_recon, attn_w_preds), (ze, zq) = model(x, x_lens, x_mask)
        # Swap the first two axes before the CTC loss
        # (presumably batch-first -> time-first — TODO confirm against the model).
        y_hat_preds = y_hat_preds.permute(1, 0, 2)

        l_alpha, (l_reconstruct, l_prediction) = model_loss.get_loss(x_hat_recon, y_recon, 
                                                                        y_hat_preds, y_preds, 
                                                                        x_lens, y_preds_lens, 
                                                                        x_mask)
        if model_type == "vqvae":
            l_embedding = model_loss.get_loss(ze.detach(), zq, x_mask)
            l_commitment = model_loss.get_loss(ze, zq.detach(), x_mask)
            loss = l_alpha + \
                l_w_embedding * l_embedding + l_w_commitment * l_commitment
        elif model_type == "mtl":
            # For the multi-task model the prediction loss is stored in both the
            # "embedding" and the "commitment" slots.
            l_embedding = l_prediction
            l_commitment = l_prediction
            loss = l_alpha
        else: 
            l_embedding = torch.tensor(0)
            l_commitment = torch.tensor(0)
            loss = l_alpha

    # Weight each batch by its size when accumulating.
    onlyST_valid_loss += loss.item() * current_batch_size
    onlyST_valid_cumulative_l_reconstruct += l_reconstruct.item() * current_batch_size
    onlyST_valid_cumulative_l_embedding += l_embedding.item() * current_batch_size
    onlyST_valid_cumulative_l_commitment += l_commitment.item() * current_batch_size