In [19]:
!pip install -q torchmetrics transformers datasets pytorch-lightning
!pip install -q --upgrade ipython ipykernel


import pytorch_lightning as pl
import json
from math import ceil
import random
from torch.utils.data import DataLoader
from torch.optim import Adam 
from torch.nn.functional import cross_entropy
from transformers import AutoTokenizer, AutoModelForSequenceClassification 
from torch.nn import Linear, Sigmoid
import torch
import torchmetrics
from google.colab import drive
from tqdm.auto import tqdm


drive.mount("/content/drive")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [20]:
BATCH_SIZE = 128
#Change paths to downloaded datasets from https://doi.org/10.7910/DVN/TTP7AO
GDRIVE_PATH = "/content/drive/MyDrive/Hackathon: European Patent Office (EPO)/data"
GDRIVE_OUT_PATH = "/content/drive/MyDrive/Hackathon: European Patent Office (EPO)/model_weights/supervised"
DATASET = "Y02A"

In [21]:
DEBUG = False
def log(x, transform=None):
    if DEBUG:
        print(
            transform(x) 
            if transform is not None 
            else x
        )
    return x
    

In [22]:
class DataModule(pl.LightningDataModule):
  def __init__(self, 
               dir_sus: str = f'{GDRIVE_PATH}ds_WO_{DATASET}_complete_final_new.json',
               dir_other: str = f'{GDRIVE_PATH}ds_WO_not_Y02_complete_final_new.json',
               tokenizer=AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2')):
    super().__init__()
    self.dir_sus=dir_sus
    self.dir_other=dir_other
    self.dataset={}
    self.train_batch_size=BATCH_SIZE
    self.data = []
    self.tokenizer=tokenizer

  def prepare_data(self):
      with open(self.dir_sus, 'r') as file:
          dsY02W= json.load(file)
      with open(self.dir_other, 'r') as file:
          dsOTHER= json.load(file)

      data = [{"label": "dsY02W", "text": ''.join(dsY02W[i][1])} for i in range(0,len(dsY02W))]
      data.extend([{"label": "dsOTHER", "text": ''.join(dsOTHER[i][1])} for i in range(0,len(dsOTHER))])
      data = [
        {
            "label": torch.Tensor([1,0]) if label == "dsY02W" else torch.Tensor([0,1]),
            "tokens": encoding.input_ids,
            "attention_mask": encoding.attention_mask, 
        }
        for label, encoding in [
            log((entry["label"], self.tokenizer(entry["text"], max_length=500, truncation=True)))
            for entry in data
        ]
      ]
      random.shuffle(data)
      self.data = data


  def setup(self, stage):
      self.dataset["train"]= self.data[:ceil(len(self.data)*0.6)]
      self.dataset["val"]= self.data[ceil(len(self.data)*0.6):ceil(len(self.data)*0.8)]
      self.dataset["test"]= self.data[ceil(len(self.data)*0.8):]
  
  def train_dataloader(self):
      return DataLoader(self.dataset["train"], batch_size=self.train_batch_size)
  
  def val_dataloader(self):
    return DataLoader(self.dataset["val"], batch_size=self.train_batch_size)

In [23]:
class Model(pl.LightningModule):
  def __init__(self, module, lr, loss_fn):
    super().__init__()
    self.module = module
    self.lr = lr
    self.loss_fn = loss_fn
    self.accuracy = torchmetrics.Accuracy(task='binary', multidim_average='samplewise')
    self.recall = torchmetrics.Recall(task='binary', multidim_average='samplewise')
    self.f1 = torchmetrics.F1Score(task='binary', multidim_average='samplewise')
    self.batch_size = BATCH_SIZE

  def forward(self, tokens, mask):
    return self.module(tokens, mask).logits

  def training_step(self, batch, batch_index):
    tokens=torch.stack(batch['tokens'],1).to(self.device)
    label=batch['label'].to(self.device, dtype=torch.float)
    mask=torch.stack(batch['attention_mask'],1).to(self.device)
    
    #mask, label = map(lambda x: x.float(), [mask, label])
    label_pred = self(tokens, mask)
    loss = self.loss_fn(
        label_pred,
        label
    )

    accuracy = self.accuracy(label_pred, label)
    recall = self.recall(label_pred, label)
    f1 = self.f1(label_pred, label)

    self.log_dict({"train_loss": loss,
                   f"train_accuracy_{DATASET}": accuracy[0],
                   "train_accuracy_notY02": accuracy[1],
                   f"train_recall_{DATASET}": recall[0],
                   "train_recall_notY02": recall[1],
                   f"train_f1_{DATASET}": f1[0],
                   "train_f1_notY02": f1[1],
                   })
    return loss

  def validation_step(self, batch, batch_idx):
    tokens=torch.stack(batch['tokens'],1).to(self.device)
    label=batch['label'].to(self.device, dtype=torch.float)
    mask=torch.stack(batch['attention_mask'],1).to(self.device)

    label_pred = self(tokens, mask)
    loss = self.loss_fn(
        label_pred,
        label
    )

    accuracy = self.accuracy(label_pred, label)
    recall = self.recall(label_pred, label)
    f1 = self.f1(label_pred, label)

    self.log_dict({"val_loss": loss,
                   f"val_accuracy_{DATASET}": accuracy[0],
                   "val_accuracy_notY02": accuracy[1],
                   f"val_recall_{DATASET}": recall[0],
                   "val_recall_notY02": recall[1],
                   f"val_f1_{DATASET}": f1[0],
                   "val_f1_notY02": f1[1],
                   })

  def configure_optimizers(self):
    return Adam(self.parameters(), lr=self.lr)

In [24]:
for cls in tqdm(["C", "D", "E", "P", "T", "W", "A", "B"], desc="All datasets"):
    cls = f"Y02{cls}"

    model = AutoModelForSequenceClassification.from_pretrained(
        'dmis-lab/biobert-base-cased-v1.2',
    )
    # Freezing everything but the last 2 layers
    for p in [p for p in model.parameters()][:-4]:
        p.requires_grad = False

    pl.Trainer(
        logger=pl.loggers.TensorBoardLogger(
            save_dir=f"{GDRIVE_OUT_PATH}/logs/{cls}",
        ),
        # gpus=1,
        max_epochs=4,
        callbacks=[
            pl.callbacks.EarlyStopping(monitor="val_loss"),
            pl.callbacks.ModelCheckpoint(
                dirpath=f"{GDRIVE_OUT_PATH}/{cls}",
                save_top_k=2,
                monitor=f"val_loss",
            ),
        ],
        # fast_dev_run=True,
    ).fit(
        Model(
            model,
            lr=0.01,
            loss_fn=cross_entropy,
        ),
        datamodule=DataModule(
            dir_sus=f"{GDRIVE_PATH}/ds_WO_{cls}_complete_final_new.json",
            dir_other=f"{GDRIVE_PATH}/ds_WO_not_Y02_complete_final_new.json",
            tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
        ),
    )

All datasets:   0%|          | 0/8 [00:00<?, ?it/s]

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.2 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification we

Sanity Checking: 0it [00:00, ?it/s]

RuntimeError: ignored