In [39]:
from scipy.io import loadmat
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, KFold
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import CTCLoss
# from torch.nn import CTCBeamSearchDecoder
from torch.autograd import Variable
# from ctcdecode import CTCBeamDecoder
import multiprocessing as mp
from torchaudio.models.decoder import ctc_decoder
from random import randint, sample
import random
import csv
import numpy as np
import re
import math
import h5py
import nltk
from collections import defaultdict

def split_data_based_on_labels(data, priority="test"):
    """Split the rows of ``data`` by whether their label has been seen before.

    The label of row ``i`` is ``tuple(data[i, 0][0])``.  Rows are scanned in
    order; the first row carrying a given label is a "first occurrence" and
    every later row with the same label is a "repeat".

    priority: "test"  -> first occurrences go to train, repeats go to test.
              anything else (e.g. "train") -> first occurrences go to test,
              repeats go to train (n-1 duplicates in training, 1 in testing).

    Returns ``(train_rows, test_rows)``, both indexed out of ``data``.
    """
    first_occurrences = []
    repeats = []
    known_labels = set()

    for row in range(data.shape[0]):
        key = tuple(data[row, 0][0])
        if key in known_labels:
            repeats.append(row)
        else:
            known_labels.add(key)
            first_occurrences.append(row)

    if priority == "test":
        train_indices, test_indices = first_occurrences, repeats
    else:
        train_indices, test_indices = repeats, first_occurrences
    print(f"Train indices: {train_indices}, Test indices: {test_indices}")

    return data[train_indices, :], data[test_indices, :]

class SequenceDataset(Dataset):
    """Minimal ``Dataset`` that exposes an indexable collection unchanged."""

    def __init__(self, data):
        # Anything supporting len() and integer indexing can be wrapped.
        self.data = data

    def __len__(self):
        """Number of samples in the wrapped collection."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

# Define a function to process the data
def _pad_time_axis(array, target_length):
    """Zero-pad ``array`` at the end of axis 0 so it has ``target_length`` rows.

    All other axes are left untouched.  ``np.pad`` raises ``ValueError`` when
    the array is already longer than ``target_length``.
    """
    pad_width = [(0, target_length - array.shape[0])] + [(0, 0)] * (array.ndim - 1)
    return np.pad(array, pad_width, 'constant', constant_values=0)


def process_data(data, max_lab_length, max_seq_length, trainortest="test", need_pad=True, label_class = 26, augment_times=20):
    """Turn raw rows into ``(label, sequence, label_length, sequence_length)`` samples.

    Each row ``i`` of ``data`` is read as:
      * ``data[i, 0][0]`` - the numeric label sequence,
      * ``data[i, 1][0]`` - the word as text (used only for the lexicon),
      * ``data[i, 2]``    - the feature sequence (time on axis 0),
      * ``data[i, 3:]``   - augmented variants of the sequence (training only).

    Args:
        data: 2-D object array laid out as described above.
        max_lab_length: length labels are zero-padded to when ``need_pad``.
        max_seq_length: length sequences are zero-padded to when ``need_pad``.
        trainortest: "train" also emits augmented samples; "test" also writes
            ``./lexicon.txt`` with the words that were kept.
        need_pad: whether to zero-pad labels and sequences.
        label_class: largest valid label value; values outside
            ``1..label_class`` are dropped from the label.
        augment_times: in training, every row yields samples for columns
            ``3 .. max(data.shape[1], augment_times) - 1``; missing or empty
            columns fall back to a copy of the original sequence.

    Returns:
        ``(processed_data, unique_labels)`` where ``processed_data`` is the list
        of sample tuples and ``unique_labels`` is the set of (padded) labels as
        tuples.  The recorded ``sequence_length`` is always the number of frames
        BEFORE padding, for augmented samples as well as the original.
    """
    processed_data = []
    lexicon = set()
    unique_labels = set()
    for i in range(data.shape[0]):
        label = data[i, 0][0]
        # skip if the label is not numeric
        if any(isinstance(x, str) for x in label):
            continue
        # drop label elements that fall outside 1..label_class
        if not all(1 <= x <= label_class for x in label):
            label = [x for x in label if 1 <= x <= label_class]
            print(f"Label {i} is not all between 1 and {label_class}, the label is: {data[i, 1]}, now changed to: {label}")

        sequence = data[i, 2]
        # skip if the sequence is empty
        if sequence.shape[0] == 0:
            continue
        original_label_length = len(label)
        original_length = sequence.shape[0]
        if original_length > max_seq_length:
            print(f"Sequence {i} is longer than max length: {original_length}")
            # NOTE(review): this stops processing ALL remaining rows, not just
            # this one.  Kept as-is; confirm whether `continue` was intended.
            break

        # Pad the sequence and label
        if need_pad:
            sequence = _pad_time_axis(sequence, max_seq_length)
            label = np.pad(label, (0, max_lab_length - len(label)), 'constant', constant_values=0)

        sequence = sequence.astype(np.float32)  # Convert sequence to float32

        # If training data, add augmented sequences
        if trainortest == "train":
            for j in range(3, max(data.shape[1], augment_times)):
                candidate = data[i, j] if j < data.shape[1] else None
                if candidate is None or candidate.shape[0] == 0:
                    # No augmentation available: reuse the original sequence
                    # (copied so samples never share a buffer) and its length.
                    augmented_sequence = sequence.copy()
                    augmented_length = original_length
                else:
                    # Record the true length BEFORE padding; measuring it after
                    # padding would always report max_seq_length.
                    augmented_length = candidate.shape[0]
                    augmented_sequence = candidate
                    if need_pad:
                        augmented_sequence = _pad_time_axis(augmented_sequence, max_seq_length)
                    augmented_sequence = augmented_sequence.astype(np.float32)
                processed_data.append((label, augmented_sequence, original_label_length, augmented_length))

        processed_data.append((label, sequence, original_label_length, original_length))

        # Add words to lexicon (letters only, lower-cased)
        word = data[i, 1][0]
        word = re.sub('[^a-zA-Z]', '', word).lower()
        lexicon.add(word)

        # Add labels to unique_labels
        unique_labels.add(tuple(label))

    # Write to lexicon.txt if testing data: "<word> <letters separated by spaces> |"
    if trainortest == "test":
        with open('./lexicon.txt', 'w') as f:
            for word in sorted(lexicon):
                f.write(f"{word} {' '.join(list(word))} |\n")
        print(f"Lexicon is written, its size is: {len(lexicon)}")

    return processed_data, unique_labels

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps along the final axis of a 3-D tensor.

    Used after a convolution that was padded on both sides so that the output
    keeps the same length as the input.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its last ``chomp_size`` entries on axis 2."""
        # `x[..., :-0]` is an EMPTY slice, so a zero chomp (e.g. kernel_size=1,
        # hence padding=0) must be handled explicitly as "remove nothing".
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()

