In [None]:
!pip install -q fasttext

import math
import json
import copy
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoTokenizer, BertModel
from functools import partial
import fasttext
import gc
import warnings
warnings.filterwarnings("ignore")

# Mount Google Drive so the project files stored under MyDrive are reachable from the runtime.
from google.colab import drive
drive.mount('/content/drive')
# Project root on Drive; the data, language-model, log, checkpoint and result paths
# used in this notebook are all built from this prefix.
# NOTE(review): the name `dir` shadows Python's built-in dir() for the rest of the notebook.
dir = '/content/drive/MyDrive/fewshot_medical'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Select the compute device once: the GPU when CUDA is available, otherwise the CPU.
# The model class and the training loop below read this module-level name.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initiate fixed random seed, for consistent and reproducible output.
def set_seed(seed: int = 42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and current CUDA
    device), then switches cuDNN to deterministic mode with autotuning off.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade a little speed for repeatable convolution algorithm selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all generators before anything random happens (model init, shuffling).
set_seed(42)

# Load named_entity_to_type_id, type_ids, generated_text_span.
# The encoding is pinned to UTF-8 so text containing non-ASCII characters is decoded
# the same way regardless of the platform's default locale encoding.
with open(f'{dir}/named_entity_to_type_id.json', 'r', encoding='utf-8') as f:
    named_entity_to_type_id = json.load(f)
with open(f'{dir}/type_ids.json', 'r', encoding='utf-8') as f:
    type_ids = json.load(f)
with open(f'{dir}/generated_text_span.json', 'r', encoding='utf-8') as f:
    generated_text_span = json.load(f)

