In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import random
import time
import datetime
import json
import pandas as pd
import numpy as np
import torch
import pickle
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from transformers import BertForSequenceClassification, BertTokenizer, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint

In [142]:
from data import DataCollector, Tokenizer, PrototypicalBatchSampler, SequenceClassificationDataset,FewShotBatchSampler
from models import BiLSTMEncoder

In [143]:
train_labels_subset = [0,1,2,3]
dev_labels_subset = [1,2,3]
test_labels_subset = [4,5,6]

In [144]:
data_root_path = '../../data/bert_sentence_independent_embeddings/'

train_data_path = Path(data_root_path,'CL_train.pkl')
val_data_path = Path(data_root_path,'CL_train.pkl')
test_data_path = Path(data_root_path,'CL_train.pkl')


train_data = SequenceClassificationDataset(train_data_path, train_labels_subset)
val_data = SequenceClassificationDataset(val_data_path, dev_labels_subset)
test_set = SequenceClassificationDataset(test_data_path, test_labels_subset)

loaded data with 5803 sentence embedding with labels subset [0 1 2 3]
loaded data with 2523 sentence embedding with labels subset [1 2 3]
loaded data with 4737 sentence embedding with labels subset [4 5 6]


In [145]:
N = 4
K = 5
SAVE_PATH = './logs'

In [146]:
train_sampler = FewShotBatchSampler(torch.tensor(train_data.get_labels()), N, K,shuffle=True)
val_sampler = FewShotBatchSampler(torch.tensor(val_data.get_labels()), N, K, shuffle=False, shuffle_once=True)

In [147]:
train_dataloader = torch.utils.data.DataLoader(train_data, batch_sampler=train_sampler,  num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_data, batch_sampler=val_sampler, num_workers=4)

In [148]:
def split_batch(batch):
    sentences, labels = batch['embeddings'], batch['labels']
    support, query = sentences.chunk(2, dim=0)
    support_targets, query_targets = labels.chunk(2, dim=0)
    return support, query, support_targets, query_targets

In [149]:
class ProtoNet(L.LightningModule):
    def __init__(self, input_size, hidden_size, lr):
        """Inputs.

        proto_dim - Dimensionality of prototype feature space
        lr - Learning rate of Adam optimizer
        """
        super().__init__()
        self.save_hyperparameters()
        self.model = BiLSTMEncoder(input_size, hidden_size)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[140, 180], gamma=0.1)
        return [optimizer], [scheduler]

    @staticmethod
    def calculate_prototypes(features, targets):
        # Given a stack of features vectors and labels, return class prototypes
        # features - shape [N, proto_dim], targets - shape [N]
        classes, _ = torch.unique(targets).sort()  # Determine which classes we have
        prototypes = []
        for c in classes:
            p = features[torch.where(targets == c)[0]].mean(dim=0)  # Average class feature vectors
            prototypes.append(p)
        prototypes = torch.stack(prototypes, dim=0)
        # Return the 'classes' tensor to know which prototype belongs to which class
        return prototypes, classes

    def classify_feats(self, prototypes, classes, feats, targets):
        # Classify new examples with prototypes and return classification error
        dist = torch.pow(prototypes[None, :] - feats[:, None], 2).sum(dim=2)  # Squared euclidean distance
        preds = F.log_softmax(-dist, dim=1)
        labels = (classes[None, :] == targets[:, None]).long().argmax(dim=-1)
        acc = (preds.argmax(dim=1) == labels).float().mean()
        
        from sklearn.metrics import f1_score
        # Calculate F1 score
        preds_flat = preds.argmax(dim=1).cpu().numpy()
        labels_flat = labels.cpu().numpy()
        f1 = f1_score(labels_flat, preds_flat, average='weighted')
        
        return preds, labels, acc, f1

    def calculate_loss(self, batch, mode):
        # Determine training loss for a given support and query set
        support_feats, query_feats, support_targets, query_targets = split_batch(batch)
        # Enocde the data
        support_feats = self.model(support_feats)
        query_feats = self.model(query_feats)
        
        prototypes, classes = ProtoNet.calculate_prototypes(support_feats, support_targets)
        
        preds, labels, acc, f1 = self.classify_feats(prototypes, classes, query_feats, query_targets)
        
        loss = F.cross_entropy(preds, labels)

        self.log("%s_loss" % mode, loss)
        self.log("%s_acc" % mode, acc)
        self.log("%s_f1" % mode, f1)
        return loss

    def training_step(self, batch, batch_idx):
        return self.calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self.calculate_loss(batch, mode="val")