class TemporalBlock(nn.Module):
    """Residual block of two dilated 1-D convolutions.

    Each stage is Conv1d -> Chomp1d -> ReLU -> Dropout.  The block output is
    ``relu(net(x) + residual + epsilon)``, where the residual is ``x`` itself
    or, when the channel counts differ, a 1x1 convolution of ``x``.

    Note: ``layernorm1`` / ``layernorm2`` are created (and therefore show up
    in the module's parameters) but are not part of ``self.net`` and are not
    called in ``forward``.
    """

    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        conv_options = dict(stride=stride, padding=padding, dilation=dilation)

        # --- stage 1 (plain convolution; weight_norm is no longer applied) ---
        self.conv1 = nn.Conv1d(n_inputs, n_outputs, kernel_size, **conv_options)
        self.layernorm1 = nn.LayerNorm(n_outputs, eps=1e-4)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        # --- stage 2 ---
        self.conv2 = nn.Conv1d(n_outputs, n_outputs, kernel_size, **conv_options)
        self.layernorm2 = nn.LayerNorm(n_outputs, eps=1e-4)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        stages = [
            self.conv1, self.chomp1, self.relu1, self.dropout1,
            self.conv2, self.chomp2, self.relu2, self.dropout2,
        ]
        self.net = nn.Sequential(*stages)

        # A 1x1 convolution matches the channel count on the residual path
        # only when the block changes the number of channels.
        if n_inputs != n_outputs:
            self.downsample = nn.Conv1d(n_inputs, n_outputs, 1)
        else:
            self.downsample = None
        self.relu = nn.ReLU()
        self.epsilon = 1e-5  # constant offset added before the final ReLU
        self.init_weights()

    def init_weights(self):
        """Re-draw all convolution weights from N(0, 0.01)."""
        for conv in (self.conv1, self.conv2, self.downsample):
            if conv is not None:
                conv.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Apply both stages and add the (possibly projected) residual."""
        out = self.net(x)
        if self.downsample is None:
            res = x
        else:
            res = self.downsample(x)
        return self.relu(out + res + self.epsilon)

class TemporalConvNet(nn.Module):
    """Stack of ``TemporalBlock``s whose dilation doubles at every level.

    ``num_channels[k]`` is the output width of level ``k``; level ``k`` uses
    dilation ``2**k`` and padding ``(kernel_size - 1) * 2**k`` with stride 1.
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super().__init__()
        blocks = []
        for level in range(len(num_channels)):
            dilation = 2 ** level
            # The first level consumes the raw inputs, later ones the previous level.
            fan_in = num_channels[level - 1] if level > 0 else num_inputs
            fan_out = num_channels[level]
            blocks.append(
                TemporalBlock(fan_in, fan_out, kernel_size, stride=1, dilation=dilation,
                              padding=(kernel_size - 1) * dilation, dropout=dropout)
            )

        self.network = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every level in order."""
        return self.network(x)

# Modify the TCN model
class TCN(nn.Module):
    """Per-frame temporal conv net followed by a linear classifier.

    Every frame of every sequence is pushed through the ``TemporalConvNet``
    independently; the flattened result is then mapped to ``output_size``
    scores by a linear layer.

    Args:
        input_size: number of input channels seen by the first conv layer
            (1 for 3-D inputs, the trailing axis size for 4-D inputs).
        output_size: number of scores produced per frame.
        num_channels: output widths of the successive temporal blocks.
        kernel_size: convolution kernel size of every block.
        dropout: dropout probability of every block.
        lstm_hidden: unused; kept so existing call sites keep working.
        feature_len: per-frame length of the axis the convolutions run over.
            The linear layer expects ``num_channels[-1] * feature_len`` inputs.
            Defaults to 100, the value that used to be hard-coded.
    """

    def __init__(self, input_size, output_size, num_channels, kernel_size=2, dropout=0.2, lstm_hidden=64, feature_len=100):
        super(TCN, self).__init__()
        self.tcn = TemporalConvNet(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(num_channels[-1] * feature_len, output_size)

    def forward(self, inputs):
        """Score every frame of a batch.

        ``inputs`` is either ``(batch_size, seq_length, features)`` or
        ``(batch_size, seq_length, features, channels)``.

        Returns ``(o, y)``: ``o`` has shape
        ``(batch_size, seq_length, output_size)`` and ``y`` is the raw output
        of the temporal conv net, shaped ``(batch_size * seq_length, C, L)``.

        Raises:
            ValueError: if ``inputs`` is neither 3-D nor 4-D.
        """
        if inputs.dim() == 3:
            batch_size, seq_length, _ = inputs.shape
            # Fold batch and time together; each frame becomes a 1-channel signal.
            # reshape (not view) so non-contiguous inputs are accepted as well.
            inputs = inputs.reshape(batch_size * seq_length, 1, -1)
        elif inputs.dim() == 4:
            batch_size, seq_length, _, input_chl = inputs.shape
            # Move channels in front of features: (batch * seq, channels, features)
            inputs = inputs.permute(0, 1, 3, 2).contiguous().view(batch_size * seq_length, input_chl, -1)
        else:
            raise ValueError(
                f"TCN expects a 3-D or 4-D input tensor, got shape {tuple(inputs.shape)}"
            )

        y = self.tcn(inputs)  # input should have dimension (N, C, L)
        # Unfold back to (batch, seq, C * L) before the per-frame classifier.
        unfolded_y = y.reshape(batch_size, seq_length, -1)
        o = self.linear(unfolded_y)

        return o, y

# Define the training function
def train_model(model, train_loader, test_loader, num_epochs, optimizer, ctc_loss, scheduler=None, test_lexicon=None, output_size=27):
    """Train ``model`` with CTC loss and evaluate it after every epoch.

    Relies on the module-level ``device`` for tensor placement.

    Args:
        model: called as ``model(sequences)``; must return ``(outputs, extra)``
            where ``outputs`` is ``(batch, seq, classes)`` or a flattened
            ``(batch * seq, classes)``.
        train_loader / test_loader: yield
            ``(labels, sequences, label_lengths, sequence_lengths)`` batches.
        num_epochs: number of passes over ``train_loader``.
        optimizer: optimizer stepping ``model``'s parameters.
        ctc_loss: callable with the ``torch.nn.CTCLoss`` call signature.
        scheduler: optional LR scheduler, stepped once per epoch from the
            second epoch on (``ReduceLROnPlateau`` receives the mean train loss).
        test_lexicon: optional collection of zero-padded label sequences.  When
            given, every test sample is scored against every lexicon entry and
            the rank of its true label is recorded.
        output_size: unused; kept so existing call sites keep working.

    Returns:
        ``(model, rank_of_all)`` - ``rank_of_all`` is the flat list of 0-based
        ranks from the most recent evaluation (empty if none were computed).
    """
    # Token inventory for the lexicon decoder: '-' (blank), 'a'..'z', '|' (separator).
    labelchar = ['-'] + [chr(code) for code in range(ord('a'), ord('z') + 1)] + ['|']
    eps = 1e-6
    # The beam-search decoder is still constructed (this needs ./lexicon.txt to
    # exist) although the decoding call itself is currently disabled.
    decoder = ctc_decoder(lexicon="./lexicon.txt",tokens = labelchar, beam_size_token=500, blank_token=labelchar[0], sil_token=labelchar[27])

    torch.autograd.set_detect_anomaly(False)

    def _trim_padding(word):
        """Cut a lexicon entry at its first 0 (the padding value)."""
        values = word.tolist()
        return word[:values.index(0)] if 0 in values else word

    # Convert and trim the lexicon once, on every device (not only with CUDA),
    # instead of re-trimming every entry for every test sample.
    lexicon_words = None
    lexicon_lengths = None
    if test_lexicon is not None:
        if not torch.is_tensor(test_lexicon):
            test_lexicon = torch.IntTensor(list(test_lexicon))
        test_lexicon = test_lexicon.to(device)
        lexicon_words = [_trim_padding(word) for word in test_lexicon]
        lexicon_lengths = [torch.tensor([len(word)]).to(device) for word in lexicon_words]

    def _rank_in_lexicon(log_probs, input_length, true_label, disable_cudnn):
        """Return the 0-based rank of ``true_label`` by CTC loss, or None if absent."""
        losses = []
        true_position = None
        cudnn_was_enabled = torch.backends.cudnn.enabled
        if disable_cudnn:
            torch.backends.cudnn.enabled = False
        try:
            for word, word_length in zip(lexicon_words, lexicon_lengths):
                if torch.equal(word, true_label):
                    if true_position is not None:
                        # A duplicate of the true label is reported and not scored.
                        print("Repeat label in the lexicon!")
                        continue
                    true_position = len(losses)
                candidate_loss = ctc_loss(log_probs, word, input_length, word_length).item()
                # NaN losses are pushed to the bottom of the ranking.
                losses.append(1e5 if math.isnan(candidate_loss) else candidate_loss)
        finally:
            torch.backends.cudnn.enabled = cudnn_was_enabled
        if true_position is None:
            return None
        return list(np.argsort(losses)).index(true_position)

    eval_every = 1  # evaluate on the test set every `eval_every` epochs
    rank_of_all = []  # defined up-front so the return works even with num_epochs == 0

    model.train()  # Set the model to training mode
    for epoch in range(num_epochs):
        loss_total = 0
        for i, (labels, sequences, original_label_lengths, original_lengths) in enumerate(train_loader):
            # Transfer data to GPU if available
            if torch.cuda.is_available():
                labels = labels.to(device)
                sequences = sequences.to(device)
                original_lengths = original_lengths.to(device)
                original_label_lengths = original_label_lengths.to(device)

            # randomly scale every element of the input sequence between (0.95, 1.05) to avoid overfitting
            sequences = sequences * (0.95 + 0.1 * torch.rand(sequences.shape)).to(device)

            # Clear the gradients
            optimizer.zero_grad()

            # Forward pass
            batch_size, seq_length = sequences.shape[:2]
            outputs, _ = model(sequences)
            # Same rule as in evaluation: only a flattened (batch * seq, classes)
            # output needs to be unfolded.
            if outputs.dim() == 2:
                outputs = outputs.view(batch_size, seq_length, outputs.shape[1])
            outputs = outputs.transpose(0, 1)  # (seq_length, batch_size, output_size)

            loss = ctc_loss(outputs.log_softmax(2).clamp(min=-10, max=10), labels, original_lengths, original_label_lengths)

            # A NaN loss cannot be back-propagated: report the batch, count it
            # as `eps` in the epoch total and move on without an optimizer step.
            if math.isnan(loss.item()):
                print(loss.item(), "_", epoch, "_", i)
                print(labels)
                loss_total += eps
                continue

            loss_total += loss.item()

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

        # Step the scheduler (ReduceLROnPlateau needs the monitored metric)
        if scheduler is not None and epoch > 0:
            if scheduler.__class__.__name__ == "ReduceLROnPlateau":
                scheduler.step(loss_total/len(train_loader))
                print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
            else:
                scheduler.step()
        # Print the mean training loss of this epoch
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {loss_total/len(train_loader)}")
        if loss_total/len(train_loader) < 0.1:
            # Informational only: training deliberately continues.
            print("original break due to loss")

        ##################################################################
        #####------ every `eval_every` epochs, evaluate the model --------
        ##################################################################
        if (epoch + 1) % eval_every == 0:
            model.eval()  # Set the model to evaluation mode
            with torch.no_grad():
                total_loss = 0
                rank_of_all = []
                for labels, sequences, original_label_lengths, original_lengths in test_loader:
                    # Transfer data to GPU if available
                    if torch.cuda.is_available():
                        labels = labels.to(device)
                        sequences = sequences.to(device)
                        original_lengths = original_lengths.to(device)
                        original_label_lengths = original_label_lengths.to(device)

                    # Forward pass
                    batch_size, seq_length = sequences.shape[:2]
                    outputs, _ = model(sequences)
                    if outputs.dim() == 2:
                        outputs = outputs.view(batch_size, seq_length, outputs.shape[1])

                    outputs = outputs.transpose(0, 1)  # (seq_length, batch_size, output_size)
                    outputs_log_softmax = outputs.log_softmax(2).clamp(min=-10, max=10)

                    # Score every test sample against the whole lexicon and
                    # record where its true label ranks (0 = best).
                    if lexicon_words is not None:
                        for b in range(batch_size):
                            label_b = labels[b, :original_label_lengths[b]]
                            rank = _rank_in_lexicon(
                                outputs_log_softmax[:, b, :],
                                original_lengths[b],
                                label_b,
                                disable_cudnn=sequences.shape[1] > 1,
                            )
                            if rank is None:
                                print(f"Test label {label_b} not in the lexicon!")
                                continue
                            rank_of_all.append(rank)

                    # Calculate the CTC loss of the batch
                    try:
                        loss = ctc_loss(outputs_log_softmax, labels, original_lengths, original_label_lengths)
                    except Exception as err:
                        # Report the offending batch and leave it out of the
                        # total rather than re-adding a stale loss value.
                        print("error")
                        print(err)
                        print(outputs.shape)
                        print(original_lengths)
                        print(labels.shape)
                        print(original_label_lengths)
                        continue

                    total_loss += loss.item()

                print(f"Epoch {epoch+1}/{num_epochs}, Test Loss: {total_loss/len(test_loader)}")
                if rank_of_all:
                    print(np.mean(rank_of_all)) # average rank of the correct word

        model.train()  # Set the model back to training mode
    return model, rank_of_all


def get_lomo_train_data(word_list, sen_list, file_name, exp_mode, pr=0, ftsession=3, scenario="none"):
    """Load the training set and, optionally, a fine-tuning subset from .mat files.

    Each file's payload is taken from the LAST key of the loaded dict and read
    as ``payload[0, :]`` (one entry per recording session).

    Args:
        word_list: identifiers of the word files; every one except
            ``file_name`` is loaded when ``exp_mode`` starts with "lomo".
        sen_list: identifiers of the sentence files; those not starting with
            ``file_name[:4]`` are loaded when ``exp_mode`` starts with "lomoall".
        file_name: identifier of the held-out file.
        exp_mode: experiment mode.  Its prefix ("lomo" / "lomoall") selects the
            training files, its suffix ("word" / "sen") selects the
            fine-tuning file.
        pr: fraction of each fine-tuning session to sample; 0 disables
            fine-tuning data.
        ftsession: number of sessions to draw fine-tuning data from.
        scenario: "none", "headmotion", "noise" or "walk".

    Returns:
        ``(train_data_raw, finetune_raw)``.  ``train_data_raw`` is the
        concatenated training array (an empty list when ``exp_mode`` does not
        start with "lomo"); ``finetune_raw`` is ``None`` when ``pr == 0``.

    Raises:
        ValueError: for a scenario / ``exp_mode`` combination that has no
            defined session selection (raised before any file is read).
    """
    # Session picked per scenario for training; "none" concatenates sessions 0-4.
    train_session = {"headmotion": 5, "noise": 4, "walk": 6}
    # Fine-tuning session order per scenario for word files.
    word_ft_sessions = {
        "none": [0, 1, 2, 4, 5, 6],        # test = word4
        "headmotion": [3, 4, 0, 1, 2, 6],  # test = word6
        "walk": [3, 4, 5, 0, 1, 2],        # test = word7
    }

    train_data_raw = []

    # Word files: used by every "lomo*" mode
    if exp_mode.startswith("lomo"):
        for word in word_list:
            if word == file_name:
                continue
            if scenario != "none" and scenario not in train_session:
                raise ValueError(f"Unknown scenario for training data: {scenario!r}")
            wordmat = loadmat('../UTokyo data/utokyo_word_fea_left_' + word + '.mat')
            worddata = wordmat[list(wordmat.keys())[-1]][0,:]
            if scenario == "none":
                train_data = np.concatenate(worddata[:5], axis=0)
            else:
                train_data = worddata[train_session[scenario]]
            train_data_raw.append(train_data)

    # Sentence files: "lomoall" modes only
    if exp_mode.startswith("lomoall"):
        prefix = file_name[:4]
        for sen in sen_list:
            if not sen.startswith(prefix):
                senmat = loadmat('../UTokyo data/utokyo_sen_fea_left_' + sen + '.mat')
                sendata = senmat[list(senmat.keys())[-1]][0,:]
                data = np.concatenate(sendata[:4], axis=0)
                train_data_raw.append(data)
    if exp_mode.startswith("lomo"):
        train_data_raw = np.concatenate(train_data_raw, axis=0)

    # Fine-tuning data generation if pr is not 0
    finetune_raw = None
    if pr != 0:
        if exp_mode.endswith("word"):
            if scenario not in word_ft_sessions:
                raise ValueError(f"No fine-tuning sessions defined for scenario {scenario!r}")
            ftindex = word_ft_sessions[scenario]
            wordmat = loadmat('../UTokyo data/utokyo_word_fea_left_' + file_name + '_spl6.mat')
            data = wordmat[list(wordmat.keys())[-1]][0,:]
        elif exp_mode.endswith("sen"):
            ftindex = [0]
            senmat = loadmat('../UTokyo data/utokyo_sen_fea_left_' + file_name + '_spl6.mat')
            data = senmat[list(senmat.keys())[-1]][0,:]
        else:
            raise ValueError(
                f"exp_mode must end with 'word' or 'sen' to build fine-tuning data, got {exp_mode!r}"
            )

        if ftsession > len(ftindex):
            print(f"fintune session {ftsession} is too large, now set to {len(ftindex)}")
            ftsession = len(ftindex)
        ft_data_set = data[ftindex[:ftsession]]
        # Fixed seed so the same rows are sampled on every call (this reseeds
        # NumPy's global generator).
        np.random.seed(581)
        ft_data_set = [
            session[np.random.choice(session.shape[0], int(session.shape[0]*pr), replace=False)]
            for session in ft_data_set
        ]
        finetune_raw = np.concatenate(ft_data_set, axis=0)

    return train_data_raw, finetune_raw

def ctcdecode_byloss(model, test_loader, test_lexicon=None, test_lex_size = None):
    """Rank each test sample's true label within a lexicon by CTC loss.

    Relies on the module-level ``ctc_loss``, ``device`` and ``max_word_length``.

    Args:
        model: called as ``model(sequences)``; must return ``(outputs, extra)``.
        test_loader: yields
            ``(labels, sequences, label_lengths, sequence_lengths)`` batches.
        test_lexicon: collection of zero-padded label tuples.  It is copied,
            so the caller's object is never modified; ``None`` means "start
            from an empty lexicon".
        test_lex_size: when given, the lexicon is topped up with words from the
            Oxford word list (shuffled with a fixed seed) until it reaches this
            size or the list runs out.

    Returns:
        A list of 9 fractions; entry ``k`` (0-based) is the share of ranked
        samples whose 0-based rank is ``<= k + 1``.

    Raises:
        ValueError: if the lexicon ends up empty.
    """
    with open("../Oxford 1035 Export pure_ext.csv", newline="") as csvfile:
        rows = list(csv.reader(csvfile))
    # NOTE(review): the first row is included here, whereas the notebook cell
    # that computes max_word_length skips it as a header - confirm which is right.
    word_list_ori = [row[0] for row in rows]
    random.seed(100)  # fixed shuffle so the filler words are reproducible
    word_list = sample(word_list_ori, len(word_list_ori))
    # letters -> 1..26, then zero-pad to max_word_length
    word_list_num = [[ord(c) - 96 for c in re.sub('[^a-zA-Z]', '', word).lower()] for word in word_list]
    word_list_num = [word + [0]*(max_word_length-len(word)) for word in word_list_num]

    # Work on a private copy so the caller's lexicon is not grown in place.
    lexicon = set() if test_lexicon is None else set(test_lexicon)
    if test_lex_size is not None:
        while len(lexicon) < test_lex_size and word_list_num:
            lexicon.add(tuple(word_list_num.pop()))
    print(f"Lexicon size is: {len(lexicon)}")
    if not lexicon:
        raise ValueError("The test lexicon is empty; pass test_lexicon and/or test_lex_size.")

    lexicon_tensor = torch.IntTensor(list(lexicon)).to(device)

    def _trim_padding(word):
        """Cut a lexicon entry at its first 0 (the padding value)."""
        values = word.tolist()
        return word[:values.index(0)] if 0 in values else word

    # Trim every entry and build its target-length tensor once, not per sample.
    lexicon_words = [_trim_padding(word) for word in lexicon_tensor]
    lexicon_lengths = [torch.tensor([len(word)]).to(device) for word in lexicon_words]

    def _rank_in_lexicon(log_probs, input_length, true_label, disable_cudnn):
        """Return the 0-based rank of ``true_label`` by CTC loss, or None if absent."""
        losses = []
        true_position = None
        cudnn_was_enabled = torch.backends.cudnn.enabled
        if disable_cudnn:
            torch.backends.cudnn.enabled = False
        try:
            for word, word_length in zip(lexicon_words, lexicon_lengths):
                if torch.equal(word, true_label):
                    if true_position is not None:
                        # A duplicate of the true label is reported and not scored.
                        print("Repeat label in the lexicon!")
                        continue
                    true_position = len(losses)
                candidate_loss = ctc_loss(log_probs, word, input_length, word_length).item()
                # NaN losses are pushed to the bottom of the ranking.
                losses.append(1e5 if math.isnan(candidate_loss) else candidate_loss)
        finally:
            torch.backends.cudnn.enabled = cudnn_was_enabled
        if true_position is None:
            return None
        return list(np.argsort(losses)).index(true_position)

    model.eval()  # Set the model to evaluation mode
    rank_of_all = []
    with torch.no_grad():
        for labels, sequences, original_label_lengths, original_lengths in test_loader:
            # Transfer data to GPU if available
            if torch.cuda.is_available():
                labels = labels.to(device)
                sequences = sequences.to(device)
                original_lengths = original_lengths.to(device)
                original_label_lengths = original_label_lengths.to(device)

            # Forward pass
            batch_size, seq_length = sequences.shape[:2]
            outputs, _ = model(sequences)
            if outputs.dim() == 2:
                outputs = outputs.view(batch_size, seq_length, outputs.shape[1])

            outputs = outputs.transpose(0, 1)  # (seq_length, batch_size, output_size)
            outputs_log_softmax = outputs.log_softmax(2).clamp(min=-10, max=10)

            # Score every sample against the whole lexicon (0 = best rank).
            for b in range(batch_size):
                label_i = labels[b, :original_label_lengths[b]]
                rank = _rank_in_lexicon(
                    outputs_log_softmax[:, b, :],
                    original_lengths[b],
                    label_i,
                    disable_cudnn=sequences.shape[1] > 1,
                )
                if rank is None:
                    print(f"Test label {label_i} not in the lexicon!")
                    continue
                rank_of_all.append(rank)

    # NOTE(review): ranks are 0-based, so `<= i` for i = 1..9 yields top-2 to
    # top-10 shares rather than top-1 to top-9.  Kept unchanged because
    # callers may depend on the current numbers - confirm the intent.
    ranks = np.array(rank_of_all)
    topNacc = [np.sum(ranks <= i) / len(rank_of_all) for i in range(1, 10)]
    # print topNacc with 3 decimal places
    print([round(i, 3) for i in topNacc])
    return topNacc
                                            


In [2]:
# Report whether PyTorch can see a CUDA device.
print(torch.cuda.is_available())

# Longest raw entry of the word list (first CSV row skipped).
with open("../Oxford 1035 Export pure_ext.csv", newline="") as csvfile:
    data = list(csv.reader(csvfile))
word_list_ori = [row[0] for row in data[1:]]
max_word_length = max(len(word) for word in word_list_ori)
print(f"Max word length is: {max_word_length}")

# Longest word in the phrase corpus once every word is reduced to its letters
# and each letter is mapped to 1..26 (first CSV row skipped).
with open("../Corpus/MS_phrases2.csv", newline="",) as csvfile:
    data = list(csv.reader(csvfile))
sentences_ori = [row[0] for row in data[1:]]
w_in_sentences = [
    [ord(c) - 96 for c in re.sub('[^a-zA-Z]', '', word).lower()]
    for sentence in sentences_ori
    for word in sentence.split()
]
print(max(len(word) for word in w_in_sentences))

True
Max word length is: 14
13


In [3]:
# Read every row of the word-list CSV into `data`.
# NOTE: the follow-up step (converting the words to visemes) is not
# implemented in this cell.
with open("../Oxford 1035 Export pure_ext.csv", newline="") as csvfile:
    data = [row for row in csv.reader(csvfile)]

In [4]:
# print_visemes("bat")
# print_visemes("bought")
# print_visemes("boat")
# print_visemes("bait")
# print_visemes("bet")
# print_visemes("but")
# print_visemes("beat")
# print_visemes("bit")
# print_visemes("hang")
# print_visemes("vision")
# print_visemes("shy")
# print_visemes("jive")
# print_visemes("i've")
# "Convert a word to a list of phonemes using CMU Pronouncing Dictionary.".split()

NameError: name 'print_visemes' is not defined

In [7]:


# Download the CMU dictionary if you haven't already
# nltk.download('cmudict')
cmudict = nltk.corpus.cmudict.dict()

# Extra entries for words missing from CMUdict.  Every value must be a LIST OF
# PRONUNCIATIONS, each pronunciation being a list of phonemes, because
# word_to_phonemes() reads `cmudict[word][0]` as one full pronunciation.
cmudict["behaviour"] = cmudict["behavior"]
cmudict["colour"] = cmudict["color"]
cmudict["favourite"] = cmudict["favorite"]
# Wrapped in an outer list: a flat phoneme list would make [0] return just "K".
cmudict["covid"] = [["K", "OW", "V", "IH", "D"]]
cmudict["racketball"] = cmudict["racquetball"]
# Join the first pronunciations of "e" and "book" into ONE pronunciation;
# concatenating the two entry lists would make [0] return only the "e" part.
cmudict["ebook"] = [cmudict["e"][0] + cmudict["book"][0]]


# Phoneme-to-viseme lookup tables.  Keys are stress-free phoneme symbols (the
# form produced by remove_stress); values are viseme class ids stored as strings.

# Earlier simplified mapping, kept for reference only (superseded by gpt_vise,
# which lists the same assignments grouped by class):
# phon_to_vise = {
#     'AA': '1', 'AE': '2', 'AH': '1', 'AO': '1', 'AW': '3', 'AY': '3',
#     'B': '4', 'CH': '5', 'D': '6', 'DH': '6', 'EH': '2', 'ER': '1',
#     'EY': '2', 'F': '7', 'G': '8', 'HH': '9', 'IH': '2', 'IY': '2',
#     'JH': '5', 'K': '8', 'L': '10', 'M': '4', 'N': '11', 'NG': '11',
#     'OW': '3', 'OY': '3', 'P': '4', 'R': '12', 'S': '13', 'Z': '13', 'ZH': '5'
# }

# Simplified mapping with classes '1'..'13', one source line per class.
gpt_vise = {
    'AA': '1', 'AH': '1', 'AO': '1', 'ER': '1', 
    'AE': '2', 'EH': '2', 'EY': '2', 'IH': '2', 'IY': '2',	
    'AW': '3', 'AY': '3', 'OY': '3', 'UH': '3','UW': '3','OW': '3',
    'B': '4', 'M': '4', 'P': '4',
    'CH': '5', 'JH': '5', 'SH': '5', 'ZH': '5',
    'D': '6', 'DH': '6', 'T': '6',
    'F': '7', 'V': '7',
    'G': '8', 'K': '8',
    'HH': '9', 'W': '9', 'Y': '9',
    'L': '10', 
    'N': '11', 'NG': '11',
    'R': '12',
    'S': '13', 'Z': '13', 'TH': '13'
    }
# Define a phoneme-to-viseme mapping (Lee's version), classes '1'..'12'.
## Vowel groups: {/A/ /aʊ/ /ai/ /ʌ/}{/e/ /ei/ /æ/} {/i/ /I/} {/O/ /OI/ /@U/} {/U/ /u/}
lee_vise = {  # Lee
    'AA': '1', 'AW': '1', 'AY': '1', 'AH': '1', 
    'AE': '2', 'EH': '2', 'EY': '2', 
    'IH': '3', 'IY': '3',
    'AO': '4', 'OY': '4', 'OW': '4',
    'UH': '5', 'UW': '5',
    'B': '6', 'P': '6', 'M': '6',
    'CH': '7', 'JH': '7', 'SH': '7', 'ZH': '7',
    'D': '8', 'T': '8', 'S': '8', 'Z': '8', 'TH': '8', 'DH': '8',
    'F': '9', 'V': '9',
    'G': '10', 'K': '10', 'N': '10', 'NG': '10', 'L': '10', 'Y': '10', 'HH': '10', 
    'R': '11', 'W': '11',
    'ER': '12'
    }

# Classes '1'..'4' hold the consonants, '5'..'10' the vowels.
disney_wood_vise = {  # Disney Vowel and Woodward Consonant
    'B': '1', 'P': '1', 'M': '1',
    'D': '2', 'T': '2', 'N': '2','L': '2','TH': '2','DH': '2','S': '2','Z': '2','CH': '2','JH': '2','SH': '2','ZH': '2','Y': '2','K': '2','G': '2','HH': '2','NG': '2',
    'F': '3', 'V': '3',
    'R': '4', 'W': '4',
    'IH': '5', 'IY': '5',
    'UH': '6', 'UW': '6',
    'OW': '7',
    'AO': '8', 'OY': '8', 'AW': '8',
    'AA': '9', 'AE': '9','ER': '9',   
    'AH': '10', 'EH': '10', 'EY': '10',	'AY': '10'
    }

# Same as disney_wood_vise except that 'AH', 'EH', 'EY' and 'AY' are merged
# into class '9' instead of forming class '10' (9 classes in total).
disney_wood_9 = {  # Disney Vowel and Woodward Consonant, 9-class variant
    'B': '1', 'P': '1', 'M': '1',
    'D': '2', 'T': '2', 'N': '2','L': '2','TH': '2','DH': '2','S': '2','Z': '2','CH': '2','JH': '2','SH': '2','ZH': '2','Y': '2','K': '2','G': '2','HH': '2','NG': '2',
    'F': '3', 'V': '3',
    'R': '4', 'W': '4',
    'IH': '5', 'IY': '5',
    'UH': '6', 'UW': '6',
    'OW': '7',
    'AO': '8', 'OY': '8', 'AW': '8',
    'AA': '9', 'AE': '9','ER': '9',   
    'AH': '9', 'EH': '9', 'EY': '9','AY': '9'
    }

# One distinct id ('1'..'39') per phoneme: used when labels should stay at
# phoneme level instead of being merged into visemes (see phonemes_to_visemes
# with ifviseme=False).
phon_to_phon = {
    'AA': '1', 'AE': '2', 'AH': '3', 'AO': '4', 'AW': '5', 'AY': '6',
    'B': '7', 'CH': '8', 'D': '9', 'DH': '10', 'EH': '11', 'ER': '12',
    'EY': '13', 'F': '14', 'G': '15', 'HH': '16', 'IH': '17', 'IY': '18',
    'JH': '19', 'K': '20', 'L': '21', 'M': '22', 'N': '23', 'NG': '24',
    'OW': '25', 'OY': '26', 'P': '27', 'R': '28', 'S': '29', 'SH': '30',
    'T': '31', 'TH': '32', 'UH': '33', 'UW': '34', 'V': '35', 'W': '36',
    'Y': '37', 'Z': '38', 'ZH': '39'
}

# Active mapping: bound as the default `viseme_dict` of the functions defined
# below at the moment they are defined.  Uncomment one alternative to switch.
# phon_to_vise = phon_to_phon
# phon_to_vise = gpt_vise
# phon_to_vise = lee_vise
phon_to_vise = disney_wood_vise

def count_visemes(phone2visemeDict):
    """Return how many distinct viseme classes the phoneme -> viseme table maps to."""
    return len(set(phone2visemeDict.values()))

def remove_stress(phoneme):
    """Return `phoneme` with every digit character removed (e.g. 'AH0' -> 'AH')."""
    kept = []
    for symbol in phoneme:
        if symbol.isdigit():
            continue
        kept.append(symbol)
    return ''.join(kept)

def word_to_phonemes(word):
    """Convert a word or space-separated phrase to a flat list of phonemes.

    Each whitespace-separated token is looked up in the CMU Pronouncing
    Dictionary (`cmudict`); the first pronunciation of every token is used and
    stress digits are stripped.

    Args:
        word: a single word or a phrase; case is ignored.

    Returns:
        The phonemes of all tokens concatenated in order, or [] as soon as one
        token is missing from `cmudict` (a message naming the token is printed).
    """
    word = word.lower()
    # split() with no argument ignores leading, trailing and repeated whitespace.
    # The previous `' ' in word[:-1]` guard left a trailing space attached to the
    # word, so such input could never be found in the dictionary.
    wordlist = word.split()
    phon_seq = []
    for word_ele in wordlist:
        if word_ele not in cmudict:
            print(f"Word '{word_ele}' in '{word}' not found in CMUdict.")
            return []
        # Take the first pronunciation entry
        phon_seq.extend(cmudict[word_ele][0])
    return [remove_stress(phoneme) for phoneme in phon_seq]

def phonemes_to_visemes(phoneme_sequence, viseme_dict = phon_to_vise, ifviseme = True):
    """Map a list of phonemes to class ids through a lookup table.

    Args:
        phoneme_sequence: phonemes without stress digits.
        viseme_dict: phoneme -> viseme table used when `ifviseme` is true. The
            default is whatever `phon_to_vise` was bound to when this function
            was defined.
        ifviseme: when false, `phon_to_phon` is used instead of `viseme_dict`.

    Returns:
        The mapped ids in order, or [] if any phoneme is missing from the
        chosen table (the first missing phoneme is reported).
    """
    cvt_dict = viseme_dict if ifviseme else phon_to_phon
    visemes = []
    for phoneme in phoneme_sequence:
        if phoneme in cvt_dict:
            visemes.append(cvt_dict[phoneme])
            continue
        print(f"Phoneme '{phoneme}' not found in the viseme dictionary.")
        return []
    return visemes

def print_visemes(word, ifviseme = True):
    """Print the phonemes of `word` and, if `ifviseme` is true, its visemes too.

    Nothing is printed here when `word_to_phonemes` returns an empty list.
    """
    phoneme_sequence = word_to_phonemes(word)
    if not phoneme_sequence:
        return
    print(f"Phonemes for '{word}': {phoneme_sequence}")
    if not ifviseme:
        return
    viseme_sequence = phonemes_to_visemes(phoneme_sequence)
    print(f"Visemes for '{word}': {viseme_sequence}")



def phoneme_labelling(data_ori, ifviseme = True, viseme_dict = phon_to_vise):
    """Label every sample with the phoneme/viseme id sequence of its text.

    Handles the two input layouts used in this notebook:
      * CSV rows (the first cell is a str): the text is read from column 0 and
        the id sequence is appended to the row.
      * loaded .mat data: the text is read from `row[1][0]` and the id sequence
        is written to `row[0]`. Input with more than one dimension is treated as
        a single session and wrapped in a list.

    Args:
        data_ori: CSV rows or .mat data as described above.
        ifviseme: forwarded to `phonemes_to_visemes` (False selects `phon_to_phon`).
        viseme_dict: phoneme -> viseme table forwarded to `phonemes_to_visemes`.

    Returns:
        * a set of the words for which `word_to_phonemes` returned nothing, if
          there are any; the data is not labelled in that case;
        * otherwise the labelled data. CSV and single-session input come back
          as a one-element list.

    Note:
        `data_ori.copy()` is a shallow copy, so nested rows/sessions may be
        shared with `data_ori` and modified in place.
    """
    ## can be used for both csv and readed mat file
    iscsv = isinstance(data_ori[0][0], str)
    data_copy = data_ori.copy()
    # CSV rows and a single .mat session are both iterated as one "session".
    if iscsv or len(data_ori.shape) > 1:
        data_copy = [data_copy]

    def _text_of(row):
        # Lower-cased text of one sample, for either layout.
        return row[0].lower() if iscsv else row[1][0].lower()

    viseme_lexicon = {}      # id sequence (tuple) -> first word that produced it
    word_to_viseme = {}      # word -> array of shape (1, len(sequence))
    conflict_list = set()    # (word, earlier word) pairs sharing one id sequence
    phoneme_lack_list = set()

    # Pass 1: build the word -> sequence table and collect problems.
    for session in data_copy:
        for i in range(len(session)):
            textword = _text_of(session[i])
            # Each distinct word is processed (and reported) once.
            if textword in word_to_viseme or textword in phoneme_lack_list:
                continue
            phoneme_sequence = word_to_phonemes(textword)
            if not phoneme_sequence:
                print(f"Error, Word {textword} not in the CMU dictionary!")
                phoneme_lack_list.add(textword)
                continue
            viseme_sequence = phonemes_to_visemes(phoneme_sequence, viseme_dict = viseme_dict, ifviseme = ifviseme)
            cvt_seq = [int(x) for x in viseme_sequence]
            word_to_viseme[textword] = np.array([cvt_seq])
            seq_key = tuple(cvt_seq)
            if seq_key not in viseme_lexicon:
                viseme_lexicon[seq_key] = textword
            else:
                # `textword` is new here, so it always differs from the stored word.
                print(f"Conflict: {textword} and {viseme_lexicon[seq_key]}")
                conflict_list.add((textword, viseme_lexicon[seq_key]))

    if len(phoneme_lack_list) > 0:
        return phoneme_lack_list

    # Conflicts do not block labelling (the original check was disabled with
    # `> -1`), but the summary now says what was actually found.
    if conflict_list:
        print(f"{len(conflict_list)} conflict(s) in the lexicon; labelling continues.")
    else:
        print("No conflict in the lexicon!")

    # Pass 2: write the labels. Every word is in word_to_viseme at this point,
    # because any missing word would have triggered the early return above.
    for session in data_copy:
        for i in range(len(session)):
            textword = _text_of(session[i])
            # if the data is csv, append the label to the row, otherwise save in 1st column
            if iscsv:
                session[i].append(word_to_viseme[textword])
            else:
                session[i][0] = word_to_viseme[textword]
    return data_copy
# data_viseme = phoneme_labelling(data)
# csvname = "Oxford 1035 Export pure_ext.csv"
# with open("../Oxford 1035 Export pure_ext.csv", newline="") as csvfile:
#     csvdata = list(csv.reader(csvfile))
#     csv_vise = phoneme_labelling(csvdata)
#     # save the viseme data to a csv file
#     with open("./Oxford 1035 Export pure_ext_vise.csv", "w", newline="") as f:
#         writer = csv.writer(f)
#         writer.writerows(csv_vise)
# count_visemes(phon_to_vise)
# Example usage: word_to_phonemes only splits on spaces, so "e-book" is looked up
# as a single token and is reported as missing from CMUdict.
print_visemes("e-book")
# First listed pronunciation of "a", with its stress digit still attached.
print(cmudict["a"][0])

Word 'e-book' in 'e-book' not found in CMUdict.
['AH0']


In [None]:
####--- Generate viseme for spelled words
# Show the phonemes of every letter name, a through z.
for letter in map(chr, range(ord('a'), ord('z') + 1)):
    print(letter, word_to_phonemes(letter))
wordsample = "hello"
# Spell the sample word letter by letter, then flatten the per-letter phoneme
# lists into one sequence.
listlist_phon = [word_to_phonemes(c) for c in wordsample]
list_phon = []
for letter_phonemes in listlist_phon:
    list_phon.extend(letter_phonemes)
# print(phonemes)

a ['AH']
b ['B', 'IY']
c ['S', 'IY']
d ['D', 'IY']
e ['IY']
f ['EH', 'F']
g ['JH', 'IY']
h ['EY', 'CH']
i ['AY']
j ['JH', 'EY']
k ['K', 'EY']
l ['EH', 'L']
m ['EH', 'M']
n ['EH', 'N']
o ['OW']
p ['P', 'IY']
q ['K', 'Y', 'UW']
r ['AA', 'R']
s ['EH', 'S']
t ['T', 'IY']
u ['Y', 'UW']
v ['V', 'IY']
w ['D', 'AH', 'B', 'AH', 'L', 'Y', 'UW']
x ['EH', 'K', 'S']
y ['W', 'AY']
z ['Z', 'IY']


In [None]:
# wordmat = loadmat('../BC data/'+ word_file +'.mat')
# mat = wordmat
# data = mat[list(mat.keys())[-1]]
# data = data[0,:]
# session = data[0]

# Cell for Training BC Data

In [None]:
# Label table for this run: one class per phoneme. It is passed explicitly to
# phoneme_labelling in the training loop.
phon_to_vise = phon_to_phon
# phon_to_vise = gpt_vise
# phon_to_vise = lee_vise
# phon_to_vise = disney_wood_vise #disney_wood_9

# Define the parameters
num_channels = [16,32]  # number of channels (passed to TCN)
kernel_size = 11  # kernel size (passed to TCN)
# input_size = 1 | 2  # number of input channels
batch_size = 8 # Define the batch size
batch_size_test = 128
# One output per label class plus one extra; presumably the CTC blank (index 0 below).
output_size = 1 + count_visemes(phon_to_vise)  # number of output channels
normal_train_epoch = 30

# Use the first GPU if available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Define the CTC Loss function
ctc_loss = CTCLoss(blank=0, reduction='mean',zero_infinity=True)

# file_list = ['']

# def train_with_file_list(mode="word"):

mode="phrase"
print(f"{mode} Mode")
word_list = ['_dmu1hzy_normal'] #dxf_44100_right_100phrase_down_spl6_str3 hlx_50word5times_down_spl6_str3
# word_list = glob.glob("../BC data/xlq_44100_50word_5times_down_spl6_*.mat")
# word_list = [word_l.split("\\")[-1].split(".")[0] for word_l in word_list]

word_list_right = []
sen_list = []
# Membership test: the previous `mode == "word" or "phrase"` was always truthy
# (a non-empty string), so sen_list could never be selected.
file_list = word_list if mode in ("word", "phrase") else sen_list

# choose a sublist of file to train
# subindex = [4]
# file_list = [file_list[i] for i in subindex]
acc_list = []
# pr and self_session are re-bound by the loops in the training cell.
pr = 0
self_session = 3
losocv = ["leave1","leave2","leave3","none"] # ["leave1","leave2","leave3","none","noise","headmotion","walk"]
lex_size = [100] #[100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
# lex_size = None
# Main experiment loop: for every file, portion `pr`, session index and scenario,
# load the data, split it into train/test sessions, train a TCN with CTC loss and
# append one accuracy row per evaluated lexicon size to `acc_list`.
for file_name in file_list:
    tcnmodel1 = None
    train_data_raw = None
    for pr in [1]: #[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]:#[0, 0.2, 0.4, 0.6, 0.8]: #
        for self_session in [3]: #np.arange(5): #
            for scenario in ["none","none_mtrain"]: #["readWord"]: #losocv: # 
                print(f"Subject: {file_name}, pr: {pr}, session: {self_session}, scenario: {scenario}")
                # Load the Matlab file
                # wordmat = loadmat('../UTokyo data/utokyo_word_fea_left_'+ file_name +'.mat')
                # Files are matched to the subject by the first four characters of the name.
                # NOTE(review): if nothing matches, senmat / wordmat keep the value from an
                # earlier iteration, or are undefined on the first one.
                for sen_file in sen_list:
                    if sen_file.startswith(file_name[:4]):
                        senmat = loadmat('../UTokyo data/utokyo_sen_fea_left_'+ sen_file +'_spl6.mat')
                        break
                for word_file in word_list:
                    if word_file.startswith(file_name[:4]):
                        wordmat = loadmat('../BC data/'+ word_file +'.mat')
                        print(f"File {word_file} loaded!")
                        break
                # for word_file_right in word_list_right:
                #     if word_file_right.startswith(file_name[:4]):
                #         wordmat_right = loadmat('../UTokyo data/utokyo_word_fea_right_'+ word_file_right +'.mat')
                #         break

                if mode == "word" or mode == "phrase":
                    mat = wordmat
                    # data1 = senmat[list(senmat.keys())[-1]][0,:] # data1 is the senmat data
                elif mode == "sen":
                    mat = senmat
                    data1 = wordmat[list(wordmat.keys())[-1]][0,:] # data1 is the wordmat data
                    
                # dataRword = wordmat_right[list(wordmat_right.keys())[-1]][0,:]
                # The recordings are taken from the last key of the loaded dict, first row;
                # the code below indexes the result by session (data[0], data[1], ...).
                data = mat[list(mat.keys())[-1]]
                data = data[0,:]
                # if scenario == "readWord":

                # NOTE(review): kfold is not referenced again in this cell.
                kfold = KFold(n_splits=5, shuffle=True, random_state=120)        

                # Phrase mode: replace the labels with id sequences derived from the text.
                if mode == "phrase":
                    data = phoneme_labelling(data, viseme_dict=phon_to_vise)
                    print("Word to viseme done!")
                    
                if mode == "word" or mode == "phrase":
                    # Choose the held-out test session and the training sessions per scenario.
                    # NOTE(review): there is no final `else`; an unlisted scenario leaves
                    # train_data_raw / test_data_raw at their previous values.
                    if scenario == "none":
                        test_data_raw = data[3]
                        train_data_raw = np.concatenate((data[[0,1,2]]), axis=0)
                    elif scenario == "none_mtrain":
                        test_data_raw = data[3]
                        train_data_raw = np.concatenate((data[[0,1,2,4]]), axis=0)
                    elif scenario == "none_wtrain":
                        test_data_raw = data[3]
                        train_data_raw = np.concatenate((data[[0,1,2,6]]), axis=0)
                    elif scenario == "leave1":
                        test_data_raw = data[0]
                        train_data_raw = np.concatenate((data[[1,2,3]]), axis=0)
                    elif scenario == "leave2":
                        test_data_raw = data[1]
                        train_data_raw = np.concatenate((data[[0,2,3]]), axis=0)
                    elif scenario == "leave3":
                        test_data_raw = data[2]
                        train_data_raw = np.concatenate((data[[0,1,3]]), axis=0)
                    elif scenario == "music":
                        test_data_raw = data[5]
                        train_data_raw = np.concatenate((data[[0,1,2,3]]), axis=0)
                    elif scenario == "noise":
                        test_data_raw = data[4]
                        train_data_raw = np.concatenate((data[[0,1,2,3]]), axis=0)
                    elif scenario == "walk":
                        test_data_raw = data[6]
                        train_data_raw = np.concatenate((data[[0,1,2,3]]), axis=0)
                        
                    elif scenario == "onesession":
                        # Pool up to seven sessions, then peel off one sample per repeated
                        # label four times; the third peel (data3) becomes the test set.
                        all_data_raw = np.concatenate((data[:7]), axis=0) if len(data) > 1 else data[0]
                        # test_data_raw = data[6]
                        # train_data_raw, test_data_raw = split_data_based_on_labels(all_data_raw, priority="train")
                        data1234, data5 = split_data_based_on_labels(all_data_raw, priority="train")
                        data123, data4 = split_data_based_on_labels(data1234, priority="train")
                        data12, data3 = split_data_based_on_labels(data123, priority="train")
                        data_1, data2 = split_data_based_on_labels(data12, priority="train")
                        train_data_raw = np.concatenate((data_1, data5, data4, data2), axis=0)
                        test_data_raw = data3

                # A 2-D feature array means one input channel; otherwise the size of the
                # third axis is used as the channel count.
                if len(test_data_raw[0][2].shape) == 2:
                    input_size = 1
                else:
                    input_size = test_data_raw[0][2].shape[2] # if the input channel number is 2, then the shape would be >3d, and 1 for 2d
                print(f"Input size is: {input_size}")
                # delete the model if it already exists
                if 'model' in locals() and 'optimizer' in locals() and 'scheduler' in locals():
                    del model
                    del optimizer
                    del scheduler
                    del train_loader
                    del test_loader
                    del ctc_loss
                    torch.cuda.empty_cache()
                    # make sure local variable of function is deleted                    
                    print("Model, optimizer and scheduler deleted!")
                    # ctc_loss was deleted above, so build it again with the same settings.
                    ctc_loss = CTCLoss(blank=0, reduction='mean',zero_infinity=True)
                
                # Fine-tuning is disabled: finetune_raw stays None because the code that
                # would fill it is commented out, so the fine-tune branch below is skipped.
                finetune_raw = None

                # # fetch train_data_raw from lomo function
                # if train_data_raw is None:
                #     train_data_raw, finetune_raw = get_lomo_train_data(word_list, sen_list, file_name, exp_mode="lomoword"+mode, pr=pr, ftsession=self_session, scenario=scenario)
                # else:
                #     _, finetune_raw = get_lomo_train_data(word_list, sen_list, file_name, exp_mode="finetune"+mode, pr=pr, ftsession=self_session, scenario=scenario)      


                # # split by vocabulary
                # train_data_raw, test_data_raw = split_data_based_on_labels(data[6])

                # train_data_raw = data[0]
                # load the test data from single file
                # test_mat = loadmat('./fyt_4_1.mat')
                # test_data_raw = test_mat[list(test_mat.keys())[-1]]


                # From here on `data` is the concatenation of the selected splits and is
                # used only for the length statistics and the prints below.
                if finetune_raw is None:
                    data = np.concatenate((train_data_raw,test_data_raw), axis=0)
                else:
                    data = np.concatenate((train_data_raw,test_data_raw,finetune_raw), axis=0)

                # Find the maximum sequence length in the data
                # fetch from the third column to last column
                max_seq_length = max([sample.shape[0] for sample in data[:, 2:].reshape(-1)])
                max_lab_length = max([sample[0].shape[0] for sample in data[:, 0] if type(sample[0]) == np.ndarray]) # fetch from the first column
                max_word_length = max_lab_length 
                # max_lab_length = max_word_length # 14 # fetch from the csv

                print(max_seq_length)
                print(max_lab_length)


                # Process the data
                # label_class is the class count without the extra output added in output_size.
                processed_train_data,train_lexicon = process_data(train_data_raw, max_lab_length, max_seq_length ,trainortest="train", label_class=output_size-1) # 
                processed_test_data,test_lexicon = process_data(test_data_raw, max_lab_length, max_seq_length, label_class=output_size-1) #need_pad=False when batch_size=1

                # # Split the data into training and testing datasets
                # train_data, test_data = train_test_split(processed_data, test_size=0.2, random_state=0)

                # Create the training and testing datasets
                train_dataset = SequenceDataset(processed_train_data)
                test_dataset = SequenceDataset(processed_test_data)

                # Define the data loaders
                # NOTE(review): the test loader is shuffled as well.
                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
                test_loader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=True)

                # Check its type
                print(f"Type: {type(data)}")

                # Check its shape
                print(f"Shape: {data.shape}")

                # print size of train and test data
                print(f"Train data size: {len(train_data_raw)}, Test data size: {len(test_data_raw)}")

                # Print the first few items
                print(f"First few items: {data[0, :3]}")

                # Create the model
                # Seed torch right before building the model, the same way in every run.
                torch.manual_seed(581)

                model = TCN(input_size, output_size, num_channels, kernel_size)

                # Define the optimizer
                optimizer = optim.Adam(model.parameters(), lr=0.0002, weight_decay=0.01)
                # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.3, verbose=True)
                # Multiply the learning rate by 0.3 every 13 scheduler steps.
                scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=13, gamma=0.3, verbose=True)
                # warm up the model
                
                model.to(device)
                torch.backends.cudnn.enabled = True
                # # use all the available GPUs
                # if torch.cuda.device_count() > 1:
                #     print("Using", torch.cuda.device_count(), "GPUs!")
                #     model = nn.DataParallel(model)

            # Train the model
                # train_model's result is unpacked as (trained model, rank_of_all).
                if finetune_raw is not None:
                    if tcnmodel1 is not None:
                        print("Model already trained, use the previous model to fine tune")
                    else:
                        tcnmodel1, rank_of_all = train_model(model, train_loader, test_loader, 10, optimizer, ctc_loss, scheduler, test_lexicon=test_lexicon, output_size=output_size)
                        
                    # print the name and pr of the file with the "fine tune started"
                    print(f"Fine tune started: {file_name}, Portion is {pr}")
                    # NOTE(review): unlike the calls above, label_class is not passed here.
                    processed_finetune_data, _ = process_data(finetune_raw, max_lab_length, max_seq_length, trainortest="train")
                    finetune_dataset = SequenceDataset(processed_finetune_data)
                    finetune_loader = DataLoader(finetune_dataset, batch_size=batch_size, shuffle=True)
                    # Fine-tune the model using finetune_loader
                    optimizer = optim.Adam(tcnmodel1.parameters(), lr=0.0002, weight_decay=0.01)
                    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.3, verbose=True)
                    # Plain assignment: tcnmodel2 is the same object as tcnmodel1, not a copy.
                    tcnmodel2 = tcnmodel1
                    tcnmodelft,rank_of_all = train_model(tcnmodel2, finetune_loader, test_loader, 5, optimizer, ctc_loss, scheduler, test_lexicon=test_lexicon, output_size=output_size)
                else:
                    tcnmodel1, rank_of_all = train_model(model, train_loader, test_loader, normal_train_epoch, optimizer, ctc_loss, scheduler, test_lexicon=test_lexicon, output_size=output_size)


            # test different lexicon size
                if mode == "word" or mode == "phrase":
                    if lex_size is None: # original lexicon
                        # Top-N accuracy for N = 1..9: share of entries in rank_of_all that are <= N.
                        topNacc = []
                        for i in range(1, 10):
                            topNacc.append(np.sum(np.array(rank_of_all) <= i) / len(rank_of_all))
                        # print topNacc with 3 decimal places
                        print([round(i, 3) for i in topNacc])
                        # append topNacc, file_name, pr to acc_list
                        acc_list.append([round(i, 5) for i in topNacc] + [file_name, scenario,pr, self_session, len(test_lexicon)])
                    else: # test different lexicon size
                        if mode == "phrase":
                            # Put the test lexicon's own size first when it is smaller than
                            # the smallest configured size.
                            # NOTE(review): insert() changes lex_size in place, so the added
                            # size is still there in later scenario/file iterations.
                            if lex_size[0] > len(test_lexicon):
                                lex_size.insert(0, len(test_lexicon))
                        for lex_size_i in lex_size:
                            # test_lexicon_i = set(random.sample(test_lexicon, lex_size_i))
                            print(f"lex_size: {lex_size_i}")
                            topNacc = ctcdecode_byloss(tcnmodel1, test_loader, test_lexicon=test_lexicon, test_lex_size=lex_size_i)
                            # append topNacc, file_name, pr to acc_list
                            acc_list.append([round(i, 5) for i in topNacc] + [file_name, scenario, pr, self_session,  lex_size_i])
                elif mode == "sen":
                    topNacc = []
                # WER per phrase
                    # acc_phrase = []
                    # for phrase in test_data_phrase:
                    #     process_phrase, _ = process_data(phrase, max_lab_length, max_seq_length)
                    #     phrase_dataset = SequenceDataset(process_phrase)
                    #     phrase_loader = DataLoader(phrase_dataset, batch_size=len(phrase_dataset), shuffle=False)
                    #     acc_tmp = ctcdecode_byloss(tcnmodel1, phrase_loader, test_lexicon=test_lexicon)
                    #     acc_phrase.append(acc_tmp)
                    # # calculate the average accuracy of the phrase
                    # topNacc = np.mean(acc_phrase, axis=0)
                    # # append topNacc, file_name, pr, lex_size to acc_list
                    # acc_list.append([round(i, 5) for i in topNacc] + [file_name, pr, self_session, len(test_lexicon)])
                # WER of all phrases
                    # Same top-N computation as in the lex_size-is-None branch above.
                    for i in range(1, 10):
                        topNacc.append(np.sum(np.array(rank_of_all) <= i) / len(rank_of_all))
                    # print topNacc with 3 decimal places
                    print([round(i, 3) for i in topNacc])
                    # append topNacc, file_name, pr to acc_list
                    acc_list.append([round(i, 5) for i in topNacc] + [file_name, scenario,pr, self_session, len(test_lexicon)])

                        

                    

#     return acc_list

# acc_list = train_with_file_list()

phrase Mode
Subject: dxf_right_gngram_phrase_downup_spl6_str3_lag100_var, pr: 1, session: 3, scenario: none
File dxf_right_gngram_phrase_downup_spl6_str3_lag100_var loaded!
Conflict: per cent of the and percent of the
No conflict in the lexicon!
Word to viseme done!
Input size is: 4
Model, optimizer and scheduler deleted!
125
38
Lexicon is written, its size is: 100
Type: <class 'numpy.ndarray'>
Shape: (400, 3)
Train data size: 300, Test data size: 100
First few items: [array([[20, 21, 13, 22, 31, 34, 11, 23, 18, 27, 12, 31, 17, 20, 37,  3,
         21, 12, 14,  1, 23, 31]])
 array(['CLAIM TO ANY PARTICULAR FONT'], dtype='<U28')
 array([[[-9.28214199e+00,  5.25925525e-01, -1.91974801e-01,
          -9.23803502e+00],
         [-3.99525938e+01,  4.31299646e+00, -1.24345385e+00,
          -3.97003462e+01],
         [-1.05546539e+02,  1.63171821e+01, -3.23437637e+00,
          -1.04964467e+02],
         ...,
         [-2.38270365e+00, -7.69486715e-01, -1.03815951e+00,
          -1.79156669e

In [44]:
import csv
print(acc_list)
# Write the accuracy table to disk: a header row first, then one row per acc_list entry
# (nine top-N accuracies followed by the run's identifying fields).
header = ['top1','top2','top3','top4','top5','top6','top7','top8','top9','file_name','scenario','pr','num_session','lex_size']
with open('../results/test.csv', 'w', newline='') as results_file:
    results_writer = csv.writer(results_file)
    results_writer.writerow(header)
    for acc_row in acc_list:
        results_writer.writerow(acc_row)

[[0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 'dxf_right_gngram_phrase_downup_spl6_str3_lag100_var', 'leave1', 1, 3, 100], [0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 'dxf_right_gngram_phrase_downup_spl6_str3_lag100_var', 'music', 1, 3, 100], [0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 'dxf_right_gngram_phrase_downup_spl6_str3_lag100_var', 'noise', 1, 3, 100], [0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 'dxf_right_gngram_phrase_downup_spl6_str3_lag100_var', 'walk', 1, 3, 100], [0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 'dxf_right_gngram_phrase_downup_spl6_str3_lag100_var', 'none', 1, 3, 100]]