In [None]:
# Define model class.
class EPNet(nn.Module):
    """Prototype-based span classifier on top of a frozen text encoder.

    A text span is embedded by a frozen BERT or fastText model, concatenated
    with a learned span-length embedding, projected by a feed-forward network
    onto the unit sphere, and classified by its squared Euclidean distance to
    a set of prototype vectors. Prototypes with index >= ``num_classes`` are
    additional "unknown" prototypes; predictions that land on one of them are
    reported as class 0 by ``recognize``.

    Args:
        hidden_dims: Widths of the hidden layers of the projection network.
        params: Hyperparameter dict. Required keys: ``num_classes``,
            ``additional_num_of_unknown_type``, ``dropout``, ``input_dim``,
            ``projection_embedding_dim``, ``prototype_train_learning_rate``,
            ``prototype_train_epochs``, ``prototype_train_patience``,
            ``max_span_length``, ``length_embedding_dim``, ``tau``,
            ``shot_sample_number``, ``validation_sample_number``,
            ``adapt_learning_rate``, ``adapt_epochs``, ``adapt_patience``,
            ``batch_size``. Optional keys: ``span_loss_weight`` (1.0),
            ``distance_loss_weight`` (0.5), ``margin`` (1.0).
        input_embedding_mode: ``'bert'`` or ``'fasttext'``. Any other value
            builds the network without an encoder; ``input_embed`` then
            raises ``ValueError``.
    """

    def __init__(self, hidden_dims, params, input_embedding_mode):
        super().__init__()

        # Receive hyperparameters.
        self.num_classes = params['num_classes']
        self.additional_num_of_unknown_type = params['additional_num_of_unknown_type']
        self.dropout = params['dropout']
        self.input_dim = params['input_dim']
        self.projection_embedding_dim = params['projection_embedding_dim']
        self.prototype_train_learning_rate = params['prototype_train_learning_rate']
        self.prototype_train_epochs = params['prototype_train_epochs']
        self.prototype_train_patience = params['prototype_train_patience']
        self.max_span_length = params['max_span_length']
        self.length_embedding_dim = params['length_embedding_dim']
        self.tau = params['tau']
        self.shot_sample_number = params['shot_sample_number']
        self.validation_sample_number = params['validation_sample_number']
        self.adapt_learning_rate = params['adapt_learning_rate']
        self.adapt_epochs = params['adapt_epochs']
        self.adapt_patience = params['adapt_patience']
        self.batch_size = params['batch_size']
        self.input_embedding_mode = input_embedding_mode

        # Weights for balancing the two loss terms.
        self.span_loss_weight = params.get('span_loss_weight', 1.0)
        self.distance_loss_weight = params.get('distance_loss_weight', 0.5)

        # Required separation between prototypes; see distance_loss().
        self.margin = params.get('margin', 1.0)

        layers = []
        dims = [self.input_dim+self.length_embedding_dim]
        dims.extend(hidden_dims)

        # Hidden layers: Linear -> BatchNorm -> GELU -> Dropout.
        for i in range(len(hidden_dims)):
            layers.append(nn.Linear(dims[i], dims[i+1]))
            layers.append(nn.BatchNorm1d(dims[i+1]))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(self.dropout))

        # Output layer followed by layer normalisation.
        layers.append(nn.Linear(dims[-1], self.projection_embedding_dim))
        layers.append(nn.LayerNorm(self.projection_embedding_dim))

        # Projection network.
        self.model = nn.Sequential(*layers)

        # Prototype vectors: random directions scaled to norm `alpha`.
        # NOTE(review): prototype_train() later rescales them to norm 10, not 20.
        alpha = 20
        self.prototypes = nn.Parameter(torch.randn(self.num_classes+self.additional_num_of_unknown_type, self.projection_embedding_dim))
        self.prototypes.data = F.normalize(self.prototypes.data, dim=1) * alpha

        # One learned embedding per span length (in whitespace tokens), capped at max_span_length.
        self.length_embeddings = nn.Parameter(torch.randn(self.max_span_length, self.length_embedding_dim))
        nn.init.xavier_uniform_(self.length_embeddings)

        # Initiate language model. The language model will not be fine-tuned to avoid overfitting.
        if input_embedding_mode == 'bert':
            # Load vanilla bert-base-cased model for input embedding.
            tokenizer = AutoTokenizer.from_pretrained(f'{dir}/language_model/{input_embedding_mode}/tokenizer')
            pretrained_model = BertModel.from_pretrained(f'{dir}/language_model/{input_embedding_mode}/transformer')
            pretrained_model.eval()
            pretrained_model.to(device)
            pretrained_model.resize_token_embeddings(tokenizer.vocab_size)
            # The encoder is an nn.Module and therefore becomes a registered submodule;
            # freeze its weights explicitly so it can never be updated by accident.
            pretrained_model.requires_grad_(False)
            self.tokenizer = tokenizer
            self.pretrained_model = pretrained_model
        elif input_embedding_mode == 'fasttext':
            # Load fasttext model for input embedding.
            pretrained_model = fasttext.load_model(f'{dir}/language_model/{input_embedding_mode}/cc.en.300.bin')
            self.tokenizer = None
            self.pretrained_model = pretrained_model

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen encoder in eval mode.

        Without this override ``model.train()`` would also put the registered
        BERT submodule into training mode and activate its dropout, making the
        supposedly fixed input embeddings stochastic.
        """
        super().train(mode)
        encoder = getattr(self, 'pretrained_model', None)
        if isinstance(encoder, nn.Module):
            encoder.eval()
        return self

    def input_embed(self, text_span):
        """Embed one text span as a ``(1, input_dim + length_embedding_dim)`` tensor.

        The encoder output is computed without gradients; the length embedding
        is concatenated outside ``no_grad`` so it stays trainable.

        Raises:
            ValueError: if ``input_embedding_mode`` is neither ``'bert'`` nor ``'fasttext'``.
        """
        # Span length in whitespace tokens, clamped to the last embedding row.
        length_index = min(len(text_span.split(' '))-1, self.max_span_length-1)
        length_embedding_tensor = self.length_embeddings[length_index].unsqueeze(0)

        if self.input_embedding_mode == 'bert':
            # A single sequence needs no padding. Padding to max_length made the
            # max/mean pooling below run over [PAD] positions as well.
            encodings = self.tokenizer(
                text_span,
                max_length=128,
                return_tensors='pt',
                truncation=True,
            )
            encodings = encodings.to(self.pretrained_model.device)
            with torch.no_grad():
                hidden_states = self.pretrained_model(**encodings).last_hidden_state
                # Average of max pooling and mean pooling over the token axis.
                max_pool_tensor, _ = torch.max(hidden_states, dim=1)
                avg_pool_tensor = torch.mean(hidden_states, dim=1)
                combined_tensor = (max_pool_tensor + avg_pool_tensor) / 2
            combined_tensor = combined_tensor.to(length_embedding_tensor.device)
            span_representation = torch.cat((combined_tensor, length_embedding_tensor), dim=-1)

        elif self.input_embedding_mode == 'fasttext':
            tokens = text_span.split(' ')
            normalized_embeddings = []
            for token in tokens:
                token_vector = self.pretrained_model.get_word_vector(token)
                norm = np.sqrt(np.sum(token_vector**2))
                if not norm == 0:
                    normalized_embeddings.append(token_vector/norm)
                else:
                    normalized_embeddings.append(token_vector)
            # Weighted average; weights grow quadratically with distance from index n/2.
            if len(normalized_embeddings) > 0:
                weights = np.array([1.0 + 0.1 * (i - len(normalized_embeddings)/2)**2 for i in range(len(normalized_embeddings))])
                weights = weights / weights.sum()
                mean_vector = np.average(normalized_embeddings, axis=0, weights=weights)
            else:
                mean_vector = np.zeros(self.input_dim)
            mean_tensor = torch.from_numpy(mean_vector).unsqueeze(0).float()
            # The numpy-derived tensor lives on the CPU; move it next to the length
            # embedding so the concatenation also works when the model is on CUDA.
            mean_tensor = mean_tensor.to(length_embedding_tensor.device)
            span_representation = torch.cat((mean_tensor, length_embedding_tensor), dim=-1)

        else:
            raise ValueError(f"Unsupported input_embedding_mode: {self.input_embedding_mode!r}")

        return span_representation

    def forward(self, input):
        """Project a batch of span representations and classify by nearest prototype.

        Args:
            input: Tensor of shape ``(batch, input_dim + length_embedding_dim)``.

        Returns:
            ``(projection, prediction, classification_result)``: unit-norm
            projections, softmax over negative squared prototype distances, and
            the argmax prototype index per row (may be >= ``num_classes``).
        """
        projection = self.model.forward(input)
        # Unit-normalise the projections before measuring distances.
        projection = F.normalize(projection, p=2, dim=1)
        # Use a local handle: assigning the result of .to() back to the registered
        # parameter fails whenever .to() returns a plain tensor.
        prototypes = self.prototypes.to(projection.device)
        dists = torch.cdist(projection, prototypes, p=2) ** 2
        # Temperature scaling for a sharper probability distribution.
        # NOTE(review): the temperature is hard-coded; `self.tau` is stored but not used here.
        temp = 0.1
        prediction = F.softmax(-dists/temp, dim=1)
        classification_result = torch.argmax(prediction, dim=1)

        return projection, prediction, classification_result

    def distance_loss(self):
        """Hinge penalty on pairwise cosine similarity between prototypes.

        Every off-diagonal pair whose cosine similarity exceeds ``1 - margin``
        is penalised by the excess. The previous form ``relu(sim - margin)``
        was identically zero for the default ``margin=1.0`` because a cosine
        similarity never exceeds 1, so prototype training had no gradient.

        Returns:
            Scalar tensor: mean penalty over all entries of the similarity matrix.
        """
        prototypes = F.normalize(self.prototypes, p=2, dim=1)
        num_types = prototypes.size(0)

        # Cosine similarity matrix.
        similarity = torch.mm(prototypes, prototypes.t())

        # Mask out self-similarity after the hinge so the diagonal never contributes.
        off_diagonal = 1 - torch.eye(num_types, device=similarity.device, dtype=similarity.dtype)
        violation = F.relu(similarity - (1.0 - self.margin)) * off_diagonal

        return violation.mean()

    def span_loss(self, y_hat_projection, y_classification_label):
        """Focal cross-entropy between projections and their target prototypes.

        Samples labelled 0 (unknown) are re-targeted to whichever of prototype 0
        or the additional unknown prototypes is currently nearest.

        Args:
            y_hat_projection: Tensor ``(n, projection_embedding_dim)``.
            y_classification_label: Sequence of ``n`` integer labels; not modified.

        Returns:
            Scalar tensor with the mean focal loss (gamma = 2).
        """
        prototypes = self.prototypes
        y_hat_projection = F.normalize(y_hat_projection.to(prototypes.device), p=2, dim=1)

        # Squared distances to every prototype.
        dist_matrix = torch.cdist(y_hat_projection, prototypes, p=2) ** 2

        # Work on a private copy so the caller's labels stay untouched.
        targets = torch.as_tensor(y_classification_label, dtype=torch.long).clone().to(dist_matrix.device)

        # Assign unknown instances to the nearest unknown prototype (index 0 or the extras).
        unknown_mask = targets == 0
        if bool(unknown_mask.any()):
            candidate_ids = torch.cat((
                torch.zeros(1, dtype=torch.long),
                torch.arange(self.num_classes, self.num_classes+self.additional_num_of_unknown_type, dtype=torch.long),
            )).to(dist_matrix.device)
            with torch.no_grad():
                nearest = dist_matrix[unknown_mask][:, candidate_ids].argmin(dim=1)
            targets[unknown_mask] = candidate_ids[nearest]

        logits = -dist_matrix

        # Focal weighting: down-weight examples that are already classified confidently.
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1-pt)**2 * ce_loss).mean()

        return focal_loss

    def prototype_train(self):
        """First training step: spread the prototypes apart using ``distance_loss``.

        Runs Adam on the prototypes only, with early stopping after
        ``prototype_train_patience`` non-improving epochs, then restores the
        best prototypes rescaled to norm 10.
        """
        best_distance_loss = float('inf')
        # Defined up front so the restore below also works with zero epochs.
        best_prototypes = self.prototypes.detach().clone()
        increase_count = 0

        optimizer = optim.Adam([self.prototypes], lr=self.prototype_train_learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=3)

        # Training loop
        for epoch in range(self.prototype_train_epochs):
            optimizer.zero_grad()
            loss = self.distance_loss()
            loss_value = loss.item()

            # Snapshot BEFORE the optimizer step so the stored prototypes are the
            # ones that actually produced `loss_value`.
            if loss_value < best_distance_loss:
                best_distance_loss = loss_value
                best_prototypes = self.prototypes.detach().clone()
                increase_count = 0
            else:
                increase_count += 1
                if increase_count >= self.prototype_train_patience:
                    break

            loss.backward()
            optimizer.step()
            scheduler.step(loss_value)

        # Restore the best prototypes, normalised to a fixed norm of 10.
        self.prototypes.data = F.normalize(best_prototypes.to(self.prototypes.device), p=2, dim=1) * 10
        # Drop the stale gradient so it cannot leak into later gradient-norm computations.
        self.prototypes.grad = None
        print(f"Result of initial prototypical network training loss: {best_distance_loss}")

        return

    # Modified Davies Bouldin Index (MDBI)
    def modified_davies_bouldin_index(self, key_projection_tensor):
        """Compute a Davies-Bouldin style cluster-quality index over known types.

        Args:
            key_projection_tensor: Dict mapping prototype index to a tensor of
                projections ``(n_i, projection_embedding_dim)``; may be empty.

        Returns:
            Mean over known types ``i`` (``0 < i < num_classes`` with samples) of
            ``max_j (s_i + s_j) / d(prototype_i, prototype_j)``, where ``s`` is
            the mean distance of a type's projections to its prototype. Returns
            0.0 when no type qualifies.
        """
        prototypes = self.prototypes
        prototype_distances = torch.cdist(prototypes, prototypes, p=2)

        # Scatter per type: mean distance of its projections to its own prototype.
        key_s_i = {}
        for key, projection_tensor in key_projection_tensor.items():
            if len(projection_tensor) == 0:
                key_s_i[key] = 0.0
                continue

            projection_tensor = projection_tensor.to(prototypes.device)
            prototype_vector = prototypes[key].unsqueeze(0)
            euclidean_distances = torch.cdist(projection_tensor, prototype_vector, p=2)
            key_s_i[key] = euclidean_distances.mean().item()  # Using mean instead of std for more stability

        # Worst-case similarity ratio per known type; type 0 and the extra unknown
        # prototypes are skipped as the "i" side.
        key_r_i = {}
        for key_i, s_i in key_s_i.items():
            if key_i == 0 or key_i >= self.num_classes or s_i == 0.0:
                continue

            max_r_i = 0
            for key_j, s_j in key_s_i.items():
                if key_i != key_j and key_j < self.num_classes and s_j > 0.0:
                    dist = max(prototype_distances[key_i, key_j].item(), 1e-5)  # Avoid division by zero
                    r_i_j = (s_i + s_j) / dist
                    if r_i_j > max_r_i:
                        max_r_i = r_i_j
            key_r_i[key_i] = max_r_i

        if len(key_r_i) == 0:
            return 0.0
        return sum(key_r_i.values()) / len(key_r_i)

    def recognize(self, query_set_type_id_list_of_text_span, input_embedding_mode):
        """Evaluate the model on a query set.

        Args:
            query_set_type_id_list_of_text_span: Dict mapping integer type index
                to a list of text spans of that type.
            input_embedding_mode: Unused; kept for interface compatibility (the
                mode given at construction time is what ``input_embed`` uses).

        Returns:
            ``(total_f1, f1_per_type, mdbi)``: macro F1 over all query spans, a
            dict from type name to the macro F1 computed on the spans whose true
            label is that type (0.0 when the type has no spans), and the
            modified Davies-Bouldin index of the projections.
        """
        self.eval()
        target_device = self.prototypes.device

        # Projections grouped by true type, for the MDBI.
        key_projection_tensor = {
            i: torch.empty((0, self.projection_embedding_dim), device=target_device)
            for i in range(self.prototypes.shape[0])
        }

        y_test_classification_label = []
        y_hat_test_classification_label = []

        # Evaluation only: no autograd graph is needed for any of this.
        with torch.no_grad():
            for key, text_spans in query_set_type_id_list_of_text_span.items():
                for start in range(0, len(text_spans), self.batch_size):
                    batch_spans = text_spans[start:start+self.batch_size]
                    test_batch_input = torch.cat(
                        [self.input_embed(text_span) for text_span in batch_spans], dim=0
                    ).to(target_device)
                    test_batch_projection, _, test_batch_classification = self.forward(test_batch_input)
                    key_projection_tensor[key] = torch.cat((key_projection_tensor[key], test_batch_projection), dim=0)
                    for predicted in test_batch_classification.tolist():
                        # Any of the extra unknown prototypes counts as class 0.
                        y_hat_test_classification_label.append(0 if predicted >= self.num_classes else predicted)
                    y_test_classification_label.extend([key] * len(batch_spans))

            # Calculate MDBI.
            mdbi = self.modified_davies_bouldin_index(key_projection_tensor)

        # Calculate F1 score in test set.
        total_f1_score_value = f1_score(y_test_classification_label, y_hat_test_classification_label, average='macro')

        # Per-type score: macro F1 restricted to the spans whose true label is that type.
        f1_score_per_type_id = {}
        for index, type_id in enumerate(type_ids):
            predicted_for_type = [
                predicted
                for true_label, predicted in zip(y_test_classification_label, y_hat_test_classification_label)
                if true_label == index
            ]
            if len(predicted_for_type) == 0:
                # No query spans of this type: report 0.0 rather than scoring empty input.
                f1_score_per_type_id[type_id] = 0.0
                continue
            true_for_type = [index] * len(predicted_for_type)
            f1_score_per_type_id[type_id] = f1_score(true_for_type, predicted_for_type, average='macro')

        return total_f1_score_value, f1_score_per_type_id, mdbi

    # Model file name formatter
    def get_model_filename(self, hidden_dims):
        """Return the checkpoint file name for the given hidden layer widths and this model's tau."""
        hidden_str = "-".join(map(str, hidden_dims))

        return f"model_{hidden_str}_tau_{self.tau}_improved.pt"

