# Installation

In [None]:
# Mount Google Drive in the Colab runtime; the dataset and checkpoint paths
# configured below all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
! nvidia-smi

Thu Jan 13 09:55:20 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0    32W /  70W |   7946MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
! pip install --upgrade transformers
# ! pip3 install git+https://github.com/huggingface/transformers

Collecting transformers
  Downloading transformers-4.15.0-py3-none-any.whl (3.4 MB)
[K     |████████████████████████████████| 3.4 MB 5.0 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 6.5 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 55.5 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 73.4 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 57.4 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  A

In [None]:
! python -V

In [None]:
! pip install --upgrade torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-0.6.2-py3-none-any.whl (332 kB)
[?25l[K     |█                               | 10 kB 41.1 MB/s eta 0:00:01[K     |██                              | 20 kB 9.7 MB/s eta 0:00:01[K     |███                             | 30 kB 8.5 MB/s eta 0:00:01[K     |████                            | 40 kB 8.1 MB/s eta 0:00:01[K     |█████                           | 51 kB 4.8 MB/s eta 0:00:01[K     |██████                          | 61 kB 5.3 MB/s eta 0:00:01[K     |███████                         | 71 kB 5.2 MB/s eta 0:00:01[K     |███████▉                        | 81 kB 5.9 MB/s eta 0:00:01[K     |████████▉                       | 92 kB 4.6 MB/s eta 0:00:01[K     |█████████▉                      | 102 kB 5.1 MB/s eta 0:00:01[K     |██████████▉                     | 112 kB 5.1 MB/s eta 0:00:01[K     |███████████▉                    | 122 kB 5.1 MB/s eta 0:00:01[K     |████████████▉                   | 133 kB 5.1 MB/s eta 0:0

# Import

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils import rnn
from transformers import BertTokenizer
import torchaudio
import torchaudio.transforms as T
from sklearn import preprocessing
import pandas as pd
import numpy as np
import os

# For Configuration
from dataclasses import dataclass

# For Model
from transformers import AutoTokenizer, AutoModelForPreTraining, AutoModel
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

# For Train and Valid and Test
from tqdm.auto import tqdm
from torchmetrics import SpearmanCorrcoef

# To control which GPU to run on
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Configuration

Users only need to adjust the parameters in this section.

In [None]:
@dataclass
class BaseCFG:
    """Root configuration shared by every experiment in this notebook.

    The notebook passes config *classes* around (e.g. ``curCFG = SomeCFG``) and
    reads their class attributes; it never instantiates them. Fields declared
    here without a default therefore only exist on a subclass that assigns them.
    """
    # Mini-batch size handed to every DataLoader.
    batch_size: int
    # Learning rate for the classifier head (``final_classifier``).
    downstream_lr: float 
    # Learning rate for the pretrained audio encoder.
    audio_encoder_lr: float 
    # Weight decay passed to AdamW.
    weight_decay : float 
    # Hugging Face model name loaded through AutoModel.from_pretrained.
    audio_encoder_model : str 
    # Input width of the classifier head (the encoder's hidden size).
    audio_embedding : int 
    # Width of the hidden layers in the classifier head.
    hidden_dim : int 
    # Output width of the classifier head.
    intent_dim : int 
    
    # 'fsc' or 'asrglue'; selects the branch in get_dataloaders.
    dataset : str
    # "prediction" selects MSE loss; any other value is treated as classification.
    task_type : str
    # max_length passed to the text tokenizer in the FSC dataset
    # (the original author flagged this field as unverified).
    max_length : int
    # Root directory of the dataset files.
    data_root : str

    # For Asrglue
    # subtask and noise_level are used to build the ASRGLUE CSV file name.
    subtask : str = ""
    noise_level: str = ""
    
    # Whether the audio encoder's parameters receive gradients.
    trainable : bool = True
    # Number of DataLoader worker processes.
    num_of_workers : int = 2

    # Prefix of the checkpoint path; "no_clap_pretrain_best.pt" is appended to it.
    project_root : str = "/content/drive/Shareddrives/miulab/checkpoints/"
    # patience / factor are forwarded to ReduceLROnPlateau.
    patience : int = 1
    factor : float = 0.8

    # Evaluated once, when this class body is executed.
    device : torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): not read by the code visible in this notebook; the loss
    # functions hard-code reduction='mean'.
    reduction : str = "mean"
    # Name passed to DistilBertTokenizer.from_pretrained.
    text_tokenizer : str = "distilbert-base-uncased"
    # Number of training epochs.
    epochs : int = 30
    # When True, training starts from the existing best checkpoint.
    checkpoint_continue : bool = False

    # for random split ASRGLUE dataset
    # Runs as a side effect when the class body executes: seeds torch's global RNG.
    torch.manual_seed(0)

@dataclass
class FluentSpeechCFG(BaseCFG):
    """Settings common to all Fluent Speech Commands (FSC) runs.

    The un-annotated assignments are plain class attributes, which is how the
    notebook reads them (configs are used as classes, not instances).
    """
    dataset = "fsc"
    task_type = "classification"
    # Passed as max_length to the DistilBERT tokenizer for the transcription.
    max_length = 128
    data_root = "/content/drive/Shareddrives/miulab/fluent_speech_commands_dataset/"
    # Output width of the classifier head for this dataset.
    intent_dim = 31

@dataclass
class AsrglueCFG(BaseCFG):
    """Settings common to all ASRGLUE runs; subtask classes refine it."""
    dataset = "asrglue"
    max_length = 512 # unused: ASRGLUEDataset never reads max_length
    data_root = "/content/drive/Shareddrives/miulab/ASRGLUE/dev/"
 
@dataclass
class AsrglueStsbCFG(AsrglueCFG):
    """ASRGLUE STS-B subtask: the only "prediction" (regression) task, one output."""
    task_type = "prediction"
    subtask = "sts-b"
    intent_dim = 1
    noise_level = "low"
    downstream_lr = 1e-5

@dataclass
class AsrglueSst2CFG(AsrglueCFG):
    """ASRGLUE SST-2 subtask: two-way classification."""
    task_type = "classification"
    subtask : str = "sst-2"
    intent_dim = 2
    noise_level = "low"
    downstream_lr = 1e-3

@dataclass
class AsrglueRteCFG(AsrglueCFG):
    """ASRGLUE RTE subtask: two-way classification over a pair of utterances."""
    task_type = "classification"
    subtask : str = "rte"
    intent_dim = 2
    noise_level = "low"
    downstream_lr = 1e-5

@dataclass
class AsrglueQqpCFG(AsrglueCFG):
    """ASRGLUE QQP subtask: two-way classification over a pair of utterances."""
    task_type = "classification"
    subtask : str = "qqp"
    intent_dim = 2
    noise_level = "low"
    downstream_lr = 1e-4

@dataclass
class AsrglueQnliCFG(AsrglueCFG):
    """ASRGLUE QNLI subtask: two-way classification over a pair of utterances."""
    task_type = "classification"
    subtask : str = "qnli"
    intent_dim = 2
    noise_level = "low"
    downstream_lr = 1e-5

@dataclass
class AsrglueScitailCFG(AsrglueCFG):
    """ASRGLUE SciTail subtask: two-way classification over a pair of utterances."""
    task_type = "classification"
    subtask : str = "scitail"
    intent_dim = 2
    noise_level = "low"
    downstream_lr = 1e-5

@dataclass
class HubertAsrglueStsbCFG(AsrglueStsbCFG):
    """HuBERT-base encoder on ASRGLUE STS-B."""
    batch_size = 1
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/hubert-base-ls960'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "hubert_asrglue_stsb_test8_"

@dataclass
class Wav2Vec2AsrglueStsbCFG(AsrglueStsbCFG):
    """wav2vec 2.0 base encoder on ASRGLUE STS-B."""
    batch_size = 2
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/wav2vec2-base-960h'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "wav2vec2_asrglue_stsb_"

@dataclass
class HubertAsrglueSst2CFG(AsrglueSst2CFG):
    """HuBERT-base encoder on ASRGLUE SST-2."""
    batch_size = 2
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/hubert-base-ls960'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "hubert_asrglue_sst2_"

@dataclass
class Wav2Vec2AsrglueSst2CFG(AsrglueSst2CFG):
    """wav2vec 2.0 base encoder on ASRGLUE SST-2."""
    batch_size = 1
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/wav2vec2-base-960h'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "wav2vec2_asrglue_sst2_"

@dataclass
class HubertAsrglueRteCFG(AsrglueRteCFG):
    """HuBERT-base encoder on ASRGLUE RTE."""
    batch_size = 1
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/hubert-base-ls960'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "hubert_asrglue_rte_"

@dataclass
class Wav2Vec2AsrglueRteCFG(AsrglueRteCFG):
    """wav2vec 2.0 base encoder on ASRGLUE RTE."""
    batch_size = 2
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/wav2vec2-base-960h'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "wav2vec2_asrglue_rte_"

@dataclass
class HubertAsrglueQqpCFG(AsrglueQqpCFG):
    """HuBERT-base encoder on ASRGLUE QQP."""
    batch_size = 1
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/hubert-base-ls960'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "hubert_asrglue_qqp_"

@dataclass
class Wav2Vec2AsrglueQqpCFG(AsrglueQqpCFG):
    """wav2vec 2.0 base encoder on ASRGLUE QQP."""
    batch_size = 2
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/wav2vec2-base-960h'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "wav2vec2_asrglue_qqp_"

@dataclass
class HubertAsrglueQnliCFG(AsrglueQnliCFG):
    """HuBERT-base encoder on ASRGLUE QNLI."""
    batch_size = 1
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/hubert-base-ls960'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "hubert_asrglue_qnli_"

@dataclass
class Wav2Vec2AsrglueQnliCFG(AsrglueQnliCFG):
    """wav2vec 2.0 base encoder on ASRGLUE QNLI."""
    batch_size = 2
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/wav2vec2-base-960h'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "wav2vec2_asrglue_qnli_"

@dataclass
class HubertAsrglueScitailCFG(AsrglueScitailCFG):
    """HuBERT-base encoder on ASRGLUE SciTail."""
    batch_size = 2
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/hubert-base-ls960'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "hubert_asrglue_scitail_"

@dataclass
class Wav2Vec2AsrglueScitailCFG(AsrglueScitailCFG):
    """wav2vec 2.0 base encoder on ASRGLUE SciTail."""
    batch_size = 2
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/wav2vec2-base-960h'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "wav2vec2_asrglue_scitail_"

@dataclass
class HubertFscCFG(FluentSpeechCFG):
    """HuBERT-base encoder on Fluent Speech Commands.

    Read as a class (never instantiated), so every setting the training code
    accesses must exist as a class attribute here or on a base class.
    """
    batch_size = 4
    # Was missing: BaseCFG only annotates downstream_lr (no default) and
    # FluentSpeechCFG never assigns it, so reading `CFG.downstream_lr` raised
    # AttributeError for this config. The value mirrors Wav2Vec2FscCFG.
    downstream_lr = 1e-4
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/hubert-base-ls960'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "hubert_fsc_"

@dataclass
class Wav2Vec2FscCFG(FluentSpeechCFG):
    """wav2vec 2.0 base encoder on Fluent Speech Commands."""
    batch_size = 4
    downstream_lr = 1e-4
    audio_encoder_lr = 1e-5
    weight_decay = 1e-2
    audio_encoder_model = 'facebook/wav2vec2-base-960h'
    audio_embedding = 768
    hidden_dim = 256
    # Checkpoint path prefix for this experiment.
    project_root = BaseCFG.project_root + "wav2vec2_fsc_"


In [None]:
# Select the experiment by assigning one of the leaf config classes (the class
# itself, not an instance).
# curCFG = HubertAsrglueStsbCFG
curCFG = Wav2Vec2AsrglueScitailCFG
# Independent switches for the training loop and the final evaluation.
do_train = True
do_test = True

# Utils

In [None]:
class AverageMeter:
    """Running arithmetic mean of every value passed to ``update`` since ``reset``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.sum_val = 0.0
        self.count = 0

    def update(self, values):
        """Add a batch of values to the running mean.

        Args:
            values: sized collection of numbers (list or array); ``len(values)``
                is added to the count and their sum to the running total.
        """
        # Keep the total a plain Python float so `get` behaves the same whether
        # or not numpy scalars have been mixed in.
        self.sum_val += float(np.sum(values))
        self.count += len(values)

    def get(self):
        """Return the mean of all values seen, or ``nan`` if none were recorded.

        Returning ``nan`` replaces the previous ZeroDivisionError on an empty
        meter (e.g. an empty loader); ``nan`` compares False against any
        threshold, so a "best loss" check will not fire on it.
        """
        if self.count == 0:
            return float("nan")
        return self.sum_val / self.count

# Dataset

## BaseDataset

In [None]:
class BaseDataset(Dataset):
    """Common base for the notebook's datasets.

    Subclasses are expected to populate ``self.df`` (whose length defines the
    dataset size) and to implement item access and label listing.
    """

    def labels_list(self):
        """Return the label names; must be provided by the subclass."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Return the sample at ``idx``; must be provided by the subclass."""
        raise NotImplementedError

    def __len__(self):
        """Number of rows in ``self.df``."""
        return len(self.df)

## FSC Dataset

In [None]:
class FluentSpeechCommandsDataset(BaseDataset):
    """Fluent Speech Commands split backed by ``<data_root>/data/<split>_data.csv``.

    Each item is a dict holding the waveform, the encoded intent label, the
    DistilBERT token ids of the transcription, their count and the raw text.
    """

    def __init__(self, data_root, split='train', intent_encoder=None):
        """Read the split's CSV, encode intents and load the tokenizer.

        Args:
            data_root: dataset root directory.
            split: one of 'train', 'test', 'valid'.
            intent_encoder: an already fitted label encoder to reuse; when
                None a new one is fitted on this split's intents.
        """
        assert split in ['train', 'test', 'valid'], 'Invalid split'
        print(f"init {split} dataset....")

        self.CFG = FluentSpeechCFG
        self.data_root = data_root

        csv_path = os.path.join(self.data_root, 'data/', '{}_data.csv'.format(split))
        self.df = pd.read_csv(csv_path)
        # The intent is the "action-object-location" triple joined into one string.
        self.df['intent'] = self.df[['action', 'object', 'location']].apply('-'.join, axis=1)

        if intent_encoder is None:
            intent_encoder = preprocessing.LabelEncoder().fit(self.df['intent'])
        self.intent_encoder = intent_encoder
        self.df['intent_label'] = intent_encoder.transform(self.df['intent'])

        # Row positions of every encoded label, keyed by label.
        self.labels_set = set(self.df['intent_label'])
        self.label2idx = {
            label: np.where(self.df['intent_label'] == label)[0]
            for label in self.labels_set
        }

        self.distilbert_tokenizer = DistilBertTokenizer.from_pretrained(self.CFG.text_tokenizer)

    def load_audio(self, idx):
        """Load row ``idx``: (squeezed waveform, intent label, tokenizer encoding, transcription)."""
        row = self.df.iloc[idx]
        audio_path = os.path.join(self.data_root, row['path'])
        waveform, sample_rate = torchaudio.load(audio_path)
        intent = row['intent_label']
        transcription = row['transcription']
        encoding = self.distilbert_tokenizer(
            transcription,
            padding=True,
            truncation=True,
            max_length=self.CFG.max_length
        )
        return waveform.squeeze(), intent, encoding, transcription

    def get_dict(self, waveform, intent, encoding, transcription, suffix=''):
        """Pack one sample into a dict; ``suffix`` is appended to every key."""
        token_ids = torch.tensor(encoding['input_ids']).flatten()
        fields = {
            'waveform': waveform,
            'label': intent,
            'encoded_text': token_ids,
            'text_length': token_ids.shape[0],
            'raw_text': transcription,
        }
        return {key + suffix: value for key, value in fields.items()}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return self.get_dict(*self.load_audio(idx))

    def labels_list(self):
        """Class names known to the intent encoder, in encoded order."""
        return self.intent_encoder.classes_

## ASRGLUE Dataset

In [None]:
class ASRGLUEDataset(BaseDataset):
    """One ASRGLUE subtask split, read from ``<subtask>_<noise_level>_<split>.csv``.

    Items are dicts with a 1-D ``waveform`` and a ``label``. For paired
    subtasks the two utterances are joined into one waveform with a run of
    zeros between them.
    """

    # Sample rate every clip is converted to before it is returned.
    TARGET_SAMPLE_RATE = 16000
    # Number of zero samples inserted between the two utterances of a pair.
    PAIR_GAP_SAMPLES = 10000
    # Subtasks whose rows reference a second audio file in the 'path2' column.
    PAIRED_SUBTASKS = ('sts-b', 'rte', 'qqp', 'qnli', 'scitail')

    def __init__(self, CFG, split='train', intent_encoder=None):
        """Read the split's CSV and, for classification tasks, encode the labels.

        Args:
            CFG: config class providing data_root, subtask, noise_level, task_type.
            split: one of 'train', 'test', 'valid'.
            intent_encoder: fitted label encoder to reuse; when None (and the
                task is classification) a new one is fitted on this split.
        """
        assert split in ['train', 'test', 'valid'], 'Invalid split'

        self.CFG = CFG

        self.data_root = CFG.data_root
        self.df = pd.read_csv(os.path.join(self.data_root, '{}_{}_{}.csv'.format(self.CFG.subtask, self.CFG.noise_level, split)))

        if self.CFG.task_type == 'classification' :
            if intent_encoder is None:
                intent_encoder = preprocessing.LabelEncoder()
                intent_encoder.fit(self.df['label'])
            self.intent_encoder = intent_encoder
            self.df['label'] = intent_encoder.transform(self.df['label'])

            # Row positions of every encoded label, keyed by label.
            self.labels_set = set(self.df['label'])
            self.label2idx = {}
            for label in self.labels_set:
                idx = np.where(self.df['label'] == label)[0]
                self.label2idx[label] = idx

    def _load_clip(self, relative_path):
        """Load one audio file of the current subtask, resampled to TARGET_SAMPLE_RATE and squeezed."""
        filename = os.path.join(self.data_root, self.CFG.subtask, relative_path)
        waveform, sample_rate = torchaudio.load(filename)
        if sample_rate != self.TARGET_SAMPLE_RATE:
            waveform = T.Resample(sample_rate, self.TARGET_SAMPLE_RATE)(waveform)
        return waveform.squeeze()

    def load_audio(self, idx):
        """Return ``(waveform, label)`` for row ``idx``.

        Fix: the second utterance used to be overwritten with a resampled copy
        of the *first* one whenever its sample rate differed from the target;
        each clip is now resampled from its own audio.
        """
        df_row = self.df.iloc[idx]
        waveform = self._load_clip(df_row['path'])

        if self.CFG.subtask in self.PAIRED_SUBTASKS:
            second = self._load_clip(df_row['path2'])
            # np.zeros is float64, so the joined waveform is float64 as before;
            # the training code casts it to float32 before the model.
            waveform = torch.tensor(np.concatenate([waveform, np.zeros(self.PAIR_GAP_SAMPLES), second]))

        return waveform, df_row['label']

    def get_dict(self, waveform, label, suffix=''):
        """Pack one sample into a dict; ``suffix`` is appended to every key."""
        ret_dict = {
            'waveform':waveform,
            'label':label,
        }
        return {k+suffix:v for k,v in ret_dict.items()}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        waveform, label = self.load_audio(idx)
        return self.get_dict(waveform, label)

    def labels_list(self):
        """Class names known to the label encoder (classification tasks only)."""
        assert self.CFG.task_type == 'classification', 'wrong task_type, required classification'
        return self.intent_encoder.classes_

## Dataloader

In [None]:
def fsc_collate_classifier(inputs):
    """Collate FSC samples into one batch dict.

    Waveforms and token-id sequences are zero-padded to the longest item
    (batch first); labels and text lengths become long tensors; raw
    transcriptions are returned as a plain list.
    """
    waveforms, labels, token_ids, lengths, transcripts = [], [], [], [], []
    for sample in inputs:
        waveforms.append(sample['waveform'])
        labels.append(sample['label'])
        token_ids.append(sample['encoded_text'])
        lengths.append(sample['text_length'])
        transcripts.append(sample['raw_text'])

    return {
        'waveform': rnn.pad_sequence(waveforms, batch_first=True),
        'label': torch.tensor(labels, dtype=torch.long),
        'encoded_text': rnn.pad_sequence(token_ids, batch_first=True),
        'text_length': torch.tensor(lengths, dtype=torch.long),
        'raw_text': transcripts,
    }

def asrglue_collate_classifier(inputs, task_type=None):
    """Collate ASRGLUE samples into one batch dict.

    Args:
        inputs: list of sample dicts with 'waveform' (1-D tensor) and 'label'.
        task_type: "prediction" yields float labels (regression); any other
            value yields long labels (classification). When omitted, falls
            back to the notebook-level ``curCFG.task_type`` as before, so
            existing ``collate_fn=asrglue_collate_classifier`` usage is
            unchanged. Passing it explicitly removes the dependency on that
            global.

    Returns:
        dict with 'waveform' (zero-padded, batch first) and 'label'.
    """
    if task_type is None:
        task_type = curCFG.task_type
    label_dtype = torch.float if task_type == "prediction" else torch.long

    padded_waveforms = rnn.pad_sequence([data['waveform'] for data in inputs], batch_first=True)
    labels = torch.tensor([data['label'] for data in inputs], dtype=label_dtype)

    return {
        'waveform' : padded_waveforms,
        'label':labels,
    }

def get_dataloaders(CFG, *args, **kwargs):
    """Build the train / validation / test DataLoaders for ``CFG.dataset``.

    Args:
        CFG: config class providing data_root, batch_size, dataset,
            num_of_workers (and subtask / noise_level for ASRGLUE).
        *args, **kwargs: forwarded to the dataset constructors.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if ``CFG.dataset`` is neither 'fsc' nor 'asrglue'.
    """
    data_root = CFG.data_root
    batch_size = CFG.batch_size
    dataset = CFG.dataset
    num_workers = CFG.num_of_workers
    print(f"dataset: {dataset}")

    if dataset == 'fsc':
        train_dataset = FluentSpeechCommandsDataset(data_root, 'train', *args, **kwargs)
        # Validation and test reuse the encoder fitted on the training split.
        val_dataset = FluentSpeechCommandsDataset(data_root, 'valid', train_dataset.intent_encoder, *args, **kwargs)
        test_dataset = FluentSpeechCommandsDataset(data_root, 'test', train_dataset.intent_encoder, *args, **kwargs)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=fsc_collate_classifier, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=fsc_collate_classifier, num_workers=num_workers)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=fsc_collate_classifier, num_workers=num_workers)

    elif dataset == 'asrglue':
        print(f"subtask: {CFG.subtask} noise_level: {CFG.noise_level}")
        full_train_dataset = ASRGLUEDataset(CFG, 'train', *args, **kwargs)
        datalen = len(full_train_dataset)
        train_len = int(datalen * 0.8)

        # Use a dedicated, freshly seeded generator so the 80/20 split is the
        # same on every call. Relying on the global RNG made the split depend
        # on whatever random ops ran earlier in the session, so a resumed run
        # could validate on samples it had already trained on.
        split_rng = torch.Generator().manual_seed(0)
        train_dataset, val_dataset = random_split(
            full_train_dataset, [train_len, datalen - train_len], generator=split_rng
        )

        # Share the label encoder fitted on the training CSV with the test
        # split (as the FSC branch does) so label ids cannot diverge. Only
        # injected when the caller did not supply an encoder; for regression
        # tasks the train dataset has no encoder and None is passed.
        test_kwargs = dict(kwargs)
        if not args:
            test_kwargs.setdefault('intent_encoder', getattr(full_train_dataset, 'intent_encoder', None))
        test_dataset = ASRGLUEDataset(CFG, 'test', *args, **test_kwargs)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=asrglue_collate_classifier, shuffle=True, drop_last=True, num_workers=num_workers)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=asrglue_collate_classifier, num_workers=num_workers)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=asrglue_collate_classifier, num_workers=num_workers)

    else:
        raise ValueError('Invalid dataset, check CFG.dataset')

    return train_loader, val_loader, test_loader

# Model

## Note
We use DistilBertTokenizer instead of AutoTokenizer

In [None]:
class AudioEncoder(torch.nn.Module):
    """Pretrained audio model wrapper that returns one vector per input.

    The vector is the first time step of the model's ``last_hidden_state``.
    """

    def __init__(self, model_name, trainable=True):
        """Load ``model_name`` via AutoModel and set whether its weights train."""
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        # Same effect as setting requires_grad on every parameter one by one.
        self.model.requires_grad_(trainable)

    def forward(self, x):
        """Encode ``x`` and keep only position 0 along the second axis."""
        hidden_states = self.model(x).last_hidden_state
        return hidden_states[:, 0, :]

class E2ESLU(torch.nn.Module):
    """End-to-end model: audio encoder followed by a 4-layer MLP head."""

    def __init__(
        self,
        CFG: BaseCFG
    ):
        """Build the encoder and the head from ``CFG``.

        Reads audio_encoder_model, audio_embedding, trainable, intent_dim and
        hidden_dim from the config.
        """
        super().__init__()

        # Read every setting up front, before the pretrained model is loaded.
        model_name, embedding, trainable, intent_dim, hidden_dim = (
            CFG.audio_encoder_model,
            CFG.audio_embedding,
            CFG.trainable,
            CFG.intent_dim,
            CFG.hidden_dim,
        )

        self.audio_encoder = AudioEncoder(model_name, trainable)
        # Re-applies the trainable flag to every encoder parameter.
        for param in self.audio_encoder.parameters():
            param.requires_grad = trainable

        nn = torch.nn
        # embedding -> hidden -> hidden -> 64 -> intent_dim; the layer order
        # fixes the state-dict keys, so it must not change.
        self.final_classifier = nn.Sequential(
            nn.Linear(embedding, hidden_dim),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(inplace=True),
            nn.Linear(hidden_dim, 64),
            nn.LeakyReLU(inplace=True),
            nn.Linear(64, intent_dim),
        )

    def forward(self, x):
        """Return the head's output for the encoded audio ``x``."""
        return self.final_classifier(self.audio_encoder(x))

# Train, Valid and Test

## Train

In [None]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, step, CFG):
    """Run one training pass over ``train_loader``.

    Args:
        model: network mapping a float waveform batch to per-sample outputs.
        train_loader: yields dicts with 'waveform' and 'label'.
        optimizer: optimizer stepped after every batch.
        lr_scheduler: stepped after every batch only when ``step == "batch"``.
        step: "batch" or "epoch" (epoch-level stepping is left to the caller).
        CFG: config providing ``device`` and ``task_type``.

    Returns:
        AverageMeter holding the mean of the per-batch losses.
    """
    device = CFG.device
    is_regression = CFG.task_type == "prediction"

    # "prediction" is regression (MSE); everything else is classification,
    # matching valid_epoch. Previously an unknown task type left loss_fn unbound.
    if is_regression:
        loss_fn = torch.nn.MSELoss(reduction='mean')
    else:
        loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')

    loss_meter = AverageMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))

    train_total = 0
    train_acc = 0

    for batch in tqdm_object:
        # (Removed a leftover `print(batch)` that dumped every tensor each step.)
        output = model(batch['waveform'].to(device=device, dtype=torch.float))
        target = batch['label'].to(device)
        if is_regression:
            # Bring output and target to the same (1, batch) layout for MSELoss;
            # assumes the head emits one value per sample — TODO confirm.
            output = torch.transpose(output, 0, 1)
            target = target.view(1, -1)
            pred = output
        else:
            # classification
            pred = torch.argmax(output, dim=1)

        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        train_total += batch['waveform'].size(0)
        # For regression this counts exact float matches, so the displayed
        # "accuracy" is only meaningful for classification.
        train_acc += (pred.to("cpu") == target.to("cpu")).sum().item()

        loss_meter.update([loss.item()])

        tqdm_object.set_postfix(train_loss=loss.item(), train_acc=train_acc/train_total)

    return loss_meter


## Valid

In [None]:
def valid_epoch(model, valid_loader, CFG):
    """Evaluate ``model`` on ``valid_loader`` and return the loss meter.

    Uses MSE loss when ``CFG.task_type`` is "prediction" and cross-entropy
    otherwise. Does not change the model's mode or gradient state; the caller
    handles that.
    """
    device, task_type = CFG.device, CFG.task_type
    is_regression = task_type == "prediction"

    running_loss = AverageMeter()
    criterion = (
        torch.nn.MSELoss(reduction='mean')
        if is_regression
        else torch.nn.CrossEntropyLoss(reduction='mean')
    )

    seen, hits = 0, 0
    progress = tqdm(valid_loader, total=len(valid_loader))
    for batch in progress:
        waveform = batch['waveform']
        output = model(waveform.to(device=device, dtype=torch.float))
        target = batch['label'].to(device)

        if is_regression:
            # Same (1, batch) layout for output and target before the loss.
            output = torch.transpose(output, 0, 1)
            target = target.view(1, -1)
            pred = output
        else:
            pred = torch.argmax(output, dim=1)

        loss = criterion(output, target)

        seen += waveform.size(0)
        hits += (pred.to("cpu") == target.to("cpu")).sum().item()
        running_loss.update([loss.item()])

        progress.set_postfix(valid_loss=running_loss.get(), valid_acc=hits/seen)
    return running_loss


## Test

In [None]:
def test(model, test_loader, CFG):
    """Load the best checkpoint and evaluate it on ``test_loader``.

    Args:
        model: network whose state dict matches the saved checkpoint.
        test_loader: yields dicts with 'waveform' and 'label'.
        CFG: config providing project_root, device, task_type and the learning
            rates (only printed).

    Returns:
        Spearman correlation between outputs and targets when
        ``CFG.task_type == "prediction"``; otherwise the classification
        accuracy (``nan`` if the loader yields no samples).
    """
    pretrain_path = CFG.project_root+"no_clap_pretrain_best.pt"
    print(f'start test. pretrain_path: {pretrain_path}')
    print(f'audio_encoder_lr: {CFG.audio_encoder_lr} downstream_lr: {CFG.downstream_lr}')

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
    model.load_state_dict(torch.load(pretrain_path, map_location=CFG.device))
    model.eval()

    tqdm_object = tqdm(test_loader, total=len(test_loader))

    # Evaluation never needs gradients; guarding here keeps memory flat even
    # if a caller forgets to wrap the call.
    with torch.no_grad():
        if CFG.task_type == "prediction":
            # Collect per-batch chunks and concatenate once instead of
            # re-copying the growing tensor on every batch.
            output_chunks = []
            target_chunks = []
            for batch in tqdm_object:
                output = model(batch['waveform'].to(device=CFG.device, dtype=torch.float))
                target = batch['label'].to(CFG.device)
                output_chunks.append(output.reshape(-1))
                target_chunks.append(target)

            if output_chunks:
                outputs = torch.cat(output_chunks, 0)
                targets = torch.cat(target_chunks, 0)
            else:
                outputs = torch.tensor([]).to(CFG.device)
                targets = torch.tensor([]).to(CFG.device)

            print(f"outputs: {outputs}")
            print(f"targets: {targets}")
            spearman = SpearmanCorrcoef()
            return spearman(outputs, targets).item()

        test_total = 0
        test_acc = 0
        for batch in tqdm_object:
            output = model(batch['waveform'].to(device=CFG.device, dtype=torch.float))
            target = batch['label'].to(CFG.device)

            # classification
            pred = torch.argmax(output, dim=1)

            test_total += batch["waveform"].size(0)
            test_acc += (pred.to("cpu") == target.to("cpu")).sum().item()

            tqdm_object.set_postfix(test_acc=test_acc/test_total)

    # An empty loader used to raise ZeroDivisionError here.
    if test_total == 0:
        return float("nan")
    return test_acc/test_total


# Main

In [None]:
# Module-level slot for the final test metric, so the value is not lost once
# the run finishes. A function has to declare `global final_test_acc` to update it.
final_test_acc = 0.0

In [None]:
def train_and_valid_and_test(CFG: BaseCFG, do_train, do_test):
    """Train, validate and/or test the model described by ``CFG``.

    Args:
        CFG: config class for the experiment.
        do_train: run the training loop, saving the weights with the lowest
            validation loss to ``CFG.project_root + "no_clap_pretrain_best.pt"``.
        do_test: evaluate the saved best checkpoint on the test split.

    Returns:
        The test metric when ``do_test`` is true, otherwise None. The metric is
        also stored in the module-level ``final_test_acc``.
    """
    # Without this declaration the assignment below created a local and the
    # module-level `final_test_acc` was never updated.
    global final_test_acc

    checkpoint_path = CFG.project_root+"no_clap_pretrain_best.pt"

    train_loader, valid_loader, test_loader = get_dataloaders(CFG)
    print(f"audio_encoder_model: {CFG.audio_encoder_model}")
    models = E2ESLU(CFG).to(CFG.device, dtype=torch.float)

    if do_train:
        if CFG.checkpoint_continue:
            # map_location lets a GPU-saved checkpoint load on a CPU-only runtime.
            models.load_state_dict(torch.load(checkpoint_path, map_location=CFG.device))

        # Separate learning rates for the pretrained encoder and the new head.
        params = [
            {"params": models.audio_encoder.parameters(), "lr": CFG.audio_encoder_lr},
            {"params": models.final_classifier.parameters(), "lr": CFG.downstream_lr},
        ]
        optimizer = torch.optim.AdamW(
            params=params,
            weight_decay=CFG.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            patience=CFG.patience,
            factor=CFG.factor,
        )
        # The plateau scheduler is driven by the validation loss once per epoch.
        step = "epoch"

        best_loss = float("inf")

        for epoch in range(CFG.epochs):
            print(f"Epoch: {epoch+1}")
            models.train()
            train_loss = train_epoch(models, train_loader, optimizer, lr_scheduler, step, CFG)
            models.eval()
            with torch.no_grad():
                valid_loss = valid_epoch(models, valid_loader, CFG)

            valid_loss_avg = valid_loss.get()
            if valid_loss_avg < best_loss:
                best_loss = valid_loss_avg
                torch.save(models.state_dict(), checkpoint_path)
                print("Saved Best Model!")

            if (step == "epoch"):
                lr_scheduler.step(valid_loss_avg)

    result = None
    if do_test:
        with torch.no_grad():
            final_test_acc = test(models, test_loader, CFG)
            print(f"Finish! Result: test_acc: {final_test_acc}")
        result = final_test_acc
    return result

In [None]:
# Release unused GPU memory cached by PyTorch's allocator before building the model.
torch.cuda.empty_cache()
# Run the selected experiment end to end.
train_and_valid_and_test(curCFG, do_train, do_test)