In [None]:
!pip install -q fasttext

import math
import json
import copy
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
from sklearn.metrics import f1_score
from sklearn.manifold import MDS
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoTokenizer, BertModel
import fasttext
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
import gc
import warnings
warnings.filterwarnings("ignore")

from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Base directory prepended to every data / model path in this notebook (e.g. f'{dir}/type_ids.json').
# NOTE(review): this name shadows the builtin `dir()`; it is kept because later cells reference it.
# NOTE(review): with the empty default, f'{dir}/...' yields paths starting at '/', i.e. the filesystem
# root — presumably this is meant to be set to the mounted Drive project folder before running.
dir = ''

In [None]:
# Pick the compute device once for the whole notebook: the GPU when CUDA is available, otherwise the CPU.
# Every model, tensor and helper below reads this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix every random number generator used in this notebook so runs are reproducible.
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN auto-tuning for repeatable kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all generators before any model or task is created.
set_seed(42)

# Load the ordered list of entity type names; a type's id is its position in this list.
with open(f'{dir}/type_ids.json', 'r') as f:
    type_ids = json.load(f)
# Position of the catch-all "UnknownType" entry (raises ValueError if the list does not contain it).
unknown_index = type_ids.index("UnknownType")

# Names of the components that can be ablated in this study.
ablations = ["hard_negative", "multiple_prototype", "contrastive_learning"]
# Component ablated in this run; ReProCon.span_loss reads this module-level name to choose its loss branch.
ablation = ablations[2]  # <- edit this index to switch the ablated component

In [None]:
# Load language model.
def load_language_model(input_embedding_mode):
    """Load the frozen language model used to embed the input tokens.

    The language model is never fine-tuned (to avoid overfitting); it is only
    used as a feature extractor by ``ReProCon.input_embed``.

    Args:
        input_embedding_mode: ``'bert'`` to load the tokenizer and ``BertModel``
            stored under ``{dir}/language_model/bert``, or ``'fasttext'`` to load
            ``cc.en.300.bin`` from ``{dir}/language_model/fasttext``.

    Returns:
        A ``(tokenizer, pretrained_model)`` tuple. ``tokenizer`` is ``None`` in
        ``'fasttext'`` mode because fastText embeds raw word strings.

    Raises:
        ValueError: If ``input_embedding_mode`` is not one of the supported modes.
            (Previously this fell through both branches and failed with an
            ``UnboundLocalError`` at the ``return``.)
    """
    supported_modes = ('bert', 'fasttext')
    if input_embedding_mode not in supported_modes:
        raise ValueError(
            f"Unsupported input_embedding_mode {input_embedding_mode!r}; expected one of {supported_modes}"
        )

    model_dir = f'{dir}/language_model/{input_embedding_mode}'

    if input_embedding_mode == 'bert':
        # Load vanilla bert-base-cased model for input embedding.
        tokenizer = AutoTokenizer.from_pretrained(f'{model_dir}/tokenizer')
        pretrained_model = BertModel.from_pretrained(f'{model_dir}/transformer')
        pretrained_model.eval()  # inference only: disables dropout
        pretrained_model.to(device)
        # NOTE(review): `vocab_size` excludes tokens added after pre-training; if the saved
        # tokenizer has added tokens, `len(tokenizer)` may be the intended size — confirm.
        pretrained_model.resize_token_embeddings(tokenizer.vocab_size)
    else:
        # Load fasttext model for input embedding.
        pretrained_model = fasttext.load_model(f'{model_dir}/cc.en.300.bin')
        tokenizer = None

    return (tokenizer, pretrained_model)