# Helper for adapt_train: embed, batch and project every span of a {type_id: [text_span, ...]} mapping.
def _collect_projections(model, type_id_list_of_text_span):
    """Run all spans of a type-to-spans mapping through the model in mini-batches.

    Args:
        model: Model exposing ``input_embed``, ``forward``, ``batch_size``,
            ``num_classes`` and ``projection_embedding_dim``.
        type_id_list_of_text_span: Dict mapping type index to a list of text spans.

    Returns:
        ``(labels, projections, classifications)``: the true label of every span
        (in iteration order), the stacked projections, and the predicted class
        per span with every index >= ``num_classes`` mapped to 0.
    """
    labels = []
    text_spans = []
    for key, spans in type_id_list_of_text_span.items():
        for text_span in spans:
            labels.append(key)
            text_spans.append(text_span)

    projections = []
    classifications = []
    for start in range(0, len(text_spans), model.batch_size):
        batch_spans = text_spans[start:start+model.batch_size]
        batch_input = torch.cat([model.input_embed(text_span) for text_span in batch_spans], dim=0).to(device)
        batch_projection, _, batch_classification = model.forward(batch_input)
        projections.append(batch_projection)
        classifications.extend(
            0 if predicted >= model.num_classes else predicted
            for predicted in batch_classification.tolist()
        )

    if projections:
        stacked = torch.cat(projections, dim=0)
    else:
        stacked = torch.zeros(0, model.projection_embedding_dim, device=device)
    return labels, stacked, classifications


