In [17]:
# Authenticate this session with the Hugging Face Hub through the interactive
# login widget rendered below the cell.
# NOTE(review): presumably needed for the (currently commented-out)
# push_to_hub calls and for loading the Hub datasets if they are private —
# confirm before removing.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [1]:
import torch
from trainer import *
from torch.utils.data import DataLoader as TorchLoader
from transformers import BertModel, BertTokenizer
from model import MiniMoELoadWeights, MiniMoE

In [2]:
# Checkpoint that provides both the tokenizer and the base encoder weights.
BASE_CHECKPOINT = 'sentence-transformers/all-MiniLM-L6-v2'

tokenizer = BertTokenizer.from_pretrained(BASE_CHECKPOINT)
# The dropout overrides are handed to from_pretrained so that they are already
# in the config when the model is constructed. Assigning them on
# base_model.config afterwards (as this cell did before) only edits the config
# object: dropout modules that were built from the checkpoint's config keep
# their original rate. base_model.config still ends up with 0.05 for both
# values, so anything that reads the config sees the same numbers as before.
base_model = BertModel.from_pretrained(
    BASE_CHECKPOINT,
    attention_probs_dropout_prob=0.05,
    hidden_dropout_prob=0.05,
)
#base_model.max_position_embeddings = 2048 # this would be a good addition
# NOTE(review): the domain order here ('[COPD]' first) differs from
# config.domains ('[CVD]' first) that is passed to get_datasets — confirm that
# the router labels produced for the datasets line up with the expert order
# the loader derives from this list.
loader = MiniMoELoadWeights(base_model=base_model, tokenizer=tokenizer, domains=['[COPD]', '[CVD]', '[HSP90]', '[TGFB]'])
# get_seeded_model returns the model to wrap and a tokenizer that replaces the
# one loaded above.
model, tokenizer = loader.get_seeded_model()
mini = MiniMoE(model)

Some weights differ
All weights match
Model loaded in  0.01 minutes
44.0 million total parameters
36.9 million effective parameters
Approximately 0.16 GB of memory in fp32



In [21]:
#copd_dataset.push_to_hub('lhallee/abstract_domain_copd')
#cvd_dataset.push_to_hub('lhallee/abstract_domain_cvd')

Uploading the dataset shards:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/91 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/91 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

In [9]:
class config:
    """Hyperparameters and data settings shared by the notebook cells."""
    # Hugging Face Hub dataset repositories, listed in the same order as `domains`.
    data_paths = ['lhallee/abstract_domain_cvd', 'lhallee/abstract_domain_copd']
    # Number of passes over the training loader.
    epochs = 100
    # Domain tokens passed to get_datasets together with data_paths.
    domains = ['[CVD]', '[COPD]']
    # Examples per batch for every DataLoader built from this config.
    batch_size = 20
    # AdamW learning rate.
    lr = 1e-4
    # Run validation every `validate_interval` training batches.
    validate_interval = 100
    # Number of consecutive non-improving validations tolerated before early
    # stopping. train_test reads config.patience; the attribute was missing,
    # so the first non-improving validation raised AttributeError.
    # NOTE(review): 10 is a starting value — tune for the run.
    patience = 10
    # Train on the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# get_datasets (from trainer) returns the train / valid / test datasets for the