In [150]:
from lightning.pytorch.loggers import TensorBoardLogger

def train_model(model_class, train_loader, val_loader, **kwargs):
    os.makedirs(SAVE_PATH, exist_ok=True)  # Create the directory if it doesn't exist
    
    trainer = L.Trainer(
        default_root_dir=SAVE_PATH,
        accelerator="auto",
        devices=1,
        max_epochs=50,
        logger=TensorBoardLogger(SAVE_PATH),
        callbacks=[
            ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
            LearningRateMonitor("epoch"),
        ],
        enable_progress_bar=True,
    )
    trainer.logger._default_hp_metric = None

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(SAVE_PATH, model_class.__name__ + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained model at %s, loading..." % pretrained_filename)
        # Automatically loads the model with the saved hyperparameters
        model = model_class.load_from_checkpoint(pretrained_filename)
    else:
        L.seed_everything(42)  # To be reproducable
        model = model_class(**kwargs)
        print(model)
        trainer.fit(model, train_loader, val_loader)
        model = model_class.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )  # Load best checkpoint after training

    return model

In [151]:
import torch.optim as optim
input_size, hidden_size = 768, 384
lr = 2e-4

protonet_model = train_model(
    ProtoNet, train_loader=train_dataloader, val_loader=val_dataloader, input_size=input_size, hidden_size=hidden_size, lr=lr)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42

  | Name  | Type          | Params
----------------------------------------
0 | model | BiLSTMEncoder | 1.5 M 
----------------------------------------
1.5 M     Trainable params
0         Non-trainable params
1.5 M     Total params
5.911     Total estimated model params size (MB)


ProtoNet(
  (model): BiLSTMEncoder(
    (bilstm): LSTM(768, 192, batch_first=True, bidirectional=True)
  )
)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x123088220>
Traceback (most recent call last):
  File "/Users/ahmed/Lnlp/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/Users/ahmed/Lnlp/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/ahmed/.pyenv/versions/3.11.3/lib/python3.11/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ahmed/.pyenv/versions/3.11.3/lib/python3.11/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ahmed/.pyenv/versions/3.11.3/lib/python3.11/multiprocessing/connection.py", line 930, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/User

RuntimeError: DataLoader worker (pid(s) 7026, 7042, 7045, 7062) exited unexpectedly

In [152]:
protonet_model

ProtoNet(
  (model): BiLSTMEncoder(
    (bilstm): LSTM(768, 192, batch_first=True, bidirectional=True)
  )
)

In [285]:
pretrained_filename = os.path.join('./logs/lightning_logs/version_4/checkpoints',"epoch=87-step=47168.ckpt")

In [26]:
os.listdir('./logs/lightning_logs/version_4/checkpoints')

['epoch=87-step=47168.ckpt']

In [286]:
if os.path.exists(pretrained_filename):
    print("Found pretrained model at %s, loading..." % pretrained_filename)
    # Automatically loads the model with the saved hyperparameters
    model = ProtoNet.load_from_checkpoint(pretrained_filename)


Found pretrained model at ./logs/lightning_logs/version_4/checkpoints/epoch=87-step=47168.ckpt, loading...


In [30]:
model

ProtoNet(
  (model): BiLSTMEncoder(
    (bilstm): LSTM(768, 192, batch_first=True, bidirectional=True)
  )
)

In [287]:
device = model.device
device

device(type='cpu')

In [288]:
from data import PrototypicalBatchSampler

test_set = SequenceClassificationDataset(test_data_path, test_labels_subset)

loaded data with 4737 sentence embedding with labels subset [4 5 6]


In [289]:
from tqdm import tqdm