# Total training process: prototype separation followed by adaptation of the projection network.
def adapt_train(model_class, params, support_type_id_list_of_text_span, validation_type_id_list_of_text_span, input_embedding_mode, hidden_dims):
    """Build a model, train it on the support set and keep the best validation checkpoint.

    Step 1 calls ``model.prototype_train()``. Step 2 optimises the projection
    network and the length embeddings with AdamW and cosine annealing, one
    full pass over the support set per epoch, and evaluates macro F1 on the
    validation set after every epoch. Training stops early once the validation
    F1 has not improved for ``adapt_patience`` consecutive epochs. The weights
    with the best validation F1 are restored before returning; if F1 never
    improved, the latest weights are kept. A per-epoch log is written as CSV.

    Args:
        model_class: Class to instantiate with ``hidden_dims``, ``params`` and
            ``input_embedding_mode``.
        params: Hyperparameter dict forwarded to ``model_class``.
        support_type_id_list_of_text_span: Dict type index -> support spans.
        validation_type_id_list_of_text_span: Dict type index -> validation spans.
        input_embedding_mode: Encoder mode; also used in the log path.
        hidden_dims: Hidden layer widths; also used in the log file name.

    Returns:
        The trained model (in eval mode unless ``adapt_epochs`` is 0).
    """
    model = model_class(
        hidden_dims=hidden_dims,
        params=params,
        input_embedding_mode=input_embedding_mode
    )
    model.to(device)
    model.train()

    # First step
    model.prototype_train()

    # Second step
    logs = []
    patience_counter = 0

    # Validation F1 is the model-selection criterion.
    best_f1_score = 0
    best_model_state = None
    best_length_embeddings = None

    # Only these parameters are optimised; the prototypes stay fixed in this step.
    trainable_params = list(model.model.parameters()) + [model.length_embeddings]
    ffn_optimizer = optim.AdamW(trainable_params, lr=model.adapt_learning_rate, weight_decay=0.01)
    ffn_scheduler = optim.lr_scheduler.CosineAnnealingLR(ffn_optimizer, T_max=model.adapt_epochs)

    for epoch in range(model.adapt_epochs):
        log = {}
        log['epoch'] = epoch+1
        log['tau'] = model.tau
        log['shot_sample_number'] = model.shot_sample_number
        log['validation_sample_number'] = model.validation_sample_number
        log['dropout'] = model.dropout
        log['num_classes'] = model.num_classes
        log['additional_num_of_unknown_type'] = model.additional_num_of_unknown_type
        log['input_dim'] = model.input_dim
        log['length_embedding_dim'] = model.length_embedding_dim
        log['projection_embedding_dim'] = model.projection_embedding_dim
        log['adapt_learning_rate'] = model.adapt_learning_rate

        model.train()
        # Clear the gradients of ALL parameters: the prototypes receive gradients from
        # the losses but are not in the optimizer, so optimizer.zero_grad() would let
        # them accumulate across epochs.
        model.zero_grad(set_to_none=True)

        y_support_classification_label, y_hat_support_projection, _ = _collect_projections(
            model, support_type_id_list_of_text_span
        )

        # Calculate support loss with weighted components
        supp_loss_d = model.distance_loss() * model.distance_loss_weight
        supp_loss_s = model.span_loss(y_hat_support_projection, y_support_classification_label) * model.span_loss_weight
        supp_loss = supp_loss_d + supp_loss_s
        log[f'model_epoch_{epoch+1}_support_loss'] = supp_loss.item()
        log[f'model_epoch_{epoch+1}_support_loss_d'] = supp_loss_d.item()
        log[f'model_epoch_{epoch+1}_support_loss_s'] = supp_loss_s.item()
        print(f"Epoch {epoch+1}, Support loss {supp_loss:.4f}")

        # Backward propagation with gradient clipping
        log[f'model_epoch_{epoch+1}_task_ffn_learning_rate'] = ffn_optimizer.param_groups[0]['lr']
        supp_loss.backward()
        # Clip only what is actually stepped, so gradients of non-optimised
        # parameters cannot inflate the norm and shrink the update.
        torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
        ffn_optimizer.step()
        ffn_scheduler.step()

        # Validation after every epoch.
        model.eval()
        with torch.no_grad():
            y_validation_classification_label, y_hat_validation_projection, y_hat_validation_classification_label = _collect_projections(
                model, validation_type_id_list_of_text_span
            )

            # Calculate validation F1 score.
            val_total_f1_score_value = f1_score(y_validation_classification_label, y_hat_validation_classification_label, average='macro')
            log['meta_model_validation_f1_score_percent'] = val_total_f1_score_value*100

            # Calculate validation loss.
            val_loss_d = model.distance_loss() * model.distance_loss_weight
            val_loss_s = model.span_loss(y_hat_validation_projection, y_validation_classification_label) * model.span_loss_weight
            val_loss = val_loss_d + val_loss_s

            log['model_validation_loss'] = val_loss.item()
            log['model_validation_loss_d'] = val_loss_d.item()
            log['model_validation_loss_s'] = val_loss_s.item()
            print(f"\n==== Epoch {epoch+1}, F1 score {val_total_f1_score_value*100:.4f}%, Validation loss {val_loss:.4f} ====\n")

        stop_early = False
        if val_total_f1_score_value > best_f1_score:
            best_f1_score = val_total_f1_score_value
            # state_dict() returns references to the live tensors; deep-copy so later
            # optimizer steps cannot overwrite the saved checkpoint.
            best_model_state = copy.deepcopy(model.model.state_dict())
            best_length_embeddings = model.length_embeddings.detach().clone()
            patience_counter = 0
        else:
            patience_counter += 1
            stop_early = patience_counter >= model.adapt_patience

        # Record the epoch before a possible early stop so the last epoch is logged too.
        logs.append(log)
        if stop_early:
            break

    if best_model_state is not None and best_length_embeddings is not None:
        model.model.load_state_dict(best_model_state)
        model.length_embeddings.data = best_length_embeddings.to(device)
    # Otherwise no epoch improved on an F1 of 0: keep the latest weights.

    log_df = pd.DataFrame(logs)
    hidden_dims_name = "-".join(map(str, hidden_dims))
    log_df.to_csv(f'{dir}/{model.shot_sample_number}_shot/{input_embedding_mode}_logs/improved_log_{hidden_dims_name}_tau_{model.tau}_validation_num_{model.validation_sample_number}.csv', index=False)

    return model

