In [1]:
import copy
import json
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm, trange
from transformers import AutoConfig, HubertModel, Wav2Vec2FeatureExtractor
from transformers.optimization import AdamW

from models.hubert_selective import HuBERTSelectiveNet
from utils.model_tools import *
from utils.selective_loss import SelectiveLoss

2023-11-26 16:55:46.966132: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
# Quick look at the saved queen / no-queen label array (first 20 entries).
# NOTE(review): `labels` is reassigned by the model-construction cell further down.
labels = np.load("data/queen_and_no_queen_labels.npy")
print(labels[0:20])

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [2]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

### Loading in Vocal Imitation Data

In [3]:
# Location of the vocal-imitation folds, and the HuBERT checkpoint id shared by the
# encoder and its feature extractor (both are loaded from `model_id` below).
data_dir = 'data/vocal_imitation-full'
model_id = "facebook/hubert-base-ls960"

def prepare_dataset(batch, feature_extractor, sampling_rate: int = 16000):
    """Collate ``(waveform, label)`` pairs into padded model inputs and a label tensor.

    Its signature matches a ``DataLoader`` ``collate_fn`` once ``feature_extractor`` is bound
    (e.g. with ``functools.partial``).

    Args:
        batch: non-empty sequence of ``(waveform, label)`` pairs.
        feature_extractor: callable such as ``Wav2Vec2FeatureExtractor``. It is called with
            ``return_tensors="pt"`` and ``padding=True`` so clips of different lengths can be batched.
        sampling_rate: sampling rate of the waveforms in Hz (defaults to the 16000 used in this notebook).

    Returns:
        ``(features, labels)``: the extractor output and ``torch.tensor`` of the labels.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        # zip(*[]) would otherwise fail with an opaque "not enough values to unpack".
        raise ValueError("prepare_dataset received an empty batch")

    waveforms, labels = zip(*batch)
    features = feature_extractor(
        list(waveforms),
        return_tensors="pt",
        padding=True,
        sampling_rate=sampling_rate,
    )
    return features, torch.tensor(labels)

# Pretrained HuBERT encoder and its matching feature extractor (same checkpoint id).
hubert_model = HubertModel.from_pretrained(model_id)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_id)

# The three folds, constructed in fold00 -> fold01 -> fold02 order.
fold00, fold01, fold02 = (
    VocalImitationDataset(data_dir, fold_name=name)
    for name in ('fold00', 'fold01', 'fold02')
)

# Label-vocabulary size, read from the first fold.
num_classes = len(fold00.vocab_list)
print(num_classes)
fold_set = {fold00, fold01, fold02}

data/vocal_imitation-full/labelvocabulary.csv
data/vocal_imitation-full/fold00.json
data/vocal_imitation-full/labelvocabulary.csv
data/vocal_imitation-full/fold01.json
data/vocal_imitation-full/labelvocabulary.csv
data/vocal_imitation-full/fold02.json
302


In [4]:
from transformers import HubertForSequenceClassification, HubertConfig

class HuBERTSelectiveNet(torch.nn.Module):
    """SelectiveNet heads on top of a (HuBERT) speech encoder.

    The encoder's ``last_hidden_state`` is mean-pooled over time and fed to three heads,
    named after the SelectiveNet paper:
    the classifier f(), the selector g() (sigmoid output in [0, 1]) and the auxiliary
    classifier h().
    """

    def __init__(self, hubert_model, num_classes: int, feature_size: int, init_weights: bool = True):
        """
        Args:
            hubert_model: encoder module whose output exposes ``last_hidden_state``.
            num_classes: number of classes predicted by f() and h().
            feature_size: size of the last dimension of ``last_hidden_state``
                          (768 for facebook/hubert-base-ls960 in this notebook).
            init_weights: re-initialise the three heads with ``_initialize_weights``.
        """
        super(HuBERTSelectiveNet, self).__init__()
        self.hubert_model = hubert_model
        self.dim_features = feature_size
        self.num_classes = num_classes

        # Classifier represented as f() in the original paper
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.dim_features, self.num_classes)
        )

        # Selector represented as g() in the original paper.
        # NOTE(review): in train mode BatchNorm1d rejects a batch of a single sample.
        self.selector = torch.nn.Sequential(
            torch.nn.Linear(self.dim_features, self.dim_features),
            torch.nn.ReLU(True),
            # Normalises the pooled (batch, features) activations
            torch.nn.BatchNorm1d(self.dim_features),
            torch.nn.Linear(self.dim_features, 1),
            torch.nn.Sigmoid()
        )

        # Auxiliary classifier represented as h() in the original paper
        self.aux_classifier = torch.nn.Sequential(
            torch.nn.Linear(self.dim_features, self.num_classes)
        )

        # Initialize weights of heads if required (the encoder keeps its pretrained weights)
        if init_weights:
            self._initialize_weights(self.classifier)
            self._initialize_weights(self.selector)
            self._initialize_weights(self.aux_classifier)

    def forward(self, input_values):
        """Encode ``input_values`` and run the three heads on the time-averaged features.

        Returns:
            ``(prediction_out, selection_out, auxiliary_out)`` with shapes
            ``(B, num_classes)``, ``(B, 1)`` and ``(B, num_classes)``.
        """
        outputs = self.hubert_model(input_values)

        # assumes last_hidden_state is (batch, timesteps, features)
        x = outputs.last_hidden_state

        # Mean pooling over timesteps -> (batch, features).
        # NOTE(review): padded frames are included in the average.
        x = torch.mean(x, dim=1)

        prediction_out = self.classifier(x)
        selection_out = self.selector(x)
        auxiliary_out = self.aux_classifier(x)

        return prediction_out, selection_out, auxiliary_out

    def _initialize_weights(self, module):
        """Initialise every Conv2d / BatchNorm1d / Linear submodule of ``module`` in place."""
        for m in module.modules():
            if isinstance(m, torch.nn.Conv2d):
                # No Conv2d exists in the current heads; kept for heads that may add one.
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.BatchNorm1d):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.normal_(m.weight, 0, 0.01)
                torch.nn.init.constant_(m.bias, 0)

    def save_model(model, save_path, model_config):
        """Write ``model``'s state dict and ``model_config`` into the directory ``save_path``.

        Callable as ``net.save_model(path, config)`` or
        ``HuBERTSelectiveNet.save_model(net, path, config)``. ``model_config`` must be
        JSON-serialisable; ``load_model`` reads its ``num_classes`` and ``feature_size`` keys.
        The directory is created if it does not exist.
        """
        import json
        import os

        os.makedirs(save_path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(save_path, "model_state.pt"))
        with open(os.path.join(save_path, "config.json"), 'w') as f:
            json.dump(model_config, f)

    @staticmethod
    def load_model(load_path, hubert_model_class=None, model_id="facebook/hubert-base-ls960",
                   map_location="cpu"):
        """Rebuild a model saved by ``save_model`` and return it in eval mode.

        Static, so it works from the class or from an instance.

        Args:
            load_path: directory containing ``config.json`` and ``model_state.pt``.
            hubert_model_class: class to instantiate (defaults to ``HuBERTSelectiveNet``).
            model_id: pretrained HuBERT checkpoint used to build the encoder before the
                      saved weights are loaded over it.
            map_location: passed to ``torch.load``. The default "cpu" lets GPU-saved
                          checkpoints load on CPU-only machines. The weights are copied into
                          the freshly built (CPU) model either way.
        """
        import json
        import os

        with open(os.path.join(load_path, "config.json"), 'r') as f:
            model_config = json.load(f)

        if hubert_model_class is None:
            hubert_model_class = HuBERTSelectiveNet

        hubert_model = HubertModel.from_pretrained(model_id)
        model = hubert_model_class(hubert_model, model_config["num_classes"], model_config["feature_size"])

        state_dict = torch.load(os.path.join(load_path, "model_state.pt"), map_location=map_location)
        model.load_state_dict(state_dict)
        model.eval()
        return model

def selective_train(dataloader, model, selective_loss, optimizer, device,
                    log_every: int = 1000, max_grad_norm: float = 1.0) -> float:
    """Run one training epoch of a SelectiveNet-style model.

    Args:
        dataloader: yields ``(X, y)`` batches; both must support ``.to(device)``.
        model: returns ``(prediction_logits, selection_out, auxiliary_logits)`` for ``X``.
        selective_loss: callable taking ``prediction_out``, ``selection_out``, ``auxiliary_out``
            and ``target`` keywords. It returns a dict whose ``'loss'`` entry is backpropagated
            (e.g. ``SelectiveLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        log_every: print the current batch loss every ``log_every`` batches.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        Mean of the per-batch losses over the epoch, or 0.0 for an empty dataloader.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if num_batches == 0:
        # Nothing to train on; avoid the ZeroDivisionError of the final average.
        return 0.0

    running_loss = 0.0
    samples_seen = 0

    model.train()
    for batch_idx, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()

        logits, selection_out, auxiliary_logits = model(X)
        loss_dict = selective_loss(prediction_out=logits,
                                   selection_out=selection_out,
                                   auxiliary_out=auxiliary_logits,
                                   target=y.long())

        loss = loss_dict['loss']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        # Count samples via the label tensor: X may be a dict-like feature batch.
        samples_seen += len(y)

        if batch_idx % log_every == 0:
            print(f"loss: {batch_loss:>7f}  [{samples_seen:>5d}/{size:>5d}]")

    return running_loss / num_batches

class SelectiveLoss(torch.nn.Module):
    """SelectiveNet training objective.

    Combines three terms:
    - the selective risk of the prediction head f();
    - a quadratic penalty pushing the selection head g() towards the target coverage;
    - a plain cross entropy on the auxiliary head h().

        selective_loss = r^(f, g) + lm * max(0, coverage - phi^(g))^2
        loss_pytorch   = alpha * selective_loss + (1 - alpha) * CE(h(x), y)

    ``forward`` also returns ``'loss'``, the source implementation's variant. It uses the
    weighted risk *without* dividing by the empirical coverage. Alongside it, ``forward``
    returns several detached float metrics.
    """

    def __init__(self, loss_func, coverage: float, alpha: float = 0.5, lm: float = 32.0, device='cpu'):
        """
        Args:
            loss_func: base loss function. It must return one loss per sample, i.e. the shape of
                       loss_func(x, target) should be (B),
                       e.g.) torch.nn.CrossEntropyLoss(reduction='none') : classification.
                       A reducing loss (reduction='mean'/'sum') returns a scalar, which turns the
                       per-sample weighting by the selection head into a no-op; a warning is emitted.
            coverage: target coverage, in (0, 1].
            alpha: weight of the selective loss against the auxiliary cross entropy, in (0, 1].
            lm: Lagrange multiplier for coverage constraint. original experiment's value is 32.
            device: stored for backward compatibility only. All intermediate tensors are derived
                    from the inputs, so they live on the inputs' device.
        """
        super(SelectiveLoss, self).__init__()
        assert 0.0 < coverage <= 1.0
        assert 0.0 < lm
        assert 0.0 < alpha <= 1.0

        reduction = getattr(loss_func, 'reduction', 'none')
        if reduction != 'none':
            import warnings
            warnings.warn(
                "SelectiveLoss expects a per-sample loss_func (reduction='none'); got "
                "reduction={!r}, so the selection head cannot re-weight individual samples.".format(reduction),
                stacklevel=2,
            )

        self.loss_func = loss_func
        self.coverage = coverage
        self.lm = lm
        self.alpha = alpha
        self.device = device

    def forward(self, prediction_out, selection_out, auxiliary_out, target, threshold=0.5, mode='train'):
        """
        Args:
            prediction_out: (B, num_classes) logits of the prediction head f().
            selection_out:  (B, 1) selection scores of g().
            auxiliary_out:  (B, num_classes) logits of the auxiliary head h().
            target:         (B,) class indices.
            threshold:      selection threshold for the coverage, selective-accuracy and
                            test-risk metrics.
            mode: str (train/validation/test). 'validation' prefixes every key with 'val_';
                  'test' additionally reports 'test_selective_risk'.
        Returns:
            dict of float metrics plus the differentiable tensors 'loss' (the one used for
            training) and 'loss_pytorch'.
        """
        cross_entropy = torch.nn.CrossEntropyLoss()
        selection = selection_out.view(-1)

        # compute empirical coverage (=phi^)
        empirical_coverage = selection.mean()

        # compute empirical risk (=r^): per-sample loss weighted by g(), normalised by coverage
        empirical_risk = (self.loss_func(prediction_out, target) * selection).mean()
        empirical_risk = empirical_risk / empirical_coverage

        # compute penalty (=psi); derived from the inputs, so no device bookkeeping is needed
        penalty = self.lm * self._coverage_penalty(empirical_coverage)

        # compute selective loss (=L(f,g))
        selective_loss = empirical_risk + penalty

        # auxiliary head h(): plain cross entropy over all samples
        ce_loss = cross_entropy(auxiliary_out, target)

        # total loss
        loss_pytorch = self.alpha * selective_loss + (1.0 - self.alpha) * ce_loss

        # metrics and loss following the source implementation
        selective_head_coverage = self.get_coverage(selection_out, threshold)
        selective_head_selective_acc = self.get_selective_acc(prediction_out, selection_out, target, threshold)
        classification_head_acc = self.get_accuracy(auxiliary_out, target)
        selective_head_loss = self.get_selective_loss(prediction_out, selection_out, target)
        # identical computation to ce_loss, so reuse it instead of recomputing
        classification_head_loss = ce_loss

        loss = self.alpha * selective_head_loss + (1.0 - self.alpha) * classification_head_loss

        # loss information dict
        pref = 'val_' if mode == 'validation' else ''
        loss_dict = {}
        loss_dict['{}empirical_coverage'.format(pref)] = empirical_coverage.detach().cpu().item()
        loss_dict['{}empirical_risk'.format(pref)] = empirical_risk.detach().cpu().item()
        loss_dict['{}penalty'.format(pref)] = penalty.detach().cpu().item()
        loss_dict['{}selective_loss'.format(pref)] = selective_loss.detach().cpu().item()
        loss_dict['{}ce_loss'.format(pref)] = ce_loss.detach().cpu().item()
        loss_dict['{}loss_pytorch'.format(pref)] = loss_pytorch
        loss_dict['{}selective_head_coverage'.format(pref)] = selective_head_coverage.detach().cpu().item()
        loss_dict['{}selective_head_selective_acc'.format(pref)] = selective_head_selective_acc.detach().cpu().item()
        loss_dict['{}classification_head_acc'.format(pref)] = classification_head_acc.detach().cpu().item()
        loss_dict['{}selective_head_loss'.format(pref)] = selective_head_loss.detach().cpu().item()
        loss_dict['{}classification_head_loss'.format(pref)] = classification_head_loss.detach().cpu().item()
        loss_dict['{}loss'.format(pref)] = loss
        if mode == 'test':
            # empirical selective risk with rejection for test mode
            test_selective_risk = self.get_selective_risk(prediction_out, selection_out, target, threshold)
            loss_dict['test_selective_risk'] = test_selective_risk.detach().cpu().item()

        return loss_dict

    def _coverage_penalty(self, empirical_coverage):
        """Squared shortfall of ``empirical_coverage`` below the target coverage (0 once it is met)."""
        return torch.clamp(self.coverage - empirical_coverage, min=0.0) ** 2

    # based on source implementation
    def get_selective_acc(self, prediction_out, selection_out, target, threshold=0.5):
        """
        Equivalent to selective_acc function of source implementation: accuracy of f()
        over the samples whose selection score is strictly above ``threshold``.
        Args:
            prediction_out: (B,num_classes)
            selection_out:  (B, 1)
        Returns:
            0-d tensor; NaN when no sample is selected.
        """
        # reshape(B, -1) keeps a 1-D (B,) result even when B == 1 (squeeze made it 0-d,
        # which torch.dot rejects).
        scores = selection_out.reshape(len(selection_out), -1).mean(dim=1)
        g = (scores > threshold).float()
        correct = (torch.argmax(prediction_out, dim=-1) == target).float()
        return torch.dot(g, correct) / torch.sum(g)

    # based on source implementation
    def get_coverage(self, selection_out, threshold):
        """
        Equivalent to coverage function of source implementation: fraction of samples
        whose selection score is >= ``threshold``.
        Args:
            selection_out:  (B, 1)
        """
        g = (selection_out.squeeze(-1) >= threshold).float()
        return torch.mean(g)

    # based on source implementation
    def get_accuracy(self, auxiliary_out, target):  # TODO: Check implementation with Lili
        """
        Equivalent to "accuracy" in Tensorflow: plain top-1 accuracy of the auxiliary head.
        Args:
            auxiliary_out:  (B, num_classes)
            target:         (B,)
        """
        num = torch.sum((torch.argmax(auxiliary_out, dim=-1) == target).float())
        return num / len(auxiliary_out)

    # based on source implementation
    def get_selective_loss(self, prediction_out, selection_out, target):
        """
        Equivalent to selective_loss function of source implementation: weighted risk
        (not divided by coverage) plus the coverage penalty.
        Args:
            prediction_out: (B,num_classes)
            selection_out:  (B, 1)
        """
        ce = self.loss_func(prediction_out, target)
        empirical_risk_variant = torch.mean(ce * selection_out.view(-1))
        empirical_coverage = selection_out.mean()
        return empirical_risk_variant + self.lm * self._coverage_penalty(empirical_coverage)

    # selective risk in test mode
    def get_selective_risk(self, prediction_out, selection_out, target, threshold):
        """Mean loss over the samples accepted at ``threshold`` (hard 0/1 selection)."""
        g = (selection_out.squeeze(-1) >= threshold).float()
        empirical_coverage_rjc = torch.mean(g)
        empirical_risk_rjc = torch.mean(self.loss_func(prediction_out, target) * g.view(-1))
        empirical_risk_rjc /= empirical_coverage_rjc
        return empirical_risk_rjc

In [16]:
# Instantiate the HuBERT encoder + SelectiveNet heads.
# The feature size is the width of HuBERT's hidden states (last dim of last_hidden_state);
# it does not depend on the number of timesteps, so padding the audio does not change it.

inputs, labels = next(iter(fold00))
inputs = feature_extractor(inputs, return_tensors="pt", sampling_rate=16000)
with torch.no_grad():  # shape probe only: no autograd graph through the whole encoder
    outputs = hubert_model(inputs.input_values)
feature_size = outputs.last_hidden_state.shape[2]
assert feature_size == hubert_model.config.hidden_size, "probe disagrees with the HuBERT config"
print('features:', feature_size)

# selective_train moves every batch to `device`, so the weights (including the shared
# hubert_model) must live there too.
model = HuBERTSelectiveNet(hubert_model, num_classes=num_classes, feature_size=feature_size).to(device)
print(model.num_classes)

features: 768
302


In [19]:
# SelectiveLoss weights each sample's loss by the selector output, so the base loss must be
# per-sample (reduction='none'); a mean-reduced scalar would make that weighting a no-op.
loss_func = torch.nn.CrossEntropyLoss(reduction='none')
coverage = 0.8
alpha = 0.5
lm = 32.0
num_epochs = 10
batch_size = 10

loss_fn = SelectiveLoss(loss_func, coverage, alpha, lm, device=device)

# torch.optim.AdamW instead of the deprecated transformers.optimization.AdamW; eps and
# weight_decay are set to the transformers defaults so the optimisation is unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-6, weight_decay=0.0)

# Derived from the config so the name cannot drift (e.g. "-10ep-80c" = 10 epochs, 80% coverage).
model_file = f'models/selective-hubert-{num_epochs}ep-{round(coverage * 100)}c.pt'

In [None]:
train_losses_file = f'logs/selective-hubert-{num_epochs}ep-{round(coverage * 100)}c-train.txt'
test_losses_file = f'logs/selective-hubert-{num_epochs}ep-{round(coverage * 100)}c-test.txt'
for output_path in (train_losses_file, test_losses_file, model_file):
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

train_losses = []
test_losses = []

# Fixed fold order: iterating a set of datasets follows object hashes, which vary between runs.
folds = [fold00, fold01, fold02]
model_root, model_ext = os.path.splitext(model_file)

# Snapshot the untrained weights and optimizer so every fold starts from the same point.
# NOTE(review): assumes `model` / `optimizer` are still untrained when this cell starts
# (cells above were just run); re-running this cell alone would snapshot trained weights.
initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

for fold_idx, test_fold in enumerate(folds):
    # k-fold CV: hold out one fold for testing and train on the concatenation of the others.
    train_folds = [fold for fold in folds if fold is not test_fold]
    train_loader = DataLoader(torch.utils.data.ConcatDataset(train_folds), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_fold, batch_size=batch_size, shuffle=False)

    # Reset weights and optimizer state: carrying them over from the previous fold would leak
    # this fold's test data (training data for an earlier fold) into its evaluation.
    model.load_state_dict(initial_model_state)
    optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))
    fold_model_file = f'{model_root}-fold{fold_idx:02d}{model_ext}'

    # A fresh progress bar per fold (a single trange is exhausted after the first fold).
    for epoch in trange(num_epochs, desc=f'fold {fold_idx}'):
        print(f"Epoch {epoch+1}\n-------------------------------")
        train_loss = selective_train(train_loader, model, loss_fn, optimizer, device)
        # NOTE(review): selective_test is not defined in this notebook; presumably it comes
        # from the utils.model_tools star import.
        test_loss = selective_test(test_loader, model, device)
        train_losses.append(train_loss)
        test_losses.append(test_loss)

        torch.save(model.state_dict(), fold_model_file)

# One value per line, folds in order, epochs within each fold.
with open(train_losses_file, 'w') as fp:
    fp.writelines(f"{s}\n" for s in train_losses)

with open(test_losses_file, 'w') as fp:
    fp.writelines(f"{x}\n" for x in test_losses)

print("Done!")

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1
-------------------------------