In [290]:
@torch.no_grad()
def test_proto_net(model, dataset, data_feats=None, k_shot=4):
    """Inputs.

    model - Pretrained ProtoNet model
    dataset - The dataset on which the test should be performed.
              Should be instance of ImageDataset
    data_feats - The encoded features of all images in the dataset.
                 If None, they will be newly calculated, and returned
                 for later usage.
    k_shot - Number of examples per class in the support set.
    """
    model = model.to(device)
    model.eval()
    num_classes = len(np.unique(dataset.get_labels()))
    exmps_per_class = len(dataset.get_labels()) // num_classes  # We assume uniform example distribution here

    # The encoder network remains unchanged across k-shot settings. Hence, we only need
    # to extract the features for all images once.
    if data_feats is None:
        # Dataset preparation
        dataloader = DataLoader(dataset, batch_size=128, num_workers=4, shuffle=False, drop_last=False)

        sent_features = []
        sent_targets = []
        for batch in tqdm(dataloader, "Extracting image features", leave=False, position=0):
            sent, targets = batch['embeddings'],batch['labels']
            
            sent = sent.to(device)
            feats = model.model(sent)
            
            sent_features.append(feats.detach().cpu())
            sent_targets.append(targets)
        sent_features = torch.cat(sent_features, dim=0)
        sent_targets = torch.cat(sent_targets, dim=0)
        # Sort by classes, so that we obtain tensors of shape [num_classes, exmps_per_class, ...]
        # Makes it easier to process later
        sent_targets, sort_idx = sent_targets.sort()
        sent_targets = sent_targets.reshape(num_classes, exmps_per_class).transpose(0, 1)
        sent_features = sent_features[sort_idx].reshape(num_classes, exmps_per_class, -1).transpose(0, 1)
    else:
        sent_features, sent_targets = data_feats
        
    #return

    # We iterate through the full dataset in two manners. First, to select the k-shot batch.
    # Second, the evaluate the model on all other examples
    accuracies = []
    f1s = []
    for k_idx in tqdm(range(0, sent_features.shape[0], k_shot), "Evaluating prototype classification", leave=False, position=0):
        # Select support set and calculate prototypes
        k_sent_feats = sent_features[k_idx : k_idx + k_shot].flatten(0, 1)
        k_targets = sent_targets[k_idx : k_idx + k_shot].flatten(0, 1)
        prototypes, proto_classes = model.calculate_prototypes(k_sent_feats, k_targets)
        # Evaluate accuracy on the rest of the dataset
        batch_acc = 0
        batch_f1 = 0
        for e_idx in range(0, sent_features.shape[0], k_shot):
            if k_idx == e_idx:  # Do not evaluate on the support set examples
                continue
            e_sent_feats = sent_features[e_idx : e_idx + k_shot].flatten(0, 1)
            e_targets = sent_targets[e_idx : e_idx + k_shot].flatten(0, 1)
            _, _, acc, f1 = model.classify_feats(prototypes, proto_classes, e_sent_feats, e_targets)
            batch_acc += acc.item()
            batch_f1 += f1
        
        batch_acc /= sent_features.shape[0] // k_shot - 1
        batch_f1 /= sent_features.shape[0] // k_shot - 1
        accuracies.append(batch_acc)
        f1s.append(batch_f1)

    return (mean(accuracies), stdev(accuracies)), (mean(f1s), stdev(f1s)), (sent_features, sent_targets)

In [291]:
from statistics import mean, stdev
protonet_accuracies = dict()
protonet_f1s = dict()
data_feats = None
for k in [2, 4, 8, 16, 32]:
    protonet_accuracies[k], protonet_f1s[k], data_feats = test_proto_net(protonet_model, test_set, data_feats=data_feats, k_shot=k)

                                                                                

Accuracy for k=2: 51.15% (+-10.64%)


                                                                                

Accuracy for k=4: 51.11% (+-9.29%)


                                                                                

Accuracy for k=8: 50.66% (+-8.50%)


                                                                                

Accuracy for k=16: 50.79% (+-7.70%)


                                                                                

Accuracy for k=32: 50.98% (+-7.27%)




In [None]:
print(
    "Accuracy for k=%i: %4.2f%% (+-%4.2f%%)"
    % (k, 100.0 * protonet_accuracies[k][0], 100 * protonet_accuracies[k][1])
)

In [292]:
for k in [2, 4, 8, 16, 32]:
    print(
        "F1-score for k=%i: %4.2f%% (+-%4.2f%%)"
        % (k, 100.0 * protonet_f1s[k][0], 100 * protonet_f1s[k][1])
    )

F1-score for k=2: 49.23% (+-10.41%)
F1-score for k=4: 51.08% (+-9.09%)
F1-score for k=8: 51.74% (+-8.35%)
F1-score for k=16: 52.54% (+-7.76%)
F1-score for k=32: 53.31% (+-7.41%)