# configured Hub paths and domain tokens.
train_dataset, valid_dataset, test_dataset = get_datasets(config.data_paths, tokenizer, config.domains)
# Only the training split is shuffled. The evaluation loaders keep a fixed
# order so that repeated validation / test passes iterate the batches
# identically and their results are comparable between runs.
train_loader = TorchLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
valid_loader = TorchLoader(valid_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = TorchLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

In [15]:
def print_words(input_ids):
    """Decode `input_ids` with the notebook-level `tokenizer` and print the text."""
    print(tokenizer.decode(input_ids))

import random

# Eyeball 25 randomly drawn training examples: the decoded text of the first
# two fields, then the two remaining fields printed as-is (presumably the
# labels that train_test unpacks as c_labels / r_labels — confirm), then a
# blank separator line.
for _ in range(25):
    i = random.randrange(len(train_dataset))
    ex = train_dataset[i]
    for text_field in (0, 1):
        print_words(ex[text_field].input_ids.squeeze())
    for other_field in (2, 3):
        print(ex[other_field])
    print('')


[COPD] pulmonary rehabilitation for moderate chronic obstructive pulmonary disease in primary care could improve patients ’ quality of life., this study aimed to assess the efficacy of a 3 - month pulmonary rehabilitation ( pr ) program with a further 9 months of maintenance ( rhbm group ) compared with both pr for 3 months without further maintenance ( rhb group ) and usual care in improving the quality of life of patients with moderate copd., we conducted a parallel - group, randomized clinical trial in majorca primary health care in which 97 patients with moderate copd were assigned to the 3 groups., health outcomes were quality of life, exercise capacity, pulmonary function and exacerbations., we found statistically and clinically significant differences in the three groups at 3 months in the emotion dimension ( 0. 53 ; 95 % ci0. 06 - 1. 01 ) in the usual care group, ( 0. 72 ; 95 % ci0. 26 - 1. 18 ) the rhb group ( 0. 87 ; 95 % ci 0. 44 - 1. 30 ) and the rhbm group as well as in fa

In [7]:
def train_test(config, model, optimizer, train_loader, val_loader):
    """Train `model`, validating periodically and keeping the best checkpoint.

    Every `config.validate_interval` batches (never on batch 0 of an epoch)
    the model is scored with `validate`. Whenever the validation F1 improves,
    the weights are written to 'best_model.pt'. Training stops early once more
    than `config.patience` consecutive validations have failed to improve.

    Args:
        config: Object exposing `epochs`, `device`, `validate_interval` and,
            optionally, `patience`. When `patience` is absent, early stopping
            is disabled instead of raising AttributeError.
        model: Called as `model(batch1, batch2, r_labels)`; must return
            `(emba, embb, router_logits, c_loss, r_loss)`, where
            `router_logits` is a sequence of tensors that can be stacked.
        optimizer: Optimizer stepping the model's parameters.
        train_loader: Sized iterable yielding
            `(batch1, batch2, c_labels, r_labels)`; `c_labels` is not used here.
        val_loader: Loader forwarded unchanged to `validate`.

    Returns:
        `model`. If at least one checkpoint was saved during this call, the
        best one is loaded back before returning; otherwise the model is
        returned as training left it.
    """
    from collections import deque

    # Only the most recent `window` values are ever reported, so bounded
    # deques replace the lists that used to grow for the whole run.
    window = 10
    c_losses, r_losses, cos_sims, accuracies = (deque(maxlen=window) for _ in range(4))
    steps_seen = 0

    # F1 is a higher-is-better metric: start from -inf and keep the maximum.
    best_val_f1 = float('-inf')
    patience = getattr(config, 'patience', float('inf'))
    patience_counter = 0
    # Tracks whether *this* run wrote 'best_model.pt', so that a missing or
    # stale file from an earlier run is never loaded.
    saved_best = False

    for epoch in range(config.epochs):
        model.train()
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for batch_idx, (batch1, batch2, c_labels, r_labels) in pbar:
            r_labels = r_labels.to(config.device)
            # Drop the singleton dimension at index 1 of every tensor in the
            # batch dicts and move them to the target device.
            batch1 = {k:v.squeeze(1).to(config.device) for k, v in batch1.items()}
            batch2 = {k:v.squeeze(1).to(config.device) for k, v in batch2.items()}
            optimizer.zero_grad()
            emba, embb, router_logits, c_loss, r_loss = model(batch1, batch2, r_labels)
            loss = c_loss + r_loss
            loss.backward()
            optimizer.step()

            steps_seen += 1
            c_losses.append(c_loss.item())
            r_losses.append(r_loss.item())
            cos_sims.append(F.cosine_similarity(emba, embb).mean().item())
            avg_logits = torch.stack(router_logits, dim=2).transpose(1, 2).mean(dim=1) # batch_size, num_experts
            router_predictions = torch.argmax(avg_logits, dim=1)
            accuracy = (router_predictions == r_labels).float().mean().item()
            accuracies.append(accuracy)

            # Same reporting rule as before: start once more than `window`
            # steps have been recorded, averaging the last `window` of them.
            if steps_seen > window:
                avg_c_loss = np.mean(c_losses)
                avg_r_loss = np.mean(r_losses)
                avg_cos_sim = np.mean(cos_sims)
                avg_accuracy = np.mean(accuracies)
                pbar.set_description(f'Epoch {epoch} C_Loss: {avg_c_loss:.4f} R_Loss: {avg_r_loss:.4f} Cosine Similarity: {avg_cos_sim:.4f} Accuracy: {avg_accuracy:.4f}')

            if batch_idx % config.validate_interval == 0 and batch_idx > 0:
                threshold, val_f1 = validate(config, model, val_loader)
                # What validate does to the train/eval mode is not visible
                # here; make sure the remaining batches train with dropout on.
                model.train()
                print(f'Epoch {epoch} Step {batch_idx} Threshold {threshold} Val F1 ', val_f1)
                if val_f1 > best_val_f1:
                    best_val_f1 = val_f1
                    patience_counter = 0
                    torch.save(model.state_dict(), 'best_model.pt')
                    saved_best = True
                else:
                    patience_counter += 1
                    if patience_counter > patience:
                        print('Early stopping due to validation F1 not improving')
                        if saved_best:
                            model.load_state_dict(torch.load('best_model.pt'))
                        return model
    if saved_best:
        model.load_state_dict(torch.load('best_model.pt'))
    return model

In [10]:
# Optimise every MiniMoE parameter with AdamW at the configured learning rate,
# then launch training. The call is the cell's last expression, so the model
# returned by train_test becomes the cell output.
optimizer = torch.optim.AdamW(params=mini.parameters(), lr=config.lr)
train_test(
    config=config,
    model=mini,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=valid_loader,
)

  0%|          | 0/156727 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [20]:
import pandas as pd
from datasets import Dataset, DatasetDict

# Column renames applied to the raw abstract-pair CSVs. The CVD valid/test
# frames are additionally asked to map 'Label' -> 'label'.
abstract_cols = {'uid1abstract': 'a', 'uid2abstract': 'b'}
abstract_and_label_cols = {'uid1abstract': 'a', 'uid2abstract': 'b', 'Label': 'label'}

# --- COPD: read each split and normalise the pair column names ---
train_copd = pd.read_csv('COPDDS.csv').rename(columns=abstract_cols)
valid_copd = pd.read_csv('COPDValidDS.csv').rename(columns=abstract_cols)
test_copd = pd.read_csv('COPDTestDS.csv').rename(columns=abstract_cols)
# Every train and test row gets label 1; the valid frame keeps whatever
# columns the CSV provided.
for labelled_frame in (train_copd, test_copd):
    labelled_frame['label'] = 1
copd_dataset = DatasetDict({
    split: Dataset.from_pandas(split_frame)
    for split, split_frame in (('train', train_copd), ('valid', valid_copd), ('test', test_copd))
})

# --- CVD: same recipe, but valid/test also rename 'Label' ---
train_cvd = pd.read_csv('CVDDS.csv').rename(columns=abstract_cols)
valid_cvd = pd.read_csv('CVDValidDS.csv').rename(columns=abstract_and_label_cols)
test_cvd = pd.read_csv('CVDTestDS.csv').rename(columns=abstract_and_label_cols)
# NOTE(review): test_cvd has 'Label' renamed to 'label' just above and the
# column is then overwritten with 1 here — confirm the overwrite is intended.
for labelled_frame in (train_cvd, test_cvd):
    labelled_frame['label'] = 1
cvd_dataset = DatasetDict({
    split: Dataset.from_pandas(split_frame)
    for split, split_frame in (('train', train_cvd), ('valid', valid_cvd), ('test', test_cvd))
})