In [None]:
# NOTE(review): this repeats the Drive mount already performed in the first cell, and the
# recorded output shows the run was interrupted (KeyboardInterrupt). The cell looks
# redundant; confirm it is still needed before keeping it in the final notebook.
from google.colab import drive
drive.mount('/content/drive')

KeyboardInterrupt: 

In [None]:
# Independent variable:
    # shot number per each support set
    # types of input embedding
# Dependent variable:
    # modified davies bouldin index (MDBI)
    # F1 score
    # F1 score per each classification category.
# Control variable:
    # dropout (0.5)
    # activation function (gelu)
    # random seed for layer initiation (42)
    # output layer dimension (1024)
    # number of hidden layers (1)
# Train stop condition: validation F1 score has not improved for `adapt_patience` consecutive epochs (early stopping tracks F1, not validation loss)

# Embedding mode to input dim map
embedding_mode_to_input_dim = {'bert': 768, 'fasttext': 300}

# Set hyperparameters.
adapt_epochs = 50
input_embedding_mode = 'bert' ##
shot_sample_number = 1 ##
validation_sample_number = 20
query_sample_number = 15 ## delete this later
number_of_hidden_layers = 1

# Group every generated text span by the index of its type in `type_ids`.
total_type_id_list_of_text_span = {}
for i in range(len(type_ids)):
    total_type_id_list_of_text_span[i] = []