# Define model class.
class ReProCon(nn.Module):
    def __init__(self, params, ffn_hidden_dims, input_embedding_mode, type_ids):
        super().__init__()

        # Receive hyperparameters.
        self.train_test_split_ratio = params['train_test_split_ratio']
        self.num_classes = params['num_classes']
        self.dropout = params['dropout']
        self.input_dim = params['input_dim']
        self.max_input_tokens_length = params['max_input_tokens_length']
        self.positional_embedding_dim = params['positional_embedding_dim']
        self.bilstm_hidden_dim = params['bilstm_hidden_dim']
        self.bilstm_layers = params['bilstm_layers']
        self.lstm_embedding_dim = params['lstm_embedding_dim']
        self.projection_embedding_dim = params['projection_embedding_dim']
        self.temp = params['temp']
        self.prototype_train_learning_rate = params['prototype_train_learning_rate']
        self.prototype_train_epochs = params['prototype_train_epochs']
        self.prototype_train_patience = params['prototype_train_patience']
        self.prototype_num_per_class = params['prototype_num_per_class']
        self.shot_sample_number = params['shot_sample_number']
        self.meta_learning_rate = params['meta_learning_rate']
        self.task_learning_rate = params['task_learning_rate']
        self.meta_epochs = params['meta_epochs']
        self.task_epochs = params['task_epochs']
        self.adapt_patience = params['adapt_patience']
        self.batch_size = params['batch_size']
        self.input_embedding_mode = input_embedding_mode
        self.type_ids = type_ids
        self.prototype_loss_weight = params['prototype_loss_weight']

        # Positional embedding
        self.set_positional_embedding()

        # BiLSTM model
        self.bilstm_model = nn.LSTM(
            input_size=self.input_dim+self.positional_embedding_dim,
            hidden_size=self.bilstm_hidden_dim,
            num_layers=self.bilstm_layers,
            bias=True,
            batch_first=True,
            dropout=self.dropout,
            bidirectional=True,
            proj_size=self.lstm_embedding_dim, # Since it is bidirectional, it becomes *2
            device=device
        )

        # FFN model
        ffn_layers = []
        ffn_dims = [self.input_dim if self.input_embedding_mode == 'bert' else 2*self.lstm_embedding_dim]
        ffn_dims.extend(ffn_hidden_dims)
        for i in range(len(ffn_hidden_dims)):
            ffn_layers.append(nn.Linear(ffn_dims[i], ffn_dims[i+1]))
            ffn_layers.append(nn.BatchNorm1d(ffn_dims[i+1]))
            ffn_layers.append(nn.GELU())
            ffn_layers.append(nn.Dropout(self.dropout))
        ffn_layers.append(nn.Linear(ffn_dims[-1], self.projection_embedding_dim))
        ffn_layers.append(nn.LayerNorm(self.projection_embedding_dim))
        self.ffn_model = nn.Sequential(*ffn_layers)

        prototypes = np.random.rand(self.num_classes*self.prototype_num_per_class, self.projection_embedding_dim)
        self.prototypes = torch.tensor(prototypes, dtype=torch.float32)
        self.prototypes.data = F.normalize(self.prototypes.data, dim=1)
        self.prototypes.requires_grad_(True)

    def set_positional_embedding(self):
        pos = torch.arange(0, self.max_input_tokens_length).unsqueeze(1)
        cols = torch.arange(0, self.positional_embedding_dim).unsqueeze(0)
        position_tensor = pos / (torch.pow(10000, (2*(cols//2)) / self.positional_embedding_dim))

        position_tensor[:, 0::2] = torch.sin(position_tensor[:, 0::2])
        position_tensor[:, 1::2] = torch.cos(position_tensor[:, 1::2])

        self.positional_embedding = position_tensor.to(device)

    def input_embed(self, sample, language_model):
        """Embed one sample with the frozen language model.

        Args:
            sample: Indexable pair; ``sample[0]`` is a list of strings containing the
                placeholder ``"[MARK_POSITION]"`` and ``sample[1]`` is the string that
                replaces the placeholder (presumably the named-entity mention — confirm
                against the dataset builder).
            language_model: ``(tokenizer, pretrained_model)`` tuple as returned by
                ``load_language_model``.

        Returns:
            ``'bert'`` mode: ``(span_representation, None)`` where the representation is the
            average of max- and mean-pooling over the mention's token states.
            ``'fasttext'`` mode: ``(concatenated_embedding, mark_index)`` where the embedding has
            one row per position (word vectors joined with the positional embedding) and
            ``mark_index`` is the position of the mention in the sentence.
            Any other mode falls through and returns ``None``.

        Raises:
            ValueError: If ``"[MARK_POSITION]"`` is not present in ``sample[0]``.
        """
        mark_index = sample[0].index("[MARK_POSITION]")
        # Work on a copy so the caller's sample keeps its placeholder.
        full_sentence = sample[0].copy()
        full_sentence[mark_index] = sample[1]

        if self.input_embedding_mode == 'bert':
            tokenizer, pretrained_model = language_model
            # Word-piece offsets of the mention inside the encoded sentence.
            # Both start at 1 because position 0 of the encoding is [CLS].
            named_entity_start_index = 1
            named_entity_end_index = 1
            for i in range(len(full_sentence)):
                if i == mark_index:
                    # Count the mention's word pieces to obtain the (exclusive) end offset.
                    named_entity_encodings = tokenizer.encode(full_sentence[i])
                    named_entity_tokens = tokenizer.convert_ids_to_tokens(named_entity_encodings)
                    named_entity_tokens.remove("[CLS]")
                    named_entity_tokens.remove("[SEP]")
                    named_entity_end_index = named_entity_start_index + len(named_entity_tokens)
                    break
                else:
                    # Every element before the mention shifts the start offset by its word-piece count.
                    not_named_entity_encodings = tokenizer.encode(full_sentence[i])
                    not_named_entity_tokens = tokenizer.convert_ids_to_tokens(not_named_entity_encodings)
                    not_named_entity_tokens.remove("[CLS]")
                    not_named_entity_tokens.remove("[SEP]")
                    named_entity_start_index += len(not_named_entity_tokens)

            full_sentence_text = ' '.join(full_sentence)
            full_sentence_encodings = tokenizer(
                full_sentence_text,
                max_length=self.max_input_tokens_length,
                return_tensors='pt',
                padding="max_length",
                truncation=True,
            )
            full_sentence_encodings = {k: v.to(device) for k, v in full_sentence_encodings.items()}
            # The language model is frozen, so no gradients are tracked here.
            with torch.no_grad():
                outputs = pretrained_model(**full_sentence_encodings)
                # NOTE(review): if the mention starts at or beyond max_input_tokens_length the
                # truncated encoding makes this slice empty and the max-pooling below will fail.
                named_entity_embedding = outputs.last_hidden_state[:, named_entity_start_index:named_entity_end_index, :]
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                max_pool_tensor, _ = torch.max(named_entity_embedding, dim=1)
                mean_pool_tensor = torch.mean(named_entity_embedding, dim=1)
                span_representation = (max_pool_tensor+mean_pool_tensor)/2

                return span_representation, None

        elif self.input_embedding_mode == 'fasttext':
            _tokenizer, pretrained_model = language_model
            # One row per sentence position; positions past the end of the sentence stay zero.
            # NOTE(review): a sentence with more than max_input_tokens_length elements raises
            # IndexError in the loop below — there is no truncation in this branch.
            forwarded_embedding = torch.zeros(self.max_input_tokens_length, self.input_dim)
            for i in range(len(full_sentence)):
                tokens = full_sentence[i].split(' ')
                if len(tokens) > 1:
                    # Multi-word element: L2-normalise each word vector, then combine them.
                    normalized_embeddings = []
                    for token in tokens:
                        token_vector = pretrained_model.get_word_vector(token)
                        norm = np.sqrt(np.sum(token_vector**2))
                        if not norm == 0:
                            normalized_embeddings.append(token_vector/norm)
                        else:
                            normalized_embeddings.append(token_vector)

                    # Use weighted averaging for better representation
                    if len(normalized_embeddings) > 0:
                        # Weights grow quadratically with the distance from the middle of the span
                        # (the comprehension's `i` is local and does not clobber the outer loop index).
                        weights = np.array([1.0 + 0.1*(i-len(normalized_embeddings)/2)**2 for i in range(len(normalized_embeddings))])
                        weights = weights / weights.sum()
                        mean_vector = np.average(normalized_embeddings, axis=0, weights=weights)
                    else:
                        mean_vector = np.zeros(self.input_dim)
                    span_representation = torch.from_numpy(mean_vector).unsqueeze(0).float().to(device)
                    forwarded_embedding[i] = span_representation
                else:
                    # Single-word element: use its L2-normalised word vector directly.
                    token_vector = pretrained_model.get_word_vector(tokens[0])
                    norm = np.sqrt(np.sum(token_vector**2))
                    if not norm == 0:
                        token_vector = token_vector/norm
                    else:
                        token_vector = token_vector
                    span_representation = torch.from_numpy(token_vector).unsqueeze(0).float().to(device)
                    forwarded_embedding[i] = span_representation

            forwarded_embedding = forwarded_embedding.to(device)
            # This embedding will be forwarded to the BiLSTM model, because fasttext itself does not encode position information.
            # When embedding via BERT, adding the positional embedding and the BiLSTM pass are skipped, since BERT already includes this step.
            concatenated_embedding = torch.cat((forwarded_embedding, self.positional_embedding), dim=1)

            return concatenated_embedding, mark_index

    def forward(self, input, mark_index_list):
        projection = torch.empty((0, self.projection_embedding_dim))
        projection = projection.to(device)
        if self.input_embedding_mode == 'bert':
            projection = torch.cat((projection, self.ffn_model.forward(input)), dim=0)

        elif self.input_embedding_mode == 'fasttext':
            lstm_projection = torch.zeros(len(mark_index_list), self.lstm_embedding_dim*2)
            bilstm_forwarded_sequence, _ = self.bilstm_model(input)
            for i in range(len(mark_index_list)):
                mark_index = mark_index_list[i]
                projection_tensor = bilstm_forwarded_sequence[i][mark_index]
                lstm_projection[i] = projection_tensor
            lstm_projection = lstm_projection.to(device)

            projection = torch.cat((projection, self.ffn_model.forward(lstm_projection)), dim=0)

        self.prototypes = self.prototypes.to(device)

        projection = F.normalize(projection, p=2, dim=1)
        prototypes_normalized = self.prototypes
        similarity_matrix = torch.mm(projection, prototypes_normalized.t())
        square_similarity_matrix = -(1-similarity_matrix)*(1-similarity_matrix)
        logits = torch.softmax(square_similarity_matrix/self.temp, dim=1)
        classification_result_list = torch.argmax(logits, dim=1).tolist()
        for i in range(len(classification_result_list)):
            classification_result_list[i] = classification_result_list[i] // self.prototype_num_per_class

        return projection, logits, classification_result_list

    def prototype_loss(self):
        prototypes = self.prototypes.to(device)
        prototypes = F.normalize(prototypes, p=2, dim=1)
        num_types = prototypes.size(0)

        # Code from https://github.com/psmmettes/hpn
        # Use this to minimize mean maximum cosine similarity
        similarity = torch.mm(prototypes, prototypes.t())+1
        similarity -= 2.0 * torch.diag(torch.diag(similarity))
        loss = similarity.max(dim=1)[0].mean()

        return loss

    # Improved span loss function using supervised contrastive learning loss
    def span_loss(self, y_hat_projection, y_classification_label):
        y_classification_label_copy = y_classification_label.copy()
        y_classification_label_copy = torch.Tensor(y_classification_label_copy).long().to(device)
        # Code from https://arxiv.org/abs/2004.11362
        y_hat_projection_normalized = F.normalize(y_hat_projection, p=2, dim=1)
        y_hat_projection_normalized = y_hat_projection_normalized.to(device)
        prototypes_normalized = self.prototypes.to(device)
        similarity_matrix = torch.mm(prototypes_normalized, y_hat_projection_normalized.t())

        # Supervised Contrastive Learning
        if ablation != "contrastive_learning":
            # Code from https://arxiv.org/abs/1901.10514
            square_similarity_matrix = (1-similarity_matrix)*(1-similarity_matrix)
            # Minimum pooling per each category
            pooled_square_similarity_matrix = torch.zeros(self.num_classes, len(y_classification_label))
            for i in range(self.num_classes):
                square_similarity_matrix_per_class = square_similarity_matrix[i*self.prototype_num_per_class:(i+1)*self.prototype_num_per_class, :]
                pooled_square_similarity_matrix[i] = square_similarity_matrix_per_class.min(dim=0)[0]
            pooled_square_similarity_matrix = pooled_square_similarity_matrix.to(device)

            mask = torch.zeros(self.num_classes, len(y_classification_label), dtype=torch.bool)
            for i in range(len(y_classification_label)):
                mask[y_classification_label[i]][i] = True
            mask = mask.to(device)
            masked_square_similarity_matrix = pooled_square_similarity_matrix*mask

            positive_mask_counts = mask.sum(dim=1).float()

            # Since anchor is a prototype not a sample, skip removing auto-similarities
            similarity_sum = torch.sum(pooled_square_similarity_matrix, dim=1)
            positive_similarity_sum = torch.sum(masked_square_similarity_matrix, dim=1)
            positive_similarity_mean = positive_similarity_sum / positive_mask_counts
            contrastive_loss_per_class = positive_similarity_mean / similarity_sum
            loss = torch.sum(contrastive_loss_per_class, dim=0)
            loss = loss.to(device)
        else:
            for i in range(self.num_classes):
                class_sample_range = (self.shot_sample_number*i, self.shot_sample_number*(i+1))
                class_prototype_range = (self.prototype_num_per_class*i, self.prototype_num_per_class*(i+1))
                class_similarity_matrix = similarity_matrix[class_prototype_range[0]:class_prototype_range[1], class_sample_range[0]:class_sample_range[1]]
                maximum_similarity_indices = torch.argmax(class_similarity_matrix, dim=0)
                maximum_similarity_indices += class_sample_range[0]
                y_classification_label_copy[class_sample_range[0]:class_sample_range[1]] = maximum_similarity_indices
            logit = torch.softmax(similarity_matrix, dim=1).t()
            loss = F.cross_entropy(logit, y_classification_label_copy)

        return loss

    def select_hard_negatives(self, task_set_type_id_list_of_samples, language_model):
        self.eval()

        hard_negative_set_type_id_list_of_samples = {}

        for key, samples in tqdm(task_set_type_id_list_of_samples.items()):
            collect = []

            if self.input_embedding_mode == 'bert':
                test_batch_input = torch.zeros(
                    self.batch_size,
                    self.input_dim
                )
            elif self.input_embedding_mode == 'fasttext':
                test_batch_input = torch.zeros(
                    self.batch_size,
                    self.max_input_tokens_length,
                    self.input_dim+self.positional_embedding_dim
                )
            test_batch_mark_index_list = []
            intra_batch_count = 0
            multiple_count = 0
            for sample in task_set_type_id_list_of_samples[key]:
                input_embedding, mark_index = self.input_embed(sample, language_model)
                test_batch_input[intra_batch_count] = input_embedding
                if self.input_embedding_mode == 'fasttext':
                    test_batch_mark_index_list.append(mark_index)
                intra_batch_count += 1
                if intra_batch_count == self.batch_size:
                    test_batch_input = test_batch_input.to(device)
                    _, _, test_batch_classification = self.forward(test_batch_input, test_batch_mark_index_list)
                    for i in range(intra_batch_count):
                        if test_batch_classification[i] != int(key):
                            collect.append(samples[multiple_count*self.batch_size+i])

                    intra_batch_count = 0
                    multiple_count += 1

                    del test_batch_classification
                    gc.collect()

                    if self.input_embedding_mode == 'bert':
                        test_batch_input = torch.zeros(
                            self.batch_size,
                            self.input_dim
                        )
                    elif self.input_embedding_mode == 'fasttext':
                        test_batch_input = torch.zeros(
                            self.batch_size,
                            self.max_input_tokens_length,
                            self.input_dim+self.positional_embedding_dim
                        )
                    test_batch_mark_index_list = []
            if intra_batch_count != 0:
                test_batch_input = test_batch_input[:intra_batch_count].to(device)
                _, _, test_batch_classification = self.forward(test_batch_input, test_batch_mark_index_list)
                for i in range(intra_batch_count):
                    if test_batch_classification[i] != int(key):
                        collect.append(samples[multiple_count*self.batch_size+i])

            hard_negative_set_type_id_list_of_samples[int(key)] = collect

            del test_batch_classification
            gc.collect()

            # Empty cuda
            test_batch_input = test_batch_input.cpu()
            del test_batch_input
            gc.collect()
            torch.cuda.empty_cache()

        return hard_negative_set_type_id_list_of_samples

# Total training process.
def _project_in_batches(model, type_id_list_of_samples, language_model):
    """Embed and forward every sample through ``model`` in mini-batches.

    Samples are taken type by type in dict order and cut into consecutive batches of
    ``model.batch_size`` (the last batch may be smaller; batches may span several types).

    Returns:
        ``(projections, predictions)``: a ``(num_samples, projection_embedding_dim)`` tensor
        holding the projection of every sample, and the list of predicted class ids.
    """
    flat_samples = [sample for samples in type_id_list_of_samples.values() for sample in samples]
    projections = torch.zeros(len(flat_samples), model.projection_embedding_dim)
    predictions = []
    uses_sequence_input = model.input_embedding_mode == 'fasttext'

    for batch_start in range(0, len(flat_samples), model.batch_size):
        batch_samples = flat_samples[batch_start:batch_start+model.batch_size]
        if uses_sequence_input:
            batch_input = torch.zeros(
                len(batch_samples),
                model.max_input_tokens_length,
                model.input_dim+model.positional_embedding_dim
            )
        else:
            batch_input = torch.zeros(len(batch_samples), model.input_dim)

        mark_index_list = []
        for row, sample in enumerate(batch_samples):
            input_embedding, mark_index = model.input_embed(sample, language_model)
            batch_input[row] = input_embedding
            if uses_sequence_input:
                mark_index_list.append(mark_index)

        batch_projection, _, batch_classification = model.forward(batch_input.to(device), mark_index_list)
        projections[batch_start:batch_start+len(batch_samples)] = batch_projection
        predictions.extend(batch_classification)

    return projections, predictions


def _interpolate_towards_task(meta_module, old_task_state, new_task_state, meta_learning_rate):
    """Move ``meta_module`` towards the adapted task weights (in place).

    Parameters are shifted by ``meta_learning_rate * (new_task - old_task)``; every other
    state entry (e.g. BatchNorm running statistics) is copied from the task model.
    """
    parameter_names = {name for name, _ in meta_module.named_parameters()}
    meta_state = meta_module.state_dict()

    with torch.no_grad():
        for name in list(meta_state.keys()):
            if name not in new_task_state:
                continue
            if name in parameter_names:
                # state_dict() tensors share storage with the module, so this updates it directly.
                meta_state[name] += (new_task_state[name]-old_task_state[name])*meta_learning_rate
            else:
                meta_state[name] = new_task_state[name]

    meta_module.load_state_dict(meta_state)


def train(model_class, meta_model, params, ffn_hidden_dims, input_embedding_mode, type_ids, tasks, validation_set_type_id_list_of_samples, language_model, is_initial):
    """Meta-train ``meta_model`` over the given tasks and restore its best validated state.

    Each meta epoch clones the meta model into a fresh task model, adapts the clone on one
    task (support set) for ``task_epochs`` steps, moves the meta model a fraction
    ``meta_learning_rate`` of the way towards the adapted weights and prototypes, and then
    validates the meta model (macro F1). Training stops early when validation F1 fails to
    improve for ``adapt_patience`` epochs or reaches 1.0.

    Args:
        model_class: Class used to build the per-task clone (called with ``params``,
            ``ffn_hidden_dims``, ``input_embedding_mode`` and ``type_ids``).
        meta_model: Model that is meta-trained in place.
        params: Hyperparameter dict forwarded to ``model_class``.
        ffn_hidden_dims: Hidden layer sizes forwarded to ``model_class``.
        input_embedding_mode: ``'bert'`` or ``'fasttext'``.
        type_ids: Type name list forwarded to ``model_class``.
        tasks: List of support sets, each a dict mapping class id to a list of samples.
            At most one task is consumed per meta epoch, in shuffled order.
        validation_set_type_id_list_of_samples: Dict mapping type id to validation samples.
        language_model: ``(tokenizer, pretrained_model)`` tuple used for input embedding.
        is_initial: Unused; kept for interface compatibility.

    Returns:
        ``pandas.DataFrame`` with one row of logged losses / scores per executed meta epoch
        (including the epoch on which early stopping triggered).
    """
    meta_model.to(device)
    meta_model.train()

    # Assign random tasks.
    task_indices = list(range(len(tasks)))
    random.shuffle(task_indices)

    uses_bilstm = input_embedding_mode == 'fasttext'

    def snapshot_meta_model():
        # Deep copies are required: state_dict() returns tensors that share storage with the
        # live module, and the meta update modifies them in place. The prototypes are stored
        # too so that restored weights and prototypes always come from the same epoch.
        snapshot = {
            'ffn': copy.deepcopy(meta_model.ffn_model.state_dict()),
            'prototypes': meta_model.prototypes.detach().clone(),
        }
        if uses_bilstm:
            snapshot['bilstm'] = copy.deepcopy(meta_model.bilstm_model.state_dict())
        return snapshot

    logs = []
    patience_counter = 0

    # F1 score would be better benchmark.
    best_f1_score = 0
    local_maxima_f1_score = 0
    best_meta_model_state = None
    local_maxima_meta_model_state = None
    task_model = task_model_optimizer = task_model_scheduler = None

    # The validation labels do not change between meta epochs.
    y_validation_classification_label = [
        int(key)
        for key, samples in validation_set_type_id_list_of_samples.items()
        for _ in samples
    ]

    # Never run more meta epochs than there are tasks to draw from.
    for meta_epoch in range(min(meta_model.meta_epochs, len(tasks))):
        log = {}
        log['meta_epoch'] = meta_epoch+1

        # Bring task.
        support_set_type_id_list_of_samples = tasks[task_indices[meta_epoch]]
        y_support_classification_label = [
            key
            for key, samples in support_set_type_id_list_of_samples.items()
            for _ in samples
        ]

        # Clone meta model.
        task_model = model_class(
            params=params,
            ffn_hidden_dims=ffn_hidden_dims,
            input_embedding_mode=input_embedding_mode,
            type_ids=type_ids,
        )

        if uses_bilstm:
            task_model.bilstm_model.load_state_dict(meta_model.bilstm_model.state_dict())
        task_model.ffn_model.load_state_dict(meta_model.ffn_model.state_dict())
        task_model.prototypes.data = meta_model.prototypes.detach().clone().to(device)
        task_model.to(device)
        task_model.train()

        # The prototypes are a plain tensor, so they are handed to the optimiser explicitly.
        trainable_tensors = list(task_model.ffn_model.parameters())+[task_model.prototypes]
        if uses_bilstm:
            trainable_tensors = list(task_model.bilstm_model.parameters())+trainable_tensors
        task_model_optimizer = optim.AdamW(
            params=trainable_tensors,
            lr=meta_model.task_learning_rate,
            weight_decay=0.01
        )

        # Cosine annealing scheduler for better optimization.
        task_model_scheduler = optim.lr_scheduler.CosineAnnealingLR(task_model_optimizer, T_max=meta_model.task_epochs)

        # Starting point of the adaptation, needed for the interpolation below.
        if uses_bilstm:
            old_task_bilstm_model_state = copy.deepcopy(task_model.bilstm_model.state_dict())
        old_task_ffn_model_state = copy.deepcopy(task_model.ffn_model.state_dict())
        old_task_prototypes_state = meta_model.prototypes.detach().clone().to(device)

        for task_epoch in range(meta_model.task_epochs):
            task_model_optimizer.zero_grad()

            y_hat_support_projection, _ = _project_in_batches(
                task_model, support_set_type_id_list_of_samples, language_model
            )

            # Calculate support loss with weighted components
            supp_loss_d = task_model.prototype_loss()*task_model.prototype_loss_weight
            supp_loss_s = task_model.span_loss(y_hat_support_projection, y_support_classification_label)
            supp_loss = supp_loss_d+supp_loss_s
            log[f'task_model_epoch_{task_epoch+1}_support_loss'] = supp_loss.item()
            log[f'task_model_epoch_{task_epoch+1}_support_loss_d'] = supp_loss_d.item()
            log[f'task_model_epoch_{task_epoch+1}_support_loss_s'] = supp_loss_s.item()
            log[f'task_model_epoch_{task_epoch+1}_task_learning_rate'] = task_model_optimizer.param_groups[0]['lr']

            supp_loss.backward()
            torch.nn.utils.clip_grad_norm_(task_model.parameters(), 1.0) # Add gradient clipping
            task_model_optimizer.step()
            task_model_scheduler.step()
            print(f"Task epoch {task_epoch+1}, Support loss {supp_loss.item():.4f}")

        # Meta update: move the meta model towards the adapted task model.
        if uses_bilstm:
            _interpolate_towards_task(
                meta_model.bilstm_model,
                old_task_bilstm_model_state,
                copy.deepcopy(task_model.bilstm_model.state_dict()),
                meta_model.meta_learning_rate,
            )
        _interpolate_towards_task(
            meta_model.ffn_model,
            old_task_ffn_model_state,
            copy.deepcopy(task_model.ffn_model.state_dict()),
            meta_model.meta_learning_rate,
        )

        new_task_prototypes_state = task_model.prototypes.detach().clone().to(device)
        updated_prototypes = old_task_prototypes_state + (new_task_prototypes_state-old_task_prototypes_state)*meta_model.meta_learning_rate
        meta_model.prototypes.data = F.normalize(updated_prototypes, p=2, dim=1)

        # Task validation
        should_stop = False
        meta_model.eval()
        with torch.no_grad():
            y_hat_validation_projection, y_hat_validation_classification_label = _project_in_batches(
                meta_model, validation_set_type_id_list_of_samples, language_model
            )

            # Calculate validation F1 score.
            y_validation_classification_label_tensor = torch.tensor(y_validation_classification_label)
            y_hat_validation_classification_label_tensor = torch.tensor(y_hat_validation_classification_label)
            val_total_f1_score_value = f1_score(y_validation_classification_label_tensor, y_hat_validation_classification_label_tensor, average='macro')
            log['meta_model_validation_f1_score_percent'] = val_total_f1_score_value*100

            # Calculate validation loss.
            val_loss_d = meta_model.prototype_loss()*meta_model.prototype_loss_weight
            val_loss_s = meta_model.span_loss(y_hat_validation_projection, y_validation_classification_label)
            val_loss = val_loss_d+val_loss_s

            log['meta_model_validation_loss'] = val_loss.item()
            log['meta_model_validation_loss_d'] = val_loss_d.item()
            log['meta_model_validation_loss_s'] = val_loss_s.item()
            print(f"\n==== Meta epoch {meta_epoch+1}, Validation F1 score {val_total_f1_score_value*100:.4f}%, Validation loss {val_loss.item():.4f} ====\n")

            if val_total_f1_score_value > local_maxima_f1_score:
                local_maxima_f1_score = val_total_f1_score_value
                local_maxima_meta_model_state = snapshot_meta_model()
                patience_counter = 0
                if local_maxima_f1_score > best_f1_score:
                    best_f1_score = local_maxima_f1_score
                    # A new global best restarts the search for the next local maximum.
                    local_maxima_f1_score = 0
                    best_meta_model_state = local_maxima_meta_model_state
                    # Perfect validation score: nothing left to improve.
                    if (1.0-best_f1_score) < 1e-8:
                        should_stop = True
            else:
                patience_counter += 1
                if patience_counter >= meta_model.adapt_patience:
                    should_stop = True

        # Record the epoch before leaving the loop so the final epoch is never lost.
        logs.append(log)
        if should_stop:
            break

    # Restore the best validated state; fall back to the latest local maximum, and leave
    # the model untouched when validation F1 never rose above zero.
    state_to_restore = best_meta_model_state if best_meta_model_state is not None else local_maxima_meta_model_state
    if state_to_restore is not None:
        if uses_bilstm:
            meta_model.bilstm_model.load_state_dict(state_to_restore['bilstm'])
        meta_model.ffn_model.load_state_dict(state_to_restore['ffn'])
        meta_model.prototypes.data = state_to_restore['prototypes'].to(device)

    # Drop every reference to the last task model and its optimiser state so the
    # collector and the CUDA caching allocator can actually release the memory.
    best_meta_model_state = local_maxima_meta_model_state = state_to_restore = None
    task_model = task_model_optimizer = task_model_scheduler = None
    gc.collect()
    torch.cuda.empty_cache()

    log_df = pd.DataFrame(logs)

    return log_df

def make_task_set(task_set_type_id_list_of_samples, meta_epochs, shot_sample_number, num_classes):
    """Build one support set per class for every meta epoch.

    For each class, `meta_epochs` support sets of `shot_sample_number` samples
    are drawn with `random.sample`, and every draw is sorted. Draws of one class
    are kept pairwise distinct for as long as the class still has an unused
    combination. Once all of them are used, repeated support sets are accepted,
    so the function also terminates for classes that are too small to provide
    `meta_epochs` different combinations.

    Args:
        task_set_type_id_list_of_samples: Mapping from type id (anything `int()`
            accepts, e.g. a JSON string key) to the list of samples of that
            class. The samples of one class must be mutually orderable because
            each draw is sorted.
        meta_epochs: Number of tasks to build.
        shot_sample_number: Number of support samples per class in each task.
        num_classes: Number of classes; every type id in 0..num_classes-1 must
            be present in the mapping.

    Returns:
        A list of `meta_epochs` dicts, each mapping a type id (int) to the
        sorted list of support samples of that class for the task.

    Raises:
        ValueError: Propagated from `random.sample` when a class has fewer than
            `shot_sample_number` samples.
        KeyError: If a type id in 0..num_classes-1 is missing from the mapping.
    """
    type_id_combinations = {}
    for type_id, list_of_samples in tqdm(task_set_type_id_list_of_samples.items()):
        # Number of different support sets this class can provide. Rejecting
        # duplicates beyond this point could never succeed and would loop forever.
        # NOTE(review): this count assumes the samples of a class are distinct.
        max_unique = math.comb(len(list_of_samples), shot_sample_number)
        combinations = []
        while len(combinations) < meta_epochs:
            # Sorting makes two draws of the same samples compare equal.
            support_samples = sorted(random.sample(list_of_samples, shot_sample_number))
            if len(combinations) >= max_unique or support_samples not in combinations:
                combinations.append(support_samples)
        type_id_combinations[int(type_id)] = combinations

    # Task i is made of the i-th support set of every class.
    return [
        {class_index: type_id_combinations[class_index][task_index] for class_index in range(num_classes)}
        for task_index in range(meta_epochs)
    ]

In [None]:
# Experiment configuration.
meta_epochs = 200
shot_sample_number = 5
task_sample_ratio = 0.3
input_embedding_mode = 'fasttext'
embedding_mode_to_input_dim = {'bert': 768, 'fasttext': 300}
ffn_hidden_dims = [1024]
language_model = load_language_model(input_embedding_mode)

# Read the per-type sample lists of the task set and of the validation set.
with open(f'{dir}/ablation_study/task_set_type_id_list_of_samples.json', 'r') as f:
    task_set_type_id_list_of_samples = json.load(f)
with open(f'{dir}/ablation_study/validation_set_type_id_list_of_samples.json', 'r') as f:
    validation_set_type_id_list_of_samples = json.load(f)

# Due to limited computing resources, keep a random subset of every type:
# at most 30000 samples per type in the task set and 500 in the validation set.
for sample_sets, per_type_cap in (
    (task_set_type_id_list_of_samples, 30000),
    (validation_set_type_id_list_of_samples, 500),
):
    # Iterate over a snapshot of the items, since the values are replaced in the loop.
    for type_id, list_of_samples in list(sample_sets.items()):
        scrambled = list(list_of_samples)
        random.shuffle(scrambled)
        sample_sets[type_id] = scrambled[:per_type_cap]

# Hyperparameters. Both embedding modes share every setting except the meta
# learning rate, the task learning rate and the number of task epochs.
mode_specific_params = {
    'bert': {'meta_learning_rate': 0.5, 'task_learning_rate': 5e-4, 'task_epochs': 3},
    'fasttext': {'meta_learning_rate': 0.4, 'task_learning_rate': 1e-3, 'task_epochs': 5},
}

if input_embedding_mode in mode_specific_params:
    tuned = mode_specific_params[input_embedding_mode]
    params = {
        'train_test_split_ratio': task_sample_ratio,
        'num_classes': len(type_ids),
        'dropout': 0.5,
        'input_dim': embedding_mode_to_input_dim[input_embedding_mode],
        'max_input_tokens_length': 300,
        'positional_embedding_dim': 200,
        'bilstm_hidden_dim': 1024,
        'bilstm_layers': 1,
        'lstm_embedding_dim': 512,
        'projection_embedding_dim': 50,
        'temp': 0.1,
        'prototype_train_learning_rate': 0.2,
        'prototype_train_epochs': 1000,
        'prototype_train_patience': 5,
        'prototype_num_per_class': 10,
        'shot_sample_number': shot_sample_number,
        'meta_learning_rate': tuned['meta_learning_rate'],
        'task_learning_rate': tuned['task_learning_rate'],
        'meta_epochs': meta_epochs,
        'task_epochs': tuned['task_epochs'],
        'adapt_patience': 100,
        'batch_size': 256,
        'prototype_loss_weight': 1.0,
    }

# The multiple-prototype ablation uses a single prototype per class.
if ablation == "multiple_prototype":
    params['prototype_num_per_class'] = 1

# Training log; the log of every training phase is appended to it.
log_df = pd.DataFrame()

# Meta model that is updated by both training phases.
model = ReProCon(
    params=params,
    ffn_hidden_dims=ffn_hidden_dims,
    input_embedding_mode=input_embedding_mode,
    type_ids=type_ids,
)

# Keyword arguments that are the same for both train() calls.
shared_train_kwargs = dict(
    model_class=ReProCon,
    meta_model=model,
    params=params,
    ffn_hidden_dims=ffn_hidden_dims,
    input_embedding_mode=input_embedding_mode,
    type_ids=type_ids,
    validation_set_type_id_list_of_samples=validation_set_type_id_list_of_samples,
    language_model=language_model,
)

# Phase 1: train on tasks drawn from the full task set.
print("Making initial task set...")
tasks = make_task_set(task_set_type_id_list_of_samples, params['meta_epochs'], params['shot_sample_number'], len(type_ids))
add_log_df = train(tasks=tasks, is_initial=True, **shared_train_kwargs)
log_df = pd.concat([log_df, add_log_df], ignore_index=True, axis=0)

# Phase 2 (skipped by the hard-negative ablation): train again on tasks drawn
# from the hard negatives selected by the model trained in phase 1.
if ablation != "hard_negative":
    hard_negative_set_type_id_list_of_samples = model.select_hard_negatives(task_set_type_id_list_of_samples, language_model)

    print("Making hard negative task set...")
    tasks = make_task_set(hard_negative_set_type_id_list_of_samples, params['meta_epochs'], params['shot_sample_number'], len(type_ids))
    add_log_df = train(tasks=tasks, is_initial=False, **shared_train_kwargs)
    log_df = pd.concat([log_df, add_log_df], ignore_index=True, axis=0)

# Drop the extra reference to the model so that deleting `model` can release it.
del shared_train_kwargs

# Make sure the output directory of this ablation exists before writing to it:
# neither DataFrame.to_csv nor torch.save creates missing parent directories,
# and failing here would lose the results of the whole training run.
import os
os.makedirs(f'{dir}/ablation_study/{ablation}', exist_ok=True)

# Save the training log.
log_df.to_csv(f'{dir}/ablation_study/{ablation}/train_log.csv', index=False)

# Save the model weights and the prototypes. The BiLSTM weights are only saved
# in 'fasttext' mode.
if input_embedding_mode == 'fasttext':
    torch.save(model.bilstm_model.state_dict(), f'{dir}/ablation_study/{ablation}/bilstm_model.pt')
torch.save(model.ffn_model.state_dict(), f'{dir}/ablation_study/{ablation}/ffn_model.pt')
torch.save(model.prototypes, f'{dir}/ablation_study/{ablation}/prototypes.pt')

# Plot the pairwise dot products between all prototype vectors as a heatmap
# and save the image next to the other outputs of this ablation.
sim_matrix = torch.mm(model.prototypes, model.prototypes.t()).detach().cpu().numpy()

# Diverging colormap: dark red (low) -> white -> dark green (high).
custom_cmap = LinearSegmentedColormap.from_list("red_white_green", ["darkred", "white", "darkgreen"], N=256)

# One tick label per prototype: the type id on the first prototype of each type
# and an empty label on the remaining prototypes of that type.
blank_labels = [''] * (model.prototype_num_per_class - 1)
expanded_type_ids = [label for type_id in model.type_ids for label in (type_id, *blank_labels)]

plt.figure(figsize=(14, 12))
sns.heatmap(
    sim_matrix,
    annot=False,
    cmap=custom_cmap,
    vmin=-1,  # the color scale is fixed to [-1, 1]
    vmax=1,
    xticklabels=expanded_type_ids,
    yticklabels=expanded_type_ids,
)
plt.title('Similarity Matrix between Prototype Vectors')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0, ha='right')
plt.tight_layout()
plt.savefig(f'{dir}/ablation_study/{ablation}/prototype_similarity_heatmap.png')

# Release GPU memory: call model.cpu() (presumably moves the weights off the
# GPU; ReProCon is defined outside this cell), drop the last reference to the
# model, collect garbage, then empty PyTorch's CUDA cache.
model = model.cpu()
del model
gc.collect()
torch.cuda.empty_cache()

# Release the remaining large Python objects of this run.
# The hard-negative set only exists when the second training phase was run.
if ablation != "hard_negative":
    del hard_negative_set_type_id_list_of_samples

del tasks, add_log_df, log_df, task_set_type_id_list_of_samples
gc.collect()

In [None]:
# Release the Colab runtime once the run above has finished.
# NOTE(review): presumably any state not written to disk is gone after this call - confirm.
from google.colab import runtime

runtime.unassign()