for text_span in generated_text_span:
    if text_span.strip() == "":
        continue
    try:
        type_id = type_ids.index(named_entity_to_type_id[text_span])
    except (KeyError, ValueError):
        # KeyError: the span has no entry in the entity-to-type map.
        # ValueError: its mapped type is not listed in `type_ids`.
        # Anything else is a real error and is no longer swallowed.
        type_id = type_ids.index('UnknownType')
    total_type_id_list_of_text_span[type_id].append(text_span)

# Split each type's spans into disjoint support / validation / query slices.
support_set_type_id_list_of_text_span = {}
validation_set_type_id_list_of_text_span = {}
query_set_type_id_list_of_text_span = {}

for key in total_type_id_list_of_text_span.keys():
    # Shuffle a copy so the grouped lists above keep their original order.
    shuffled = list(total_type_id_list_of_text_span[key])
    random.shuffle(shuffled)
    validation_start = shot_sample_number
    query_start = validation_start + validation_sample_number
    support_set_type_id_list_of_text_span[key] = shuffled[:validation_start]
    validation_set_type_id_list_of_text_span[key] = shuffled[validation_start:query_start]
    query_set_type_id_list_of_text_span[key] = shuffled[query_start:query_start+query_sample_number]

In [None]:
# One width entry per hidden layer of the projection network.
layer_scale = [1024 for _ in range(number_of_hidden_layers)]
results_list = []

tau = 5.0 ##

# Hyperparameters consumed by EPNet and adapt_train.
# NOTE(review): adapt_patience (100) exceeds adapt_epochs (50), so early stopping
# cannot trigger with these values.
params = dict(
    dropout=0.5,
    num_classes=len(type_ids),
    additional_num_of_unknown_type=10,
    input_dim=embedding_mode_to_input_dim[input_embedding_mode],
    projection_embedding_dim=1024,
    prototype_train_learning_rate=0.2,
    prototype_train_epochs=1000,
    prototype_train_patience=5,
    max_span_length=10,
    length_embedding_dim=25,
    shot_sample_number=shot_sample_number,
    validation_sample_number=validation_sample_number,
    adapt_learning_rate=5e-3,
    adapt_epochs=adapt_epochs,
    adapt_patience=100,
    batch_size=32,
    tau=tau,
)

# Train
model = adapt_train(
    EPNet,
    params,
    support_type_id_list_of_text_span=support_set_type_id_list_of_text_span,
    validation_type_id_list_of_text_span=validation_set_type_id_list_of_text_span,
    input_embedding_mode=input_embedding_mode,
    hidden_dims=layer_scale,
)

Result of initial prototypical network training loss: 0.0
Epoch 1, Support loss 2.9832

==== Epoch 1, F1 score 7.7806%, Validation loss 4.6021 ====

Epoch 2, Support loss 0.0043

==== Epoch 2, F1 score 8.3731%, Validation loss 4.8376 ====

Epoch 3, Support loss 0.3882

==== Epoch 3, F1 score 8.1585%, Validation loss 5.1661 ====

Epoch 4, Support loss 0.0000

==== Epoch 4, F1 score 11.5916%, Validation loss 5.1236 ====

Epoch 5, Support loss 0.0000

==== Epoch 5, F1 score 8.8469%, Validation loss 5.2875 ====

Epoch 6, Support loss 0.2844

==== Epoch 6, F1 score 9.9181%, Validation loss 5.2122 ====

Epoch 7, Support loss 0.0000

==== Epoch 7, F1 score 10.3230%, Validation loss 5.1252 ====

Epoch 8, Support loss 0.0075

==== Epoch 8, F1 score 11.1667%, Validation loss 5.1285 ====

Epoch 9, Support loss 0.0014

==== Epoch 9, F1 score 11.2769%, Validation loss 5.3986 ====

Epoch 10, Support loss 0.0001

==== Epoch 10, F1 score 9.7707%, Validation loss 5.7818 ====

Epoch 11, Support loss 0.0

In [None]:
# Recognize: evaluate the trained model on the held-out query set.
total_f1_score_value, f1_score_per_type_id, mdbi = model.recognize(
    query_set_type_id_list_of_text_span,
    input_embedding_mode,
)
print(f"Total F1 score: {total_f1_score_value*100}%, Modified Davies Bouldin Index: {mdbi}")
print(f"F1 score per type: {f1_score_per_type_id}")

# Collect one flat record per run; per-type scores become their own columns.
result = dict(
    number_of_hidden_layers=number_of_hidden_layers,
    hidden_dims="-".join(str(width) for width in layer_scale),
    total_f1_score_precent=total_f1_score_value*100,
    mdbi=mdbi,
)
for type_id, f1_score_value in f1_score_per_type_id.items():
    result.update({f'f1_score_{type_id}_percent': f1_score_value*100})
results_list.append(result)

Total F1 score: 10.001984045993057%, Modified Davies Bouldin Index: 1.4364369419588623
F1 score per type: {'UnknownType': 0.0, 'Idea or Concept': 0.0125, 'Substance': 0.09090909090909091, 'Natural Phenomenon or Process': 0.06015037593984962, 'Occupational Activity': 0.0, 'Anatomical Structure': 0.03361344537815126, 'Finding': 0.0, 'Organism': 0.08421052631578947, 'Intellectual Product': 0.0, 'Manufactured Object': 0.0, 'Group': 0.4230769230769231, 'Activity': 0.0, 'Organism Attribute': 0.05263157894736842, 'Phenomenon or Process': 0.0, 'Behavior': 0.0, 'Injury or Poisoning': 0.0, 'Organization': 0.2, 'Conceptual Entity': 0.0, 'Occupation or Discipline': 0.0}


In [None]:
# Save model: projection network weights, length embeddings and prototypes,
# each to its own file under the per-shot, per-embedding model directory.
model_filename = model.get_model_filename(layer_scale)
model_dir = f'{dir}/{model.shot_sample_number}_shot/{input_embedding_mode}_models'
artifacts = (
    ('ablated_model', model.model.state_dict()),
    ('ablated_length_embeddings', model.length_embeddings),
    ('ablated_prototypes', model.prototypes),
)
for prefix, payload in artifacts:
    torch.save(payload, f'{model_dir}/{prefix}_{model_filename}')

# Save performance results.
results = pd.DataFrame(results_list)
results_path = f'{dir}/{model.shot_sample_number}_shot/{input_embedding_mode}_results/ablated_performance_result_tau_{tau}.csv'
results.to_csv(results_path, index=False)

In [None]:
from google.colab import runtime

# Final cell: hand the Colab runtime back once the artifacts above have been written.
# NOTE(review): presumably this ends the session and discards in-memory state; run it
# only after the save cell has completed.
runtime.unassign()